from weaver.verification import VerificationMethod, FilteringStrategyType
import numpy as np
from collections import Counter
from typing import List
from loguru import logger
from tqdm import tqdm
from copy import deepcopy 

# Drop every sink currently registered on the imported loguru `logger` and
# install a single one that prints formatted records (INFO and above) with
# print(). NOTE(review): this runs at import time and presumably affects the
# shared loguru logger used by other modules too — confirm that is intended.
logger.remove()
logger.add(
    lambda record_text: print(record_text, end=""),
    format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {message}",
    level="INFO",
)

class FirstSample(VerificationMethod):
    """Always selects the first sample from each row."""

    def filter(self, dataset, cumulative_mask, **kwargs) -> List[List[int]]:
        """
        Returns indices for first sample in each row.

        Args:
            dataset: HuggingFace dataset with samples; only its length is used.
            cumulative_mask: Ignored; index 0 is selected even if it is masked.

        Returns:
            One list per row of ``dataset``, each equal to ``[0]``.
        """
        # Build an independent list per row. ``[[0]] * n`` would repeat a single
        # shared list object, so mutating one row's selection would alter all rows.
        return [[0] for _ in range(len(dataset))]

class MajorityVote(VerificationMethod):
    """Selects samples whose extracted answer is among the most common answers in each row."""

    def filter(self, dataset, cumulative_mask, k: int = 1, **kwargs) -> List[List[int]]:
        """
        Returns indices for samples whose answer is among the most frequent ones.

        Answers equal to ``'NO_ANSWER'`` are never counted or selected. When
        ``cumulative_mask`` is given, only positions whose mask value equals 1
        are counted or selected.

        Strategies (``self.filter_strategy.name``):
            MAJORITY_VOTE: select every eligible index whose answer is among the
                ``self.filter_strategy.params`` most common answers.
            TOP_K: walk answers from most to least frequent (ties broken by first
                appearance), collecting their eligible indices, then keep exactly
                the first ``self.filter_strategy.params`` of them.

        Args:
            dataset: HuggingFace dataset with an ``extracted_answers`` column.
            cumulative_mask: Optional per-row mask over samples, or None.
            k: Currently unused; the count is taken from ``self.filter_strategy.params``.

        Returns:
            One list of selected sample indices per row (empty when a row has no
            countable answer).

        Raises:
            ValueError: If the filter strategy is neither MAJORITY_VOTE nor TOP_K.
        """
        strategy = self.filter_strategy.name
        if strategy not in (
            FilteringStrategyType.MAJORITY_VOTE.value,
            FilteringStrategyType.TOP_K.value,
        ):
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")
        limit = self.filter_strategy.params

        # Fetch the column and convert the mask once: doing either inside the loop
        # rebuilds a full-size structure on every row (quadratic overall).
        all_answers = dataset["extracted_answers"]
        mask = np.array(cumulative_mask) if cumulative_mask is not None else None

        filtered_idxs = []
        for row in tqdm(range(len(dataset))):
            answers = all_answers[row]
            if mask is not None:
                allowed = set(np.flatnonzero(mask[row] == 1).tolist())
            else:
                allowed = set(range(len(answers)))

            # Eligible (index, answer) pairs in their original order.
            eligible = [(j, ans) for j, ans in enumerate(answers) if j in allowed]
            answer_counts = Counter(ans for _, ans in eligible if ans != 'NO_ANSWER')
            if not answer_counts:
                filtered_idxs.append([])
                continue

            if strategy == FilteringStrategyType.MAJORITY_VOTE.value:
                top_answers = {answer for answer, _ in answer_counts.most_common(limit)}
                best_idxs = [j for j, ans in eligible if ans in top_answers]
            else:  # TOP_K
                # most_common() orders equal counts by first insertion, i.e. by
                # first appearance among the eligible answers.
                best_idxs = []
                for answer, _ in answer_counts.most_common():
                    best_idxs.extend(j for j, ans in eligible if ans == answer)
                    if len(best_idxs) >= limit:
                        break
                best_idxs = best_idxs[:limit]
            filtered_idxs.append(best_idxs)

        return filtered_idxs

