from typing import Union, Optional, List
from enum import Enum
from dataclasses import dataclass
import numpy as np
from copy import deepcopy

from weaver.constants import VERIFIER_DESCRIPTIONS, DATASET_TO_REWARD_MODELS, DATASET_TO_LM_JUDGES


class FilteringStrategyType(Enum):
    """Supported generation-filtering strategies.

    Each member's ``value`` is the string accepted by ``FilterStrategy.name``.
    """
    TOP_K = "top_k"  # requires int param
    BEST_TIED = "best_tied"  # no params
    THRESHOLD = "threshold"  # requires float param
    UNIQUE_TOP_K = "unique_top_k"  # requires int param
    MAJORITY_VOTE = "majority_vote_k"  # requires int param: select generations that have one of the top k answers


def _is_integer(value) -> bool:
    """Return True for Python/NumPy integers, excluding ``bool`` (a subclass of ``int``)."""
    return isinstance(value, (int, np.integer)) and not isinstance(value, bool)


def _is_number(value) -> bool:
    """Return True for Python/NumPy integers and floats, excluding ``bool``."""
    return isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(value, bool)


@dataclass
class FilterStrategy:
    """Simple container and validator for filtering strategy configuration.

    Attributes:
        name: Value of a ``FilteringStrategyType`` member (e.g. ``"top_k"``).
            A ``FilteringStrategyType`` member, or a known name in any letter
            case, is also accepted and normalized to the lowercase value.
        params: Strategy parameter. It must be a positive integer ``k`` for
            ``top_k``, ``unique_top_k`` and ``majority_vote_k`` (NumPy
            integers are normalized to ``int``). It must be a number for
            ``threshold``, and ``None`` for ``best_tied``. Booleans are
            rejected everywhere.

    Raises:
        ValueError: If ``name`` is not a known strategy, or ``params`` does
            not match what the strategy requires.
    """
    name: str
    params: Optional[Union[int, float]] = None

    def __post_init__(self):
        valid_names = [strategy.value for strategy in FilteringStrategyType]

        # Accept enum members and any casing of a known name; store the
        # canonical lowercase value. Unknown names are reported as given.
        name = self.name
        if isinstance(name, FilteringStrategyType):
            name = name.value
        elif isinstance(name, str):
            name = name.lower()
        if name not in valid_names:
            raise ValueError(f"Unknown filtering strategy: {self.name}. "
                             f"Must be one of {valid_names}")
        self.name = name

        # Validate parameters based on strategy type
        if self.name == FilteringStrategyType.BEST_TIED.value:
            if self.params is not None:
                raise ValueError(f"Strategy '{self.name}' does not require any parameters (threshold is set by global.reward_threshold).")
        elif self.name in (FilteringStrategyType.TOP_K.value, FilteringStrategyType.UNIQUE_TOP_K.value):
            if not (_is_integer(self.params) and self.params > 0):
                raise ValueError(f"Strategy '{self.name}' requires a positive integer parameter, k.")
            self.params = int(self.params)
        elif self.name == FilteringStrategyType.MAJORITY_VOTE.value:
            if not (_is_integer(self.params) and self.params > 0):
                raise ValueError(f"Strategy '{self.name}' requires a positive integer parameter")
            self.params = int(self.params)
        elif self.name == FilteringStrategyType.THRESHOLD.value:
            if self.params is None:
                raise ValueError(f"Strategy '{self.name}' requires a parameter for thresholding")
            if not _is_number(self.params):
                raise ValueError(
                    f"Strategy '{self.name}' requires a numerical threshold parameter")



