""" This module is used to generate a verification dataset (the FoVer dataset) for FLDx2 """

import json
import random
from typing import Literal
from copy import deepcopy

from tap import Tap

from src.path import get_error_labels_path, get_fover_dataset_path
from src.config import splits_list
from src.typing import BASE_MODEL, PrmDatasetInstance
from src.llm.utils import save_md5_hash
from src.dataset_creation.prompts import \
    get_user_message, get_assistant_message, get_verification_prompt_for_single_turn_data
from src.dataset_creation.template import get_verification_reference_for_single_turn_data, get_verification_reference_for_multi_turn_data
from src.dataset_creation.utils import get_prm_dataset_stats


# Command-line arguments of this script (parsed in `main`).
class GenerateVerificationDatasetTap(Tap):
    # Name of the model whose error labels are loaded; also used for the
    # output dataset path and passed to `get_verification_data`.
    model_name: str
    # Base dataset name. `main` accepts "fldx2_symbol", "fldx2_text",
    # "isabelle_all", and "prm800k".
    dataset_name: str = "fldx2_symbol"
    # Passed as `seed` to `get_error_labels_path` for `model_name`
    # ("ground_truth" always uses seed 1 in `main`).
    generation_seed: str = "selected"
    # When non-empty, the error labels file is read from the path whose
    # suffix is replaced by ".{cot_type}.{suffix}.jsonl".
    suffix: str = "step_merged"
    # Ratio of correct instances requested from the balancing functions and
    # from `sample_dataset`.
    instance_correct_ratio: float = 0.10


# (name suffix, number of instances) of every down-sampled dataset variant
sampled_dataset_size_list: list[tuple[str, int]] = [
    ("5k", 5000),
    ("10k", 10000),
    ("20k", 20000),
]


def get_balanced_error_label_dataset(
        llm_generated_data: list[dict], instance_correct_ratio: float=0.20,
        target_data_size: int | None = None,
        seed: int=68,
    ) -> list[dict]:
    """ Get a balanced dataset for the verification model.
    
    Args:
        llm_generated_data (list[dict]): The error labels dataset generated by
            the LLM model.
        instance_correct_ratio (float): The ratio of correct instances in the
            balanced dataset.
        seed (int): The seed for shuffling the dataset.
    
    Returns:
        output (list[dict]): The balanced dataset.
    """
    
    correct_instances = [
        d for d in llm_generated_data if d["all_process_correct"]]
    wrong_instances = [
        d for d in llm_generated_data if not d["all_process_correct"]]
    
    if target_data_size is not None:
        target_wrong_num = round(target_data_size * (1 - instance_correct_ratio))
        if len(wrong_instances) < target_wrong_num:
            print(
                f"Not enough wrong instances: {len(wrong_instances)} < " \
                f"{target_wrong_num}"
            )
        else:
            wrong_instances = random.Random(seed).sample(
                wrong_instances, target_wrong_num
            )

    # number of correct instances
    correct_num = round(
        len(wrong_instances) *
        instance_correct_ratio / (1 - instance_correct_ratio)
    )
    if len(correct_instances) < correct_num:
        raise ValueError(
            f"Not enough correct instances: {len(correct_instances)} < " \
            f"{correct_num}"
        )

    selected_correct_instances = random.Random(seed).sample(
        correct_instances, correct_num
    )
    
    # merge
    output = selected_correct_instances + wrong_instances
    # shuffle
    output = random.Random(seed).sample(output, len(output))
    
    return output


