"""Active learning acquisition policies for ARCOS."""

import torch
import numpy as np
from typing import List, Tuple, Optional, Dict
from sklearn.metrics.pairwise import euclidean_distances
from tqdm import tqdm

from pathlib import Path
from typing import Dict, Any, Optional

from ..trace.ot import compute_w1_feature_space, normalize_features
from ..data.datasets import ImageDataset
from ..utils.config import get_config


class BaseAcquisitionPolicy:
    """Common interface shared by all acquisition policies.

    Subclasses override :meth:`select_samples`; the base implementation
    only stores a display name and raises when asked to select.
    """

    def __init__(self, name: str):
        """Store the policy's display name.

        Args:
            name: Short policy name, used by ``__str__``
        """
        self.name = name

    def __str__(self) -> str:
        # Rendered as "<name>Policy", e.g. "W1MinPolicy".
        return "{}Policy".format(self.name)

    def select_samples(
        self,
        unlabeled_indices: List[int],
        unlabeled_features: torch.Tensor,
        labeled_features: torch.Tensor,
        source_features: torch.Tensor,
        batch_size: int,
        **kwargs
    ) -> Tuple[List[int], torch.Tensor]:
        """Pick the next batch of samples to label.

        Args:
            unlabeled_indices: Indices of unlabeled samples
            unlabeled_features: Features of unlabeled samples
            labeled_features: Features of labeled samples
            source_features: Features of source domain samples
            batch_size: Number of samples to select
            **kwargs: Policy-specific extra arguments

        Returns:
            Tuple of (selected_indices, updated_labeled_features)

        Raises:
            NotImplementedError: Always; concrete policies must override.
        """
        raise NotImplementedError()


class PolicyA_W1Min(BaseAcquisitionPolicy):
    """Policy A: W1-min - minimize Wasserstein distance.

    Every unlabeled candidate is scored by how much the approximate W1
    distance between the reference pool (source + labeled features) and the
    unlabeled pool drops when that single candidate is appended to the
    reference pool. The highest-scoring candidates are selected.
    """

    def __init__(self, num_projections: int = 256, normalize: bool = True):
        """Initialize W1-min policy.

        Args:
            num_projections: Number of random projections for W1 estimation
            normalize: Whether to normalize features
        """
        super().__init__("W1Min")
        self.num_projections = num_projections
        self.normalize = normalize

    @staticmethod
    def _to_device(features, device: str) -> torch.Tensor:
        """Return ``features`` as a tensor living on ``device``.

        Tensors keep their dtype and are only moved; anything else is
        converted to a float32 tensor.
        """
        if isinstance(features, torch.Tensor):
            # ``.to`` is a no-op when the tensor is already on ``device``.
            return features.to(device)
        return torch.tensor(features, device=device, dtype=torch.float32)

    def select_samples(
        self,
        unlabeled_indices: List[int],
        unlabeled_features: torch.Tensor,
        labeled_features: torch.Tensor,
        source_features: torch.Tensor,
        batch_size: int,
        device: str = "cuda",
        **kwargs
    ) -> Tuple[List[int], torch.Tensor]:
        """Select samples to minimize W1 distance.

        All candidates are scored against the same baseline; scores are not
        refreshed after each pick.

        Args:
            unlabeled_indices: Indices of unlabeled samples
            unlabeled_features: Features of unlabeled samples
            labeled_features: Features of labeled samples
            source_features: Features of source domain samples
            batch_size: Number of samples to select
            device: Device to use for computation
            **kwargs: Additional arguments (ignored)

        Returns:
            Tuple of (selected_indices, updated_labeled_features). The
            returned features live on ``device`` and, when ``normalize`` is
            enabled, are the normalized versions of the inputs.
        """
        print(f"W1-min policy selecting {batch_size} samples...")

        # Ensure tensors are on the correct device
        unlabeled_features = self._to_device(unlabeled_features, device)
        labeled_features = self._to_device(labeled_features, device)
        source_features = self._to_device(source_features, device)

        # Normalize features if requested
        if self.normalize:
            unlabeled_features = normalize_features(unlabeled_features)
            labeled_features = normalize_features(labeled_features)
            source_features = normalize_features(source_features)

        # Combine source and labeled target features
        combined_features = torch.cat([source_features, labeled_features], dim=0)

        # The baseline distance does not depend on the candidate, so compute
        # it once instead of once per candidate. It is skipped when there are
        # no candidates, so an empty pool still triggers no W1 computation.
        baseline_w1 = 0.0
        if len(unlabeled_features) > 0:
            baseline_w1 = compute_w1_approximate_fast(
                combined_features, unlabeled_features, self.num_projections, device
            )

        # Score = reduction in W1 distance when the candidate joins the pool
        scores = []
        for candidate in tqdm(unlabeled_features, desc="Computing W1 scores"):
            augmented = torch.cat([combined_features, candidate.unsqueeze(0)], dim=0)
            augmented_w1 = compute_w1_approximate_fast(
                augmented, unlabeled_features, self.num_projections, device
            )
            scores.append(baseline_w1 - augmented_w1)

        # Select top samples by score
        scores_tensor = torch.tensor(scores, device=device)
        selected_indices = torch.argsort(scores_tensor, descending=True)[:batch_size]
        selected_unlabeled_indices = [
            unlabeled_indices[position] for position in selected_indices.tolist()
        ]

        # Update labeled features
        selected_features = unlabeled_features[selected_indices]
        updated_labeled_features = torch.cat([labeled_features, selected_features], dim=0)

        print(f"Selected {len(selected_unlabeled_indices)} samples with W1-min policy")

        return selected_unlabeled_indices, updated_labeled_features


