import copy
import time

import numpy as np
import torch
import tqdm
from open_clip.transformer import ResidualAttentionBlock

import json
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.datasets.registry import get_dataset
from src.models.heads import get_classification_head
from src.models.modeling import ImageClassifier
from src.models.task_vectors import _Checkpoint, _TaskVector
from src.utils import utils
from torch.cuda.amp import autocast
from torch.nn import MultiheadAttention
from scipy.stats import kurtosis, skew

def average_cosine_similarity(matrix1, matrix2):
    """
    Compute the average column cosine similarity between two matrices.

    Columns are L2-normalised (all-zero columns are left as zeros) and the
    pairwise similarities ``cos(matrix1[:, i], matrix2[:, j])`` are averaged
    over the upper triangle ``j >= i`` of the similarity matrix. When both
    inputs are equal (self-similarity) the diagonal ``i == j`` is excluded,
    because every column trivially matches itself.

    Args:
        matrix1: 2-D array whose columns are compared, shape (d, n1).
        matrix2: 2-D array with the same number of rows, shape (d, n2).

    Returns:
        The mean similarity as a Python float, or ``nan`` when no pair is
        selected (e.g. self-similarity of a single-column matrix).
    """
    matrix1 = np.asarray(matrix1)
    matrix2 = np.asarray(matrix2)

    # Compute each norm once; substitute 1 for zero norms so all-zero
    # columns stay zero instead of turning into NaNs.
    norms1 = np.linalg.norm(matrix1, axis=0)
    norms2 = np.linalg.norm(matrix2, axis=0)
    normalized_matrix1 = matrix1 / np.where(norms1 == 0, 1, norms1)
    normalized_matrix2 = matrix2 / np.where(norms2 == 0, 1, norms2)

    cosine_similarity_matrix = np.dot(normalized_matrix1.T, normalized_matrix2)

    # Upper triangle; drop the trivial diagonal for self-similarity.
    offset = 1 if np.array_equal(matrix1, matrix2) else 0
    upper_triangle_values = cosine_similarity_matrix[
        np.triu_indices_from(cosine_similarity_matrix, k=offset)
    ]

    # Avoid np.mean's "Mean of empty slice" RuntimeWarning; the result is
    # undefined when no pairs were selected.
    if upper_triangle_values.size == 0:
        return float("nan")
    return float(np.mean(upper_triangle_values))

def compute_tensor_stats(tensor):
    """
    Summarise a tensor with a handful of scalar statistics.

    Returns a dict holding the tensor's shape under ``output_dimension``
    plus its 2-norm over all elements and the variance, standard deviation
    and mean of all elements (torch defaults), each as a Python float.
    """
    stats = {"output_dimension": tensor.shape}
    stats["norm"] = tensor.norm().item()
    stats["variance"] = tensor.var().item()
    stats["std"] = tensor.std().item()
    stats["mean"] = tensor.mean().item()
    return stats

def compute_spectral_properties(tensor):
    """
    Compute rank- and singular-value-based properties of a matrix.

    Returns:
        dict with ``rank`` (from ``torch.linalg.matrix_rank``), the largest
        singular value, the condition number (largest / smallest singular
        value, ``inf`` when the smallest is exactly zero), and the average
        cosine similarity between distinct columns of the matrix.
    """
    singular_values = torch.linalg.svdvals(tensor)
    sigma_max = singular_values.max().item()
    sigma_min = singular_values.min().item()

    if sigma_min == 0:
        condition = float("inf")
    else:
        condition = sigma_max / sigma_min

    # Column similarity of the matrix against itself (diagonal excluded).
    as_array = tensor.detach().cpu().numpy()

    return {
        "rank": torch.linalg.matrix_rank(tensor).item(),
        "largest_singular_value": sigma_max,
        "condition_number_singular_value": condition,
        "column_cosine_similarity": average_cosine_similarity(as_array, as_array),
    }

def compute_distribution_stats(tensor):
    """
    Describe the shape of a tensor's value distribution.

    Every element is flattened into one 1-D NumPy array and passed to
    ``scipy.stats.kurtosis`` and ``scipy.stats.skew`` with their default
    settings.
    """
    values = np.ravel(tensor.detach().cpu().numpy())
    return {"kurtosis": kurtosis(values), "skewness": skew(values)}

