# train_ablation.py

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import logging
import time
import csv
import os
from tqdm import tqdm
from functools import partial
import copy

from config_ablation import training_config, dataset_config, neuromamba_config
from data_generator import generate_dataset


from neuromamba.models.mixer_seq_simple import NeuroMambaLMHeadModel, Block
from neuromamba.modules.mlp import GatedMLP
try:
    from neuromamba.ops.triton.layer_norm import RMSNorm
except ImportError:
    RMSNorm = None

from neuromamba_ab import NeuroMamba_ab


def create_block_ablation(
    d_model,
    expand_gc,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    """Build one ``Block`` whose mixer is the ablation-aware ``NeuroMamba_ab``.

    Args:
        d_model: Model width, passed to ``Block`` and used as the MLP
            ``out_features``.
        expand_gc: Forwarded to ``NeuroMamba_ab`` as ``expand_gc``.
        d_intermediate: Hidden width of the ``GatedMLP``; ``0`` replaces the
            MLP with ``nn.Identity``.
        ssm_cfg: Optional dict of extra keyword arguments for
            ``NeuroMamba_ab``. It is deep-copied, never mutated. If it
            contains ``ablate_gc`` / ``ablate_y2`` those values take
            precedence over the ones read from ``neuromamba_config``.
        attn_layer_idx: Accepted for signature compatibility; not used here.
        attn_cfg: Accepted for signature compatibility; not used here.
        norm_epsilon: ``eps`` for the normalisation layer.
        rms_norm: Use ``RMSNorm`` instead of ``nn.LayerNorm``.
        residual_in_fp32: Forwarded to ``Block``.
        fused_add_norm: Forwarded to ``Block``.
        layer_idx: Forwarded to the mixer and stored on the returned block.
        device: Factory device for all sub-modules.
        dtype: Factory dtype for all sub-modules.

    Returns:
        The constructed ``Block`` with ``block.layer_idx`` set.

    Raises:
        ImportError: If ``rms_norm`` is requested but ``RMSNorm`` could not be
            imported.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    # Private copy so the caller's configuration dict is left untouched.
    ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}

    # The ablation flags default to the global config, but a value carried in
    # ssm_cfg wins. Popping them is essential: passing the same keyword both
    # explicitly and through **ssm_cfg makes partial() raise TypeError.
    ablate_gc = ssm_cfg.pop(
        'ablate_gc', getattr(neuromamba_config, 'ablate_gc', False)
    )
    ablate_y2 = ssm_cfg.pop(
        'ablate_y2', getattr(neuromamba_config, 'ablate_y2', False)
    )

    mixer_cls = partial(
        NeuroMamba_ab,
        expand_gc=expand_gc,
        layer_idx=layer_idx,
        ablate_gc=ablate_gc,
        ablate_y2=ablate_y2,
        **ssm_cfg,
        **factory_kwargs
    )

    if rms_norm and RMSNorm is None:
        # Fail with an actionable message instead of partial(None, ...)'s
        # generic "the first argument must be callable".
        raise ImportError(
            "rms_norm=True requires neuromamba.ops.triton.layer_norm.RMSNorm, "
            "which could not be imported"
        )
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )

    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
        )

    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block



# Monkey-patch: rebind the module-level ``create_block`` of mixer_seq_simple to
# the ablation variant defined above. The original callable is saved first so
# it can be put back later (the ``__main__`` section does this in ``finally``).
# NOTE(review): this only has an effect if the model code looks up
# ``create_block`` through the module globals at call time - confirm.
import neuromamba.models.mixer_seq_simple as mixer_module
original_create_block = mixer_module.create_block 
mixer_module.create_block = create_block_ablation

class CustomNeuroMambaLMHeadModel(NeuroMambaLMHeadModel):
    """``NeuroMambaLMHeadModel`` that passes the ablation flags via ``ssm_cfg``.

    During construction ``config.ssm_cfg`` is temporarily replaced by a copy
    that additionally holds ``ablate_gc`` and ``ablate_y2`` (read from
    ``config``). The caller's original ``ssm_cfg`` object is put back
    afterwards, even if the parent constructor raises.
    """

    def __init__(self, config, **kwargs):
        """Initialise the model.

        Args:
            config: Model config exposing ``ssm_cfg`` (dict or ``None``),
                ``ablate_gc`` and ``ablate_y2``.
            **kwargs: Forwarded unchanged to ``NeuroMambaLMHeadModel``.
        """
        original_ssm_cfg = config.ssm_cfg

        # Work on a fresh dict: mutating config.ssm_cfg in place would also
        # mutate ``original_ssm_cfg`` (same object), turning the restore
        # below into a no-op and leaking the ablation keys to the caller.
        patched_ssm_cfg = dict(original_ssm_cfg) if original_ssm_cfg else {}
        patched_ssm_cfg['ablate_gc'] = config.ablate_gc
        patched_ssm_cfg['ablate_y2'] = config.ablate_y2
        config.ssm_cfg = patched_ssm_cfg

        try:
            super().__init__(config, **kwargs)
        finally:
            # Always hand the caller's config back in its original state.
            config.ssm_cfg = original_ssm_cfg

def set_seed(seed=42):
    """Seed all random number generators used by this script.

    Seeds torch (CPU and every CUDA device), NumPy and the stdlib ``random``
    module with the same value, then sets ``cudnn.deterministic`` to ``True``
    and ``cudnn.benchmark`` to ``False``.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Script-wide logging: INFO-level records formatted as
# "<timestamp> - <level> - <message>", emitted through the root logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger()

def _masked_accuracy(logits, targets, ignore_index=-100):
    """Return the argmax accuracy (in percent) over non-ignored target positions."""
    predictions = logits.argmax(dim=-1)
    mask = targets != ignore_index
    total = mask.sum().item()
    if total == 0:
        return 0.0
    correct = (predictions[mask] == targets[mask]).sum().item()
    return 100 * correct / total


def _flattened_loss(criterion, logits, targets):
    """Apply ``criterion`` to logits flattened to (N, C) and targets flattened to (N,)."""
    # The class dimension is taken from the logits themselves rather than from
    # the config: a model that pads its vocabulary emits more classes than
    # config.vocab_size, which would make a config-based reshape fail or
    # silently misalign logits and targets.
    return criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))


