
import argparse
from json import load
import logging
import os

from src.postprocess import PostProcessor
from src.execution import evaluate_with_test_code, evaluate_with_test_cases
from src.io_utils import Tools
from src.agreement import DataManager, DualAgreement
from src.evaluation import pass_at_K, get_result_of_sorted_solutions

from src.evaluation import get_SolScorePassed_pairs


import tempfile
import os
from cruise.utilities.hdfs_io import hcopy
def load_hdfs_path(ckpt_path):
    """Return a local filesystem path for ``ckpt_path``, downloading it if needed.

    Paths that start with ``"hdfs"`` are copied with ``hcopy`` into a fresh,
    uniquely named temporary directory, and the path of the local copy is
    returned. Any other path, and ``None``, is returned unchanged.

    Args:
        ckpt_path: A local path, an HDFS path (any string starting with
            ``"hdfs"``), or ``None`` (e.g. an unset optional CLI argument).

    Returns:
        The local path to use in place of ``ckpt_path`` (or ``None``).
    """
    if ckpt_path is None or not ckpt_path.startswith("hdfs"):
        return ckpt_path
    # Use a fresh directory per call. Several inputs can share a basename
    # (e.g. two different ".../source.jsonl" files). Copying them all straight
    # into the shared temp dir would let later downloads clobber earlier ones.
    local_dir = tempfile.mkdtemp()
    # Strip a trailing "/" so basename() does not return "".
    local_path = os.path.join(local_dir, os.path.basename(ckpt_path.rstrip("/")))
    hcopy(ckpt_path, local_path)
    return local_path


# Configure the root logger once at import time. Records at INFO and above are
# emitted with a "SystemLog:" prefix, a timestamp, the logger name and the level.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="SystemLog: [%(asctime)s][%(name)s][%(levelname)s] - %(message)s",
)

# Module-level logger. Its records propagate to the root handler configured above.
logger = logging.getLogger(name=__name__)


def main():
    """Execute solutions, rank them with dual agreement, and log pass rates.

    Parses CLI arguments and copies any ``hdfs`` inputs to local temp files.
    It then executes the solutions against the ground-truth test code and
    against the generated test cases, and caches both execution results as
    pickles under ``--cache_dir`` (reusable by ``debug_from_pkl``). Finally,
    it logs the pass rates of the ``DualAgreement``-ranked solutions,
    followed by the random-order baseline from ``pass_at_K``.
    """
    parser = argparse.ArgumentParser()
    # The input paths and the cache dir have no usable default. Without them,
    # the run would fail only later (possibly after the slow execution step),
    # so reject a missing value at parse time instead.
    parser.add_argument("--source_path_for_solution", type=str, required=True, help="model input file in .jsonl format")
    parser.add_argument("--predict_path_for_solution", type=str, required=True, help="model output file in .jsonl format")
    parser.add_argument("--source_path_for_test", type=str, required=True, help="model input file in .jsonl format")
    parser.add_argument("--predict_path_for_test", type=str, required=True, help="model output file in .jsonl format")
    parser.add_argument("--cache_dir", type=str, required=True, help="the directory to store the cache files")
    parser.add_argument("--timeout", type=float, default=0.1, help="how many seconds to wait during execution for each test case")
    parser.add_argument("--test_case_limit", type=int, default=5, help="first n test cases per sample")

    args = parser.parse_args()

    args.source_path_for_solution = load_hdfs_path(args.source_path_for_solution)
    args.predict_path_for_solution = load_hdfs_path(args.predict_path_for_solution)
    args.source_path_for_test = load_hdfs_path(args.source_path_for_test)
    args.predict_path_for_test = load_hdfs_path(args.predict_path_for_test)

    # Make sure a local cache dir exists before the slow execution step, so the
    # pickle dumps below do not hit a missing directory after all that work.
    if not args.cache_dir.startswith("hdfs"):
        os.makedirs(args.cache_dir, exist_ok=True)

    # The task count returned alongside the solutions is not needed here.
    handled_solutions, _ = PostProcessor.map_task_id_for_solution(args.predict_path_for_solution, args.source_path_for_solution)
    handled_test_cases = PostProcessor.map_task_id_for_test_case(args.predict_path_for_test, args.source_path_for_test)

    ground_truth_exec_result = evaluate_with_test_code(handled_solutions, timeout=args.timeout)
    dual_exec_result = evaluate_with_test_cases(handled_solutions, handled_test_cases, timeout=args.timeout, limit=args.test_case_limit)

    # Cache both execution results so later runs can re-rank without re-executing.
    Tools.dump_pickle(os.path.join(args.cache_dir, 'ground_truth_exec_result.pkl'), ground_truth_exec_result)
    Tools.dump_pickle(os.path.join(args.cache_dir, 'dual_exec_result.pkl'), dual_exec_result)

    data_manager = DataManager(dual_exec_result, handled_solutions, handled_test_cases, args.test_case_limit)
    set_consistency = DualAgreement(data_manager)
    ranked_result = set_consistency.get_sorted_solutions_without_iter()
    logger.info('pass rates of ranked solutions')
    get_result_of_sorted_solutions(ground_truth_exec_result, ranked_result)
    logger.info('pass rates of random solutions')
    pass_at_K(ground_truth_exec_result)


