Practitioner's Guide to Two-Tailed Averaging

Tags: ai, Date: 2022-12-06

This is a complement to the Two-Tailed Averaging paper, approached from the direction of what I think is a fairly common technique: averaging checkpoints.

We want to speed up training and improve generalization. One way to do that is to average the weights produced during optimization, and that's a big win (e.g. 1, 2, 3). For example, while training a language model for the down-stream task of summarization, we can save checkpoints periodically and average the model weights from the last 10 or so checkpoints to produce the final solution. This is pretty much what Stochastic Weight Averaging (SWA) does.
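For concreteness, here is a minimal sketch of that kind of checkpoint averaging. The representation of a checkpoint as a dict of parameter arrays and the names average_checkpoints and load_checkpoint are illustrative assumptions, not anything prescribed by the paper.

def average_checkpoints(checkpoints):
  # Each checkpoint is a dict mapping parameter names to arrays (or
  # anything else supporting + and /). Average parameter-wise.
  n = len(checkpoints)
  return {name: sum(c[name] for c in checkpoints) / n
          for name in checkpoints[0]}

# Hypothetical usage: average the weights of the last 10 checkpoints.
# last_10 = [load_checkpoint(path) for path in checkpoint_paths[-10:]]
# final_weights = average_checkpoints(last_10)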

Problems with SWA

There are a number of problems with SWA. In summary, working with SWA is tricky because its hyperparameters (when to start averaging and how many checkpoints to average) directly control the averaging length and need to be tuned, it does not give us a usable average at all points during training, and its final average is close to optimal only if those hyperparameters happen to be well chosen.

Two-Tailed Averaging

These are the issues Two-Tailed Averaging tackles. The algorithm needs storage for only two sets of weights (constant storage cost) and performance (e.g. summarization) to be evaluated periodically. In return, it provides a weight average of approximately optimal length at all optimization steps. Now we can start training that language model, periodically evaluating how the averaged weights are doing at summarization. We can stop the training run anytime if it's getting worse.

This is how Two-Tailed Averaging (orange) compares to SWA (green) tuned to start averaging at the point that's optimal for final validation loss:

TTA (orange) vs SWA (green)

The Algorithm

The core algorithm is quite simple: as the optimizer produces new weights, we add them to two moving averages. Whenever the short average (the one to which we have added no more weights than to the other) does at least as well as the long average according to some arbitrary evaluation function, we discard the long average: the short average takes its place as the new long average, and we start a new, empty short average.

# Initialize the short (s, sw) and long averages (l, lw). s and l are
# the number of weights averaged (the "averaging lengths"). sw and lw
# are the averaged weights.
s, sw, l, lw = 0, 0, 0, 0

# Update the moving averages with the latest weights from the optimizer.
def update_tta(w):
  global s, sw, l, lw
  assert s <= l
  s, sw = s+1, (s*sw + w)/(s+1)
  l, lw = l+1, (l*lw + w)/(l+1)

# Evaluate the model with the non-averaged and the averaged weights,
# and use the results to adapt the lengths of the averages. Return
# three values: the best evaluation result, the corresponding weights
# and the averaging length.
def evaluate_tta(w, evaluate):
  global s, sw, l, lw
  # Evaluate the non-averaged weights w, the short and the long average.
  f1, fs, fl = evaluate(w), evaluate(sw), evaluate(lw)
  is_first_eval = (s == l)
  # If the short average is better, then *switch*: empty the long
  # average, which is now the shorter one.
  if fs <= fl:
    s, l, lw, fl = 0, s, sw, fs
  if f1 <= fl:
    # The non-averaged weights performed better. This may happen in
    # the very early stages of training.
    if is_first_eval:
      # If there has never been a switch (s == l), then f1 is probably
      # still improving fast so reset both averages.
      s, l = 0, 0
    return f1, w, 1
  else:
    # Return the long average.
    return fl, lw, l

In addition to the core algorithm, the code above has some extra logic to deal with the non-averaged weights being better than the averaged ones.

Let's write a fake training loop that optimizes $f(x)=x^2$.

import random

def test_tta_simple():
  def f(w):
    return w**2
  def df_dw(w):
    # Simulate stochasticity due to e.g. minibatching.
    return 2*w + random.uniform(-1.0, 1.0)
  lr = 0.5
  w = 3.14
  for i in range(1, 2001):
    w = w - lr*df_dw(w)
    update_tta(w)
    if i % 100 == 0:
      tta_f, tta_w, tta_l = evaluate_tta(w, f)
      print(f'i={i:4d}: f(w_i)={f(w):7.3f},'
            f' f(w_tta)={tta_f:7.3f}, l={tta_l:4d}')

We added some noise to the gradients in df_dw to make it more like training a neural net with SGD. We take 2000 optimization steps, calling update_tta on every one of them and evaluate_tta every 100 steps. Running test_tta_simple, we get something like this:

i= 100: f(w_i)=  0.108, f(w_tta)=  0.000, l= 100
i= 200: f(w_i)=  0.011, f(w_tta)=  0.000, l= 200
i= 300: f(w_i)=  0.098, f(w_tta)=  0.000, l= 200
i= 400: f(w_i)=  0.085, f(w_tta)=  0.000, l= 300
i= 500: f(w_i)=  0.221, f(w_tta)=  0.000, l= 200
i= 600: f(w_i)=  0.185, f(w_tta)=  0.000, l= 300
i= 700: f(w_i)=  0.019, f(w_tta)=  0.000, l= 400
i= 800: f(w_i)=  0.180, f(w_tta)=  0.000, l= 500
i= 900: f(w_i)=  0.161, f(w_tta)=  0.000, l= 600
i=1000: f(w_i)=  0.183, f(w_tta)=  0.000, l= 700
i=1100: f(w_i)=  0.057, f(w_tta)=  0.000, l= 800
i=1200: f(w_i)=  0.045, f(w_tta)=  0.000, l= 900
i=1300: f(w_i)=  0.051, f(w_tta)=  0.000, l=1000
i=1400: f(w_i)=  0.010, f(w_tta)=  0.000, l= 900
i=1500: f(w_i)=  0.012, f(w_tta)=  0.000, l=1000
i=1600: f(w_i)=  0.168, f(w_tta)=  0.000, l=1100
i=1700: f(w_i)=  0.001, f(w_tta)=  0.000, l=1200
i=1800: f(w_i)=  0.020, f(w_tta)=  0.000, l=1300
i=1900: f(w_i)=  0.090, f(w_tta)=  0.000, l=1400
i=2000: f(w_i)=  0.115, f(w_tta)=  0.000, l=1500

In the above, f(w_i) is the loss with the non-averaged weights, f(w_tta) is the loss with the weights provided by TTA, and l is the number of weights averaged. We see that with the high constant learning rate, SGD keeps jumping around the optimum; TTA jumps around too, but its jitter is far smaller (below the three decimals printed here). Also, the length of the average increases almost, but not quite, monotonically due to the switching logic.

OK, that was easy. Let's now do something a bit more involved, where the function being optimized changes. We will change the loss function to $f(x) = (x-m)^2$ where $m$ is set randomly every 400 steps. We will deal with this non-stationarity by resetting the long average if it has not improved for a while.

def reset_tta_long_average():
  global s, sw, l, lw
  s, sw, l, lw = 0, 0, s, sw

def test_tta_non_stationary():
  optimum = 0
  def f(w):
    return (w-optimum)**2
  def df_dw(w):
    # Simulate stochasticity due to e.g. minibatching.
    return 2*w - 2*optimum + random.uniform(-1.0, 1.0)
  lr = 0.5
  w = 3.14
  best_f = float("inf")
  best_iteration = 0
  for i in range(1, 2001):
    w = w - lr*df_dw(w)
    update_tta(w)
    if i % 400 == 0:
      optimum = random.uniform(-10.0, 10.0)
      print(f'setting optimum={optimum:.3f}')
    if i % 100 == 0:
      tta_f, tta_w, tta_l = evaluate_tta(w, f)
      print(f'i={i:4d}: f(w_i)={f(w):7.3f},'
            f' f(w_tta)={tta_f:7.3f}, l={tta_l:4d}',
            end='')
      if tta_l > 1 and tta_f < best_f:
        best_f = tta_f
        best_iteration = i
        print()
      elif best_iteration + 1 <= i:
        # Reset heuristic: this evaluation of the long average did not
        # improve on the best result seen so far, so reset it so that
        # it may adapt quicker.
        print(' Reset!')
        reset_tta_long_average()
        best_f = float("inf")
        best_iteration = 0

We can see that TTA adapts to the non-stationarity in a reasonable way although the reset heuristic gets triggered spuriously a couple of times:

i= 100: f(w_i)=  0.008, f(w_tta)=  0.005, l= 100
i= 200: f(w_i)=  0.060, f(w_tta)=  0.000, l= 100
i= 300: f(w_i)=  0.004, f(w_tta)=  0.000, l= 100
setting optimum=9.691
i= 400: f(w_i)= 87.194, f(w_tta)= 87.194, l=   1 Reset!
i= 500: f(w_i)=  0.002, f(w_tta)=  0.000, l= 100
i= 600: f(w_i)=  0.033, f(w_tta)=  0.000, l= 200 Reset!
i= 700: f(w_i)=  0.126, f(w_tta)=  0.000, l= 200
setting optimum=9.899
i= 800: f(w_i)=  0.022, f(w_tta)=  0.022, l=   1 Reset!
i= 900: f(w_i)=  0.004, f(w_tta)=  0.003, l= 100
i=1000: f(w_i)=  0.094, f(w_tta)=  0.000, l= 100
i=1100: f(w_i)=  0.146, f(w_tta)=  0.000, l= 100
setting optimum=3.601
i=1200: f(w_i)= 35.623, f(w_tta)= 35.623, l=   1 Reset!
i=1300: f(w_i)=  0.113, f(w_tta)=  0.001, l= 100
i=1400: f(w_i)=  0.166, f(w_tta)=  0.000, l= 200
i=1500: f(w_i)=  0.112, f(w_tta)=  0.000, l= 200
setting optimum=6.662
i=1600: f(w_i)= 11.692, f(w_tta)=  9.409, l= 300 Reset!
i=1700: f(w_i)=  0.075, f(w_tta)=  0.000, l= 100
i=1800: f(w_i)=  0.229, f(w_tta)=  0.000, l= 200 Reset!
i=1900: f(w_i)=  0.217, f(w_tta)=  0.000, l= 100
setting optimum=-8.930
i=2000: f(w_i)=242.481, f(w_tta)=242.481, l=   1 Reset!

Note that in these examples the evaluation function in TTA was the training loss, but TTA is mainly intended for when the evaluation function measures performance on the validation set or on a down-stream task (e.g. summarization).
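As a sketch of what that might look like, here is a helper where loss_fn and validation_batches are hypothetical stand-ins for whatever measures down-stream performance; lower values must mean better, since that is how evaluate_tta compares them.

def make_validation_evaluate(loss_fn, validation_batches):
  # Build an evaluation function over a held-out set. loss_fn(w, batch)
  # and validation_batches are supplied by the caller.
  def evaluate(w):
    total = sum(loss_fn(w, batch) for batch in validation_batches)
    return total / len(validation_batches)
  return evaluate

# Hypothetical usage at evaluation time in the training loop:
# validation_evaluate = make_validation_evaluate(loss_fn, val_batches)
# tta_f, tta_w, tta_l = evaluate_tta(w, validation_evaluate)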

Downsampling weights

In its proposed form, Two-Tailed Averaging incorporates every set of weights produced by the optimizer in both averages it maintains. This is good because the theory of Tail Averaging, also known as Suffix Averaging, has nice things to say about convergence to a local optimum in this setting. However, in a memory-constrained situation, these averages will not fit on the GPU/TPU, so we must move the weights off the device to add them to the averages (which may be in RAM or on disk). Moving stuff off the device can be slow, so we might want to do that, say, only every 20 optimization steps. Obviously, downsampling the weights too much will affect the convergence rate, so there is a tradeoff.
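A minimal sketch of such a downsampled loop, reusing update_tta and evaluate_tta from above (optimizer_step and evaluate stand in for the caller's optimizer update and evaluation function, and moving the weights off the device is elided):

def train_with_downsampled_tta(w, optimizer_step, evaluate, n_steps,
                               update_every=20, eval_every=100):
  best = None
  for i in range(1, n_steps + 1):
    w = optimizer_step(w)
    if i % update_every == 0:
      # Only every update_every-th set of weights is folded into the
      # two averages; this is where they would be moved off the device.
      update_tta(w)
    if i % eval_every == 0:
      # best is (evaluation result, weights, averaging length).
      best = evaluate_tta(w, evaluate)
  return best

With update_every=1 this reduces to the loops above; larger values trade some convergence rate for fewer device-to-host transfers.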

Learning rate

Note that in our experiments with Two-Tailed Averaging, we used a constant learning rate, motivated by the fact that the closely related method of Tail Averaging guarantees an optimal convergence rate in such a setting. The algorithm should work with decreasing learning rates but would require modification for cyclical schedules.

Related works

Adaptivity: SWA and LAWA have hyperparameters that directly control the averaging length; NT-ASGD still has one, but its effect is more indirect.

Anytimeness: LAWA provides an average at all times; SWA and NT-ASGD don't.

Optimality: The final averages of SWA and LAWA are optimal if their hyperparameters are well-tuned; intermediate results of LAWA are unlikely to be optimal; NT-ASGD can miss the right time to start averaging.

Summary

Two-Tailed Averaging can be thought of as online SWA with no hyperparameters. It is a great option when training runs take a long (or even an a priori unknown amount of) time, and when we could do without optimizing yet another hyperparameter.
