import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict
import torch
import numpy as np
import os 
from typing import List, Tuple
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
    
# Registry of the models whose activation dumps are analysed by this script.
# Each entry carries the directory holding the '.pt' files, placeholders for
# the clean/poisoned file lists (filled in after the directory is scanned),
# a display name, and the layer indices of interest for that model.
_DEFAULT_LAYERS = (0, 7, 15, 23, 31)
_MODEL_TABLE = (
    # (registry key, activations directory, display name, layers to check)
    ('mistral_7B', '/disk1/activations/mistral_7B/test',
     'Mistral 7B', _DEFAULT_LAYERS),
    ('llama_8B', '/disk1/activations/llama3_8b/test',
     'LLaMA 8B', _DEFAULT_LAYERS),
    ('mixtral_8_7B', '/disk1/activations/mixtral_8x7B_fp16/test',
     'Mixtral 8×7B', _DEFAULT_LAYERS),
    ('phi3_3.8', '/disk1/activations/phi3/test',
     'Phi 3 (3.8B)', _DEFAULT_LAYERS),
    ('llama3_70B', '/disk3/activations/llama_3_70B_Instruct/test',
     'LLaMA-3 70B', (0, 7, 15, 23, 31, 39, 47, 55, 59)),
    ('phi3_medium_128k', '/disk3/activations/phi3_medium_128k/test',
     'Phi3 Medium (128k)', _DEFAULT_LAYERS),
)

MODELS = {
    key: {
        'disk_path': disk_path,
        'clean': [],       # fresh list per model
        'poisoned': [],    # fresh list per model
        'readable_name': readable_name,
        'layers_to_check': list(layers),
    }
    for key, disk_path, readable_name, layers in _MODEL_TABLE
}


def load_files(model, directory_path) -> Tuple[List[str], List[str]]:
    """
    Collect the clean and poisoned activation files found in a directory.

    A file is 'clean' if its name starts with 'clean' and ends with '.pt',
    and 'poisoned' if it starts with 'poisoned' and ends with '.pt'. Any
    other entry is reported on stdout and skipped.

    Parameters
    ----------
    model          : label used only in the printed messages
    directory_path : directory to scan (not recursive)

    Returns
    -------
    (clean, poisoned) : two lists of full paths, each ordered by filename.
                        Both are empty when the directory does not exist.
    """
    clean = []
    poisoned = []

    try:
        # os.listdir yields entries in arbitrary order; sorting makes any
        # later "first N files" selection reproducible across runs/machines.
        filenames = sorted(os.listdir(directory_path))
    except FileNotFoundError:
        # A missing directory is reported instead of aborting the whole
        # script; the caller then simply sees a model without files.
        print(f'{model}: Directory {directory_path} not found. Skipping.')
        return clean, poisoned

    for filename in filenames:
        full_path = os.path.join(directory_path, filename)
        is_pt = filename.endswith('.pt')
        if is_pt and filename.startswith('clean'):
            clean.append(full_path)
        elif is_pt and filename.startswith('poisoned'):
            poisoned.append(full_path)
        else:
            print(f'File type {full_path} not recognised. Skipping.')

    print(f'{model}: Processed {len(clean)} clean files and {len(poisoned)} poisoned files.')
    return clean, poisoned
            
# Load the activation files for each model
# NOTE: this runs at import time, so every directory listed in MODELS is
# scanned as soon as the module is loaded.
for model, data in MODELS.items():
    clean, poisoned = load_files(model, directory_path=data['disk_path'])
    data.update(clean=clean, poisoned=poisoned)
 

