import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from cycler import cycler
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
import gc, os, tracemalloc

# Attempt to import psutil, but proceed without it if not available
try:
    import psutil
except ImportError:
    psutil = None

# --- Matplotlib Configuration ---
# Global figure style applied to every plot produced by this script.
mpl.rcParams.update(
    {
        # Typography
        "font.family": "serif",
        "font.serif": ["Times New Roman", "Times"],
        "font.size": 12,
        # Axes frame and default line colours
        "axes.linewidth": 0.6,
        "axes.prop_cycle": cycler(color=["#D55E00", "#0072B2", "#009E73"]),
        # Tick marks: thin and pointing into the axes
        "xtick.major.width": 0.6,
        "ytick.major.width": 0.6,
        "xtick.direction": "in",
        "ytick.direction": "in",
        # Legend, PDF font embedding and layout engine
        "legend.frameon": False,
        "pdf.fonttype": 42,
        "figure.autolayout": False,
    }
)

# --- Constants ---
# Number of values the network emits per timestep; also used when counting
# the parameters of the final linear layer.
OUTPUT_SIZE: int = 2
# Bytes charged per stored value when converting counts into memory sizes.
BYTES_PER_FLOAT: int = 4

# --- SNN Model Definition ---
class SNNRegression(nn.Module):
    """Three-layer spiking network for sequence regression.

    Layer 1 is a leaky integrate-and-fire (LIF) population driven by the
    current input and, through ``fc_rec``, by its own spikes from the
    previous timestep. Layer 2 is a feed-forward LIF population with
    ``hidden_size // 2`` units. Layer 3 is a LIF readout created with
    ``reset_mechanism="none"``; its membrane potential (not its spikes) is
    what the network returns.

    Args:
        input_size: Number of input features per timestep.
        hidden_size: Width of the first hidden layer.
        output_size: Number of values produced per timestep.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        spike_grad = surrogate.fast_sigmoid()
        half_hidden = hidden_size // 2

        # Layer 1: input projection plus bias-free spike recurrence.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc_rec = nn.Linear(hidden_size, hidden_size, bias=False)
        self.lif1 = snn.Leaky(beta=0.7, spike_grad=spike_grad, init_hidden=False)

        # Layer 2: feed-forward, half the width of layer 1.
        self.fc2 = nn.Linear(hidden_size, half_hidden)
        self.lif2 = snn.Leaky(beta=0.7, spike_grad=spike_grad, init_hidden=False)

        # Layer 3: readout whose membrane is never reset.
        self.fc3 = nn.Linear(half_hidden, output_size)
        self.lif3 = snn.Leaky(beta=0.5, spike_grad=spike_grad, init_hidden=False, threshold=1.0, reset_mechanism="none")

    def forward(self, x):
        """Run the network over a whole sequence.

        Args:
            x: Tensor indexed as ``x[batch, time, feature]``; dimension 0 is
                used as the batch size and dimension 1 as the number of
                timesteps.

        Returns:
            The layer-3 membrane potential at every timestep, stacked along
            dimension 1 (so the result is indexed ``[batch, time, output]``).
            ``torch.stack`` requires at least one timestep.
        """
        batch_size, num_steps = x.size(0), x.size(1)

        # Create the initial state in the parameters' dtype. Hard-coding the
        # default float32 breaks models cast with .double()/.half(), because
        # the very first fc_rec call would mix dtypes.
        state_kwargs = {"device": x.device, "dtype": self.fc1.weight.dtype}
        spk1 = torch.zeros(batch_size, self.fc1.out_features, **state_kwargs)
        mem1 = torch.zeros(batch_size, self.fc1.out_features, **state_kwargs)
        mem2 = torch.zeros(batch_size, self.fc2.out_features, **state_kwargs)
        mem3 = torch.zeros(batch_size, self.fc3.out_features, **state_kwargs)

        outputs = []
        for t in range(num_steps):
            # spk1 still holds the previous step's spikes here, which is what
            # feeds the recurrent connection.
            cur1 = self.fc1(x[:, t, :]) + self.fc_rec(spk1)
            spk1, mem1 = self.lif1(cur1, mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            # Only the readout membrane is used; its spikes are discarded.
            _, mem3 = self.lif3(self.fc3(spk2), mem3)
            outputs.append(mem3)
        return torch.stack(outputs, dim=1)

# --- Memory Calculation and Profiling Functions ---

def calculate_static_memory(input_size, hidden_size, algorithm_type):
    """Return the sequence-length-independent memory, in bytes, of one training setup.

    The parameter count mirrors the layers of ``SNNRegression``: ``fc1``,
    the bias-free recurrent matrix ``fc_rec``, ``fc2`` and ``fc3`` (with
    ``OUTPUT_SIZE`` outputs), plus the biases of ``fc1``, ``fc2`` and ``fc3``.

    Args:
        input_size: Number of input features.
        hidden_size: Width of the first hidden layer; the second has
            ``hidden_size // 2`` units.
        algorithm_type: ``'online'`` or ``'bptt'``.

    Returns:
        ``copies * n_params * BYTES_PER_FLOAT`` where ``copies`` is 3 for
        ``'online'`` (parameters + fast and slow eligibility traces) and 4
        for ``'bptt'`` (parameters + ``.grad`` buffers + two Adam state
        buffers).

    Raises:
        ValueError: If ``algorithm_type`` is not one of the two supported
            names. Returning 0 here would silently present a typo as a
            memory-free algorithm.
    """
    half_hidden = hidden_size // 2
    weight_count = (
        input_size * hidden_size       # fc1
        + hidden_size * hidden_size    # fc_rec
        + hidden_size * half_hidden    # fc2
        + half_hidden * OUTPUT_SIZE    # fc3
    )
    bias_count = hidden_size + half_hidden + OUTPUT_SIZE
    params_mem = (weight_count + bias_count) * BYTES_PER_FLOAT

    if algorithm_type == 'online':
        # Parameters (1x) + fast/slow eligibility traces (2x).
        return params_mem + 2 * params_mem
    if algorithm_type == 'bptt':
        # Parameters (1x) + .grad buffers (1x) + Adam optimizer state (2x).
        return params_mem + params_mem + 2 * params_mem
    raise ValueError(
        f"Unknown algorithm_type {algorithm_type!r}; expected 'online' or 'bptt'."
    )

# Measured byte counts keyed by (input_size, hidden_size, sequence_length, device).
_bptt_measurement_cache = {}
def measure_bptt_activation_memory(input_size, hidden_size, sequence_length, device='cpu'):
    """Measure the extra host memory used by one BPTT training step.

    A fresh ``SNNRegression`` is trained for a few warm-up steps on an
    all-zero batch of size 1, then one further step is executed while
    sampling ``tracemalloc`` (peak traced Python allocations) and, when
    ``psutil`` is available, the process RSS. The result is the larger of
    the two increases and is cached per configuration.

    Args:
        input_size: Number of input features of the benchmark model.
        hidden_size: Hidden width of the benchmark model.
        sequence_length: Number of timesteps in the benchmark batch.
        device: Device the model and data are placed on.

    Returns:
        The measured increase in bytes.

    NOTE(review): both probes observe host memory only; on a non-CPU device
    they presumably do not see device allocations -- confirm before using
    ``device != 'cpu'``.
    """
    shape_key = (input_size, hidden_size, sequence_length)
    # Include the device so a measurement taken on one device is never
    # returned for another.
    cache_key = shape_key + (str(device),)
    if cache_key in _bptt_measurement_cache:
        return _bptt_measurement_cache[cache_key]

    print(f"--- Measuring BPTT activation memory for {shape_key} ---")

    model = SNNRegression(input_size, hidden_size, OUTPUT_SIZE).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    x = torch.zeros(1, sequence_length, input_size, device=device)
    y = torch.zeros(1, sequence_length, OUTPUT_SIZE, device=device)
    criterion = torch.nn.MSELoss()

    def one_update_step():
        # NOTE(review): set_to_none=True drops the .grad tensors, so every
        # step (including the measured one) creates them again in backward().
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

    # 1. Warm up kernels and let Adam allocate its state before measuring
    #    (two warm-up steps plus one priming step).
    for _ in range(3):
        one_update_step()
    optimizer.zero_grad(set_to_none=True)
    gc.collect()

    # 2. Measure the incremental peak memory of a single step.
    process = psutil.Process(os.getpid()) if psutil else None

    def resident_bytes():
        # Without psutil the RSS probe contributes nothing to the result.
        return process.memory_info().rss if process is not None else 0

    gc.collect()
    tracemalloc.start()
    try:
        cur0, _ = tracemalloc.get_traced_memory()
        rss0 = resident_bytes()
        one_update_step()
        _, peak = tracemalloc.get_traced_memory()
        rss1 = resident_bytes()
    finally:
        # Always switch tracing off again, even if the step raised.
        tracemalloc.stop()

    act_mem_bytes = max(peak - cur0, rss1 - rss0)
    _bptt_measurement_cache[cache_key] = act_mem_bytes

    del model, optimizer, x, y
    gc.collect()

    return act_mem_bytes

def get_bptt_dynamic_memory(input_size, hidden_size, sequence_length):
    """Estimate BPTT activation memory by rescaling one reference measurement.

    A single benchmark (input 96, hidden 256, 120 timesteps) is measured via
    ``measure_bptt_activation_memory`` and multiplied by three ratios taken
    against that benchmark: combined hidden width
    (``hidden + hidden // 2``), sequence length and input size.

    Args:
        input_size: Number of input features of the configuration to estimate.
        hidden_size: Hidden width of the configuration to estimate.
        sequence_length: Number of timesteps of the configuration to estimate.

    Returns:
        The scaled estimate in bytes (a float).
    """
    ref_input, ref_hidden, ref_steps = 96, 256, 120
    reference_bytes = measure_bptt_activation_memory(ref_input, ref_hidden, ref_steps)

    def combined_width(hidden):
        # Both hidden populations together: full width plus half width.
        return hidden + hidden // 2

    ratio_width = combined_width(hidden_size) / combined_width(ref_hidden)
    ratio_steps = sequence_length / ref_steps
    ratio_input = input_size / ref_input

    return reference_bytes * (ratio_width * ratio_steps * ratio_input)

# --- Plotting Function ---
def _save_memory_comparison(x_values, online_bytes, bptt_bytes, xlabel, filename):
    """Plot Online vs. BPTT memory (converted to MiB) over ``x_values`` and save it."""
    fig, ax = plt.subplots(figsize=(3.25, 2.3), constrained_layout=True)
    ax.plot(x_values, np.array(online_bytes) / (1024**2), label='Online SNN', linewidth=1.3)
    ax.plot(x_values, np.array(bptt_bytes) / (1024**2), label='BPTT SNN', linewidth=1.3, linestyle='--')
    ax.set(xlabel=xlabel, ylabel='Total Memory (MiB)', ylim=(0, None))
    ax.grid(True, linestyle=':', linewidth=0.5, alpha=0.7)
    ax.legend()
    fig.savefig(filename, bbox_inches='tight')
    plt.close(fig)


def generate_memory_plots(hidden_size, suffix):
    """Create the three memory-comparison figures for one hidden size.

    Files written to the current directory:
      * ``memory_vs_duration_{suffix}.pdf`` -- memory vs. input duration at
        96 input neurons.
      * ``memory_vs_neurons_{suffix}.pdf`` -- memory vs. number of input
        neurons at 120 timesteps.
      * ``mem_region_map_{suffix}.pdf`` -- map over (input neurons, duration)
        of where the online memory is below the BPTT memory, with the
        break-even contour drawn in black.

    Args:
        hidden_size: Hidden width passed to the memory model.
        suffix: Text appended to each output filename.
    """
    print(f"\n--- Generating plots for HIDDEN_SIZE = {hidden_size} ---")
    fs = 20  # timesteps per second; int(duration * fs) is the sequence length

    duration_points = np.linspace(0.1, 5.0, 50)
    neuron_points = np.linspace(16, 1500, 50).astype(int)
    fixed_input_size, fixed_seq_len = 96, 120

    # Plot 1: vs. Duration. The online footprint does not depend on the
    # duration, so it is computed once and repeated.
    online_static = calculate_static_memory(fixed_input_size, hidden_size, 'online')
    bptt_static = calculate_static_memory(fixed_input_size, hidden_size, 'bptt')
    online_mem_dur = [online_static] * len(duration_points)
    bptt_mem_dur = [
        bptt_static + get_bptt_dynamic_memory(fixed_input_size, hidden_size, int(d * fs))
        for d in duration_points
    ]
    _save_memory_comparison(
        duration_points, online_mem_dur, bptt_mem_dur,
        'Input Duration (s)', f'memory_vs_duration_{suffix}.pdf',
    )

    # Plot 2: vs. Neurons. The static terms per input size are reused for
    # the region map below.
    input_sizes = [int(n) for n in neuron_points]
    online_mem_neu = [calculate_static_memory(n, hidden_size, 'online') for n in input_sizes]
    bptt_static_mem_neu = [calculate_static_memory(n, hidden_size, 'bptt') for n in input_sizes]
    bptt_mem_neu = [
        s + get_bptt_dynamic_memory(n, hidden_size, fixed_seq_len)
        for s, n in zip(bptt_static_mem_neu, input_sizes)
    ]
    _save_memory_comparison(
        neuron_points, online_mem_neu, bptt_mem_neu,
        'Number of Input Neurons', f'memory_vs_neurons_{suffix}.pdf',
    )

    # Plot 3: Region Map. Rows follow duration, columns follow input size.
    duration_range = np.linspace(0.1, 2.0, 50)
    online_grid = np.array([online_mem_neu for _ in duration_range])
    bptt_grid = np.array([
        [
            s + get_bptt_dynamic_memory(n, hidden_size, int(d * fs))
            for s, n in zip(bptt_static_mem_neu, input_sizes)
        ]
        for d in duration_range
    ])

    fig, ax = plt.subplots(figsize=(3.25, 2.3), constrained_layout=True)
    # Hand contourf a 0.0/1.0 float mask instead of relying on implicit
    # handling of a boolean array.
    online_is_smaller = (online_grid < bptt_grid).astype(float)
    ax.contourf(neuron_points, duration_range, online_is_smaller, levels=[-0.1, 0.5, 1.1], colors=['#0072B2', '#D55E00'], alpha=0.7)
    ax.contour(neuron_points, duration_range, (online_grid - bptt_grid), levels=[0], colors='k', linewidths=1.2)
    ax.set(xlabel='Input Neurons', ylabel='Input Duration (s)')
    fig.savefig(f'mem_region_map_{suffix}.pdf', bbox_inches='tight')
    plt.close(fig)

    print(f"All plots saved for {suffix}")

# --- Detailed Breakdown Function ---
def print_detailed_breakdown(hidden_size, suffix):
    """Print the memory breakdown for the reference setup (96 inputs, 120 timesteps).

    Values are divided by 1024**2 and therefore reported in MiB, the same
    unit used on the plot axes.

    Args:
        hidden_size: Hidden width passed to the memory model.
        suffix: Label shown in the section header.
    """
    input_size, seq_len = 96, 120
    mib = 1024**2

    print(f"\n--- Detailed Memory Calculation ({suffix}) ---")
    print(f"Parameters: Input={input_size}, Hidden={hidden_size}, SeqLen={seq_len}")

    online_static = calculate_static_memory(input_size, hidden_size, 'online')
    bptt_static = calculate_static_memory(input_size, hidden_size, 'bptt')
    bptt_dynamic = get_bptt_dynamic_memory(input_size, hidden_size, seq_len)

    print("\nOnline SNN Total Memory:")
    print(f"  {'Static (Params + Traces)':<35}: {online_static/mib:7.2f} MiB")

    print("\nBPTT SNN Breakdown:")
    print(f"  {'Static (Params + Grads + Optimizer)':<35}: {bptt_static/mib:7.2f} MiB")
    print(f"  {'Dynamic (Activation Graph)':<35}: {bptt_dynamic/mib:7.2f} MiB")
    print(f"  {'Total Peak Memory':<35}: {(bptt_static + bptt_dynamic)/mib:7.2f} MiB")

if __name__ == "__main__":
    if psutil is None:
        # The RSS probe in measure_bptt_activation_memory needs psutil; without
        # it that probe reports 0. Abort with an error message on stderr and a
        # non-zero exit status instead of a "warning" followed by a silent,
        # successful exit through the interactive-only exit() helper.
        raise SystemExit("Error: psutil not found. It is required for the memory measurements.")

    # --- Generate Artifacts for Zenodo & MC Maze Architectures ---
    print_detailed_breakdown(hidden_size=256, suffix="Zenodo")
    generate_memory_plots(hidden_size=256, suffix="zenodo_256")

    print_detailed_breakdown(hidden_size=1024, suffix="MC Maze")
    generate_memory_plots(hidden_size=1024, suffix="mcmaze_1024")
