import os
import logging
import random
import numpy as np
import torch
import json
import copy

from pathlib import Path
from datetime import datetime
from itertools import repeat
from collections import OrderedDict
#from models.prec_conv import Preconditioned_Conv2d

def CDNV(all_class_dict):
    """Return the mean class-distance normalized variance over all class pairs.

    For every class the feature rows are stacked, the class mean ``mu`` is
    taken, and the class variance is the mean squared Euclidean distance of
    the rows to ``mu``.  For each unordered pair of classes (i, j) the value

        (var_i + var_j) / (2 * ||mu_i - mu_j||^2)

    is computed, and the average over all pairs is returned.

    Args:
        all_class_dict: mapping from class label to a sequence of feature
            vectors (anything ``np.vstack`` accepts).

    Returns:
        The mean pairwise value as a numpy float.  With fewer than two
        classes there are no pairs and ``np.mean`` of an empty list yields
        ``nan``.
    """
    class_stats = []
    for features in all_class_dict.values():
        stacked = np.vstack(features)
        mu = np.mean(stacked, axis=0)
        # Square the per-sample distance itself; the exponent must apply to
        # the norm, not to the mean vector being subtracted.
        sq_dist = np.linalg.norm(stacked - mu[None, :], axis=1) ** 2
        class_stats.append((mu, np.mean(sq_dist)))

    all_cdnv = []
    for i, (mu_a, var_a) in enumerate(class_stats):
        for mu_b, var_b in class_stats[i + 1:]:
            gap_sq = np.linalg.norm(mu_a - mu_b) ** 2
            all_cdnv.append((var_a + var_b) / (2 * gap_sq))

    return np.mean(all_cdnv)

# Below two functions modified from https://github.com/facebookresearch/mixup-cifar10/blob/main/train.py
def mixup_data(x, y, alpha=1.0):
    """Mix each sample with a randomly chosen partner from the same batch.

    The mixing weight ``lam`` is chosen from ``alpha``:
      * ``alpha > 0``: drawn from ``Beta(alpha, alpha)``;
      * ``alpha < 0``: deterministic, ``-alpha`` divided by
        ``10 ** ceil(log10(-alpha))``;
      * otherwise: ``1`` (the inputs are returned unmixed).

    Args:
        x: batch tensor, indexed along its first dimension (needs at least
            two dimensions because of the ``[index, :]`` lookup).
        y: targets, indexable by a permutation tensor.
        alpha: controls how ``lam`` is chosen, see above.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``y_a`` is ``y`` and ``y_b`` is
        ``y`` reordered by the same permutation used for the partners.
    """
    lam = 1
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    elif alpha < 0:
        magnitude = -alpha
        lam = magnitude / 10 ** np.ceil(np.log10(magnitude))

    # One random permutation pairs every sample with its mixing partner.
    n_samples = x.size()[0]
    partner = torch.randperm(n_samples).to(x.device)

    blended = lam * x + (1 - lam) * x[partner, :]
    return blended, y, y[partner], lam

# def mixup_data(x, y, alpha=1.0):
#     '''Returns mixed inputs, pairs of targets, and lambda'''
#     if alpha > 0:
#         lam = np.random.beta(alpha, alpha)
#     elif alpha < 0:
#         div = 10 ** np.ceil(np.log10(-alpha))
#         lam = -alpha / div
#     else:
#         lam = 1

#     batch_size = x.size()[0]
    
#     mixed_x = []
#     y_a_list = []
#     y_b_list = []
#     for i in range(batch_size):
#         index = torch.randperm(batch_size).to(x.device)
#         current_class = y[i]
#         sample = x[i]
#         ite = 0
#         while y[index[ite]] == current_class:
#             ite += 1
#         mixed_sample = lam * sample + (1 - lam) * x[index[ite]]
#         y_a, y_b = y[i], y[index[ite]]
#         mixed_x.append(mixed_sample)
#         y_a_list.append(y_a)
#         y_b_list.append(y_b)
#     mixed_x = torch.stack(mixed_x, 0)
#     y_a = torch.stack(y_a_list, 0)
#     y_b = torch.stack(y_b_list, 0)
#     return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against both mixup targets.

    Evaluates ``criterion(pred, y_a)`` and ``criterion(pred, y_b)`` and
    returns their combination weighted by ``lam`` and ``1 - lam``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


