import argparse
from asyncio import wrap_future
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import Subset, DataLoader
import torch.optim as optim
from derivatives import derivatives
from cifar_data import load_cifar
from vgg import VGG11, VGG13, VGG16, VGG19
from resnet import resnet32
from resnet_kuang import ResNet18
import numpy as np

# Number of training samples drawn for the derivative computations; also used
# as the batch size of the derivative loader built in main().
deriv_size = 2500
# Batch size used when iterating over the full training set.
train_size = 5000
# Device string that models and batches are moved to.
device = 'cuda:0'
# Number of independent trials (seeded 0 .. num_trials - 1) per learning rate.
num_trials = 5
# Label-smoothing values passed to nn.CrossEntropyLoss, one training run each.
smoothings = [0.0, 0.5, 0.75]

# load_cifar returns a pair; only its first element is used in this file.
# NOTE(review): presumably the CIFAR training split -- confirm in cifar_data.
traindata, _ = load_cifar('ce')
# Loader over traindata in batches of train_size, using one worker process.
loader = DataLoader(traindata, batch_size = train_size, num_workers = 1)

def GD_step(net, criterion, lr):
    """Take one full-batch gradient-descent step on ``net``.

    Every batch of the module-level ``loader`` is passed through ``net`` once.
    Per-batch gradients and losses are combined with weights equal to each
    batch's share of the dataset, and the parameters are then updated a single
    time, in place, as ``p <- p - lr * grad``.

    Args:
        net: Model to update; called as ``net(inputs)``.
        criterion: Loss called as ``criterion(outputs, labels)``. The batch
            weighting yields the dataset-wide loss/gradient only if the
            criterion averages over the batch.
        lr: Step size.

    Returns:
        A pair ``(accuracy, loss)``: the fraction of samples whose arg-max
        output equals the label (measured with the parameters as they were
        before the update) and the weighted loss as a Python float.
    """
    params = list(net.parameters())
    grads = [torch.zeros_like(p) for p in params]
    # Weight each batch by its true share of the dataset instead of assuming
    # a fixed batch size and a 50000-sample dataset.
    n_samples = len(loader.dataset)

    correct = 0
    total = 0
    lval = 0.

    for data in loader:
        net.zero_grad()
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs, 1)
        batch_len = labels.size(0)
        total += batch_len
        correct += (predicted == labels).sum().item()

        loss = criterion(outputs, labels)
        loss.backward()

        weight = batch_len / n_samples
        with torch.no_grad():
            for acc_grad, p in zip(grads, params):
                # Parameters that received no gradient contribute nothing.
                if p.grad is not None:
                    acc_grad.add_(weight * p.grad)
        lval += weight * loss.item()

    # Apply the update in place without recording it in the autograd graph.
    with torch.no_grad():
        for p, acc_grad in zip(params, grads):
            p.sub_(lr * acc_grad)

    return correct / total, lval