def debug_from_pkl():
    """Re-rank solutions from cached execution results and log pass rates.

    Unlike ``main``, nothing is executed. The ground-truth and dual execution
    results are loaded from the pickles that ``main`` / ``param_search_T_beta``
    write under ``--cache_dir``. Pass rates are logged for three
    ``DualAgreement`` ranking variants (the default ranking, the ranking
    without sqrt, and iterative PageRank with ``T=0, beta=0.85``), then for
    the random-order baseline.

    Accepts the same CLI arguments as ``main``. ``--timeout`` is accepted but
    unused here.
    """
    parser = argparse.ArgumentParser()
    # Same interface as main(). Paths and cache dir are mandatory because
    # nothing below works without them.
    parser.add_argument("--source_path_for_solution", type=str, required=True, help="model input file in .jsonl format")
    parser.add_argument("--predict_path_for_solution", type=str, required=True, help="model output file in .jsonl format")
    parser.add_argument("--source_path_for_test", type=str, required=True, help="model input file in .jsonl format")
    parser.add_argument("--predict_path_for_test", type=str, required=True, help="model output file in .jsonl format")
    parser.add_argument("--cache_dir", type=str, required=True, help="the directory to store the cache files")
    parser.add_argument("--timeout", type=float, default=0.1, help="how many seconds to wait during execution for each test case")
    parser.add_argument("--test_case_limit", type=int, default=5, help="first n test cases per sample")

    args = parser.parse_args()

    logger.info('%s', args)

    # Resolve hdfs inputs to local copies, exactly as main() does, so the
    # same command line works for both entry points.
    args.source_path_for_solution = load_hdfs_path(args.source_path_for_solution)
    args.predict_path_for_solution = load_hdfs_path(args.predict_path_for_solution)
    args.source_path_for_test = load_hdfs_path(args.source_path_for_test)
    args.predict_path_for_test = load_hdfs_path(args.predict_path_for_test)

    # The task count returned alongside the solutions is not needed here.
    handled_solutions, _ = PostProcessor.map_task_id_for_solution(args.predict_path_for_solution, args.source_path_for_solution)
    handled_test_cases = PostProcessor.map_task_id_for_test_case(args.predict_path_for_test, args.source_path_for_test)

    # NOTE(review): these pickles are expected to be ones this script wrote.
    # Never point --cache_dir at untrusted files, since unpickling can execute code.
    ground_truth_exec_result = Tools.load_pickle(os.path.join(args.cache_dir, 'ground_truth_exec_result.pkl'))
    dual_exec_result = Tools.load_pickle(os.path.join(args.cache_dir, 'dual_exec_result.pkl'))

    data_manager = DataManager(dual_exec_result, handled_solutions, handled_test_cases, args.test_case_limit)
    set_consistency = DualAgreement(data_manager)
    ranked_result = set_consistency.get_sorted_solutions_without_iter()
    logger.info('pass rates of ranked solutions')
    get_result_of_sorted_solutions(ground_truth_exec_result, ranked_result)

    ranked_result = set_consistency.get_sorted_solutions_without_iter_remove_sqrt()
    logger.info('pass rates of ranked solutions without sqrt')
    get_result_of_sorted_solutions(ground_truth_exec_result, ranked_result)

    ranked_result = set_consistency.get_sorted_solutions_with_iter_page_rank(T=0, beta=0.85)
    logger.info('pass rates of ranked solutions with iter page rank')
    get_result_of_sorted_solutions(ground_truth_exec_result, ranked_result)

    logger.info('pass rates of random solutions')
    pass_at_K(ground_truth_exec_result)



def _page_rank_grid_search(set_consistency, ground_truth_exec_result, num_T, num_beta, beta_step):
    """Evaluate iterative-PageRank ranking at every (T, beta) grid point.

    T ranges over ``range(num_T)`` and beta over
    ``[i * beta_step for i in range(num_beta)]``.

    Returns:
        One result per grid point: the value returned by
        ``get_result_of_sorted_solutions`` with ``'T'`` and ``'beta'`` keys
        added (assumes that value is a mutable mapping, as the key
        assignment below requires).
    """
    all_rst = []
    for each_T in range(num_T):
        for each_beta_n in range(num_beta):
            # round() strips float noise (3 * 0.05 == 0.15000000000000002), so
            # the beta values written to the jsonl are clean and groupable.
            each_beta = round(each_beta_n * beta_step, 10)
            ranked_result = set_consistency.get_sorted_solutions_with_iter_page_rank(T=each_T, beta=each_beta)
            this_rst = get_result_of_sorted_solutions(ground_truth_exec_result, ranked_result)

            this_rst['T'] = each_T
            this_rst['beta'] = each_beta

            all_rst.append(this_rst)
    return all_rst


