This is the initial draft of Reward Models for the GPU Poor, written in under an hour to capitalize on a bounty. You'll probably find the rewritten version more helpful but this might be good for a laugh.
Chai Research is holding a neat little large language model competition. This season they've introduced the ability to package a reward model with your submission, to be used with best-of-4 sampling. Looks like there haven't been any custom models submitted yet though.
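For the unfamiliar, best-of-4 just means the serving side samples four candidate replies and keeps whichever one the reward model scores highest. Conceptually it's something like this (a sketch, not Chai's actual pipeline; generate_candidates and score are made-up stand-ins):
# Conceptual sketch of best-of-N sampling with a reward model.
# `generate_candidates` and `score` are hypothetical stand-ins, not Chai's API.
def best_of_n(prompt, generate_candidates, score, n=4):
    candidates = generate_candidates(prompt, n)  # sample n possible replies
    best = max(candidates, key=lambda reply: score(prompt + reply))
    return best                                  # serve the highest-scoring reply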
Let's train a reward model! Really quickly. Go go go.
Base Model
We can use either gpt2 or Phi as the base for our reward model. Phi is definitely more capable, but larger and slower both to train and to evaluate. Well phooey to that. I'm going to go the opposite direction, and train an even smaller reward model. I can't afford 137M parameters in this economy. Let's chop a few layers off of gpt2 and start from there.
Let's pop open mergekit and make the aerodynamic, streamlined base model of our dreams. A nice, simple config:
slices:
  - sources:
      - model: gpt2
        layer_range: [0, 8]
merge_method: passthrough
dtype: float16
mergekit-yaml gpt2-small.yml ./gpt2-small
And now we have gpt2-small, weighing in at 96M parameters. Much better!
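If you want to double-check that number, a quick count will do it (assuming the merge landed in ./gpt2-small):
# Quick sanity check on the parameter count of the truncated model.
import transformers

m = transformers.AutoModelForCausalLM.from_pretrained("./gpt2-small")
print(f"{sum(p.numel() for p in m.parameters()) / 1e6:.0f}M parameters")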
Training
Now it's time to cram a whole bunch of data in there. Chai has provided a nice dataset of real feedback from their users, which will serve us just fine. It's a little unusual in that it provides a binary 'thumbs up'/'thumbs down' label on single conversations, as opposed to the accept/reject pair used in typical RLHF schemes. That just means we can approach this as simple sequence classification instead of a pairwise rating objective.
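For contrast, here's roughly what the two objectives look like in torch (made-up numbers, purely illustrative; the classification path is what AutoModelForSequenceClassification gives us for free):
import torch
import torch.nn.functional as F

# Pairwise (Bradley-Terry style) objective: needs (chosen, rejected) pairs.
r_chosen, r_rejected = torch.tensor([1.3]), torch.tensor([0.2])  # made-up reward scores
pairwise_loss = -F.logsigmoid(r_chosen - r_rejected)

# Binary thumbs up/down objective: plain cross entropy on a two-class head.
logits = torch.tensor([[0.4, 1.1]])  # made-up logits for one conversation
label = torch.tensor([1])            # 1 = thumbs up
classification_loss = F.cross_entropy(logits, label)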
Let's crank out a tiny bit of code to get transformers.Trainer to work for us:
import transformers, datasets

model = transformers.AutoModelForSequenceClassification.from_pretrained("./gpt2-small")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
tokenizer.truncation_side = "right"
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id  # nice! super elegant, love that we need this

def prepare_row(row):
    tokenized = tokenizer(
        row["input_text"],
        padding=True,
        max_length=model.config.n_positions,
        truncation=True,
    )
    return {
        "input_ids": tokenized["input_ids"],
        "attention_mask": tokenized["attention_mask"],
        "labels": row["labels"],
    }

ds = datasets.load_dataset("ChaiML/20231012_chai_prize_reward_model_data")["train"]
ds = ds.map(prepare_row)
And for the sake of propriety, let's put aside an evaluation split as well:
splits = ds.train_test_split(test_size=0.01)
ds_train = splits["train"]
ds_eval = splits["test"]
Great! Now we can train the model. For finding optimal hyperparameters, I used the well-known "gut feeling" theorem. My finely-honed instincts informed me that eight is a nice number, 1e-4 is pretty safe usually, and one epoch sounds like plenty of waiting. Go go go.
train_args = transformers.TrainingArguments(
    output_dir="reward-model-out",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    bf16=True,
    lr_scheduler_type="cosine",
    optim="adamw_torch_fused",
    learning_rate=0.0001,
    report_to="wandb",
    logging_steps=1,
    num_train_epochs=1,
)

trainer = transformers.Trainer(
    model,
    args=train_args,
    tokenizer=tokenizer,
    train_dataset=ds_train,
)
trainer.train()
Turns out a batch size of 8 fits quite nicely on a 3090. Some 45 minutes later, we have a trained model. Let's see how it does!
import torch

def good_prob(example: str):
    # Probability of the 'thumbs up' class for a single conversation.
    # Truncate to the model's context length so long conversations don't blow up.
    with torch.no_grad():
        input_ids = tokenizer(example, truncation=True, max_length=model.config.n_positions, return_tensors="pt")["input_ids"]
        probs = model(input_ids.to(model.device))[0].softmax(-1)
        return probs[..., 1]
>>> good_prob(ds_eval[0]["input_text"]), ds_eval[0]["labels"]
(tensor([0.4884], device='cuda:0'), 0)
>>> good_prob(ds_eval[1]["input_text"]), ds_eval[1]["labels"]
(tensor([0.5567], device='cuda:0'), 1)
Alright, that's two points that are on the correct side of average! And two points make a line. That line's slope? Success.
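If you want a slightly less vibes-based number, looping good_prob over the whole held-out split gives you an accuracy (one example at a time, so it's slow, but it's a sketch):
# Rough accuracy over the eval split, reusing good_prob from above.
correct = sum(
    int((good_prob(row["input_text"]).item() > 0.5) == bool(row["labels"]))
    for row in ds_eval
)
print(f"eval accuracy: {correct / len(ds_eval):.3f}")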
For my next pass at this I'm going to train two models and compare them: one on the Chai binary-labelled dataset, and another on Anthropic's public RLHF data. I'll be curious to see how the two differ. To facilitate this I've written a script that can train a model on either objective, which is hideous but available here.