import re, sys, numpy as np, pandas as pd
import matplotlib.pyplot as plt, seaborn as sns
from pathlib import Path
import warnings
# Treat numerical trouble as fatal: NumPy floating-point errors (overflow,
# divide-by-zero, invalid operations) raise FloatingPointError, and every
# RuntimeWarning is promoted to an exception.
np.seterr(over="raise", divide="raise", invalid="raise")
warnings.filterwarnings("error", category=RuntimeWarning)

# ===== ZENODO PARSING =====
# Patterns for the Zenodo log. Without re.S, "." does not cross newlines, so each
# step/epoch pattern needs its marker and its "Val Avg Corr" value on one line.
re_session_hdr = re.compile(r"--- Session (\d+)/\d+ \(([^)]+)\) ---")
re_online_step = re.compile(r"Step\s+(\d+).*?Val\s+Avg\s+Corr:\s*([-+]?\d*\.?\d+)")
re_bptt_epoch = re.compile(r"Epoch\s*\[(\d+)/(\d+)\].*?Val\s+Avg\s+Corr:\s*([-+]?\d*\.?\d+)")
re_train_size = re.compile(r"Datasets created: Train=(\d+)")

# The whole Zenodo training log, read once up front.
log_path = Path("fullzenodo.txt")
txt = log_path.read_text()

def parse_session(session_text):
    """Extract the validation-correlation curve from one Zenodo session slice.

    A slice containing "--- Running SNN Online Training ---" is treated as an
    online session. Otherwise, a "Datasets created: Train=N" line marks it as
    a BPTT session.

    Args:
        session_text: Log text from one "--- Session i/n (...) ---" header up
            to (not including) the next header.

    Returns:
        * ``{"type": "online", "data": df}``. ``df`` has the columns
          ``samples`` (logged step number), ``corr`` and ``progress``. It also
          has ``epoch_max`` when "Epoch [i/n]" markers occur in the online
          block.
        * ``{"type": "bptt", "data": df, "train_size": N}``. ``df`` has the
          columns ``samples`` (epoch * N), ``corr``, ``epoch`` and
          ``progress``.
        * ``{}`` when a marker was found but no correlation lines parsed.
        * ``None`` when the slice contains neither marker.
    """
    session_data = {}

    # Online session?
    online_match = re.search(r"--- Running SNN Online Training ---", session_text)
    if online_match:
        start_idx = online_match.start()
        # Bound the online block at the next "Datasets created: Train=" line so
        # later (BPTT) output in the same slice is not parsed as online steps.
        m_next_train = re.search(r"Datasets\s+created:\s*Train=", session_text[start_idx:])
        end_idx = start_idx + m_next_train.start() if m_next_train else len(session_text)
        online_block = session_text[start_idx:end_idx]

        steps = re_online_step.findall(online_block)
        if steps:
            df = pd.DataFrame([(int(s), float(c)) for s, c in steps], columns=["samples", "corr"])
            # Fraction of the last logged step (assumes steps are logged in increasing order).
            df["progress"] = df["samples"] / df["samples"].iloc[-1]
            # Largest epoch numerator seen in the online block. The regex
            # captures only digits, so int() cannot fail and no try/except is needed.
            ep_hits = re.findall(r"Epoch\s*\[(\d+)/(\d+)\]", online_block)
            if ep_hits:
                df["epoch_max"] = max(int(e) for e, _ in ep_hits)
            session_data["type"] = "online"
            session_data["data"] = df
        return session_data

    # BPTT session?
    train_size_match = re_train_size.search(session_text)
    if train_size_match:
        Ntrain = int(train_size_match.group(1))
        epochs = re_bptt_epoch.findall(session_text)
        if epochs:
            df_rows = []
            for e, _, c in epochs:
                epoch_num = int(e)
                # Samples seen after `epoch_num` full passes over the training set.
                cumulative_samples = epoch_num * Ntrain
                df_rows.append((cumulative_samples, float(c), epoch_num))

            df = pd.DataFrame(df_rows, columns=["samples", "corr", "epoch"])
            df["progress"] = df["samples"] / df["samples"].iloc[-1]
            session_data["type"] = "bptt"
            session_data["data"] = df
            session_data["train_size"] = Ntrain

            print(f"BPTT session found: {len(df)} epochs, {Ntrain} samples/epoch, max samples: {df['samples'].max()}")
        return session_data

    return None

