"""
Training framework for Multi-Scale Attention U-Net
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import time
import os
from typing import Dict, List, Tuple, Optional, Callable
from tqdm import tqdm
import json
import matplotlib.pyplot as plt
import seaborn as sns

from model import MSAUNet, BaselineUNet
from losses import CombinedLoss, create_loss_function
from metrics import ModelEvaluator, compute_model_efficiency
from dataset import DatasetConfig

class EarlyStopping:
    """Early stopping utility for a "higher is better" validation score.

    Keeps the best score seen so far together with a snapshot of the model
    weights taken at that moment. When the score has failed to improve by at
    least ``min_delta`` for ``patience`` consecutive calls, the instance
    signals that training should stop and (optionally) restores the snapshot.
    """

    def __init__(self, patience: int = 20, min_delta: float = 0.001, restore_best_weights: bool = True):
        """
        Args:
            patience: Number of consecutive non-improving calls tolerated
                before stopping is requested.
            min_delta: Minimum increase over the best score that counts as
                an improvement.
            restore_best_weights: If True, load the best snapshot back into
                the model when stopping is requested.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_score = None
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_score: float, model: nn.Module) -> bool:
        """
        Args:
            val_score: Current validation score
            model: Model to potentially restore
        Returns:
            should_stop: Whether to stop training
        """
        # First observation: it is the best by definition.
        if self.best_score is None:
            self.best_score = val_score
            self.save_checkpoint(model)
            return False

        # Not enough improvement: count it and possibly request a stop.
        if val_score < self.best_score + self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                if self.restore_best_weights:
                    model.load_state_dict(self.best_weights)
                return True
            return False

        # Improvement: record it, reset the counter, refresh the snapshot.
        self.best_score = val_score
        self.counter = 0
        self.save_checkpoint(model)
        return False

    def save_checkpoint(self, model: nn.Module):
        """Snapshot the model's current weights into ``self.best_weights``.

        Each tensor is cloned. A plain ``state_dict().copy()`` only copies the
        mapping, so the stored tensors would alias the live parameters and
        silently follow every later optimizer update.
        """
        self.best_weights = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }

class LearningRateScheduler:
    """Selects a torch learning-rate scheduler by name and steps it.

    Supported modes:
        'step'    -> ``optim.lr_scheduler.StepLR``
        'plateau' -> ``optim.lr_scheduler.ReduceLROnPlateau`` (mode='max')
    """

    def __init__(self, optimizer: optim.Optimizer, mode: str = 'step', 
                 step_size: int = 50, gamma: float = 0.1, patience: int = 10):
        """
        Args:
            optimizer: Optimizer whose learning rate is scheduled.
            mode: 'step' or 'plateau'.
            step_size: Epoch period passed to StepLR.
            gamma: Decay factor (StepLR ``gamma`` / plateau ``factor``).
            patience: Patience passed to ReduceLROnPlateau.
        Raises:
            ValueError: If ``mode`` is not one of the supported names.
        """
        # Reject unknown modes before building anything.
        if mode not in ('step', 'plateau'):
            raise ValueError(f"Unknown scheduler mode: {mode}")

        self.optimizer = optimizer
        self.mode = mode
        self.step_size = step_size
        self.gamma = gamma
        self.patience = patience
        self.best_score = None
        self.counter = 0

        lr_sched = optim.lr_scheduler
        if mode == 'plateau':
            self.scheduler = lr_sched.ReduceLROnPlateau(
                optimizer, mode='max', patience=patience, factor=gamma
            )
        else:
            self.scheduler = lr_sched.StepLR(optimizer, step_size, gamma)

    def step(self, val_score: Optional[float] = None):
        """Advance the wrapped scheduler.

        In 'step' mode ``val_score`` is ignored. In 'plateau' mode the
        scheduler is stepped only when a ``val_score`` is supplied.
        """
        if self.mode == 'plateau':
            if val_score is not None:
                self.scheduler.step(val_score)
        elif self.mode == 'step':
            self.scheduler.step()

    def get_lr(self) -> float:
        """Return the learning rate of the optimizer's first param group."""
        first_group = self.optimizer.param_groups[0]
        return first_group['lr']

