"""Training and fine-tuning utilities for ARCOS models."""

import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
from typing import Dict, List, Tuple, Optional, Callable
import numpy as np
from tqdm import tqdm

from ..utils.checkpoint import CheckpointManager


class Trainer:
    """Trainer class for model training and fine-tuning.

    Bundles a model with an AdamW optimizer, a ``CosineAnnealingLR``
    schedule stepped once per batch, a cross-entropy loss, optional mixed
    precision (``GradScaler`` + ``autocast``) and an optional
    ``CheckpointManager``. One metrics dict per epoch is appended to
    ``train_history`` and ``val_history``.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        device: str = "cuda",
        learning_rate: float = 3e-4,
        weight_decay: float = 1e-4,
        mixed_precision: bool = True,
        checkpoint_dir: Optional[str] = None,
        scheduler_t_max: Optional[int] = None,
    ):
        """Initialize trainer.

        Args:
            model: Model to train; moved to ``device``.
            train_loader: Training data loader yielding ``(data, target)``.
            val_loader: Optional validation data loader.
            device: Device to train on.
            learning_rate: AdamW learning rate.
            weight_decay: AdamW weight decay.
            mixed_precision: Whether to use ``autocast`` + ``GradScaler``.
            checkpoint_dir: Directory for checkpoints; when falsy no
                checkpoint manager is created.
            scheduler_t_max: ``T_max`` of the cosine schedule, measured in
                optimizer steps (batches). Defaults to ``len(train_loader)``,
                i.e. one cosine half-period per epoch. ``CosineAnnealingLR``
                keeps following the cosine past ``T_max``, so with the
                default the LR climbs back up during the next epoch. Pass
                ``epochs * len(train_loader)`` for a single decay over the run.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.mixed_precision = mixed_precision

        # Optimizer and scheduler
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )

        if scheduler_t_max is None:
            scheduler_t_max = len(train_loader)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=scheduler_t_max
        )

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # Mixed precision
        self.scaler = GradScaler() if mixed_precision else None

        # Checkpoint manager
        self.checkpoint_manager = None
        if checkpoint_dir:
            self.checkpoint_manager = CheckpointManager(checkpoint_dir, "trainer")

        # Training history: one metrics dict per epoch (val entries may be {})
        self.train_history: List[Dict[str, float]] = []
        self.val_history: List[Dict[str, float]] = []

    @staticmethod
    def _count_correct(output: torch.Tensor, target: torch.Tensor) -> int:
        """Return how many top-1 predictions in ``output`` match ``target``."""
        pred = output.argmax(dim=1, keepdim=True)
        # view_as keeps the comparison element-wise whatever target's shape.
        return pred.eq(target.view_as(pred)).sum().item()

    @staticmethod
    def _summarize(total_loss: float, correct: int, total: int) -> Dict[str, float]:
        """Turn running sums into sample-averaged loss and percent accuracy.

        ``total_loss`` is the sum of per-sample losses. When no samples were
        seen, both metrics are NaN instead of raising ``ZeroDivisionError``.
        """
        if total == 0:
            nan = float('nan')
            return {'loss': nan, 'accuracy': nan}
        return {
            'loss': total_loss / total,
            'accuracy': 100. * correct / total
        }

    def _optimization_step(
        self, data: torch.Tensor, target: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run forward, backward and one optimizer step on a batch.

        Returns:
            ``(output, loss)`` for the batch.
        """
        self.optimizer.zero_grad()

        if self.mixed_precision and self.scaler is not None:
            with autocast():
                output = self.model(data)
                loss = self.criterion(output, target)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            output = self.model(data)
            loss = self.criterion(output, target)

            loss.backward()
            self.optimizer.step()

        return output, loss

    def train_epoch(self) -> Dict[str, float]:
        """Train for one epoch.

        Steps the LR scheduler once per batch.

        Returns:
            Dictionary with ``'loss'`` (sample-averaged cross-entropy) and
            ``'accuracy'`` (percent); both NaN for an empty loader.
        """
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="Training")

        for data, target in pbar:
            data, target = data.to(self.device), target.to(self.device)

            output, loss = self._optimization_step(data, target)
            self.scheduler.step()

            # .item() syncs with the device; call it once per batch.
            batch_loss = loss.item()
            batch_size = target.size(0)
            # The criterion averages over the batch; re-weight by batch size
            # so a short final batch does not skew the epoch average.
            total_loss += batch_loss * batch_size
            correct += self._count_correct(output, target)
            total += batch_size

            pbar.set_postfix({
                'Loss': f"{batch_loss:.4f}",
                'Acc': f"{100. * correct / total:.2f}%"
            })

        return self._summarize(total_loss, correct, total)

    def validate(self) -> Dict[str, float]:
        """Validate the model.

        Returns:
            ``{}`` when there is no validation loader, otherwise a dictionary
            with sample-averaged ``'loss'`` and percent ``'accuracy'``
            (both NaN for an empty loader).
        """
        if self.val_loader is None:
            return {}

        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for data, target in tqdm(self.val_loader, desc="Validation"):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                batch_size = target.size(0)
                total_loss += loss.item() * batch_size
                correct += self._count_correct(output, target)
                total += batch_size

        return self._summarize(total_loss, correct, total)

    def train(
        self,
        epochs: int,
        save_best: bool = True,
        early_stopping_patience: Optional[int] = None
    ) -> Dict[str, List[Dict[str, float]]]:
        """Train the model for specified number of epochs.

        Args:
            epochs: Number of epochs to train.
            save_best: Whether to save a checkpoint every epoch, flagged
                ``is_best`` when validation loss improved. Only has an effect
                when a checkpoint directory was configured.
            early_stopping_patience: Stop after this many consecutive epochs
                without a new best validation loss. Epochs without validation
                metrics neither improve nor count against patience. ``None``
                or ``0`` disables early stopping.

        Returns:
            ``{'train': self.train_history, 'val': self.val_history}``. These
            are the instance's accumulated history lists, so they also contain
            epochs from earlier calls.
        """
        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(epochs):
            print(f"\nEpoch {epoch + 1}/{epochs}")

            # Training
            train_metrics = self.train_epoch()
            self.train_history.append(train_metrics)

            # Validation
            val_metrics = self.validate()
            self.val_history.append(val_metrics)

            # Print metrics
            print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.2f}%")
            if val_metrics:
                print(f"Val   - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.2f}%")

            # Track improvement independently of checkpointing so early
            # stopping works even without a checkpoint manager.
            is_best = False
            if 'loss' in val_metrics:
                if val_metrics['loss'] < best_val_loss:
                    best_val_loss = val_metrics['loss']
                    patience_counter = 0
                    is_best = True
                else:
                    patience_counter += 1

            # Save checkpoint
            if self.checkpoint_manager and save_best:
                # Keep the legacy merged keys ('loss'/'accuracy' resolve to the
                # validation values when present) and add unambiguous
                # prefixed keys so the train metrics are not lost.
                metrics = {**train_metrics, **val_metrics}
                metrics.update({f"train_{k}": v for k, v in train_metrics.items()})
                metrics.update({f"val_{k}": v for k, v in val_metrics.items()})

                self.checkpoint_manager.save_checkpoint(
                    epoch=epoch,
                    model=self.model,
                    optimizer=self.optimizer,
                    scheduler=self.scheduler,
                    metrics=metrics,
                    is_best=is_best
                )

            # Early stopping
            if early_stopping_patience and patience_counter >= early_stopping_patience:
                print(f"Early stopping triggered after {epoch + 1} epochs")
                break

        return {
            'train': self.train_history,
            'val': self.val_history
        }

    def load_checkpoint(self, checkpoint_path: str) -> int:
        """Load checkpoint into the model, optimizer and scheduler.

        Args:
            checkpoint_path: Path to checkpoint file.

        Returns:
            The epoch reported by the checkpoint manager, or ``0`` when no
            checkpoint manager is configured (nothing is loaded in that case).
        """
        if self.checkpoint_manager:
            # The manager returns a pair whose first element is the epoch.
            epoch, _ = self.checkpoint_manager.load_checkpoint(
                checkpoint_path,
                self.model,
                self.optimizer,
                self.scheduler,
                self.device
            )
            return epoch
        else:
            print("No checkpoint manager available")
            return 0


class FineTuner:
    """Fine-tuning utility for Q_tilde model.

    Runs a training loop with AdamW, cross-entropy and optional mixed
    precision (``GradScaler`` + ``autocast``). Unlike ``Trainer`` it has no
    LR scheduler, no validation and no checkpointing.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        device: str = "cuda",
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-4,
        mixed_precision: bool = True
    ):
        """Initialize fine-tuner.

        Args:
            model: Model to fine-tune; moved to ``device``.
            train_loader: Training data loader yielding ``(data, target)``.
            device: Device to train on.
            learning_rate: AdamW learning rate (default lower than
                ``Trainer``'s, as is usual for fine-tuning).
            weight_decay: AdamW weight decay.
            mixed_precision: Whether to use ``autocast`` + ``GradScaler``.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.device = device
        self.mixed_precision = mixed_precision

        # Optimizer for fine-tuning (lower learning rate)
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # Mixed precision
        self.scaler = GradScaler() if mixed_precision else None

    @staticmethod
    def _count_correct(output: torch.Tensor, target: torch.Tensor) -> int:
        """Return how many top-1 predictions in ``output`` match ``target``."""
        pred = output.argmax(dim=1, keepdim=True)
        return pred.eq(target.view_as(pred)).sum().item()

    @staticmethod
    def _summarize(total_loss: float, correct: int, total: int) -> Dict[str, float]:
        """Turn running sums into sample-averaged loss and percent accuracy.

        Both metrics are NaN when no samples were seen (empty loader).
        """
        if total == 0:
            nan = float('nan')
            return {'loss': nan, 'accuracy': nan}
        return {
            'loss': total_loss / total,
            'accuracy': 100. * correct / total
        }

    def _optimization_step(
        self, data: torch.Tensor, target: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run forward, backward and one optimizer step; return ``(output, loss)``."""
        self.optimizer.zero_grad()

        if self.mixed_precision and self.scaler is not None:
            with autocast():
                output = self.model(data)
                loss = self.criterion(output, target)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            output = self.model(data)
            loss = self.criterion(output, target)

            loss.backward()
            self.optimizer.step()

        return output, loss

    def fine_tune_epoch(self) -> Dict[str, float]:
        """Fine-tune for one epoch.

        Returns:
            Dictionary with ``'loss'`` (sample-averaged cross-entropy) and
            ``'accuracy'`` (percent); both NaN for an empty loader.
        """
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="Fine-tuning")

        for data, target in pbar:
            data, target = data.to(self.device), target.to(self.device)

            output, loss = self._optimization_step(data, target)

            batch_loss = loss.item()
            batch_size = target.size(0)
            # Re-weight the batch-mean loss so a short final batch does not
            # skew the epoch average.
            total_loss += batch_loss * batch_size
            correct += self._count_correct(output, target)
            total += batch_size

            pbar.set_postfix({
                'Loss': f"{batch_loss:.4f}",
                'Acc': f"{100. * correct / total:.2f}%"
            })

        return self._summarize(total_loss, correct, total)

    def fine_tune(self, epochs: int) -> List[Dict[str, float]]:
        """Fine-tune the model for specified number of epochs.

        Args:
            epochs: Number of epochs to fine-tune.

        Returns:
            List of per-epoch metrics dicts, as returned by ``fine_tune_epoch``.
        """
        history = []

        for epoch in range(epochs):
            print(f"Fine-tuning epoch {epoch + 1}/{epochs}")

            metrics = self.fine_tune_epoch()
            history.append(metrics)

            print(f"Loss: {metrics['loss']:.4f}, Acc: {metrics['accuracy']:.2f}%")

        return history


