import numpy as np
import torch
import torch.nn as nn
import h5py
from umap import UMAP
from ripser import ripser
from persim import plot_diagrams
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Pick the compute device once at import time: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

class LowRankRNN(nn.Module):
    """Single-layer tanh RNN whose recurrent weight is factorized as ``U @ V.T``.

    The recurrent matrix ``W_r = U @ V.T`` has rank at most ``rank``.
    Parameters are created on the module-level ``device``.

    Args:
        input_dim: size of the input vector ``x``.
        hidden_dim: size of the hidden state ``h``.
        rank: number of columns of each factor ``U`` and ``V``.
    """

    def __init__(self, input_dim, hidden_dim, rank):
        super().__init__()
        # Low-rank factors of the recurrent weight, each (hidden_dim, rank).
        self.U = nn.Parameter(torch.randn(hidden_dim, rank, device=device))
        self.V = nn.Parameter(torch.randn(hidden_dim, rank, device=device))
        self.W_input = nn.Linear(input_dim, hidden_dim).to(device)
        self.activation = nn.Tanh()

    def forward(self, x, h):
        """Advance the hidden state by one step: ``tanh(W_r h + W_input x)``.

        Args:
            x: input of shape ``(input_dim,)`` or ``(batch, input_dim)``.
            h: hidden state of shape ``(hidden_dim,)`` or ``(batch, hidden_dim)``.

        Returns:
            The next hidden state, shaped like ``h`` (broadcast with ``x``).
        """
        # h @ W_r.T == h @ V @ U.T: identical to W_r @ h for a 1-D state, but it
        # also broadcasts over a leading batch dimension and never materializes
        # the dense (hidden_dim, hidden_dim) matrix.
        recurrent = (h @ self.V) @ self.U.T
        return self.activation(recurrent + self.W_input(x))

def generate_trajectory(T=1000, seed=42):
    """Generate per-step velocities of a reproducible 2-D Gaussian random walk.

    Steps are drawn as ``0.1 * N(0, 1)``, integrated into positions, then
    differenced back into velocities. Because the difference is prepended
    with the first position, the first velocity row is always zero.

    A private ``RandomState`` is used so the global NumPy RNG is left
    untouched. It yields exactly the same values as the previous
    ``np.random.seed(seed)`` + ``np.random.randn`` approach.

    Args:
        T: number of time steps.
        seed: seed for the private random generator.

    Returns:
        float32 array of shape ``(T, 2)``.
    """
    rng = np.random.RandomState(seed)
    steps = rng.randn(T, 2) * 0.1
    pos = np.cumsum(steps, axis=0)
    vel = np.diff(pos, axis=0, prepend=pos[:1])
    return vel.astype(np.float32)

# --- Run RNN and Collect Hidden States ---
def run_rnn(model, velocity_seq):
    """Run ``model`` over ``velocity_seq`` from a zero state and collect hidden states.

    Inference runs under ``torch.no_grad()`` so no autograd graph is built
    across the sequence. Inputs are converted once to the device and dtype of
    ``model.U``, and states are moved to host memory once at the end.

    Args:
        model: module exposing a parameter ``U`` whose first dimension is the
            hidden size, and called as ``model(x_t, h)`` (as ``LowRankRNN`` is).
        velocity_seq: array-like of shape ``(T, input_dim)``.

    Returns:
        NumPy array of shape ``(T, hidden_dim)``; ``(0, hidden_dim)`` if the
        sequence is empty.
    """
    ref = model.U
    hidden_dim = ref.shape[0]
    inputs = torch.as_tensor(velocity_seq, dtype=ref.dtype, device=ref.device)
    h = torch.zeros(hidden_dim, dtype=ref.dtype, device=ref.device)
    states = []
    with torch.no_grad():
        for x_t in inputs:
            h = model(x_t, h)
            states.append(h)
    if not states:
        return torch.zeros((0, hidden_dim), dtype=ref.dtype).numpy()
    return torch.stack(states).cpu().numpy()

# --- Curvature Proxy ---
def curvature_index(data, n_components=3):
    """Return a [0, 1] curvature proxy: the variance fraction *not* on the first PC.

    A point cloud lying along a straight line puts (nearly) all variance on
    the first principal component, giving a score near 0. The more variance
    spreads onto further components, the closer the score gets to 1.

    Args:
        data: array-like of shape ``(n_samples, n_features)``.
        n_components: number of PCA components to fit. Only the first
            explained-variance ratio is used. An integer value is clamped to
            ``min(n_samples, n_features)`` so that small or low-dimensional
            inputs do not make PCA raise. Non-integer values (e.g. a variance
            fraction) are passed to PCA unchanged.

    Returns:
        ``1 - explained_variance_ratio_[0]``, clipped to ``[0, 1]``.
    """
    shape = np.shape(data)
    if len(shape) == 2 and isinstance(n_components, (int, np.integer)):
        n_components = max(1, min(int(n_components), shape[0], shape[1]))
    pca = PCA(n_components=n_components)
    pca.fit(data)
    ratios = pca.explained_variance_ratio_
    curvature_score = 1.0 - ratios[0]
    return np.clip(curvature_score, 0, 1)