def compute_parameter_metrics(param, eval_weight_delta=False, is_weight=False):
    """
    Collect basic, spectral, distribution and (optionally) sparsity metrics
    for one parameter tensor.

    Args:
        param: parameter tensor to describe.
        eval_weight_delta: when True, also record the element count, the
            non-zero element count and their ratio (used when ``param`` is a
            weight delta tau_merged - tau_baseline).
        is_weight: when True, compute spectral properties; otherwise the
            spectral keys are present with value ``None``.

    Returns:
        dict of metric name -> value.
    """
    metrics = compute_tensor_stats(param)

    if is_weight:
        spectral = compute_spectral_properties(param)
    else:
        # Bias vectors: keep the keys so every entry shares one schema.
        spectral = dict.fromkeys((
            "rank",
            "largest_singular_value",
            "condition_number_singular_value",
            "column_cosine_similarity",
        ))
    metrics.update(spectral)
    metrics.update(compute_distribution_stats(param))

    if eval_weight_delta:
        total = param.numel()
        nonzero = torch.count_nonzero(param).item()
        metrics["num_params"] = total
        metrics["num_params_nonzero"] = nonzero
        # A fraction in [0, 1], despite the "percentage" name.
        metrics["percentage_params_nonzero"] = nonzero / total
    return metrics

def compute_model_output_metrics(output):
    """
    Compute statistics for the final model logits.

    Softmax is taken over the last dimension. ``probabilities`` is the
    per-class probability averaged over dim 0 (presumably the batch
    dimension; confirm with caller) as a NumPy array. ``entropy_mean`` is
    the mean Shannon entropy (natural log) of the per-sample distributions.

    Args:
        output: logits tensor.

    Returns:
        dict of metric name -> value.
    """
    probabilities = torch.nn.functional.softmax(output, dim=-1)
    # xlogy(p, p) == p * log(p) with the convention 0 * log(0) == 0. This gives
    # the exact entropy without an epsilon. The old ``log(p + 1e-12)`` hack
    # turns into log(0) = -inf and then NaN in fp16, where 1e-12 rounds to 0.
    entropy = -torch.sum(torch.special.xlogy(probabilities, probabilities), dim=-1)

    return {
        "output_dimension": output.shape,
        "probabilities_dimension": probabilities.shape,
        "probabilities": probabilities.detach().mean(dim=0).cpu().numpy(),
        "logit_variance": torch.var(output).item(),
        "probability_variance": torch.var(probabilities).item(),
        "logit_std": torch.std(output).item(),
        "probability_std": torch.std(probabilities).item(),
        "logit_mean": torch.mean(output).item(),
        "probability_mean": torch.mean(probabilities).item(),
        "entropy_mean": torch.mean(entropy).item(),
    }

def register_layer_hook(metrics, layer_name):
    """
    Return a forward hook that records statistics for a layer's output.

    The hook stores ``compute_tensor_stats`` of the output under
    ``metrics[layer_name]``, only on the first call (later calls are ignored).
    It does not register itself; pass it to ``module.register_forward_hook``.

    Args:
        metrics: dict the hook writes into.
        layer_name: key to store this layer's statistics under.
    """
    def hook(module, input, output):
        if layer_name in metrics:
            return
        # MultiheadAttention returns (attn_output, attn_weights); only the
        # output tensor matters. Any other module that returns a tuple is
        # unwrapped the same way instead of crashing compute_tensor_stats.
        if isinstance(module, MultiheadAttention) or isinstance(output, tuple):
            tensor = output[0]
        else:
            tensor = output
        metrics[layer_name] = compute_tensor_stats(tensor)

    return hook