class PolicyB_DiscMax(BaseAcquisitionPolicy):
    """Policy B: Disc-max - maximize output discrepancy.

    Ranks unlabeled samples by the L2 norm of the difference between the
    outputs of two models and selects the samples with the largest norm.
    """

    def __init__(self):
        """Initialize Disc-max policy."""
        super().__init__("DiscMax")

    def select_samples(
        self,
        unlabeled_indices: List[int],
        unlabeled_features: torch.Tensor,
        labeled_features: torch.Tensor,
        source_features: torch.Tensor,
        batch_size: int,
        model_Q: Optional[torch.nn.Module] = None,
        model_Qt: Optional[torch.nn.Module] = None,
        unlabeled_loader: Optional[torch.utils.data.DataLoader] = None,
        device: str = "cuda",
        **kwargs
    ) -> Tuple[List[int], torch.Tensor]:
        """Select samples to maximize output discrepancy.

        Args:
            unlabeled_indices: Indices of unlabeled samples
            unlabeled_features: Features of unlabeled samples
            labeled_features: Features of labeled samples
            source_features: Features of source domain samples (not used by
                this policy)
            batch_size: Number of samples to select; values <= 0 select nothing
            model_Q: Frozen model Q
            model_Qt: Fine-tuned model Q_tilde
            unlabeled_loader: DataLoader for unlabeled samples, yielding
                ``(data, _)`` pairs. Presumably iterates in the same order as
                ``unlabeled_indices`` / ``unlabeled_features`` - verify
                against the caller.
            device: Device to use
            **kwargs: Additional arguments (ignored)

        Returns:
            Tuple of (selected_indices, updated_labeled_features). Selected
            indices are ordered from largest to smallest discrepancy; the
            features are returned as a CPU tensor.

        Raises:
            ValueError: If ``model_Q``, ``model_Qt`` or ``unlabeled_loader``
                is missing.
        """
        print(f"Disc-max policy selecting {batch_size} samples...")

        if model_Q is None or model_Qt is None or unlabeled_loader is None:
            raise ValueError("Disc-max policy requires model_Q, model_Qt, and unlabeled_loader")

        # Compute output discrepancies for all unlabeled samples
        discrepancies = []

        model_Q.eval()
        model_Qt.eval()

        with torch.no_grad():
            for data, _ in tqdm(unlabeled_loader, desc="Computing discrepancies"):
                data = data.to(device)

                # Per-sample L2 norm of the difference between both outputs
                # (reduces over dim 1, so outputs are assumed batch-first).
                diff = model_Q(data) - model_Qt(data)
                l2_norm = torch.norm(diff, p=2, dim=1)
                discrepancies.extend(l2_norm.cpu().numpy())

        # Select top samples by discrepancy. Reversing first and then taking
        # a clamped prefix avoids the ``[-0:]`` pitfall, where a batch size of
        # zero would otherwise select *every* sample.
        num_selected = max(batch_size, 0)
        selected_indices = np.argsort(discrepancies)[::-1][:num_selected]
        selected_unlabeled_indices = [unlabeled_indices[i] for i in selected_indices]

        # Update labeled features (done in NumPy on the CPU)
        if isinstance(unlabeled_features, torch.Tensor):
            unlabeled_features = unlabeled_features.detach().cpu().numpy()
        else:
            unlabeled_features = np.asarray(unlabeled_features)
        if isinstance(labeled_features, torch.Tensor):
            labeled_features = labeled_features.detach().cpu().numpy()
        else:
            labeled_features = np.asarray(labeled_features)

        selected_features = unlabeled_features[selected_indices]
        updated_labeled_features = np.vstack([labeled_features, selected_features])

        print(f"Selected {len(selected_unlabeled_indices)} samples with Disc-max policy")

        return selected_unlabeled_indices, torch.from_numpy(updated_labeled_features)


