from __future__ import absolute_import, division, print_function
import numpy as np
import torch
import string
import copy
from tqdm import tqdm
import ot
from math import log
from collections import defaultdict, Counter
import os
from transformers import AutoModelForMaskedLM, AutoTokenizer
import zipfile
import itertools
import utils

################
##### TODO #####
################
# Launch 2016 translation + scripts
# 2017 MOVER
# 2018 BERT
# Explore impact of temperature on small translation
# Launch COCO
# Launch paraphrasing
# run with lambda dic
import logging

# Configure the root logger at import time: INFO level, with a timestamp,
# the level name and the logger name in front of every message.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s", datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO)


def findsubsets(s, n):
    """Return all ``n``-element combinations of ``s`` as a list of tuples.

    The tuples appear in the order produced by ``itertools.combinations``.
    """
    return [*itertools.combinations(s, n)]


class DepthScoreMetric:
    """Score hypothesis/reference sentence pairs with depth-based and transport distances.

    Every sentence is encoded by each model listed in ``model_names``. For each
    selected hidden layer, the L2-normalised token embeddings of the reference
    and of the hypothesis are treated as two point clouds and compared through
    :func:`dr_distance`.
    """

    def __init__(self):
        """Load the tokenizers/models and select the hidden layers to score.

        The models run on the GPU when ``torch.cuda.is_available()`` and on
        the CPU otherwise.
        """

        self.model_names = ["roberta-base", "bert-base-cased"]
        # The four attributes `temperatures`, `use_idf_weights`,
        # `invert_support` and `idfs` are not read by any method of this class.
        self.temperatures = [1.0]
        self.use_idf_weights = None
        self.invert_support = False
        # Number of entries expected in the hidden-states tuple: 25 when the
        # first model name contains 'large', 13 otherwise.
        n = 25 if 'large' in self.model_names[0] else 13
        # One single-layer "combination" per hidden-state index above 8.
        self.combinations = [[i] for i in range(n) if i > 8]
        self.models_tokenizer = self.load_tokenizer_and_model()
        self.idfs = []
        # The original expression selected 'cpu' in both branches.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger = logging.getLogger(__name__)

    def load_tokenizer_and_model(self):
        """Load a tokenizer and a masked-LM model for every name in ``model_names``.

        Returns
        -------
        list of (tokenizer, model, model_name) tuples, with each model put in
        eval mode and configured to output its hidden states.
        """
        models_tokenizer = []
        for model_name in self.model_names:
            print(model_name)

            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForMaskedLM.from_pretrained(model_name)

            model.config.output_hidden_states = True
            model.eval()
            models_tokenizer.append((tokenizer, model, model_name))
        return models_tokenizer

    def evaluate_batch(self, hyps, refs, idf_dict_hyp=None, idf_dict_ref=None, batch_size=256):
        """Score every (hypothesis, reference) pair with every model, layer and measure.

        Parameters
        ----------
        hyps : str | list of str
            Hypothesis sentence(s). A single string is treated as one sentence.
        refs : str | list of str
            Reference sentence(s), aligned with ``hyps``.
        idf_dict_hyp, idf_dict_ref :
            Unused; accepted for interface compatibility.
        batch_size : int
            Number of sentence pairs encoded per forward pass.

        Returns
        -------
        dict
            Maps ``'{model_name}_{combination}_depth_{measure}'`` to a list
            holding one score per sentence pair, in input order.

        Raises
        ------
        ValueError
            If ``hyps`` and ``refs`` do not contain the same number of sentences.
        """
        if isinstance(hyps, str):
            hyps = [hyps]
        if isinstance(refs, str):
            refs = [refs]
        if len(hyps) != len(refs):
            raise ValueError(
                'hyps and refs must contain the same number of sentences '
                '(got {} and {})'.format(len(hyps), len(refs)))
        if len(refs) == 0:
            return {}

        final_dic = {}
        for tokenizer, model, model_name in tqdm(self.models_tokenizer, 'Models'):
            model = model.to(self.device)
            model_scores = {}
            for batch_start in range(0, len(refs), batch_size):
                batch_end = batch_start + batch_size
                batch_scores = self._score_batch(tokenizer, model, model_name,
                                                 refs[batch_start:batch_end],
                                                 hyps[batch_start:batch_end])
                # Concatenate the batches so that every sentence keeps its
                # score (previously only the first batch was returned).
                for key, values in batch_scores.items():
                    model_scores.setdefault(key, []).extend(values)
            # Move the model back to the CPU before handling the next one.
            model.cpu()
            final_dic.update(model_scores)
        return final_dic

    def _score_batch(self, tokenizer, model, model_name, batch_refs, batch_hyps):
        """Return ``{score_key: [score per pair]}`` for one batch of sentence pairs."""
        batch_scores = {}
        with torch.no_grad():
            enc_refs = tokenizer(batch_refs, return_tensors='pt', padding=True).to(self.device)
            refs_hidden = model(**enc_refs)[-1]
            enc_hyps = tokenizer(batch_hyps, return_tensors='pt', padding=True).to(self.device)
            hyps_hidden = model(**enc_hyps)[-1]

            # Token counts do not depend on the layer, so compute them once.
            ref_lengths = self._unpadded_lengths(tokenizer, enc_refs['input_ids'].cpu().tolist())
            hyp_lengths = self._unpadded_lengths(tokenizer, enc_hyps['input_ids'].cpu().tolist())

            for combination in tqdm(self.combinations, "combinations"):
                refs_embeddings = self._select_layers(refs_hidden, combination)
                hyps_embeddings = self._select_layers(hyps_hidden, combination)

                # Iterate over the sentences of *this batch* only.
                for index_sentence, (n_ref, n_hyp) in enumerate(zip(ref_lengths, hyp_lengths)):
                    ref_points = self._token_points(refs_embeddings, index_sentence, n_ref)
                    hyp_points = self._token_points(hyps_embeddings, index_sentence, n_hyp)

                    measures_results = self.compute_measure(ref_points, hyp_points,
                                                            model_name, combination)
                    for key, value in measures_results.items():
                        score_key = '{}_{}_depth_{}'.format(model_name, str(combination), key)
                        batch_scores.setdefault(score_key, []).append(value)
        return batch_scores

    @staticmethod
    def _unpadded_lengths(tokenizer, token_ids):
        """Count, for each id sequence, the tokens that differ from the pad token."""
        pad_token = tokenizer.pad_token
        lengths = []
        for ids in token_ids:
            tokens = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False)
            lengths.append(sum(1 for token in tokens if token != pad_token))
        return lengths

    @staticmethod
    def _select_layers(hidden_states, combination):
        """Stack the hidden states listed in ``combination`` and L2-normalise the last axis."""
        stacked = torch.cat([hidden_states[i].unsqueeze(0) for i in combination])
        stacked.div_(torch.norm(stacked, dim=-1).unsqueeze(-1))
        return stacked

    @staticmethod
    def _token_points(embeddings, index_sentence, n_tokens):
        """Return the first ``n_tokens`` embeddings of one sentence as a float64 array.

        The result has the token axis first and the layer axis second, which
        is the layout ``compute_measure`` expects.
        """
        points = embeddings[:, index_sentence, :n_tokens, :].permute(1, 0, 2)
        return points.cpu().numpy().astype(np.float64)

    def compute_measure(self, measures_locations_ref, measures_locations_hyps, model_name, layer):
        """Compute every configured distance between two token-embedding clouds.

        Parameters
        ----------
        measures_locations_ref, measures_locations_hyps :
            Array-likes whose axis 1 has length one (a single layer); that
            axis is squeezed away before calling :func:`dr_distance`.
        model_name, layer :
            Unused; accepted for interface compatibility.

        Returns
        -------
        dict
            Maps ``'{p}_{eps}_{n_alpha}_{measure}'`` to the value returned by
            :func:`dr_distance`, for the measures ``irw``, ``aiirw``,
            ``wasserstein``, ``sliced`` and ``mmd``.
        """
        measures_locations_ref = np.array(measures_locations_ref).squeeze(1)
        measures_locations_hyps = np.array(measures_locations_hyps).squeeze(1)

        # Fixed hyper-parameters shared by every measure.
        p, eps, n_alpha = 2, 0.3, 5

        measure_dic = {}
        self.logger.info("Power {} ---------- Eps {}".format(p, eps))
        for data_depth in ["irw", "ai_irw", "wasserstein", "sliced", "mmd"]:
            if data_depth in ("irw", "ai_irw"):
                self.logger.info(data_depth)
            resutls = dr_distance(measures_locations_ref, measures_locations_hyps, n_alpha=n_alpha,
                                  n_dirs=10000, data_depth=data_depth, eps_min=eps, eps_max=1, p=p)
            measure_dic["{}_{}_{}_{}".format(p, eps, n_alpha, data_depth.replace('_', ''))] = resutls
        return measure_dic