class ModelTrainer:
    """Main training class for MSA-UNet.

    Owns the model, data loaders, loss, optimizer, LR scheduler, early
    stopping and evaluator, and records per-epoch metrics in ``self.history``.
    """
    
    def __init__(self, 
                 model: nn.Module,
                 train_loader: DataLoader,
                 val_loader: DataLoader,
                 test_loader: DataLoader,
                 device: torch.device,
                 config: Dict):
        """
        Args:
            model: Network to train; it is moved to ``device``.
            train_loader: Loader yielding ``(images, targets)`` training batches.
            val_loader: Loader yielding ``(images, targets)`` validation batches.
            test_loader: Loader passed to the evaluator in ``evaluate_model``.
            device: Device used for the model and for every batch.
            config: Training options. Keys read here (with defaults):
                'loss_type', 'loss_params', 'learning_rate', 'weight_decay',
                'scheduler_mode', 'step_size', 'gamma', 'patience',
                'early_stopping_patience', 'early_stopping_min_delta',
                'num_classes'. ``train``/``train_epoch`` additionally read
                'grad_clip' and 'save_interval'.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device
        self.config = config
        
        # Move model to device
        self.model = self.model.to(device)
        
        # Initialize loss function
        self.criterion = create_loss_function(
            config.get('loss_type', 'combined'),
            **config.get('loss_params', {})
        )
        
        # Initialize optimizer
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config.get('learning_rate', 0.001),
            weight_decay=config.get('weight_decay', 1e-4)
        )
        
        # Initialize scheduler
        self.scheduler = LearningRateScheduler(
            self.optimizer,
            mode=config.get('scheduler_mode', 'step'),
            step_size=config.get('step_size', 50),
            gamma=config.get('gamma', 0.1),
            patience=config.get('patience', 10)
        )
        
        # Initialize early stopping
        self.early_stopping = EarlyStopping(
            patience=config.get('early_stopping_patience', 20),
            min_delta=config.get('early_stopping_min_delta', 0.001)
        )
        
        # Initialize evaluator
        self.evaluator = ModelEvaluator(self.model, device, config.get('num_classes', 5))
        
        # Training history
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_dice': [],
            'val_iou': [],
            'val_hausdorff': [],
            'val_boundary_f1': [],
            'learning_rate': []
        }
        
        # Best model tracking
        self.best_val_dice = 0.0
        self.best_epoch = 0

        # Directory where train() last wrote 'best_model.pth'; evaluate_model()
        # uses it to find the checkpoint. Matches train()'s default.
        self.save_dir = 'checkpoints'
        
    def train_epoch(self) -> Dict[str, float]:
        """Train for one epoch.

        Returns:
            Dict with 'train_loss' (mean of per-batch losses) and
            'learning_rate' (current LR of the first param group).
        """
        self.model.train()
        
        epoch_loss = 0.0
        num_batches = len(self.train_loader)
        
        progress_bar = tqdm(self.train_loader, desc="Training", leave=False)
        
        for images, targets in progress_bar:
            # Move to device
            images = images.to(self.device)
            targets = targets.to(self.device)
            
            # Zero gradients
            self.optimizer.zero_grad()
            
            # Forward pass
            predictions = self.model(images)
            
            # Compute loss; CombinedLoss returns (loss, details) and only the
            # loss itself is needed here.
            if isinstance(self.criterion, CombinedLoss):
                loss, _ = self.criterion(predictions, targets)
            else:
                loss = self.criterion(predictions, targets)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping
            if self.config.get('grad_clip', 0) > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['grad_clip'])
            
            # Update weights
            self.optimizer.step()
            
            # Read the scalar once and reuse it for the sum and the display.
            batch_loss = loss.item()
            epoch_loss += batch_loss
            
            # Update progress bar
            progress_bar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'LR': f'{self.scheduler.get_lr():.6f}'
            })
        
        return {
            'train_loss': epoch_loss / num_batches,
            'learning_rate': self.scheduler.get_lr()
        }
    
    def validate_epoch(self) -> Dict[str, float]:
        """Validate for one epoch.

        The mean validation loss is computed here under ``torch.no_grad()``;
        the remaining metrics come from a separate
        ``self.evaluator.evaluate(self.val_loader)`` call.

        Returns:
            Dict with 'val_loss', 'val_dice', 'val_iou', 'val_hausdorff'
            and 'val_boundary_f1'.
        """
        self.model.eval()
        
        val_loss = 0.0
        num_batches = len(self.val_loader)
        
        with torch.no_grad():
            for images, targets in tqdm(self.val_loader, desc="Validation", leave=False):
                # Move to device
                images = images.to(self.device)
                targets = targets.to(self.device)
                
                # Forward pass
                predictions = self.model(images)
                
                # Compute loss
                if isinstance(self.criterion, CombinedLoss):
                    loss, _ = self.criterion(predictions, targets)
                else:
                    loss = self.criterion(predictions, targets)
                
                val_loss += loss.item()
        
        # Compute validation metrics
        val_metrics = self.evaluator.evaluate(self.val_loader)
        
        return {
            'val_loss': val_loss / num_batches,
            'val_dice': val_metrics['mean_dice'],
            'val_iou': val_metrics['mean_iou'],
            'val_hausdorff': val_metrics['mean_hausdorff'],
            'val_boundary_f1': val_metrics['mean_boundary_f1']
        }
    
    def train(self, num_epochs: int, save_dir: str = 'checkpoints') -> Dict[str, List[float]]:
        """Train the model.

        Args:
            num_epochs: Maximum number of epochs to run.
            save_dir: Directory for 'best_model.pth' and periodic
                checkpoints; created if missing and remembered on
                ``self.save_dir`` for ``evaluate_model``.
        Returns:
            The history dict mapping metric name to per-epoch values.
        """
        print(f"Starting training for {num_epochs} epochs...")
        print(f"Device: {self.device}")
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        
        # Create save directory
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir

        # A falsy interval (0/None) disables periodic checkpoints instead of
        # failing on the modulo below.
        save_interval = self.config.get('save_interval', 10)
        
        start_time = time.time()
        
        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print("-" * 50)
            
            # Train
            train_metrics = self.train_epoch()
            
            # Validate
            val_metrics = self.validate_epoch()
            
            # Update learning rate
            self.scheduler.step(val_metrics['val_dice'])
            
            # Update history
            self.history['train_loss'].append(train_metrics['train_loss'])
            self.history['val_loss'].append(val_metrics['val_loss'])
            self.history['val_dice'].append(val_metrics['val_dice'])
            self.history['val_iou'].append(val_metrics['val_iou'])
            self.history['val_hausdorff'].append(val_metrics['val_hausdorff'])
            self.history['val_boundary_f1'].append(val_metrics['val_boundary_f1'])
            self.history['learning_rate'].append(train_metrics['learning_rate'])
            
            # Print metrics
            print(f"Train Loss: {train_metrics['train_loss']:.4f}")
            print(f"Val Loss: {val_metrics['val_loss']:.4f}")
            print(f"Val Dice: {val_metrics['val_dice']:.4f}")
            print(f"Val IoU: {val_metrics['val_iou']:.4f}")
            print(f"Val Hausdorff: {val_metrics['val_hausdorff']:.4f}")
            print(f"Val Boundary F1: {val_metrics['val_boundary_f1']:.4f}")
            print(f"Learning Rate: {train_metrics['learning_rate']:.6f}")
            
            # Save best model
            if val_metrics['val_dice'] > self.best_val_dice:
                self.best_val_dice = val_metrics['val_dice']
                self.best_epoch = epoch
                self.save_model(os.path.join(save_dir, 'best_model.pth'))
                print(f"New best model saved! Dice: {self.best_val_dice:.4f}")
            
            # Early stopping
            if self.early_stopping(val_metrics['val_dice'], self.model):
                print(f"Early stopping at epoch {epoch+1}")
                break
            
            # Save checkpoint
            if save_interval and (epoch + 1) % save_interval == 0:
                self.save_checkpoint(os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth'), epoch)
        
        training_time = time.time() - start_time
        print(f"\nTraining completed in {training_time:.2f} seconds")
        print(f"Best validation Dice: {self.best_val_dice:.4f} at epoch {self.best_epoch+1}")
        
        return self.history
    
    def save_model(self, filepath: str):
        """Save model and optimizer state plus config and history to ``filepath``."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config,
            'history': self.history
        }, filepath)
    
    def load_model(self, filepath: str):
        """Load a file written by ``save_model``.

        Restores model and optimizer state and replaces ``self.config`` and
        ``self.history`` with the stored values.
        """
        checkpoint = torch.load(filepath, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.config = checkpoint['config']
        self.history = checkpoint['history']
    
    def save_checkpoint(self, filepath: str, epoch: int):
        """Save a training checkpoint.

        Stores everything ``save_model`` stores, plus the epoch index, the
        wrapped scheduler's state and the best validation Dice so far.
        """
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.scheduler.state_dict(),
            'config': self.config,
            'history': self.history,
            'best_val_dice': self.best_val_dice
        }, filepath)
    
    def plot_training_history(self, save_path: Optional[str] = None):
        """Plot training history on a 2x3 grid of panels.

        Args:
            save_path: If given, the figure is also written to this path
                (dpi=300) before being shown.
        """
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        
        # Loss curves: the only panel with two series and default colors.
        loss_ax = axes[0, 0]
        loss_ax.plot(self.history['train_loss'], label='Train Loss')
        loss_ax.plot(self.history['val_loss'], label='Val Loss')
        loss_ax.set_title('Training and Validation Loss')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.legend()
        loss_ax.grid(True)
        
        # Remaining panels share one layout:
        # (axis, history key, legend label, color, title, y-axis label)
        single_series_panels = [
            (axes[0, 1], 'val_dice', 'Val Dice', 'green',
             'Validation Dice Score', 'Dice Score'),
            (axes[0, 2], 'val_iou', 'Val IoU', 'blue',
             'Validation IoU Score', 'IoU Score'),
            (axes[1, 0], 'val_hausdorff', 'Val Hausdorff', 'red',
             'Validation Hausdorff Distance', 'Hausdorff Distance'),
            (axes[1, 1], 'val_boundary_f1', 'Val Boundary F1', 'purple',
             'Validation Boundary F1 Score', 'Boundary F1 Score'),
            (axes[1, 2], 'learning_rate', 'Learning Rate', 'orange',
             'Learning Rate Schedule', 'Learning Rate'),
        ]
        for ax, key, label, color, title, ylabel in single_series_panels:
            ax.plot(self.history[key], label=label, color=color)
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True)
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        
        plt.show()
    
    def evaluate_model(self, checkpoint_path: Optional[str] = None) -> Dict[str, float]:
        """Evaluate model on test set.

        Args:
            checkpoint_path: File written by ``save_model`` to load before
                evaluating. Defaults to 'best_model.pth' inside the
                directory last used by ``train`` ('checkpoints' if
                ``train`` has not been called).
        Returns:
            The metrics dict returned by the evaluator for the test loader.
        """
        print("Evaluating model on test set...")
        
        # Load best model from wherever train() actually saved it.
        if checkpoint_path is None:
            checkpoint_path = os.path.join(self.save_dir, 'best_model.pth')
        self.load_model(checkpoint_path)
        
        # Evaluate
        test_metrics = self.evaluator.evaluate(self.test_loader)
        
        print("Test Results:")
        print(f"Dice Score: {test_metrics['mean_dice']:.4f}")
        print(f"IoU Score: {test_metrics['mean_iou']:.4f}")
        print(f"Hausdorff Distance: {test_metrics['mean_hausdorff']:.4f}")
        print(f"Boundary F1 Score: {test_metrics['mean_boundary_f1']:.4f}")
        print(f"Pixel Accuracy: {test_metrics['pixel_accuracy']:.4f}")
        
        return test_metrics