def organize_error_label_dataset_by_categories_fldx2(
        error_labels_data: list[dict], base_dataset_name: str,
        seed: int=68
    ) -> dict[str, list[dict]]:
    """ Group the error labels dataset by ground-truth label and correctness.

    An instance counts as correct when its "y_correct" is truthy and every
    entry of its "proof_step_correctness" is truthy.

    Args:
        error_labels_data (list[dict]): The error labels dataset.
        base_dataset_name (str): The base dataset name.
        seed (int): The seed for shuffling each category.

    Returns:
        output (dict[str, list[dict]]): The organized dataset. The key is
            "y_true={y_true},correct={correct}". Each category is shuffled.

    Raises:
        ValueError: If the base dataset name is not supported.
    """

    if base_dataset_name not in ["fldx2_symbol", "fldx2_text"]:
        raise ValueError(f"Unknown base dataset name: {base_dataset_name}")

    from src.dataset_creation.base_dataset_specific.fol.typing \
        import FOL_PROOF_LABEL
    label_options = FOL_PROOF_LABEL.__args__

    # one empty bucket per (correctness, label) pair
    buckets: dict[str, list[dict]] = {
        f"y_true={label},correct={flag}": []
        for flag in [True, False]
        for label in label_options
    }

    # assign every instance to its bucket
    for instance in error_labels_data:
        label = instance["y_true"]
        answer_ok = instance["y_correct"]
        steps_ok = all(instance["proof_step_correctness"])
        is_correct = answer_ok and steps_ok
        buckets[f"y_true={label},correct={is_correct}"].append(instance)

    # shuffle every bucket with a fresh generator built from the same seed
    return {
        key: random.Random(seed).sample(items, len(items))
        for key, items in buckets.items()
    }


def get_balanced_error_label_dataset_fldx2(
        llm_generated_data: list[dict], ground_truth_data: list[dict],
        base_dataset_name: str,
        target_data_size: int, instance_correct_ratio: float=0.20,
        seed: int=68
    ) -> list[dict]:
    """ Get a balanced dataset for the verification model.

    The output contains the same number of instances for every ground-truth
    label, separately for correct and wrong instances. LLM-generated data is
    used first in each category and ground truth data fills the remainder.

    Args:
        llm_generated_data (list[dict]): The error labels dataset generated by
            the LLM model.
        ground_truth_data (list[dict]): The error labels dataset of ground
            truth data.
        base_dataset_name (str): The base dataset name.
        target_data_size (int): The target data size.
        instance_correct_ratio (float): The ratio of correct instances in the
            balanced dataset.
        seed (int): The seed for the final shuffle of the output. The
            per-category shuffles inside
            `organize_error_label_dataset_by_categories_fldx2` use that
            function's default seed.

    Returns:
        output (list[dict]): The balanced dataset.

    Raises:
        ValueError: If the base dataset name is unknown, if the requested
            numbers of correct or wrong instances are not integers divisible
            by the number of labels, or if a category has too few instances.
    """

    if base_dataset_name in ["fldx2_symbol", "fldx2_text"]:
        from src.dataset_creation.base_dataset_specific.fol.typing \
            import FOL_PROOF_LABEL
        base_label_options = FOL_PROOF_LABEL.__args__
    else:
        raise ValueError(f"Unknown base dataset name: {base_dataset_name}")
    num_labels = len(base_label_options)

    # check if target_data_size * instance_correct_ratio is integer
    # (compare with a tolerance: the product is a float and may carry
    # rounding noise, e.g., 100 * 0.07 == 7.000000000000001)
    correct_total_exact = target_data_size * instance_correct_ratio
    correct_total = round(correct_total_exact)
    if abs(correct_total_exact - correct_total) > 1e-6:
        raise ValueError(
            "target_data_size * instance_correct_ratio must be integer."
        )
    wrong_total = target_data_size - correct_total

    # both totals must be divisible by the number of labels
    if correct_total % num_labels != 0:
        raise ValueError(
            "target_data_size * instance_correct_ratio must be divisible by "\
            "len(base_label_options)."
        )
    if wrong_total % num_labels != 0:
        raise ValueError(
            "target_data_size * (1 - instance_correct_ratio) must be " \
            "divisible by len(base_label_options): "\
            f"{target_data_size * (1 - instance_correct_ratio)} % " \
            f"{num_labels} != 0"
        )

    ###
    # generate the balanced dataset

    # organize the dataset
    llm_generated_data_organized = organize_error_label_dataset_by_categories_fldx2(
        llm_generated_data, base_dataset_name=base_dataset_name
    )
    ground_truth_data_organized = organize_error_label_dataset_by_categories_fldx2(
        ground_truth_data, base_dataset_name=base_dataset_name
    )

    # get the target data size per category
    print(f"Target data size: {target_data_size}")
    output: list[dict] = []
    for correct, category_total in [(True, correct_total), (False, wrong_total)]:
        target_data_size_per_category = category_total // num_labels
        print(
            f"Correct={correct}, " \
            f"target data size per category: {target_data_size_per_category}"
        )

        for base_dataset_label in base_label_options:
            key = f"y_true={base_dataset_label},correct={correct}"
            # LLM-generated instances come first, so ground truth instances
            # are only used when the LLM-generated ones are not enough
            category_data = (
                llm_generated_data_organized[key]
                + ground_truth_data_organized[key]
            )

            if len(category_data) < target_data_size_per_category:
                raise ValueError(
                    f"Category {base_dataset_label}, correct={correct} has " \
                    f"less data than the target size: {len(category_data)} < " \
                    f"{target_data_size_per_category}"
                )

            output.extend(category_data[:target_data_size_per_category])
            print(
                f"Category {base_dataset_label}, correct={correct} has " \
                f"{len(category_data)} data, {target_data_size_per_category} "\
                "data is selected."
            )

    assert len(output) == target_data_size
    return random.Random(seed).sample(output, target_data_size)


