import numpy as np
import pandas as pd
import torch

from scipy.special import softmax

from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import StandardScaler


def describe_array(x):
    """
    Summarise an array with a handful of descriptive statistics.

    Args:
        x (np.ndarray): The values to summarise.

    Returns:
        pd.Series: The statistics, indexed in this order: 'mean', 'std',
            'min', '25%', '50%', '75%', 'max'. 'std' is whatever ``np.std``
            returns with its default arguments.
    """
    summary = {
        'mean': np.mean(x),
        'std': np.std(x),
        'min': np.min(x),
    }

    # The quartiles are inserted between 'min' and 'max' so the resulting
    # index keeps the order listed in the docstring.
    for q in (25, 50, 75):
        summary[f'{q}%'] = np.percentile(x, q)

    summary['max'] = np.max(x)
    return pd.Series(summary)


def test_model(model, dataloader, device='cuda'):
    """
    Evaluate the classification accuracy of a PyTorch model on a dataloader.

    The model is switched to evaluation mode and moved to ``device``; it is
    left in that state when the function returns.

    Args:
        model (torch.nn.Module): The model to evaluate. Its output is reduced
            with ``argmax`` over dimension 1 to obtain the predicted class.
        dataloader (torch.utils.data.DataLoader): Iterable yielding
            ``(inputs, labels)`` pairs of tensors.
        device (str, optional): The device to use for testing the model. Default is 'cuda'.

    Returns:
        float: The fraction of samples whose predicted class equals the label.

    Raises:
        ValueError: If the dataloader yields no samples, in which case the
            accuracy is undefined.
    """
    model.eval()  # Set the model to evaluation mode
    model.to(device)  # Move the model to the specified device
    correct = 0
    total = 0

    with torch.no_grad():  # Disable gradient calculation to save memory
        for data in dataloader:
            images, labels = [tensor.to(device) for tensor in data]
            # Gradients are already disabled, so the legacy `.data` access is
            # unnecessary; argmax gives the predicted class index directly.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("Cannot compute accuracy: the dataloader yielded no samples.")

    return correct / total


# Despite its name this is a library helper, not a test case. Opting out of
# pytest collection keeps pytest from trying to run it (and failing on the
# missing 'model'/'dataloader' fixtures) when it is imported into a test module.
test_model.__test__ = False


def compute_predictions(model, X_test, softmax_flag=False, centre_flag=False, device='cuda'):
    """
    Run the model on a test tensor and return its outputs as a NumPy array.

    The model is put in evaluation mode and moved to ``device``; ``X_test`` is
    moved to ``device`` as well before the forward pass.

    Args:
        model (torch.nn.Module): The trained model.
        X_test (torch.Tensor): The test inputs as a single tensor.
        softmax_flag (bool, optional): If True, apply a softmax along axis 1 of the outputs. Defaults to False.
        centre_flag (bool, optional): If True, subtract the mean over axis 0 from the outputs
            (after the softmax, when both flags are set). Defaults to False.
        device (str, optional): The device to use for computing predictions. Default is 'cuda'.

    Returns:
        numpy.ndarray: The (optionally softmaxed and centred) outputs as float64.
    """
    model.eval()
    model.to(device)

    with torch.no_grad():
        inputs = X_test.to(device)
        raw_outputs = model(inputs)
        # Bring the result back to host memory and widen to double precision.
        predictions = raw_outputs.cpu().numpy().astype(np.float64)

    if softmax_flag:
        predictions = softmax(predictions, axis=1)

    if not centre_flag:
        return predictions

    column_means = np.mean(predictions, axis=0)
    return predictions - column_means


