## Imports

import numpy as np
from sklearn.datasets import load_svmlight_file
import os
import random
import copy
from typing import Iterable
import pickle
from tqdm import tqdm

import torch
from torch import nn


## Setup

def seed_everything(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU generator plus
    the CUDA generators), stores the seed in the ``PYTHONHASHSEED``
    environment variable and switches on the cuDNN ``deterministic`` flag.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True

def train(epochs, net, optimizer, trainloader, criterion, device, sps):
    """Train ``net`` for ``epochs`` passes over ``trainloader``.

    Args:
        epochs: number of passes over ``trainloader``.
        net: model called as ``net(inputs)``.
        optimizer: optimizer exposing ``zero_grad()`` and ``step()``.
        trainloader: sized iterable of ``(inputs, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: device the batches are moved to.
        sps: when truthy, the loss tensor is handed to the optimizer via
            ``optimizer.step(loss=loss)``; otherwise ``step()`` is called
            without arguments.

    Returns:
        ``(losses, accuracies)``: two lists with one entry per epoch, the
        mean batch loss and the accuracy in percent.
    """
    loss_history, acc_history = [], []

    for _ in tqdm(range(epochs)):
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0

        for inputs, labels in trainloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            # Loss-aware optimizers receive the loss as a keyword argument.
            step_kwargs = {'loss': loss} if sps else {}
            optimizer.step(**step_kwargs)

            loss_sum += loss.item()
            # Index of the largest output along dim 1 is the predicted class.
            predicted = torch.max(outputs.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

        loss_history.append(loss_sum / len(trainloader))
        acc_history.append(100 * n_correct / n_seen)

    return loss_history, acc_history


## Optimizers

class alr_smag(torch.optim.Optimizer):
    # from [70]
    """Momentum SGD with a loss-based (Polyak-type) adaptive step size.

    Every ``step`` maintains a running gradient sum
    ``avg_grad = momentum_decay * avg_grad + grad`` (initialised with a copy
    of the first gradient) and moves the parameters along ``-avg_grad`` with

        ``step_size = (loss - fstar) / (c * ||avg_grad||**2 + eps)``

    which is set to 0 when ``loss < fstar`` and is capped by ``eta_max`` when
    ``eta_max`` is not ``None``.

    Args:
        params: iterable of parameters to optimize.
        c: multiplier of the squared averaged-gradient norm in the
            denominator of the step size.
        weight_decay: if non-zero, ``weight_decay * p`` is added (in place)
            to the averaged gradient of ``p`` before the update.
        momentum_decay: decay applied to the running gradient sum.
        eta_max: upper bound on the step size; ``None`` disables the bound
            (and, with it, the warm-up schedule).
        projection_fn: if ``True``, parameters are rescaled after each step
            so that their global L2 norm does not exceed ``max_norm``.
        max_norm: norm bound used by the projection.
        adapt_flag: step-size rule; only ``'constant'`` is implemented.
        fstar_flag: if truthy, ``fstar`` is read from
            ``batch['meta']['fstar'].mean()``; otherwise ``fstar = 0``.
        eps: constant added to the denominator of the step size.
        warmup: if ``True`` (and ``eta_max`` is not ``None``), ``eta_max`` is
            overwritten on every step with
            ``warmup_tau * min(warmup_tau2 * step, 1)``.
        warmup_tau: see ``warmup``.
        warmup_tau2: see ``warmup``.
        centralize_grad: if true, gradients with more than one dimension
            have their mean over all but the first dimension subtracted
            (in place on ``p.grad``).
        centralize_grad_norm: same centralization, applied in place to the
            averaged gradients while their norm is computed.
    """

    def __init__(self,
                 params,
                 c=1,
                 weight_decay=0,
                 momentum_decay=0.9,
                 eta_max=None,
                 projection_fn=False,
                 max_norm=500,
                 adapt_flag='constant',
                 fstar_flag=None,
                 eps=1e-5,
                 warmup=False,
                 warmup_tau=0.1,
                 warmup_tau2=1e-4,
                 centralize_grad=False,
                 centralize_grad_norm=False
                ):
        params = list(params)
        super().__init__(params, {})
        self.params = params
        self.c = c
        self.weight_decay = weight_decay
        self.decay = momentum_decay
        self.eta_max = eta_max
        self.projection = projection_fn
        self.max_norm = max_norm
        self.adapt_flag = adapt_flag
        self.fstar_flag = fstar_flag
        self.eps = eps
        self.warmup = warmup
        self.tau = warmup_tau
        self.tau_2 = warmup_tau2
        self.centralize_grad_norm = centralize_grad_norm
        self.centralize_grad = centralize_grad
        # Bookkeeping is stored under string keys of the optimizer state.
        self.state['step'] = 0
        self.state['step_size'] = eta_max
        self.state['n_forwards'] = 0
        self.state['n_backwards'] = 0

    def getstate(self):
        """Return the step size used by the most recent ``step``."""
        return self.state['step_size']

    @torch.autograd.no_grad()
    def l2_projection(self):
        """Rescale all parameters in place so their global L2 norm is at most ``max_norm``."""
        total_norm = 0
        for group in self.param_groups:
            for p in group['params']:
                total_norm += torch.sum(torch.mul(p, p))
        total_norm = total_norm.clone().detach()
        total_norm = float(torch.sqrt(total_norm))

        if total_norm > self.max_norm:
            ratio = self.max_norm / total_norm
            for group in self.param_groups:
                for p in group['params']:
                    p *= ratio

    def step(self, closure=None, loss=None, batch=None):
        """Perform one optimization step.

        Exactly one of ``closure`` and ``loss`` must be given. Gradients are
        read from ``p.grad``; they must already have been computed.

        Args:
            closure: callable returning the loss; used when ``loss`` is None.
            loss: current loss (tensor or number).
            batch: only read when ``fstar_flag`` is truthy, in which case
                ``batch['meta']['fstar']`` supplies the optimal value.

        Returns:
            The loss as a Python float.

        Raises:
            ValueError: if neither ``loss`` nor ``closure`` is given, if
                ``adapt_flag`` is not supported, or if the first parameter
                tensor contains NaNs after the update.
        """
        if loss is None and closure is None:
            raise ValueError('please specify either closure or loss')

        # Without this check an unsupported flag would skip the update and
        # then fail with an UnboundLocalError when the state is recorded.
        if self.adapt_flag not in ['constant']:
            raise ValueError("unsupported adapt_flag: only 'constant' is implemented")

        if loss is not None:
            if not isinstance(loss, torch.Tensor):
                loss = torch.tensor(loss)
        if loss is None:
            loss = closure()
        else:
            assert closure is None, 'if loss is provided then closure should be None'

        # get fstar
        if self.fstar_flag:
            fstar = float(batch['meta']['fstar'].mean())
        else:
            fstar = 0.

        # compute current gradients and averaged gradient
        grad_current = self.get_grad_list(self.params, centralize_grad=self.centralize_grad)
        if self.state['step'] == 0:
            avg_grad = copy.deepcopy(grad_current)
        else:
            avg_grad = self.state['avg_grad']
            avg_grad = self.get_avg_grad(grad_current, avg_grad, self.decay, centralize_grad=False)

        # update search direction
        avg_grad_norm = self.compute_grad_norm(avg_grad, centralize_grad_norm=self.centralize_grad_norm)
        if self.adapt_flag in ['constant']:
            # adjust the step size
            step_size = (loss - fstar) / (self.c * avg_grad_norm**2 + self.eps)
            if loss < fstar:
                step_size = 0.
            else:
                if self.eta_max is None:
                    step_size = step_size.item()
                else:
                    if self.warmup is True:
                        # The bound ramps up linearly with the step count
                        # until it reaches tau.
                        self.eta_max = self.tau * min(self.tau_2*self.state['step'], 1)
                    else:
                        self.eta_max = self.eta_max
                    step_size = min(self.eta_max, step_size.item())

            self.sgd_update(self.params, step_size, avg_grad, self.weight_decay)

        if self.projection is True:
            self.l2_projection()

        # update state with metrics
        self.state['n_forwards'] += 1
        self.state['n_backwards'] += 1
        self.state['step'] += 1
        self.state['step_size'] = step_size
        self.state['avg_grad'] = avg_grad

        # Only the first parameter tensor is inspected.
        if torch.isnan(self.params[0]).sum() > 0:
            raise ValueError('Got NaNs')

        return float(loss)

    @torch.no_grad()
    def get_grad_list(self, params, centralize_grad=False):
        """Collect the gradients of ``params`` in order.

        A parameter without a gradient contributes the float placeholder
        ``0.`` so that the result has one entry per parameter. With
        ``centralize_grad`` the gradients of rank > 1 are centralized in
        place.
        """
        grad_list = []
        for p in params:
            g = p.grad
            if g is None:
                g = 0.
            else:
                g = p.grad.data
                if len(list(g.size()))>1 and centralize_grad:
                    # centralize grads
                    g.add_(-g.mean(dim = tuple(range(1,len(list(g.size())))), keepdim = True))

            grad_list += [g]

        return grad_list

    @torch.no_grad()
    def get_avg_grad(self, grad, avg_grad, decay, centralize_grad=False):
        """Update the running gradient sums in place and return them.

        The returned list always has one entry per entry of ``grad`` so that
        it stays aligned with the parameter list; entries are tensors, or
        the float placeholder ``0.`` for parameters that have not produced a
        gradient yet.

        Args:
            grad: current gradients as returned by ``get_grad_list``.
            avg_grad: running sums from the previous step.
            decay: momentum decay factor.
            centralize_grad: centralize the updated sums of rank > 1.
        """
        avg_grad_list = []
        for v, g in zip(avg_grad, grad):
            if g is None or (isinstance(g, float) and g == 0.):
                # No gradient this step: treat it as a zero gradient but keep
                # the slot, otherwise later entries would be paired with the
                # wrong parameter in sgd_update.
                if isinstance(v, torch.Tensor):
                    v.mul_(decay)
                avg_grad_list += [v]
                continue

            if isinstance(v, torch.Tensor):
                v.mul_(decay).add_(g, alpha=1)
            else:
                # First gradient seen for this parameter: start the running
                # sum from a copy, as is done for all parameters on step 0.
                v = g.clone()

            if len(list(g.size()))>1 and centralize_grad:
                # centralize grads
                v.add_(-v.mean(dim = tuple(range(1,len(list(v.size())))), keepdim = True))

            avg_grad_list += [v]

        return avg_grad_list

    @torch.no_grad()
    def compute_grad_norm(self, grad_list, centralize_grad_norm=False):
        """Return the global L2 norm of the tensors in ``grad_list``.

        ``None`` and float-zero placeholders are skipped. With
        ``centralize_grad_norm`` the tensors of rank > 1 are centralized in
        place before being accumulated.
        """
        grad_norm = 0.
        for g in grad_list:
            if g is None or (isinstance(g, float) and g == 0.):
                continue

            if g.dim() > 1 and centralize_grad_norm:
                # centralize grads
                g.add_(-g.mean(dim = tuple(range(1,g.dim())), keepdim = True))

            grad_norm += torch.sum(torch.mul(g, g))
        grad_norm = grad_norm.clone().detach()
        grad_norm = torch.sqrt(grad_norm)

        return grad_norm

    @torch.no_grad()
    def sgd_update(self, params, step_size, grad_current, weight_decay):
        """Apply ``p <- p - step_size * g`` for each ``(p, g)`` pair.

        Float-zero placeholders are skipped. When ``weight_decay`` is
        non-zero, ``weight_decay * p`` is first added to ``g`` in place.
        """
        for p, g in zip(params, grad_current):
            if isinstance(g, float) and g == 0.:
                continue
            if weight_decay!= 0:
                g.add_(p.data, alpha=weight_decay)
            p.data.add_(other=g, alpha=- step_size)
    
class SPS_smooth(torch.optim.Optimizer):
    # from [39]
    """SGD with a smoothed stochastic Polyak step size.

    The step size is ``loss / (c * ||grad||**2)``, capped at
    ``gamma**(1 / n_batches_per_epoch)`` times the previous step size; the
    first cap is computed from ``init_step_size``.
    """

    def __init__(self,
                 params: Iterable[nn.parameter.Parameter],
                 n_batches_per_epoch: int = 256,
                 init_step_size: float = 1.0,
                 c: float = 1.0,
                 gamma: float = 2.0,
                 ):
        super().__init__(
            params,
            {
                "n_batches_per_epoch": n_batches_per_epoch,
                "init_step_size": init_step_size,
                "c": c,
                "gamma": gamma,
            },
        )

        self.n_batches_per_epoch = n_batches_per_epoch
        self.init_step_size = init_step_size
        self.gamma = gamma
        self.c = c
        # The previous step size lives in the optimizer state.
        self.state['step_size'] = init_step_size

    def step(self, loss):
        """Take one step using ``loss`` and the gradients in ``p.grad``.

        Returns the ``loss`` object that was passed in.
        """
        polyak = loss / (self.c * self.compute_grad_terms() ** 2)
        growth = self.gamma ** (1. / self.n_batches_per_epoch)
        lr = min(growth * self.state['step_size'], polyak.item())

        # Plain gradient step with the chosen step size.
        for param in (q for grp in self.param_groups for q in grp['params']):
            param.data.add_(other=param.grad, alpha=-lr)

        self.state['step_size'] = lr

        return loss

    @torch.no_grad()
    def compute_grad_terms(self):
        """Return the L2 norm of all gradients taken together."""
        squared_total = 0.
        for grp in self.param_groups:
            for param in grp['params']:
                flat = param.grad.data.detach()
                squared_total = squared_total + torch.sum(torch.mul(flat, flat))

        return torch.sqrt(squared_total)

class MomSPS_smooth(torch.optim.Optimizer):
    """Heavy-ball momentum with a smoothed stochastic Polyak step size.

    Each step applies

        ``x_{t+1} = x_t - lr * grad + beta * (x_t - x_{t-1})``

    (the momentum term is absent on the first step) with

        ``lr = min((1 - beta) * loss / (c * ||grad||**2), coeff * lr_prev)``

    where ``coeff = gamma**(1 / n_batches_per_epoch)`` and ``lr_prev`` starts
    at ``init_step_size``. ``||grad||`` is the L2 norm over all parameters.

    Args:
        params: iterable of parameters to optimize.
        n_batches_per_epoch: exponent denominator of the growth factor.
        init_step_size: value of ``lr_prev`` on the first step.
        c: multiplier of the squared gradient norm.
        beta: momentum coefficient.
        gamma: base of the growth factor.
    """

    def __init__(
            self,
            params: Iterable[nn.parameter.Parameter],
            n_batches_per_epoch: int = 256,
            init_step_size: float = 1.0,
            c: float = 1.0,
            beta: float = 0.9,
            gamma: float = 2.0,
            ):
        defaults = {"n_batches_per_epoch": n_batches_per_epoch, "init_step_size": init_step_size, "c": c, "beta": beta, "gamma": gamma}
        super().__init__(params, defaults)

        self.c = c
        self.beta = beta
        self.gamma = gamma
        self.coeff = gamma**(1./n_batches_per_epoch)

        self.number_steps = 0
        # Kept for backward compatibility of the state layout; step() only
        # reads the per-parameter entries self.state[p]["p"].
        self.state["p"] = 0
        # Previous step size, used for the growth cap.
        self.state["momsps"] = init_step_size

    @torch.no_grad()
    def step(self, loss=None):
        """Take one step using ``loss`` and the gradients in ``p.grad``.

        Args:
            loss: current loss; required despite the ``None`` default.

        Returns:
            The ``loss`` object that was passed in.
        """
        self.number_steps += 1
        beta = self.beta
        grad_norm = self.compute_grad_terms()

        # The step size is a property of the iteration, not of a parameter:
        # compute it once so every tensor moves with the same step size and
        # the growth cap advances once per step.
        momsps = (1-beta) * (loss/(self.c*grad_norm**2))
        momsps = min(momsps.item(), self.coeff * self.state["momsps"])

        for group in self.param_groups:
            for p in group['params']:
                # new_p = x^{t+1}
                # p = x^t
                # old_p = x^{t-1}
                grad = p.grad.data.detach()
                state = self.state[p]

                if self.number_steps == 1:
                    state["p"] = p.detach().clone()
                    new_p = p - momsps * grad
                else:
                    old_p = state["p"]
                    state["p"] = p.detach().clone()
                    new_p = p - momsps * grad + beta * (p-old_p)

                p.copy_(new_p)

        self.state["momsps"] = momsps

        return loss

    @torch.no_grad()
    def compute_grad_terms(self):
        """Return the L2 norm of all gradients taken together."""
        grad_norm = 0.
        for group in self.param_groups:
            for p in group['params']:
                g = p.grad.data.detach()
                grad_norm += torch.sum(torch.mul(g, g))

        grad_norm = torch.sqrt(grad_norm)
        return grad_norm

class Naive_MomSPS_smooth(torch.optim.Optimizer):
    """Heavy-ball momentum with an unscaled smoothed Polyak step size.

    Each step applies

        ``x_{t+1} = x_t - lr * grad + beta * (x_t - x_{t-1})``

    (the momentum term is absent on the first step) with

        ``lr = min(loss / (c * ||grad||**2), coeff * lr_prev)``

    where ``coeff = gamma**(1 / n_batches_per_epoch)`` and ``lr_prev`` starts
    at ``init_step_size``. Unlike ``lr`` scaled by ``(1 - beta)``, the Polyak
    ratio is used as is. ``||grad||`` is the L2 norm over all parameters.

    Args:
        params: iterable of parameters to optimize.
        n_batches_per_epoch: exponent denominator of the growth factor.
        init_step_size: value of ``lr_prev`` on the first step.
        c: multiplier of the squared gradient norm.
        beta: momentum coefficient.
        gamma: base of the growth factor.
    """

    def __init__(
            self,
            params: Iterable[nn.parameter.Parameter],
            n_batches_per_epoch: int = 256,
            init_step_size: float = 1.0,
            c: float = 1.0,
            beta: float = 0.9,
            gamma: float = 2.0,
            ):
        defaults = {"n_batches_per_epoch": n_batches_per_epoch, "init_step_size": init_step_size, "c": c, "beta": beta, "gamma": gamma}
        super().__init__(params, defaults)

        self.c = c
        self.beta = beta
        self.gamma = gamma
        self.coeff = gamma**(1./n_batches_per_epoch)

        self.number_steps = 0
        # Kept for backward compatibility of the state layout; step() only
        # reads the per-parameter entries self.state[p]["p"].
        self.state["p"] = 0
        # Previous step size, used for the growth cap.
        self.state["momsps"] = init_step_size

    @torch.no_grad()
    def step(self, loss=None):
        """Take one step using ``loss`` and the gradients in ``p.grad``.

        Args:
            loss: current loss; required despite the ``None`` default.

        Returns:
            The ``loss`` object that was passed in.
        """
        self.number_steps += 1
        beta = self.beta
        grad_norm = self.compute_grad_terms()

        # The step size is a property of the iteration, not of a parameter:
        # compute it once so every tensor moves with the same step size and
        # the growth cap advances once per step.
        momsps = (loss/(self.c*grad_norm**2))
        momsps = min(momsps.item(), self.coeff * self.state["momsps"])

        for group in self.param_groups:
            for p in group['params']:
                # new_p = x^{t+1}
                # p = x^t
                # old_p = x^{t-1}
                grad = p.grad.data.detach()
                state = self.state[p]

                if self.number_steps == 1:
                    state["p"] = p.detach().clone()
                    new_p = p - momsps * grad
                else:
                    old_p = state["p"]
                    state["p"] = p.detach().clone()
                    new_p = p - momsps * grad + beta * (p-old_p)

                p.copy_(new_p)

        self.state["momsps"] = momsps

        return loss

    @torch.no_grad()
    def compute_grad_terms(self):
        """Return the L2 norm of all gradients taken together."""
        grad_norm = 0.
        for group in self.param_groups:
            for p in group['params']:
                g = p.grad.data.detach()
                grad_norm += torch.sum(torch.mul(g, g))

        grad_norm = torch.sqrt(grad_norm)
        return grad_norm

# Decreasing

class DecSPS(torch.optim.Optimizer):
    """SGD with a decreasing stochastic Polyak step size.

    With ``k`` the 1-based step counter and ``sps = loss / ||grad||**2``:

    * first step: ``lr = min(sps / c, gamma_max)``;
    * afterwards: ``lr = min(sps / c, sqrt(k) * lr_prev) / sqrt(k + 1)``.
    """

    def __init__(self,
                 params: Iterable[nn.parameter.Parameter],
                 c: float = 1.0,
                 gamma_max: float = 10.0,
                 ):
        super().__init__(params, {"c": c, "gamma_max": gamma_max})

        self.gamma_max = gamma_max
        self.c = c
        self.ss = 0  # step size used by the previous step
        self.number_steps = 0

    def step(self, loss):
        """Take one step using ``loss`` and the gradients in ``p.grad``.

        Returns the ``loss`` object that was passed in.
        """
        self.number_steps += 1
        k = self.number_steps
        polyak = loss / self.compute_grad_terms() ** 2
        scaled = polyak.item() / self.c

        if k == 1:
            lr = min(scaled, self.gamma_max)
        else:
            lr = min(scaled, np.sqrt(k) * self.ss) / np.sqrt(k + 1)

        self.ss = lr

        # Plain gradient step with the chosen step size.
        for param in (q for grp in self.param_groups for q in grp['params']):
            param.data.add_(other=param.grad, alpha=-lr)

        return loss

    @torch.no_grad()
    def compute_grad_terms(self):
        """Return the L2 norm of all gradients taken together."""
        squared_total = 0.
        for grp in self.param_groups:
            for param in grp['params']:
                flat = param.grad.data.detach()
                squared_total = squared_total + torch.sum(torch.mul(flat, flat))

        return torch.sqrt(squared_total)

class AdaSPS(torch.optim.Optimizer):
    """SGD with an AdaGrad-style normalised stochastic Polyak step size.

    The candidate step size is
    ``(loss / ||grad||**2) / sqrt(sum of all losses seen so far) / c``;
    from the second step on it is additionally capped by the previous step
    size, so the step size never increases.
    """

    def __init__(self,
                 params: Iterable[nn.parameter.Parameter],
                 c: float = 1.0,
                 ):
        super().__init__(params, {"c": c})

        self.c = c
        self.sum = 0  # running sum of the loss values
        self.ss = 0  # step size used by the previous step
        self.number_steps = 0

    def step(self, loss):
        """Take one step using ``loss`` and the gradients in ``p.grad``.

        Returns the ``loss`` object that was passed in.
        """
        self.number_steps += 1
        self.sum += loss.item()
        polyak = loss / self.compute_grad_terms() ** 2
        candidate = (polyak / np.sqrt(self.sum)).item() / self.c

        lr = candidate if self.number_steps == 1 else min(candidate, self.ss)
        self.ss = lr

        # Plain gradient step with the chosen step size.
        for param in (q for grp in self.param_groups for q in grp['params']):
            param.data.add_(other=param.grad, alpha=-lr)

        return loss

    @torch.no_grad()
    def compute_grad_terms(self):
        """Return the L2 norm of all gradients taken together."""
        squared_total = 0.
        for grp in self.param_groups:
            for param in grp['params']:
                flat = param.grad.data.detach()
                squared_total = squared_total + torch.sum(torch.mul(flat, flat))

        return torch.sqrt(squared_total)

class MomDecSPS(torch.optim.Optimizer):
    """Heavy-ball momentum with a decreasing stochastic Polyak step size.

    With ``k`` the 1-based step counter and
    ``scaled = (1 - beta) * (loss / ||grad||**2) / c``:

    * first step: ``lr = min(scaled, gamma_max)``;
    * afterwards: ``lr = min(scaled, sqrt(k) * lr_prev) / sqrt(k + 1)``.

    The parameters then move by
    ``x_{t+1} = x_t - lr * grad + beta * (x_t - x_{t-1})``; the momentum term
    is absent on the first step.
    """

    def __init__(self,
                 params: Iterable[nn.parameter.Parameter],
                 c: float = 1.0,
                 gamma_max: float = 10.0,
                 beta: float = 0.9,
                 ):
        super().__init__(params, {"c": c, "gamma_max": gamma_max, "beta": beta})

        self.beta = beta
        self.gamma_max = gamma_max
        self.c = c
        self.ss = 0  # step size used by the previous step
        self.number_steps = 0

    def step(self, loss):
        """Take one step using ``loss`` and the gradients in ``p.grad``.

        Returns the ``loss`` object that was passed in.
        """
        self.number_steps += 1
        k = self.number_steps
        polyak = loss / self.compute_grad_terms() ** 2
        scaled = (1 - self.beta) * polyak.item() / self.c

        if k == 1:
            step_size = min(scaled, self.gamma_max)
        else:
            step_size = min(scaled, np.sqrt(k) * self.ss) / np.sqrt(k + 1)

        self.ss = step_size

        first_step = k == 1
        for grp in self.param_groups:
            for param in grp['params']:
                grad = param.grad.data.detach()
                slot = self.state[param]

                # Fetch x^{t-1} (if any) before overwriting it with x^t.
                previous = None if first_step else slot["p"]
                slot["p"] = param.detach().clone()

                if first_step:
                    updated = param - step_size * grad
                else:
                    updated = param - step_size * grad + self.beta * (param - previous)

                with torch.no_grad():
                    param.copy_(updated)

        return loss

    @torch.no_grad()
    def compute_grad_terms(self):
        """Return the L2 norm of all gradients taken together."""
        squared_total = 0.
        for grp in self.param_groups:
            for param in grp['params']:
                flat = param.grad.data.detach()
                squared_total = squared_total + torch.sum(torch.mul(flat, flat))

        return torch.sqrt(squared_total)

class MomAdaSPS(torch.optim.Optimizer):
    """Heavy-ball momentum with an AdaGrad-style normalised Polyak step size.

    The candidate step size is
    ``(1 - beta) * (loss / ||grad||**2) / sqrt(sum of all losses so far) / c``;
    from the second step on it is capped by the previous step size. The
    parameters then move by
    ``x_{t+1} = x_t - lr * grad + beta * (x_t - x_{t-1})``; the momentum term
    is absent on the first step.
    """

    def __init__(self,
                 params: Iterable[nn.parameter.Parameter],
                 c: float = 1.0,
                 beta: float = 0.9,
                 ):
        super().__init__(params, {"c": c, "beta": beta})

        self.beta = beta
        self.c = c
        self.sum = 0  # running sum of the loss values
        self.ss = 0  # step size used by the previous step
        self.number_steps = 0

    def step(self, loss):
        """Take one step using ``loss`` and the gradients in ``p.grad``.

        Returns the ``loss`` object that was passed in.
        """
        self.number_steps += 1
        self.sum += loss.item()
        polyak = loss / self.compute_grad_terms() ** 2
        normalised = polyak / np.sqrt(self.sum)
        candidate = (1 - self.beta) * normalised.item() / self.c

        first_step = self.number_steps == 1
        step_size = candidate if first_step else min(candidate, self.ss)
        self.ss = step_size

        for grp in self.param_groups:
            for param in grp['params']:
                grad = param.grad.data.detach()
                slot = self.state[param]

                # Fetch x^{t-1} (if any) before overwriting it with x^t.
                previous = None if first_step else slot["p"]
                slot["p"] = param.detach().clone()

                if first_step:
                    updated = param - step_size * grad
                else:
                    updated = param - step_size * grad + self.beta * (param - previous)

                with torch.no_grad():
                    param.copy_(updated)

        return loss

    @torch.no_grad()
    def compute_grad_terms(self):
        """Return the L2 norm of all gradients taken together."""
        squared_total = 0.
        for grp in self.param_groups:
            for param in grp['params']:
                flat = param.grad.data.detach()
                squared_total = squared_total + torch.sum(torch.mul(flat, flat))

        return torch.sqrt(squared_total)


## Tests

def opt_test(name, device, opt_func, seed, trainloader, criterion, need_loss=False):
    """Train a linear classifier with one optimizer and pickle its curves.

    Reads the module-level globals ``X``, ``y``, ``epochs``, ``path``,
    ``dset_name`` and ``batch_size``: the model is
    ``nn.Linear(X.shape[1], max(y)+1)`` and the per-epoch training loss and
    accuracy lists are written to two pickle files under ``path`` whose
    names encode the dataset, batch size, epoch count, seed and ``name``.

    Args:
        name: optimizer label; ``'sgd'``, ``'adam'`` and ``'shb'`` select
            fixed torch optimizers with ``lr=0.001``, ``'alr_smag_warmup'``
            builds ``opt_func`` with ``warmup=True``, anything else builds
            ``opt_func(net.parameters())``.
        device: device the model and batches are moved to.
        opt_func: optimizer constructor (ignored for the fixed labels).
        seed: seed passed to ``seed_everything``.
        trainloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function.
        need_loss: forwarded to ``train`` as its ``sps`` flag, i.e. whether
            the optimizer's ``step`` must receive the loss.
    """
    net = nn.Linear(X.shape[1], max(y)+1)
    net.to(device)
    net.train()

    # NOTE(review): the seed is applied after the model has been created, so
    # the initial weights come from whatever RNG state the caller left
    # behind - confirm whether seeding should precede model construction.
    seed_everything(seed)
    if name == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(), lr=0.001)
    elif name == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    elif name == 'shb':
        optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    elif name == 'alr_smag_warmup':
        optimizer = opt_func(net.parameters(), warmup=True)
    else:
        optimizer = opt_func(net.parameters())

    trainlos, trainacc = train(epochs, net, optimizer, trainloader, criterion, device, need_loss)

    # Common part of both output file names.
    prefix = path+'m_logreg_set_'+dset_name+'_bs_'+str(batch_size)+'_e_'+str(epochs)+'_seed_'+str(seed)

    # Context managers make sure each file is flushed and closed.
    with open(prefix+'_trainlos_'+name+'.p', "wb") as los_file:
        pickle.dump(trainlos, los_file)
    with open(prefix+'_trainacc_'+name+'.p', "wb") as acc_file:
        pickle.dump(trainacc, acc_file)

    print(name+' completed')