# Below 2 methods from https://github.com/samaonline/Orthogonal-Convolutional-Neural-Networks/blob/14de5261e24544cef78b94c56684d2f1520c1e41/imagenet/utils.py#L34
def deconv_orth_dist(kernel, stride = 2, padding = 1):
    """Frobenius distance between the kernel's self-convolution and an identity target.

    The 4-D ``kernel`` is convolved with itself (``torch.conv2d`` with the
    given ``stride`` and ``padding``).  The target is zero everywhere except
    at the centre index, derived from the last output dimension and used for
    both spatial axes, where it holds an identity matrix across the two
    channel dimensions.

    Returns:
        ``torch.norm`` of the difference, as a scalar tensor.
    """
    # Unpacking also enforces that the kernel has exactly four dimensions.
    out_channels, _in_channels, _k0, _k1 = kernel.shape
    self_conv = torch.conv2d(kernel, kernel, stride=stride, padding=padding)

    rows, cols = self_conv.shape[-2], self_conv.shape[-1]
    identity_target = torch.zeros((out_channels, out_channels, rows, cols)).to(kernel.device)
    centre = cols // 2
    identity_target[:, :, centre, centre] = torch.eye(out_channels).to(kernel.device)

    return torch.norm(self_conv - identity_target)
    
def orth_dist(mat, stride=None):
    """Frobenius distance of the flattened weight's Gram matrix from identity.

    ``mat`` is reshaped to 2-D keeping its first dimension, transposed if it
    is wider than tall, and ``||M^T M - I||`` is returned as a scalar tensor.
    ``stride`` is accepted but not used.
    """
    flat = mat.reshape((mat.shape[0], -1))
    n_rows, n_cols = flat.shape
    if n_rows < n_cols:
        # Work with the tall orientation so the Gram matrix is the small one.
        flat = flat.permute(1, 0)
    gram = torch.t(flat) @ flat
    identity = torch.eye(flat.shape[1]).to(flat.device)
    return torch.norm(gram - identity)
                      
def load_from_state_dict(current_model, checkpoint_state):
    """Load a checkpoint whose ``running_V`` entries differ in size from the model's.

    Since running_V in preconditioning layers don't have correct init size,
    they are loaded manually: every checkpoint entry whose key contains
    ``"running_V"`` is replaced by a ``torch.zeros(1)`` placeholder for the
    regular ``load_state_dict`` call, and the real tensor is then assigned
    directly onto the module that owns it.

    The owning module is located from the state-dict key itself (the dotted
    prefix is the module path, the last component the attribute name), so
    this does not depend on any particular layer class being importable.

    Args:
        current_model: the ``torch.nn.Module`` to load into.
        checkpoint_state: state dict from the checkpoint; it is not modified.
    """
    state = copy.deepcopy(checkpoint_state)
    saved_running_v = {}
    for key in checkpoint_state.keys():
        if "running_V" in key:
            saved_running_v[key] = checkpoint_state[key]
            # Placeholder so load_state_dict does not reject the real size.
            state[key] = torch.zeros(1)

    # Load all other parameters
    current_model.load_state_dict(state)

    # Restore each running_V on the module named by its key.
    for key, value in saved_running_v.items():
        module_path, _, attr_name = key.rpartition(".")
        owner = current_model
        if module_path:
            for part in module_path.split("."):
                owner = getattr(owner, part)
        setattr(owner, attr_name, value)
            
