import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
from functorch import make_functional
import torch.optim as optim
import numpy as np
from torch.utils.data import Subset, DataLoader
# from pf import derivatives
from derivatives import derivatives
from cifar_data import load_cifar
from pyhessian import hessian
import matplotlib.pyplot as plt
import os
import argparse
# from vgg import vgg11, vgg11_bn

# trainset = datasets.CIFAR10(root='../data', train=True, download=True, transform = transforms_train)
# testset = datasets.CIFAR10(root='../data', train=False, download=True, transform = transforms_train)

# loader = torch.utils.data.dataloader.DataLoader(trainset, 50000, shuffle=False, num_workers = 2)
# testloader = torch.utils.data.dataloader.DataLoader(trainset, 10000, shuffle=False, num_workers = 2)

# device = 'cuda:0'

# for data, labels in loader:
#     data = data.flatten(1).to(device)
#     data = data - data.mean(0)
#     data = data / data.std(0)
#     data = torch.reshape(data, [50000, 3, 32, 32])
#     labels = labels.to(device)

# # for tdata, tlabels in testloader:
# #     tdata = tdata.flatten(1).to(device)
# #     tdata = tdata - tdata.mean(0)
# #     tdata = tdata / tdata.std(0)
# #     tlabels = tlabels.to(device)

# Training data: the first 5000 CIFAR training examples, served by `loader`
# as a single unshuffled batch of all 5000 (so each step is full-batch GD).
# NOTE(review): the 'ce' argument to load_cifar presumably selects labels
# suited to cross-entropy loss -- confirm against cifar_data.
trainset, _ = load_cifar('ce')
indices = list(range(5000))

subset = Subset(trainset, indices)
loader = DataLoader(subset, batch_size=5000, shuffle=False)

device = 'cuda:0'


# import pdb; pdb.set_trace()

# tindices = [i for i in range(1000)]
# tdata = tdata[tindices]
# tlabels = tlabels[tindices]
# tonehotlabels = F.one_hot(tlabels).type(torch.float32)
# tdata = tdata, tonehotlabels