def evaluate_activations(image_encoder, dataset_name, args):
    """
    Evaluate model activations on a dataset and compute statistics for each layer.

    Forward hooks are attached to the MLP, MLP GELU and attention sub-modules
    of every ``ResidualAttentionBlock`` in the visual tower. One batch from
    the dataset's eval dataloader is then run, and statistics of each hooked
    output and of the final logits are collected.

    Args:
        image_encoder: image encoder wrapped into an ``ImageClassifier``.
        dataset_name: dataset to draw the batch from; also selects the head.
        args: needs ``device``, ``data_location`` and ``model_stats_batch_size``,
            plus whatever ``get_classification_head`` / ``get_dataloader`` read.

    Returns:
        dict mapping ``"{block}.mlp"``, ``"{block}.mlp.gelu"``, ``"{block}.attn"``
        and ``"model_output"`` to statistics dicts.
    """
    classification_head = get_classification_head(args, dataset_name)
    model = ImageClassifier(image_encoder, classification_head)
    model.to(args.device)
    model.eval()

    dataset = get_dataset(
        dataset_name, model.val_preprocess,
        location=args.data_location,
        batch_size=args.model_stats_batch_size
    )
    dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None)
    device = args.device

    metrics = {}
    hooks = []

    try:
        # Register hooks on ResidualAttentionBlock modules.
        for name, module in model.image_encoder.model.visual.named_modules():
            if isinstance(module, ResidualAttentionBlock):
                # Register hooks for MLP layer
                hooks.append(module.mlp.register_forward_hook(register_layer_hook(metrics, f"{name}.mlp")))
                hooks.append(module.mlp.gelu.register_forward_hook(register_layer_hook(metrics, f"{name}.mlp.gelu")))

                # Register hook for attention layer
                hooks.append(module.attn.register_forward_hook(register_layer_hook(metrics, f"{name}.attn")))
                print(f"Hooks registered for {name}")

        with torch.no_grad():
            data = next(iter(dataloader))
            data = maybe_dictionarize(data)
            x = data["images"].to(device)

            output = model(x)

            metrics["model_output"] = compute_model_output_metrics(output)
    finally:
        # Always detach the hooks, even if loading or the forward pass raises.
        # The hooked modules come from the caller's ``image_encoder``
        # (presumably held by reference), so leaked hooks would keep firing
        # on later forward passes.
        for hook in hooks:
            hook.remove()

    torch.cuda.empty_cache()
    return metrics

def evaluate_params(image_encoder, dataset_name, args):
    """
    Evaluate model parameter statistics (e.g., weight norms, spectral properties,
    distribution metrics, and (optionally) weight delta metrics).

    For every ``ResidualAttentionBlock`` named ``{name}`` in the visual tower,
    the returned dict gets these entries:
      * ``{name}.{param}`` for each MLP parameter,
      * ``{name}.attn.weight_{query,key,value,out_proj}`` for full matrices,
      * ``{name}.attn.weight_{query,key,value}_head_{i}`` per attention head,
      * ``{name}.attn.weight_{query,key,value}_summary`` with head averages,
      * ``{name}.resblock_summary`` with sparsity counts (only when
        ``args.eval_weight_delta`` is set).

    Args:
        image_encoder: image encoder wrapped into an ``ImageClassifier``.
        dataset_name: dataset used to select the classification head.
        args: needs ``device`` and ``eval_weight_delta``, plus whatever
            ``get_classification_head`` reads.

    Returns:
        dict mapping metric names to statistics dicts.
    """
    classification_head = get_classification_head(args, dataset_name)
    model = ImageClassifier(image_encoder, classification_head)
    model.to(args.device)
    model.eval()

    metrics = {}
    eval_weight_delta = args.eval_weight_delta

    with torch.no_grad():
        for name, module in model.image_encoder.model.visual.named_modules():
            if not isinstance(module, ResidualAttentionBlock):
                continue

            _collect_mlp_param_metrics(name, module.mlp, metrics, eval_weight_delta)
            _collect_attention_metrics(name, module.attn, metrics, eval_weight_delta)

            # Evaluate task vector, i.e., tau_merged - tau_baseline.
            if eval_weight_delta:
                metrics[f"{name}.resblock_summary"] = _resblock_sparsity(module)

            torch.cuda.empty_cache()

    return metrics


