import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import logging
import time
import csv
import os
from tqdm import tqdm
import copy

# Import model and default configs
from neuromamba.models.mixer_seq_simple import NeuroMambaLMHeadModel
from config import training_config as default_training_config
from config import dataset_config as default_dataset_config
from config import neuromamba_config as default_neuromamba_config
from data_generator import generate_dataset

# Setup logging. basicConfig installs a root handler (and INFO level) at import
# time, unless the root logger already has handlers. This module logs through
# its own named logger, whose records propagate to that root handler. Output is
# identical, but adjusting `logger` no longer changes every library's logging.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def set_seed(seed=42):
    """Seed every RNG a run touches and force deterministic cuDNN kernels.

    Args:
        seed: Seed passed to PyTorch (CPU and all CUDA devices), NumPy's
            global RNG and Python's ``random`` module.
    """
    # Same order as before: torch CPU, torch CUDA, NumPy, stdlib random.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

def _masked_accuracy(logits, targets, ignore_index=-100):
    """Return the % of non-ignored positions where ``argmax(logits, -1) == targets``.

    Positions whose target equals ``ignore_index`` are excluded. The default of
    -100 matches the default ``ignore_index`` of ``nn.CrossEntropyLoss``.
    Returns 0.0 when every position is ignored.
    """
    preds = logits.argmax(dim=-1)
    mask = targets != ignore_index
    total = mask.sum().item()
    if total == 0:
        return 0.0
    correct = (preds[mask] == targets[mask]).sum().item()
    return 100 * correct / total

def run_experiment(model, optimizer, criterion, validation_sets, device, 
                   training_csv_file, validation_csv_file, t_config, d_config, nm_config):
    """Train ``model`` for ``t_config["num_steps"]`` steps and log metrics to CSV.

    Training steps:
        Each step draws a fresh batch from ``generate_dataset(d_config, t_config)``,
        takes one optimizer step, and appends one row to ``training_csv_file``.

    Validation:
        Runs every ``t_config['eval_interval']`` steps and on the final step.
        The model is evaluated on each pre-built batch in ``validation_sets``.
        One row per sequence length is appended to ``validation_csv_file``.

    Args:
        model: Model whose forward output exposes ``.logits``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss applied to logits reshaped to ``(-1, nm_config.vocab_size)``
            and to flattened targets.
        validation_sets: Mapping ``seq_len -> (inputs, targets)``. These tensors
            are used as-is and are not moved, so they are expected to already be
            on the model's device.
        device: Device each training batch is moved to.
        training_csv_file: Path of the per-step training CSV (overwritten).
        validation_csv_file: Path of the validation CSV (overwritten).
        t_config: Training config. Reads ``num_steps`` and ``eval_interval``,
            and is forwarded to ``generate_dataset``.
        d_config: Dataset config forwarded to ``generate_dataset``.
        nm_config: Model config; only ``vocab_size`` is read here.
    """
    logger.info("--- Starting Experiment Run ---")
    start_time = time.time()

    num_steps = t_config["num_steps"]
    eval_interval = t_config['eval_interval']
    vocab_size = nm_config.vocab_size

    train_fieldnames = ['step', 'loss', 'training_accuracy']
    val_fieldnames = ['step', 'sequence_length', 'loss', 'validation_accuracy']

    # Keep both CSVs open for the whole run instead of reopening them on every
    # step / row. Flush after writes so the on-disk logs stay current while
    # training is in progress.
    with open(training_csv_file, 'w', newline='') as train_f, \
            open(validation_csv_file, 'w', newline='') as val_f:
        train_writer = csv.DictWriter(train_f, fieldnames=train_fieldnames)
        train_writer.writeheader()
        train_f.flush()

        val_writer = csv.DictWriter(val_f, fieldnames=val_fieldnames)
        val_writer.writeheader()
        val_f.flush()

        progress_bar = tqdm(range(1, num_steps + 1), desc="Training")

        for step in progress_bar:
            model.train()
            inputs, targets = generate_dataset(d_config, t_config)
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs).logits
            # reshape (unlike view) also works if the logits are non-contiguous.
            loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() synchronizes with the device; call it once per step.
            loss_value = loss.item()
            with torch.no_grad():
                train_accuracy = _masked_accuracy(outputs, targets)

            progress_bar.set_postfix(loss=f"{loss_value:.4f}", train_acc=f"{train_accuracy:.2f}%")
            train_writer.writerow({'step': step, 'loss': loss_value, 'training_accuracy': train_accuracy})
            train_f.flush()

            if step % eval_interval == 0 or step == num_steps:
                model.eval()
                logger.info(f"\n--- Validation at Step {step} ---")
                with torch.no_grad():
                    for sl, (eval_inputs, eval_targets) in validation_sets.items():
                        eval_outputs = model(eval_inputs).logits
                        val_loss = criterion(
                            eval_outputs.reshape(-1, vocab_size), eval_targets.reshape(-1)
                        ).item()
                        val_accuracy = _masked_accuracy(eval_outputs, eval_targets)
                        logger.info(f"  SeqLen {sl:<7} -> Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")
                        val_writer.writerow({'step': step, 'sequence_length': sl, 'loss': val_loss, 'validation_accuracy': val_accuracy})
                val_f.flush()
                logger.info("---------------------------------")

    end_time = time.time()
    logger.info(f"--- Experiment run completed in {(end_time - start_time) / 60:.2f} minutes ---")

