import os
import gc
import psutil
import numpy as np
import torch
import torch.nn as nn
from ripser import ripser
from umap import UMAP
from sklearn.decomposition import PCA
from persim import plot_diagrams
import matplotlib.pyplot as plt
import warnings
# Hide FutureWarning messages for the rest of the process.
warnings.filterwarnings(action="ignore", category=FutureWarning)


# Module-wide compute device: CUDA when torch reports it available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


import matplotlib.pyplot as plt
import seaborn as sns

def detect_anomalies(betti_series, dim=1, thresh=2.5):
    """Return the time indices at which Betti-``dim`` jumps abruptly.

    A change between two consecutive windows is anomalous when its absolute
    size exceeds ``thresh`` times the standard deviation of all absolute
    window-to-window changes of that Betti number.

    Args:
        betti_series: list of ``(time_index, bettis)`` pairs in time order.
        dim: which Betti number (index into ``bettis``) to inspect.
        thresh: multiple of the change standard deviation that counts as a jump.

    Returns:
        List of ``time_index`` values of the windows at which a jump is
        observed (the later window of each anomalous pair). Empty when the
        series has fewer than two windows.
    """
    if len(betti_series) < 2:
        # No consecutive pairs: avoids np.std([]) (RuntimeWarning, nan).
        return []
    betti_vals = np.array([b[dim] for _, b in betti_series], dtype=float)
    diffs = np.abs(np.diff(betti_vals))
    cutoff = thresh * diffs.std()
    # diffs[i] is the change from window i to window i + 1, so the new value
    # (the anomaly) first appears at window i + 1.
    return [betti_series[i + 1][0] for i in np.flatnonzero(diffs > cutoff)]


