import os 
import csv
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns 
from typing import List, Tuple
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from statsmodels.stats.multitest import multipletests
from scipy.stats import ttest_ind
from config import MODELS, MODEL_NAME_MAP

def load_files(model, directory_path) -> Tuple[List[str], List[str]]:
    """
    Split the ``.pt`` files in ``directory_path`` into clean and poisoned lists.

    Files are classified by name: ``clean*.pt`` and ``poisoned*.pt``. Any
    other entry is reported and skipped. Entries are visited in sorted filename
    order so the returned lists are deterministic.

    Parameters
    ----------
    model : str
        Model name; used only in the summary message.
    directory_path : str
        Directory to scan (non-recursive).

    Returns
    -------
    (clean, poisoned) : tuple of lists of str
        Full paths of the clean and poisoned files, each in sorted filename order.
    """
    clean = []
    poisoned = []
    # os.listdir order is filesystem-dependent; callers may keep only a prefix
    # of these lists, so sort to make that selection reproducible.
    for filename in sorted(os.listdir(directory_path)):
        full_path = os.path.join(directory_path, filename)
        if filename.startswith('clean') and filename.endswith('.pt'):
            clean.append(full_path)
        elif filename.startswith('poisoned') and filename.endswith('.pt'):
            poisoned.append(full_path)
        else:
            print(f'File type {full_path} not recognised. Skipping.')
    print(f'{model}: Processed {len(clean)} clean files and {len(poisoned)} poisoned files.')
    return clean, poisoned


def get_layer_diff_data(pt_files: List[str]) -> np.ndarray:
    """
    Load activation files and return the per-sample difference of their two slots.

    Each ``.pt`` file is expected to hold a tensor of shape ``(2, N, L, D)``.
    The tensors are concatenated along dim 1 (the sample axis), and
    ``tensor[1] - tensor[0]`` is returned.

    Parameters
    ----------
    pt_files : List[str]
        Paths to the ``.pt`` files. Must be non-empty (``torch.cat`` rejects an
        empty list).

    Returns
    -------
    np.ndarray
        float32 array of shape ``(sum_N, L, D)``.
    """
    # Load straight to CPU: the result is converted to NumPy anyway, and this
    # lets tensors saved from a GPU be read on a CPU-only machine.
    loaded_tensors = [torch.load(p, map_location="cpu") for p in pt_files]
    cat_torch = torch.cat(loaded_tensors, dim=1)  # (2, sum_N, L, D)

    # Subtract in at least float32: doing it in bf16/fp16 rounds the difference
    # back to the low-precision grid before the final float32 conversion.
    work_dtype = torch.promote_types(cat_torch.dtype, torch.float32)
    diffs = cat_torch[1].to(work_dtype) - cat_torch[0].to(work_dtype)  # (sum_N, L, D)

    return diffs.cpu().float().numpy()
      
# Discover each model's clean / poisoned activation files and attach the path
# lists to its MODELS entry under the keys 'clean' and 'poisoned'.
for model, data in MODELS.items():
    clean, poisoned = load_files(model, directory_path=data['disk_path'])
    data.update(clean=clean, poisoned=poisoned)
    

