# Trainable Quantum Convolution

In [None]:
import torch
from torch import nn

import torchvision

import pennylane as qml
from pennylane import numpy as np
from pennylane.templates import RandomLayers

from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
class QonvLayer(nn.Module):
    """Trainable 2x2 quantum convolution layer.

    Every 2x2 patch of a single-channel image is encoded with RY rotations on
    the first four wires, sent through a ``RandomLayers`` template whose
    weights are registered as trainable parameters via ``qml.qnn.TorchLayer``,
    and read out as Pauli-Z expectation values (one per output channel).

    Args:
        stride: Step, in pixels, between neighbouring patches (>= 1).
        device: Name of the PennyLane device the circuit runs on.
        wires: Number of wires of the device; at least 4 are needed because
            each patch is encoded on wires 0..3.
        circuit_layers: First dimension of the ``RandomLayers`` weight tensor.
        n_rotations: Second dimension of the ``RandomLayers`` weight tensor.
        out_channels: Number of measured wires; capped at ``wires``.
        seed: Seed handed to ``RandomLayers``. Drawn at random when ``None``.

    Raises:
        ValueError: If ``wires`` is smaller than 4 or ``stride`` is below 1.
    """

    # Side length of the square patch fed to the circuit.
    _KERNEL_SIZE = 2

    def __init__(self, stride=2, device="default.qubit", wires=4, circuit_layers=4, n_rotations=8, out_channels=4, seed=None):
        super().__init__()

        n_inputs = self._KERNEL_SIZE ** 2
        if wires < n_inputs:
            raise ValueError(
                f"QonvLayer needs at least {n_inputs} wires to encode a "
                f"{self._KERNEL_SIZE}x{self._KERNEL_SIZE} patch, got {wires}"
            )
        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")

        # init device
        self.wires = wires
        self.dev = qml.device(device, wires=self.wires)

        self.stride = stride
        self.out_channels = min(out_channels, wires)

        if seed is None:
            seed = np.random.randint(low=0, high=10_000_000)

        print("Initializing Circuit with random seed", seed)

        # random circuits
        @qml.qnode(device=self.dev, interface="torch")
        def my_circuit(inputs, weights):
            # Encoding of the classical patch values, one per wire
            for j in range(n_inputs):
                qml.RY(inputs[j], wires=j)
            # Random quantum circuit
            RandomLayers(weights, wires=list(range(self.wires)), seed=seed)

            # Measurement producing one classical value per output channel
            return [qml.expval(qml.PauliZ(j)) for j in range(self.out_channels)]

        weight_shapes = {"weights": [circuit_layers, n_rotations]}
        self.circuit = qml.qnn.TorchLayer(my_circuit, weight_shapes=weight_shapes)

        # Just to demonstrate the circuit. The dummy weights use the
        # configured shape so the drawing matches the circuit that is trained.
        inputs = torch.zeros(n_inputs, dtype=torch.float64)
        weights = torch.zeros((circuit_layers, n_rotations), dtype=torch.float64)
        print(qml.draw_mpl(my_circuit)(inputs, weights))

    def my_draw(self):
        """Print a text drawing of the circuit."""
        # build circuit by sending dummy data through it
        _ = self.circuit(inputs=torch.from_numpy(np.zeros(4)))
        # NOTE(review): QNode.draw() may not exist in recent PennyLane
        # releases, where qml.draw(qnode)(...) is the documented way to draw
        # a circuit -- confirm against the installed version.
        print(self.circuit.qnode.draw())
        self.circuit.zero_grad()

    def forward(self, img):
        """Apply the quantum circuit to every 2x2 patch of ``img``.

        Args:
            img: 4-dim tensor indexed as ``img[image, height, width, channel]``.
                Images with more than one channel are averaged over the
                channel axis first.

        Returns:
            Tensor of shape ``(n_images, h_out, w_out, out_channels)`` where
            ``h_out`` and ``w_out`` are the number of patch positions along
            each axis for the configured stride.
        """
        # TODO : add code to print the initial images

        n_imgs, height, width, channels = img.size()
        if channels > 1:
            img = img.mean(dim=-1, keepdim=True)

        ks = self._KERNEL_SIZE
        row_starts = range(0, height - ks + 1, self.stride)
        col_starts = range(0, width - ks + 1, self.stride)

        out = torch.zeros((n_imgs, len(row_starts), len(col_starts), self.out_channels))

        for b in range(n_imgs):
            # The output position is the index of the patch, not its pixel
            # offset divided by the kernel size; the two only coincide when
            # stride == kernel size.
            for row, j in enumerate(row_starts):
                for col, k in enumerate(col_starts):
                    # Row-major flattening gives the order
                    # (j, k), (j, k+1), (j+1, k), (j+1, k+1) and keeps the
                    # patch attached to the autograd graph of ``img``.
                    patch = img[b, j:j + ks, k:k + ks, 0].reshape(-1)
                    q_results = self.circuit(inputs=patch)
                    for c in range(self.out_channels):
                        out[b, row, col, c] = q_results[c]

        # TODO : add code here to print the convolved output image (out)

        return out

