import os
import csv
import random
import numpy as np
import torch
from typing import List, Tuple
from scipy.stats import ttest_ind
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import NearestNeighbors
from config import MODELS

###############################################################################
# LOAD .pt FILES
###############################################################################
def load_files(model, directory_path) -> Tuple[List[str], List[str]]:
    """Collect the clean and poisoned ``.pt`` file paths in ``directory_path``.

    A file is "clean" if its name starts with ``clean`` and "poisoned" if it
    starts with ``poisoned``; in both cases it must end with ``.pt``. All other
    entries are ignored. ``model`` is only used in the summary message.

    Returns:
        (clean, poison): full paths, each list sorted by file name. Sorting
        makes the result, and any later ``[:k]`` slicing by callers,
        independent of the arbitrary order returned by ``os.listdir``.

    Raises:
        OSError (e.g. FileNotFoundError) from ``os.listdir`` if the directory
        cannot be read.
    """
    clean = []
    poison = []
    for filename in sorted(os.listdir(directory_path)):
        if not filename.endswith('.pt'):
            continue
        full_path = os.path.join(directory_path, filename)
        if filename.startswith('clean'):
            clean.append(full_path)
        elif filename.startswith('poisoned'):
            poison.append(full_path)
    print(f"{model}: Found {len(clean)} clean, {len(poison)} poisoned files.")
    return clean, poison

###############################################################################
# DISTANCE MEASURES
###############################################################################
def mean_neighborhood_distance(X: np.ndarray, n_neighbors=10):
    """Return the mean distance from each row of X to its nearest neighbours.

    ``kneighbors(X)`` queries the fitted points themselves, so column 0 of the
    returned distance matrix is the zero self-distance and is dropped. Hence
    ``n_neighbors`` counts the point itself: the average is over the
    ``n_neighbors - 1`` nearest other rows.

    Args:
        X: 2-D array, one sample per row.
        n_neighbors: neighbourhood size handed to NearestNeighbors (including
            self). Clamped to the number of rows so small inputs do not raise.

    Returns:
        The mean neighbour distance, or ``np.nan`` when X has fewer than two
        rows (there is no other point to measure against).
    """
    n_samples = X.shape[0]
    if n_samples < 2:
        return np.nan
    # kneighbors raises ValueError if asked for more neighbours than fitted samples.
    k = min(n_neighbors, n_samples)
    nn = NearestNeighbors(n_neighbors=k).fit(X)
    dist_mat, _ = nn.kneighbors(X)
    # skip the 0 self-dist in column 0
    return dist_mat[:, 1:].mean()

def compute_distance_scalar(X: np.ndarray, metric: str):
    """Summarize how spread out the rows of X are as a single distance value.

    ``metric`` selects the statistic:
      * ``"mean_neigh"`` or ``"mean_neigh_<k>"``: mean nearest-neighbour
        distance via ``mean_neighborhood_distance`` with ``n_neighbors=k``.
        k defaults to 10 when the last ``_``-separated part is not all digits.
      * anything else (e.g. ``"cosine"``, ``"euclidean"``): mean pairwise
        distance between distinct rows, computed with sklearn's
        ``pairwise_distances`` using ``metric`` as the metric name.

    Returns:
        The scalar statistic, or ``np.nan`` when X has fewer than two rows.
    """
    n = X.shape[0]
    if n < 2:
        return np.nan
    if metric.startswith('mean_neigh'):
        # e.g. "mean_neigh_10"
        suffix = metric.split('_')[-1]
        n_neighbors = int(suffix) if suffix.isdigit() else 10
        return mean_neighborhood_distance(X, n_neighbors=n_neighbors)
    mat = pairwise_distances(X, metric=metric)
    # Exclude the n self-distances on the diagonal, so this is the mean over
    # the n*(n-1) ordered pairs of distinct rows. Averaging the full matrix
    # would shrink every value by a factor of (n-1)/n.
    return (mat.sum() - np.trace(mat)) / (n * (n - 1))

