from datasets import concatenate_datasets, load_dataset, Dataset
import json

from data.utils.deepcoder import parse_description as parse_description_deepcoder
from data.utils.livecodebench import LCB_TRAIN_CUTOFF, load_livecodebench
from data.utils.taco import load_taco
from data.utils.primeintellect import load_primeintellect
from data.utils.codeforces import load_codeforces
from data.utils.code_contests import load_code_contests
from data.utils.leetcode import load_leetcode

# Minimum number of test cases a problem needs in order to be kept.
MIN_TESTS_DEFAULT = 5

# Per-dataset threshold, keyed by dataset name. Every dataset currently uses
# the default; to tune a single dataset, override its entry after this table
# is built.
MIN_TESTS = {
    name: MIN_TESTS_DEFAULT
    for name in (
        "deepcoder",
        "livecodebench",
        "taco",
        "primeintellect",
        "codeforces",
        "code_contests",
        "kodcode",
        "leetcode",
    )
}


def load_deepcoder() -> Dataset:
    """Load the DeepCoder preview subsets and merge them into one dataset.

    The "taco", "primeintellect" and "lcbv5" configurations of
    ``agentica-org/DeepCoder-Preview-Dataset`` (train split) are loaded in
    that order. Every example gets four extra fields:

    * ``description``: result of ``parse_description_deepcoder`` applied to
      the example's ``problem`` text for that subset,
    * ``answer``: always the empty string,
    * ``kind``: always ``"code"``,
    * ``dataset``: the subset name, with "lcbv5" reported as "livecodebench".

    The ``tests`` field is rewritten as a JSON-encoded list of dicts with the
    keys ``input``, ``output``, ``testtype`` and ``fn_name``.

    Returns:
        The concatenation of the three preprocessed subsets.

    Raises:
        ValueError: If an example's decoded ``tests`` is neither a list nor a
            dict.
    """
    # NOTE(review): presumably guards against the LiveCodeBench train cutoff
    # drifting away from the one this data was prepared with -- confirm.
    assert LCB_TRAIN_CUTOFF == "2024-08-01T00:00:00"

    def _resolve_testtype(case):
        """Translate a raw test case's type marker to "stdin"/"functional"."""
        if "type" in case:
            if case["type"] == "stdin_stdout":
                return "stdin"
            if case["type"] == "function_call":
                return "functional"
            # Unrecognised "type" values fall back to stdin.
            return "stdin"
        if "testtype" in case:
            return case["testtype"]
        return "stdin"

    def _normalise_tests(serialized):
        """Decode raw tests and re-encode them in the standard list layout."""
        parsed = json.loads(serialized)
        if isinstance(parsed, list):
            cases = [
                {
                    "input": case["input"],
                    "output": case["output"],
                    "testtype": _resolve_testtype(case),
                    "fn_name": case.get("fn_name", ""),
                }
                for case in parsed
            ]
        elif isinstance(parsed, dict):
            # Index-based lookup is intentional: a shorter "outputs" list
            # raises instead of being silently truncated.
            cases = [
                {
                    "input": parsed["inputs"][idx],
                    "output": parsed["outputs"][idx],
                    "testtype": "functional" if "fn_name" in parsed else "stdin",
                    "fn_name": parsed.get("fn_name", ""),
                }
                for idx in range(len(parsed["inputs"]))
            ]
        else:
            raise ValueError(f"Unexpected type for tests: {type(parsed)}")
        return json.dumps(cases, ensure_ascii=False)

    def _load_subset(cfg):
        """Load one configuration and bring its rows into the shared schema."""
        label = "livecodebench" if cfg == "lcbv5" else cfg

        def cast_deepcoder(ex):
            out = dict(ex)
            out["description"] = parse_description_deepcoder(out["problem"], subset=cfg)
            out["answer"] = ""
            out["kind"] = "code"
            out["dataset"] = label
            out["tests"] = _normalise_tests(out["tests"])
            return out

        raw = load_dataset(
            "agentica-org/DeepCoder-Preview-Dataset",
            cfg,
            split="train",
        )
        return raw.map(cast_deepcoder, desc=f"Preprocess DeepCoder/{cfg}")

    return concatenate_datasets(
        [_load_subset(cfg) for cfg in ("taco", "primeintellect", "lcbv5")]
    )


def count_test_cases(tests) -> int:
    """Return the number of test cases held by a ``tests`` field.

    ``tests`` may be an already-decoded container or a JSON-encoded string
    (the format ``load_deepcoder`` in this module writes). Calling ``len`` on
    the encoded string counts characters rather than test cases, so strings
    are decoded first.

    Args:
        tests: A sized container of test cases, or a JSON string encoding
            either a list of test cases or a dict with an ``"inputs"`` list.

    Returns:
        The number of test cases. A string that is not valid JSON, or that
        decodes to a scalar, cannot be counted; for those the length of the
        raw string is returned, which is what the filter used before.
    """
    if not isinstance(tests, str):
        return len(tests)
    try:
        decoded = json.loads(tests)
    except json.JSONDecodeError:
        return len(tests)
    if isinstance(decoded, dict) and "inputs" in decoded:
        # Column layout: {"inputs": [...], "outputs": [...], ...}.
        return len(decoded["inputs"])
    if isinstance(decoded, (list, dict)):
        return len(decoded)
    return len(tests)


def load_code_dataset(dataset_name: str) -> Dataset:
    """Load one code dataset by name and drop rows that are not usable.

    A row is dropped when any of the following holds:

    * it has fewer than ``MIN_TESTS[dataset_name]`` test cases,
    * its ``description`` is ``None`` or blank,
    * its ``description`` has exactly the same length as its ``problem``.

    Args:
        dataset_name: One of "livecodebench", "taco", "primeintellect",
            "codeforces", "code_contests" or "leetcode".

    Returns:
        The filtered dataset.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # NOTE: the "deepcoder" (load_deepcoder), "kodcode" and "apps" loaders are
    # currently disabled and therefore rejected as invalid names.
    if dataset_name == "livecodebench":
        ds = load_livecodebench(dataset_split="train")
    elif dataset_name == "taco":  # contains some of codeforces, leetcode
        ds = load_taco()
    elif dataset_name == "primeintellect":  # contains some of apps, code_contests
        ds = load_primeintellect()
    elif dataset_name == "codeforces":
        ds = load_codeforces(dataset_split="train")
    elif dataset_name == "code_contests":
        ds = load_code_contests()
    elif dataset_name == "leetcode":
        ds = load_leetcode()
    else:
        raise ValueError(f"Invalid dataset name: {dataset_name}")

    # Loop-invariant: look the threshold up once rather than once per row.
    min_tests = MIN_TESTS[dataset_name]

    def _filter(ex):
        # Count test cases, not characters of a JSON-encoded tests string.
        if count_test_cases(ex["tests"]) < min_tests:
            return False
        description = ex["description"]
        if description is None or description.strip() == "":
            return False
        # NOTE(review): equal lengths presumably mean description parsing left
        # the problem text untouched -- confirm against the loaders.
        if len(description) == len(ex["problem"]):
            return False
        return True

    filtered_ds = ds.filter(_filter)
    print(f"Loaded {len(filtered_ds)} rows from {dataset_name} out of {len(ds)}")
    return filtered_ds
