Convolutional Autoencoders

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import IPython.display as ipd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import time

Dataset Definitions / Example

In [2]:
# Fix the RNG seed so that runs are repeatable
torch.manual_seed(1)
# Pad every digit by 2 pixels on each side (giving 32x32) and convert it to a float tensor
transform = transforms.Compose([transforms.Pad(2), transforms.ToTensor()])
# Build the training (train=True) and test (train=False) splits with identical settings
data_train, data_test = (
    datasets.MNIST(root='data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
# Labels are not needed for an autoencoder, so pool both splits to maximize the data
data = torch.utils.data.ConcatDataset([data_train, data_test])

# Side length of a padded digit, read off the last axis of the first training image
DIGIT_RES = data_train[0][0].shape[-1]
In [3]:
# Grab the first training example: the image tensor and its class label
X, y = data_train[0]
# Axis order is channels x height x width
print(X.shape)
# Show the single channel of the image in grayscale
plt.imshow(X[0], cmap='gray')
torch.Size([1, 32, 32])
Out[3]:
<matplotlib.image.AxesImage at 0x7f6489ed52d0>
In [4]:
# First strided convolution: 1 -> 16 channels
layer1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
X1 = layer1(X)
print("X1.shape", X1.shape)
# Second strided convolution: 16 -> 32 channels
layer2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
X2 = layer2(X1)
print("X2.shape", X2.shape)
# Show the shape of every parameter tensor held by the second layer
for idx, param in enumerate(layer2.parameters()):
    print("layer 2 weight set {}".format(idx), param.shape)
X1.shape torch.Size([16, 16, 16])
X2.shape torch.Size([32, 8, 8])
layer 2 weight set 0 torch.Size([32, 16, 3, 3])
layer 2 weight set 1 torch.Size([32])

Model Definition

In [5]:
class ConvAutoencoder(nn.Module):
    """
    Convolutional autoencoder: a stack of stride-2 convolutions followed by a
    linear layer maps an image to a low-dimensional latent code, and a linear
    layer followed by upsample + convolution stages maps the code back to an image.
    """
    def __init__(self, digit_res, depth=4, dim_latent=2, dim_img=32, in_channels=1):
        """
        Parameters
        ----------
        digit_res: int
            Resolution of digit.  This argument is accepted but never read;
            the spatial size actually used to build the network is dim_img
        depth: int
            How many convolutional layers there are in the encoder/decoder
        dim_latent: int
            Dimension of the latent space
        dim_img: int
            Width/height of input image
        in_channels: int
            Number of channels of input image
        """
        super().__init__()
        self.dim_latent = dim_latent

        ## Step 1: Create convolutional encoder
        in_orig = in_channels
        layers = []
        out_channels = 16
        for _ in range(depth):
            # Each stage uses stride 2 and doubles the number of channels
            layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1))
            layers.append(nn.LeakyReLU())
            in_channels = out_channels
            out_channels *= 2
        # Push a dummy input through the convolutional stack to find the shape
        # of its output.  No gradients are needed for this probe.
        with torch.no_grad():
            probe = nn.Sequential(*layers)(torch.zeros(1, in_orig, dim_img, dim_img))
        shape = probe.shape[1:]
        # Cast to a plain Python int so nn.Linear does not store a numpy scalar
        # as its in_features/out_features
        n_flat = int(np.prod(shape))
        layers += [nn.Flatten(), nn.Linear(n_flat, dim_latent), nn.Sigmoid()]
        self.encoder = nn.Sequential(*layers)

        ## Step 2: Setup convolutional decoder
        layers = [nn.Linear(dim_latent, n_flat), nn.LeakyReLU(), nn.Unflatten(1, shape)]
        in_channels = out_channels//2
        for i in range(depth):
            # NOTE(review): with "depth-2" the last two stages both output a
            # single channel; "depth-1" would mirror the encoder exactly.
            # Left unchanged because it alters the architecture -- confirm intent.
            out_channels = 1
            if i < depth-2:
                out_channels = in_channels // 2
            # Use upsampling with bilinear interpolation instead of ConvTranspose
            # to avoid checkerboard artifacts
            # See this link for more info: https://distill.pub/2016/deconv-checkerboard/
            # align_corners=False is what the default resolves to; it is spelled
            # out to make the interpolation convention explicit
            layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False))
            layers.append(nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1))
            layers.append(nn.LeakyReLU())
            in_channels = out_channels

        self.decoder = nn.Sequential(*layers)

    def forward(self, X):
        """
        Encode a batch of images, decode them again and measure the error

        Parameters
        ----------
        X: tensor
            Batch of input images

        Returns
        -------
        z: tensor
            Latent codes produced by the encoder
        XOut: tensor
            Reconstructions produced by the decoder
        loss: tensor
            Squared reconstruction error, summed (not averaged) over every
            element of the batch
        """
        z = self.encoder(X) # Encoding in latent space
        XOut = self.decoder(z) # Decoding
        loss = torch.sum((X-XOut)**2)
        return z, XOut, loss

Plotting Code for Training

