## Imports

import numpy as np
import os
import random
import copy
from typing import Iterable
import pickle
from tqdm import tqdm

import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18, resnet34
from torchvision.models import densenet121


## Setup

def seed_everything(seed):
    """Seed every random number generator used by the experiments.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), exports ``PYTHONHASHSEED`` and configures cuDNN for
    repeatable kernel selection.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, so it is
    # switched off together with requesting deterministic kernels.
    torch.backends.cudnn.benchmark = False

def load_data(dataset, batch_size, num_workers=4):
    """Build the CIFAR train and test data loaders.

    The training pipeline applies a padded random crop and a random
    horizontal flip before tensor conversion and normalisation; the test
    pipeline only converts and normalises. Data is read from (and, if
    needed, downloaded to) ``./data`` when that directory exists, otherwise
    ``../data``.

    Args:
        dataset: ``'cifar10'`` or ``'cifar100'``.
        batch_size: Batch size used by both loaders.
        num_workers: Number of loader worker processes.

    Returns:
        Tuple ``(trainloader, testloader, num_classes)``; the training
        loader shuffles, the test loader does not.

    Raises:
        ValueError: If ``dataset`` is neither ``'cifar10'`` nor ``'cifar100'``.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    root_path = './data' if os.path.isdir('./data') else '../data'

    # Pick the dataset class and its label count, then build both splits
    # through one shared code path.
    if dataset == 'cifar10':
        dataset_cls, num_classes = torchvision.datasets.CIFAR10, 10
    elif dataset == 'cifar100':
        dataset_cls, num_classes = torchvision.datasets.CIFAR100, 100
    else:
        raise ValueError('Only cifar 10 and cifar 100 are supported')

    trainset = dataset_cls(root=root_path, train=True, download=True, transform=transform_train)
    testset = dataset_cls(root=root_path, train=False, download=True, transform=transform_test)

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    print('Data Prepared.')

    return trainloader, testloader, num_classes

def accuracy_and_loss(net, dataloader, device, criterion):
    """Evaluate ``net`` on ``dataloader``.

    The network is switched to evaluation mode and is left in that mode;
    callers that continue training must call ``net.train()`` themselves.

    Args:
        net: Model mapping a batch of inputs to per-class scores.
        dataloader: Sized iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        criterion: Loss function called as ``criterion(outputs, labels)``.

    Returns:
        Tuple ``(accuracy, loss)``: accuracy in percent over all samples and
        the mean of the per-batch losses.
    """
    net.eval()
    n_correct, n_seen = 0, 0
    mean_batch_loss = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = net(inputs)
            predicted = torch.max(outputs.data, 1)[1]

            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

            # Each batch contributes its loss divided by the number of batches.
            batch_loss = criterion(outputs, labels).cpu().item()
            mean_batch_loss += batch_loss / len(dataloader)

    accuracy = 100 * n_correct / n_seen
    return accuracy, mean_batch_loss

def train(net, optimizer, n_epoch, trainloader, testloader, criterion, checkpoint, device, sps=False):
    """Train ``net`` and record metrics every ``checkpoint`` batches.

    Inside each epoch, after every ``checkpoint``-th batch the function
    records the mean training loss of the batches seen since the previous
    record, the training accuracy accumulated since the start of the epoch,
    and the test accuracy and loss returned by ``accuracy_and_loss``.

    Args:
        net: Model to optimise. It is put back into training mode after each
            evaluation; this function does not set the mode before the
            first batch.
        optimizer: Optimizer whose ``step`` is called after every backward
            pass.
        n_epoch: Number of passes over ``trainloader``.
        trainloader: Sized iterable of ``(inputs, labels)`` training batches.
        testloader: Batches used for the periodic evaluation.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        checkpoint: Number of batches between two records; must be positive.
        device: Device the batches are moved to.
        sps: When true, the loss is handed to the optimizer through
            ``optimizer.step(loss=loss)``.

    Returns:
        Tuple ``(train_los, train_acc, test_los, test_acc)`` of lists with
        one entry per record.
    """
    train_los = []
    train_acc = []
    test_los = []
    test_acc = []

    for epoch in tqdm(range(n_epoch)):  # loop over the dataset multiple times
        running_loss = 0.0
        running_corrects = 0
        running_total = 0

        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad(set_to_none=True)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            if sps:
                optimizer.step(loss=loss)
            else:
                optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            running_total += labels.size(0)
            running_corrects += (predicted == labels).sum().item()

            if i % checkpoint == checkpoint - 1:
                # running_loss is reset at every record (and at the start of
                # the epoch), so it sums exactly `checkpoint` batch losses;
                # dividing by that count gives the mean loss per batch.
                train_los.append(running_loss / checkpoint)
                train_acc.append(100 * running_corrects / running_total)
                test_a, test_l = accuracy_and_loss(net, testloader, device, criterion)
                test_acc.append(test_a)
                test_los.append(test_l)
                running_loss = 0.0
                # accuracy_and_loss leaves the network in evaluation mode.
                net.train()

    return train_los, train_acc, test_los, test_acc