# parse all runs: cut the log at each session header and file every parsed
# session under its type ("online" or "bptt").
online_sessions_list = []
bptt_sessions_list = []
sessions_by_type = {"online": online_sessions_list, "bptt": bptt_sessions_list}

session_matches = list(re_session_hdr.finditer(txt))
# A session spans from its header to the next header (the last one to end-of-file).
session_bounds = [hdr.start() for hdr in session_matches] + [len(txt)]
for seg_start, seg_end in zip(session_bounds[:-1], session_bounds[1:]):
    parsed_data = parse_session(txt[seg_start:seg_end])
    if not parsed_data:
        continue
    bucket = sessions_by_type.get(parsed_data["type"])
    if bucket is not None:
        bucket.append(parsed_data["data"])

print(f"Found {len(online_sessions_list)} online sessions and {len(bptt_sessions_list)} BPTT sessions")

if not (online_sessions_list and bptt_sessions_list):
    print("couldn't find both types of sessions - check regex markers")
    sys.exit(1)

# ===== Comparison range: choose a sample window wide enough to show where the curves intersect =====
def get_fair_comparison_range(online_list, bptt_list, dataset_name="", *,
                              min_points=3, target_points=5, n_candidates=20):
    """Choose the sample-count window in which online and BPTT runs are compared.

    The window starts at 0. By default it ends at the largest final sample
    count among the online runs. If some BPTT run then has fewer than
    ``min_points`` evaluations inside the window, the upper bound is moved
    towards the largest BPTT sample count. ``n_candidates`` evenly spaced
    values are tried, in order:

    1. The first candidate that gives every BPTT run at least
       ``target_points`` evaluations is used.
    2. Failing that, the first candidate giving at least ``min_points``.
    3. Failing both, the online maximum is kept.

    Args:
        online_list: DataFrames (one per online run) with a ``samples`` column.
        bptt_list: DataFrames (one per BPTT run) with a ``samples`` column.
        dataset_name: Label used in the printed diagnostics.
        min_points: Fewest BPTT evaluations per run considered acceptable.
        target_points: Preferred number of BPTT evaluations per run.
        n_candidates: Number of candidate upper bounds tried when extending.

    Returns:
        ``(start, end)`` sample counts, or ``(0, 0)`` if either list is empty.
    """
    if not online_list or not bptt_list:
        return 0, 0

    # The online range is typically the limiting factor, so the longest online run sets the default window.
    online_max = max(d["samples"].max() for d in online_list)
    bptt_max = max(d["samples"].max() for d in bptt_list)

    def bptt_coverage(upper):
        """Fewest evaluations any single BPTT run has at or below ``upper``."""
        return min(int((d["samples"] <= upper).sum()) for d in bptt_list)

    comparison_max = online_max
    min_bptt_points = bptt_coverage(comparison_max)

    print(f"{dataset_name} ranges:")
    print(f"  Online max: {online_max:,} samples")
    print(f"  BPTT max: {bptt_max:,} samples")
    print(f"  Comparison max: {comparison_max:,} samples")
    print(f"  Min BPTT points in range: {min_bptt_points}")

    if min_bptt_points < min_points:
        candidates = np.linspace(online_max, bptt_max, n_candidates)
        coverage = [(upper, bptt_coverage(upper)) for upper in candidates]
        # Prefer the target density. Accept min_points rather than keep a window
        # in which some BPTT run has too few evaluations to interpolate.
        chosen = next((u for u, n in coverage if n >= target_points), None)
        if chosen is None:
            chosen = next((u for u, n in coverage if n >= min_points), None)
        if chosen is not None:
            comparison_max = chosen
            print(f"  Extended to {comparison_max:,} samples for better BPTT coverage")

    return 0, comparison_max

