import os
import json
import re
import numpy as np

import torch
import torch.nn as nn

from scipy.stats import entropy
from scipy.special import softmax
from scipy.spatial.distance import pdist, squareform, jensenshannon
from scipy.integrate import trapezoid

from sklearn.model_selection import StratifiedShuffleSplit

from torchattacks import AutoAttack

# Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Query sets accepted by the `testset` argument of `distance_matrix`;
# 'union' concatenates the data, random, awgn and adv query sets.
TESTSETS = ["data", "random", "awgn", "adv", "union"]


def flatten_weights(model):
    """
    Concatenate every parameter of a model into one flat NumPy vector.

    Parameters are visited in ``model.parameters()`` order, detached from the
    autograd graph and copied to host memory before being flattened.

    Args:
        model (torch.nn.Module): The model whose parameters are collected.

    Returns:
        np.ndarray: A 1-D array holding all parameter values back to back.
    """
    chunks = []
    for param in model.parameters():
        chunks.append(param.detach().cpu().numpy().ravel())
    return np.concatenate(chunks)

def kl_divergence(model1, model2):
    """
    Compute the mean Kullback-Leibler divergence KL(model1 || model2) between two models predictions on the same queries.

    Each query's prediction vector is treated as a discrete distribution;
    ``scipy.stats.entropy`` normalises each row to sum to one and uses the
    natural logarithm. The per-query divergences are then averaged.

    Two input layouts are accepted:

    * 2-D arrays of shape (n_queries, n_outputs), one row per query; or
    * the flattened layout produced by ``compute_predictions``: a 1-D array whose
      first element is the number of queries, followed by the row-major
      predictions. This is the layout the other pairwise metrics in this module
      use, so the function can also serve as a ``pdist`` metric.

    Args:
        model1 (np.ndarray): The predictions of the first model.
        model2 (np.ndarray): The predictions of the second model.

    Returns:
        float: The Kullback-Leibler divergence averaged over queries. KL is
        asymmetric, so swapping the arguments generally changes the result.

    Raises:
        AssertionError: If flattened inputs disagree on the number of queries.
    """
    model1 = np.asarray(model1)
    model2 = np.asarray(model2)

    # Flattened layout: the first element is the query count needed to rebuild the matrix.
    if model1.ndim == 1 and model2.ndim == 1:
        n_queries = int(model1[0])
        assert n_queries == int(model2[0]), f"Expected the same number of queries, got {n_queries} and {int(model2[0])}"
        model1 = model1[1:].reshape(n_queries, -1)
        model2 = model2[1:].reshape(n_queries, -1)

    return np.mean(entropy(model1, model2, axis=1))

def js_distance(model1, model2):
    """
    Compute the mean Jensen-Shannon distance between two models predictions on the same dataset.

    Each query's prediction vector is treated as a discrete distribution and compared with
    ``scipy.spatial.distance.jensenshannon``. That function normalises the vectors and uses the
    natural logarithm, so each per-query distance lies in [0, sqrt(ln 2)]. The per-query
    distances are then averaged.

    Args:
        model1 (np.ndarray): The flattened predictions of the first model, where the first element indicates the number of queries to reconstruct the matrix.
        model2 (np.ndarray): The flattened predictions of the second model, where the first element indicates the number of queries to reconstruct the matrix.

    Returns:
        float: The Jensen-Shannon distance between the two models, averaged over queries.

    Raises:
        AssertionError: If model1 or model2 is not a numpy array.
        AssertionError: If model1 and model2 encode a different number of queries.
    """

    assert isinstance(model1, np.ndarray) and isinstance(model2, np.ndarray), "Expected two arrays of model predictions"

    n_queries = int(model1[0])
    assert n_queries == model2[0], f"Expected the same number of queries, got {n_queries} and {model2[0]}"

    # Drop the header and restore one row of predictions per query.
    model1 = model1[1:].reshape(n_queries, -1)
    model2 = model2[1:].reshape(n_queries, -1)

    return np.mean(jensenshannon(model1, model2, axis=1))

