from pathlib import Path
import os
from multiprocessing import Pool
import threading
from tqdm import tqdm
import yaml
import time
import concurrent.futures
import re
import logging
import json
from importlib import import_module

from omegaconf import OmegaConf
import hydra

from inference_rlhf.code.code_contest_utils import execution_server_client
from inference_rlhf.code.helpers.utils import set_seeds, rget_json_files_from_dir
from inference_rlhf.code.tasks.code_contests import get_test_cases

log = logging.getLogger(__name__)

# Process-wide cap on in-flight execution-server requests; solution_is_correct()
# makes its server calls while holding this semaphore.
MAX_CONCURRENT_REQUESTS = 512
semaphore = threading.Semaphore(MAX_CONCURRENT_REQUESTS)

# Retry policy for execution-server calls: at most NUM_RETRIES attempts,
# sleeping RETRY_BACKOFF ** attempt seconds between consecutive attempts.
NUM_RETRIES = 3
RETRY_BACKOFF = 3

def get_timeout(item):
    """Return the problem's time limit in seconds, or None if it has none.

    ``item["time_limit"]`` is either None or a mapping with ``"seconds"`` and
    ``"nanos"`` entries. A limit that adds up to zero is also reported as None.
    """
    limit = item["time_limit"]
    if limit is None:
        return None

    total_seconds = limit["seconds"] + limit["nanos"] / 1_000_000_000
    # A zero total means "no limit"; fold it into None like the missing case.
    return total_seconds or None

def is_valid_python(snippet):
    """Return True if *snippet* compiles as a Python module, else False.

    The snippet is only compiled, never executed. ``compile`` reports bad
    source with ``SyntaxError`` (including ``IndentationError``). On Python
    versions before 3.12, source containing NUL bytes raises ``ValueError``
    instead. Pathologically nested input can raise ``RecursionError``. All of
    these mean "not valid Python", so arbitrary model output never crashes
    the caller.
    """
    try:
        compile(snippet, "<string>", "exec")
    except (SyntaxError, ValueError, RecursionError):
        return False
    return True
           
# First triple-backtick fenced block (non-greedy, may span lines).
_CODE_FENCE_RE = re.compile(r"```(.*?)```", re.DOTALL)
# A markdown fence "info string" / language tag: python, py, python3, c++, ...
_FENCE_INFO_RE = re.compile(r"[\w+#.-]*")


def extract_first_code(output_string: str):
    """Return the code contained in a model response, or None if there is none.

    The first triple-backtick fenced block wins. If its opening line is a bare
    language tag (``python``, ``py``, ``python3``, or empty), that line is
    dropped. The remaining text is returned stripped. Only the tag line is
    removed, so code that merely starts with the word "python" (e.g.
    ``python_version = 3``) is left intact.

    Without any fence, the whole stripped response is returned if it compiles
    as Python; otherwise None.
    """
    trimmed = output_string.strip()

    code_match = _CODE_FENCE_RE.search(trimmed)
    if code_match:
        body = code_match.group(1)
        header, newline, rest = body.partition("\n")
        # Only treat the first line as a tag when a newline follows it, so a
        # single-line block like ```print(1)``` is kept whole.
        if newline and _FENCE_INFO_RE.fullmatch(header.strip()):
            body = rest
        return body.strip()

    if is_valid_python(trimmed):
        return trimmed

    return None