def to_grid_with_extrapolation(df, grid_samples, max_sample):
    """Resample a run's running-best correlation onto a common sample grid.

    Evaluations with ``samples > max_sample`` are ignored. The running best
    (cumulative max of ``corr`` in sample order) is linearly interpolated
    between evaluations. It is held constant after the last evaluation, and
    back to sample 0 from the first one.

    If no evaluation lies inside the window, the earliest evaluation (the one
    closest to the window) is pinned to sample 0. The run is then shown at
    that value across the whole window.

    Args:
        df: DataFrame with ``samples`` and ``corr`` columns.
        grid_samples: Sample counts at which to evaluate the curve.
        max_sample: Upper end of the comparison window.

    Returns:
        Float array shaped like ``grid_samples``; all NaN when ``df`` is empty.
    """
    grid = np.asarray(grid_samples, dtype=float)
    df = df.sort_values("samples", kind="mergesort")
    in_range = df[df["samples"] <= max_sample].copy()

    if in_range.empty:
        if df.empty:
            # np.full with a float dtype: full_like on an integer grid cannot hold NaN.
            return np.full(grid.shape, np.nan)
        in_range = df.iloc[[0]].copy()
        in_range["samples"] = 0

    xp = in_range["samples"].to_numpy(dtype=float)
    yp = in_range["corr"].cummax().to_numpy(dtype=float)

    # Anchor the curve at sample 0 with the first observed best value.
    if xp[0] > 0:
        xp = np.insert(xp, 0, 0.0)
        yp = np.insert(yp, 0, yp[0])

    return np.interp(grid, xp, yp, left=yp[0], right=yp[-1])

# Zenodo: choose the comparison window, then resample every run onto a
# shared 101-point grid spanning it.
common_start_z, common_end_z = get_fair_comparison_range(online_sessions_list, bptt_sessions_list, "Zenodo")

if common_end_z <= common_start_z:
    print("Error: Invalid comparison range for Zenodo")
    sys.exit(1)

grid_samples_zenodo = np.linspace(common_start_z, common_end_z, 101)

# Keep only runs whose resampled curve has at least one non-NaN value.
online_mat_zenodo = [
    row
    for row in (to_grid_with_extrapolation(d, grid_samples_zenodo, common_end_z) for d in online_sessions_list)
    if not np.all(np.isnan(row))
]
bptt_mat_zenodo = [
    row
    for row in (to_grid_with_extrapolation(d, grid_samples_zenodo, common_end_z) for d in bptt_sessions_list)
    if not np.all(np.isnan(row))
]

if not online_mat_zenodo or not bptt_mat_zenodo:
    print("Error: No valid data in Zenodo comparison range")
    sys.exit(1)

online_mat_zenodo, bptt_mat_zenodo = np.vstack(online_mat_zenodo), np.vstack(bptt_mat_zenodo)

print(f"Zenodo matrices: Online {online_mat_zenodo.shape}, BPTT {bptt_mat_zenodo.shape}")

def mean_sem(mat):
    """Column-wise mean and standard error of the mean, ignoring NaNs.

    Args:
        mat: 2-D array with one row per run and one column per grid point.

    Returns:
        ``(mean, sem)`` as 1-D float arrays with one entry per column.
        * A column with no non-NaN entries gets a NaN mean.
        * A column with fewer than two non-NaN entries gets an SEM of 0.
        * An empty input yields two empty arrays.
    """
    mat = np.asarray(mat, dtype=float)
    if mat.size == 0:
        return np.array([]), np.array([])

    n_valid = np.sum(~np.isnan(mat), axis=0)

    # np.nanmean warns ("Mean of empty slice") on all-NaN columns, and this
    # script promotes RuntimeWarning to an error. Reduce only populated columns.
    mean = np.full(mat.shape[1], np.nan)
    has_data = n_valid > 0
    if np.any(has_data):
        mean[has_data] = np.nanmean(mat[:, has_data], axis=0)

    # The sample std (ddof=1) needs at least two values. Other columns keep SEM 0.
    sem = np.zeros_like(mean)
    enough = n_valid >= 2
    if np.any(enough):
        sem[enough] = np.nanstd(mat[:, enough], axis=0, ddof=1) / np.sqrt(n_valid[enough])
    return mean, sem

