import time

import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import ConcatDataset, DataLoader
from torchvision import datasets, transforms
# Convert images to torch.FloatTensor and pad every side by 2 pixels so the
# 28x28 MNIST digits become 32x32 (a power of two, which the stride-2 conv /
# 2x upsample autoencoder needs to reproduce the input size exactly)
transform = transforms.Compose([transforms.Pad(2), transforms.ToTensor()])
data_train = datasets.MNIST(root='data', train=True, download=True, transform=transform)
data_test = datasets.MNIST(root='data', train=False, download=True, transform=transform)
# Train on the union of both splits to maximize the data available
# (unsupervised reconstruction, so the test split is not held out here)
data = ConcatDataset([data_train, data_test])
# Width (last dimension) of each padded digit image
DIGIT_RES = data_train[0][0].shape[-1]
# Inspect a single training example: image tensor and its label
X, y = data_train[0]
# num channels x height x width
print(X.shape)  # notebook output: torch.Size([1, 32, 32])
plt.imshow(X[0, :, :], cmap='gray')
# notebook output: <matplotlib.image.AxesImage at 0x7f6489ed52d0>
# Push one digit through two stride-2 convolutions to see how each layer
# halves the spatial size while changing the channel count.
layer1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1)
X1 = layer1(X)
print("X1.shape", X1.shape)
layer2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1)
X2 = layer2(X1)
print("X2.shape", X2.shape)
# parameters() yields the weight tensor (out, in, kH, kW) followed by the
# bias vector (out,)
for i, weights in enumerate(layer2.parameters()):
    print("layer 2 weight set {}".format(i), weights.shape)
# notebook output:
# X1.shape torch.Size([16, 16, 16])
# X2.shape torch.Size([32, 8, 8])
# layer 2 weight set 0 torch.Size([32, 16, 3, 3])
# layer 2 weight set 1 torch.Size([32])
class ConvAutoencoder(nn.Module):
    """
    Convolutional autoencoder: a stack of stride-2 convolutions followed by a
    fully-connected, sigmoid-bounded latent layer, mirrored by a decoder that
    uses bilinear upsampling + convolutions to rebuild the image.
    """

    def __init__(self, digit_res, depth=4, dim_latent=2, dim_img=32, in_channels=1):
        """
        Parameters
        ----------
        digit_res: int
            Resolution of digit (kept for interface compatibility; the layer
            sizes are derived from ``dim_img``)
        depth: int
            How many convolutional layers there are in the encoder/decoder
        dim_latent: int
            Dimension of the latent space
        dim_img: int
            Width/height of input image. The reconstruction only matches the
            input size when ``dim_img`` is divisible by ``2**depth``.
        in_channels: int
            Number of channels of input image
        """
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        self.dim_latent = dim_latent

        ## Step 1: Create convolutional encoder
        # Each stride-2 conv halves height/width (rounding up); the channel
        # count starts at 16 and doubles per layer.
        in_orig = in_channels
        layers = []
        out_channels = 16
        for _ in range(depth):
            layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                    kernel_size=3, stride=2, padding=1))
            in_channels = out_channels
            out_channels *= 2
        # Push a dummy input through the conv stack to find the shape of its
        # output, which sizes the fully-connected bottleneck.
        with torch.no_grad():
            X = torch.zeros(1, in_orig, dim_img, dim_img)
            XOut = nn.Sequential(*layers)(X)
        shape = XOut.shape[1:]
        flat_dim = int(np.prod(shape))
        layers += [nn.Flatten(), nn.Linear(flat_dim, dim_latent), nn.Sigmoid()]
        self.encoder = nn.Sequential(*layers)

        ## Step 2: Setup convolutional decoder
        # Map the latent code back to the encoder's final feature-map shape.
        layers = [nn.Linear(dim_latent, flat_dim), nn.LeakyReLU(), nn.Unflatten(1, shape)]
        # out_channels was doubled once past the last encoder layer, so halve
        # it to recover the channel count of that layer's output.
        in_channels = out_channels // 2
        for i in range(depth):
            out_channels = 1
            if i < depth - 2:
                out_channels = in_channels // 2
            # Use upsampling with bilinear interpolation instead of ConvTranspose
            # to avoid checkerboard artifacts. See Odena et al., "Deconvolution
            # and Checkerboard Artifacts": https://distill.pub/2016/deconv-checkerboard/
            layers.append(nn.Upsample(scale_factor=2, mode='bilinear'))
            layers.append(nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1))
            layers.append(nn.LeakyReLU())
            in_channels = out_channels
        self.decoder = nn.Sequential(*layers)

    def forward(self, X):
        """
        Encode and reconstruct a batch of images.

        Parameters
        ----------
        X: torch.Tensor
            Batch of images of shape (batch, in_channels, dim_img, dim_img)

        Returns
        -------
        z: torch.Tensor
            Latent codes, shape (batch, dim_latent), each entry in [0, 1]
        XOut: torch.Tensor
            Reconstructions, same shape as ``X``
        loss: torch.Tensor
            Scalar sum of squared reconstruction errors over the whole batch
        """
        z = self.encoder(X)  # Encoding in latent space
        XOut = self.decoder(z)  # Decoding
        loss = torch.sum((X - XOut) ** 2)
        return z, XOut, loss
