"""
Plot the spindle PLV and TG-HMM figure.

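1) Compare phase-locking values (PLV) near spindles (within 200 ms of the
   alignment point) with PLV farther away; the differences are dense.
2) Using a TG-HMM fitted beforehand (this script loads its saved state
   sequences and parameters), identify a state associated with spindles.
3) Show that this state differs from the other states in a sparse and
   interpretable way.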
"""
__date__ = "September 2025"

import jax.numpy as jnp
import joblib
import numpy as np
import matplotlib.pyplot as plt

import sys
from pathlib import Path
# Repository root (two levels above this file). It is prepended to sys.path so
# that the ``src.plots`` import below resolves when the script is run directly.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.plots import stats_to_colors


# Recording channels, in the order used for the channel axis of the CWT data.
# The PLV matrix is laid out channel-major: row/col index = channel * F + freq.
channels = [
    "MD_Thal_01",
    "MD_Thal_02",
    "Cg_Cx_L_01",
    "Cg_Cx_R_01",
    "IL_Cx_L_01",
    "PrL_Cx_L_01",
    "PrL_Cx_R_01",
    "S1_Cx_01",
    "dHipp_01",
    "vHipp_01"
]
F = 6  # frequency bands per channel; must match the F axis of the CWT data
K = 7  # presumably the TG-HMM state count; selects the ``_{K}`` HMM output files

# (label, centre channel index) for each anatomical group of adjacent channels,
# e.g. "Thal" spans channels 0-1, so its centre is 0.5.
_channel_groups = [
    ("Thal", 0.5),
    ("Cg Cx", 2.5),
    ("IL Cx", 4),
    ("PrL Cx", 5.5),
    ("S1 Cx", 7),
    ("Hipp", 8.5),
]
# imshow centres pixel r on coordinate r, so a group covering channels a..b
# spans rows a*F .. (b+1)*F - 1, whose midpoint is F * (c + 0.5) - 0.5 for
# centre channel c = (a + b) / 2.
tick_vals = [F * (centre + 0.5) - 0.5 for _, centre in _channel_groups]
reduced_channels = [label for label, _ in _channel_groups]

def _load_aligned_lfps():
    """Return the raw spindle-aligned LFP windows on the reference channel.

    Reads ``data/spindle_data/spindles_refined_sorted.npz`` under ``ROOT`` and
    returns ``windows_raw[:, :, ref_channel]`` (one row per spindle).
    """
    path = ROOT / "data" / "spindle_data" / "spindles_refined_sorted.npz"
    # The context manager closes the underlying zip archive when done.
    with np.load(path) as archive:
        print(list(archive.keys()))
        return archive["windows_raw"][:, :, archive["ref_channel"]]  # [N,T]


def _plot_aligned_spindles(ax, lfps):
    """Panel 1: heat map of every aligned spindle (rows) over time (ms)."""
    n_spindles = lfps.shape[0]
    # Symmetric colour range clipped at the 99th percentile of |LFP| so a few
    # extreme values do not dominate the colour scale.
    vlim = np.quantile(np.abs(lfps), 0.99)
    ax.imshow(
        lfps,
        aspect=1,
        origin="lower",
        extent=[-750, 750, 1, n_spindles],
        cmap="PRGn",
        vmin=-vlim,
        vmax=vlim,
    )
    ax.set_title("Aligned Spindles")
    ax.set_ylabel("Spindle # (sorted)")
    ax.set_xlabel("Time (ms)")


def _load_cwt_phases():
    """Load ``spindle_cwt.npy`` (N, T, C, F) and decimate it by 2 in time.

    Returns ``(ts, phases)``. ``ts`` is the decimated time axis in seconds.
    ``phases`` is a float32 NumPy array of shape ``(N, T', C * F)``, flattened
    channel-major; its values are treated as phase angles in radians.
    """
    # jnp.load (unlike np.load) also restores bfloat16 arrays saved from JAX.
    cwt = jnp.load("spindle_cwt.npy")  # (N,T,C,F)
    print("Original data shape:", cwt.shape)
    ts = np.linspace(-0.75, 0.75, cwt.shape[1])[::2]
    cwt = cwt[:, ::2]
    n, t = cwt.shape[:2]
    phases = np.asarray(cwt.reshape(n, t, -1).astype(jnp.float32))
    print("Modified data shape:", phases.shape)
    return ts, phases


def _plv_matrix(phases):
    """Phase-locking value between every pair of columns of ``phases``.

    ``phases`` is ``(n_samples, n_features)`` in radians. The result is the
    ``(n_features, n_features)`` matrix ``|mean_k exp(i (phi_kb - phi_ka))|``.
    It is computed as ``|Z^H Z| / n`` with ``Z = exp(i * phases)``, which gives
    the same values without building an ``(n_samples, n_features, n_features)``
    tensor of pairwise differences.
    """
    z = np.exp(1j * phases)
    return np.abs(z.conj().T @ z) / z.shape[0]


def _plot_delta_plv(ax, ts, phases):
    """Panel 2: PLV within 200 ms of the alignment point minus PLV elsewhere."""
    near = np.abs(ts) < 0.2
    n_features = phases.shape[-1]
    phases_near = phases[:, near].reshape(-1, n_features)  # (N1, CF)
    print(phases_near.shape)
    phases_far = phases[:, ~near].reshape(-1, n_features)  # (N2, CF)
    print(phases_far.shape)
    diff = _plv_matrix(phases_near) - _plv_matrix(phases_far)
    vlim = np.max(np.abs(diff))
    ax.imshow(diff, vmin=-vlim, vmax=vlim, cmap="bwr")
    ax.set_title(r"$\Delta$ PLV")
    ax.set_yticks(tick_vals)
    ax.set_yticklabels(reduced_channels)


def _state_occupancy(zs):
    """Return the observed state labels and their occupancy over time.

    ``zs`` holds one state sequence per row. Row ``j`` of the returned
    occupancy is, at each time step, the fraction of sequences in
    ``states[j]``.
    """
    states = np.unique(zs)
    occupancy = np.array([np.mean(zs == s, axis=0) for s in states])  # (S, T)
    return states, occupancy


def _plot_state_occupancy(ax, ts, occupancy, spindle_idx):
    """Panel 3: mean occupancy of every state, spindle state highlighted."""
    gray_labelled = False
    for i, occ in enumerate(occupancy):
        style = {"color": "gray"}
        if i == spindle_idx:
            style = {"color": "mediumpurple", "label": "Spindle State"}
        elif not gray_labelled:
            # Label a single gray line so the legend has one "Other" entry.
            style["label"] = "Other States"
            gray_labelled = True
        ax.plot(1e3 * ts, occ, **style)
    ax.legend(loc="best")
    ax.set_title("TG-HMM State Occupancy")
    ax.set_xlabel("Time (ms)")
    ax.set_ylabel("Average State Occupancy")


def _plot_delta_phi(ax, states, occupancy, spindle_idx):
    """Panel 4: spindle-state phi minus the occupancy-weighted other states."""
    info = joblib.load(f"hmm_spindle_info_{K}.joblib")
    phis = info["phis"]
    avg_occ = occupancy.mean(axis=1)
    others = [i for i in range(len(states)) if i != spindle_idx]
    # Weights of the non-spindle states, normalised to sum to 1.
    weights = avg_occ[others] / avg_occ[others].sum()
    # Index ``phis`` by state label, not by position in ``states``: np.unique
    # omits states that never occur, which would otherwise shift every index.
    phi_spindle = phis[int(states[spindle_idx])]
    phi_other = sum(w * phis[int(states[i])] for w, i in zip(weights, others))
    phi_diff = phi_spindle - phi_other
    # Assumes the last axis of phi holds (real, imag) parts — TODO confirm.
    # The largest complex magnitude sets the colour scale.
    r_max = np.max(np.abs(phi_diff[..., 0] + 1j * phi_diff[..., 1]))
    print("r_max:", r_max)
    rgb = stats_to_colors(phi_diff, r_max=r_max, mode="expanded_complex")
    ax.imshow(rgb)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(r"$\Delta \ \phi$")


def main():
    """Build the four-panel spindle PLV / TG-HMM figure and save ``temp.pdf``."""
    fig, (ax_lfp, ax_plv, ax_occ, ax_phi) = plt.subplots(
        ncols=4, figsize=(10, 2.8)
    )

    lfps = _load_aligned_lfps()
    print(lfps.shape)
    _plot_aligned_spindles(ax_lfp, lfps)

    ts, phases = _load_cwt_phases()
    _plot_delta_plv(ax_plv, ts, phases)

    with np.load(f"hmm_spindle_all_z_seq_{K}.npz") as archive:
        zs = archive["all_seq"]
    states, occupancy = _state_occupancy(zs)
    print("unique_states", states)

    # The spindle state is taken to be the most-occupied state across these
    # spindle-aligned windows; the occupancy and delta-phi panels both use it.
    # NOTE(review): confirm this matches the state identified by inspection.
    spindle_idx = int(np.argmax(occupancy.mean(axis=1)))
    print("idx:", spindle_idx)

    _plot_state_occupancy(ax_occ, ts, occupancy, spindle_idx)
    _plot_delta_phi(ax_phi, states, occupancy, spindle_idx)

    fig.tight_layout()
    fig.savefig("temp.pdf")


if __name__ == '__main__':
    main()