# Per-grid-point mean and SEM across Zenodo runs, online first, then BPTT.
(on_mean_z, on_sem_z), (bt_mean_z, bt_sem_z) = mean_sem(online_mat_zenodo), mean_sem(bptt_mat_zenodo)

# ===== MCMAZE PARSING =====
print("\n" + "=" * 50 + "\nProcessing MCMaze...")

# The whole MCMaze training log, read once up front.
mcmaze_log = Path("fullmcmaze.txt")
txt_mc = mcmaze_log.read_text()

# MCMaze specific regexes.
# The epoch/step patterns use re.S so one record may span several lines. The gap
# between a header and its first value is "tempered": it may not run into the
# next "Epoch [" / "Step <n>" header. Otherwise an epoch or step logged without
# validation output would capture the following record's correlations, and that
# following record would be consumed and lost.
split_hdr = re.compile(r"---\s*Split\s+(\d+)/(\d+)\s*---")
train_sz_pat = re.compile(r"Train(?:\s+set)?:\s+(\d+)\s+samples", re.I)
epoch_pat = re.compile(
    r"Epoch\s+\[(\d+)/\d+\](?:(?!Epoch\s+\[).)*?Val\s+Corr\s+X:\s*([-+]?\d*\.?\d+).*?"
    r"Val\s+Corr\s+Y:\s*([-+]?\d*\.?\d+)",
    re.S)
online_mark = "--- Running SNN Online Training ---"
step_avg_pat = re.compile(
    r"Step\s+(\d+)(?:(?!Step\s+\d).)*?Val\s+Avg\s+Corr\s*:\s*([-+]?\d*\.?\d+)",
    re.S,
)
step_xy_pat = re.compile(
    r"Step\s+(\d+)(?:(?!Step\s+\d).)*?Val\s+Corr\s+X\s*:\s*([-+]?\d*\.?\d+)"
    r".*?Val\s+Corr\s+Y\s*:\s*([-+]?\d*\.?\d+)",
    re.S,
)
epoch_any_pat = re.compile(r"Epoch\s+\[(\d+)/(\d+)\]")
online_session_hdr = re.compile(r"===\s*PRE-SNN\s+VERIFICATION\s*\(Session\s*(\d+)\)\s*===")

def parse_bptt_mc(block: str):
    """Build the BPTT learning curve for one MCMaze split.

    The training-set size N comes from the first "Train: N samples" or
    "Train set: N samples" line (case-insensitive). Each "Epoch [e/..]" record
    carrying both X and Y validation correlations contributes one row, placed
    at e * N samples.

    Returns:
        DataFrame with columns ``samples``, ``corr`` (mean of X and Y),
        ``epoch`` and ``progress`` (samples / last row's samples). Returns
        ``None`` when the size line or all epoch records are missing.
    """
    size_hit = train_sz_pat.search(block)
    if size_hit is None:
        return None
    N = int(size_hit.group(1))

    epoch_hits = epoch_pat.findall(block)
    if not epoch_hits:
        return None

    records = [
        (int(ep) * N, (float(cx) + float(cy)) / 2, int(ep))
        for ep, cx, cy in epoch_hits
    ]
    df = pd.DataFrame(records, columns=["samples", "corr", "epoch"])
    df["progress"] = df["samples"] / df["samples"].iloc[-1]

    print(f"MCMaze BPTT: {len(df)} epochs, {N} samples/epoch, max samples: {df['samples'].max()}")
    return df

