"""
Training loop for BSNP.
Enhanced with comprehensive checkpoint saving.
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict, Optional, Callable
import time
from pathlib import Path
from tqdm import tqdm
from datetime import datetime


class Trainer:
    """
    Trainer for BSNP models.
    """
    
    def __init__(
        self,
        model: nn.Module,
        loss_fn: nn.Module,
        optimizer: torch.optim.Optimizer,
        device: str = 'cpu',
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        gradient_clip: Optional[float] = None,
        log_interval: int = 10,
        checkpoint_dir: Optional[str] = None,
        use_tqdm: bool = True
    ):
        """
        Set up the trainer and its bookkeeping state.

        The model is moved to ``device``. When ``checkpoint_dir`` is given it
        is converted to a ``Path`` and created on disk (including missing
        parent directories); otherwise checkpointing is disabled.

        Args:
            model: Model to train
            loss_fn: Loss function
            optimizer: Optimizer
            device: Device to train on
            scheduler: Learning rate scheduler
            gradient_clip: Gradient clipping value (None to disable)
            log_interval: How often to log (in steps)
            checkpoint_dir: Directory to save checkpoints
            use_tqdm: Whether to use tqdm progress bar
        """
        # Place the model on the target device before anything else touches it.
        self.model = model.to(device)
        self.device = device

        # Optimisation components are stored as given.
        self.loss_fn, self.optimizer, self.scheduler = loss_fn, optimizer, scheduler

        # Behavioural knobs.
        self.gradient_clip = gradient_clip
        self.log_interval = log_interval
        self.use_tqdm = use_tqdm

        # Resolve the checkpoint location; None means "never write checkpoints".
        ckpt_path = None if checkpoint_dir is None else Path(checkpoint_dir)
        if ckpt_path is not None:
            ckpt_path.mkdir(parents=True, exist_ok=True)
        self.checkpoint_dir = ckpt_path

        # Progress counters and the best validation loss seen so far.
        self.epoch, self.step = 0, 0
        self.best_val_loss = float('inf')

        # Per-epoch loss dictionaries, appended to by train().
        self.train_history = []
        self.val_history = []
    
    def train_epoch(
        self,
        train_loader: DataLoader,
        lambda_params: Optional[torch.Tensor] = None
    ) -> Dict[str, float]:
        """
        Train for one epoch.

        For every batch the loss function is called as
        ``loss_fn(model, x_context, y_context, x_target, y_target,
        lambda_params=...)`` and must return ``(loss, loss_dict)``. Batches
        whose loss is NaN/Inf are reported and skipped: no backward pass, no
        optimizer step, and ``self.step`` is not advanced for them.

        Args:
            train_loader: Training data loader yielding dict batches with the
                keys ``x_context``, ``y_context``, ``x_target`` and
                ``y_target`` (and optionally ``lambda_params``).
            lambda_params: PDE parameters (if not in batch)

        Returns:
            Dictionary of average losses over the batches that were actually
            optimized, as plain Python floats. Every value is 0.0 when no
            batch was optimized.

        Raises:
            ValueError: If the resolved ``lambda_params`` is a dict that does
                not contain both ``'xi'`` and ``'w'``.
        """
        self.model.train()

        epoch_losses = {
            'total': 0.0,
            'data': 0.0,
            'physics': 0.0,
            'reg': 0.0,
            'physics_weight': 0.0  # Track current physics weight
        }
        num_batches = 0

        # Create iterator
        if self.use_tqdm:
            iterator = tqdm(train_loader, desc=f"Epoch {self.epoch + 1}")
        else:
            iterator = train_loader

        for batch_idx, batch in enumerate(iterator):
            # Move top-level tensors of the batch to the training device
            batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}

            # Extract batch components
            x_context = batch['x_context']
            y_context = batch['y_context']
            x_target = batch['x_target']
            y_target = batch['y_target']

            # Get lambda_params (batch entry wins over the method argument)
            batch_lambda = batch.get('lambda_params', lambda_params)
            if isinstance(batch_lambda, dict):
                if 'xi' in batch_lambda and 'w' in batch_lambda:
                    batch_lambda = torch.cat([batch_lambda['xi'], batch_lambda['w']], dim=-1)
                else:
                    raise ValueError(f"Unexpected lambda_params dict format: {batch_lambda.keys()}")
            # The device move above only reaches top-level tensors, so tensors
            # nested in a lambda_params dict and the fallback argument would
            # otherwise stay on their original device.
            if isinstance(batch_lambda, torch.Tensor):
                batch_lambda = batch_lambda.to(self.device)

            # **CRITICAL: Update loss function step counter for warmup**
            if hasattr(self.loss_fn, 'current_step'):
                self.loss_fn.current_step = self.step

            # Zero gradients
            self.optimizer.zero_grad()

            # Compute loss
            loss, loss_dict = self.loss_fn(
                self.model,
                x_context,
                y_context,
                x_target,
                y_target,
                lambda_params=batch_lambda
            )

            # Check for NaN/Inf in loss
            if not torch.isfinite(loss):
                print(f"\n⚠️  WARNING: NaN/Inf detected in loss at step {self.step}")
                print(f"   Loss dict: {loss_dict}")
                print("   Skipping this batch...")
                continue

            # Backward pass
            loss.backward()

            # Gradient clipping
            if self.gradient_clip is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.gradient_clip
                )

                # Log gradient norm if it's unusually large
                if grad_norm > self.gradient_clip * 10:
                    print(f"\n⚠️  Large gradient norm: {grad_norm:.2f} at step {self.step}")

            # Optimizer step
            self.optimizer.step()

            # Convert the tracked entries to plain floats once. This keeps the
            # running sums (and the returned averages) as Python floats and
            # avoids holding on to autograd graphs if the loss function
            # reports tensors.
            batch_losses = {
                key: float(loss_dict[key]) for key in epoch_losses if key in loss_dict
            }
            for key, value in batch_losses.items():
                epoch_losses[key] += value
            num_batches += 1

            # Logging tolerates missing 'total'/'data' entries the same way
            # accumulation does.
            if 'total' in batch_losses:
                total_value = batch_losses['total']
            else:
                total_value = float(loss.detach())
            data_value = batch_losses.get('data', 0.0)
            physics_value = batch_losses.get('physics', 0.0)

            # Update progress bar
            if self.use_tqdm:
                postfix = {
                    'loss': f"{total_value:.4f}",
                    'data': f"{data_value:.4f}"
                }
                if physics_value > 0:
                    postfix['phys'] = f"{physics_value:.4f}"
                if 'physics_weight' in batch_losses:
                    postfix['λ_p'] = f"{batch_losses['physics_weight']:.2e}"
                iterator.set_postfix(postfix)
            elif batch_idx % self.log_interval == 0:
                print(f"  Batch {batch_idx}/{len(train_loader)}: "
                      f"Loss = {total_value:.6f} "
                      f"(Data: {data_value:.6f}, "
                      f"Physics: {physics_value:.6f})")

            self.step += 1

        # Average losses (left at 0.0 when every batch was skipped)
        if num_batches > 0:
            for key in epoch_losses:
                epoch_losses[key] /= num_batches

        return epoch_losses

    def validate(
        self,
        val_loader: DataLoader,
        lambda_params: Optional[torch.Tensor] = None
    ) -> Dict[str, float]:
        """
        Validate on validation set.

        Note: If physics loss is used, gradients are enabled during
        validation so the loss function can differentiate the model output;
        otherwise the loss is evaluated under ``no_grad``. The model stays in
        eval mode either way and no optimizer step is taken.

        Args:
            val_loader: Validation data loader yielding dict batches with the
                keys ``x_context``, ``y_context``, ``x_target`` and
                ``y_target`` (and optionally ``lambda_params``).
            lambda_params: PDE parameters (used when a batch carries none)

        Returns:
            Dictionary of average losses as plain Python floats. Every value
            is 0.0 when the loader yields no batches.

        Raises:
            ValueError: If the resolved ``lambda_params`` is a dict that does
                not contain both ``'xi'`` and ``'w'``.
        """
        self.model.eval()

        val_losses = {
            'total': 0.0,
            'data': 0.0,
            'physics': 0.0,
            'reg': 0.0,
            'physics_weight': 0.0
        }
        num_batches = 0

        # Check if physics loss is used. Both ends of the weight schedule are
        # considered: looking only at the initial weight would run validation
        # under no_grad for a schedule that starts at 0 and ends above 0.
        # NOTE(review): presumably the loss ramps its physics weight from
        # lambda_physics_initial to lambda_physics_final — confirm against the
        # loss implementation.
        physics_enabled = bool(getattr(self.loss_fn, 'use_physics_loss', False))
        weight_initial = getattr(self.loss_fn, 'lambda_physics_initial', 0.0) or 0.0
        weight_final = getattr(self.loss_fn, 'lambda_physics_final', 0.0) or 0.0
        use_physics = physics_enabled and max(weight_initial, weight_final) > 0

        # Create iterator
        if self.use_tqdm:
            iterator = tqdm(val_loader, desc="Validation")
        else:
            iterator = val_loader

        for batch in iterator:
            # Move top-level tensors of the batch to the device
            batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}

            x_context = batch['x_context']
            y_context = batch['y_context']
            x_target = batch['x_target']
            y_target = batch['y_target']

            # Get lambda_params (batch entry wins over the method argument)
            batch_lambda = batch.get('lambda_params', lambda_params)
            if isinstance(batch_lambda, dict):
                if 'xi' in batch_lambda and 'w' in batch_lambda:
                    batch_lambda = torch.cat([batch_lambda['xi'], batch_lambda['w']], dim=-1)
                else:
                    raise ValueError(f"Unexpected lambda_params dict format: {batch_lambda.keys()}")
            # Tensors nested in a lambda_params dict and the fallback argument
            # are not covered by the device move above.
            if isinstance(batch_lambda, torch.Tensor):
                batch_lambda = batch_lambda.to(self.device)

            # Update step counter for validation too (to get current physics weight)
            if hasattr(self.loss_fn, 'current_step'):
                self.loss_fn.current_step = self.step

            # Gradients are only switched on when the physics term needs them;
            # the model stays in eval mode (no dropout, batch norm in eval mode).
            with torch.set_grad_enabled(use_physics):
                _, loss_dict = self.loss_fn(
                    self.model,
                    x_context,
                    y_context,
                    x_target,
                    y_target,
                    lambda_params=batch_lambda
                )

            # Accumulate as plain floats so no autograd graph outlives its
            # batch when gradients are enabled.
            for key in val_losses:
                if key in loss_dict:
                    val_losses[key] += float(loss_dict[key])
            num_batches += 1

        # Average (left at 0.0 when there were no batches)
        if num_batches > 0:
            for key in val_losses:
                val_losses[key] /= num_batches

        return val_losses
    
    def train(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        num_epochs: int = 100,
        lambda_params: Optional[torch.Tensor] = None,
        early_stopping_patience: Optional[int] = None,
        checkpoint_interval: int = 10
    ):
        """
        Full training loop.

        Each epoch runs ``train_epoch`` and, when ``val_loader`` is given,
        ``validate``. The per-epoch loss dictionaries are appended to
        ``self.train_history`` / ``self.val_history``. A new best validation
        loss is written to ``best_model.pt`` when a checkpoint directory is
        configured, and the scheduler (if any) is stepped once per epoch.
        ``ReduceLROnPlateau`` receives the validation total loss, or the
        training total loss when there is no validation loader.

        Args:
            train_loader: Training data loader
            val_loader: Validation data loader
            num_epochs: Number of epochs to train
            lambda_params: PDE parameters
            early_stopping_patience: Stop after this many consecutive epochs
                without a new best validation loss (None to disable)
            checkpoint_interval: Save ``checkpoint_epoch_<n>.pt`` every this
                many epochs when a checkpoint directory is configured; values
                below 1 disable periodic checkpoints
        """
        print(f"\nStarting training for {num_epochs} epochs...")
        print(f"Device: {self.device}")
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        # Print physics warmup info if applicable
        if hasattr(self.loss_fn, 'warmup_steps'):
            print(f"Physics loss warmup: {self.loss_fn.lambda_physics_initial:.2e} → "
                  f"{self.loss_fn.lambda_physics_final:.2e} over {self.loss_fn.warmup_steps} steps")

        # Decide once, by identity. Truth-testing a loader calls its __len__,
        # which raises for loaders without a length and is False for an empty one.
        has_val = val_loader is not None
        val_losses = None
        patience_counter = 0

        for epoch in range(num_epochs):
            self.epoch = epoch
            epoch_start = time.time()

            if not self.use_tqdm:
                print(f"\nEpoch {epoch + 1}/{num_epochs}")
                print("-" * 50)

            # Training
            train_losses = self.train_epoch(train_loader, lambda_params)
            self.train_history.append(train_losses)

            # Validation
            if has_val:
                val_losses = self.validate(val_loader, lambda_params)
                self.val_history.append(val_losses)

                # Check for improvement
                if val_losses['total'] < self.best_val_loss:
                    self.best_val_loss = val_losses['total']
                    patience_counter = 0

                    # Save best model with enhanced info
                    if self.checkpoint_dir is not None:
                        self.save_checkpoint('best_model.pt', is_best=True)
                        if not self.use_tqdm:
                            print(f"  ✓ New best model saved!")
                else:
                    patience_counter += 1

            # Learning rate scheduling
            if self.scheduler is not None:
                if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    monitored = val_losses['total'] if has_val else train_losses['total']
                    self.scheduler.step(monitored)
                else:
                    self.scheduler.step()

            # Print epoch summary
            epoch_time = time.time() - epoch_start

            if not self.use_tqdm:
                print(f"\nEpoch {epoch + 1} Summary:")
                print(f"  Time: {epoch_time:.2f}s")
                print(f"  Train Loss: {train_losses['total']:.6f} "
                      f"(Data: {train_losses['data']:.6f}, "
                      f"Physics: {train_losses['physics']:.6f})")
                if has_val:
                    print(f"  Val Loss: {val_losses['total']:.6f} "
                          f"(Data: {val_losses['data']:.6f}, "
                          f"Physics: {val_losses['physics']:.6f})")
                    print(f"  Best Val Loss: {self.best_val_loss:.6f}")
                print(f"  LR: {self.optimizer.param_groups[0]['lr']:.2e}")
                if 'physics_weight' in train_losses:
                    print(f"  Physics Weight: {train_losses['physics_weight']:.2e}")
            else:
                # Print compact summary even with tqdm
                summary = (f"Epoch {epoch+1}/{num_epochs} | "
                          f"Train: {train_losses['total']:.4f} | ")
                if has_val:
                    summary += f"Val: {val_losses['total']:.4f} | "
                summary += f"Best: {self.best_val_loss:.4f} | "
                summary += f"LR: {self.optimizer.param_groups[0]['lr']:.2e}"
                if 'physics_weight' in train_losses and train_losses['physics_weight'] > 0:
                    summary += f" | λ_p: {train_losses['physics_weight']:.2e}"
                print(summary)

            # Save periodic checkpoint
            if (self.checkpoint_dir is not None
                    and checkpoint_interval > 0
                    and (epoch + 1) % checkpoint_interval == 0):
                self.save_checkpoint(f'checkpoint_epoch_{epoch+1}.pt')

            # Early stopping (the counter only moves when validation runs)
            if early_stopping_patience is not None and patience_counter >= early_stopping_patience:
                print(f"\nEarly stopping triggered after {epoch + 1} epochs")
                break

        print("\n" + "="*50)
        print("Training completed!")
        print(f"Best validation loss: {self.best_val_loss:.6f}")
    
    def save_checkpoint(self, filename: str, is_best: bool = False):
        """
        Save model checkpoint with comprehensive information.

        Does nothing when no checkpoint directory is configured. The
        checkpoint holds the training counters, best validation loss, model
        and optimizer state, both loss histories, a timestamp, the ``is_best``
        flag and basic model info; scheduler state is added when a scheduler
        is set. For models exposing ``spatial_dim`` a ``model_config`` entry
        is added with whichever of the known configuration attributes the
        model actually has.

        The file is written to ``<filename>.tmp`` first and then renamed over
        the target, so an interrupted save does not leave a truncated file in
        place of a previous checkpoint of the same name.

        Args:
            filename: Name of checkpoint file
            is_best: Whether this is the best model so far
        """
        if self.checkpoint_dir is None:
            return

        # Walk the parameters once for both counts.
        parameters = list(self.model.parameters())

        checkpoint = {
            # Training state
            'epoch': self.epoch,
            'step': self.step,
            'best_val_loss': self.best_val_loss,

            # Model and optimizer
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),

            # History
            'train_history': self.train_history,
            'val_history': self.val_history,

            # Metadata
            'timestamp': datetime.now().isoformat(),
            'is_best': is_best,

            # Model architecture info
            'model_info': {
                'type': self.model.__class__.__name__,
                'parameter_count': sum(p.numel() for p in parameters),
                'trainable_params': sum(p.numel() for p in parameters if p.requires_grad)
            }
        }

        # Add scheduler state if available
        if self.scheduler is not None:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

        # Add model configuration if available. Only attributes the model
        # really has are recorded, so a model with a partial set of them does
        # not abort the save with AttributeError.
        if hasattr(self.model, 'spatial_dim'):
            config_fields = (
                'spatial_dim',
                'observation_dim',
                'output_dim',
                'latent_dim',
                'use_parameter_conditioning',
                'parameter_dim',
            )
            checkpoint['model_config'] = {
                field: getattr(self.model, field)
                for field in config_fields
                if hasattr(self.model, field)
            }

        # Save via a temporary sibling file, then rename into place.
        filepath = self.checkpoint_dir / filename
        tmp_path = filepath.with_name(filepath.name + '.tmp')
        try:
            torch.save(checkpoint, tmp_path)
            tmp_path.replace(filepath)
        finally:
            # Only still present if saving or renaming failed.
            if tmp_path.exists():
                tmp_path.unlink()

        if is_best:
            print(f"🏆 Best model saved: {filepath} (val_loss: {self.best_val_loss:.6f})")
    
    def load_checkpoint(self, filename: str):
        """
        Restore trainer state from a file in the checkpoint directory.

        Loads the model and optimizer state, the epoch/step counters, the
        best validation loss and both loss histories. Scheduler state is
        restored only when this trainer has a scheduler and the checkpoint
        contains one. A short summary is printed afterwards.

        NOTE(review): ``torch.load`` may unpickle arbitrary objects depending
        on the installed PyTorch version's ``weights_only`` default — only
        load checkpoints from trusted sources.

        Args:
            filename: Name of checkpoint file inside the checkpoint directory

        Raises:
            ValueError: If no checkpoint directory was configured.
        """
        if self.checkpoint_dir is None:
            raise ValueError("No checkpoint directory specified")

        filepath = self.checkpoint_dir / filename
        state = torch.load(filepath, map_location=self.device)

        # Weights and optimizer buffers first.
        self.model.load_state_dict(state['model_state_dict'])
        self.optimizer.load_state_dict(state['optimizer_state_dict'])

        # Scalar progress and histories are stored under the attribute's own name.
        for attr in ('epoch', 'step', 'best_val_loss', 'train_history', 'val_history'):
            setattr(self, attr, state[attr])

        # The scheduler is optional on both sides.
        scheduler_state = state.get('scheduler_state_dict') if 'scheduler_state_dict' in state else None
        if 'scheduler_state_dict' in state and self.scheduler is not None:
            self.scheduler.load_state_dict(scheduler_state)

        for line in (
            f"Checkpoint loaded from {filepath}",
            f"  Epoch: {self.epoch}, Step: {self.step}",
            f"  Best val loss: {self.best_val_loss:.6f}",
        ):
            print(line)