import numpy as np
from sklearn.preprocessing import normalize
from sklearn.covariance import MinCovDet as MCD
import ot

########################################################
#################### Sampled distribution ########################
########################################################
from sklearn.preprocessing import normalize
from sklearn.covariance import MinCovDet as MCD
from sklearn.decomposition import PCA
import numpy as np


########################################################
#################### Some useful functions ########################
########################################################


def cov_matrix(X, robust=False):
    """Compute the covariance matrix of X.

    Parameters
    ----------
    X : Array of shape (n_samples, n_features)
        The data, one observation per row.

    robust : bool, default=False
        If True, the covariance is estimated with ``MinCovDet``; otherwise
        ``np.cov`` is used.

    Returns
    -------
    sigma : Array of shape (n_features, n_features)
        The covariance matrix. It is always two-dimensional, including when
        X has a single feature.
    """
    if robust:
        return MCD().fit(X).covariance_

    # np.cov returns a 0-d array for single-feature data; keep it a matrix so
    # that callers can feed it to linear-algebra routines.
    return np.atleast_2d(np.cov(X.T))


def standardize(X, robust=False):
    """Rescale X by the inverse square root of its covariance spectrum.

    When X is rank-deficient (rank lower than its number of columns), it is
    first projected onto ``rank`` principal components and the covariance is
    recomputed, non-robustly, on the projected data.

    Parameters
    ----------
    X : Array of shape (n_samples, n_features)
        The data to transform.

    robust : bool, default=False
        Forwarded to ``cov_matrix`` for the covariance of the original X.

    Returns
    -------
    Array with ``n_samples`` rows: the (possibly PCA-reduced) data multiplied
    by ``u / sqrt(s)``, where ``u, s`` come from the SVD of the covariance.
    """
    covariance = cov_matrix(X, robust)
    _, n_features = X.shape
    rank = np.linalg.matrix_rank(X)

    if rank >= n_features:
        data = X.copy()
    else:
        reducer = PCA(rank)
        # NOTE(review): fit_transform below fits again, so this call looks
        # redundant; kept so the sequence of library calls is unchanged.
        reducer.fit(X)
        data = reducer.fit_transform(X)
        covariance = cov_matrix(data)

    left_vectors, singular_values, _ = np.linalg.svd(covariance)
    inv_sqrt_factor = left_vectors / np.sqrt(singular_values)

    return data @ inv_sqrt_factor