def parse_online_mc(block: str):
    """Parse the online-learning validation curve from one MCMaze log segment.

    If the segment contains "Starting SNN Online Training", parsing is limited
    to the text from there up to the first "HEBBIAN Results" that follows it
    (or to the end of the segment).

    Per-step values come from "Val Avg Corr", or from the mean of "Val Corr X"
    and "Val Corr Y". When a step has both kinds, the "Val Avg Corr" value is
    kept. Among repeats of the same kind, the first logged value wins.

    Returns:
        DataFrame sorted by step with columns ``samples`` (step number),
        ``corr``, ``epoch_max`` (only if "Epoch [i/n]" markers are present)
        and ``progress``. Returns ``None`` if no step lines were found.
    """
    start_idx = block.find("Starting SNN Online Training")
    if start_idx != -1:
        # Search for the end marker only after the start marker. An earlier
        # "HEBBIAN Results" would otherwise give a reversed, empty slice.
        end_idx = block.find("HEBBIAN Results", start_idx)
        block = block[start_idx: end_idx if end_idx != -1 else len(block)]

    rows: list[tuple[int, float]] = [
        (int(m.group(1)), float(m.group(2))) for m in step_avg_pat.finditer(block)
    ]
    rows += [
        (int(m.group(1)), (float(m.group(2)) + float(m.group(3))) / 2.0)
        for m in step_xy_pat.finditer(block)
    ]
    if not rows:
        return None

    # Keep the first value per step. Avg-based rows precede X/Y rows in `rows`.
    dedup: dict[int, float] = {}
    for step_num, corr in rows:
        dedup.setdefault(step_num, corr)

    df = pd.DataFrame(sorted(dedup.items()), columns=["samples", "corr"])

    ep_hits = epoch_any_pat.findall(block)
    if ep_hits:
        # The captures are pure digits, so int() cannot fail here.
        df["epoch_max"] = max(int(e) for e, _ in ep_hits)

    df["progress"] = df["samples"] / df["samples"].iloc[-1]
    return df

def parse_online_sessions_mc(block: str):
    """Split an MCMaze block into online sessions and parse each one.

    Sessions are delimited by "=== PRE-SNN VERIFICATION (Session k) ===" headers;
    any text before the first header is ignored. If there is no header at all,
    the whole block is parsed as a single session. Sessions that yield no rows
    are dropped.

    Returns:
        List of DataFrames as produced by ``parse_online_mc``.
    """
    header_starts = [hit.start() for hit in online_session_hdr.finditer(block)]
    if header_starts:
        cuts = header_starts + [len(block)]
        segments = [block[lo:hi] for lo, hi in zip(cuts[:-1], cuts[1:])]
    else:
        segments = [block]

    sessions = []
    for seg in segments:
        df = parse_online_mc(seg)
        if df is not None and len(df) > 0:
            sessions.append(df)
    return sessions

# Scan MCMaze file: each "--- Split i/n ---" header starts a block that is
# parsed for a BPTT run and for online session(s).
online_runs_mc, bptt_runs_mc = [], []
splits = list(split_hdr.finditer(txt_mc))

for i, m in enumerate(splits):
    block = txt_mc[m.start(): splits[i+1].start() if i+1 < len(splits) else len(txt_mc)]

    # Parse BPTT.
    # NOTE(review): this scans the whole split, including any online section;
    # confirm online logs never contain "Epoch [..] ... Val Corr X/Y" records.
    bt = parse_bptt_mc(block)
    if bt is not None:
        bptt_runs_mc.append(bt)

    # Parse Online, starting at the online-training marker when it is present.
    if online_mark in block:
        _, post = block.split(online_mark, 1)
        on_list = parse_online_sessions_mc(online_mark + post)
    else:
        on_list = parse_online_sessions_mc(block)

    if on_list:
        online_runs_mc.extend(on_list)

print(f"MCMaze found {len(online_runs_mc)} online runs and {len(bptt_runs_mc)} BPTT runs")

if not bptt_runs_mc:
    raise RuntimeError("BPTT runs not found - check log format.")
if not online_runs_mc:
    # The comparison window, grid matrices and plots all need online runs.
    # With none, get_fair_comparison_range returns (0, 0) and the script would
    # stop later with a misleading "Invalid comparison range" error. Stop here
    # with an accurate message, as the Zenodo section does.
    print("No online runs found - cannot compare against BPTT; check log markers.")
    sys.exit(1)