def _collect_mlp_param_metrics(block_name, mlp, metrics, eval_weight_delta):
    """Store metrics for each MLP parameter; names containing "weight" get spectral stats."""
    for param_name, param in mlp.named_parameters():
        is_weight = "weight" in param_name
        metrics[f"{block_name}.{param_name}"] = compute_parameter_metrics(param, eval_weight_delta, is_weight)


def _collect_attention_metrics(block_name, attn, metrics, eval_weight_delta):
    """Store full-matrix, per-head and head-averaged metrics for an attention layer."""
    in_proj = attn.in_proj_weight
    embed_dim = in_proj.shape[1]
    # Use the layer's real head count. The previous hard-coded 12 mis-split the
    # heads, or made the reshape below fail, for any model that does not have
    # exactly 12 heads.
    num_heads = getattr(attn, "num_heads", 12)
    head_dim = embed_dim // num_heads

    # The packed in-projection stacks the Q, K and V weights along dim 0.
    weights = {
        "query": in_proj[:embed_dim],
        "key": in_proj[embed_dim:2 * embed_dim],
        "value": in_proj[2 * embed_dim:],
    }

    # Compute metrics for query, key, and value weight matrices.
    for mat_name, mat in weights.items():
        metrics[f"{block_name}.attn.weight_{mat_name}"] = compute_parameter_metrics(mat, eval_weight_delta, True)

    # Attention projection weights.
    metrics[f"{block_name}.attn.weight_out_proj"] = compute_parameter_metrics(
        attn.out_proj.weight, eval_weight_delta, True
    )

    # Consecutive blocks of head_dim output rows belong to one head.
    heads = {
        mat_name: mat.reshape(num_heads, head_dim, embed_dim)
        for mat_name, mat in weights.items()
    }

    for mat_name, mat_heads in heads.items():
        head_results = []
        qk_cosine_sum = 0
        for head_index, head_matrix in enumerate(mat_heads):
            m = compute_parameter_metrics(head_matrix, eval_weight_delta, True)
            metrics[f"{block_name}.attn.weight_{mat_name}_head_{head_index}"] = m
            head_results.append(m)

            if mat_name == "query":
                # Cosine similarity between corresponding query and key heads.
                qk_cosine_sum += average_cosine_similarity(
                    head_matrix.detach().cpu().numpy(),
                    heads["key"][head_index].detach().cpu().numpy(),
                )

        metrics[f"{block_name}.attn.weight_{mat_name}_summary"] = _summarize_head_metrics(
            head_results, qk_cosine_sum, num_heads, eval_weight_delta
        )


def _summarize_head_metrics(head_results, qk_cosine_sum, num_heads, eval_weight_delta):
    """Average a list of per-head metric dicts into one summary dict."""
    denom = float(num_heads)

    def avg(key):
        # Running sum over heads starting from 0, then divided by the head count.
        return sum(m.get(key, 0) for m in head_results) / denom

    summary = {
        "norm": avg("norm"),
        "variance": avg("variance"),
        "std": avg("std"),
        "mean": avg("mean"),
        # Accumulated only for the query projection; key/value summaries report 0.0.
        "Q_K_cosine_similarity": qk_cosine_sum / denom,
        "rank": avg("rank"),
        "largest_singular_value": avg("largest_singular_value"),
        "condition_number_singular_value": avg("condition_number_singular_value"),
        "kurtosis": avg("kurtosis"),
        "skewness": avg("skewness"),
        "column_cosine_similarity": avg("column_cosine_similarity"),
    }

    # Add num parameters for the weight delta (tau_merged - tau_baseline).
    # These counts are already in each head's dict, so they are not recomputed.
    if eval_weight_delta:
        summary.update({
            "num_params": avg("num_params"),
            "num_params_nonzero": avg("num_params_nonzero"),
            "percentage_params_nonzero": avg("percentage_params_nonzero"),
        })
    return summary


def _resblock_sparsity(module):
    """Count total and non-zero parameters across an entire residual block."""
    params = list(module.parameters())
    total_params = sum(param.numel() for param in params)
    nonzero_params = sum(torch.count_nonzero(param).item() for param in params)
    return {
        "total_params": total_params,
        "nonzero_params": nonzero_params,
        "percentage_nonzero": nonzero_params / total_params,
    }