# --- Betti Number Estimation via Ripser ---
def compute_betti_numbers(points, min_persistence=0.0, maxdim=2, n_perm=None):
    """Estimate Betti numbers of a point cloud from its Rips persistence diagrams.

    For each homology dimension ``0..maxdim`` this counts the persistence bars
    whose lifetime ``death - birth`` exceeds ``min_persistence``. Bars that
    never die have infinite lifetime and always count.

    NOTE(review): with the default ``min_persistence=0.0`` every bar is
    counted. This reproduces the previous ``len(dgm)`` behavior, because
    ripser does not emit zero-length bars. It also counts short-lived noise
    bars: e.g. H0 has roughly one bar per point. A positive threshold gives a
    more meaningful topological summary.

    Args:
        points: point cloud passed to ``ripser`` (``(n_points, n_dims)``).
        min_persistence: bars with lifetime ``<=`` this value are ignored.
        maxdim: highest homology dimension computed by ripser.
        n_perm: optional greedy subsample size forwarded to ripser. Rips
            complexes up to H2 on many points are very expensive, and this
            bounds that cost.

    Returns:
        List of ``maxdim + 1`` integer counts, one per homology dimension.
    """
    diagrams = ripser(points, maxdim=maxdim, n_perm=n_perm)['dgms']
    bettis = []
    for dgm in diagrams[:maxdim + 1]:
        # Each row is (birth, death); death is inf for surviving classes.
        lifetimes = dgm[:, 1] - dgm[:, 0]
        bettis.append(int(np.count_nonzero(lifetimes > min_persistence)))
    return bettis

# --- CTIS ---
def geometry_aware_ctis(beta_base, beta_pert, gamma_g, weights=(1, 1, 1)):
    """Weighted L1 distance between two Betti-number vectors, scaled by ``gamma_g``.

    Dimensions are paired positionally with ``zip``, so the shortest of
    ``beta_base``, ``beta_pert`` and ``weights`` determines how many are used.

    Args:
        beta_base: reference Betti numbers.
        beta_pert: Betti numbers of the perturbed system.
        gamma_g: geometric scaling factor (e.g. the curvature proxy).
        weights: per-dimension weights.

    Returns:
        ``gamma_g * sum_k weights[k] * |beta_base[k] - beta_pert[k]|``.
    """
    total = 0
    for base_val, pert_val, weight in zip(beta_base, beta_pert, weights):
        total += weight * abs(base_val - pert_val)
    return gamma_g * total

# --- Alzheimer Simulation Curve ---
def lesion_curve(model, velocity_seq, levels=(0.0, 0.1, 0.2, 0.3, 0.4),
                 base_betti=None, umap_random_state=None):
    """Measure manifold degradation as entries of ``model.U`` are randomly zeroed.

    The steps for each lesion level are:
      1. Restore ``U`` to its intact values.
      2. Zero each entry independently with probability ``level``.
      3. Re-run the RNN on ``velocity_seq``.
      4. Embed the hidden states in 3-D with UMAP.
      5. Compare the embedding's Betti numbers with ``base_betti`` via
         ``geometry_aware_ctis``, scaled by the embedding's ``curvature_index``.

    ``model.U`` is restored to its original values before returning, even if
    an exception is raised, so the caller's model is left intact.

    Args:
        model: RNN exposing a parameter ``U`` (e.g. ``LowRankRNN``).
        velocity_seq: input sequence passed to ``run_rnn``.
        levels: iterable of lesion probabilities.
        base_betti: reference Betti numbers. If omitted, the module-level
            ``base_betti`` is used when defined (the previous behavior).
            Otherwise it is computed from the intact model.
        umap_random_state: forwarded to ``UMAP(random_state=...)``. ``None``
            keeps UMAP's default behavior.

    Returns:
        List of ``(level, ctis, betti)`` tuples in the order of ``levels``.
    """
    if base_betti is None:
        # Backward compatibility: previously this always read the global.
        base_betti = globals().get("base_betti")

    def embed(hidden):
        # Same 3-D UMAP configuration for the baseline and every lesion level.
        return UMAP(n_components=3, random_state=umap_random_state).fit_transform(hidden)

    original_U = model.U.detach().clone()
    if base_betti is None:
        base_betti = compute_betti_numbers(embed(run_rnn(model, velocity_seq)))

    results = []
    try:
        for level in levels:
            with torch.no_grad():
                # Every level starts from the intact weights, not the previous lesion.
                model.U.copy_(original_U)
                mask = torch.rand_like(model.U) < level
                model.U[mask] = 0.0

            hidden = run_rnn(model, velocity_seq)
            embedding = embed(hidden)
            beta = compute_betti_numbers(embedding)
            gamma_g = curvature_index(embedding)
            ctis = geometry_aware_ctis(base_betti, beta, gamma_g)
            results.append((level, ctis, beta))
    finally:
        # Undo the last lesion so the caller's model is not left damaged.
        with torch.no_grad():
            model.U.copy_(original_U)
    return results

# --- Main Execution ---
T = 1000
vel = generate_trajectory(T)
model = LowRankRNN(input_dim=2, hidden_dim=64, rank=4).to(device)
hidden_base = run_rnn(model, vel)
embedding_base = UMAP(n_components=3).fit_transform(hidden_base)
# lesion_curve reads this module-level name as its CTIS reference.
base_betti = compute_betti_numbers(embedding_base)

# Alzheimer Lesion Curve
results = lesion_curve(model, vel, levels=np.linspace(0.0, 0.5, 6))

# --- Plot Results ---
levels = [level for level, _, _ in results]
ctis_scores = [ctis for _, ctis, _ in results]

fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(levels, ctis_scores, marker='o')
ax.set_xlabel('Lesion Level (Fraction of Low-Rank Matrix Zeroed)')
ax.set_ylabel('Geometry-Aware CTIS')
ax.set_title('Manifold Degradation Curve')
ax.grid(True)
# Keep the long x-axis label from being clipped in the saved image.
fig.tight_layout()
fig.savefig("Manifold Degradation Curve syntetic.png")
# Release the figure; it is only written to disk, never shown.
plt.close(fig)


print("Base Betti Numbers:", base_betti)
for level, ctis, betti in results:
    print(f"Lesion {level:.2f}: CTIS = {ctis:.3f}, Betti = {betti}")