########################################################
#################### Sampled distributions ########################
########################################################

def sampled_sphere(n_dirs, d):
    """ Draw n_dirs directions in dimension d: standard Gaussian samples
        passed through ``normalize`` (points on the unit sphere).

        NOTE: ``sampled_sphere`` is defined a second time further down this
        module; that later definition replaces this one at import time.
    """
    gaussian_draws = np.random.multivariate_normal(
        mean=np.zeros(d),
        cov=np.identity(d),
        size=n_dirs,
    )
    return normalize(gaussian_draws)


import ot


def sampled_sphere(ndirs, d):
    """Return ``ndirs`` random directions of dimension ``d``.

    Rows are drawn from a zero-mean Gaussian with identity covariance and
    then rescaled by ``normalize``.
    """
    origin, unit_cov = np.zeros(d), np.identity(d)
    draws = np.random.multivariate_normal(mean=origin, cov=unit_cov, size=ndirs)
    return normalize(draws)


def Wasserstein(X, Y):
    """Optimal-transport cost between the empirical measures of X and Y.

    Both samples carry uniform weights. The ground cost is whatever
    ``ot.dist`` computes by default (presumably squared Euclidean; confirm
    against the installed POT version).

    Parameters
    ----------
    X, Y : Arrays of shape (n_samples, n_features)
        The two samples; they may have different numbers of rows.

    Returns
    -------
    The value returned by ``ot.emd2`` for these weights and cost matrix.
    """
    cost = ot.dist(X, Y)
    n_x, n_y = len(X), len(Y)
    weights_x = np.full(n_x, 1 / n_x)
    weights_y = np.full(n_y, 1 / n_y)

    return ot.emd2(weights_x, weights_y, cost)


