import torch
import numpy as np
import os 
from typing import List, Tuple
import csv
import math
import random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns  # optional, for prettier plots
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from scipy.stats import ttest_ind
from statsmodels.stats.multitest import multipletests
from config import MODELS

##############################################################################
# Helper: load files
##############################################################################
def load_files(model, directory_path) -> Tuple[List[str], List[str]]:
    """
    Collect the clean and poisoned activation files found in `directory_path`.

    An entry counts as "clean" when its name starts with 'clean' and ends with
    '.pt', and as "poisoned" when it starts with 'poisoned' and ends with
    '.pt'. Anything else is reported on stdout and skipped. The scan is not
    recursive.

    Args:
        model: label used only in the summary message.
        directory_path: directory to scan.

    Returns:
        (clean, poisoned): two lists of full paths, each in sorted
        filename order.
    """
    clean = []
    poisoned = []
    # os.listdir() returns entries in arbitrary order. Sorting makes the
    # result deterministic, so code that slices these lists by position
    # selects the same files on every run and on every machine.
    for filename in sorted(os.listdir(directory_path)):
        full_path = os.path.join(directory_path, filename)
        if filename.startswith('clean') and filename.endswith('.pt'):
            clean.append(full_path)
        elif filename.startswith('poisoned') and filename.endswith('.pt'):
            poisoned.append(full_path)
        else:
            print(f'File type {full_path} not recognised. Skipping.')
    print(f'{model}: Processed {len(clean)} clean files and {len(poisoned)} poisoned files.')
    return clean, poisoned

##############################################################################
# Plotting
##############################################################################
def plot_curvature_distribution(c1, c2, k, bins=30, title="Curvature Distribution"):
    """
    Show a histogram and a KDE plot comparing two sets of curvature values.

    The KDE figure is also saved to
    "cp_dispersion_test_results_parametric_{k}_ablation.pdf" in the current
    working directory.

    Args:
        c1, c2: arrays of curvature values for two groups (e.g., group1 vs group2).
        k: neighbourhood size; used only to build the output filename.
        bins: number of histogram bins.
        title: base title; " (Histogram)" / " (KDE)" is appended per figure.
    """
    # Histogram
    hist_fig = plt.figure(figsize=(8, 4))
    plt.hist(c1, bins=bins, alpha=0.5, color='blue', label='Group 1')
    plt.hist(c2, bins=bins, alpha=0.5, color='orange', label='Group 2')
    plt.xlabel("Curvature Value")
    plt.ylabel("Count")
    plt.title(title + " (Histogram)")
    plt.legend()
    plt.show()
    # Release the figure explicitly: when show() does not block (e.g. a
    # non-GUI backend) figures would otherwise accumulate across calls.
    plt.close(hist_fig)

    # KDE
    kde_fig = plt.figure(figsize=(8, 4))
    # `fill` replaces the deprecated `shade` keyword of seaborn.kdeplot.
    sns.kdeplot(c1, color='blue', label='Group 1', fill=True)
    sns.kdeplot(c2, color='orange', label='Group 2', fill=True)
    plt.xlabel("Curvature Value")
    plt.ylabel("Density")
    plt.title(title + " (KDE)")
    plt.legend()
    # NOTE(review): the filename depends only on k, so every call with the
    # same k overwrites the previous PDF. Confirm whether a per-title or
    # per-layer filename is wanted.
    file = f"cp_dispersion_test_results_parametric_{k}_ablation.pdf"
    plt.savefig(file)
    plt.show()
    plt.close(kde_fig)