## Optimizers

class alr_smag(torch.optim.Optimizer):
    """Polyak-type adaptive step size applied to a moving-average gradient.

    On every call to :meth:`step` the optimizer

    1. folds the current gradients into a running buffer
       ``avg_grad = momentum_decay * avg_grad + grad`` (the first step
       starts from a copy of the gradients),
    2. computes ``step_size = (loss - fstar) / (c * ||avg_grad||^2 + eps)``,
       replaced by ``0`` when ``loss < fstar`` and capped by ``eta_max``
       when one is set,
    3. updates every parameter with ``p <- p - step_size * avg_grad``.

    Args:
        params: Iterable of parameters to optimise.
        c: Factor multiplying the squared norm in the step-size denominator.
        weight_decay: When non-zero, ``weight_decay * p`` is added in place
            to the averaged gradient before the update, so it also persists
            in the buffer used by later steps.
        momentum_decay: Factor applied to the previous averaged gradient.
        eta_max: Optional upper bound on the step size; ``None`` disables
            the bound.
        projection_fn: When exactly ``True``, parameters are rescaled after
            each update so that their joint L2 norm is at most ``max_norm``.
        max_norm: Norm bound used by the projection.
        adapt_flag: Step-size rule; only ``'constant'`` is implemented.
        fstar_flag: When truthy, ``fstar`` is read from
            ``batch['meta']['fstar'].mean()``; otherwise ``fstar`` is ``0``.
        eps: Constant added to the step-size denominator.
        warmup: When exactly ``True`` and ``eta_max`` is not ``None``,
            ``eta_max`` is overwritten at every step with
            ``warmup_tau * min(warmup_tau2 * step, 1)``.
        warmup_tau: Final value of the warm-up bound.
        warmup_tau2: Growth rate of the warm-up bound per step.
        centralize_grad: When true, gradients with more than one dimension
            have their mean over all but the first dimension subtracted
            (in place on ``p.grad``).
        centralize_grad_norm: When true, the same centring is applied in
            place to the averaged gradients before their norm is taken.
    """
    # from [70]
    def __init__(self,
                 params,
                 c=1,
                 weight_decay=0,
                 momentum_decay=0.9,
                 eta_max=None,
                 projection_fn=False,
                 max_norm=500,
                 adapt_flag='constant',
                 fstar_flag=None,
                 eps=1e-5,
                 warmup=False,
                 warmup_tau=0.1,
                 warmup_tau2=1e-4,
                 centralize_grad=False,
                 centralize_grad_norm=False
                ):
        params = list(params)
        super().__init__(params, {})
        self.params = params
        self.c = c
        self.weight_decay = weight_decay
        self.decay = momentum_decay
        self.eta_max = eta_max
        self.projection = projection_fn
        self.max_norm = max_norm
        self.adapt_flag = adapt_flag
        self.fstar_flag = fstar_flag
        self.eps = eps
        self.warmup = warmup
        self.tau = warmup_tau
        self.tau_2 = warmup_tau2
        self.centralize_grad_norm = centralize_grad_norm
        self.centralize_grad = centralize_grad
        self.state['step'] = 0
        self.state['step_size'] = eta_max
        self.state['n_forwards'] = 0
        self.state['n_backwards'] = 0

    def getstate(self):
        """Return the step size used by the most recent step.

        Before the first step this is the ``eta_max`` passed to the
        constructor.
        """
        return self.state['step_size']

    @torch.no_grad()
    def l2_projection(self):
        """Rescale all parameters so their joint L2 norm is at most ``max_norm``."""
        total_norm = 0
        for group in self.param_groups:
            for p in group['params']:
                total_norm += torch.sum(torch.mul(p, p))
        total_norm = float(torch.sqrt(total_norm))

        if total_norm > self.max_norm:
            ratio = self.max_norm / total_norm
            for group in self.param_groups:
                for p in group['params']:
                    p *= ratio

    def step(self, closure=None, loss=None, batch=None):
        """Perform one optimisation step.

        Args:
            closure: Callable returning the loss; used only when ``loss`` is
                not given.
            loss: Loss of the current batch (tensor or number). Gradients
                must already be stored on the parameters.
            batch: Only read when ``fstar_flag`` is truthy; must then
                provide ``batch['meta']['fstar']``.

        Returns:
            The loss as a Python float.

        Raises:
            ValueError: If neither ``loss`` nor ``closure`` is given, if
                ``adapt_flag`` is not ``'constant'``, or if the first
                parameter contains NaNs after the update.
        """
        if loss is None and closure is None:
            raise ValueError('please specify either closure or loss')
        # Only the 'constant' rule defines a step size; reject anything else
        # before parameters or state are modified.
        if self.adapt_flag not in ('constant',):
            raise ValueError("unsupported adapt_flag %r; only 'constant' is implemented"
                             % (self.adapt_flag,))

        if loss is not None:
            if not isinstance(loss, torch.Tensor):
                loss = torch.tensor(loss)
        if loss is None:
            loss = closure()
        else:
            assert closure is None, 'if loss is provided then closure should be None'

        # get fstar
        if self.fstar_flag:
            fstar = float(batch['meta']['fstar'].mean())
        else:
            fstar = 0.

        # compute current gradients and averaged gradient
        grad_current = self.get_grad_list(self.params, centralize_grad=self.centralize_grad)
        if self.state['step'] == 0:
            avg_grad = copy.deepcopy(grad_current)
        else:
            avg_grad = self.state['avg_grad']
            avg_grad = self.get_avg_grad(grad_current, avg_grad, self.decay, centralize_grad=False)

        # update search direction
        avg_grad_norm = self.compute_grad_norm(avg_grad, centralize_grad_norm=self.centralize_grad_norm)

        # adjust the step size
        step_size = (loss - fstar) / (self.c * avg_grad_norm**2 + self.eps)
        if loss < fstar:
            step_size = 0.
        elif self.eta_max is None:
            # NOTE: without an eta_max the warm-up schedule is never applied.
            step_size = step_size.item()
        else:
            if self.warmup is True:
                self.eta_max = self.tau * min(self.tau_2*self.state['step'], 1)
            step_size = min(self.eta_max, step_size.item())

        self.sgd_update(self.params, step_size, avg_grad, self.weight_decay)

        if self.projection is True:
            self.l2_projection()

        # update state with metrics
        self.state['n_forwards'] += 1
        self.state['n_backwards'] += 1
        self.state['step'] += 1
        self.state['step_size'] = step_size
        self.state['avg_grad'] = avg_grad

        # Only the first parameter tensor is inspected.
        if torch.isnan(self.params[0]).sum() > 0:
            raise ValueError('Got NaNs')

        return float(loss)

    @torch.no_grad()
    def get_grad_list(self, params, centralize_grad=False):
        """Collect the gradients of ``params`` in order.

        A parameter without a gradient contributes the float placeholder
        ``0.``. The returned tensors are the gradient storage itself, not
        copies; centring therefore modifies ``p.grad`` in place.
        """
        grad_list = []
        for p in params:
            g = p.grad
            if g is None:
                g = 0.
            else:
                g = p.grad.data
                if g.dim() > 1 and centralize_grad:
                    # centralize grads
                    g.add_(-g.mean(dim=tuple(range(1, g.dim())), keepdim=True))

            grad_list += [g]

        return grad_list

    @torch.no_grad()
    def get_avg_grad(self, grad, avg_grad, decay, centralize_grad=False):
        """Fold ``grad`` into the running buffers ``avg_grad``.

        The returned list always has one entry per input pair, in the same
        order, so it stays aligned with the parameter list. A missing
        gradient (``None`` or the ``0.`` placeholder) is treated as a zero
        gradient: an existing buffer is only decayed and a placeholder
        buffer stays a placeholder. A placeholder buffer that meets its
        first real gradient starts from a copy of that gradient.
        """
        avg_grad_list = []
        for v, g in zip(avg_grad, grad):
            grad_missing = g is None or (isinstance(g, float) and g == 0.)
            if grad_missing:
                if isinstance(v, torch.Tensor):
                    v.mul_(decay)
                # Keep the slot so later zips with the parameters line up.
                avg_grad_list.append(v)
                continue

            if isinstance(v, torch.Tensor):
                v.mul_(decay).add_(g, alpha=1)
            else:
                # No history for this parameter yet.
                v = g.clone()

            if g.dim() > 1 and centralize_grad:
                # centralize grads
                v.add_(-v.mean(dim=tuple(range(1, v.dim())), keepdim=True))

            avg_grad_list.append(v)

        return avg_grad_list

    @torch.no_grad()
    def compute_grad_norm(self, grad_list, centralize_grad_norm=False):
        """Return the joint L2 norm of the tensors in ``grad_list``.

        Placeholder entries are skipped. With ``centralize_grad_norm`` the
        tensors with more than one dimension are centred in place first.
        The result is always a tensor (zero when no tensor is present).
        """
        grad_norm = 0.
        for g in grad_list:
            if g is None or (isinstance(g, float) and g == 0.):
                continue

            if g.dim() > 1 and centralize_grad_norm:
                # centralize grads
                g.add_(-g.mean(dim=tuple(range(1, g.dim())), keepdim=True))

            grad_norm += torch.sum(torch.mul(g, g))

        # as_tensor covers the case where every entry was a placeholder and
        # grad_norm is still a Python float.
        return torch.sqrt(torch.as_tensor(grad_norm))

    @torch.no_grad()
    def sgd_update(self, params, step_size, grad_current, weight_decay):
        """Apply ``p <- p - step_size * g`` for each aligned ``(p, g)`` pair.

        Placeholder entries are skipped. With a non-zero ``weight_decay``,
        ``weight_decay * p`` is first added in place to ``g``.
        """
        for p, g in zip(params, grad_current):
            if isinstance(g, float) and g == 0.:
                continue
            if weight_decay != 0:
                g.add_(p.data, alpha=weight_decay)
            p.data.add_(other=g, alpha=-step_size)