def load_from_state_dict_without_fc(current_model, checkpoint_state):
    """Load a checkpoint but discard its final linear layer.

    Every entry whose key contains ``"linear.weight"`` is replaced with a
    fresh ``torch.randn`` tensor and every ``"linear.bias"`` entry with
    zeros, both sized from ``current_model.linear.weight``.  All other
    entries are loaded unchanged.

    The input width is read from the model rather than fixed at 512, so
    models whose ``linear`` layer has a different input width load correctly.

    Args:
        current_model: module with a ``linear`` attribute holding a 2-D
            ``weight``.
        checkpoint_state: state dict from the checkpoint; it is not modified.
    """
    num_classes, in_features = current_model.linear.weight.shape
    state = copy.deepcopy(checkpoint_state)
    for key in checkpoint_state.keys():
        if "linear.weight" in key:
            state[key] = torch.randn(num_classes, in_features)
        if "linear.bias" in key:
            state[key] = torch.zeros(num_classes)
    # Load all other parameters
    current_model.load_state_dict(state)

def get_logger(name, verbosity=2):
    """Return the logger called ``name`` with its level set from ``verbosity``.

    Verbosity 0 maps to WARNING, 1 to INFO and 2 to DEBUG.  Any other value
    fails an assertion whose message lists the valid options.
    """
    levels_by_verbosity = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
    assert verbosity in levels_by_verbosity, (
        'verbosity option {} is invalid. Valid options are {}.'.format(
            verbosity, levels_by_verbosity.keys())
    )
    configured = logging.getLogger(name)
    configured.setLevel(levels_by_verbosity[verbosity])
    return configured

class Timer:
    """Stopwatch that reports the seconds elapsed between successive checks."""

    def __init__(self):
        # Reference instant that the next check() is measured against.
        self.cache = datetime.now()

    def reset(self):
        """Restart the measurement from the current instant."""
        self.cache = datetime.now()

    def check(self):
        """Return seconds since the last check/reset and restart from now."""
        previous, self.cache = self.cache, datetime.now()
        return (self.cache - previous).total_seconds()
        
def path_formatter(args):
    """Build a checkpoint directory path from the run's settings.

    The selected settings are read from ``vars(args)``, converted to text and
    joined with commas, followed by a ``month-day-hour.minute`` timestamp.
    An empty-string value is written as ``default_scheduler`` and the
    ``deconv`` value is prefixed with ``"deconv "``.

    Args:
        args: namespace-like object exposing the settings listed below plus
            ``checkpoint_path``.

    Returns:
        ``args.checkpoint_path`` joined with the generated name, using
        forward slashes.

    Raises:
        KeyError: if one of the required settings is missing from ``args``.
    """
    args_dict = vars(args)
    needed_args = ['dataset', 'model', 'lr', 'optimizer', 'batch_size',
                   'lr_scheduler', 'weight_decay', 'norm_method', 'deconv',
                   'seed']

    check_path = []
    for setting in needed_args:
        value = args_dict[setting]
        if value == '':
            value = 'default_scheduler'
        if setting == 'deconv':
            value = 'deconv ' + str(value)
        check_path.append('{}'.format(value))

    # `datetime` here is the class imported via `from datetime import datetime`.
    timestamp = datetime.now().strftime("%m-%d-%H.%M")
    check_path.append(timestamp)
    save_path = ','.join(check_path)
    return os.path.join(args.checkpoint_path, save_path).replace("\\", "/")

def inf_loop(data_loader):
    """Yield items from ``data_loader`` forever, restarting it each time it ends."""
    while True:
        yield from data_loader