class HighestScoringRM(VerificationMethod):
    """Selects samples by the scores of a single reward-model column."""

    def filter(self, dataset, cumulative_mask, rm_column: str, **kwargs) -> List[List[int]]:
        """
        Returns indices for samples selected by one reward model's scores.

        Strategies (``self.filter_strategy.name``):
            TOP_K: the ``self.filter_strategy.params`` highest-scoring indices
                (stable sort, so ties keep their original order).
            THRESHOLD: every index whose score is >= the row's value in the
                ``f"{rm_column}_threshold"`` column.

        Args:
            dataset: HuggingFace dataset with samples.
            cumulative_mask: Optional per-row mask; positions where it is 0 get
                score -1e9 before selection.
            rm_column: Name of reward model column to use.

        Returns:
            One numpy index array per row; rows whose scores are all None get ``[]``.

        Raises:
            ValueError: If ``rm_column`` is missing or the strategy is unsupported.
        """
        if rm_column not in dataset.column_names:
            raise ValueError(f"Reward model column {rm_column} not found in dataset")

        strategy = self.filter_strategy.name
        if strategy not in (FilteringStrategyType.TOP_K.value, FilteringStrategyType.THRESHOLD.value):
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        # Convert the mask once instead of on every row.
        mask = np.array(cumulative_mask) if cumulative_mask is not None else None

        filtered_idxs = []
        for i, sample in enumerate(dataset):
            raw_scores = sample[rm_column]  # one entry per generation
            if all(s is None for s in raw_scores):
                filtered_idxs.append([])
                continue

            # dtype=float turns any remaining None into NaN instead of producing an
            # object array, which cannot be negated or compared against a threshold.
            scores = np.array(raw_scores, dtype=float)
            if mask is not None:
                scores[mask[i] == 0] = -10e8
                # NOTE(review): with TOP_K, masked indices can still be returned when
                # fewer than k samples are unmasked — confirm that is intended.

            if strategy == FilteringStrategyType.TOP_K.value:
                best_idxs = np.argsort(-scores, kind='stable')[:self.filter_strategy.params]
            else:  # THRESHOLD
                # Only read the threshold column when it is actually needed.
                reward_threshold = sample[f"{rm_column}_threshold"]
                best_idxs = np.where(scores >= reward_threshold)[0]

            filtered_idxs.append(best_idxs)

        return filtered_idxs

class HighestScoringLM(VerificationMethod):
    """Selects samples that a specified LM judge marks as correct."""

    def filter(self, dataset, cumulative_mask, judge_column: str, **kwargs) -> List[List[int]]:
        """
        Returns indices for samples approved by an LM judge.

        Strategies (``self.filter_strategy.name``):
            TOP_K: the ``self.filter_strategy.params`` highest verdict indices
                (stable sort; ties keep original order).
            THRESHOLD: every index whose verdict equals 1.

        Args:
            dataset: HuggingFace dataset with samples.
            cumulative_mask: Optional per-row mask; positions where it is 0 get
                score -1e9 before selection.
            judge_column: Name of LM judge verdicts column to use.

        Returns:
            One list of selected indices per row (empty when a row has no verdicts).

        Raises:
            ValueError: If ``judge_column`` is missing or the strategy is unsupported.
        """
        if judge_column not in dataset.column_names:
            raise ValueError(f"LM judge column {judge_column} not found in dataset")

        strategy = self.filter_strategy.name
        if strategy not in (FilteringStrategyType.TOP_K.value, FilteringStrategyType.THRESHOLD.value):
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        if strategy == FilteringStrategyType.TOP_K.value:
            logger.warning("top-k for a judge has arbitrary tie-breaking policy.")

        # Fetch the column and convert the mask once; doing it per row rebuilds
        # the whole column/mask each iteration.
        all_verdicts = dataset[judge_column]
        mask = np.array(cumulative_mask) if cumulative_mask is not None else None

        filtered_idxs = []
        for i in tqdm(range(len(dataset))):
            verdicts = all_verdicts[i]
            if not verdicts:
                filtered_idxs.append([])
                continue

            # Cast to float: with boolean verdicts, assigning -10e8 would store True
            # (letting masked samples pass THRESHOLD) and negating would raise.
            scores = np.array(verdicts, dtype=float)
            if mask is not None:
                scores[mask[i] == 0] = -10e8

            if strategy == FilteringStrategyType.TOP_K.value:
                best_idxs = np.argsort(-scores, kind='stable')[:self.filter_strategy.params]
            else:  # THRESHOLD
                best_idxs = np.where(scores == 1)[0]
            filtered_idxs.append(best_idxs.tolist())

        return filtered_idxs
    