##############################################################################
# Local Curvature Estimation
##############################################################################
def estimate_local_curvatures(activations: np.ndarray, k: int = 10, verbose: bool = False) -> np.ndarray:
    """
    Estimate a local curvature measure for each point in `activations` using
    a PCA-based approach on its k-nearest neighbors.
    curvature = sum_of_rest_eigenvalues / largest_eigenvalue

    Args:
        activations: 2-D array of shape (N, D), one row per sample.
        k: neighbourhood size passed to the nearest-neighbour query. The query
           is run on the fitted points themselves. If k exceeds N it is
           clamped to N and a note is printed.
        verbose: print progress messages.

    Returns:
        float32 array of shape (N,) with one curvature value per sample.
        All zeros when a neighbourhood would contain fewer than two points
        (N < 2 or k == 1), since a single point has no spread to measure.

    Raises:
        ValueError: if k < 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    N, D = activations.shape
    if verbose:
        print(f"[Curvature] Starting local curvature computation for N={N}, D={D}, k={k}")

    curvatures = np.zeros(N, dtype=np.float32)

    # A neighbour query cannot return more points than were fitted.
    effective_k = min(k, N)
    if effective_k < 2:
        # Degenerate neighbourhoods: report zero curvature rather than
        # fitting a PCA on at most one point.
        return curvatures
    if effective_k < k:
        print(f"[Curvature] k={k} exceeds N={N}; using k={effective_k} instead.")

    nbrs = NearestNeighbors(n_neighbors=effective_k, algorithm='auto').fit(activations)
    _, indices = nbrs.kneighbors(activations)

    epsilon = 1e-12  # guards the division when the leading eigenvalue is 0

    for i in range(N):
        if verbose and i % 100 == 0:
            print(f"[Curvature] Processing sample {i}/{N}...")

        neighbor_idx = indices[i]  # shape (effective_k,)
        neighbor_points = activations[neighbor_idx]  # shape (effective_k, D)

        pca = PCA(n_components=None)
        pca.fit(neighbor_points)
        eigenvalues = pca.explained_variance_

        if eigenvalues.size == 0:
            # Defensive: no components returned; leave the curvature at 0.
            continue

        eigenvalues_sorted = np.sort(eigenvalues)[::-1]
        lambda_1 = eigenvalues_sorted[0]
        lambda_rest = np.sum(eigenvalues_sorted[1:])

        curvatures[i] = lambda_rest / (lambda_1 + epsilon)

    return curvatures

##############################################################################
# Cohen's d, CIs
##############################################################################
def cohen_d(x: np.ndarray, y: np.ndarray) -> float:
    """
    Cohen's d effect size between two samples:

        d = (mean(x) - mean(y)) / pooled_std

    The pooled standard deviation weights each sample's unbiased variance
    (ddof=1) by its degrees of freedom. A tiny constant is added to both
    denominators so that neither division hits an exact zero.
    """
    tiny = 1e-12
    n_x, n_y = len(x), len(y)

    # Sum of squared deviations of both groups, via their unbiased variances.
    weighted_var_sum = (n_x - 1) * np.var(x, ddof=1) + (n_y - 1) * np.var(y, ddof=1)
    pooled_std = math.sqrt(weighted_var_sum / ((n_x + n_y - 2) + tiny))

    mean_gap = np.mean(x) - np.mean(y)
    return mean_gap / (pooled_std + tiny)

def normal_95_ci(data: np.ndarray) -> Tuple[float, float]:
    """
    95% CI for the mean of `data` assuming large-sample normal approx.

    Args:
        data: 1-D collection of observations.

    Returns:
        (lower, upper) as plain Python floats, computed as
        mean +/- 1.96 * (sample std with ddof=1) / sqrt(n).
        With fewer than two observations no spread can be estimated, so the
        interval collapses to (mean, mean), or (0.0, 0.0) for empty input.
    """
    n = len(data)
    if n < 2:
        mean_val = float(np.mean(data)) if n > 0 else 0.0
        return (mean_val, mean_val)

    z_value = 1.96  # two-sided 95% quantile of the standard normal
    mean_val = np.mean(data)
    std_dev = np.std(data, ddof=1)
    sem = std_dev / math.sqrt(n)

    halfwidth = z_value * sem
    # Convert only at the end so the arithmetic is unchanged while the
    # result type is consistent with the short-input branch.
    return (float(mean_val - halfwidth), float(mean_val + halfwidth))

##############################################################################
# Main Comparison Function
##############################################################################
def compare_curvature_layerwise(
    pt_files_1: List[str],
    pt_files_2: List[str],
    model: str,
    csv_file: str = "curvature_results.csv",
    k: int = 10,
    normalize: bool = True,
    alpha: float = 0.05,
    correction_method: str = "fdr_bh",
    layers_to_check: List[int] = None,
    verbose: bool = False
):
    """
    Compare local-curvature distributions of two groups of activations, layer by layer.

    Each .pt file is loaded with torch.load and the tensors of a group are
    concatenated along dim 1. The code then indexes the result as
    (2, N, L, D): it takes [1] - [0] and unpacks the difference as
    (N, L, D). For every requested layer the function

      * optionally z-scores each group separately (per feature),
      * estimates local curvature per sample with `estimate_local_curvatures`,
      * plots both distributions with `plot_curvature_distribution`,
      * computes 95% CIs of the means, a t-test with equal_var=False, and
        Cohen's d.

    The raw p-values of all processed layers are then corrected together with
    `multipletests`, and one row per layer is appended to `csv_file` (a header
    is written first if the file does not exist yet).

    Args:
        pt_files_1, pt_files_2: paths of the .pt files forming group 1 / group 2.
        model: label written to the "model" column of the CSV.
        csv_file: path of the CSV file to append to.
        k: neighbourhood size for the curvature estimate.
        normalize: z-score each group's layer activations before the estimate.
        alpha: significance level for the multiple-comparison correction.
        correction_method: `method` argument of `multipletests`.
        layers_to_check: layer indices to test; defaults to [0, 15, 23, 31]
            when None. Indices >= L are skipped.
        verbose: print progress information.

    Raises:
        ValueError: if the two groups disagree on L or D.
    """
    # Honour the documented default instead of failing on `for ... in None`.
    if layers_to_check is None:
        layers_to_check = [0, 15, 23, 31]

    if verbose:
        print(f"[compare_curvature_layerwise] Loading pt_files_1={pt_files_1}")
        print(f"[compare_curvature_layerwise] Loading pt_files_2={pt_files_2}")

    # 1) Load
    # NOTE: torch.load unpickles its input; only point this at trusted files.
    data_1_torch = torch.cat([torch.load(fn) for fn in pt_files_1], dim=1)
    data_2_torch = torch.cat([torch.load(fn) for fn in pt_files_2], dim=1)

    # 2) Compute difference (pooled[1] - pooled[0])
    data_1_diff = data_1_torch[1] - data_1_torch[0]  # shape (N1, L, D)
    data_2_diff = data_2_torch[1] - data_2_torch[0]  # shape (N2, L, D)

    data_1 = data_1_diff.float().cpu().numpy()
    data_2 = data_2_diff.float().cpu().numpy()

    N1, L1, D1 = data_1.shape
    N2, L2, D2 = data_2.shape
    if L1 != L2 or D1 != D2:
        raise ValueError(f"Mismatched shapes: data1 {data_1.shape}, data2 {data_2.shape}")

    pvals = []
    results_layerwise = []

    for layer_idx in layers_to_check:
        if layer_idx >= L1:
            if verbose:
                print(f"Skipping layer {layer_idx}, out of range for data shape {data_1.shape}")
            continue

        if verbose:
            print(f"\n[compare_curvature_layerwise] Processing layer={layer_idx}")

        layer_data_1 = data_1[:, layer_idx, :]  # shape (N1, D)
        layer_data_2 = data_2[:, layer_idx, :]  # shape (N2, D)

        # (Optional) Z-score, each group with its own per-feature statistics.
        # The small constant avoids dividing by zero for constant features.
        if normalize:
            m1, s1 = layer_data_1.mean(axis=0), layer_data_1.std(axis=0) + 1e-12
            layer_data_1 = (layer_data_1 - m1) / s1

            m2, s2 = layer_data_2.mean(axis=0), layer_data_2.std(axis=0) + 1e-12
            layer_data_2 = (layer_data_2 - m2) / s2

        # 3) Compute local curvature
        c1 = estimate_local_curvatures(layer_data_1, k=k, verbose=verbose)
        c2 = estimate_local_curvatures(layer_data_2, k=k, verbose=verbose)

        # 4) Plot
        plot_curvature_distribution(c1, c2, k=k, bins=30,
                                    title=f"Group1 vs Group2 Curvature Layer {layer_idx}")

        # 5) Stats
        c1_lower, c1_upper = normal_95_ci(c1)
        c2_lower, c2_upper = normal_95_ci(c2)
        ttest_res = ttest_ind(c1, c2, equal_var=False)
        p_val = ttest_res.pvalue
        pvals.append(p_val)

        d_val = cohen_d(c1, c2)

        if verbose:
            print(f"Layer {layer_idx}: p_val_raw={p_val:.3e}, cohen_d={d_val:.3f}")

        results_layerwise.append({
            "layer": layer_idx,
            "mean_c1": float(np.mean(c1)),
            "ci_c1_lower": float(c1_lower),
            "ci_c1_upper": float(c1_upper),
            "mean_c2": float(np.mean(c2)),
            "ci_c2_lower": float(c2_lower),
            "ci_c2_upper": float(c2_upper),
            "p_val_raw": float(p_val),
            "cohen_d": float(d_val)
        })

    # 6) Correct for multiple comparisons (across the layers processed above)
    if len(pvals) > 0:
        rej, pvals_corr, _, _ = multipletests(pvals, alpha=alpha, method=correction_method)
    else:
        rej, pvals_corr = [], []

    # 7) Write results to CSV (append; header only for a new file)
    file_exists = os.path.isfile(csv_file)
    with open(csv_file, "a", newline="") as out_f:
        writer = csv.writer(out_f)
        if not file_exists:
            writer.writerow([
                "model",
                "layer",
                "mean_c1", "ci_c1_lower", "ci_c1_upper",
                "mean_c2", "ci_c2_lower", "ci_c2_upper",
                "raw_p", "corrected_p",
                "reject_H0",
                "cohen_d"
            ])

        # One p-value is recorded per result, so the three sequences have
        # equal length and can be walked in lockstep.
        for res, corr_p, rejected in zip(results_layerwise, pvals_corr, rej):
            writer.writerow([
                model,
                res["layer"],
                f"{res['mean_c1']:.6f}",
                f"{res['ci_c1_lower']:.6f}",
                f"{res['ci_c1_upper']:.6f}",
                f"{res['mean_c2']:.6f}",
                f"{res['ci_c2_lower']:.6f}",
                f"{res['ci_c2_upper']:.6f}",
                f"{res['p_val_raw']:.3e}",
                f"{corr_p:.3e}",
                "Yes" if rejected else "No",
                f"{res['cohen_d']:.3f}"
            ])

    print(f"[compare_curvature_layerwise] Done. Results appended to {csv_file} for model={model}.")

##############################################################################
# MAIN SCRIPT
##############################################################################

# 1) Load activations files for each model
for model, data in MODELS.items():
    clean_files, poisoned_files = load_files(model, directory_path=data['disk_path'])
    data['clean'] = clean_files
    data['poisoned'] = poisoned_files

# 2) For each model, compare:
#    a) Poisoned vs Poisoned
#    b) Clean vs Clean
#    c) Mixed vs Mixed (i.e. c + p vs c + p)

CSV_OUTPUT = "cp_dispersion_test_results_parametric_ablation.csv"
K = 30  # neighbors for local curvature

for model, model_info in MODELS.items():
    print(f"\n\n=== Analyzing Model: {model} ===")
    pt_files_clean = model_info["clean"]
    pt_files_poisoned = model_info["poisoned"]
    layers_to_check = model_info["layers_to_check"]

    # Build every split up front, for every model. Previously the poisoned
    # split was assigned only inside the p_vs_p branch, so a model that skips
    # that branch but still runs the mixed comparison (llama_8B) either hit a
    # NameError or silently reused the previous model's poisoned files.
    # Groups are the first two files and the next two files of each category.
    pt_files_poisoned_1 = pt_files_poisoned[:2]
    pt_files_poisoned_2 = pt_files_poisoned[2:4]
    pt_files_clean_1 = pt_files_clean[:2]
    pt_files_clean_2 = pt_files_clean[2:4]

    # (label, group 1 files, group 2 files), in execution order.
    comparisons = []

    if model not in ['mistral_7B', 'llama_8B']:
        # a) Poisoned vs Poisoned
        comparisons.append(("p_vs_p", pt_files_poisoned_1, pt_files_poisoned_2))

    if model not in ['mistral_7B']:
        # b) Clean vs Clean
        comparisons.append(("c_vs_c", pt_files_clean_1, pt_files_clean_2))

        # c) Mixed vs Mixed (c + p vs c + p)
        # NOTE(review): group 1 takes one file per category ([:1]) while
        # group 2 takes up to two ([:2]). Kept as-is; confirm the asymmetry
        # is intentional.
        pt_files_mixed_1 = pt_files_clean_1[:1] + pt_files_poisoned_1[:1]
        pt_files_mixed_2 = pt_files_clean_2[:2] + pt_files_poisoned_2[:2]
        comparisons.append(("mixed_vs_mixed", pt_files_mixed_1, pt_files_mixed_2))

    for label, files_1, files_2 in comparisons:
        compare_curvature_layerwise(
            pt_files_1=files_1,
            pt_files_2=files_2,
            model=f"{model}_{label}",
            csv_file=CSV_OUTPUT,
            k=K,
            layers_to_check=layers_to_check,
            normalize=True,
            alpha=0.05,
            correction_method="fdr_bh",
            verbose=True
        )
