Code
import re
import requests
from collections import Counter
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoaderThis notebook is based on the output generated by the Claude LLM model claude-opus-4-7. The LLM prompt was
Could you provide a complete PyTorch example of a generative RNN that is trained on “War and peace” and that generates sequences? Ideally, the example should use an LSTM and word embeddings.
The output has been edited and reformatted for increased clarity. If you want to test this notebook, you will need a GPU. Completing 5 training epochs takes approximately 15 minutes on my computer's GPU; to get reasonable outputs you probably need to run 50 epochs or more, which would take several hours to complete. You can try it out on Google Colab:
Claude also provided some comments on the code, which I include here quoted as-is:
==============================================
| Choice | Reason |
|---|---|
| Word-level (not char-level) | Faster convergence, more coherent output |
| Embedding dim 128 | Standard for ~10k vocab; can go up to 256 |
| 2-layer LSTM, hidden 256 | Good balance of capacity vs. training time |
| Vocabulary cap at 10k | Tail words become `<unk>` |
| Orthogonal init + forget bias = 1 | Matches Keras defaults that often work better |
| Top-k + temperature sampling | Avoids both repetition (greedy) and gibberish (pure sampling) |
| Gradient clipping at 1.0 | Standard safeguard for RNNs |
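This trains on disjoint windows rather than maintaining hidden state across an epoch (truncated BPTT). For a more sophisticated implementation, you’d batch contiguous chunks and carry hidden state forward — but the above is the standard introductory pattern and works well. [Editor's note: the `TextDataset` below actually yields overlapping windows that start at every token position (stride 1), not disjoint ones; what is true is that no hidden state is carried between windows.]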
==============================================
Load the relevant libraries and define the live-plotting class, the training loop, and a function to save model checkpoints.
import re
import requests
from collections import Counter
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoaderfrom typing import Optional
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torchmetrics
class LivePlot():
    """Interactive Plotly chart that grows as training metrics are reported.

    Series reported with ``secondary_y=False`` use the left axis (titled
    ``left_label``); series reported with ``secondary_y=True`` use the right
    axis (titled ``right_label``). The x coordinate of new points only
    advances when :meth:`tick` is called.
    """

    def __init__(self, left_label="Loss", right_label="Accuracy"):
        layout_spec = [[{"secondary_y": True}]]
        self.fig = go.FigureWidget(make_subplots(specs=layout_spec))
        for title, on_secondary in ((left_label, False), (right_label, True)):
            self.fig.update_yaxes(title_text=title, secondary_y=on_secondary)
        self.plot_indices = {}     # series name -> index into self.fig.data
        self.trace_secondary = {}  # series name -> whether it uses the right axis
        display(self.fig)          # show the widget in the notebook output
        self.limits = [0, 0]       # current [min, max] of the x-axis range
        self.current_x = 0         # x coordinate given to newly reported points

    def report(self, name: str, value: float, secondary_y: bool = False):
        """Append ``value`` at the current x position to series ``name``.

        The series is created the first time ``name`` is seen; ``secondary_y``
        only has an effect at that moment.
        """
        if name not in self.plot_indices:
            new_index = len(self.fig.data)
            self.fig.add_scatter(y=[], x=[], name=name, secondary_y=secondary_y)
            self.plot_indices[name] = new_index
            self.trace_secondary[name] = secondary_y
        series = self.fig.data[self.plot_indices[name]]
        series.y += (value,)
        series.x += (self.current_x,)

    def _apply_limits(self):
        """Push the stored x-axis limits to the figure layout."""
        self.fig.update_layout(xaxis_range=self.limits)

    def increment(self, n_ticks: int):
        """Extend the upper x-axis limit by ``n_ticks``."""
        self.limits[1] += n_ticks
        self._apply_limits()

    def set_limit(self, n_ticks: int):
        """Set the upper x-axis limit to exactly ``n_ticks``."""
        self.limits[1] = n_ticks
        self._apply_limits()

    def tick(self, n_ticks: Optional[int] = None):
        """Advance the x position used by :meth:`report` (by 1 when ``n_ticks`` is None)."""
        step = 1 if n_ticks is None else n_ticks
        self.current_x += step
def save_model(*, epoch, model, optimizer, total_loss, n, val_loss, vn, word2idx):
    """Write a training checkpoint to ``checkpoint_epoch_<epoch + 1>.pt``.

    ``epoch`` is zero-based; both the stored ``"epoch"`` entry and the file
    name use the one-based number. ``total_loss / n`` and ``val_loss / vn``
    are stored as the average training and validation losses. The checkpoint
    also carries the model and optimizer state dicts and the vocabulary
    mapping ``word2idx``. Raises ZeroDivisionError if ``n`` or ``vn`` is 0.
    """
    completed = epoch + 1
    checkpoint = dict(
        epoch=completed,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_loss=total_loss / n,
        val_loss=val_loss / vn,
        vocab=word2idx,
    )
    torch.save(checkpoint, f"checkpoint_epoch_{completed}.pt")
def _train_one_epoch(model, loader, optimizer, criterion, device, clip_norm):
    """Run one optimisation pass over ``loader``; return (summed loss, example count)."""
    model.train()
    loss_sum, n_examples = 0.0, 0
    for x_batch, y_batch in loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        optimizer.zero_grad()
        logits, _ = model(x_batch)
        loss = criterion(logits.reshape(-1, logits.size(-1)), y_batch.reshape(-1))
        loss.backward()
        # Clip gradients to prevent exploding gradients, a common issue with RNNs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()
        batch_size = x_batch.size(0)
        # The criterion returns a per-batch mean (assumed: default reduction);
        # weight it by the batch size so that dividing by the example count
        # later yields a true per-example average instead of mean / batch_size.
        loss_sum += loss.item() * batch_size
        n_examples += batch_size
    return loss_sum, n_examples


@torch.no_grad()
def _evaluate(model, loader, criterion, device, metric):
    """Evaluate without gradients; return (loss sum, example count, metric sum, batch count)."""
    model.eval()
    loss_sum, n_examples, metric_sum, n_batches = 0.0, 0, 0.0, 0
    for x_batch, y_batch in loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        logits, _ = model(x_batch)
        loss = criterion(logits.reshape(-1, logits.size(-1)), y_batch.reshape(-1))
        batch_size = x_batch.size(0)
        loss_sum += loss.item() * batch_size
        n_examples += batch_size
        n_batches += 1
        if metric is not None:
            # Per-batch metric value on the predicted token ids.
            metric_sum += float(metric(torch.argmax(logits, -1), y_batch))
    return loss_sum, n_examples, metric_sum, n_batches


def train(*,
          model: torch.nn.Module,
          train_loader: DataLoader,
          dev_loader: DataLoader,
          optimizer: torch.optim.Optimizer,
          criterion: torch.nn.Module,
          max_epochs: int,
          metric: Optional[torchmetrics.Metric] = None,
          device: Optional[torch.device] = None,
          liveplot: Optional[LivePlot] = None,
          vocab: Optional[dict] = None,
          clip_norm: float = 1.0):
    """Train a next-token model, reporting progress and checkpointing every epoch.

    Each epoch runs one optimisation pass over ``train_loader`` (with
    gradient-norm clipping), evaluates on ``dev_loader`` without gradients,
    optionally reports the results to ``liveplot``, and writes a checkpoint
    via ``save_model``.

    Args:
        model: module whose forward returns ``(logits, hidden)``; ``logits``
            has the class dimension last and is flattened to match the
            flattened targets.
        train_loader, dev_loader: iterables of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss module; assumed to return a batch mean (the default
            for ``nn.CrossEntropyLoss``). Losses are re-weighted by batch size
            so that the reported values are per-example averages.
        max_epochs: number of passes over ``train_loader``.
        metric: optional torchmetrics metric called on
            ``(argmax predictions, targets)`` for each dev batch; its
            batch-averaged value is plotted as "Development accuracy".
        device: defaults to CUDA when available, else CPU.
        liveplot: optional LivePlot receiving "Training loss",
            "Development loss" and, with a metric, "Development accuracy".
        vocab: vocabulary stored in each checkpoint; defaults to the
            module-level ``word2idx`` so existing calls keep working.
        clip_norm: maximum total gradient norm (default 1.0).
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    if vocab is None:
        vocab = word2idx  # notebook-level vocabulary built earlier
    model.to(device)
    if metric is not None:
        # The metric's internal state must live on the same device as the predictions.
        metric = metric.to(device)

    for epoch in tqdm(range(max_epochs)):
        train_loss_sum, train_examples = _train_one_epoch(
            model, train_loader, optimizer, criterion, device, clip_norm)
        dev_loss_sum, dev_examples, dev_metric_sum, dev_batches = _evaluate(
            model, dev_loader, criterion, device, metric)

        if liveplot is not None:
            liveplot.tick()  # advance the liveplot's x position by one epoch
            liveplot.report("Training loss", train_loss_sum / max(train_examples, 1))
            liveplot.report("Development loss", dev_loss_sum / max(dev_examples, 1))
            if metric is not None:
                liveplot.report("Development accuracy",
                                dev_metric_sum / max(dev_batches, 1),
                                secondary_y=True)

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError
        # inside save_model (the summed loss is 0.0 in that case).
        save_model(
            epoch=epoch,
            model=model,
            optimizer=optimizer,
            total_loss=train_loss_sum,
            n=max(train_examples, 1),
            val_loss=dev_loss_sum,
            vn=max(dev_examples, 1),
            word2idx=vocab,
        )
# Download and tokenize War and Peace from https://www.gutenberg.org.
URL = "https://www.gutenberg.org/files/2600/2600-0.txt"


def tokenize_book(text):
    """Strip Project Gutenberg boilerplate from ``text`` and split it into tokens.

    The body runs from the first "CHAPTER I" up to the first following
    "end of the project gutenberg" marker, matched case-insensitively so that
    both the old "End of the Project Gutenberg EBook" and the current
    "*** END OF THE PROJECT GUTENBERG EBOOK" footers are found. A missing
    marker leaves that end of the text untouched instead of silently
    producing a truncated or empty corpus.

    Typographic apostrophes (U+2019) are normalized to "'" so contractions
    such as "don’t" remain one token. Text is lower-cased, and characters
    other than a-z, 0-9, whitespace, . , ; : ? ! ' - become spaces. Returns a
    list of word tokens (runs of letters, digits, hyphens and apostrophes)
    and single-character punctuation tokens.
    """
    start = text.find("CHAPTER I")
    body = text if start == -1 else text[start:]
    footer = re.search(r"end of the project gutenberg", body, flags=re.IGNORECASE)
    if footer is not None:
        body = body[:footer.start()]
    body = body.replace("\u2019", "'").lower()
    body = re.sub(r"[^a-z0-9\s\.\,\;\:\?\!\'\-]", " ", body)
    return re.findall(r"[a-z0-9\-']+|[\.\,\;\:\?\!]", body)


def load_text(url=URL):
    """Download a Project Gutenberg plain-text book and return its token list.

    Raises ``requests.HTTPError`` on a non-2xx response, rather than
    tokenizing an error page.
    """
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    # requests falls back to ISO-8859-1 for text/* responses without a
    # charset, which would garble the UTF-8 quotes/apostrophes in Gutenberg
    # "-0.txt" files; assume UTF-8 unless the server says otherwise.
    if "charset" not in response.headers.get("Content-Type", "").lower():
        response.encoding = "utf-8"
    return tokenize_book(response.text)
# Download and tokenize the book once; `tokens` is the flat token list used below.
tokens = load_text()
n_tokens = len(tokens)
print(f"Total tokens: {n_tokens:,}")
# Build a 10,000-entry vocabulary from the most common words, including special tokens for unknown words and padding.
VOCAB_SIZE = 10_000

# Count every token; the VOCAB_SIZE - 2 most frequent ones get an id,
# leaving two slots for the special tokens below.
counter = Counter(tokens)
most_common = counter.most_common(VOCAB_SIZE - 2)

# Special tokens take ids 0 and 1; corpus words follow in descending
# frequency order.
word2idx = {"<pad>": 0, "<unk>": 1}
for word, _freq in most_common:
    word2idx[word] = len(word2idx)
idx2word = dict(zip(word2idx.values(), word2idx.keys()))


def encode(tok):
    """Return the id of ``tok``, or the ``<unk>`` id for out-of-vocabulary tokens."""
    unk_id = word2idx["<unk>"]
    return word2idx.get(tok, unk_id)


# The whole corpus as one flat LongTensor of token ids.
data = torch.tensor(list(map(encode, tokens)), dtype=torch.long)
print(f"Vocabulary size: {len(word2idx):,}")
# Define a TextDataset class that holds input/output pairs.
SEQ_LEN = 32


class TextDataset(Dataset):
    """Sliding-window next-token dataset over a 1-D tensor of token ids.

    Item ``idx`` is the pair ``(x, y)`` with ``x = data[idx : idx + seq_len]``
    and ``y = data[idx + 1 : idx + seq_len + 1]``, i.e. ``y`` is ``x`` shifted
    by one token. Windows start at every offset (stride 1), so consecutive
    items overlap.
    """

    def __init__(self, data, seq_len):
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        # A window starting at idx needs seq_len + 1 tokens (inputs plus the
        # shifted target), so valid starts are 0 .. len(data) - seq_len - 1.
        # Clamp at 0 so data shorter than a window yields an empty dataset
        # instead of a negative length (which len() rejects).
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            # Without this, slicing past the end would silently return
            # truncated windows.
            raise IndexError(f"index {idx} out of range for dataset of length {n}")
        x = self.data[idx : idx + self.seq_len]
        y = self.data[idx + 1 : idx + self.seq_len + 1]  # shifted by 1
        return x, y
# Train/val split: the last 5% of the token stream is a contiguous
# validation tail, kept in a separate tensor from the training part.
split = int(0.95 * len(data))
train_ds, val_ds = (TextDataset(part, SEQ_LEN) for part in (data[:split], data[split:]))

# Only the training windows are shuffled; validation order stays fixed.
train_loader, val_loader = (
    DataLoader(ds, batch_size=128, shuffle=is_train, num_workers=2)
    for ds, is_train in ((train_ds, True), (val_ds, False))
)
# Define the LSTM model, consisting of an embedding layer, a (two-layer) LSTM, and a linear output layer.
class CharLSTM(nn.Module):
    """LSTM language model: embedding -> stacked LSTM -> linear vocabulary head.

    Despite the name, this notebook feeds it word/punctuation token ids; the
    name is kept so existing code and checkpoints continue to work (the
    parameter names embed/lstm/fc are unchanged).
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_size=256,
                 num_layers=2, dropout=0.3):
        super().__init__()
        # padding_idx=0: that embedding row starts at zero and receives no gradient.
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # nn.LSTM applies dropout only between stacked layers and warns
            # when num_layers == 1, so disable it explicitly in that case.
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, vocab_size)
        self._init_lstm_weights()

    def _init_lstm_weights(self):
        """Keras-style init: orthogonal recurrent, Glorot input weights, forget bias 1."""
        with torch.no_grad():
            for name, p in self.lstm.named_parameters():
                if "weight_hh" in name:
                    nn.init.orthogonal_(p)
                elif "weight_ih" in name:
                    nn.init.xavier_uniform_(p)
                elif "bias" in name:
                    nn.init.zeros_(p)
                    # PyTorch orders gates (input, forget, cell, output) and
                    # adds bias_ih + bias_hh, so setting the forget slice in
                    # only bias_ih gives an effective forget bias of exactly 1
                    # (setting it in both would give 2).
                    if name.startswith("bias_ih"):
                        h = p.size(0) // 4
                        p[h : 2 * h].fill_(1.0)

    def forward(self, x, hidden=None):
        """Return next-token logits and the final LSTM state.

        x: token ids shaped (batch, tokens) (the LSTM is batch_first).
        hidden: optional (h, c) state from a previous call, to continue a sequence.
        """
        emb = self.embed(x)                   # (batch, tokens, embed_dim)
        out, hidden = self.lstm(emb, hidden)  # (batch, tokens, hidden_size)
        logits = self.fc(out)                 # (batch, tokens, vocab_size)
        return logits, hidden
# Train the model, saving a checkpoint after each epoch.
epochs = 5
# Targets equal to 0 (the <pad> id) are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CharLSTM(len(word2idx)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

# Live chart of the train/dev losses, x-axis pre-sized to the number of epochs.
liveplot = LivePlot()
liveplot.fig.update_layout(width=1200, height=800, font_size=18)
liveplot.increment(epochs)

train(
    model=model,
    train_loader=train_loader,
    dev_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    max_epochs=epochs,
    liveplot=liveplot,
    device=device,
)
# Define a function that generates text from a prompt.
def _detokenize(tokens):
    """Join tokens with spaces, attaching punctuation to the preceding word."""
    punctuation = {".", ",", ";", ":", "?", "!"}
    pieces = []
    for tok in tokens:
        if tok in punctuation and pieces:
            pieces[-1] += tok
        else:
            pieces.append(tok)
    return " ".join(pieces)


@torch.no_grad()
def generate(model, prompt, n_words=100, temperature=0.8, top_k=20):
    """Sample a continuation of ``prompt`` from a trained word-level model.

    Args:
        model: module returning ``(logits, hidden)`` and accepting a previous
            ``hidden`` state (as CharLSTM does); put into eval mode.
        prompt: seed text, tokenized with the same pattern as the corpus.
        n_words: number of tokens to sample after the prompt.
        temperature: logits are divided by this before sampling; must be > 0.
        top_k: if truthy, sample only among the ``top_k`` most likely tokens
            (clamped to the vocabulary size).

    Returns:
        The prompt tokens followed by the sampled tokens, joined with spaces
        and with punctuation attached to the preceding word.

    Raises:
        ValueError: if ``temperature <= 0`` or the prompt contains no tokens.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    model.eval()
    prompt_tokens = re.findall(r"[a-z0-9\-']+|[\.\,\;\:\?\!]", prompt.lower())
    if not prompt_tokens:
        raise ValueError("prompt must contain at least one word or punctuation token")
    # Use the model's own device rather than relying on a global `device`.
    model_device = next(model.parameters()).device
    idx = torch.tensor([encode(t) for t in prompt_tokens],
                       dtype=torch.long, device=model_device).unsqueeze(0)

    # Run the whole prompt once; the logits at its last position are already
    # the distribution over the first generated token, so the last prompt
    # token must not be fed a second time.
    logits, hidden = model(idx)
    output_tokens = list(prompt_tokens)
    k = min(top_k, logits.size(-1)) if top_k else 0

    for step in range(n_words):
        next_logits = logits[:, -1, :] / temperature
        if k:
            # Top-k sampling: mask everything below the k-th largest logit.
            kth = torch.topk(next_logits, k).values[:, -1:]
            next_logits = next_logits.masked_fill(next_logits < kth, float("-inf"))
        probs = torch.softmax(next_logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # shape (1, 1)
        output_tokens.append(idx2word[next_id.item()])
        if step + 1 < n_words:
            logits, hidden = model(next_id, hidden)

    return _detokenize(output_tokens)
# After training, we can load models from different epochs to see how (and whether) the model has improved.
prompt = "prince andrew looked at"
n_words = 80
temperature = 0.8


def _restore(path):
    """Build a fresh CharLSTM on `device` and load the weights stored at ``path``."""
    net = CharLSTM(len(word2idx)).to(device)
    checkpoint = torch.load(path, map_location=device)
    net.load_state_dict(checkpoint["model_state_dict"])
    return net, checkpoint


m1, ckpt = _restore("checkpoint_epoch_1.pt")
m5, ckpt = _restore("checkpoint_epoch_5.pt")

for header, net in (("\n--- Sample, m1 ---", m1), ("\n--- Sample, m5 ---", m5)):
    print(header)
    print(generate(net, prompt, n_words=n_words, temperature=temperature))