from lm_polygraph.stat_calculators.stat_calculator import StatCalculator
from lm_polygraph.utils.model import Model

from .utils import recursive_to

from typing import Dict, Tuple, List
import torch
import numpy as np
import logging

# Module-level logger used by the calculator below.
# NOTE: getLogger() without a name returns the root logger.
log = logging.getLogger()


class CalculatorApplyUQHead(StatCalculator):
    """
    Stat calculator that runs an uncertainty head over precomputed features
    and returns its outputs for the extracted claims.

    The head is invoked through ``uncertainty_head._compute_tensors``; entries
    of its output equal to ``-100`` are dropped from the returned arrays.
    """

    def __init__(self, uncertainty_head):
        """
        Parameters:
            uncertainty_head: object exposing ``_compute_tensors(batch,
                features, attention_mask)``; stored as ``self.uncertainty_head``.
        """
        super().__init__()
        self.uncertainty_head = uncertainty_head

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        """
        Returns the statistics and dependencies for the calculator.
        """
        # NOTE(review): __call__ also reads "llm_inputs" and
        # "full_attention_mask" from `dependencies`; presumably they are
        # produced together with "uhead_features" -- confirm.
        return [
            "uncertainty_claim_logits",
        ], ["uhead_features", "claims"]

    def __call__(
        self,
        dependencies: Dict[str, object],
        texts: List[str],
        model: Model,
        max_new_tokens: int = 100,  # TODO: move to args_generate
        **kwargs,
    ) -> Dict[str, List[np.ndarray]]:
        """
        Computes the uncertainty head outputs for every claim.

        Two input layouts are supported, selected by the type of
        ``dependencies["uhead_features"]``:

        * a ``list`` -- "uhead_features", "full_attention_mask" and
          "llm_inputs" are iterated in lockstep, one entry per batch, while
          "claims" is a flat sequence that is sliced per batch using the
          number of "input_ids" in each batch;
        * anything else -- all dependencies describe one single batch
          (old format, kept for backward compatibility).

        Parameters:
            dependencies: must contain "uhead_features", "full_attention_mask",
                "llm_inputs" and "claims".
            texts: not used by this calculator.
            model: only ``model.device()`` is used, to place the tensors.
            max_new_tokens: not used by this calculator.
        Returns:
            Dict with the single key "uncertainty_claim_logits" mapping to a
            list holding one numpy array per row of the head's output, with
            entries equal to -100 removed.
        """
        features = dependencies["uhead_features"]

        if isinstance(features, list):
            log.info("Using list format for uncertainty head")
            claim_logits = []
            # Offset of the current batch inside the flat "claims" sequence.
            sample_idx = 0
            for batch_features, batch_mask, batch_inputs in zip(
                features,
                dependencies["full_attention_mask"],
                dependencies["llm_inputs"],
            ):
                batch_size = len(batch_inputs["input_ids"])
                batch_claims = dependencies["claims"][
                    sample_idx:sample_idx + batch_size
                ]
                claim_logits.extend(
                    self._score_batch(
                        batch_inputs,
                        batch_claims,
                        batch_features,
                        batch_mask,
                        model,
                    )
                )
                sample_idx += batch_size
        else:
            # Old format: single tensors covering the whole input.
            log.info("Using old format for uncertainty head")
            claim_logits = self._score_batch(
                dependencies["llm_inputs"],
                dependencies["claims"],
                features,
                dependencies["full_attention_mask"],
                model,
            )

        return {"uncertainty_claim_logits": claim_logits}

    def _score_batch(self, batch, claims, features, attention_mask, model):
        """
        Runs the uncertainty head on one batch and filters its output.

        Side effect: the prepared claim masks are stored in ``batch["claims"]``,
        i.e. the passed-in batch mapping is modified in place.

        Returns:
            List with one numpy array per row of the head's output, keeping
            only the entries that differ from -100.
        """
        batch["claims"] = self.prepare_claims(
            batch,
            claims,
            attention_mask.shape[1],
        )

        device = model.device()
        with torch.no_grad():
            uncertainty_logits = self.uncertainty_head._compute_tensors(
                recursive_to(batch, device),
                features.to(device),
                attention_mask[:, :-1].to(device),  # Ignoring last token
            )

        # -100 entries are skipped; presumably they mark padded claim slots
        # in the head's output -- TODO confirm against the head implementation.
        return [
            np.asarray([e.item() for e in claim if e != -100])
            for claim in uncertainty_logits.cpu().numpy()
        ]

    def prepare_claims(self, batch, claims, full_len):
        """
        Builds, for every instance of the batch, a stack of 0/1 token masks
        with one row per claim.

        Parameters:
            batch: mapping providing "input_ids" (its length is the batch size)
                and "context_lenghts" (per-instance offset added to the
                claim token ids).
            claims: ``claims[i]`` is the iterable of claims of instance ``i``;
                each claim exposes ``aligned_token_ids``.
            full_len: length of the mask before the first position is dropped.
        Returns:
            List with one integer tensor per instance, of shape
            ``(n_claims, full_len - 1)``; instances without claims get an
            empty ``(0, full_len - 1)`` tensor.
        """
        batch_size = len(batch["input_ids"])
        # The key spelling ("lenghts") is part of the batch schema; keep it.
        context_lengths = batch["context_lenghts"]

        all_claim_tensors = []
        for i in range(batch_size):
            instance_masks = []
            for claim in claims[i]:
                mask = torch.zeros(full_len, dtype=torch.long)
                # Shift the claim token ids by the context length of the instance.
                positions = context_lengths[i] + torch.as_tensor(
                    claim.aligned_token_ids
                )
                mask[positions.long()] = 1
                instance_masks.append(mask[1:])  # ignoring <s>

            if instance_masks:
                all_claim_tensors.append(torch.stack(instance_masks))
            else:
                all_claim_tensors.append(
                    torch.zeros(0, full_len - 1, dtype=torch.long)
                )

        return all_claim_tensors