def _sample_with_correct_ratio(
        correct_pool: list[dict], incorrect_pool: list[dict],
        correct_ratio: float, seed: int
    ) -> tuple[list[dict], list[dict]]:
    """ Down-sample two pools so that `correct_ratio` of the union is correct. """

    correct_num = len(correct_pool)
    incorrect_num = len(incorrect_pool)

    # the scarcer pool (relative to the ratio) determines the total size
    if round((correct_num / correct_ratio) * (1 - correct_ratio)) \
            <= incorrect_num:
        # incorrect_num is large
        total = round(correct_num / correct_ratio)
    else:
        # correct_num is large
        total = round(incorrect_num / (1 - correct_ratio))

    # min() guards against rounding asking for one more than is available
    selected_correct = random.Random(seed).sample(
        correct_pool, min(correct_num, round(total * correct_ratio))
    )
    selected_incorrect = random.Random(seed).sample(
        incorrect_pool, min(incorrect_num, round(total * (1 - correct_ratio)))
    )
    return selected_correct, selected_incorrect


def get_last_step_balanced_dataset(
            llm_generate_data: list[dict], last_step_correct_ratio: float = 0.5,
            seed: int = 68
        ) -> list[dict]:
    """ Get the data whose last steps are balanced.
    During training, only the last step is used.
    We don't use this dataset for evaluation.

    Every instance is expanded into one new instance per prefix of its
    solution (steps 1..k for every k). Prefixes that end at the final step of
    the original solution and prefixes that end at an intermediate step are
    balanced separately and then merged.

    Args:
        llm_generate_data (list[dict]): The error labels dataset generated by
            the LLM model.
        last_step_correct_ratio (float): The ratio of instances whose last
            step is correct. Must be strictly between 0 and 1.
        seed (int): The seed for sampling and shuffling.

    Returns:
        output (list[dict]): The last step balanced dataset.

    Raises:
        ValueError: If last_step_correct_ratio is not in (0, 1).
    """

    if not 0 < last_step_correct_ratio < 1:
        raise ValueError(
            "last_step_correct_ratio must be larger than 0 and smaller than "
            f"1: {last_step_correct_ratio}"
        )

    last_step_labels_dict: dict[str, list[dict]] = {
        "correct": [], "incorrect": [],
        "last_step_correct": [], "last_step_incorrect": []
    }
    for d in llm_generate_data:
        num_steps = len(d["proof_steps"])
        for step_id in range(1, num_steps + 1):
            # only use the steps up to step_id
            new_d = deepcopy(d)
            new_d["proof_steps"] = d["proof_steps"][:step_id]
            new_d["proof_step_correctness"] = \
                d["proof_step_correctness"][:step_id]

            if "cot_steps" in d:
                new_d["cot_steps"] = d["cot_steps"][:step_id]

            # make y labels null because we only use the last step
            new_d["y_true"] = None
            new_d["y_correct"] = None

            # last step correctness
            last_step_correct = new_d["proof_step_correctness"][-1]

            # the last steps are the final answers, which are different from
            # intermediate steps
            # we will store the last steps in a separate list
            # (step_id is 1-based, so the final step is step_id == num_steps)
            prefix = "last_step_" if step_id == num_steps else ""
            label = "correct" if last_step_correct else "incorrect"
            last_step_labels_dict[prefix + label].append(new_d)

    # last step data
    # make the last step data also balanced
    last_step_correct_data, last_step_incorrect_data = \
        _sample_with_correct_ratio(
            last_step_labels_dict["last_step_correct"],
            last_step_labels_dict["last_step_incorrect"],
            correct_ratio=last_step_correct_ratio, seed=seed,
        )

    # intermediate data
    correct_data, incorrect_data = _sample_with_correct_ratio(
        last_step_labels_dict["correct"],
        last_step_labels_dict["incorrect"],
        correct_ratio=last_step_correct_ratio, seed=seed,
    )

    # merge
    all_correct_data = last_step_correct_data + correct_data
    all_incorrect_data = last_step_incorrect_data + incorrect_data

    output = all_correct_data + all_incorrect_data
    return random.Random(seed).sample(output, len(output))


# Task description prepended to fldx2_text problems. Reference:
# https://github.com/EleutherAI/lm-evaluation-harness/
# blob/main/lm_eval/tasks/fld/fld_default.yaml
fld_text_problem_definition = (
    "Based on the provided facts ($context$), "
    "either prove or disprove the hypothesis or state that it is unknown."
)


# Task description prepended to isabelle_all problems.
isabelle_problem_definition = (
    "Generate a proof for the following theorem "
    "in the Isabelle proof assistant format."
)


def get_verification_data(
        error_labels_instance: dict,
        base_dataset_name: str, model_name: BASE_MODEL,
        conversation_type: Literal["single_turn", "multi_turn"] = "single_turn"
    ) -> PrmDatasetInstance:
    """ Convert one error labels instance into a verification dataset instance.

    Args:
        error_labels_instance (dict): The error labels instance. The keys
            "id", "problem", "proof_steps", and "proof_step_correctness" are
            read; "cot_steps" is read for single-turn data.
        base_dataset_name (str): The base dataset name.
        model_name (BASE_MODEL): The model name, used to format the assistant
            message (single turn) or to select the role name (multi turn).
        conversation_type: "single_turn" puts the whole verification in one
            response; "multi_turn" builds a multi-turn conversation and adds
            "messages_for_prediction" to the output.

    Returns:
        instance (PrmDatasetInstance): The verification dataset instance.

    Raises:
        ValueError: If the base dataset name or the conversation type is
            unknown.
    """

    # task description that is placed in front of the problem
    if base_dataset_name == "fldx2_symbol":
        from src.dataset_creation.base_dataset_specific.fol.prompts.initial_answers_prompt \
            import fol_symbol_problem_definition
        problem_definition = fol_symbol_problem_definition
    elif base_dataset_name == "fldx2_text":
        problem_definition = fld_text_problem_definition
    elif base_dataset_name == "isabelle_all":
        problem_definition = isabelle_problem_definition
    elif base_dataset_name == "prm800k":
        problem_definition = ""
    else:
        raise ValueError(f"Unknown base dataset name: {base_dataset_name}")

    definition_prefix = \
        problem_definition + "\n\n" if len(problem_definition) > 0 else ""
    problem = definition_prefix + error_labels_instance["problem"]

    # build the conversation
    if conversation_type == "single_turn":
        # old version: everything is included in a single response
        user_prompt = get_verification_prompt_for_single_turn_data(
            data_id=error_labels_instance["id"],
            problem=problem,
            solution_steps=error_labels_instance["proof_steps"]
        )
        reference_response = get_verification_reference_for_single_turn_data(
            explanations_list=error_labels_instance["cot_steps"],
            error_labels=error_labels_instance["proof_step_correctness"]
        )

        # this part is common to all base datasets
        conversation = [
            get_user_message(user_prompt),
            get_assistant_message(reference_response, model_name=model_name),
        ]
    elif conversation_type == "multi_turn":
        # new version
        shared_arguments = {
            "problem": problem,
            "solution_steps": error_labels_instance["proof_steps"],
            "model_role_name":
                "model" if "gemma" in model_name else "assistant",
        }

        # for reference
        conversation = get_verification_reference_for_multi_turn_data(
            reference_error_labels=error_labels_instance["proof_step_correctness"],
            **shared_arguments,
        )

        # for prediction (no reference labels)
        conversation_for_prediction = \
            get_verification_reference_for_multi_turn_data(
                reference_error_labels=None, **shared_arguments,
            )
    else:
        raise ValueError(f"Unknown conversation type: {conversation_type}")

    # the key order is kept because the instance is serialized as JSON
    instance = {
        "id": error_labels_instance["id"],
        "problem": problem,
        "solution_steps": error_labels_instance["proof_steps"],
        "error_labels": error_labels_instance["proof_step_correctness"],
        "problem_witout_definition": error_labels_instance["problem"],
        "messages": conversation,
        "base_dataset": base_dataset_name,
    }
    if conversation_type == "multi_turn":
        instance["messages_for_prediction"] = conversation_for_prediction

    # add dataset specific items (e.g., metadata)
    if base_dataset_name in ["fldx2_symbol", "fldx2_text"]:
        from src.dataset_creation.base_dataset_specific.fol.utils.fld_final_dataset \
            import get_fld_specific_items_in_final_dataset
        extra_items = get_fld_specific_items_in_final_dataset(
            error_labels_instance
        )
    elif base_dataset_name in ["isabelle_all", "prm800k"]:
        extra_items = {}
    else:
        raise ValueError(f"Unknown base dataset name: {base_dataset_name}")
    instance.update(extra_items)

    return instance


def sample_dataset(
        dataset: list[dict], target_data_size: int,
        instance_correct_ratio: float | None = None,
        last_step_correct_ratio: float | None = None,
        seed: int=68
    ) -> list[dict]:

    # check input
    if instance_correct_ratio is None and \
            last_step_correct_ratio is None:
        raise ValueError(
            "Either instance_correct_ratio or last_step_correct_ratio must be "
            "provided."
        )
    if instance_correct_ratio is not None and \
            last_step_correct_ratio is not None:
        raise ValueError(
            "Either instance_correct_ratio or last_step_correct_ratio must be "
            "provided, not both."
        )

    # categorize the dataset
    if instance_correct_ratio is not None:
        correct_data = [
            d for d in dataset if all(d["error_labels"])]
        wrong_data = [
            d for d in dataset if not all(d["error_labels"])]
        
        correct_num = round(target_data_size * instance_correct_ratio)
        wrong_num = round(target_data_size * (1 - instance_correct_ratio))
    elif last_step_correct_ratio is not None:
        correct_data = [
            d for d in dataset if d["error_labels"][-1]]
        wrong_data = [
            d for d in dataset if not d["error_labels"][-1]]
        
        correct_num = round(target_data_size * last_step_correct_ratio)
        wrong_num = round(target_data_size * (1 - last_step_correct_ratio))
    else:
        raise ValueError(
            "Either instance_correct_ratio or last_step_correct_ratio must be "
            "provided."
        )

    # select correct and wrong instances
    selected_correct_instances = random.Random(seed).sample(
        correct_data, round(correct_num)
    )
    selected_wrong_instances = random.Random(seed).sample(
        wrong_data, round(wrong_num)
    )

    # merge and shuffle
    output = selected_correct_instances + selected_wrong_instances
    output = random.Random(seed).sample(output, len(output))

    return output


def get_final_dataset_name(dataset_name: str, data_type: str) -> str:
    """ Get the final dataset name based on the dataset name and data type.

    Args:
        dataset_name (str): The base dataset name.
        data_type (str): One of "no_cot", "with_cot", "multi_turn", and
            "multi_turn_balanced_last_step".

    Returns:
        name (str): `dataset_name` itself for "no_cot"; otherwise
            `dataset_name` followed by "_" and the data type.

    Raises:
        ValueError: If the data type is unknown. An explicit exception is
            used instead of `assert` because assertions are removed when
            Python runs with -O, which would silently map an unknown data
            type to the "_with_cot" name.
    """

    if data_type == "no_cot":
        return dataset_name
    if data_type in ("with_cot", "multi_turn",
                     "multi_turn_balanced_last_step"):
        return f"{dataset_name}_{data_type}"
    raise ValueError(f"Unknown data type: {data_type}")


# every dataset variant generated by main(), in generation order
data_types_list = [
    "no_cot",
    "with_cot",
    "multi_turn",
    "multi_turn_balanced_last_step",
]

def main() -> None:
    """ Generate and save the verification datasets for one base dataset.

    For every data type in `data_types_list` and every selected split, this
    function loads the error labels, balances them with a dataset-specific
    procedure, converts every instance with `get_verification_data`, and
    writes the result as JSON lines. After each dataset file is written,
    `save_md5_hash` is called on it and the output of `get_prm_dataset_stats`
    is saved next to it as ".stats.json". The variants listed in
    `sampled_dataset_size_list` are written in the same way.
    """
    args = GenerateVerificationDatasetTap().parse_args()

    # dataset-specific target sizes and the models whose error labels are
    # loaded ("ground_truth" is loaded in addition for FLDx2, and is the only
    # source for prm800k)
    if args.dataset_name in ["fldx2_symbol", "fldx2_text"]:
        target_data_size_dict = {"train": 60000, "validation": 360, "test": 360}
        model_names_list = ["ground_truth", args.model_name]
    elif args.dataset_name in ["isabelle_all"]:
        # None: the training split is not down-sampled to a fixed size
        target_data_size_dict = {"train": None, "validation": 360, "test": 360}
        model_names_list = [args.model_name]
    elif args.dataset_name in ["prm800k"]:
        target_data_size_dict = {}
        model_names_list = ["ground_truth"]
    else:
        raise ValueError(f"Unknown base dataset name: {args.dataset_name}")


    for data_type in data_types_list:
        # only the "with_cot" data type reads the with_cot error labels;
        # every other data type reads the no_cot ones
        cot_type = "with_cot" if data_type == "with_cot" else "no_cot"
        conversation_type = "multi_turn" if "multi_turn" in data_type \
            else "single_turn"

        if data_type == "multi_turn_balanced_last_step":
            # this type is only for training
            selected_splits = ["train"]
        else:
            selected_splits = splits_list

        for split in selected_splits:
            print(f"Gathering data for {split} split")

            # ground_truth is used to add correct cases
            error_labels_dict: dict[str, list[dict]] = {}
            for model_name in model_names_list:
                # ground truth labels are always read with seed 1
                generation_seed = 1 if model_name == "ground_truth" \
                    else args.generation_seed

                # error labels
                error_labels_path = get_error_labels_path(
                    dataset_name=args.dataset_name, model_name=model_name,
                    split=split, seed=generation_seed
                )
                if len(args.suffix) > 0:
                    # read the file whose suffix is replaced by
                    # ".{cot_type}.{suffix}.jsonl"
                    error_labels_path = error_labels_path.with_suffix(
                        f".{cot_type}.{args.suffix}.jsonl"
                    )
                # one JSON object per line
                with open(error_labels_path, "r") as f:
                    error_labels = [json.loads(line) for line in f]

                error_labels_dict[model_name] = error_labels

            # make balanced dataset
            if args.dataset_name in ["fldx2_symbol", "fldx2_text"]:
                balanced_dataset = get_balanced_error_label_dataset_fldx2(
                    llm_generated_data=error_labels_dict[args.model_name],
                    ground_truth_data=error_labels_dict["ground_truth"],
                    base_dataset_name=args.dataset_name,
                    target_data_size=target_data_size_dict[split],
                    instance_correct_ratio=args.instance_correct_ratio
                )
            elif args.dataset_name in ["isabelle_all"]:
                if split in target_data_size_dict.keys():
                    target_data_size = target_data_size_dict[split]
                else:
                    target_data_size = None

                print("target_data_size", target_data_size)

                balanced_dataset = get_balanced_error_label_dataset(
                    error_labels_dict[args.model_name],
                    instance_correct_ratio=args.instance_correct_ratio,
                    target_data_size=target_data_size,
                )
            elif args.dataset_name in ["prm800k"]:
                # used as is, without balancing
                balanced_dataset = error_labels_dict["ground_truth"]
            else:
                raise ValueError(f"Unknown base dataset name: {args.dataset_name}")

            # last step dataset
            # this type includes data with balanced last steps
            # during training, only the last step is used
            # e.g., in llama-factory, set mask_history=True
            if data_type == "multi_turn_balanced_last_step":
                balanced_dataset = get_last_step_balanced_dataset(
                    balanced_dataset
                )

            # generate final dataset
            output: list[dict] = []
            for d in balanced_dataset:
                output.append(
                    get_verification_data(
                        d, base_dataset_name=args.dataset_name,
                        model_name=args.model_name,
                        conversation_type=conversation_type,
                    )
                )

            # save final dataset
            final_dataset_name = get_final_dataset_name(
                dataset_name=args.dataset_name, data_type=data_type
            )

            verification_dataset_path = get_fover_dataset_path(
                dataset_name=final_dataset_name, model_name=args.model_name,
                split=split
            )
            verification_dataset_path.parent.mkdir(parents=True, exist_ok=True)
            with open(verification_dataset_path, "w") as f:
                for line in output:
                    f.write(json.dumps(line) + "\n")
            save_md5_hash(verification_dataset_path)
            print(f"Saved {split} split to {verification_dataset_path}")

            # get statistics
            stats = get_prm_dataset_stats(output)
            stats_path = verification_dataset_path.with_suffix(".stats.json")
            with open(stats_path, "w") as f:
                json.dump(stats, f, indent=4)

            # multiple size versions
            for size_name, size_num in sampled_dataset_size_list:
                if len(output) > size_num:
                    if data_type == "multi_turn_balanced_last_step":
                        # this data type is balanced by the last step, so it
                        # is sampled by the last step as well
                        sampled_dataset = sample_dataset(
                            output, target_data_size=size_num,
                            last_step_correct_ratio=0.5,
                        )
                    else:
                        sampled_dataset = sample_dataset(
                            output, target_data_size=size_num,
                            instance_correct_ratio=args.instance_correct_ratio
                        )
                else:
                    # the dataset is not larger than the requested size:
                    # the whole dataset is saved under this size name
                    sampled_dataset = output

                sampled_dataset_path = get_fover_dataset_path(
                    dataset_name=f"{final_dataset_name}_{size_name}",
                    model_name=args.model_name,
                    split=split
                )
                sampled_dataset_path.parent.mkdir(
                    parents=True, exist_ok=True)

                with open(sampled_dataset_path, "w") as f:
                    for line in sampled_dataset:
                        f.write(json.dumps(line) + "\n")
                save_md5_hash(sampled_dataset_path)
                print(
                    f"Saved {split} split to {sampled_dataset_path}"
                )

                # get statistics
                stats = get_prm_dataset_stats(sampled_dataset)
                stats_path = sampled_dataset_path.with_suffix(
                    ".stats.json")
                with open(stats_path, "w") as f:
                    json.dump(stats, f, indent=4)


# generate the datasets only when this file is executed as a script
if __name__ == "__main__":
    main()