def solution_is_correct(
    code: str | None,
    test_cases: dict,
    timeout: float | None,
    client: execution_server_client.ExecutionServerClient,
    timeout_buffer: float = 10,
):
    """Run a candidate solution against a problem's test cases.

    Args:
        code: Candidate solution text; ``extract_first_code`` is applied to it.
            ``None`` means no code could be extracted.
        test_cases: Mapping with parallel ``"input"`` and ``"output"`` lists.
        timeout: Problem time limit in seconds, or ``None`` when the problem
            has no limit (as returned by ``get_timeout``).
        client: Execution server client that runs the code.
        timeout_buffer: Extra seconds added on top of *timeout*; used on its
            own when *timeout* is ``None``.

    Returns:
        The value reported by ``client.execute_code``, or ``False`` when no
        code could be extracted or every attempt to reach the server failed.
    """
    if code is None:
        return False

    program = extract_first_code(code)
    if program is None:
        return False

    assert len(test_cases["input"]) == len(test_cases["output"])
    input_expected_output_pairs = list(zip(test_cases["input"], test_cases["output"]))

    # get_timeout() yields None for problems without a limit. `None + 10`
    # used to raise TypeError inside the retry loop, so such problems were
    # always graded incorrect after sleeping through every retry.
    exec_timeout = timeout_buffer if timeout is None else timeout + timeout_buffer

    with semaphore:
        for attempt in range(NUM_RETRIES):
            try:
                return client.execute_code(
                    program,
                    input_expected_output_pairs,
                    timeout=exec_timeout,
                    # NOTE(review): 2_000_000_000_000 bytes is 2 TB, which does not
                    # obviously match "double max limit"; confirm whether 2 GB
                    # (2_000_000_000) was intended before changing it.
                    memory_limit_bytes=2_000_000_000_000,
                )
            except Exception as err:  # never trap KeyboardInterrupt/SystemExit
                if attempt == NUM_RETRIES - 1:
                    log.exception(
                        "Execution failed after %d attempts; grading as incorrect",
                        NUM_RETRIES,
                    )
                    return False
                delay = RETRY_BACKOFF**attempt
                log.warning(
                    "Execution attempt %d/%d failed (%r); retrying in %s s",
                    attempt + 1,
                    NUM_RETRIES,
                    err,
                    delay,
                )
                time.sleep(delay)

    # Only reachable when NUM_RETRIES <= 0 (no attempt was made).
    return False


def grade_problems(
    solutions_data: list[dict],
    test_cases: dict,
    timeouts: dict,
    output_file: str,
    client: execution_server_client.ExecutionServerClient,
):
    """Grade every sample of one generations file and write the results.

    Each sample is checked with ``solution_is_correct``, using the test cases
    and time limit looked up by its ``"prompt_idx"``. The result is stored
    under ``"is_correct"``. The ``"solution"`` field is removed, and the
    annotated list is written as JSON to *output_file*. The samples are
    mutated in place.

    Args:
        solutions_data: Samples loaded from one generations file; each has
            ``"solution"`` and ``"prompt_idx"`` keys.
        test_cases: Test cases keyed by prompt_idx.
        timeouts: Time limits (seconds or None) keyed by prompt_idx.
        output_file: Destination path for the graded JSON.
        client: Execution server client shared by all requests.
    """
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=MAX_CONCURRENT_REQUESTS // 2
    ) as executor:
        futures = [
            executor.submit(
                solution_is_correct,
                code=sample["solution"],
                test_cases=test_cases[sample["prompt_idx"]],
                timeout=timeouts[sample["prompt_idx"]],
                client=client,
            )
            for sample in solutions_data
        ]

        is_corrects = []
        for i, future in enumerate(futures):
            if i % 100 == 0:
                print(f"Progress being made... {i}")
            is_corrects.append(future.result())

    for sample, is_correct in zip(solutions_data, is_corrects):
        sample["is_correct"] = is_correct
        del sample["solution"]  # drop the bulky code from the checked output

    # main() treats an existing --CHECKED file as "already evaluated". Write
    # to a temporary file and atomically rename it, so a crash mid-write
    # cannot leave a truncated result that would be skipped on every rerun.
    tmp_file = f"{output_file}.tmp"
    with open(tmp_file, "w") as f:
        json.dump(solutions_data, f)
    os.replace(tmp_file, output_file)