def _eval_batch_size(seq_len, train_batch_size):
    """Return the validation batch size to use for ``seq_len``.

    The rule is:
        - 1 for lengths above 32768;
        - half the training batch (at least 1) for lengths above 8192;
        - twice the training batch otherwise.

    Presumably this keeps long-sequence batches within device memory.
    """
    if seq_len > 32768:
        return 1
    if seq_len > 8192:
        return max(1, train_batch_size // 2)
    return train_batch_size * 2

def get_validation_sets(device, d_config, t_config):
    """Build one fixed validation batch per length in ``d_config['eval_seq_lens']``.

    Failure handling is best-effort:
        - A CUDA out-of-memory error is logged as a warning, the CUDA cache is
          emptied, and that length is skipped.
        - Any other exception is logged with its traceback, and that length is
          skipped.

    Args:
        device: Device the generated tensors are moved to.
        d_config: Dataset config; ``eval_seq_lens`` lists the lengths to build.
        t_config: Training config; only ``batch_size`` is read.

    Returns:
        Dict mapping sequence length -> ``(inputs, targets)`` on ``device``.
        It contains only the lengths that were generated successfully and may
        be empty.
    """
    logger.info("Generating fixed validation sets...")
    validation_sets = {}
    for sl in d_config['eval_seq_lens']:
        # NOTE(review): only batch_size is forwarded to generate_dataset here.
        # If it reads other t_config keys, the call raises and this length is
        # skipped — confirm.
        temp_t_config = {"batch_size": _eval_batch_size(sl, t_config['batch_size'])}
        try:
            inputs, targets = generate_dataset(d_config, temp_t_config, seq_len=sl)
            validation_sets[sl] = (inputs.to(device), targets.to(device))
        except torch.cuda.OutOfMemoryError:
            logger.warning(f"OOM creating validation set for SeqLen {sl}. Skipping.")
            torch.cuda.empty_cache()
        except Exception as e:
            # Deliberately best-effort. logger.exception logs at ERROR level and
            # keeps the traceback, which logger.error dropped.
            logger.exception(f"Error generating validation set for SeqLen {sl}: {e}. Skipping.")

    logger.info(f"Validation sets created for sequence lengths: {list(validation_sets.keys())}")
    return validation_sets
    
def _build_log_basename(d_config, nm_config, param_str, log_dir="logs_neuromamba"):
    """Return the extension-less log path for a run, creating ``log_dir`` if needed.

    Format:
    ``<log_dir>/ih_lv<level>[_<sublevel>]_<param_str>_<n_layer>_<d_model>_gc<expand_gc>_<timestamp>``

    Details:
        - Level 4 gets a sublevel suffix derived from ``level_4_noise_type``:
          'none' -> 0, 'between' -> 1, 'conflict' -> 2, anything else -> 0.
        - Dots in ``expand_gc`` are replaced with 'p' (2.5 -> 2p5).
        - The timestamp has one-second resolution.
    """
    level = d_config.get('difficulty_level', 0)
    level_str = f"lv{level}"
    if level == 4:
        sublevel_map = {'none': 0, 'between': 1, 'conflict': 2}
        noise_type = d_config.get('level_4_noise_type', 'none')
        level_str += f"_{sublevel_map.get(noise_type, 0)}"

    gc_val = str(nm_config.expand_gc).replace('.', 'p')
    arch_str = f"{nm_config.n_layer}_{nm_config.d_model}_gc{gc_val}"
    timestamp = time.strftime("%Y%m%d-%H%M%S")

    os.makedirs(log_dir, exist_ok=True)
    return os.path.join(log_dir, f"ih_{level_str}_{param_str}_{arch_str}_{timestamp}")

def execute_neuromamba_training_run(t_config, d_config, nm_config):
    """Build a NeuroMamba model and run one seeded training experiment.

    Steps:
        - Seed all RNGs with 42 and pick CUDA if available.
        - Set the model's ``vocab_size`` from ``d_config['vocab_size']``.
        - Build the model, an AdamW optimizer and a cross-entropy loss.
        - Derive timestamped CSV log paths under ``logs_neuromamba/``.
        - Hand off to ``run_experiment``.

    If no validation set could be generated, it returns without training.

    Args:
        t_config: Training config. ``learning_rate`` and ``weight_decay`` are
            read here; the config is also passed to the data/training helpers.
        d_config: Dataset config. ``vocab_size``, ``difficulty_level`` and
            ``level_4_noise_type`` are read here.
        nm_config: NeuroMamba model config. It is deep-copied before
            ``vocab_size`` is assigned, so the caller's object is left unchanged.
    """
    set_seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Executing run on device: {device}")

    # Work on a private copy. Assigning vocab_size on the caller's object would
    # mutate the shared default config and leak into any later run reusing it.
    nm_config = copy.deepcopy(nm_config)
    nm_config.vocab_size = d_config['vocab_size']

    logger.info("Initializing NeuroMamba Model and Optimizer...")
    model = NeuroMambaLMHeadModel(nm_config).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=t_config["learning_rate"], weight_decay=t_config["weight_decay"])
    criterion = nn.CrossEntropyLoss()

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    param_str = f"{round(total_params / 1000)}K"
    logger.info(f"Model created. Arch: {nm_config.n_layer}L_{nm_config.d_model}D, expand_gc: {nm_config.expand_gc}. Params: {total_params:,} (~{param_str})")

    base_filename = _build_log_basename(d_config, nm_config, param_str)
    training_csv_filename = f"{base_filename}_training.csv"
    validation_csv_filename = f"{base_filename}_validation.csv"

    logger.info(f"Logging training metrics to: {training_csv_filename}")
    logger.info(f"Logging validation results to: {validation_csv_filename}")

    fixed_validation_sets = get_validation_sets(device, d_config, t_config)

    if not fixed_validation_sets:
        logger.warning("No validation sets were created (likely due to OOM). Skipping run.")
        return

    run_experiment(
        model, optimizer, criterion,
        fixed_validation_sets, device,
        training_csv_filename, validation_csv_filename,
        t_config, d_config, nm_config
    )

if __name__ == '__main__':
    # Script entry point: one experiment driven entirely by the defaults in config.py.
    logger.info("Running a single NeuroMamba experiment with default config from config.py...")
    execute_neuromamba_training_run(
        t_config=default_training_config,
        d_config=default_dataset_config,
        nm_config=default_neuromamba_config,
    )