import os

import numpy as np
import torch
from joblib import Parallel, delayed
# import torch.nn.functional as F
from scipy.stats import gaussian_kde
from sklearn.metrics import brier_score_loss, log_loss
from torchmetrics import AUROC, Accuracy, CalibrationError, MeanSquaredError
from torchmetrics.functional import kl_divergence
from tqdm import tqdm


def compute_classif_metrics(y_scores, y_labels, y_true_probas=None):
    """Compute accuracy, ranking and calibration metrics of predicted scores.

    Parameters
    ----------
    y_scores : array-like
        Predicted scores. Presumably probabilities of the positive class of
        a binary problem -- TODO confirm against callers.

    y_labels : array-like
        Ground-truth labels associated with the scores.

    y_true_probas : array-like, optional
        Reference probabilities. When given, the scores are additionally
        compared to them and the same accuracy/Brier metrics are computed
        for these reference probabilities.

    Returns
    -------
    metrics : dict
        Keys 'acc', 'auroc', 'ece', 'mce', 'rmsce', 'brier', 'nll' and
        'msce' (the square of 'rmsce'). When `y_true_probas` is given, keys
        'mse', 'acc_bayes' and 'brier_bayes' are added.

    """
    scores = torch.from_numpy(np.array(y_scores))
    labels = torch.from_numpy(np.array(y_labels))

    metrics = {}
    metrics['acc'] = Accuracy()(scores, labels).item()
    metrics['auroc'] = AUROC(compute_on_step=True).forward(scores, labels).item()

    # The three calibration errors only differ by the norm that is used.
    for key, norm in (('ece', 'l1'), ('mce', 'max'), ('rmsce', 'l2')):
        calibration = CalibrationError(norm=norm, compute_on_step=True)
        metrics[key] = calibration.forward(scores, labels).item()

    metrics['brier'] = brier_score_loss(labels, scores)
    metrics['nll'] = log_loss(labels, scores)
    metrics['msce'] = np.square(metrics['rmsce'])

    if y_true_probas is None:
        return metrics

    # NOTE: a KDE-based KL divergence metric ('kl') between the distribution
    # of the scores and of the reference probabilities was sketched here but
    # is currently disabled.
    true_probas = torch.from_numpy(np.array(y_true_probas))
    metrics['mse'] = MeanSquaredError(compute_on_step=True).forward(scores, true_probas).item()
    metrics['acc_bayes'] = Accuracy()(true_probas, labels).item()
    metrics['brier_bayes'] = brier_score_loss(labels, true_probas)

    return metrics


def compute_IPW(T_labels, T_scores, y, y_hat):
    """Compute the inverse-probability-weighted sum of squared errors.

    Each sample is weighted by ``1 / T_scores`` when its label is 1 and by
    ``1 / (1 - T_scores)`` when its label is 0. Samples whose label is
    neither 0 nor 1 get a zero weight.

    Parameters
    ----------
    T_labels : array-like
        Labels in {0, 1}.

    T_scores : array-like, same shape as T_labels
        Scores used to build the weights. Presumably estimated probabilities
        that the label is 1 -- TODO confirm against callers.

    y : array-like
        Observed values.

    y_hat : array-like
        Predicted values.

    Returns
    -------
    IPW : float
        Weighted sum (not mean) of the squared differences ``y - y_hat``.

    """
    # Coerce to arrays so that plain lists are supported: comparing a list
    # to a scalar yields a single bool instead of an element-wise mask.
    T_labels = np.asarray(T_labels)
    T_scores = np.asarray(T_scores)
    residuals = np.asarray(y) - np.asarray(y_hat)

    is_zero = T_labels == 0
    is_one = T_labels == 1

    weights = np.zeros(T_labels.shape, dtype=float)
    weights[is_zero] = np.divide(1, 1 - T_scores[is_zero])
    weights[is_one] = np.divide(1, T_scores[is_one])

    return np.vdot(weights, np.square(residuals))


def save_path(dirpath, ext, order=None, **kwargs):
    """Build a file path whose name encodes the given keyword arguments.

    The directory `dirpath` is created if it does not exist. The file name
    is made of ``key=value`` pairs joined by ':' (or 'fig' when no keyword
    argument is given), followed by the extension.

    Parameters
    ----------
    dirpath : str
        Directory of the file. Created if missing.

    ext : str
        File extension, without the leading dot.

    order : iterable of str, optional
        Keys to place first in the file name, in the given order. Must be a
        subset of the keyword arguments, without duplicates. Remaining keys
        follow in sorted order.

    **kwargs
        Key/value pairs to encode in the file name. True, False and None are
        abbreviated to 'T', 'F' and 'N'.

    Returns
    -------
    filepath : str
        ``dirpath`` joined with the generated file name.

    Raises
    ------
    ValueError
        If `order` has duplicated keys or keys absent from `kwargs`.

    """
    os.makedirs(dirpath, exist_ok=True)

    # Accept any iterable (tuple, generator...) and avoid a mutable default.
    order = [] if order is None else list(order)
    leading = set(order)
    if len(leading) != len(order):
        raise ValueError(f'Given order {order} contains duplicated keys.')

    keys = sorted(kwargs)
    if not leading.issubset(keys):
        raise ValueError(f'Given order {order} should be a subset of {keys}.')

    keys = order + [key for key in keys if key not in leading]

    def replace(x):
        # Identity checks on purpose: 1 and 0 must not be abbreviated.
        if x is True:
            return 'T'
        if x is False:
            return 'F'
        if x is None:
            return 'N'
        return x

    filename = ':'.join(f'{k}={replace(kwargs[k])}' for k in keys)
    if not filename:
        filename = 'fig'

    # Single pass replacing the characters unwanted in file names.
    # ':' and '=' are deliberately kept.
    table = str.maketrans({
        '(': ':',
        ')': '',
        ' ': '_',
        ',': ':',
        '@': '_',
        '.': '_',
    })
    filename = filename.translate(table)

    return os.path.join(dirpath, f'{filename}.{ext}')