def compute_w1_approximate_fast(
    X: torch.Tensor,
    Y: torch.Tensor,
    num_projections: int = 256,
    device: str = "cuda"
) -> float:
    """Approximate W1 distance using PyTorch (GPU-accelerated).

    Both feature sets are projected onto the same random unit directions.
    For each direction, 100 evenly spaced quantiles (1% to 99%) of the two
    projected sets are compared, and the mean absolute quantile difference is
    averaged over all directions.

    The projections are drawn from a private generator with a fixed seed, so
    the result is deterministic and the global torch RNG state is untouched.

    Args:
        X: First set of features, 2-D (samples, features)
        Y: Second set of features, 2-D with the same feature dimension
        num_projections: Number of random projections
        device: Device to use for computation

    Returns:
        Approximate W1 distance
    """
    # Ensure tensors are on the correct device (``.to`` is a no-op when the
    # tensor is already there); non-tensors become float32 tensors.
    if isinstance(X, torch.Tensor):
        X = X.to(device)
    else:
        X = torch.tensor(X, device=device, dtype=torch.float32)

    if isinstance(Y, torch.Tensor):
        Y = Y.to(device)
    else:
        Y = torch.tensor(Y, device=device, dtype=torch.float32)

    # Unpacking enforces that both inputs are 2-D.
    _, n_features = X.shape
    _, _ = Y.shape

    # Fixed-seed *local* generator: reproducible projections without
    # reseeding the global RNG that the rest of the program relies on.
    generator = torch.Generator(device=device)
    generator.manual_seed(42)
    projections = torch.randn(
        n_features,
        num_projections,
        generator=generator,
        device=device,
        dtype=torch.float32,
    )
    projections = projections / torch.norm(projections, dim=0, keepdim=True)

    # Project data (matrix multiplication on the target device)
    X_proj = X @ projections  # (N, num_projections)
    Y_proj = Y @ projections  # (M, num_projections)

    # Comparing at a shared quantile grid handles N != M. No explicit sort is
    # needed because torch.quantile orders the values itself.
    quantiles = torch.linspace(0.01, 0.99, 100, device=device)
    x_quantiles = torch.quantile(X_proj, quantiles, dim=0)  # (100, num_projections)
    y_quantiles = torch.quantile(Y_proj, quantiles, dim=0)  # (100, num_projections)

    # Mean absolute quantile gap per projection, then averaged over projections
    distances = torch.mean(torch.abs(x_quantiles - y_quantiles), dim=0)  # (num_projections,)

    return torch.mean(distances).item()