def cka_distance(X, Y):
    """
    Compute the linear Centred Kernel Alignment (CKA) distance between two models, which is equal to 1-CKA.

    With centred prediction matrices X and Y (one row per query), linear CKA is
    ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F).

    Args:
        X (np.ndarray): Flattened array of model predictions for the first model, which must be centred. The first element of the flattened array indicates the number of queries to reconstruct the matrix.
        Y (np.ndarray): Flattened array of model predictions for the second model, which must be centred. The first element of the flattened array indicates the number of queries to reconstruct the matrix.

    Returns:
        float: The CKA distance 1 - CKA: 0 for representations identical up to
        scale, 1 for orthogonal ones. Returns np.nan when the denominator is zero
        or not finite, e.g. when a model's centred predictions are all zero.

    Raises:
        AssertionError: If X and Y are not numpy arrays.
        AssertionError: If the number of queries in X and Y are not the same.
        AssertionError: If X and Y are not centred.
    """
    assert isinstance(X, np.ndarray) and isinstance(Y, np.ndarray), "Expected two arrays of model predictions"

    n_queries = int(X[0])
    n_queries_y = int(Y[0])

    # Report Y's actual query count (the header), not a value derived from the array length.
    assert n_queries == n_queries_y, f"Expected the same number of queries, got {n_queries} and {n_queries_y}"

    X = X[1:].reshape(n_queries, -1)
    Y = Y[1:].reshape(n_queries, -1)

    assert np.allclose(np.mean(X, axis=0), 0.), f"Expected centred predictions for the first model, got {np.mean(X, axis=0)}"
    assert np.allclose(np.mean(Y, axis=0), 0.), f"Expected centred predictions for the second model, got {np.mean(Y, axis=0)}"

    denominator = np.linalg.norm(X.T @ X, ord='fro') * np.linalg.norm(Y.T @ Y, ord='fro')

    # Degenerate (constant) predictions make CKA undefined; signal it with NaN rather than dividing by zero.
    return 1 - np.linalg.norm(Y.T @ X, ord='fro')**2 / denominator if (np.isfinite(denominator) and denominator != 0.) else np.nan

def l2_worst_distance(model1, model2):
    """
    Return the largest per-query L2 distance between two models predictions on the same queries.

    The two flattened prediction arrays are unpacked into one row per query. The
    Euclidean norm of each row difference is taken, and the maximum over all
    queries is returned.

    Args:
        model1 (np.ndarray): The flattened predictions of the first model, where the first element indicates the number of queries to reconstruct the matrix.
        model2 (np.ndarray): The flattened predictions of the second model, where the first element indicates the number of queries to reconstruct the matrix.

    Returns:
        float: The worst-case (maximum over queries) L2 distance between the two models.

    Raises:
        AssertionError: If model1 or model2 is not a numpy array.
        AssertionError: If model1 and model2 encode a different number of queries.
    """
    assert isinstance(model1, np.ndarray) and isinstance(model2, np.ndarray), "Expected two arrays of model predictions"

    n_queries = int(model1[0])
    assert n_queries == model2[0], f"Expected the same number of queries, got {n_queries} and {model2[0]}"

    rows_a = model1[1:].reshape(n_queries, -1)
    rows_b = model2[1:].reshape(n_queries, -1)
    per_query = np.linalg.norm(rows_a - rows_b, axis=1)
    return per_query.max()