def set_seed(manualSeed=666):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN flags.

    Also sets ``torch.backends.cudnn.deterministic`` to True,
    ``torch.backends.cudnn.benchmark`` to False, and writes the seed to the
    ``PYTHONHASHSEED`` environment variable.
    """
    # Random number generators.
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    for seed_fn in (torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(manualSeed)

    # cuDNN flags.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(manualSeed)

def ensure_dir(dirname):
    """Create ``dirname`` (and any missing parents) if it does not exist yet.

    Uses ``exist_ok=True`` instead of a separate existence check, so two
    processes creating the same directory at the same time do not make one
    of them fail.

    Raises:
        FileExistsError: if the path exists and is not a directory.
    """
    Path(dirname).mkdir(parents=True, exist_ok=True)

def read_json(fname):
    """Read a JSON file and return its content with objects as ``OrderedDict``.

    Args:
        fname: path to the file, as a ``str`` or a ``pathlib.Path``.

    Returns:
        The parsed JSON value; every JSON object is an ``OrderedDict``.
    """
    # Wrapping in Path lets callers pass plain strings as well as Path objects.
    with Path(fname).open('rt') as handle:
        return json.load(handle, object_hook=OrderedDict)

def write_json(content, fname):
    """Write ``content`` to ``fname`` as JSON with 4-space indentation.

    Keys are written in their existing order (no sorting).

    Args:
        content: JSON-serialisable object.
        fname: destination path, as a ``str`` or a ``pathlib.Path``.
    """
    # Wrapping in Path lets callers pass plain strings as well as Path objects.
    with Path(fname).open('wt') as handle:
        json.dump(content, handle, indent=4, sort_keys=False)

def sigmoid_rampup(current, rampup_length):
    """Exponential ramp-up weight ``exp(-5 * (1 - t)^2)``.

    ``t`` is ``current / rampup_length`` with ``current`` clipped to
    ``[0, rampup_length]``.  A ``rampup_length`` of 0 returns 1.0.
    """
    if rampup_length == 0:
        return 1.0
    clipped = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - clipped / rampup_length
    return float(np.exp(-5.0 * remaining * remaining))

def sigmoid_rampdown(current, rampdown_length):
    """Exponential ramp-down weight ``exp(-12.5 * phase^2)``.

    ``phase`` is ``1 - (rampdown_length - current) / rampdown_length`` with
    ``current`` clipped to ``[0, rampdown_length]``.  A ``rampdown_length``
    of 0 returns 1.0.
    """
    if rampdown_length == 0:
        return 1.0
    clipped = np.clip(current, 0.0, rampdown_length)
    progress = 1.0 - (rampdown_length - clipped) / rampdown_length
    return float(np.exp(-12.5 * progress * progress))

def linear_rampup(current, rampup_length):
    """Linear ramp-up: ``current / rampup_length``, capped at 1.0.

    Both arguments must be non-negative (checked with an assertion).
    """
    assert current >= 0 and rampup_length >= 0
    return 1.0 if current >= rampup_length else current / rampup_length

def linear_rampdown(current, rampdown_length):
    """Linear ramp-down from 1.0 at ``current == 0`` to 0.0 at ``rampdown_length``.

    Returns ``1.0 - current / rampdown_length`` while the ramp is in
    progress and 0.0 once ``current`` reaches ``rampdown_length``, so the
    weight stays at its final value instead of jumping back to 1.0.  A
    ``rampdown_length`` of 0 returns 1.0, matching ``sigmoid_rampdown``.

    Both arguments must be non-negative (checked with an assertion).
    """
    assert current >= 0 and rampdown_length >= 0
    if rampdown_length == 0:
        return 1.0
    if current >= rampdown_length:
        return 0.0
    return 1.0 - current / rampdown_length


def cosine_rampdown(current, rampdown_length):
    """Cosine rampdown from https://arxiv.org/abs/1608.03983

    Returns ``0.5 * (cos(pi * current / rampdown_length) + 1)`` with
    ``current`` clipped to ``[0, rampdown_length]``: 1.0 at the start and
    0.0 at the end.  A ``rampdown_length`` of 0 returns 1.0 instead of
    dividing by zero, matching the sigmoid ramp functions in this file.
    """
    if rampdown_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampdown_length)
    return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))


def cosine_rampup(current, rampup_length):
    """Cosine rampup.

    Returns ``-0.5 * (cos(pi * current / rampup_length) - 1)`` with
    ``current`` clipped to ``[0, rampup_length]``: 0.0 at the start and 1.0
    at the end.  A ``rampup_length`` of 0 returns 1.0 instead of dividing by
    zero, matching the sigmoid ramp functions in this file.
    """
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    return float(-.5 * (np.cos(np.pi * current / rampup_length) - 1))


