import json
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm

from Utils.logger import *
from Utils.metrics import hessian_stats

def train(model, loss, optimizer, dataloader, device, epoch, verbose, log_interval=10, args=None):
    """Run one training epoch with a plain forward / backward / step loop.

    Args:
        model: network to train; it is switched to train mode.
        loss: criterion called as ``loss(output, target)``. Presumably it
            returns a batch-mean scalar, since the epoch total re-weights it
            by the batch size.
        optimizer: stepped once per batch. If it exposes ``get_delta()``, the
            value is logged as ``running_delta``; otherwise 0 is logged.
        dataloader: yields ``(data, target)`` batches and exposes ``.dataset``.
        device: device the batches are moved to.
        epoch: epoch index, used only in progress messages.
        verbose: when truthy, print progress every ``log_interval`` batches.
        log_interval: number of batches between progress messages.
        args: optional run configuration. When given, ``args.training_step``
            is incremented once per batch. If ``args.wb`` is also not None,
            per-step metrics are sent through ``args.wb.log``.

    Returns:
        Sum of ``batch_loss * batch_size`` divided by
        ``len(dataloader.dataset)``.
    """
    model.train()
    total = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        train_loss = loss(output, target)
        # Read the scalar once; every .item() forces a device synchronisation.
        loss_value = train_loss.item()
        total += loss_value * data.size(0)
        train_loss.backward()
        optimizer.step()

        if args is not None and args.wb is not None:
            with torch.no_grad():
                mask_dict = dict(model.named_buffers())
                grad_sq = 0.0
                for name, param in model.named_parameters():
                    # BN parameters, biases and parameters without gradients
                    # are excluded from the reported norm.
                    if 'bn' in name or 'bias' in name or param.grad is None:
                        continue
                    grad = param.grad
                    mask = mask_dict.get(name + '_mask')
                    # Restrict to the active sub-network when a mask buffer
                    # exists; unmasked models fall back to the full gradient.
                    if mask is not None:
                        grad = grad * mask
                    grad_sq += grad.norm(2).item() ** 2
                grad_norm = grad_sq ** 0.5

                try:
                    delta = optimizer.get_delta()
                except Exception:
                    # Best effort: optimizers without get_delta() report 0.
                    delta = 0
                try:
                    args.wb.log({
                        "running_delta": delta,
                        "grad_norm": grad_norm,
                        "train_loss": loss_value,
                        "Steps": args.training_step,
                    })
                except Exception:
                    # Metric logging must never interrupt training.
                    pass

        if args is not None:
            args.training_step += 1

        if verbose and batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(dataloader.dataset),
                100. * batch_idx / len(dataloader), loss_value))
    return total / len(dataloader.dataset)


def train_osgm(model, loss, optimizer, dataloader, device, epoch, verbose, log_interval=10, args=None):
    """Run one training epoch with a closure-based optimizer step.

    The forward and backward pass is wrapped in a closure passed to
    ``optimizer.step``. The value ``step`` returns is taken as the batch loss,
    following the ``torch.optim.Optimizer.step(closure)`` contract.

    Args:
        model: network to train; it is switched to train mode.
        loss: criterion called as ``loss(output, target)``. Presumably it
            returns a batch-mean scalar.
        optimizer: optimizer whose ``step`` accepts a closure. When metrics are
            logged, every param group must carry a ``'P'`` entry of tensors
            (NOTE(review): OSGM-specific state — confirm against the optimizer).
        dataloader: yields ``(data, target)`` batches and exposes ``.dataset``.
        device: device the batches are moved to.
        epoch: epoch index, used only in progress messages.
        verbose: when truthy, print progress every ``log_interval`` batches.
        log_interval: number of batches between progress messages.
        args: optional run configuration. When given, ``args.training_step``
            is incremented once per batch. If ``args.wb`` is also not None,
            ``P_norm`` and ``grad_norm`` are sent through ``args.wb.log``.

    Returns:
        Sum of ``batch_loss * batch_size`` divided by
        ``len(dataloader.dataset)``. The result is NaN if the optimizer's
        ``step`` returns None.
    """
    model.train()
    total = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)

        # The closure is consumed inside optimizer.step below, so capturing
        # the loop variables data and target here is safe.
        def closure():
            optimizer.zero_grad()
            output = model(data)
            batch_loss = loss(output, target)
            batch_loss.backward()
            return batch_loss

        train_loss = optimizer.step(closure)
        loss_value = train_loss.item() if train_loss is not None else float('nan')
        # Accumulate the sample-weighted loss so the epoch average is meaningful.
        total += loss_value * data.size(0)

        if args is not None and args.wb is not None:
            with torch.no_grad():
                p_sq = 0.0
                for p_group in optimizer.param_groups:
                    for p in p_group['P']:
                        p_sq += p.norm(2).item() ** 2
                p_norm = p_sq ** 0.5

                grad_sq = 0.0
                for param in model.parameters():
                    if param.grad is not None:
                        grad_sq += param.grad.norm(2).item() ** 2
                grad_norm = grad_sq ** 0.5
                try:
                    args.wb.log({
                        "P_norm": p_norm,
                        "grad_norm": grad_norm,
                        "Steps": args.training_step,
                    })
                except Exception:
                    # Metric logging must never interrupt training.
                    pass

        if args is not None:
            args.training_step += 1

        if verbose and batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(dataloader.dataset),
                100. * batch_idx / len(dataloader), loss_value))

    return total / len(dataloader.dataset)