def scatter_digits(model, data, device, n_scatter=1000):
    """
    Scatter a subset of digits in their latent representation

    Each selected digit is drawn as a small image at its (x, y) latent
    coordinate, tinted with the tab10 color of its class label.

    Parameters
    ----------
    model: nn.Module
        Autoencoder model; called as ``model(batch)`` and expected to return
        ``(z, XOut, loss)`` with a 2-dimensional latent ``z``
    data: torch dataset
        Digits dataset of ``(image, label)`` pairs, images shaped
        (1, height, width)
    device: str
        Device on which to run the model
    n_scatter: int
        Number of example digits to scatter (capped at ``len(data)``)

    Returns
    -------
    encoded: np.ndarray
        Array of shape (number scattered, 2) holding the latent coordinates
    """
    from matplotlib.offsetbox import OffsetImage, AnnotationBbox
    ax = plt.gca()
    # Convert a grayscale digit to one with a background color chosen from
    # the tab10 colorcycle to indicate its class
    c = plt.get_cmap("tab10")
    n_scatter = min(n_scatter, len(data))
    if n_scatter <= 0:
        return np.zeros((0, 2))
    # Evenly spaced sample; n_scatter <= len(data) keeps the stride >= 1
    jump = len(data) // n_scatter
    encoded = []
    # Inference only: no autograd graph is needed
    with torch.no_grad():
        for k in range(n_scatter):
            img, label = data[k * jump]
            z, _, _ = model(img.to(device).unsqueeze(0))
            # Plain Python floats for matplotlib / numpy
            x, y = z[0, :].cpu().tolist()
            encoded.append([x, y])
            img = img.cpu()[0, :, :].numpy()
            C = c([label]).flatten()[0:3]
            # RGBA image: color = class color scaled by pixel intensity,
            # alpha = intensity, so the digit's background is transparent
            img_disp = np.zeros((img.shape[0], img.shape[1], 4))
            img_disp[:, :, 0:3] = img[:, :, None] * C[None, None, :]
            img_disp[:, :, 3] = img
            img_disp = OffsetImage(img_disp, zoom=0.7)
            ab = AnnotationBbox(img_disp, (x, y), xycoords='data', frameon=False)
            ax.add_artist(ab)
    encoded = np.array(encoded)
    # AnnotationBbox artists do not expand the axes' data limits; scatter the
    # points as well so the view autoscales to cover every digit.
    ax.scatter(encoded[:, 0], encoded[:, 1], s=1)
    return encoded
def plot_digits_dimreduced_examples(model, data, device, n_examples=20):
    """
    Plot examples of encoded digits, as well as a scatter of some digits
    in their latent representation

    The current figure is laid out on an ``n_examples x n_examples`` grid:
    row 0 shows original digits, row 1 their reconstructions, and rows
    2 onward hold the latent-space scatter drawn by ``scatter_digits``.

    Parameters
    ----------
    model: nn.Module
        Autoencoder model; called as ``model(batch)`` and expected to return
        ``(z, XOut, loss)``
    data: torch dataset
        Digits dataset of ``(image, label)`` pairs
    device: str
        Device on which to run the model
    n_examples: int
        Number of example encodings to show
    """
    ## Step 1: Plot examples of encodings
    jump = len(data) // n_examples
    # Inference only: no autograd graph is needed
    with torch.no_grad():
        for k in range(n_examples):
            tidx = k * jump
            x = data[tidx][0].to(device)
            _, xenc, _ = model(x.unsqueeze(0))
            x = x.cpu()[0, :, :]
            xenc = xenc.cpu()[0, 0, :, :]
            # Row 0: original digit
            plt.subplot(n_examples, n_examples, k + 1)
            plt.imshow(x, vmin=0, vmax=1, cmap='gray')
            # Row 1: its reconstruction
            plt.subplot(n_examples, n_examples, n_examples + k + 1)
            plt.imshow(xenc, vmin=0, vmax=1, cmap='gray')
    ## Step 2: Do a scatterplot of a subset of the digits in their latent space
    plt.subplot2grid((n_examples, n_examples), (2, 0), colspan=n_examples, rowspan=n_examples - 2)
    scatter_digits(model, data, device)
# Train on the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ConvAutoencoder(32)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
n_epochs = 20
batch_size = 16
train_losses = []  # mean per-digit loss of each epoch
plt.figure(figsize=(10, 10))
# shuffle=True reshuffles on every pass, so the loader only needs building once
loader = DataLoader(data, batch_size=batch_size, shuffle=True)
for epoch in range(n_epochs):
    # Notebook: replace the previous epoch's log/figure once new output arrives
    ipd.clear_output(wait=True)
    train_loss = 0
    for i, (X, _) in enumerate(loader):  # Go through each mini batch
        X = X.to(device)
        # Reset the optimizer's gradients
        optimizer.zero_grad()
        # Run the sequential model on all inputs
        _, _, loss = model(X)
        # Compute the gradients of the loss function with respect
        # to all of the parameters of the model
        loss.backward()
        # Update the parameters based on the gradient and
        # the optimization scheme
        optimizer.step()
        train_loss += loss.item()
        if i % 100 == 0:
            print("Epoch {} batch {}: loss {:.3f}".format(epoch, i, train_loss/((i+1)*batch_size)))
    train_losses.append(train_loss / len(data))
    # Redraw the progress figure from scratch so epochs don't pile up
    plt.clf()
    plot_digits_dimreduced_examples(model, data_train, device)
    ipd.display(plt.gcf())
# notebook output (last line): Epoch 19 batch 4300: loss 33.857
# A 5x5 convolution from 16 to 32 channels: its parameters are the weight
# tensor (out, in, kH, kW) and the bias vector (out,)
layer = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)
for p in layer.parameters():
    print(p.shape)
# notebook output: torch.Size([32, 16, 5, 5]) torch.Size([32])