def save_fig(fig, dirpath, ext='pdf', order=[], pad_inches=0.1, **kwargs):
    """Save a figure under a path generated by `save_path`.

    Parameters
    ----------
    fig : object with a `savefig` method
        Figure to save.

    dirpath : str
        Directory in which the figure is saved.

    ext : str
        File extension, without the leading dot.

    order : list of str
        Forwarded to `save_path`.

    pad_inches : float
        Forwarded to `fig.savefig`.

    **kwargs
        Forwarded to `save_path` to build the file name.

    Returns
    -------
    filepath : str
        Path the figure was saved to.

    """
    target = save_path(dirpath, ext=ext, order=order, **kwargs)
    savefig_options = {
        'bbox_inches': 'tight',
        'transparent': True,
        'pad_inches': pad_inches,
    }
    fig.savefig(target, **savefig_options)
    return target


def list_list_to_array(L, fill_value=None, dtype=None):
    """Convert a list of sequences of varying size into a 2D numpy array.

    Shorter rows are right-padded with `fill_value` so that all rows have
    the length of the longest one.

    Parameters
    ----------
    L : list of lists (or of any other sequences, e.g. tuples).

    fill_value : any
        Value to fill the blank with.

    dtype : data-type, optional
        Forwarded to numpy when building the array.

    Returns
    -------
    a : array
        Array with one row per element of L. An empty L gives an empty
        array of shape (0, 0).

    """
    # Copy rows into lists so that tuples (or other sequences) can be padded
    # by list concatenation.
    rows = [list(row) for row in L]

    if not rows:
        # max() of an empty sequence would raise.
        return np.empty((0, 0), dtype=dtype)

    width = max(map(len, rows))
    padded = [row + [fill_value]*(width - len(row)) for row in rows]
    return np.array(padded, dtype=dtype)


def pad_array(a, shape, fill_value=0):
    """Grow a numpy array to a target shape by appending values to each axis.

    Parameters
    ----------
    a : array

    shape : tuple
        Desired shape. Must have as many dimensions as `a` and be at least
        as large as `a` along every axis, otherwise an error is raised.

    fill_value : scalar
        Value used for the appended entries.

    Returns
    -------
    b : array
        Copy of `a` padded at the end of each axis to have shape `shape`.

    Raises
    ------
    ValueError
        If `shape` has a different number of dimensions than `a` or is
        smaller than `a` along some axis.
    """
    current = np.array(a.shape, dtype=int)
    target = np.array(shape, dtype=int)

    n_current, n_target = len(current), len(target)
    if n_current != n_target:
        raise ValueError(f'Desired shape and array shape must have same '
                         f'dimension. Array is {n_current}D, desired shape '
                         f'is {n_target}D.')

    if np.any(target < current):
        raise ValueError(f'Desired shape must have all its dimension at least '
                         f'as large as input array. Asked shape {target} on '
                         f'array of shape {current}.')

    # Nothing is prepended: only the end of each axis is extended.
    missing = target - current
    pad_width = tuple((0, extra) for extra in missing)
    return np.pad(a, pad_width, mode='constant', constant_values=fill_value)


def pad_arrays(L, shape=None, fill_value=0):
    """Pad a list of arrays to a common shape by appending values to axes.

    Parameters
    ----------
    L : list of arrays.

    shape : tuple, optional
        Desired shape. If None, the element-wise maximum of the shapes of
        the arrays is used. If one array is larger than the desired shape
        along some axis, an error is raised.

    fill_value : any
        Value to fill the blank with.

    Returns
    -------
    padded : list of arrays
        One padded array per element of L, in the same order. Empty if L is
        empty.

    """
    arrays = list(L)
    if not arrays:
        # Nothing to pad. Also avoids reducing over an empty list of shapes,
        # which numpy rejects.
        return []

    if shape is None:
        # Find the largest shape
        shapes = [np.array(a.shape) for a in arrays]
        shape = np.maximum.reduce(shapes)

    return [pad_array(a, shape, fill_value) for a in arrays]


def pairwise_call(L, f, symmetric=True, n_jobs=1, verbose=0):
    """Evaluate f on every pair of objects of L and gather results in a matrix.

    Parameters
    ----------
    L : list
        List of n objects.

    f : callable (obj1, obj2) returning a float

    symmetric : bool
        Whether f(x, y) = f(y, x). If so, only pairs (i, j) with i <= j are
        evaluated and the result is mirrored.

    n_jobs : int
        Number of jobs given to joblib.

    verbose : int
        Progress bars are displayed only if strictly positive.

    Returns
    -------
    D : (n, n) array
        D[i, j] holds f(L[i], L[j]).

    """
    n = len(L)

    # When f is symmetric, the lower triangle is deduced from the upper one.
    pairs = [(i, j)
             for i in range(n)
             for j in range(i if symmetric else 0, n)]

    quiet = verbose <= 0
    jobs = (delayed(f)(L[i], L[j]) for i, j in tqdm(pairs, disable=quiet))
    values = Parallel(n_jobs=n_jobs, require='sharedmem')(jobs)

    D = np.full((n, n), np.nan)
    for (i, j), value in zip(tqdm(pairs, disable=quiet), values):
        D[i, j] = value
        if symmetric:
            D[j, i] = value

    return D