def compare_models_prediction(model1, model2, dataloader, n_batches=32, device='cuda'):
    """Computes the average L2 distance, Mean Squared Error (MSE), and agreement between the softmaxed predictions of two models.

    Both models are evaluated on the same inputs via ``compute_predictions``
    with ``softmax_flag=True``. The per-sample metrics are computed with
    vectorised NumPy operations, one batch at a time, and averaged over all
    processed samples.

    Args:
        model1 (torch.nn.Module): The first model to evaluate.
        model2 (torch.nn.Module): The second model to evaluate.
        dataloader (torch.utils.data.DataLoader): Iterable yielding ``(inputs, labels)`` pairs; the labels are ignored.
        n_batches (int, optional): Number of batches to process. If None, all batches of the testset are processed. Defaults to 32.
        device (str, optional): The device to use for computing predictions. Default is 'cuda'.

    Returns:
        tuple: A tuple containing:
            - float: The average L2 distance between the softmaxed predictions of model1 and model2.
            - float: The average Mean Squared Error (MSE) between the softmaxed predictions of model1 and model2.
            - float: The fraction of samples on which both models predict the same class (same argmax).

    Raises:
        ValueError: If no sample was processed (empty dataloader or
            ``n_batches=0``), in which case the averages are undefined.
    """
    l2_per_batch = []
    mse_per_batch = []
    agreement = 0
    total_predictions = 0

    for bdx, (X_test, _) in enumerate(dataloader):
        if n_batches is not None and bdx == n_batches:
            break

        y_preds_1 = compute_predictions(model1, X_test, softmax_flag=True, device=device)
        y_preds_2 = compute_predictions(model2, X_test, softmax_flag=True, device=device)

        batch_size = y_preds_1.shape[0]
        if batch_size == 0:
            continue

        # Flatten everything except the batch axis so each row holds one
        # sample's full prediction; the norm, mean and argmax below then
        # operate per sample over all of its entries.
        flat_1 = y_preds_1.reshape(batch_size, -1)
        flat_2 = y_preds_2.reshape(batch_size, -1)
        squared_diff = (flat_1 - flat_2) ** 2

        l2_per_batch.append(np.sqrt(squared_diff.sum(axis=1)))
        mse_per_batch.append(squared_diff.mean(axis=1))

        # Count the samples where both models pick the same top entry.
        agreement += int(np.sum(flat_1.argmax(axis=1) == flat_2.argmax(axis=1)))
        total_predictions += batch_size

    if total_predictions == 0:
        raise ValueError("Cannot compare models: no predictions were computed "
                         "(empty dataloader or n_batches=0).")

    mean_l2 = np.mean(np.concatenate(l2_per_batch))
    mean_mse = np.mean(np.concatenate(mse_per_batch))
    return mean_l2, mean_mse, agreement / total_predictions


def generate_z_samples(z_anchors, n_samples=None, n_components=32, bandwidth=0.1, random_state=2020):
    """
    Generates samples from the distribution of the given z_anchors using PCA and KDE.

    The process involves several steps:
    1. Standardize the input data using StandardScaler.
    2. Apply Principal Component Analysis (PCA) to reduce the dimensionality to
       at most ``n_components`` components.
    3. Fit a Gaussian Kernel Density Estimate (KDE) on the PCA-transformed data.
    4. Sample from the fitted KDE to generate new samples.
    5. Inverse transform the samples back to the original feature space using the inverse PCA and scaler transformations.

    Args:
        z_anchors (numpy.ndarray): A 2D array of shape (n_anchors, n_features) containing the anchor points.
        n_samples (int, optional): The number of samples to generate. If None, defaults to the number of rows in z_anchors.
        n_components (int, optional): Number of PCA components to keep. It is
            capped at ``min(n_anchors, n_features)`` so that small inputs do
            not make the PCA fail. Defaults to 32.
        bandwidth (float, optional): Bandwidth of the Gaussian kernel used by the KDE. Defaults to 0.1.
        random_state (int, optional): Seed passed to the KDE sampler; with a
            fixed seed repeated calls on the same input give the same samples.
            Defaults to 2020.

    Returns:
        numpy.ndarray: A 2D array of shape (n_samples, n_features) containing the generated samples.

    Raises:
        AssertionError: If the input z_anchors does not have exactly 2 dimensions.

    Example:
        >>> import numpy as np
        >>> z_anchors = np.random.rand(100, 50)
        >>> samples = generate_z_samples(z_anchors, n_samples=50)
        >>> print(samples.shape)
        (50, 50)
    """
    # Raised explicitly rather than via `assert` so the check still runs when
    # Python is started with -O; the exception type is kept for callers.
    if len(z_anchors.shape) != 2:
        raise AssertionError(
            f"z_anchors must have exactly 2 dimensions, got {len(z_anchors.shape)}"
        )

    n_anchors, n_features = z_anchors.shape
    if n_samples is None:
        n_samples = n_anchors

    # PCA cannot keep more components than the smaller of the two data
    # dimensions, so cap the request instead of failing on small inputs.
    n_components = min(n_components, n_anchors, n_features)

    # Scale
    scaler = StandardScaler()
    z_std = scaler.fit_transform(z_anchors)

    # Apply PCA
    pca = PCA(n_components=n_components)
    z_pca = pca.fit_transform(z_std)

    # Perform KDE and draw the new points in PCA space
    kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth).fit(z_pca)
    z_kde = kde.sample(n_samples, random_state=random_state)

    # Inverse transform and return
    return scaler.inverse_transform(pca.inverse_transform(z_kde))