def get_layer_diff_data(pt_files: List[str]) -> np.ndarray:
    """
    Load multiple .pt files and return their stacked activation differences.

    Each file is expected to hold a tensor of shape (2, N, L, D). For every
    file the difference (index 1 minus index 0 along the first axis, i.e.
    after - before) is taken, and the per-file results are concatenated
    along the sample axis.

    Parameters
    ----------
    pt_files : paths of the .pt files to load

    Returns
    -------
    float32 array of shape (totalN, L, D)

    Raises
    ------
    RuntimeError : from torch.cat when pt_files is empty or the per-file
                   shapes are incompatible.
    """
    diffs = []
    for path in pt_files:
        # The result ends up on the CPU anyway, so load straight to CPU:
        # tensors saved from a GPU then also load on machines without one.
        pair = torch.load(path, map_location='cpu')   # (2, N, L, D)
        # Subtract in the stored dtype, then widen to float32. Diffing per
        # file avoids keeping a second, concatenated copy of all raw tensors.
        diffs.append((pair[1] - pair[0]).float())     # (N, L, D)
        del pair

    return torch.cat(diffs, dim=0).numpy()            # (sumN, L, D)

def estimate_local_curvatures(activations: np.ndarray, k: int) -> np.ndarray:
    """
    PCA-based local curvature proxy for every row of 'activations' (N, D).

    For each point, a PCA is fitted on the rows returned by a k-nearest-
    neighbour query against the same data, and the curvature is the ratio
        (sum of all but the largest explained variance) / (largest + 1e-12).

    Returns an array of shape (N,).
    """
    knn = NearestNeighbors(n_neighbors=k).fit(activations)
    # The queried points are the fitted points themselves, so each
    # neighbourhood presumably contains the point it is centred on.
    _, neighbor_idx = knn.kneighbors(activations)  # (N, k)

    curvatures = np.zeros(activations.shape[0])

    for row, idx in enumerate(neighbor_idx):
        local_pca = PCA(n_components=None)
        local_pca.fit(activations[idx])
        variances = local_pca.explained_variance_

        if len(variances) == 0:
            # Nothing to compare; the entry keeps its initial value of 0.
            continue

        # Largest variance first; everything after it is the "rest".
        ordered = np.sort(variances)[::-1]
        curvatures[row] = np.sum(ordered[1:]) / (ordered[0] + 1e-12)

    return curvatures

def standard_scale(data: np.ndarray) -> np.ndarray:
    """
    Z-score each column of 'data': subtract the column mean and divide by
    the column standard deviation (plus 1e-12 so constant columns do not
    divide by zero).
    """
    centered = data - data.mean(axis=0)
    spread = data.std(axis=0) + 1e-12
    return centered / spread

def _plot_k_selection(k_values, avg_abs_diffs, best_k, best_val) -> None:
    """Draw the single k-vs-average-absolute-difference figure, marking best_k."""
    plt.figure(figsize=(6, 4))
    plt.title("k Selection Across All Models & Layers")
    plt.plot(k_values, avg_abs_diffs, marker='o', linewidth=1.5)
    plt.axhline(y=0.0, color='gray', linestyle='--', alpha=0.5)

    plt.scatter(best_k, best_val, color='red', zorder=5,
                label=f"Best k={best_k} (avg abs diff={best_val:.4f})")
    plt.xlabel("k (neighbors)")
    plt.ylabel("Average Absolute Curvature Difference")
    plt.legend()
    plt.tight_layout()
    plt.show()