def epsilon_delta_distance(model1, model2, eps_min=0., eps_max=1., step=1e-3):
    """
    Compute the epsilon-delta distance between two models predictions on the same dataset.

    For every query, the per-query discrepancy is the mean absolute difference between the two
    prediction vectors. The epsilon-delta curve maps a threshold epsilon to the fraction of queries
    whose discrepancy is strictly less than epsilon. The epsilon-delta distance is the integral of
    this curve over [eps_min, eps_max] (trapezoidal rule), divided by the length of the interval.

    Note: the curve grows with agreement, so the result behaves like a similarity. Values close to 1
    mean the predictions almost always agree within eps_min; values close to 0 mean they rarely
    agree even within eps_max.

    Args:
        model1 (np.ndarray): The flattened predictions of the first model, where the first element indicates the number of queries to reconstruct the matrix.
        model2 (np.ndarray): The flattened predictions of the second model, where the first element indicates the number of queries to reconstruct the matrix.
        eps_min (float, optional): Lower end of the integration interval, must be >= 0. Defaults to 0.
        eps_max (float, optional): Upper end of the integration interval, must be <= 1. Defaults to 1.
        step (float, optional): Target spacing of the integration grid. Defaults to 1e-3.

    Returns:
        float: The normalised area under the epsilon-delta curve, in [0, 1].

    Raises:
        AssertionError: If the inputs are not numpy arrays, if the interval is invalid,
            if step is not positive, or if the inputs encode a different number of queries.
    """
    assert isinstance(model1, np.ndarray) and isinstance(model2, np.ndarray), "Expected two arrays of model predictions"
    assert eps_min < eps_max, f"Expected eps_min < eps_max, got {eps_min} and {eps_max}"
    assert eps_min >= 0., f"Expected eps_min >= 0, got {eps_min}"
    assert eps_max <= 1., f"Expected eps_max <= 1, got {eps_max}"
    assert step > 0., f"Expected step > 0, got {step}"

    n_queries = int(model1[0])
    assert n_queries == model2[0], f"Expected the same number of queries, got {n_queries} and {model2[0]}"

    model1 = model1[1:].reshape(n_queries, -1)
    model2 = model2[1:].reshape(n_queries, -1)

    # Per-query discrepancy: mean absolute difference over the output dimension.
    dist = np.abs(model1 - model2).mean(axis=1)

    # linspace (unlike arange) includes eps_max, so the grid covers the whole interval we normalise
    # by, and it always has at least two points even when the interval is narrower than `step`.
    n_intervals = max(int(round((eps_max - eps_min) / step)), 1)
    x = np.linspace(eps_min, eps_max, n_intervals + 1)
    # Fraction of queries strictly below each threshold, vectorised over the grid -> shape (len(x),).
    y = np.mean(dist[np.newaxis, :] < x[:, np.newaxis], axis=1)

    return trapezoid(y, x) / (eps_max - eps_min)

def compute_predictions(model, X_test, softmax_flag=False, centre_flag=False):
    """
    Run a model on a batch of queries and pack its outputs into one flat float64 vector.

    The model is moved to ``DEVICE`` and switched to eval mode (both persist after
    the call). The whole of ``X_test`` goes through a single forward pass without
    gradient tracking. The returned vector starts with the number of queries,
    followed by the row-major predictions, which is the layout the pairwise
    metrics in this module expect.

    Args:
        model (torch.nn.Module): The trained model.
        X_test (torch.Tensor): The queries, stacked into a single tensor.
        softmax_flag (bool, optional): Apply a softmax over axis 1 of the outputs. Defaults to False.
        centre_flag (bool, optional): Subtract the per-column mean over queries. Defaults to False.

    Returns:
        numpy.ndarray: ``[n_queries, *predictions.ravel()]`` as float64.
    """
    model.to(DEVICE)
    model.eval()

    with torch.no_grad():
        raw = model(X_test.to(DEVICE)).cpu()
    preds = raw.numpy().astype(np.float64)

    if softmax_flag:
        preds = softmax(preds, axis=1)
    if centre_flag:
        preds = preds - preds.mean(axis=0)

    header = np.array([preds.shape[0]], dtype=np.float64)
    return np.concatenate((header, preds.ravel()))

def _build_queries(models, testset, ds_test, n_queries, seed):
    """
    Build the query inputs on which the prediction-based metrics compare models.

    Args:
        models (list): The models being compared; ``models[0]`` is the AutoAttack target for 'adv'.
        testset (str): One of TESTSETS (validated by the caller).
        ds_test (torch.utils.data.Dataset): Dataset yielding (input, label) pairs.
        n_queries (int): Number of queries per query set.
        seed (int or None): Seed for the stratified subsampling and for the noise generator.

    Returns:
        torch.Tensor: n_queries rows per selected query set. For 'union', the data,
        random, awgn and adv sets are concatenated in that order.
    """
    X = torch.stack([x for x, _ in ds_test], dim=0)
    y = torch.tensor([label for _, label in ds_test], dtype=torch.long)
    subsample = n_queries < len(X)
    # Lazy iterator of independent stratified draws: one each for data, awgn and adv.
    splits = StratifiedShuffleSplit(n_splits=3, train_size=n_queries, random_state=seed).split(X, y)

    # Private generator, so the caller's global torch RNG state is left untouched.
    # seed=None mirrors sklearn's random_state=None: non-deterministic queries.
    generator = torch.Generator()
    if seed is None:
        generator.seed()
    else:
        generator.manual_seed(int(seed))

    parts = []
    if testset in ('data', 'union'):
        parts.append(X[next(splits)[0]] if subsample else X)
    if testset in ('random', 'union'):
        # Uniform noise in [0, 1) with the shape of a single input.
        parts.append(torch.rand([n_queries] + list(X[0].shape), generator=generator))
    if testset in ('awgn', 'union'):
        clean = X[next(splits)[0]] if subsample else X
        parts.append(torch.clamp(clean + 0.1 * torch.randn(clean.shape, generator=generator), 0, 1))
    if testset in ('adv', 'union'):
        idx = next(splits)[0] if subsample else np.arange(len(X), dtype=int)
        attack = AutoAttack(models[0], norm='Linf', eps=8/255, version='standard', verbose=False)
        parts.append(attack(X[idx], y[idx]).detach().cpu())

    X_test = parts[0] if len(parts) == 1 else torch.cat(parts, dim=0)
    expected = len(parts) * n_queries
    assert X_test.shape[0] == expected, f"Expected {expected} queries, got {X_test.shape[0]}"
    return X_test


