# run_experiment_ablation.py
import os
import time
import traceback
from datetime import datetime
import math
import numpy as np
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from einops import rearrange, repeat
import argparse

from neuromamba.ops.hm_interface import hm_fn
from causal_conv1d import causal_conv1d_fn


from config import simulation_config, training_config, dataset_config, model_config, paths_config
from data_generator import generate_trial_sequence


class NeuroMamba(nn.Module):
    """State-space sequence block built around the external ``hm_fn`` kernel.

    The input is projected into three channel groups (``x``, ``z`` and ``gc``).
    ``x`` and ``gc`` each pass through their own depthwise causal convolution;
    ``gc`` is then projected to ``mf``. ``x``, ``mf``, ``z`` and the
    input-dependent ``dt``/``B``/``C`` tensors are handed to ``hm_fn``, which
    returns two tensors (``y1``, ``y2``). Each goes through its own output
    projection and the two projections are summed.

    Two ablation switches are supported. Both work by zeroing and freezing a
    projection at construction time, so the forward graph is unchanged:

    * ``ablate_mf``: zeroes ``mf_proj`` so the ``mf`` input to ``hm_fn`` is 0.
    * ``ablate_y2``: zeroes ``out_cathree_proj`` so ``y2`` contributes nothing
      to the block output.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        d_conv_gc=4,
        expand=2,
        expand_gc=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        layer_idx=None,
        device=None,
        dtype=None,
        ablate_mf: bool = False,
        ablate_y2: bool = False,
    ):
        """Build the projections, convolutions and state-space parameters.

        Args:
            d_model: Width of the block's input and output features.
            d_state: Size of the state dimension (width of ``B``/``C`` and the
                second axis of ``A_log``).
            d_conv: Kernel size of the depthwise convolution on the ``x`` branch.
            d_conv_gc: Kernel size of the depthwise convolution on the ``gc`` branch.
            expand: ``d_inner = int(expand * d_model)``.
            expand_gc: ``d_gc = int(expand_gc * d_model)``.
            dt_rank: Rank of the ``dt`` projection; ``"auto"`` means
                ``ceil(d_model / 16)``.
            dt_min, dt_max: Range from which the initial ``dt`` values are
                sampled (log-uniformly).
            dt_init: ``"random"`` for uniform init of ``dt_proj.weight``; any
                other value gives a constant init.
            dt_scale: Multiplier on the ``dt_proj.weight`` init scale.
            dt_init_floor: Lower clamp applied to the sampled ``dt`` values.
            conv_bias: Whether both convolutions have a bias.
            bias: Whether ``in_proj``, ``mf_proj`` and the output projections
                have a bias.
            layer_idx: Stored on the instance; only used in the ablation log messages.
            device, dtype: Forwarded to the parameter constructors.
            ablate_mf: Zero and freeze ``mf_proj``.
            ablate_y2: Zero and freeze ``out_cathree_proj``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.d_conv_gc = d_conv_gc
        self.expand = expand
        self.expand_gc = expand_gc
        self.d_inner = int(self.expand * self.d_model)
        self.d_gc = int(self.expand_gc * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.layer_idx = layer_idx


        # One projection yields all three branches: x (d_inner), z (d_inner), gc (d_gc).
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2 + self.d_gc, bias=bias, **factory_kwargs)
        # Depthwise (groups == channels) convolutions. forward() does not call
        # these modules directly; it passes their weight/bias to causal_conv1d_fn.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias,
            kernel_size=d_conv, groups=self.d_inner, padding=d_conv - 1, **factory_kwargs,
        )
        self.conv1d_gc = nn.Conv1d(
            in_channels=self.d_gc, out_channels=self.d_gc, bias=conv_bias,
            kernel_size=d_conv_gc, groups=self.d_gc, padding=d_conv_gc - 1, **factory_kwargs,
        )
        self.activation = "silu"
        self.act = nn.SiLU()
        # x -> (dt low-rank, B, C), all input-dependent.
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
        # gc -> mf, the extra d_inner-wide input given to hm_fn.
        self.mf_proj = nn.Linear(self.d_gc, self.d_inner, bias=bias, **factory_kwargs)

        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "random": nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else: nn.init.constant_(self.dt_proj.weight, dt_init_std)

        # Sample dt log-uniformly in [dt_min, dt_max], floored at dt_init_floor.
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: softplus(inv_dt) == dt, so the bias yields dt after softplus.
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad(): self.dt_proj.bias.copy_(inv_dt)
        # Marker attribute; nothing in this file reads it.
        self.dt_proj.bias._no_reinit = True

        # A starts as the values 1..d_state repeated for every inner channel; it is
        # stored as a log and negated in forward() (A = -exp(A_log)).
        A = repeat(torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), "n -> d n", d=self.d_inner).contiguous()
        self.A_log = nn.Parameter(torch.log(A))
        # Marker attributes; nothing in this file reads them.
        self.A_log._no_weight_decay = True
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))
        self.D._no_weight_decay = True
        # Separate output projections for the two tensors returned by hm_fn.
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.out_cathree_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)


        # Ablations: zero the projection and disable its gradients so it stays zero.
        if ablate_mf:
            print(f"Layer {self.layer_idx}: Ablating 'mf' branch. Setting mf_proj parameters to zero and freezing.")
            with torch.no_grad():
                self.mf_proj.weight.zero_()
                if self.mf_proj.bias is not None:
                    self.mf_proj.bias.zero_()
            self.mf_proj.weight.requires_grad = False
            if self.mf_proj.bias is not None:
                self.mf_proj.bias.requires_grad = False

        if ablate_y2:
            print(f"Layer {self.layer_idx}: Ablating 'y2' branch. Setting out_cathree_proj parameters to zero and freezing.")
            with torch.no_grad():
                self.out_cathree_proj.weight.zero_()
                if self.out_cathree_proj.bias is not None:
                    self.out_cathree_proj.bias.zero_()
            self.out_cathree_proj.weight.requires_grad = False
            if self.out_cathree_proj.bias is not None:
                self.out_cathree_proj.bias.requires_grad = False

    def forward(self, hidden_states, inference_params=None):
        """Run the block over a full sequence.

        Args:
            hidden_states: 3-D tensor unpacked as ``(batch, seqlen, dim)``.
            inference_params: Accepted but not used by this implementation.

        Returns:
            A tuple ``(out, y1)``. ``out`` is
            ``out_proj(y1) + out_cathree_proj(y2)``; ``y1`` is the first tensor
            returned by ``hm_fn``, rearranged to ``(batch, seqlen, channels)``.
        """
        batch, seqlen, dim = hidden_states.shape
        # Apply in_proj as an explicit matmul so the result is already laid out
        # channels-first ("b d l") for the convolution and scan calls below.
        xz_gc = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l", l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz_gc = xz_gc + rearrange(self.in_proj.bias.to(dtype=xz_gc.dtype), "d -> d 1")
        A = -torch.exp(self.A_log.float())
        x, z, gc = torch.split(xz_gc, [self.d_inner, self.d_inner, self.d_gc], dim=1)
        # Depthwise causal convolutions with SiLU, using the Conv1d parameters.
        x = causal_conv1d_fn(
            x=x, weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
            bias=self.conv1d.bias, activation=self.activation,
        )
        gc = causal_conv1d_fn(
            gc, weight=rearrange(self.conv1d_gc.weight, "d 1 w -> d w"),
            bias=self.conv1d_gc.bias, activation=self.activation,
        )
        # Project the gc branch to d_inner channels; all zeros when ablate_mf was set.
        mf = self.mf_proj(rearrange(gc, "b d l -> (b l) d"))
        mf = rearrange(mf, "(b l) d -> b d l", l=seqlen)
        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # NOTE(review): dt_proj(dt) already adds dt_proj.bias, and the same bias is
        # passed again as delta_bias below. If hm_fn adds delta_bias internally the
        # bias is applied twice - confirm against the hm_fn implementation.
        dt = self.dt_proj(dt)
        dt = rearrange(dt, "(b l) d -> b d l", l=seqlen)
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        # hm_fn is an external kernel; it is unpacked below as a pair (y1, y2).
        y = hm_fn(x, dt, mf, A, B, C, self.D.float(), z=z,
                  delta_bias=self.dt_proj.bias.float(), delta_softplus=True)
        y1, y2 = y
        y1 = rearrange(y1, "b d l -> b l d")
        y2 = rearrange(y2, "b d l -> b l d")
        out = self.out_proj(y1) + self.out_cathree_proj(y2)
        return out, y1