class NaiveRMEnsemble(VerificationMethod):
    """Selects samples by the unweighted mean score of the reward models."""

    def filter(self, dataset, cumulative_mask, **kwargs) -> List[List[int]]:
        """
        Returns indices selected by the mean score over the reward-model verifiers.

        Scores are read from ``self.scores`` (indexed as ``[:, :, verifier]``),
        restricted to the columns in ``self.reward_model_idxs``; ``dataset`` is
        not read.

        Strategies (``self.filter_strategy.name``):
            TOP_K: per row, the ``self.filter_strategy.params`` highest mean scores
                (stable sort).
            THRESHOLD: per row, every index whose mean score is >=
                ``self.filter_strategy.params``.

        Args:
            dataset: Unused.
            cumulative_mask: Optional mask aligned with the first two axes of
                ``self.scores``; falsy positions get score -1e9 for every verifier.

        Returns:
            None when there are no reward models. Otherwise
            ``(filtered_idxs, weights)``: filtered_idxs is a 2-D index array (TOP_K)
            or a list of per-row index arrays (THRESHOLD); weights has one entry
            per verifier column, 1 for the reward models and 0 elsewhere.

        Raises:
            ValueError: If the filter strategy is neither TOP_K nor THRESHOLD.
        """
        if len(self.reward_models) == 0:
            return None

        if cumulative_mask is not None:
            keep = np.array(cumulative_mask)[..., np.newaxis]  # broadcast over verifiers
            masked_scores = np.where(keep, self.scores, -10e8)
        else:
            masked_scores = self.scores

        ensembled_scores = masked_scores[:, :, self.reward_model_idxs].mean(axis=-1)  # average across reward models
        if self.filter_strategy.name == FilteringStrategyType.TOP_K.value:
            filtered_idxs = np.argsort(-ensembled_scores, axis=1, kind='stable')[:, :self.filter_strategy.params]
        elif self.filter_strategy.name == FilteringStrategyType.THRESHOLD.value:
            filtered_idxs = [np.where(row >= self.filter_strategy.params)[0] for row in ensembled_scores]
        else:
            # Must be raised: otherwise `filtered_idxs` is unbound at the return.
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        weights = np.zeros(self.scores.shape[-1])
        weights[self.reward_model_idxs] = 1

        return filtered_idxs, weights

        
class NaiveLMEnsemble(VerificationMethod):
    """Selects samples by the unweighted mean verdict of the LM judges."""

    def filter(self, dataset, cumulative_mask, **kwargs) -> List[List[int]]:
        """
        Returns indices selected by the mean score over the LM-judge verifiers.

        Scores are read from ``self.scores`` (indexed as ``[:, :, verifier]``),
        restricted to the columns in ``self.lm_judges_idxs``; ``dataset`` is not read.

        Strategies (``self.filter_strategy.name``):
            TOP_K: per row, the ``self.filter_strategy.params`` highest mean scores
                (stable sort).
            BEST_TIED: per row, every index tied for the row's maximum mean score.
            THRESHOLD: per row, every index whose mean score is >=
                ``self.filter_strategy.params``.

        Args:
            dataset: Unused.
            cumulative_mask: Optional mask aligned with the first two axes of
                ``self.scores``; falsy positions get score -1e9 for every verifier.

        Returns:
            None when there are no LM judges. Otherwise ``(filtered_idxs, weights)``:
            filtered_idxs is a 2-D index array (TOP_K) or a list of per-row index
            arrays (BEST_TIED, THRESHOLD); weights has one entry per verifier
            column, 1 for the LM judges and 0 elsewhere.

        Raises:
            ValueError: If the filter strategy is not TOP_K, BEST_TIED or THRESHOLD.
        """
        if len(self.lm_judges) == 0:
            return None

        if self.filter_strategy.name == FilteringStrategyType.TOP_K.value:
            logger.warning("top-k for a judge has arbitrary tie-breaking policy.")

        if cumulative_mask is not None:
            keep = np.array(cumulative_mask)[..., np.newaxis]  # broadcast over verifiers
            masked_scores = np.where(keep, self.scores, -10e8)
        else:
            masked_scores = self.scores

        ensembled_scores = masked_scores[:, :, self.lm_judges_idxs].mean(axis=-1)  # average across judges
        if self.filter_strategy.name == FilteringStrategyType.TOP_K.value:
            filtered_idxs = np.argsort(-ensembled_scores, axis=1, kind='stable')[:, :self.filter_strategy.params]
        elif self.filter_strategy.name == FilteringStrategyType.BEST_TIED.value:
            # NOTE(review): if every sample in a row is masked, all tie at -1e9 and
            # all are returned — confirm callers handle that.
            max_pass_rate = ensembled_scores.max(axis=1)  # per row
            filtered_idxs = [np.where(row == max_pass_rate[i])[0] for i, row in enumerate(ensembled_scores)]
        elif self.filter_strategy.name == FilteringStrategyType.THRESHOLD.value:
            filtered_idxs = [np.where(row >= self.filter_strategy.params)[0] for row in ensembled_scores]
        else:
            # Must be raised: otherwise `filtered_idxs` is unbound at the return.
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        weights = np.zeros(self.scores.shape[-1])
        weights[self.lm_judges_idxs] = 1

        return filtered_idxs, weights


