"""
Latent Information Gain acquisition functions for active learning.

This module contains acquisition functions that select samples based on
the information gain in the latent space of neural processes.
"""

import torch
import torch.nn as nn
import torch.distributions as dist
from torch.utils.data import DataLoader, Subset
import numpy as np
from typing import List, Dict, Any, Optional

from .base import BaseAcquisition


class LatentInfoGain(BaseAcquisition):
    """
    Latent Information Gain acquisition function (batch-level selection).

    A random subset of the pool is split into loader batches. Each batch is
    scored by the KL divergence between the latent distribution obtained
    after blending the batch's latent statistics into the training latent
    statistics, and the current training latent distribution. The indices of
    the highest-scoring batch are returned.

    The model handed to ``get_candidate_batch`` must provide
    ``get_latent_tensors()``, ``get_input_embedding(x, xt)``,
    ``get_latent_representation(embedding, y_latent_prev, y0_latent_prev)``
    and be callable as ``model(x, xt, y0_latent_prev)``.
    """

    def __init__(
        self,
        acquisition_size: int,
        pool_loader_batch_size: int,
        acquisition_pool_fraction: float,
        num_workers: int = 4,
        device: str = "cpu"
    ):
        """
        Initialize the LatentInfoGain acquisition function.

        Args:
            acquisition_size: Number of samples to acquire
            pool_loader_batch_size: Batch size for processing pool; must be
                a multiple of ``acquisition_size``
            acquisition_pool_fraction: Fraction of pool to consider
            num_workers: Number of data loading workers
            device: Computation device

        Raises:
            ValueError: If ``acquisition_size`` is not positive or
                ``pool_loader_batch_size`` is not divisible by it.
        """
        super().__init__(
            acquisition_size=acquisition_size,
            pool_loader_batch_size=pool_loader_batch_size,
            acquisition_pool_fraction=acquisition_pool_fraction,
            num_workers=num_workers,
            device=device
        )

        # Reject non-positive sizes explicitly; a zero would otherwise
        # surface as an unrelated ZeroDivisionError in the modulo below.
        if acquisition_size <= 0:
            raise ValueError("acquisition_size must be a positive integer")

        # Validate batch size constraint
        if pool_loader_batch_size % acquisition_size != 0:
            raise ValueError("pool_loader_batch_size must be divisible by acquisition_size")

    @torch.no_grad()
    def get_candidate_batch(
        self,
        model: nn.Module,
        active_data: 'ActiveLearningData',
        **kwargs
    ) -> np.ndarray:
        """
        Get candidate samples using latent information gain.

        Args:
            model: The trained model (switched to eval mode and moved to
                ``self.device`` as a side effect)
            active_data: Active learning data manager; ``pool_dataset`` and
                ``train_size`` are read from it
            **kwargs: Additional arguments (unused)

        Returns:
            Array of indices of the highest-scoring loader batch. An empty
            integer array is returned when the sampled pool is empty.

        NOTE(review): the whole winning loader batch is returned, so the
        result can hold up to ``pool_loader_batch_size`` indices rather than
        ``acquisition_size`` -- confirm this is what callers expect.
        """
        model.eval().to(self.device)

        # Random subset of the pool to score
        pool_size = int(self.acquisition_pool_fraction * len(active_data.pool_dataset))
        pool_indices = torch.randperm(len(active_data.pool_dataset))[:pool_size]
        if pool_indices.numel() == 0:
            # Nothing to score; avoid building a loader / max() over nothing.
            return np.empty(0, dtype=np.int64)

        pool_loader = self._build_pool_loader(active_data, pool_indices)

        # Current latent representation from training data
        mu_z_train, var_z_train = model.get_latent_tensors()
        mu_z_train = mu_z_train.to(self.device)
        sigma_z_train = var_z_train.sqrt().to(self.device)

        num_train_data = active_data.train_size

        best_score = None
        best_indices = None
        for batch_idx, batch in enumerate(pool_loader):
            mu_z_q, sigma_z_q, indices, query_size = self._query_latent(
                model, batch, batch_idx, pool_indices
            )
            score = self.acquisition_fn(
                mu_z_train, sigma_z_train, mu_z_q, sigma_z_q,
                query_size, num_train_data
            )
            # NaN compares False against everything, which would make the
            # winner depend on batch order; rank such batches last instead.
            if np.isnan(score):
                score = float("-inf")
            # Strict '>' keeps the first batch on ties, like max() does.
            if best_score is None or score > best_score:
                best_score, best_indices = score, indices

        # as_tensor tolerates collate functions that yield lists/arrays.
        return torch.as_tensor(best_indices).cpu().numpy()

    def _build_pool_loader(
        self,
        active_data: 'ActiveLearningData',
        pool_indices: torch.Tensor
    ) -> DataLoader:
        """Build the loader over the sampled pool subset, in ``pool_indices`` order."""
        loader_kwargs = dict(
            # pool_indices is already a random permutation, so reshuffling
            # adds nothing; a fixed order is what lets loader positions be
            # mapped back to pool indices in _query_latent.
            shuffle=False,
            batch_size=self.pool_loader_batch_size,
            num_workers=self.num_workers,
        )
        try:
            from ..dataset import pool_collate_fn
        except ImportError:
            # Project collate function unavailable: use DataLoader's default.
            pass
        else:
            loader_kwargs["collate_fn"] = pool_collate_fn

        # Plain Python ints so the dataset never receives 0-d tensors.
        subset = Subset(active_data.pool_dataset, pool_indices.tolist())
        return DataLoader(subset, **loader_kwargs)

    def _query_latent(
        self,
        model: nn.Module,
        batch,
        batch_idx: int,
        pool_indices: torch.Tensor
    ):
        """Run the model on one pool batch; return (mu_z, sigma_z, indices, batch_len)."""
        if len(batch) == 4:
            x, xt, y0_latent_prev, indices = batch
        else:
            # Batch carries no indices: recover them from the loader position.
            # Assumes callers expect pool-dataset indices -- TODO confirm.
            x, xt, y0_latent_prev = batch[:3]
            start = batch_idx * self.pool_loader_batch_size
            indices = pool_indices[start:start + self.pool_loader_batch_size]

        # Move to device
        x = x.to(self.device)
        xt = xt.to(self.device)
        y0_latent_prev = y0_latent_prev.to(self.device)

        # Get embeddings and predictions
        embed_out = model.get_input_embedding(x, xt)
        y_post = model(x, xt, y0_latent_prev)

        # Extract latent prevalence predictions (assuming 4 compartments)
        if hasattr(model, 'NUM_COMP') and model.NUM_COMP == 4:
            y_latent_prev_post = torch.chunk(y_post, model.NUM_COMP, dim=-1)[-1]
        else:
            # Fallback: last quarter of the final axis. Floor-dividing the
            # negated width rounds the slice length up to ceil(width / 4).
            width = y_post.shape[-1]
            y_latent_prev_post = y_post[..., (-width) // 4:]

        # Latent representation for query samples
        mu_z_q, var_z_q = model.get_latent_representation(
            embed_out, y_latent_prev_post, y0_latent_prev
        )
        return mu_z_q, var_z_q.sqrt(), indices, x.size(0)

    def acquisition_fn(
        self,
        mu_z_train: torch.Tensor,
        sigma_z_train: torch.Tensor,
        mu_z_q: torch.Tensor,
        sigma_z_q: torch.Tensor,
        query_size: int,
        num_train_data: int
    ) -> float:
        """
        Compute information gain score.

        The "posterior" is a count-weighted blend of the query and training
        latent statistics (means averaged directly, standard deviations via
        their variances); the score is KL(posterior || training) summed over
        all elements.

        Args:
            mu_z_train: Training latent means
            sigma_z_train: Training latent standard deviations
            mu_z_q: Query latent means
            sigma_z_q: Query latent standard deviations
            query_size: Size of query batch
            num_train_data: Number of training samples

        Returns:
            Information gain score
        """
        total = query_size + num_train_data
        mu_z_post = (mu_z_q * query_size + mu_z_train * num_train_data) / total
        var_z_post = (
            sigma_z_q**2.0 * query_size + sigma_z_train**2.0 * num_train_data
        ) / total
        sigma_z_post = torch.sqrt(var_z_post)

        # Everything stays on the inputs' device: moving only the posterior
        # could put it on a different device from the training tensors.
        posterior = dist.Normal(mu_z_post, sigma_z_post)
        prior = dist.Normal(mu_z_train, sigma_z_train)
        score = dist.kl_divergence(posterior, prior).sum()

        return score.item()


class LatentInfoGainStream(BaseAcquisition):
    """
    Streaming version of Latent Information Gain acquisition function.

    Every sample of a random pool subset is scored individually with the
    latent information gain, and the ``acquisition_size`` samples with the
    highest scores are returned.

    The model handed to ``get_candidate_batch`` must provide
    ``get_latent_tensors()``, ``get_input_embedding(x, xt)``,
    ``get_latent_representation(embedding, y_latent_prev, y0_latent_prev)``
    and be callable as ``model(x, xt, y0_latent_prev)``.
    """

    def __init__(
        self,
        acquisition_size: int,
        pool_loader_batch_size: int,
        acquisition_pool_fraction: float,
        num_workers: int = 4,
        device: str = "cpu"
    ):
        """
        Initialize the LatentInfoGainStream acquisition function.

        Args:
            acquisition_size: Number of samples to acquire
            pool_loader_batch_size: Batch size for processing pool; must be
                a multiple of ``acquisition_size``
            acquisition_pool_fraction: Fraction of pool to consider
            num_workers: Number of data loading workers
            device: Computation device

        Raises:
            ValueError: If ``acquisition_size`` is not positive or
                ``pool_loader_batch_size`` is not divisible by it.
        """
        super().__init__(
            acquisition_size=acquisition_size,
            pool_loader_batch_size=pool_loader_batch_size,
            acquisition_pool_fraction=acquisition_pool_fraction,
            num_workers=num_workers,
            device=device
        )

        # Reject non-positive sizes explicitly: zero would surface as a
        # ZeroDivisionError below, and a "[-0:]" top-k slice would select
        # every scored sample.
        if acquisition_size <= 0:
            raise ValueError("acquisition_size must be a positive integer")

        # Validate batch size constraint
        if pool_loader_batch_size % acquisition_size != 0:
            raise ValueError("pool_loader_batch_size must be divisible by acquisition_size")

    @torch.no_grad()
    def get_candidate_batch(
        self,
        model: nn.Module,
        active_data: 'ActiveLearningData',
        **kwargs
    ) -> np.ndarray:
        """
        Get candidate samples using streaming latent information gain.

        Args:
            model: The trained model (switched to eval mode and moved to
                ``self.device`` as a side effect)
            active_data: Active learning data manager; ``pool_dataset`` and
                ``train_size`` are read from it
            **kwargs: Additional arguments (unused)

        Returns:
            Array of indices of the (at most) ``acquisition_size`` samples
            with the highest scores, ordered by ascending score. An empty
            integer array is returned when the sampled pool is empty.
        """
        model.eval().to(self.device)

        # Random subset of the pool to score
        pool_size = int(self.acquisition_pool_fraction * len(active_data.pool_dataset))
        pool_indices = torch.randperm(len(active_data.pool_dataset))[:pool_size]
        if pool_indices.numel() == 0:
            # Nothing to score.
            return np.empty(0, dtype=np.int64)

        pool_loader = self._build_pool_loader(active_data, pool_indices)

        # Current latent representation
        mu_z_train, var_z_train = model.get_latent_tensors()
        mu_z_train = mu_z_train.to(self.device)
        sigma_z_train = var_z_train.sqrt().to(self.device)

        all_scores = []
        all_indices = []
        num_train_data = active_data.train_size

        for batch_idx, batch in enumerate(pool_loader):
            mu_z_q, sigma_z_q, indices, query_size = self._query_latent(
                model, batch, batch_idx, pool_indices
            )
            # Accept tensors, lists or arrays of indices from the collate fn.
            batch_indices = torch.as_tensor(indices).reshape(-1).tolist()

            # Score each sample on its own (query size 1).
            # NOTE(review): this assumes get_latent_representation returns
            # one row per sample; if it aggregates over the batch, the
            # slices for j >= 1 are empty and score 0 -- confirm.
            for j in range(query_size):
                score = self.acquisition_fn(
                    mu_z_train, sigma_z_train,
                    mu_z_q[j:j + 1], sigma_z_q[j:j + 1],
                    1, num_train_data
                )
                all_scores.append(score)
                all_indices.append(batch_indices[j])

        scores = np.asarray(all_scores, dtype=np.float64)
        # argsort places NaN last, which would rank NaN-scored samples as
        # the best candidates; push them to the bottom instead.
        scores = np.where(np.isnan(scores), -np.inf, scores)

        top_positions = np.argsort(scores)[-self.acquisition_size:]
        return np.asarray(all_indices)[top_positions]

    def _build_pool_loader(
        self,
        active_data: 'ActiveLearningData',
        pool_indices: torch.Tensor
    ) -> DataLoader:
        """Build the loader over the sampled pool subset, in ``pool_indices`` order."""
        loader_kwargs = dict(
            # pool_indices is already a random permutation, so reshuffling
            # adds nothing; a fixed order is what lets loader positions be
            # mapped back to pool indices in _query_latent.
            shuffle=False,
            batch_size=self.pool_loader_batch_size,
            num_workers=self.num_workers,
        )
        try:
            from ..dataset import pool_collate_fn
        except ImportError:
            # Project collate function unavailable: use DataLoader's default.
            pass
        else:
            loader_kwargs["collate_fn"] = pool_collate_fn

        # Plain Python ints so the dataset never receives 0-d tensors.
        subset = Subset(active_data.pool_dataset, pool_indices.tolist())
        return DataLoader(subset, **loader_kwargs)

    def _query_latent(
        self,
        model: nn.Module,
        batch,
        batch_idx: int,
        pool_indices: torch.Tensor
    ):
        """Run the model on one pool batch; return (mu_z, sigma_z, indices, batch_len)."""
        if len(batch) == 4:
            x, xt, y0_latent_prev, indices = batch
        else:
            # Batch carries no indices: recover them from the loader position.
            # Assumes callers expect pool-dataset indices -- TODO confirm.
            x, xt, y0_latent_prev = batch[:3]
            start = batch_idx * self.pool_loader_batch_size
            indices = pool_indices[start:start + self.pool_loader_batch_size]

        # Move to device
        x = x.to(self.device)
        xt = xt.to(self.device)
        y0_latent_prev = y0_latent_prev.to(self.device)

        # Get embeddings and predictions
        embed_out = model.get_input_embedding(x, xt)
        y_post = model(x, xt, y0_latent_prev)

        # Extract latent prevalence predictions
        if hasattr(model, 'NUM_COMP') and model.NUM_COMP == 4:
            y_latent_prev_post = torch.chunk(y_post, model.NUM_COMP, dim=-1)[-1]
        else:
            # Fallback: last quarter of the final axis. Floor-dividing the
            # negated width rounds the slice length up to ceil(width / 4).
            width = y_post.shape[-1]
            y_latent_prev_post = y_post[..., (-width) // 4:]

        # Latent representation for query samples
        mu_z_q, var_z_q = model.get_latent_representation(
            embed_out, y_latent_prev_post, y0_latent_prev
        )
        return mu_z_q, var_z_q.sqrt(), indices, x.size(0)

    def acquisition_fn(
        self,
        mu_z_train: torch.Tensor,
        sigma_z_train: torch.Tensor,
        mu_z_q: torch.Tensor,
        sigma_z_q: torch.Tensor,
        query_size: int,
        num_train_data: int
    ) -> float:
        """
        Compute information gain score for streaming version.

        The "posterior" is a count-weighted blend of the query and training
        latent statistics (means averaged directly, standard deviations via
        their variances); the score is KL(posterior || training) summed over
        all elements.

        Args:
            mu_z_train: Training latent means
            sigma_z_train: Training latent standard deviations
            mu_z_q: Query latent means
            sigma_z_q: Query latent standard deviations
            query_size: Size of query batch
            num_train_data: Number of training samples

        Returns:
            Information gain score
        """
        total = query_size + num_train_data
        mu_z_post = (mu_z_q * query_size + mu_z_train * num_train_data) / total
        var_z_post = (
            sigma_z_q**2.0 * query_size + sigma_z_train**2.0 * num_train_data
        ) / total
        sigma_z_post = torch.sqrt(var_z_post)

        # Everything stays on the inputs' device: moving only the posterior
        # could put it on a different device from the training tensors.
        posterior = dist.Normal(mu_z_post, sigma_z_post)
        prior = dist.Normal(mu_z_train, sigma_z_train)
        score = dist.kl_divergence(posterior, prior).sum()

        return score.item()
