import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler
import wandb
import time
import os
from typing import Dict, Any, Optional, List
from pathlib import Path
import numpy as np

from .losses import DRAMNetLoss
from .optimizers import get_optimizer, get_scheduler
from .validation import ValidationManager
from ..models import DRAMNet
from ..metrics import MetricsManager
from ..utils.checkpoint import CheckpointManager
from ..utils.logger import TrainingLogger

class DRAMNetTrainer:
    """Epoch-based training driver for a DRAMNet model.

    The constructor builds the loss, optimizer, LR scheduler, optional AMP
    gradient scaler, validation manager, metrics manager, checkpoint manager
    and logger from ``config``; :meth:`train` then runs the
    train / validate / checkpoint loop.

    Config keys read by this class (default in parentheses):
        reconstruction_weight (1.0), perceptual_weight (0.1),
        early_exit_weight (0.05), blur_map_weight (0.1),
        optimizer ('adamw'), learning_rate (1e-4), weight_decay (1e-2),
        scheduler ('cosine'), epochs (300), warmup_epochs (10),
        use_amp (True), gradient_clip_norm (None), save_top_k (3),
        early_exit_warmup_epochs (50), progressive_blur_epochs (100),
        log_interval (100), val_interval (5), save_interval (10),
        use_wandb (False), wandb_project ('dramnet'),
        experiment_name ('dramnet_experiment'), log_level ('INFO').
    """

    def __init__(
        self,
        model: "DRAMNet",
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: Dict[str, Any],
        device: torch.device,
        experiment_dir: str
    ):
        """Store the collaborators and build all training components.

        Args:
            model: Network to optimise. It is used as given; this class does
                not move it to ``device``.
            train_loader: Yields dict batches with at least ``'blur'`` and
                ``'sharp'`` entries and optionally ``'blur_severity'``.
            val_loader: Handed to the ``ValidationManager``.
            config: Flat settings dictionary (see the class docstring).
            device: Device that batch tensors are moved to.
            experiment_dir: Root directory for logs and checkpoints.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device
        self.experiment_dir = Path(experiment_dir)

        # Progress counters and best-so-far validation scores.
        self.current_epoch = 0
        self.global_step = 0
        self.best_psnr = 0.0
        self.best_ssim = 0.0

        self._setup_training()
        self._setup_logging()

    @property
    def _device_type(self) -> str:
        """Device family of ``self.device`` (e.g. ``'cuda'`` or ``'cpu'``)."""
        # torch.device() accepts both strings and torch.device instances.
        return torch.device(self.device).type

    def _setup_training(self):
        """Setup training components."""
        cfg = self.config

        self.criterion = DRAMNetLoss(
            reconstruction_weight=cfg.get('reconstruction_weight', 1.0),
            perceptual_weight=cfg.get('perceptual_weight', 0.1),
            early_exit_weight=cfg.get('early_exit_weight', 0.05),
            blur_map_weight=cfg.get('blur_map_weight', 0.1)
        )

        self.optimizer = get_optimizer(
            self.model,
            optimizer_type=cfg.get('optimizer', 'adamw'),
            lr=cfg.get('learning_rate', 1e-4),
            weight_decay=cfg.get('weight_decay', 1e-2)
        )

        self.scheduler = get_scheduler(
            self.optimizer,
            scheduler_type=cfg.get('scheduler', 'cosine'),
            total_epochs=cfg.get('epochs', 300),
            warmup_epochs=cfg.get('warmup_epochs', 10)
        )

        self.use_amp = cfg.get('use_amp', True)
        if self.use_amp:
            # Bind the scaler to the actual training device instead of the
            # implicit 'cuda' default. Loss scaling is only switched on for
            # CUDA; elsewhere the scaler is a pass-through so the backward
            # pass can use one code path.
            device_type = self._device_type
            self.scaler = GradScaler(device_type, enabled=device_type == 'cuda')

        self.val_manager = ValidationManager(
            model=self.model,
            val_loader=self.val_loader,
            device=self.device
        )

        self.metrics_manager = MetricsManager()

        self.checkpoint_manager = CheckpointManager(
            experiment_dir=self.experiment_dir,
            save_top_k=cfg.get('save_top_k', 3)
        )

        # Epoch thresholds consumed by _get_training_strategy().
        self.early_exit_warmup_epochs = cfg.get('early_exit_warmup_epochs', 50)
        self.progressive_blur_epochs = cfg.get('progressive_blur_epochs', 100)

    def _setup_logging(self):
        """Setup logging and monitoring."""
        # Initialize wandb if configured
        if self.config.get('use_wandb', False):
            wandb.init(
                project=self.config.get('wandb_project', 'dramnet'),
                name=self.config.get('experiment_name', 'dramnet_experiment'),
                config=self.config
            )

        self.logger = TrainingLogger(
            log_dir=self.experiment_dir / 'logs',
            log_level=self.config.get('log_level', 'INFO')
        )

        self.logger.info("Training setup completed")
        self.logger.info(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")

    def train_epoch(self) -> Dict[str, float]:
        """Train for one epoch.

        Returns:
            Per-epoch means of the tracked loss terms. ``'loss'`` is the mean
            of each batch's ``'total_loss'``.
        """
        self.model.train()
        epoch_metrics = {
            'loss': 0.0,
            'reconstruction_loss': 0.0,
            'perceptual_loss': 0.0,
            'early_exit_loss': 0.0,
            'blur_map_loss': 0.0,
            'exit_accuracy': 0.0
        }

        num_batches = len(self.train_loader)
        log_interval = self.config.get('log_interval', 100)

        for batch_idx, batch in enumerate(self.train_loader):
            # Move data to device
            blur_images = batch['blur'].to(self.device)
            sharp_images = batch['sharp'].to(self.device)
            blur_severity = batch.get('blur_severity', None)
            if isinstance(blur_severity, torch.Tensor):
                # Keep the optional severity labels on the same device as the
                # images so the criterion never mixes devices.
                blur_severity = blur_severity.to(self.device)

            current_strategy = self._get_training_strategy()
            loss_dict = self._forward_pass(
                blur_images, sharp_images, blur_severity, current_strategy
            )
            self._backward_pass(loss_dict['total_loss'])
            self._update_batch_metrics(epoch_metrics, loss_dict, num_batches)
            if batch_idx % log_interval == 0:
                self._log_batch_progress(batch_idx, num_batches, loss_dict)

            self.global_step += 1

        # Scheduler step (once per epoch)
        self.scheduler.step()

        return epoch_metrics

    def _get_training_strategy(self) -> Dict[str, Any]:
        """Get current training strategy based on epoch.

        Returns:
            Dict with ``enable_early_exit``, ``blur_augmentation_prob``
            (linear ramp from 0 to 1 over ``progressive_blur_epochs``),
            ``multi_scale_training`` and ``curriculum_learning``. Within this
            class only ``enable_early_exit`` is consumed.
        """
        epoch = self.current_epoch
        ramp_epochs = self.progressive_blur_epochs

        # A non-positive ramp length means "no ramp": full probability at once
        # (also avoids dividing by zero).
        if ramp_epochs > 0:
            blur_prob = min(1.0, epoch / ramp_epochs)
        else:
            blur_prob = 1.0

        return {
            'enable_early_exit': epoch >= self.early_exit_warmup_epochs,
            'blur_augmentation_prob': blur_prob,
            'multi_scale_training': epoch >= 20,
            'curriculum_learning': epoch < 100
        }

    def _compute_losses(
        self,
        blur_images: torch.Tensor,
        sharp_images: torch.Tensor,
        blur_severity: Optional[torch.Tensor],
        enable_early_exit: bool
    ) -> Dict[str, torch.Tensor]:
        """Run the model and the criterion once and return the loss dict."""
        outputs = self.model(blur_images, return_exit_info=enable_early_exit)
        return self.criterion(
            outputs=outputs,
            targets=sharp_images,
            blur_severity=blur_severity,
            enable_early_exit=enable_early_exit
        )

    def _forward_pass(
        self,
        blur_images: torch.Tensor,
        sharp_images: torch.Tensor,
        blur_severity: Optional[torch.Tensor],
        strategy: Dict[str, Any]
    ) -> Dict[str, torch.Tensor]:
        """Forward pass with adaptive training.

        Args:
            blur_images: Network input batch.
            sharp_images: Target batch handed to the criterion.
            blur_severity: Optional per-sample labels forwarded to the
                criterion unchanged.
            strategy: Output of :meth:`_get_training_strategy`; only
                ``'enable_early_exit'`` is read here.

        Returns:
            The criterion's loss dictionary; ``'total_loss'`` is the entry
            that gets back-propagated by the caller.
        """
        enable_early_exit = strategy['enable_early_exit']

        if not self.use_amp:
            return self._compute_losses(
                blur_images, sharp_images, blur_severity, enable_early_exit
            )

        # Autocast must target the device the tensors actually live on, not
        # whichever accelerator happens to be installed on the host.
        with autocast(device_type=self._device_type):
            return self._compute_losses(
                blur_images, sharp_images, blur_severity, enable_early_exit
            )

    def _backward_pass(self, loss: torch.Tensor):
        """Backward pass with gradient scaling.

        Gradients are zeroed, back-propagated, optionally clipped to
        ``config['gradient_clip_norm']`` and then applied. Under AMP the
        gradients are unscaled before clipping so the threshold applies to
        their true magnitude.
        """
        clip_norm = self.config.get('gradient_clip_norm', None)
        self.optimizer.zero_grad()

        if self.use_amp:
            self.scaler.scale(loss).backward()

            if clip_norm:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_norm)

            self.scaler.step(self.optimizer)
            self.scaler.update()
            return

        loss.backward()

        if clip_norm:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_norm)

        self.optimizer.step()

    def _update_batch_metrics(
        self,
        epoch_metrics: Dict[str, float],
        loss_dict: Dict[str, torch.Tensor],
        num_batches: int
    ):
        """Update epoch metrics with batch results.

        Each tracked entry is increased by ``value / num_batches`` so that
        after a full epoch it holds the mean over batches. Keys that are not
        already present in ``epoch_metrics`` are ignored.

        The criterion's ``'total_loss'`` is accumulated under ``'loss'`` when
        ``epoch_metrics`` has no ``'total_loss'`` slot and the loss dict has
        no explicit ``'loss'`` entry; otherwise the epoch-level loss would
        stay at zero.
        """
        for key, value in loss_dict.items():
            target = key
            if (
                key == 'total_loss'
                and key not in epoch_metrics
                and 'loss' not in loss_dict
            ):
                target = 'loss'

            if target not in epoch_metrics:
                continue

            # Accept both tensors and plain Python numbers.
            scalar = value.item() if isinstance(value, torch.Tensor) else float(value)
            epoch_metrics[target] += scalar / num_batches

    def _log_batch_progress(
        self,
        batch_idx: int,
        num_batches: int,
        loss_dict: Dict[str, torch.Tensor]
    ):
        """Log batch training progress."""
        progress = 100.0 * batch_idx / num_batches
        self.logger.info(
            f"Epoch {self.current_epoch} [{batch_idx:4d}/{num_batches:4d}] "
            f"({progress:5.1f}%) | Loss: {loss_dict['total_loss'].item():.4f} | "
            f"LR: {self.scheduler.get_last_lr()[0]:.2e}"
        )

    def validate(self) -> Dict[str, float]:
        """Run validation.

        Returns:
            The validation manager's metrics; ``'psnr'``, ``'ssim'`` and
            ``'lpips'`` must be present because they are logged here.
        """
        self.logger.info("Running validation...")

        val_metrics = self.val_manager.validate()

        # Log validation metrics
        self.logger.info(
            f"Validation - PSNR: {val_metrics['psnr']:.2f}, "
            f"SSIM: {val_metrics['ssim']:.4f}, "
            f"LPIPS: {val_metrics['lpips']:.4f}"
        )

        return val_metrics

    def save_checkpoint(self, val_metrics: Dict[str, float], is_best: bool = False):
        """Save model checkpoint.

        Args:
            val_metrics: Validation metrics stored with the checkpoint and
                passed to the checkpoint manager.
            is_best: Forwarded to the checkpoint manager.
        """
        checkpoint_data = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_psnr': self.best_psnr,
            'best_ssim': self.best_ssim,
            'config': self.config,
            'val_metrics': val_metrics
        }

        if self.use_amp:
            checkpoint_data['scaler_state_dict'] = self.scaler.state_dict()

        self.checkpoint_manager.save(
            checkpoint_data,
            epoch=self.current_epoch,
            metrics=val_metrics,
            is_best=is_best
        )

    def _update_best(self, val_metrics: Dict[str, float]) -> bool:
        """Record ``val_metrics`` as the new best if PSNR improved.

        ``best_ssim`` is the SSIM measured at the best-PSNR validation, not
        an independent maximum.

        Returns:
            True when ``val_metrics['psnr']`` strictly exceeds the best so far.
        """
        is_best = val_metrics['psnr'] > self.best_psnr
        if is_best:
            self.best_psnr = val_metrics['psnr']
            self.best_ssim = val_metrics['ssim']
        return is_best

    def train(self):
        """Main training loop.

        Runs ``epochs`` epochs. Every ``val_interval`` epochs the model is
        validated; a checkpoint is written when that validation is a new best
        or falls on a ``save_interval`` epoch. After the last epoch a final
        validation and checkpoint are always performed.
        """
        # Same default as the scheduler setup, so a config without 'epochs'
        # trains for the length the schedule was built for.
        total_epochs = self.config.get('epochs', 300)
        val_interval = self.config.get('val_interval', 5)
        save_interval = self.config.get('save_interval', 10)
        use_wandb = self.config.get('use_wandb', False)

        self.logger.info(f"Starting training for {total_epochs} epochs")

        start_time = time.time()

        for epoch in range(total_epochs):
            self.current_epoch = epoch
            train_metrics = self.train_epoch()

            if epoch % val_interval == 0:
                val_metrics = self.validate()
                is_best = self._update_best(val_metrics)

                # Always persist a new best, even off the save interval;
                # otherwise the best weights could never reach disk.
                if is_best or epoch % save_interval == 0:
                    self.save_checkpoint(val_metrics, is_best)

                if use_wandb:
                    wandb.log({
                        'epoch': epoch,
                        'train/loss': train_metrics['loss'],
                        'train/reconstruction_loss': train_metrics['reconstruction_loss'],
                        'train/perceptual_loss': train_metrics['perceptual_loss'],
                        'val/psnr': val_metrics['psnr'],
                        'val/ssim': val_metrics['ssim'],
                        'val/lpips': val_metrics['lpips'],
                        'learning_rate': self.scheduler.get_last_lr()[0]
                    })

            self.logger.info(
                f"Epoch {epoch:3d} completed | "
                f"Train Loss: {train_metrics['loss']:.4f} | "
                f"Time: {time.time() - start_time:.1f}s"
            )

        # The final weights may be the best ones if the last epoch was not a
        # validation epoch, so they take part in best tracking as well.
        final_val_metrics = self.validate()
        final_is_best = self._update_best(final_val_metrics)
        self.save_checkpoint(final_val_metrics, is_best=final_is_best)

        self.logger.info("Training completed!")
        self.logger.info(f"Best PSNR: {self.best_psnr:.2f}")
        self.logger.info(f"Best SSIM: {self.best_ssim:.4f}")

        if use_wandb:
            wandb.finish()

class ProgressiveTrainer(DRAMNetTrainer):
    """DRAMNetTrainer that carries a resolution schedule and a blur curriculum.

    Both are read from ``config`` (``'resolution_schedule'`` and
    ``'blur_curriculum'``) and exposed through the two helper methods below.

    NOTE(review): nothing in this class feeds the helpers' results into the
    inherited training loop; presumably the data pipeline or a subclass is
    expected to query them - confirm against callers.
    """

    # Resolution used while no schedule threshold has been reached yet.
    _BASE_RESOLUTION = 128
    # Lower end of the severity range handed out by the curriculum.
    _MIN_SEVERITY = 0.1
    # Fallbacks for keys missing from a partially specified curriculum.
    _CURRICULUM_DEFAULTS = {
        'start_severity': 0.2,
        'end_severity': 1.0,
        'progression_epochs': 100
    }

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as ``DRAMNetTrainer``."""
        super().__init__(*args, **kwargs)

        # Maps "first epoch at which it applies" -> training resolution.
        self.resolution_schedule = self.config.get('resolution_schedule', {
            0: 128,
            20: 256,
            50: 384,
            100: 512
        })

        self.blur_curriculum = self.config.get(
            'blur_curriculum', dict(self._CURRICULUM_DEFAULTS)
        )

    def _get_current_resolution(self) -> int:
        """Get current training resolution based on epoch.

        Returns the resolution attached to the largest threshold that is
        ``<= current_epoch``, or 128 when no threshold has been reached.
        The result does not depend on the order of the schedule's entries,
        and thresholds may be ints or numeric strings (as produced by
        JSON/YAML configs).
        """
        reached_threshold = None
        resolution = self._BASE_RESOLUTION

        for raw_threshold, candidate in self.resolution_schedule.items():
            threshold = int(raw_threshold)
            if threshold > self.current_epoch:
                continue
            # '>=' lets a later duplicate threshold win, as plain iteration
            # over an ordered schedule would.
            if reached_threshold is None or threshold >= reached_threshold:
                reached_threshold = threshold
                resolution = candidate

        return resolution

    def _get_blur_severity_range(self) -> tuple:
        """Get current blur severity range for curriculum learning.

        The upper bound moves linearly from ``start_severity`` to
        ``end_severity`` over ``progression_epochs`` epochs and then stays
        there. The lower bound is 0.1, lowered to the upper bound if needed
        so the returned range is never inverted.

        Returns:
            ``(min_severity, max_severity)``.
        """
        curriculum = {**self._CURRICULUM_DEFAULTS, **self.blur_curriculum}

        span = curriculum['progression_epochs']
        # A non-positive span means the curriculum is already complete
        # (and avoids a division by zero).
        progress = min(1.0, self.current_epoch / span) if span > 0 else 1.0

        start_sev = curriculum['start_severity']
        end_sev = curriculum['end_severity']
        current_max = start_sev + (end_sev - start_sev) * progress

        return (min(self._MIN_SEVERITY, current_max), current_max)

