'''Helper functions for PyTorch training and evaluation, including:
    - get_mean_and_std: compute the per-channel mean and std of a dataset.
'''
import argparse
import scipy.sparse
import logging
import math
import random
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import (accuracy_score, f1_score, precision_score,
                             recall_score, roc_auc_score)
from torch.optim.lr_scheduler import LambdaLR

from utils.ramps import cosine_rampdown

# Names exported by `from <this module> import *`; every other helper in this
# file must be imported explicitly by name.
__all__ = ['get_mean_and_std', 'accuracy', 'AverageMeter']


def get_raw_dict(args):
    """Return the plain ``dict`` that backs an ``argparse.Namespace``.

    The result comes from ``vars(args)``, so it is the namespace's own
    ``__dict__`` rather than a copy.

    Example::

        with open(path, 'w') as f:
            json.dump(get_raw_dict(args), f, indent=2)

    Raises:
        NotImplementedError: if ``args`` is not an ``argparse.Namespace``.
    """
    if not isinstance(args, argparse.Namespace):
        raise NotImplementedError("Unknown type {}".format(type(args)))
    return vars(args)


def get_mean_and_std(dataset, num_workers=4):
    """Compute the per-channel mean and std of an image dataset.

    Samples are loaded one at a time (``batch_size=1``). For each channel,
    the mean and (unbiased) std of that single sample are computed, and each
    is averaged over all samples. The returned ``std`` is therefore the
    *mean of per-sample stds*. That is an approximation of, not equal to,
    the std of all pixel values pooled over the dataset.

    Args:
        dataset: dataset whose items are ``(input, target)`` pairs. After
            batching, ``input`` is indexed as ``inputs[:, c, :, :]``, so each
            input must be a 3-D ``(C, H, W)`` tensor.
        num_workers: worker processes for the internal ``DataLoader``.
            The default of 4 matches the previous hard-coded value. Pass 0
            to load in the main process.

    Returns:
        ``(mean, std)``: two 1-D float tensors of length ``C``. The channel
        count is taken from the first sample.

    Raises:
        ValueError: if ``dataset`` yields no samples.
    """
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=num_workers)

    mean = None
    std = None
    n_samples = 0
    for inputs, _ in dataloader:
        if mean is None:
            # Infer the channel count instead of assuming RGB, so grayscale
            # or multi-band data also works.
            mean = torch.zeros(inputs.size(1))
            std = torch.zeros(inputs.size(1))
        for i in range(mean.numel()):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
        # Count samples ourselves so datasets without __len__ also work.
        n_samples += 1

    if mean is None:
        raise ValueError("get_mean_and_std() requires a non-empty dataset")

    mean.div_(n_samples)
    std.div_(n_samples)
    return mean, std


def zero_one_loss(h, t, is_logistic=False):
    """Compute binary classification metrics from hard predictions.

    Labels are first normalised to one encoding:

    * ``is_logistic=True``: classes are ``{1, 0}`` (``-1`` is mapped to ``0``).
    * otherwise: classes are ``{1, -1}`` (``0`` is mapped to ``-1``).

    The inputs are copied before remapping, so the caller's arrays or tensors
    are never modified.

    Args:
        h: predicted labels (torch tensor, NumPy array, or array-like).
        t: true labels, aligned element-wise with ``h``.
        is_logistic: use the ``{0, 1}`` encoding instead of ``{-1, 1}``.

    Returns:
        ``(precision, recall, error_rate, n_predicted_positive)``.

        * ``precision`` and ``recall`` are ``0.0`` when there are no true
          positives.
        * ``error_rate`` is ``1 - (TP + TN) / (#pos + #neg in t)``. It is
          ``nan`` when ``t`` holds no label of either class.
        * ``n_predicted_positive`` is the count of ``h == 1``.
    """
    positive = 1
    negative = 0 if is_logistic else -1
    # The "other" negative encoding that must be folded into ``negative``.
    alias = -1 if is_logistic else 0

    # Work on copies: the remapping below assigns in place. Converting
    # non-tensors with np.array also makes plain lists index correctly.
    t = t.clone() if isinstance(t, torch.Tensor) else np.array(t)
    h = h.clone() if isinstance(h, torch.Tensor) else np.array(h)
    t[t == alias] = negative
    h[h == alias] = negative

    n_p = (t == positive).sum()
    n_n = (t == negative).sum()
    size = n_p + n_n

    n_pp = (h == positive).sum()
    t_p = ((h == positive) & (t == positive)).sum()
    t_n = ((h == negative) & (t == negative)).sum()
    f_p = n_n - t_n
    f_n = n_p - t_p

    precision = 0.0 if t_p == 0 else t_p / (t_p + f_p)
    recall = 0.0 if t_p == 0 else t_p / (t_p + f_n)
    # Avoid a division by zero (and its warning) on empty input.
    error_rate = 1 - (t_p + t_n) / size if size else float('nan')

    return precision, recall, error_rate, n_pp


