"""
Query utilities for IFC.

Provides functions to:
1. Combine cached z_c solutions for subset queries
2. Compute influence scores for validation samples
3. Compute predicted loss changes

DDP Support:
    - Query operations run only on rank 0 (inference-only, no multi-GPU benefit).
    - In DDP mode, non-rank-0 processes should not call query functions.
"""

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Optional, Dict, List, Tuple, Union
import numpy as np
import os
from tqdm import tqdm

from .vit_full import ViTWithHooks, load_vit
from .imagenet_loader import ImageNetDataset
from .ifc_build import IFCBuilder, load_ifc
from .logging_utils import (
    get_logger, log_dict, log_tensor_stats,
    is_rank0, is_ddp, barrier, rank0_print,
)

logger = get_logger(__name__)


class IFCQuery:
    """
    Query interface for Influence Function Cache.

    Combines the cached per-cluster solutions ``z_c`` held by an IFC builder
    into parameter-change estimates, and scores validation samples against
    them via dot products with their loss gradients.

    Loaded solutions are memoised twice: once as returned by
    ``ifc.load_solution`` and once after being moved to ``device``. Call
    ``clear_cache`` to release both.
    """

    def __init__(
        self,
        ifc: IFCBuilder,
        model: Optional[ViTWithHooks] = None,
        device: str = "cuda",
        n_train: Optional[int] = None,
        v_is_mean: bool = True,
    ):
        """
        Args:
            ifc: IFCBuilder instance with cached solutions
            model: Optional ViT model (needed for gradient computation)
            device: Computation device
            n_train: Total number of training samples (N). If None, will try to infer from cluster counts.
            v_is_mean: If True, z_c solutions are IHVP of mean gradients (default).
                       If False, z_c solutions are IHVP of sum gradients.
        """
        self.ifc = ifc
        self.model = model
        self.device = device
        self.v_is_mean = v_is_mean

        # Per-cluster sample counts, taken from the builder's RHS manager
        # when it exposes one; otherwise left as None.
        self._cluster_counts: Optional[np.ndarray] = None
        rhs_manager = getattr(ifc, '_rhs_manager', None)
        if rhs_manager is not None:
            self._cluster_counts = rhs_manager.counts

        # Total training samples: explicit argument wins, then the sum of
        # the cluster counts, else unknown (resolved per query).
        if n_train is not None:
            self.n_train = n_train
        elif self._cluster_counts is not None:
            self.n_train = int(self._cluster_counts.sum())
        else:
            self.n_train = None
            logger.warning("n_train not specified and cannot be inferred from cluster counts")

        # Cache for loaded solutions (as loaded / moved to self.device)
        self._z_cache: Dict[int, torch.Tensor] = {}
        self._z_cache_device: Dict[int, torch.Tensor] = {}

        # Query statistics
        self._query_count = 0

        logger.info(f"IFCQuery initialized (n_train={self.n_train}, v_is_mean={v_is_mean})")

    def _get_z(self, cluster_id: int, to_device: bool = True) -> torch.Tensor:
        """
        Get cached solution for a cluster.

        Args:
            cluster_id: Cluster whose solution is requested
            to_device: If True, return the copy living on ``self.device``;
                       otherwise return the tensor as loaded

        Returns:
            The (memoised) solution tensor
        """
        if cluster_id not in self._z_cache:
            self._z_cache[cluster_id] = self.ifc.load_solution(cluster_id)
        z = self._z_cache[cluster_id]

        if not to_device:
            return z

        if cluster_id not in self._z_cache_device:
            self._z_cache_device[cluster_id] = z.to(self.device)
        return self._z_cache_device[cluster_id]

    def _zero_delta(self) -> torch.Tensor:
        """
        Return the all-zero parameter change used when nothing can be combined.

        Without a model the parameter count is unknown, so a length-1
        placeholder is returned.
        """
        num_params = self.model.num_params if self.model is not None else 1
        return torch.zeros(num_params, device=self.device)

    def compute_delta_theta(
        self,
        subset_indices: np.ndarray,
        cluster_ids: np.ndarray,
        use_exact_den: bool = True,
    ) -> torch.Tensor:
        """
        Compute parameter change from removing a subset.

        Δθ = (1/denom) * Σ_c coeff_c * z_c

        where:
        - m_c is the number of subset samples in cluster c
        - If v_is_mean: coeff_c = m_c (z_c is IHVP of mean gradient)
        - If not v_is_mean: coeff_c = m_c / N_c (z_c is IHVP of sum gradient)
        - denom = (N - |S|) if use_exact_den else N, clamped to at least 1

        Subset indices outside ``[0, len(cluster_ids))`` are dropped. Samples
        whose cluster id lies outside ``[0, n_clusters)`` (for example ``-1``
        for "unassigned") and samples in unsolved clusters contribute no
        ``z_c`` term, but still count towards |S| in the denominator.

        Args:
            subset_indices: Indices of samples in subset
            cluster_ids: Cluster assignments for all samples
            use_exact_den: If True, use (N - |S|) as denominator; else use N

        Returns:
            delta_theta: Parameter change vector (all zeros when no cluster
                contributes)
        """
        cluster_ids = np.asarray(cluster_ids)

        # Validate and filter indices
        idx_S = np.asarray(subset_indices, dtype=int)
        idx_S = idx_S[(idx_S >= 0) & (idx_S < len(cluster_ids))]
        if len(idx_S) == 0:
            return self._zero_delta()

        C = self.ifc.n_clusters

        # Count cluster memberships. np.bincount rejects negative values, so
        # ids outside [0, C) are masked out before counting.
        subset_clusters = cluster_ids[idx_S]
        in_range = (subset_clusters >= 0) & (subset_clusters < C)
        m_c = np.bincount(subset_clusters[in_range].astype(np.int64), minlength=C)

        # Cluster sizes are only needed to rescale sum-gradient solutions.
        counts_all = None
        if not self.v_is_mean:
            if self._cluster_counts is not None:
                counts_all = self._cluster_counts
            else:
                # Fallback: compute from cluster_ids (valid ids only)
                valid_ids = cluster_ids[(cluster_ids >= 0) & (cluster_ids < C)]
                counts_all = np.bincount(valid_ids.astype(np.int64), minlength=C)

        # Weighted sum of z_c over the clusters the subset actually touches.
        delta_theta = None
        for c in np.flatnonzero(m_c):
            c = int(c)
            if not self.ifc.is_solved(c):
                continue

            z_c = self._get_z(c, to_device=True)

            if self.v_is_mean:
                # z_c ≈ H^{-1} * (mean grad in cluster c)
                coeff = float(m_c[c])
            else:
                # z_c ≈ H^{-1} * (sum grad in cluster c)
                N_c = max(counts_all[c], 1)
                coeff = float(m_c[c]) / float(N_c)

            if delta_theta is None:
                # Fresh tensor: the cached z_c itself must never be mutated.
                delta_theta = z_c * coeff
            else:
                delta_theta.add_(z_c, alpha=coeff)

        if delta_theta is None:
            # No valid clusters
            return self._zero_delta()

        # Denominator: every retained subset sample is removed from training.
        N = self.n_train if self.n_train is not None else len(cluster_ids)
        denom = N - len(idx_S) if use_exact_den else N
        denom = max(1, denom)

        return delta_theta / denom

    def compute_val_gradient(
        self,
        image: torch.Tensor,
        label: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute gradient for a validation sample.

        Note: this switches the wrapped model to eval mode and sets
        ``requires_grad = True`` on all of its parameters; neither is
        restored afterwards.

        Args:
            image: (1, 3, H, W) validation image
            label: (1,) label

        Returns:
            Gradient vector (detached)

        Raises:
            RuntimeError: If no model was supplied at construction time
        """
        if self.model is None:
            raise RuntimeError("Model required for gradient computation")

        net = self.model.model
        net.eval()
        for p in net.parameters():
            p.requires_grad = True

        image = image.to(self.device)
        label = label.to(self.device)

        # Clear stale gradients so the vector reflects this sample only.
        net.zero_grad()
        logits = self.model(image)
        loss = F.cross_entropy(logits, label)
        loss.backward()

        return self.model.get_grad_vector().detach()

    def compute_fold_loss_change(
        self,
        val_loader: DataLoader,
        subset_indices: np.ndarray,
        cluster_ids: np.ndarray,
        max_samples: Optional[int] = None,
    ) -> Tuple[float, List[float]]:
        """
        Compute predicted loss change for a validation fold.

        ΔL = (1/|V|) Σ_{v∈V} -g_v · Δθ

        Args:
            val_loader: DataLoader for validation samples
            subset_indices: Training subset to remove
            cluster_ids: Cluster assignments
            max_samples: Maximum validation samples; None means no limit,
                0 scores nothing

        Returns:
            mean_loss_change: Average predicted loss change (0.0 if no
                sample was scored)
            individual_scores: List of per-sample scores

        Raises:
            RuntimeError: If no model was supplied at construction time
        """
        if self.model is None:
            raise RuntimeError("Model required for gradient computation")

        self._query_count += 1
        logger.info(f"Computing fold loss change (query #{self._query_count})")

        # Precompute delta_theta
        delta_theta = self.compute_delta_theta(subset_indices, cluster_ids)
        log_tensor_stats(logger, "delta_theta", delta_theta)
        delta_flat = delta_theta.flatten()  # loop-invariant

        scores: List[float] = []

        def limit_reached() -> bool:
            # Explicit None test so that max_samples=0 means "score nothing".
            return max_samples is not None and len(scores) >= max_samples

        # Checked up front so a zero limit does not fetch a batch at all.
        if not limit_reached():
            for images, labels in tqdm(val_loader, desc="Computing val scores"):
                images = images.to(self.device)
                labels = labels.to(self.device)

                # Gradients are taken one sample at a time.
                for i in range(images.shape[0]):
                    if limit_reached():
                        break
                    grad = self.compute_val_gradient(
                        images[i:i+1], labels[i:i+1]
                    )
                    scores.append(-torch.dot(grad.flatten(), delta_flat).item())

                if limit_reached():
                    break

        mean_score = float(np.mean(scores)) if scores else 0.0

        # Log influence statistics
        if scores:
            scores_arr = np.array(scores)
            log_dict(logger, f"Fold loss change (query #{self._query_count})", {
                'n_val_samples': len(scores),
                'n_subset_samples': len(subset_indices),
                'mean_score': mean_score,
                'score_std': float(scores_arr.std()),
                'score_min': float(scores_arr.min()),
                'score_max': float(scores_arr.max()),
                'positive_scores': int((scores_arr > 0).sum()),
                'negative_scores': int((scores_arr < 0).sum()),
            })

        return mean_score, scores

    def compute_per_cluster_influence(
        self,
        val_gradient: torch.Tensor,
    ) -> Dict[int, float]:
        """
        Compute influence of each cluster on a validation sample.

        The score for cluster c is ``-(val_gradient · z_c)``. Unsolved
        clusters are omitted from the result.

        Args:
            val_gradient: Gradient of validation sample

        Returns:
            Dict mapping cluster_id -> influence score
        """
        g = val_gradient.to(self.device).flatten()  # loop-invariant
        influences: Dict[int, float] = {}

        for c in range(self.ifc.n_clusters):
            if not self.ifc.is_solved(c):
                continue
            z_c = self._get_z(c, to_device=True)
            influences[c] = -torch.dot(g, z_c.flatten()).item()

        return influences

    def get_most_influential_clusters(
        self,
        val_gradient: torch.Tensor,
        top_k: int = 10,
        most_helpful: bool = True,
    ) -> List[Tuple[int, float]]:
        """
        Get most influential clusters for a validation sample.

        Args:
            val_gradient: Validation sample gradient
            top_k: Number of clusters to return
            most_helpful: If True, return clusters that decrease loss
                          (lowest scores first). If False, return clusters
                          that increase loss (highest scores first)

        Returns:
            List of (cluster_id, influence_score) tuples
        """
        influences = self.compute_per_cluster_influence(val_gradient)

        # Most helpful = most negative, hence ascending order in that case.
        ranked = sorted(
            influences.items(),
            key=lambda item: item[1],
            reverse=not most_helpful,
        )
        return ranked[:top_k]

    def clear_cache(self):
        """Clear both solution caches and release cached CUDA memory."""
        self._z_cache.clear()
        self._z_cache_device.clear()
        torch.cuda.empty_cache()


def compute_subset_influence(
    ifc_path: str,
    imagenet_root: str,
    subset_indices: np.ndarray,
    cluster_ids: np.ndarray,
    val_split: str = "train",
    max_val_samples: int = 1000,
    device: str = "cuda",
    use_tiny: bool = True,
) -> Tuple[float, List[float]]:
    """
    Estimate the effect of removing a training subset on held-out loss.

    Loads the IFC and a pretrained ViT, wraps them in an ``IFCQuery`` and
    delegates to ``IFCQuery.compute_fold_loss_change`` over a batch-size-1,
    unshuffled loader of the requested split.

    DDP: Runs only on rank 0. On any other rank a warning is logged and
    ``(0.0, [])`` is returned without loading anything.

    Args:
        ifc_path: Path to IFC directory
        imagenet_root: Path to ImageNet
        subset_indices: Training indices to remove
        cluster_ids: Cluster assignments
        val_split: Name of the split to score against
        max_val_samples: Maximum validation samples
        device: Computation device
        use_tiny: Forwarded to ``load_vit``

    Returns:
        mean_influence: Average influence score
        per_sample_scores: List of per-sample scores
    """
    # Guard clause: queries are inference-only, so only rank 0 does the work.
    on_secondary_rank = is_ddp() and not is_rank0()
    if on_secondary_rank:
        logger.warning("compute_subset_influence called on non-rank-0 in DDP mode, returning empty")
        return 0.0, []

    # Assemble the query interface: cache first, then the model.
    cache = load_ifc(ifc_path, device)
    vit = load_vit(pretrained=True, device=device, use_tiny=use_tiny)
    influence_query = IFCQuery(cache, vit, device)

    # One sample per batch, fixed order.
    loader = ImageNetDataset(imagenet_root, split=val_split).get_loader(
        batch_size=1, shuffle=False
    )

    mean_influence, per_sample_scores = influence_query.compute_fold_loss_change(
        loader,
        subset_indices,
        cluster_ids,
        max_samples=max_val_samples,
    )
    return mean_influence, per_sample_scores


def find_poisoned_samples(
    ifc_path: str,
    imagenet_root: str,
    cluster_ids: np.ndarray,
    val_loader: DataLoader,
    n_suspicious: int = 100,
    device: str = "cuda",
    use_tiny: bool = True,
) -> np.ndarray:
    """
    Find potentially poisoned training samples.

    Averages the per-cluster influence score over every sample in
    ``val_loader``, ranks clusters from highest (most harmful) to lowest
    mean score, and collects training indices cluster by cluster until
    ``n_suspicious`` have been gathered.

    Only clusters that actually received a score (i.e. solved clusters) are
    ranked; a cluster with no cached solution has no evidence for or against
    it and is never reported.

    DDP: Runs only on rank 0. Other ranks should not call this function.

    Args:
        ifc_path: Path to IFC directory
        imagenet_root: Path to ImageNet (currently unused; kept for interface
            compatibility)
        cluster_ids: Cluster assignments
        val_loader: Validation DataLoader
        n_suspicious: Number of suspicious samples to return
        device: Computation device
        use_tiny: Forwarded to ``load_vit``

    Returns:
        suspicious_indices: Integer indices of suspicious training samples
            (empty on non-rank-0 processes or when ``val_loader`` yields no
            samples)
    """
    # DDP: Only rank 0 runs query
    if is_ddp() and not is_rank0():
        logger.warning("find_poisoned_samples called on non-rank-0 in DDP mode, returning empty")
        return np.array([], dtype=np.int64)

    # Load IFC and model
    ifc = load_ifc(ifc_path, device)
    model = load_vit(pretrained=True, device=device, use_tiny=use_tiny)
    query = IFCQuery(ifc, model, device)

    # Running sum of influence per scored cluster across the validation set.
    cluster_totals: Dict[int, float] = {}
    n_val = 0

    for images, labels in tqdm(val_loader, desc="Computing cluster influences"):
        images = images.to(device)
        labels = labels.to(device)

        # Gradients are taken one sample at a time.
        for i in range(len(images)):
            grad = query.compute_val_gradient(images[i:i+1], labels[i:i+1])
            for c, score in query.compute_per_cluster_influence(grad).items():
                cluster_totals[c] = cluster_totals.get(c, 0.0) + score
            n_val += 1

    if n_val == 0:
        # Nothing to average over; avoid dividing by zero.
        logger.warning("find_poisoned_samples received no validation samples, returning empty")
        return np.array([], dtype=np.int64)

    # Average
    cluster_means = {c: total / n_val for c, total in cluster_totals.items()}

    # Find most harmful clusters (highest mean score first)
    ranked_clusters = sorted(
        cluster_means.items(),
        key=lambda item: item[1],
        reverse=True,
    )

    # Get samples from harmful clusters, stopping once enough are collected
    cluster_ids = np.asarray(cluster_ids)
    suspicious: List[int] = []
    for cluster_id, _ in ranked_clusters:
        suspicious.extend(np.flatnonzero(cluster_ids == cluster_id).tolist())
        if len(suspicious) >= n_suspicious:
            break

    return np.array(suspicious[:n_suspicious], dtype=np.int64)


class InfluenceAnalyzer:
    """
    High-level interface for influence analysis.

    Loads the IFC, a pretrained ViT and the saved cluster assignments once,
    then answers per-validation-sample questions through an ``IFCQuery``.
    Results are reported at cluster granularity.
    """

    def __init__(
        self,
        ifc_path: str,
        imagenet_root: str,
        device: str = "cuda",
        use_tiny: bool = True,
    ):
        """
        Args:
            ifc_path: Path to IFC directory (must contain ``cluster_id.npy``)
            imagenet_root: Path to ImageNet
            device: Computation device
            use_tiny: Forwarded to ``load_vit``
        """
        self.ifc_path = ifc_path
        self.imagenet_root = imagenet_root
        self.device = device

        # Load components
        self.ifc = load_ifc(ifc_path, device)
        self.model = load_vit(pretrained=True, device=device, use_tiny=use_tiny)
        self.query = IFCQuery(self.ifc, self.model, device)

        # Load cluster IDs
        cluster_ids_path = os.path.join(ifc_path, "cluster_id.npy")
        self.cluster_ids = np.load(cluster_ids_path)

    def find_helpful_training_samples(
        self,
        val_image: torch.Tensor,
        val_label: torch.Tensor,
        top_k: int = 10,
    ) -> List[Tuple[int, float]]:
        """
        Find the training clusters most helpful for a validation sample.

        Despite the method name, results are per cluster, not per sample.

        Args:
            val_image: Validation image passed to the gradient computation
            val_label: Label for ``val_image``
            top_k: Number of clusters to return

        Returns:
            List of (cluster_id, influence) tuples, most helpful first.
        """
        val_grad = self.query.compute_val_gradient(val_image, val_label)

        return self.query.get_most_influential_clusters(
            val_grad, top_k, most_helpful=True
        )

    def find_harmful_training_samples(
        self,
        val_image: torch.Tensor,
        val_label: torch.Tensor,
        top_k: int = 10,
    ) -> List[Tuple[int, float]]:
        """
        Find the training clusters most harmful for a validation sample.

        Despite the method name, results are per cluster, not per sample.

        Args:
            val_image: Validation image passed to the gradient computation
            val_label: Label for ``val_image``
            top_k: Number of clusters to return

        Returns:
            List of (cluster_id, influence) tuples, most harmful first.
        """
        val_grad = self.query.compute_val_gradient(val_image, val_label)

        return self.query.get_most_influential_clusters(
            val_grad, top_k, most_helpful=False
        )


if __name__ == "__main__":
    import sys

    # Two positional arguments are required: the IFC directory and the
    # ImageNet root. Without them, print usage and exit.
    if len(sys.argv) < 3:
        print("Usage: python query.py /path/to/ifc /path/to/imagenet")
        print("\nExample queries will be shown here when IFC is available.")
        sys.exit(0)

    ifc_path, imagenet_root = sys.argv[1], sys.argv[2]

    # Build the analyzer (loads IFC, model and cluster assignments).
    analyzer = InfluenceAnalyzer(ifc_path, imagenet_root)

    # Single-sample, unshuffled loader over the validation split.
    val_loader = ImageNetDataset(imagenet_root, split="val").get_loader(
        batch_size=1, shuffle=False
    )

    # Heading paired with the analyzer method that produces its ranking.
    reports = (
        ("\nMost helpful clusters:", analyzer.find_helpful_training_samples),
        ("\nMost harmful clusters:", analyzer.find_harmful_training_samples),
    )

    for images, labels in val_loader:
        print(f"\nValidation sample: label={labels[0].item()}")

        for heading, rank_clusters in reports:
            ranking = rank_clusters(images, labels, top_k=5)
            print(heading)
            for cluster_id, score in ranking:
                print(f"  Cluster {cluster_id}: influence={score:.6f}")

        # Demo only: stop after the first validation sample.
        break
