How to use TorchMetrics
A guide for using TorchMetrics with PyTorch or PyTorch Lightning

A critical part of good training is to properly manage validation metrics. This is required to monitor which state of the model performs best. Different tasks require different metrics to evaluate the performance of the model, and implementing them generally requires writing boilerplate code. TorchMetrics addresses this problem by providing a modular approach to define and track all the evaluation metrics.
Table of contents
- How the Metric class works
- What is a MetricCollection
- How to implement a custom metric
- Using TorchMetrics with PyTorch Lightning
To install the latest version of TorchMetrics from PyPI use:
pip install torchmetrics
1. How the Metric class works
TorchMetrics provides many ready-to-use metrics such as Accuracy, Dice, F1 Score, Recall, Mean Absolute Error, and more. All of them inherit from `Metric`, the parent class of every metric.

Before using a metric, it must be instantiated. Then, for every batch read from the dataloader, the prediction and the target are passed to the metric object, which computes the result for the current batch and saves it in an internal state that keeps track of the data seen so far. Once all the batches have been processed, the final result can be retrieved from the metric object.
Each metric inherits the following methods from the `Metric` class:
- `metric.forward(preds, target)` — computes the metric using `preds` and `target`, the prediction and target of the current batch, then updates the metric state and returns the batch result. As a shortcut, it is possible to use `metric(preds, target)`; there is no difference between the two syntaxes.
- `metric.update(preds, target)` — the same as `forward`, but without returning the metric result. If logging or printing the result for each batch is not required, this is the method that should be used because it is faster.
- `metric.compute()` — returns the result computed over all the data seen so far. It should be called after the end of all the batches.
- `metric.reset()` — clears the metric state. It has to be called at the end of each validation phase.
Summing up:
- For each batch, call `forward` or `update`.
- After the validation loader loop, call `compute` to get the final result.
- In the end, call `reset` to clear the metric state.
The following code shows the complete process:
If printing batch metrics is not needed, just replace the `metric(preds, target)` call with `metric.update(preds, target)` and delete the batch print.
2. What is a MetricCollection?
In the example above, the validation uses only a single metric, but most validation loops will likely use more than one. TorchMetrics provides `MetricCollection`, which wraps a list or a dictionary of metrics into a single callable object with the same interface as any other metric (the Composite pattern). This way, `forward`, `update`, `compute` and `reset` are called once on the collection instead of once per metric.

The code below shows how `MetricCollection` works:
3. How to implement a custom metric
Although TorchMetrics offers a wide range of metrics for different tasks, it might be necessary to implement custom ones. To implement a metric, create a class that inherits from Metric
. Then you should extend the __init__
method and override update
and compute
methods:
- `__init__()` is the method called every time an object is created. Inside it, add the required internal states with `self.add_state(state_name, default)`, where `state_name` is a string of your choice and `default` is the initial value (and the value after a reset).
- Inside the `update(preds, target)` method, update the states using `preds` and `target`.
- The `compute()` method must return the final result using the state values.
For example, Accuracy is a metric that returns the percentage of identical elements between the prediction and the target.

The procedure to implement this metric from scratch is:
- Create a state `total` to count the total number of values.
- Create a state `correct` to count the values equal to the target.
- `update(preds, target)` must increment the `total` state by the number of values and the `correct` state by the number of values in `preds` that are equal to `target`.
- `compute()` returns `correct` divided by `total`, multiplied by 100.
This is the complete implementation:
4. Using TorchMetrics with PyTorch Lightning
TorchMetrics combines well with PyTorch Lightning to further reduce boilerplate code. If you have never heard of PyTorch Lightning, it is a framework that simplifies model training code. For further information, refer to their website. If you don't use PyTorch Lightning, just skip this section.
Through the Lightning method `self.log_dict(collection, on_step, on_epoch)` it is possible to log the metric collection object. By logging the metric object itself, PyTorch Lightning takes care of when to compute and reset the metric.
Set `on_step=True` to log the metrics for each batch. Set `on_epoch=True` to log the epoch metrics (i.e. the results computed over all the batches). If both are set to True, step metrics and epoch metrics are both logged. For more information about PyTorch Lightning logging, check their documentation.
The following code shows a PyTorch Lightning module that uses TorchMetrics to handle the metrics:
This way, during the training stage the metrics are logged for each batch, while during the validation stage the step results are accumulated and only the final metrics are logged at the end of the epoch. The `compute` and `reset` methods are not called explicitly because PyTorch Lightning handles them by itself.
To print the metrics at the end of each validation epoch, this piece of code has to be added (note that `validation_epoch_end` was removed in Lightning 2.0 in favor of the `on_validation_epoch_end` hook):

```python
def validation_epoch_end(self, outputs):
    results = self.metric_collection.compute()
    print(results)
    self.metric_collection.reset()
```
But in most cases, logging the metrics is enough.
Concluding Remarks
You have discovered the power of TorchMetrics. I think this library is a very clean way to manage metrics and significantly improves the quality of the code. For further information, check their documentation.
Thanks for reading, I hope you have found this useful.