import numpy as np
import scipy.stats
from scipy.stats import gamma
import torch
import itertools
import torch
import os
import copy
from datetime import datetime
import math
import numpy as np
import tqdm

import torch.nn.functional as F


def get_scf_idxes(ents, p=0.05, fit_method='MM'):
    """
    Get indices of the spurious-cue-free samples from predictive entropy array.
    The function fits the entropy array to a Gamma distribution and returns
    the indices of the samples whose upper-tail p-value, P(X >= ent) under the
    fitted distribution, is lower than the cut-off ``p`` (default 0.05).

    Args:
        ents (np.array[Float]): 1-D numpy array of the predictive entropy of the model.
        p (float, optional): Cut-off p-value. Defaults to 0.05.
        fit_method (str, optional): Estimation method forwarded to
            ``scipy.stats.gamma.fit`` ('MM' = method of moments,
            'MLE' = maximum likelihood). Defaults to 'MM'.

    Returns:
        np.array[Integer]: Indices of the spurious-cue-free samples
    """
    # Assume the entropy follows a (shifted, scaled) Gamma distribution.
    fit_alpha, fit_loc, fit_scale = gamma.fit(ents, method=fit_method)
    # Survival function = 1 - cdf, but computed directly so small tail
    # probabilities are not destroyed by floating-point cancellation.
    p_vals = gamma.sf(ents, fit_alpha, loc=fit_loc, scale=fit_scale)
    return np.where(p_vals < p)[0]


class EMA:
    """
    Exponential moving average for LfF.

    Keeps one running value (e.g. a per-sample loss) for every entry of
    ``label``; entries are addressed by their position in ``label``.
    """

    def __init__(self, label, alpha=0.7):
        """
        Args:
            label (torch.Tensor): per-sample labels; ``label.size(0)`` fixes
                the number of tracked samples.
            alpha (float): weight kept on the previous average at each update.
        """
        self.label = label
        self.alpha = alpha
        # Running averages, stored on CPU (update() moves incoming data there).
        self.parameter = torch.zeros(label.size(0))
        # 1 for samples that have been updated at least once, 0 otherwise.
        self.updated = torch.zeros(label.size(0))

    def update(self, data, index):
        """
        Update the moving average tensor (self.parameter) with its 'index'.

        Args:
            data ([torch.Tensor]): The current values (losses) to be updated.
            index ([torch.LongTensor]): The indices to be updated.
        """
        # detach(): the stored averages must never hold on to the caller's
        # autograd graph (otherwise self.parameter starts requiring grad and
        # every update chains onto the previous graph).
        data, index = data.detach().cpu(), index.cpu()
        # For never-updated samples (updated == 0) the weight on `data` is 1,
        # so the zero initialisation is fully replaced by the first value;
        # afterwards this is alpha * old + (1 - alpha) * new.
        self.parameter[index] = self.alpha * self.parameter[index] + \
            (1 - self.alpha * self.updated[index]) * data
        self.updated[index] = 1

    def max_loss(self, label):
        """
        Return the largest running value among samples whose label equals ``label``.

        Raises:
            RuntimeError: if no sample has that label (max of an empty tensor).
        """
        # torch.nonzero instead of np.where so this also works when
        # self.label lives on a GPU; indices are moved to CPU to match
        # self.parameter.
        label_index = torch.nonzero(self.label == label, as_tuple=True)[0].cpu()
        return self.parameter[label_index].max()


def init_grad(model):
    """
    Drop the accumulated gradient of every parameter of ``model``.

    Gradients are set to ``None`` (not zero-filled), so the next backward
    pass allocates them afresh.
    """
    for _name, weight in model.named_parameters():
        weight.grad = None


