import glob
import os
import random
import shutil

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable

from dataset import get_label_name, get_num_class

# Apply seaborn's default theme to matplotlib globally at import time, so the
# figures produced in this module (e.g. PolicyHistory.plot) use seaborn styling.
sns.set()


class DotDict(dict):
    """Dictionary whose items are also reachable as attributes.

    Keys written through ``d[key] = value``, ``d.key = value`` or the
    constructor are mirrored into the instance ``__dict__`` so ``d.key`` and
    ``d[key]`` agree. Reading a missing (non-dunder) attribute returns
    ``None``, like ``dict.get``, instead of raising.

    Example:
        m = DotDict({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
        m.first_name  # 'Eduardo'
        m.missing     # None

    NOTE(review): ``update``, ``pop``, ``popitem``, ``setdefault`` and
    ``clear`` are inherited from ``dict`` and do not touch ``__dict__``, so
    attribute access can go stale after them.
    """

    def __init__(self, *args, **kwargs):
        super(DotDict, self).__init__(*args, **kwargs)
        # Mirror everything dict.__init__ accepted (mappings, iterables of
        # pairs, keyword arguments) so attribute access and __delitem__ stay
        # consistent with the dict contents.
        self.__dict__.update(self)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails. Dunder names must
        # raise: copy/pickle probe for hooks such as ``__setstate__`` and would
        # otherwise receive ``None`` and try to call it.
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(DotDict, self).__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(DotDict, self).__delitem__(key)
        del self.__dict__[key]


def save_checkpoint(state, is_best, save):
    """Serialize ``state`` to ``<save>/checkpoint.pth.tar``.

    When ``is_best`` is truthy the freshly written file is additionally copied
    to ``<save>/model_best.pth.tar``.
    """
    ckpt_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, ckpt_path)
    if not is_best:
        return
    shutil.copyfile(ckpt_path, os.path.join(save, 'model_best.pth.tar'))


def save_ckpt(model, optimizer, scheduler, epoch, model_path):
    """Save model, optimizer and scheduler state plus the epoch to ``model_path``.

    The file holds a dict with keys ``'model'``, ``'epoch'``, ``'optimizer'``
    and ``'scheduler'`` (the layout ``restore_ckpt`` reads back).
    """
    payload = dict(
        model=model.state_dict(),
        epoch=epoch,
        optimizer=optimizer.state_dict(),
        scheduler=scheduler.state_dict(),
    )
    torch.save(payload, model_path)


def restore_ckpt(model, optimizer, scheduler, model_path, location):
    """Restore a checkpoint written by ``save_ckpt`` and return its epoch.

    Args:
        model: module that receives ``state['model']`` (strict key matching).
        optimizer: optimizer that receives ``state['optimizer']``.
        scheduler: LR scheduler that receives ``state['scheduler']``.
        model_path: path of the checkpoint file.
        location: a GPU index (int or digit string) maps tensors to
            ``cuda:<index>`` as before. Any other value (e.g. ``'cpu'``,
            ``'cuda:1'``, a ``torch.device``) is passed to ``torch.load`` as
            ``map_location`` unchanged, so CPU-only hosts can restore too.

    Returns:
        The value stored under ``state['epoch']``.
    """
    if isinstance(location, int) or str(location).isdigit():
        map_location = f'cuda:{location}'
    else:
        map_location = location
    state = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state['model'], strict=True)
    optimizer.load_state_dict(state['optimizer'])
    scheduler.load_state_dict(state['scheduler'])
    return state['epoch']


def save(model, model_path):
    """Write only ``model.state_dict()`` (no optimizer/epoch) to ``model_path``."""
    weights = model.state_dict()
    torch.save(weights, model_path)


def load(model, model_path, location):
    """Load a state dict written by ``save`` into ``model`` (strict key matching).

    Args:
        model: module receiving the weights.
        model_path: path of the saved state dict.
        location: a GPU index (int or digit string) maps tensors to
            ``cuda:<index>`` as before. Any other value (e.g. ``'cpu'``,
            ``'cuda:1'``, a ``torch.device``) is passed to ``torch.load`` as
            ``map_location`` unchanged, so loading works without CUDA.
    """
    if isinstance(location, int) or str(location).isdigit():
        map_location = f'cuda:{location}'
    else:
        map_location = location
    model.load_state_dict(torch.load(model_path, map_location=map_location), strict=True)


