import os
from datetime import datetime
import torch
from torch.autograd import grad
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict

#==============================================
# utility functions: evaluate loss and accuracy
#==============================================


def eval_loss_and_acc_on_batch(net, x, y, loss_fn, require_acc=True, mode='train'):
    """Run one forward pass of ``net`` on a batch and return loss and accuracy.

    Args:
        net: Model invoked as ``net(x)``. It is switched with ``net.train()``
            or ``net.eval()`` according to ``mode`` and left in that mode.
        x: Input batch passed to ``net``.
        y: Targets passed to ``loss_fn``. When ``require_acc`` is true they
            are also compared element-wise with the arg-max over dim 1 of
            the network output.
        loss_fn: Callable invoked as ``loss_fn(pred, y)``.
        require_acc: Whether to compute the accuracy.
        mode: ``'train'`` or ``'eval'``. In ``'eval'`` the forward pass and
            the loss are computed under ``torch.no_grad()``.

    Returns:
        A tuple ``(loss, acc)``. ``loss`` is whatever ``loss_fn`` returns;
        ``acc`` is the fraction of matching predictions
        (``correct / y.size(0)``) or ``None`` when ``require_acc`` is false.

    Raises:
        AssertionError: If ``mode`` is neither ``'train'`` nor ``'eval'``.
    """
    assert mode in ('train', 'eval'), "mode must be 'train' or 'eval', got %r" % (mode,)

    # Compare by value: ``mode is 'train'`` only works when the caller's
    # string happens to be the same interned object, and otherwise silently
    # falls through to the no-grad branch.
    if mode == 'train':
        net.train()
        pred = net(x)
        loss = loss_fn(pred, y)
    else:
        with torch.no_grad():
            net.eval()
            pred = net(x)
            loss = loss_fn(pred, y)

    if not require_acc:
        return loss, None

    # The arg-max is taken on a detached tensor so that the accuracy
    # computation never becomes part of the autograd graph.
    _, predicted = torch.max(pred.detach(), 1)
    acc = (predicted == y).sum().item() / y.size(0)
    return loss, acc

def eval_loss_and_acc_on_valid_set(net, test_dataloader, loss_fn, device='cpu'):
    """Evaluate ``net`` over a dataloader and average the per-batch results.

    Every ``(x, y)`` pair yielded by ``test_dataloader`` is moved to
    ``device`` and scored with ``eval_loss_and_acc_on_batch`` in ``'eval'``
    mode.

    Returns:
        ``(mean_loss, mean_acc)``: the unweighted ``np.mean`` of the
        per-batch loss values and of the per-batch accuracies. Batches of
        different sizes therefore contribute equally.
    """
    batch_losses = []
    batch_accs = []

    for inputs, targets in test_dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch_loss, batch_acc = eval_loss_and_acc_on_batch(
            net, inputs, targets, loss_fn=loss_fn, require_acc=True, mode='eval'
        )
        batch_losses.append(batch_loss.item())
        batch_accs.append(batch_acc)

    mean_loss = np.mean(batch_losses)
    mean_acc = np.mean(batch_accs)
    return mean_loss, mean_acc

#=================================
# utilities: evaluate sharpness
#=================================
def perturb_model_weights_with_gaussian_noise(net, state_dict, noise_std, device):
    """Reset ``net`` to ``state_dict`` and add Gaussian noise to its parameters.

    The weights are first reloaded with ``net.load_state_dict(state_dict)``.
    Then, for every tensor in ``net.parameters()``, a standard-normal sample
    of the same size is drawn, moved to ``device``, scaled by ``noise_std``
    and added in place to the parameter's data. Nothing is returned.
    """
    net.load_state_dict(state_dict)
    for weight in net.parameters():
        gaussian = torch.randn(weight.size())
        weight.data.add_(gaussian.to(device) * noise_std)

