import torch
import copy
import os
import numpy as np
import random
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm


class EarlyStopping:
    """Signal when validation loss has stopped improving.

    The tracker is called once per epoch with the validation loss. Internally it
    works on the negated loss (higher is better). A call counts as an improvement
    only when the negated loss reaches at least ``best_score + delta``; otherwise
    the no-improvement counter is incremented, and once that counter reaches
    ``patience`` the ``early_stop`` flag is set.

    Args:
    patience (int): Number of non-improving calls tolerated before stopping.
    verbose (bool): Stored on the instance; not used by this class itself.
    delta (float): Minimum increase of the negated loss that counts as an improvement.
    """

    def __init__(self, patience=5, verbose=False, delta=0.05):
        # Configuration.
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        # Running state.
        self.early_stop = False
        self.epochs_no_improve = 0
        self.best_score = None

    def __call__(self, val_loss):
        """Record one validation loss and update ``early_stop``."""
        current = -val_loss

        # The very first observation just establishes the baseline.
        if self.best_score is None:
            self.best_score = current
            return

        if current < self.best_score + self.delta:
            # Not enough of an improvement: count it against the patience budget.
            self.epochs_no_improve += 1
            if self.epochs_no_improve >= self.patience:
                self.early_stop = True
            return

        # Sufficient improvement: new baseline, counter starts over.
        self.best_score = current
        self.epochs_no_improve = 0


