import random

from pathlib import Path
from typing import (
    Callable,
    List,
    Optional,
)

import fire

from loguru import logger

from src.dataset import dataset_factory
from src.schema.qa import QA
from src.schema.result import (
    InferenceResult,
    PredictedAnswerSet,
)
from src.utils.json import write_json_file
from src.utils.log import set_log_level
from src.utils.set import union_of_sets


def infer_element(qa: QA, pick_strategy_fn: Callable[[List[str]], List[str]]) -> InferenceResult:
    """Produce baseline predictions for a single QA item.

    The candidate universe is ``qa.answer_element_universe`` when it is
    non-empty; otherwise it is derived from the elements of all of the QA's
    answer sets. ``pick_strategy_fn`` is then called once per answer set, so
    randomized strategies draw independently for each answer set.

    Args:
        qa: QA item whose ``answer_sets`` (and optional
            ``answer_element_universe``) are read.
        pick_strategy_fn: Maps the candidate universe to the predicted
            elements for one answer set.

    Returns:
        An ``InferenceResult`` holding ``qa`` and one ``PredictedAnswerSet``
        per answer set. The ``ref_memory_*`` fields are always empty lists.
    """
    answer_element_universe = qa.answer_element_universe
    if not answer_element_universe:
        # Fall back to every element that appears in any answer set.
        # Sorted on purpose: union_of_sets presumably returns a set (confirm),
        # and iteration order of a set of strings changes between processes
        # (hash randomization), which would make seeded picks irreproducible.
        answer_element_universe = sorted(
            union_of_sets([answer_set.elements for answer_set in qa.answer_sets])
        )

    pred_answer_sets = [
        PredictedAnswerSet(
            answer_set_id=answer_set.answer_set_id,
            user_id=answer_set.user_id,
            elements=pick_strategy_fn(answer_element_universe),
            ref_memory_ids=[],
            ref_memory_contents=[],
            ref_memory_scores=[],
        )
        for answer_set in qa.answer_sets
    ]
    return InferenceResult(qa=qa, pred_answer_sets=pred_answer_sets)


def pick_strategy_factory(pick_strategy_name: str) -> Callable[[List[str]], List[str]]:
    """Resolve a strategy name to the function implementing it.

    Args:
        pick_strategy_name: One of ``"pick_one_uniform_random"``,
            ``"pick_set_uniform_random"`` or ``"pick_all"``.

    Returns:
        The module-level pick function with the same name.

    Raises:
        ValueError: If the name matches none of the known strategies.
    """
    if pick_strategy_name == "pick_one_uniform_random":
        strategy = pick_one_uniform_random
    elif pick_strategy_name == "pick_set_uniform_random":
        strategy = pick_set_uniform_random
    elif pick_strategy_name == "pick_all":
        strategy = pick_all
    else:
        raise ValueError(f"Unknown strategy: {pick_strategy_name}")
    return strategy


def pick_one_uniform_random(answer_element_universe: List[str]) -> List[str]:
    """Pick exactly one element uniformly at random from the universe.

    Args:
        answer_element_universe: Candidate elements to choose from.

    Returns:
        A single-element list with the chosen element, or an empty list when
        the universe is empty (matching what the other strategies return for
        an empty universe instead of raising ``IndexError``).
    """
    if not answer_element_universe:
        # random.choice raises IndexError on an empty sequence; there is
        # nothing to pick, so the prediction is simply empty.
        return []
    return [random.choice(answer_element_universe)]


def pick_set_uniform_random(answer_element_universe: List[str]) -> List[str]:
    """Pick a random subset of the universe, keeping each element with a fair coin flip.

    Args:
        answer_element_universe: Candidate elements to choose from.

    Returns:
        The kept elements in their original order. The result may be empty.
    """
    # Draw every coin up front, one per element in universe order, so the
    # random stream is consumed before any element is inspected.
    coin_flips = [random.randint(0, 1) for _ in answer_element_universe]

    kept: List[str] = []
    for candidate, keep in zip(answer_element_universe, coin_flips):
        if keep:
            kept.append(candidate)
    return kept


def pick_all(answer_element_universe: List[str]) -> List[str]:
    """Pick every element of the universe.

    Args:
        answer_element_universe: Candidate elements to choose from.

    Returns:
        A new list with the same elements in the same order, so callers can
        mutate the result without touching the shared universe.
    """
    return list(answer_element_universe)


def infer(
    inference_results_dir: str,
    dataset_name: str,
    qa_data_path: str,
    batch_size: Optional[int] = None,
    shuffle: bool = False,
    num_workers: int = 8,
    seed: int = 42,
) -> None:
    """Run every baseline pick strategy over a QA dataset and save the results.

    For each of the strategies ``pick_one_uniform_random``,
    ``pick_set_uniform_random`` and ``pick_all``, the data loader is iterated
    once and the results are written to
    ``<inference_results_dir>/<strategy name>.json``.

    Args:
        inference_results_dir: Output directory; created (with parents) if
            it does not exist.
        dataset_name: Passed to ``dataset_factory`` as ``dataset_name``.
        qa_data_path: Passed to ``dataset_factory`` as ``data_path``.
        batch_size: Forwarded to ``dataset.get_data_loader``.
        shuffle: Forwarded to ``dataset.get_data_loader``.
        num_workers: Forwarded to ``dataset.get_data_loader``.
        seed: Seed for the ``random`` module. It is set once, so all
            strategies draw from the same random stream in sequence.
    """
    set_log_level()

    random.seed(seed)
    results_dir = Path(inference_results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    dataset = dataset_factory(dataset_name=dataset_name, data_path=qa_data_path)
    data_loader = dataset.get_data_loader(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

    for pick_strategy_name in ["pick_one_uniform_random", "pick_set_uniform_random", "pick_all"]:
        pick_strategy_fn = pick_strategy_factory(pick_strategy_name)
        inference_results = [infer_element(qa, pick_strategy_fn) for qa in data_loader]
        inference_results_dict = [res.model_dump(mode="json") for res in inference_results]

        # Indexing [0] on an empty result list would raise IndexError before
        # anything is written; log a warning instead and still emit the file.
        if inference_results_dict:
            logger.info(f"[Inference Result Sample]\n{inference_results_dict[0]}")
        else:
            logger.warning(f"No inference results produced for strategy: {pick_strategy_name}")

        inference_results_path = results_dir / f"{pick_strategy_name}.json"
        write_json_file(file_path=inference_results_path, data=inference_results_dict)


# Script entry point: when run directly, hand `infer` to python-fire, which
# builds the command-line interface from it.
if __name__ == "__main__":
    fire.Fire(infer)
