Practitioner's Guide to Two-Tailed Averaging

Tags: ai, Date: 2022-12-06

This is a complement to the Two-Tailed Averaging paper, approached from the direction of what I think is a fairly common technique: averaging checkpoints.

We want to speed up training and improve generalization. One way to do that is by averaging weights from optimization, and that's a big win (e.g. 1, 2, 3). For example, while training a language model for the down-stream task of summarization, we can save checkpoints periodically and average the model weights from the last 10 or so checkpoints to produce the final solution. This is pretty much what Stochastic Weight Averaging (SWA) does.
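As a minimal sketch of the checkpoint-averaging idea described above, the snippet below averages the last few checkpoints element-wise. The function name `average_checkpoints`, the dict-of-lists checkpoint layout, and the toy values are assumptions made for illustration; real checkpoints would hold framework tensors.

```python
def average_checkpoints(checkpoints):
    """Return the element-wise mean of a list of checkpoints.

    Each checkpoint is assumed to be a dict mapping a parameter name
    to a flat list of floats (a stand-in for real tensors).
    """
    n = len(checkpoints)
    averaged = {}
    for name in checkpoints[0]:
        # Group the i-th element of this parameter across all checkpoints.
        columns = zip(*(ckpt[name] for ckpt in checkpoints))
        averaged[name] = [sum(col) / n for col in columns]
    return averaged

# Hypothetical checkpoints saved periodically during training (toy values).
saved = [
    {"layer.weight": [1.0, 2.0], "layer.bias": [0.0]},
    {"layer.weight": [3.0, 4.0], "layer.bias": [1.0]},
    {"layer.weight": [5.0, 6.0], "layer.bias": [2.0]},
]

# Average only the most recent checkpoints (here the last two).
final = average_checkpoints(saved[-2:])
print(final)
```

How many checkpoints to include in the average is exactly the kind of choice that has to be tuned by hand in this scheme.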

Problems with SWA

There are a number of problems with SWA:

In summary, working with SWA is tricky because:

These are the issues Two-Tailed Averaging tackles.

Two-Tailed Averaging

The algorithm needs storage for only two sets of weights (constant storage cost) and performance (e.g. of summarization) to be evaluated periodically. In return, it provides a weight average of approximately optimal length at all optimization steps. Now we can start training that language model, periodically evaluating how the averaged weights are doing at summarization. We can stop the training run any time if it's getting worse.

This is how Two-Tailed Averaging (orange) compares to SWA (green) tuned to start averaging at the point that's optimal for final validation loss:

2TA (orange) vs SWA (green)

The Algorithm

The core algorithm maintains two weight averages. Both averages are over the most recent weights produced by the optimizer, but they differ in length (i.e. how many weights they average). As the optimizer produces new sets of weights, they are added to both averages. We periodically evaluate the performance of our model with each average. If the short average (the one that currently has fewer weights averaged) does at least as well as the long average according to some arbitrary evaluation function, then we empty the long average, which will now be the short one.

# Initialize the short (s, sw) and long averages (l, lw). s and l are
# the number of weights averaged (the "averaging lengths"). sw and lw
# are the averaged weights.
s, sw, l, lw = 0, 0, 0, 0

# Update the averages with the latest weights from the optimizer.
def update_2ta(w):
  global s, sw, l, lw
  assert s <= l
  s, sw = s+1, (s*sw + w)/(s+1)
  l, lw = l+1, (l*lw + w)/(l+1)

# Evaluate the model with the short-, the long-, and the
# non-averaged weights. Based on the results, adapt the length of
# the averages. Return three values: the best evaluation results,
# the corresponding weights and averaging length.
def evaluate_2ta(w, evaluate):
  global s, sw, l, lw
  # Evaluate the non-averaged weights w, the short and the long average.
  f1, fs, fl = evaluate(w), evaluate(sw), evaluate(lw)
  is_first_eval = (s == l)
  # If the short average is better, then *switch*: empty the long
  # average, which is now the shorter one.
  if fs <= fl:
    s, l, lw, fl = 0, s, sw, fs
  if f1 <= fl:
    # The non-averaged weights performed better. This may happen in
    # the very early stages of training.
    if is_first_eval:
      # If there has never been a switch (s == l), then f1 is probably
      # still improving fast so reset both averages.
      s, l = 0, 0
    return f1, w, 1
  else:
    # Return the long average.
    return fl, lw, l

In addition to the core algorithm, the code above has some extra logic to deal with the non-averaged weights being better than the averaged ones.
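The update `(s*sw + w)/(s+1)` in update_2ta is an incremental mean: it yields the same value as summing all weights seen so far and dividing by their count, without storing them. A small sketch to check this, where the helper name `running_mean` and the sample values are made up for illustration:

```python
def running_mean(values):
    # n is the number of values averaged so far, mean is their average.
    n, mean = 0, 0.0
    for v in values:
        # Same form as the update of (s, sw) and (l, lw) in update_2ta.
        n, mean = n + 1, (n * mean + v) / (n + 1)
    return n, mean

values = [3.0, 1.0, 2.0, 6.0]
n, mean = running_mean(values)
print(n, mean, sum(values) / len(values))
```

This is why the storage cost stays constant: each average needs only its current value and its length.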

Let's write a fake training loop that optimizes $f(x)=x^2$.

import random

def test_2ta_simple():
  def f(w):
    return w**2
  def df_dw(w):
    # Simulate stochasticity due to e.g. minibatching.
    return 2*w + random.uniform(-1.0, 1.0)
  lr = 0.5
  w = 3.14
  for i in range(1, 2001):
    w = w - lr*df_dw(w)
    # Add the latest weights to both averages on every step.
    update_2ta(w)
    if i % 100 == 0:
      f_2ta, w_2ta, l_2ta = evaluate_2ta(w, f)
      print(f'i={i:4d}: f(w_i)={f(w):7.3f},'
            f' f(w_2ta)={f_2ta:7.3f}, l={l_2ta:4d}')

We added some noise to the gradients in df_dw to make it more like training a neural net with SGD. We take 2000 optimization steps, calling update_2ta after each one and evaluate_2ta every 100 steps. Running test_2ta_simple, we get something like this:

i= 100: f(w_i)=  0.108, f(w_2ta)=  0.000, l= 100
i= 200: f(w_i)=  0.011, f(w_2ta)=  0.000, l= 200
i= 300: f(w_i)=  0.098, f(w_2ta)=  0.000, l= 200
i= 400: f(w_i)=  0.085, f(w_2ta)=  0.000, l= 300
i= 500: f(w_i)=  0.221, f(w_2ta)=  0.000, l= 200
i= 600: f(w_i)=  0.185, f(w_2ta)=  0.000, l= 300
i= 700: f(w_i)=  0.019, f(w_2ta)=  0.000, l= 400
i= 800: f(w_i)=  0.180, f(w_2ta)=  0.000, l= 500
i= 900: f(w_i)=  0.161, f(w_2ta)=  0.000, l= 600
i=1000: f(w_i)=  0.183, f(w_2ta)=  0.000, l= 700
i=1100: f(w_i)=  0.057, f(w_2ta)=  0.000, l= 800
i=1200: f(w_i)=  0.045, f(w_2ta)=  0.000, l= 900
i=1300: f(w_i)=  0.051, f(w_2ta)=  0.000, l=1000
i=1400: f(w_i)=  0.010, f(w_2ta)=  0.000, l= 900
i=1500: f(w_i)=  0.012, f(w_2ta)=  0.000, l=1000
i=1600: f(w_i)=  0.168, f(w_2ta)=  0.000, l=1100
i=1700: f(w_i)=  0.001, f(w_2ta)=  0.000, l=1200
i=1800: f(w_i)=  0.020, f(w_2ta)=  0.000, l=1300
i=1900: f(w_i)=  0.090, f(w_2ta)=  0.000, l=1400
i=2000: f(w_i)=  0.115, f(w_2ta)=  0.000, l=1500

In the above, f(w_i) is the loss with the non-averaged weights, f(w_2ta) is the loss with the weights provided by 2TA, and l is the number of weights averaged. We see that with the high, constant learning rate, SGD keeps jumping around the optimum, and while 2TA does the same, its jitter is way smaller (it's beyond the three decimal places printed here). Also, the length of the average increases almost monotonically, but not quite, due to the switching logic.

OK, that was easy. Let's now do something a bit more involved, where the function being optimized changes. We will change the loss function to $f(x) = (x-m)^2$ where $m$ is set randomly every 400 steps. We will deal with this non-stationarity by resetting the long average if it has not improved for a while.

def reset_2ta_long_average():
  global s, sw, l, lw
  s, sw, l, lw = 0, 0, s, sw

def test_2ta_non_stationary():
  optimum = 0
  def f(w):
    return (w-optimum)**2
  def df_dw(w):
    # Simulate stochasticity due to e.g. minibatching.
    return 2*w - 2*optimum + random.uniform(-1.0, 1.0)
  lr = 0.5
  w = 3.14
  best_f = float("inf")
  best_iteration = 0
  for i in range(1, 2001):
    w = w - lr*df_dw(w)
    # Add the latest weights to both averages on every step.
    update_2ta(w)
    if i % 400 == 0:
      optimum = random.uniform(-10.0, 10.0)
      print(f'setting optimum={optimum:.3f}')
    if i % 100 == 0:
      f_2ta, w_2ta, l_2ta = evaluate_2ta(w, f)
      # No newline yet: ' Reset!' may be appended to this line below.
      print(f'i={i:4d}: f(w_i)={f(w):7.3f},'
            f' f(w_2ta)={f_2ta:7.3f}, l={l_2ta:4d}',
            end='')
      if l_2ta > 1 and f_2ta < best_f:
        best_f = f_2ta
        best_iteration = i
        print()
      elif best_iteration + 1 <= i:
        # Reset heuristic: the results of the long average have not
        # improved for a while, let's reset it so that it may adapt
        # quicker.
        reset_2ta_long_average()
        print(' Reset!')
        best_f = float("inf")
        best_iteration = 0

We can see that 2TA adapts to the non-stationarity in a reasonable way although the reset heuristic gets triggered spuriously a couple of times:

i= 100: f(w_i)=  0.008, f(w_2ta)=  0.005, l= 100
i= 200: f(w_i)=  0.060, f(w_2ta)=  0.000, l= 100
i= 300: f(w_i)=  0.004, f(w_2ta)=  0.000, l= 100
setting optimum=9.691
i= 400: f(w_i)= 87.194, f(w_2ta)= 87.194, l=   1 Reset!
i= 500: f(w_i)=  0.002, f(w_2ta)=  0.000, l= 100
i= 600: f(w_i)=  0.033, f(w_2ta)=  0.000, l= 200 Reset!
i= 700: f(w_i)=  0.126, f(w_2ta)=  0.000, l= 200
setting optimum=9.899
i= 800: f(w_i)=  0.022, f(w_2ta)=  0.022, l=   1 Reset!
i= 900: f(w_i)=  0.004, f(w_2ta)=  0.003, l= 100
i=1000: f(w_i)=  0.094, f(w_2ta)=  0.000, l= 100
i=1100: f(w_i)=  0.146, f(w_2ta)=  0.000, l= 100
setting optimum=3.601
i=1200: f(w_i)= 35.623, f(w_2ta)= 35.623, l=   1 Reset!
i=1300: f(w_i)=  0.113, f(w_2ta)=  0.001, l= 100
i=1400: f(w_i)=  0.166, f(w_2ta)=  0.000, l= 200
i=1500: f(w_i)=  0.112, f(w_2ta)=  0.000, l= 200
setting optimum=6.662
i=1600: f(w_i)= 11.692, f(w_2ta)=  9.409, l= 300 Reset!
i=1700: f(w_i)=  0.075, f(w_2ta)=  0.000, l= 100
i=1800: f(w_i)=  0.229, f(w_2ta)=  0.000, l= 200 Reset!
i=1900: f(w_i)=  0.217, f(w_2ta)=  0.000, l= 100
setting optimum=-8.930
i=2000: f(w_i)=242.481, f(w_2ta)=242.481, l=   1 Reset!

Note that in these examples, the evaluation function in 2TA was the training loss, but 2TA is intended for when the evaluation function measures performance on the validation set or on a down-stream task (e.g. summarization).
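To illustrate the intended use, here is a hedged sketch where the evaluation function is the loss on a held-out validation set rather than the training loss. Everything in it is an assumption made for the example: the `TwoTailedAverager` class (a state-in-an-object variant of the global-state code above), the `make_data` and `validation_loss` helpers, and the toy one-parameter linear model.

```python
import random

class TwoTailedAverager:
    """Object-based variant of update_2ta/evaluate_2ta for float weights."""
    def __init__(self):
        self.s, self.sw, self.l, self.lw = 0, 0.0, 0, 0.0

    def update(self, w):
        self.s, self.sw = self.s + 1, (self.s * self.sw + w) / (self.s + 1)
        self.l, self.lw = self.l + 1, (self.l * self.lw + w) / (self.l + 1)

    def evaluate(self, w, evaluate):
        f1, fs, fl = evaluate(w), evaluate(self.sw), evaluate(self.lw)
        is_first_eval = (self.s == self.l)
        if fs <= fl:
            # Switch: the short average becomes the long one.
            self.s, self.l, self.lw, fl = 0, self.s, self.sw, fs
        if f1 <= fl:
            if is_first_eval:
                self.s, self.l = 0, 0
            return f1, w, 1
        return fl, self.lw, self.l

random.seed(0)

def make_data(n):
    # Toy data: y = 2x plus Gaussian noise.
    xs = [random.uniform(-1.0, 1.0) for _ in range(n)]
    return [(x, 2.0 * x + random.gauss(0.0, 0.1)) for x in xs]

train, valid = make_data(200), make_data(50)

def validation_loss(w):
    # Mean squared error of the model y_hat = w*x on the held-out set.
    return sum((w * x - y) ** 2 for x, y in valid) / len(valid)

averager = TwoTailedAverager()
w, lr = 0.0, 0.5
for step in range(1, 1001):
    x, y = random.choice(train)          # minibatch of size one
    w -= lr * 2.0 * (w * x - y) * x      # SGD step on the training example
    averager.update(w)
    if step % 100 == 0:
        f_2ta, w_2ta, l_2ta = averager.evaluate(w, validation_loss)
        print(f'step={step:4d}: valid(w)={validation_loss(w):.4f},'
              f' valid(w_2ta)={f_2ta:.4f}, l={l_2ta}')
```

By construction, the value returned by `evaluate` is never worse than the validation loss of the non-averaged weights at that step.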

Scaling to Large Models

In its proposed form, Two-Tailed Averaging incorporates every set of weights produced by the optimizer in both averages it maintains. This is good because the theory of Tail Averaging (also known as Suffix Averaging) has nice things to say about convergence to a local optimum of the training loss in this setting. However, in a memory-constrained situation, these averages will not fit on the GPU/TPU, so we must move the weights off the device to add them to the averages (which may be in RAM or on disk). Moving stuff off the device can be slow, so we might want to do that only, say, every 20 optimization steps. Obviously, downsampling the weights too much will affect the convergence rate, so there is a tradeoff.
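A rough sketch of this downsampling, reusing the noisy $f(w)=w^2$ toy problem: the weights are added to a host-side running average only every `add_every` steps. The name `add_every` and the single average are simplifications assumed here; in 2TA the same gate would apply to both the short and the long average.

```python
import random

random.seed(0)
add_every = 20           # hypothetical downsampling interval
n, avg = 0, 0.0          # host-side running average of the sampled weights
w, lr = 3.14, 0.5

for step in range(1, 2001):
    # Noisy gradient step on f(w) = w**2.
    w -= lr * (2 * w + random.uniform(-1.0, 1.0))
    if step % add_every == 0:
        # In a real setup, this is where the weights would be copied
        # off the device before being added to the average.
        n, avg = n + 1, (n * avg + w) / (n + 1)

print(f'averaged {n} of 2000 weights, f(avg)={avg**2:.5f}')
```

With fewer weights in the average, its noise shrinks more slowly per optimization step, which is the tradeoff mentioned above.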

Learning Rate

Note that in our experiments with Two-Tailed Averaging, we used a constant learning rate, motivated by the fact that the closely related method of Tail Averaging guarantees an optimal convergence rate in such a setting. The algorithm should work with decreasing learning rates but would require modification for cyclical schedules.

Related Works

Adaptivity: SWA and LAWA have hyperparameters that directly control the averaging length; NT-ASGD still has one, but its effect is more indirect.

Anytimeness: LAWA provides an average at all times, SWA and NT-ASGD don't.

Optimality: The final averages of SWA and LAWA are optimal if their hyperparameters are well-tuned; intermediate results of LAWA are unlikely to be optimal; NT-ASGD can miss the right time to start averaging.


Two-Tailed Averaging can be thought of as online SWA with no hyperparameters. It is a great option when training runs take a long (or even an a priori unknown amount of) time, and when we could do without optimizing yet another hyperparameter.
