No learning rates needed: Introducing SALSA - Stable Armijo Line Search Adaptation
Abstract
The paper introduces SALSA (Stable Armijo Line Search Adaptation), an optimization method that extends the Armijo line search algorithm to improve its performance and stability, particularly in the mini-batch setting. The key contributions are:
- Introducing a momentum term to the Armijo criterion to mitigate the effect of noise in mini-batch gradients.
- Limiting the frequency of line search computations to reduce the overall computational overhead.
- Extensive experiments evaluating SALSA across different datasets, model architectures, and optimization methods (SGD, Adam).
- Demonstrating that SALSA outperforms previous line search methods and tuned optimizers, with an average 1.5% advantage in accuracy and 50% lower average log loss at the end of training.
Q&A
[01] Addressing Mini-batch Noise
1. What is the problem with the original Armijo line search method in the mini-batch setting? The original Armijo line search criterion is checked for every mini-batch, which is problematic due to the inherent noise in mini-batch data. This can lead to frequent changes in the step size, as the criterion is violated by some noisy mini-batches even when the step size is appropriate for most of the mini-batches.
2. How does the proposed SaLSa criterion address this issue? The SaLSa criterion applies exponential smoothing (a momentum term) to the factors in the Armijo equation that depend on the mini-batch. This makes the left-hand side of the criterion less sensitive to the current mini-batch, so the step size shrinks more slowly in response to noisy mini-batches.
3. What is the theoretical basis for the convergence of the SaLSa criterion? The authors provide a convergence theorem for the SaLSa criterion with SGD that extends the original Armijo convergence proof. The key assumptions are that a unique minimizer exists and that every learning rate found by the line search yields an improved loss.
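To make the difference between the two criteria concrete, here is a minimal Python sketch. The plain check is the standard Armijo sufficient-decrease condition for a gradient step. The smoothed variant is only one plausible reading of "exponential smoothing of the mini-batch-dependent factors": it averages the observed loss decrease and the squared gradient norm. The names armijo_accepts, SmoothedArmijo and backtrack, and the values of c, beta and shrink, are assumptions for illustration and are not taken from the paper.

```python
def armijo_accepts(loss_before, loss_after, grad_sq, step, c=0.1):
    """Standard Armijo check on one mini-batch for a plain gradient step:
    the loss must drop by at least c * step * ||grad||^2."""
    return loss_before - loss_after >= c * step * grad_sq


class SmoothedArmijo:
    """Hypothetical smoothed variant: the loss decrease and the squared
    gradient norm are exponentially averaged before the check, so one
    noisy mini-batch has limited influence on the decision."""

    def __init__(self, c=0.1, beta=0.9):
        self.c = c
        self.beta = beta            # smoothing factor (assumed value)
        self.avg_decrease = None    # running average of the loss decrease
        self.avg_grad_sq = None     # running average of ||grad||^2

    def _blend(self, avg, value):
        # The first observation initialises the running average.
        if avg is None:
            return value
        return self.beta * avg + (1.0 - self.beta) * value

    def accepts(self, loss_before, loss_after, grad_sq, step):
        # Evaluate a candidate step without changing the stored averages.
        decrease = self._blend(self.avg_decrease, loss_before - loss_after)
        g_sq = self._blend(self.avg_grad_sq, grad_sq)
        return decrease >= self.c * step * g_sq

    def commit(self, loss_before, loss_after, grad_sq):
        # Store the statistics of the step that was actually taken.
        self.avg_decrease = self._blend(self.avg_decrease, loss_before - loss_after)
        self.avg_grad_sq = self._blend(self.avg_grad_sq, grad_sq)


def backtrack(loss, w, grad, step, accepts, shrink=0.5, max_tries=30):
    """Shrink the step until accepts(...) holds; returns (new_w, step)."""
    grad_sq = sum(g * g for g in grad)
    loss_before = loss(w)
    for _ in range(max_tries):
        w_new = [wi - step * gi for wi, gi in zip(w, grad)]
        if accepts(loss_before, loss(w_new), grad_sq, step):
            return w_new, step
        step *= shrink
    return w, step  # no acceptable step found; keep the old weights
```

In this sketch, backtrack can be called with either armijo_accepts or the accepts method of a SmoothedArmijo instance. With the smoothed variant, commit would be called once per training step after the update, so that the averages track the steps actually taken.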
[02] Addressing Computational Costs
1. What is the problem with performing a line search at every training step? Performing a line search at every step increases the overall training compute cost by roughly 30%, as it requires additional forward passes.
2. How does the proposed method address this computational overhead? The authors introduce a method to perform the line search less frequently, by monitoring the rate of change in the step size and adjusting the line search frequency accordingly. This reduces the extra compute needed from 30% to approximately 3% for longer training runs.
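The idea of searching less often when the step size is stable can be sketched as follows. The doubling rule, the tolerance tol and the cap max_interval are hypothetical choices for illustration; the summary does not state the authors' actual schedule. The overhead helper assumes that each line search adds about 30% to the cost of the step on which it runs and that steps without a search add nothing.

```python
def next_interval(interval, old_step, new_step, tol=0.05, max_interval=32):
    """Hypothetical schedule: if the step size barely changed since the last
    line search, wait twice as long before the next one; otherwise go back
    to searching on every training step."""
    rel_change = abs(new_step - old_step) / max(abs(old_step), 1e-12)
    if rel_change < tol:
        return min(interval * 2, max_interval)
    return 1


def overhead_fraction(avg_interval, per_search_cost=0.30):
    """Average extra compute when a line search runs once every
    avg_interval training steps (simplified linear cost model)."""
    return per_search_cost / avg_interval
```

Under this simplified cost model, an average of one line search every 10 steps gives 0.30 / 10 = 0.03, which is consistent with the reported drop from roughly 30% to approximately 3%.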
[03] Experimental Evaluation
1. What are the key findings from the natural language processing (NLP) experiments? In the NLP experiments, the authors found that ADAM + SaLSa achieves a lower final loss compared to ADAM, ADAM + SLS, and SGD + SLS, although the accuracy improvements are not always significant. ADAM + SLS and ADAM + SaLSa perform similarly in terms of accuracy, but both outperform ADAM and SGD + SLS on average.
2. What are the key findings from the image classification experiments? In the image classification experiments, the authors found that ADAM + SLS or SGD + SLS yield good results for CIFAR10 and CIFAR100, but perform poorly for ImageNet, likely due to stability issues. In contrast, the ADAM + SaLSa and SGD + SaLSa approaches do not encounter these problems and deliver the best average performance.
3. What are the overall performance improvements observed with the SaLSa methods? The authors report that the SaLSa methods have on average a 1.5% advantage in accuracy and a 50% lower average log loss at the end of training, compared to the other optimization methods evaluated.