
import argparse
from json import load
import logging
import os

from src.postprocess import PostProcessor
from src.execution import evaluate_with_test_code, evaluate_with_test_cases
from src.io_utils import Tools
from src.agreement import DataManager, DualAgreement
from src.evaluation import pass_at_K, get_result_of_sorted_solutions

from src.evaluation import get_SolScorePassed_pairs

import pandas as pd


import tempfile
import os
from cruise.utilities.hdfs_io import hcopy
def load_hdfs_path(ckpt_path):
    """Return a local path for ``ckpt_path``, copying it from HDFS when needed.

    A path that does not start with ``"hdfs"`` is returned unchanged.
    Otherwise the path is copied with ``hcopy`` to
    ``<system temp dir>/<basename of ckpt_path>`` and that destination is
    returned.
    """
    # Anything that is not an HDFS path is used as-is.
    if not ckpt_path.startswith("hdfs"):
        return ckpt_path

    destination = os.path.join(tempfile.gettempdir(), os.path.basename(ckpt_path))
    hcopy(ckpt_path, destination)
    return destination


# Root logging configuration, applied once when this module is imported.
_LOG_FORMAT = "SystemLog: [%(asctime)s][%(name)s][%(levelname)s] - %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(
    level=logging.INFO,
    datefmt=_LOG_DATE_FORMAT,
    format=_LOG_FORMAT,
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)




def main():
    """Rank solutions and test cases for every input parquet file.

    Inputs are selected from the command line: either ``--input_parquet_dir``
    (a local directory whose ``*parquet`` files are processed) or, when that
    is missing/empty, ``--input_hdfs_list_txt`` (a text file with one HDFS
    parquet path per line; each is fetched through ``load_hdfs_path``).

    For each input file the solutions and test cases are loaded via
    ``PostProcessor``, run through ``evaluate_with_test_cases``, scored with
    ``DualAgreement`` and written to
    ``<output_parquet_dir>/<base name>.save_ranked.parquet``. Intermediate
    results are pickled into ``--cache_dir``. Files whose output already
    exists are skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_parquet_dir", type=str, help="directory containing the input .parquet files")
    parser.add_argument("--input_hdfs_list_txt", type=str, help="txt file, each line is a hdfs parquet path")
    parser.add_argument("--output_parquet_dir", type=str, help="directory to write the ranked .parquet files to")
    parser.add_argument("--cache_dir", type=str, help="the directory to store the cache files")
    parser.add_argument("--timeout", type=float, default=0.1, help="how many seconds to wait during execution for each test case")
    parser.add_argument("--test_case_limit", type=int, default=5, help="first n test cases per sample")

    args = parser.parse_args()

    print(args)

    # argparse leaves omitted options as None; normalise before calling
    # .strip() so that passing only one of the two input options works.
    input_parquet_dir = (args.input_parquet_dir or "").strip()
    input_hdfs_list_txt = (args.input_hdfs_list_txt or "").strip()

    # The output directory is used through os.path / to_parquet below, which
    # do not create missing parents; create it up front so a long run does
    # not fail at the final write.
    if args.output_parquet_dir:
        os.makedirs(args.output_parquet_dir, exist_ok=True)

    def base_name_of(path):
        """Return the last ``/`` component of ``path`` without its final extension.

        Yields ``""`` when the file name contains no dot.
        """
        return path.rpartition("/")[2].rpartition(".")[0]

    def save_parquet(input_parquet_path, save_parquet_path, ranked_result, ranked_test, correct_test):
        """Append ranking columns to the input parquet and write it out.

        Each of ``ranked_result``, ``ranked_test`` and ``correct_test`` is
        indexed by ``task_id`` and yields a sequence of entries ``e`` where
        ``e[0][0]`` is stored as the item and ``e[1]`` as its score.
        """
        import pandas as pd

        prompt = pd.read_parquet(input_parquet_path)
        task_ids = prompt['task_id']

        def add_ranking_columns(prefix, ranking):
            # <prefix>: ranked items, <prefix>_score: their scores,
            # <prefix>_score_is_same: whether first and last score are equal
            # (False when the task has no entries).
            prompt[prefix] = task_ids.apply(
                lambda task_id: [entry[0][0] for entry in ranking[task_id]]
            )
            prompt[prefix + '_score'] = task_ids.apply(
                lambda task_id: [entry[1] for entry in ranking[task_id]]
            )
            prompt[prefix + '_score_is_same'] = task_ids.apply(
                lambda task_id: ranking[task_id][0][1] == ranking[task_id][-1][1]
                if len(ranking[task_id]) else False
            )

        add_ranking_columns('ranked_code', ranked_result)
        add_ranking_columns('ranked_test', ranked_test)
        add_ranking_columns('correct_test', correct_test)

        print("[SAVE] Save to " + str(save_parquet_path))

        prompt.to_parquet(save_parquet_path, engine='pyarrow')

    def process_each_file(each_input_parquet_path, base_path_for_save):
        """Run the full ranking pipeline for one local parquet file.

        ``base_path_for_save`` is the prefix used for the output parquet and
        for the pickled cache files.
        """
        save_path = os.path.join(args.output_parquet_dir, base_path_for_save + ".save_ranked.parquet")
        # Existing output means this file was already processed.
        if os.path.exists(save_path):
            print("[SKIP] Skip to " + str(save_path))
            return

        handled_solutions, task_count, handled_test_cases = PostProcessor.parquet_map_task_id_for_find_correct_test(each_input_parquet_path)

        dual_exec_result = evaluate_with_test_cases(handled_solutions, handled_test_cases, timeout=args.timeout, limit=args.test_case_limit)

        Tools.dump_pickle(os.path.join(args.cache_dir, base_path_for_save + 'dual_exec_result.pkl'), dual_exec_result)

        data_manager = DataManager(dual_exec_result, handled_solutions, handled_test_cases, args.test_case_limit)
        set_consistency = DualAgreement(data_manager)

        ranked_result, ranked_test, correct_test = set_consistency.get_correct_test_with_highest_code()
        logger.info('pass rates of ranked solutions with iter page rank')

        Tools.dump_pickle(os.path.join(args.cache_dir, base_path_for_save + "page_rank_scores.pkl"), ranked_result)
        Tools.dump_pickle(os.path.join(args.cache_dir, base_path_for_save + "page_rank_test_scores.pkl"), ranked_test)

        save_parquet(each_input_parquet_path, save_path, ranked_result, ranked_test, correct_test)

    if input_parquet_dir:
        # Sorted so that the processing order is deterministic across runs.
        for file_name in sorted(os.listdir(args.input_parquet_dir)):
            if not file_name.endswith("parquet"):
                print("[Skip] Skip non-parquet file: " + str(file_name))
                continue
            process_each_file(
                os.path.join(args.input_parquet_dir, file_name),
                base_name_of(file_name),
            )
    elif input_hdfs_list_txt:
        with open(args.input_hdfs_list_txt, 'r') as f:
            hdfs_paths = [line.strip() for line in f]
        for hdfs_path in hdfs_paths:
            # Blank lines in the list file are not paths; ignore them.
            if not hdfs_path:
                continue
            process_each_file(load_hdfs_path(hdfs_path), base_name_of(hdfs_path))
    else:
        logger.warning("No input given: pass --input_parquet_dir or --input_hdfs_list_txt")



# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