def create_trainer_config() -> Dict:
    """Build a fresh dict holding the default training configuration."""
    # Weights and smoothing forwarded to the loss factory.
    loss_params = dict(
        dice_weight=0.7,
        boundary_weight=0.3,
        dice_smooth=1e-7,
    )

    defaults = dict(
        # Optimizer
        learning_rate=0.001,
        weight_decay=1e-4,
        # Run length and batching
        num_epochs=200,
        batch_size=16,
        # Loss
        loss_type='combined',
        loss_params=loss_params,
        # Learning-rate schedule
        scheduler_mode='step',
        step_size=50,
        gamma=0.1,
        # Early stopping
        early_stopping_patience=20,
        early_stopping_min_delta=0.001,
        # Miscellaneous
        grad_clip=1.0,
        save_interval=10,
        num_classes=5,
    )
    return defaults

if __name__ == "__main__":
    # Smoke test: build the default pipeline, train briefly, plot, evaluate.
    from dataset import create_dataloaders, DatasetConfig
    
    # Create config
    config = create_trainer_config()
    dataset_config = DatasetConfig()
    
    # Create dataloaders; batch size comes from the config rather than a
    # second hard-coded copy of the same number.
    train_loader, val_loader, test_loader = create_dataloaders(
        dataset_config, batch_size=config['batch_size']
    )
    
    # Create model; the class count must agree with the one the trainer
    # hands to its evaluator, so both read config['num_classes'].
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MSAUNet(in_channels=3, num_classes=config['num_classes'], num_heads=4)
    
    # Create trainer
    trainer = ModelTrainer(model, train_loader, val_loader, test_loader, device, config)
    
    # Train for a few epochs
    history = trainer.train(num_epochs=5)
    
    # Plot history
    trainer.plot_training_history()
    
    # Evaluate
    test_metrics = trainer.evaluate_model()
    print(f"Test metrics: {test_metrics}")