def plot_curvature_distribution(c1, c2, k, bins=30, title="Curvature Distribution", save_path=None):
    """
    Plot two groups of curvature values as overlaid histograms and KDEs.

    Two figures are produced:
    - an overlaid histogram, which is only shown;
    - an overlaid kernel density estimate, which is saved to ``save_path`` and
      then shown.
    Both figures are closed afterwards, so repeated calls (e.g. one per layer)
    do not accumulate open figures.

    Parameters
    ----------
    c1, c2 : array-like
        Curvature values for the two groups (e.g. clean vs. poisoned), plotted
        as 'Group 1' and 'Group 2'.
    k : int
        Neighbourhood size used for the curvature estimate; included in the
        default output filename.
    bins : int, str or sequence
        Histogram binning. One set of bin edges is computed from the finite
        values of both groups combined, so the two histograms are comparable.
    title : str
        Plot title prefix. Its alphanumeric form is part of the default output
        filename, so plots with different titles (e.g. different layers) do not
        overwrite each other.
    save_path : str, optional
        Output path for the KDE figure. Defaults to
        ``cp_dispersion_test_results_parametric_{k}_{title_slug}.pdf``.
    """
    c1 = np.asarray(c1, dtype=float)
    c2 = np.asarray(c2, dtype=float)

    if save_path is None:
        title_slug = "".join(ch if ch.isalnum() else "_" for ch in title)
        save_path = f"cp_dispersion_test_results_parametric_{k}_{title_slug}.pdf"

    # Shared edges: letting each plt.hist call choose its own bins gives the
    # groups different bin widths, which makes the bar heights incomparable.
    combined = np.concatenate([c1.ravel(), c2.ravel()])
    finite = combined[np.isfinite(combined)]
    hist_bins = np.histogram_bin_edges(finite, bins=bins) if finite.size else bins

    # --- Histogram, both groups overlaid ---
    hist_fig = plt.figure(figsize=(8, 4))
    plt.hist(c1, bins=hist_bins, alpha=0.5, color='blue', label='Group 1')
    plt.hist(c2, bins=hist_bins, alpha=0.5, color='orange', label='Group 2')
    plt.xlabel("Curvature Value")
    plt.ylabel("Count")
    plt.title(title + " (Histogram)")
    plt.legend()
    plt.show()
    plt.close(hist_fig)

    # --- Kernel Density Estimate (KDE), both groups in one figure ---
    kde_fig = plt.figure(figsize=(8, 4))
    # `fill` replaces seaborn's deprecated `shade` argument.
    sns.kdeplot(c1, color='blue', label='Group 1', fill=True)
    sns.kdeplot(c2, color='orange', label='Group 2', fill=True)
    plt.xlabel("Curvature Value")
    plt.ylabel("Density")
    plt.title(title + " (KDE)")
    plt.legend()
    # Save before show(): some backends clear the figure when it is shown.
    plt.savefig(save_path)
    plt.show()
    plt.close(kde_fig)


##############################################################################
# Local Curvature Estimation
##############################################################################

def estimate_local_curvatures(activations: np.ndarray, k: int = 10, verbose: bool = False) -> np.ndarray:
    """
    Estimate a local curvature measure for each point in `activations` using
    a PCA-based approach on its k-nearest neighbors.

    For each sample i:
      1) Take its k nearest neighbours (the point itself is included, since the
         query set is the fitted set).
      2) Compute the PCA spectrum of those points, i.e. the explained variances
         s_j**2 / (k - 1), where s_j are the singular values of the centred
         neighbourhood matrix.
      3) Let lambda_1 = largest eigenvalue, sum_rest = sum of the remaining eigenvalues.
      4) 'Curvature' = sum_rest / (lambda_1 + epsilon).

    If verbose=True, logs progress every 100 samples.

    Parameters
    ----------
    activations : np.ndarray
        Activation data of shape (N, D).
    k : int
        Number of neighbours for the local PCA. Clamped to N when larger,
        because a neighbourhood cannot contain more points than exist.
    verbose : bool
        If True, prints progress messages.

    Returns
    -------
    curvatures : np.ndarray of shape (N,), float32
        Curvature estimate for each point. When fewer than two points are
        available per neighbourhood (N < 2 or k < 2), local PCA is undefined
        and all curvatures are 0.0.
    """
    N, D = activations.shape
    if verbose:
        print(f"[Curvature] Starting local curvature computation for N={N}, D={D}, k={k}")

    curvatures = np.zeros(N, dtype=np.float32)
    if N == 0:
        return curvatures

    # NearestNeighbors raises if n_neighbors > n_samples.
    n_neighbors = min(k, N)
    if n_neighbors < 2:
        # One point has no spread: the sample variance divides by n - 1 = 0.
        return curvatures

    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='auto').fit(activations)
    _, indices = nbrs.kneighbors(activations)

    epsilon = 1e-12
    for i in range(N):
        if verbose and i % 100 == 0:
            print(f"[Curvature] Processing sample {i}/{N}")

        # float64 for a numerically stable SVD of the (n_neighbors, D) patch.
        neighbor_points = np.asarray(activations[indices[i]], dtype=np.float64)
        centered = neighbor_points - neighbor_points.mean(axis=0)

        # Same spectrum as sklearn PCA's explained_variance_, already sorted
        # in descending order by the SVD.
        singular_values = np.linalg.svd(centered, compute_uv=False)
        eigenvalues = singular_values ** 2 / (n_neighbors - 1)

        curvatures[i] = np.sum(eigenvalues[1:]) / (eigenvalues[0] + epsilon)

    return curvatures