# Apply same fair comparison logic to MCMaze: choose the window, resample every
# run onto a shared 101-point grid, then summarise each method.
common_start_mc, common_end_mc = get_fair_comparison_range(online_runs_mc, bptt_runs_mc, "MCMaze")

if common_end_mc <= common_start_mc:
    print("Error: Invalid comparison range for MCMaze")
    sys.exit(1)

grid_samples_mc = np.linspace(common_start_mc, common_end_mc, 101)

# Keep only runs whose resampled curve has at least one non-NaN value.
online_mat_mc = [
    row
    for row in (to_grid_with_extrapolation(d, grid_samples_mc, common_end_mc) for d in online_runs_mc)
    if not np.all(np.isnan(row))
]
bptt_mat_mc = [
    row
    for row in (to_grid_with_extrapolation(d, grid_samples_mc, common_end_mc) for d in bptt_runs_mc)
    if not np.all(np.isnan(row))
]

if not online_mat_mc or not bptt_mat_mc:
    print("Error: No valid MCMaze data in comparison range")
    sys.exit(1)

online_mat_mc, bptt_mat_mc = np.vstack(online_mat_mc), np.vstack(bptt_mat_mc)

print(f"MCMaze matrices: Online {online_mat_mc.shape}, BPTT {bptt_mat_mc.shape}")

(on_mean_mc, on_sem_mc), (bt_mean_mc, bt_sem_mc) = mean_sem(online_mat_mc), mean_sem(bptt_mat_mc)

# ===== PLOTTING =====
# Shared style for both figures: serif text, larger fonts, thicker lines and a
# light dashed grid on the y axis only.
paper_rc = {
    # text
    "font.family": "serif",
    "font.serif": ["Times New Roman"],
    "font.size": 18,
    "axes.labelsize": 16,
    "axes.titlesize": 18,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "legend.fontsize": 14,
    # lines and axes
    "lines.linewidth": 2.0,
    "axes.linewidth": 1.0,
    # grid
    "axes.grid": True,
    "axes.grid.axis": "y",
    "grid.linestyle": "--",
    "grid.alpha": 0.35,
}
plt.rcParams.update(paper_rc)

# Series colours: online = blue, BPTT = orange.
color_online, color_bptt = "#1f77b4", "#ff7f0e"

# === ZENODO PLOT ===
# savefig() does not create missing directories, so make sure the output folder exists.
Path("figures").mkdir(parents=True, exist_ok=True)

fig1, ax1 = plt.subplots(1, 1, figsize=(8, 5))

# Mean curves with a +/-1 SEM band for each method
ax1.plot(grid_samples_zenodo, on_mean_z, color=color_online, linewidth=2.5,
         label='Online SNN', alpha=0.9)
ax1.fill_between(grid_samples_zenodo, on_mean_z - on_sem_z, on_mean_z + on_sem_z,
                 color=color_online, alpha=0.2)

ax1.plot(grid_samples_zenodo, bt_mean_z, color=color_bptt, linewidth=2.5,
         label='BPTT', alpha=0.9)
ax1.fill_between(grid_samples_zenodo, bt_mean_z - bt_sem_z, bt_mean_z + bt_sem_z,
                 color=color_bptt, alpha=0.2)

# Raw evaluations inside the window as faint scatter
for d in online_sessions_list:
    dfp = d[d["samples"] <= common_end_z].sort_values("samples")
    if len(dfp) > 0:
        ax1.scatter(dfp["samples"], dfp["corr"], s=6, alpha=0.15,
                    color=color_online, linewidths=0)

for d in bptt_sessions_list:
    dfp = d[d["samples"] <= common_end_z].sort_values("samples")
    if len(dfp) > 0:
        ax1.scatter(dfp["samples"], dfp["corr"], s=8, alpha=0.25,
                    color=color_bptt, linewidths=0)