class SPS_smooth(torch.optim.Optimizer):
    """SGD with a smoothed stochastic Polyak step size.

    Each step uses
    ``step_size = min(coeff * previous_step_size, loss / (c * ||grad||^2))``
    with ``coeff = gamma ** (1 / n_batches_per_epoch)``, where ``||grad||``
    is the joint L2 norm of all available gradients, and then updates every
    parameter with ``p <- p - step_size * grad``.

    Args:
        params: Iterable of parameters to optimise.
        n_batches_per_epoch: Exponent denominator of the growth factor.
        init_step_size: Value used as the "previous" step size on the first
            step.
        c: Factor multiplying the squared gradient norm.
        gamma: Base of the growth factor.
    """
    # from [39]
    def __init__(self,
                 params: Iterable[nn.parameter.Parameter],
                 n_batches_per_epoch: int = 256,
                 init_step_size: float = 1.0,
                 c: float = 1.0,
                 gamma: float = 2.0,
                 ):

        defaults = {"n_batches_per_epoch": n_batches_per_epoch, "init_step_size": init_step_size, "c": c, "gamma": gamma}
        super().__init__(params, defaults)

        self.c = c
        self.gamma = gamma
        self.init_step_size = init_step_size
        self.n_batches_per_epoch = n_batches_per_epoch
        self.state['step_size'] = init_step_size

    @torch.no_grad()
    def step(self, loss):
        """Perform one update using the loss of the current batch.

        Args:
            loss: Loss whose gradients are stored on the parameters.

        Returns:
            The ``loss`` object that was passed in.
        """
        grad_norm = self.compute_grad_terms()

        step_size = loss / (self.c * (grad_norm)**2)
        coeff = self.gamma**(1./self.n_batches_per_epoch)
        step_size = min(coeff * self.state['step_size'], step_size.item())

        # update with step size
        for group in self.param_groups:
            for p in group['params']:
                # Parameters that took no part in the backward pass have no
                # gradient and are left untouched.
                if p.grad is None:
                    continue
                p.data.add_(other=p.grad, alpha=-step_size)

        self.state['step_size'] = step_size

        return loss

    @torch.no_grad()
    def compute_grad_terms(self):
        """Return the joint L2 norm of all available gradients as a tensor."""
        grad_norm = 0.
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                g = p.grad.data.detach()
                grad_norm += torch.sum(torch.mul(g, g))

        # as_tensor covers the case where no parameter had a gradient.
        grad_norm = torch.sqrt(torch.as_tensor(grad_norm))
        return grad_norm

