"""
Active learning framework for GLEAM-AI.

This module contains the main ActiveLearner class that orchestrates the
active learning process for iterative model improvement.
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from typing import Dict, Any, Optional, List, Type, Union
from pathlib import Path
import logging

from .dataset import ActiveLearningData
from .acquisition.base import BaseAcquisition
from ..config.settings import ActiveLearningConfig, TrainingConfig, DataConfig
from ..data.datasets import FeatureDataset, get_z_score_stats, collate_fn

logger = logging.getLogger(__name__)


class ActiveLearner:
    """
    Main active learning framework for iterative model improvement.
    
    This class orchestrates the active learning process, including sample
    selection, model training, and data management.
    """
    
    def __init__(
        self,
        model: nn.Module,
        train_dataset: torch.utils.data.Dataset,
        val_dataset: Optional[torch.utils.data.Dataset],
        pool_dataset: torch.utils.data.Dataset,
        acquisition_function: BaseAcquisition,
        config: ActiveLearningConfig,
        training_config: Optional[TrainingConfig] = None,
        data_config: Optional[DataConfig] = None,
        **kwargs
    ):
        """
        Initialize the active learner.
        
        Args:
            model: The model to train
            train_dataset: Training dataset
            val_dataset: Validation dataset (optional)
            pool_dataset: Pool dataset for active learning
            acquisition_function: Acquisition function for sample selection
            config: Active learning configuration
            training_config: Training configuration (optional)
            data_config: Data configuration (optional)
            **kwargs: Additional configuration parameters
        """
        # Store core components
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.pool_dataset = pool_dataset
        self.acquisition_function = acquisition_function
        
        # Store configurations
        self.config = config
        self.training_config = training_config
        self.data_config = data_config
        
        # Store additional parameters
        self.device = kwargs.get("device", "cpu")
        self.gradient_clip_val = kwargs.get("gradient_clip_val", 0)
        self.patience = kwargs.get("patience", 30)
        self.NUM_COMP = kwargs.get("NUM_COMP", 4)
        
        # Initialize active learning data manager
        self.active_data = ActiveLearningData(pool_dataset)
        
        # Initialize data loaders
        self._setup_data_loaders()
        
        # Initialize tracking variables
        self.current_iteration = 0
        self.training_history = []
        self.acquisition_history = []
        
        # Initialize metadata (if available)
        self.meta_df = None
        self.populations = None
        self._load_metadata(kwargs)
        
        logger.info(f"ActiveLearner initialized with {len(pool_dataset)} pool samples")
    
    def _setup_data_loaders(self) -> None:
        """Setup data loaders for training and validation."""
        # Training data loader
        self.train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.training_config.batch_size if self.training_config else 32,
            shuffle=True,
            num_workers=self.training_config.num_workers if self.training_config else 4,
            pin_memory=True
        )
        
        # Validation data loader (if validation dataset provided)
        if self.val_dataset:
            self.val_dataloader = DataLoader(
                self.val_dataset,
                batch_size=self.training_config.batch_size if self.training_config else 32,
                shuffle=False,
                num_workers=self.training_config.num_workers if self.training_config else 4,
                pin_memory=True
            )
        else:
            self.val_dataloader = None
    
    def _load_metadata(self, kwargs: Dict[str, Any]) -> None:
        """Load metadata if available."""
        try:
            meta_path = kwargs.get("meta_path", "./meta_data/x_df.csv")
            if Path(meta_path).exists():
                import pandas as pd
                self.meta_df = pd.read_csv(meta_path)
                logger.info(f"Loaded metadata from {meta_path}")
        except Exception as e:
            logger.warning(f"Could not load metadata: {e}")
        
        try:
            population_path = kwargs.get("population_csv_path", "./meta_data/populations.csv")
            if Path(population_path).exists():
                import pandas as pd
                self.populations = pd.read_csv(population_path)
                logger.info(f"Loaded population data from {population_path}")
        except Exception as e:
            logger.warning(f"Could not load population data: {e}")
    
    def initialize_data(self) -> None:
        """
        Initialize the training data with random samples from the pool.
        
        This method selects initial training samples and sets up the
        data statistics for the model.
        """
        logger.info("Initializing training data...")
        
        # Get random initial samples
        initial_indices = self.active_data.acquire_random(self.config.initial_samples)
        
        # Update training dataset
        self._update_train_dataset()
        
        # Compute and update data statistics
        data_stats = self._compute_data_stats()
        if hasattr(self.model, 'update_y_stats'):
            self.model.update_y_stats(data_stats["y_mean"], data_stats["y_std"])
        
        # Save initial logs
        self._save_active_logs(initial_indices)
        
        logger.info(f"Initialized with {len(initial_indices)} training samples")
    
    def learn(self, max_iterations: Optional[int] = None) -> None:
        """
        Run the active learning process.
        
        Args:
            max_iterations: Maximum number of iterations (uses config default if None)
        """
        if max_iterations is None:
            max_iterations = self.config.max_iterations
        
        logger.info(f"Starting active learning for {max_iterations} iterations")
        
        while self.current_iteration < max_iterations:
            self.current_iteration += 1
            logger.info(f"Starting iteration {self.current_iteration}")
            
            # Train the model
            self._train_model()
            
            # Select new samples
            candidate_batch = self._search_candidates()
            
            # Acquire selected samples
            self._acquire_samples(candidate_batch)
            
            # Update data statistics
            data_stats = self._compute_data_stats()
            if hasattr(self.model, 'update_y_stats'):
                self.model.update_y_stats(data_stats["y_mean"], data_stats["y_std"])
            
            # Update training dataset
            self._update_train_dataset()
            
            # Save logs
            self._save_active_logs(candidate_batch)
            
            # Check stopping criteria
            if self._should_stop():
                logger.info("Stopping criteria met, ending active learning")
                break
        
        logger.info("Active learning completed")
    
    def _train_model(self) -> None:
        """Train the model on the current training set."""
        logger.info("Training model...")
        
        # This is a simplified training loop
        # In practice, you would use PyTorch Lightning or a custom training loop
        self.model.train()
        
        # Set up optimizer
        if self.training_config:
            optimizer = torch.optim.Adam(
                self.model.parameters(),
                lr=self.training_config.lr,
                weight_decay=self.training_config.weight_decay
            )
        else:
            optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        
        # Training loop (simplified)
        for epoch in range(self.training_config.max_epochs if self.training_config else 10):
            total_loss = 0
            for batch in self.train_dataloader:
                optimizer.zero_grad()
                
                # Forward pass (this would need to be adapted based on your model)
                if hasattr(self.model, 'training_step'):
                    loss = self.model.training_step(batch, None)
                else:
                    # Fallback: assume batch contains (x, xt, y0, y)
                    x, xt, y0, y = batch
                    mu, phi = self.model(x, xt, y0)
                    # Compute loss (this would need to be adapted)
                    loss = torch.nn.functional.mse_loss(mu, y)
                
                loss.backward()
                
                # Gradient clipping
                if self.gradient_clip_val > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip_val)
                
                optimizer.step()
                total_loss += loss.item()
            
            if epoch % 10 == 0:
                logger.info(f"Epoch {epoch}, Loss: {total_loss / len(self.train_dataloader):.4f}")
    
    def _search_candidates(self) -> np.ndarray:
        """
        Search for candidate samples using the acquisition function.
        
        Returns:
            Array of indices of selected samples
        """
        logger.info("Searching for candidate samples...")
        
        # Use acquisition function to select candidates
        candidate_indices = self.acquisition_function.get_candidate_batch(
            self.model, self.active_data
        )
        
        logger.info(f"Selected {len(candidate_indices)} candidate samples")
        return candidate_indices
    
    def _acquire_samples(self, candidate_indices: np.ndarray) -> None:
        """
        Acquire selected samples into the training set.
        
        Args:
            candidate_indices: Indices of samples to acquire
        """
        # Acquire samples
        self.active_data.acquire(candidate_indices)
        
        # Record acquisition
        self.acquisition_history.append({
            "iteration": self.current_iteration,
            "indices": candidate_indices.tolist(),
            "num_samples": len(candidate_indices)
        })
        
        logger.info(f"Acquired {len(candidate_indices)} samples")
    
    def _update_train_dataset(self) -> None:
        """Update the training dataset with newly acquired samples."""
        # Get current training indices
        training_indices = self.active_data.get_training_indices()
        
        # Create new training dataset subset
        self.train_dataset = torch.utils.data.Subset(self.pool_dataset, training_indices)
        
        # Update data loader
        self._setup_data_loaders()
    
    def _compute_data_stats(self) -> Dict[str, torch.Tensor]:
        """
        Compute data statistics for normalization.
        
        Returns:
            Dictionary containing mean and standard deviation statistics
        """
        # Get configuration parameters
        if self.data_config:
            data_path = self.data_config.data_path
            batch_size = self.training_config.batch_size if self.training_config else 64
        else:
            # Try to get from kwargs or use defaults
            data_path = getattr(self, 'data_path', './data')
            batch_size = 64
        
        # Get model configuration
        if hasattr(self.model, 'config') and self.model.config:
            x_dim = self.model.config.x_dim
            xt_dim = self.model.config.xt_dim
            y_dim = self.model.config.y_dim
            seq_len = self.model.config.seq_len
        else:
            # Fallback to hparams or defaults
            x_dim = getattr(self.model, 'x_dim', 6)
            xt_dim = getattr(self.model, 'xt_dim', 1)
            y_dim = getattr(self.model, 'y_dim', 1)
            seq_len = getattr(self.model, 'seq_len', 7)
        
        # Get populations data if available
        populations = None
        if hasattr(self, 'populations') and self.populations is not None:
            if hasattr(self.populations, 'values'):
                populations = self.populations['population'].values
            elif isinstance(self.populations, np.ndarray):
                populations = self.populations
        
        # Compute statistics using the actual training data
        try:
            x_mean, x_std, y_mean, y_std = get_z_score_stats(
                data_path=data_path,
                x_dim=x_dim,
                xt_dim=xt_dim,
                y_dim=y_dim,
                seq_len=seq_len,
                NUM_COMP=self.NUM_COMP,
                populations=populations,
                category="train",
                batch_size=batch_size
            )
            
            # Convert to tensors
            y_mean = torch.from_numpy(y_mean).float()
            y_std = torch.from_numpy(y_std).float()
            x_mean = torch.from_numpy(x_mean).float() if x_mean is not None else None
            x_std = torch.from_numpy(x_std).float() if x_std is not None else None
            
            logger.info("Computed data statistics from training data")
            
            return {
                "y_mean": y_mean,
                "y_std": y_std,
                "x_mean": x_mean,
                "x_std": x_std
            }
        except Exception as e:
            logger.warning(f"Could not compute data statistics: {e}")
            logger.warning("Falling back to default statistics")
            
            # Fallback to default statistics if computation fails
            seq_len = getattr(self.config, 'seq_len', 7) if self.config else 7
            y_dim = getattr(self.config, 'y_dim', 1) if self.config else 1
            
            y_mean = torch.zeros(seq_len, y_dim * self.NUM_COMP)
            y_std = torch.ones(seq_len, y_dim * self.NUM_COMP)
            
            return {"y_mean": y_mean, "y_std": y_std}
    
    def _save_active_logs(self, indices: Union[List[int], np.ndarray]) -> None:
        """
        Save active learning logs.
        
        Args:
            indices: Indices of acquired samples
        """
        if isinstance(indices, np.ndarray):
            indices = indices.tolist()
        
        # Get file IDs if metadata is available
        file_ids = []
        if self.meta_df is not None:
            try:
                file_ids = self.meta_df.iloc[indices]["file_id"].tolist()
            except Exception as e:
                logger.warning(f"Could not get file IDs: {e}")
        
        # Record in history
        log_entry = {
            "iteration": self.current_iteration,
            "indices": indices,
            "file_ids": file_ids,
            "train_size": self.active_data.train_size,
            "pool_size": self.active_data.pool_size
        }
        
        self.training_history.append(log_entry)
        
        logger.info(f"Saved logs for iteration {self.current_iteration}")
    
    def _should_stop(self) -> bool:
        """
        Check if stopping criteria are met.
        
        Returns:
            True if should stop, False otherwise
        """
        # Check if pool is empty
        if self.active_data.pool_size == 0:
            logger.info("Pool is empty, stopping")
            return True
        
        # Check if minimum pool size reached
        if self.active_data.pool_size < self.config.min_pool_size:
            logger.info("Minimum pool size reached, stopping")
            return True
        
        return False
    
    def get_stats(self) -> Dict[str, Any]:
        """Get current statistics about the active learning process."""
        return {
            "current_iteration": self.current_iteration,
            "train_size": self.active_data.train_size,
            "pool_size": self.active_data.pool_size,
            "total_samples": self.active_data.total_size,
            "acquisition_function": self.acquisition_function.__class__.__name__,
            "config": self.config.__dict__
        }
    
    def get_training_history(self) -> List[Dict[str, Any]]:
        """Get the training history."""
        return self.training_history.copy()
    
    def get_acquisition_history(self) -> List[Dict[str, Any]]:
        """Get the acquisition history."""
        return self.acquisition_history.copy()
    
    def reset(self) -> None:
        """Reset the active learning process."""
        self.active_data.reset()
        self.current_iteration = 0
        self.training_history = []
        self.acquisition_history = []
        logger.info("Active learning process reset")
    
    def __repr__(self) -> str:
        """String representation of the active learner."""
        stats = self.get_stats()
        return (f"ActiveLearner(iteration={stats['current_iteration']}, "
                f"train={stats['train_size']}, pool={stats['pool_size']}, "
                f"acquisition={stats['acquisition_function']})")