def warmup(net, loader):
    """Run 100 forward passes of ``net`` on the last batch of ``loader``.

    The outputs are discarded; the calls are made only for whatever side
    effects a forward pass has on ``net``.
    NOTE(review): presumably this refreshes state such as batch-norm running
    statistics -- confirm against the model definitions.

    Args:
        net: Callable model, invoked as ``net(batch)``.
        loader: Iterable of ``(inputs, targets)`` pairs; only the inputs of
            the last pair are used.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    batch = None
    for batch, _ in loader:
        pass
    if batch is None:
        raise ValueError('warmup() needs a loader that yields at least one batch')

    # Only the batch that is actually used is moved to the device.
    batch = batch.to(device)

    # The outputs are thrown away, so no autograd graph is needed.
    with torch.no_grad():
        for _ in range(100):
            net(batch)

def compute_derivs(derivative, loader):
    """Feed the last batch of ``loader`` to ``derivative`` and query it.

    Args:
        derivative: Object exposing ``update(batch)`` and ``power(name)``.
        loader: Iterable of ``(inputs, targets)`` pairs; each pair is moved to
            ``device`` and the last one is handed to ``derivative.update``.

    Returns:
        A 5-tuple with the results of ``derivative.power`` for ``'H'``,
        ``'jac1train'``, ``'jac1eval'``, ``'jac1trainsoft'`` and
        ``'jac1evalsoft'``, in that order.
    """
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

    derivative.update((batch_x, batch_y))

    quantities = ('H', 'jac1train', 'jac1eval', 'jac1trainsoft', 'jac1evalsoft')
    return tuple(derivative.power(name) for name in quantities)

# Names of the recorded series, in the order they are written to disk. The
# last five line up with the tuple returned by compute_derivs().
_SERIES = ('loss', 'hess', 'j1t', 'j1e', 'j1ts', 'j1es')


def _train_and_record(net, criterion, d, deriv_loader, lr, n, s, eig_freq,
                      max_epochs, target_acc=None):
    """Run up to ``max_epochs`` GD steps, recording losses and derivatives.

    Before the first step ``warmup`` is called on ``deriv_loader``. Every
    ``eig_freq`` steps (starting with step 0) ``compute_derivs`` is evaluated
    and its five results are appended to the corresponding series.

    Args:
        net: Model being trained.
        criterion: Loss passed to ``GD_step``.
        d: Derivative object passed to ``compute_derivs``.
        deriv_loader: Loader used for warm-up and derivative evaluation.
        lr: Step size for ``GD_step``.
        n: Trial index (only used in the progress message).
        s: Label-smoothing value (only used in the progress message).
        eig_freq: Number of steps between derivative evaluations.
        max_epochs: Upper bound on the number of steps.
        target_acc: If given, stop as soon as a step reports at least this
            accuracy.

    Returns:
        ``(records, steps)`` where ``records`` maps each name in ``_SERIES``
        to the list recorded during this run and ``steps`` is the number of
        GD steps performed.
    """
    records = {name: [] for name in _SERIES}
    epoch = 0
    while epoch < max_epochs:
        if epoch == 0:
            warmup(net, deriv_loader)

        if epoch % eig_freq == 0:
            derivs = compute_derivs(d, deriv_loader)
            for name, value in zip(_SERIES[1:], derivs):
                records[name].append(value)

        net.train()
        acc, lossval = GD_step(net, criterion, lr)
        print(f'LR {lr}, trial {n}, smoothing {s:g}')
        print(f'Acc: {acc}')
        records['loss'].append(lossval)

        epoch += 1
        if target_acc is not None and acc >= target_acc:
            break

    return records, epoch


def main(model):
    """Train ``model`` with full-batch GD and save loss/derivative histories.

    For each learning rate, ``num_trials`` networks are trained (seeded with
    the trial index), each under every label-smoothing value in
    ``smoothings``. The first run of a learning rate (trial 0, smoothing 0.0)
    trains until accuracy reaches 0.99 or 1000 steps have been taken; every
    other run for that learning rate uses the same number of steps.

    Results are written to ``./fullbatch/<model>/<lr>/``: the initial
    parameters of each trial as ``params_<n>.npy`` (written with
    ``torch.save`` despite the extension) and one ``<series>.npy`` file per
    recorded series, indexed as ``[trial][smoothing][record]``.

    Args:
        model: ``'vgg'`` or ``'resnet'``.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    if model == 'vgg':
        lrs = [0.08, 0.04, 0.02]
        eig_freq = 10
        build_net = VGG11
    elif model == 'resnet':
        lrs = [0.08, 0.04, 0.02]
        eig_freq = 5
        build_net = ResNet18
    else:
        raise ValueError(f"unknown model {model!r}; expected 'vgg' or 'resnet'")

    for lr in lrs:
        # One directory string for both the parameter dumps and the result
        # arrays, so the two can never disagree on how lr is formatted.
        out_dir = f'./fullbatch/{model}/{lr:g}'
        os.makedirs(out_dir, exist_ok=True)

        # Number of steps taken by the reference run; reused by later runs.
        epochs = 0
        results = {name: [] for name in _SERIES}

        for n in range(num_trials):
            trial = {name: [] for name in _SERIES}

            torch.manual_seed(n)
            net = build_net(10).to(device)
            torch.save(net.state_dict(), f'{out_dir}/params_{n}.npy')
            deriv_subset = Subset(traindata, torch.randint(0, 50000, (deriv_size,)))
            deriv_loader = DataLoader(deriv_subset, deriv_size)

            # NOTE(review): the same `net` keeps training across smoothing
            # values within a trial; it is not reset to the parameters saved
            # above. Confirm that this is intended.
            for s in smoothings:
                criterion = nn.CrossEntropyLoss(label_smoothing=s)
                d = derivatives(net, criterion, [deriv_size, 3, 32, 32], [deriv_size, 10], device)

                if n == 0 and s == 0.0:
                    # Reference run: fixes the step budget for this lr.
                    records, epochs = _train_and_record(
                        net, criterion, d, deriv_loader, lr, n, s, eig_freq,
                        max_epochs=1000, target_acc=0.99)
                else:
                    records, _ = _train_and_record(
                        net, criterion, d, deriv_loader, lr, n, s, eig_freq,
                        max_epochs=epochs)

                for name in _SERIES:
                    trial[name].append(records[name])

            for name in _SERIES:
                results[name].append(trial[name])

        for name in _SERIES:
            np.save(f'{out_dir}/{name}.npy', results[name])

if __name__ == "__main__":
    # Script entry point: the single positional argument selects the
    # architecture that main() trains.
    cli = argparse.ArgumentParser(description="Train using gradient descent.")
    cli.add_argument('model', type=str, choices=['vgg', 'resnet'])
    main(model=cli.parse_args().model)



                                



            

    