def SW(X, Y, ndirs, p=2, max_sliced=False):
    """Monte-Carlo (max-)sliced Wasserstein distance between two samples.

    Parameters
    ----------
    X, Y : Arrays of shape (n_samples, n_features)
        The two samples; they must share the same number of features.

    ndirs : int
        Number of random directions used for the projections.

    p : int | float, default=2
        Power of the ground cost of the one-dimensional transport problems;
        the result is raised to ``1 / p``.

    max_sliced : bool, default=False
        If true, return the maximum over directions instead of the mean.

    Returns
    -------
    float
        ``(mean_k cost_k) ** (1 / p)``, or ``(max_k cost_k) ** (1 / p)`` when
        ``max_sliced`` is set.
    """
    _, d = X.shape
    U = sampled_sphere(ndirs, d)
    proj_X = np.matmul(X, U.T)
    proj_Y = np.matmul(Y, U.T)

    # The 1-D solver only uses `p` with the 'minkowski' metric. Previously the
    # cost was always the squared difference, whatever `p` was. For p == 2 the
    # original call is kept so existing results are reproduced exactly.
    metric = 'sqeuclidean' if p == 2 else 'minkowski'

    sliced = np.zeros(ndirs)
    for k in range(ndirs):
        sliced[k] = ot.emd2_1d(proj_X[:, k], proj_Y[:, k], metric=metric, p=p)

    aggregate = np.max(sliced) if max_sliced else np.mean(sliced)
    return aggregate ** (1 / p)


def gaussian_kernel(x1, x2, sigma=1.0):
    """Row-wise Gaussian (RBF) kernel between two aligned sets of points.

    Parameters
    ----------
    x1, x2 : Array-likes of shape (n_samples, n_features)
        Point sets compared row by row (or broadcastable to a common shape).

    sigma : float, default=1.0
        Bandwidth of the kernel.

    Returns
    -------
    Array of shape (n_samples,)
        ``exp(-||x1_i - x2_i||^2 / (2 * sigma^2))`` for every row ``i``.
    """
    diff = np.asarray(x1) - np.asarray(x2)
    # A Gaussian kernel uses the *squared* Euclidean distance; the previous
    # version used the plain norm.
    squared_dist = np.sum(diff ** 2, axis=1)
    return np.exp(-squared_dist / (2 * sigma ** 2))


import geomloss


def MMD(x1, x2, sigma=1):
    """Gaussian-kernel sample loss between x1 and x2, computed by geomloss.

    Parameters
    ----------
    x1, x2 : Array-likes accepted by ``torch.tensor``
        The two samples.

    sigma :
        Not used: the loss is built with geomloss' default settings.
        NOTE(review): confirm whether ``sigma`` was meant to set the kernel
        bandwidth.

    Returns
    -------
    float
        The scalar value of ``geomloss.SamplesLoss("gaussian")`` on the two samples.
    """
    loss_fn = geomloss.SamplesLoss("gaussian")
    sample_a = torch.tensor(x1)
    sample_b = torch.tensor(x2)
    return loss_fn(sample_a, sample_b).item()


########################################################
#################### Data Depths ########################
########################################################

def tukey_depth(X, n_dirs=None):
    """ Compute the score of the classical tukey depth of X w.r.t. X

    Parameters
    ----------
    X : Array of shape (n_samples, n_features)
            The training set.

    n_dirs : int | None
        The number of random directions to compute the score.
        If None, the number of directions is chosen as
        n_features * 100.

    Return
    -------
    tukey_score: Array of float
        Depth score of each element in X.
    """

    # The shape must be read before the default for n_dirs can be derived
    # from it (the previous order raised UnboundLocalError for n_dirs=None).
    n_samples, n_features = X.shape

    if n_dirs is None:
        n_dirs = n_features * 100

    # Simulated random directions on the unit sphere.
    U = sampled_sphere(n_dirs, n_features)

    # Compute projections
    proj = np.matmul(X, U.T)

    # For each direction (column), rank_matrix lists the sample indices in
    # increasing order of projection.
    rank_matrix = np.argsort(proj, axis=0)

    # depth[i, k] = rank (1..n_samples) of sample i along direction k.
    ranks = np.arange(1, n_samples + 1)
    depth = np.zeros((n_samples, n_dirs))
    np.put_along_axis(depth, rank_matrix, ranks[:, np.newaxis], axis=0)

    depth = depth / (n_samples * 1.)

    # NOTE(review): the score is the minimum of the one-sided rank fraction;
    # the original also computed min(depth, 1 - depth) but never used it.
    # Confirm which variant is intended.
    tukey_score = np.amin(depth, axis=1)

    return tukey_score


