Build and Train Vision Transformer from Scratch

Mikhail Kravets
Published in Towards AI · 19 min read · Apr 3, 2023

Preview image of the tutorial: two robots with neon eyes
Image by author

A few years ago, it was hard to imagine what a transformer was; today, it is hard to imagine a modern neural network that doesn't use transformers.

In this tutorial, we'll build a Vision Transformer using PyTorch and PyTorch Lightning. Along with the ViT model itself, you will also see how to organize your code in a well-structured and efficient manner.

All the code of the tutorial can be found in the vision_transformer repository.

Overview

Let’s have a quick theory overview before we proceed to the practical part of the tutorial.

Transformer & self-attention

The history of transformers began with the Attention Is All You Need paper. Initially, they were used for machine translation but were later adapted to many other tasks. Jay Alammar explains transformers in his detailed article, The Illustrated Transformer.

In the diagram below, you may see the architecture of the transformer network for the machine translation task.

Architecture diagram of machine translation transformer model
Fig 1. Transformer architecture for machine translation. Image by author.

The transformer consists of Encoder and Decoder blocks; we only need encoders for the vision transformer model.
Both the Encoder and the Decoder are based on a mechanism called self-attention.

Architecture diagram of one Encoder block
Fig 2. Encoder architecture. Image by author.

The Multi-Head Attention block calculates an importance (or attention) score for each element in a sequence. For example, let's take the sentence

The animal didn’t cross the street because it was too tired. [5].

One attention score vector for the element `it` may look like

Visualization of self-attention mechanism for a word in a sentence
Fig 3. Self-attention example for one element in a sequence. Illustrated Transformer.

Encoder and Decoder blocks are almost identical, with one important difference. An Encoder can attend to all elements in a sequence when calculating attention scores; you may see encoder attention in figure 3. The BERT model is an example of an encoder-only architecture. Decoders, on the contrary, can attend only to the previous elements in a sequence when calculating attention scores. The GPT model, for instance, is a decoder-only model.
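To make this difference concrete, here is a small, hedged sketch (not part of the tutorial's repository) that contrasts decoder-style causal masking with the unmasked encoder-style attention we'll use for ViT:

import torch
import torch.nn.functional as F

seq_len = 5
scores = torch.rand(seq_len, seq_len)  # raw attention scores for a 5-element sequence

# Decoder-style (causal) attention: future positions are hidden with -inf before softmax
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
decoder_attn = F.softmax(scores.masked_fill(causal_mask, float("-inf")), dim=-1)

# Encoder-style attention: every element attends to every other element
encoder_attn = F.softmax(scores, dim=-1)

print(decoder_attn[0])  # only the first weight is non-zero for the first element
print(encoder_attn[0])  # all five weights are non-zero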

Vision Transformer

As already mentioned above, we can use transformers for image classification tasks. The main difference between a Vision Transformer and an NLP transformer is the special embedding operation we apply to the images.

Architecture diagram of Vision Transformer model
Fig 4. Vision Transformer architecture. [Dosovitskiy et al., 2021].

The image embedding begins with image preprocessing. The image should be split into 2D patches, as shown in figure 4. The resulting number of patches N can be calculated as

N = (H × W) / P²

where H is the image height, W is the image width, and P is the patch size.

After we've got the image patches, we flatten each patch into a 1D vector. The size M of a flattened patch vector can be calculated as

M = P² × C

where C is the number of color channels (C = 3 for RGB).
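As a quick sanity check with the CIFAR10 numbers used later in the tutorial (32 x 32 images, patch size 4, 3 channels):

H, W, P, C = 32, 32, 4, 3   # CIFAR10 image size, patch size, color channels

N = (H * W) // P ** 2       # number of patches
M = P ** 2 * C              # length of one flattened patch

print(N, M)                 # 64 48 -- the N x M matrix we feed to the model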

After the patch transformation is done, the image is represented as a matrix of size N x M. This matrix is the input tensor fed to the model. The input tensor goes through a linear projection, is concatenated with the [class] token parameter, and is summed with a learnable position embedding. The authors discuss position embedding in Appendix D.4 of the original paper.

After processing the input tensor with embeddings, there is a standard set of encoders with a classification head at the end.

Let’s move to the code 🎸.

Installation

As mentioned above, the vision_transformer repository contains the full working project, but installing the dependencies may be tricky.

CPU or MPS

If you plan to train the network on a CPU or you're using an Apple computer, you may proceed with the standard installation flow, i.e., install the packages from the requirements.txt file:

pip install -r requirements.txt

CUDA

If you’re going to train the network on GPU, you should install PyTorch with the command from the official website. Then you have to install the remaining packages from requirements-cuda.txt. The installation instructions may look like this:

pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements-cuda.txt

Dataset

We use the CIFAR10 dataset, which consists of 60,000 images of 10 classes (50k for training and 10k for validation/testing). The size of a single image is 32 x 32 with 3 RGB color channels.

Full code that manages datasets can be found at src/dataset.py.

The first thing we should do is import the required objects and define constants:

from pathlib import Path
import torch
import pytorch_lightning as pl

from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms, AutoAugment, AutoAugmentPolicy

BASE_DIR = Path(__file__).parent.parent

The CIFAR10 dataset is already implemented in the torchvision package that ships alongside PyTorch, so we use it directly from torchvision.datasets.

Lightning Data Module

We wrap the CIFAR10 dataset in a pytorch_lightning.LightningDataModule. A data module simplifies the usage of datasets and, especially, data loaders during training.

Let's look at the full code of the data module first; then we'll skim through it step by step.

class CIFAR10DataModule(pl.LightningDataModule):

    def __init__(self, batch_size: int, patch_size: int = 4, val_batch_size: int = 16):
        super().__init__()

        # CIFAR10 images are 32 x 32; rotations are sampled from -30 to 30 degrees
        im_size = 32
        rotation_degrees = 30

        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.train_transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(size=(im_size, im_size)),
                transforms.RandomRotation(degrees=rotation_degrees),
                AutoAugment(AutoAugmentPolicy.CIFAR10),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
                PatchifyTransform(patch_size)
            ]
        )
        self.val_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
                PatchifyTransform(patch_size)
            ]
        )
        self.patch_size = patch_size
        self.ds_train = None
        self.ds_val = None

    def prepare_data(self) -> None:
        CIFAR10(BASE_DIR.joinpath('data/cifar'), train=True, transform=self.train_transform, download=True)
        CIFAR10(BASE_DIR.joinpath('data/cifar'), train=False, transform=self.val_transform, download=True)

    def setup(self, stage: str) -> None:
        self.ds_train = CIFAR10(BASE_DIR.joinpath('data/cifar'), train=True, transform=self.train_transform)
        self.ds_val = CIFAR10(BASE_DIR.joinpath('data/cifar'), train=False, transform=self.val_transform)

    def train_dataloader(self):
        # The dataset is small, so we don't need multiprocessing here
        return DataLoader(self.ds_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=self.val_batch_size)

    @property
    def classes(self):
        return 10  # CIFAR10 has 10 possible classes

In the code above, we create a class CIFAR10DataModule that inherits from LightningDataModule. The main intent of the data module is to create data loaders, not datasets; this is why we pass batch sizes to the constructor.

LightningDataModule has several methods to override:

  • the prepare_data method is called from a single process, so the downloaded data cannot be corrupted by concurrent writes. It runs before training, so we use it to download the CIFAR10 data into a local directory;
  • the setup method is called after prepare_data. Here we instantiate the train and validation datasets;
  • train_dataloader returns the data loader for training;
  • val_dataloader returns the data loader for validation;
  • classes is just a property that returns the number of classes.
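Below is a small, hedged usage sketch (not part of train.py) showing how the data module could be exercised on its own to inspect a training batch; the shapes assume patch_size=4:

dm = CIFAR10DataModule(batch_size=8, patch_size=4)
dm.prepare_data()            # downloads CIFAR10 into data/cifar
dm.setup("fit")              # instantiates the train and validation datasets

images, targets = next(iter(dm.train_dataloader()))
print(images.shape)          # torch.Size([8, 64, 48]) -- 64 patches of 48 values each
print(targets.shape)         # torch.Size([8])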

Transforms

A transform is an operation applied to an image before it is passed to the model. We define two sets of transforms: one for training and one for validation. The training transform looks like this:

self.train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(size=(im_size, im_size)),
        transforms.RandomRotation(degrees=rotation_degrees),
        AutoAugment(AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        PatchifyTransform(patch_size)
    ]
)

CIFAR10 has only 50,000 training images, which is a relatively small dataset for neural networks, so we use data augmentation to expand it. The first four transforms do exactly that:

  • RandomHorizontalFlip reflects the image horizontally with a 50% probability (by default);
  • RandomResizedCrop crops a random portion of the image and resizes it to the given size (in our case 32 x 32);
  • RandomRotation rotates the image by a random angle within the given range (in our case from -30 to 30 degrees);
  • AutoAugment is a pre-trained set of auto-augmentation policies, described in AutoAugment: Learning Augmentation Strategies from Data. We use the policy pre-trained on CIFAR10.

Then we convert the image to a tensor of size 3 x 32 x 32, normalize it with the CIFAR10 mean and std values, and split the image tensor into patches.

Patchify Image

Image patching is, in my opinion, the hardest part of the project. In order to have well-structured code, I move image patching to its own transform class:

class PatchifyTransform:

    def __init__(self, patch_size):
        self.patch_size = patch_size

    def __call__(self, img: torch.Tensor):
        res = img.unfold(1, self.patch_size, self.patch_size)   # 3 x 8 x 32 x 4
        res = res.unfold(2, self.patch_size, self.patch_size)   # 3 x 8 x 8 x 4 x 4

        return res.reshape(-1, self.patch_size * self.patch_size * 3)  # -1 x 48 == 64 x 48

PyTorch tensors have a method called unfold(dimension, size, step) that does exactly what we need. It creates a sliding window along the given dimension and unfolds it into a new dimension. Let's take apart the __call__ method. Its first line:

res = img.unfold(1, self.patch_size, self.patch_size)

The tensor img has size 3 x 32 x 32, and self.patch_size equals 4. The unfold method goes through all patches of size 4 with step 4 along dimension 1 (which has 32 elements) and puts the found patches into a new dimension.

So, now we have a new tensor res of size 3 x 8 x 32 x 4. For easier understanding, I follow this logic:

  • Discard the color dimension for a moment. The tensor is 8 x 32 x 4;
  • This tensor can be seen as an 8 x 32 matrix of four-element vectors, for example, [0.3, 0.01, 0.4, 0.7];
  • Now add the color dimension back: there is an 8 x 32 grid of four-element vectors for each RGB color channel.

In the second line, we unfold the second dimension of the image:

res = res.unfold(2, self.patch_size, self.patch_size)

Now the size of res is 3 x 8 x 8 x 4 x 4, which may seem insane, but it isn't. You may understand it as follows:

  • Discard the color dimension for now, leaving an 8 x 8 x 4 x 4 tensor;
  • Each element of the 8 x 8 matrix contains a 4 x 4 patch;
  • Bring back the color dimension, and you have three 8 x 8 matrices of 4 x 4 patches.

Visual representation of unfolded image tensor
Fig 5. Visual representation of the two unfolded dimensions. Each element of the 8 x 8 matrix is a 4 x 4 matrix. Image by author.

What remains is to reshape the tensor res back into a 2D matrix. This is achieved with the reshape method:

res.reshape(-1, self.patch_size * self.patch_size * 3)

After the reshape operation is done, res is a matrix of size 64 x 48. If you have worked with NLP tasks, you may notice how similar it is to a sentence embedding: the first dimension corresponds to a word, and the second dimension corresponds to the context vector of that word.
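Here is a small, hedged sanity check of the transform (the random tensor stands in for a real CIFAR10 image):

img = torch.rand(3, 32, 32)             # a fake CIFAR10-sized image, C x H x W

step1 = img.unfold(1, 4, 4)             # torch.Size([3, 8, 32, 4])
step2 = step1.unfold(2, 4, 4)           # torch.Size([3, 8, 8, 4, 4])
patches = PatchifyTransform(4)(img)     # torch.Size([64, 48])

print(step1.shape, step2.shape, patches.shape)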

Model

We build a model as a set of standard PyTorch modules, except the main ViT module. ViT inherits LightningModule. The model diagram is displayed in the figure below.

The full code of the model can be found at src/models/basic.py.

PlantUML diagram of model classes
Fig 6. Vision Transformer diagram. Image by author.

Let’s take a detailed look at each module.

InputEmbedding

Input Embedding accepts a batch of patchified images and returns a full embedding of patches with the [class] token prepended.

class ImageEmbedding(nn.Module):

    def __init__(self, size: int, hidden_size: int, num_patches: int, dropout: float = 0.2):
        super().__init__()

        self.projection = nn.Linear(size, hidden_size)
        self.class_token = nn.Parameter(torch.rand(1, hidden_size))
        self.position = nn.Parameter(torch.rand(1, num_patches + 1, hidden_size))

        self.dropout = nn.Dropout(dropout)

    def forward(self, inp: torch.Tensor):
        res = self.projection(inp)

        class_token = self.class_token.repeat(res.size(0), 1, 1)  # batch_size x 1 x hidden_size
        res = torch.concat([class_token, res], dim=1)

        position = self.position.repeat(res.size(0), 1, 1)
        return self.dropout(res + position)

An interesting thing to look at is how we create class_token and position embedding parameters:

self.class_token = nn.Parameter(torch.rand(1, hidden_size))
self.position = nn.Parameter(torch.rand(1, num_patches + 1, hidden_size))

Tensors created via nn.Parameter are registered as model parameters and trained during the fitting process.

class_token is a tensor of size 1 x hidden_size, which we later repeat for every element in a batch. The position tensor has size 1 x (num_patches + 1) x hidden_size and is also repeated for every element in a batch. Its second dimension is num_patches + 1 because the [class] token gets a position embedding as well.

Let’s take a closer look at the forward method:

def forward(self, inp: torch.Tensor):
    res = self.projection(inp)

First of all, we accept an inp tensor of size batch_size x 64 x 48 and pass it through the linear projection layer.

The context size of 48 in the input tensor is too small; the model would barely be able to capture dependencies between the input and the target. So, we expand it to hidden_size. The projection res has size batch_size x 64 x hidden_size.

In the next operation, we repeat the class_token parameter for each element in the batch and concatenate it with res:

class_token = self.class_token.repeat(res.size(0), 1, 1)
res = torch.concat([class_token, res], dim=1)

The size of the repeated class_token tensor is batch_size x 1 x hidden_size.

position = self.position.repeat(res.size(0), 1, 1)  
return self.dropout(res + position)

The operations above repeat the position tensor for each image in the batch, sum res with position, and pass the result through the dropout layer.
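A quick, hedged shape check of the module, using the sizes we'll work with later (64 patches of 48 values projected to a hidden size of 512):

embedding = ImageEmbedding(size=48, hidden_size=512, num_patches=64)

batch = torch.rand(8, 64, 48)    # a batch of 8 patchified images
out = embedding(batch)

print(out.shape)                 # torch.Size([8, 65, 512]) -- 64 patches plus the [class] token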

AttentionHead

After we've got the input embeddings, they are sent to each AttentionHead. The AttentionHead module looks like this:

class AttentionHead(nn.Module):

    def __init__(self, size: int):  # size is the hidden size
        super(AttentionHead, self).__init__()

        self.query = nn.Linear(size, size)
        self.key = nn.Linear(size, size)
        self.value = nn.Linear(size, size)

    def forward(self, input_tensor: torch.Tensor):
        q, k, v = self.query(input_tensor), self.key(input_tensor), self.value(input_tensor)

        # scale by sqrt(d_k), the size of the key/query vectors
        scale = q.size(-1) ** 0.5
        scores = torch.bmm(q, k.transpose(1, 2)) / scale

        scores = F.softmax(scores, dim=-1)

        # (batch x seq x seq) @ (batch x seq x hidden) -> batch x seq x hidden
        output = torch.bmm(scores, v)
        return output

The size argument here is our hidden size.

As an example, let's set parameter values that we'll refer to later:

batch_size = 64
sequence_size = 65  # 64 patches + 1 [class] token
hidden_size = 512
num_heads = 8

Now let’s take a look at the code of forward method:

def forward(self, input_tensor):
    q, k, v = self.query(input_tensor), self.key(input_tensor), self.value(input_tensor)

First, we create the query, key, and value projections of input_tensor, which has size batch_size x sequence_size x hidden_size, or 64 x 65 x 512.

In the next set of operations, we calculate attention scores using the famous formula:
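Attention(Q, K, V) = softmax(Q · Kᵀ / √d_k) · V

where d_k is the dimensionality of the key (and query) vectors. The corresponding code: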

scale = q.size(-1) ** 0.5  # sqrt(d_k)
scores = torch.bmm(q, k.transpose(1, 2)) / scale
scores = F.softmax(scores, dim=-1)

The size of scores is batch_size x sequence_size x sequence_size | 64 x 65 x 65, meaning that each element of the sequence has an attention score for every other element in the sequence.

Note that we do not apply any masking. After the scores are calculated, we multiply them with the value tensor:

output = torch.bmm(scores, v)

The size of output is the same as the size of input_tensor: batch_size x sequence_size x hidden_size | 64 x 65 x 512.
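As an aside, on PyTorch 2.x the same unmasked attention can be computed with the built-in scaled_dot_product_attention helper. A hedged sketch for comparison (not used in the repository):

head = AttentionHead(512)
x = torch.rand(64, 65, 512)

q, k, v = head.query(x), head.key(x), head.value(x)
manual = torch.bmm(F.softmax(torch.bmm(q, k.transpose(1, 2)) / q.size(-1) ** 0.5, dim=-1), v)

builtin = F.scaled_dot_product_attention(q, k, v)   # available in PyTorch >= 2.0
print(torch.allclose(manual, builtin, atol=1e-5))   # should print True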

MultiHeadAttention

The intent of the MultiHeadAttention module is to unite the attention heads.

class MultiHeadAttention(nn.Module):

    def __init__(self, size: int, num_heads: int):
        super().__init__()

        self.heads = nn.ModuleList([AttentionHead(size) for _ in range(num_heads)])
        self.linear = nn.Linear(size * num_heads, size)

    def forward(self, input_tensor: torch.Tensor):
        s = [head(input_tensor) for head in self.heads]
        s = torch.cat(s, dim=-1)

        output = self.linear(s)
        return output

The size argument in this module is the hidden size.

We calculate the output of each attention head and concatenate the results along the last dimension.

def forward(self, input_tensor: torch.Tensor):
    s = [head(input_tensor) for head in self.heads]
    s = torch.cat(s, dim=-1)

The resulting size of s is batch_size x sequence_size x (num_heads * hidden_size). With the example values above, that is 64 x 65 x 4096.

output = self.linear(s)  
return output

Then we pass s through the linear layer. The size of output is batch_size x sequence_size x hidden_size | 64 x 65 x 512, the same as the input size.

Encoder

The Encoder module contains multi-head attention and a feed-forward network. It also applies layer normalization.

class Encoder(nn.Module):

    def __init__(self, size: int, num_heads: int, dropout: float = 0.1):
        super().__init__()

        self.attention = MultiHeadAttention(size, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(size, 4 * size),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(4 * size, size),
            nn.Dropout(dropout)
        )
        self.norm_attention = nn.LayerNorm(size)
        self.norm_feed_forward = nn.LayerNorm(size)

    def forward(self, input_tensor):
        attn = input_tensor + self.attention(self.norm_attention(input_tensor))
        output = attn + self.feed_forward(self.norm_feed_forward(attn))
        return output

The feed-forward network is created as a Sequential module:

self.feed_forward = nn.Sequential(
    nn.Linear(size, 4 * size),
    nn.Dropout(dropout),
    nn.GELU(),
    nn.Linear(4 * size, size),
    nn.Dropout(dropout)
)

We make the hidden layer of the feed-forward network four times wider than the hidden size to make it more expressive and able to capture more complex dependencies between the input and the target. It helps mitigate the vanishing gradients problem as well.

We use the Gaussian Error Linear Units (GELU) activation function.
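For reference, GELU is defined as

GELU(x) = x · Φ(x)

where Φ(x) is the cumulative distribution function of the standard normal distribution. Unlike ReLU, it is smooth around zero and lets small negative values through with a small weight.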

Visual representation of Gaussian Error Linear Unit function
Fig 7. Visualization of GELU function. Image by author.

The forward method is short and expressive:

def forward(self, input_tensor):
    attn = input_tensor + self.attention(self.norm_attention(input_tensor))
    output = attn + self.feed_forward(self.norm_feed_forward(attn))
    return output

Note that we apply the normalization layer before the attention and feed-forward blocks, in contrast to Attention Is All You Need, where normalization is applied after the residual sum. This approach is called pre-normalization. The paper Understanding the Difficulty of Training Transformers analyzes both approaches.
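To make the difference concrete, here is a hedged sketch with the two orderings factored into helper functions (these helpers are illustrative only and are not part of the repository):

def post_norm_block(x, attention, feed_forward, norm1, norm2):
    # Original Transformer: normalize after each residual sum
    attn = norm1(x + attention(x))
    return norm2(attn + feed_forward(attn))

def pre_norm_block(x, attention, feed_forward, norm1, norm2):
    # This project: normalize before each sub-block, keeping the residual path clean
    attn = x + attention(norm1(x))
    return attn + feed_forward(norm2(attn))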

ViT

Finally, we are ready to proceed with the main module of the Vision Transformer, ViT. This class is too big to include in one piece, so I split it into several parts and cover each of them individually.

ViT not only combines all parts of the model but also provides the training functionality. Note that ViT inherits from pl.LightningModule, not nn.Module.

class ViT(pl.LightningModule):

    def __init__(self, size: int, hidden_size: int, num_patches: int, num_classes: int, num_heads: int,
                 num_encoders: int, emb_dropout: float = 0.1, dropout: float = 0.1,
                 lr: float = 1e-4, min_lr: float = 4e-5,
                 weight_decay: float = 0.1, epochs: int = 200):
        super().__init__()
        self.save_hyperparameters()

        self.lr = lr
        self.min_lr = min_lr
        self.weight_decay = weight_decay
        self.epochs = epochs

        self.embedding = ImageEmbedding(size, hidden_size, num_patches, dropout=emb_dropout)

        self.encoders = nn.Sequential(
            *[Encoder(hidden_size, num_heads, dropout=dropout) for _ in range(num_encoders)],
        )
        self.mlp_head = nn.Linear(hidden_size, num_classes)

We create the model modules in the constructor:

  1. Input Embedding layer;
  2. Set of encoders;
  3. MLP head that does the final classification.

The forward pass is implemented in the forward method:

def forward(self, input_tensor: torch.Tensor):
    emb = self.embedding(input_tensor)
    attn = self.encoders(emb)

    return self.mlp_head(attn[:, 0, :])

First, we convert the input tensor into the model's inner representation, with the [class] token and positional encoding added.

The emb tensor has size batch_size x sequence_size x hidden_size. With the values defined above, its size is 64 x 65 x 512. Then we pass emb through the sequential set of encoders.

The attn tensor has the same size as emb: batch_size x sequence_size x hidden_size, or 64 x 65 x 512. The first element of the sequence in attn corresponds to the [class] token, so we pass only this element to mlp_head and return the result.

attn[:, 0, :] has size batch_size x hidden_size | 64 x 512. The function outputs a tensor of logits with a size batch_size x num_classes | 64 x 10.
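Putting it all together, here is a hedged end-to-end shape check with the hyperparameters we'll use for training:

model = ViT(size=48, hidden_size=512, num_patches=64, num_classes=10,
            num_heads=8, num_encoders=6)

batch = torch.rand(64, 64, 48)   # 64 patchified CIFAR10 images
logits = model(batch)

print(logits.shape)              # torch.Size([64, 10])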

Now we can use the output of the model to organize the training process.

Training

The training process starts in the ViT class. A few more methods participate in the training:

  • configure_optimizers
  • configure_parameters
  • training_step
  • validation_step

Let’s run through each of them.

def configure_optimizers(self):
    optimizer = AdamW(self.configure_parameters(), lr=self.lr)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, self.epochs, eta_min=self.min_lr)

    return {"optimizer": optimizer, "lr_scheduler": scheduler}

We create our optimizers and schedulers in the configure_optimizers method; read more in the Lightning Optimization docs. We use AdamW with the CosineAnnealingLR scheduler. There is also a nice article that visualizes various learning rate schedulers, A Visual Guide to Learning Rate Schedulers in PyTorch.
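To get a feeling for the schedule, here is a small, hedged standalone sketch that steps CosineAnnealingLR over 200 epochs with the learning rates used in this tutorial and prints the value at a few checkpoints:

import torch
from torch.optim import AdamW, lr_scheduler

params = [torch.nn.Parameter(torch.zeros(1))]     # a dummy parameter, just to build an optimizer
optimizer = AdamW(params, lr=1e-4)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=200, eta_min=2.5e-5)

for epoch in range(200):
    if epoch in (0, 50, 100, 199):
        print(epoch, scheduler.get_last_lr()[0])  # decays from 1e-4 towards 2.5e-5
    optimizer.step()
    scheduler.step()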

Note that we don't pass the model parameters to AdamW as a single group. Instead, we configure them in our custom method configure_parameters:

def configure_parameters(self):
    no_decay_modules = (nn.LayerNorm,)
    decay_modules = (nn.Linear,)

    decay = set()
    no_decay = set()

    for module_name, module in self.named_modules():
        if module is self:
            continue
        for param_name, value in module.named_parameters():
            full_name = f"{module_name}.{param_name}" if module_name else param_name
            if param_name.endswith('bias'):
                no_decay.add(full_name)
            elif param_name.endswith('weight') and isinstance(module, no_decay_modules):
                no_decay.add(full_name)
            elif param_name.endswith('weight') and isinstance(module, decay_modules):
                decay.add(full_name)

    optim_groups = [
        {"params": [v for name, v in self.named_parameters() if name in decay],
         "weight_decay": self.weight_decay},
        {"params": [v for name, v in self.named_parameters() if name in no_decay],
         "weight_decay": 0}
    ]
    return optim_groups

The code of the method above is adapted from Andrej Karpathy's minGPT. It is common practice not to apply weight decay to LayerNorm weights and biases, so we disable it for those parameters. configure_parameters prepares two groups of parameters: one with weight decay enabled and one with weight decay disabled.

def training_step(self, batch, batch_idx):
    input_batch, target = batch

    logits = self(input_batch)
    loss = F.cross_entropy(logits, target)

    if batch_idx % 5 == 0:
        self.log("train_acc", logit_accuracy(logits, target), prog_bar=True)

    self.log("train_loss", loss)
    return loss

training_step should return the loss of the particular training step. We use the cross_entropy loss function. The logits tensor has size batch_size x num_classes, while the target tensor has size (batch_size,); these are exactly the arguments that the cross_entropy function expects.

We also periodically log the loss and model accuracy to TensorBoard.

The code of validation_step is almost the same as that of training_step:

def validation_step(self, batch, batch_idx):
    input_batch, target = batch
    output = self(input_batch)

    loss = F.cross_entropy(output, target)

    self.log("val_loss", loss, prog_bar=True)
    self.log("val_accuracy", logit_accuracy(output, target), prog_bar=True)

    return loss

A function that calculates accuracy looks like this:

def logit_accuracy(logits: torch.Tensor, target: torch.Tensor) -> float:
    idx = logits.max(1).indices
    acc = (idx == target).int()
    return acc.sum() / torch.numel(acc)

logit_accuracy function takes two tensors as arguments:

  • the logits tensor is the output of the model. It has a size of batch_size x 10 (10 because we have 10 possible classes);
  • target is the target tensor for the batch. It has a size of (batch_size,).

In the first line of logit_accuracy, we take the class index with the maximum logit value, i.e., the element with the highest value along dimension 1:

idx = logits.max(1).indices

The idx tensor now has the same size as target: (batch_size,).

Then we create a tensor acc whose elements are 0 or 1, where 1 means the model output and the target are equal for that element of the batch, and 0 means the prediction is incorrect.

In the return statement, we calculate the ratio of correctly predicted classes to the total number of elements.
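A tiny worked example of the function (the numbers are made up for illustration; during training the logits would have 10 columns for CIFAR10):

logits = torch.tensor([[2.0, 0.1, 0.3],    # predicted class 0
                       [0.2, 1.5, 0.1],    # predicted class 1
                       [0.3, 0.2, 2.2]])   # predicted class 2
target = torch.tensor([0, 2, 2])           # the second prediction is wrong

print(logit_accuracy(logits, target))      # tensor(0.6667) -- 2 of 3 predictions are correct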

Lightning Trainer

The script that runs model training is located in train.py.

First of all, we import all the required objects:

import torch
import pytorch_lightning as pl

from pathlib import Path
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

from src.dataset import CIFAR10DataModule
from src.models.basic import ViT

Then we set constants and hyperparameters of the model:

BASE_DIR = Path(__file__).parent
LIGHTNING_DIR = BASE_DIR.joinpath("data/lightning")
MODELS_DIR = LIGHTNING_DIR.joinpath("models")

LOG_EVERY_N_STEPS = 50
MAX_EPOCHS = 200

BATCH_SIZE = 512
VAL_BATCH_SIZE = 512
PATCH_SIZE = 4

SIZE = PATCH_SIZE * PATCH_SIZE * 3
HIDDEN_SIZE = 512
NUM_PATCHES = int(32 * 32 / PATCH_SIZE ** 2) # 32 x 32 is the size of image in CIFAR10

NUM_HEADS = 8
NUM_ENCODERS = 6

DROPOUT = 0.1
EMB_DROPOUT = 0.1

LEARNING_RATE = 1e-4
MIN_LEARNING_RATE = 2.5e-5
WEIGHT_DECAY = 1e-6

where

  • BASE_DIR is the base directory of the project;
  • LIGHTNING_DIR is a directory where lightning stores models and logs;
  • MODELS_DIR is a directory where lightning stores models;
  • LOG_EVERY_N_STEPS defines how often to log statistics into tensorboard;
  • MAX_EPOCHS is the maximum number of epochs to run;
  • SIZE is the size of the context vector of each element in the input sequence;
  • HIDDEN_SIZE is the size of the context after embedding is applied;
  • NUM_PATCHES is the total number of patches of an image;
  • NUM_HEADS is the number of attention heads;
  • NUM_ENCODERS is the number of stacked encoders;
  • DROPOUT is the dropout probability used in the encoders;
  • EMB_DROPOUT is the dropout probability used in the embedding.

Since we're training the model on a GPU, we can speed up the training process by lowering the precision of float32 matrix multiplications. We do it with the command:

torch.set_float32_matmul_precision('medium')

Under the if __name__ == '__main__' section, we create the data module and instantiate the model:

data = CIFAR10DataModule(batch_size=BATCH_SIZE, patch_size=PATCH_SIZE)

model = ViT(
    size=SIZE,
    hidden_size=HIDDEN_SIZE,
    num_patches=NUM_PATCHES,
    num_classes=data.classes,
    num_heads=NUM_HEADS,
    num_encoders=NUM_ENCODERS,
    emb_dropout=EMB_DROPOUT,
    dropout=DROPOUT,
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    epochs=MAX_EPOCHS
)

Then we create several useful callbacks:

checkpoint_callback = ModelCheckpoint(
    dirpath=MODELS_DIR,
    monitor="val_loss",
    save_last=True,
    verbose=True
)
es = EarlyStopping(monitor="val_loss", mode="min", patience=10)
lr_monitor = LearningRateMonitor(logging_interval='epoch')

  • The ModelCheckpoint callback saves a model after each epoch into the MODELS_DIR directory, so if the training process breaks, we won't lose the progress;
  • EarlyStopping monitors the validation loss and stops training if it hasn't improved during the last 10 epochs;
  • LearningRateMonitor adds a visualization of the learning rate to TensorBoard.

Finally, we create the trainer:

trainer = pl.Trainer(
    accelerator="cuda",
    precision="bf16",
    default_root_dir=LIGHTNING_DIR,
    log_every_n_steps=LOG_EVERY_N_STEPS,
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_callback, es, lr_monitor],
    resume_from_checkpoint=MODELS_DIR.joinpath("last.ckpt")
)
trainer.fit(model, data)
trainer.fit(model, data)

We create an instance of pytorch_lightning.Trainer and fit it with the model and data. For GPU training, we should set cuda as the accelerator. We also want to set the precision to bf16; read more about PyTorch Lightning precision management in N-bit Precision.

One of the biggest advantages of pytorch_lightning (I suppose) is that you don't need to call .to(device) on every tensor you have; you just pass the accelerator to the trainer.

We run the model training with the command:

python train.py

In the terminal, you may see the training progress. Note that if you don't have a last.ckpt checkpoint saved, you should remove the resume_from_checkpoint argument from the Trainer creation.

Run this command to visualize training in Tensorboard:

tensorboard --logdir=data/lightning/lightning_logs

CPU Training

If there is only a CPU or MPS device available, you can train a smaller model.

Set the following hyperparameters in train.py:

BATCH_SIZE = 256
VAL_BATCH_SIZE = 256
PATCH_SIZE = 4

SIZE = PATCH_SIZE * PATCH_SIZE * 3 # 4 * 4 * 3 (RGB colors)
HIDDEN_SIZE = 512
NUM_PATCHES = int(32 * 32 / PATCH_SIZE ** 2) # 32 x 32 is the size of image in CIFAR10

NUM_HEADS = 8
NUM_ENCODERS = 4

DROPOUT = 0.1
EMB_DROPOUT = 0.16

LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-6

update the accelerator value in the Trainer:

trainer = pl.Trainer(
    accelerator="cpu",
    default_root_dir=LIGHTNING_DIR,
    log_every_n_steps=LOG_EVERY_N_STEPS,
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_callback, es, lr_monitor],
    resume_from_checkpoint=MODELS_DIR.joinpath("last.ckpt")
)
trainer.fit(model, data)

and remove the AutoAugment transform in src/dataset.py, leaving:

self.train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(size=(im_size, im_size)),
        transforms.RandomRotation(degrees=rotation_degrees),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        PatchifyTransform(patch_size)
    ]
)

The model with the above parameters can be trained to ~80% accuracy in around 6–8 hours on an MPS device.

Training Results

I trained the model on an RTX 4090; 200 epochs of training took almost 3 hours. At the end of the 200th epoch, the model is ~83% accurate on the validation data.

Below you may see charts of the training progress.

Validation accuracy change over training
Fig 8. Validation accuracy during the training process. Image by author.
Validation loss change over training
Fig 9. Validation loss during the training process. Image by author.

The model is neither plateauing nor overfitting, so we could continue training and get even higher accuracy.

Evaluate Model

Now, it’s time to evaluate the model by running classify.py. Follow the code on GitHub to see the full implementation of the script.

First of all, we create the CIFAR10 validation dataset and load the trained model:

model = ViT.load_from_checkpoint(MODELS_DIR.joinpath('last.ckpt'))  
model.eval()

Then we can use the model to classify the images.
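The exact script lives in classify.py; the following is only a hedged sketch of the idea. It assumes the validation dataset ds_val was built with the same val_transform as above, the model was loaded as shown, and CLASSES maps CIFAR10 label indices to names:

CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

patches, target = ds_val[0]                  # a single patchified image, 64 x 48
with torch.no_grad():
    logits = model(patches.unsqueeze(0))     # add the batch dimension -> 1 x 64 x 48
predicted = logits.argmax(dim=-1).item()

print(f"Predicted class: {predicted} - {CLASSES[predicted]}")
print(f"Target class: {target} - {CLASSES[target]}")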

If you run the script, you should see a nice 32 x 32 frog.

A picture of frog from CIFAR10
Fig 10. A nice 32 x 32 frog. CIFAR10.

And the output in the terminal:

Predicted class: 6 - frog
Target class: 6 - frog

Summary

As you can see, there is nothing complicated about the vision transformer; it uses the same self-attention mechanism as any other transformer model. Although production models are usually pre-trained on huge datasets, we showed that a vision transformer can be trained even on the CIFAR10 dataset alone. However, a powerful compute device is required: 200 epochs of training of the above model ran for almost 3 hours on an RTX 4090.

We also fulfilled the secondary goal of the tutorial and showed how well the code can be organized when we combine pytorch and pytorch_lightning. This approach can, in fact, be used to prepare any production-ready model.

References

[1] Mikhail Kravets, vision_transformer (2023). GitHub repository containing the full code of the tutorial;

[2] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, Illia Polosukhin, Attention Is All You Need (2017);

[3] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby, An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale (2021);

[4] Liyuan Liu, Xiaodong Liu, Jianfeng Gao, Weizhu Chen, Jiawei Han, Understanding the Difficulty of Training Transformers (2020);

[5] Jay Alammar, The Illustrated Transformer (2018);

[6] Ekin D. Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, Quoc V. Le, AutoAugment: Learning Augmentation Strategies from Data (2019);

[7] Leonie Monigatti, A Visual Guide to Learning Rate Schedulers in PyTorch (2022).