def eval(model, loss, dataloader, device, verbose):
    """Evaluate ``model`` and return average loss with top-1 and top-5 accuracy.

    Args:
        model: network to evaluate; it is switched to eval mode and run
            under ``torch.no_grad()``.
        loss: criterion called as ``loss(output, target)``. Presumably it
            returns a batch-mean scalar, since the total re-weights it by the
            batch size.
        dataloader: yields ``(data, target)`` batches and exposes ``.dataset``.
        device: device the batches are moved to.
        verbose: when truthy, print a one-line summary.

    Returns:
        ``(average_loss, top1_accuracy, top5_accuracy)``. Accuracies are
        percentages, and all three values are normalised by
        ``len(dataloader.dataset)``. When the model has fewer than 5 outputs,
        "top-5" covers all of them, so it is 100%.
    """
    model.eval()
    total = 0
    correct1 = 0
    correct5 = 0
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total += loss(output, target).item() * data.size(0)
            # topk(5) raises when there are fewer than 5 classes; clamp k.
            k = min(5, output.size(1))
            _, pred = output.topk(k, dim=1)
            correct = pred.eq(target.view(-1, 1).expand_as(pred))
            correct1 += correct[:, :1].sum().item()
            correct5 += correct.sum().item()
    n_samples = len(dataloader.dataset)
    average_loss = total / n_samples
    accuracy1 = 100. * correct1 / n_samples
    accuracy5 = 100. * correct5 / n_samples
    if verbose:
        print('Evaluation: Average loss: {:.4f}, Top 1 Accuracy: {}/{} ({:.2f}%)'.format(
            average_loss, correct1, n_samples, accuracy1))
    return average_loss, accuracy1, accuracy5


import copy
def modify_subnet(model, args, flipping_type=None, flipping_ratio=0.05):
    """Flip the sign of a fraction of the active (unmasked) weights of ``model``.

    Every parameter whose name contains neither ``'bn'`` nor ``'bias'`` is
    scored element-wise. Each score is multiplied by the matching
    ``<name>_mask`` buffer, so masked-out entries score 0. The
    ``int(flipping_ratio * #non-zero scores)`` highest-scoring entries across
    the whole model are then negated in place.

    Args:
        model: network whose scored parameters each have a ``<name>_mask``
            buffer.
        args: run configuration. ``args.logger`` receives the log lines, and
            ``args.flipping_type`` is the fallback scoring rule.
        flipping_type: ``'rand'`` (random active entries), ``'mag'``
            (largest magnitude first) or ``'inv_mag'`` (per-layer
            ``max|w| - |w|``, i.e. smallest magnitude first). ``None`` uses
            ``args.flipping_type``.
        flipping_ratio: fraction of the non-zero-scored entries to flip.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: if the scoring rule is not one of the three above.
        KeyError: if a scored parameter has no ``<name>_mask`` buffer.
    """
    print_and_log(args.logger, '='*60)
    print_and_log(args.logger, 'FLIPPING SIGN OF PARAMETERS')
    if flipping_type is None:
        flipping_type = args.flipping_type
    if flipping_type not in ('rand', 'inv_mag', 'mag'):
        raise ValueError(
            f"Unknown flipping_type {flipping_type!r}; expected 'rand', 'inv_mag' or 'mag'")

    mask_dict = dict(model.named_buffers())
    # BN parameters and biases are never flipped.
    prunable = [(name, param) for name, param in model.named_parameters()
                if 'bn' not in name and 'bias' not in name]

    with torch.no_grad():
        scores = []
        for name, param in prunable:
            mask = mask_dict[name + '_mask']
            if flipping_type == 'rand':
                # The +1 offset keeps every active score strictly positive, so
                # it is counted as non-zero below. A small offset preserves
                # float32 precision; a large one such as 10000 would collapse
                # the random values into a few hundred ties.
                score = mask * (torch.rand_like(mask) + 1.0)
            elif flipping_type == 'inv_mag':
                magnitude = torch.abs(param)
                score = mask * (torch.max(magnitude) - magnitude)
            else:  # 'mag'
                score = mask * torch.abs(param)
            scores.append(score)

        flat_scores = torch.cat([s.flatten() for s in scores]) if scores else torch.zeros(0)
        n_flipped = int(flipping_ratio * int(torch.count_nonzero(flat_scores)))

        if n_flipped > 0:
            # topk flips exactly n_flipped entries. All scores are >= 0 and
            # n_flipped <= #non-zero, so only active entries are selected.
            _, flip_idx = torch.topk(flat_scores, n_flipped)
            flat_sign = torch.ones_like(flat_scores)
            flat_sign[flip_idx] = -1.0
            chunks = torch.split(flat_sign, [s.numel() for s in scores])
            for (name, param), score, chunk in zip(prunable, scores, chunks):
                # In-place multiply keeps the parameter's storage and dtype.
                param.mul_(chunk.view_as(score).to(device=param.device, dtype=param.dtype))

    print_and_log(args.logger, f'Sign Flips of {n_flipped} parameters in total of {flat_scores.numel()}')
    print_and_log(args.logger, '='*60)

    return model