def _write_csv_header(path, fieldnames):
    """Create (or truncate) ``path`` and write the CSV header row."""
    with open(path, 'w', newline='') as f:
        csv.DictWriter(f, fieldnames=fieldnames).writeheader()


def _append_csv_row(path, fieldnames, row):
    """Append one row to ``path``; the file is closed again right away."""
    with open(path, 'a', newline='') as f:
        csv.DictWriter(f, fieldnames=fieldnames).writerow(row)


def run_experiment(model, optimizer, criterion, validation_sets, device, logger,
                   training_csv_file, validation_csv_file):
    """Train ``model`` and periodically evaluate it, logging metrics to CSV.

    Each step draws a fresh batch from ``generate_dataset``, performs one
    optimisation step and appends loss / accuracy to ``training_csv_file``.
    Every ``training_config['eval_interval']`` steps, and on the final step,
    the model is evaluated on every entry of ``validation_sets`` and the
    results are appended to ``validation_csv_file``.

    Args:
        model: Module whose call result exposes ``.logits``.
        optimizer: Optimiser stepping ``model``'s parameters.
        criterion: Loss taking flattened logits (N, C) and targets (N,).
        validation_sets: Mapping ``sequence_length -> (inputs, targets)``.
            The tensors are used as-is; they are not moved to ``device`` here.
        device: Device the training batches are moved to.
        logger: Logger used for progress messages.
        training_csv_file: Path of the per-step training CSV (overwritten).
        validation_csv_file: Path of the validation CSV (overwritten).

    Target positions equal to ``-100`` are excluded from the accuracy.
    """
    logger.info("--- Starting Experiment ---")
    start_time = time.time()

    train_fieldnames = ['step', 'loss', 'training_accuracy']
    val_fieldnames = ['step', 'sequence_length', 'loss', 'validation_accuracy']
    _write_csv_header(training_csv_file, train_fieldnames)
    _write_csv_header(validation_csv_file, val_fieldnames)

    num_steps = training_config["num_steps"]
    progress_bar = tqdm(range(1, num_steps + 1), desc="Training")

    for step in progress_bar:
        model.train()
        inputs, targets = generate_dataset(dataset_config, training_config)
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs).logits
        loss = _flattened_loss(criterion, outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            train_accuracy = _masked_accuracy(outputs, targets)
        # Read the scalar once; every .item() call forces a device sync.
        loss_value = loss.item()

        progress_bar.set_postfix(loss=f"{loss_value:.4f}", train_acc=f"{train_accuracy:.2f}%")

        # Appending row by row keeps the CSV usable if the run is interrupted.
        _append_csv_row(
            training_csv_file,
            train_fieldnames,
            {'step': step, 'loss': loss_value, 'training_accuracy': train_accuracy},
        )

        if step % training_config['eval_interval'] == 0 or step == num_steps:
            model.eval()
            logger.info(f"\n--- Validation at Step {step} ---")
            with torch.no_grad():
                for sl, (eval_inputs, eval_targets) in validation_sets.items():
                    eval_outputs = model(eval_inputs).logits
                    val_loss = _flattened_loss(criterion, eval_outputs, eval_targets).item()
                    val_accuracy = _masked_accuracy(eval_outputs, eval_targets)
                    logger.info(f"  SeqLen {sl:<7} -> Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")
                    _append_csv_row(
                        validation_csv_file,
                        val_fieldnames,
                        {'step': step, 'sequence_length': sl, 'loss': val_loss, 'validation_accuracy': val_accuracy},
                    )
            logger.info("---------------------------------")

    end_time = time.time()
    logger.info(f"--- Experiment completed in {(end_time - start_time) / 60:.2f} minutes ---")

def get_validation_sets(device):
    """Generate one fixed validation batch per configured sequence length.

    Iterates over ``dataset_config['eval_seq_lens']``. The batch size is 1 for
    lengths above 32768, 4 for lengths above 8192, and
    ``training_config['batch_size']`` otherwise. A length whose batch triggers
    a CUDA out-of-memory error is skipped with a warning.

    Args:
        device: Device the generated tensors are moved to.

    Returns:
        Dict mapping sequence length to an ``(inputs, targets)`` tuple.
    """
    logger.info("Generating fixed validation sets...")
    per_length = {}

    for seq_len in dataset_config['eval_seq_lens']:
        # Longer sequences get smaller batches.
        if seq_len > 32768:
            batch = 1
        elif seq_len > 8192:
            batch = 4
        else:
            batch = training_config['batch_size']

        try:
            eval_inputs, eval_targets = generate_dataset(
                dataset_config, {"batch_size": batch}, seq_len=seq_len
            )
            per_length[seq_len] = (eval_inputs.to(device), eval_targets.to(device))
        except torch.cuda.OutOfMemoryError:
            logger.warning(f"OOM creating validation set for SeqLen {seq_len}. Skipping.")
            torch.cuda.empty_cache()

    logger.info(f"Validation sets created for sequence lengths: {list(per_length)}")
    return per_length
    
if __name__ == '__main__':
    # Script entry point: seed, build model/optimizer, derive the output file
    # names from the configuration, create validation data and run training.
    set_seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    logger.info("Initializing Model and Optimizer...")

    model = NeuroMambaLMHeadModel(neuromamba_config).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=training_config["learning_rate"], weight_decay=training_config["weight_decay"])
    criterion = nn.CrossEntropyLoss()

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params_in_k = round(total_params / 1000)
    param_str = f"{params_in_k}K"
    logger.info(f"Model created. Total trainable parameters: {total_params:,} (~{param_str})")

    # Difficulty tag; level 4 additionally encodes its noise type as a sub-level.
    level = dataset_config.get('difficulty_level', 0)
    level_str = f"lv{level}"
    if level == 4:
        noise_type = dataset_config.get('level_4_noise_type', 'none')
        sublevel_map = {'none': 0, 'between': 1, 'conflict': 2}
        sublevel = sublevel_map.get(noise_type, 0)
        level_str += f"_{sublevel}"

    # Read the flags with the same "missing means False" rule that
    # create_block_ablation applies, so a config without these attributes
    # yields a "base" run instead of an AttributeError here.
    ablate_gc = getattr(neuromamba_config, 'ablate_gc', False)
    ablate_y2 = getattr(neuromamba_config, 'ablate_y2', False)
    if ablate_gc and ablate_y2:
        ablation_str = "ab_gc_y2"
    elif ablate_gc:
        ablation_str = "ab_gc"
    elif ablate_y2:
        ablation_str = "ab_y2"
    else:
        ablation_str = "base"

    n_layer = neuromamba_config.n_layer
    d_model = neuromamba_config.d_model
    arch_str = f"{n_layer}_{d_model}"

    timestamp = time.strftime("%Y%m%d-%H%M%S")

    base_filename = f"ih_{level_str}_{ablation_str}_{param_str}_{arch_str}_{timestamp}"
    training_csv_filename = f"{base_filename}_training.csv"
    validation_csv_filename = f"{base_filename}_validation.csv"

    logger.info(f"Logging training metrics to: {training_csv_filename}")
    logger.info(f"Logging validation results to: {validation_csv_filename}")

    fixed_validation_sets = get_validation_sets(device)

    try:
        run_experiment(
            model, optimizer, criterion,
            fixed_validation_sets, device, logger,
            training_csv_filename, validation_csv_filename
        )
    except Exception as e:
        logger.error(f"An unhandled error occurred: {e}", exc_info=True)
        # Exit non-zero so shells and job schedulers see the run as failed
        # rather than as a successful, silently truncated experiment.
        raise SystemExit(1) from e
    finally:
        # Undo the module-level monkey-patch of create_block.
        mixer_module.create_block = original_create_block