def pick_best_k_across_all_models_layers(
    models_dict: Dict[str, Dict],
    candidate_k=(10, 20, 30, 40),
    normalize: bool = True,
    max_pilot_files: int = 5
) -> int:
    """
    Choose the neighbourhood size k that best separates clean from poisoned
    activations, aggregated over every model and every configured layer.

    1) Iterates over all models in models_dict.
    2) For each model:
       - Takes up to 'max_pilot_files' files from 'clean' and from 'poisoned'.
       - Loads their activation differences.
       - For each layer in 'layers_to_check', computes the difference between
         the mean local curvature of the two groups for each k in candidate_k.
    3) Aggregates the *absolute mean difference* across all (model, layer) combos.
    4) Plots a SINGLE graph with candidate_k on X-axis and average absolute
       difference on Y-axis.
    5) Returns the best_k (largest average absolute diff).

    Parameters
    ----------
    models_dict     : a dictionary like MODELS
    candidate_k     : non-empty sequence of integers
    normalize       : whether to standard-scale each group's data per layer
    max_pilot_files : how many clean / poisoned files to sample per model

    Returns
    -------
    best_k : int
        If no (model, layer, k) combination could be evaluated, every average
        is 0.0 and the first candidate is returned (a warning is printed).

    Raises
    ------
    ValueError : if candidate_k is empty.
    """
    candidate_k = list(candidate_k)
    if not candidate_k:
        # Fail before any expensive loading instead of at the final argmax.
        raise ValueError("candidate_k must contain at least one value")

    # For each k, collect the absolute differences across all (model, layer).
    k2absdiffs = {k: [] for k in candidate_k}

    for model_name, info in models_dict.items():
        print(f'Processing {model_name}')
        clean_files = info.get('clean')
        poisoned_files = info.get('poisoned')
        layers_to_check = info.get('layers_to_check') or []

        if not clean_files or not poisoned_files:
            print(f"Skipping model={model_name} because no clean/poisoned files.")
            continue

        data_clean = get_layer_diff_data(clean_files[:max_pilot_files])        # (Nc, L, D)
        data_poison = get_layer_diff_data(poisoned_files[:max_pilot_files])    # (Np, L, D)
        _, Lc, Dc = data_clean.shape
        _, Lp, Dp = data_poison.shape

        if (Lc != Lp) or (Dc != Dp):
            print(f"Shape mismatch for model={model_name}: clean {data_clean.shape}, poison {data_poison.shape}")
            continue

        # Evaluate every layer configured for this model.
        for layer_idx in layers_to_check:
            print(f'Processing layer {layer_idx}')
            if layer_idx >= Lc:  # out of range
                print(f'Layer {layer_idx} out of range for model={model_name} ({Lc} layers). Skipping.')
                continue

            layer_c = data_clean[:, layer_idx, :]   # (Nc, D)
            layer_p = data_poison[:, layer_idx, :]  # (Np, D)

            if normalize:
                # Each group is scaled with its own statistics.
                layer_c = standard_scale(layer_c)
                layer_p = standard_scale(layer_p)

            for k in candidate_k:
                # Skip k if not feasible for either group.
                if k >= layer_c.shape[0] or k >= layer_p.shape[0]:
                    continue

                cvals_c = estimate_local_curvatures(layer_c, k=k)
                cvals_p = estimate_local_curvatures(layer_p, k=k)
                k2absdiffs[k].append(abs(cvals_c.mean() - cvals_p.mean()))

    if not any(k2absdiffs.values()):
        print("Warning: no curvature differences could be computed; "
              "the reported best k is just the first candidate.")

    # Average per k; a k without any valid measurement scores 0.0.
    k_values = candidate_k
    avg_abs_diffs = [
        float(np.mean(k2absdiffs[k])) if k2absdiffs[k] else 0.0
        for k in k_values
    ]

    # Pick the k that yields the largest average absolute difference.
    best_idx = int(np.argmax(avg_abs_diffs))
    best_k = k_values[best_idx]
    best_val = avg_abs_diffs[best_idx]

    _plot_k_selection(k_values, avg_abs_diffs, best_k, best_val)

    print(f"** Final chosen k = {best_k}, with average abs difference = {best_val:.4f} **")
    return best_k


if __name__ == "__main__":
    # Script entry point: select k over the registered models, then persist
    # the choice to 'chosen_k.txt' in the current working directory.
    chosen_k = pick_best_k_across_all_models_layers(
        models_dict=MODELS,
        candidate_k=[10, 30, 35, 40, 45],
        normalize=True,
    )
    print("Final selected k:", chosen_k)

    out_file = os.path.join(os.getcwd(), "chosen_k.txt")
    with open(out_file, "w") as f:
        f.write(f"{chosen_k}\n")

    print(f"Wrote chosen k={chosen_k} to {out_file}")