def ece_score(py, y_test, n_bins=10):
    """
    Expected Calibration Error (ECE) of predicted class probabilities.

    A sample's confidence is its largest predicted probability. Samples are
    grouped into ``n_bins`` equal-width bins ``(m/n_bins, (m+1)/n_bins]`` and
    the ECE is the sample-weighted mean over non-empty bins of
    ``|accuracy - mean confidence|``.

    Args:
        py: array-like of shape (n_samples, n_classes) of predicted probabilities.
        y_test: integer labels of shape (n_samples,), or a 2-D array
            (e.g. one-hot) that is reduced with argmax over axis 1.
        n_bins (int): number of equal-width confidence bins.

    Returns:
        float: the ECE. Samples whose confidence lies outside (0, 1] fall in
        no bin and are ignored; if no sample falls in any bin, returns NaN.
    """
    py = np.array(py)
    y_test = np.array(y_test)
    if y_test.ndim > 1:
        y_test = np.argmax(y_test, axis=1)
    py_index = np.argmax(py, axis=1)
    # Confidence = probability assigned to the predicted class.
    py_value = py[np.arange(py.shape[0]), py_index]
    correct = py_index == y_test

    weighted_gap = 0.0
    total = 0
    for m in range(n_bins):
        # Same edge arithmetic as a per-sample loop, so binning is unchanged.
        a, b = m / n_bins, (m + 1) / n_bins
        in_bin = (py_value > a) & (py_value <= b)
        count = int(in_bin.sum())
        if count == 0:
            continue
        acc = correct[in_bin].mean()
        conf = py_value[in_bin].mean()
        weighted_gap += count * abs(acc - conf)
        total += count
    if total == 0:
        return float("nan")
    return weighted_gap / total


def confidence_interval(mu, std, confidence=0.9):
    """
    Two-sided normal confidence interval centred on ``mu``.

    Args:
        mu: centre of the interval.
        std: spread that scales the half-width.
        confidence (float): coverage level, e.g. 0.9 for a 90% interval.

    Returns:
        tuple: ``(mu, mu - h, mu + h)`` where ``h = std * z`` and ``z`` is the
        ``(1 + confidence) / 2`` quantile of the standard normal.
    """
    z_score = scipy.stats.norm.ppf((1 + confidence) / 2)
    half_width = std * z_score
    return mu, mu - half_width, mu + half_width


def mean_confidence_interval(data, confidence=0.95):
    """
    Student-t confidence interval for the mean of ``data`` along axis 0.

    Args:
        data: array-like of samples (first axis indexes samples).
        confidence (float): coverage level, e.g. 0.95 for a 95% interval.

    Returns:
        tuple: ``(mean, mean - h, mean + h)`` with ``h`` = standard error of
        the mean times the t quantile with ``len(data) - 1`` degrees of freedom.
    """
    samples = 1.0 * np.array(data)
    dof = len(samples) - 1
    center = np.mean(samples, axis=0)
    std_err = scipy.stats.sem(samples, axis=0)
    t_crit = scipy.stats.t.ppf((1 + confidence) / 2., dof)
    half_width = std_err * t_crit
    return center, center - half_width, center + half_width


def get_hard_index(model, criterion, p=0.1, trainloader=None, device=None):
    """
    Return the positions of the training samples with the largest loss.

    The model is put in ``eval`` mode and run without gradients over the
    training data in loader order. The top ``p`` fraction of samples, ranked
    by loss, is returned with the largest loss first. The model is left in
    ``eval`` mode.

    Args:
        model: network mapping a batch of inputs to outputs.
        criterion: callable ``criterion(outputs, targets)``. It must return one
            loss per sample (e.g. ``reduction='none'``), because per-batch
            results are concatenated and ranked per sample.
        p (float): fraction of samples to return; the count is
            ``int(n_samples * p)``.
        trainloader: optional iterable of ``(inputs, targets, metadata)``
            batches in a fixed order. When omitted, the legacy
            ``dataset.get_loader(args, shuffle_train=False)`` is used. That path
            needs module-level ``dataset`` and ``args``, which this block
            does not define.
        device: optional device each batch is moved to. When omitted, the
            legacy module-level ``use_cuda`` / ``device_id`` flags decide.

    Returns:
        list[int]: sample positions (in loader order), sorted by descending
        loss. Equal losses keep loader order.
    """
    import heapq  # not imported at file level

    if trainloader is None:
        # Legacy behaviour: relies on globals supplied elsewhere.
        trainloader, _ = dataset.get_loader(args, shuffle_train=False)
    model.eval()
    score_list = []
    with torch.no_grad():
        for inputs, targets, _metadata in trainloader:
            if device is not None:
                inputs, targets = inputs.to(device), targets.to(device)
            elif use_cuda:
                inputs, targets = inputs.cuda(device_id), targets.cuda(device_id)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            score_list.append(loss.cpu().numpy())
    if not score_list:
        return []
    scores = np.concatenate(score_list)
    k = int(len(scores) * p)
    # nlargest == sorted(..., reverse=True)[:k], so ties keep loader order.
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)


