"""
Batch-aware UHead scorer that uses pre-extracted features
Following the pattern from luh/calculator_apply_uq_head.py
"""

import logging
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from lm_polygraph import WhiteboxModel
from luh import AutoUncertaintyHead
from luh.utils import recursive_to

from synthetic_dataset_generation.utils.steps_extractor import StepsExtractor

log = logging.getLogger(__name__)


class BatchUHeadScorer:
    """
    Scorer that computes uncertainty using pre-extracted features.
    No forward passes - just applies UHead to existing features.
    """

    # Score assigned to a sample that has no claims or no usable logits.
    NEUTRAL_SCORE = 0.5
    # Value treated as padding in the uncertainty head output.
    PAD_LOGIT = -100

    def __init__(
        self,
        uncertainty_head=None,
        uhead_path: Optional[str] = None,
        model=None,
        device: str = "cuda"
    ):
        """
        Initialize scorer with uncertainty head.

        Args:
            uncertainty_head: Pre-loaded uncertainty head. Used as-is (it is
                neither moved to ``device`` nor switched to eval mode here).
            uhead_path: Path to load uncertainty head from. Only used when
                ``uncertainty_head`` is None.
            model: WhiteboxModel (needed for steps extraction and loading uhead)
            device: Device to run on

        Raises:
            ValueError: If neither ``uncertainty_head`` nor the pair
                (``uhead_path``, ``model``) is provided.
        """
        self.device = device
        self.model = model  # Store WhiteboxModel for StepsExtractor

        if uncertainty_head is not None:
            self.uncertainty_head = uncertainty_head
        elif uhead_path is not None and model is not None:
            self.uncertainty_head = AutoUncertaintyHead.from_pretrained(
                uhead_path,
                model.model
            ).to(device)
            self.uncertainty_head.eval()
        else:
            raise ValueError("Either uncertainty_head or (uhead_path and model) must be provided")

        # Initialize steps extractor
        self.steps_extractor = StepsExtractor(progress_bar=False)

    def compute_uncertainties(
        self,
        features_dict: Dict[str, Any],
        input_texts: Optional[List[str]] = None
    ) -> List[float]:
        """
        Compute uncertainty scores from pre-extracted features.

        Note: each ``llm_inputs`` entry is modified in place - a "claims" key
        holding the per-sample claim masks is written into it before the
        uncertainty head is applied.

        Args:
            features_dict: Dict with greedy_texts, greedy_tokens, uhead_features,
                full_attention_mask, llm_inputs. The last three are either
                single-batch objects or lists with one entry per batch.
            input_texts: Optional input texts for claim extraction context

        Returns:
            List of uncertainty scores (one per candidate). Samples without
            claims, or without non-padding logits, get ``NEUTRAL_SCORE``.

        Raises:
            KeyError: If uhead_features, full_attention_mask or llm_inputs is
                missing from ``features_dict``.
        """
        # Validate required fields
        greedy_texts = features_dict.get('greedy_texts')
        greedy_tokens = features_dict.get('greedy_tokens')

        if greedy_texts is None:
            log.warning("greedy_texts is None in features_dict, using empty list")
            greedy_texts = []

        if greedy_tokens is None:
            log.warning("greedy_tokens is None in features_dict, using empty list")
            greedy_tokens = []

        # Create stats dict for StepsExtractor
        stats_dict = {
            'greedy_texts': greedy_texts,
            'greedy_tokens': greedy_tokens
        }

        # Use empty input texts if not provided (for step candidates)
        if input_texts is None:
            log.warning("input_texts is None, using %d empty strings", len(greedy_texts))
            input_texts = [''] * len(greedy_texts)

        # Extract claims using StepsExtractor
        extraction_result = self.steps_extractor(
            stats_dict,
            input_texts,
            self.model
        )
        # A missing key and an explicit None are both treated as "no claims".
        claims = extraction_result.get('claims') or []
        log.debug("Extracted %d claims for %d texts", len(claims), len(greedy_texts))

        uhead_features = features_dict["uhead_features"]
        attention_masks = features_dict["full_attention_mask"]
        llm_inputs = features_dict["llm_inputs"]
        if not isinstance(uhead_features, list):
            # Single batch format - wrap in lists so both formats share one loop
            uhead_features = [uhead_features]
            attention_masks = [attention_masks]
            llm_inputs = [llm_inputs]

        all_uncertainty_scores = []

        # Running offset of the current batch's first sample inside `claims`
        sample_idx = 0

        for batch_idx, (batch_features, batch_mask, batch_inputs) in enumerate(
            zip(uhead_features, attention_masks, llm_inputs)
        ):
            batch_size = len(batch_inputs["input_ids"])

            # Claims for this batch; samples the extractor returned nothing for
            # (short or empty `claims`, or None entries) get an empty list.
            batch_claims = [
                sample_claims if sample_claims is not None else []
                for sample_claims in claims[sample_idx:sample_idx + batch_size]
            ]
            batch_claims.extend([] for _ in range(batch_size - len(batch_claims)))
            sample_idx += batch_size

            empty_claim_indices = {
                i for i, sample_claims in enumerate(batch_claims) if not sample_claims
            }

            # Prepare claims following luh/calculator_apply_uq_head.py pattern
            batch_inputs["claims"] = self.prepare_claims(
                batch_inputs,
                batch_claims,
                batch_mask.shape[1]
            )

            if len(empty_claim_indices) == batch_size:
                # No sample has claims, so the head output would never be read.
                log.debug("Batch %d has no claims, using neutral scores", batch_idx)
                all_uncertainty_scores.extend([self.NEUTRAL_SCORE] * batch_size)
                continue

            with torch.no_grad():
                # Apply uncertainty head following luh/calculator_apply_uq_head.py
                uncertainty_logits = self.uncertainty_head._compute_tensors(
                    recursive_to(batch_inputs, self.device),
                    batch_features.to(self.device),
                    batch_mask[:, :-1].to(self.device)  # Ignoring last token
                )

            # One device->host transfer per batch; float() keeps half-precision
            # outputs convertible to NumPy.
            logits_np = uncertainty_logits.detach().float().cpu().numpy()

            # The uncertainty head is expected to skip samples with empty
            # claims, so output rows are matched to the non-empty samples in
            # order and neutral scores are inserted for the rest.
            logit_idx = 0
            for i in range(batch_size):
                if i in empty_claim_indices:
                    score = self.NEUTRAL_SCORE
                    log.debug(
                        "Sample %d in batch %d had empty claims, using neutral score",
                        i, batch_idx
                    )
                elif logit_idx < logits_np.shape[0]:
                    score = self._logits_to_score(logits_np[logit_idx])
                    logit_idx += 1
                else:
                    # Shouldn't happen, but safeguard
                    log.warning(
                        "Logit index %d out of bounds for uncertainty_logits shape %s",
                        logit_idx, logits_np.shape
                    )
                    score = self.NEUTRAL_SCORE

                all_uncertainty_scores.append(score)

        if claims and len(claims) != sample_idx:
            log.warning(
                "Number of claim lists (%d) does not match number of samples (%d)",
                len(claims), sample_idx
            )

        return all_uncertainty_scores

    @classmethod
    def _logits_to_score(cls, logits) -> float:
        """Average exp() of the non-padding logits; neutral score if none remain."""
        values = np.asarray(logits, dtype=np.float64).ravel()
        valid_logits = values[values != cls.PAD_LOGIT]
        if valid_logits.size == 0:
            return cls.NEUTRAL_SCORE
        return float(np.mean(np.exp(valid_logits)))

    def prepare_claims(self, batch, claims, full_len):
        """
        Build per-sample binary claim masks over token positions.

        Args:
            batch: Mapping with "input_ids" (its length gives the batch size)
                and "context_lenghts" (per-sample offset added to each claim's
                token ids; the misspelled key is the one actually read).
            claims: Per-sample sequences of claim objects exposing
                ``aligned_token_ids``. Missing or None entries count as
                "no claims".
            full_len: Length of each mask before the first position is dropped.

        Returns:
            One integer tensor per sample of shape (n_claims, full_len - 1),
            where row j is 1 at the positions of claim j. Samples without
            claims get a (0, full_len - 1) tensor.
        """
        batch_size = len(batch["input_ids"])
        context_lengths = batch["context_lenghts"]
        all_claim_tensors = []
        for i in range(batch_size):
            sample_claims = claims[i] if i < len(claims) else None
            if sample_claims is None:
                sample_claims = []

            instance_masks = []
            for claim in sample_claims:
                mask = torch.zeros(full_len, dtype=torch.long)
                # long is the canonical index dtype and works on every torch version
                positions = (context_lengths[i] + torch.as_tensor(claim.aligned_token_ids)).long()
                mask[positions] = 1
                instance_masks.append(mask[1:])  # ignoring <s>

            if instance_masks:
                all_claim_tensors.append(torch.stack(instance_masks))
            else:
                all_claim_tensors.append(torch.zeros(0, full_len - 1, dtype=torch.long))

        return all_claim_tensors