def projection_depth(X, n_dirs=None):
    """ Compute the score of the projection depth of X w.r.t. X

    Parameters
    ----------
    X : Array of shape (n_samples, n_features)
        The training set.

    n_dirs : int | None
        The number of random directions to compute the score.
        If None, the number of directions is chosen as
        n_features * 100.

    Return
    -------
    projection score: Array of float
        Depth score of each element of X.
    """
    # Read the shape first: the default for n_dirs depends on n_features
    # (the previous order raised UnboundLocalError for n_dirs=None).
    _, n_features = X.shape

    if n_dirs is None:
        n_dirs = n_features * 100

    # Simulated random directions on the unit sphere.
    U = sampled_sphere(n_dirs, n_features)

    # Compute projections
    proj = np.matmul(X, U.T)

    # Compute stahel-Donoho outlyingness on projections: absolute deviation
    # from the per-direction median, scaled by the per-direction MAD.
    med_proj = np.median(proj, axis=0)
    abs_dev = np.absolute(proj - med_proj.reshape(1, -1))
    MAD = np.median(abs_dev, axis=0)
    outlyingness = np.amax(abs_dev / MAD, axis=1)

    projection_score = 1 / (1 + outlyingness)

    return projection_score


def ai_irw(X, AI=True, robust=False, n_dirs=None, random_state=None):
    """ Compute the score of the (Affine-invariant-) integrated rank
        weighted depth of X w.r.t. X

    Parameters
    ----------

    X : Array of shape (n_samples, n_features)
            The sample whose elements are scored.

    AI: bool
        if True, X is passed through ``standardize`` first (affine-invariant
        version). If False, the original irw is computed on a copy of X.

    robust: bool, default=False
        Forwarded to ``standardize`` when AI is True.

    n_dirs : int | None
        The number of random directions needed to approximate
        the integral over the unit sphere.
        If None, n_dirs is set as 100 * n_features, where n_features is the
        number of columns after the optional standardization.

    random_state : int | None
        Seed passed to ``np.random.seed``; None is treated as 0. Note that
        this reseeds NumPy's global random generator.

    Returns
    -------
    ai_irw_score: Array
        Depth score of each element in X.
    """

    seed = 0 if random_state is None else random_state
    np.random.seed(seed)

    X_reduced = standardize(X, robust) if AI else X.copy()

    n_samples, n_features = X_reduced.shape
    if n_dirs is None:
        n_dirs = n_features * 100

    # Simulated random directions on the unit sphere.
    directions = sampled_sphere(n_dirs, n_features)
    projections = np.matmul(X_reduced, directions.T)

    # Per direction (column): sample indices sorted by increasing projection.
    order = np.argsort(projections, axis=0)

    # rank_fraction[i, k] = rank of sample i along direction k, over n_samples.
    rank_fraction = np.zeros((n_samples, n_dirs))
    np.put_along_axis(rank_fraction, order,
                      np.arange(1, n_samples + 1)[:, np.newaxis], axis=0)
    rank_fraction = rank_fraction / (n_samples * 1.)

    two_sided = np.minimum(rank_fraction, 1 - rank_fraction)
    return np.mean(two_sided, axis=1)


import numpy as np