class NaiveEnsemble(VerificationMethod):
    """Selects samples by the unweighted mean score across all verifiers."""

    def filter(self, dataset, cumulative_mask, **kwargs) -> List[List[int]]:
        """
        Returns indices selected by the mean of ``self.scores`` over every verifier.

        ``dataset`` is not read.

        Strategies (``self.filter_strategy.name``):
            TOP_K: per row, the ``self.filter_strategy.params`` highest mean scores
                (stable sort).
            THRESHOLD: per row, every index whose mean score is >=
                ``self.filter_strategy.params``.

        Args:
            dataset: Unused.
            cumulative_mask: Optional mask aligned with the first two axes of
                ``self.scores``; falsy positions get score -1e9 for every verifier.

        Returns:
            ``(filtered_idxs, weights)``: filtered_idxs is a 2-D index array (TOP_K)
            or a list of per-row index arrays (THRESHOLD); weights is all ones, one
            entry per verifier column.

        Raises:
            ValueError: If the filter strategy is neither TOP_K nor THRESHOLD.
        """
        if cumulative_mask is not None:
            keep = np.array(cumulative_mask)[..., np.newaxis]  # broadcast over verifiers
            masked_scores = np.where(keep, self.scores, -10e8)
        else:
            masked_scores = self.scores

        ensembled_scores = masked_scores.mean(axis=-1)  # average across verifiers to get (n_problems, n_generations)
        if self.filter_strategy.name == FilteringStrategyType.TOP_K.value:
            filtered_idxs = np.argsort(-ensembled_scores, axis=1, kind='stable')[:, :self.filter_strategy.params]
        elif self.filter_strategy.name == FilteringStrategyType.THRESHOLD.value:
            filtered_idxs = [np.where(row >= self.filter_strategy.params)[0] for row in ensembled_scores]
        else:
            # Must be raised: otherwise `filtered_idxs` is unbound at the return.
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        weights = np.ones(self.scores.shape[-1])
        return filtered_idxs, weights


class NaiveBinaryEnsemble(VerificationMethod):
    """Selects samples by the unweighted mean of all verifiers' binary scores."""

    def filter(self, dataset, cumulative_mask, **kwargs) -> List[List[int]]:
        """
        Returns indices selected by the mean of ``self.binary_scores`` over every verifier.

        ``dataset`` is not read.

        Strategies (``self.filter_strategy.name``):
            TOP_K: per row, the ``self.filter_strategy.params`` highest mean scores
                (stable sort).
            BEST_TIED: per row, every index tied for the row's maximum mean score.
            THRESHOLD: per row, every index whose mean score is >=
                ``self.filter_strategy.params``.

        Args:
            dataset: Unused.
            cumulative_mask: Optional mask aligned with the first two axes of
                ``self.binary_scores``; falsy positions get -1e9 for every verifier.

        Returns:
            ``(filtered_idxs, weights)``: filtered_idxs is a 2-D index array (TOP_K)
            or a list of per-row index arrays (BEST_TIED, THRESHOLD); weights is all
            ones with length ``self.scores.shape[-1]``.

        Raises:
            ValueError: If the filter strategy is not TOP_K, BEST_TIED or THRESHOLD.
        """
        if cumulative_mask is not None:
            keep = np.array(cumulative_mask)[..., np.newaxis]  # broadcast over verifiers
            masked_scores = np.where(keep, self.binary_scores, -10e8)
        else:
            masked_scores = self.binary_scores

        ensembled_scores = masked_scores.mean(axis=-1)  # average across verifiers to get (n_problems, n_generations)
        if self.filter_strategy.name == FilteringStrategyType.TOP_K.value:
            filtered_idxs = np.argsort(-ensembled_scores, axis=1, kind='stable')[:, :self.filter_strategy.params]
        elif self.filter_strategy.name == FilteringStrategyType.BEST_TIED.value:
            max_pass_rate = ensembled_scores.max(axis=1)  # per row
            filtered_idxs = [np.where(row == max_pass_rate[i])[0] for i, row in enumerate(ensembled_scores)]
        elif self.filter_strategy.name == FilteringStrategyType.THRESHOLD.value:
            filtered_idxs = [np.where(row >= self.filter_strategy.params)[0] for row in ensembled_scores]
        else:
            # Must be raised: otherwise `filtered_idxs` is unbound at the return.
            raise ValueError(f"{self.filter_strategy.name} is not a valid filtering strategy for {self.__class__.__name__}.")

        # NOTE(review): the weights length comes from `self.scores`, not
        # `self.binary_scores`; presumably both share the verifier axis — confirm.
        weights = np.ones(self.scores.shape[-1])
        return filtered_idxs, weights
