import torch
import argparse
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import Subset, DataLoader
import torch.optim as optim
from derivatives import derivatives
import matplotlib.pyplot as plt
from vgg import VGG11, VGG13, VGG16, VGG19
import numpy as np
import os

# Convert each image to a tensor, then normalize every RGB channel with the
# given per-channel mean and standard deviation.
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

# CIFAR-10 training split, read from ./data (it is not downloaded here).
trainset = datasets.CIFAR10('./data', train=True, transform=data_transform)

# Device used for the model and the data.
device = 'cuda:0'

# Shuffled mini-batch loader (batch size 128) used for the SGD updates.
loader = DataLoader(trainset, 128, shuffle=True)

# Number of training examples in the fixed subset on which the derivative
# statistics are measured.
size = 2000

# Number of independent trials (seeds).
num_trials = 5

# Label-smoothing strengths; one full training run is performed per value.
scalars = [0.0, 0.5, 0.75]

def compute_derivs_and_norm(net, derivative, loader):
    """Measure derivative statistics and the output norm on one batch.

    The final batch yielded by ``loader`` is moved to the network's device,
    handed to ``derivative.update`` and pushed through ``net``.

    Args:
        net: Network to evaluate; it is called as ``net(inputs)``.
        derivative: Object exposing ``update(data)`` and ``power(name)``.
            NOTE(review): ``power`` presumably estimates a leading eigenvalue /
            singular value of the named operator -- confirm against the
            ``derivatives`` module.
        loader: Iterable of ``(inputs, targets)`` batches. Only the last batch
            is used, so callers normally pass a single-batch loader.

    Returns:
        A 6-tuple: the results of ``derivative.power`` for ``'H'``,
        ``'jac1train'``, ``'jac1eval'``, ``'jac1trainsoft'`` and
        ``'jac1evalsoft'`` (in that order), followed by the vector norm of the
        network outputs as a NumPy value.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    batch = None
    for batch in loader:
        # Intentionally empty: only the final batch is kept.
        pass
    if batch is None:
        raise ValueError('loader yielded no batches')
    inputs, targets = batch

    # Put the data wherever the network lives; the module-level ``device`` is
    # only consulted for a network without parameters.
    first_param = next(net.parameters(), None)
    target_device = device if first_param is None else first_param.device
    inputs = inputs.to(target_device)
    targets = targets.to(target_device)

    derivative.update((inputs, targets))

    # The outputs are needed only for their norm, so do not build an autograd
    # graph for this forward pass.
    with torch.no_grad():
        outs = net(inputs)

    return (
        derivative.power('H'),
        derivative.power('jac1train'),
        derivative.power('jac1eval'),
        derivative.power('jac1trainsoft'),
        derivative.power('jac1evalsoft'),
        torch.linalg.vector_norm(outs).detach().cpu().numpy(),
    )

# Prefixes of the saved .npy files, listed in the order in which
# compute_derivs_and_norm returns the corresponding quantities.
_HISTORY_KEYS = ('hess', 'j1t', 'j1e', 'j1ts', 'j1es', 'outnorm')


def _record_derivatives(history, net, derivative, subloader):
    """Append one compute_derivs_and_norm measurement to every history list."""
    measurement = compute_derivs_and_norm(net, derivative, subloader)
    for key, value in zip(_HISTORY_KEYS, measurement):
        history[key].append(value)


def _train_step(net, criterion, optimizer, batchdata, batchtargets):
    """Perform a single SGD update of ``net`` on one mini-batch."""
    # Training mode is re-entered on every step.
    # NOTE(review): presumably the derivative measurements may change the
    # network's mode -- confirm against the ``derivatives`` module.
    net.train()
    net.zero_grad()

    batchdata = batchdata.to(device)
    batchtargets = batchtargets.to(device)

    loss = criterion(net(batchdata), batchtargets)
    loss.backward()
    optimizer.step()


def main(trial):
    """Train VGG11 with SGD for every label-smoothing value and log derivatives.

    The network is seeded with ``trial`` and its initial weights are saved to
    ``./vgg_sgd/params_{trial}.pt``. For each value in ``scalars`` the network
    is reset to those weights and trained for 90 epochs (learning rate 0.1).
    Measurements from ``compute_derivs_and_norm`` on a fixed random subset of
    ``size`` training examples are taken

    * before every 5th step of the first epoch (steps 0, 5, 10, ...), and
    * at the start of every later even-numbered epoch.

    The measurement histories are written to
    ``./vgg_sgd/{name}_{scalar:g}_{trial}.npy`` with ``name`` in ``hess``,
    ``j1t``, ``j1e``, ``j1ts``, ``j1es`` and ``outnorm``.

    Args:
        trial: Integer used both as the random seed and in output file names.
    """
    os.makedirs('./vgg_sgd/', exist_ok=True)

    torch.manual_seed(trial)
    net = VGG11(10).to(device)
    init_path = f'./vgg_sgd/params_{trial}.pt'
    torch.save(net.state_dict(), init_path)

    epochs = 90

    # Fixed random subset, served as one batch, on which derivatives are measured.
    indices = torch.randperm(len(trainset))[:size].tolist()
    subset = Subset(trainset, indices)
    subloader = DataLoader(subset, size)

    for sc in scalars:
        history = {key: [] for key in _HISTORY_KEYS}

        criterion = nn.CrossEntropyLoss(label_smoothing=sc)

        # Every run starts from the same saved initialization, with an
        # optimizer of its own so that no state can leak between runs.
        net.load_state_dict(torch.load(init_path, map_location=device))
        optimizer = optim.SGD(net.parameters(), lr=0.1)

        d = derivatives(net, criterion, [size, 3, 32, 32], [size, 10], device)

        for e in range(epochs):
            # After the first epoch, measure once at the start of each even
            # epoch, before the training loader is iterated.
            if e > 0 and e % 2 == 0:
                _record_derivatives(history, net, d, subloader)

            for j, (batchdata, batchtargets) in enumerate(loader):
                # During the first epoch, measure before every 5th step.
                if e == 0 and j % 5 == 0:
                    net.zero_grad()
                    _record_derivatives(history, net, d, subloader)

                _train_step(net, criterion, optimizer, batchdata, batchtargets)

        for key, values in history.items():
            np.save(f'./vgg_sgd/{key}_{sc:g}_{trial}.npy', np.array(values))

if __name__ == "__main__":
    # Command-line entry point: a single positional integer selects the trial.
    parser = argparse.ArgumentParser(description="SGD")
    # Valid trial indices are derived from num_trials so that the two cannot
    # drift apart.
    parser.add_argument('trial', type=int, choices=list(range(num_trials)))
    args = parser.parse_args()

    main(trial=args.trial)

