"""
Test information about verifiers:
pytest test_constants.py -k test_verifier_descriptions_present 
pytest test_constants.py -k test_consistent_verifiers_across_datasets
pytest test_constants.py -k test_reward_models_after_renaming 
pytest test_constants.py -s -k test_print_verifiers_by_size_for_each_dataset
"""
# Make the `weaver` package importable when the tests are run from a checkout
# that has not been installed: fall back to putting the directory one level
# above this file's directory on sys.path.
try:
    import weaver
except ImportError:
    # Only a failed import should trigger the fallback; a bare `except` would
    # also swallow KeyboardInterrupt/SystemExit and unrelated errors.
    import os
    import sys

    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from weaver.constants import VERIFIER_DESCRIPTIONS, DATASET_TO_HF, DATASET_TO_VERIFIERS, VERIFIER_NAME_MAP
import pandas as pd
import pytest


# Datasets that are excluded from every check in this module.
DATA_IGNORE_LIST = ["CodeContests_gonly"]

# Only v2 datasets are checked.
# The list is sorted because it feeds pytest.mark.parametrize: building it
# from a set difference gives an order that changes between interpreter runs
# (string hash randomization), which makes the test order unstable and can
# break collection under pytest-xdist.
ALL_DATASETS = sorted(
    d for d in DATASET_TO_HF if d not in DATA_IGNORE_LIST and "v2" in d
)


@pytest.mark.parametrize("data", ALL_DATASETS)
def test_verifier_descriptions_present(data):
    """Check verifier naming and description coverage for one dataset.

    For every entry registered under ``data`` in DATASET_TO_HF, the verifier
    list looked up in DATASET_TO_VERIFIERS must:

    * contain no repeated raw names,
    * map every raw name through VERIFIER_NAME_MAP to a key of
      VERIFIER_DESCRIPTIONS,
    * contain no two raw names that map to the same renamed verifier.
    """
    for dataset_name in DATASET_TO_HF[data].values():
        all_verifier_names = DATASET_TO_VERIFIERS[dataset_name]

        # Raw names: de-duplicating must not shrink the collection.
        verifier_names = set(all_verifier_names)
        assert len(all_verifier_names) == len(verifier_names), \
            f"Duplicate verifier names found in {dataset_name}: {verifier_names}"

        # Each raw name needs a renamed counterpart, and that counterpart
        # needs a description.
        for verifier_name in all_verifier_names:
            assert verifier_name in VERIFIER_NAME_MAP, f"{verifier_name} is not in VERIFIER_NAME_MAP"

            assert VERIFIER_NAME_MAP[verifier_name] in VERIFIER_DESCRIPTIONS, (
                f"{verifier_name} {VERIFIER_NAME_MAP[verifier_name]} for {dataset_name} is not in VERIFIER_DESCRIPTIONS"
            )

        # Renamed names: two raw names must never collapse onto one verifier.
        seen_after_renaming = set()
        for verifier_name in all_verifier_names:
            renamed_verifier_name = VERIFIER_NAME_MAP[verifier_name]

            assert renamed_verifier_name not in seen_after_renaming, \
                f"Duplicate renamed verifier name found for {dataset_name}: original: {verifier_name}, renamed: {renamed_verifier_name}"
            seen_after_renaming.add(renamed_verifier_name)

    print(f"All verifier names are unique for dataset {data}.")


def test_consistent_verifiers_across_datasets():
    """Print a verifier-presence table across datasets for the 8B model size.

    Builds a DataFrame with one row per verifier name and one column per
    dataset in ALL_DATASETS that has an "8B" entry in DATASET_TO_HF; a cell is
    True when that dataset's verifier list contains the name. The table is
    printed as 0/1, followed by the rows whose dataset count differs from the
    first row's count.

    The check is informational only: mismatches are reported but do not fail
    the test (the strict assertion has been left disabled on purpose).
    """
    # Only "8B" is inspected for now. To cover every size, iterate over
    # sorted({ms for data in ALL_DATASETS for ms in DATASET_TO_HF[data]}).
    for model_size in ["8B"]:
        # Verifier names of every dataset that provides this model size.
        verifiers_by_dataset = {}
        for data in ALL_DATASETS:
            model_sizes = DATASET_TO_HF[data]
            if model_size in model_sizes:
                verifiers_by_dataset[data] = DATASET_TO_VERIFIERS[model_sizes[model_size]]

        # Union of all verifier names seen for this model size -> table rows.
        model_size_reward_models = set()
        for reward_models in verifiers_by_dataset.values():
            model_size_reward_models.update(reward_models)

        df = pd.DataFrame(index=sorted(model_size_reward_models))
        for data, reward_models in verifiers_by_dataset.items():
            df[data] = df.index.isin(reward_models)

        print(f"\nReward Models Presence Table for Model Size: {model_size}")
        print(df.astype(int))  # 0/1 reads more clearly than False/True

        # With no rows there is no "first row" to compare against; indexing
        # .iloc[0] on the empty sum would raise IndexError.
        if df.empty:
            continue

        # Number of datasets each verifier appears in.
        presence_counts = df.sum(axis=1)
        mismatched = df[presence_counts != presence_counts.iloc[0]]
        if not mismatched.empty:
            print("Verifiers whose dataset count differs from the first row:")
            print(mismatched.astype(int))

        # Strict variant, intentionally disabled:
        # assert presence_counts.nunique() == 1

