# train_ablation.py

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import logging
import time
import csv
import os
import copy
from functools import partial

from config_ablation import training_config, dataset_config, neuma_config
from data_generator import generate_dataset

from neuromamba.models.mixer_seq_simple import NeuroMambaLMHeadModel, Block
from neuromamba.modules.mlp import GatedMLP
try:
    from neuromamba.ops.triton.layer_norm import RMSNorm
except ImportError:
    RMSNorm = None

from neuromamba_ab import NeuroMamba_ab

def create_block_ablation(
    d_model,
    expand_gc,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    """Build one ``Block`` whose mixer is the ablation-aware ``NeuroMamba_ab``.

    The ablation switches ``ablate_gc`` and ``ablate_y2`` are read from the
    module-level ``neuma_config``; a missing attribute counts as ``False``.

    Args:
        d_model: Model width; passed to ``Block`` and used as the MLP's
            ``out_features``.
        expand_gc: Forwarded unchanged to ``NeuroMamba_ab``.
        d_intermediate: ``hidden_features`` of the ``GatedMLP``. ``0`` selects
            ``nn.Identity`` as the MLP class instead.
        ssm_cfg: Optional dict of extra keyword arguments for the mixer. It is
            deep-copied, so the caller's dict is never mutated.
        attn_layer_idx: Accepted but not used by this factory (presumably kept
            so the signature stays close to the factory it replaces).
        attn_cfg: Accepted but not used by this factory.
        norm_epsilon: ``eps`` passed to the normalization layer.
        rms_norm: Use ``RMSNorm`` instead of ``nn.LayerNorm``.
        residual_in_fp32: Forwarded to ``Block``.
        fused_add_norm: Forwarded to ``Block``.
        layer_idx: Forwarded to the mixer and stored on the returned block.
        device: Forwarded to the mixer, norm and MLP constructors.
        dtype: Forwarded to the mixer, norm and MLP constructors.

    Returns:
        The constructed ``Block`` with its ``layer_idx`` attribute set.

    Raises:
        ImportError: If ``rms_norm`` is requested but ``RMSNorm`` could not be
            imported when this module was loaded.
    """
    # RMSNorm is an optional import that falls back to None. Without this
    # guard, partial(None, ...) fails with an opaque
    # "the first argument must be callable" TypeError.
    if rms_norm and RMSNorm is None:
        raise ImportError(
            "rms_norm=True requires neuromamba.ops.triton.layer_norm.RMSNorm, "
            "which could not be imported."
        )

    ablate_gc = getattr(neuma_config, 'ablate_gc', False)
    ablate_y2 = getattr(neuma_config, 'ablate_y2', False)

    factory_kwargs = {"device": device, "dtype": dtype}
    ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}

    mixer_cls = partial(
        NeuroMamba_ab,
        expand_gc=expand_gc,
        layer_idx=layer_idx,
        ablate_gc=ablate_gc,
        ablate_y2=ablate_y2,
        **ssm_cfg,
        **factory_kwargs,
    )

    norm_cls = partial(
        RMSNorm if rms_norm else nn.LayerNorm, eps=norm_epsilon, **factory_kwargs
    )

    if d_intermediate == 0:
        # Pass the class itself (not a partial): no MLP is configured.
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
        )

    block = Block(
        d_model, mixer_cls, mlp_cls, norm_cls=norm_cls,
        fused_add_norm=fused_add_norm, residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block

# -----------------------------------------------------------------------------
# --- Monkey patching ---
# Swap the block factory exposed by neuromamba.models.mixer_seq_simple for
# create_block_ablation. The original factory is kept in original_create_block
# so it can be put back later.
# NOTE: this runs at import time, so merely importing this file applies the
# patch. Presumably the model classes look up `create_block` through the
# module's globals when they are constructed -- confirm against
# mixer_seq_simple before relying on it.
# -----------------------------------------------------------------------------
import neuromamba.models.mixer_seq_simple as mixer_module
original_create_block = mixer_module.create_block 
mixer_module.create_block = create_block_ablation


def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed handed to every random number generator.
    """
    # The seeding calls are independent of one another, so apply them in turn.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN for deterministic kernels and turn off its autotuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def _ablation_tag(config):
    """Return the short tag naming the ablations enabled on *config*.

    Missing flags are treated as disabled, the same way create_block_ablation
    reads them, so a config without the attributes does not raise.
    """
    ablate_gc = getattr(config, 'ablate_gc', False)
    ablate_y2 = getattr(config, 'ablate_y2', False)
    if ablate_gc and ablate_y2:
        return "ab_gc_y2"
    if ablate_gc:
        return "ab_gc"
    if ablate_y2:
        return "ab_y2"
    return "base"


def _evaluate(model, criterion, device, eval_steps):
    """Score *model* on ``eval_steps`` freshly generated batches.

    Returns:
        A ``(average_loss, accuracy_percent)`` tuple. The loss is averaged
        over batches; the accuracy is computed over all target elements.
    """
    model.eval()
    loss_sum = 0
    correct = 0
    seen = 0
    with torch.no_grad():
        for _ in range(eval_steps):
            val_inputs, val_targets = generate_dataset(dataset_config, training_config)
            val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
            val_outputs = model(val_inputs, num_last_tokens=dataset_config['l_memorize']).logits
            val_loss = criterion(val_outputs.view(-1, neuma_config.vocab_size), val_targets.view(-1))
            loss_sum += val_loss.item()
            # NOTE(review): the comparison assumes argmax over the last axis
            # yields the same shape as val_targets -- confirm against
            # generate_dataset.
            preds = val_outputs.argmax(dim=-1)
            correct += (preds == val_targets).sum().item()
            seen += val_targets.numel()
    return loss_sum / eval_steps, 100 * correct / seen


def run_experiment():
    """Run the complete training and evaluation pipeline for the NeuroMamba model.

    Steps, in order:
      1. Seed all RNGs with 42 and configure root logging at INFO level.
      2. Build ``NeuroMambaLMHeadModel`` from ``neuma_config`` on ``cuda:0``
         when available (otherwise CPU), with cross-entropy loss and Adam.
      3. Create ``results/log_<optimizer>_<ablation>_gc<expand_gc>_lr<lr>_<timestamp>.csv``
         containing the header ``step,loss,accuracy``.
      4. Train for ``training_config["num_steps"]`` steps on freshly generated
         batches, logging the training loss every 200 steps.
      5. Every 1000 steps, and on the final step, evaluate on 100 generated
         batches and append one row to the CSV.

    Returns:
        None. Results are written to the log and to the CSV file.
    """
    set_seed(42)

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    logger.info(f'Using device: {device}')

    model = NeuroMambaLMHeadModel(neuma_config, device=device)
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    learning_rate = training_config["learning_rate"]
    num_steps = training_config["num_steps"]
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    optimizer_name = optimizer.__class__.__name__.lower()

    eval_interval = 1000
    eval_steps = 100

    output_dir = "results"
    os.makedirs(output_dir, exist_ok=True)

    timestamp = time.strftime("%Y%m%d-%H%M%S")

    # Read the flags with a False default so a config lacking them cannot
    # crash here after the model has already been built.
    ablation_str = _ablation_tag(neuma_config)

    csv_filename = (f"log_{optimizer_name}_{ablation_str}_gc{neuma_config.expand_gc}_"
                    f"lr{learning_rate:.0e}_{timestamp}.csv")
    csv_filepath = os.path.join(output_dir, csv_filename)

    csv_fieldnames = ['step', 'loss', 'accuracy']

    # Write the header once; rows are appended after each evaluation so that
    # partial results survive an interrupted run.
    with open(csv_filepath, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=csv_fieldnames)
        writer.writeheader()
    logger.info(f"Evaluation results will be saved to: {csv_filepath}")

    start_time = time.time()
    logger.info("--- Starting Training ---")

    for step in range(1, num_steps + 1):
        # Evaluation switches the model to eval mode, so re-enter train mode
        # at the start of every step.
        model.train()
        inputs, targets = generate_dataset(dataset_config, training_config)
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs, num_last_tokens=dataset_config['l_memorize']).logits
        loss = criterion(outputs.view(-1, neuma_config.vocab_size), targets.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 200 == 0:
            logger.info(f'Step [{step}/{num_steps}], Training Loss: {loss.item():.4f}')

        if step % eval_interval == 0 or step == num_steps:
            logger.info(f"--- Step {step}: Starting evaluation... ---")
            avg_eval_loss, avg_eval_accuracy = _evaluate(model, criterion, device, eval_steps)

            logger.info(f'>>> Evaluation at step [{step}]: '
                        f'Average Validation Loss: {avg_eval_loss:.4f}, '
                        f'Validation Accuracy: {avg_eval_accuracy:.2f}% <<<')

            with open(csv_filepath, 'a', newline='') as f:
                writer = csv.DictWriter(f, fieldnames=csv_fieldnames)
                writer.writerow({
                    'step': step, 'loss': f"{avg_eval_loss:.4f}", 'accuracy': f"{avg_eval_accuracy:.2f}"
                })

    end_time = time.time()
    total_duration = (end_time - start_time) / 60
    logger.info(f"--- Training and evaluation finished. Total duration: {total_duration:.2f} minutes ---")

# Script entry point: run the experiment, then undo the monkey patch applied
# at import time. Note that the restore only happens on this path; importing
# the module without running it as a script leaves the patch in place.
if __name__ == '__main__':
    try:
        run_experiment()
    finally:
        # Executed whether run_experiment() returned or raised, so the
        # original factory is always reinstated before the script exits.
        mixer_module.create_block = original_create_block
        print("Restored original create_block function.")