In [None]:
# Smoke test: instantiate a QonvLayer with a single random layer
qonv = QonvLayer(
    stride=2,
    circuit_layers=1,
    n_rotations=8,
    out_channels=4,
)

In [None]:
def transform(x):
    """Convert an image to a float32 tensor with values divided by 255."""
    scaled = np.array(x) / 255.0
    tensor = torch.from_numpy(scaled)
    return tensor.float()

# Model training

In [None]:
def train(model, train_loader, epochs=50, max_steps_per_epoch=None):
    """Train ``model`` with Adam and cross-entropy loss, logging every step.

    Args:
        model: Module mapping a ``(batch, 28, 28, 1)`` tensor to class scores.
        train_loader: Iterable yielding ``(x, y)`` batches; ``x`` is reshaped
            to ``(-1, 28, 28, 1)`` and ``y`` is cast to ``long``.
        epochs: Number of passes over ``train_loader``.
        max_steps_per_epoch: If given, only this many batches are processed in
            each epoch. ``None`` (the default) uses every batch.

    Returns:
        Tuple ``(model, losses, accs)`` where ``losses`` and ``accs`` are
        arrays holding the per-step loss and per-step batch accuracy.
    """
    from itertools import islice

    print(f"Starting Training for {epochs} epochs")

    model.train()

    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    # Plain lists: appending is O(1), whereas np.append copies the whole
    # history on every step.
    losses = []
    accs = []

    for epoch in range(epochs):
        # islice(..., None) iterates over the whole loader.
        for i, (x, y) in enumerate(islice(train_loader, max_steps_per_epoch)):
            # prepare inputs and labels
            x = x.view(-1, 28, 28, 1)
            y = y.long()

            # reset optimizer
            optimizer.zero_grad()

            # engage
            y_pred = model(x)

            # error, gradients and optimization
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()

            # Fraction of correctly classified samples in this batch
            with torch.no_grad():
                n_correct = (y_pred.argmax(dim=-1) == y).sum().item()
            acc = n_correct / y.numel()
            loss_value = loss.item()

            accs.append(acc)
            losses.append(loss_value)

            # Running means are taken over the 30 most recent steps
            print("Epoch:", epoch,
                  "\tStep:", i,
                  "\tAcc:", round(acc, 3),
                  "\tLoss:", round(loss_value, 3),
                  "\tMean Loss:", round(float(np.mean(losses[-30:])), 3),
                  "\tMean Acc:", round(float(np.mean(accs[-30:])), 3)
                 )
            print("---------------------------------------\n")

    return model, np.array(losses), np.array(accs)

In [None]:
if __name__ == "__main__":
    # Load MNIST; every image goes through `transform` before batching
    train_set = torchvision.datasets.MNIST(
        root='./mnist',
        train=True,
        download=True,
        transform=transform,
    )
    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=4)

    # Quantum convolution -> flatten -> linear classifier over 10 digits.
    # A 28x28 input with a 2x2 kernel and stride 2 yields 14x14 positions,
    # each with 4 channels.
    quantum_conv = QonvLayer(stride=2, circuit_layers=2, n_rotations=4, out_channels=4, seed=9321727)
    classifier = torch.nn.Linear(in_features=14*14*4, out_features=10)
    model = torch.nn.Sequential(quantum_conv, torch.nn.Flatten(), classifier)

    # Run the training loop
    model, losses, accs = train(model, train_loader, epochs=10)

    # Per-step loss and accuracy curves, side by side
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 4))
    for axis, series, label in ((ax1, losses, "Loss"), (ax2, accs, "Accuracy")):
        axis.plot(series)
        axis.set_title(label)
        axis.set_xlabel("Steps")
        axis.set_ylabel(label)