class VerificationMethod:
    """Base class holding verifier scores for a dataset of generations.

    On construction it:
    selects the reward models and LM judges registered for ``dataset_path``
    (optionally restricted by ``verifier_subset``); stacks their columns into
    ``self.scores``; and builds ``self.binary_scores``, in which each
    reward-model score is replaced by ``score > threshold``.

    Args:
        dataset_path: Key into ``DATASET_TO_REWARD_MODELS`` and
            ``DATASET_TO_LM_JUDGES``.
        dataset: Object indexable by column name. It must contain one column
            per selected verifier and a ``f"{rm}_threshold"`` column per
            selected reward model (including ``"mv_verifier"`` when
            ``mv_as_voter`` is true). Presumably each verifier column holds
            per-problem lists of per-generation scores (TODO confirm).
        verifier_metrics: Stored unchanged on ``self.verifier_metrics``.
        filter_strategy: A ``FilteringStrategyType`` name, or ``None``.
        filter_strategy_param: Parameter forwarded to ``FilterStrategy``.
        verifier_subset: ``None`` to use every verifier for the dataset, an
            iterable of verifier names to keep, or an ``int`` size cap (keep
            verifiers whose ``VERIFIER_DESCRIPTIONS[name]['num_parameters']``
            is ``<=`` the cap).
        mv_as_voter: If true, ``"mv_verifier"`` is appended as an extra
            reward model.

    Raises:
        ValueError: If no verifier is selected, or the filter strategy is
            invalid.
    """

    def __init__(self, dataset_path, dataset, verifier_metrics,
                 filter_strategy, filter_strategy_param,
                 verifier_subset, mv_as_voter
                 ):
        self.dataset_path = dataset_path
        self.dataset = dataset

        if isinstance(verifier_subset, int):
            # select all reward models and judges that are under this parameter size
            verifier_subset = [v for v, desc in VERIFIER_DESCRIPTIONS.items()
                               if desc['num_parameters'] <= verifier_subset]

        reward_models = DATASET_TO_REWARD_MODELS[self.dataset_path]
        lm_judges = DATASET_TO_LM_JUDGES[self.dataset_path]
        if verifier_subset is not None:
            keep = set(verifier_subset)
            reward_models = set(reward_models) & keep
            lm_judges = set(lm_judges) & keep
        # Reward models with "_step" in their name are excluded.
        self.reward_models = [rm for rm in sorted(reward_models) if "_step" not in rm]
        self.lm_judges = sorted(lm_judges)

        if mv_as_voter:
            self.reward_models += ['mv_verifier']

        verifier_names = self.reward_models + self.lm_judges
        n_verifiers = len(verifier_names)
        if n_verifiers == 0:
            # np.column_stack([]) would fail here with an opaque numpy error.
            raise ValueError(
                f"No verifiers selected for dataset '{self.dataset_path}' "
                f"(verifier_subset={verifier_subset!r})")

        # column_stack places each verifier's values side by side along axis 1:
        # (n_rows, n_verifiers * k). Reshape to (n_rows, n_verifiers, k) and move
        # the verifier axis last -> (n_rows, k, n_verifiers). The row count comes
        # from the stacked array itself rather than len(dataset).
        stacked = np.column_stack([self.dataset[col] for col in verifier_names])
        self.scores = stacked.reshape(stacked.shape[0], n_verifiers, -1).transpose(0, 2, 1)  # X_data

        # Reward models occupy the leading verifier columns; LM judges follow.
        self.reward_model_idxs = np.arange(len(self.reward_models))
        self.lm_judges_idxs = np.arange(len(self.reward_models), n_verifiers)

        self.verifier_metrics = verifier_metrics  # weight by different verifier accuracies
        # Always define the attribute so callers can check `is None` instead of
        # hitting AttributeError when no strategy was configured.
        self.filter_strategy = (FilterStrategy(filter_strategy, filter_strategy_param)
                                if filter_strategy is not None else None)

        # Binarize reward-model scores against their thresholds; LM-judge columns
        # are copied unchanged. Skipped when there are no reward models, since
        # column_stack cannot stack an empty list.
        self.binary_scores = self.scores.copy()
        n_reward_models = len(self.reward_models)
        if n_reward_models:
            thresholds = np.column_stack(
                [self.dataset[f"{rm}_threshold"] for rm in self.reward_models])  # Shape (n_problems, n_reward_models)
            self.binary_scores[:, :, :n_reward_models] = (
                self.binary_scores[:, :, :n_reward_models] > thresholds[:, None, :]
            ).astype(int)

    def filter(self, dataset, cumulative_mask, **kwargs) -> Optional[List[List[int]]]:
        """Hook for subclasses to select generations.

        The base implementation performs no filtering and returns ``None``.
        """
        return None
