import torch
import torch.nn as nn
import numpy as np
from typing import Dict, Any, Optional
from tqdm import tqdm
from pathlib import Path

from cpro import CPRO
from model import RichLazyControlledRNN

def train_epoch(model: RichLazyControlledRNN, 
                env: CPRO,
                optimizer: torch.optim.Optimizer,
                batch_size: int = 128,
                device: Optional[torch.device] = None) -> Dict[str, Any]:
    """
    Train the model for one epoch.

    One epoch consists of ``len(env.training_tasks) * 256 // batch_size``
    batches requested from ``env.generate_batch(..., training=True)``. Only the
    last entry along the leading dimension of the model output and of the
    labels is scored (presumably the final timestep -- TODO confirm against
    the model/environment). Gradients are clipped to a global norm of 1.0
    before each optimizer step.

    Args:
        model: The RichLazyControlledRNN model
        env: CPRO environment instance
        optimizer: PyTorch optimizer
        batch_size: Size of training batches
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter.

    Returns:
        Dict containing training metrics for the epoch:
            'loss': mean per-batch loss
            'accuracy': mean per-batch accuracy
            'grad_norms': empty placeholders (gradient tracking is disabled),
                kept so the structure of the result stays stable
            'h2h_weights_dim': value of ``model.get_h2h_weights_dim()`` after
                the epoch
            'weight_norms': norm of every named parameter after the epoch

    Raises:
        ValueError: If ``batch_size`` is not positive, or is so large that the
            epoch would contain no batch at all.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    if device is None:
        device = next(model.parameters()).device

    model.train()
    criterion = nn.CrossEntropyLoss()

    # Every task contributes 256 examples per epoch.
    stimuli_per_task = 256
    total_examples = len(env.training_tasks) * stimuli_per_task
    n_batches = total_examples // batch_size
    if n_batches == 0:
        raise ValueError(
            f"batch_size={batch_size} exceeds the {total_examples} available "
            "training examples; the epoch would contain no batches"
        )

    loss_sum = 0.0
    acc_sum = 0.0

    for batch_idx in range(n_batches):
        # Request exactly the batch size the epoch length was computed from.
        inputs, labels = env.generate_batch(batch_size=batch_size, training=True, batch_idx=batch_idx)

        # Move data to device
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs, _, _ = model(inputs)
        final_logits, final_labels = outputs[-1], labels[-1]
        loss = criterion(final_logits, final_labels)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        with torch.no_grad():
            decisions = final_logits.argmax(dim=-1)
            acc_sum += (decisions == final_labels).float().mean().item()
        loss_sum += loss.item()

    return {
        # Averages across batches
        'loss': loss_sum / n_batches,
        'accuracy': acc_sum / n_batches,
        'grad_norms': {'after_backward': {}, 'after_clip': {}},
        # Model properties at the end of the epoch
        'h2h_weights_dim': model.get_h2h_weights_dim(),
        'weight_norms': {
            name: param.norm().item() for name, param in model.named_parameters()
        },
    }

def _batch_positions(task_positions, batch_start, batch_size, stimuli_per_task):
    """Return batch-relative indices of the examples belonging to the given tasks.

    Each task occupies the global example range
    ``[pos * stimuli_per_task, (pos + 1) * stimuli_per_task)``; that range is
    intersected with the batch window ``[batch_start, batch_start + batch_size)``
    so tasks that straddle a batch boundary are counted in every batch they touch.
    """
    batch_end = batch_start + batch_size
    positions = []
    for task_pos in task_positions:
        task_start = task_pos * stimuli_per_task
        lo = max(task_start, batch_start)
        hi = min(task_start + stimuli_per_task, batch_end)
        if lo < hi:
            positions.extend(range(lo - batch_start, hi - batch_start))
    return positions


def _record_hidden_states(store, hidden_cpu, n_examples, batch_start, env,
                          stimuli_per_task, save_all_stimuli):
    """Accumulate one batch of hidden states into ``store`` (per timepoint, per task)."""
    n_timepoints = hidden_cpu.size(0)
    n_tasks = len(env.test_tasks)

    for i in range(n_examples):
        # Global example index -> (task index, stimulus index within the task)
        task_idx, stim_idx = divmod(batch_start + i, stimuli_per_task)
        if task_idx >= n_tasks:
            continue

        task = env.test_tasks[task_idx]
        task_info = {
            'logical_ctx': task['logical_ctx'],
            'sensory_ctx': task['sensory_ctx'],
            'motor_ctx': task['motor_ctx']
        }

        stim_info = None
        if save_all_stimuli:
            stim_info = {
                'stim_idx': stim_idx,
                'stimulus': env.all_stim_combinations[stim_idx]
            }

        # Lazily create one slot list per timepoint on first use
        if not store['hidden_states']:
            for t in range(n_timepoints):
                store['hidden_states'].append([])
                store['task_indices'].append([])
                store['task_info'].append([])
                store['timepoints'].append(t)
                if save_all_stimuli:
                    store['stimulus_hidden_states'].append([])
                    store['stimulus_info'].append([])

        if n_timepoints == 0:
            continue

        # The per-timepoint lists grow in lockstep, so the slot of a task is
        # the same at every timepoint and only needs to be looked up once.
        slot = next(
            (j for j, stored in enumerate(store['task_info'][0]) if stored == task_info),
            None,
        )

        for t in range(n_timepoints):
            state = hidden_cpu[t, i]
            if slot is None:
                # clone(): on CPU `state` is a view of the model output; the
                # accumulator is updated in place and must own its memory.
                store['hidden_states'][t].append(state.clone())
                store['task_indices'][t].append(1)
                store['task_info'][t].append(task_info)
            else:
                store['hidden_states'][t][slot] += state
                store['task_indices'][t][slot] += 1

            if save_all_stimuli:
                store['stimulus_hidden_states'][t].append(state.clone())
                store['stimulus_info'][t].append({
                    'task': task_info,
                    'stimulus': stim_info
                })


def _mean_or_nan(values):
    """Mean of a list of per-batch values; NaN (without a NumPy warning) when empty."""
    return np.mean(values) if values else float('nan')


def evaluate(model: RichLazyControlledRNN, 
            env: CPRO, 
            batch_size: int = 256,
            collect_hidden: bool = False,
            save_all_stimuli: bool = False, 
            device: Optional[torch.device] = None) -> Dict[str, Any]:
    """
    Evaluate model performance on training and test tasks.

    Batches are requested with ``env.generate_batch(..., training=False)`` and
    scored on the last entry along the leading dimension of the model output.
    The index arithmetic assumes that these batches enumerate
    ``env.test_tasks`` in order with 256 consecutive examples per task
    (TODO confirm against CPRO.generate_batch). Tasks that also appear in
    ``env.training_tasks`` are reported under 'training', all others under
    'test'. Examples beyond the last full batch are not evaluated.

    Args:
        model: The RichLazyControlledRNN model
        env: CPRO environment instance
        batch_size: Size of evaluation batches
        collect_hidden: Whether to collect hidden activations
        save_all_stimuli: Whether to save individual stimulus activations (without averaging)
        device: Device to run evaluation on. Defaults to the device of the
            model's first parameter.

    Returns:
        Dict containing evaluation metrics and error analysis:
            'training' / 'test': mean per-batch 'loss' and 'acc' (NaN when the
                split received no examples)
            'overlap_analysis': for overlaps 1 and 2 the mean 'loss' / 'acc'
                and 'n_tasks' of the non-training tasks in that category (the
                entry stays ``{'loss': [], 'acc': []}`` when nothing was
                scored). NOTE(review): overlap 3 holds the *per-batch lists*
                of the training split rather than their mean; kept as-is for
                compatibility with existing consumers.
            'error_analysis': confusion matrices indexed [predicted, label]
                and concatenated probabilities / labels per split
            'hidden_activities': per-timepoint, per-task mean hidden states
                (populated only when ``collect_hidden`` is set)

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    if device is None:
        device = next(model.parameters()).device
    
    model.eval()
    criterion = nn.CrossEntropyLoss()
    
    stimuli_per_task = 256  # Number of stimulus combinations per task
    total_examples = len(env.test_tasks) * stimuli_per_task
    n_batches = total_examples // batch_size
    
    metrics = {
       'training': {'loss': [], 'acc': []},
       'test': {'loss': [], 'acc': []},  # For overall generalization in this case
       'overlap_analysis': {  # For overlap-specific analysis
           1: {'loss': [], 'acc': []},
           2: {'loss': [], 'acc': []},
           3: {'loss': [], 'acc': []}
       }
    }

    hidden_activities = {
        'timepoints': [],
        'task_indices': [],
        'hidden_states': [],
        'task_info': [],
        'stimulus_hidden_states': [] if save_all_stimuli else None,
        'stimulus_info': [] if save_all_stimuli else None
    }
    
    confusion_matrices = {
        'training': torch.zeros(env.config.output_size, env.config.output_size),
        'test': torch.zeros(env.config.output_size, env.config.output_size)
    }
    
    probabilities = {
        'training': {'probs': [], 'labels': []},
        'test': {'probs': [], 'labels': []}
    }

    # Task bookkeeping does not depend on the batch, so resolve it once.
    test_tasks = env.test_tasks
    training_tasks = env.training_tasks
    trained_positions = [
        pos
        for task in training_tasks
        for pos, candidate in enumerate(test_tasks)
        if candidate == task
    ]
    held_out_positions = [
        test_tasks.index(task) for task in test_tasks if task not in training_tasks
    ]
    categorized_tasks = env.categorize_test_tasks()
    overlap_tasks = {
        overlap: [task for task in categorized_tasks[overlap] if task not in training_tasks]
        for overlap in (1, 2)
    }
    overlap_positions = {
        overlap: [test_tasks.index(task) for task in tasks]
        for overlap, tasks in overlap_tasks.items()
    }
    
    with torch.no_grad():
        for batch_idx in range(n_batches):
            inputs, labels = env.generate_batch(batch_size=batch_size, training=False, batch_idx=batch_idx)
            
            # Move data to device
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs, probs, hidden_state = model(inputs)
            batch_start = batch_idx * batch_size

            if collect_hidden:
                # hidden_state is [seq_len, batch_size, hidden_size]; transfer
                # it to the CPU once per batch instead of once per example.
                _record_hidden_states(
                    hidden_activities,
                    hidden_state.detach().cpu(),
                    inputs.size(1),
                    batch_start,
                    env,
                    stimuli_per_task,
                    save_all_stimuli,
                )

            # Trained tasks and held-out (generalization) tasks
            for split, positions in (('training', trained_positions),
                                     ('test', held_out_positions)):
                selected = _batch_positions(positions, batch_start, batch_size, stimuli_per_task)
                if not selected:
                    continue

                split_outputs = outputs[-1, selected]
                split_probs = probs[-1, selected]
                split_labels = labels[-1, selected]
                decisions = split_outputs.argmax(dim=-1)

                split_loss = criterion(split_outputs, split_labels)
                split_acc = (decisions == split_labels).float().mean()
                metrics[split]['loss'].append(split_loss.item())
                metrics[split]['acc'].append(split_acc.item())

                # Vectorized equivalent of `matrix[pred, label] += 1` per example
                matrix = confusion_matrices[split]
                matrix.index_put_(
                    (decisions.cpu(), split_labels.cpu()),
                    torch.ones(decisions.numel(), dtype=matrix.dtype),
                    accumulate=True,
                )
                probabilities[split]['probs'].append(split_probs)
                probabilities[split]['labels'].append(split_labels)

            # Generalization tasks broken down by overlap category
            for overlap, positions in overlap_positions.items():
                selected = _batch_positions(positions, batch_start, batch_size, stimuli_per_task)
                if not selected:
                    continue

                overlap_outputs = outputs[-1, selected]
                overlap_labels = labels[-1, selected]
                decisions = overlap_outputs.argmax(dim=-1)

                overlap_loss = criterion(overlap_outputs, overlap_labels)
                overlap_acc = (decisions == overlap_labels).float().mean()
                metrics['overlap_analysis'][overlap]['loss'].append(overlap_loss.item())
                metrics['overlap_analysis'][overlap]['acc'].append(overlap_acc.item())
            
    # Average metrics and prepare final output
    for overlap in (1, 2):
        if metrics['overlap_analysis'][overlap]['loss']:
            metrics['overlap_analysis'][overlap] = {
                'loss': np.mean(metrics['overlap_analysis'][overlap]['loss']),
                'acc': np.mean(metrics['overlap_analysis'][overlap]['acc']),
                'n_tasks': len(overlap_tasks[overlap])
            }
    # Add training metrics as overlap 3
    metrics['overlap_analysis'][3] = {
        'loss': metrics['training']['loss'],
        'acc': metrics['training']['acc'],
        'n_tasks': len(training_tasks)
    }

    # Convert probability lists to tensors; torch.cat rejects an empty list,
    # so a split without any example yields empty tensors instead.
    for key in probabilities:
        prob_chunks = probabilities[key]['probs']
        label_chunks = probabilities[key]['labels']
        probabilities[key]['probs'] = (
            torch.cat(prob_chunks) if prob_chunks else torch.empty(0)
        )
        probabilities[key]['labels'] = (
            torch.cat(label_chunks) if label_chunks else torch.empty(0, dtype=torch.long)
        )

    if collect_hidden:
        # Average hidden states across examples for each task
        for states, counts in zip(hidden_activities['hidden_states'],
                                  hidden_activities['task_indices']):
            for state, count in zip(states, counts):
                if count > 0:
                    state /= count  # in place: updates the stored tensor
                    
        # If we're not saving all stimuli, remove the empty fields
        if not save_all_stimuli:
            hidden_activities.pop('stimulus_hidden_states', None)
            hidden_activities.pop('stimulus_info', None)
    
    return {
        'training': {
            'loss': _mean_or_nan(metrics['training']['loss']),
            'acc': _mean_or_nan(metrics['training']['acc'])
        },
        'test': {
            'loss': _mean_or_nan(metrics['test']['loss']),
            'acc': _mean_or_nan(metrics['test']['acc'])
        },
        'overlap_analysis': metrics['overlap_analysis'],
        'error_analysis': {
            'confusion_matrices': confusion_matrices,
            'probabilities': probabilities
        },
        'hidden_activities': hidden_activities
    }