class NeuroMambaModelForAnalysis(nn.Module):
    """Embedding -> one NeuroMamba block -> LM head with tied weights.

    The LM head shares its weight matrix with the token embedding. Besides the
    logits, ``forward`` also returns the second output of the NeuroMamba block
    so callers can analyse it.

    Extra keyword arguments are accepted and ignored, which lets a config dict
    holding additional keys be splatted into the constructor.
    """

    def __init__(self, vocab_size, d_model, d_state, d_conv, d_conv_gc, expand, expand_gc,
                 ablate_mf: bool = False, ablate_y2: bool = False, **kwargs):
        super().__init__()
        block_options = dict(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            d_conv_gc=d_conv_gc,
            expand=expand,
            expand_gc=expand_gc,
            ablate_mf=ablate_mf,
            ablate_y2=ablate_y2,
        )
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.neuromamba_block = NeuroMamba(**block_options)

        # Tie the output head to the embedding matrix before registering it.
        tied_head = nn.Linear(d_model, vocab_size, bias=False)
        tied_head.weight = self.embedding.weight
        self.lm_head = tied_head

    def forward(self, input_ids):
        """Return ``(logits, hidden_states)`` for a batch of token ids."""
        embedded = self.embedding(input_ids)
        block_output, block_hidden = self.neuromamba_block(embedded)
        return self.lm_head(block_output), block_hidden