## Datasets shapes
# 'aloi': (108000, 128)
# 'dna': (2000, 180)
# 'glass': (214, 9)
# 'iris': (150, 4)
# 'letter': (15000, 16)
# 'pendigits': (7494, 16)
# 'sensorless': (58509, 48)
# 'smallNORB': (24300, 18432)
# 'usps': (7291, 256)
# 'vehicle': (846, 18)
# 'vowel': (528, 10)
# 'wine': (178, 13)

# Datasets to benchmark; each name is the stem of a LIBSVM file that the
# driver loop loads from 'data/libsvm/<name>.scale'.
dset_list = ['aloi', 'dna', 'glass', 'iris', 'letter', 'pendigits', 'sensorless', 'smallNORB', 'usps', 'vehicle', 'vowel', 'wine']

# Batch sizes to run for each dataset.
bs_dict = {'aloi': [16], 
           'dna': [10, 16, 20, 128, 200], 
           'glass': [10, 16, 20, 25, 32], 
           'iris': [10, 16, 20], 
           'letter': [10, 16, 20, 128, 256, 1500], 
           'pendigits': [10, 16, 20, 128, 256, 750], 
           'sensorless': [16], 
           'smallNORB': [16], 
           'usps': [10, 16, 20, 128, 256, 730], 
           'vehicle': [10, 16, 20, 85], 
           'vowel': [10, 16, 20, 52], 
           'wine': [10, 16, 20]}