def load_solutions_from_json_file(input_path):
    """Load a generations JSON file and extract the code from each response.

    Args:
        input_path: Path (string) of a JSON file holding a list of samples,
            each with a ``"response"`` key.

    Returns:
        ``(samples, output_path)``:
            - *samples* is the decoded list, with ``"solution"`` set to
              ``extract_first_code(sample["response"])`` (possibly None).
            - *output_path* is *input_path* with ``.json`` replaced by
              ``--CHECKED.json``.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default encoding.
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    for sample in data:
        sample["solution"] = extract_first_code(sample["response"])
    # Must match the "--CHECKED" naming main() uses to detect finished files.
    # NOTE(review): str.replace rewrites every ".json" in the path, not just the suffix.
    return data, input_path.replace(".json", "--CHECKED.json")


def _select_files_to_evaluate(generation_files, max_files):
    """Split *generation_files* into (files still to grade, already-graded files).

    A file is pending when it is not itself a ``--CHECKED`` output and its
    ``--CHECKED`` counterpart is not in the list. When *max_files* is not None,
    at most that many pending files are returned.
    """
    existing = set(generation_files)  # O(1) lookups instead of a list scan per file
    pending = [
        f
        for f in generation_files
        if "--CHECKED" not in f and f.replace(".json", "--CHECKED.json") not in existing
    ]
    already_checked = [f for f in generation_files if "--CHECKED" in f]
    if max_files is not None:
        pending = pending[:max_files]
    return pending, already_checked


def _collect_problem_metadata(solutions_data, problems):
    """Return ``(test_cases, timeouts)`` for every prompt_idx used by any sample.

    *solutions_data* holds one list of samples per generations file, and
    *problems* is indexed by ``prompt_idx``. Every sample is inspected, so
    empty files and files mixing several prompts are handled. Each problem is
    processed only once.
    """
    test_cases = {}
    timeouts = {}
    for samples in solutions_data:
        for sample in samples:
            prompt_idx = sample["prompt_idx"]
            if prompt_idx in test_cases:
                continue
            problem = problems[prompt_idx]
            test_cases[prompt_idx] = get_test_cases(problem)
            timeouts[prompt_idx] = get_timeout(problem)
    return test_cases, timeouts


@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg):
    """Grade pending generations files for the configured task and policy.

    Reads ``<io.load_root>/data/<task>/<policy>/generations``. For each
    selected file without a ``--CHECKED`` counterpart, it runs every sample
    against the execution server and writes a ``--CHECKED`` JSON next to it.
    """
    # print out the config nicely
    print(OmegaConf.to_yaml(cfg))

    set_seeds(cfg.seed)

    load_path = os.path.join(
        cfg.io.load_root, 'data', cfg.task.name, cfg.policy.name, 'generations'
    )
    log.info(f"Loading generations from root @ {load_path}")

    generation_files = rget_json_files_from_dir(load_path)
    # Historically only the first pending file was graded per run; keep that
    # default, but allow `evaluation.code_contests.max_files` to override it
    # (null = grade every pending file).
    max_files = OmegaConf.select(cfg, "evaluation.code_contests.max_files", default=1)
    to_eval_files, already_evaled = _select_files_to_evaluate(generation_files, max_files)

    print(
        f"Num to eval: {len(to_eval_files)}",
        f"Num already evaled: {len(already_evaled)}",
    )
    if not to_eval_files:
        return  # nothing to do; skip loading the dataset and the server client

    # A multiprocessing pool loads the JSON files and extracts code in parallel;
    # imap lets tqdm track actual progress while preserving input order.
    with Pool(cfg.evaluation.code_contests.num_workers) as process_pool:
        solutions_with_output_files = list(
            tqdm(
                process_pool.imap(load_solutions_from_json_file, to_eval_files),
                total=len(to_eval_files),
                desc="loading json files",
            )
        )
    solutions_data = [solution_data for solution_data, _ in solutions_with_output_files]
    output_files = [output_file for _, output_file in solutions_with_output_files]

    print("Done loading json files.")

    print(f'\nLoading data for {cfg.task.name} with {cfg.shots} shots')
    data_module = import_module(f"inference_rlhf.code.tasks.{cfg.task.name}",  package='inference_rlhf.code')
    dl = data_module.DataLoader(cfg)

    test_cases, timeouts = _collect_problem_metadata(solutions_data, dl.non_image_data)

    with execution_server_client.ExecutionServerClient() as client:
        # One thread per file; grade_problems fans out further per sample.
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=cfg.evaluation.code_contests.num_workers
        ) as executor:
            futures = [
                executor.submit(
                    grade_problems,
                    solutions_data=solution_data,
                    test_cases=test_cases,
                    timeouts=timeouts,
                    output_file=output_file,
                    client=client,
                )
                for solution_data, output_file in zip(solutions_data, output_files)
            ]

            for future in tqdm(futures, desc="Running tests on problem"):
                future.result()  # re-raise any grading failure


if __name__ == "__main__":
    main()
