Research Forum Brief | September 2024

Direct Nash Optimization: Teaching language models to self-improve with general preferences

Corby Rosset

“The traditional way to fine-tune an LLM for post-training … basically tells the model to emulate good behaviors, but it does not target or correct any mistakes or bad behaviors that it makes explicitly. … Self-improving post-training explicitly identifies and tries to correct bad behaviors or mistakes that the model makes.”

Corby Rosset, Senior Researcher, Microsoft Research AI Frontiers

Transcript: Lightning Talk

This talk discusses teaching language models to self-improve using a preference oracle like GPT-4, framing it as a two-player game to find an optimal policy at a Nash equilibrium, and achieving state-of-the-art win rates against GPT-4 Turbo on benchmarks such as AlpacaEval and MT-Bench.

Microsoft Research Forum, September 3, 2024

CORBY ROSSET: Hi, I’m Corby. I’m a scientist in Microsoft Research. Today, we’re going to be talking about Direct Nash Optimization, which is a technique to help language models self-improve.

We all know that there are two main ways to improve language models: one is to scale up the number of parameters, and the other is to scale up the amount of training data. Both of these approaches are costly, even for post-training techniques. The traditional way to fine-tune an LLM for post-training is using supervised fine-tuning (SFT). SFT basically tells the model to emulate good behaviors, but it does not target or correct any mistakes or bad behaviors that it makes explicitly. More advanced post-training techniques such as RLHF use a fixed reward model, which can be easily hacked or go stale during training, and they involve much more complex reinforcement learning, which can be unstable. Self-improving post-training explicitly identifies and tries to correct bad behaviors or mistakes that the model makes.

Before we move on, we want to give a concrete example of what we mean by self-improving behavior. Here’s a simple geometry problem where a base model that was already SFTed makes a simple arithmetic error on the left-hand side. After our self-improving technique, the model is able to correct this mistake.

Here we give a simple overview of how Direct Nash Optimization works. One of the properties of generative LLMs is that you can sample multiple outputs from them. This is advantageous because, given an input, we can take our language model and sample, in this case, two outputs (answer A and answer B) and have them scored or rated by a preference function oracle, which tells us which response is better. Then we can use a contrastive training mechanism, such as DPO, IPO, or others, to update the parameters of the language model to hopefully improve it. In the next iteration, timestep t+1, we repeat the process.

The key insight of this technique is how we define reward. Typically, in the RLHF framework, we want to maximize the reward of a language model policy against some given external reward model. Here, we redefine “reward” as the expected win rate against your own behavior as judged by a preference function P. What this means is that for a given response y to an input x, the reward of that response is defined as the expected win rate against responses y′ (the “y primes”) sampled from the policy itself. Hence, rewards are maximized by responses that are preferred over other responses.
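The reward definition above can be sketched in a few lines of Python. This is only an illustration under stated assumptions, not the implementation behind the talk: `toy_policy` and `toy_preference` are hypothetical stand-ins for the language model and the GPT-4 preference oracle, four samples are drawn instead of two, and picking the highest- and lowest-reward samples as the training pair is one plausible choice rather than something the talk specifies.

```python
import random
from typing import Callable, List, Tuple

# Hypothetical stand-ins: in the talk the policy is an LLM and the
# preference oracle is GPT-4. Here both are toy functions.
def toy_policy(prompt: str, rng: random.Random) -> str:
    """Sample one 'response': an attempted answer to a sum, sometimes off by one."""
    a, b = (int(t) for t in prompt.split("+"))
    return str(a + b + rng.choice([-1, 0, 0, 1]))

def toy_preference(prompt: str, y: str, y_prime: str) -> float:
    """P(y preferred over y_prime | prompt): the answer closer to the true sum wins."""
    a, b = (int(t) for t in prompt.split("+"))
    err_y = abs(int(y) - (a + b))
    err_p = abs(int(y_prime) - (a + b))
    if err_y == err_p:
        return 0.5  # tie
    return 1.0 if err_y < err_p else 0.0

def win_rate_rewards(
    prompt: str,
    responses: List[str],
    preference: Callable[[str, str, str], float],
) -> List[float]:
    """Reward of each response = its mean preference over the other samples."""
    rewards = []
    for i, y in enumerate(responses):
        others = [r for j, r in enumerate(responses) if j != i]
        rewards.append(sum(preference(prompt, y, o) for o in others) / len(others))
    return rewards

def build_pair(responses: List[str], rewards: List[float]) -> Tuple[str, str]:
    """Return (chosen, rejected): the highest- and lowest-reward samples."""
    best = max(range(len(responses)), key=rewards.__getitem__)
    worst = min(range(len(responses)), key=rewards.__getitem__)
    return responses[best], responses[worst]

# One data-collection step for a single prompt.
rng = random.Random(0)
prompt = "17+25"
samples = [toy_policy(prompt, rng) for _ in range(4)]
rewards = win_rate_rewards(prompt, samples, toy_preference)
chosen, rejected = build_pair(samples, rewards)
```

The resulting `(chosen, rejected)` pair is what a contrastive update would consume; repeating sampling, judging, and updating with the new policy gives the next iteration.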

When you start comparing the y primes (the model’s own outputs) to each other, this incentivizes a self-improving behavior because you’re basically competing against yourself. You can formulate this in a game-theoretic manner where, in this game, you have a single player which is competing against itself, and the payoffs are given by the preference function. In this game, a Nash equilibrium is achieved by the best possible policy π*, whose responses are preferred over those of any other competing policy in its class.
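In symbols, the game described above can be written as follows. The notation is an assumption made for illustration: the talk names only x, y, the y primes, the preference function P, and π*, so the iteration index t, the prompt distribution ρ, and the symbol ≻ for "is preferred over" are added here.

```latex
% Reward of response y to input x at iteration t: expected win rate
% against responses y' drawn from the current policy \pi_t.
r_t(x, y) = \mathbb{E}_{y' \sim \pi_t(\cdot \mid x)}
            \left[ \mathcal{P}(y \succ y' \mid x) \right]

% Policy-level preference: how often responses from \pi beat responses from \pi'.
\mathcal{P}(\pi \succ \pi') =
    \mathbb{E}_{x \sim \rho,\; y \sim \pi(\cdot \mid x),\; y' \sim \pi'(\cdot \mid x)}
    \left[ \mathcal{P}(y \succ y' \mid x) \right]

% Nash equilibrium of the symmetric game: the policy with the best
% worst-case win rate against any competitor.
\pi^* \in \arg\max_{\pi} \min_{\pi'} \mathcal{P}(\pi \succ \pi')
```

Because P(y ≻ y′ | x) + P(y′ ≻ y | x) = 1, the game is symmetric and constant-sum, so the equilibrium policy wins at least half the time against any competitor: P(π* ≻ π′) ≥ 1/2 for every π′ in the policy class.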

At a high level, Direct Nash Optimization has many advantages. Firstly, it optimizes towards a more general preference function directly rather than a point-wise reward model, which is limited in its expressibility since it can’t model intransitive preferences. Secondly, it is an iterative algorithm, meaning it is much simpler to implement. We use a contrastive update as the loss, which does not involve any policy gradients or heavy reinforcement learning machinery. We also sample on-policy outputs from the model and compare them to each other in a self-play framework. We use a powerful preference annotator (in this case, GPT-4) to rank or judge the best response among them. This approach is also flexible since we can compare the responses to each other but also to outputs from a more powerful teacher such as GPT-4, which provides even bigger improvements. Most importantly, this algorithm is theoretically guaranteed to monotonically approach the Nash equilibrium, hence the name Direct Nash Optimization.
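To make "a contrastive update as the loss" concrete, here is a minimal sketch of the standard DPO loss for a single preference pair, since DPO is one of the mechanisms the talk names. It is not necessarily the exact objective used in DNO. The function names, the scalar log-probability inputs, and the default `beta=0.1` are assumptions for illustration.

```python
import math

def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))

def dpo_loss(
    logp_chosen: float,        # log pi(chosen | x) under the policy being trained
    logp_rejected: float,      # log pi(rejected | x) under the policy being trained
    ref_logp_chosen: float,    # same quantities under the frozen reference policy
    ref_logp_rejected: float,
    beta: float = 0.1,         # assumed strength of the implicit KL regularization
) -> float:
    """DPO loss for one pair: -log sigmoid(beta * (chosen margin - rejected margin))."""
    chosen_margin = logp_chosen - ref_logp_chosen
    rejected_margin = logp_rejected - ref_logp_rejected
    return -log_sigmoid(beta * (chosen_margin - rejected_margin))

# When the policy equals the reference, the loss is log(2); it falls as the
# policy shifts probability from the rejected response to the chosen one.
loss_at_reference = dpo_loss(-5.0, -5.0, -5.0, -5.0)
loss_after_shift = dpo_loss(-4.0, -6.0, -5.0, -5.0)
```

The loss is an ordinary differentiable function of the model's log-probabilities, so it can be minimized with the same supervised training loop used for SFT, which is why no policy-gradient machinery is needed.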

If you implement this algorithm correctly, you will find state-of-the-art results on several benchmarks, including this one, which is AlpacaEval 2. This benchmark basically measures how well language models follow instructions and align with human expectations. It computes a win rate of the language model’s outputs versus a powerful reference (in this case, GPT-4) in a side-by-side comparison. The y-axis is the win rate, and the x-axis is the number of training iterations. We see that the dark blue line, which is DNO, the vanilla implementation, outperforms two important baselines. The red line is SFT, and the orange and yellow lines are offline contrastive algorithms, such as DPO and KTO. Hence, we see that self-improving post-training is better than offline contrastive training and SFT. Notably, DNO is also able to outperform similar training techniques applied to models that were 10 times as large, namely the gray line, which was a 70-billion-parameter Llama model. We are also encouraged to see that these results do not saturate, and with more training over more iterations, shown in the purple line, we see even better results.
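As a rough illustration of the win-rate metric described above, the sketch below tallies side-by-side judgments into a percentage. It is not the AlpacaEval implementation: the label strings and the convention of counting a tie as half a win are assumptions.

```python
from typing import Iterable

def win_rate(judgments: Iterable[str]) -> float:
    """Percentage of comparisons won by the model; a tie counts as half a win."""
    counts = {"model": 0, "reference": 0, "tie": 0}
    for judgment in judgments:
        counts[judgment] += 1  # raises KeyError on an unknown label
    total = sum(counts.values())
    if total == 0:
        raise ValueError("no judgments to score")
    return 100.0 * (counts["model"] + 0.5 * counts["tie"]) / total

# Each entry is the judge's verdict for one prompt (hypothetical data).
example_judgments = ["model", "reference", "tie", "model"]
example_rate = win_rate(example_judgments)
```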

We hope this work inspires other researchers to continue to investigate self-improving post-training as an effective method for aligning language models with human expectations. Thank you for watching.