import os
import pandas as pd
import argparse
import json
import shutil
from torchvision import datasets
import importlib
import sys
sys.path.append("..")

from MLAgentBench.low_level_actions import execute_script

# Per-task substring that the output of train.py must contain for an executed
# run to count as successful (main checks `match_string in observation`).
MATCH_TASK_STRING_DICT = {
    "cifar10": "Train Accuracy: ",
    "dog-breed-identification": "Train Loss: ",
    "feedback": "Validation MCRMSE: ",
    "house-price": "Train MAE: ",
    "detecting-insults-in-social-commentary": "Test Accuracy: ",
    "chaii-hindi-and-tamil-question-answering": "Validation Jaccard Score: ",
    # NOTE(review): unlike the other entries this pattern has no trailing colon;
    # matching is by substring so it still works -- confirm this is intended.
    "denoising-dirty-documents": "Validation RMSE",
    "google-quest-challenge": "Validation Mean Spearman Correlation Coefficient:",
    "leaf-classification": "Validation Log Loss: ",
    "learning-agency-lab-automated-essay-scoring-2": "Validation Quadratic Weighted Kappa:",
    "spooky-author-identification": "Validation Log Loss:",
    "statoil-iceberg-classifier-challenge": "Final Validation Log Loss:",
    "us-patent-phrase-to-phrase-matching": "Validation Pearson Correlation:",
    "jigsaw-toxic-comment-classification-challenge": "Mean column-wise ROC AUC:",
    "lmsys-chatbot-arena": "Validation Log Loss:",
    "tabular-playground-series-dec-2021": "Validation Accuracy:",
    "ventilator-pressure-prediction": "Validation MAE:",
    "whale-categorization-playground": "MAP@5:",
    "nomad2018-predict-transparent-conductors": "Average RMSLE:",
    "spaceship-titanic": "Validation Accuracy:",
}

# Metric direction per task; decides whether max() or min() picks the best score.
# Tasks whose score improves as it grows (accuracy, AUC, correlation, kappa, MAP@5).
HIGHER_IS_BETTER_TASKS = [
    "cifar10",
    "detecting-insults-in-social-commentary",
    "chaii-hindi-and-tamil-question-answering",
    "google-quest-challenge",
    "learning-agency-lab-automated-essay-scoring-2",
    "us-patent-phrase-to-phrase-matching",
    "jigsaw-toxic-comment-classification-challenge",
    "tabular-playground-series-dec-2021",
    "whale-categorization-playground",
    "spaceship-titanic",
]
# Tasks whose score improves as it shrinks (losses and error metrics).
LOWER_IS_BETTER_TASKS = [
    "dog-breed-identification",
    "feedback",
    "house-price",
    "leaf-classification",
    "denoising-dirty-documents",
    "spooky-author-identification",
    "statoil-iceberg-classifier-challenge",
    "lmsys-chatbot-arena",
    "ventilator-pressure-prediction",
    "nomad2018-predict-transparent-conductors",
]

def remove_unnecessary_workspaces(data_dir_workspaces):
    """Delete workspace subdirectories that have no matching entry under logs/.

    The logs directory is derived by replacing ``"workspaces/"`` with ``"logs/"``
    in *data_dir_workspaces*. Every real subdirectory of *data_dir_workspaces*
    whose name does not also appear in that logs directory is removed with
    ``shutil.rmtree``. Plain files and symlinks are left untouched, because
    ``rmtree`` cannot remove them and would otherwise abort the whole sweep.

    Args:
        data_dir_workspaces: Path of the workspaces directory to prune.

    Raises:
        FileNotFoundError: If either directory does not exist. This is
            deliberately not swallowed: treating a missing logs dir as empty
            would delete every workspace.
    """
    data_dir_log = data_dir_workspaces.replace("workspaces/", "logs/")
    # List the logs directory once, instead of once per workspace entry.
    logged_runs = set(os.listdir(data_dir_log))
    for entry in os.listdir(data_dir_workspaces):
        if entry in logged_runs:
            continue
        path = os.path.join(data_dir_workspaces, entry)
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)