def distance_matrix(models, metric='l2', testset='random', ds_test=None, n_queries=128, seed=2020, **kwargs):
    """
    Compute the distance matrix between a list of models.

    Metrics:
        * 'l2': Euclidean distance between flattened model weights.
        * 'precomputed': ``models`` is already an (n_models, n_features) array; Euclidean distance between rows.
        * 'l2_preds': Euclidean distance between softmax predictions on the queries.
        * 'l2_worst': largest per-query L2 distance between softmax predictions.
        * 'js': mean Jensen-Shannon distance between softmax predictions.
        * 'cka': 1 - linear CKA between centred raw predictions.

    Query sets (prediction-based metrics only):
        * 'data': a stratified subsample of ``ds_test``.
        * 'random': uniform noise in [0, 1) with the input shape.
        * 'awgn': a stratified subsample with additive Gaussian noise (std 0.1), clamped to [0, 1].
        * 'adv': AutoAttack (Linf, eps=8/255, standard) examples crafted against ``models[0]``.
        * 'union': the four sets above, concatenated.

    Args:
        models (list): A list of models (or an array of feature vectors for 'precomputed').
        metric (str, optional): The distance metric to use -- one of 'l2', 'l2_preds', 'l2_worst', 'js', 'cka', 'precomputed'. Defaults to 'l2'.
        testset (str, optional): The query set for the prediction-based metrics -- one of 'data', 'random', 'awgn', 'adv', 'union'. Defaults to 'random'.
        ds_test (torch.utils.data.Dataset, optional): Dataset of (input, label) pairs; required by the prediction-based metrics. Defaults to None.
        n_queries (int, optional): The number of queries per query set for the prediction-based metrics. Defaults to 128.
        seed (int, optional): Seed for the stratified subsampling and for a private torch generator
            used for the noise queries; the global torch RNG is not modified. None gives
            non-deterministic queries. Defaults to 2020.
        **kwargs: Additional arguments forwarded to ``scipy.spatial.distance.pdist``.

    Returns:
        np.ndarray: The square (n_models, n_models) distance matrix.

    Raises:
        ValueError: If the metric or testset is unknown, or if ds_test is missing for a prediction-based metric.
    """
    metric_dict = {
        'l2': 'euclidean',
        'l2_preds': 'euclidean',
        'l2_worst': l2_worst_distance,
        'js': js_distance,
        'cka': cka_distance,
        'precomputed': 'euclidean',
    }

    if metric == 'l2':
        models = np.vstack([flatten_weights(model) for model in models])
    elif metric == 'precomputed':
        pass
    elif metric in ['js', 'l2_preds', 'l2_worst', 'cka']:
        # Validate up front, before touching the dataset or running any attack.
        if testset not in TESTSETS:
            raise ValueError(f"Unknown testset {testset}, possible values are {TESTSETS}")
        if ds_test is None:
            raise ValueError(f"Metric {metric} requires a test dataset, got ds_test=None")

        X_test = _build_queries(models, testset, ds_test, n_queries, seed)

        softmax_flag = metric in ['js', 'l2_preds', 'l2_worst']
        centre_flag = metric == 'cka'

        models = np.vstack([compute_predictions(model, X_test.to(DEVICE), softmax_flag=softmax_flag, centre_flag=centre_flag) for model in models])

        # 'euclidean' must not see the leading query-count header; the callables decode it themselves.
        if metric == 'l2_preds':
            models = models[:, 1:]
    else:
        raise ValueError(f"Unknown metric {metric}, possible values are {list(metric_dict.keys())}")

    return squareform(pdist(models, metric=metric_dict[metric], **kwargs))