Prioritized Training - Regression vs Classification with RHO Loss
tl;dr
This is a blog post exploring how the loss function (classification vs regression losses) affects the feasibility of this prioritized training technique, and in which regimes its benefits justify the overhead.
Original paper: “Prioritized Training on Points that are Learnable, Worth Learning, and Not Yet Learnt”
My experiment: https://github.com/timholds/prioritized-training
Background
In a previous post, Prioritized Training with Reducible Holdout Loss, we built an intuition for why RHO loss is a good way to select training points. In a nutshell, points that are learnable by a weak model and not yet learned by a target model make for the perfect training points! Doing this lets us train on way smaller subsets of the training data on any given epoch with similar performance to training on the full dataset.
However, there is no free lunch here. We incur a one-time cost to train the holdout model, although it can be underfit and we can reuse it with many different target models. We also incur a cost every epoch to calculate the RHO losses: either in memory and synchronization, if we use a copy of the target model to calculate the losses in parallel, or in compute, if we use the target model directly. We also have to sort the candidate batch by RHO loss and take the top batch_size datapoints. The hope is that this overhead is outweighed by the benefit of training on only ~10-50% of the training data every epoch.
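The selection step itself is small. Below is a minimal plain-Python sketch, assuming we already have per-example losses for the candidate batch from the target model and from the holdout model; the function name `rho_select` and its list-based interface are hypothetical, while the score (training loss minus irreducible holdout loss) is the RHO loss from the paper.

```python
def rho_select(train_losses, irreducible_losses, batch_size):
    """Return indices of the batch_size candidates with the highest RHO loss.

    train_losses: per-example loss of the target model on the candidate batch.
    irreducible_losses: per-example loss of the (underfit) holdout model on
        the same points (hypothetical interface: plain lists of floats).
    """
    if batch_size > len(train_losses):
        raise ValueError("batch_size cannot exceed the candidate batch size")
    # Reducible holdout loss: high when the target model has not learnt the
    # point yet (high train loss) but the holdout model shows it is learnable
    # (low irreducible loss).
    rho = [tl - il for tl, il in zip(train_losses, irreducible_losses)]
    # Sort candidate indices by RHO loss, largest first, and keep the top batch_size.
    ranked = sorted(range(len(rho)), key=lambda i: rho[i], reverse=True)
    return ranked[:batch_size]


train = [2.0, 0.1, 1.5, 3.0]
irreducible = [0.2, 0.05, 1.0, 2.9]
print(rho_select(train, irreducible, batch_size=2))
```

In this toy batch, point 1 is already learnt (low training loss) and point 3 has a high loss that the holdout model cannot reduce either (likely noisy or unlearnable), so neither is picked over points 0 and 2. A full sort costs O(B log B); `heapq.nlargest` would avoid sorting the whole candidate batch.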
A crucial detail is that in the original paper, they only train classification models! In this post, we will explore how well this method works for regression problems, and how the dimensionality of the regression problem affects the performance of the RHO loss. Ultimately, the point of this post is to explore the limits of the Prioritized Training approach and when the overhead outweighs the benefits.
Experiment Notes
Challenges with comparison: We care about what subsample rate prioritized training (PT) models need to achieve performance comparable to a baseline trained without any curriculum learning. To measure this, we need to decide which metric, and what threshold on it, counts as the model being “trained”.
For the classification models, accuracy is a good metric to use, since it is easy to interpret and compare across models. However, for regression problems, it’s unclear which metric and threshold should play that role.
The only thing we can compare between the regression and classification models is the subsample rate needed to get comparable performance. We could use a loss threshold, but since losses like MSE and CCE (categorical cross-entropy) live on different numerical scales, it’s better to stick to a downstream metric: accuracy for classification and MSE for regression.
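To make the "different numerical nature" concrete, here is a small sketch with made-up numbers (the predictions, targets, and helper names are illustrative only). MSE is measured in squared target units, so merely changing units moves it by orders of magnitude, while CCE depends only on the probability assigned to the correct class.

```python
import math


def mse(preds, targets):
    """Mean squared error, in squared units of the targets."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def cross_entropy(probs, label):
    """Categorical cross-entropy for one example: -log p(correct class), in nats."""
    return -math.log(probs[label])


preds, targets = [1.0, 2.0, 3.0], [1.5, 2.0, 2.0]
base = mse(preds, targets)
# Same data expressed in units 10x smaller (e.g. meters -> decimeters):
# the model is equally good, but the MSE is 100x larger.
scaled = mse([10 * p for p in preds], [10 * t for t in targets])
# A uniform guess over 10 classes always scores ln(10), whatever the inputs are.
uniform = cross_entropy([0.1] * 10, label=3)
print(base, scaled, uniform)
```

A threshold such as "loss below 0.5" therefore means something different for every regression dataset, which is why a per-task downstream metric is the safer yardstick.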
What regression datasets could play the role that QMNIST and CIFAR-10/100 play for classification? Regression datasets:
Since this is not a method to generate SoTA models but rather to train models up to some level of performance, we will use the accuracy that the model trained on the full dataset reaches as the threshold. This is a bit arbitrary, but it is a good enough proxy for our purposes.
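One way to turn this into a number is to record when a PT run first reaches the baseline's level. The sketch below assumes metrics are logged once per evaluation (e.g. per epoch) and that the threshold is the baseline's final value; the helper name and the histories are hypothetical.

```python
def first_index_reaching(history, threshold, higher_is_better=True):
    """Return the index of the first logged value that reaches threshold, or None.

    Use higher_is_better=True for accuracy and False for MSE.
    """
    for i, value in enumerate(history):
        reached = value >= threshold if higher_is_better else value <= threshold
        if reached:
            return i
    return None


# Classification: accuracy, higher is better.
baseline_acc = [0.60, 0.75, 0.82, 0.85]
pt_acc = [0.70, 0.83, 0.86, 0.86]
print(first_index_reaching(pt_acc, baseline_acc[-1]))

# Regression: MSE, lower is better.
baseline_mse = [1.0, 0.5, 0.3]
pt_mse = [0.8, 0.35, 0.29]
print(first_index_reaching(pt_mse, baseline_mse[-1], higher_is_better=False))
```

Returning `None` instead of a large sentinel keeps "never reached the threshold" distinct from "reached it late".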
Batch Size: Since we subsample every candidate batch, we need to increase the batch size of the RHO loss models in inverse proportion to the subsample rate. More concretely, if we train our baseline (null hypothesis) models with a batch size of 64, then training a RHO loss model with a subsample rate of 10% calls for a candidate batch size of 640.
This keeps the effective batch size fixed in an attempt to run a clean experiment, controlling for the hyperparameters and learning dynamics so that any changes we observe are due to the RHO loss and not the batch size.
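The batch-size arithmetic is B = b / r. A minimal sketch (the helper name is hypothetical) that rounds up so the candidate batch always contains at least b points to select from:

```python
import math


def candidate_batch_size(train_batch_size, subsample_rate):
    """Candidate batch size B so that keeping a fraction r of it yields b points."""
    if not 0 < subsample_rate <= 1:
        raise ValueError("subsample_rate must be in (0, 1]")
    # round() first so float error (e.g. 64 / 0.1) cannot push ceil() up by one.
    return math.ceil(round(train_batch_size / subsample_rate, 6))


print(candidate_batch_size(64, 0.1))  # baseline b = 64 at a 10% subsample rate
print(candidate_batch_size(64, 0.3))  # not divisible, so B is rounded up
```

When b / r is not an integer, selecting exactly b points out of the rounded-up B gives a realised rate slightly below r, which is usually an acceptable trade for keeping the training batch fixed.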
In Summary
We can think about Prioritized Training as a form of distillation, but instead of getting the full prediction distribution over the labels from a teacher model, we get datapoints that will induce better gradients in the target model. However, the nature of some loss functions and the dimensionality of the output space change the efficacy of this method, and YMMV depending on the task at hand and on the overhead of training the holdout model and computing the RHO losses every epoch.
We are using an undertrained holdout model as a proxy to evaluate how learnable a point is, and then using the target model to evaluate how well the point has already been learned.
Assume a candidate batch size B and an actual training batch size b, so the subsample rate is r = b / B. The authors assume that a forward pass on a candidate batch takes time T, a forward pass on a training batch takes t, and a backward pass on a training batch takes 2t, where t < T.
Since the forward pass takes t and the backward pass takes 2t, a baseline training step costs 3t. Let’s think about how the needed subsample rate affects the total time per step, which depends on the ratio between t and T.
When is paying T + 3t per RHO step, instead of the baseline's 3t per step, worth it?
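If a forward pass costs c per example (so T = Bc and t = bc), the total time per RHO step is Bc (candidate forward pass) + bc (forward pass on the selected points) + 2bc (backward pass on the selected points):
$$Bc + bc + 2bc = Bc + 3bc = T + 3t = T(1 + 3r)$$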
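However, we can do even better by avoiding the recomputation of the forward pass. Assuming we have enough memory to keep the activations of the whole candidate batch (gradient checkpointing could ease this, at the price of recomputing some activations), we can reuse the candidate forward pass and mask the loss so that only the selected datapoints contribute. This means we need only one forward pass T per step plus a 2t backward pass. Now, when is T + 2t < 3t? Only when T < t, which contradicts our assumption that t < T, so a RHO step is always more expensive than a baseline step and any savings must come from needing fewer steps. One caveat: in standard autograd frameworks, masking the loss still backpropagates through the activations of every candidate (the masked rows just receive zero gradient), so in practice the backward pass costs closer to 2T than 2t, and T + 2t is an optimistic estimate.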
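Assuming forward cost scales linearly with batch size (so t = rT), a RHO step costs T + 3t = T(1 + 3r). With a subsample rate of 10%, that is 1.3T (1.2T with the optimistic masked estimate), while a baseline step costs 3t = 0.3T, so each RHO step is about 4.3x as expensive. But a RHO epoch takes only r = 10% as many steps, so it costs (1 + 3r)/3 ≈ 43% of a baseline epoch (40% with masking), ignoring the sort and the holdout model. Basically, if the needed subsample rate is actually more like 95%, then you pay for a forward pass over every candidate while still training on almost all of the data: an epoch costs (1 + 3 × 0.95)/3 ≈ 128% of a baseline epoch with recomputation (≈ 97% even with masking), and recomputation only breaks even at r = 2/3.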
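A small cost model makes this arithmetic easy to check. It is a sketch under the assumptions stated in the text (forward cost linear in batch size, backward = 2x forward) and ignores sorting and the one-time holdout-model cost; all function names are hypothetical.

```python
def rho_step_cost(r, reuse_forward=False):
    """Time of one RHO step in units of T (forward pass over the candidate batch).

    Assumes t = r * T and that a backward pass costs twice its forward pass.
    reuse_forward=False: recompute the forward pass on the selected points, T + 3t.
    reuse_forward=True: reuse the candidate forward pass, T + 2t (optimistic,
        because masking usually does not shrink the backward pass).
    """
    if not 0 < r <= 1:
        raise ValueError("subsample rate r must be in (0, 1]")
    t = r  # forward pass on the training batch, in units of T
    return 1 + (2 * t if reuse_forward else 3 * t)


def baseline_step_cost(r):
    """Time of one baseline step (forward + backward on b points): 3t = 3rT."""
    return 3 * r


def epoch_cost_ratio(r, reuse_forward=False):
    """RHO epoch time divided by baseline epoch time over the same N points.

    The baseline takes N/b steps per epoch; RHO takes N/B = r * N/b steps,
    because each RHO step consumes a whole candidate batch of B = b / r points.
    """
    return r * rho_step_cost(r, reuse_forward) / baseline_step_cost(r)


for r in (0.1, 0.5, 2 / 3, 0.95):
    print(f"r={r:.2f}  recompute={epoch_cost_ratio(r):.2f}  "
          f"reuse={epoch_cost_ratio(r, reuse_forward=True):.2f}")
```

The ratio simplifies to (1 + 3r)/3 with recomputation and (1 + 2r)/3 with reuse. It only compares the cost of one pass over the data; whether PT wins overall still depends on how many epochs it needs to reach the target metric.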
The authors state “Recall that our setting assumes training time is a bottleneck but data is abundant—more than we can train on (see Bottou & LeCun (2004)). This is common e.g. for web-scraped data where state-of-the-art performance is often reached in less than half of one epoch (Komatsuzaki, 2019; Brown et al., 2020).”
They also discuss using a copy of the target model to calculate the RHO losses in parallel, but for increasingly large models, memory and synchronization costs become a bottleneck. In that case, we are back to paying the time penalty of computing the selection losses with the target model itself.
Do I really need to rerun the forward pass on the selected points, or can I just mask the loss so that only the indices with the highest RHO loss in the candidate batch (the top batch_size) contribute?
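For the gradient itself, masking is equivalent to training on the selected subset, as long as the masked loss is averaged over the selected points rather than over the whole candidate batch. Here is a toy check on a one-parameter linear regression model y_hat = w * x (the model, data, and helper name are hypothetical); note that the masked version still loops over every candidate row, which mirrors why masking does not shrink the backward pass.

```python
def mse_grad(w, xs, ys, mask=None):
    """d/dw of the mean squared error of y_hat = w * x.

    If mask is given, only rows with mask[i] == 1 contribute, and the mean is
    taken over the selected rows, so it matches training on that subset.
    """
    if mask is None:
        mask = [1] * len(xs)
    n_selected = sum(mask)
    if n_selected == 0:
        raise ValueError("mask selects no rows")
    total = 0.0
    for x, y, m in zip(xs, ys, mask):
        # Masked rows still get visited: they just contribute zero.
        total += m * 2 * (w * x - y) * x
    return total / n_selected


xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 7.0, 8.0]
selected = [0, 2]  # e.g. the top-2 RHO indices of this candidate batch
mask = [1 if i in selected else 0 for i in range(len(xs))]

masked = mse_grad(1.0, xs, ys, mask)
subset = mse_grad(1.0, [xs[i] for i in selected], [ys[i] for i in selected])
print(masked, subset)
```

Dividing by the full candidate batch size instead would scale the gradient by r, which silently acts like a smaller learning rate.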
When the model is very underfit, the rho loss can help select points that are more likely to improve the model. In this case, the holdout model is also very underfit, and thus the rho loss is a good measure of learnability.
If the training data is very small, then the holdout model will not be able to learn anything useful about the data distribution, and thus the irreducible loss will not be a good measure of learnability. In this case, it would be better to train on all the points in the training set.