# Generally 10, 16, 20, 128, 256 and roughly 10% of the dataset size (where
# the dataset is large enough); the largest datasets (aloi, sensorless,
# smallNORB) only use 16.

# Optimizers that are enabled; the driver runs an optimizer only if its
# label appears in this list.
opt_list = ['sgd', 'shb', 'adam', 
            'sps_smooth', 'momsps_smooth', 'naive_momsps_smooth', 'alr_smag_warmup', 
            'decsps', 'momdecsps', 'adasps', 'momadasps', 'adagrad']

# Every (dataset, batch size, optimizer) combination is repeated per seed.
seed_list = [42, 123, 1029, 1234, 2048]
# Loss function shared by all runs.
criterion = torch.nn.CrossEntropyLoss()
# Number of training epochs per run.
epochs = 100
# Prefix (directory) of the pickled result files written by opt_test.
path = 'saved_data/convex_results/'


# Benchmark driver: for every dataset, batch size and seed, train the linear
# model once per enabled optimizer and let opt_test pickle the curves.

# (label, optimizer constructor, whether step() must receive the loss),
# listed in the order in which the optimizers are run.
optimizer_specs = [
    ('sgd', torch.optim.SGD, False),
    ('shb', torch.optim.SGD, False),
    ('adam', torch.optim.Adam, False),
    ('sps_smooth', SPS_smooth, True),
    ('momsps_smooth', MomSPS_smooth, True),
    ('naive_momsps_smooth', Naive_MomSPS_smooth, True),
    ('alr_smag_warmup', alr_smag, True),
    ('decsps', DecSPS, True),
    ('momdecsps', MomDecSPS, True),
    ('adasps', AdaSPS, True),
    ('momadasps', MomAdaSPS, True),
    ('adagrad', torch.optim.Adagrad, False),
]

for dset_name in dset_list:
    data = load_svmlight_file('data/libsvm/'+dset_name+'.scale')
    X = data[0].toarray()
    y = data[1].astype(int)
    y = y - min(y) # classes start from 0
    print(dset_name)

    # The tensors do not depend on the batch size, so build them once.
    dset = torch.utils.data.TensorDataset(torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.long))

    for batch_size in bs_dict[dset_name]:
        trainloader = torch.utils.data.DataLoader(dset, batch_size=batch_size)

        for seed in seed_list:
            seed_everything(seed)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # Compare the device *type*: a torch.device never compares equal
            # to a plain string, so `device == 'cpu'` was always False and
            # this guard could not fire.
            if device.type == 'cpu':
                print("Running on CPU - BREAK")
                break

            for name, opt_func, need_loss in optimizer_specs:
                if name in opt_list:
                    opt_test(name, device, opt_func, seed, trainloader, criterion, need_loss=need_loss)

            print('Dataset: '+dset_name+', BS: '+str(batch_size)+', Epochs: '+str(epochs)+', Seed: '+str(seed)+', Finished.')

print("Script FINISHED!!!")
