import time
import os
import csv
import torch
import numpy as np
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required
import argparse

from backend import Model
from datasets import Dataset
from utils import *

class HSPG(Optimizer):
    """Half-Space Proximal Stochastic Gradient (HSPG) optimizer.

    Parameters whose shape satisfies ``is_conv_weights`` are treated as
    group-sparse: every slice along dim 0 (e.g. one kernel) is one group and
    the regularizer is ``lmbda * sum_g ||[x]_g||_2``. All other parameters
    take plain gradient steps ``x <- x - alpha * grad``.

    The caller chooses which step to run:
      * ``proxsg_step``: gradient step followed by group soft-thresholding.
      * ``half_space_step``: gradient step on the regularized objective,
        groups recorded as zero by ``init_polyhedron`` are held at zero, and
        a half-space projection may zero out further groups.

    Args:
        params: iterable of parameters or parameter-group dicts.
        alpha: step size (learning rate); must be >= 0.
        lmbda: group-sparsity weight; must be >= 0.
        epsilon: half-space projection threshold; must be >= 0.

    Raises:
        ValueError: if any of alpha, lmbda, epsilon is negative.
    """

    def __init__(self, params, alpha=required, lmbda=required, epsilon=required):
        if alpha is not required and alpha < 0.0:
            raise ValueError("Invalid learning rate: {}".format(alpha))

        if lmbda is not required and lmbda < 0.0:
            raise ValueError("Invalid lambda: {}".format(lmbda))

        if epsilon is not required and epsilon < 0.0:
            raise ValueError("Invalid epsilon: {}".format(epsilon))

        defaults = dict(alpha=alpha, lmbda=lmbda, epsilon=epsilon)
        super(HSPG, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(HSPG, self).__setstate__(state)

    @staticmethod
    def _group_norms(x):
        """Return the L2 norm of each slice of ``x`` along dim 0, shape ``(x.shape[0],)``."""
        return torch.norm(x.reshape(x.shape[0], -1), p=2, dim=1)

    @staticmethod
    def _per_group(values, like):
        """Reshape a length-``like.shape[0]`` vector so it broadcasts over the groups of ``like``."""
        return values.view(-1, *([1] * (like.dim() - 1)))

    def prox_mapping_group(self, x, grad_f, lmbda, alpha):
        '''
        Proximal Mapping for next iterate for Omega(x) = sum_{g in G}||[x]_g||_2

        Takes the gradient step ``trial_x = x - alpha * grad_f`` and scales each
        group g (slice along dim 0) of ``trial_x`` by
        ``max(0, 1 - alpha * lmbda / (||[trial_x]_g||_2 + 1e-6))``.

        Returns:
            ``delta`` such that ``x + delta`` is the new iterate.
        '''
        trial_x = x - alpha * grad_f
        denoms = self._group_norms(trial_x)
        # The 1e-6 keeps all-zero groups from dividing by zero.
        coeffs = 1.0 - (alpha * lmbda) / (denoms + 1e-6)
        coeffs[coeffs <= 0] = 0.0
        return self._per_group(coeffs, trial_x) * trial_x - x

    def proxsg_step(self, closure=None):
        """Perform one proximal SGD step on every parameter that has a gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or None.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad_f = p.grad.data
                if is_conv_weights(p.shape):
                    delta = self.prox_mapping_group(p.data, grad_f, group['lmbda'], group['alpha'])
                    p.data.add_(delta)
                else:
                    p.data.add_(grad_f, alpha=-group['alpha'])
        return loss

    def half_space_step(self, closure=None):
        """Perform one half-space step on every parameter that has a gradient.

        For group-sparse parameters: take ``gradient_step``, keep the groups in
        ``state['zero_idx']`` at zero, then zero every group g whose trial point
        satisfies ``hat_x_g . x_g < epsilon * ||x_g||_2``. Other parameters take
        a plain gradient step.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or None.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad_f = p.grad.data

                if not is_conv_weights(p.shape):
                    p.data.add_(grad_f, alpha=-group['alpha'])
                    continue

                state = self.state[p]
                if 'zero_idx' in state:
                    zero_idx = state['zero_idx']
                else:
                    # init_polyhedron() was not called; use the groups that are zero now.
                    zero_idx = self._group_norms(p.data) == 0.0

                hat_x = self.gradient_step(p.data, grad_f, group['alpha'], group['lmbda'])
                # Groups that are already zero stay fixed at zero.
                hat_x[zero_idx] = 0.0

                # Half-space projection. Per-group dot product via bmm; reshape(n)
                # (not squeeze) keeps a 1-D result even when there is one group.
                # NOTE(review): threshold uses ||x_g||_2 (not squared) - confirm
                # against the intended half-space definition.
                n = p.shape[0]
                dots = torch.bmm(hat_x.reshape(n, 1, -1), p.data.reshape(n, -1, 1)).reshape(n)
                proj_idx = dots < group['epsilon'] * self._group_norms(p.data)
                hat_x[proj_idx] = 0.0

                p.data.copy_(hat_x)

        return loss

    def gradient_step(self, x, grad_f, alpha, lmbda):
        """Return ``x - alpha * grad_f - alpha * lmbda * x_g / ||x_g||_2`` per group (dim 0).

        Groups of ``x`` that are exactly zero get a zero regularizer term
        instead of NaN (0/0).
        """
        norms = self._group_norms(x)
        # For an all-zero group x_g == 0, so dividing by 1 leaves the term at 0.
        safe_norms = torch.where(norms == 0, torch.ones_like(norms), norms)
        return x - alpha * grad_f - alpha * lmbda * x / self._per_group(safe_norms, x)

    def adjust_learning_rate(self, epoch, decays, factor=10):
        """Divide every group's ``alpha`` by ``factor`` when ``epoch`` is in ``decays``.

        Always prints the first group's current ``alpha``.
        """
        if epoch in decays:
            for group in self.param_groups:
                group['alpha'] = group['alpha'] / factor
        print('lr:', self.param_groups[0]['alpha'])

    def init_polyhedron(self):
        """Record which groups of each group-sparse parameter are exactly zero.

        Stores a boolean mask of shape ``(p.shape[0],)`` in
        ``self.state[p]['zero_idx']``; ``half_space_step`` holds those groups at zero.
        """
        for group in self.param_groups:
            for p in group['params']:
                if is_conv_weights(p.shape):
                    self.state[p]['zero_idx'] = self._group_norms(p.data) == 0.0

 
def ParseArgs():
    """Build the HSPG training command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with batch_size, lmbda, max_epoch, backend,
        dataset_name, num_classes, epsilon, n_p and lr.
    """
    cli = argparse.ArgumentParser()
    option = cli.add_argument

    option('--batch_size', type=int, default=128)
    option('--lmbda', type=float, default=1e-3, help='weighting parameters')
    option('--max_epoch', type=int, default=300)
    option('--backend', type=str, required=True)
    option('--dataset_name', type=str, required=True)
    option('--num_classes', type=int, default=10)
    option('--epsilon', type=float, default=0.1, help="Halfspace Threshold")
    option('--n_p', type=int, default=150, help="N_P")
    option('--lr', type=float, default=0.1, help="learning rate")

    parsed = cli.parse_args()
    return parsed

if __name__ == "__main__":
    # Train a model with HSPG: ProxSG steps for the first N_P epochs, half-space
    # steps afterwards. One CSV row of metrics is written per epoch to
    # results/<setting>.csv, and the final model is saved to models/<setting>.pt.

    args = ParseArgs()
    lmbda = args.lmbda
    max_epoch = args.max_epoch
    backend = args.backend
    dataset_name = args.dataset_name
    num_classes = args.num_classes
    N_P = args.n_p
    alpha = args.lr
    batch_size = args.batch_size

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Epochs at which adjust_learning_rate divides the step size.
    if dataset_name == 'fashion_mnist' and backend == 'vgg16':
        decays = [25, 100, 150]
    else:
        decays = [75, 130, 150]

    trainloader, testloader = Dataset(dataset_name, batch_size=batch_size)
    model = Model(backend=backend, device=device, num_classes=num_classes, finetune=True)

    # Parameters whose name contains "weight"; handed to the metric helpers.
    weights = [w for name, w in model.named_parameters() if "weight" in name]

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = HSPG(model.parameters(), alpha=alpha, lmbda=lmbda, epsilon=args.epsilon)

    os.makedirs('results', exist_ok=True)
    os.makedirs('models', exist_ok=True)
    setting = 'hspg_np_%d_%s_%s_lmbda%.1E_eps%.2f_lr%.1E_maxepoch%d' % (N_P, backend, dataset_name, lmbda, args.epsilon, alpha, max_epoch)

    csvname = 'results/' + setting + '.csv'
    print('The csv file is %s' % csvname)

    fieldnames = ['epoch', 'F_value', 'f_value', 'omega_value', 'sparsity', 'sparsity_tol', 'sparsity_group', 'validation_acc1', 'validation_acc5', 'train_time', 'remarks']

    # `with` guarantees the CSV is closed even if training raises.
    with open(csvname, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=",")
        writer.writeheader()

        alg_start_time = time.time()

        for epoch in range(max_epoch):
            optimizer.adjust_learning_rate(epoch, decays)

            epoch_start_time = time.time()
            print("epoch {}".format(epoch), end='...')

            # ProxSG phase for the first N_P epochs, half-space phase afterwards.
            use_proxsg = epoch < N_P
            for X, y in trainloader:
                if not use_proxsg:
                    # Refresh the set of all-zero groups before every half-space step.
                    optimizer.init_polyhedron()
                X = X.to(device)
                y = y.to(device)

                y_pred = model.forward(X)
                loss = criterion(y_pred, y)
                optimizer.zero_grad()
                loss.backward()
                if use_proxsg:
                    optimizer.proxsg_step()
                else:
                    optimizer.half_space_step()

            train_time = time.time() - epoch_start_time
            # NOTE(review): if these helpers put the model in eval mode, nothing
            # here switches it back to train mode - confirm in utils.
            F, f, omega = compute_func_values(trainloader, model, weights, criterion, lmbda)
            sparsity, sparsity_tol, sparsity_group = compute_sparsity(weights)
            accuracy1, accuracy5 = check_accuracy(model, testloader)
            step_name = 'proxsg_step' if use_proxsg else 'halfspace_step'
            writer.writerow({
                'epoch': epoch + 1,
                'F_value': F,
                'f_value': f,
                'omega_value': omega,
                'sparsity': sparsity,
                'sparsity_tol': sparsity_tol,
                'sparsity_group': sparsity_group,
                'validation_acc1': accuracy1,
                'validation_acc5': accuracy5,
                'train_time': train_time,
                'remarks': '%s;%s;%E;%s;%f' % (backend, dataset_name, lmbda, step_name, optimizer.param_groups[0]['alpha']),
            })
            csvfile.flush()
            print("Epoch time: {:.2f}seconds".format(train_time), end='...')

        # Average wall time per epoch; skipped when no epoch ran (avoids division by zero).
        if max_epoch > 0:
            alg_time = time.time() - alg_start_time
            writer.writerow({'train_time': alg_time / max_epoch})
        torch.save(model, 'models/' + setting + '.pt')

