import os
import random
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lrs

import torch
import numpy as np

from pdb import set_trace

def set_seed(manualSeed=666):
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(manualSeed)

def make_optimizer(args, my_model):

    if args.bias_only_decrease:
        bias_params = [p for n, p in my_model.named_parameters() if (n.endswith('bias') and p.requires_grad and "fc" not in n)]
        other_params = [p for n, p in my_model.named_parameters() if (p.requires_grad and p not in set(bias_params))]
        trainable = [{'params': bias_params, 'weight_decay': 0},
                     {'params': other_params}]
    else:
        trainable = filter(lambda x: x.requires_grad, my_model.parameters())

    if args.sep_decay:
        wd_term = 0
    else:
        wd_term = args.weight_decay

    if args.optimizer == 'SGD':
        optimizer_function = optim.SGD
        kwargs = {'momentum': 0.9,
                  'lr': args.lr,
                  'weight_decay': wd_term#args.weight_decay
        }
    elif args.optimizer == 'Adam':
        optimizer_function = optim.Adam
        kwargs = {
            'betas': (0.9, 0.999),
            'eps': 1e-08,
            'lr': args.lr,
            'weight_decay': wd_term#args.weight_decay
        }
    elif args.optimizer == 'LBFGS':
        optimizer_function = optim.LBFGS
        kwargs = {'lr': args.lr,
                  'history_size': args.history_size,
                  'line_search_fn': 'strong_wolfe'
        }

    return optimizer_function(trainable, **kwargs)


def make_scheduler(args, my_optimizer):
    if args.decay_type == 'step':
        scheduler = lrs.StepLR(
            my_optimizer,
            step_size=args.patience,
            gamma=args.gamma
        )
    elif args.decay_type.find('step') >= 0:
        milestones = args.decay_type.split('_')
        milestones.pop(0)
        milestones = list(map(lambda x: int(x), milestones))
        scheduler = lrs.MultiStepLR(
            my_optimizer,
            milestones=milestones,
            gamma=args.gamma
        )

    return scheduler


def make_criterion(args):
    if args.loss == 'CrossEntropy':
        criterion = nn.CrossEntropyLoss()
    elif args.loss == 'MSE':
        criterion = nn.MSELoss()

    return criterion


class AverageMeter(object):
    """Computes and stores the average and current value
       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class AverageDict(object):
    """Computes and stores the average and current value
       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = dict()
        self.avg = dict()
        self.sum = dict()
        self.count = 0

    def update(self, val, n=1):
        self.val = val

        for key, item in val.items():
            if key in self.sum:
                self.sum[key] += item * n
            else:
                self.sum[key] = item * n

        self.count += n

        for key, item in val.items():
            self.avg[key] = self.sum[key] / self.count


def count_network_parameters(model):
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in parameters])


def print_and_save(text_str, file_stream):
    print(text_str)
    print(text_str, file=file_stream)


def compute_accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class SparsityFlipCal(object):
    def __init__(self):
        self.compare_dict = dict()
        self.feature_dict = dict()

    def update(self, dict, indexs, set_compare=False):
        for cnt, index in enumerate(indexs):
            if int(index) < 10:
                if set_compare:
                    self.compare_dict[int(index)] = {key: item[cnt].detach() for key, item in dict.items()}
                else:
                    self.feature_dict[int(index)] = {key: item[cnt].detach() for key, item in dict.items()}

    def compare(self, feature_as_compare=False):
        compare_summary = dict()

        for key in self.feature_dict:
            assert key in self.compare_dict

            for module in self.compare_dict[key]:
                if module not in compare_summary:
                    compare_summary[module] = []

                compare_summary[module].append(torch.abs((self.compare_dict[key][module] < 1e-6).float() - (self.feature_dict[key][module] < 1e-6).float()).mean().cpu().item())

        compare_summary = {key: np.mean(item) for key, item in compare_summary.items()}

        if feature_as_compare:
            self.compare_dict = self.feature_dict
            self.feature_dict = dict()

        return compare_summary