##############################################################################
# Cohen's d and Confidence Intervals
##############################################################################

def cohen_d(x: np.ndarray, y: np.ndarray) -> float:
    """
    Effect size (Cohen's d) between two independent samples.

    Computes ``(mean(x) - mean(y)) / s_pooled``, where ``s_pooled`` is the
    pooled standard deviation built from each sample's unbiased (ddof=1)
    variance. A tiny epsilon is added to both denominators so they are never
    exactly zero.

    Returns
    -------
    float
        Difference of the means in pooled-standard-deviation units.
    """
    tiny = 1e-12
    size_x, size_y = len(x), len(y)

    weighted_ss = (size_x - 1) * np.var(x, ddof=1) + (size_y - 1) * np.var(y, ddof=1)
    s_pooled = math.sqrt(weighted_ss / ((size_x + size_y - 2) + tiny))

    mean_gap = np.mean(x) - np.mean(y)
    return mean_gap / (s_pooled + tiny)


def normal_95_ci(data: np.ndarray) -> Tuple[float, float]:
    """
    Compute a 95% CI for the mean of `data` under a normal (large-sample) assumption.

    The interval is ``mean ± 1.96 * std(ddof=1) / sqrt(n)``. With fewer than
    two values there is no spread estimate, so a zero-width interval at the
    mean is returned (or at 0.0 when `data` is empty).

    Returns
    -------
    (ci_lower, ci_upper) : tuple of floats
        The lower and upper bound of the 95% confidence interval.
    """
    n = len(data)
    if n < 2:
        mean_val = np.mean(data) if n > 0 else 0.0
        return (mean_val, mean_val)

    z_value = 1.96  # two-sided 95% quantile of the standard normal
    mean_val = np.mean(data)
    sem = np.std(data, ddof=1) / math.sqrt(n)

    halfwidth = z_value * sem
    return (mean_val - halfwidth, mean_val + halfwidth)


##############################################################################
# Main Function: Compare Curvature Layerwise + T-Test
##############################################################################

def _append_curvature_rows(csv_file, model, results_layerwise, pvals_corr, rej):
    """
    Append one CSV row per tested layer to ``csv_file``.

    The header row is written when the file does not exist yet or is empty.
    ``pvals_corr`` and ``rej`` must be aligned element-wise with
    ``results_layerwise``.
    """
    write_header = not os.path.isfile(csv_file) or os.path.getsize(csv_file) == 0
    with open(csv_file, "a", newline="") as out_f:
        writer = csv.writer(out_f)
        if write_header:
            writer.writerow([
                "model",
                "layer",
                "mean_c1", "ci_c1_lower", "ci_c1_upper",
                "mean_c2", "ci_c2_lower", "ci_c2_upper",
                "raw_p", "corrected_p",
                "reject_H0",
                "cohen_d"
            ])

        for res, corr_p, rj in zip(results_layerwise, pvals_corr, rej):
            writer.writerow([
                model,
                res["layer"],
                f"{res['mean_c1']:.6f}",
                f"{res['ci_c1_lower']:.6f}",
                f"{res['ci_c1_upper']:.6f}",
                f"{res['mean_c2']:.6f}",
                f"{res['ci_c2_lower']:.6f}",
                f"{res['ci_c2_upper']:.6f}",
                f"{res['p_val_raw']:.3e}",
                f"{corr_p:.3e}",
                "Yes" if rj else "No",
                f"{res['cohen_d']:.3f}"
            ])