@pytest.mark.parametrize("dataset", DATASET_TO_VERIFIERS.keys())
def test_reward_models_after_renaming(dataset):
    """Verify that renaming through VERIFIER_NAME_MAP causes no collisions.

    Walks the verifiers listed for ``dataset`` in DATASET_TO_VERIFIERS and
    fails as soon as two of them map to the same renamed verifier.

    Raises:
        ValueError: if a verifier of the dataset has no entry in
            VERIFIER_NAME_MAP.
    """
    seen_renamed = set()
    used_verifier_names = set()
    for model in DATASET_TO_VERIFIERS[dataset]:
        # Guard first: a verifier without a mapping cannot be checked at all.
        if model not in VERIFIER_NAME_MAP:
            raise ValueError(f"Model {model} not found in VERIFIER_NAME_MAP")
        renamed_model = VERIFIER_NAME_MAP[model]
        assert renamed_model not in seen_renamed, f"Duplicate renamed reward model found in {dataset}: {model}:{renamed_model} used: {used_verifier_names}"
        # Record both sides so a later failure can list what was processed.
        seen_renamed.add(renamed_model)
        used_verifier_names.add(model)
    print(f"All renamed reward models are unique for dataset {dataset}.")


@pytest.mark.parametrize("data", ALL_DATASETS)
def test_print_verifiers_by_size_for_each_dataset(data):
    """Print the verifiers of each dataset grouped by parameter-count bucket.

    For every (model size, dataset name) pair registered under ``data`` in
    DATASET_TO_HF, each verifier's ``num_parameters`` (read from its
    VERIFIER_DESCRIPTIONS entry) is sorted into exactly one bucket:
    <=8, <=32, <=75, <=400 or >400. The printed labels present these bounds
    as billions of parameters. Verifiers without ``num_parameters`` are
    listed separately instead of being dropped.

    This test only prints (run with ``-s``); it asserts nothing.
    """
    for model_size, dataset_name in DATASET_TO_HF[data].items():
        # The if/elif chain below puts each verifier into one bucket only, so
        # every bucket already holds just the verifiers "additional" to the
        # smaller ones; no set subtraction is needed afterwards.
        up_to_8B = set()
        up_to_32B = set()
        up_to_75B = set()
        up_to_400B = set()
        above_400B = set()
        unknown_size = set()

        for verifier in DATASET_TO_VERIFIERS[dataset_name]:
            details = VERIFIER_DESCRIPTIONS[VERIFIER_NAME_MAP[verifier]]
            num_parameters = details.get('num_parameters')
            if num_parameters is None:
                unknown_size.add(verifier)
            elif num_parameters <= 8:
                up_to_8B.add(verifier)
            elif num_parameters <= 32:
                up_to_32B.add(verifier)
            elif num_parameters <= 75:
                up_to_75B.add(verifier)
            elif num_parameters <= 400:
                up_to_400B.add(verifier)
            else:
                above_400B.add(verifier)

        print(f"\n\nDataset: {data} base model size {model_size} in {dataset_name}")
        print(f"Verifiers w <=8B params: len={len(up_to_8B)}:\t", up_to_8B)
        print(f"Additional verifiers w <=32B params: len={len(up_to_32B)}:\t", up_to_32B)
        print(f"Additional verifiers w <=75B params: len={len(up_to_75B)}:\t", up_to_75B)
        print(f"Additional verifiers with <=400B params: len={len(up_to_400B)}:\t", up_to_400B)
        # Exactly 400 falls into the previous bucket, so this one is strictly >400.
        print(f"Verifiers with >400B params: len={len(above_400B)}:\t", above_400B)
        if unknown_size:
            print(f"Verifiers with unknown params: len={len(unknown_size)}:\t", unknown_size)