def param_search_T_beta(num_T=50, num_beta=20, beta_step=0.05):
    """Execute solutions, then grid-search the PageRank ranking's ``T`` and ``beta``.

    Runs the same execution pipeline as ``main`` and caches both execution
    results under ``--cache_dir``. It then logs pass rates for the three
    baseline ``DualAgreement`` rankings and evaluates iterative PageRank over a
    (T, beta) grid. The per-grid-point results are written to
    ``<cache_dir>/page_rank_T_beta_passrate.jsonl``. Finally, the
    random-order baseline is logged.

    Args:
        num_T: Number of T values tried (``T = 0 .. num_T - 1``).
        num_beta: Number of beta values tried.
        beta_step: Spacing between beta values (``beta = i * beta_step``).
            The defaults reproduce the original 50 x 20 grid with beta in
            ``[0, 0.95]``.
    """
    parser = argparse.ArgumentParser()
    # Same interface as main(). Paths and cache dir are mandatory because
    # nothing below works without them.
    parser.add_argument("--source_path_for_solution", type=str, required=True, help="model input file in .jsonl format")
    parser.add_argument("--predict_path_for_solution", type=str, required=True, help="model output file in .jsonl format")
    parser.add_argument("--source_path_for_test", type=str, required=True, help="model input file in .jsonl format")
    parser.add_argument("--predict_path_for_test", type=str, required=True, help="model output file in .jsonl format")
    parser.add_argument("--cache_dir", type=str, required=True, help="the directory to store the cache files")
    parser.add_argument("--timeout", type=float, default=0.1, help="how many seconds to wait during execution for each test case")
    parser.add_argument("--test_case_limit", type=int, default=5, help="first n test cases per sample")

    args = parser.parse_args()

    logger.info('%s', args)

    # Resolve hdfs inputs to local copies, exactly as main() does.
    args.source_path_for_solution = load_hdfs_path(args.source_path_for_solution)
    args.predict_path_for_solution = load_hdfs_path(args.predict_path_for_solution)
    args.source_path_for_test = load_hdfs_path(args.source_path_for_test)
    args.predict_path_for_test = load_hdfs_path(args.predict_path_for_test)

    # Make sure a local cache dir exists before the slow execution step, so the
    # dumps below do not hit a missing directory after all that work.
    if not args.cache_dir.startswith("hdfs"):
        os.makedirs(args.cache_dir, exist_ok=True)

    # The task count returned alongside the solutions is not needed here.
    handled_solutions, _ = PostProcessor.map_task_id_for_solution(args.predict_path_for_solution, args.source_path_for_solution)
    handled_test_cases = PostProcessor.map_task_id_for_test_case(args.predict_path_for_test, args.source_path_for_test)

    ground_truth_exec_result = evaluate_with_test_code(handled_solutions, timeout=args.timeout)
    dual_exec_result = evaluate_with_test_cases(handled_solutions, handled_test_cases, timeout=args.timeout, limit=args.test_case_limit)

    # Cache for debug_from_pkl(). The in-memory results are used directly
    # below; there is no need to reload what was just written.
    Tools.dump_pickle(os.path.join(args.cache_dir, 'ground_truth_exec_result.pkl'), ground_truth_exec_result)
    Tools.dump_pickle(os.path.join(args.cache_dir, 'dual_exec_result.pkl'), dual_exec_result)

    data_manager = DataManager(dual_exec_result, handled_solutions, handled_test_cases, args.test_case_limit)
    set_consistency = DualAgreement(data_manager)
    ranked_result = set_consistency.get_sorted_solutions_without_iter()
    logger.info('pass rates of ranked solutions')
    get_result_of_sorted_solutions(ground_truth_exec_result, ranked_result)

    ranked_result = set_consistency.get_sorted_solutions_without_iter_remove_sqrt()
    logger.info('pass rates of ranked solutions without sqrt')
    get_result_of_sorted_solutions(ground_truth_exec_result, ranked_result)

    ranked_result = set_consistency.get_sorted_solutions_with_iter_page_rank(T=0, beta=0.85)
    logger.info('pass rates of ranked solutions with iter page rank')
    get_result_of_sorted_solutions(ground_truth_exec_result, ranked_result)

    all_rst = _page_rank_grid_search(set_consistency, ground_truth_exec_result, num_T, num_beta, beta_step)

    Tools.write_jsonl(os.path.join(args.cache_dir, 'page_rank_T_beta_passrate.jsonl'), all_rst)

    logger.info('pass rates of random solutions')
    pass_at_K(ground_truth_exec_result)



if __name__ == "__main__":
    # Script entry point: runs the PageRank (T, beta) grid search. main() and
    # debug_from_pkl() are alternative pipelines with the same CLI arguments.
    param_search_T_beta()