def corr_finder(hidden_all, test_trials, tr_len, hidden_size):
    """Correlate the trial-averaged activations of trial type 0 and type 1.

    Args:
        hidden_all: Tensor of activations indexed as ``[time, unit]``; trial
            ``k`` occupies rows ``k * tr_len`` to ``(k + 1) * tr_len``.
        test_trials: Array of per-trial labels (0 or 1).
        tr_len: Number of time steps in one trial.
        hidden_size: Number of units per time step.

    Returns:
        The ``(2 * tr_len, 2 * tr_len)`` matrix from ``np.corrcoef`` over the
        two trial-averaged ``(tr_len, hidden_size)`` activation matrices. If
        either trial type is absent, an identity matrix of that size.
    """
    activations = hidden_all.cpu().detach().numpy()
    type0_ids = np.nonzero(test_trials == 0)[0]
    type1_ids = np.nonzero(test_trials == 1)[0]

    # Without both trial types there is nothing to correlate.
    if type0_ids.size == 0 or type1_ids.size == 0:
        return np.eye(tr_len * 2)

    def trial_average(trial_ids):
        # Stack each trial's (tr_len, hidden_size) window on a trailing axis,
        # then average across trials.
        stacked = np.zeros((tr_len, hidden_size, len(trial_ids)))
        for slot, trial in enumerate(trial_ids):
            first_row = trial * tr_len
            stacked[:, :, slot] = activations[first_row:first_row + tr_len, :]
        return stacked.mean(axis=2)

    avg_type0 = trial_average(type0_ids)
    avg_type1 = trial_average(type1_ids)
    return np.corrcoef(avg_type0, avg_type1)

def print_model_parameters(model):
    """Print the total and trainable parameter counts of ``model`` in millions.

    Counts come from ``model.parameters()``; a parameter is counted as
    trainable when its ``requires_grad`` flag is set.
    """
    total_count = 0
    trainable_count = 0
    for param in model.parameters():
        size = param.numel()
        total_count += size
        if param.requires_grad:
            trainable_count += size

    print(f"total_params: {total_count / 1e6:.2f} M")
    print(f"trainable_params: {trainable_count / 1e6:.2f} M (Note: reflects trainable status and tied weights)")

def _reward_prediction_accuracy(predicted_tokens, actual_tokens, reward_token=6):
    """Fraction of next-step reward tokens that the model predicted.

    Position ``i`` counts as a hit when ``actual_tokens[i + 1]`` is the reward
    token and ``predicted_tokens[i]`` (the prediction for step ``i + 1``) is
    too. If the sequence has no reward token, the result is 1.0 when the model
    predicted none either, and 0.0 otherwise.
    """
    actual_is_reward = actual_tokens[1:] == reward_token
    predicted_is_reward = predicted_tokens[:-1] == reward_token
    num_actual = int(actual_is_reward.sum().item())
    if num_actual == 0:
        return 0.0 if bool(predicted_is_reward.any().item()) else 1.0
    num_hits = int((actual_is_reward & predicted_is_reward).sum().item())
    return num_hits / num_actual


