import time
import csv
import math
import os
import argparse
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required

from backend import Model
from datasets import Dataset
from utils import *

def ParseArgs():
    """Parse the training script's command-line options from ``sys.argv``.

    Options:
        --lmbda         float, default 1e-3 (regularization weight)
        --max_epoch     int, default 300
        --backend       str, required
        --dataset_name  str, required
        --num_classes   int, default 10
        --batch_size    int, default 128

    Returns:
        argparse.Namespace holding one attribute per option above.
    """
    # (flag, keyword arguments) pairs, registered in this exact order.
    option_specs = (
        ('--lmbda', dict(default=1e-3, type=float, help='weighting parameters')),
        ('--max_epoch', dict(default=300, type=int)),
        ('--backend', dict(required=True, type=str)),
        ('--dataset_name', dict(required=True, type=str)),
        ('--num_classes', dict(default=10, type=int)),
        ('--batch_size', dict(default=128, type=int)),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

class ProxSVRG(Optimizer):
    """Proximal SVRG-style optimizer driven step by step from the training loop.

    The optimizer never runs forward/backward passes itself. The caller
    computes gradients and then asks the optimizer to record them or to
    advance the inner iterate. Per-parameter state kept in ``self.state[p]``:

        hat_x       snapshot of the weights taken at the start of the epoch
        hat_v       accumulated gradient at the snapshot divided by the
                    number of batches
        xs_end      current inner iterate
        xs_sum      running sum of the inner iterates of this epoch
        i           number of inner iterations taken this epoch
        grad_f      last minibatch gradient saved by ``save_grad_f``
        grad_f_hat  last minibatch gradient saved by ``save_grad_f_hat``

    Expected call order per epoch (as used by the driver in this file):
    ``save_weight_and_grad_init_xs`` -> for each batch
    (``set_weights_from_xs``, backward, ``save_grad_f``,
    ``set_weights_from_hat_x``, backward, ``save_grad_f_hat``,
    ``update_xs``) -> ``step``.

    Args:
        params: iterable of parameters (or parameter-group dicts) to optimize.
        alpha: step size; must be non-negative.
        lmbda: regularization weight; must be non-negative.

    Raises:
        ValueError: if ``alpha`` or ``lmbda`` is negative.
    """

    def __init__(self, params, alpha=required, lmbda=required):
        if alpha is not required and alpha < 0.0:
            raise ValueError("Invalid alpha: {}".format(alpha))

        if lmbda is not required and lmbda < 0.0:
            raise ValueError("Invalid lambda: {}".format(lmbda))

        defaults = dict(alpha=alpha, lmbda=lmbda)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    def prox_mapping_l1(self, x, grad_f, lmbda, alpha):
        """Return the displacement ``d`` of an l1 proximal-gradient step.

        The trial point is the element-wise soft-thresholding of
        ``x - alpha * grad_f`` with threshold ``alpha * lmbda``; the
        returned tensor is ``trial_point - x``.
        """
        trial_x = torch.zeros_like(x)
        pos_shrink = x - alpha * grad_f - alpha * lmbda
        neg_shrink = x - alpha * grad_f + alpha * lmbda
        pos_shrink_idx = (pos_shrink > 0)
        neg_shrink_idx = (neg_shrink < 0)
        # Entries selected by neither mask stay at zero.
        trial_x[pos_shrink_idx] = pos_shrink[pos_shrink_idx]
        trial_x[neg_shrink_idx] = neg_shrink[neg_shrink_idx]
        return trial_x - x

    def prox_mapping_group(self, x, grad_f, lmbda, alpha):
        """Return the displacement of a group-sparsity proximal step.

        Proximal mapping for Omega(x) = sum_{g in G} ||[x]_g||_2, where each
        slice ``x[k]`` along the first dimension is one group. Works for any
        tensor with at least one dimension; for 4-D input each group is one
        ``x[k, :, :, :]`` slice.

        Args:
            x: current iterate.
            grad_f: gradient estimate, same shape as ``x``.
            lmbda: regularization weight.
            alpha: step size.

        Returns:
            Tensor ``delta`` of the same shape as ``x`` such that
            ``x + delta`` is the next iterate.
        """
        trial_x = x - alpha * grad_f
        num_groups = x.shape[0]
        group_norms = torch.norm(trial_x.reshape(num_groups, -1), p=2, dim=1)
        # 1e-6 guards against division by zero for all-zero groups.
        coeffs = 1.0 - (alpha * lmbda) / (group_norms + 1e-6)
        # Groups whose norm is below the threshold are zeroed entirely.
        coeffs = torch.clamp(coeffs, min=0.0)
        # Reshape to (num_groups, 1, ..., 1) so it broadcasts over each group.
        coeffs = coeffs.view(num_groups, *([1] * (x.dim() - 1)))
        return coeffs * trial_x - x

    def step(self, closure=None):
        """Set every parameter to the average of this epoch's inner iterates.

        Parameters without a gradient, without recorded state, or with zero
        inner iterations are left untouched (dividing by a zero count would
        otherwise write NaN/inf into the weights).

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                num_iterates = state.get('i', 0)
                if num_iterates == 0:
                    continue

                p.data.copy_(state['xs_sum'] / num_iterates)
        return loss

    def save_weight_and_grad_init_xs(self, num_batches):
        """Take the epoch snapshot; invoked at the beginning of the epoch.

        Expects ``p.grad`` to hold the gradients accumulated over
        ``num_batches`` minibatches. Stores ``p.grad / num_batches`` as
        ``hat_v``, the current weights as ``hat_x``, and resets the inner
        iterate (``xs_end = hat_x``), its running sum and the counter ``i``.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['i'] = 0

                if 'hat_v' not in state:
                    state['hat_v'] = torch.zeros_like(p.grad.data)
                state['hat_v'].copy_(p.grad.data)
                state['hat_v'].div_(num_batches)

                if 'hat_x' not in state:
                    state['hat_x'] = torch.zeros_like(p.data)
                state['hat_x'].copy_(p.data)

                if 'xs_end' not in state:
                    state['xs_end'] = torch.zeros_like(state['hat_x'])
                state['xs_end'].copy_(state['hat_x'])

                if 'xs_sum' not in state:
                    state['xs_sum'] = torch.zeros_like(state['hat_x'])
                state['xs_sum'].zero_()

    def save_grad_f(self):
        """Copy the current ``p.grad`` of every parameter into ``grad_f``."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if 'grad_f' not in state:
                    state['grad_f'] = torch.zeros_like(p.grad.data)
                state['grad_f'].copy_(p.grad.data)

    def set_weights_from_xs(self):
        """Load the current inner iterate ``xs_end`` into the parameters."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                p.data.copy_(state['xs_end'])

    def set_weights_from_hat_x(self):
        """Load the epoch snapshot ``hat_x`` into the parameters."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                p.data.copy_(state['hat_x'])

    def save_grad_f_hat(self):
        """Copy the current ``p.grad`` of every parameter into ``grad_f_hat``."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if 'grad_f_hat' not in state:
                    state['grad_f_hat'] = torch.zeros_like(p.grad.data)
                state['grad_f_hat'].copy_(p.grad.data)

    def update_xs(self):
        """Advance the inner iterate by one variance-reduced step.

        Uses ``v = grad_f - grad_f_hat + hat_v`` as the gradient estimate.
        Parameters for which ``is_conv_weights(p.shape)`` is true take a
        group-sparsity proximal step; all others take a plain gradient step
        with no regularization. The new iterate is added to ``xs_sum`` and
        the counter ``i`` is incremented.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['i'] += 1
                v = state['grad_f'] - state['grad_f_hat'] + state['hat_v']
                # is_conv_weights is defined outside this file (utils).
                if is_conv_weights(p.shape):
                    delta = self.prox_mapping_group(
                        state['xs_end'], v, group['lmbda'], group['alpha'])
                    state['xs_end'].add_(delta)
                else:
                    # Keyword form; the positional add_(scalar, tensor)
                    # overload is deprecated.
                    state['xs_end'].add_(v, alpha=-group['alpha'])
                state['xs_sum'].add_(state['xs_end'])

    def adjust_learning_rate(self, epoch, decays):
        """Divide ``alpha`` of every group by 10 if ``epoch`` is in ``decays``.

        Always prints the step size of the first parameter group.
        """
        if epoch in decays:
            for group in self.param_groups:
                group['alpha'] = group['alpha'] / 10
        print('lr:', self.param_groups[0]['alpha'])

if __name__ == "__main__":
    # Training driver: runs ProxSVRG for `max_epoch` epochs, logs one CSV row
    # per epoch under results/, and saves the final model under models/.
    args = ParseArgs()
    lmbda = args.lmbda
    max_epoch = args.max_epoch
    backend = args.backend
    dataset_name = args.dataset_name
    alpha = 0.1  # initial step size; divided by 10 at every epoch in `decays`
    batch_size = args.batch_size
    num_classes = args.num_classes

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Epochs at which the step size is decayed.
    if dataset_name == 'fashion_mnist' and backend == 'vgg16':
        decays = [25, 100, 150]
    else:
        decays = [75, 130, 150]

    trainloader, testloader = Dataset(dataset_name, batch_size=batch_size)
    model = Model(backend=backend, device=device, num_classes=num_classes, finetune=True)

    # Parameters whose name contains "weight"; handed to the reporting
    # helpers below (compute_func_values / compute_sparsity from utils).
    weights = [w for name, w in model.named_parameters() if "weight" in name]

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = ProxSVRG(model.parameters(), alpha=alpha, lmbda=lmbda)

    os.makedirs('results', exist_ok=True)
    setting = 'proxsvrg_%s_%s_%E_maxepoch%d'%(backend, dataset_name, lmbda, max_epoch)
    csvname = 'results/' + setting + '.csv'
    print('The csv file is %s'%csvname)

    fieldnames = ['epoch', 'F_value', 'f_value', 'omega_value', 'sparsity', 'sparsity_tol', 'sparsity_group', 'validation_acc1', 'validation_acc5', 'train_time', 'remarks']

    # The context manager guarantees the log is closed even if training raises.
    with open(csvname, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=",")
        writer.writeheader()

        alg_start_time = time.time()
        epoch = 0
        while True:
            optimizer.adjust_learning_rate(epoch, decays)

            epoch_start_time = time.time()

            # Full pass: accumulate the gradient of every batch at the current
            # weights, then snapshot weights and the averaged gradient.
            # NOTE(review): this pass also runs on the last iteration, right
            # before the break below; order kept as in the original.
            optimizer.zero_grad()
            for index, (X, y) in enumerate(trainloader):
                X = X.to(device)
                y = y.to(device)
                y_pred = model.forward(X)
                f1 = criterion(y_pred, y)
                f1.backward()
            optimizer.save_weight_and_grad_init_xs(len(trainloader))
            optimizer.zero_grad()

            if epoch >= max_epoch:
                break

            # Inner loop: one variance-reduced update per minibatch.
            for index, (X, y) in enumerate(trainloader):
                X = X.to(device)
                y = y.to(device)

                # Minibatch gradient at the current inner iterate.
                optimizer.set_weights_from_xs()
                y_pred = model.forward(X)
                f = criterion(y_pred, y)
                optimizer.zero_grad()
                f.backward()
                optimizer.save_grad_f()

                # Gradient of the same minibatch at the epoch snapshot.
                optimizer.set_weights_from_hat_x()
                y_pred = model.forward(X)
                f = criterion(y_pred, y)
                optimizer.zero_grad()
                f.backward()
                optimizer.save_grad_f_hat()

                optimizer.update_xs()

            epoch += 1
            # Replace the weights by the average of this epoch's inner iterates.
            optimizer.step()

            train_time = time.time() - epoch_start_time
            F, f, omega = compute_func_values(trainloader, model, weights, criterion, lmbda)
            sparsity, sparsity_tol, sparsity_group = compute_sparsity(weights)
            accuracy1, accuracy5 = check_accuracy(model, testloader)

            writer.writerow({
                'epoch': epoch,
                'F_value': F,
                'f_value': f,
                'omega_value': omega,
                'sparsity': sparsity,
                'sparsity_tol': sparsity_tol,
                'sparsity_group': sparsity_group,
                'validation_acc1': accuracy1,
                'validation_acc5': accuracy5,
                'train_time': train_time,
                'remarks': '%s;%s;%E;%f'%(backend, dataset_name, lmbda,optimizer.param_groups[0]['alpha']),
            })

            # Flush each epoch so partial results survive an interrupted run.
            csvfile.flush()
            print("epoch {}: {:2f}seconds ...".format(epoch, train_time))

        # Final row: average wall-clock time per completed epoch. max(..., 1)
        # avoids a ZeroDivisionError when no epoch ran (max_epoch <= 0).
        alg_time = time.time() - alg_start_time
        writer.writerow({'train_time': alg_time / max(epoch, 1)})

    os.makedirs('models', exist_ok=True)
    torch.save(model, 'models/' + setting+'.pt')