class Net(nn.Module):
    """Three-layer fully-connected classifier: input -> hidden -> hidden -> classes.

    Args:
        activation: ``'Tanh'`` or ``'ReLU'``; the nonlinearity applied after
            each of the two hidden layers.
        in_features: input width. The default 3072 matches the flattened
            CIFAR images this script feeds in.
        hidden: width of both hidden layers (default 200).
        num_classes: number of output logits (default 10).

    Raises:
        ValueError: if ``activation`` is not a supported name.
    """

    _ACTIVATIONS = {'Tanh': nn.Tanh, 'ReLU': nn.ReLU}

    def __init__(self, activation, in_features=3072, hidden=200, num_classes=10):
        super().__init__()
        try:
            act_cls = self._ACTIVATIONS[activation]
        except KeyError:
            # Previously an unknown name produced a layer-less module that only
            # failed later with AttributeError inside forward().
            raise ValueError(
                f"unsupported activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            ) from None
        # Attribute names (relu1/relu2 even for Tanh) and creation order are
        # kept unchanged so state_dict keys and the parameter initialisation
        # drawn under a fixed torch.manual_seed stay identical.
        self.lin1 = nn.Linear(in_features, hidden)
        self.relu1 = act_cls()
        self.lin2 = nn.Linear(hidden, hidden)
        self.relu2 = act_cls()
        self.lin3 = nn.Linear(hidden, num_classes)

    def features(self, x):
        """Return the spectral norms of both hidden-layer activation matrices.

        Computed without gradient tracking. ``torch.linalg.matrix_norm`` with
        ``ord=2`` gives the largest singular value. For a 2-D input batch each
        entry of the returned list is a 0-d numpy array.
        """
        with torch.no_grad():
            z1 = self.relu1(self.lin1(x))
            z2 = self.relu2(self.lin2(z1))
        return [
            torch.linalg.matrix_norm(z1, ord=2).cpu().numpy(),
            torch.linalg.matrix_norm(z2, ord=2).cpu().numpy(),
        ]

    def forward(self, x):
        """Return the raw (pre-softmax) class logits for input ``x``."""
        z1 = self.relu1(self.lin1(x))
        z2 = self.relu2(self.lin2(z1))
        return self.lin3(z2)

# Learning rates to sweep.
lrs = [0.2]
# Input scale factors. 1.0 is the reference scale: main's seed-0 run at this
# scale decides how many steps every other run takes.
scalars = [1.0, 0.5, 1.5]

def _train_step(net, pd, optimizer, criterion, data, sc, records):
    """Take one full-batch SGD step on ``data`` and append diagnostics to ``records``.

    Records ``pd.power('H')`` under 'hess', ``pd.power('jac1eval')`` under
    'jac', ``net.features`` under 'feat' and the pre-step loss under 'loss'.

    Returns:
        Training accuracy of ``net`` on ``data``, measured before the step.
    """
    inputs, labels = data
    optimizer.zero_grad()
    # NOTE(review): pd.update runs after zero_grad and before the forward pass
    # (same order as before). If it writes .grad on net's parameters, that
    # gradient is added to the loss gradient below -- confirm.
    pd.update(data)
    h = pd.power('H')
    jac = pd.power('jac1eval')
    records['hess'].append(h)
    records['jac'].append(jac)
    records['feat'].append(net.features(inputs))

    net.train()
    outputs = net(inputs)

    _, predicted = torch.max(outputs, 1)
    correct = (predicted == labels).sum().item()
    acc = correct / inputs.shape[0]
    print('Scalar: ', sc, 'Acc: ', acc, 'Sharp: ', h)

    loss = criterion(outputs, labels)
    records['loss'].append(loss.detach().cpu().numpy())

    loss.backward()
    optimizer.step()
    return acc


def main(activation, max_epochs=None):
    """Record per-step training diagnostics for full-batch GD runs.

    For every learning rate in ``lrs`` and every input scale in ``scalars``,
    trains five freshly seeded ``Net(activation)`` models (seeds 0-4). Each is
    trained on the ``loader`` batch multiplied by the scale. The per-step
    loss, ``pd.power('H')``, ``pd.power('jac1eval')`` and feature-norm
    histories are saved to ``./sharpening/{activation}/``.

    The step budget comes from a reference run (scale 1.0, seed 0) that trains
    until training accuracy reaches 99%. Every other run takes exactly that
    many steps.

    Args:
        activation: ``'Tanh'`` or ``'ReLU'``, forwarded to ``Net``.
        max_epochs: optional cap on the reference run's step count. ``None``
            (the default) trains until 99% accuracy, which never terminates
            if that accuracy is not reached.

    Raises:
        ValueError: if 1.0 is not in ``scalars`` (no reference run exists).
    """
    if 1.0 not in scalars:
        raise ValueError('scalars must contain 1.0: its seed-0 run sets the step budget')
    # Run the reference scale first wherever it sits in `scalars`; otherwise
    # earlier scales would see epochs == 0 and save empty trajectories.
    # The sort is stable, so the remaining scales keep their order.
    ordered_scalars = sorted(scalars, key=lambda s: s != 1.0)

    out_dir = f'./sharpening/{activation}/'
    for lr in lrs:
        # NOTE(review): saved file names do not include lr, so with more than
        # one learning rate the later runs overwrite the earlier ones.
        os.makedirs(out_dir, exist_ok=True)
        epochs = 0

        for sc in ordered_scalars:
            # `loader` serves the whole subset as one batch.
            inputs, labels = next(iter(loader))
            inputs = sc * inputs.flatten(1).to(device)
            labels = labels.to(device)
            data = (inputs, labels)
            n_samples = inputs.shape[0]

            for n in range(5):
                records = {'loss': [], 'hess': [], 'jac': [], 'feat': []}

                torch.manual_seed(n)
                net = Net(activation).to(device)

                criterion = nn.CrossEntropyLoss()
                pd = derivatives(net, criterion, list(inputs.shape), [n_samples, 10], device)

                optimizer = torch.optim.SGD(net.parameters(), lr)

                if n == 0 and sc == 1.0:
                    # Reference run: train until 99% training accuracy and
                    # count the steps taken.
                    acc = 0.
                    while acc < .99 and (max_epochs is None or epochs < max_epochs):
                        print('Epoch: ', epochs)
                        acc = _train_step(net, pd, optimizer, criterion, data, sc, records)
                        epochs += 1
                else:
                    for e in range(epochs):
                        print('Epoch: ', e)
                        _train_step(net, pd, optimizer, criterion, data, sc, records)

                for key in ('loss', 'hess', 'jac', 'feat'):
                    np.save(f'{out_dir}{key}_{sc:g}_{n:g}.npy', np.array(records[key]))


if __name__ == "__main__":
    # Command line: a single positional argument picks the hidden-layer
    # nonlinearity ('Tanh' or 'ReLU').
    cli = argparse.ArgumentParser(description="Train using gradient descent.")
    cli.add_argument('activation', type=str, choices=['Tanh', 'ReLU'])
    main(activation=cli.parse_args().activation)