###############################################################################
# REPEATED SUBSAMPLING ("bootstrap", drawn without replacement) for AVERAGE DISTANCE
###############################################################################
def bootstrap_distance(
    dataA: np.ndarray,
    dataB: np.ndarray,
    subsample_size: int,
    num_subsamples: int,
    metric: str
) -> Tuple[List[float], List[float]]:
    """
    Build paired distributions of the 'average distance' statistic.

    Each of the ``num_subsamples`` rounds draws ``subsample_size`` distinct
    rows from dataA and, independently, ``subsample_size`` distinct rows from
    dataB, using the global NumPy RNG. It then scores each draw with
    compute_distance_scalar(draw, metric). Both inputs must be 2-D.

    Returns:
        (distA_vals, distB_vals): one score per round for dataA and dataB.
    """
    rows_a, _ = dataA.shape
    rows_b, _ = dataB.shape
    scores_a: List[float] = []
    scores_b: List[float] = []

    for _round in range(num_subsamples):
        # Both index draws happen before any scoring, matching the RNG call order.
        pick_a = np.random.choice(rows_a, subsample_size, replace=False)
        pick_b = np.random.choice(rows_b, subsample_size, replace=False)
        score_a = compute_distance_scalar(dataA[pick_a], metric)
        score_b = compute_distance_scalar(dataB[pick_b], metric)
        scores_a.append(score_a)
        scores_b.append(score_b)

    return scores_a, scores_b

###############################################################################
# MAIN COMPARISONS
###############################################################################
def do_comparisons_for_layer(
    model_name: str,
    layer_idx: int,
    dataC: np.ndarray,
    dataP: np.ndarray,
    dist_metrics: List[str],
    SUBSAMPLE_SIZE: int,
    NUM_SUBSAMPLES: int,
    out_csv: str
):
    """
    Run four subsampled distance comparisons for one layer and append rows to out_csv.

    Rows of dataC (clean) and dataP (poisoned) are samples. For every metric in
    dist_metrics this logs:
      (1) clean_vs_poison     - dataC vs dataP (needs SUBSAMPLE_SIZE rows each)
      (2) within_class_clean  - two random halves of dataC (needs 2*SUBSAMPLE_SIZE)
      (3) within_class_poison - two random halves of dataP (needs 2*SUBSAMPLE_SIZE)
      (4) mixed_sample_equal  - two independent draws, each made of
                                SUBSAMPLE_SIZE//2 clean rows plus the
                                remaining rows from dataP
    A comparison without enough rows is skipped with a console message.
    All randomness comes from the global NumPy RNG.
    """
    n_clean = dataC.shape[0]
    n_poison = dataP.shape[0]
    enough_each = n_clean >= SUBSAMPLE_SIZE and n_poison >= SUBSAMPLE_SIZE

    def _log_pair(comp_name, first, second):
        # Subsampling, the t-test and CSV output live in the shared helper.
        do_bootstrap_and_log(model_name, layer_idx, comp_name,
                             first, second,
                             dist_metrics,
                             SUBSAMPLE_SIZE,
                             NUM_SUBSAMPLES,
                             out_csv)

    def _random_halves(data):
        # Shuffle row indices and cut; the second half gets the extra row when odd.
        order = np.random.permutation(data.shape[0])
        cut = data.shape[0] // 2
        return data[order[:cut]], data[order[cut:]]

    # (1) Clean vs Poison
    comp_name = "clean_vs_poison"
    if enough_each:
        _log_pair(comp_name, dataC, dataP)
    else:
        print(f" => Not enough data for {comp_name} (need {SUBSAMPLE_SIZE} each).")

    # (2) + (3) Within-class: each class split into two disjoint halves.
    for comp_name, data in (("within_class_clean", dataC),
                            ("within_class_poison", dataP)):
        if data.shape[0] < 2 * SUBSAMPLE_SIZE:
            print(f" => Not enough data for {comp_name} (need 2*{SUBSAMPLE_SIZE}).")
            continue
        first_half, second_half = _random_halves(data)
        _log_pair(comp_name, first_half, second_half)

    # (4) Mixed: both sides come from the same clean+poison blend.
    comp_name = "mixed_sample_equal"
    if not enough_each:
        print(f" => Not enough data for {comp_name} (need {SUBSAMPLE_SIZE} each).")
        return

    take_clean = SUBSAMPLE_SIZE // 2
    take_poison = SUBSAMPLE_SIZE - take_clean

    def _mixed_draw():
        pick_c = np.random.choice(n_clean, take_clean, replace=False)
        pick_p = np.random.choice(n_poison, take_poison, replace=False)
        return np.concatenate([dataC[pick_c], dataP[pick_p]], axis=0)

    scores_a = {m: [] for m in dist_metrics}
    scores_b = {m: [] for m in dist_metrics}
    for _ in range(NUM_SUBSAMPLES):
        draw_a = _mixed_draw()
        draw_b = _mixed_draw()
        for m in dist_metrics:
            scores_a[m].append(compute_distance_scalar(draw_a, m))
            scores_b[m].append(compute_distance_scalar(draw_b, m))

    # Welch t-test per metric, then one CSV row each (CI columns are placeholders).
    for m in dist_metrics:
        vals_a = np.array(scores_a[m])
        vals_b = np.array(scores_b[m])
        if min(len(vals_a), len(vals_b)) < 2:
            p_val = np.nan
        else:
            p_val = ttest_ind(vals_a, vals_b, equal_var=False).pvalue
        reject = "Yes" if p_val < 0.05 else "No"
        with open(out_csv, "a", newline="") as fh:
            csv.writer(fh).writerow([
                model_name, layer_idx, comp_name, m,
                f"{p_val:.2e}",
                reject,
                f"{vals_a.mean():.4f}", "na", "na",
                f"{vals_b.mean():.4f}", "na", "na"
            ])