def clear_logs_workspaces(data_dir_workspaces):
    """Wipe per-run ``env_log`` directories and all workspace subdirectories.

    The logs directory is derived by replacing ``"workspaces/"`` with ``"logs/"``
    in *data_dir_workspaces*. For every entry of that logs directory, its
    ``env_log`` subdirectory is deleted. Other log content, such as sibling
    subdirectories or files like ``results.json``, is kept. Every real
    subdirectory of *data_dir_workspaces* is then deleted.

    Plain files and symlinks are skipped rather than passed to
    ``shutil.rmtree``, which raises on them.

    Args:
        data_dir_workspaces: Path of the workspaces directory to clear.

    Raises:
        FileNotFoundError: If either directory does not exist.
    """
    data_dir_logs = data_dir_workspaces.replace("workspaces/", "logs/")
    for run in os.listdir(data_dir_logs):
        env_log = os.path.join(data_dir_logs, run, "env_log")
        if os.path.isdir(env_log) and not os.path.islink(env_log):
            shutil.rmtree(env_log)
    for entry in os.listdir(data_dir_workspaces):
        path = os.path.join(data_dir_workspaces, entry)
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)


def cal_score(submission_dir, task):
    """Score *submission_dir* with the benchmark evaluator for *task*.

    Imports ``MLAgentBench.benchmarks.<task>.scripts.eval`` on demand and
    returns whatever its ``get_score(submission_dir)`` returns.
    Import errors and evaluator exceptions propagate to the caller.
    """
    module_name = f'MLAgentBench.benchmarks.{task}.scripts.eval'
    get_score = importlib.import_module(module_name).get_score
    return get_score(submission_dir)

def _select_beam(run_log_dir, n_beam):
    """Return the index of the beam choice with the highest verifier score.

    Returns 0 without touching the filesystem when ``n_beam <= 1``. Otherwise it
    reads the last record of ``agent_log/beam_tree_log.jsonl`` under *run_log_dir*.
    """
    if n_beam <= 1:
        return 0
    beam_tree_path = os.path.join(run_log_dir, "agent_log/beam_tree_log.jsonl")
    with open(beam_tree_path, "r") as f:
        # Skip blank lines (e.g. a trailing newline) so json.loads does not choke on them.
        records = [json.loads(line) for line in f if line.strip()]
    last = records[-1]
    verifier_scores, choices = last["verifier_scores"], last["choices"]
    # verifier_scores is looked up by the *value* stored in choices, not by position.
    return max(range(len(choices)), key=lambda i: verifier_scores[choices[i]])


def _count_steps(run_log_dir):
    """Return the number of agent steps logged in ``agent_log/main_log``.

    The count is N + 1, where N is taken from the last
    ``==================== Step N ====================`` header in the log.
    """
    with open(os.path.join(run_log_dir, "agent_log/main_log"), "r") as f:
        content = f.read()
    last_step = content.split("==================== Step ")[-1].split(" ====================")[0]
    return int(last_step) + 1


def _pick_best(task, scores):
    """Return the best value in *scores* according to the task's metric direction.

    Raises:
        ValueError: If *task* is in neither HIGHER_IS_BETTER_TASKS nor
            LOWER_IS_BETTER_TASKS. This prevents silently reusing a stale best score.
    """
    if task in HIGHER_IS_BETTER_TASKS:
        return max(scores)
    if task in LOWER_IS_BETTER_TASKS:
        return min(scores)
    raise ValueError(f"Task {task!r} is not listed in HIGHER_IS_BETTER_TASKS or LOWER_IS_BETTER_TASKS.")


