"""
Training framework for GLEAM-AI.

This module provides the main training framework using PyTorch Lightning
for the GLEAM-AI epidemiological forecasting system.
"""

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from typing import Dict, Any, Optional, Union
from pathlib import Path
import logging

from ..models import STNP
from ..data import FeatureDataset, PoolDataset, get_datasets, load_graph_data
from ..active_learning import ActiveLearner, MeanStd, LatentInfoGain
from ..config.settings import TrainingConfig, ModelConfig, DataConfig, ActiveLearningConfig
from ..utils import set_seed, get_device, save_model_checkpoint

logger = logging.getLogger(__name__)


class GLEAMTrainer:
    """
    Main training class for GLEAM-AI models.
    
    This class provides a comprehensive training framework that supports
    both standard training and active learning scenarios.
    """
    
    def __init__(
        self,
        model_config: ModelConfig,
        training_config: TrainingConfig,
        data_config: Optional[DataConfig] = None,
        active_learning_config: Optional[ActiveLearningConfig] = None,
        device: Optional[str] = None,
        seed: Optional[int] = None
    ):
        """
        Initialize the GLEAM trainer.
        
        Args:
            model_config: Model configuration
            training_config: Training configuration
            data_config: Data configuration (optional)
            active_learning_config: Active learning configuration (optional)
            device: Device to use for training
            seed: Random seed for reproducibility
        """
        # Set seed if provided
        if seed is not None:
            set_seed(seed)
        
        # Set device
        self.device = device or get_device()
        
        # Store configurations
        self.model_config = model_config
        self.training_config = training_config
        self.data_config = data_config
        self.active_learning_config = active_learning_config
        
        # Initialize components
        self.model = None
        self.trainer = None
        self.datasets = {}
        self.active_learner = None
        
        logger.info(f"GLEAMTrainer initialized with device: {self.device}")
    
    def setup_data(
        self,
        meta_path: Union[str, Path],
        data_path: Union[str, Path],
        src_path: Union[str, Path],
        population_csv_path: Union[str, Path]
    ) -> None:
        """
        Setup datasets for training.
        
        Args:
            meta_path: Path to metadata
            data_path: Path to data directory
            src_path: Path to source directory
            population_csv_path: Path to population CSV file
        """
        logger.info("Setting up datasets...")
        
        # Get datasets
        train_dataset, val_dataset, pool_dataset = get_datasets(
            meta_path=meta_path,
            data_path=data_path,
            src_path=src_path,
            x_col_names=self.data_config.x_col_names if self.data_config else [],
            frac_pops_names=self.data_config.frac_pops_names if self.data_config else [],
            initial_col_names=self.data_config.initial_col_names if self.data_config else [],
            seq_len=self.model_config.seq_len,
            num_nodes=self.model_config.num_nodes,
            population_csv_path=population_csv_path,
            population_scaler=self.model_config.population_scaler
        )
        
        self.datasets = {
            "train": train_dataset,
            "val": val_dataset,
            "pool": pool_dataset
        }
        
        logger.info(f"Datasets loaded: train={len(train_dataset)}, val={len(val_dataset)}, pool={len(pool_dataset)}")
    
    def setup_model(
        self,
        meta_path: Union[str, Path]
    ) -> None:
        """
        Setup the STNP model.
        
        Args:
            meta_path: Path to metadata directory
        """
        logger.info("Setting up model...")
        
        # Load graph data
        edge_index, edge_weight = load_graph_data(meta_path)
        
        # Create model
        self.model = STNP(
            x_dim=self.model_config.x_dim,
            xt_dim=self.model_config.xt_dim,
            y_dim=self.model_config.y_dim,
            z_dim=self.model_config.z_dim,
            r_dim=self.model_config.r_dim,
            seq_len=self.model_config.seq_len,
            num_nodes=self.model_config.num_nodes,
            in_channels=self.model_config.in_channels,
            out_channels=self.model_config.out_channels,
            embed_out_dim=self.model_config.embed_out_dim,
            max_diffusion_step=self.model_config.max_diffusion_step,
            encoder_num_rnn=self.model_config.encoder_num_rnn,
            decoder_num_rnn=self.model_config.decoder_num_rnn,
            decoder_hidden_dims=self.model_config.decoder_hidden_dims,
            num_comp=self.model_config.num_comp,
            context_percentage=self.model_config.context_percentage,
            lr=self.training_config.lr,
            lr_encoder=self.training_config.lr_encoder,
            lr_decoder=self.training_config.lr_decoder,
            lr_milestones=self.training_config.lr_milestones,
            lr_gamma=self.training_config.lr_gamma,
            edge_index=edge_index,
            edge_weight=edge_weight
        )
        
        logger.info("Model created successfully")
    
    def setup_trainer(
        self,
        output_dir: Union[str, Path],
        experiment_name: str = "gleam_experiment"
    ) -> None:
        """
        Setup PyTorch Lightning trainer.
        
        Args:
            output_dir: Output directory for logs and checkpoints
            experiment_name: Name of the experiment
        """
        logger.info("Setting up PyTorch Lightning trainer...")
        
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        
        # Setup callbacks
        callbacks = []
        
        # Model checkpoint callback
        checkpoint_callback = ModelCheckpoint(
            dirpath=output_dir / "checkpoints",
            filename=f"{experiment_name}_{{epoch:02d}}_{{val_loss:.2f}}",
            monitor="val_loss",
            mode="min",
            save_top_k=3,
            save_last=True
        )
        callbacks.append(checkpoint_callback)
        
        # Early stopping callback
        if self.training_config.patience > 0:
            early_stopping = EarlyStopping(
                monitor="val_loss",
                mode="min",
                patience=self.training_config.patience,
                min_delta=self.training_config.min_delta
            )
            callbacks.append(early_stopping)
        
        # Learning rate monitor
        lr_monitor = LearningRateMonitor(logging_interval="epoch")
        callbacks.append(lr_monitor)
        
        # Setup logger
        logger_tb = TensorBoardLogger(
            save_dir=output_dir / "logs",
            name=experiment_name
        )
        
        # Create trainer
        self.trainer = pl.Trainer(
            max_epochs=self.training_config.max_epochs,
            devices=1 if self.device == "cpu" else "auto",
            accelerator="cpu" if self.device == "cpu" else "gpu",
            callbacks=callbacks,
            logger=logger_tb,
            gradient_clip_val=self.training_config.gradient_clip_val,
            precision=self.training_config.precision,
            deterministic=True,
            enable_progress_bar=True,
            enable_model_summary=True
        )
        
        logger.info("PyTorch Lightning trainer setup complete")
    
    def train(
        self,
        output_dir: Union[str, Path],
        experiment_name: str = "gleam_experiment"
    ) -> Dict[str, Any]:
        """
        Train the model.
        
        Args:
            output_dir: Output directory for logs and checkpoints
            experiment_name: Name of the experiment
            
        Returns:
            Dictionary containing training results
        """
        if self.model is None:
            raise ValueError("Model not setup. Call setup_model() first.")
        
        if self.trainer is None:
            self.setup_trainer(output_dir, experiment_name)
        
        logger.info("Starting training...")
        
        # Train the model
        self.trainer.fit(self.model)
        
        # Get best model path
        best_model_path = self.trainer.checkpoint_callback.best_model_path
        
        # Load best model
        if best_model_path:
            self.model = STNP.load_from_checkpoint(best_model_path)
            logger.info(f"Best model loaded from: {best_model_path}")
        
        # Return training results
        results = {
            "best_model_path": best_model_path,
            "trainer": self.trainer,
            "model": self.model
        }
        
        logger.info("Training completed successfully")
        return results
    
    def setup_active_learning(
        self,
        acquisition_type: str = "mean_std",
        data_retriever=None
    ) -> None:
        """
        Setup active learning framework.
        
        Args:
            acquisition_type: Type of acquisition function ("mean_std" or "latent_info_gain")
            data_retriever: Data retriever for active learning
        """
        if self.active_learning_config is None:
            raise ValueError("Active learning config not provided")
        
        logger.info("Setting up active learning...")
        
        # Create acquisition function
        if acquisition_type == "mean_std":
            acquisition = MeanStd(
                output_dim=self.model_config.y_dim,
                acquisition_size=self.active_learning_config.samples_per_iteration,
                pool_loader_batch_size=self.active_learning_config.pool_loader_batch_size,
                acquisition_pool_fraction=self.active_learning_config.acquisition_pool_fraction,
                num_workers=self.active_learning_config.num_workers,
                device=self.device
            )
        elif acquisition_type == "latent_info_gain":
            acquisition = LatentInfoGain(
                acquisition_size=self.active_learning_config.samples_per_iteration,
                pool_loader_batch_size=self.active_learning_config.pool_loader_batch_size,
                acquisition_pool_fraction=self.active_learning_config.acquisition_pool_fraction,
                num_workers=self.active_learning_config.num_workers,
                device=self.device
            )
        else:
            raise ValueError(f"Unknown acquisition type: {acquisition_type}")
        
        # Create active learner
        self.active_learner = ActiveLearner(
            model=self.model,
            train_dataset=self.datasets["train"],
            val_dataset=self.datasets["val"],
            pool_dataset=self.datasets["pool"],
            acquisition_function=acquisition,
            config=self.active_learning_config,
            training_config=self.training_config,
            data_config=self.data_config,
            device=self.device
        )
        
        logger.info("Active learning setup complete")
    
    def train_with_active_learning(
        self,
        output_dir: Union[str, Path],
        experiment_name: str = "gleam_active_learning"
    ) -> Dict[str, Any]:
        """
        Train the model using active learning.
        
        Args:
            output_dir: Output directory for logs and checkpoints
            experiment_name: Name of the experiment
            
        Returns:
            Dictionary containing training results
        """
        if self.active_learner is None:
            raise ValueError("Active learning not setup. Call setup_active_learning() first.")
        
        logger.info("Starting active learning training...")
        
        # Initialize data
        self.active_learner.initialize_data()
        
        # Run active learning
        self.active_learner.learn(self.active_learning_config.max_iterations)
        
        # Get results
        results = {
            "active_learner": self.active_learner,
            "training_history": self.active_learner.get_training_history(),
            "acquisition_history": self.active_learner.get_acquisition_history(),
            "model": self.model
        }
        
        logger.info("Active learning training completed successfully")
        return results
    
    def save_model(
        self,
        filepath: Union[str, Path],
        **kwargs
    ) -> None:
        """
        Save the trained model.
        
        Args:
            filepath: Path to save the model
            **kwargs: Additional data to save
        """
        if self.model is None:
            raise ValueError("No model to save")
        
        save_model_checkpoint(
            model=self.model,
            optimizer=None,  # Will be handled by PyTorch Lightning
            epoch=0,  # Will be updated by PyTorch Lightning
            loss=0.0,  # Will be updated by PyTorch Lightning
            filepath=filepath,
            **kwargs
        )
        
        logger.info(f"Model saved to: {filepath}")
    
    def load_model(
        self,
        filepath: Union[str, Path]
    ) -> None:
        """
        Load a trained model.
        
        Args:
            filepath: Path to the model file
        """
        self.model = STNP.load_from_checkpoint(filepath)
        logger.info(f"Model loaded from: {filepath}")
    
    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the current model.
        
        Returns:
            Dictionary containing model information
        """
        if self.model is None:
            return {"error": "No model loaded"}
        
        return {
            "model_type": type(self.model).__name__,
            "device": str(self.device),
            "model_config": self.model_config.__dict__,
            "training_config": self.training_config.__dict__,
            "datasets": {k: len(v) for k, v in self.datasets.items()}
        }