def run_simulation(gpu_id, seeds, shared_dict, ablation_setting, verbose=True):
    """Train one model per seed on a single GPU and publish the metric curves.

    For every seed a fresh ``NeuroMambaModelForAnalysis`` is trained for
    ``training_config['epochs']`` steps, one freshly generated trial sequence
    per step. At every epoch where ``epoch % save_interval == 0`` the model is
    evaluated on the held-out part of that sequence, and the training loss,
    reward-prediction accuracy and the ``corr_finder`` matrix are recorded.

    Args:
        gpu_id: CUDA device index; also the key under which results are stored.
        seeds: Sequence of integer seeds, one simulation each.
        shared_dict: Mapping shared with the parent process. On success
            ``shared_dict[gpu_id]`` holds ``loss_all``, ``corr_curve`` and
            ``accuracy_curve_all_test``. On failure it holds ``error``,
            ``error_msg`` and ``traceback``.
        ablation_setting: One of ``'none'``, ``'mf'``, ``'y2'``, ``'both'``.
        verbose: Print progress information.

    Any exception is caught, printed and reported through ``shared_dict``
    rather than propagated, so a failing worker does not take down the parent.
    """
    try:
        torch.cuda.set_device(gpu_id)
        device = torch.device(f"cuda:{gpu_id}")
        if verbose:
            print(f"[GPU {gpu_id}] starting. Device: {device}. Seeds: {seeds}. Ablation: '{ablation_setting}'")

        epochs = training_config['epochs']
        lr = training_config['learning_rate']
        save_interval = training_config['save_interval']
        tr_len = dataset_config['trial_length']
        num_train_trials = dataset_config['num_train_trials']
        vocab_size = dataset_config['vocab_size']
        hidden_size_for_corr = model_config['hidden_size_for_corr']

        # Checkpoints are taken at epochs 0, save_interval, 2 * save_interval, ...
        # Count exactly those epochs. `epochs // save_interval` is one short
        # whenever epochs is not a multiple of save_interval, which overflowed
        # the result arrays.
        num_saves = len(range(0, epochs, save_interval))
        loss_all = np.zeros((len(seeds), num_saves))
        corr_curve = np.zeros((len(seeds), num_saves, tr_len * 2, tr_len * 2))
        accuracy_curve = np.zeros((len(seeds), num_saves))

        # The ablation flags do not depend on the seed.
        ablate_mf = ablation_setting in ('mf', 'both')
        ablate_y2 = ablation_setting in ('y2', 'both')
        split_point = num_train_trials * tr_len

        for idx, seed in enumerate(seeds):
            if verbose:
                print(f"[GPU {gpu_id}] Sim {idx+1}/{len(seeds)}, Seed: {seed}")
            torch.manual_seed(int(seed))
            np.random.seed(int(seed))

            model = NeuroMambaModelForAnalysis(**model_config, ablate_mf=ablate_mf, ablate_y2=ablate_y2).to(device)

            if idx == 0 and verbose:
                print("-" * 40)
                print(f"MODEL CONFIG FOR ABLATION: '{ablation_setting.upper()}'")
                print_model_parameters(model)
                print(f"Hidden State size for correlation (d_inner): {hidden_size_for_corr}")
                print("-" * 40)

            loss_func = nn.CrossEntropyLoss()
            optimizer = optim.AdamW(model.parameters(), lr=lr)

            save_idx = 0
            # Reported as -1.0 until the first evaluation of this seed has run.
            accuracy = -1.0
            for epoch in range(epochs):
                x_int, trials = generate_trial_sequence(dataset_config)
                train_data = torch.tensor(x_int[:split_point], dtype=torch.long, device=device).unsqueeze(0)
                test_data = torch.tensor(x_int[split_point:], dtype=torch.long, device=device).unsqueeze(0)

                # Next-token prediction on the training part of the sequence.
                model.train()
                optimizer.zero_grad()
                prediction_logits, _ = model(train_data[:, :-1])
                loss = loss_func(prediction_logits.view(-1, vocab_size), train_data[:, 1:].view(-1))
                loss.backward()
                optimizer.step()

                if epoch % save_interval == 0:
                    model.eval()
                    with torch.no_grad():
                        pred_test_logits, hidden_states_y1 = model(test_data)
                        predicted_tokens = pred_test_logits.argmax(dim=-1).squeeze(0)
                        actual_tokens = test_data.squeeze(0)
                        accuracy = _reward_prediction_accuracy(predicted_tokens, actual_tokens)

                        accuracy_curve[idx, save_idx] = accuracy
                        corr_curve[idx, save_idx] = corr_finder(
                            hidden_states_y1.squeeze(0), trials[num_train_trials:], tr_len, hidden_size_for_corr
                        )
                        loss_all[idx, save_idx] = loss.item()
                        save_idx += 1

                if epoch % 50 == 0 and verbose:
                    print(f"[GPU {gpu_id}] Sim {idx+1}, Epoch {epoch}, Loss: {loss.item():.4f}, Acc: {accuracy:.4f}")

        shared_dict[gpu_id] = {'loss_all': loss_all, 'corr_curve': corr_curve, 'accuracy_curve_all_test': accuracy_curve}
        if verbose:
            print(f"[GPU {gpu_id}] finished and stored results.")
    except Exception as e:
        # Worker boundary: report the failure to the parent instead of raising.
        tb = traceback.format_exc()
        print(f"[GPU {gpu_id}] ERROR: {e}\n{tb}")
        shared_dict[gpu_id] = {'error': True, 'error_msg': str(e), 'traceback': tb}