def do_bootstrap_and_log(
    model_name: str,
    layer_idx: int,
    comparison_name: str,
    dataA: np.ndarray,
    dataB: np.ndarray,
    dist_metrics: List[str],
    SUBSAMPLE_SIZE: int,
    NUM_SUBSAMPLES: int,
    out_csv: str
):
    """
    Run bootstrap_distance once per metric and append one CSV row per metric.

    Row layout:
      model, layer, comparison, metric,
      Welch t-test p-value ('.2e'), 'Yes'/'No' for p < 0.05,
      mean of dataA's scores, 'na', 'na',
      mean of dataB's scores, 'na', 'na'.
    The 'na' columns are placeholders for confidence-interval bounds. The
    p-value is NaN (verdict 'No') when either side has fewer than two scores.
    """
    for metric in dist_metrics:
        raw_a, raw_b = bootstrap_distance(
            dataA, dataB, SUBSAMPLE_SIZE, NUM_SUBSAMPLES, metric
        )
        vals_a = np.array(raw_a)
        vals_b = np.array(raw_b)

        enough = len(vals_a) >= 2 and len(vals_b) >= 2
        if enough:
            p_val = ttest_ind(vals_a, vals_b, equal_var=False).pvalue
        else:
            p_val = np.nan
        verdict = "Yes" if p_val < 0.05 else "No"

        row = [
            model_name, layer_idx, comparison_name, metric,
            f"{p_val:.2e}",
            verdict,
            f"{vals_a.mean():.4f}", "na", "na",
            f"{vals_b.mean():.4f}", "na", "na",
        ]
        with open(out_csv, "a", newline="") as fh:
            csv.writer(fh).writerow(row)

