Lab demonstration: War and Peace

War and Peace demonstration
Author
Affiliation

Per Unneberg

NBIS

Published

May 6, 2026

About

This notebook is based on output generated by the Claude LLM model claude-opus-4-7. The prompt was:

Could you provide a complete PyTorch example of a generative RNN that is trained on “War and peace” and that generates sequences? Ideally, the example should use an LSTM and word embeddings.

The output has been edited and reformatted for increased clarity. If you want to test this notebook, you will need a GPU. Completing 5 training epochs takes approximately 15 minutes on my computer's GPU; to get reasonable output you probably need to run 50 epochs or more, which would take several hours. You can try it out on Google Colab:

  1. Download the notebook from this link.
  2. Go to colab.research.google.com, then File → Open notebook → Upload and select the downloaded file.
  3. Make a copy to your own Google Drive before starting your work, so that your changes are saved.

Claude also provided some comments on the code, which I include here quoted as-is:

==============================================

Key design choices

| Choice | Reason |
|---|---|
| Word-level (not char-level) | Faster convergence, more coherent output |
| Embedding dim 128 | Standard for ~10k vocab; can go up to 256 |
| 2-layer LSTM, hidden 256 | Good balance of capacity vs. training time |
| Vocabulary cap at 10k | Tail words become `<unk>`; keeps softmax tractable |
| Orthogonal init + forget bias = 1 | Matches Keras defaults that often work better |
| Top-k + temperature sampling | Avoids both repetition (greedy) and gibberish (pure sampling) |
| Gradient clipping at 1.0 | Standard safeguard for RNNs |

Tuning for better output

  • Train longer (20–50 epochs) — 5 is just to verify it runs
  • Increase hidden_size to 512 if you have GPU memory
  • Lower temperature (0.6–0.7) for more conservative text
  • Increase SEQ_LEN to 64 or 128 for longer-range context

One known limitation

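This trains on disjoint windows rather than maintaining hidden state across an epoch (truncated BPTT). For a more sophisticated implementation, you’d batch contiguous chunks and carry hidden state forward — but the above is the standard introductory pattern and works well.

[Editor's note: the windows produced by the TextDataset in section 3 actually overlap, since consecutive windows are offset by a single token. "Disjoint" should be read as "independent": the hidden state is reset for every window. "The above" refers to the code that follows in this notebook.]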

==============================================

0. Initialization

Load the relevant libraries and define a live-plotting class, the training loop, and a function that saves model checkpoints.

Code
import re
import requests
from collections import Counter
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
Code
from typing import Optional
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import torchmetrics

class LivePlot():
    """Plotly figure widget that is updated while training is running.

    The figure has a primary (left) and a secondary (right) y-axis. A named
    series is created the first time a value is reported for it, and each
    reported value is placed at the current x position, which is advanced
    with :meth:`tick`.
    """

    def __init__(self, left_label="Loss", right_label="Accuracy"):
        """Build the figure, title both y-axes and display the widget.

        Args:
            left_label: Title of the primary (left) y-axis.
            right_label: Title of the secondary (right) y-axis.
        """
        two_axis_figure = make_subplots(specs=[[{"secondary_y": True}]])
        self.fig = go.FigureWidget(two_axis_figure)
        self.fig.update_yaxes(title_text=left_label,  secondary_y=False)
        self.fig.update_yaxes(title_text=right_label, secondary_y=True)

        # Series name -> position of its trace in ``self.fig.data``
        self.plot_indices = {}
        # Series name -> whether the trace is drawn against the right axis
        self.trace_secondary = {}
        # NOTE: ``display`` is not imported in this file; presumably it is the
        # builtin provided by the Jupyter/IPython environment.
        display(self.fig)
        # [lower, upper] range of the x-axis
        self.limits = [0, 0]
        # x position used for the next reported values
        self.current_x = 0

    def report(self, name: str, value: float, secondary_y: bool = False):
        """Append ``value`` to the series ``name`` at the current x position.

        Args:
            name: Name of the series; a new trace is added on first use.
            value: The value to plot.
            secondary_y: Draw a newly created trace against the right axis.
                Only consulted when the series is created.
        """
        if name not in self.plot_indices:
            # First value for this series: register an empty trace for it.
            new_index = len(self.fig.data)
            self.fig.add_scatter(
                y=[], x=[], name=name,
                secondary_y=secondary_y
            )
            self.plot_indices[name] = new_index
            self.trace_secondary[name] = secondary_y
        plot_index = self.plot_indices[name]
        self.fig.data[plot_index].y += (value,)
        self.fig.data[plot_index].x += (self.current_x,)

    def increment(self, n_ticks: int):
        """Extend the upper x-axis limit by ``n_ticks``."""
        upper = self.limits[1]
        self.limits[1] = upper + n_ticks
        self.fig.update_layout(xaxis_range=self.limits)

    def set_limit(self, n_ticks: int):
        """Set the upper x-axis limit to ``n_ticks``."""
        self.limits[1] = n_ticks
        self.fig.update_layout(xaxis_range=self.limits)

    def tick(self, n_ticks: Optional[int] = None):
        """Advance the current x position by ``n_ticks`` (one step if omitted)."""
        self.current_x += 1 if n_ticks is None else n_ticks

def save_model(*, epoch, model, optimizer, total_loss, n, val_loss, vn, word2idx):
    """Write a training checkpoint to ``checkpoint_epoch_<epoch + 1>.pt``.

    Args:
        epoch: Zero-based epoch index; stored and used in the file name as
            ``epoch + 1``.
        model: Model whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        total_loss: Accumulated training loss; stored as ``total_loss / n``.
        n: Divisor for ``total_loss``.
        val_loss: Accumulated validation loss; stored as ``val_loss / vn``.
        vn: Divisor for ``val_loss``.
        word2idx: Vocabulary mapping, stored under the key ``"vocab"``.
    """
    epoch_number = epoch + 1
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    mean_train_loss = total_loss / n
    mean_val_loss = val_loss / vn

    torch.save(
        {
            "epoch": epoch_number,
            "model_state_dict": model_state,
            "optimizer_state_dict": optimizer_state,
            "train_loss": mean_train_loss,
            "val_loss": mean_val_loss,
            "vocab": word2idx,
        },
        f"checkpoint_epoch_{epoch_number}.pt",
    )
        
def train(*,
          model: torch.nn.Module,
          train_loader: DataLoader,
          dev_loader: DataLoader,
          optimizer: torch.optim.Optimizer,
          criterion: torch.nn.Module,
          max_epochs: int,
          metric: Optional[torchmetrics.Metric] = None,
          device: Optional[torch.device] = None,
          liveplot: Optional[LivePlot]=None):
    """Train ``model`` and write a checkpoint after every epoch.

    Each epoch runs one pass over ``train_loader`` with gradient clipping,
    then evaluates on ``dev_loader`` without gradients, optionally reports
    the results to ``liveplot`` and finally calls ``save_model``.

    Args:
        model: Module called as ``model(x)`` and returning
            ``(logits, hidden)``.
        train_loader: Yields ``(x_batch, y_batch)`` training pairs.
        dev_loader: Yields ``(x_batch, y_batch)`` development pairs.
        optimizer: Optimizer over the parameters of ``model``.
        criterion: Loss applied to the logits flattened to
            ``(N, vocab_size)`` and the targets flattened to ``(N,)``.
            Assumed to return the mean over a batch (the default reduction
            of ``nn.CrossEntropyLoss``).
        max_epochs: Number of epochs to run.
        metric: Optional metric called as ``metric(predictions, targets)``
            on every development batch; assumed to return a scalar.
        device: Device to train on. Defaults to CUDA when available,
            otherwise CPU.
        liveplot: Optional plot that receives the per-epoch results.

    Note:
        The vocabulary stored in each checkpoint is read from the
        module-level ``word2idx``, which must exist when this is called.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model.to(device)
    if metric is not None:
        # A metric keeps state tensors of its own; they have to live on the
        # same device as the predictions it is fed.
        metric = metric.to(device)

    for epoch in tqdm(range(max_epochs)):
        training_loss_acc = 0.0
        training_examples = 0
        model.train()

        for x_batch, y_batch in train_loader:
            optimizer.zero_grad()

            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            logits, _ = model(x_batch)

            loss = criterion(logits.reshape(-1, logits.size(-1)), y_batch.reshape(-1))
            loss.backward()

            # Clip gradients at 1.0 to prevent exploding gradients,
            # which can be a common issue with RNNs
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()

            # The loss is a per-batch mean, so weight it by the batch size:
            # dividing the sum by the number of examples then gives a true
            # average instead of a value that is too small by ~batch_size.
            batch_size = x_batch.size(0)
            training_loss_acc += loss.item() * batch_size
            training_examples += batch_size

        model.eval()
        dev_loss_acc = 0.0
        dev_examples = 0
        dev_accuracy = 0.0
        dev_batches = 0
        with torch.no_grad():
            for x_batch, y_batch in dev_loader:
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)
                logits, _ = model(x_batch)

                batch_size = x_batch.size(0)
                batch_loss = criterion(logits.reshape(-1, logits.size(-1)), y_batch.reshape(-1)).item()
                dev_loss_acc += batch_loss * batch_size
                dev_examples += batch_size
                dev_batches += 1
                if metric is not None:
                    # Predicted token = arg max over the vocabulary axis
                    dev_accuracy += float(metric(torch.argmax(logits, -1), y_batch))

        if liveplot is not None:
            liveplot.tick() # Update the liveplot time
            liveplot.report("Training loss", training_loss_acc / training_examples)
            liveplot.report("Development loss", dev_loss_acc / dev_examples)
            if metric is not None:
                # Average of the per-batch metric values
                liveplot.report("Development accuracy", dev_accuracy / max(dev_batches, 1), secondary_y=True)

        # Save model
        save_model(
            epoch=epoch,
            model=model,
            optimizer=optimizer,
            total_loss=training_loss_acc,
            n=training_examples,
            val_loss=dev_loss_acc,
            vn=dev_examples,
            word2idx=word2idx
        )

1. Download and tokenize War and Peace

Download and tokenize War and Peace from https://www.gutenberg.org.

Code
URL = "https://www.gutenberg.org/files/2600/2600-0.txt"

def load_text():
    """Download War and Peace and return it as a list of lowercase tokens.

    The text between the first ``CHAPTER I`` and the last
    "End of the Project Gutenberg" marker is kept. Words (letters, digits,
    hyphens and apostrophes) and the punctuation marks ``. , ; : ? !``
    become separate tokens; every other character is treated as whitespace.

    Returns:
        The list of tokens, in reading order.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or a timeout.
    """
    # Fail loudly on a hung connection or an HTTP error page instead of
    # silently tokenizing whatever came back.
    response = requests.get(URL, timeout=60)
    response.raise_for_status()
    text = response.text

    # Strip Gutenberg header/footer. A missing marker must not be used as a
    # slice index: -1 would silently cut the text down to (almost) nothing.
    # NOTE(review): the first "CHAPTER I" may be a table-of-contents entry
    # rather than the start of the story - confirm against the file.
    start = text.find("CHAPTER I")
    if start == -1:
        start = 0
    # Match the footer case-insensitively; presumably its capitalization
    # differs between editions of the file - TODO confirm.
    footer_starts = [
        m.start()
        for m in re.finditer(r"end of the project gutenberg", text, flags=re.IGNORECASE)
    ]
    if footer_starts and footer_starts[-1] > start:
        end = footer_starts[-1]
    else:
        end = len(text)
    text = text[start:end]

    # Lowercase and split off punctuation as separate tokens
    text = text.lower()
    text = re.sub(r"[^a-z0-9\s\.\,\;\:\?\!\'\-]", " ", text)
    tokens = re.findall(r"[a-z0-9\-']+|[\.\,\;\:\?\!]", text)
    return tokens

tokens = load_text()
print(f"Total tokens: {len(tokens):,}")

2. Build vocabulary (keep top-N most frequent words)

Build a vocabulary of 10,000 entries: the 9,998 most common tokens plus two special tokens, one for unknown words and one for padding.

Code
VOCAB_SIZE = 10_000
counter = Counter(tokens)
# Two vocabulary slots are reserved for the special tokens <pad> and <unk>
most_common = counter.most_common(VOCAB_SIZE - 2)

# Special tokens take indices 0 and 1; the frequent words follow in order
# of decreasing frequency.
word2idx = {"<pad>": 0, "<unk>": 1}
for index, (word, _) in enumerate(most_common, start=len(word2idx)):
    word2idx[word] = index
# Inverse mapping, used to turn generated indices back into words
idx2word = dict(zip(word2idx.values(), word2idx.keys()))

def encode(tok):
    """Return the vocabulary index of ``tok``, or that of ``<unk>`` if unknown."""
    try:
        return word2idx[tok]
    except KeyError:
        return word2idx["<unk>"]

data = torch.tensor(list(map(encode, tokens)), dtype=torch.long)
print(f"Vocabulary size: {len(word2idx):,}")

3. Dataset of (sequence, next-word) windows

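Define a TextDataset class that yields input/target pairs, where the target is the input window shifted one token ahead, and create the training and validation data loaders.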

Code
SEQ_LEN = 32

class TextDataset(Dataset):
    """Sliding windows over a 1-D token sequence for next-token prediction.

    Item ``idx`` is the pair ``(x, y)`` where ``x`` is the ``seq_len`` tokens
    starting at ``idx`` and ``y`` is the same window shifted one token ahead.
    Consecutive items therefore overlap in all but one token.
    """

    def __init__(self, data, seq_len):
        """Store the token sequence and the window length.

        Args:
            data: 1-D sequence of token ids that supports ``len`` and slicing.
            seq_len: Number of tokens in each input/target window.
        """
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        """Return the number of complete ``(x, y)`` windows.

        A window starting at ``idx`` reads ``data[idx : idx + seq_len + 1]``,
        so there are ``len(data) - seq_len`` valid start positions. The result
        is clamped at zero because ``__len__`` may not return a negative number.
        """
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        """Return the input window and its one-step-shifted target.

        Raises:
            IndexError: If ``idx`` is out of range. Besides guarding against
                truncated windows, this is what ends a plain ``for`` loop
                over the dataset.
        """
        size = len(self)
        if idx < 0:
            idx += size  # allow Python-style negative indexing
        if not 0 <= idx < size:
            raise IndexError("TextDataset index out of range")
        x = self.data[idx : idx + self.seq_len]
        y = self.data[idx + 1 : idx + self.seq_len + 1]  # shifted by 1
        return x, y

# Train/val split: the first 95% of the token stream is used for training,
# the remaining 5% is held out for validation.
split = int(0.95 * len(data))
train_ds = TextDataset(data[:split], seq_len=SEQ_LEN)
val_ds   = TextDataset(data[split:], seq_len=SEQ_LEN)

# Only the training windows are shuffled; validation keeps its order.
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
val_loader = DataLoader(
    dataset=val_ds,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

4. Model

Define the LSTM model, consisting of an embedding layer, a stacked (by default two-layer) LSTM, and a linear output layer.

Code
class CharLSTM(nn.Module):
    """Embedding -> stacked LSTM -> linear language model.

    Despite the name, the model works on whatever token ids it is given; in
    this notebook those are word-level ids.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_size=256,
                 num_layers=2, dropout=0.3):
        """Build and initialize the layers.

        Args:
            vocab_size: Number of distinct token ids; also the output size.
            embed_dim: Size of the token embeddings.
            hidden_size: Size of the LSTM hidden state.
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout applied between LSTM layers. Ignored for a
                single-layer LSTM, where there is nothing in between.
        """
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # nn.LSTM warns about a non-zero dropout when num_layers == 1
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, vocab_size)

        # Better recurrent weight init; improves training performance
        for name, p in self.lstm.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(p)
            elif "weight_ih" in name:
                nn.init.xavier_uniform_(p)
            elif "bias" in name:
                nn.init.zeros_(p)
                # Forget gate bias = 1. The LSTM adds bias_ih and bias_hh
                # together, so only one of the two may carry the 1; filling
                # both would give an effective forget bias of 2.
                if "bias_ih" in name:
                    n = p.size(0)
                    # Gates are stored as (input, forget, cell, output), so
                    # the forget gate is the second quarter of the vector.
                    with torch.no_grad():
                        p[n // 4 : n // 2].fill_(1.0)

    def forward(self, x, hidden=None):
        """Return the next-token logits for every position and the new state.

        Args:
            x: Token ids of shape (batch, tokens).
            hidden: Optional LSTM state to continue from; ``None`` starts
                from a zero state.

        Returns:
            ``(logits, hidden)`` where ``logits`` has shape
            (batch, tokens, vocab_size) and ``hidden`` is the LSTM state
            after the last token.
        """
        emb = self.embed(x)                    # (batch, tokens, embedding_size)
        out, hidden = self.lstm(emb, hidden)   # (batch, tokens, hidden_size)
        logits = self.fc(out)                  # (batch, tokens, vocab_size)
        return logits, hidden

5. Training loop

Train the model, saving a checkpoint after each epoch.

Code
# Number of passes over the training data
epochs = 5

# Train on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CharLSTM(len(word2idx)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
# Index 0 is <pad>; targets with that index do not contribute to the loss
criterion = nn.CrossEntropyLoss(ignore_index=0)

# Live plot with one x-axis tick per epoch
liveplot = LivePlot()
liveplot.fig.update_layout(width=1200, height=800, font_size=18)
liveplot.increment(epochs)

train(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    dev_loader=val_loader,
    max_epochs=epochs,
    device=device,
    liveplot=liveplot,
)

6. Generation with temperature + top-k sampling

Define a function that generates text as a continuation of a prompt.

Code
@torch.no_grad()
def generate(model, prompt, n_words=100, temperature=0.8, top_k=20):
    """Continue ``prompt`` with ``n_words`` sampled tokens.

    The prompt is tokenized like the training text, run through the model
    once to build up the hidden state, and then extended one token at a
    time with temperature-scaled top-k sampling.

    Args:
        model: Model called as ``model(ids, hidden)`` and returning
            ``(logits, hidden)``.
        prompt: Text to continue; must contain at least one token.
        n_words: Number of tokens to generate.
        temperature: Divisor applied to the logits; values below 1 make the
            output more conservative. Must be positive.
        top_k: Sample only among the ``top_k`` most likely tokens. A falsy
            value disables the filter.

    Returns:
        The prompt tokens followed by the generated tokens, joined by
        spaces with punctuation attached to the preceding token.

    Raises:
        ValueError: If ``temperature`` is not positive or the prompt
            contains no tokens.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    model.eval()
    tokens = re.findall(r"[a-z0-9\-']+|[\.\,\;\:\?\!]", prompt.lower())
    if not tokens:
        raise ValueError("prompt must contain at least one token")

    # Put the input where the model lives rather than relying on a global
    model_device = next(model.parameters()).device
    idx = torch.tensor([encode(t) for t in tokens],
                       dtype=torch.long, device=model_device).unsqueeze(0)

    # Warm up the hidden state on the prompt. The logits at the last prompt
    # position already predict the first new token, so the last prompt token
    # must not be fed to the model a second time.
    logits, hidden = model(idx)
    output_tokens = list(tokens)

    for step in range(n_words):
        step_logits = logits[:, -1, :] / temperature

        # Top-k filtering: restrict sampling to the k most likely tokens
        if top_k:
            k = min(top_k, step_logits.size(-1))  # topk fails for k > vocab size
            v, _ = torch.topk(step_logits, k)
            step_logits[step_logits < v[:, [-1]]] = -float("inf")

        probs = torch.softmax(step_logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        output_tokens.append(idx2word[next_id.item()])

        # Advance the state with the sampled token, unless this was the last
        if step + 1 < n_words:
            logits, hidden = model(next_id, hidden)

    # Re-attach punctuation nicely: glue it onto the preceding token
    pieces = []
    for tok in output_tokens:
        if tok in ".,;:?!" and pieces:
            pieces[-1] += tok
        else:
            pieces.append(tok)
    return " ".join(pieces)

After training has completed, we can load the checkpoints from different epochs to see whether, and how, the model has improved.

Code
# Generation settings shared by both checkpoints
prompt, n_words, temperature = "prince andrew looked at", 80, 0.8

# Model as it was after the first epoch
m1 = CharLSTM(len(word2idx)).to(device)
ckpt = torch.load("checkpoint_epoch_1.pt", map_location=device)
m1.load_state_dict(ckpt["model_state_dict"])

# Model as it was after the fifth epoch
m5 = CharLSTM(len(word2idx)).to(device)
ckpt = torch.load("checkpoint_epoch_5.pt", map_location=device)
m5.load_state_dict(ckpt["model_state_dict"])

print("\n--- Sample, m1 ---")
sample = generate(m1, prompt, n_words=n_words, temperature=temperature)
print(sample)

print("\n--- Sample, m5 ---")
sample = generate(m5, prompt, n_words=n_words, temperature=temperature)
print(sample)