import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import logging
import time
import csv
import os
from neuromamba.models.mixer_seq_simple import NeuroMambaLMHeadModel
from config import training_config, dataset_config, neuma_config
from data_generator import generate_dataset

def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Passes ``seed`` to Python's ``random`` module, NumPy's global generator,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``, then switches
    cuDNN to deterministic mode with benchmarking turned off.

    Args:
        seed: Value handed to each seeding function. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable the cuDNN autotuner and request deterministic kernels.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True

def _append_csv_row(csv_path, fieldnames, row):
    """Append a single ``row`` dict to the CSV file at ``csv_path``."""
    with open(csv_path, 'a', newline='') as f:
        csv.DictWriter(f, fieldnames=fieldnames).writerow(row)


def _evaluate(model, criterion, device, num_batches):
    """Return ``(mean loss, accuracy in percent)`` over freshly generated batches.

    Puts ``model`` into eval mode and leaves it there; the caller is
    responsible for switching back to train mode before the next update.

    Args:
        model: Model called as ``model(inputs, num_last_tokens=...)`` whose
            result exposes a ``.logits`` tensor.
        criterion: Loss callable taking ``(flat_logits, flat_targets)``.
        device: Device the generated batches are moved to.
        num_batches: Number of batches drawn from ``generate_dataset``.

    Returns:
        Tuple of the per-batch loss averaged over ``num_batches`` and the
        percentage of target elements whose argmax prediction matched.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_targets = 0

    with torch.no_grad():  # No gradients are needed while scoring.
        for _ in range(num_batches):
            inputs, targets = generate_dataset(dataset_config, training_config)
            inputs, targets = inputs.to(device), targets.to(device)

            logits = model(inputs, num_last_tokens=dataset_config['l_memorize']).logits

            # Flatten once and reuse for both loss and accuracy so the two
            # metrics are always computed over the same element pairing.
            # reshape (unlike view) also accepts non-contiguous tensors.
            flat_logits = logits.reshape(-1, neuma_config.vocab_size)
            flat_targets = targets.reshape(-1)

            loss_sum += criterion(flat_logits, flat_targets).item()
            num_correct += (flat_logits.argmax(dim=-1) == flat_targets).sum().item()
            num_targets += flat_targets.numel()

    return loss_sum / num_batches, 100 * num_correct / num_targets


def run_experiment():
    """
    Runs the complete training and evaluation pipeline for the NeuroMamba model.

    Seeds the RNGs, builds the model, loss and Adam optimizer, then trains for
    ``training_config["num_steps"]`` steps on batches from ``generate_dataset``.
    Training loss is logged every 200 steps. Every ``eval_interval`` steps, and
    on the final step, the model is evaluated on ``eval_steps`` fresh batches
    and the resulting loss/accuracy row is appended to a timestamped CSV file
    in the 'results' directory.
    """
    set_seed(42)

    # --- 1. Setup Logging and Device ---
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    logger.info(f'Using device: {device}')

    # Read loop-invariant settings once.
    num_steps = training_config["num_steps"]
    learning_rate = training_config["learning_rate"]

    # --- 2. Initialize Model, Loss Function, and Optimizer ---
    model = NeuroMambaLMHeadModel(neuma_config, device=device)
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    optimizer_name = optimizer.__class__.__name__.lower()

    # --- 3. Configure Logging and Evaluation Parameters ---
    eval_interval = 1000  # Evaluate every 1000 steps
    eval_steps = 100      # Use 100 batches for a stable evaluation result

    # Create results directory if it doesn't exist
    output_dir = "results"
    os.makedirs(output_dir, exist_ok=True)

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    csv_filename = (f"log_{optimizer_name}_{neuma_config.expand_gc}_"
                    f"lr{learning_rate:.0e}_{timestamp}.csv")
    csv_filepath = os.path.join(output_dir, csv_filename)

    csv_fieldnames = ['step', 'loss', 'accuracy']

    # Start the log with a header row; evaluation rows are appended later.
    with open(csv_filepath, 'w', newline='') as f:
        csv.DictWriter(f, fieldnames=csv_fieldnames).writeheader()
    logger.info(f"Evaluation results will be saved to: {csv_filepath}")

    # --- 4. Training and Evaluation Loop ---
    start_time = time.time()
    logger.info("--- Starting Training ---")

    for step in range(1, num_steps + 1):
        # --- Training Step ---
        model.train()  # Evaluation below switches to eval mode, so reset each step.

        inputs, targets = generate_dataset(dataset_config, training_config)
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs, num_last_tokens=dataset_config['l_memorize']).logits
        loss = criterion(outputs.reshape(-1, neuma_config.vocab_size), targets.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log training loss periodically
        if step % 200 == 0:
            logger.info(f'Step [{step}/{num_steps}], Training Loss: {loss.item():.4f}')

        # --- Evaluation Step ---
        # Always evaluate on the last step so the CSV ends with final metrics.
        if step % eval_interval == 0 or step == num_steps:
            logger.info(f"--- Step {step}: Starting evaluation... ---")
            avg_eval_loss, avg_eval_accuracy = _evaluate(model, criterion, device, eval_steps)

            logger.info(f'>>> Evaluation at step [{step}]: '
                        f'Average Validation Loss: {avg_eval_loss:.4f}, '
                        f'Validation Accuracy: {avg_eval_accuracy:.2f}% <<<')

            _append_csv_row(csv_filepath, csv_fieldnames, {
                'step': step,
                'loss': f"{avg_eval_loss:.4f}",
                'accuracy': f"{avg_eval_accuracy:.2f}",
            })

    end_time = time.time()
    total_duration = (end_time - start_time) / 60
    logger.info(f"--- Training and evaluation finished. Total duration: {total_duration:.2f} minutes ---")


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_experiment()