Lab session: Representation learning
Running on colab
When running on Colab, we need to install some dependencies.
try:
    from google.colab import output
    output.enable_custom_widget_manager()
    # umap-learn provides the `umap` module used for the latent-space plots
    %pip install medmnist open_clip_torch umap-learn
except ImportError:
    pass
Preparation
Configuration
Execute the following code blocks to configure the session and import relevant modules.
%config InlineBackend.figure_format ='retina'
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
# Environment variable for AMD ROCm GPUs running under WSL
os.environ["HSA_ENABLE_DXG_DETECTION"] = "1"
import os
import sys
import math
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset, SubsetRandomSampler
from torchvision.datasets import ImageFolder
import torchvision.transforms.v2 as transforms
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from sklearn.decomposition import PCA
from sklearn.metrics import ConfusionMatrixDisplay, classification_report
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
import medmnist
import matplotlib.pyplot as plt
import plotly.graph_objects as go
torch.cuda.is_available()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
Training helper functions
Live plot
This cell defines a simple live plot we’ll use during training
from typing import Optional
import plotly.graph_objects as go
from IPython.display import display
class LivePlot():
    def __init__(self):
        self.fig = go.FigureWidget()
        self.plot_indices = {}
        display(self.fig)
        self.limits = [0, 0]
        self.current_x = 0
    def report(self, name: str, value: float):
        "Report a new value for the line `name` at the current time step; a new line is created the first time a name is reported."
        try:
            plot_index = self.plot_indices[name]
        except KeyError:
            plot_index = len(self.fig.data)
            self.fig.add_scatter(y=[], x=[], name=name)
            self.plot_indices[name] = plot_index
        self.fig.data[plot_index].y += (value,)
        self.fig.data[plot_index].x += (self.current_x,)
    def increment(self, n_ticks: int):
        "Extend the displayed x-axis range by this many ticks."
        self.limits[1] += n_ticks
        self.fig.update_layout(xaxis_range=self.limits)
    def set_limit(self, n_ticks: int):
        "Set the displayed x-axis range to end at exactly this many ticks."
        self.limits[1] = n_ticks
        self.fig.update_layout(xaxis_range=self.limits)
    def tick(self, n_ticks: Optional[int] = None):
        "Advance the current time by this many ticks, or by 1 tick if no argument is supplied."
        if n_ticks is None:
            n_ticks = 1
        self.current_x += n_ticks
You can test it by running the cells below. They show how the figure is updated in place without creating a new one (try running the second cell, with the for loop, multiple times).
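plot = LivePlot()
i = 0
plot.increment(5)
for n in range(5):
    plot.tick()
    plot.report("dev loss", i**2)
    i += 1
Training loop
This is a boilerplate training loop that trains a model for a given number of epochs and, after each epoch, reports the average training and development loss per example to the live plot (if one is given).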
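The loss functions used in this lab return the mean loss over a batch, so getting a per-example average over an epoch requires weighting each batch mean by its batch size (the last batch is often smaller). A minimal pure-Python sketch of that bookkeeping, using made-up batch values:

```python
def per_example_mean(batch_means, batch_sizes):
    """Combine per-batch mean losses into the mean loss over all examples.

    Each batch mean is weighted by its batch size, so a short final batch
    does not count as much as a full one.
    """
    total_loss = 0.0
    total_examples = 0
    for mean_loss, size in zip(batch_means, batch_sizes):
        total_loss += mean_loss * size  # undo the per-batch averaging
        total_examples += size
    return total_loss / total_examples

# Two full batches of 128 examples and a final batch of 4 examples (made-up values)
batch_means = [0.5, 0.7, 3.0]
batch_sizes = [128, 128, 4]
print(per_example_mean(batch_means, batch_sizes))
# Dividing the summed batch means by the number of examples is far too small,
# because each batch mean has already been divided by its batch size
print(sum(batch_means) / sum(batch_sizes))
```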
import torch
from torch.utils.data import DataLoader
from torch.nn import Module
import torch.nn as nn
def train(*,
          model: Module,
          train_loader: DataLoader,
          dev_loader: DataLoader,
          optimizer: torch.optim.Optimizer,
          criterion: torch.nn.Module,
          max_epochs: int,
          device: Optional[torch.device] = None,
          liveplot: Optional[LivePlot] = None):
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    for epoch in range(max_epochs):
        training_loss_acc = 0
        training_examples = 0
        model.train()
        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()
            x_batch, y_batch = batch
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            y_hat = model(x_batch)
            loss = criterion(y_hat, y_batch)
            loss.backward()
            optimizer.step()
            # The criterion returns the mean loss over the batch, so we weight it by the
            # batch size to get a correct per-example average at the end of the epoch
            training_loss_acc += loss.item() * x_batch.size(0)
            training_examples += x_batch.size(0)
        model.eval()
        with torch.no_grad():
            dev_loss_acc = 0
            dev_examples = 0
            for batch in dev_loader:
                x_batch, y_batch = batch
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)
                y_hat = model(x_batch)
                dev_loss_acc += criterion(y_hat, y_batch).item() * x_batch.size(0)
                dev_examples += x_batch.size(0)
        if liveplot is not None:
            liveplot.tick()  # Update the liveplot time
            liveplot.report("Training loss", training_loss_acc / training_examples)
            liveplot.report("Development loss", dev_loss_acc / dev_examples)
from sklearn.metrics import confusion_matrix
import plotly.figure_factory as ff
def plot_confusion_matrix(model, dataset, title="Confusion matrix", device=None):
    if device is None:
        device = next(model.parameters()).device
    label_dict = {int(k): v for k, v in dataset.dataset.info['label'].items()}
    mapped_labels = dataset.get_mapped_labels()
    class_names = [label_dict[orig] for orig in dataset.get_original_labels()]
    loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=8)
    all_preds, all_targets = [], []
    model.eval()
    with torch.no_grad():
        for x_batch, y_batch in loader:
            x_batch = x_batch.to(device)
            preds = model(x_batch).argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            # reshape(-1) also works for a final batch with a single example
            all_targets.extend(y_batch.reshape(-1).numpy())
    cm = confusion_matrix(all_targets, all_preds, labels=mapped_labels)
    # Normalize each row by the number of examples of that true class
    cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
    fig = ff.create_annotated_heatmap(
        z=np.round(cm_norm, 2),
        x=class_names,
        y=class_names,
        colorscale='Blues',
        showscale=True,
    )
    fig.update_layout(
        title=title,
        xaxis_title="Predicted",
        yaxis_title="True",
        yaxis_autorange="reversed",
    )
    fig.show()
Representation learning
In this lab we’ll look at the representations neural networks learn. We’ll start by looking at how a neural network learns representations during supervised learning, an approach commonly used in computer vision for transfer learning.
Learning supervised representations
By training a neural network on some dataset, the representation it learns might also be useful for some other dataset. We’ll take a look at how this could work with a small dataset.
The data
We’re using data from the MedMNIST dataset, a varied collection of MNIST-style datasets from the medical imaging domain. The images are RGB (some datasets are grayscale only) with a size of \(28 \times 28\) pixels. In this lab we’ll work a lot with the idea of having only a small amount of labeled data, so we simulate this by setting aside some of our data as “unlabeled”. We also want to simulate transfer learning, which we do by splitting the labeled data into separate sets depending on the label.
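To make the tensor layout concrete: the ToTensor transform used below takes an 8-bit image with the channel axis last and returns a float tensor with values in [0, 1] and the channel axis first. The numpy sketch below only mimics that conversion on a synthetic 28x28 RGB image (it does not load MedMNIST or call torchvision), and also shows the transpose back to the channel-last layout that matplotlib's imshow expects:

```python
import numpy as np

def to_chw_float(img_hwc: np.ndarray) -> np.ndarray:
    """Mimic ToTensor for a uint8 image: (H, W, C) in [0, 255] -> (C, H, W) float32 in [0, 1]."""
    return np.transpose(img_hwc, (2, 0, 1)).astype(np.float32) / 255.0

def to_hwc(img_chw: np.ndarray) -> np.ndarray:
    """Move the channel axis last again, the layout matplotlib's imshow expects."""
    return np.transpose(img_chw, (1, 2, 0))

demo_rng = np.random.default_rng(0)
# A synthetic 28x28 RGB image standing in for a MedMNIST sample
fake_image = demo_rng.integers(0, 256, size=(28, 28, 3), dtype=np.uint8)
x = to_chw_float(fake_image)
print(x.shape, x.dtype, float(x.min()), float(x.max()))
print(to_hwc(x).shape)
```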
We start by setting up the basic data classes.
# ToTensor converts the images to float tensors in [0, 1] with the channel axis first (C, H, W),
# which is what the convolutional networks in this lab expect. Matplotlib's imshow instead wants
# the channel axis last (H, W, C). This Transpose module can be added to the transforms to reorder
# the axes; it is currently unused, since the plotting code below transposes the images itself.
class Transpose(nn.Module):
    def __init__(self, order):
        super().__init__()
        self.order = order
    def forward(self, img):
        return img.permute(self.order)
# We keep separate transforms in case we want to add augmentations to the training data later on.
# For now, both transforms only convert the images to tensors.
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    #Transpose((1,2,0))
])
eval_transforms = transforms.Compose([
    transforms.ToTensor(),
    #Transpose((1,2,0))
])
from medmnist import BloodMNIST, PathMNIST, TissueMNIST
# You can easily change what dataset to use by changing which class this points to
DatasetClass = PathMNIST
train_data = DatasetClass(split="train", download=True, transform=train_transforms)
# The development and test sets use the evaluation transforms (no augmentation)
dev_data = DatasetClass(split="val", download=True, transform=eval_transforms)
test_data = DatasetClass(split="test", download=True, transform=eval_transforms)
img, _ = train_data[0]
IMG_SHAPE = img.shape
N_CHANNELS, HEIGHT, WIDTH = IMG_SHAPE
N_CLASSES = len(train_data.info['label'])
print("Train dataset size:", len(train_data))
print("Development dataset size:", len(dev_data))
print("Test dataset size:", len(test_data))
print("Image shape:", img.shape)
print("Number of channels:", N_CHANNELS)
Let’s have a look at the images and the targets in the dataset:
n_per_class = 5
label_dict = {int(k): v for k, v in train_data.info['label'].items()}
labels = sorted(label_dict.keys())
samples_per_class = {label: [] for label in labels}
for idx in range(len(train_data)):
    img_tensor, y = train_data[idx]
    label = y.item()
    if len(samples_per_class[label]) < n_per_class:
        samples_per_class[label].append((img_tensor, label))
    if all(len(v) == n_per_class for v in samples_per_class.values()):
        break
fig, axes = plt.subplots(n_per_class, N_CLASSES, figsize=(N_CLASSES * 1.5, n_per_class * 1.5))
for col, label in enumerate(labels):
    for row, (img_tensor, _) in enumerate(samples_per_class[label]):
        ax = axes[row, col]
        ax.imshow(img_tensor.numpy().transpose(1, 2, 0), cmap='gray')
        if row == 0:
            ax.set_title(f"{label}", fontsize=8)
        ax.axis('off')
plt.tight_layout()
plt.show()
Making splits
We simulate having limited data, and new kinds of data, by training only on a subset of the labels and samples. We do this by wrapping the datasets in a small SubsetDataset class that exposes only selected indices and remaps the labels. We first split the data into a set of “labeled” and pretend “unlabeled” images.
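The split is done per class so that the pretend-labeled set keeps the class balance of the full data: for each label, the indices are shuffled and the first \(\lceil n \cdot r \rceil\) are kept as labeled, where \(n\) is the number of examples of that label and \(r\) the labeled ratio. A small self-contained sketch of that idea on a toy label list; the function name is hypothetical, and it uses Python's random module instead of the numpy generator used in the lab code:

```python
import math
import random
from collections import defaultdict

def stratified_labeled_split(labels, labeled_ratio, seed=42):
    """Split example indices into (labeled, unlabeled) lists, class by class."""
    indices_by_label = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_label[label].append(index)
    rng = random.Random(seed)
    labeled, unlabeled = [], []
    for label in sorted(indices_by_label):
        indices = indices_by_label[label]
        rng.shuffle(indices)  # in-place shuffle within the class
        n_labeled = math.ceil(len(indices) * labeled_ratio)
        labeled.extend(indices[:n_labeled])
        unlabeled.extend(indices[n_labeled:])
    return labeled, unlabeled

# 50 examples of class 0, 30 of class 1 and 20 of class 2
toy_labels = [0] * 50 + [1] * 30 + [2] * 20
labeled, unlabeled = stratified_labeled_split(toy_labels, labeled_ratio=0.2)
print(len(labeled), len(unlabeled))
```

Every class contributes the same fraction of its examples, so a rare class is not accidentally left out of the labeled set.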
from collections import defaultdict
from torch.utils.data import Dataset
class SubsetDataset(Dataset):
    """Custom dataset class which wraps a dataset and exposes only a
    subset of its contents. Also performs a label-mapping to
    simulate a different label set"""
    def __init__(self, dataset, indices, label_map=None):
        super().__init__()
        self.dataset = dataset
        self.indices = indices
        self.label_map = label_map
    def __len__(self):
        return len(self.indices)
    def __getitem__(self, item):
        real_index = self.indices[item]
        x, y = self.dataset[real_index]
        if self.label_map is not None:
            y = self.label_map[y.item()]
        return x, y
    def get_mapped_labels(self):
        return sorted(self.label_map.values())
    def get_original_labels(self):
        return sorted(self.label_map.keys())
    def get_n_classes(self):
        return len(self.label_map)
# These constants control the splitting
pretend_labeled_ratio = 0.2  # We pretend that only 20% of data is labeled
held_out_ratio = 0.3  # We pretend that this ratio of labels belong to the transfer task
label_dict = {int(int_label): str_label for int_label, str_label in train_data.info['label'].items()}
labels = sorted(label_dict.keys())
rng = np.random.default_rng(seed=42)
n_held_out_labels = int(np.ceil(len(labels)*held_out_ratio))
held_out_labels = rng.choice(labels, size=n_held_out_labels, replace=False)
kept_labels = [l for l in labels if l not in held_out_labels]
print(f"Held-out labels: {held_out_labels}, training labels: {kept_labels}")
held_out_label_map = {l: i for i, l in enumerate(sorted(held_out_labels))}
kept_label_map = {l: i for i, l in enumerate(sorted(kept_labels))}
# We start by collecting indices by label. We want to do the
# pretend split based on label so that it's balanced
train_indices_by_label = defaultdict(list)
for i, l in enumerate(train_data.labels):
    # The order of the indices per label list will be
    # deterministic since they are in the labels attribute of the dataset
    train_indices_by_label[l.item()].append(i)
# We now split all the indices into a set of pretend train and unlabeled.
train_pretend_unlabeled_indices = defaultdict(list)
train_pretend_labeled_indices = defaultdict(list)
for label, indices in train_indices_by_label.items():
    n_labeled_examples = int(np.ceil(len(indices)*pretend_labeled_ratio))
    rng.shuffle(indices)  # Shuffle is in-place
    train_pretend_labeled_indices[label].extend(indices[:n_labeled_examples])
    train_pretend_unlabeled_indices[label].extend(indices[n_labeled_examples:])
# We do the same for the development set: collect the indices by label
# so that the pretend split is balanced
dev_indices_by_label = defaultdict(list)
for i, l in enumerate(dev_data.labels):
    dev_indices_by_label[l.item()].append(i)
# We now split all the indices into a set of pretend labeled dev and unlabeled.
dev_pretend_unlabeled_indices = defaultdict(list)
dev_pretend_labeled_indices = defaultdict(list)
for label, indices in dev_indices_by_label.items():
    n_labeled_examples = int(np.ceil(len(indices)*pretend_labeled_ratio))
    rng.shuffle(indices)  # Shuffle is in-place
    dev_pretend_labeled_indices[label].extend(indices[:n_labeled_examples])
    dev_pretend_unlabeled_indices[label].extend(indices[n_labeled_examples:])
train_subsample_indices = [i for keep_label in kept_labels for i in train_pretend_labeled_indices[keep_label]]
train_transfer_subsample_indices = [i for held_out_label in held_out_labels for i in train_pretend_labeled_indices[held_out_label]]
train_all_unlabeled = [i for indices in train_pretend_unlabeled_indices.values() for i in indices]
train_data_a = SubsetDataset(train_data, train_subsample_indices, kept_label_map)
train_data_b = SubsetDataset(train_data, train_transfer_subsample_indices, held_out_label_map)
train_unlabeled = SubsetDataset(train_data, train_all_unlabeled)
dev_subsample_indices = [i for keep_label in kept_labels for i in dev_pretend_labeled_indices[keep_label]]
dev_transfer_subsample_indices = [i for held_out_label in held_out_labels for i in dev_pretend_labeled_indices[held_out_label]]
dev_all_unlabeled = [i for indices in dev_pretend_unlabeled_indices.values() for i in indices]
dev_data_a = SubsetDataset(dev_data, dev_subsample_indices, kept_label_map)
dev_data_b = SubsetDataset(dev_data, dev_transfer_subsample_indices, held_out_label_map)
dev_unlabeled = SubsetDataset(dev_data, dev_all_unlabeled)
print("Train subsampler number of examples: ", len(train_data_a))
print("Train transfer number of examples: ", len(train_data_b))
print("Train \"unlabeled\" number of examples: ", len(train_unlabeled))
print("dev subsampler number of examples: ", len(dev_data_a))
print("dev transfer number of examples: ", len(dev_data_b))
print("dev \"unlabeled\" number of examples: ", len(dev_unlabeled))
Supervised learning
Now that the data is prepared, we’ll train a simple Convolutional Neural Network to solve the problem.
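The shape comments in the ConvNet below (28x28 -> 14x14 -> 7x7 -> 3x3) and the input size of its first Linear layer, 128 * 3 * 3, follow from the usual output-size formula \(\lfloor (n + 2p - k)/s \rfloor + 1\) for convolution and pooling layers (input size \(n\), padding \(p\), kernel size \(k\), stride \(s\)). A quick sketch that redoes this arithmetic, assuming 28x28 inputs and no dilation:

```python
def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (no dilation)."""
    return (n + 2 * padding - kernel) // stride + 1

def convnet_flat_dim(n: int, channels=(32, 64, 128)) -> int:
    """Follow one ConvBlock per entry in `channels`: a 3x3 conv with padding 1, then 2x2 max pooling."""
    for _ in channels:
        n = out_size(n, kernel=3, stride=1, padding=1)  # the conv keeps the spatial size
        n = out_size(n, kernel=2, stride=2)             # MaxPool2d(2) halves it, rounding down
    return channels[-1] * n * n

print(convnet_flat_dim(28))  # prints 1152, i.e. 128 * 3 * 3
```

If the image size changes, the Linear layer's input size has to change with it; this function gives the number to use.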
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
    def __init__(self, n_input_channels, n_output_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(n_input_channels, n_output_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(n_output_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
    def forward(self, x):
        return self.block(x)
class ConvNet(nn.Module):
    def __init__(self, input_channels, n_classes, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim
        self.conv_blocks = nn.ModuleList([
            ConvBlock(input_channels, 32),  # 28x28 -> 14x14
            ConvBlock(32, 64),  # 14x14 -> 7x7
            ConvBlock(64, 128),  # 7x7 -> 3x3
        ])
        self.encoder_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 3 * 3, latent_dim),
            nn.ReLU(),
        )
        self.bn = nn.BatchNorm1d(latent_dim, affine=False)
        self.classifier = nn.Linear(latent_dim, n_classes)
    def encode(self, x):
        for block in self.conv_blocks:
            x = block(x)
        return self.encoder_head(x)
    def reset_head(self, n_classes):
        device = next(self.parameters()).device
        self.classifier = nn.Linear(self.latent_dim, n_classes, device=device)
    def forward(self, x):
        embedded = self.encode(x)
        batch_normed = self.bn(embedded)
        normalized = F.normalize(batch_normed, dim=1)
        return self.classifier(normalized)
Now we’ll train the model using the train function from the training helpers above; we’ll use this same training loop throughout the lab.
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
batch_size = 128
latent_dim = 128
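# Note that we use train_data_a, which only contains the labeled examples of the kept labels.
# drop_last avoids a final batch with a single example, which BatchNorm1d cannot normalize in training mode
train_loader = DataLoader(train_data_a, shuffle=True, batch_size=batch_size, drop_last=True, num_workers=8)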
dev_loader = DataLoader(dev_data_a, batch_size=batch_size, num_workers=8)
n_classes = train_data_a.get_n_classes()
cnn_model = ConvNet(input_channels=N_CHANNELS, n_classes=n_classes, latent_dim=latent_dim)
cnn_model.to(device)
optimizer = AdamW(cnn_model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = CrossEntropyLoss()
liveplot = LivePlot()
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train(model=cnn_model, train_loader=train_loader, dev_loader=dev_loader, optimizer=optimizer, criterion=criterion, max_epochs=10, device=device, liveplot=liveplot)
Looking at the representations
Now that we’ve trained the network, we can look at how it maps the images to representations. A useful tool for building feature extractors is the create_feature_extractor function from torchvision. First we inspect which layers our network contains, using the get_graph_node_names function.
from torchvision.models.feature_extraction import get_graph_node_names
train_nodes, eval_nodes = get_graph_node_names(cnn_model)
print("Eval nodes:")
for node in eval_nodes:
    print(node)
The function traces the modules in the network and creates one list of the nodes used in training mode and one for evaluation mode. We’re interested in the evaluation list, because we want to use the model in evaluation mode. The nodes named encoder_head.0 to encoder_head.2 are the ones we’re interested in; they correspond to the layers we used to construct encoder_head above. The features coming directly from the conv layers correspond to the Flatten node, encoder_head.0. We might also want to inspect the encoder features after the nonlinearity, which are the ones passed on towards the classifier. These are the outputs of the ReLU, so they correspond to encoder_head.2. Finally, the normalize node gives the batch-normalized, L2-normalized vectors that the classifier actually receives. Knowing the names of the nodes, we can now create a feature extractor with our model as a base:
from torchvision.models.feature_extraction import create_feature_extractor
feature_extractor = create_feature_extractor(cnn_model, return_nodes={'encoder_head.0': 'flat_conv_features', 'encoder_head.2': 'encoder_features', 'normalize': 'norm_features'})
We have now created a feature extractor for our model. It behaves like our original model, but instead of returning the result of the forward() method, it returns a dictionary with the keys flat_conv_features, encoder_features and norm_features, holding the outputs of the corresponding nodes. Let’s use this to embed our test examples.
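In evaluation mode, norm_features is obtained from encoder_features in two steps: BatchNorm1d with affine=False standardizes each dimension using its running mean and variance, and F.normalize then rescales each vector to unit L2 norm. A numpy sketch of those two steps, using random stand-in statistics rather than the trained model's buffers (the eps values are PyTorch's defaults):

```python
import numpy as np

def batchnorm_eval(x, running_mean, running_var, eps=1e-5):
    """BatchNorm without learnable scale and shift, in evaluation mode."""
    return (x - running_mean) / np.sqrt(running_var + eps)

def l2_normalize(x, eps=1e-12):
    """Rescale each row to unit Euclidean norm, like F.normalize(x, dim=1)."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, eps)

demo_rng = np.random.default_rng(0)
demo_dim = 128
# Stand-ins for a batch of ReLU encoder outputs and the BatchNorm running statistics
demo_encoder_out = np.maximum(demo_rng.normal(size=(4, demo_dim)), 0.0)
demo_mean = demo_rng.uniform(0.0, 1.0, size=demo_dim)
demo_var = demo_rng.uniform(0.5, 2.0, size=demo_dim)
demo_norm_features = l2_normalize(batchnorm_eval(demo_encoder_out, demo_mean, demo_var))
print(np.linalg.norm(demo_norm_features, axis=1))  # every row has norm 1
```

Because every normalized vector lies on the unit sphere, only the direction of a representation matters to the classifier in ConvNet.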
test_loader = DataLoader(test_data, batch_size=batch_size)
conv_features = []
encoder_features = []
norm_features = []
labels = []
feature_extractor.eval()
with torch.inference_mode():
    for batch in test_loader:
        x, y = batch
        # Flatten the (B, 1) MedMNIST label tensor into plain Python ints
        labels.extend(y.reshape(-1).tolist())
        extracted_features = feature_extractor(x.to(device))
        conv_features.append(extracted_features['flat_conv_features'].cpu().numpy())
        encoder_features.append(extracted_features['encoder_features'].cpu().numpy())
        norm_features.append(extracted_features['norm_features'].cpu().numpy())
encoder_features = np.concatenate(encoder_features)
conv_features = np.concatenate(conv_features)
norm_features = np.concatenate(norm_features)
A common tool in representation learning is to inspect the representation using nonlinear dimensionality reduction techniques such as t-SNE or UMAP, which try to preserve local structure in the data when projecting it to a 2D space for visualization. Below we use UMAP, with cosine distance as the metric, to project the features to 2D. The test set contains all classes, including the held-out ones the model never saw during training, so the plots also show how well the learned representation separates unseen classes.
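The cosine metric passed to UMAP compares directions only and ignores vector length. For the normalized features this makes no practical difference, since for unit vectors the squared Euclidean distance is exactly twice the cosine distance; for the unnormalized conv and encoder features it means differences in magnitude are ignored. A small numpy check of both facts on random vectors:

```python
import numpy as np

def cosine_distance(a, b):
    """1 minus the cosine of the angle between a and b."""
    return 1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

demo_rng = np.random.default_rng(1)
a, b = demo_rng.normal(size=(2, 16))
# Scaling a vector changes Euclidean distances but not the cosine distance
print(cosine_distance(a, b), cosine_distance(3.0 * a, b))
# For unit vectors: ||a - b||^2 = 2 - 2 cos(a, b) = 2 * cosine distance
a_unit, b_unit = a / np.linalg.norm(a), b / np.linalg.norm(b)
print(np.sum((a_unit - b_unit) ** 2), 2 * cosine_distance(a_unit, b_unit))
```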
import umap
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
# This general visualization function will be used throughout the lab. It takes a matrix of
# precomputed features (one row per example) and their labels, optionally subsamples at most
# `samples_per_class` examples per class, projects the features to 2D with UMAP (cosine metric),
# and shows the result as a scatter plot colored by class label.
def visualize_latent_space(features, feature_labels, model_name, n_neighbors=15, min_dist=0.1, samples_per_class=300, cmap_name="tab10", rng=None):
    if rng is None:
        rng = np.random.default_rng(1729)
    if samples_per_class is not None:
        indices_per_class = defaultdict(list)
        for i, l in enumerate(feature_labels):
            indices_per_class[int(l)].append(i)
        sampled_indices = []
        for l, indices in indices_per_class.items():
            if len(indices) > samples_per_class:
                indices = rng.choice(indices, size=samples_per_class, replace=False)
            # indices is either a list or a numpy array here, so convert element by element
            sampled_indices.extend(int(i) for i in indices)
        features = features[sampled_indices]
        feature_labels = [feature_labels[i] for i in sampled_indices]
    reducer = umap.UMAP(n_components=2, n_neighbors=n_neighbors, min_dist=min_dist, metric='cosine', random_state=int(rng.integers(0, 2**32-1)))
    embedding = reducer.fit_transform(features)
    plt.figure(figsize=(8, 6))
    cmap = plt.get_cmap(cmap_name, N_CLASSES)  # resample to exactly N_CLASSES slots
    norm = mcolors.BoundaryNorm(boundaries=np.arange(-0.5, N_CLASSES), ncolors=N_CLASSES)
    scatter = plt.scatter(embedding[:, 0], embedding[:, 1], c=feature_labels, cmap=cmap, norm=norm, s=20, alpha=.8)
    plt.colorbar(scatter, ticks=np.arange(N_CLASSES))
    plt.title(f"UMAP projection of {model_name} representations with d={features.shape[1]}")
    plt.tight_layout()
    plt.show()
visualize_latent_space(conv_features, labels, f"Conv features (held-out labels {held_out_labels})")
visualize_latent_space(encoder_features, labels, f"Encoder features (held-out labels {held_out_labels})")
visualize_latent_space(norm_features, labels, f"Normalized features (held-out labels {held_out_labels})")
Transfer learning
We will use the model trained above to transfer what it learned to our pretend new task. We do this by replacing the prediction head with a newly initialized one and training on the transfer data (train_data_b and dev_data_b) instead.
During transfer learning it’s a good idea to first train only the newly initialized layer, to make the best possible use of the already learned encoder (otherwise, the encoder gets updated with gradients coming from a head that is still essentially random). We therefore use an optimizer that only receives the parameters of the prediction head. Note that the encoder’s BatchNorm running statistics are still updated while the model is in training mode, even though its weights are not.
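Training only a new head on top of a fixed encoder is often called a linear probe: since the features don't change, the problem reduces to multinomial logistic regression on them. A minimal numpy sketch of that idea on synthetic, clustered “features” (it does not use the lab's model; the data, learning rate and number of steps are arbitrary choices for illustration):

```python
import numpy as np

def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def train_linear_probe(features, targets, n_classes, lr=0.1, steps=200):
    """Fit only a linear head (W, b) on fixed features with full-batch gradient descent on cross-entropy."""
    n, d = features.shape
    W = np.zeros((d, n_classes))
    b = np.zeros(n_classes)
    one_hot = np.eye(n_classes)[targets]
    for _ in range(steps):
        probs = softmax(features @ W + b)
        grad_logits = (probs - one_hot) / n  # gradient of the mean cross-entropy w.r.t. the logits
        W -= lr * features.T @ grad_logits
        b -= lr * grad_logits.sum(axis=0)
    return W, b

demo_rng = np.random.default_rng(0)
n_per_class, demo_dim, demo_classes = 100, 8, 3
# Synthetic "frozen encoder features": one Gaussian cluster per class
centers = 4.0 * np.eye(demo_classes, demo_dim)
demo_targets = np.repeat(np.arange(demo_classes), n_per_class)
demo_features = centers[demo_targets] + demo_rng.normal(size=(demo_classes * n_per_class, demo_dim))
head_W, head_b = train_linear_probe(demo_features, demo_targets, demo_classes)
accuracy = np.mean((demo_features @ head_W + head_b).argmax(axis=1) == demo_targets)
print(f"training accuracy of the probe: {accuracy:.2f}")
```

The probe can only do well if the classes are already (close to) linearly separable in feature space, which is why probe accuracy is a common way to judge a learned representation.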
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
batch_size = 128
# We use the other dataset to simulate transfer learning
train_loader = DataLoader(train_data_b, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
# The development loader keeps every example (no shuffling or dropping) so the dev loss covers the whole set
dev_loader = DataLoader(dev_data_b, batch_size=batch_size, num_workers=8)
n_classes = train_data_b.get_n_classes()
cnn_model.reset_head(n_classes)
optimizer = AdamW(cnn_model.classifier.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = CrossEntropyLoss()
liveplot = LivePlot()
We can see how well this newly created classifier does on our previously unseen classes.
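plot_confusion_matrix normalizes each row of the confusion matrix by the number of examples of that true class, so each row sums to 1 and the diagonal shows the fraction of each class that is predicted correctly (per-class recall). A small numpy sketch of that normalization on hypothetical counts; unlike the lab function, it also guards against a class with no examples, whose row would otherwise be filled with NaN:

```python
import numpy as np

def row_normalize(cm):
    """Divide each row of a confusion matrix by its total; empty rows stay all zero."""
    cm = np.asarray(cm, dtype=float)
    row_totals = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals > 0)

# Rows are true classes, columns are predicted classes (hypothetical counts)
counts = [[8, 2, 0],
          [1, 3, 6],
          [0, 0, 0]]  # a class with no examples in this split
normalized = row_normalize(counts)
print(np.round(normalized, 2))
print("per-class recall:", np.round(np.diag(normalized), 2))
```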
plot_confusion_matrix(cnn_model, dev_data_b, title="Confusion matrix — dev set (transfer learning - before training)")
Now let’s train it for a bit and see whether it can solve the problem.
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train(model=cnn_model, train_loader=train_loader, dev_loader=dev_loader, optimizer=optimizer, criterion=criterion, max_epochs=10, device=device, liveplot=liveplot)
Let’s inspect the results after training.
plot_confusion_matrix(cnn_model, dev_data_b, title="Confusion matrix — dev set (transfer learning - after training)")
Exercises
Exercise 1: Effect of labeled data amount
Change pretend_labeled_ratio (defined in the data splitting cell) from 0.2 to 0.05 and then to 1.0. Re-run the CNN training and the transfer learning for each value. How does the amount of labeled data affect:
- The final dev accuracy on the transfer task?
- The UMAP structure of the encoder features, specifically how cleanly the held-out class clusters are separated?
Exercise 2: Freeze vs. fine-tune during transfer
The transfer training above only optimizes cnn_model.classifier.parameters(), keeping the encoder frozen. Add a second transfer run where you replace that with cnn_model.parameters() (fine-tuning everything). Compare the two on dev_data_b:
- Which reaches higher accuracy?
- Which converges faster?
- Does fine-tuning ever hurt? Think about why that might happen with a small transfer dataset.
Contrastive learning
We’ve seen how representations can be learned using regular supervised learning. We’ll now change this a bit: we drop the idea of classification and instead learn representations using contrastive learning. Instead of attracting examples to a per-class weight vector, we attract examples of the same class towards each other and repel examples of different classes from each other.
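To see the attraction and repulsion at work, a plain numpy version of the supervised contrastive loss (written here to mirror the formula in the supcon_loss docstring, not taken from a library) can be evaluated on a toy batch of unit vectors: the loss is close to zero when same-class vectors point in the same direction, and large when each class contains vectors from opposite sides.

```python
import numpy as np

def supcon_loss_np(z, y, temperature=0.07):
    """Supervised contrastive loss for L2-normalized rows z with integer labels y."""
    y = np.asarray(y)
    sim = z @ z.T / temperature
    np.fill_diagonal(sim, -np.inf)  # exclude self-pairs from the denominator
    same = (y[:, None] == y[None, :]) & ~np.eye(len(y), dtype=bool)
    # Row-wise log-sum-exp, computed stably by subtracting the row maximum
    row_max = np.max(sim, axis=1, keepdims=True)
    log_denom = row_max + np.log(np.sum(np.exp(sim - row_max), axis=1, keepdims=True))
    log_prob = sim - log_denom
    n_pos = same.sum(axis=1)
    has_pos = n_pos > 0  # anchors without positives are left out
    per_anchor = -np.where(same, log_prob, 0.0).sum(axis=1)[has_pos] / n_pos[has_pos]
    return per_anchor.mean()

def unit(v):
    v = np.asarray(v, dtype=float)
    return v / np.linalg.norm(v)

labels_toy = [0, 0, 1, 1]
# Same-class vectors close together, the two classes pointing in opposite directions
aligned = np.stack([unit([1, 0.1]), unit([1, -0.1]), unit([-1, 0.1]), unit([-1, -0.1])])
# The same vectors reordered so that each class contains one vector from each side
mixed = aligned[[0, 2, 1, 3]]
print(supcon_loss_np(aligned, labels_toy), supcon_loss_np(mixed, labels_toy))
```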
def supcon_loss(z, y, temperature: float = 0.07):
    """
    Supervised contrastive (SupCon) loss.
    For each anchor i, maximise the average log-probability of sampling a
    same-class example j relative to all other examples in the batch:
        L_i = -(1/|P(i)|) * sum_{j in P(i)} [ s_ij/T - log sum_{k≠i} exp(s_ik/T) ]
    where P(i) is the set of other examples in the batch with the same label
    as i, and s_ij = z_i · z_j is the similarity of the (L2-normalized) features.
    Unlike a margin-based loss this is never geometrically infeasible: it finds
    the best attainable arrangement and never collapses classes to share an angle
    just to escape an impossible hard constraint. Lower T gives sharper
    separation (analogous to a smaller margin).
    """
    import torch
    # B is the batch size
    B = z.size(0)
    # sim is the matrix of dot-products between all pairs of features, i.e. their cosine
    # similarities since the features are L2-normalized. Dividing by the temperature scales
    # the similarities: the lower the temperature, the more weight we assign to the contrast
    # between different labels
    sim = (z @ z.T) / temperature
    # The identity matrix is used to mask the
    # self-pairs, which are uninformative (they compare a vector with itself)
    eye = torch.eye(B, dtype=torch.bool, device=z.device)
    # By setting -inf on the diagonal, the self-pairs receive zero weight after exponentiation
    sim = sim.masked_fill(eye, float("-inf"))
    # We create a target matrix by broadcasting y as a row vector against itself as a column vector.
    # This creates a matrix of the same shape as sim, which is True wherever y_i and
    # y_j have the same class. We also make sure the diagonal is filled with
    # False values since we don't care about that similarity
    same = (y.unsqueeze(0) == y.unsqueeze(1)) & ~eye
    # Tally the number of positive pairs per anchor. We clamp it to 1 to avoid dividing by zero.
    n_pos = same.float().sum(dim=1).clamp(min=1)
    # logsumexp exponentiates all elements of the similarity matrix, sums them along each row
    # and takes the logarithm of these sums. The -inf diagonal becomes exp(-inf) = 0, which is
    # how the self-pairs are excluded from the denominator.
    log_denom = torch.logsumexp(sim, dim=1, keepdim=True)
    log_prob = sim - log_denom  # (B, B) log-probabilities
    # Average the log-probabilities of the positives for each anchor
    per_anchor = -(torch.where(same, log_prob, torch.zeros_like(log_prob))
                   .sum(dim=1)) / n_pos
    # Anchors without any positive in the batch are left out of the mean
    return per_anchor[same.any(dim=1)].mean()
import torch.nn as nn
import torch.nn.functional as F
class ContrastiveConvNet(nn.Module):
    def __init__(self, input_channels, n_classes, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim
        self.conv_blocks = nn.ModuleList([
            ConvBlock(input_channels, 32),  # 28x28 -> 14x14
            ConvBlock(32, 64),  # 14x14 -> 7x7
            ConvBlock(64, 128),  # 7x7 -> 3x3
        ])
        self.encoder_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 3 * 3, latent_dim),
            nn.ReLU(),
        )
        self.projection_head = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
        )
        self.bn = nn.BatchNorm1d(latent_dim, affine=False)
        # No classifier during contrastive training; reset_head adds one for transfer
        self.classifier = None
    def encode(self, x):
        for block in self.conv_blocks:
            x = block(x)
        return self.encoder_head(x)
    def reset_head(self, n_classes):
        device = next(self.parameters()).device
        self.classifier = nn.Linear(self.latent_dim, n_classes, device=device)
    def forward(self, x):
        # After reset_head, the model classifies directly from the encoder features
        if self.classifier is not None:
            return self.classifier(self.encode(x))
        # Otherwise return the normalized projections that the contrastive loss is computed on
        embedding = self.encode(x)
        projection = self.projection_head(embedding)
        batch_normed = self.bn(projection)
        return F.normalize(batch_normed, dim=1)
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
batch_size = 128
latent_dim = 128
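# Note that we use train_data_a, which only contains the labeled examples of the kept labels
train_loader = DataLoader(train_data_a, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
# The dev loader is also shuffled: dev_data_a is ordered by label, and the SupCon loss needs
# batches that mix several classes to have both positives and negatives
dev_loader = DataLoader(dev_data_a, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)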
n_classes = train_data_a.get_n_classes()
contrastive_cnn_model = ContrastiveConvNet(input_channels=N_CHANNELS, n_classes=n_classes, latent_dim=latent_dim)
optimizer = AdamW(contrastive_cnn_model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = supcon_loss # We use the Supervised Contrastive loss
liveplot = LivePlot()
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train(model=contrastive_cnn_model, train_loader=train_loader, dev_loader=dev_loader, optimizer=optimizer, criterion=criterion, max_epochs=10, device=device, liveplot=liveplot)
Inspect representations
Like before we can construct a feature extractor from this model
from torchvision.models.feature_extraction import get_graph_node_names
train_nodes, eval_nodes = get_graph_node_names(contrastive_cnn_model)
print("Eval nodes:")
for node in eval_nodes:
    print(node)
The interesting features here are encoder_head.0 and encoder_head.2, which share the same upstream architecture with our previous model. The normalize node is also interesting, since its output is what the contrastive loss was computed on. With contrastive learning we often don’t use the features from the projection head for downstream tasks, as those become too specialized to the specific pairs we contrasted on.
from torchvision.models.feature_extraction import create_feature_extractor
feature_extractor = create_feature_extractor(contrastive_cnn_model, return_nodes={'encoder_head.0': 'flat_conv_features', 'encoder_head.2': 'encoder_features', 'normalize': 'contrastive_features'})
As before, the feature extractor returns a dictionary instead of the result of the forward() method, this time with the keys flat_conv_features, encoder_features and contrastive_features. Let’s use it to embed our test examples.
test_loader = DataLoader(test_data, batch_size=batch_size)
conv_features = []
encoder_features = []
contrastive_features = []
labels = []
feature_extractor.eval()
with torch.inference_mode():
    for batch in test_loader:
        x, y = batch
        # Flatten the (B, 1) MedMNIST label tensor into plain Python ints
        labels.extend(y.reshape(-1).tolist())
        extracted_features = feature_extractor(x.to(device))
        conv_features.append(extracted_features['flat_conv_features'].cpu().numpy())
        encoder_features.append(extracted_features['encoder_features'].cpu().numpy())
        contrastive_features.append(extracted_features['contrastive_features'].cpu().numpy())
encoder_features = np.concatenate(encoder_features)
conv_features = np.concatenate(conv_features)
contrastive_features = np.concatenate(contrastive_features)
visualize_latent_space(conv_features, labels, f"Conv features (held-out labels {held_out_labels})")
visualize_latent_space(encoder_features, labels, f"Encoder features (held-out labels {held_out_labels})")
visualize_latent_space(contrastive_features, labels, f"Contrastive features (held-out labels {held_out_labels})")
Transfer learning
Like the supervised model, we can transfer the learned encoder to a new task by replacing the contrastive projection head with a linear classifier and training only that head on the held-out labels.
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
batch_size = 128
train_loader = DataLoader(train_data_b, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
# The development loader keeps every example (no shuffling or dropping) so the dev loss covers the whole set
dev_loader = DataLoader(dev_data_b, batch_size=batch_size, num_workers=8)
n_classes = train_data_b.get_n_classes()
contrastive_cnn_model.reset_head(n_classes)
optimizer = AdamW(contrastive_cnn_model.classifier.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = CrossEntropyLoss()
liveplot = LivePlot()
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train(model=contrastive_cnn_model, train_loader=train_loader, dev_loader=dev_loader, optimizer=optimizer, criterion=criterion, max_epochs=10, device=device, liveplot=liveplot)
Evaluate transferred results
plot_confusion_matrix(contrastive_cnn_model, dev_data_b, title="Confusion matrix: dev set (contrastive transfer learning)")
Exercises
Exercise 1: Temperature ablation
The temperature argument in supcon_loss controls how sharply the loss penalises nearby negatives. Try temperature=0.01, 0.07, and 0.5. For each, train a fresh ContrastiveConvNet and inspect the UMAP of contrastive_features. How does temperature affect:
- Training stability (watch the loss curve)?
- Cluster separation in the latent space?
Exercise 2: Batch size and positive availability
The SupCon loss can only use positives that appear in the same batch. Look at the clamp(min=1) in supcon_loss — this prevents division by zero when no same-class example is in the batch. Try batch_size=32 vs batch_size=512. Does a smaller batch produce a noisier loss? Does it hurt the final UMAP structure or transfer accuracy?
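To get a feel for the numbers, assume (as a simplification) that a class makes up a fraction p of the training data and that batch members are drawn independently. An anchor from that class then has no same-class partner among the other B - 1 batch members with probability (1 - p)^(B - 1). The helper prob_no_positive below is hypothetical and only evaluates this formula. It is not part of the lab's code.

```python
def prob_no_positive(p, batch_size):
    """Probability that an anchor of a class with frequency p has no same-class
    partner among the other batch_size - 1 samples, assuming independent draws."""
    return (1.0 - p) ** (batch_size - 1)

for p in (0.02, 0.1):
    for b in (32, 512):
        print(f"class frequency {p:.2f}, batch size {b:4d}: "
              f"P(no positive) = {prob_no_positive(p, b):.4f}")
```

Rare classes combined with small batches are where the clamp(min=1) safeguard is hit most often.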
Self-supervised learning
We’ve seen how supervised learning is able to learn representations, but relies on labeled data. We will look at how we can use unlabeled data to learn representations using self-supervised learning.
Principal Component Analysis (PCA)
A classical form of representation learning is Principal Component Analysis (PCA). PCA finds the orthogonal directions of maximum variance in the data, and projects the data onto those directions. This can be useful for dimensionality reduction, visualization, and as a preprocessing step for other machine learning algorithms.
Below we illustrate how PCA works on a simple 2D dataset. The first principal component (PC1) is the direction of maximum variance. The second principal component (PC2) is orthogonal to PC1 and captures the remaining variance. After projecting the data onto these components, the data is represented in a new coordinate system defined by PC1 and PC2. In this new space the coordinates are uncorrelated: PC1 carries the largest variance of any single direction, and PC2 carries what is left.
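The same components can be computed by hand in three steps. First centre the data, then form the sample covariance matrix, and finally take its eigenvectors sorted by decreasing eigenvalue. The sketch below does this with NumPy on synthetic correlated data and checks the result against scikit-learn's PCA. The covariance values are arbitrary illustrative choices. Eigenvectors are only defined up to sign, so the comparison uses absolute values.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
cov = np.array([[2.0, 1.2], [1.2, 1.0]])  # symmetric, positive definite
X = rng.multivariate_normal(mean=[0.0, 0.0], cov=cov, size=500)

# 1. Centre the data (PCA always works with mean-centred data)
X_centred = X - X.mean(axis=0)
# 2. Sample covariance matrix (d x d), normalized by n - 1 like scikit-learn
C = np.cov(X_centred, rowvar=False)
# 3. Eigendecomposition; eigh returns eigenvalues in ascending order, so reverse them
eigvals, eigvecs = np.linalg.eigh(C)
order = np.argsort(eigvals)[::-1]
eigvals = eigvals[order]
components = eigvecs[:, order].T  # one principal component per row

pca = PCA(n_components=2).fit(X)
print("eigenvalues:      ", eigvals)
print("sklearn variances:", pca.explained_variance_)
print("components agree up to sign:", np.allclose(np.abs(components), np.abs(pca.components_)))
```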
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
n_samples = 200
# A valid covariance matrix must be symmetric and positive semi-definite;
# the off-diagonal entries introduce a correlation between x1 and x2
cov = np.array([[2.0, 1.2], [1.2, 1.0]])
X = np.random.multivariate_normal(mean=[0, 0], cov=cov, size=n_samples)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
projection = PCA(n_components=2, whiten=False).fit(X)
component_1 = projection.components_[0]
component_2 = projection.components_[1]
# Scale each arrow by the standard deviation of the data along that component
std1, std2 = np.sqrt(projection.explained_variance_)
axes[0].scatter(X[:, 0], X[:, 1], alpha=0.5, edgecolor='k')
axes[0].set_title("Random 2D data with correlated features")
axes[0].set_xlabel("x1")
axes[0].set_ylabel("x2")
axes[0].axis("equal")
axes[0].arrow(0, 0, std1*component_1[0], std1*component_1[1], color='red', width=0.05, label='PC1')
axes[0].arrow(0, 0, std2*component_2[0], std2*component_2[1], color='blue', width=0.05, label='PC2')
axes[0].legend()
x_pca_projected = projection.transform(X)
axes[1].scatter(x_pca_projected[:, 0], x_pca_projected[:, 1], alpha=0.5, edgecolor='k')
axes[1].set_title("PCA projected data")
axes[1].set_xlabel("PC1")
axes[1].set_ylabel("PC2")
axes[1].axis("equal")
axes[1].arrow(0, 0, std1, 0, color='red', width=0.05, label='PC1')
axes[1].arrow(0, 0, 0, std2, color='blue', width=0.05, label='PC2')
axes[1].legend()
plt.tight_layout()
plt.show()
We can use this to represent the data in a reduced space by keeping only its projections onto a subset of the principal components. Below, you can see the data projected onto each of the two principal components separately.
fig, axes = plt.subplots(2, 1, figsize=(12, 5), sharey='all', sharex='all')
# Projection onto PC1 only: each point is reduced to a single number
axes[0].scatter(x_pca_projected[:, 0], np.zeros_like(x_pca_projected[:, 0]), color='red', alpha=0.5, edgecolor='k')
axes[0].set_title("Projection onto PC1")
axes[0].set_yticks([])
# Projection onto PC2 only
axes[1].scatter(x_pca_projected[:, 1], np.zeros_like(x_pca_projected[:, 1]), color='blue', alpha=0.5, edgecolor='k')
axes[1].set_title("Projection onto PC2")
axes[1].set_yticks([])
plt.tight_layout()
plt.subplots_adjust(hspace=0.5)
plt.show()
In this sense, projecting the data onto a set of principal components changes how it’s represented. It is the optimal choice if the goal is to preserve as much variance as possible with a linear projection. It’s important to emphasize that PCA is a linear method and may not capture complex nonlinear relationships in the data. Later in this lab we’ll look at a close deep-learning analogue of PCA, and then at nonlinear representation learning methods.
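The optimality claim can be checked numerically. The variance of the data projected onto PC1 should be at least as large as the variance along any other unit direction. The sketch below compares PC1 against a few thousand random unit vectors. The synthetic dataset is regenerated here so the snippet is self-contained, and projected_variance is a hypothetical helper.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(1)
X = rng.multivariate_normal(mean=[0.0, 0.0], cov=[[2.0, 1.2], [1.2, 1.0]], size=1000)
pc1 = PCA(n_components=1).fit(X).components_[0]

def projected_variance(X, direction):
    """Sample variance of the data projected onto a (normalized) direction."""
    direction = direction / np.linalg.norm(direction)
    return np.var(X @ direction, ddof=1)

random_dirs = rng.normal(size=(5000, 2))
random_vars = np.array([projected_variance(X, u) for u in random_dirs])
var_pc1 = projected_variance(X, pc1)
print(f"variance along PC1: {var_pc1:.3f}")
print(f"largest variance along a random direction: {random_vars.max():.3f}")
```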
PCA on our medical dataset
We’ll now see how we can apply PCA to our dataset. We will first flatten the images into vectors, and then apply PCA to find the principal components of the data. We can then visualize the data in the space of the first two principal components to see if there are any interesting patterns or clusters.
def flatten_dataset(dataset):
    """Flatten every image in `dataset` into a row vector; returns (X, y) as NumPy arrays."""
    X, y = [], []
    for img, label in dataset:
        X.append(np.asarray(img).ravel())
        y.append(label.item())
    return np.stack(X), np.array(y)
X_train, y_train = flatten_dataset(train_unlabeled)
X_dev, y_dev = flatten_dataset(dev_unlabeled)
X_test, y_test = flatten_dataset(test_data)
pca = PCA(n_components=128).fit(X_train)  # 128 components; the grid below is sized to fit them all
As these components are linear combinations of the original pixel values, we can visualize them as images to see what kind of features they capture.
n_components, d = pca.components_.shape
n_rows = int(np.ceil(np.sqrt(n_components)))
n_cols = int(np.ceil(n_components / n_rows))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 12), squeeze=False)
for ax in axes.flatten():
    ax.axis('off')
for i, ax in zip(range(n_components), axes.flatten()):
    # Copy first: reshape returns a view, so normalizing in place would modify pca.components_ itself
    component_img = pca.components_[i].reshape(*IMG_SHAPE).copy()
    # Since these are linear bases, they don't necessarily stay in the [0, 1] range of
    # the input images. We rescale them to [0, 1] to make the captured patterns visible
    component_img -= component_img.min()
    component_img /= component_img.max()
    ax.imshow(component_img.transpose(1, 2, 0), cmap='gray')
    ax.set_title(f"PC {i+1}")
We can then project the data onto these components to get a new, lower-dimensional representation of it. We can think of this as a simple lossy compression of the data, using the linear basis that captures the most variance in the full training dataset. Let’s look at what the data looks like if we “decompress” it back to the original space using only this reduced number of principal components.
x_projected = pca.transform(X_train)
x_approx = pca.inverse_transform(x_projected)
n_rows, n_cols = 2, 10
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 6))
# The top row shows the images approximated from the PCA components; the bottom row shows the originals for comparison.
for col_i in range(n_cols):
    ax_a = axes[0, col_i]
    ax_o = axes[1, col_i]
    img_a = x_approx[col_i].reshape(*IMG_SHAPE).clip(0, 1)
    img_o = X_train[col_i].reshape(*IMG_SHAPE)
    ax_a.imshow(img_a.transpose(1, 2, 0))
    ax_a.set_title(f"PCs used: {pca.n_components_}")
    ax_a.axis('off')
    ax_o.imshow(img_o.transpose(1, 2, 0))
    ax_o.set_title("Original")
    ax_o.axis('off')
plt.tight_layout()
plt.show()
Exercise:
Experiment with different numbers of principal components k (the n_components argument) and see how they affect the quality of the reconstruction. You can also apply the fitted PCA (e.g. pca.transform(X_dev)) to the development or test set, to see how well the components learned from the training set generalize to new data. At what value of k can you no longer see a difference between the original and the approximated images? What happens when you set k=1?
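For this exercise it can help to track a number rather than judge reconstructions only by eye. PCA has a useful identity here. Take the squared reconstruction error summed over features and average it over samples. With k components, this equals the sum of the variances of the discarded components, up to the (n - 1)/n factor between scikit-learn's sample-variance convention and a plain mean. The sketch below checks this on synthetic data. To use it in the lab, you would swap in X_train.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(2)
n, d = 400, 20
# Synthetic data with a few dominant directions plus small noise (an arbitrary choice)
X = rng.normal(size=(n, 5)) @ rng.normal(size=(5, d)) + 0.1 * rng.normal(size=(n, d))

full = PCA().fit(X)  # all components; used only for their variances
for k in (1, 2, 5, 10):
    pca_k = PCA(n_components=k).fit(X)
    X_hat = pca_k.inverse_transform(pca_k.transform(X))
    mse = np.mean(np.sum((X - X_hat) ** 2, axis=1))  # mean over samples of the squared error
    discarded = full.explained_variance_[k:].sum() * (n - 1) / n
    print(f"k={k:2d}  reconstruction error={mse:.4f}  discarded variance={discarded:.4f}")
```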
Tangent: variance explained
The variance explained by the first k principal components is the sum of the first k eigenvalues of the covariance matrix divided by the sum of all eigenvalues. This can be used to decide how many principal components to keep in order to retain a certain percentage of the variance in the data. You can compute it from the explained_variance_ratio_ attribute of the PCA object, which gives the fraction of the total variance explained by each component.
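The number of components needed for a target fraction of variance can be read off the cumulative sum directly. scikit-learn can also choose it for you when n_components is a float between 0 and 1. The sketch below shows both approaches on synthetic data. The 95% target is an arbitrary choice.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(3)
X = rng.normal(size=(500, 3)) @ rng.normal(size=(3, 30)) + 0.2 * rng.normal(size=(500, 30))

pca_full = PCA().fit(X)
cumulative = np.cumsum(pca_full.explained_variance_ratio_)
target = 0.95
# Smallest k whose cumulative explained variance reaches the target
k = int(np.searchsorted(cumulative, target) + 1)
print(f"{k} components explain {cumulative[k - 1]:.3f} of the variance")

# With a float n_components, scikit-learn selects the number of components itself
pca_95 = PCA(n_components=target).fit(X)
print(pca_95.n_components_)
```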
x = np.arange(1, len(pca.explained_variance_ratio_) + 1)
y = np.cumsum(pca.explained_variance_ratio_)
plt.figure(figsize=(8, 5))
plt.plot(x, y)
plt.xlabel('Number of principal components')
plt.ylabel('Cumulative explained variance')
plt.title('Variance explained by principal components')
plt.show()
Visualizing the latent space
class PCAWrapper(nn.Module):
    '''Thin wrapper exposing a fitted PCA through an `encode` method, like the PyTorch models in this lab. The cell below passes the projected features to visualize_latent_space directly and does not need it.'''
    def __init__(self, pca):
        super(PCAWrapper, self).__init__()
        self.pca = pca
        self.latent_dim = pca.n_components_
    def encode(self, x):
        return torch.tensor(self.pca.transform(x), dtype=torch.float32)
pca_features = pca.transform(X_dev)
visualize_latent_space(pca_features, y_dev, "PCA projections")
Experiment with different numbers of principal components (k) and see how this affects the visualization of the latent space. You can also apply UMAP directly to the original data without PCA, or to the data projected onto a different number of principal components, to see how it affects the structure of the latent space.
PCA as a linear neural network
In the example above, we learned a projection matrix \(W\) that maps the original data to a lower-dimensional space. The latent representation was \(z_i = W x_i\), where \(W\) is a \(k \times d\) matrix. Here \(k\) is the number of principal components we keep and \(d\) is the dimensionality of the original data. The principal components are the eigenvectors of the covariance matrix of the data, and there are highly optimized algorithms for computing them. However, we can also learn the projection matrix \(W\) with gradient descent, by treating it as a linear neural network with no activation function. If we reconstruct with the transpose, \(\hat{x}_i = W^T z_i = W^T W x_i\), the loss is the mean squared error between the original data and its reconstruction:
\[L = \frac{1}{n} \sum_{i=1}^n ||x_i - W^T W x_i||^2\]
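The loss can be evaluated directly for any candidate \(W\). The sketch below compares three choices on synthetic, mean-centred data: the top-k PCA components, a random \(k \times d\) matrix with orthonormal rows, and the PCA components mixed by a random \(k \times k\) orthogonal matrix \(Q\). All names here are illustrative. Because \((QW)^T(QW) = W^T Q^T Q W = W^T W\), the PCA components and their rotated version give exactly the same loss.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(4)
n, d, k = 300, 10, 3
X = rng.normal(size=(n, d)) * np.linspace(3.0, 0.5, d)  # features with decreasing spread
X = X - X.mean(axis=0)  # mean-centre, as PCA does

def linear_ae_loss(W, X):
    """L = 1/n * sum_i ||x_i - W^T W x_i||^2, with the x_i stored as rows of X."""
    X_hat = X @ W.T @ W  # row-vector form of W^T W x_i (W^T W is symmetric)
    return np.mean(np.sum((X - X_hat) ** 2, axis=1))

W_pca = PCA(n_components=k).fit(X).components_       # k x d, orthonormal rows
W_rand = np.linalg.qr(rng.normal(size=(d, k)))[0].T  # random k x d with orthonormal rows
Q = np.linalg.qr(rng.normal(size=(k, k)))[0]         # random k x k orthogonal matrix
W_rot = Q @ W_pca

print("PCA components:        ", linear_ae_loss(W_pca, X))
print("random orthonormal W:  ", linear_ae_loss(W_rand, X))
print("rotated PCA components:", linear_ae_loss(W_rot, X))
```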
Let’s implement this in PyTorch and see if we can learn similar projection matrices as PCA.
from tqdm import tqdm
class LinearAutoencoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(LinearAutoencoder, self).__init__()
        self.latent_dim = latent_dim
        # This encoder is a linear layer that maps the input to the latent space. We set bias=False because PCA does not include a bias term, and we want to learn a linear transformation that is centered around the origin.
        self.encoder = nn.Linear(input_dim, latent_dim, bias=False)
        # We could use different matrices for the "compression" and "decompression" steps, but for PCA the decoder is the transpose of the encoder, so we reuse (tie) the encoder weights in decode().
        # self.decoder = nn.Linear(latent_dim, input_dim, bias=False)
    def forward(self, x):
        z = self.encode(x)
        x_recon = self.decode(z)
        x_recon = x_recon.reshape(x.shape)
        return x_recon
    def encode(self, x):
        # x is an image, so we flatten it first. We need to keep
        # the batch dimension, so we flatten every dimension except the first
        batch_size, *remaining_shape = x.shape
        x = x.reshape((batch_size, -1))
        return self.encoder(x)
    def decode(self, z):
        # nn.Linear stores W with shape (latent_dim, input_dim); compute W^T z for each row of the batch
        x_hat = self.encoder.weight.t() @ z.t()
        return x_hat.t()
from torch.utils.data import Dataset
class AEDataset(Dataset):
    """Simple dataset wrapper that returns the image as the label as well"""
    def __init__(self, dataset):
        self.dataset = dataset
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, item):
        x, y = self.dataset[item]
        return x, x
train_unlabeled_ae = AEDataset(train_unlabeled)
dev_unlabeled_ae = AEDataset(dev_unlabeled)
from torch.nn import MSELoss
batch_size = 256
train_loader = DataLoader(train_unlabeled_ae, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
dev_loader = DataLoader(dev_unlabeled_ae, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
latent_dim = 128
x, _ = train_unlabeled_ae[0]
input_dim = x.flatten().shape[0]
linear_model = LinearAutoencoder(input_dim=input_dim, latent_dim=latent_dim)
optimizer = AdamW(linear_model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = MSELoss()
liveplot = LivePlot()
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train(model=linear_model, train_loader=train_loader, dev_loader=dev_loader, optimizer=optimizer, criterion=criterion, max_epochs=10, device=device, liveplot=liveplot)
We’ve now “trained” a “linear neural network” that learns a projection matrix mapping the original data to a lower-dimensional latent space, and reconstructs the original data from that latent representation. The loss we used is the mean squared error between the original and the reconstructed data, which is the reconstruction error that PCA minimizes. After training, we can inspect the learned “principal components” and compare them to the ones found by PCA.
# Get the learned projection matrix from the encoder weights
W_learned = linear_model.encoder.weight.detach().cpu().numpy()
n_components, d = W_learned.shape
n_rows = int(np.ceil(np.sqrt(n_components)))
n_cols = int(np.ceil(n_components / n_rows))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 12), squeeze=False)
for ax in axes.flatten():
    ax.axis('off')
for i, ax in zip(range(n_components), axes.flatten()):
    # Copy first: for a model on the CPU, .numpy() shares memory with the weights,
    # so normalizing in place would silently modify the model
    component_img = W_learned[i].reshape(*IMG_SHAPE).copy()
    component_img -= component_img.min()
    component_img /= component_img.max()
    ax.imshow(component_img.transpose(1, 2, 0), cmap='gray')
    ax.set_title(f"PC {i+1}")
def visualized_reconstruction(X_train, model, n_examples=5, device=None):
    """Plot reconstructions (top row) and originals (bottom row) for the first `n_examples` rows of the flattened array X_train."""
    if device is None:
        # Default to the device the model's parameters live on
        device = next(model.parameters()).device
    X_train_tensor = torch.tensor(X_train[:n_examples], dtype=torch.float32, device=device)
    model.eval()
    with torch.no_grad():
        x_reconstructed = model(X_train_tensor).cpu().numpy()
    fig, axes = plt.subplots(2, n_examples, figsize=(3 * n_examples, 6), squeeze=False)
    # The top row shows the reconstructed images, and the bottom row shows the original images for comparison.
    for col_i in range(n_examples):
        ax_a = axes[0, col_i]
        ax_o = axes[1, col_i]
        img_a = x_reconstructed[col_i].reshape(*IMG_SHAPE).clip(0, 1)
        img_o = X_train[col_i].reshape(*IMG_SHAPE)
        ax_a.imshow(img_a.transpose(1, 2, 0), cmap='gray')
        ax_a.axis('off')
        ax_o.imshow(img_o.transpose(1, 2, 0), cmap='gray')
        ax_o.set_title("Original")
        ax_o.axis('off')
    plt.suptitle(f"Reconstruction of original images using {type(model).__name__} with latent dim {model.latent_dim}")
    plt.tight_layout()
    plt.show()
visualized_reconstruction(X_train, linear_model, n_examples=5)
Questions
Are the learned components similar to the ones learned by PCA? If there is a difference, why do you think that is? Try experimenting with different learning rates, batch sizes, and numbers of epochs to see how it affects the learned components and the reconstruction loss. You can also try using a different optimization algorithm, such as SGD or RMSprop, to see if it converges to a different solution.
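One way to make “similar” precise is to compare the subspaces spanned by the two sets of components, rather than the components one by one. The MSE objective above is unchanged when the rows of \(W\) are rotated within their span. scipy.linalg.subspace_angles returns the principal angles between two subspaces, each given as a matrix whose columns span it. Angles near zero mean the subspaces coincide. The sketch below uses synthetic matrices. To try it on the lab's models, you could pass pca.components_.T and W_learned.T, both of shape (d, k).

```python
import numpy as np
from scipy.linalg import subspace_angles

rng = np.random.default_rng(5)
d, k = 50, 4
A = np.linalg.qr(rng.normal(size=(d, k)))[0]  # d x k, orthonormal columns
Q = np.linalg.qr(rng.normal(size=(k, k)))[0]  # random rotation within the subspace
B_same_span = (A @ Q) * np.array([1.0, 2.0, 0.5, 3.0])  # rotated and rescaled columns: same span
B_other = rng.normal(size=(d, k))  # an unrelated subspace

print("same span, max angle (degrees):", np.degrees(subspace_angles(A, B_same_span)).max())
print("unrelated, max angle (degrees):", np.degrees(subspace_angles(A, B_other)).max())
```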
Visualizing the latent space
Just as with regular PCA, we can also visualize the latent space of the linear autoencoder by projecting the data onto the learned components and applying UMAP. Try this out and see how the structure of the latent space compares to the one obtained from PCA. You can also experiment with different numbers of components in the linear autoencoder, and see how it affects the structure of the latent space.
test_loader = DataLoader(test_data, batch_size=batch_size)
linear_ae_features = []
feature_labels = []
with torch.inference_mode():
    for batch in test_loader:
        x, y = batch
        feature_labels.extend(y.squeeze())
        embedded = linear_model.encode(x.to(device))
        linear_ae_features.append(embedded.cpu().numpy())
linear_ae_features = np.concatenate(linear_ae_features)
visualize_latent_space(linear_ae_features, feature_labels, "Linear AE projections")
Exercises
Exercise 1: Orthogonality of learned components
PCA components are orthonormal by construction (they are unit-length eigenvectors of the covariance matrix). Does gradient descent also find an orthonormal solution? After training, compute:
W = linear_model.encoder.weight.detach().cpu().numpy()
G = W @ W.T  # Close to the identity if the rows of W are orthonormal
plt.imshow(G, cmap='RdBu_r', vmin=-1, vmax=1)
plt.colorbar()
plt.title("W @ W.T (identity = orthonormal components)")
plt.show()
How close to the identity matrix is the result? If it is not close, what might keep gradient descent from reaching an orthonormal solution? Consider the learning rate, the number of epochs, the weight decay and the initial weights.
Exercise 2: Effect of a bias term
Change bias=False to bias=True in LinearAutoencoder.encoder. Retrain and compare:
- Does the reconstruction loss change?
- Do the learned components (visualised as images) look different?
Classical PCA is a centred projection — it does not include a bias. What does adding a bias allow the model to learn that PCA cannot?
Autoencoders
In the example above, we saw how we could approximate what PCA does using a linear neural network. However, PCA can only learn linear ways of representing the data (as a sum of basis vectors). Let’s see if we can learn a more complex, nonlinear representation of the data using a nonlinear autoencoder. An autoencoder is a neural network that consists of an encoder and a decoder. The encoder maps the input data to a latent representation, and the decoder maps the latent representation back to the original data. The autoencoder is trained to minimize the reconstruction loss between the original data and the reconstructed data from the latent representation.
class AEBlock(nn.Module):
    """Linear layer followed by LayerNorm and an optional nonlinearity."""
    def __init__(self, input_dim, output_dim, activation_function='leaky_relu'):
        super(AEBlock, self).__init__()
        self.dense = nn.Linear(input_dim, output_dim, bias=True)
        self.normalization = nn.LayerNorm(output_dim)
        self.activation_function = activation_function
    def forward(self, x):
        pre_activation = self.dense(x)
        pre_activation = self.normalization(pre_activation)
        if self.activation_function == 'leaky_relu':
            activation = nn.functional.leaky_relu(pre_activation)
        elif self.activation_function == 'relu':
            activation = nn.functional.relu(pre_activation)
        else:
            activation = pre_activation
        return activation
class Autoencoder(nn.Module):
    def __init__(self, input_dim, latent_dim=16, encoder_dims=(256, 256), decoder_dims=(256, 256)):
        super(Autoencoder, self).__init__()
        # We create explicit lists of layers so that we can easily inspect them and compose them however we like in the forward method. We could also have used `nn.Sequential` for the encoder and decoder, but this way we can easily add skip connections or other architectural features if we want to experiment with that later on.
        self.latent_dim = latent_dim
        self.encoder_layers = nn.ModuleList()
        dim_from_below = input_dim
        for encoder_dim in encoder_dims:
            dense_layer = AEBlock(dim_from_below, encoder_dim, activation_function='leaky_relu')
            self.encoder_layers.append(dense_layer)
            dim_from_below = encoder_dim
        self.inner_layer = nn.Linear(dim_from_below, latent_dim, bias=True)
        self.decoder_layers = nn.ModuleList()
        dim_from_above = latent_dim
        for decoder_dim in decoder_dims:
            dense_layer = AEBlock(dim_from_above, decoder_dim, activation_function='leaky_relu')
            self.decoder_layers.append(dense_layer)
            dim_from_above = decoder_dim
        self.final_layer = nn.Linear(dim_from_above, input_dim, bias=True)
        self.classifier = None
    def reset_head(self, n_classes):
        # Attach a linear classifier on the latent code; forward() then returns class logits
        device = next(self.parameters()).device
        self.classifier = nn.Linear(self.latent_dim, n_classes, device=device)
    def forward(self, x):
        if self.classifier is not None:
            return self.classifier(self.encode(x))
        z = self.encode(x)
        x_hat = self.decode(z).reshape(x.shape)
        return x_hat
    def encode(self, x):
        z = x.reshape((x.shape[0], -1))
        for layer in self.encoder_layers:
            z = layer(z)
        return self.inner_layer(z)
    def decode(self, z):
        # Every decoder block applies a nonlinearity, but the final output layer does not: we don't want to squash the range of the output values, since we want to reconstruct the original data as accurately as possible.
        for layer in self.decoder_layers:
            z = layer(z)
        z = self.final_layer(z)
        return z
batch_size = 256
train_loader = DataLoader(train_unlabeled_ae, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
dev_loader = DataLoader(dev_unlabeled_ae, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
latent_dim = 128
x, _ = train_unlabeled_ae[0]
input_dim = x.flatten().shape[0]
ae_model = Autoencoder(input_dim=input_dim, latent_dim=latent_dim)
optimizer = AdamW(ae_model.parameters(), lr=1e-3, weight_decay=1e-5)
criterion = MSELoss()
liveplot = LivePlot()
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train(model=ae_model, train_loader=train_loader, dev_loader=dev_loader, optimizer=optimizer, criterion=criterion, max_epochs=10, device=device, liveplot=liveplot)
With a nonlinear autoencoder, we no longer have a single matrix projecting the data, so we can’t display the “components” as images as we did with PCA. However, we can still look at the reconstructions. We can also visualize the latent space by passing the data through the encoder and applying UMAP to the resulting latent representations.
visualized_reconstruction(X_train, ae_model, n_examples=5)
test_loader = DataLoader(test_data, batch_size=batch_size)
ae_features = []
feature_labels = []
with torch.inference_mode():
    for batch in test_loader:
        x, y = batch
        feature_labels.extend(y.squeeze())
        embedded = ae_model.encode(x.to(device))
        ae_features.append(embedded.cpu().numpy())
ae_features = np.concatenate(ae_features)
visualize_latent_space(ae_features, feature_labels, "Autoencoder projections")
Transfer learning
Like the supervised and contrastive models, we can use the encoder of the trained autoencoder as a feature extractor for a new task. We replace the decoder with a linear classifier and train only that head on the held-out labels. Only the classifier’s parameters are passed to the optimizer, so the encoder weights stay fixed. This way, the randomly initialized classifier can’t corrupt the learned representations through large, noisy gradients early in training.
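Restricting the optimizer is enough to keep the encoder weights fixed, but PyTorch still computes gradients for them. A more explicit alternative is to switch off requires_grad on the encoder parameters. The sketch below shows this on a small hypothetical model, not the lab's Autoencoder. The check at the end confirms that a training step leaves the encoder untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in: a "pretrained" encoder plus a fresh linear classifier head
encoder = nn.Sequential(nn.Linear(20, 16), nn.LeakyReLU(), nn.Linear(16, 8))
classifier = nn.Linear(8, 3)
model = nn.Sequential(encoder, classifier)

encoder.requires_grad_(False)  # freeze: no gradients are computed for the encoder
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3)

encoder_before = [p.clone() for p in encoder.parameters()]
x, y = torch.randn(32, 20), torch.randint(0, 3, (32,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()

unchanged = all(torch.equal(a, b) for a, b in zip(encoder_before, encoder.parameters()))
print("encoder unchanged:", unchanged)
```

Freezing this way also saves the memory and compute of the encoder's gradients.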
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
batch_size = 128
train_loader = DataLoader(train_data_b, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
dev_loader = DataLoader(dev_data_b, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
n_classes = train_data_b.get_n_classes()
ae_model.reset_head(n_classes)
optimizer = AdamW(ae_model.classifier.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = CrossEntropyLoss()
liveplot = LivePlot()
plot_confusion_matrix(ae_model, dev_data_b, title="Confusion matrix: dev set (autoencoder transfer - before training)")
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train(model=ae_model, train_loader=train_loader, dev_loader=dev_loader, optimizer=optimizer, criterion=criterion, max_epochs=10, device=device, liveplot=liveplot)
plot_confusion_matrix(ae_model, dev_data_b, title="Confusion matrix: dev set (autoencoder transfer - after training)")
Exercises
Exercise 1: Bottleneck size sweep
Train three separate autoencoders with latent_dim=8, 32, and 128. For each, compare:
- Reconstruction quality with visualized_reconstruction
- UMAP cluster separation with visualize_latent_space
What is the smallest latent dimension that still produces visually clean clusters? Is there a point where reducing the bottleneck forces the model to learn more class-discriminative features?
Exercise 2: Self-supervised transfer with kept labels
So far transfer learning has always used the held-out classes (train_data_b). Try using the autoencoder’s encoder as a feature extractor for the kept labels (train_data_a) instead: call ae_model.reset_head(train_data_a.get_n_classes()), then train only the classifier head. Compare dev accuracy to the supervised cnn_model trained from scratch on the same data. How much performance is lost by never having seen labels during pre-training?
Tweaking the architecture: Convolutional Autoencoder
Once we move from a linear model to a nonlinear one, we have much more freedom in how the nonlinear functions are constructed. A natural choice for image data is to use convolutional layers, which exploit spatial structure and share weights across positions. Below we define a ConvAutoencoder whose convolutional encoder is similar to the ConvNet and ContrastiveConvNet used in the supervised and contrastive sections, except that it downsamples with stride-2 convolutions instead of max pooling. Its decoder mirrors the encoder. It uses bilinear upsampling followed by convolutions, rather than transposed convolutions, to get back to the original resolution.
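The spatial sizes in the code comments below follow from the standard output-size formula for a convolution with kernel size k, stride s and padding p: out = floor((n + 2p - k) / s) + 1. With k=3, s=2 and p=1 this gives 28 → 14 → 7 → 4. The last step rounds up instead of halving exactly, so the decoder needs an explicit target size to get from 4 back to 7. The snippet checks the formula against PyTorch on a dummy single-channel 28×28 input. The channel counts follow the ConvAutoencoder below, and conv_out_size is a hypothetical helper.

```python
import torch
import torch.nn as nn

def conv_out_size(n, kernel_size=3, stride=2, padding=1):
    """Output size of a convolution along one spatial dimension."""
    return (n + 2 * padding - kernel_size) // stride + 1

sizes = [28]
for _ in range(3):
    sizes.append(conv_out_size(sizes[-1]))
print(sizes)  # [28, 14, 7, 4]

x = torch.zeros(1, 1, 28, 28)
for in_ch, out_ch in [(1, 32), (32, 64), (64, 128)]:
    x = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)(x)
print(tuple(x.shape))  # (1, 128, 4, 4)
```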
import torch.nn as nn
import torch.nn.functional as F
class StridedConvBlock(nn.Module):
    """Conv2d with stride=2 instead of MaxPool, so the downsampling itself is learned rather than fixed."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    def forward(self, x):
        return self.block(x)
class UpsampleConvBlock(nn.Module):
    """Upsample, then Conv2d; this avoids the checkerboard artifacts that ConvTranspose2d can produce."""
    def __init__(self, in_channels, out_channels, target_size=None, activate=True):
        super().__init__()
        if target_size is not None:
            self.upsample = nn.Upsample(size=target_size, mode='bilinear', align_corners=False)
        else:
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        conv_layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)]
        if activate:
            conv_layers += [nn.BatchNorm2d(out_channels), nn.ReLU()]
        self.conv = nn.Sequential(*conv_layers)
    def forward(self, x):
        return self.conv(self.upsample(x))
class ConvAutoencoder(nn.Module):
    def __init__(self, input_channels, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder: strided conv (stride=2) instead of MaxPool: 28→14→7→4
        self.conv_blocks = nn.ModuleList([
            StridedConvBlock(input_channels, 32),  # 28×28 → 14×14
            StridedConvBlock(32, 64),  # 14×14 → 7×7
            StridedConvBlock(64, 128),  # 7×7 → 4×4
        ])
        self.encoder_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, latent_dim),
            nn.ReLU(),
        )
        # Decoder: Upsample + Conv: 4→7→14→28
        # The first step uses an explicit target_size=(7, 7) because 4*2=8≠7.
        self.decoder_head = nn.Sequential(
            nn.Linear(latent_dim, 128 * 4 * 4),
            nn.ReLU(),
        )
        self.deconv_blocks = nn.ModuleList([
            UpsampleConvBlock(128, 64, target_size=(7, 7)),  # 4×4 → 7×7
            UpsampleConvBlock(64, 32),  # 7×7 → 14×14
            UpsampleConvBlock(32, input_channels, activate=False),  # 14×14 → 28×28, no activation
        ])
        self.classifier = None
    def encode(self, x):
        for block in self.conv_blocks:
            x = block(x)
        return self.encoder_head(x)
    def decode(self, z):
        z = self.decoder_head(z).view(z.shape[0], 128, 4, 4)
        for block in self.deconv_blocks:
            z = block(z)
        return z
    def reset_head(self, n_classes):
        device = next(self.parameters()).device
        self.classifier = nn.Linear(self.latent_dim, n_classes, device=device)
    def forward(self, x):
        if self.classifier is not None:
            return self.classifier(self.encode(x))
        return self.decode(self.encode(x))
from torch.optim import AdamW
from torch.nn import MSELoss
batch_size = 256
latent_dim = 128
train_loader = DataLoader(train_unlabeled_ae, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
dev_loader = DataLoader(dev_unlabeled_ae, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
conv_ae_model = ConvAutoencoder(input_channels=N_CHANNELS, latent_dim=latent_dim)
conv_ae_model.to(device)
optimizer = AdamW(conv_ae_model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = MSELoss()
liveplot = LivePlot()
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train(model=conv_ae_model, train_loader=train_loader, dev_loader=dev_loader, optimizer=optimizer, criterion=criterion, max_epochs=10, device=device, liveplot=liveplot)
def visualized_reconstruction_dataset(dataset, model, n_examples=5, device=None):
    """Plot reconstructions (top row) and originals (bottom row) for the first `n_examples` images of `dataset`."""
    if device is None:
        device = next(model.parameters()).device
    imgs = torch.stack([dataset[i][0] for i in range(n_examples)]).to(device)
    model.eval()
    with torch.no_grad():
        recons = model(imgs).cpu()
    fig, axes = plt.subplots(2, n_examples, figsize=(n_examples * 3, 6), squeeze=False)
    for col_i in range(n_examples):
        axes[0, col_i].imshow(recons[col_i].numpy().transpose(1, 2, 0).clip(0, 1))
        axes[0, col_i].axis('off')
        axes[1, col_i].imshow(imgs[col_i].cpu().numpy().transpose(1, 2, 0))
        axes[1, col_i].set_title("Original")
        axes[1, col_i].axis('off')
    plt.suptitle(f"Reconstruction: {type(model).__name__} (latent dim {model.latent_dim})")
    plt.tight_layout()
    plt.show()
visualized_reconstruction_dataset(train_unlabeled, conv_ae_model, n_examples=5)
test_loader = DataLoader(test_data, batch_size=batch_size)
conv_ae_features = []
feature_labels = []
conv_ae_model.eval()
with torch.inference_mode():
    for batch in test_loader:
        x, y = batch
        feature_labels.extend(y.squeeze())
        embedded = conv_ae_model.encode(x.to(device))
        conv_ae_features.append(embedded.cpu().numpy())
conv_ae_features = np.concatenate(conv_ae_features)
visualize_latent_space(conv_ae_features, feature_labels, "Conv Autoencoder projections")
Transfer learning
Like the other models, we can use the trained convolutional encoder as a feature extractor for the held-out classes by attaching a linear classifier and training only that head.
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
batch_size = 128
train_loader = DataLoader(train_data_b, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
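# For evaluation, shuffling is unnecessary and drop_last would silently skip the final partial batch
dev_loader = DataLoader(dev_data_b, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=8)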
n_classes = train_data_b.get_n_classes()
conv_ae_model.reset_head(n_classes)
optimizer = AdamW(conv_ae_model.classifier.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = CrossEntropyLoss()
liveplot = LivePlot()
plot_confusion_matrix(conv_ae_model, dev_data_b, title="Confusion matrix — dev set (conv AE transfer - before training)")
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train(model=conv_ae_model, train_loader=train_loader, dev_loader=dev_loader, optimizer=optimizer, criterion=criterion, max_epochs=10, device=device, liveplot=liveplot)
plot_confusion_matrix(conv_ae_model, dev_data_b, title="Confusion matrix — dev set (conv AE transfer - after training)")
Exercises
Exercise: Freeze vs. fine-tune for the convolutional AE
The transfer above trains only the linear head while the convolutional encoder stays frozen (linear probing). Add a second run where you pass conv_ae_model.parameters() to the optimizer instead, fine-tuning the full model. With the small train_data_b transfer set:
- Does fine-tuning the encoder improve or hurt accuracy on dev_data_b?
- Does the model converge faster or slower?
This is a practical question in transfer learning: fine-tuning gives more capacity but risks overwriting representations learned during pre-training (catastrophic forgetting). Does that risk materialise here?
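One common variant for the fine-tuning run gives the pretrained encoder a smaller learning rate than the new head, using optimizer parameter groups. This is an assumption, not something the exercise requires. The sketch uses a hypothetical `TinyModel` that, like the lab's models, exposes its head as `classifier`.

```python
import torch.nn as nn
from torch.optim import AdamW

class TinyModel(nn.Module):
    """Hypothetical stand-in with the lab's layout: an encoder plus a `classifier` head."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(16, 8), nn.ReLU())
        self.classifier = nn.Linear(8, 3)

model = TinyModel()

# Split parameters into "head" and "everything else" by identity
head_params = list(model.classifier.parameters())
head_ids = {id(p) for p in head_params}
encoder_params = [p for p in model.parameters() if id(p) not in head_ids]

# Full fine-tuning: every parameter is optimised, the pretrained part with a smaller step size
optimizer = AdamW([
    {"params": encoder_params, "lr": 1e-5},
    {"params": head_params, "lr": 1e-4},
], weight_decay=1e-5)

for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"])
```

The same split works for `conv_ae_model`, since it also stores its head in `classifier`. Whether a reduced encoder learning rate helps on `train_data_b` is part of what the exercise asks you to find out.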
SimCLR
SimCLR is a self-supervised contrastive learning method that learns representations without any labels. The key idea is simple: for each image, generate two differently augmented views, then train the network to map both views of the same image close together in embedding space while pushing apart views from different images. The only source of learning signal is the augmentation — the model must learn invariances to crops, flips, color changes, and blur in order to produce consistent representations.
This is in contrast to the supervised contrastive approach we saw earlier, which used label information to define which examples should be attracted. SimCLR defines positives purely from the data itself, making it applicable to the unlabeled portion of our dataset.
Data augmentation
The augmentation strategy is central to SimCLR. If it is too weak, the task is trivially solved; if it is too strong, the positives become unrecognizable. For 28×28 medical images we use random resized crops, horizontal and vertical flips, color jitter, random grayscale conversion, and a small Gaussian blur.
import torchvision.transforms.v2 as transforms
from torch.utils.data import Dataset
simclr_augment = transforms.Compose([
transforms.RandomResizedCrop(28, scale=(0.5, 1.0), antialias=True),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=3),
])
class SimCLRDataset(Dataset):
    """Wraps a dataset and returns two independently augmented views of each image."""
    def __init__(self, dataset, augment):
        self.dataset = dataset
        self.augment = augment
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        x, _ = self.dataset[idx]  # the label is discarded: SimCLR is label-free
        # Each call samples fresh random augmentation parameters, so the two views differ
        return self.augment(x), self.augment(x)
simclr_train_data = SimCLRDataset(train_unlabeled, simclr_augment)
simclr_dev_data = SimCLRDataset(dev_unlabeled, simclr_augment)
NT-Xent loss
The NT-Xent (Normalized Temperature-scaled Cross Entropy) loss treats each pair of views from the same image as a positive pair and all other views in the batch as negatives. With a batch of \(B\) images we have \(2B\) views total, giving \(2(B-1)\) negatives per anchor.
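Define the following quantities:

- \(z_1, \dots, z_{2B}\) are the L2-normalised embeddings of all \(2B\) views.
- \(p(i)\) is the index of the other view of the same image as \(z_i\).
- \(\tau\) is the temperature.

With these, the per-anchor loss and its batch average are:

\[
\ell_i = -\log \frac{\exp\left(z_i^\top z_{p(i)} / \tau\right)}{\sum_{k=1,\, k \neq i}^{2B} \exp\left(z_i^\top z_k / \tau\right)},
\qquad
\mathcal{L} = \frac{1}{2B} \sum_{i=1}^{2B} \ell_i
\]

The denominator has \(2B - 1\) terms: the positive plus the \(2(B-1)\) negatives. In the code, the diagonal mask removes the \(k = i\) term. The default mean reduction of `F.cross_entropy` supplies the factor \(\frac{1}{2B}\).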
import torch
import torch.nn.functional as F
def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent loss from SimCLR (Chen et al. 2020).
    z1, z2: (B, D) L2-normalised embeddings from the two augmented views.
    The positive for view z1[i] is z2[i] and vice versa; all other
    2(B-1) views within the batch are treated as negatives.
    """
    B = z1.shape[0]
    # Concatenate both views: rows 0..B-1 are view-1, rows B..2B-1 are view-2
    z = torch.cat([z1, z2], dim=0)  # (2B, D)
    sim = (z @ z.T) / temperature  # (2B, 2B) cosine similarities / T
    # Mask the diagonal so self-similarity doesn't enter the denominator
    eye = torch.eye(2 * B, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(eye, float('-inf'))
    # Positive for row i is row i+B, and for row i+B is row i
    labels = torch.cat([
        torch.arange(B, 2 * B, device=z.device),
        torch.arange(0, B, device=z.device),
    ])
    return F.cross_entropy(sim, labels)
Model
The architecture mirrors ContrastiveConvNet: a shared convolutional encoder followed by a small projection MLP. The projection head is only used during self-supervised training; for transfer learning it is replaced by a linear classifier via reset_head.
import torch.nn as nn
import torch.nn.functional as F
class SimCLRConvNet(nn.Module):
    def __init__(self, input_channels, latent_dim=128, projection_dim=128):
        super().__init__()
        self.latent_dim = latent_dim
        self.conv_blocks = nn.ModuleList([
            ConvBlock(input_channels, 32),  # 28×28 → 14×14
            ConvBlock(32, 64),  # 14×14 → 7×7
            ConvBlock(64, 128),  # 7×7 → 3×3
        ])
        self.encoder_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 3 * 3, latent_dim),
            nn.ReLU(),
        )
        self.projection_head = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, projection_dim),
        )
        self.bn = nn.BatchNorm1d(projection_dim, affine=False)
        self.classifier = None
    def encode(self, x):
        for block in self.conv_blocks:
            x = block(x)
        return self.encoder_head(x)
    def reset_head(self, n_classes):
        device = next(self.parameters()).device
        self.classifier = nn.Linear(self.latent_dim, n_classes, device=device)
    def forward(self, x):
        if self.classifier is not None:
            return self.classifier(self.encode(x))
        projection = self.projection_head(self.encode(x))
        return F.normalize(self.bn(projection), dim=1)
Training
SimCLR’s training loop differs from the supervised case: each batch yields two views rather than (image, label), so we use a dedicated loop here.
def train_simclr(*, model, train_loader, dev_loader, optimizer, temperature=0.5,
                 max_epochs, device=None, liveplot=None):
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    for epoch in range(max_epochs):
        model.train()
        train_loss_acc, train_n = 0.0, 0
        for view1, view2 in train_loader:
            optimizer.zero_grad()
            view1, view2 = view1.to(device), view2.to(device)
            loss = nt_xent_loss(model(view1), model(view2), temperature)
            loss.backward()
            optimizer.step()
            # The loss is a per-batch mean: weight it by batch size so that dividing by train_n gives a per-sample average
            train_loss_acc += loss.item() * view1.size(0)
            train_n += view1.size(0)
        model.eval()
        with torch.no_grad():
            dev_loss_acc, dev_n = 0.0, 0
            for view1, view2 in dev_loader:
                view1, view2 = view1.to(device), view2.to(device)
                dev_loss_acc += nt_xent_loss(model(view1), model(view2), temperature).item() * view1.size(0)
                dev_n += view1.size(0)
        if liveplot is not None:
            liveplot.tick()
            liveplot.report("Training loss", train_loss_acc / train_n)
            liveplot.report("Development loss", dev_loss_acc / dev_n)
from torch.optim import AdamW
batch_size = 256
latent_dim = 128
train_loader = DataLoader(simclr_train_data, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
dev_loader = DataLoader(simclr_dev_data, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=8)
simclr_model = SimCLRConvNet(input_channels=N_CHANNELS, latent_dim=latent_dim)
simclr_model.to(device)
optimizer = AdamW(simclr_model.parameters(), lr=1e-4, weight_decay=1e-5)
liveplot = LivePlot()
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train_simclr(model=simclr_model, train_loader=train_loader, dev_loader=dev_loader,
             optimizer=optimizer, temperature=0.5, max_epochs=10, device=device, liveplot=liveplot)
Inspect representations
from torchvision.models.feature_extraction import get_graph_node_names
train_nodes, eval_nodes = get_graph_node_names(simclr_model)
print("Eval nodes:")
for node in eval_nodes:
    print(node)
from torchvision.models.feature_extraction import create_feature_extractor
feature_extractor = create_feature_extractor(
simclr_model,
return_nodes={
'encoder_head.0': 'flat_conv_features',
'encoder_head.2': 'encoder_features',
}
)
test_loader = DataLoader(test_data, batch_size=batch_size)
simclr_conv_features = []
simclr_encoder_features = []
feature_labels = []
feature_extractor.eval()
with torch.inference_mode():
    for x, y in test_loader:
        feature_labels.extend(y.view(-1).tolist())
        feats = feature_extractor(x.to(device))
        simclr_conv_features.append(feats['flat_conv_features'].cpu().numpy())
        simclr_encoder_features.append(feats['encoder_features'].cpu().numpy())
simclr_conv_features = np.concatenate(simclr_conv_features)
simclr_encoder_features = np.concatenate(simclr_encoder_features)
visualize_latent_space(simclr_conv_features, feature_labels, f"SimCLR — conv features (held-out labels {held_out_labels})")
visualize_latent_space(simclr_encoder_features, feature_labels, f"SimCLR — encoder features (held-out labels {held_out_labels})")
Transfer learning
As with the other models, we freeze the encoder and train only a freshly initialised linear head on the held-out labels.
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
batch_size = 128
train_loader = DataLoader(train_data_b, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
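# For evaluation, shuffling is unnecessary and drop_last would silently skip the final partial batch
dev_loader = DataLoader(dev_data_b, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=8)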
n_classes = train_data_b.get_n_classes()
simclr_model.reset_head(n_classes)
optimizer = AdamW(simclr_model.classifier.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = CrossEntropyLoss()
liveplot = LivePlot()
plot_confusion_matrix(simclr_model, dev_data_b, title="Confusion matrix — dev set (SimCLR transfer - before training)")
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train(model=simclr_model, train_loader=train_loader, dev_loader=dev_loader,
      optimizer=optimizer, criterion=criterion, max_epochs=10, device=device, liveplot=liveplot)
plot_confusion_matrix(simclr_model, dev_data_b, title="Confusion matrix — dev set (SimCLR transfer - after training)")
Exercises
Exercise 1: Augmentation ablation
The augmentation pipeline is the core design decision in SimCLR — it defines what invariances the model is forced to learn. Create a stripped version of simclr_augment that removes one transform at a time (try removing ColorJitter, then GaussianBlur, then both). For each, train a fresh SimCLRConvNet and compare UMAP structure and transfer accuracy. Which augmentations matter most for this medical imaging dataset?
Exercise 2: Temperature in NT-Xent
Try temperature=0.1, 0.5, and 2.0 in nt_xent_loss (pass it via train_simclr’s temperature argument). Very low temperatures make the loss concentrate on the hardest negatives and can cause instability. Very high temperatures make the softmax over similarities nearly uniform, so all negatives are weighted almost equally and the learning signal weakens. Can you observe these effects in the training loss curve? Which temperature produces the best transfer accuracy?
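The temperature's effect can be seen on a single anchor without any training. The sketch below uses made-up similarity values for illustration. For each \(\tau\), it reports two quantities:

- the entropy of the softmax over the positive and four negatives;
- the fraction of the negatives' probability mass that goes to the hardest negative.

```python
import torch
import torch.nn.functional as F

# Made-up cosine similarities for one anchor: index 0 is its positive, the rest are negatives
sims = torch.tensor([0.9, 0.7, 0.3, 0.0, -0.5])

results = {}
for tau in (0.1, 0.5, 2.0):
    probs = F.softmax(sims / tau, dim=0)
    entropy = -(probs * probs.log()).sum().item()
    negatives = probs[1:]
    # Fraction of the negatives' probability mass taken by the hardest negative (similarity 0.7)
    hardest_share = (negatives[0] / negatives.sum()).item()
    results[tau] = (entropy, hardest_share)
    print(f"tau={tau}: entropy={entropy:.3f}, hardest-negative share={hardest_share:.3f}")
```

At a low \(\tau\), almost all of the negative mass lands on the hardest negative. At a high \(\tau\), the distribution flattens. The gradient of the loss with respect to each similarity \(s_k\) is \((p_k - \mathbb{1}[k = \text{positive}])/\tau\), so a large \(\tau\) also scales the whole gradient down. Raw loss values are not directly comparable across temperatures, so compare runs on transfer accuracy rather than final loss.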
DINO
DINO (Self-DIstillation with NO labels) is a self-supervised method built around a student–teacher framework. Unlike SimCLR, which learns by contrasting different instances, DINO trains a student network to match the output distribution of a teacher network. The teacher is not trained with gradients; instead it is an exponential moving average (EMA) of the student — always a slightly smoother, more stable snapshot of the student’s weights.
Three ideas work together to prevent representational collapse:
- Multi-crop augmentation: the teacher sees only large (global) crops; the student sees both global and small (local) crops. This forces the student to predict global structure from limited local context.
- Centering: a running mean is subtracted from the teacher output before computing the loss, preventing all outputs from collapsing to a single prototype.
- Sharpening: the teacher uses a lower temperature than the student, providing confident pseudo-targets rather than a flat distribution.
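The sketch below shows centering and sharpening on synthetic teacher outputs in which one prototype dominates every sample, a collapse-prone situation constructed for illustration. It uses a single batch mean as the center, whereas the lab keeps an EMA of it across batches. Centering alone would push targets towards uniform, and sharpening alone towards a single prototype. DINO relies on the two pulling in opposite directions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K = 8  # number of prototypes (out_dim), kept small for illustration

# Synthetic teacher outputs in which prototype 0 dominates for every sample
teacher_out = 0.1 * torch.randn(64, K)
teacher_out[:, 0] += 1.0

def teacher_probs(logits, center, temp):
    # Teacher side of the DINO loss: softmax((t - center) / temp)
    return F.softmax((logits - center) / temp, dim=-1)

def mean_entropy(p):
    return -(p * p.clamp_min(1e-12).log()).sum(dim=-1).mean().item()

no_center = torch.zeros(1, K)
center = teacher_out.mean(dim=0, keepdim=True)

# Centering: without it, almost every sample's target peaks on prototype 0
share_raw = (teacher_probs(teacher_out, no_center, 0.04).argmax(-1) == 0).float().mean().item()
share_centered = (teacher_probs(teacher_out, center, 0.04).argmax(-1) == 0).float().mean().item()

# Sharpening: a lower teacher temperature gives peakier, lower-entropy targets
entropy_soft = mean_entropy(teacher_probs(teacher_out, center, 1.0))
entropy_sharp = mean_entropy(teacher_probs(teacher_out, center, 0.04))

print(f"share with argmax = prototype 0: raw {share_raw:.2f}, centered {share_centered:.2f}")
print(f"mean per-sample entropy: temp 1.0 -> {entropy_soft:.3f}, temp 0.04 -> {entropy_sharp:.3f}")
```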
Data augmentation
import copy
import torchvision.transforms.v2 as transforms
from torch.utils.data import Dataset
dino_global_augment = transforms.Compose([
transforms.RandomResizedCrop(28, scale=(0.4, 1.0), antialias=True),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=3),
])
dino_local_augment = transforms.Compose([
transforms.RandomResizedCrop(28, scale=(0.15, 0.4), antialias=True),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
transforms.RandomGrayscale(p=0.2),
])
class DINODataset(Dataset):
    """Returns 2 global views (for teacher + student) and n_local local views (student only)."""
    def __init__(self, dataset, global_augment, local_augment, n_local_crops=4):
        self.dataset = dataset
        self.global_augment = global_augment
        self.local_augment = local_augment
        self.n_local_crops = n_local_crops
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        x, _ = self.dataset[idx]
        global_views = torch.stack([self.global_augment(x) for _ in range(2)])
        local_views = torch.stack([self.local_augment(x) for _ in range(self.n_local_crops)])
        return global_views, local_views  # (2,C,H,W), (n_local,C,H,W)
dino_train_data = DINODataset(train_unlabeled, dino_global_augment, dino_local_augment, n_local_crops=4)
dino_dev_data = DINODataset(dev_unlabeled, dino_global_augment, dino_local_augment, n_local_crops=4)
DINO head
Before applying its final linear layer, the projection head L2-normalises two things: the output of its MLP and the prototype weight vectors (the rows of that last layer). Each output is therefore the cosine similarity between the feature vector and one prototype. Only the prototypes' directions are learned, not their scales. We implement this explicitly rather than using nn.utils.weight_norm, which creates non-leaf tensors that are incompatible with copy.deepcopy.
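The standalone check below isolates that final step in a hypothetical `CosinePrototypes` module. It has the same logic as the end of `DINOHead`, but it is not part of the lab's code. It shows three properties:

- rescaling either the inputs or the prototypes leaves the scores unchanged;
- the scores equal cosine similarities;
- the module deep-copies cleanly.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class CosinePrototypes(nn.Module):
    """Hypothetical stand-in for the last step of DINOHead: cosine scores against learned prototypes."""
    def __init__(self, dim, n_prototypes):
        super().__init__()
        self.last_layer = nn.Linear(dim, n_prototypes, bias=False)

    def forward(self, x):
        x = F.normalize(x, dim=-1)
        w = F.normalize(self.last_layer.weight, dim=-1)  # one unit-norm prototype per row
        return F.linear(x, w)

torch.manual_seed(0)
head = CosinePrototypes(dim=16, n_prototypes=4)
x = torch.randn(5, 16)
scores = head(x)

# Only directions matter: scaling inputs or prototypes leaves the scores unchanged
scaled_input_scores = head(3.0 * x)
with torch.no_grad():
    head.last_layer.weight.mul_(10.0)
scaled_proto_scores = head(x)

# Plain parameters (no weight_norm reparametrisation) deep-copy without issues, e.g. to create a teacher
teacher_head = copy.deepcopy(head)
print(scores.abs().max().item(), torch.allclose(scores, scaled_proto_scores, atol=1e-6))
```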
import torch.nn as nn
import torch.nn.functional as F
class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
        )
        self.last_layer = nn.Linear(hidden_dim, out_dim, bias=False)
    def forward(self, x):
        x = self.mlp(x)
        x = F.normalize(x, dim=-1)
        # Normalise prototype vectors so dot-product = cosine similarity
        w = F.normalize(self.last_layer.weight, dim=-1)
        return F.linear(x, w)
Model
class DINOConvNet(nn.Module):
    def __init__(self, input_channels, latent_dim=128, out_dim=256):
        super().__init__()
        self.latent_dim = latent_dim
        self.conv_blocks = nn.ModuleList([
            ConvBlock(input_channels, 32),  # 28×28 → 14×14
            ConvBlock(32, 64),  # 14×14 → 7×7
            ConvBlock(64, 128),  # 7×7 → 3×3
        ])
        self.encoder_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 3 * 3, latent_dim),
            nn.GELU(),
        )
        self.head = DINOHead(latent_dim, out_dim)
        self.classifier = None
    def encode(self, x):
        for block in self.conv_blocks:
            x = block(x)
        return self.encoder_head(x)
    def reset_head(self, n_classes):
        device = next(self.parameters()).device
        self.classifier = nn.Linear(self.latent_dim, n_classes, device=device)
    def forward(self, x):
        if self.classifier is not None:
            return self.classifier(self.encode(x))
        return self.head(self.encode(x))
Training
The DINO loss is cross-entropy between the teacher’s centered, sharpened distribution and the student’s log-probabilities. Same-view pairs (student global crop \(i\) vs teacher global crop \(i\)) are skipped — the student shouldn’t just copy the teacher on the identical input.
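In symbols, use the following notation:

- \(g_s\) and \(g_t\) are the student and teacher networks.
- \(c\) is the center.
- \(\tau_s\) is student_temp and \(\tau_t\) is teacher_temp.
- Views are indexed so that \(1, 2\) are the global crops and \(3, \dots, 2 + n_{\text{local}}\) are the local ones.

With this notation, dino_loss_step computes the following, with each cross-entropy term averaged over the batch:

\[
P_t^{(v)} = \operatorname{softmax}\!\left(\frac{g_t(x_v) - c}{\tau_t}\right),
\qquad
P_s^{(u)} = \operatorname{softmax}\!\left(\frac{g_s(x_u)}{\tau_s}\right)
\]

\[
\mathcal{L} = \frac{1}{|\mathcal{P}|} \sum_{(u, v) \in \mathcal{P}} \left( -\sum_{k=1}^{K} P_t^{(v)}[k] \, \log P_s^{(u)}[k] \right),
\qquad
\mathcal{P} = \{(u, v) : v \in \{1, 2\},\; u \in \{1, \dots, 2 + n_{\text{local}}\},\; u \neq v\}
\]

With the lab's n_local_crops=4, there are \(|\mathcal{P}| = 2 \times 6 - 2 = 10\) pairs, and \(K\) = out_dim = 256.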
def dino_loss_step(student_out, teacher_out, center, student_temp, teacher_temp):
    """Cross-entropy loss over all student/teacher view pairs, skipping same-index global pairs."""
    teacher_probs = [
        F.softmax((t - center) / teacher_temp, dim=-1).detach()
        for t in teacher_out
    ]
    loss, n_pairs = 0.0, 0
    for s_idx, s_logits in enumerate(student_out):
        s_log_probs = F.log_softmax(s_logits / student_temp, dim=-1)
        for t_idx, t_probs in enumerate(teacher_probs):
            if s_idx == t_idx:
                continue  # skip same-view pair
            loss += -(t_probs * s_log_probs).sum(dim=-1).mean()
            n_pairs += 1
    return loss / n_pairs
def train_dino(*, student, teacher, train_loader, dev_loader, optimizer,
               out_dim, student_temp=0.1, teacher_temp=0.04,
               teacher_momentum=0.996, max_epochs, device=None, liveplot=None):
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    student.to(device)
    teacher.to(device)
    for p in teacher.parameters():
        p.requires_grad = False
    center = torch.zeros(1, out_dim, device=device)
    for epoch in range(max_epochs):
        student.train()
        teacher.eval()
        train_loss_acc, train_n = 0.0, 0
        for global_views, local_views in train_loader:
            optimizer.zero_grad()
            n_global = global_views.shape[1]
            n_local = local_views.shape[1]
            # Teacher: global crops only, no gradients
            with torch.no_grad():
                teacher_out = [teacher(global_views[:, i].to(device)) for i in range(n_global)]
            # Student: all crops concatenated into one forward pass
            all_views = torch.cat(
                [global_views[:, i] for i in range(n_global)] +
                [local_views[:, j] for j in range(n_local)],
                dim=0,
            ).to(device)
            student_out = student(all_views).chunk(n_global + n_local)
            loss = dino_loss_step(student_out, teacher_out, center, student_temp, teacher_temp)
            loss.backward()
            nn.utils.clip_grad_norm_(student.parameters(), 3.0)
            optimizer.step()
            # EMA update of teacher weights
            with torch.no_grad():
                for s_p, t_p in zip(student.parameters(), teacher.parameters()):
                    t_p.data.mul_(teacher_momentum).add_((1 - teacher_momentum) * s_p.data)
                # Buffers (e.g. BatchNorm running statistics, if ConvBlock has any) are not parameters;
                # copy them so the eval-mode teacher does not keep its initial statistics
                for s_b, t_b in zip(student.buffers(), teacher.buffers()):
                    t_b.copy_(s_b)
            # Center update (EMA of teacher batch outputs)
            with torch.no_grad():
                batch_center = torch.cat(teacher_out, dim=0).mean(0, keepdim=True)
                center.mul_(0.9).add_(0.1 * batch_center)
            # Weight the per-batch mean loss by batch size to get a per-sample average
            train_loss_acc += loss.item() * global_views.shape[0]
            train_n += global_views.shape[0]
        student.eval()
        with torch.no_grad():
            dev_loss_acc, dev_n = 0.0, 0
            for global_views, local_views in dev_loader:
                n_global = global_views.shape[1]
                n_local = local_views.shape[1]
                teacher_out = [teacher(global_views[:, i].to(device)) for i in range(n_global)]
                all_views = torch.cat(
                    [global_views[:, i] for i in range(n_global)] +
                    [local_views[:, j] for j in range(n_local)], dim=0,
                ).to(device)
                student_out = student(all_views).chunk(n_global + n_local)
                dev_loss_acc += dino_loss_step(
                    student_out, teacher_out, center, student_temp, teacher_temp
                ).item() * global_views.shape[0]
                dev_n += global_views.shape[0]
        if liveplot is not None:
            liveplot.tick()
            liveplot.report("Training loss", train_loss_acc / train_n)
            liveplot.report("Development loss", dev_loss_acc / dev_n)
from torch.optim import AdamW
batch_size = 256
latent_dim = 128
out_dim = 256
train_loader = DataLoader(dino_train_data, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
dev_loader = DataLoader(dino_dev_data, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=8)
dino_student = DINOConvNet(input_channels=N_CHANNELS, latent_dim=latent_dim, out_dim=out_dim)
dino_teacher = copy.deepcopy(dino_student)
dino_student.to(device)
dino_teacher.to(device)
optimizer = AdamW(dino_student.parameters(), lr=1e-4, weight_decay=1e-5)
liveplot = LivePlot()
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train_dino(student=dino_student, teacher=dino_teacher, train_loader=train_loader, dev_loader=dev_loader,
           optimizer=optimizer, out_dim=out_dim, student_temp=0.1, teacher_temp=0.04,
           teacher_momentum=0.996, max_epochs=10, device=device, liveplot=liveplot)
Inspect representations
from torchvision.models.feature_extraction import create_feature_extractor
feature_extractor = create_feature_extractor(
dino_student,
return_nodes={
'encoder_head.0': 'flat_conv_features',
'encoder_head.2': 'encoder_features',
}
)
test_loader = DataLoader(test_data, batch_size=batch_size)
dino_conv_features = []
dino_encoder_features = []
feature_labels = []
feature_extractor.eval()
with torch.inference_mode():
    for x, y in test_loader:
        feature_labels.extend(y.view(-1).tolist())
        feats = feature_extractor(x.to(device))
        dino_conv_features.append(feats['flat_conv_features'].cpu().numpy())
        dino_encoder_features.append(feats['encoder_features'].cpu().numpy())
dino_conv_features = np.concatenate(dino_conv_features)
dino_encoder_features = np.concatenate(dino_encoder_features)
visualize_latent_space(dino_conv_features, feature_labels, f"DINO — conv features (held-out labels {held_out_labels})")
visualize_latent_space(dino_encoder_features, feature_labels, f"DINO — encoder features (held-out labels {held_out_labels})")
Transfer learning
For transfer learning we use the student encoder. Like the other models, we freeze the encoder and train only a freshly initialised linear classifier on the held-out labels.
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
batch_size = 128
train_loader = DataLoader(train_data_b, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
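# For evaluation, shuffling is unnecessary and drop_last would silently skip the final partial batch
dev_loader = DataLoader(dev_data_b, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=8)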
n_classes = train_data_b.get_n_classes()
dino_student.reset_head(n_classes)
optimizer = AdamW(dino_student.classifier.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = CrossEntropyLoss()
liveplot = LivePlot()
plot_confusion_matrix(dino_student, dev_data_b, title="Confusion matrix — dev set (DINO transfer - before training)")
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train(model=dino_student, train_loader=train_loader, dev_loader=dev_loader,
      optimizer=optimizer, criterion=criterion, max_epochs=10, device=device, liveplot=liveplot)
plot_confusion_matrix(dino_student, dev_data_b, title="Confusion matrix — dev set (DINO transfer - after training)")
Exercises
Exercise 1: Teacher momentum
The teacher_momentum parameter controls how quickly the teacher tracks the student. Try teacher_momentum=0.9 (fast tracking) and 0.9999 (very slow tracking), alongside the default 0.996. Lower momentum means the teacher is barely more stable than the student — does this destabilise training? Higher momentum means the teacher barely changes — does the student still get a useful learning signal early in training?
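Momentum values are easier to compare as time scales. With momentum \(m\), a snapshot of the weights is down-weighted by a factor of \(m\) on every later update. Its influence on the teacher therefore halves after \(n_{1/2} = \ln(1/2) / \ln(m)\) optimizer steps. The helper below is a small illustration, not part of the lab's code.

```python
import math

def ema_half_life(momentum):
    """Number of EMA updates after which a past snapshot's weight in the teacher has halved."""
    return math.log(0.5) / math.log(momentum)

for m in (0.9, 0.996, 0.9999):
    print(f"teacher_momentum={m}: half-life ≈ {ema_half_life(m):.1f} optimizer steps")

# Sanity check on a scalar: a teacher starting at 0 that tracks a constant student value of 1
# has covered half the distance after roughly one half-life
teacher_value = 0.0
for _ in range(round(ema_half_life(0.996))):
    teacher_value = 0.996 * teacher_value + (1 - 0.996) * 1.0
print(f"teacher value after one half-life: {teacher_value:.3f}")
```

Compare these numbers with the number of optimizer steps per epoch, `len(train_loader)`. If the half-life is much longer than the whole training run, the teacher stays close to its random initialisation.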
Exercise 2: Student vs. teacher features
After training, the teacher is an EMA-smoothed version of the student. Extract encode features from both dino_student and dino_teacher on test_data and run visualize_latent_space on each. Does the teacher produce cleaner cluster separation than the student? What does this tell you about the role of EMA in representation quality?
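A small reusable helper avoids duplicating the extraction loop for the two models. `extract_encoder_features` and `ToyEncoder` are hypothetical names. The toy model and random data only exist to make the sketch runnable; any module with an `encode` method works.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def extract_encoder_features(model, loader, device):
    """Collect model.encode(x) for every batch, plus the labels flattened to plain ints."""
    model.eval()
    feats, labels = [], []
    with torch.inference_mode():
        for x, y in loader:
            feats.append(model.encode(x.to(device)).cpu().numpy())
            labels.extend(y.view(-1).tolist())
    return np.concatenate(feats), labels

class ToyEncoder(nn.Module):
    """Hypothetical stand-in exposing the lab's `encode` interface."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 8))

    def encode(self, x):
        return self.net(x)

device = torch.device("cpu")
data = TensorDataset(torch.rand(10, 1, 28, 28), torch.randint(0, 3, (10, 1)))
loader = DataLoader(data, batch_size=4)
features, labels = extract_encoder_features(ToyEncoder(), loader, device)
print(features.shape, len(labels))
```

In the lab, you would call it once with `dino_student` and once with `dino_teacher` on `test_loader`. Then pass each result to `visualize_latent_space`.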
Multimodal learning - CLIP
CLIP (Contrastive Language–Image Pre-training) was trained by OpenAI on 400 million image–text pairs scraped from the internet. Like SimCLR, it uses contrastive learning — but across two modalities: images and text. The model learns to pull together the embedding of an image and its matching caption, while pushing apart mismatched pairs. The result is a visual encoder that has seen an enormous breadth of concepts, even if it has never seen a single explicit class label.
We will not train CLIP here. Instead, we import the pretrained ViT-B/32 weights and use the visual encoder directly as a feature extractor. The interesting question is whether representations learned from natural image–text pairs on the internet transfer to our medical imaging domain.
The CLIPWrapper class
The wrapper passes each image through open_clip's validation transform. That transform does three things: it resizes the image to the 224×224 resolution CLIP was trained at, converts it to RGB (repeating the channel for grayscale images), and applies CLIP's expected pixel normalisation. The wrapper then runs the frozen visual encoder. Otherwise it exposes the same interface as every other model in this lab: encode, reset_head, classifier, and forward.
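The PIL round trip in `_preprocess` runs one image at a time on the CPU. The alternative below is a batched, tensor-only sketch of similar preprocessing; it is not what the wrapper does. It assumes square inputs with values in [0, 1] and uses the per-channel mean and standard deviation published with OpenAI's CLIP. Its output is numerically close to, but not identical with, open_clip's PIL-based bicubic resize.

```python
import torch
import torch.nn.functional as F

# Per-channel normalisation constants published with OpenAI's CLIP
CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)

def clip_preprocess_tensor(x, size=224):
    """Sketch: (B, 1 or 3, H, W) square images in [0, 1] -> (B, 3, size, size) CLIP-normalised."""
    if x.shape[1] == 1:
        x = x.repeat(1, 3, 1, 1)  # grayscale -> RGB by repeating the channel
    x = F.interpolate(x, size=(size, size), mode="bicubic", align_corners=False)
    x = x.clamp(0, 1)  # bicubic upsampling can overshoot the [0, 1] range slightly
    return (x - CLIP_MEAN.to(x.device)) / CLIP_STD.to(x.device)

batch = torch.rand(4, 1, 28, 28)
print(clip_preprocess_tensor(batch).shape)
```

Because it stays on the input's device, a function like this could replace `_preprocess` in a GPU-heavy setting. It skips open_clip's center crop, which is a no-op for square inputs after resizing.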
import open_clip
open_clip.list_pretrained()
import open_clip
import torch.nn.functional as F
from torchvision.transforms.functional import to_pil_image
class CLIPWrapper(nn.Module):
    """Frozen CLIP visual encoder wrapped in the lab's common interface."""
    def __init__(self, model_name='ViT-B-32', pretrained='openai', latent_dim=512):
        super().__init__()
        self.latent_dim = latent_dim
        # Build the requested architecture together with its matching validation transform
        self.model, _, self.preprocess_val = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
        self.model.eval()  # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
        self.tokenizer = open_clip.get_tokenizer(model_name)
        # Freeze all CLIP parameters: we are only using it as a feature extractor
        for p in self.model.parameters():
            p.requires_grad = False
        self.classifier = None
    def train(self, mode=True):
        # Keep the frozen CLIP backbone in eval mode even when a training loop calls model.train()
        super().train(mode)
        self.model.eval()
        return self
    def _preprocess(self, x):
        # The open_clip transforms expect PIL images, so convert each tensor first (on the CPU).
        # to_pil_image scales float tensors by 255, so pixel values are assumed to lie in [0, 1].
        preprocessed = [self.preprocess_val(to_pil_image(img.cpu())) for img in x]
        device = next(self.parameters()).device
        imgs = torch.stack(preprocessed).to(device)
        return imgs
    def encode(self, x):
        x = self._preprocess(x)
        return self.model.encode_image(x)
    def reset_head(self, n_classes):
        device = next(self.parameters()).device
        self.classifier = nn.Linear(self.latent_dim, n_classes, device=device)
    def forward(self, x):
        if self.classifier is not None:
            return self.classifier(self.encode(x))
        return self.encode(x)
clip_model = CLIPWrapper(model_name='ViT-B-32', pretrained='openai', latent_dim=512)
clip_model.to(device)
clip_model.eval()
print(f"CLIP encoder ready — latent dim: {clip_model.latent_dim}")
Inspect representations
CLIP was never shown a label for any of our classes. The UMAP below reveals whether the structure it learned from natural image captions accidentally captures structure that is meaningful in our domain.
test_loader = DataLoader(test_data, batch_size=64, num_workers=8)
clip_features = []
clip_labels = []
with torch.inference_mode():
    for x, y in test_loader:
        # encode() moves the preprocessed batch to the model's device itself
        clip_features.append(clip_model.encode(x).cpu().numpy())
        clip_labels.extend(y.view(-1).tolist())
clip_features = np.concatenate(clip_features)
visualize_latent_space(clip_features, clip_labels, f"CLIP ViT-B/32 — zero-shot features (held-out labels {held_out_labels})")
Transfer learning
Even though CLIP’s encoder is frozen and was never trained on our data, we can attach a small linear classifier on top and train it with the held-out labels. This is called linear probing and is a standard way to evaluate the quality of pretrained representations.
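Because the CLIP encoder is frozen, its features never change between epochs. An alternative is therefore to extract them once and fit a linear model directly, for example with scikit-learn's `LogisticRegression`. This is not what the cells below do. The sketch uses synthetic stand-in features; with real data, `X` would come from `clip_model.encode` on `train_data_b` and `dev_data_b`.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Synthetic stand-ins for frozen 512-d features of two classes
n, d = 200, 512
class_means = rng.normal(size=(2, d))
y = rng.integers(0, 2, size=n)
X = class_means[y] + rng.normal(scale=2.0, size=(n, d))

X_train, X_dev, y_train, y_dev = train_test_split(X, y, test_size=0.25, random_state=0, stratify=y)

# A linear probe: feature standardisation followed by a linear classifier
probe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
probe.fit(X_train, y_train)
print(f"dev accuracy: {probe.score(X_dev, y_dev):.3f}")
```

Both approaches train a linear classifier on frozen features. Precomputing the features simply avoids running the 224×224 ViT forward pass again on every epoch.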
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
batch_size = 128
train_loader = DataLoader(train_data_b, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8)
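# For evaluation, shuffling is unnecessary and drop_last would silently skip the final partial batch
dev_loader = DataLoader(dev_data_b, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=8)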
n_classes = train_data_b.get_n_classes()
clip_model.reset_head(n_classes)
optimizer = AdamW(clip_model.classifier.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = CrossEntropyLoss()
liveplot = LivePlot()
plot_confusion_matrix(clip_model, dev_data_b, title="Confusion matrix — dev set (CLIP linear probe - before training)")
# Each time you run this cell, the model will continue training `max_epochs` from its previous state, so you can experiment with training for different numbers of epochs by running this cell multiple times.
train(model=clip_model, train_loader=train_loader, dev_loader=dev_loader,
      optimizer=optimizer, criterion=criterion, max_epochs=10, device=device, liveplot=liveplot)
plot_confusion_matrix(clip_model, dev_data_b, title="Confusion matrix — dev set (CLIP linear probe - after training)")
Discussion topics
- Which self-supervised method comes closest to the supervised CNN on the transfer task?
- Does CLIP’s zero-shot representation outperform models trained on the same domain without labels? Why might that be surprising — or unsurprising?
- Where would you invest effort if you had more unlabeled data but labeling was expensive?
Summary and key takeaways
This lab explored how neural networks learn to represent data, moving from classical methods to modern self-supervised approaches. Here are the main ideas to take away.
Representations are shaped by the training objective
The loss function determines what the encoder learns to preserve. A cross-entropy classifier learns features that separate known classes and nothing else. A reconstruction loss preserves everything needed to rebuild the input. A contrastive loss clusters by similarity. None of these is universally best — the right choice depends on the downstream task.
Supervised representations transfer, but are brittle
A CNN trained on labeled data learns strong features for its training classes. Those features transfer reasonably well to related tasks, but degrade when the transfer domain diverges. The model has no incentive to learn structure beyond what the labels require.
Self-supervised learning can match supervised learning with enough unlabeled data
Methods like SimCLR and DINO learn from data structure alone — no labels required. With sufficient unlabeled data and well-chosen augmentations they produce representations competitive with supervised ones on transfer tasks.
Augmentation design is the core inductive bias in contrastive learning
In SimCLR and DINO, the augmentation pipeline defines what the model is trained to be invariant to. Augmenting with color jitter encourages color-invariant features; leaving it out leaves color as a free dimension the encoder may use. Choosing augmentations that match the invariances relevant to your downstream task is one of the most important practical decisions.
The projection head is for training; the encoder is for transfer
Both SimCLR and DINO attach a projection head on top of the encoder during self-supervised training, then discard it at transfer time. The projection head specialises to the self-supervised objective, while the encoder beneath learns more general features. As a result, evaluating representations at the projection layer typically underperforms evaluating at the encoder.
Scale and pretraining distribution matter
CLIP suggests that representations trained on a very large, diverse dataset can transfer to narrow domains without any domain-specific pretraining; here, only a linear head is trained on our data. Breadth of pretraining can compensate for source-target mismatch, up to a point. When the target domain is highly specialised (e.g., histopathology), domain-specific self-supervised pretraining often outperforms generic large-scale models.
Linear probing is the standard measure of representation quality
Throughout this lab, we evaluated representations by keeping the encoder fixed and training only a linear classifier head. This linear probe protocol is a standard benchmark. If a linear classifier can solve the task on top of frozen features, the information the task needs is linearly accessible in those features. Fine-tuning often achieves higher accuracy, but it can mask a poor encoder by compensating with gradient updates, which makes it harder to compare representations fairly.