def compare_curvature_layerwise(
    pt_files_1: List[str],
    pt_files_2: List[str],
    model: str,
    csv_file: str = "curvature_results.csv",
    k: int = 10,
    normalize: bool = True,
    alpha: float = 0.05,
    correction_method: str = "fdr_bh",
    verbose: bool = False,
    layers_to_check=None,
):
    """
    Compare local curvature between two conditions, layer by layer.

    Loads two lists of .pt files (each tensor of shape (2, N, L, D)) and turns
    each into per-sample differences (slot 1 - slot 0). For each selected layer it:
    - computes local curvature;
    - plots both distributions;
    - runs Welch's t-test.
    Multiple-comparisons correction is then applied across the tested layers,
    and the results (including 95% normal CIs and Cohen's d) are appended to a CSV.

    Parameters
    ----------
    pt_files_1 : List[str]
        Paths to the .pt files for condition 1 (e.g. 'clean' or 'before').
        Each file has shape (2, N, L, D).
    pt_files_2 : List[str]
        Paths to the .pt files for condition 2 (e.g. 'poisoned' or 'after').
        Same L and D as pt_files_1.
    model : str
        Name of the model (for CSV logging).
    csv_file : str
        Output CSV file path. Appended if it exists, else created. The header
        is written when the file is new or empty.
    k : int
        #neighbors for local PCA curvature estimation.
    normalize : bool
        Whether to z-score each layer's features prior to curvature analysis.
        Each condition is standardised with its own per-feature mean and std.
    alpha : float
        Significance level for the corrected tests (default 0.05).
    correction_method : str
        Method for multiple comparisons correction, e.g. 'bonferroni', 'fdr_bh'.
    verbose : bool
        If True, prints progress messages for data loading, layer iteration, etc.
    layers_to_check : list of int, optional
        Layer indices to test. Defaults to [0, 15, 23, 31]. Indices that
        cannot address a layer (outside [-L, L)) are skipped.

    Returns
    -------
    None
        (Writes results with 95% CIs to the CSV.) Layers whose raw p-value is
        NaN are excluded from the correction; their corrected p-value is
        written as 'nan' and reject_H0 as 'No'.
    """
    if layers_to_check is None:
        layers_to_check = [0, 15, 23, 31]

    if verbose:
        print(f"[compare_curvature_layerwise] Loading pt_files_1={pt_files_1}")
        print(f"[compare_curvature_layerwise] Loading pt_files_2={pt_files_2}")

    # 1) Per-sample differences (slot 1 - slot 0) as float32 arrays (N, L, D)
    data_1 = get_layer_diff_data(pt_files_1)
    data_2 = get_layer_diff_data(pt_files_2)

    if verbose:
        print(f"[compare_curvature_layerwise] data_1 shape={data_1.shape}, data_2 shape={data_2.shape}")

    if data_1.ndim != 3 or data_2.ndim != 3:
        raise ValueError(f"Expected data shape (N, L, D). Got {data_1.shape}, {data_2.shape}")

    _, L1, D1 = data_1.shape
    _, L2, D2 = data_2.shape
    if L1 != L2 or D1 != D2:
        raise ValueError(f"Mismatched shapes: data1 {data_1.shape}, data2 {data_2.shape}")

    if verbose:
        print(f"[compare_curvature_layerwise] Will test layers={layers_to_check}")

    # 2) Per layer: curvature, plot, CIs, Welch's t-test, effect size
    pvals = []
    results_layerwise = []

    for layer_idx in layers_to_check:
        if not -L1 <= layer_idx < L1:
            if verbose:
                print(f"[compare_curvature_layerwise] Skipping layer {layer_idx}, out of range.")
            continue

        if verbose:
            print(f"\n[compare_curvature_layerwise] Processing layer {layer_idx}")

        layer_data_1 = data_1[:, layer_idx, :]  # shape (N1, D)
        layer_data_2 = data_2[:, layer_idx, :]  # shape (N2, D)

        if normalize:
            # NOTE(review): each condition is standardised with its OWN
            # statistics, which removes per-feature scale differences between
            # the groups before the comparison. Confirm this is intended versus
            # using statistics pooled over both groups.
            m1, s1 = layer_data_1.mean(axis=0), layer_data_1.std(axis=0) + 1e-12
            layer_data_1 = (layer_data_1 - m1) / s1

            m2, s2 = layer_data_2.mean(axis=0), layer_data_2.std(axis=0) + 1e-12
            layer_data_2 = (layer_data_2 - m2) / s2

        c1 = estimate_local_curvatures(layer_data_1, k=k, verbose=verbose)
        c2 = estimate_local_curvatures(layer_data_2, k=k, verbose=verbose)
        plot_curvature_distribution(c1, c2, k=k, bins=30, title=f"Clean vs Poisoned Curvature Layer {layer_idx}")

        c1_lower, c1_upper = normal_95_ci(c1)
        c2_lower, c2_upper = normal_95_ci(c2)

        # Welch's t-test (no equal-variance assumption)
        p_val = ttest_ind(c1, c2, equal_var=False).pvalue
        pvals.append(p_val)

        d_val = cohen_d(c1, c2)

        if verbose:
            print(f"[compare_curvature_layerwise] Layer {layer_idx}: p_val_raw={p_val:.3e}, cohen_d={d_val:.3f}")

        results_layerwise.append({
            "layer": layer_idx,
            "mean_c1": float(np.mean(c1)),
            "ci_c1_lower": float(c1_lower),
            "ci_c1_upper": float(c1_upper),
            "mean_c2": float(np.mean(c2)),
            "ci_c2_lower": float(c2_lower),
            "ci_c2_upper": float(c2_upper),
            "p_val_raw": float(p_val),
            "cohen_d": float(d_val)
        })

    # 3) Multiple-comparisons correction over the finite p-values only.
    # multipletests does not handle NaN: with fdr_bh, one NaN (e.g. a layer
    # where both groups have constant curvature) propagates through the
    # cumulative minimum and can turn every corrected p-value into NaN.
    pvals_arr = np.asarray(pvals, dtype=float)
    pvals_corr = np.full(pvals_arr.shape, np.nan)
    rej = np.zeros(pvals_arr.shape, dtype=bool)
    finite = np.isfinite(pvals_arr)
    if finite.any():
        if verbose:
            print(f"[compare_curvature_layerwise] Applying multiple comparisons correction: {correction_method}")
        rej[finite], pvals_corr[finite], _, _ = multipletests(
            pvals_arr[finite], alpha=alpha, method=correction_method
        )

    # 4) Append results to the CSV
    _append_curvature_rows(csv_file, model, results_layerwise, pvals_corr, rej)

    print(f"[compare_curvature_layerwise] Analysis complete for model {model}, results appended to {csv_file}.")

# Run the clean-vs-poisoned curvature comparison for every configured model.
FILES = 2  # number of clean / poisoned activation files used per model
K = 30  # neighbourhood size for the local-PCA curvature estimate
CSV_OUTPUT = f"cp_dispersion_test_results_parametric_{K}.csv"

for model, model_info in MODELS.items():
    print(f'Processing model {model}')
    pt_files_clean = model_info["clean"][:FILES]
    pt_files_poisoned = model_info["poisoned"][:FILES]

    # One call per model: with identical arguments the analysis produces the
    # same results every time, so repeating it would only append duplicate
    # rows to CSV_OUTPUT.
    compare_curvature_layerwise(
        pt_files_1=pt_files_clean,
        pt_files_2=pt_files_poisoned,
        model=model,
        csv_file=CSV_OUTPUT,
        k=K,
        normalize=True,
        alpha=0.05,
        correction_method="fdr_bh",
        verbose=True,
    )
    
    
    