class MultiGPUTrainer(DRAMNetTrainer):
    """
    Multi-GPU trainer that wraps the model in ``nn.DataParallel``.

    This is single-process data parallelism; DistributedDataParallel is not
    set up here. The optimizer and the validation manager are created by the
    base constructor from the unwrapped model, before the wrapping happens.
    """

    def __init__(self, *args, **kwargs):
        """Build the base trainer, then wrap the model if several GPUs exist."""
        super().__init__(*args, **kwargs)

        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            self.logger.info(f"Using {gpu_count} GPUs")
            self.model = nn.DataParallel(self.model)

    def save_checkpoint(self, val_metrics: Dict[str, float], is_best: bool = False):
        """Save checkpoint with multi-GPU consideration.

        The state dict is taken from the wrapped module (``self.model.module``)
        when the model has one, so the saved parameter names carry no wrapper
        prefix. Everything else is delegated to the base implementation, so
        the checkpoint payload stays identical to the single-GPU trainer's.

        Args:
            val_metrics: Validation metrics stored with the checkpoint.
            is_best: Forwarded to the checkpoint manager.
        """
        wrapped_model = self.model
        # Expose the inner module to the base implementation for the duration
        # of the save, then restore the wrapper even if saving raises.
        self.model = getattr(wrapped_model, 'module', wrapped_model)
        try:
            super().save_checkpoint(val_metrics, is_best)
        finally:
            self.model = wrapped_model