def accuracy(prob, predicted, target):
    """Compute binary-classification metrics for one set of predictions.

    ``target`` is remapped so that label ``0`` becomes ``-1``. The remapping
    is applied to a copy, so the caller's ``target`` is not modified.

    Args:
        prob: scores for the positive class, passed to ``roc_auc_score``.
        predicted: hard predicted labels. Presumably already encoded as
            ``{-1, 1}`` to match the remapped target (TODO confirm with
            callers).
        target: true labels (torch tensor, NumPy array, or array-like).

    Returns:
        ``(acc, auc, f1, precision, recall, error_rate, n_predicted_positive)``.
        The last four values come from ``zero_one_loss``.
    """
    # Copy before remapping so the caller's labels are not mutated.
    if isinstance(target, torch.Tensor):
        target = target.clone()
    else:
        target = np.array(target)
    target[target == 0] = -1

    acc = accuracy_score(target, predicted)
    auc = roc_auc_score(target, prob)
    f1 = f1_score(target, predicted)

    precision, recall, erate, npp = zero_one_loss(predicted, target)

    return acc, auc, f1, precision, recall, erate, npp


def accuracy1(output, target):
    """Compute binary-classification metrics from two-class logits.

    Args:
        output: logits of shape ``(batch, 2)``. Column 1 is treated as the
            positive class: its softmax probability is used for AUC.
        target: tensor of true labels. ``1`` is the positive class, which is
            sklearn's default ``pos_label``.

    Returns:
        ``(acc, auc, f1, precision, recall)``. ``auc`` is ``-1`` when
        ``roc_auc_score`` raises ``ValueError``, e.g. when the batch contains
        a single class. This matches ``multi_class_accuracy``.
    """
    logits = output.detach()
    _, predicted = torch.max(logits, 1)
    # Clamp so downstream consumers never see exact 0/1 probabilities.
    prob = torch.clamp(F.softmax(logits, dim=1), 1e-10, 1 - 1e-10)

    prob = prob.cpu().numpy()
    predicted = predicted.cpu().numpy()
    target = target.detach().cpu().numpy()

    acc = accuracy_score(target, predicted)
    try:
        auc = roc_auc_score(target, prob[:, 1])
    except ValueError:
        # AUC is undefined when only one class is present in ``target``.
        auc = -1
    f1 = f1_score(target, predicted)
    prec = precision_score(target, predicted)
    recall = recall_score(target, predicted)
    return acc, auc, f1, prec, recall