def evaluate_model(
    model: nn.Module,
    data_loader: DataLoader,
    device: str = "cuda",
    return_predictions: bool = False
) -> Dict[str, float]:
    """Evaluate model on dataset.

    Puts ``model`` into eval mode (it is left there) and runs it under
    ``torch.no_grad``. The model is not moved; it is assumed to already live
    on ``device``.

    Args:
        model: Model to evaluate.
        data_loader: Data loader yielding ``(data, target)`` pairs.
        device: Device the batches are moved to.
        return_predictions: Whether to also return raw outputs and targets.

    Returns:
        Dictionary with ``'loss'`` (sample-averaged cross-entropy),
        ``'accuracy'`` and ``'error_rate'`` (both percentages). Loss and
        accuracy are NaN when the loader yields no samples. If
        ``return_predictions`` is true it also contains ``'predictions'``
        (list with one NumPy row of model outputs per sample) and
        ``'targets'`` (list with one NumPy entry per sample).
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_predictions = []
    all_targets = []

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for data, target in tqdm(data_loader, desc="Evaluation"):
            data, target = data.to(device), target.to(device)

            output = model(data)
            loss = criterion(output, target)

            batch_size = target.size(0)
            # The criterion averages over the batch; re-weight by batch size
            # so the result is the true per-sample mean over the dataset.
            total_loss += loss.item() * batch_size
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += batch_size

            if return_predictions:
                # extend() over an array splits it along the batch axis.
                all_predictions.extend(output.cpu().numpy())
                all_targets.extend(target.cpu().numpy())

    if total == 0:
        # Empty loader: nothing to average; avoid ZeroDivisionError.
        avg_loss = float('nan')
        accuracy = float('nan')
    else:
        avg_loss = total_loss / total
        accuracy = 100. * correct / total

    metrics = {
        'loss': avg_loss,
        'accuracy': accuracy,
        'error_rate': 100. - accuracy
    }

    if return_predictions:
        metrics['predictions'] = all_predictions
        metrics['targets'] = all_targets

    return metrics


def compute_risk(model: nn.Module, data_loader: DataLoader, device: str = "cuda") -> float:
    """Return the model's risk, i.e. its error rate on ``data_loader``.

    Thin wrapper around ``evaluate_model``; only its ``'error_rate'``
    entry is kept.

    Args:
        model: Model to evaluate.
        data_loader: Data loader to evaluate on.
        device: Device the batches are moved to.

    Returns:
        Error rate as a percentage (``100 - accuracy``).
    """
    return evaluate_model(model, data_loader, device)['error_rate']
