# This file defines several useful utility functions, such as data loading, filtering, and preprocessing.

import os
import pickle
import zlib
import base64
import glob
import json
import logging
from tqdm import tqdm
from datasets import load_dataset, Dataset
from omegaconf import OmegaConf
from typing import List, Dict, Tuple, Any
import concurrent.futures

LOGGER = logging.getLogger(__name__)


def load_local_rationale(
    ROOTDIR: str, start_index: int, end_index: int, do_decode=True, verbose=False
) -> Dataset:
    """Build a Dataset from locally stored ``test_cases_<i>.jsonl`` shards.

    The original rationale test-case dataset is too large to load in one go,
    so a subset of its jsonl shards is downloaded first and then loaded from
    local disk with this function.

    Args:
        ROOTDIR: Directory that contains the ``test_cases_<i>.jsonl`` files.
        start_index: Index of the first shard to load (inclusive).
        end_index: Index at which to stop (exclusive).
        do_decode: Forwarded to ``read_jsonl``; when true, each record is
            passed through ``decode_sample``.
        verbose: Show a tqdm progress bar over the shards when true.

    Returns:
        A Dataset holding the records of all shards, in shard order.
    """
    shard_paths = [
        os.path.join(ROOTDIR, "test_cases_{}.jsonl".format(shard_id))
        for shard_id in range(start_index, end_index)
    ]

    records = []
    progress = tqdm(shard_paths, disable=not verbose)
    for shard_path in progress:
        # read_jsonl yields an empty list (plus a warning) for a missing shard.
        records += read_jsonl(shard_path, do_decode=do_decode)

    return Dataset.from_list(records)


def encode_testcases(testcases: List[Dict[str, str]]) -> str:
    """Encode test cases into a compact text blob.

    According to LiveCodeBench, private test cases should be encoded.
    The pipeline is: JSON-serialize -> pickle -> zlib-compress -> base64.

    Args:
        testcases: The test cases to encode; must be JSON-serializable.

    Returns:
        The base64 text (decoded as UTF-8) of the compressed, pickled JSON
        string.
    """
    payload = pickle.dumps(json.dumps(testcases))
    return base64.b64encode(zlib.compress(payload)).decode("utf-8")


def decode_testcases(encoded_testcases: str) -> List[Dict[str, str]]:
    """Reverse ``encode_testcases``: base64 -> zlib -> pickle -> JSON.

    Args:
        encoded_testcases: Base64 text of a zlib-compressed pickle that wraps
            a JSON string.

    Returns:
        The object obtained by parsing the embedded JSON string.
    """
    compressed = base64.b64decode(encoded_testcases.encode("utf-8"))
    pickled = zlib.decompress(compressed)
    # SECURITY NOTE: pickle.loads can execute arbitrary code while
    # unpickling; only feed this function data from a trusted source.
    json_text = pickle.loads(pickled)
    return json.loads(json_text)


def decode_sample(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Decode the ``test_cases`` field of one rationale_code_test_case sample.

    The sample is modified in place and the same dict is returned. A
    ``test_cases`` value of ``None`` is left untouched.

    Args:
        sample: A record that has a ``test_cases`` key.

    Returns:
        The same ``sample`` dict, with ``test_cases`` decoded when present.
    """
    encoded = sample["test_cases"]
    if encoded is None:
        return sample
    sample["test_cases"] = decode_testcases(encoded)
    return sample


def read_jsonl(path: str, do_decode=True) -> List[Dict]:
    """Read a JSON Lines file into a list of records.

    Args:
        path: Path of the ``.jsonl`` file to read.
        do_decode: When true, every parsed record is passed through
            ``decode_sample``; otherwise records are returned as parsed.

    Returns:
        The parsed records in file order. If ``path`` does not exist, a
        warning is logged and an empty list is returned.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if not os.path.exists(path):
        LOGGER.warning(f"File {path} does not exist.")
        return []

    decode_fn = decode_sample if do_decode else lambda x: x
    # Read as UTF-8 explicitly (matching read_jsonl_parallel) so the result
    # does not depend on the platform's default locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        # Blank / whitespace-only lines carry no record; skip them instead of
        # letting json.loads fail on an empty document.
        return [decode_fn(json.loads(line)) for line in f if line.strip()]


def read_jsonl_parallel(path: str, do_decode=True, workers: int = 4) -> List[Dict]:
    """Read a JSON Lines file, parsing its lines on a thread pool.

    Args:
        path: Path of the ``.jsonl`` file to read.
        do_decode: When true, every parsed record is passed through
            ``decode_sample``; otherwise records are returned as parsed.
        workers: ``max_workers`` for the ``ThreadPoolExecutor``.

    Returns:
        The parsed records in file order (``Executor.map`` yields results in
        input order). If ``path`` does not exist, a warning is logged and an
        empty list is returned.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if not os.path.exists(path):
        LOGGER.warning(f"File {path} does not exist.")
        return []

    decode_fn = decode_sample if do_decode else lambda x: x

    def parse_line(line: str) -> Dict:
        return decode_fn(json.loads(line))

    with open(path, "r", encoding="utf-8") as f:
        # Blank / whitespace-only lines carry no record; drop them up front
        # so json.loads is never handed an empty document.
        lines = [line for line in f if line.strip()]

    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
        return list(executor.map(parse_line, lines))