###############################################################################
def main():
    """Run every distance comparison for every model configured in MODELS.

    Steps:
      1. Seed NumPy's and Python's global RNGs with 0 for repeatable subsampling.
      2. Start a fresh ``distance_comparison_noTDA.csv`` holding only the header.
      3. For each model, list its clean/poisoned ``.pt`` files under
         ``info["disk_path"]`` via load_files.
      4. For each model, load at most MAX_FILES_PER_CLASS files per class
         (chosen in sorted order), concatenate them along dim 1, subtract
         index 0 from index 1 along dim 0, and run do_comparisons_for_layer
         for every in-range layer in ``info["layers_to_check"]``.
    """
    np.random.seed(0)
    random.seed(0)
    print("[main] => Starting distance-based comparisons (no TDA).")

    out_csv = "distance_comparison_noTDA.csv"
    if os.path.isfile(out_csv):
        os.remove(out_csv)
    with open(out_csv, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow([
            "model", "layer", "comparison", "distance_metric",
            "p_value_raw", "reject_H0",
            "distA_mean", "distA_ci_lower", "distA_ci_upper",
            "distB_mean", "distB_ci_lower", "distB_ci_upper"
        ])

    dist_metrics = ["cosine", "euclidean", "mean_neigh_10"]
    # Rows drawn per subsample, and number of subsample rounds per comparison.
    BOOTSTRAP_SIZE = 5000
    NUM_SUBSAMPLES = 3
    # Cap on how many clean / poisoned files are loaded per model.
    MAX_FILES_PER_CLASS = 5

    # Phase 1: discover files for every model. Kept in a local dict instead of
    # being written back into the shared config.MODELS entries.
    file_lists = {}
    for model, info in MODELS.items():
        print(f"\n[main] => Loading data for {model}")
        file_lists[model] = load_files(model, info["disk_path"])

    # Phase 2: load tensors, take the difference, run comparisons per layer.
    for model, info in MODELS.items():
        clean_files, poison_files = file_lists[model]
        # os.listdir order is arbitrary; sort so the same files are picked on every run.
        c_files = sorted(clean_files)[:MAX_FILES_PER_CLASS]
        p_files = sorted(poison_files)[:MAX_FILES_PER_CLASS]
        if not c_files or not p_files:
            print(f" => Skipping {model}, no data.")
            continue

        # torch.load unpickles its input, so only point disk_path at trusted files.
        # map_location="cpu" lets tensors saved from a GPU load on CPU-only hosts.
        c_data_torch = torch.cat([torch.load(x, map_location="cpu") for x in c_files], dim=1)
        p_data_torch = torch.cat([torch.load(x, map_location="cpu") for x in p_files], dim=1)
        # NOTE(review): the indexing below implies each loaded tensor is 4-D,
        # presumably (2, N_i, L, D) with dim 0 holding the two states being
        # differenced; the result is indexed as (N, L, D) — confirm with the producer.
        diffC = (c_data_torch[1] - c_data_torch[0]).cpu().float().numpy()
        diffP = (p_data_torch[1] - p_data_torch[0]).cpu().float().numpy()
        # Release the stacked tensors before the memory-heavy distance computations.
        del c_data_torch, p_data_torch

        for ly in info["layers_to_check"]:
            if ly >= diffC.shape[1] or ly >= diffP.shape[1]:
                print(f" => layer {ly} out of range => skip.")
                continue

            # shape (Nc, D) / (Np, D)
            clean_data = diffC[:, ly, :]
            poison_data = diffP[:, ly, :]

            # run all comparisons
            do_comparisons_for_layer(
                model_name=model,
                layer_idx=ly,
                dataC=clean_data,
                dataP=poison_data,
                dist_metrics=dist_metrics,
                SUBSAMPLE_SIZE=BOOTSTRAP_SIZE,
                NUM_SUBSAMPLES=NUM_SUBSAMPLES,
                out_csv=out_csv
            )

    print(f"\n[main] => Done. Results => {out_csv}")
    print("[main] => Check the CSV for p-values, means, etc. No TDA used.\n")

if __name__ == "__main__":
    # Execute the analysis only when run as a script, not on import.
    main()