def load_datasets(dataset_name, model_type='resnet18'):
    """
    Build the train/test datasets and data loaders for a named dataset.

    Args:
    dataset_name (str): One of 'cifar10', 'cifar100', 'mini-fashion' or 'celeba'.
    model_type (str): One of 'resnet18', 'vgg' or 'vit'. It selects the input resolution of the
                      transform pipeline ('vit' resizes to 224). It is ignored for 'celeba',
                      which always uses the same pipeline.

    Returns:
    tuple: (train_dataset, train_loader, test_loader, num_classes). The train loader uses
           batch_size=128 with shuffling, the test loader batch_size=100 without shuffling;
           both use 2 worker processes.

    Raises:
    ValueError: If `dataset_name` is not supported, or if `model_type` is not supported for a
                dataset whose transform depends on it. (Previously an unknown `model_type`
                surfaced as an UnboundLocalError on `transform`.)

    Description:
    Datasets are stored under (and downloaded to) the relative directory 'datas'.
    """
    small_input_models = ('resnet18', 'vgg')
    large_input_models = ('vit',)
    supported_datasets = ('cifar10', 'cifar100', 'mini-fashion', 'celeba')

    # Validate everything up front so that a bad argument never triggers a download.
    if dataset_name not in supported_datasets:
        raise ValueError("Unsupported dataset")
    if dataset_name != 'celeba' and model_type not in small_input_models + large_input_models:
        raise ValueError(
            f"Unsupported model_type {model_type!r} for dataset {dataset_name!r}; "
            f"expected one of {small_input_models + large_input_models}"
        )

    use_large_input = model_type in large_input_models

    if dataset_name in ('cifar10', 'cifar100'):
        if dataset_name == 'cifar10':
            dataset_cls, num_classes = torchvision.datasets.CIFAR10, 10
            mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        else:
            dataset_cls, num_classes = torchvision.datasets.CIFAR100, 100
            mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)

        # Only the 'vit' pipeline upsamples; the other models consume the native resolution.
        steps = [transforms.Resize((224, 224))] if use_large_input else []
        steps += [transforms.ToTensor(), transforms.Normalize(mean, std)]
        transform = transforms.Compose(steps)

        train_dataset = dataset_cls(root='datas', train=True, download=True, transform=transform)
        test_dataset = dataset_cls(root='datas', train=False, download=True, transform=transform)
    elif dataset_name == 'mini-fashion':
        transform = transforms.Compose([
            transforms.Resize(224 if use_large_input else 32),
            transforms.Grayscale(num_output_channels=3),  # Convert 1 channel image to 3 channel
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = torchvision.datasets.FashionMNIST(root='datas', train=True, download=True, transform=transform)
        test_dataset = torchvision.datasets.FashionMNIST(root='datas', train=False, download=True, transform=transform)
        num_classes = 10
    else:  # 'celeba'
        transform = transforms.Compose([
            transforms.CenterCrop(178),
            transforms.Resize(128),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        train_dataset = torchvision.datasets.CelebA(root='datas', split='train', download=True, transform=transform)
        test_dataset = torchvision.datasets.CelebA(root='datas', split='test', download=True, transform=transform)

        # Assuming binary classification for CelebA attributes
        num_classes = 2

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=2)
    return train_dataset, train_loader, test_loader, num_classes

def create_pseudo_labels(num_classes, batch_size):
    """
    Build uniform soft labels for one batch.

    Args:
    num_classes (int): Number of classes; each row is split evenly across this many entries.
    batch_size (int): Number of rows (samples) to generate.

    Returns:
    torch.Tensor: Tensor of shape (batch_size, num_classes) in which every entry equals
                  1 / num_classes, so each row is a uniform distribution over the classes.
    """
    # Start from all-ones and scale, so every row sums to one.
    all_ones = torch.ones(batch_size, num_classes)
    uniform_rows = all_ones / num_classes
    return uniform_rows

def evaluate_accuracy(model, dataloader):
    """
    Compute the top-1 accuracy of `model` over `dataloader`.

    Args:
    model (torch.nn.Module): Model to evaluate. Inputs and targets are moved to the device of
                             its first parameter.
    dataloader (iterable): Yields (inputs, targets) batches of tensors.

    Returns:
    float: Fraction of samples whose arg-max prediction equals the target. Returns 0.0 when the
           dataloader yields no samples (instead of raising ZeroDivisionError).

    Description:
    The model is switched to eval mode for the duration of the call and its previous
    train/eval mode is restored afterwards, even if evaluation raises.
    """
    was_training = model.training
    # Resolve the device before touching the mode so a failure here leaves the model untouched.
    device = next(model.parameters()).device

    correct = 0
    total = 0
    model.eval()
    try:
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
    finally:
        model.train(mode=was_training)

    if total == 0:
        return 0.0
    return correct / total

def clone_and_freeze_model(model):
    """
    Return a deep copy of `model` whose parameters do not require gradients.

    Args:
    model (torch.nn.Module): Model to duplicate. It is not modified.

    Returns:
    torch.nn.Module: Independent copy of `model` with `requires_grad` set to False on every
                     parameter.
    """
    frozen_clone = copy.deepcopy(model)
    # Freeze the copy only; the source model keeps its own requires_grad flags.
    for weight in frozen_clone.parameters():
        weight.requires_grad = False
    return frozen_clone

######################## Compute first order and second order grad ###########################

def compute_epoch_grad(loader, model, loss_fn, loss_type='sum', pseudo_labels=None):
    """
    Accumulate the loss over a whole loader and differentiate it w.r.t. the model parameters.

    Args:
    loader (iterable): Yields (inputs, targets) batches of tensors.
    model (torch.nn.Module): Model to evaluate; batches are moved to the device of its first parameter.
    loss_fn (callable): Called as loss_fn(outputs, targets). Its result is multiplied by the batch
                        size, so it is presumably a per-batch mean loss — confirm against callers.
    loss_type (str): 'sum' differentiates the accumulated loss; any other value differentiates
                     the accumulated loss divided by the number of samples.
    pseudo_labels (torch.Tensor, optional): When given, the loader's targets are ignored and the
                                            first `outputs.size(0)` rows of this tensor are used
                                            as targets for every batch.

    Returns:
    tuple: (grad, total_loss, total_samples) where `grad` is the list returned by `compute_grad`
           (one tensor per model parameter) and the other two are scalar tensors on the model device.

    Description:
    The forward passes of all batches stay attached to the autograd graph until the gradient is
    taken, so memory use grows with the number of batches. Targets and pseudo labels are moved to
    the model device before the loss is evaluated (previously they were left on their original
    device, which breaks when the model lives on an accelerator).
    """
    device = next(model.parameters()).device
    was_training = model.training

    total_loss = torch.tensor(0.0, device=device)
    total_samples = torch.tensor(0.0, device=device)

    if pseudo_labels is not None:
        pseudo_labels = pseudo_labels.to(device)

    for inputs, targets in loader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        if pseudo_labels is None:
            batch_targets = targets.to(device)
        else:
            # The last batch may be smaller than the pseudo-label tensor.
            batch_targets = pseudo_labels[:outputs.size(0)]
        total_loss = total_loss + loss_fn(outputs, batch_targets) * inputs.size(0)
        total_samples += inputs.size(0)

    try:
        # NOTE(review): train mode is enabled only after the forward passes above have run;
        # kept from the original — confirm whether it is actually required here.
        model.train()
        loss = total_loss if loss_type == 'sum' else total_loss / total_samples
        grad = compute_grad(loss, list(model.parameters()))
    finally:
        model.train(mode=was_training)

    return grad, total_loss, total_samples

def compute_epoch_hessian_vector_product(model, dataloader, loss_fn, v, loss_type='sum', num_samples=100):
    """
    Average per-batch Hessian-vector products over (a prefix of) a dataloader.

    Args:
    model (torch.nn.Module): Model whose parameters the Hessian is taken with respect to.
    dataloader (iterable): Yields (inputs, targets) batches of tensors.
    loss_fn (callable): Called as loss_fn(outputs, targets).
    v (list of torch.Tensor): Vector to multiply with the Hessian; passed to
                              `compute_hessian_vector_product` together with `list(model.parameters())`.
    loss_type (str): 'sum' uses loss_fn(...) * batch_size as the per-batch loss; any other value
                     uses that quantity divided by the batch size again.
    num_samples (int): Batches are consumed until at least this many samples have been seen.

    Returns:
    list of torch.Tensor: One tensor per model parameter, on the model device, equal to the sum of
                          the per-batch products divided by the number of processed batches. When no
                          batch was processed, zero tensors are returned (previously 0/0 gave NaNs).

    Description:
    Per-batch products are accumulated on the CPU and moved back to the model device at the end.
    The model's train/eval mode is restored before returning, even on error.
    """
    was_training = model.training
    device = next(model.parameters()).device
    params = list(model.parameters())

    seen = 0
    batch_count = 0
    accumulated = [torch.zeros_like(param, device="cpu") for param in params]

    try:
        for inputs, targets in dataloader:
            if seen >= num_samples:
                break
            batch_count += 1
            inputs, targets = inputs.to(device), targets.to(device)
            model.zero_grad()
            outputs = model(inputs)
            seen += inputs.size(0)
            loss = loss_fn(outputs, targets) * inputs.size(0)
            mean_loss = loss / inputs.size(0)

            # NOTE(review): because train mode is switched on here, the first batch's forward
            # pass runs in the caller's mode while later batches run in train mode. Kept as in
            # the original — confirm which mode is intended.
            model.train()
            batch_loss = loss if loss_type == 'sum' else mean_loss
            batch_product = compute_hessian_vector_product(batch_loss, params, v)

            accumulated = [
                running + piece.detach().cpu()
                for running, piece in zip(accumulated, batch_product)
            ]
    finally:
        model.train(mode=was_training)

    if batch_count == 0:
        # Nothing was processed; avoid dividing by zero and return the zero accumulators.
        return [running.to(device) for running in accumulated]

    return [(running.to(device) / batch_count) for running in accumulated]

def compute_epoch_hessian(model, dataloader, loss_fn, loss_type='sum', num_samples=100):
    """
    Compute per-parameter Hessian blocks of the loss accumulated over (a prefix of) a dataloader.

    Args:
    model (torch.nn.Module): Model to evaluate; it is put in eval mode for the forward passes.
    dataloader (iterable): Yields (inputs, targets) batches of tensors.
    loss_fn (callable): Called as loss_fn(outputs, targets); its result is multiplied by the batch size.
    loss_type (str): 'sum' differentiates the accumulated loss; any other value differentiates the
                     accumulated loss divided by the number of samples seen.
    num_samples (int): Batches are consumed until at least this many samples have been seen.

    Returns:
    list of torch.Tensor: The result of `compute_hessian` for the chosen loss and the model parameters.

    Raises:
    ValueError: If no sample was processed (empty dataloader or num_samples <= 0). Previously this
                surfaced as a ZeroDivisionError or an autograd error.

    Description:
    The model's previous train/eval mode is restored before returning, even on error.
    """
    was_training = model.training
    device = next(model.parameters()).device
    model.eval()

    try:
        count = 0
        total_loss = 0.0

        for inputs, targets in dataloader:
            if count >= num_samples:
                break
            inputs, targets = inputs.to(device), targets.to(device)
            model.zero_grad()
            outputs = model(inputs)
            total_loss += loss_fn(outputs, targets) * inputs.size(0)
            count += inputs.size(0)

        if count == 0:
            raise ValueError("compute_epoch_hessian needs at least one sample to build a loss.")

        loss = total_loss if loss_type == 'sum' else total_loss / count
        hessian = compute_hessian(loss, list(model.parameters()))
    finally:
        model.train(mode=was_training)

    return hessian

def compute_grad(loss, params):
    """
    Compute the gradients of the specified loss with respect to the given parameters, handling parameters
    that do not influence the loss by assigning them zero gradients.

    Args:
    loss (torch.Tensor): The loss tensor for which gradients are to be computed.
    params (iterable of torch.Tensor): Parameters with respect to which the gradients of the loss will
                                       be computed. Any iterable is accepted, including a generator
                                       such as `model.parameters()`.

    Returns:
    list of torch.Tensor: A list of gradient tensors corresponding to each parameter in `params`. If a gradient
                          is not defined (i.e., the parameter does not contribute to the loss), a zero tensor
                          of the same shape as the parameter is returned in its place.

    Description:
    This function utilizes `torch.autograd.grad` with `create_graph=True`, so the returned gradients can
    themselves be differentiated. `allow_unused=True` makes unused parameters come back as `None`, which
    are then replaced by zero tensors.
    """
    # Materialize once: a generator would be exhausted by autograd.grad and the zip below would
    # then silently produce an empty list.
    params = list(params)

    gradients = torch.autograd.grad(loss, params, allow_unused=True, create_graph=True)

    # Replace None gradients with zero tensors of the same shape as the corresponding parameters
    return [
        torch.zeros_like(param) if grad is None else grad
        for grad, param in zip(gradients, params)
    ]

def compute_hessian_vector_product(loss, params, v):
    """
    Compute the product of the Hessian of `loss` with respect to `params` and a vector `v`,
    without materializing the Hessian.

    Args:
    loss (torch.Tensor): The scalar loss tensor for which the Hessian is to be computed.
    params (iterable of torch.Tensor): Parameters with respect to which the Hessian is taken. Any
                                       iterable is accepted, including a generator.
    v (iterable of torch.Tensor): One tensor per parameter, multiplied element-wise with the
                                  corresponding gradient.

    Returns:
    list of torch.Tensor: A list of tensors representing the Hessian-vector product for each parameter.

    Raises:
    ValueError: If `v` and `params` do not contain the same number of tensors (previously the extra
                entries were silently dropped by `zip`).

    Description:
    The function first computes the gradients of the loss with respect to the parameters, keeping the graph.
    Parameters that do not contribute to the loss get a zero tensor in place of their gradient. The
    gradients are reduced to the scalar sum over all parameters of sum(grad * v_i), and that scalar is
    differentiated again with respect to the parameters. Entries that come back as None from this second
    differentiation are replaced by zero tensors. The graph is retained after both differentiations.
    """
    # Materialize both so generators are not exhausted before the zips below.
    params = list(params)
    v = list(v)
    if len(v) != len(params):
        raise ValueError(
            f"v must contain one tensor per parameter (got {len(v)} for {len(params)} parameters)."
        )

    # First compute the first order gradient of loss with respect to params
    first_grads = torch.autograd.grad(loss, params, allow_unused=True, create_graph=True, retain_graph=True)
    first_grads = [
        torch.zeros_like(param) if grad is None else grad
        for grad, param in zip(first_grads, params)
    ]

    # Calculate the dot product of first_grads and v
    grad_v_product = torch.sum(torch.stack([torch.sum(g * v_i) for g, v_i in zip(first_grads, v)]))

    # Calculate gradient of grad_v_product with respect to params
    second_grads = torch.autograd.grad(grad_v_product, params, allow_unused=True, retain_graph=True)
    return [
        torch.zeros_like(param) if grad is None else grad
        for grad, param in zip(second_grads, params)
    ]

def compute_hessian(loss, parameters):
    """
    Compute, for each trainable parameter, the Hessian block of `loss` with respect to that parameter.

    Args:
    loss (torch.Tensor): Scalar loss tensor with an autograd graph attached.
    parameters (iterable of torch.Tensor): Parameters to differentiate with respect to. Entries with
                                           `requires_grad == False` are skipped.

    Returns:
    list of torch.Tensor: One tensor per trainable parameter, of shape `param.shape + param.shape`.
                          Only the block of second derivatives of a parameter with itself is computed;
                          cross terms between different parameters are not.

    Description:
    The block is built row by row: the gradient is flattened and every element is differentiated again,
    which costs one backward pass per parameter element. Parameters that do not influence the loss, or
    whose gradient does not depend on them (loss linear in the parameter), get an all-zero block, in line
    with how `compute_grad` and `compute_hessian_vector_product` treat unused parameters.
    """
    hessians = []
    for param in parameters:
        # Calculate only the parameters that require a gradient
        if not param.requires_grad:
            continue

        n = param.numel()
        # A full (n, n) buffer; the previous zeros_like(param).reshape(-1, n) was only (1, n) and
        # overflowed as soon as a parameter had more than one element.
        hessian = torch.zeros(n, n, dtype=param.dtype, device=param.device)

        # First calculate the gradient of the loss with respect to the current parameter
        grad = torch.autograd.grad(
            loss, param, create_graph=True, retain_graph=True, allow_unused=True
        )[0]

        # A missing gradient, or one detached from the graph, means all second derivatives are zero.
        if grad is not None and grad.requires_grad:
            flat_grad = grad.reshape(-1)
            for idx in range(n):
                row = torch.autograd.grad(
                    flat_grad[idx], param, retain_graph=True, allow_unused=True
                )[0]
                if row is not None:
                    hessian[idx] = row.reshape(-1)

        hessians.append(hessian.reshape(param.shape + param.shape))

    return hessians

######################## Split sample indices ###########################

def split_indices(ori_indices, T, split_type, num_per_subset=0, num_classes=None, train_dataset=None, max_per=None):
    """
    Split a list of original indices into T subsets according to various splitting strategies.

    Args:
    ori_indices (list): List of original indices to be split. It is copied, not modified.
    T (int): Total number of subsets to split the indices into. Must be positive.
    split_type (str): Type of splitting strategy ('uniform', 'random', or 'class').
    num_per_subset (int, optional): For 'uniform', number of indices per subset; for 'class', number of
                                    classes per subset. If <= 0, everything is split as evenly as possible.
                                    Ignored by 'random'. Defaults to 0.
    num_classes (int, optional): Number of classes in the dataset. Required if split_type is 'class'.
    train_dataset (list or Dataset, optional): Indexable returning (sample, label); required if split_type
                                               is 'class' to fetch class labels.
    max_per (int, optional): Maximum number of indices taken from each class. Only used with class-based split.

    Returns:
    tuple:
        list: A shuffled copy of the original indices.
        list: The T subsets. 'uniform' and 'random' return slices of the shuffled copy; 'class' returns
              NumPy arrays.

    Raises:
    ValueError: If T is not positive, the split_type is unknown, required arguments are missing, or there
                are not enough indices/classes to fill the subsets.

    Description:
    'uniform' with num_per_subset <= 0 partitions all indices into T contiguous chunks whose sizes differ
    by at most one. 'uniform' with num_per_subset > 0 picks T non-overlapping windows of that size at
    random positions of the shuffled indices. 'random' cuts the shuffled indices at T - 1 random positions
    (cut points may coincide, so subsets can be empty). 'class' assigns whole classes to subsets, each
    class to at most one subset.

    Both `np.random` and `random` are used as sources of randomness, so both must be seeded for
    reproducible results.
    """
    if T <= 0:
        raise ValueError("T must be a positive integer.")

    # Create copies of indices
    indices = copy.deepcopy(ori_indices)
    np.random.shuffle(indices)
    n = len(indices)

    if split_type == 'uniform':
        if num_per_subset <= 0:
            # The first `remainder` subsets receive one extra index.
            base_size, remainder = divmod(n, T)
            subsets = []
            start = 0
            for i in range(T):
                size = base_size + (1 if i < remainder else 0)
                subsets.append(indices[start:start + size])
                start += size
        elif num_per_subset * T <= n:
            # Spread the unused "slack" positions randomly between windows. Sorted offsets plus
            # i * num_per_subset give window starts that can neither overlap nor run past the end,
            # and this also works when there is no slack at all (n == num_per_subset * T).
            slack = n - num_per_subset * T
            offsets = np.sort(np.random.randint(0, slack + 1, T))
            subsets = [
                indices[offset + i * num_per_subset:offset + (i + 1) * num_per_subset]
                for i, offset in enumerate(offsets)
            ]
        else:
            raise ValueError("Not enough indices to distribute among subsets.")
        return indices, subsets

    if split_type == 'random':
        # Sorted random cut points, bracketed by both ends, delimit the subsets.
        if T > 1:
            cuts = np.sort(np.random.randint(1, n, T - 1))
        else:
            cuts = np.array([], dtype=int)
        boundaries = np.concatenate(([0], cuts, [n])).astype(int)
        subsets = [indices[boundaries[i]:boundaries[i + 1]] for i in range(T)]
        return indices, subsets

    if split_type == 'class':
        if num_classes is None or train_dataset is None:
            raise ValueError("num_classes and train_dataset must be provided for class based split type.")

        # Collect indices for each class
        class_indices = {i: [] for i in range(num_classes)}
        for idx in indices:
            # Assuming the second item is the label in the dataset
            _, label = train_dataset[idx]
            # int() lets tensor / NumPy scalar labels hit the integer keys.
            class_indices[int(label)].append(idx)

        subsets = [[] for _ in range(T)]

        if num_per_subset <= 0:
            # Deal out the shuffled classes as evenly as possible; the first `extra_classes`
            # subsets get one more class. Slicing keeps the shuffled order (np.setdiff1d, used
            # before, returned the remaining classes sorted and undid the shuffle).
            available_classes = list(range(num_classes))
            random.shuffle(available_classes)
            base_classes_per_subset, extra_classes = divmod(num_classes, T)

            cursor = 0
            for i in range(T):
                take = base_classes_per_subset + (1 if i < extra_classes else 0)
                for cls in available_classes[cursor:cursor + take]:
                    # A max_per of None slices the whole list.
                    subsets[i].extend(class_indices[cls][:max_per])
                cursor += take
            return indices, [np.array(subset) for subset in subsets]

        if num_classes >= num_per_subset * T:
            # Draw distinct classes and hand consecutive groups of them to the subsets.
            chosen_classes = random.sample(range(num_classes), T * num_per_subset)
            for i in range(T):
                for cls in chosen_classes[i * num_per_subset:(i + 1) * num_per_subset]:
                    subsets[i].extend(class_indices[cls][:max_per])
            return indices, [np.array(subset) for subset in subsets]

        raise ValueError("Not enough class to fill all subsets.")

    raise ValueError("Unknown split_type provided.")

######################## Layer-wise CKA ###########################

def adjust_vector_lengths(X, Y):
    """
    Zero-pad two inputs along their second axis so both end up with the same number of columns.

    Both inputs are first promoted to at least 2-D (a 1-D vector becomes a single row). The
    narrower one is then extended on the right with zero columns; the number of rows is left
    untouched.
    """
    def _pad_columns(matrix, width):
        # Append zero columns until `matrix` is `width` columns wide; no-op if already wide enough.
        missing = width - matrix.shape[1]
        if missing <= 0:
            return matrix
        filler = np.zeros((matrix.shape[0], missing))
        return np.hstack((matrix, filler))

    first = np.atleast_2d(X)
    second = np.atleast_2d(Y)
    target_width = max(first.shape[1], second.shape[1])
    return _pad_columns(first, target_width), _pad_columns(second, target_width)

def rbf_kernel(X, Y, sigma=None):
    """
    Calculate the similarity of two vectors using RBF (Gaussian Radial Basis) kernel dealing with vectors of different lengths.

    The inputs are zero-padded to a common width by `adjust_vector_lengths`, then the value
    exp(-||X - Y||^2 / (2 * sigma^2)) is returned, where the norm is taken over all entries.

    When `sigma` is None it is set to the average of the norms of X and Y after subtracting their
    per-column means. NOTE(review): for single-row inputs (e.g. 1-D vectors) that heuristic is
    always zero, because a single row equals its own column mean — confirm whether a different
    bandwidth heuristic is wanted for that case.

    A zero bandwidth is handled as the limiting case instead of dividing by zero: the result is 1.0
    when the padded inputs are identical and 0.0 otherwise (previously identical inputs gave NaN).
    """
    X, Y = adjust_vector_lengths(X, Y)
    if sigma is None:
        sigma = (np.linalg.norm(X - X.mean(axis=0)) + np.linalg.norm(Y - Y.mean(axis=0))) / 2  # heuristic for sigma

    squared_distance = np.linalg.norm(X - Y) ** 2

    if sigma == 0:
        # Limit of the Gaussian as the bandwidth shrinks to zero.
        return np.float64(1.0 if squared_distance == 0 else 0.0)

    gamma = 1 / (2 * sigma ** 2)
    return np.exp(-gamma * squared_distance)

def linear_kernel(X, Y):
    """
    Linear-kernel similarity between two inputs whose widths may differ.

    Both inputs are brought to a common number of columns by `adjust_vector_lengths`; the
    result is the matrix product of the first padded input with the transpose of the second.
    """
    padded_x, padded_y = adjust_vector_lengths(X, Y)
    gram = padded_x.dot(padded_y.T)
    return gram

def centered_kernel_alignment(X, Y, kernel_type='linear'):
    """
    Calculate the CKA similarity between two feature sets, allowing for feature vectors of different lengths.

    Args:
    X (np.ndarray): First feature array.
    Y (np.ndarray): Second feature array.
    kernel_type (str): 'linear' or 'rbf'.

    Returns:
    The similarity value. For 'linear' this is
    linear_kernel(X, Y) / (sqrt(linear_kernel(X, X)) * sqrt(linear_kernel(Y, Y))), which keeps the
    array shape produced by `linear_kernel`; for 'rbf' it is the value of `rbf_kernel(X, Y)`.

    Raises:
    ValueError: If `kernel_type` is not 'linear' or 'rbf' (previously this surfaced as an
                UnboundLocalError).

    Description:
    Both inputs are centered by subtracting their mean along axis 0 before the kernels are applied.
    Any length adjustment happens inside the kernel functions, i.e. after centering.
    """
    # Reject unknown kernels before doing any work.
    if kernel_type not in ('linear', 'rbf'):
        raise ValueError(f"Unknown kernel_type {kernel_type!r}; expected 'linear' or 'rbf'.")

    X = X - X.mean(axis=0, keepdims=True)
    Y = Y - Y.mean(axis=0, keepdims=True)

    if kernel_type == 'rbf':
        return rbf_kernel(X, Y)

    hsic = linear_kernel(X, Y)
    var1 = np.sqrt(linear_kernel(X, X))
    var2 = np.sqrt(linear_kernel(Y, Y))
    return hsic / (var1 * var2)

def extract_features(model):
    """
    Collect one flat parameter vector per top-level layer of a model.

    If `model` exposes a `model` attribute (as a wrapper such as CustomModel does), the layers of
    that inner module are used; otherwise the layers of `model` itself. For every direct child,
    all of its parameters are detached, moved to the CPU, flattened and concatenated into a single
    1-D NumPy array. Children without parameters contribute nothing.

    Returns:
    list of np.ndarray: One 1-D array per parameterized child, in `children()` order.
    """
    # Unwrap the encapsulated network when there is one.
    backbone = getattr(model, "model", model)

    # For each child: the list of its flattened parameter arrays (possibly empty).
    flattened_per_child = (
        [weights.detach().cpu().numpy().ravel() for weights in child.parameters()]
        for child in backbone.children()
    )

    # Skip parameter-free children and join the rest into one vector each.
    return [np.concatenate(chunks) for chunks in flattened_per_child if chunks]

def compare_param(model_a, model_b, path):
    """
    Compare the layer-wise parameter similarity of two models and save it as a heat map.

    Args:
    model_a: First model; its layers index the rows of the similarity matrix.
    model_b: Second model; its layers index the columns of the similarity matrix.
    path (str): File path the figure is written to.

    Description:
    Per-layer parameter vectors are obtained with `extract_features`, and every pair of layers is
    scored with `centered_kernel_alignment` (default kernel). The heat map shows rows on the y-axis
    and columns on the x-axis, so the y-axis is labelled with model A and the x-axis with model B
    (the labels were previously swapped). The figure is closed even if saving fails.
    """
    # Extract features
    features_a = extract_features(model_a)
    features_b = extract_features(model_b)

    # Calculate CKA similarity
    cka_matrix = np.zeros((len(features_a), len(features_b)))
    for i, layer_a in enumerate(features_a):
        for j, layer_b in enumerate(features_b):
            similarity = centered_kernel_alignment(np.asarray(layer_a), np.asarray(layer_b))
            # The similarity may come back as a size-one array; store it as a plain scalar.
            cka_matrix[i, j] = np.asarray(similarity).item()

    # Heat mapping
    plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(cka_matrix, annot=False, cmap='viridis', fmt=".2f")
        plt.title("Layer-wise CKA Similarity between Models")
        plt.xlabel("Layers of Model B")
        plt.ylabel("Layers of Model A")
        plt.savefig(path)
    finally:
        plt.close()

def plt_param_compare(ma_list, mb_list, path, baseline_name):
    """
    Compares models pairwise and saves one comparison plot per pair.

    Args:
    ma_list (list): First sequence of models; each entry is passed to `compare_param` as the first model.
    mb_list (list): Second sequence of models, paired by position with `ma_list`.
    path (str): Base directory where the comparison plots will be saved.
    baseline_name (str): Subdirectory name under the base path to save plots.

    Raises:
    AssertionError: If `ma_list` and `mb_list` have different lengths. It is raised explicitly, so
                    the check is still performed when Python runs with assertions disabled (-O).

    Description:
    The directory `path/baseline_name` is created if it does not exist. The pair at position `idx` is
    plotted by `compare_param` into `time-{idx}.png` inside that directory. A completion message is
    printed at the end.
    """
    if len(ma_list) != len(mb_list):
        raise AssertionError("ma_list and mb_list must have the same length")

    save_path = os.path.join(path, baseline_name)
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(save_path, exist_ok=True)

    for idx, (ma, mb) in enumerate(zip(ma_list, mb_list)):
        compare_param(ma, mb, os.path.join(save_path, f'time-{idx}.png'))
    print('Complete models compare !')