def save_grad_heatmap(matrix, name, output_dir):
    """Save ``matrix`` as a coolwarm heatmap PNG ``<output_dir>/<name>_heatmap.png``.

    Row ticks read ``h0, h1, ...`` and column ticks ``r0, r1, ...``. In this
    script the matrix is ``grad_U`` (same shape as ``model.U``) or its average
    over sessions. Cell annotations are off; a colour bar is drawn.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        matrix,
        ax=ax,
        cmap="coolwarm",
        annot=False,
        cbar=True,
        xticklabels=[f"r{col}" for col in range(matrix.shape[1])],
        yticklabels=[f"h{row}" for row in range(matrix.shape[0])],
    )
    ax.set_title(name)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{name}_heatmap.png"))
    plt.close(fig)

# --- Plot Betti Evolution Time Series ---
def plot_betti_timeseries(betti_series, sid, output_dir):
    """Plot every Betti dimension over time and mark Betti-1 anomalies.

    Args:
        betti_series: list of ``(time_index, bettis)`` pairs, e.g. from
            ``track_dynamic_betti``. May be empty; a titled but empty figure
            is still written in that case.
        sid: session id used in the title and the output filename.
        output_dir: directory that receives ``session_<sid>_betti_plot.png``.
    """
    times = [t for t, _ in betti_series]
    # Transpose [(b0, b1, b2), ...] into one sequence per Betti dimension.
    betti_dims = list(zip(*[b for _, b in betti_series]))

    plt.figure(figsize=(10, 5))
    for i, dim in enumerate(betti_dims):
        plt.plot(times, dim, label=f"Betti-{i}", linewidth=2 if i == 1 else 1)

    # Mark the windows detect_anomalies flags for Betti-1: jumps larger than
    # its default 2.5 standard deviations of all window-to-window changes.
    # Guarded because betti_dims is empty for an empty series.
    if len(betti_dims) > 1:
        for t in detect_anomalies(betti_series, dim=1):
            plt.axvline(x=t, color='red', linestyle='--', alpha=0.6)

    plt.title(f"Session {sid} – Betti Evolution with Anomalies")
    plt.xlabel("Time Index")
    plt.ylabel("Topological Feature Count")
    if betti_dims:
        # Only add a legend when there are labelled lines to show.
        plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"session_{sid}_betti_plot.png"))
    plt.close()

# --- Downsampling Utility ---
def downsample(data, step=5):
    """Return ``data[::step]``: every ``step``-th element along the first axis."""
    stride = slice(None, None, step)
    return data[stride]

def add_jitter(data, scale=1e-4, seed=42):
    """Return ``data`` plus i.i.d. Gaussian noise with mean 0 and std ``scale``.

    The noise comes from a fresh ``np.random.default_rng(seed)``, so repeated
    calls with the same arguments add exactly the same perturbation.
    Presumably used to break exact duplicate points before UMAP / Ripser.
    """
    generator = np.random.default_rng(seed)
    noise = generator.normal(loc=0, scale=scale, size=data.shape)
    return data + noise



# --- Low-Rank RNN ---
class LowRankRNN(nn.Module):
    """Tanh RNN cell whose recurrent weight matrix is the low-rank product ``U @ V.T``.

    Args:
        input_dim: size of each input vector ``x``.
        hidden_dim: size of the hidden state ``h``.
        rank: inner dimension of the factors ``U`` and ``V`` (each of shape
            ``(hidden_dim, rank)``), so the recurrent matrix has rank <= ``rank``.

    All parameters are created directly on the module-level ``device``.
    """

    def __init__(self, input_dim, hidden_dim, rank):
        super().__init__()
        factor_shape = (hidden_dim, rank)
        # U, V, then W_input: this creation order fixes which random draws
        # each parameter receives, so keep it.
        self.U = nn.Parameter(torch.randn(*factor_shape, device=device))
        self.V = nn.Parameter(torch.randn(*factor_shape, device=device))
        # Affine input projection (nn.Linear includes a bias by default).
        self.W_input = nn.Linear(input_dim, hidden_dim).to(device)
        self.activation = nn.Tanh()

    def forward(self, x, h):
        """Return the next hidden state ``tanh((U @ V.T) @ h + W_input(x))``.

        Assumes ``h`` is a single hidden vector of length ``hidden_dim``: the
        recurrent term is ``W_r @ h``, so a batched ``(batch, hidden_dim)``
        ``h`` would not be handled as expected — TODO confirm before batching.
        """
        recurrent_weight = torch.matmul(self.U, self.V.transpose(0, 1))
        pre_activation = torch.matmul(recurrent_weight, h) + self.W_input(x)
        return self.activation(pre_activation)

# --- Curvature Proxy ---
def curvature_index(data, n_components=3):
    """Score how far a point cloud is from lying along a single direction.

    Returns ``1 - r0`` clipped to [0, 1], where ``r0`` is the fraction of
    variance explained by the first principal component: close to 0 for an
    almost linear cloud, larger when the variance is spread over more
    directions.

    Args:
        data: array of shape (n_samples, n_features).
        n_components: requested number of PCA components. It is capped at
            ``min(n_samples, n_features)`` so low-dimensional or tiny inputs do
            not make PCA raise; only the first component's ratio is used.
    """
    data = np.asarray(data)
    n_components = min(n_components, data.shape[0], data.shape[1])
    pca = PCA(n_components=n_components)
    pca.fit(data)
    first_ratio = pca.explained_variance_ratio_[0]
    return np.clip(1.0 - first_ratio, 0, 1)

# --- Betti Number Estimation with Betti-2 Support ---

def compute_betti_numbers(points, maxdim=2, thresh=0.8, step=2, show=False, return_lifetime=False):
    """Count persistent-homology features of a point cloud with Ripser.

    The cloud is thinned to every ``step``-th point and, if still longer than
    500 points, strided further so that at most 500 points reach Ripser,
    whose cost grows steeply with the number of points (especially for
    ``maxdim=2``).

    Args:
        points: array of shape (n_points, n_dims).
        maxdim: highest homology dimension computed.
        thresh: maximum Rips filtration radius passed to ``ripser``.
        step: initial thinning stride.
        show: currently unused; kept for API compatibility.
        return_lifetime: also return mean feature lifetimes per dimension.

    Returns:
        ``bettis``: the number of persistence pairs per dimension 0..maxdim,
        zero-padded to at least 3 entries. When ``return_lifetime`` is True,
        returns ``(bettis, lifetimes)`` instead. ``lifetimes[k]`` is the mean
        ``death - birth`` over the finite pairs of dimension k (0.0 if there
        are none), padded the same way. Pairs with infinite death (still alive
        at ``thresh``, or the essential H0 class) are counted in ``bettis`` but
        excluded from the mean, which would otherwise be ``inf``.

    If Ripser raises, the error is printed and every count is 0.
    """
    max_points = 500
    points = np.asarray(points)[::step]
    if len(points) > max_points:
        # Ceil division. Floor division gives a stride of 1 (no thinning at
        # all) for 501-999 points, so e.g. 900 points went straight to Ripser.
        stride = -(-len(points) // max_points)
        points = points[::stride]
    try:
        diagrams = ripser(points, maxdim=maxdim, thresh=thresh)['dgms']
    except Exception as e:
        print(f"[Ripser Error] {e}")
        # Empty (birth, death) arrays, matching the shape ripser returns.
        diagrams = [np.empty((0, 2)) for _ in range(maxdim + 1)]

    bettis = [len(dgm) for dgm in diagrams]
    bettis += [0] * (3 - len(bettis))

    if not return_lifetime:
        return bettis

    lifetimes = []
    for dgm in diagrams:
        spans = dgm[:, 1] - dgm[:, 0] if len(dgm) else np.empty(0)
        spans = spans[np.isfinite(spans)]
        lifetimes.append(float(spans.mean()) if spans.size else 0.0)
    # Pad like bettis so callers can always index [2].
    lifetimes += [0.0] * (3 - len(lifetimes))
    return bettis, lifetimes

# --- CTIS Metric ---
def geometry_aware_ctis(beta_base, beta_pert, gamma_g, weights=(1, 1, 1)):
    """Curvature-scaled, weighted L1 distance between two Betti vectors.

    Computes ``gamma_g * sum_k weights[k] * |beta_base[k] - beta_pert[k]|``.
    The sum runs over the shortest of the three sequences (``zip`` semantics),
    so with the default weights only the first three Betti numbers count.
    """
    total = 0
    for base_k, pert_k, weight_k in zip(beta_base, beta_pert, weights):
        total += weight_k * abs(base_k - pert_k)
    return gamma_g * total

# --- CRCNS Spike Loading ---
def load_crcns_spike_train(res_file, clu_file, bin_size=0.1, duration=1800, sampling_rate=20000):
    """Load a CRCNS ``.res``/``.clu`` pair and return z-scored binned spike counts.

    Args:
        res_file: text file of integer spike times, one per line, in samples.
        clu_file: matching cluster ids, one per spike. An optional leading
            header line is detected (and dropped) when the file has exactly one
            more entry than ``res_file``.
        bin_size: bin width in seconds.
        duration: length of the binned window in seconds; later spikes are ignored.
        sampling_rate: samples per second, used to convert spike times.

    Returns:
        Float array of shape ``(int(duration / bin_size), max cluster id)``.
        Column k counts cluster ``k + 1``; spikes with cluster id < 1 have no
        column and are ignored. Each column is z-scored, with 1e-5 added to
        the std so empty columns do not divide by zero.

    Raises:
        RuntimeError: if either file cannot be read.
        ValueError: if the spike and cluster counts disagree.
    """
    try:
        # ndmin=1 keeps single-line files 1-D so len() and indexing work.
        spikes = np.loadtxt(res_file, dtype=int, ndmin=1)
        clusters = np.loadtxt(clu_file, dtype=int, ndmin=1)
    except Exception as e:
        raise RuntimeError(f"Failed to load files:\n{res_file}\n{clu_file}\nError: {e}") from e

    # Handle .clu header line
    if len(clusters) == len(spikes) + 1:
        print("Detected header in .clu file, removing first line...")
        clusters = clusters[1:]

    # Check shape consistency
    if len(spikes) != len(clusters):
        raise ValueError(f"Mismatch: {len(spikes)} spikes vs {len(clusters)} clusters")

    num_bins = int(duration / bin_size)
    n_units = int(clusters.max()) if clusters.size else 0
    binned = np.zeros((num_bins, n_units))

    # Same float arithmetic and truncation as int(spike / rate / bin_size).
    bin_idx = (spikes / sampling_rate / bin_size).astype(np.int64)
    # Column index is cluster_id - 1. Id 0 would map to -1 and silently wrap
    # onto the last unit's column, so such spikes are dropped instead.
    keep = (bin_idx >= 0) & (bin_idx < num_bins) & (clusters >= 1)
    # np.add.at accumulates repeated (bin, unit) pairs, unlike fancy-index +=.
    np.add.at(binned, (bin_idx[keep], clusters[keep] - 1), 1)

    binned = (binned - binned.mean(axis=0)) / (binned.std(axis=0) + 1e-5)
    return binned


# --- RNN Run ---
def run_rnn(model, input_seq):
    """Roll ``model`` over ``input_seq`` from a zero hidden state, without autograd.

    Args:
        model: module exposing a ``U`` parameter whose first dimension is the
            hidden size, called as ``model(x_t, h_t) -> h_{t+1}``.
        input_seq: 2-D array-like with one input row per time step.

    Returns:
        float32 numpy array of shape ``(len(input_seq), hidden_dim)`` holding the
        hidden state after each step; ``(0, hidden_dim)`` for an empty input.
    """
    # Use the device the model's parameters actually live on, so a model that
    # was not moved to the module-level ``device`` still works.
    model_device = model.U.device
    hidden_dim = model.U.shape[0]
    # One host->device transfer for the whole sequence instead of one per step.
    inputs = torch.as_tensor(np.asarray(input_seq, dtype=np.float32), device=model_device)
    h = torch.zeros(hidden_dim, device=model_device)
    hidden_states = []
    with torch.no_grad():
        for v in inputs:
            h = model(v, h)
            hidden_states.append(h)
    if not hidden_states:
        return np.empty((0, hidden_dim), dtype=np.float32)
    # A single device->host copy at the end avoids a sync on every step.
    return torch.stack(hidden_states).cpu().numpy().astype(np.float32)

# --- Lesion Evaluation ---
def lesion_and_analyze(model, input_seq, base_betti, lesion_level=0.3, weights=(1, 1, 1)):
    """Temporarily zero a random fraction of ``model.U`` and measure the topological change.

    Each entry of ``U`` is zeroed independently with probability
    ``lesion_level``. The lesioned network is run on ``input_seq``; its hidden
    states are thinned by 10, jittered, embedded in 3-D with UMAP, and compared
    with ``base_betti`` via ``geometry_aware_ctis``. ``model.U`` is restored
    right after the rollout, even if the rollout raises.

    Returns:
        ``(embedding, betti, ctis)`` for the lesioned run.
    """
    original_U = model.U.clone().detach()
    with torch.no_grad():
        mask = torch.rand_like(model.U) < lesion_level
        model.U[mask] = 0.0
        try:
            hidden = run_rnn(model, input_seq)
        finally:
            # Only the rollout depends on U; restore unconditionally so a
            # failure cannot leave the caller's model lesioned.
            model.U.copy_(original_U)

    hidden = downsample(hidden, step=10)
    embedding = UMAP(n_components=3, n_neighbors=10, min_dist=0.3, low_memory=True).fit_transform(add_jitter(hidden))

    betti = compute_betti_numbers(embedding, maxdim=2, thresh=0.8, step=2)
    gamma = curvature_index(embedding)
    ctis = geometry_aware_ctis(base_betti, betti, gamma, weights)
    return embedding, betti, ctis


# --- Dynamic Topological Fingerprint ---
def track_dynamic_betti(hidden_seq, window_size=300, step_size=50, maxdim=2, thresh=0.3, pca_components=10):
    """Compute Betti numbers of a trajectory over sliding windows.

    Windows of ``window_size`` consecutive rows start every ``step_size`` rows,
    and every full window is used. Each window is processed in these steps:

    1. Reduce with PCA to at most ``pca_components`` dimensions.
    2. Jitter the result.
    3. Stride down if it has more than 500 rows.
    4. Embed in 3-D with UMAP.
    5. Pass the embedding to ``compute_betti_numbers`` with ``maxdim`` and ``thresh``.

    Args:
        hidden_seq: 2-D array, one row per time step.
        window_size: rows per window.
        step_size: stride between window starts.
        maxdim: highest homology dimension.
        thresh: Ripser filtration threshold.
        pca_components: upper bound on PCA components. The value actually
            used is also capped by the window's rows and columns.

    Returns:
        List of ``(start_row, bettis)``. Windows whose processing raises are
        reported on stdout and skipped. The list is empty when ``hidden_seq``
        has fewer than ``window_size`` rows.
    """
    hidden_seq = np.asarray(hidden_seq)
    betti_series = []
    # "+ 1" includes the final window ending exactly at len(hidden_seq); without
    # it that window (and the only window when len == window_size) is dropped.
    for start in range(0, len(hidden_seq) - window_size + 1, step_size):
        window = hidden_seq[start:start + window_size]
        try:
            # PCA accepts at most min(n_samples, n_features) components; a fixed
            # 10 made every window of a low-dimensional input (e.g. a 3-D
            # embedding) fail.
            n_comp = min(pca_components, window.shape[0], window.shape[1])
            pca_proj = PCA(n_components=n_comp).fit_transform(window)
            pca_proj = add_jitter(pca_proj, scale=1e-4)
            if len(pca_proj) > 500:
                # Ceil division so the result really has <= 500 rows.
                pca_proj = pca_proj[::-(-len(pca_proj) // 500)]
            embedding = UMAP(n_components=3, n_neighbors=30, min_dist=0.1).fit_transform(pca_proj)

            betti = compute_betti_numbers(embedding, maxdim=maxdim, thresh=thresh)
            betti_series.append((start, betti))
        except Exception as e:
            print(f"[UMAP error @window {start}] {e}")
            continue
    return betti_series



# --- Gradient Attribution to Betti₂ ---
def compute_grad_betti2(model, input_seq, base_betti):
    """Finite-difference sensitivity of the Betti-2 deviation to each entry of ``model.U``.

    For every ``(i, j)``, ``U[i, j]`` is increased by ``eps = 1e-2``. The
    network is then re-run on the full ``input_seq``, embedded with UMAP, and
    its Betti-2 is compared with ``base_betti[2]``::

        grad_U[i, j] = (|b2(perturbed) - base_betti[2]| - delta_b2) / eps

    Here ``delta_b2`` is the same deviation for the unperturbed model. Each
    entry of U costs one full rollout plus one UMAP fit and one Ripser run.
    UMAP is unseeded here, so the estimate also contains embedding noise.

    Returns:
        ``(grad_U, None, delta_b2)``. ``grad_U`` is a numpy array shaped like
        ``model.U``. The ``None`` slot stands in for a gradient with respect to
        ``V``, which is not computed.

    ``model.U`` is restored after every perturbation, also if a step raises.
    """
    model.zero_grad()

    # Unperturbed rollout. run_rnn evaluates under torch.no_grad(), so no
    # autograd graph over the whole sequence is built and kept alive for the
    # duration of the perturbation loop below.
    base_hidden = run_rnn(model, input_seq)
    base_embedding = UMAP(n_components=3).fit_transform(add_jitter(base_hidden))
    delta_b2 = abs(compute_betti_numbers(base_embedding)[2] - base_betti[2])

    eps = 1e-2
    grad_U = torch.zeros_like(model.U)

    with torch.no_grad():
        for i in range(model.U.shape[0]):
            for j in range(model.U.shape[1]):
                original = model.U[i, j].item()
                model.U[i, j] += eps
                try:
                    pert_hidden = run_rnn(model, input_seq)
                    pert_embedding = UMAP(n_components=3).fit_transform(add_jitter(pert_hidden))
                    pert_betti = compute_betti_numbers(pert_embedding)
                    grad_U[i, j] = (abs(pert_betti[2] - base_betti[2]) - delta_b2) / eps
                finally:
                    # Never leave the caller's model perturbed.
                    model.U[i, j] = original

    return grad_U.cpu().numpy(), None, delta_b2  # skip grad_V if slow


def simulate_recovery(model, input_seq, base_betti, lesion_steps=5):
    """Lesion ``model.U``, restore it over ``lesion_steps`` stages, and score each stage.

    Each entry of ``U`` is zeroed independently with probability 0.4. At stage
    ``i`` (0-based), every originally zeroed entry is restored with probability
    ``(i + 1) / lesion_steps``; restoration accumulates, and at the last stage
    all entries are restored. After each stage the network is run on
    ``input_seq`` and scored with ``geometry_aware_ctis`` against
    ``base_betti``.

    Returns:
        List of ``(stage, betti, ctis)`` tuples, one per stage.

    ``model.U`` is reset to its original values on return, also if a stage raises.
    """
    results = []
    original_U = model.U.clone().detach()
    zeroed_mask = torch.rand_like(model.U) < 0.4  # strong lesion

    # U is a leaf Parameter that requires grad: autograd rejects in-place
    # writes to it unless they happen under torch.no_grad().
    try:
        with torch.no_grad():
            # Apply full lesion
            model.U[zeroed_mask] = 0.0

        for i in range(lesion_steps):
            # Gradually recover weights
            recovery_mask = (torch.rand_like(model.U) < (i + 1) / lesion_steps) & zeroed_mask
            with torch.no_grad():
                model.U[recovery_mask] = original_U[recovery_mask]

            hidden = run_rnn(model, input_seq)
            embedding = UMAP(n_components=3).fit_transform(add_jitter(hidden))
            betti = compute_betti_numbers(embedding)
            gamma = curvature_index(embedding)
            ctis = geometry_aware_ctis(base_betti, betti, gamma)
            results.append((i, betti, ctis))
    finally:
        with torch.no_grad():
            model.U.copy_(original_U)
    return results

class CTISWrapper:
    """Callable that scores a batch of input sequences by CTIS against a baseline.

    Each call runs ``model`` over every sequence in the batch, embeds the
    jittered hidden trajectory in 3-D with UMAP, and compares that embedding's
    Betti numbers with ``base_betti``. The comparison is scaled by the
    embedding's ``curvature_index``.
    """

    def __init__(self, model, base_betti):
        # model: network passed to run_rnn; base_betti: reference Betti numbers.
        self.model = model
        self.base_betti = base_betti

    def __call__(self, vel_batch):
        """Return a numpy array with one ``[ctis]`` row per sequence in ``vel_batch``."""
        rows = []
        for sequence in vel_batch:
            emb = UMAP(n_components=3).fit_transform(add_jitter(run_rnn(self.model, sequence)))
            # Arguments evaluate left to right: Betti numbers, then curvature.
            score = geometry_aware_ctis(
                self.base_betti,
                compute_betti_numbers(emb),
                curvature_index(emb),
            )
            rows.append([score])
        return np.array(rows)

    
def plot_summary_heatmap(metric_list, metric_name, output_dir):
    """Draw one metric across sessions as a single-row annotated heatmap.

    Column k is labelled ``S{k+1}`` and the row is labelled ``metric_name``.
    Values are annotated with 3 decimals. The figure is saved as
    ``<metric_name lowercased>_summary_heatmap.png`` in ``output_dir``.
    """
    fig, ax = plt.subplots(figsize=(6, 1.5))
    values = np.array([metric_list])
    column_labels = [f"S{idx}" for idx in range(1, len(metric_list) + 1)]
    sns.heatmap(
        values,
        ax=ax,
        annot=True,
        fmt=".3f",
        cmap="viridis",
        xticklabels=column_labels,
        yticklabels=[metric_name],
    )
    ax.set_title(f"{metric_name} across Sessions")
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{metric_name.lower()}_summary_heatmap.png"))
    plt.close(fig)

def plot_persistence_summary(persistence_list, output_dir):
    """Draw per-session Betti-2 lifetimes as a single-row annotated heatmap.

    Column k is labelled ``S{k+1}``. The figure is saved as
    ``betti2_persistence_summary.png`` in ``output_dir``.
    """
    fig, ax = plt.subplots(figsize=(6, 1.5))
    values = np.array([persistence_list])
    column_labels = [f"S{idx}" for idx in range(1, len(persistence_list) + 1)]
    sns.heatmap(
        values,
        ax=ax,
        annot=True,
        fmt=".3f",
        cmap="magma",
        xticklabels=column_labels,
        yticklabels=["Betti2 Persistence"],
    )
    ax.set_title("Betti-2 Lifetimes Across Sessions")
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "betti2_persistence_summary.png"))
    plt.close(fig)


def plot_recovery_curve(results, output_dir):
    """Plot CTIS against recovery step and save ``manifold_recovery_curve.png``.

    ``results`` holds records shaped like ``simulate_recovery``'s
    ``(step, betti, ctis)`` tuples; only the first and third fields are used.
    """
    step_axis = [record[0] for record in results]
    score_axis = [record[2] for record in results]

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(step_axis, score_axis, marker='o')
    ax.set_xlabel("Recovery Step")
    ax.set_ylabel("CTIS Score")
    ax.set_title("Manifold Recovery Curve")
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "manifold_recovery_curve.png"))
    plt.close(fig)
from concurrent.futures import ProcessPoolExecutor, as_completed
import time


def _plot_betti_comparison(betti_series_pre, betti_series_post, sid, output_dir, dim=1):
    """Overlay Betti-``dim`` of a pre- and a post-lesion ``(time, bettis)`` series.

    Saves ``session_<sid>_betti_pre_vs_post.png`` in ``output_dir``. An empty
    series is simply not drawn, so the figure is still written if one side failed.
    """
    plt.figure(figsize=(10, 5))
    drawn = False
    for series, label, style in (
        (betti_series_pre, "pre-lesion", "-"),
        (betti_series_post, "post-lesion", "--"),
    ):
        if series:
            plt.plot([t for t, _ in series], [b[dim] for _, b in series], style,
                     label=f"Betti-{dim} ({label})")
            drawn = True
    plt.title(f"Session {sid} – Betti-{dim} before vs after lesion")
    plt.xlabel("Time Index")
    plt.ylabel("Topological Feature Count")
    if drawn:
        plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"session_{sid}_betti_pre_vs_post.png"))
    plt.close()


def process_single_session(sid, output_dir):
    """Run the full topology / lesion / recovery pipeline for one session.

    Reads ``/data/ec013.544.res.<sid>`` and ``/data/ec013.544.clu.<sid>`` and
    writes all per-session artefacts to ``<output_dir>/session_<sid>/``.

    Returns:
        ``(sid, delta_b2, ctis, betti2_lifetime, error)``. On success ``error``
        is None. If an input file is missing, or any step raises, the three
        metrics are None and ``error`` is a message (with traceback for
        exceptions). The function is submitted to worker processes, so errors
        are returned rather than raised.
    """
    import traceback

    try:
        start_time = time.time()
        res_file = f"/data/ec013.544.res.{sid}"
        clu_file = f"/data/ec013.544.clu.{sid}"

        if not os.path.exists(res_file) or not os.path.exists(clu_file):
            return sid, None, None, None, f"Skipping session {sid} (missing files)"

        session_out = os.path.join(output_dir, f"session_{sid}")
        os.makedirs(session_out, exist_ok=True)

        print(f"[S{sid}] Loading spike train...")
        binned_activity = load_crcns_spike_train(res_file, clu_file)

        # Uses the module-level ``device`` (re-created on import in each worker).
        model = LowRankRNN(input_dim=binned_activity.shape[1], hidden_dim=32, rank=4).to(device)
        hidden_base = run_rnn(model, binned_activity)
        hidden_base = downsample(hidden_base, step=10)

        umap_model = UMAP(n_components=3, n_neighbors=10, min_dist=0.3, low_memory=True)
        embedding_base = umap_model.fit_transform(hidden_base)

        subset = embedding_base[::max(1, len(embedding_base) // 500)] if len(embedding_base) > 500 else embedding_base

        base_betti, base_lifetime = compute_betti_numbers(subset, maxdim=2, thresh=0.8, step=2, return_lifetime=True)
        betti_evolution = track_dynamic_betti(hidden_base)

        with open(os.path.join(session_out, "betti_timeseries.txt"), "w") as f:
            for t, b in betti_evolution:
                f.write(f"{t}, {b}\n")

        plot_betti_timeseries(betti_evolution, sid, session_out)

        grad_U, _, delta_b2 = compute_grad_betti2(model, binned_activity, base_betti)
        save_grad_heatmap(grad_U, "grad_U", session_out)
        np.save(os.path.join(session_out, "grad_U.npy"), grad_U)

        embedding_post, betti_post, ctis_val = lesion_and_analyze(model, binned_activity, base_betti)
        # The pre-lesion series is exactly betti_evolution. Re-running the
        # windowed (unseeded) UMAP would be slow and would give a different
        # series than the one saved and plotted above.
        betti_series_post = track_dynamic_betti(embedding_post)
        # NOTE(review): the post series is computed on the 3-D lesioned UMAP
        # embedding while the pre series uses raw hidden states, so the two
        # are not strictly comparable.
        _plot_betti_comparison(betti_evolution, betti_series_post, sid, session_out)

        recovery_results = simulate_recovery(model, binned_activity, base_betti)
        plot_recovery_curve(recovery_results, session_out)

        with open(os.path.join(session_out, "result.txt"), "w") as f:
            f.write(
                f"Session {sid}\nBase Betti: {base_betti}\nPost-Lesion Betti: {betti_post}\nCTIS Score: {ctis_val:.5f}\n"
            )

        print(f"[S{sid}] done in {time.time() - start_time:.1f}s")
        return sid, delta_b2, ctis_val, base_lifetime[2], None

    except Exception as e:
        # Include the traceback: the message is printed by the parent process,
        # where the worker's stack is otherwise lost.
        return sid, None, None, None, f"[S{sid}] Error: {e}\n{traceback.format_exc()}"


if __name__ == "__main__":
    import multiprocessing as mp

    mp.set_start_method("spawn", force=True)  # required for CUDA compatibility
    session_ids = [1, 2, 3, 4, 5, 6, 7, 8]
    output_dir = "pfc7_result"
    os.makedirs(output_dir, exist_ok=True)

    # Metrics keyed by session id. as_completed() yields in completion order
    # and failed sessions are skipped, so collect everything first and lay it
    # out in session_ids order afterwards.
    results = {}
    with ProcessPoolExecutor(max_workers=4) as executor:
        futures = {executor.submit(process_single_session, sid, output_dir): sid for sid in session_ids}
        for fut in as_completed(futures):
            try:
                sid, delta_b2, ctis_val, persistence, err = fut.result()
            except Exception as e:
                # process_single_session returns its own errors; this catches
                # failures of the worker process itself.
                print(f"[S{futures[fut]}] Worker failed: {e}")
                continue
            if err:
                print(err)
                continue
            results[sid] = (delta_b2, ctis_val, persistence)

    # The summary heatmaps label column k as "S{k+1}", so emit exactly one
    # value per entry of session_ids (1..8 here). Failed sessions get NaN,
    # which seaborn masks as an empty cell.
    missing = (np.nan, np.nan, np.nan)
    delta_b2_all = [results.get(sid, missing)[0] for sid in session_ids]
    ctis_all = [results.get(sid, missing)[1] for sid in session_ids]
    persistence_b2_all = [results.get(sid, missing)[2] for sid in session_ids]

    grad_paths = [os.path.join(output_dir, f"session_{sid}/grad_U.npy") for sid in session_ids]
    grad_mats = [np.load(path) for path in grad_paths if os.path.exists(path)]
    if grad_mats:
        avg_grad_U = np.mean(grad_mats, axis=0)
        save_grad_heatmap(avg_grad_U, "avg_grad_U", output_dir)
    else:
        # np.mean([]) would give a 0-d nan that save_grad_heatmap cannot plot.
        print("No grad_U.npy files found; skipping averaged gradient heatmap.")

    if results:
        plot_summary_heatmap(delta_b2_all, "Delta_Betti2", output_dir)
        plot_summary_heatmap(ctis_all, "CTIS", output_dir)
        plot_persistence_summary(persistence_b2_all, output_dir)
    else:
        print("No session completed successfully; skipping summary plots.")