def multi_class_accuracy(prob, predicted, target, is_logistic=False):
    """Compute overall and per-class classification metrics.

    The per-class metrics are computed for exactly two labels, ``-1`` and
    ``1``, in that order, via sklearn's ``pos_label``. The labels are
    therefore expected to be ``-1`` and ``1``.

    Args:
        prob: scores passed unchanged to
            ``roc_auc_score(target, prob, multi_class='ovr')``.
        predicted: predicted labels (array-like supporting ``==`` and ``.sum()``).
        target: true labels.
        is_logistic: accepted for interface compatibility; not used.

    Returns:
        A pair ``(overall_metrics, class_metrics)``.

        * ``overall_metrics`` is
          ``(acc, auc, f1_macro, f1_micro, precision, recall, erate)``.
          ``auc`` is ``-1`` if ``roc_auc_score`` raised ``ValueError``.
          ``precision`` is weighted-averaged and ``recall`` is
          macro-averaged. ``erate`` is ``1 - acc``.
        * ``class_metrics`` is
          ``(class_f1, class_precision, class_recall, class_npp)``.
          Each is a two-element list ordered by labels ``[-1, 1]``.
          ``class_npp`` counts the predictions equal to each label.
    """
    labels = (-1, 1)

    # ---- overall metrics ----
    acc = accuracy_score(target, predicted)
    try:
        auc = roc_auc_score(target, prob, multi_class='ovr')
    except ValueError:
        # AUC could not be computed for these inputs; report a sentinel.
        auc = -1

    f1_macro = f1_score(target, predicted, average='macro')
    f1_micro = f1_score(target, predicted, average='micro')
    # NOTE(review): precision is 'weighted' while recall is 'macro'. This is
    # kept as-is; confirm the asymmetry is intended.
    precision = precision_score(target, predicted, average='weighted')
    recall = recall_score(target, predicted, average='macro')

    # ---- per-class metrics, ordered as ``labels`` ----
    class_f1 = [f1_score(target, predicted, pos_label=lbl) for lbl in labels]
    class_precision = [precision_score(target, predicted, pos_label=lbl)
                       for lbl in labels]
    class_recall = [recall_score(target, predicted, pos_label=lbl)
                    for lbl in labels]
    class_npp = [(predicted == lbl).sum() for lbl in labels]

    overall = (acc, auc, f1_macro, f1_micro, precision, recall, 1 - acc)
    per_class = (class_f1, class_precision, class_recall, class_npp)
    return overall, per_class


# def set_seed(args):
#     random.seed(args.seed)
#     np.random.seed(args.seed)
#     torch.manual_seed(args.seed)
#     if args.n_gpu > 0:
#         torch.cuda.manual_seed_all(args.seed)


def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7. / 16.,
                                    last_epoch=-1):
    """Build a ``LambdaLR`` whose multiplier comes from ``cosine_rampdown``.

    At step ``s`` the learning-rate multiplier is
    ``cosine_rampdown(s, num_training_steps - num_warmup_steps)``.

    Despite the name there is no warmup phase. ``num_warmup_steps`` only
    shortens the rampdown length, and ``num_cycles`` is unused; both are
    kept so existing callers keep working. The exact decay curve is defined
    by ``utils.ramps.cosine_rampdown``, presumably a cosine decay over the
    given length (verify there).
    """
    rampdown_length = num_training_steps - num_warmup_steps

    def _lr_lambda(current_step):
        return cosine_rampdown(current_step, rampdown_length)

    return LambdaLR(optimizer, _lr_lambda, last_epoch=last_epoch)


def interleave(x, size):
    """Reorder ``x`` along dim 0 by viewing it as ``(-1, size, ...)`` and
    swapping the first two axes.

    ``x.shape[0]`` must be divisible by ``size``, or the reshape fails.
    ``de_interleave(interleave(x, size), size)`` restores the original order.
    """
    trailing = list(x.shape)[1:]
    blocks = x.reshape([-1, size] + trailing)
    return blocks.transpose(0, 1).reshape([-1] + trailing)


def de_interleave(x, size):
    """Undo ``interleave``: view ``x`` as ``(size, -1, ...)``, swap the
    first two axes, and flatten back along dim 0.

    ``x.shape[0]`` must be divisible by ``size``, or the reshape fails.
    """
    trailing = list(x.shape)[1:]
    blocks = x.reshape([size, -1] + trailing)
    return blocks.transpose(0, 1).reshape([-1] + trailing)


def three_sigma(x, threshold=0.2 / 9):
    """Return the elements of ``x`` that are strictly below ``threshold``.

    Despite its name, this does not compute a 3-sigma bound. It applies a
    fixed cut-off, which defaults to ``0.2 / 9`` (the value previously
    hard-coded; an earlier revision used ``0.2 / 14``).

    Args:
        x: NumPy array or torch tensor.
        threshold: exclusive upper bound for the kept elements.

    Returns:
        A 1-D array/tensor of the selected elements, in index order.
    """
    # Boolean-mask indexing selects the same elements as x[np.where(mask)].
    # It also works on torch tensors without a round-trip through NumPy.
    return x[x < threshold]


class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262

    Attributes:
        val: the most recent value passed to ``update``.
        sum: running sum of ``val * n`` over all updates.
        count: running total of the weights ``n``.
        avg: ``sum / count``; stays at 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all statistics to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (typically a batch mean and the
        batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against ZeroDivisionError when only zero-weight updates
        # (e.g. n=0 for an empty batch) have been recorded.
        if self.count:
            self.avg = self.sum / self.count