class MomSPS_smooth(torch.optim.Optimizer):
    """Heavy-ball momentum with a smoothed stochastic Polyak step size.

    With ``x_t`` the current and ``x_{t-1}`` the previous value of a
    parameter, each step applies

        ``x_{t+1} = x_t - step * grad + beta * (x_t - x_{t-1})``

    where ``step = min((1 - beta) * loss / (c * ||grad||^2),
    coeff * previous_step)``, ``coeff = gamma ** (1 / n_batches_per_epoch)``
    and ``||grad||`` is the joint L2 norm of all available gradients. The
    momentum term is omitted the first time a parameter is updated.

    Args:
        params: Iterable of parameters to optimise.
        n_batches_per_epoch: Exponent denominator of the growth factor.
        init_step_size: Value used as the "previous" step on the first step.
        c: Factor multiplying the squared gradient norm.
        beta: Momentum coefficient.
        gamma: Base of the growth factor.
    """
    def __init__(
            self,
            params: Iterable[nn.parameter.Parameter],
            n_batches_per_epoch: int = 256,
            init_step_size: float = 1.0,
            c: float = 1.0,
            beta: float = 0.9,
            gamma: float = 2.0,
            ):
        defaults = {"n_batches_per_epoch": n_batches_per_epoch, "init_step_size": init_step_size, "c": c, "beta": beta, "gamma": gamma}
        super().__init__(params, defaults)

        self.c = c
        self.beta = beta
        self.gamma = gamma
        self.coeff = gamma**(1./n_batches_per_epoch)

        self.number_steps = 0
        # Not read anywhere in this class; kept so the stored state keeps
        # the same keys.
        self.state["p"] = 0
        self.state["momsps"] = init_step_size

    @torch.no_grad()
    def step(self, loss=None):
        """Perform one update using the loss of the current batch.

        Args:
            loss: Loss whose gradients are stored on the parameters.

        Returns:
            The ``loss`` object that was passed in.

        Raises:
            ValueError: If ``loss`` is ``None``.
        """
        if loss is None:
            raise ValueError('loss must be provided to compute the step size')

        self.number_steps += 1
        beta = self.beta
        grad_norm = self.compute_grad_terms()

        # The step size depends only on the loss and the global gradient
        # norm, so it is computed once per step: every parameter tensor
        # uses the same value and the growth cap advances once per step.
        momsps = (1-beta) * (loss/(self.c*grad_norm**2))
        momsps = min(momsps.item(), self.coeff * self.state["momsps"])

        for group in self.param_groups:
            for p in group['params']:
                # new_p = x^{t+1}
                # p = x^t
                # old_p = x^{t-1}
                if p.grad is None:
                    # No gradient for this parameter: leave it untouched.
                    continue
                grad = p.grad.data.detach()
                state = self.state[p]

                new_p = p - momsps * grad
                if "p" in state:
                    old_p = state["p"]
                    new_p = new_p + beta * (p-old_p)
                state["p"] = p.detach().clone()

                p.copy_(new_p)

        self.state["momsps"] = momsps

        return loss

    @torch.no_grad()
    def compute_grad_terms(self):
        """Return the joint L2 norm of all available gradients as a tensor."""
        grad_norm = 0.
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                g = p.grad.data.detach()
                grad_norm += torch.sum(torch.mul(g, g))

        # as_tensor covers the case where no parameter had a gradient.
        grad_norm = torch.sqrt(torch.as_tensor(grad_norm))
        return grad_norm