def parse_genotype(path):
    """Parse a policy/genotype text file into three parallel lists.

    Each non-blank line must contain exactly three comma-separated fields:
    ``op_name,magnitude,weight``. Blank lines (e.g. a trailing empty line)
    are skipped.

    Returns:
        ``(ops, ms, ws)``: op names (surrounding whitespace stripped),
        magnitudes as floats, and weights as floats.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields or
            a numeric field cannot be parsed.
    """
    ops, ms, ws = [], [], []
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            op, m, w = line.split(',')
            ops.append(op.strip())
            ms.append(float(m))
            ws.append(float(w))
    return ops, ms, ws


def load_augmentor(args, augmentor, after_transforms, n_class, sub_policies, aug_model_name):
    """Configure (or build) the augmentor used when training with a fixed policy.

    Args:
        args: namespace read for ``aug_mode`` and, depending on the mode,
            ``policy_path``, ``k_ops``, ``temperature``, ``save``,
            ``n_proj_layer``, ``gpu``, ``dataset`` and ``epochs``.
        augmentor: augmentor object. In the 'vector', 'projection' and 'cnn'
            modes a policy agent is attached to it and all of its parameters
            are frozen.
        after_transforms: transforms forwarded to the agent / factory.
        n_class: forwarded as ``infer_n_class`` in the 'projection'/'cnn' modes.
        sub_policies: op names used by the 'projection'/'cnn' modes.
        aug_model_name: backbone name; agents are built with ``resize=False``
            only for ``'resnet50'``.

    Returns:
        The configured augmentor: the same object for 'vector'/'projection'/
        'cnn', a newly built one for 'randaug', 'dada'/'aa'/'fa' and 'pba'.

    Raises:
        ValueError: if ``args.aug_mode`` is not recognised.
    """
    resize = aug_model_name != 'resnet50'
    mode = args.aug_mode

    if mode == 'vector':
        # Fixed policy read from a text file: one "op,magnitude,weight" per line.
        ops, ms, ws = parse_genotype(args.policy_path)
        augmentor.add_augment_agent(ops, after_transforms, aug_mode='vector',
                                    search=False, ops_weights=ws, magnitudes=ms, resize=resize)
    elif mode in ('projection', 'cnn'):
        augmentor.add_augment_agent(sub_policies, after_transforms, aug_mode=mode,
                                    search=False, k_ops=args.k_ops, sampling='prob',
                                    temperature=args.temperature, save_dir=args.save,
                                    infer_n_class=n_class, resize=resize, n_proj_layer=args.n_proj_layer)
        load(augmentor, f'{args.policy_path}/weights.pt', location=args.gpu)
    # NOTE(review): getRandAugmentAgent, getPolicyAugmentAgent and
    # getPBAAugmentAgent are not imported in this module's import block, so the
    # three branches below raise NameError unless those names are injected
    # elsewhere — confirm and add the missing import.
    elif mode == 'randaug':
        return getRandAugmentAgent(args.dataset, after_transforms)
    elif mode in ['dada', 'aa', 'fa']:
        return getPolicyAugmentAgent(args.dataset, mode, after_transforms)
    elif mode == 'pba':
        return getPBAAugmentAgent(args.dataset, args.epochs, after_transforms)
    else:
        raise ValueError('invalid augmentor name=%s' % args.aug_mode)

    # The attached policy is only used for inference here: freeze it.
    for param in augmentor.parameters():
        param.requires_grad = False
    return augmentor