def _append_eval_metrics(history, eval_metrics):
    """Append the summary metrics of one `evaluate` result to the history lists."""
    history['trained_task_loss'].append(eval_metrics['training']['loss'])
    history['trained_task_acc'].append(eval_metrics['training']['acc'])
    history['test_task_loss'].append(eval_metrics['test']['loss'])
    history['test_task_acc'].append(eval_metrics['test']['acc'])

    # Detailed overlap analysis
    for overlap in (1, 2, 3):
        overlap_metrics = eval_metrics['overlap_analysis'][overlap]
        history['overlap_loss'][overlap].append(overlap_metrics['loss'])
        history['overlap_acc'][overlap].append(overlap_metrics['acc'])


def run_experiment(env: CPRO,
                   hidden_size: int,
                   scale: float,
                   weight_decay: float = 0.01,
                   learning_rate: float = 0.01,  # Default LR for SGD
                   optimizer_type: str = 'sgd',  # Default to SGD
                   train_batch_size: int = 128,
                   eval_batch_size: int = 256,
                   n_epochs: int = 1000,
                   eval_every: int = 10,
                   use_early_stopping: bool = False,
                   early_stop_acc: float = 0.95,
                   patience: int = 10,
                   results_dir: Optional[Path] = None,
                   seed: Optional[int] = None,
                   use_gpu: bool = False,
                   save_all_stimuli: bool = False,
                   save_initial_states: bool = False) -> Dict:
    """
    Run a complete training experiment. Tracks three types of performance:
    1. Practice performance: How well network performs during training
    2. Test performance on training tasks: How well it recalls trained tasks
    3. Generalization performance: How well it performs on unseen tasks

    Evaluation runs every ``eval_every`` epochs and once more after training;
    the final evaluation also provides the hidden activities and the error
    analysis stored in the result.
    
    Args:
        env: CPRO environment instance
        hidden_size: Size of RNN hidden layer
        scale: Forwarded to the RichLazyControlledRNN constructor as ``scale``
        weight_decay: Weight decay for the AdamW optimizer. Ignored by the
            SGD optimizer, but always recorded in the returned history.
        learning_rate: Learning rate of the optimizer
        optimizer_type: 'adamw' (case-insensitive) selects AdamW; any other
            value falls back to plain SGD
        train_batch_size: Batch size during training
        eval_batch_size: Batch size during evaluation
        n_epochs: Maximum number of training epochs
        eval_every: Evaluate every N epochs
        use_early_stopping: Whether to use early stopping
        early_stop_acc: Practice-accuracy threshold for early stopping
        patience: Number of epochs at or above the threshold without a new
            best practice accuracy before training stops
        results_dir: Directory for the saved model and hidden-state files;
            created if missing. Nothing is written to disk when None.
        seed: Only used to label the output filenames; this function does not
            seed any random number generator
        use_gpu: Use CUDA when it is available
        save_all_stimuli: Whether to save individual stimulus activations
        save_initial_states: Whether to save the initial hidden states before training
    
    Returns:
        Dict containing complete training history and results, including the
        weight decay setting and 'stopping_epoch' (``n_epochs`` unless early
        stopping ended training sooner). With ``results_dir`` set, the hidden
        activities are replaced by the 'hidden_file' / 'model_file' names.
    """
    # Initialize patience variables
    patience_counter = 0
    best_acc = 0
    
    # Diagnosis 
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"GPU count: {torch.cuda.device_count() if torch.cuda.is_available() else 0}")
    print(f"use_gpu flag: {use_gpu}")
    
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")
    print(f"Using device: {device}")

    # Accept plain strings and make sure the output directory exists before
    # any torch.save call needs it.
    if results_dir is not None:
        results_dir = Path(results_dir)
        results_dir.mkdir(parents=True, exist_ok=True)
    
    model = RichLazyControlledRNN(
        input_size=env.config.input_size,
        hidden_size=hidden_size,
        output_size=env.config.output_size,
        scale=scale
    )
    
    # Move model to device
    model = model.to(device)
    
    if save_initial_states:
        initial_eval = evaluate(model, env, batch_size=eval_batch_size, 
                               collect_hidden=True, save_all_stimuli=save_all_stimuli,
                               device=device)
        
        # Save initial hidden activities if results_dir is provided
        if results_dir is not None:
            initial_hidden_file = f"initial_hidden_scale{scale}_{optimizer_type}_seed{seed}.pt"
            torch.save(initial_eval['hidden_activities'], results_dir / initial_hidden_file)
            print(f"Initial hidden states saved to {initial_hidden_file}")
    
    if optimizer_type.lower() == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    else:  # default to SGD without weight decay
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    
    history = {
        # Performance during practice
        'practice_loss': [],      # Loss while learning training tasks
        'practice_acc': [],       # Accuracy while learning training tasks
        
        # Performance during evaluation (no updates)
        'trained_task_loss': [],  # Loss on training tasks during testing
        'trained_task_acc': [],   # Accuracy on training tasks during testing
        'test_task_loss': [],           # Loss on generalization tasks
        'test_task_acc': [],            # Accuracy on generalization tasks
        
        # Detailed overlap analysis
        'overlap_loss': {1: [], 2: [], 3: []},
        'overlap_acc': {1: [], 2: [], 3: []},
        
        # Model properties
        'h2h_weights_dim': [],

        # Weight decay
        'weight_decay': weight_decay,
        'weight_norms': [],

        # Error analysis
        'error_analysis': None
    }

    # Epoch at which training ended; stays at the maximum unless early
    # stopping triggers.
    stopping_epoch = n_epochs

    # Training loop with integrated periodic evaluation
    with tqdm(total=n_epochs) as pbar:
        for epoch in range(n_epochs):
            # Training step
            train_metrics = train_epoch(model, env, optimizer,
                                        batch_size=train_batch_size, device=device)
            
            # Store practice performance
            history['practice_loss'].append(train_metrics['loss'])
            history['practice_acc'].append(train_metrics['accuracy'])
            history['h2h_weights_dim'].append(train_metrics['h2h_weights_dim'])
            history['weight_norms'].append(train_metrics['weight_norms'])
            
            # Periodic evaluation during training
            if epoch % eval_every == 0:
                eval_metrics = evaluate(model, env, batch_size=eval_batch_size, device=device)
                _append_eval_metrics(history, eval_metrics)
            
            # Early stopping check: once practice accuracy is at or above the
            # threshold, count the epochs that fail to set a new best.
            if use_early_stopping and train_metrics['accuracy'] >= early_stop_acc:
                if train_metrics['accuracy'] > best_acc:
                    best_acc = train_metrics['accuracy']
                    patience_counter = 0
                else:
                    patience_counter += 1

                if patience_counter >= patience:
                    stopping_epoch = epoch
                    print(f"\nSustained accuracy threshold for {patience} epochs at epoch {epoch}")
                    break
            
            pbar.update(1)

    # Final evaluation of the trained model. Its metrics close the evaluation
    # history (instead of repeating the last periodic evaluation, which may be
    # stale or may not exist at all when n_epochs == 0).
    final_eval = evaluate(model, env, batch_size=eval_batch_size, collect_hidden=True,
                          save_all_stimuli=save_all_stimuli, device=device)
    _append_eval_metrics(history, final_eval)

    # Store final evaluation's hidden activity and error analysis
    history['hidden_activities'] = final_eval['hidden_activities']
    history['error_analysis'] = final_eval['error_analysis']
    history['stopping_epoch'] = stopping_epoch

    # Save the final model state as a separate file if results_dir is provided
    if results_dir is not None:
        # Create a filename that includes all relevant hyperparameters
        model_filename = f"model_scale{scale}_{optimizer_type}_seed{seed}.pt"
        
        # Save the model state dictionary
        torch.save(model.state_dict(), results_dir / model_filename)

        # Save hidden activities separately (they're large)
        hidden_file = f"hidden_scale{scale}_{optimizer_type}_seed{seed}.pt"
        torch.save(history['hidden_activities'], results_dir / hidden_file)
        history['hidden_file'] = hidden_file
        
        # Remove from main history to keep JSON size manageable
        del history['hidden_activities']
        
        # Add the filename to the history for reference
        history['model_file'] = model_filename

    return history

# Public API of this module: only these names are exported by a wildcard import.
__all__ = ['train_epoch', 'evaluate', 'run_experiment']