def moving_average(net1, net2, alpha=1):
    """
    Blend ``net2``'s parameters into ``net1`` in place.

    Each parameter of ``net1`` becomes ``(1 - alpha) * net1 + alpha * net2``.
    Parameters are paired positionally, so both networks are expected to
    share one architecture. Buffers are not touched. With the default
    ``alpha=1``, ``net1`` is overwritten by ``net2``.
    """
    keep = 1.0 - alpha
    pairs = zip(net1.parameters(), net2.parameters())
    for dst, src in pairs:
        dst.data *= keep
        dst.data += src.data * alpha


def _check_bn(module, flag):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        flag[0] = True


def check_bn(model):
    """Return True if ``model`` or any of its submodules is a BatchNorm layer."""
    bn_type = torch.nn.modules.batchnorm._BatchNorm
    return any(isinstance(sub, bn_type) for sub in model.modules())


def reset_bn(module):
    """
    Reset a BatchNorm layer's running statistics to mean 0 / variance 1.

    The buffers are replaced with fresh tensors of the same shape, dtype and
    device. Non-BatchNorm modules are ignored, as are buffers that are
    ``None``: BatchNorm layers built with ``track_running_stats=False`` keep
    no running statistics.
    """
    if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        return
    if module.running_mean is not None:
        module.running_mean = torch.zeros_like(module.running_mean)
    if module.running_var is not None:
        module.running_var = torch.ones_like(module.running_var)


def _get_momenta(module, momenta):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        momenta[module] = module.momentum


def _set_momenta(module, momenta):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        module.momentum = momenta[module]


def bn_update(loader, model, verbose=False, subset=None, **kwargs):
    """
        BatchNorm buffers update (if any).
        Performs one pass over ``loader`` to re-estimate the BatchNorm running
        statistics from scratch. The model is left in train mode, and each
        BatchNorm layer's original ``momentum`` is restored afterwards.
        :param loader: iterable of batches. Each batch is either the input
            tensor itself or a tuple/list whose first element is the input
            tensor (e.g. ``(inputs, targets, metadata)``). Must support ``len``.
        :param model: model being updated; inputs are moved to the device of
            its first parameter/buffer (CUDA if it has neither).
        :param verbose: wrap the loader in a ``tqdm`` progress bar.
        :param subset: optional fraction of batches to use (``int(len * subset)``).
        :param kwargs: forwarded to ``model(...)`` on every forward pass.
        :return: None
    """
    if not check_bn(model):
        return
    model.train()
    momenta = {}
    model.apply(reset_bn)
    model.apply(lambda module: _get_momenta(module, momenta))
    n = 0
    num_batches = len(loader)

    # Send inputs to wherever the model lives instead of assuming the
    # default CUDA device.
    ref = next(itertools.chain(model.parameters(), model.buffers()), None)
    device = ref.device if ref is not None else torch.device("cuda")

    with torch.no_grad():
        if subset is not None:
            num_batches = int(num_batches * subset)
            loader = itertools.islice(loader, num_batches)
        if verbose:
            loader = tqdm.tqdm(loader, total=num_batches)
        for batch in loader:
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            inputs = inputs.to(device, non_blocking=True)
            b = inputs.size(0)
            if b == 0:
                # An empty batch carries no statistics (and would divide by zero).
                continue

            # momentum = b / (n + b) turns BatchNorm's exponential running
            # average into a cumulative average weighted by batch size.
            momentum = b / (n + b)
            for module in momenta:
                module.momentum = momentum

            model(inputs, **kwargs)
            n += b

    model.apply(lambda module: _set_momenta(module, momenta))