class PolicyHistory(object):
    """Record per-class augmentation-policy statistics and dump/plot them.

    For each class index the history keeps four lists (``magnitudes``,
    ``weights``, ``var_magnitudes``, ``var_weights``). Every accepted ``add``
    call appends one entry to each of them.
    """

    def __init__(self, op_names, save_dir, n_class):
        """
        Args:
            op_names: operation names; written as the CSV header columns.
            save_dir: directory under which ``policy/`` and ``vis_policy/``
                are created by ``save``.
            n_class: number of classes to keep a history for.
        """
        self.op_names = op_names
        self.save_dir = save_dir
        self._initialize(n_class)

    def _initialize(self, n_class):
        """Reset ``self.history`` to ``n_class`` empty per-class records."""
        self.history = [
            {'magnitudes': [], 'weights': [], 'var_magnitudes': [], 'var_weights': []}
            for _ in range(n_class)
        ]

    def add(self, class_idx, m_mu, w_mu, m_std, w_std):
        """Append one set of statistics to the record of ``class_idx``.

        Calls whose ``m_mu`` is not a list are ignored.
        """
        if not isinstance(m_mu, list):  # ugly way to bypass batch with single element
            return
        record = self.history[class_idx]
        record['magnitudes'].append(m_mu)
        record['weights'].append(w_mu)
        record['var_magnitudes'].append(m_std)
        record['var_weights'].append(w_std)

    def save(self, class2label=None):
        """Write each class's magnitude and weight history as CSV files.

        Files go to ``<save_dir>/policy/policy<i>(<k>)_magnitude.csv`` and
        ``..._weights.csv``, where ``k`` is ``class2label[i]`` when a mapping is
        given and ``i`` otherwise. The header row is the op names. The
        ``vis_policy`` directory is created, but the variance histories are not
        currently written.
        """
        path = os.path.join(self.save_dir, 'policy')
        vis_path = os.path.join(self.save_dir, 'vis_policy')
        os.makedirs(path, exist_ok=True)
        os.makedirs(vis_path, exist_ok=True)
        header = ','.join(self.op_names)
        for i, history in enumerate(self.history):
            k = i if class2label is None else class2label[i]
            np.savetxt(f'{path}/policy{i}({k})_magnitude.csv',
                       history['magnitudes'], delimiter=',', header=header, comments='')
            np.savetxt(f'{path}/policy{i}({k})_weights.csv',
                       history['weights'], delimiter=',', header=header, comments='')

    @staticmethod
    def _file_sort_key(path):
        """Order ``policy<i>(...)_*.csv`` files by numeric class index ``i``."""
        name = os.path.basename(path)
        index = name[len('policy'):].split('(', 1)[0]
        return (0, int(index), name) if index.isdigit() else (1, 0, name)

    def plot(self):
        """Stack-plot the saved CSV histories and save them as ``policy/visual.png``.

        One row per class: magnitudes in the left column, weights in the right.
        Both file lists are sorted by class index so each row shows a single
        class (``glob`` order is arbitrary). Returns the matplotlib figure; it
        is not closed here.
        """
        policy_dir = os.path.join(self.save_dir, 'policy')
        pattern_root = glob.escape(policy_dir)
        mag_file_list = sorted(glob.glob(os.path.join(pattern_root, '*_magnitude.csv')),
                               key=self._file_sort_key)
        weights_file_list = sorted(glob.glob(os.path.join(pattern_root, '*_weights.csv')),
                                   key=self._file_sort_key)
        n_class = len(mag_file_list)

        # squeeze=False keeps ``axes`` 2-D even when there is a single class.
        f, axes = plt.subplots(n_class, 2, figsize=(15, 5 * n_class), squeeze=False)

        for col, file_list in enumerate((mag_file_list, weights_file_list)):
            for row, file in enumerate(file_list):
                df = pd.read_csv(file).dropna()
                ax = axes[row, col]
                ax.stackplot(range(0, len(df)), df.to_numpy().T, labels=df.columns, edgecolor='none')
                ax.set_title(os.path.splitext(os.path.basename(file))[0])

        axes[-1, -1].legend(loc='upper center', bbox_to_anchor=(-0.1, -0.2), fancybox=True, shadow=True, ncol=10)
        f.savefig(os.path.join(policy_dir, 'visual.png'))
        return f