def compute_w1_approximate(
    X: np.ndarray,
    Y: np.ndarray,
    num_projections: int = 256
) -> float:
    """Approximate W1 distance using random projections (CPU version).

    Both feature sets are projected onto the same random unit directions.
    For each direction the 1D W1 distance is estimated by interpolating the
    two sorted projections onto a common quantile grid, and the per-direction
    distances are averaged.

    The projections come from a private, fixed-seed random state, so the
    result is deterministic and the global NumPy RNG is left untouched.

    Args:
        X: First set of features, 2-D (samples, features)
        Y: Second set of features, 2-D with the same feature dimension
        num_projections: Number of random projections

    Returns:
        Approximate W1 distance
    """
    n, n_features = X.shape
    m, _ = Y.shape

    # Local RandomState(42) yields the same draws as seeding the global
    # legacy RNG with 42, without the global side effect.
    rng = np.random.RandomState(42)
    projections = rng.randn(n_features, num_projections)
    projections = projections / np.linalg.norm(projections, axis=0, keepdims=True)

    # Project data and sort each projection (column) once
    X_sorted = np.sort(X @ projections, axis=0)
    Y_sorted = np.sort(Y @ projections, axis=0)

    # Quantile grids are identical for every projection, so build them once.
    # The common grid uses the larger sample count (handles unequal sizes).
    grid_size = max(n, m)
    u = (np.arange(grid_size) + 0.5) / grid_size
    x_levels = (np.arange(n) + 0.5) / n
    y_levels = (np.arange(m) + 0.5) / m

    # 1D W1 per projection via quantile interpolation
    distances = []
    for i in range(num_projections):
        qx = np.interp(u, x_levels, X_sorted[:, i])
        qy = np.interp(u, y_levels, Y_sorted[:, i])
        distances.append(np.mean(np.abs(qx - qy)))

    return float(np.mean(distances))


def normalize_features_numpy(features: np.ndarray, method: str = "l2") -> np.ndarray:
    """Normalize each row of a feature array with NumPy.

    All statistics are taken along axis 1, i.e. every row is normalized
    independently. Denominators are floored at 1e-8 so constant or all-zero
    rows do not divide by zero.

    Args:
        features: Feature array
        method: Normalization method (l2, minmax, zscore)

    Returns:
        Normalized features

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == "l2":
        # Scale each row to unit Euclidean length.
        row_norm = np.linalg.norm(features, axis=1, keepdims=True)
        return features / np.clip(row_norm, 1e-8, None)

    if method == "minmax":
        # Map each row onto [0, 1] using its own min and max.
        row_min = np.min(features, axis=1, keepdims=True)
        row_max = np.max(features, axis=1, keepdims=True)
        span = np.clip(row_max - row_min, 1e-8, None)
        return (features - row_min) / span

    if method == "zscore":
        # Center each row on its mean and divide by its standard deviation.
        row_mean = np.mean(features, axis=1, keepdims=True)
        row_std = np.clip(np.std(features, axis=1, keepdims=True), 1e-8, None)
        return (features - row_mean) / row_std

    raise ValueError(f"Unknown normalization method: {method}")


def get_acquisition_policy(policy_name: str, **kwargs) -> BaseAcquisitionPolicy:
    """Build an acquisition policy from its (case-insensitive) name.

    Args:
        policy_name: Name of policy (w1min, discmax)
        **kwargs: Forwarded unchanged to the policy constructor

    Returns:
        Acquisition policy instance

    Raises:
        ValueError: If ``policy_name`` matches no known policy.
    """
    key = policy_name.lower()

    if key == "w1min":
        return PolicyA_W1Min(**kwargs)

    if key == "discmax":
        return PolicyB_DiscMax(**kwargs)

    raise ValueError(f"Unknown acquisition policy: {policy_name}")

