import time
import csv
import os
import argparse
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required

from backend import Model
from datasets import Dataset
from utils import *

torch.manual_seed(0)

def ParseArgs():
    """Parse the command-line options of the training script.

    Recognized flags:
        --lmbda         float, default 1e-3 (weighting parameter)
        --max_epoch     int, default 300
        --backend       str, required
        --dataset_name  str, required
        --num_classes   int, default 10
        --batch_size    int, default 128

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # One (flag, keyword-arguments) entry per option, registered in order.
    option_table = (
        ('--lmbda', dict(default=1e-3, type=float, help='weighting parameters')),
        ('--max_epoch', dict(default=300, type=int)),
        ('--backend', dict(required=True, type=str)),
        ('--dataset_name', dict(required=True, type=str)),  # cifar10 | mnist
        ('--num_classes', dict(default=10, type=int)),
        ('--batch_size', dict(default=128, type=int)),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

class SProx(Optimizer):
    """Proximal stochastic gradient optimizer with a group-norm regularizer.

    In ``step``, parameters whose shape is accepted by the ``is_conv_weights``
    helper (defined outside this class) receive a proximal-gradient update
    for ``lmbda * sum_g ||[x]_g||_2``; all other parameters receive a plain
    gradient step ``p <- p - alpha * grad``.

    Args:
        params: iterable of parameters (or parameter-group dicts) to optimize.
        alpha: step size (learning rate); must not be negative.
        lmbda: weight of the group-norm regularizer; must not be negative.

    Raises:
        ValueError: if ``alpha`` or ``lmbda`` is negative.
    """

    def __init__(self, params, alpha=required, lmbda=required):
        if alpha is not required and alpha < 0.0:
            raise ValueError("Invalid learning rate: {}".format(alpha))

        if lmbda is not required and lmbda < 0.0:
            raise ValueError("Invalid lambda: {}".format(lmbda))

        defaults = dict(alpha=alpha, lmbda=lmbda)
        super().__init__(params, defaults)

    def prox_mapping_group(self, x, grad_f, lmbda, alpha):
        """Return the proximal-gradient update for Omega(x) = sum_{g in G}||[x]_g||_2.

        Each slice ``x[i]`` along dim 0 is treated as one group. The trial
        point ``x - alpha * grad_f`` is scaled group-wise by
        ``max(0, 1 - alpha * lmbda / ||group||_2)``.

        Args:
            x: current parameter tensor (at least 1-D).
            grad_f: gradient of the smooth loss, same shape as ``x``.
            lmbda: regularization weight.
            alpha: step size.

        Returns:
            Tensor ``delta`` of the same shape as ``x`` such that ``x + delta``
            is the next iterate.
        """
        trial_x = x - alpha * grad_f
        num_groups = x.shape[0]
        threshold = alpha * lmbda
        # reshape (not view) so non-contiguous layouts such as channels-last
        # weights are handled as well.
        group_norms = torch.norm(trial_x.reshape(num_groups, -1), p=2, dim=1)
        # The 1e-6 keeps the division finite for all-zero groups; groups whose
        # norm does not exceed the threshold are zeroed out entirely.
        coeffs = (1.0 - threshold / (group_norms + 1e-6)).clamp_(min=0.0)
        # Broadcast one coefficient per group over the remaining dimensions.
        coeffs = coeffs.view(-1, *([1] * (x.dim() - 1)))
        return coeffs * trial_x - x

    def __setstate__(self, state):
        super().__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            alpha = group['alpha']
            lmbda = group['lmbda']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad_f = p.grad.data

                if is_conv_weights(p.shape):  # weights
                    delta = self.prox_mapping_group(p.data, grad_f, lmbda, alpha)
                    p.data.add_(delta)
                else:  # bias
                    # add_(tensor, alpha=...) replaces the deprecated
                    # add_(Number, Tensor) overload.
                    p.data.add_(grad_f, alpha=-alpha)

        return loss

    def adjust_learning_rate(self, epoch, decays):
        """Divide every group's step size by 10 if ``epoch`` is in ``decays``.

        The step size of the first parameter group is printed on every call.
        """
        if epoch in decays:
            for group in self.param_groups:
                group['alpha'] = group['alpha'] / 10
        print('lr:', self.param_groups[0]['alpha'])


if __name__ == "__main__":
    # Training driver: trains `Model` with the SProx optimizer, logs one CSV
    # row of metrics per epoch under results/, and saves the final model
    # under models/.

    args = ParseArgs()
    lmbda = args.lmbda
    max_epoch = args.max_epoch
    backend = args.backend
    dataset_name = args.dataset_name
    alpha = 0.1  # initial step size; divided by 10 at every epoch in `decays`
    batch_size = args.batch_size
    num_classes = args.num_classes

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Epochs at which the step size is decayed.
    if dataset_name == 'fashion_mnist' and backend == 'vgg16':
        decays = [25, 100, 150]
    else:
        decays = [75, 130, 150]

    trainloader, testloader = Dataset(dataset_name, batch_size=batch_size)
    model = Model(backend=backend, device=device, num_classes=num_classes, finetune=True)

    # Parameters whose name contains "weight"; used for the metric helpers.
    weights = [w for name, w in model.named_parameters() if "weight" in name]

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = SProx(model.parameters(), alpha=alpha, lmbda=lmbda)

    os.makedirs('results', exist_ok=True)
    setting = 'proxsg_%s_%s_%E_maxepoch%d'%(backend, dataset_name, lmbda, max_epoch)
    csvname = 'results/' + setting + '.csv'
    print('The csv file is %s'%csvname)

    fieldnames = ['epoch', 'F_value', 'f_value', 'omega_value', 'sparsity', 'sparsity_tol', 'sparsity_group', 'validation_acc1', 'validation_acc5', 'train_time', 'remarks']

    # The context manager guarantees the log is closed even if training fails.
    with open(csvname, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=",")
        writer.writeheader()

        alg_start_time = time.time()

        epoch = 0
        while epoch < max_epoch:
            optimizer.adjust_learning_rate(epoch, decays)
            epoch_start_time = time.time()

            for X, y in trainloader:
                X = X.to(device)
                y = y.to(device)
                y_pred = model.forward(X)

                f = criterion(y_pred, y)
                optimizer.zero_grad()
                f.backward()
                optimizer.step()

            epoch += 1

            train_time = time.time() - epoch_start_time
            F, f, omega = compute_func_values(trainloader, model, weights, criterion, lmbda)
            sparsity, sparsity_tol, sparsity_group = compute_sparsity(weights)
            accuracy1, accuracy5 = check_accuracy(model, testloader)

            writer.writerow({
                'epoch': epoch,
                'F_value': F,
                'f_value': f,
                'omega_value': omega,
                'sparsity': sparsity,
                'sparsity_tol': sparsity_tol,
                'sparsity_group': sparsity_group,
                'validation_acc1': accuracy1,
                'validation_acc5': accuracy5,
                'train_time': train_time,
                'remarks': '%s;%s;%E;%f'%(backend, dataset_name, lmbda, optimizer.param_groups[0]['alpha']),
            })

            # Flush every epoch so partial results survive an interruption.
            csvfile.flush()
            print("epoch {}: {:2f}seconds ...".format(epoch, train_time))

        alg_time = time.time() - alg_start_time
        # Final row: average wall-clock time per epoch (skipped when no epoch
        # was run, which would otherwise divide by zero).
        if epoch > 0:
            writer.writerow({'train_time': alg_time / epoch})

    os.makedirs('models', exist_ok=True)
    torch.save(model, 'models/' + setting+'.pt')