def dr_distance(X, Y, n_alpha=10, n_dirs=100, data_depth='tukey', eps_min=0,
                eps_max=1, p=2, random_state=None):
    """The Depth-based pseudo-metric between two probability distributions.

    Parameters
    ----------

    X: array of shape (n_samples, n_features)
        The first sample.

    Y: array of shape (n_samples, n_features)
        The second sample.

    n_alpha: int
        The Monte-Carlo parameter for the approximation of the integral
        over alpha.

    n_dirs: int
        The number of directions for approximating the supremum over
        the unit sphere.

    data_depth: str in {'tukey', 'projection', 'irw', 'ai_irw',
                        'wasserstein', 'sliced', 'mmd'}
        The four depth names select the depth used to build the level sets.
        'wasserstein', 'sliced' and 'mmd' bypass the depth machinery and
        return ``Wasserstein(X, Y)``, ``SW(X, Y, ndirs=10000)`` or
        ``MMD(X, Y)`` directly; n_alpha, n_dirs, eps_min, eps_max and p are
        then validated but not used.

    eps_min: float in [0,eps_max]
        the lowest level set.

    eps_max: float in [eps_min,1]
        the highest level set.

    p: int
        the power of the ground cost.

    random_state : int | None
        Seed passed to ``np.random.seed``; None is treated as 0. Note that
        this reseeds NumPy's global random generator.

    Return
    ------

    dr_score: float
        the computed pseudo-metric score.

    Raises
    ------
    NotImplementedError
        If ``data_depth`` is not one of the supported names.
    ValueError
        If ``eps_min`` / ``eps_max`` are out of range or out of order.
    """

    if random_state is None:
        random_state = 0

    np.random.seed(random_state)

    if data_depth not in {'tukey', 'projection', 'irw', 'ai_irw', 'wasserstein', 'mmd', 'sliced'}:
        raise NotImplementedError('This data depth is not implemented')

    if eps_min > eps_max:
        raise ValueError('eps_min must be lower than eps_max')

    if eps_min < 0 or eps_min > 1:
        raise ValueError('eps_min must be in [0,eps_max]')

    if eps_max < 0 or eps_max > 1:
        raise ValueError('eps_max must be in [eps_min,1]')

    _, n_features = X.shape

    # Plain distances: no depth, no level sets.
    if data_depth == 'wasserstein':
        return Wasserstein(X, Y)
    if data_depth == 'sliced':
        return SW(X, Y, ndirs=10000)
    if data_depth == 'mmd':
        return MMD(X, Y)

    if data_depth == "tukey":
        depth_X = tukey_depth(X, n_dirs=n_dirs)
        depth_Y = tukey_depth(Y, n_dirs=n_dirs)
    elif data_depth == "projection":
        depth_X = projection_depth(X, n_dirs=n_dirs)
        depth_Y = projection_depth(Y, n_dirs=n_dirs)
    elif data_depth == "irw":
        depth_X = ai_irw(X, AI=False, n_dirs=n_dirs)
        depth_Y = ai_irw(Y, AI=False, n_dirs=n_dirs)
    else:  # data_depth == "ai_irw"
        depth_X = ai_irw(X, AI=True, n_dirs=n_dirs)
        depth_Y = ai_irw(Y, AI=True, n_dirs=n_dirs)

    # draw n_dirs vectors of the unit sphere in dimension n_features.
    U = sampled_sphere(n_dirs, n_features)
    proj_X = np.matmul(X, U.T)
    proj_Y = np.matmul(Y, U.T)

    # Percentile levels, truncated to whole percents, evenly spaced between
    # eps_min and eps_max.
    liste_alpha = np.linspace(int(eps_min * 100), int(eps_max * 100), n_alpha)
    quantiles_DX = np.percentile(depth_X, liste_alpha)
    quantiles_DY = np.percentile(depth_Y, liste_alpha)

    dr_score = 0
    for quantile_X, quantile_Y in zip(quantiles_DX, quantiles_DY):
        # Support function of each upper level set: per direction, the
        # largest projection among points at least as deep as the quantile.
        supp_X = np.max(proj_X[depth_X >= quantile_X], axis=0)
        supp_Y = np.max(proj_Y[depth_Y >= quantile_Y], axis=0)
        # Absolute difference keeps the cost non-negative for odd p too.
        dr_score += np.max(np.abs(supp_X - supp_Y) ** p)

    return (dr_score / n_alpha) ** (1 / p)


if __name__ == '__main__':
    # Smoke test: score one sentence against itself with every configured
    # model, layer and measure, then print the resulting dictionary.
    metric_call = DepthScoreMetric()

    ref = [
        'I like my cakes very much']
    hypothesis = ['I like my cakes very much']

    # Keyword arguments make the hypothesis/reference roles explicit
    # (evaluate_batch takes the hypotheses first).
    final_preds = metric_call.evaluate_batch(hyps=hypothesis, refs=ref, batch_size=256)
    print(final_preds)