def _score_existing_submission(submission_folder, run_log_dir, task, agent_max_steps):
    """Score an existing submission.csv and its per-step trace snapshots.

    Args:
        submission_folder: Workspace folder containing ``submission.csv``.
        run_log_dir: Log folder of the same run. Snapshots are read from
            ``env_log/traces/step_<k>_files``.
        task: Task name passed to ``cal_score``.
        agent_max_steps: Upper bound on the number of step snapshots to scan.

    Returns:
        ``(final_score, best_score)``, or ``None`` if the final submission could
        not be scored. ``best_score`` is the best of the final score and every
        snapshot that contains a ``submission.csv`` and scores successfully.
        Scanning stops at the first missing step folder.
    """
    try:
        final_score = cal_score(submission_folder, task)
    except Exception as e:
        print(f"\033[31m>> Error in {submission_folder}: {e}\033[0m")
        return None

    trace_scores = []
    for step in range(agent_max_steps):
        trace_folder = os.path.join(run_log_dir, f"env_log/traces/step_{step}_files")
        if not os.path.exists(trace_folder):
            break
        if "submission.csv" not in os.listdir(trace_folder):
            continue
        # A broken intermediate snapshot must not abort scoring of the whole run.
        try:
            trace_scores.append(cal_score(trace_folder, task))
        except Exception as e:
            print(f"\033[31m>> Error in {trace_folder}: {e}\033[0m")

    return final_score, _pick_best(task, trace_scores + [final_score])


def _score_by_executing(submission_folder, subdir, task, match_string, python, device):
    """Run train.py in *submission_folder* and score the submission it produces.

    The run counts as successful only if its output contains *match_string*.

    Returns:
        ``(final_score, final_score)`` on success (no traces, so best == final),
        or ``None`` if execution or scoring failed.
    """
    print(f">> Executing train.py in {subdir}")
    kwargs = {"device": int(device), "python": python}
    try:
        observation = execute_script("train.py", submission_folder, **kwargs)
    except Exception as e:
        # Report instead of silently swallowing; the match check below then fails.
        print(f">> Exception while executing train.py in {subdir}: {e}")
        observation = ""
    print(f">> {observation}")
    if match_string not in observation:
        print(f">> Execute train.py failed in {subdir}.")
        return None
    try:
        final_score = cal_score(submission_folder, task)
    except Exception as e:
        print(f"\033[31m>> Error in {submission_folder}: {e}\033[0m")
        return None
    return final_score, final_score


def _summarize(task, final_scores, best_scores, num_steps):
    """Return (avg steps, avg final score, avg best score, best best-score) for non-empty lists."""
    n = len(final_scores)
    return (
        sum(num_steps) / len(num_steps),
        sum(final_scores) / n,
        sum(best_scores) / n,
        _pick_best(task, best_scores),
    )