def train_eval_loop(model, loss, optimizer, scheduler, train_loader, test_loader, device, epochs, verbose, args):
    """Train for ``epochs`` epochs, evaluating on ``test_loader`` after each one.

    Each epoch runs these steps in order:
      1. Train with ``train_osgm`` when ``args.optimizer == 'osgmrx'``, and
         with ``train`` otherwise.
      2. Evaluate on the test set.
      3. Step the scheduler.
      4. Log the results.
      5. Optionally collect Hessian statistics (``args.is_stat_eigenvalues``).
      6. On every 10th epoch, including epoch 0, decay a ``'delta'``
         param-group entry by 0.9.
      7. Optionally flip weight signs through ``modify_subnet``
         (``args.is_flipping``) when ``epoch % args.flipping_freq == 0`` and
         at least 50 epochs remain.

    Failures in the Hessian statistics and in flipping are logged and do not
    stop training.

    Args:
        args: run configuration. It is read for ``optimizer``, ``wb``,
            ``logger``, ``is_stat_eigenvalues`` and ``is_flipping``, and for
            ``flipping_freq``, ``flipping_type`` and ``flipping_ratio`` when
            flipping. ``args.training_step`` is reset to 0 here.

    Returns:
        DataFrame with columns ``train_loss``, ``test_loss``, ``top1_accuracy``
        and ``top5_accuracy``. The first row is the pre-training evaluation,
        with ``train_loss`` NaN; after that there is one row per epoch.
    """
    test_loss, accuracy1, accuracy5 = eval(model, loss, test_loader, device, verbose)
    rows = [[np.nan, test_loss, accuracy1, accuracy5]]
    args.training_step = 0
    for epoch in tqdm(range(epochs)):
        if args.optimizer == 'osgmrx':
            train_loss = train_osgm(model, loss, optimizer, train_loader, device, epoch, verbose, args=args)
        else:
            train_loss = train(model, loss, optimizer, train_loader, device, epoch, verbose, args=args)

        test_loss, accuracy1, accuracy5 = eval(model, loss, test_loader, device, verbose)
        row = [train_loss, test_loss, accuracy1, accuracy5]
        scheduler.step()
        rows.append(row)
        print_and_log(args.logger, "==> Epoch {} \t | \t Loss: {:.2f} \t | \t Accuracy: {:.2f}".format(epoch, test_loss, accuracy1))
        if args.wb is not None:
            args.wb.log({"Accuracy": accuracy1, "Loss": test_loss, "Steps": epoch})

        if args.is_stat_eigenvalues:
            try:
                hessian_stats(train_loader, model, loss, optimizer, epoch, args)
            except Exception as ex:
                print_and_log(args.logger, f'Fail when tracing hessian with ex: {ex}')

        # Decay 'delta' only in param groups that define it; other optimizers
        # are left untouched instead of relying on a swallowed KeyError.
        if epoch % 10 == 0:
            for param_group in optimizer.param_groups:
                if 'delta' in param_group:
                    param_group['delta'] *= 0.9

        # Sign flipping of the active sub-network; never in the last 50 epochs.
        if args.is_flipping:
            try:
                if epoch % args.flipping_freq == 0 and epochs - epoch >= 50:
                    model = modify_subnet(model, args,
                                          flipping_type=args.flipping_type,
                                          flipping_ratio=args.flipping_ratio)
            except Exception as ex:
                print_and_log(args.logger, f'Fail when flipping with ex: {ex}')

    columns = ['train_loss', 'test_loss', 'top1_accuracy', 'top5_accuracy']

    return pd.DataFrame(rows, columns=columns)