class Naive_MomSPS_smooth(torch.optim.Optimizer):
    """Heavy-ball momentum with an unscaled smoothed Polyak step size.

    With ``x_t`` the current and ``x_{t-1}`` the previous value of a
    parameter, each step applies

        ``x_{t+1} = x_t - step * grad + beta * (x_t - x_{t-1})``

    where ``step = min(loss / (c * ||grad||^2), coeff * previous_step)``,
    ``coeff = gamma ** (1 / n_batches_per_epoch)`` and ``||grad||`` is the
    joint L2 norm of all available gradients. Unlike a ``(1 - beta)``-scaled
    rule, the Polyak candidate is used as is. The momentum term is omitted
    the first time a parameter is updated.

    Args:
        params: Iterable of parameters to optimise.
        n_batches_per_epoch: Exponent denominator of the growth factor.
        init_step_size: Value used as the "previous" step on the first step.
        c: Factor multiplying the squared gradient norm.
        beta: Momentum coefficient.
        gamma: Base of the growth factor.
    """
    def __init__(
            self,
            params: Iterable[nn.parameter.Parameter],
            n_batches_per_epoch: int = 256,
            init_step_size: float = 1.0,
            c: float = 1.0,
            beta: float = 0.9,
            gamma: float = 2.0,
            ):
        defaults = {"n_batches_per_epoch": n_batches_per_epoch, "init_step_size": init_step_size, "c": c, "beta": beta, "gamma": gamma}
        super().__init__(params, defaults)

        self.c = c
        self.beta = beta
        self.gamma = gamma
        self.coeff = gamma**(1./n_batches_per_epoch)

        self.number_steps = 0
        # Not read anywhere in this class; kept so the stored state keeps
        # the same keys.
        self.state["p"] = 0
        self.state["momsps"] = init_step_size

    @torch.no_grad()
    def step(self, loss=None):
        """Perform one update using the loss of the current batch.

        Args:
            loss: Loss whose gradients are stored on the parameters.

        Returns:
            The ``loss`` object that was passed in.

        Raises:
            ValueError: If ``loss`` is ``None``.
        """
        if loss is None:
            raise ValueError('loss must be provided to compute the step size')

        self.number_steps += 1
        beta = self.beta
        grad_norm = self.compute_grad_terms()

        # The step size depends only on the loss and the global gradient
        # norm, so it is computed once per step: every parameter tensor
        # uses the same value and the growth cap advances once per step.
        momsps = (loss/(self.c*grad_norm**2))
        momsps = min(momsps.item(), self.coeff * self.state["momsps"])

        for group in self.param_groups:
            for p in group['params']:
                # new_p = x^{t+1}
                # p = x^t
                # old_p = x^{t-1}
                if p.grad is None:
                    # No gradient for this parameter: leave it untouched.
                    continue
                grad = p.grad.data.detach()
                state = self.state[p]

                new_p = p - momsps * grad
                if "p" in state:
                    old_p = state["p"]
                    new_p = new_p + beta * (p-old_p)
                state["p"] = p.detach().clone()

                p.copy_(new_p)

        self.state["momsps"] = momsps

        return loss

    @torch.no_grad()
    def compute_grad_terms(self):
        """Return the joint L2 norm of all available gradients as a tensor."""
        grad_norm = 0.
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                g = p.grad.data.detach()
                grad_norm += torch.sum(torch.mul(g, g))

        # as_tensor covers the case where no parameter had a gradient.
        grad_norm = torch.sqrt(torch.as_tensor(grad_norm))
        return grad_norm