def loop_evaluate_expected_sharpness(net, state_dict, noise_std, repeat, dataloader, loss_fn, device):
    """Estimate the mean absolute change of loss/accuracy under weight noise.

    The reference loss and accuracy are measured with the weights from
    ``state_dict``. Then, ``repeat`` times, the weights are reset to that
    state, perturbed with Gaussian noise of scale ``noise_std``, and
    re-evaluated on ``dataloader``.

    Args:
        net: Model to evaluate. Its weights are overwritten during the call
            and reset to ``state_dict`` before returning.
        state_dict: Reference weights, as accepted by ``net.load_state_dict``.
        noise_std: Scale applied to the standard-normal weight noise.
        repeat: Number of noisy evaluations to average over.
        dataloader: Iterable of ``(x, y)`` batches used for every evaluation.
        loss_fn: Callable invoked as ``loss_fn(pred, y)``.
        device: Device the batches and the noise are moved to.

    Returns:
        ``(dLoss, dAcc)``: the averages over ``repeat`` runs of
        ``|loss - loss_ref|`` and ``|acc - acc_ref|``. Both are ``0`` when
        ``repeat`` is 0.
    """
    import copy  # local import: keeps this block self-contained

    # ``net.state_dict()`` returns tensors that share storage with the live
    # parameters. Work from a private copy so that perturbing ``net`` in
    # place can never drift the reference weights between repeats.
    reference_state = copy.deepcopy(state_dict)

    # Measure the baseline on the reference weights, not on whatever
    # weights ``net`` happens to hold when this function is called.
    net.load_state_dict(reference_state)
    lossT0, accT0 = eval_loss_and_acc_on_valid_set(net, dataloader, loss_fn, device=device)
    print((lossT0, accT0))

    dLoss, dAcc = 0, 0
    try:
        for idx_repeat in range(repeat):
            if idx_repeat % 50 == 0:
                print(idx_repeat)
            # Reset to the reference weights and add fresh noise.
            perturb_model_weights_with_gaussian_noise(net, reference_state, noise_std, device)
            lossT, accT = eval_loss_and_acc_on_valid_set(net, dataloader, loss_fn, device=device)
            # Accumulate the running mean of the absolute deviations.
            dLoss += abs(lossT - lossT0) / repeat
            dAcc += abs(accT - accT0) / repeat
    finally:
        # Do not hand the caller back a model with noisy weights.
        net.load_state_dict(reference_state)

    return dLoss, dAcc


#======================
# generic utils
#======================

def get_model_weights(model):
    """Return all parameters of ``model`` flattened into one 1-D tensor.

    Parameters are visited in ``model.parameters()`` order and concatenated
    with ``torch.cat``; the result is a new tensor detached from autograd.

    ``reshape(-1)`` is used rather than ``view(-1)`` so that parameters
    stored non-contiguously (for example after a memory-format conversion)
    are flattened instead of raising a ``RuntimeError``.
    """
    return torch.cat([p.detach().reshape(-1) for p in model.parameters()])

def clip_model_weights(model, model_weight_clip, eps=0.1):
    """Rescale the weights of ``model`` when their global norm is too large.

    The norm of the flattened weight vector (from ``get_model_weights``) is
    compared with ``model_weight_clip + eps``. Only when it is strictly
    larger is every parameter multiplied in place by
    ``model_weight_clip / norm``, which brings the global norm to
    ``model_weight_clip``. Otherwise the model is left untouched.
    """
    total_norm = get_model_weights(model).norm()
    if not total_norm > model_weight_clip + eps:
        return

    shrink = model_weight_clip / total_norm
    for param in model.parameters():
        param.data.mul_(shrink)


def adjust_learning_rate(optimizer, lr):
    """Set the ``'lr'`` entry of every group in ``optimizer.param_groups`` to ``lr``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)

#=================================
# utilities for heavy tailed noise
#=================================

def get_grads(model):
    """Return the gradients of ``model`` flattened into one 1-D tensor.

    Parameters are visited in ``model.parameters()`` order. A parameter is
    skipped when it does not require grad or when its ``.grad`` is ``None``.
    ``torch.cat`` is applied to the collected gradients, so it is still
    called with an empty list when no parameter has a gradient.

    ``reshape(-1)`` is used rather than ``view(-1)`` so that gradients
    stored non-contiguously are flattened instead of raising a
    ``RuntimeError``.
    """
    flat_grads = [
        p.grad.reshape(-1)
        for p in model.parameters()
        if p.requires_grad and p.grad is not None
    ]
    return torch.cat(flat_grads)

def get_grads_dict(model):
    """Return a snapshot of the gradients of ``model`` keyed by parameter name.

    Returns:
        An ``OrderedDict`` in ``model.named_parameters()`` order that maps
        each parameter name to a detached clone of its ``.grad``. Parameters
        whose ``.grad`` is ``None`` are omitted instead of raising an
        ``AttributeError``.
    """
    out = OrderedDict()
    for name, p in model.named_parameters():
        if p.grad is None:
            # No gradient has been populated for this parameter.
            continue
        # Clone so that later backward passes or zero_grad() calls cannot
        # change the stored snapshot.
        out[name] = p.grad.detach().clone()
    return out

def get_dict_differnce(dictA, dictB):
    """Return ``dictA[key] - dictB[key]`` for every key of ``dictA``.

    The result is an ``OrderedDict`` in the key order of ``dictA``. Keys that
    exist only in ``dictB`` are ignored; a key of ``dictA`` that is missing
    from ``dictB`` raises ``KeyError``.
    """
    return OrderedDict(
        (key, dictA[key] - dictB[key]) for key in dictA.keys()
    )

def modify_model_noise(model, gradient_dict, noise_dict, noise_multiplier):
    """Overwrite each parameter's ``.grad`` with a gradient plus scaled noise.

    For every ``(name, parameter)`` pair of ``model.named_parameters()`` the
    gradient is replaced by
    ``gradient_dict[name] + noise_multiplier * noise_dict[name]``.
    Both dictionaries must contain an entry for every parameter name.
    """
    for name, param in model.named_parameters():
        base_grad = gradient_dict[name]
        param.grad = base_grad + noise_multiplier * noise_dict[name]
