"""
Trainer Module

This module contains the dataset, model, and trainer for training the KV correction MLP.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import json
import numpy as np
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import logging
from typing import List, Optional, Tuple

# Module-wide logger used by every class and function below.
logger = logging.getLogger(name=__name__)
# Send INFO-level progress messages to the console. Note that basicConfig is a
# no-op when the root logger already has handlers configured.
logging.basicConfig(level=logging.INFO)


class KVCorrectionDataset(Dataset):
    """
    Dataset for KV correction training.

    Loads preprocessed data containing:
    - Trigger token hidden states (at error detection point)
    - Error KV states (incorrect anchor KV)
    - Correct anchor KV states (ground truth targets)

    Files are read from ``<data_dir>/<name>_sample_info.json`` and
    ``<data_dir>/layer_<layer_idx>/<name>_{trigger_hidden,error_kv,anchor_kv}.pt``.
    The samples of all datasets are exposed as one sequence, concatenated in
    ``dataset_names`` order.
    """

    def __init__(
        self,
        data_dir: str,
        dataset_names: List[str],
        layer_idx: int,
        train_indices: Optional[List[int]] = None
    ):
        """
        Initialize the dataset.

        Args:
            data_dir: Directory containing preprocessed data
            dataset_names: List of dataset names to load
            layer_idx: Which layer's data to use
            train_indices: Global indices of the samples to keep, counted over
                the concatenation of all datasets in ``dataset_names`` order
                (negative values count from the end, as for a list). Within
                each dataset the caller's ordering is kept. If None, use all
                samples.

        Raises:
            IndexError: If an entry of ``train_indices`` is out of range.
        """
        self.data_dir = Path(data_dir)
        self.layer_idx = layer_idx
        self.layer_dir = self.data_dir / f"layer_{layer_idx}"

        # Storage: one entry per dataset for tensors / KV dicts, and a flat
        # list of per-sample info records.
        self.trigger_hiddens = []
        self.error_kvs = []
        self.anchor_kvs = []
        self.sample_infos = []

        # Read every dataset first: global indices can only be mapped to
        # (dataset, local index) pairs once all dataset sizes are known.
        loaded = [(name, self._read_dataset(name)) for name in dataset_names]
        total = sum(len(parts[1]) for _, parts in loaded)

        selected = None
        if train_indices is not None:
            selected = [self._normalize_index(i, total) for i in train_indices]

        offset = 0
        for name, (info, hidden, error_kv, anchor_kv) in loaded:
            count = len(hidden)
            local = None
            if selected is not None:
                local = [i - offset for i in selected if offset <= i < offset + count]
            self._store(name, info, hidden, error_kv, anchor_kv, local)
            offset += count

        logger.info(f"Dataset loaded: {len(self)} total samples")

    @staticmethod
    def _normalize_index(idx: int, size: int) -> int:
        """Map a possibly negative index into ``[0, size)``; raise IndexError if out of range."""
        idx = int(idx)
        if not -size <= idx < size:
            raise IndexError(f"Sample index {idx} out of range for {size} samples")
        return idx + size if idx < 0 else idx

    def _read_dataset(self, dataset_name: str) -> Tuple[list, torch.Tensor, dict, dict]:
        """Read sample info, trigger hiddens and error/anchor KV dicts of one dataset (unfiltered)."""
        sample_info_file = self.data_dir / f"{dataset_name}_sample_info.json"
        with open(sample_info_file, 'r', encoding='utf-8') as f:
            sample_info = json.load(f)

        trigger_hidden = torch.load(
            self.layer_dir / f"{dataset_name}_trigger_hidden.pt",
            map_location='cpu'
        )
        error_kv = torch.load(
            self.layer_dir / f"{dataset_name}_error_kv.pt",
            map_location='cpu'
        )
        anchor_kv = torch.load(
            self.layer_dir / f"{dataset_name}_anchor_kv.pt",
            map_location='cpu'
        )
        return sample_info, trigger_hidden, error_kv, anchor_kv

    def _store(
        self,
        dataset_name: str,
        sample_info: list,
        trigger_hidden: torch.Tensor,
        error_kv: dict,
        anchor_kv: dict,
        indices: Optional[List[int]] = None
    ):
        """Optionally filter one dataset by local ``indices`` and append it to storage."""
        if indices is not None:
            sample_info = [sample_info[i] for i in indices]
            # A long tensor (rather than a Python list) also handles an empty selection
            index = torch.as_tensor(indices, dtype=torch.long)
            trigger_hidden = trigger_hidden[index]
            error_kv = {k: v[index] for k, v in error_kv.items()}
            anchor_kv = {k: v[index] for k, v in anchor_kv.items()}

        self.sample_infos.extend(sample_info)
        self.trigger_hiddens.append(trigger_hidden)
        self.error_kvs.append(error_kv)
        self.anchor_kvs.append(anchor_kv)

        logger.info(f"Loaded dataset '{dataset_name}': {len(sample_info)} samples")

    def _load_dataset(self, dataset_name: str, train_indices: Optional[List[int]] = None):
        """
        Load a single dataset and append it to storage.

        Args:
            dataset_name: Name of the dataset
            train_indices: Indices local to this dataset to keep. If None, keep all
        """
        self._store(dataset_name, *self._read_dataset(dataset_name), train_indices)

    def __len__(self) -> int:
        """Return total number of samples."""
        return sum(len(th) for th in self.trigger_hiddens)

    @staticmethod
    def _flatten_kv(kv: dict, sample_idx: int) -> torch.Tensor:
        """Concatenate the flattened key and value of one sample: ``[keys..., values...]``."""
        # NOTE(review): per-sample keys/values are presumably [1, num_heads, head_dim]
        # (KVCorrectionMLP.predict_kv reshapes to that) — confirm against preprocessing.
        return torch.cat([
            kv['keys'][sample_idx].flatten(),
            kv['values'][sample_idx].flatten(),
        ])

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a single training sample.

        Args:
            idx: Global sample index (negative values count from the end)

        Returns:
            input_tensor: float32 ``[trigger_hidden, error_keys_flat, error_values_flat]``
            target_tensor: float32 ``[anchor_keys_flat, anchor_values_flat]``

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        sample_idx = self._normalize_index(idx, len(self))

        # Walk the per-dataset chunks to find the one holding sample_idx
        dataset_idx = 0
        for dataset_idx, chunk in enumerate(self.trigger_hiddens):
            if sample_idx < len(chunk):
                break
            sample_idx -= len(chunk)

        trigger_hidden = self.trigger_hiddens[dataset_idx][sample_idx]
        error_kv_flat = self._flatten_kv(self.error_kvs[dataset_idx], sample_idx)
        anchor_kv_flat = self._flatten_kv(self.anchor_kvs[dataset_idx], sample_idx)

        input_tensor = torch.cat([
            trigger_hidden.float(),
            error_kv_flat.float()
        ])
        target_tensor = anchor_kv_flat.float()

        return input_tensor, target_tensor


class KVCorrectionMLP(nn.Module):
    """
    MLP model for KV correction.

    Takes concatenated [trigger_hidden, error_kv] as input and predicts corrected KV.
    Architecture: Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim1: int = 2048,
        hidden_dim2: int = 1024,
        dropout_rate: float = 0.1
    ):
        """
        Initialize the MLP model.

        Args:
            input_dim: Input dimension (hidden_dim + 2*kv_dim)
            output_dim: Output dimension (2*kv_dim for corrected K and V)
            hidden_dim1: First hidden layer dimension
            hidden_dim2: Second hidden layer dimension
            dropout_rate: Dropout rate for regularization
        """
        super().__init__()

        # Keep the full configuration so callers can rebuild the model from a checkpoint
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim1 = hidden_dim1
        self.hidden_dim2 = hidden_dim2
        self.dropout_rate = dropout_rate

        self.network = nn.Sequential(
            # Input -> Hidden1
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            # Hidden1 -> Hidden2
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            # Hidden2 -> Output
            nn.Linear(hidden_dim2, output_dim)
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize Linear weights with Xavier-uniform and biases with zeros."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: Input tensor [batch_size, input_dim]

        Returns:
            Output tensor [batch_size, output_dim]
        """
        return self.network(x)

    def predict_kv(
        self,
        x: torch.Tensor,
        num_heads: Optional[int] = None,
        head_dim: int = 128
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict and reshape output to KV format.

        The first half of the output vector is interpreted as keys and the
        second half as values.

        Args:
            x: Input tensor [batch_size, input_dim]
            num_heads: Number of KV heads. If None, derived as
                ``output_dim // (2 * head_dim)`` (8 for output_dim=2048).
            head_dim: Per-head dimension (default 128).

        Returns:
            keys: Predicted keys [batch_size, 1, num_heads, head_dim]
            values: Predicted values [batch_size, 1, num_heads, head_dim]

        Raises:
            ValueError: If ``head_dim`` is not positive or ``output_dim`` is
                not equal to ``2 * num_heads * head_dim``.
        """
        if head_dim <= 0:
            raise ValueError(f"head_dim must be positive, got {head_dim}")
        if num_heads is None:
            num_heads = self.output_dim // (2 * head_dim)
        if 2 * num_heads * head_dim != self.output_dim:
            raise ValueError(
                f"output_dim={self.output_dim} cannot be split into 2 x "
                f"{num_heads} heads x {head_dim} dims"
            )

        output = self.forward(x)  # [batch_size, output_dim]
        kv = output.reshape(output.shape[0], 2, 1, num_heads, head_dim)
        return kv[:, 0], kv[:, 1]


class MLPTrainer:
    """
    Trainer for KV correction MLP.

    Handles training loop, validation, checkpointing, and visualization.
    Per-epoch losses are appended to ``train_losses`` / ``val_losses`` and
    accumulate across repeated calls to :meth:`train`.
    """

    def __init__(self, model: nn.Module, device: str = 'cuda'):
        """
        Initialize the trainer.

        Args:
            model: Model to train (moved to ``device`` here)
            device: Device to train on, e.g. 'cuda', 'cpu' or a torch.device
        """
        self.model = model.to(device)
        self.device = device
        self.train_losses: List[float] = []
        self.val_losses: List[float] = []

    def train_epoch(
        self,
        train_loader: DataLoader,
        optimizer: optim.Optimizer,
        criterion: nn.Module
    ) -> float:
        """
        Train for one epoch.

        Args:
            train_loader: Training data loader
            optimizer: Optimizer
            criterion: Loss function; assumed to return a per-sample mean
                (e.g. the default ``reduction='mean'``)

        Returns:
            Mean training loss over all samples of the epoch. Each batch's loss
            is weighted by its batch size so a smaller final batch does not
            skew the average.

        Raises:
            ValueError: If ``train_loader`` yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        num_samples = 0

        for inputs, targets in tqdm(train_loader, desc="Training"):
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Forward pass
            optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = criterion(outputs, targets)

            # Backward pass
            loss.backward()
            optimizer.step()

            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

        if num_samples == 0:
            raise ValueError("train_loader yielded no samples")
        return total_loss / num_samples

    def validate(self, val_loader: DataLoader, criterion: nn.Module) -> float:
        """
        Validate the model.

        Args:
            val_loader: Validation data loader
            criterion: Loss function; assumed to return a per-sample mean

        Returns:
            Mean validation loss over all samples (batch-size weighted)

        Raises:
            ValueError: If ``val_loader`` yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        num_samples = 0

        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc="Validating"):
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                outputs = self.model(inputs)
                loss = criterion(outputs, targets)

                batch_size = inputs.size(0)
                total_loss += loss.item() * batch_size
                num_samples += batch_size

        if num_samples == 0:
            raise ValueError("val_loader yielded no samples")
        return total_loss / num_samples

    def train(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
        num_epochs: int,
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-5,
        save_dir: Optional[str] = None,
        patience: int = 10
    ) -> Tuple[List[float], List[float]]:
        """
        Complete training loop with LR scheduling, checkpointing and early stopping.

        Note that after this returns ``self.model`` holds the weights of the
        last epoch run; the best weights are only in the saved checkpoint.

        Args:
            train_loader: Training data loader
            val_loader: Validation data loader
            num_epochs: Maximum number of epochs to train
            learning_rate: Initial AdamW learning rate
            weight_decay: AdamW weight decay
            save_dir: Directory for ``best_model.pt``. If None, no checkpoints saved
            patience: Stop after this many consecutive epochs without a new
                best validation loss

        Returns:
            train_losses: List of training losses per epoch (all calls so far)
            val_losses: List of validation losses per epoch (all calls so far)
        """
        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )
        criterion = nn.MSELoss()

        # Halve the LR after 5 epochs without validation improvement
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5
        )

        best_val_loss = float('inf')
        epochs_without_improvement = 0

        checkpoint_path = None
        if save_dir:
            save_root = Path(save_dir)
            save_root.mkdir(parents=True, exist_ok=True)
            checkpoint_path = save_root / 'best_model.pt'

        logger.info(f"Starting training for {num_epochs} epochs...")

        for epoch in range(num_epochs):
            logger.info(f"\nEpoch {epoch+1}/{num_epochs}")

            train_loss = self.train_epoch(train_loader, optimizer, criterion)
            self.train_losses.append(train_loss)

            val_loss = self.validate(val_loader, criterion)
            self.val_losses.append(val_loss)

            old_lr = optimizer.param_groups[0]['lr']
            scheduler.step(val_loss)
            new_lr = optimizer.param_groups[0]['lr']

            if new_lr != old_lr:
                logger.info(f"Learning rate reduced from {old_lr:.2e} to {new_lr:.2e}")

            logger.info(f"Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}, LR: {new_lr:.2e}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_without_improvement = 0

                if checkpoint_path is not None:
                    torch.save({
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'epoch': epoch,
                        'train_loss': train_loss,
                        'val_loss': val_loss,
                        # getattr: the trainer accepts any nn.Module, not only KVCorrectionMLP
                        'model_config': {
                            'input_dim': getattr(self.model, 'input_dim', None),
                            'output_dim': getattr(self.model, 'output_dim', None)
                        }
                    }, checkpoint_path)
                    logger.info(f"Saved best model with val loss: {val_loss:.6f}")
            else:
                epochs_without_improvement += 1

            if epochs_without_improvement >= patience:
                logger.info(f"Early stopping at epoch {epoch+1}")
                break

        logger.info("\nTraining completed!")
        return self.train_losses, self.val_losses

    def plot_losses(self, save_path: Optional[str] = None):
        """
        Plot training and validation losses against 1-based epoch numbers.

        Args:
            save_path: Path to save the plot (parent directories are created).
                If None, the figure is built and closed without being saved
                or displayed.
        """
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(range(1, len(self.train_losses) + 1), self.train_losses,
                label='Train Loss', color='blue', linewidth=2)
        ax.plot(range(1, len(self.val_losses) + 1), self.val_losses,
                label='Validation Loss', color='red', linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.set_title('Training and Validation Loss', fontsize=14)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)

        if save_path:
            # The directory may not exist if no checkpoint was ever written
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Loss curve saved to: {save_path}")

        plt.close(fig)


def main():
    """
    Main training function with example configuration.

    NOTE: This is an example. Adjust paths and parameters for your setup.

    Loads every sample once and splits them into train/validation index views.
    It infers the model's input/output sizes from the first sample, then
    trains with early stopping. The best checkpoint and a loss curve are
    written to ``./models/layer_<LAYER_IDX>``.

    Raises:
        ValueError: If no samples were found.
    """
    # Local import: only the entry point needs index views onto a dataset
    from torch.utils.data import Subset

    # Configuration parameters
    DATA_DIR = "./data/train_mlp"  # CHANGE THIS to your data directory
    DATASET_NAMES = ["dataset1"]  # CHANGE THIS to your dataset names
    LAYER_IDX = 40  # Layer to use for training

    # Training parameters
    BATCH_SIZE = 16
    NUM_EPOCHS = 100
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    TRAIN_RATIO = 0.8

    # Hidden layer sizes; input/output sizes are inferred from the data below
    HIDDEN_DIM1 = 2048
    HIDDEN_DIM2 = 1024

    SAVE_DIR = Path(f"./models/layer_{LAYER_IDX}")

    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    logger.info("Creating dataset...")

    # Load every sample exactly once; train/val are index views onto it
    # instead of two more full reloads.
    full_dataset = KVCorrectionDataset(DATA_DIR, DATASET_NAMES, LAYER_IDX)
    total_samples = len(full_dataset)
    if total_samples == 0:
        raise ValueError(f"No samples found in {DATA_DIR} for datasets {DATASET_NAMES}")

    train_indices, val_indices = train_test_split(
        list(range(total_samples)), train_size=TRAIN_RATIO, random_state=42
    )

    logger.info(f"Total samples: {total_samples}")
    logger.info(f"Training samples: {len(train_indices)}")
    logger.info(f"Validation samples: {len(val_indices)}")

    train_dataset = Subset(full_dataset, train_indices)
    val_dataset = Subset(full_dataset, val_indices)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Derive model dimensions from one sample so the model always matches the data
    sample_input, sample_target = full_dataset[0]
    input_dim = sample_input.numel()
    output_dim = sample_target.numel()

    logger.info("Creating model...")
    logger.info(f"Model dimensions: input={input_dim}, output={output_dim}")
    model = KVCorrectionMLP(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_dim1=HIDDEN_DIM1,
        hidden_dim2=HIDDEN_DIM2
    )

    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model parameters: {total_params:,}")

    trainer = MLPTrainer(model, device)

    logger.info("Starting training...")
    trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        save_dir=str(SAVE_DIR)
    )

    trainer.plot_losses(str(SAVE_DIR / "loss_curve.png"))

    logger.info("Training completed successfully!")


if __name__ == "__main__":
    main()