## Tests

def opt_test(name, model, device, opt_func, seed, trainloader, testloader, criterion, checkpoint, need_loss=False):
    """Train a fresh network with one optimizer and pickle the metrics.

    Besides its arguments the function reads the module-level globals
    ``epochs``, ``path``, ``dset`` and ``batch_size``; they determine the
    number of epochs and the names of the four result files
    (``trainlos``, ``trainacc``, ``testlos``, ``testacc``).

    Args:
        name: Label of the run; used in file names. ``'shb'`` selects
            ``torch.optim.SGD`` with momentum 0.9 and ``'alr_smag_warmup'``
            passes ``warmup=True`` to ``opt_func``; any other name calls
            ``opt_func`` with the parameters only.
        model: Model constructor, called as ``model(weights=None)``.
        device: Device the network is moved to.
        opt_func: Optimizer constructor.
        seed: Seed applied before the network is built and again before
            training starts.
        trainloader: Training batches.
        testloader: Evaluation batches.
        criterion: Loss function.
        checkpoint: Number of batches between two metric records.
        need_loss: Whether the optimizer's ``step`` expects the loss.
    """
    # Seed before building the network so that its initial weights depend
    # only on `seed`, not on whatever consumed random numbers beforehand.
    seed_everything(seed)
    net = model(weights=None)
    net.to(device)
    net.train()

    seed_everything(seed)
    if name == 'shb':
        optimizer = torch.optim.SGD(net.parameters(), momentum=0.9)
    elif name == 'alr_smag_warmup':
        optimizer = opt_func(net.parameters(), warmup=True)
    else:
        optimizer = opt_func(net.parameters())

    file_prefix = (path + 'm_' + model.__name__ + '_set_' + dset + '_bs_' + str(batch_size)
                   + '_e_' + str(epochs) + '_seed_' + str(seed))
    # Create the output directory up front so that a missing directory
    # cannot discard the results of a finished training run.
    out_dir = os.path.dirname(file_prefix)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    trainlos, trainacc, testlos, testacc = train(net,
                                                 optimizer,
                                                 n_epoch=epochs,
                                                 trainloader=trainloader,
                                                 testloader=testloader,
                                                 criterion=criterion,
                                                 checkpoint=checkpoint,
                                                 device=device,
                                                 sps=need_loss)

    results = (
        ('trainlos', trainlos),
        ('trainacc', trainacc),
        ('testlos', testlos),
        ('testacc', testacc),
    )
    for tag, values in results:
        with open(file_prefix + '_' + tag + '_' + name + '.p', "wb") as fh:
            pickle.dump(values, fh)

    print(name+' completed')