class AvgrageMeter(object):
    """Running weighted mean of scalar values.

    After every ``update`` the attribute ``avg`` equals ``sum / cnt``;
    ``reset`` sets ``avg``, ``sum`` and ``cnt`` back to 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the accumulated total, count and average."""
        self.avg, self.sum, self.cnt = 0, 0, 0

    def update(self, val, n=1):
        """Fold in ``val`` with weight ``n`` (e.g. a batch mean over ``n`` samples)."""
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each ``k`` in ``topk``.

    ``output`` holds per-class scores along dim 1; ``target`` holds one class
    index per sample (flattened via ``view(1, -1)``). Returns a list of 0-dim
    tensors, one per entry of ``topk``, each in ``[0, 100]``.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # (largest_k, batch): row j holds every sample's j-th best class index.
    ranked = output.topk(largest_k, 1, True, True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    return [hits[:k].reshape(-1).float().sum(0).mul_(100.0 / n_samples) for k in topk]


class Cutout(object):
    """Zero out one square patch of an image tensor, in place.

    Dims 1 and 2 of the tensor are treated as height and width (e.g. C x H x W).
    The patch centre is drawn with ``np.random`` (row first, then column).
    The patch is clipped at the borders, so it can be smaller than
    ``length x length``. The same patch is applied to every channel.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        height, width = img.size(1), img.size(2)
        half = self.length // 2
        centre_row = np.random.randint(height)
        centre_col = np.random.randint(width)

        top, bottom = np.clip([centre_row - half, centre_row + half], 0, height)
        left, right = np.clip([centre_col - half, centre_col + half], 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.
        img *= torch.from_numpy(keep).expand_as(img)
        return img


def _data_transforms_cifar10(args):
    """Build the CIFAR-10 ``(train_transform, valid_transform)`` pipelines.

    Train:
        * random 32x32 crop with 4px padding
        * random horizontal flip
        * ``ToTensor``
        * ``Cutout(args.cutout_length)`` when ``args.cutout`` is truthy

        Normalization is not applied (that step is commented out).
    Valid: ``ToTensor`` then ``Normalize`` with the CIFAR-10 channel mean/std.

    NOTE(review): train and valid tensors end up on different scales unless the
    training data is normalized elsewhere (presumably in the augmentor's
    ``after_transforms``) — confirm.
    """
    cifar_mean = [0.49139968, 0.48215827, 0.44653124]
    cifar_std = [0.24703233, 0.24348505, 0.26158768]

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # transforms.Normalize(cifar_mean, cifar_std),
    ]
    if args.cutout:
        train_steps.append(Cutout(args.cutout_length))
    train_transform = transforms.Compose(train_steps)

    valid_steps = [transforms.ToTensor(), transforms.Normalize(cifar_mean, cifar_std)]
    valid_transform = transforms.Compose(valid_steps)
    return train_transform, valid_transform


def count_parameters_in_MB(model):
    """Return the number of parameters in millions (despite the "MB" name).

    Parameters whose qualified name contains ``"auxiliary"`` are excluded.
    """
    # Built-in sum over a generator: np.sum(generator) is deprecated in NumPy.
    n_params = sum(v.numel() for name, v in model.named_parameters() if "auxiliary" not in name)
    return n_params / 1e6


def drop_path(x, drop_prob):
    """Randomly zero whole samples of ``x`` (drop-path), modifying ``x`` in place.

    Each sample along dim 0 is kept with probability ``1 - drop_prob``. Kept
    samples are scaled by ``1 / (1 - drop_prob)``. The mask has shape
    ``(x.size(0), 1, 1, 1)``, so ``x`` is expected to be 4-D with the batch
    first. With ``drop_prob <= 0`` ``x`` is returned untouched.

    Returns:
        ``x`` itself (the same tensor object).
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # Allocate the mask with x's device and dtype; the former
        # torch.cuda.FloatTensor forced float32 on the current CUDA device and
        # failed for CPU tensors.
        mask = x.new_empty(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x


def create_exp_dir(path, scripts_to_save=None):
    """Create the experiment directory and optionally snapshot source files into it.

    Args:
        path: experiment directory; created along with any missing parents.
            Re-using an existing directory is fine.
        scripts_to_save: optional iterable of file paths, copied into
            ``<path>/scripts/`` by basename. Existing copies are overwritten,
            so re-running an experiment does not fail.
    """
    os.makedirs(path, exist_ok=True)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        scripts_dir = os.path.join(path, 'scripts')
        os.makedirs(scripts_dir, exist_ok=True)
        for script in scripts_to_save:
            dst_file = os.path.join(scripts_dir, os.path.basename(script))
            shutil.copyfile(script, dst_file)


def reproducibility(seed, *, detect_anomaly=True):
    """Seed every RNG used here and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA
    devices). Sets ``cudnn.deterministic = True`` and
    ``cudnn.benchmark = False``.

    Args:
        seed: integer seed.
        detect_anomaly: enable autograd anomaly detection. This keeps the
            historical default, but it is a debugging mode that slows down
            backward passes; pass ``False`` for regular training runs.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU (a per-device seed is redundant).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.autograd.set_detect_anomaly(detect_anomaly)