def _safe_concatenate(list_of_arrays, axis=0):
    return np.concatenate(list_of_arrays, axis=axis) if list_of_arrays else None


if __name__ == "__main__":
    # Entry point: fan the seeds out over all visible GPUs (one worker process
    # per GPU), gather the per-worker metric arrays, then save and plot them.
    parser = argparse.ArgumentParser(description="Run NeuroMamba Ablation Experiments on OSM task.")
    parser.add_argument('--ablation', type=str, default='none', choices=['none', 'mf', 'y2', 'both'], help="Specify the ablation to run.")
    args = parser.parse_args()

    print(f"\n{'='*50}\n--- STARTING EXPERIMENT RUN: ABLATION = {args.ablation.upper()} ---\n{'='*50}\n")

    mp.set_start_method("spawn", force=True)
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available.")

    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs found: {num_gpus}")

    total_sims = simulation_config['total_simulations']
    base_seed = simulation_config['base_seed']
    all_seeds = np.arange(base_seed, base_seed + total_sims)

    manager = mp.Manager()
    shared_dict = manager.dict()
    workers = []  # (gpu_id, process) pairs in ascending gpu_id order
    # array_split hands the first (total_sims % num_gpus) GPUs one extra seed.
    for gpu_id, gpu_seeds in enumerate(np.array_split(all_seeds, num_gpus)):
        if len(gpu_seeds) == 0:
            continue
        proc = mp.Process(target=run_simulation, args=(gpu_id, gpu_seeds, shared_dict, args.ablation))
        proc.start()
        workers.append((gpu_id, proc))

    for _, proc in workers:
        proc.join()

    loss_list, corr_list, acc_list = [], [], []
    failed_gpus = []
    for gpu_id, _ in workers:
        res = shared_dict.get(gpu_id)
        # A worker either reports an error dict or, if it died hard, nothing at all.
        if res is None or res.get('error'):
            failed_gpus.append(gpu_id)
            continue
        loss_list.append(res['loss_all'])
        corr_list.append(res['corr_curve'])
        acc_list.append(res['accuracy_curve_all_test'])
    if not loss_list:
        raise RuntimeError("No simulation results produced.")
    if failed_gpus:
        # Make partial aggregation visible instead of silently averaging fewer seeds.
        print(f"WARNING: no results from worker(s) on GPU {failed_gpus}; aggregating the remaining workers only.")

    combined_loss = _safe_concatenate(loss_list, axis=0)
    combined_corr = _safe_concatenate(corr_list, axis=0)
    combined_accuracy = _safe_concatenate(acc_list, axis=0)

    now = datetime.now()
    dir_prefix = now.strftime("%m%d")
    dir_name = f"{dir_prefix}_ablation-{args.ablation}"
    os.makedirs(dir_name, exist_ok=True)

    timestamp = now.strftime("%Y-%m-%d-%H-%M-%S")

    np.save(os.path.join(dir_name, paths_config['npy_filename_pattern'].format(metric='corr_curve', timestamp=timestamp)), combined_corr)
    np.save(os.path.join(dir_name, paths_config['npy_filename_pattern'].format(metric='accuracy_curve_all_test', timestamp=timestamp)), combined_accuracy)
    np.save(os.path.join(dir_name, paths_config['npy_filename_pattern'].format(metric='loss_all', timestamp=timestamp)), combined_loss)

    print(f"\n--- Training Complete ---\nResults for '{args.ablation.upper()}' saved to '{dir_name}'")
    print(f"Final Average Loss: {np.mean(combined_loss[:, -1]):.4f}, Final Average Accuracy: {np.mean(combined_accuracy[:, -1]):.4f}")

    # Seed-averaged loss and accuracy curves, side by side.
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(np.mean(combined_loss, axis=0))
    plt.title(f"Average Loss Curve (Ablation: {args.ablation.upper()})")
    plt.xlabel(f"Epochs (x{training_config['save_interval']})")
    plt.ylabel("Loss")
    plt.subplot(1, 2, 2)
    plt.plot(np.mean(combined_accuracy, axis=0))
    plt.title(f"Average Accuracy Curve (Ablation: {args.ablation.upper()})")
    plt.xlabel(f"Epochs (x{training_config['save_interval']})")
    plt.ylabel("Accuracy")

    plot_path = paths_config['plot_filename_pattern'].format(timestamp=timestamp)
    plt.tight_layout()
    plt.savefig(os.path.join(dir_name, plot_path))
    plt.show()
    print(f"Training visualization saved to {os.path.join(dir_name, plot_path)}")