"""
================================================================================
                        Smart Log Filename Convention
================================================================================

This script automatically generates descriptive filenames for the training and
validation log files, ensuring that each experiment is self-documenting.
The filename structure provides a quick overview of the key experimental settings.

The convention is as follows:

    ih_{type}_{level_str}_{param_str}_{arch_str}_{timestamp}.csv

Where each component represents:

1.  ih:
    - Fixed prefix, standing for "Induction Heads", our core task.

2.  {type}:
    - "training" or "validation", indicating the content of the log file.

3.  {level_str}:
    - Describes the difficulty level of the dataset, controlled by `difficulty_level`
      in `config.py`.
    - Examples:
        - "lv0": Baseline task (Level 0)
        - "lv1": Memory Robustness task (Level 1)
        - "lv2": Abstract Pattern Recognition task (Level 2)
        - "lv3": Combined Stress Test (Level 3)
    - For Level 4 (Autonomous Learning), it includes a sub-level:
        - "lv4_0": Sanity Check (level_4_noise_type: 'none')
        - "lv4_1": Robust Discovery (level_4_noise_type: 'between')
        - "lv4_2": Dynamic World Modeling (level_4_noise_type: 'conflict')

4.  {param_str}:
    - Represents the total number of trainable parameters in the model,
      rounded to the nearest thousand and expressed in "K".
    - Example: "27K", "138K", "700K"

5.  {arch_str}:
    - A compact representation of the model's core architecture.
    - Format: "{n_layer}_{d_model}"
    - Example:
        - "2_64": 2 layers, 64 model dimension.
        - "4_96": 4 layers, 96 model dimension.

6.  {timestamp}:
    - The date and time the experiment was started.
    - Format: "YYYYMMDD-HHMMSS"
    - Example: "20250716-181410"

---
Example Filename:
    ih_validation_lv4_2_55K_2_96_20250716-193000.csv

This filename instantly tells us:
- It's a validation log for the Induction Heads task.
- The difficulty was Level 4.2 (Dynamic World Modeling).
- The model had approximately 55,000 parameters.
- The model architecture was 2 layers with a dimension of 96.
- The experiment was run on July 16, 2025, at 7:30 PM.
================================================================================
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import logging
import time
import csv
import os
from tqdm import tqdm

from neuromamba.models.mixer_seq_simple import NeuroMambaLMHeadModel
from config import training_config, dataset_config, neuromamba_config
from data_generator import generate_dataset

def set_seed(seed=42):
    """Seed every random number generator this script has in scope.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    (CPU plus all CUDA devices), then puts cuDNN into deterministic mode with
    autotuning disabled.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    # `random` is imported by this file but was previously left unseeded, so
    # anything drawing from it (e.g. the data generator, if it uses `random`)
    # would not be reproducible across runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Configure root logging at import time: INFO and above, with each record
# rendered as "<timestamp> - <level name> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# getLogger() without a name returns the root logger; the rest of this file
# logs through this module-level handle.
logger = logging.getLogger()

def _masked_accuracy(logits, targets, ignore_index=-100):
    """Return the percentage of scored target positions predicted correctly.

    Positions whose target equals `ignore_index` are excluded. Returns 0.0
    when no position is scored, so callers never divide by zero.
    """
    scored = targets != ignore_index
    n_scored = scored.sum().item()
    if n_scored == 0:
        return 0.0
    predictions = logits.argmax(dim=-1)
    n_correct = (predictions[scored] == targets[scored]).sum().item()
    return 100 * n_correct / n_scored


def _flat_loss(criterion, logits, targets):
    """Apply `criterion` to logits flattened to 2-D and targets flattened to 1-D.

    The flattening uses the logits' own trailing dimension instead of a
    configured vocabulary size, so the reshape cannot misalign if the model's
    output width ever differs from the config value.
    """
    return criterion(logits.view(-1, logits.size(-1)), targets.view(-1))