def main(data_dir_workspaces, args):
    """Score every agent run under *data_dir_workspaces* and print aggregate results.

    Unless ``args.report_only`` is set, each run subdirectory that is not yet in
    ``results.json`` goes through these steps:

    * Select a beam (only when ``args.n_beam > 1``).
    * Count the agent steps from its main_log.
    * Score the run:
      * If ``submission.csv`` exists, score it and its per-step trace snapshots.
      * Otherwise, re-run ``train.py`` and score the result.
    * Append the result to ``results_each.jsonl`` and rewrite ``results.json``.

    Runs that fail to score are not recorded, so they are attempted again on the
    next invocation. Afterwards it prints the averages over all recorded runs and
    over the first ``args.report_number`` of them.

    Args:
        data_dir_workspaces: Workspaces directory. The matching logs directory is
            derived by replacing ``"workspaces/"`` with ``"logs/"``.
        args: Parsed CLI namespace. Reads n_beam, task, agent_max_steps, python,
            device, report_only and report_number. It is not modified.

    Raises:
        KeyError: If ``args.task`` is not in MATCH_TASK_STRING_DICT.
    """
    n_beam, task, agent_max_steps, python = args.n_beam, args.task, args.agent_max_steps, args.python

    data_dir_logs = data_dir_workspaces.replace("workspaces/", "logs/")
    results_path = os.path.join(data_dir_logs, "results.json")
    match_string = MATCH_TASK_STRING_DICT[task]
    remove_unnecessary_workspaces(data_dir_workspaces)

    if os.path.exists(results_path):
        with open(results_path, "r") as f:
            data = json.load(f)
    else:
        data = {}

    if not args.report_only:
        for subdir in sorted(os.listdir(data_dir_workspaces)):
            if subdir in data:
                continue  # already scored by a previous invocation
            print(f">> Processing {subdir}")
            run_log_dir = os.path.join(data_dir_logs, subdir)
            beam = _select_beam(run_log_dir, n_beam)
            num_steps = _count_steps(run_log_dir)

            submission_folder = os.path.join(data_dir_workspaces, subdir)
            if os.path.exists(os.path.join(submission_folder, "submission.csv")):
                scores = _score_existing_submission(submission_folder, run_log_dir, task, agent_max_steps)
            else:  # TODO: the script is not unique
                scores = _score_by_executing(submission_folder, subdir, task, match_string, python, args.device)
            if scores is None:
                continue
            final_score, best_score = scores

            data[subdir] = {"beam": beam, "final_score": final_score, "best_score": best_score, "step": num_steps}
            with open(os.path.join(data_dir_logs, "results_each.jsonl"), "a") as f:
                json.dump({subdir: data[subdir]}, f)
                f.write("\n")
            print("Save!")
            # Rewrite after every run so progress survives a crash mid-loop.
            with open(results_path, "w") as f:
                json.dump(data, f, indent=4)
            print("Save to results.json !")

    final_scores = [value["final_score"] for value in data.values()]
    best_scores = [value["best_score"] for value in data.values()]
    num_steps_list = [value["step"] for value in data.values()]

    if not args.report_only:
        if final_scores:
            avg_steps, avg_final, avg_best, best_best = _summarize(task, final_scores, best_scores, num_steps_list)
            print(f">> {task} len {len(final_scores)} | step {avg_steps:.2f}: Avg. Final Score={avg_final:.4f}, Avg. Best Score={avg_best:.4f}, Best Score={best_best:.4f}.")
        print(f">> Save to {results_path}.")
    else:
        # NOTE(review): report-only mode deletes every workspace subdirectory and
        # every per-run env_log directory via clear_logs_workspaces (marked
        # "For debug" upstream). Confirm this is intended before treating
        # --report_only as read-only.
        clear_logs_workspaces(data_dir_workspaces)

    if not final_scores:
        print(">> No scored submissions found; nothing to report.")
        return

    report_number = args.report_number
    if report_number > len(final_scores):
        print(f"Warning: Report number {report_number} is larger than the number of submissions {len(final_scores)}. Set report number to {len(final_scores)}.")
        report_number = len(final_scores)
    if report_number < 1:
        print(f"Warning: Report number {report_number} must be positive; skipping the final report.")
        return

    print(f">> Final Report of First \033[33m{report_number}\033[0m Submissions on Task \033[33m{task}:\033[0m")
    avg_steps, avg_final, avg_best, best_best = _summarize(
        task,
        final_scores[:report_number],
        best_scores[:report_number],
        num_steps_list[:report_number],
    )
    print(f">> step {avg_steps:.2f}: \033[33m{avg_final:.4f}/{avg_best:.4f}/{best_best:.4f}.\033[0m")
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, default="~/EvolveResearchAgent_Ray")
    parser.add_argument("--python", type=str, default="~/miniconda3/envs/mlagentbench/bin/python")
    parser.add_argument("--generator_name", type=str, default="gpt-4o-mini")
    parser.add_argument("--report_only", action="store_true", default=False)
    parser.add_argument("--report_number", type=int, default=8)
    parser.add_argument("--device", type=int, default=2)
    # Restrict to tasks main() knows how to score; otherwise it fails later with a KeyError.
    parser.add_argument("--task", type=str, default="cifar10", choices=sorted(MATCH_TASK_STRING_DICT))
    parser.add_argument("--n_beam", type=int, default=1)
    parser.add_argument("--agent_max_steps", type=int, default=15)
    args = parser.parse_args()

    # os.path.exists does not expand "~", so the default data_root would never be
    # found without this. The interpreter path is expanded for the same reason.
    args.data_root = os.path.expanduser(args.data_root)
    args.python = os.path.expanduser(args.python)

    # NOTE(review): "search_n1_b1" is hard-coded and does not follow --n_beam; confirm.
    data_dir = f"{args.data_root}/workspaces/search_n1_b1_MLAgentBench_{args.task}/step_{args.agent_max_steps}_{args.generator_name}"

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(data_dir):
        parser.error(f"{data_dir} does not exist.")
    main(data_dir, args)