# Experiment grid: every dataset x batch size x model x seed combination is
# trained once with each selected optimizer.
seed_list = [42, 123, 1029, 1234, 2048]
criterion = torch.nn.CrossEntropyLoss()

dset_list = ['cifar10', 'cifar100']
model_list = [resnet18, resnet34, densenet121]
bs_list = [128, 256]

epochs = 200
path = 'saved_results/'

# Optimizers to run. Names without an entry in `optimizer_runs` below are
# ignored.
opt_list = ['sgd', 'shb', 'adam',
            'sps_smooth', 'momsps_smooth', 'naive_momsps_smooth',
            'alr_smag_warmup',
            'decsps', 'adasps', 'momdecsps', 'momadasps', 'adagrad']

# (run name, optimizer constructor, whether step() must receive the loss),
# in execution order.
optimizer_runs = [
    ('sgd', torch.optim.SGD, False),
    ('shb', torch.optim.SGD, False),
    ('adam', torch.optim.Adam, False),
    ('sps_smooth', SPS_smooth, True),
    ('momsps_smooth', MomSPS_smooth, True),
    ('naive_momsps_smooth', Naive_MomSPS_smooth, True),
    # Wang ALR SMAG with warmup
    ('alr_smag_warmup', alr_smag, True),
]

for dset in dset_list:
    for batch_size in bs_list:
        trainloader, testloader, num_classes = load_data(dataset=dset, batch_size=batch_size)
        checkpoint = len(trainloader) // 3 + 1

        for model in model_list:
            for seed in seed_list:
                seed_everything(seed)
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                # A torch.device does not compare equal to a plain string,
                # so the device type is inspected instead.
                if device.type == 'cpu':
                    print("Running on CPU - BREAK")
                    break

                for name, opt_func, need_loss in optimizer_runs:
                    if name in opt_list:
                        opt_test(name, model, device, opt_func, seed, trainloader, testloader, criterion, checkpoint, need_loss=need_loss)

                print('Model: '+model.__name__+', Dataset: '+dset+', BS: '+str(batch_size)+', Epochs: '+str(epochs)+', Seed: '+str(seed)+', Finished.')

print("Script FINISHED!!!")