# Mark where the mean curves cross. The first sign change of (online - BPTT)
# between neighbouring valid grid points is located by linear interpolation.
# If the curves never cross, mark the point of closest approach and label it as such.
valid_mask = ~(np.isnan(on_mean_z) | np.isnan(bt_mean_z))
if np.any(valid_mask):
    valid_indices = np.where(valid_mask)[0]
    gap_z = on_mean_z[valid_indices] - bt_mean_z[valid_indices]
    crossings_z = np.where(np.sign(gap_z[:-1]) * np.sign(gap_z[1:]) <= 0)[0]
    if crossings_z.size > 0:
        k_z = crossings_z[0]
        i0_z, i1_z = valid_indices[k_z], valid_indices[k_z + 1]
        # Fraction of the way from i0 to i1 where the gap reaches zero (0 if both ends touch).
        frac_z = 0.0 if gap_z[k_z] == gap_z[k_z + 1] else gap_z[k_z] / (gap_z[k_z] - gap_z[k_z + 1])
        intersection_sample = grid_samples_zenodo[i0_z] + frac_z * (grid_samples_zenodo[i1_z] - grid_samples_zenodo[i0_z])
        intersection_corr = on_mean_z[i0_z] + frac_z * (on_mean_z[i1_z] - on_mean_z[i0_z])
        marker_label_z = "Intersection"
    else:
        actual_idx = valid_indices[np.argmin(np.abs(gap_z))]
        intersection_sample = grid_samples_zenodo[actual_idx]
        intersection_corr = (on_mean_z[actual_idx] + bt_mean_z[actual_idx]) / 2
        marker_label_z = "Closest approach"

    ax1.axvline(x=intersection_sample, color='red', linestyle=':', alpha=0.7, linewidth=1.5)
    ax1.text(intersection_sample, intersection_corr + 0.05,
             f'{marker_label_z}\n{int(intersection_sample):,} samples',
             ha='center', va='bottom', fontsize=12,
             bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

ax1.set_xlabel("Samples Processed")
ax1.set_ylabel("Validation Correlation")
ax1.set_title("Zenodo Dataset")
ax1.set_xlim(common_start_z, common_end_z)
ax1.set_ylim(0, 1)
ax1.spines["top"].set_visible(False)
ax1.spines["right"].set_visible(False)
ax1.legend(loc='lower right')

fig1.tight_layout()
fig1.savefig("figures/zenodo_learning_curves.pdf", bbox_inches="tight", pad_inches=0.02)
fig1.savefig("figures/zenodo_learning_curves.png", dpi=600, bbox_inches="tight", pad_inches=0.02)
plt.show()

# === MCMAZE PLOT ===
# savefig() does not create missing directories, so make sure the output folder exists.
Path("figures").mkdir(parents=True, exist_ok=True)

fig2, ax2 = plt.subplots(1, 1, figsize=(8, 5))

# Mean curves with a +/-1 SEM band for each method
ax2.plot(grid_samples_mc, on_mean_mc, color=color_online, linewidth=2.5,
         label='Online SNN', alpha=0.9)
ax2.fill_between(grid_samples_mc, on_mean_mc - on_sem_mc, on_mean_mc + on_sem_mc,
                 color=color_online, alpha=0.2)

ax2.plot(grid_samples_mc, bt_mean_mc, color=color_bptt, linewidth=2.5,
         label='BPTT', alpha=0.9)
ax2.fill_between(grid_samples_mc, bt_mean_mc - bt_sem_mc, bt_mean_mc + bt_sem_mc,
                 color=color_bptt, alpha=0.2)

# Raw evaluations inside the window as faint scatter
for d in online_runs_mc:
    dfp = d[d["samples"] <= common_end_mc].sort_values("samples")
    if len(dfp) > 0:
        ax2.scatter(dfp["samples"], dfp["corr"], s=6, alpha=0.15,
                    color=color_online, linewidths=0)

for d in bptt_runs_mc:
    dfp = d[d["samples"] <= common_end_mc].sort_values("samples")
    if len(dfp) > 0:
        ax2.scatter(dfp["samples"], dfp["corr"], s=8, alpha=0.25,
                    color=color_bptt, linewidths=0)

# Mark where the mean curves cross. The first sign change of (online - BPTT)
# between neighbouring valid grid points is located by linear interpolation.
# If the curves never cross, mark the point of closest approach and label it as such.
valid_mask_mc = ~(np.isnan(on_mean_mc) | np.isnan(bt_mean_mc))
if np.any(valid_mask_mc):
    valid_indices_mc = np.where(valid_mask_mc)[0]
    gap_mc = on_mean_mc[valid_indices_mc] - bt_mean_mc[valid_indices_mc]
    crossings_mc = np.where(np.sign(gap_mc[:-1]) * np.sign(gap_mc[1:]) <= 0)[0]
    if crossings_mc.size > 0:
        k_mc = crossings_mc[0]
        i0_mc, i1_mc = valid_indices_mc[k_mc], valid_indices_mc[k_mc + 1]
        # Fraction of the way from i0 to i1 where the gap reaches zero (0 if both ends touch).
        frac_mc = 0.0 if gap_mc[k_mc] == gap_mc[k_mc + 1] else gap_mc[k_mc] / (gap_mc[k_mc] - gap_mc[k_mc + 1])
        intersection_sample_mc = grid_samples_mc[i0_mc] + frac_mc * (grid_samples_mc[i1_mc] - grid_samples_mc[i0_mc])
        intersection_corr_mc = on_mean_mc[i0_mc] + frac_mc * (on_mean_mc[i1_mc] - on_mean_mc[i0_mc])
        marker_label_mc = "Intersection"
    else:
        actual_idx_mc = valid_indices_mc[np.argmin(np.abs(gap_mc))]
        intersection_sample_mc = grid_samples_mc[actual_idx_mc]
        intersection_corr_mc = (on_mean_mc[actual_idx_mc] + bt_mean_mc[actual_idx_mc]) / 2
        marker_label_mc = "Closest approach"

    ax2.axvline(x=intersection_sample_mc, color='red', linestyle=':', alpha=0.7, linewidth=1.5)
    ax2.text(intersection_sample_mc, intersection_corr_mc + 0.05,
             f'{marker_label_mc}\n{int(intersection_sample_mc):,} samples',
             ha='center', va='bottom', fontsize=12,
             bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

ax2.set_xlabel("Samples Processed")
ax2.set_ylabel("Avg. Val. Correlation")
ax2.set_title("MCMaze Dataset")
ax2.set_xlim(common_start_mc, common_end_mc)
ax2.set_ylim(0, 1)
ax2.spines["top"].set_visible(False)
ax2.spines["right"].set_visible(False)
ax2.legend(loc='lower right')

fig2.tight_layout()
fig2.savefig("figures/mcmaze_learning_curves.pdf", bbox_inches="tight", pad_inches=0.02)
fig2.savefig("figures/mcmaze_learning_curves.png", dpi=600, bbox_inches="tight", pad_inches=0.02)
plt.show()

# Print summary statistics: window, final mean +/- SEM per method, and the
# marked intersection sample when one was computed for the plot.
print("\n=== COMPARISON SUMMARY ===")

zenodo_summary = [
    "Zenodo Dataset:",
    f"  - Comparison range: {common_start_z:,} to {common_end_z:,} samples",
    f"  - Online final performance: {on_mean_z[-1]:.3f} ± {on_sem_z[-1]:.3f}",
    f"  - BPTT final performance: {bt_mean_z[-1]:.3f} ± {bt_sem_z[-1]:.3f}",
]
if np.any(valid_mask):
    zenodo_summary.append(f"  - Intersection at ~{int(intersection_sample):,} samples")
print("\n".join(zenodo_summary))

mcmaze_summary = [
    "\nMCMaze Dataset:",
    f"  - Comparison range: {common_start_mc:,} to {common_end_mc:,} samples",
    f"  - Online final performance: {on_mean_mc[-1]:.3f} ± {on_sem_mc[-1]:.3f}",
    f"  - BPTT final performance: {bt_mean_mc[-1]:.3f} ± {bt_sem_mc[-1]:.3f}",
]
if np.any(valid_mask_mc):
    mcmaze_summary.append(f"  - Intersection at ~{int(intersection_sample_mc):,} samples")
print("\n".join(mcmaze_summary))