def run_experiment(model, optimizer, criterion, validation_sets, device, logger,
                   training_csv_file, validation_csv_file):
    """Train `model` for `training_config["num_steps"]` steps, logging to CSV.

    Each step draws a fresh batch from `generate_dataset`, performs one
    optimizer update, and appends a row (step, loss, training accuracy) to the
    training CSV. Every `training_config['eval_interval']` steps, and on the
    final step, the model is evaluated on every entry of `validation_sets`
    and one row per sequence length is appended to the validation CSV.

    Args:
        model: Callable whose result exposes a `.logits` tensor.
        optimizer: Optimizer updating `model`'s parameters.
        criterion: Loss applied to flattened logits/targets. Its
            `ignore_index` attribute (falling back to -100 when absent) also
            selects which target positions count towards accuracy.
        validation_sets: Mapping of sequence length to an
            `(inputs, targets)` pair. These tensors are fed to the model
            as-is; no device transfer happens here.
        device: Device that each training batch is moved to.
        logger: Logger used for progress and validation messages.
        training_csv_file: Path of the training log; overwritten.
        validation_csv_file: Path of the validation log; overwritten.
    """
    logger.info("--- Starting Experiment ---")
    start_time = time.time()

    train_fieldnames = ['step', 'loss', 'training_accuracy']
    val_fieldnames = ['step', 'sequence_length', 'loss', 'validation_accuracy']

    # Keep the accuracy mask consistent with whatever the loss ignores.
    ignore_index = getattr(criterion, 'ignore_index', -100)
    num_steps = training_config["num_steps"]
    eval_interval = training_config['eval_interval']

    # Open both logs once for the whole run instead of once per row; every row
    # is flushed immediately so partial results survive an interrupted run.
    with open(training_csv_file, 'w', newline='') as train_f, \
            open(validation_csv_file, 'w', newline='') as val_f:
        train_writer = csv.DictWriter(train_f, fieldnames=train_fieldnames)
        train_writer.writeheader()
        train_f.flush()

        val_writer = csv.DictWriter(val_f, fieldnames=val_fieldnames)
        val_writer.writeheader()
        val_f.flush()

        progress_bar = tqdm(range(1, num_steps + 1), desc="Training")

        for step in progress_bar:
            model.train()
            inputs, targets = generate_dataset(dataset_config, training_config)
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs).logits
            loss = _flat_loss(criterion, outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once; it is reused for the progress bar and CSV.
            loss_value = loss.item()
            with torch.no_grad():
                train_accuracy = _masked_accuracy(outputs, targets, ignore_index)

            progress_bar.set_postfix(loss=f"{loss_value:.4f}", train_acc=f"{train_accuracy:.2f}%")

            train_writer.writerow({'step': step, 'loss': loss_value, 'training_accuracy': train_accuracy})
            train_f.flush()

            if step % eval_interval == 0 or step == num_steps:
                model.eval()
                logger.info(f"\n--- Validation at Step {step} ---")
                with torch.no_grad():
                    for sl, (eval_inputs, eval_targets) in validation_sets.items():
                        eval_outputs = model(eval_inputs).logits
                        val_loss_value = _flat_loss(criterion, eval_outputs, eval_targets).item()
                        val_accuracy = _masked_accuracy(eval_outputs, eval_targets, ignore_index)
                        logger.info(f"  SeqLen {sl:<7} -> Loss: {val_loss_value:.4f}, Accuracy: {val_accuracy:.2f}%")
                        val_writer.writerow({'step': step, 'sequence_length': sl, 'loss': val_loss_value, 'validation_accuracy': val_accuracy})
                        val_f.flush()
                logger.info("---------------------------------")

    end_time = time.time()
    logger.info(f"--- Experiment completed in {(end_time - start_time) / 60:.2f} minutes ---")

def get_validation_sets(device):
    """Build one fixed validation batch per configured sequence length.

    Iterates over `dataset_config['eval_seq_lens']`, generates a batch for
    each length with `generate_dataset`, and moves it to `device`. The batch
    size shrinks as sequences grow: 16x the training batch size up to 8192
    tokens, 4 up to 32768 tokens, and 1 beyond that. A length whose batch
    triggers a CUDA out-of-memory error is logged and left out of the result.

    Args:
        device: Device the generated tensors are moved to.

    Returns:
        Dict mapping sequence length to an `(inputs, targets)` tuple.
    """
    logger.info("Generating fixed validation sets...")
    sets_by_length = {}

    for seq_len in dataset_config['eval_seq_lens']:
        if seq_len <= 8192:
            n_samples = training_config['batch_size'] * 16
        elif seq_len <= 32768:
            n_samples = 4
        else:
            n_samples = 1

        try:
            batch_inputs, batch_targets = generate_dataset(
                dataset_config, {"batch_size": n_samples}, seq_len=seq_len
            )
            sets_by_length[seq_len] = (batch_inputs.to(device), batch_targets.to(device))
        except torch.cuda.OutOfMemoryError:
            logger.warning(f"OOM creating validation set for SeqLen {seq_len}. Skipping.")
            # Release cached blocks before attempting the next length.
            torch.cuda.empty_cache()

    logger.info(f"Validation sets created for sequence lengths: {list(sets_by_length)}")
    return sets_by_length
    
if __name__ == '__main__':
    # Script entry point: build the model, derive the self-describing log
    # filenames (see the module docstring), then run training + validation.
    set_seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    logger.info("Initializing Model and Optimizer...")
    model = NeuroMambaLMHeadModel(neuromamba_config).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=training_config["learning_rate"], weight_decay=training_config["weight_decay"])
    criterion = nn.CrossEntropyLoss()

    # --- Parameter Count Calculation ---
    # Only trainable parameters are counted; the count is expressed in thousands.
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params_in_k = round(total_params / 1000)

    param_str = f"{params_in_k}K"
    logger.info(f"Model created. Total trainable parameters: {total_params:,} (~{param_str})")

    # --- Smart Filename Generation ---
    # 1. Difficulty level string ("lv0".."lv3", or "lv4_<sublevel>").
    level = dataset_config.get('difficulty_level', 0)
    level_str = f"lv{level}"
    if level == 4:
        noise_type = dataset_config.get('level_4_noise_type', 'none')
        sublevel_map = {'none': 0, 'between': 1, 'conflict': 2}
        if noise_type not in sublevel_map:
            # Keep the historical fallback to sub-level 0, but say so: the
            # filename would otherwise silently mislabel the experiment.
            logger.warning(
                f"Unknown level_4_noise_type {noise_type!r}; "
                f"labelling log files as sub-level 0."
            )
        sublevel = sublevel_map.get(noise_type, 0)
        level_str += f"_{sublevel}"

    # 2. Architecture string: "{n_layer}_{d_model}".
    n_layer = neuromamba_config.n_layer
    d_model = neuromamba_config.d_model
    arch_str = f"{n_layer}_{d_model}"

    # 3. Timestamp of experiment start (local time).
    timestamp = time.strftime("%Y%m%d-%H%M%S")

    # 4. Assemble the final filenames:
    #    ih_{type}_{level}_{params}_{arch}_{timestamp}.csv
    training_csv_filename = f"ih_training_{level_str}_{param_str}_{arch_str}_{timestamp}.csv"
    validation_csv_filename = f"ih_validation_{level_str}_{param_str}_{arch_str}_{timestamp}.csv"

    logger.info(f"Logging training metrics to: {training_csv_filename}")
    logger.info(f"Logging validation results to: {validation_csv_filename}")
    # --- End of Filename Logic ---

    fixed_validation_sets = get_validation_sets(device)

    try:
        run_experiment(
            model, optimizer, criterion,
            fixed_validation_sets, device, logger,
            training_csv_filename, validation_csv_filename
        )
    except Exception as e:
        logger.error(f"An unhandled error occurred: {e}", exc_info=True)
        # The traceback is already logged; exit non-zero so shells, schedulers
        # and sweep scripts can tell that this run failed.
        raise SystemExit(1)