In [6]:
def scatter_digits(model, data, device, n_scatter=1000):
    """
    Scatter a subset of digits in their latent representation.  Each digit
    image is drawn on the current axes at its latent coordinates, tinted with
    a tab10 color chosen by its class label.

    The latent code is unpacked into exactly two coordinates, so the model
    must have a 2-dimensional latent space.

    Parameters
    ----------
    model: nn.Module
        Autoencoder model
    data: torch dataset
        Digits dataset
    device: str
        Device on which to run the model
    n_scatter: int
        Number of example digits to scatter.  Clamped to len(data) so that
        the sampling stride never becomes 0
    """
    from matplotlib.offsetbox import OffsetImage, AnnotationBbox
    ax = plt.gca()
    # Never ask for more digits than exist; nothing to draw for an empty request
    n_scatter = min(n_scatter, len(data))
    if n_scatter <= 0:
        return
    # Convert a grayscale digit to one with a background color chosen from
    # the tab10 colorcycle to indicate its class
    c = plt.get_cmap("tab10")
    # Take evenly strided samples across the dataset
    jump = len(data)//n_scatter
    encoded = []
    # Pure inference: skip building an autograd graph for every sample
    with torch.no_grad():
        for k in range(n_scatter):
            # Fetch (and transform) the sample once and unpack both fields
            img, label = data[k*jump]
            z, _, _ = model(img.unsqueeze(0).to(device))
            # Plain Python floats for the two latent coordinates
            x, y = z[0, :].cpu().tolist()
            encoded.append([x, y])
            img = img.detach().cpu()[0, :, :].numpy()
            C = c([label]).flatten()[0:3]
            # RGBA image: class color scaled by pixel intensity, with the
            # intensity itself as the alpha channel
            img_disp = np.zeros((img.shape[0], img.shape[1], 4))
            img_disp[:, :, 0:3] = img[:, :, None]*C[None, None, :]
            img_disp[:, :, 3] = img
            img_disp = OffsetImage(img_disp, zoom=0.7)
            ab = AnnotationBbox(img_disp, (x, y), xycoords='data', frameon=False)
            ax.add_artist(ab)
    # The artists are added directly, so feed the points to the axes'
    # data limits by hand before autoscaling
    encoded = np.array(encoded)
    ax.update_datalim(encoded)
    ax.autoscale()

def plot_digits_dimreduced_examples(model, data, device, n_examples=20):
    """
    Plot examples of encoded digits, as well as a scatter of some digits
    in their latent representation.

    The current figure is laid out as an n_examples x n_examples grid: the
    first row holds input digits, the second row their reconstructions, and
    the remaining n_examples-2 rows hold the latent scatterplot, so
    n_examples needs to be at least 3 for the layout to be valid.

    Parameters
    ----------
    model: nn.Module
        Autoencoder model
    data: torch dataset
        Digits dataset
    device: str
        Device on which to run the model
    n_examples: int
        Number of example encodings to show
    """
    ## Step 1: Plot examples of encodings
    jump = len(data)//n_examples
    # Pure inference: skip building an autograd graph for every example
    with torch.no_grad():
        for k in range(n_examples):
            x = data[k*jump][0].to(device)
            _, xenc, _ = model(x.unsqueeze(0))
            x = x.detach().cpu()[0, :, :]
            xenc = xenc.detach().cpu()[0, 0, :, :]

            # Top row: the original digit
            plt.subplot(n_examples, n_examples, k+1)
            plt.imshow(x, vmin=0, vmax=1, cmap='gray')
            plt.axis("off")
            # Second row: its reconstruction, directly underneath
            plt.subplot(n_examples, n_examples, n_examples+k+1)
            plt.imshow(xenc, vmin=0, vmax=1, cmap='gray')
            plt.axis("off")

    ## Step 2: Do a scatterplot of a subset of the digits in their latent space
    plt.subplot2grid((n_examples, n_examples), (2, 0), colspan=n_examples, rowspan=n_examples-2)
    scatter_digits(model, data, device)
In [7]:
# Use the GPU when one is present, otherwise fall back to the CPU so that
# the notebook still runs from a fresh kernel on any machine
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Build the model for the padded digit size measured from the dataset
model = ConvAutoencoder(DIGIT_RES, dim_img=DIGIT_RES)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

n_epochs = 20
batch_size = 16
# Mean per-sample reconstruction loss of each epoch
train_losses = []

# The loader reshuffles every time it is iterated, so one instance
# can be shared by all epochs
loader = DataLoader(data, batch_size=batch_size, shuffle=True)

plt.figure(figsize=(10, 10))

for epoch in range(n_epochs):
    model.train()
    train_loss = 0
    for i, (X, Y) in enumerate(loader): # Go through each mini batch
        X = X.to(device)
        # Reset the optimizer's gradients
        optimizer.zero_grad()
        # Run the model on the batch; it returns the summed squared
        # reconstruction error as its third output
        _, _, loss = model(X)
        # Compute the gradients of the loss function with respect
        # to all of the parameters of the model
        loss.backward()
        # Update the parameters based on the gradient and
        # the optimization scheme
        optimizer.step()
        train_loss += loss.item()

        if i%100 == 0:
            # Report the running loss per sample seen so far in this epoch
            ipd.clear_output()
            print("Epoch {} batch {}: loss {:.3f}".format(epoch, i, train_loss/((i+1)*batch_size)))
    # The loss is summed over samples, so divide by the dataset size
    train_losses.append(train_loss/len(data))
    # Redraw the diagnostic figure for this epoch and save it to disk
    plt.clf()
    model.eval()
    plot_digits_dimreduced_examples(model, data_train, device)
    plt.savefig("Epoch{}.png".format(epoch))
Epoch 19 batch 4300: loss 33.857
In [8]:
# A 5x5 convolution from 16 to 32 channels; print the shape of each of its
# parameter tensors, one per line
layer = nn.Conv2d(16, 32, 5)
print(*[p.shape for p in layer.parameters()], sep="\n")
torch.Size([32, 16, 5, 5])
torch.Size([32])
In [ ]: