from datasets import concatenate_datasets, load_dataset
import os
import sys
import json

# Module-level setup: make the repository root importable and use it as the
# working directory. The `rllm` import that follows relies on the sys.path entry.
script_dir = os.path.dirname(os.path.abspath(__file__))
# Two directory levels above the folder that contains this script.
project_root = os.path.dirname(os.path.dirname(script_dir))
sys.path.insert(0, project_root)

# NOTE: changes the process-wide working directory as an import-time side
# effect, so relative paths used later resolve against project_root.
os.chdir(project_root)

from rllm.data.dataset import DatasetRegistry


def _match_first(data_type: str, table):
    """Return the value of the first ``(key, value)`` pair whose key occurs in ``data_type``.

    ``table`` is scanned in order, so earlier entries take precedence.

    Raises:
        ValueError: If no key is a substring of ``data_type``.
    """
    for key, value in table:
        if key in data_type:
            return value
    raise ValueError(f"Invalid data type: {data_type}")


def select_data_type(data_type: str, do_pilot: bool):
    """Map an experiment ``data_type`` string to split names and dataset prefixes.

    ``data_type`` is matched by substring. An uppercase level letter selects the
    split names, and a lowercase difficulty keyword selects the dataset prefixes.
    In both cases the first match wins.

    Args:
        data_type: Experiment identifier, e.g. ``"A_normal"`` or ``"B_short_easy"``.
        do_pilot: If True, use the pilot splits ``train_0``..``train_4`` (levels
            A-E; difficulties ``easy``/``normal``/``all``). Otherwise use the
            ``sft_<k>``/``rl_<k>`` splits (levels A-C; difficulties
            ``easy``/``medium_hard``/``medium``/``hard``/``normal``/``all``).
            In that case ``short`` or ``concise`` selects the SFT split variant.

    Returns:
        ``(target_dict, possible_data_prefix)``:
        - ``target_dict`` maps ``"sft"`` and ``"rl"`` to lists of split names.
        - ``possible_data_prefix`` is a list of dataset prefixes.

    Raises:
        ValueError: If no level letter or no difficulty keyword matches.
    """
    if do_pilot:
        # The level letter decides how many leading pilot splits go to SFT;
        # the remaining splits go to RL.
        pilot_splits = [f"train_{i}" for i in range(5)]
        n_sft = _match_first(data_type, (("A", 4), ("B", 3), ("C", 2), ("D", 4), ("E", 5)))
        target_dict = {"sft": pilot_splits[:n_sft], "rl": pilot_splits[n_sft:]}
        possible_data_prefix = _match_first(
            data_type,
            (
                ("easy", ["4x4"]),
                ("normal", ["4x4", "9x9"]),
                ("all", ["4x4", "9x9", "ctc"]),
            ),
        )
    else:
        # Level A/B/C uses the first 1/2/3 sft_<k> and rl_<k> splits.
        n_splits = _match_first(data_type, (("A", 1), ("B", 2), ("C", 3)))
        if "short" in data_type:
            sft_suffix = "_gpt"
        elif "concise" in data_type:
            sft_suffix = "_concise"
        else:
            sft_suffix = ""
        indices = range(1, n_splits + 1)
        target_dict = {
            "sft": [f"sft_{i}{sft_suffix}" for i in indices],
            "rl": [f"rl_{i}" for i in indices],
        }

        # "medium_hard" must be tested before "medium" and "hard": it contains
        # both as substrings, so testing it after them made it unreachable.
        if "easy" in data_type:
            possible_data_prefix = ["4x4", "9x9_easy", "ctc_easy"]
        elif "medium_hard" in data_type:
            possible_data_prefix = ["9x9_medium", "ctc_medium", "9x9_hard", "ctc_hard"]
        elif "medium" in data_type:
            possible_data_prefix = ["9x9_medium", "ctc_medium"]
        elif "hard" in data_type:
            # tmp: no sft data
            target_dict["sft"] = []
            possible_data_prefix = ["9x9_hard", "ctc_hard"]
        elif "normal" in data_type:
            possible_data_prefix = ["4x4", "9x9_easy", "ctc_easy", "9x9_medium", "ctc_medium"]
        elif "all" in data_type:
            possible_data_prefix = ["4x4", "9x9_easy", "ctc_easy", "9x9_medium", "ctc_medium", "9x9_hard", "ctc_hard"]
        else:
            raise ValueError(f"Invalid data type: {data_type}")

    return target_dict, possible_data_prefix

def possible_name(target_dict: dict, possible_data_prefix: list):
    """Expand split names and dataset prefixes into full dataset names.

    Every name has the form ``sudoku_<prefix>_<split>``. Names are ordered
    split-major: all prefixes for the first split, then all prefixes for the
    next split, and so on.

    Args:
        target_dict: Mapping with keys ``"sft"`` and ``"rl"``, each a list of
            split names (as returned by ``select_data_type``).
        possible_data_prefix: List of dataset prefixes such as ``"4x4"``.

    Returns:
        ``(sft_list, rl_list)``: the dataset names for SFT and for RL.
    """
    def expand(splits):
        return [f"sudoku_{prefix}_{split}" for split in splits for prefix in possible_data_prefix]

    return expand(target_dict["sft"]), expand(target_dict["rl"])
    

def collect_datasets(data_type: str, do_pilot: bool, exp_type: str, valid_type: str, blending_ratio: float):
    """Load the SFT, RL-train and RL-validation datasets for one experiment.

    Args:
        data_type: For ``exp_type == "benchmark"``, one of ``"sudoku"``,
            ``"standard_sudoku"``, ``"standard_easy_sudoku"`` or
            ``"sudoku_validation"``. Otherwise an identifier understood by
            ``select_data_type`` (e.g. ``"A_normal"``).
        do_pilot: Forwarded to ``select_data_type``.
        exp_type: ``"benchmark"``, ``"sft"`` or ``"rl"``.
        valid_type: RL validation subset selector ``"A"``, ``"B"`` or ``"C"``.
            Any other value keeps the whole validation dataset.
        blending_ratio: If > 0, RL training mixes a main and a blending
            difficulty tier. The result has as many rows as the main tier:
            ``int(rows * blending_ratio)`` come from the main tier and the rest
            from the blending tier. Only data types containing ``"normal"``
            (easy + medium) or ``"medium_hard"`` (medium + hard) are supported.

    Returns:
        ``(sft_datasets, rl_datasets, rl_valid_dataset)``. Entries not used by
        the chosen ``exp_type`` are ``None``.

    Raises:
        ValueError: On an unknown ``exp_type``, an unknown benchmark
            ``data_type``, blending with an unsupported ``data_type``, or when
            no SFT dataset could be loaded.
    """
    if exp_type == "benchmark":
        sft_datasets = None
        rl_datasets = None
        if data_type == "sudoku":
            rl_valid_dataset = load_dataset(YOUR_PATH)
        elif data_type == "standard_sudoku":
            rl_valid_dataset = load_dataset(YOUR_PATH)
            rl_valid_dataset = rl_valid_dataset.select(range(450))
        elif data_type == "standard_easy_sudoku":
            rl_valid_dataset = load_dataset(YOUR_PATH)
            rl_valid_dataset = rl_valid_dataset.select(range(150))
        elif data_type == "sudoku_validation":
            rl_valid_dataset = load_dataset(YOUR_PATH)
        else:
            # Report the offending data_type; exp_type is always "benchmark" here.
            raise ValueError(f"Invalid data type (Benchmark): {data_type}")
    else:
        target_dict, possible_data_prefix = select_data_type(data_type, do_pilot)
        sft_list, rl_list = possible_name(target_dict, possible_data_prefix)
        sft_datasets = []
        rl_datasets = []
        if exp_type == "sft":
            print(f"\n\nLoading sft datasets: {sft_list}")
            for data_name in sft_list:
                # NOTE(review): YOUR_PATH is a placeholder and data_name is unused;
                # the real path presumably depends on data_name — confirm.
                dataset_name = YOUR_PATH
                print(f"Loading dataset: {dataset_name}")
                try:
                    sft_dataset = load_dataset(dataset_name, split="train")
                except Exception as e:
                    # Best effort: a dataset that fails to load is reported and skipped.
                    print(f"Error loading dataset: {e}")
                else:
                    sft_datasets.append(sft_dataset)
            if not sft_datasets:
                raise ValueError(f"No sft datasets could be loaded from: {sft_list}")
            sft_datasets = concatenate_datasets(sft_datasets)
            rl_datasets = None
            rl_valid_dataset = None
        elif exp_type == "rl":
            print(f"\n\nLoading rl datasets: {rl_list}")
            # Bucket each RL dataset by the difficulty keyword in its name;
            # names with neither "medium" nor "hard" count as easy.
            rl_difficulty_dict = {
                "easy": [],
                "medium": [],
                "hard": [],
            }
            for data_name in rl_list:
                # NOTE(review): YOUR_PATH is a placeholder; see the sft branch above.
                dataset_name = YOUR_PATH
                print(f"Loading dataset: {dataset_name}")
                if "medium" in data_name:
                    rl_difficulty_dict["medium"].append(load_dataset(dataset_name, split="train"))
                elif "hard" in data_name:
                    rl_difficulty_dict["hard"].append(load_dataset(dataset_name, split="train"))
                else:
                    rl_difficulty_dict["easy"].append(load_dataset(dataset_name, split="train"))

            if blending_ratio > 0.0:
                if "normal" in data_type:
                    main_dataset = concatenate_datasets(rl_difficulty_dict["easy"])
                    blending_dataset = concatenate_datasets(rl_difficulty_dict["medium"])
                    # shuffle
                    main_dataset = main_dataset.shuffle(seed=42)
                    blending_dataset = blending_dataset.shuffle(seed=42)
                elif "medium_hard" in data_type:
                    # NOTE(review): unlike the "normal" case these are not shuffled,
                    # so select() below takes the leading rows of each
                    # concatenation — confirm this is intended.
                    main_dataset = concatenate_datasets(rl_difficulty_dict["medium"])
                    blending_dataset = concatenate_datasets(rl_difficulty_dict["hard"])
                else:
                    # Previously fell through with main_dataset unbound.
                    raise ValueError(
                        f"Blending is only supported for 'normal' or 'medium_hard' data types, got: {data_type}"
                    )
                # Keep the total size equal to the main tier; blending_ratio is
                # the fraction that comes from the main tier.
                main_num = int(main_dataset.num_rows * blending_ratio)
                blending_num = main_dataset.num_rows - main_num
                rl_main_dataset = main_dataset.select(range(main_num))
                rl_blending_dataset = blending_dataset.select(range(blending_num))
                rl_datasets.extend([rl_main_dataset, rl_blending_dataset])
            else:
                for difficulty in rl_difficulty_dict:
                    rl_datasets.extend(rl_difficulty_dict[difficulty])
            rl_datasets = concatenate_datasets(rl_datasets)

            rl_valid_dataset = load_dataset(YOUR_PATH)
            # Fixed row windows of the validation set. The easy/medium/hard
            # variable names suggest it is laid out in difficulty blocks — confirm.
            if valid_type == "A":
                rl_valid_dataset = rl_valid_dataset.select(range(50, 150))
            elif valid_type == "B":
                easy_valid_dataset = rl_valid_dataset.select(range(50, 100))
                medium_valid_dataset = rl_valid_dataset.select(range(200, 250))
                rl_valid_dataset = concatenate_datasets([easy_valid_dataset, medium_valid_dataset])
            elif valid_type == "C":
                easy_valid_dataset = rl_valid_dataset.select(range(50))
                medium_valid_dataset = rl_valid_dataset.select(range(150, 200))
                hard_valid_dataset = rl_valid_dataset.select(range(300, 350))
                rl_valid_dataset = concatenate_datasets([easy_valid_dataset, medium_valid_dataset, hard_valid_dataset])
            sft_datasets = None
        else:
            raise ValueError(f"Invalid exp type: {exp_type}")

    return sft_datasets, rl_datasets, rl_valid_dataset


def prepare_sudoku_data(data_type, do_pilot, exp_type, valid_type, blending_ratio, add_info):
    """Collect the Sudoku datasets, convert their rows and register them.

    - An SFT dataset, if present, gets ``task_type``/``data_source`` columns and
      is registered as split ``"sft"`` of ``"sudoku_train"``.
    - RL train/validation rows are wrapped into records with an empty
      placeholder prompt and the original row under ``extra_info``. They are
      registered as splits ``"rl_train"`` and ``"rl_valid"`` of ``"sudoku_train"``.
    - Any dataset ``collect_datasets`` returned as ``None`` is replaced by ``[]``.

    Args:
        data_type, do_pilot, exp_type, valid_type, blending_ratio: Forwarded
            unchanged to ``collect_datasets``.
        add_info: If not None, stored as ``extra_info["add_info"]`` on every
            RL row.

    Returns:
        ``(sft_dataset, rl_dataset, rl_valid_dataset)`` as returned by
        ``DatasetRegistry.register_dataset``, or ``[]`` for a missing SFT dataset.
    """
    sft_dataset, rl_dataset, rl_valid_dataset = collect_datasets(
        data_type, do_pilot, exp_type, valid_type, blending_ratio
    )

    def tag_sft_row(row):
        # Keep every column and tag the row as Sudoku SFT data.
        return dict(row, task_type="sudoku", data_source="sudoku_train")

    def to_rl_record(row):
        # Rows that are not plain dicts are copied into one; dict rows are
        # updated in place.
        extra_info = row if isinstance(row, dict) else dict(row)
        extra_info["task_type"] = "sudoku"
        if add_info is not None:
            extra_info["add_info"] = add_info
        # Empty user message: trajectories are collected from the environment,
        # so no real prompt is needed.
        placeholder_prompt = [{"role": "user", "content": ""}]
        return {
            "data_source": "sudoku_rl",
            "prompt": placeholder_prompt,
            "ability": "sudoku",
            "reward_model": {"style": "rule", "ground_truth": ""},
            "extra_info": extra_info,
        }

    if sft_dataset is None:
        sft_dataset = []
    else:
        sft_dataset = DatasetRegistry.register_dataset("sudoku_train", sft_dataset.map(tag_sft_row), "sft")

    rl_dataset = [] if rl_dataset is None else rl_dataset.map(to_rl_record)
    rl_valid_dataset = [] if rl_valid_dataset is None else rl_valid_dataset.map(to_rl_record)

    print(f"Current working directory: {os.getcwd()}")
    print(f"Registering datasets in: {os.path.abspath('rllm/data/datasets')}")

    # NOTE(review): for exp_type "sft" (and "benchmark" for rl_train) the RL
    # splits are [] here and are still registered. Confirm this does not
    # overwrite previously registered RL splits.
    rl_dataset = DatasetRegistry.register_dataset("sudoku_train", rl_dataset, "rl_train")
    rl_valid_dataset = DatasetRegistry.register_dataset("sudoku_train", rl_valid_dataset, "rl_valid")

    print("Datasets registered successfully!")
    print(f"Available datasets: {DatasetRegistry.get_dataset_names()}")
    print(f"Sudoku bench splits: {DatasetRegistry.get_dataset_splits('sudoku_train')}")

    return sft_dataset, rl_dataset, rl_valid_dataset


if __name__ == "__main__":
    import argparse

    # Each CLI flag maps one-to-one onto a prepare_sudoku_data parameter.
    cli = argparse.ArgumentParser()
    cli.add_argument("--data_type", type=str, default="A_normal")
    cli.add_argument("--do_pilot", action="store_true")
    cli.add_argument("--exp_type", type=str, default="rl")
    cli.add_argument("--valid_type", type=str, default="B")
    cli.add_argument("--blending_ratio", type=float, default=0.0)
    cli.add_argument("--add_info", type=str, default=None)
    args = cli.parse_args()

    sft_dataset, rl_dataset, rl_valid_dataset = prepare_sudoku_data(
        data_type=args.data_type,
        do_pilot=args.do_pilot,
        exp_type=args.exp_type,
        valid_type=args.valid_type,
        blending_ratio=args.blending_ratio,
        add_info=args.add_info,
    )

    # Print the length of the RL training dataset returned above.
    print(len(rl_dataset))
