import argparse
import datetime
import json
import os
import sys

import yaml

from preference_learning_main import APRICOTPrefLearning
from pref_learing_baselines.non_interactive import NonInteractivePrefLearning
from pref_learing_baselines.interactive import InteractiveBaselinePrefLearning
# from pref_learing_baselines.querying_only_interactive import QueryingOnlyInteractivePrefLearning

from utils import *

# Kept after the star import so this `logger` cannot be shadowed by a name re-exported from utils.
from loguru import logger

# With LOGURU_LEVEL=INFO, replace loguru's handlers with a single stderr sink
# that prints only the message text (format "{message}": no time/level/location prefix).
if os.environ.get("LOGURU_LEVEL") == "INFO":
    logger.configure(handlers=[{"sink": sys.stderr, "format": "{message}"}])

# runtime = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# logger.add(f"./logging/{runtime}.log", colorize=False, backtrace=True, diagnose=True)

def evaluate_one_scenario(approach, score_plan_fn, gt_pref, demos, obj_to_put_away, scenario, progress_save_file, logging_file, logger_handler_id=None):
    """
    Run ``approach`` on one scenario, score its plan against the ground-truth
    preference, log the outcome and record it via ``approach.save_progress``.

    Parameters:
        approach (object) - learner to evaluate; dispatched on ``approach.NAME``.
        score_plan_fn (callable) - called as ``score_plan_fn(preference=...,
            initial_state=..., plan=...)``; its result is logged with ``:.3f``.
        gt_pref (str) - ground-truth preference used for scoring.
        demos (list of dictionary) - demonstrations; converted with
            ``format_demos_dict_for_gb`` before being handed to the approach.
        obj_to_put_away (list of str) - objects to put away in this scenario.
        scenario (dictionary) - initial state; the scorer receives the dict,
            the approach receives ``format_state_dictionary_to_str(scenario)``.
        progress_save_file (str) - progress file path forwarded to ``approach.main``.
        logging_file (str or None) - when truthy, a loguru file sink is attached
            for the duration of this call and detached before it returns.
        logger_handler_id (int or None) - id of a previously added loguru
            handler to remove before the new sink is attached (only consulted
            when ``logging_file`` is truthy).

    Raises:
        ValueError - if ``approach.NAME`` is not a supported approach.
    """
    added_handler_id = None
    if logging_file:
        if logger_handler_id is not None:
            # If there is an old handler, remove it
            logger.remove(logger_handler_id)

        added_handler_id = logger.add(logging_file, colorize=False, backtrace=True, diagnose=True)

    try:
        demos_str = format_demos_dict_for_gb(demos)
        scenario_str = format_state_dictionary_to_str(scenario)

        logger.info(f"Saving Progress At... {progress_save_file}")

        # Every approach's ``main`` is called with the same keyword arguments.
        main_kwargs = dict(demos=demos_str,
                           initial_state=scenario_str,
                           objs_to_put_away=obj_to_put_away,
                           gt_pref=gt_pref,
                           progress_save_file=progress_save_file)

        if approach.NAME == APRICOTPrefLearning.NAME:
            idx_to_best_plan, plan_list, best_pref_idx, preference_list, plan_score_given_best_pref, step = approach.main(**main_kwargs)

            best_plan = plan_list[idx_to_best_plan]
            # Indexed by preference: presumably plan_list[i] is the plan generated
            # for preference_list[i] -- TODO confirm against APRICOTPrefLearning.main.
            best_plan_according_to_best_preference = plan_list[best_pref_idx]

            gt_score_for_best_plan = score_plan_fn(preference=gt_pref, initial_state=scenario, plan=best_plan)
            gt_score_for_best_plan_according_to_best_preference = score_plan_fn(preference=gt_pref, initial_state=scenario, plan=best_plan_according_to_best_preference)

            # None cannot be formatted with ":.3f" below; store the -1 sentinel instead.
            if plan_score_given_best_pref is None:
                plan_score_given_best_pref = -1

            logger.info(f"\nQuery Steps: {step}\nBest Plan (index={idx_to_best_plan}):\n{best_plan}\n\nBest Preference given Best Plan (index={idx_to_best_plan}):\n{preference_list[idx_to_best_plan]}\n\n[Plan Score] Given Best Pref: {plan_score_given_best_pref:.3f}   |    Given Gt Pref: {gt_score_for_best_plan:.3f}")

            approach.save_progress("p_theta", approach.p_theta.tolist())
            approach.save_progress("idx_to_best_plan", int(idx_to_best_plan))
            approach.save_progress("best_pref_idx", int(best_pref_idx))
            approach.save_progress("query_steps", int(step))
            approach.save_progress("plan_score_given_best_pref", plan_score_given_best_pref)
            approach.save_progress("gt_score_best_plan", gt_score_for_best_plan)
            approach.save_progress("gt_score_best_preference_plan", gt_score_for_best_plan_according_to_best_preference)
        elif approach.NAME in (NonInteractivePrefLearning.NAME, InteractiveBaselinePrefLearning.NAME):
            best_plan, best_preference, step = approach.main(**main_kwargs)

            gt_score = score_plan_fn(preference=gt_pref, initial_state=scenario, plan=best_plan)

            logger.info(f"\nBest Plan:\n{best_plan}\n\nBest Preference:\n{best_preference}\n\nGround-Truth Preference:\n{gt_pref}\n\n[Plan Score]    |    Given Gt Pref: {gt_score:.3f}")

            approach.save_progress("query_steps", int(step))
            approach.save_progress("gt_score", gt_score)
        elif approach.NAME == "QueryingOnlyInteractive":
            # Compared by string: the QueryingOnlyInteractivePrefLearning class is
            # not imported at module level, so naming it here would raise NameError.
            best_plan, best_pref_idx, best_preference, step = approach.main(**main_kwargs)

            gt_score = score_plan_fn(preference=gt_pref, initial_state=scenario, plan=best_plan)

            logger.info(f"\nBest Plan:\n{best_plan}\n\nBest Preference (index={best_pref_idx}):\n{best_preference}\n\nGround-Truth Preference:\n{gt_pref}\n\n[Plan Score]    |    Given Gt Pref: {gt_score:.3f}")

            approach.save_progress("best_pref_idx", int(best_pref_idx))
            approach.save_progress("query_steps", int(step))
            approach.save_progress("gt_score", gt_score)
        else:
            # Explicit raise (not assert) so the check survives `python -O`.
            raise ValueError(f"Not supported yet: {approach.NAME=}")

        approach.reset_progress_save_file()
    finally:
        # Detach this scenario's sink; otherwise it stays registered and every
        # later scenario's log lines would also be appended to this file.
        if added_handler_id is not None:
            logger.remove(added_handler_id)


def init_approach(approach_name, configs, debug):
    """
    Build the learner named ``approach_name`` together with the plan scorer.

    Parameters:
        approach_name (str) - "APRICOT", "NonInteractive", "Interactive" or
            "QueryingOnlyInteractive".
        configs (dict) - testing config; "seed", "prompt_config_path",
            "learn_pref_config", "planner_config" and "sim_config" are read.
        debug (bool) - forwarded to every constructor.

    Return
        approach - the constructed learner
        scoring function - ``score_plan`` of an APRICOTPrefLearning instance;
            the same scorer is returned for every approach.

    Raises:
        ValueError - if ``approach_name`` is not supported.
    """
    common_kwargs = dict(
        seed=configs["seed"],
        prompt_config_path=configs["prompt_config_path"],
        learn_pref_config=configs["learn_pref_config"],
        planner_config=configs["planner_config"],
        sim_config=configs["sim_config"],
        debug=debug,
    )

    # The APRICOT instance is always built because its score_plan is the scorer.
    gb = APRICOTPrefLearning(**common_kwargs)
    scoring_fn = gb.score_plan

    if approach_name == "APRICOT":
        return gb, scoring_fn

    if approach_name == "NonInteractive":
        approach_cls = NonInteractivePrefLearning
    elif approach_name == "Interactive":
        approach_cls = InteractiveBaselinePrefLearning
    elif approach_name == "QueryingOnlyInteractive":
        # Imported here because the module-level import of this baseline is
        # commented out; referencing the bare class name would raise NameError.
        from pref_learing_baselines.querying_only_interactive import QueryingOnlyInteractivePrefLearning
        approach_cls = QueryingOnlyInteractivePrefLearning
    else:
        # Explicit raise (not assert) so the check survives `python -O`.
        raise ValueError(f"{approach_name=} is not supported")

    return approach_cls(**common_kwargs), scoring_fn
    

def init_progress_save_file(approach_name, test_case_type, gt_id, scenario_id, not_agreeable_plan_threshold=None):
    """
    Locate the progress file for one scenario and decide whether it must be run.

    The path is
    ``saved_progress/<approach>/<type>/<type>_gt=<gt_id>_sc=<scenario_id>_saved_progress.json``.

    * File missing -> the scenario must run. For "APRICOT", "NonInteractive" and
      "QueryingOnlyInteractive", a non-zero scenario's file is first created and
      seeded with selected keys copied from scenario 0's progress file of the
      same ``gt_id``.
    * File present -> the scenario runs again only if its completion key is
      absent: ``str(not_agreeable_plan_threshold)`` for "APRICOT", otherwise
      "gt_score".

    Returns:
        (run_test, progress_save_file) - whether to run, and the progress path.

    Raises:
        ValueError - approach is "APRICOT", the file exists and
            ``not_agreeable_plan_threshold`` is None.
        FileNotFoundError - seeding is needed but scenario 0's file is missing.
    """
    def _path_for(sc_id):
        return os.path.join(f"saved_progress/{approach_name}/{test_case_type}", f"{test_case_type}_gt={gt_id}_sc={sc_id}_saved_progress.json")

    progress_save_file = _path_for(scenario_id)

    if not os.path.exists(progress_save_file):
        # Keys reused from scenario 0 (in this insertion order) so they are not regenerated.
        keys_to_copy = {
            "APRICOT": ("preference_writing_reasoning", "reasoning_worksheet", "preference_list"),
            "QueryingOnlyInteractive": ("preference_writing_reasoning", "reasoning_worksheet", "preference_list"),
            "NonInteractive": ("preference_writing_reasoning", "preference"),
        }.get(approach_name)

        if keys_to_copy is not None and scenario_id != 0:
            with open(_path_for(0), "r") as fin:
                scenario_0_progress_save_dict = json.load(fin)

            curr_scenario_progress_save_dict = {key: scenario_0_progress_save_dict[key] for key in keys_to_copy}

            with open(progress_save_file, "w") as fout:
                json.dump(curr_scenario_progress_save_dict, fout, indent=4)

        return True, progress_save_file

    if approach_name == "APRICOT":
        # Explicit raise (not assert): under `python -O` the assert vanished and
        # the completion key silently became the string "None".
        if not_agreeable_plan_threshold is None:
            raise ValueError("not_agreeable_plan_threshold shouldn't be none for APRICOT")
        keyword = str(not_agreeable_plan_threshold)
    else:
        keyword = "gt_score"

    with open(progress_save_file, "r") as fin:
        progress_save_dict = json.load(fin)

    return keyword not in progress_save_dict, progress_save_file


def init_logging_file(approach_name, test_case_type, gt_id, scenario_id):
    """
    Build a log-file path for one (approach, test-case type, gt, scenario) run.

    The file name is prefixed with the current local time
    (``%Y-%m-%d_%H-%M-%S``), so repeated runs get distinct files under
    ``logging/<approach_name>/<test_case_type>``. Nothing is created on disk.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_dir = f"logging/{approach_name}/{test_case_type}"
    file_name = f"{stamp}_{approach_name}_{test_case_type}_gt={gt_id}_sc={scenario_id}.log"
    return os.path.join(log_dir, file_name)

    
def evaluate_test_cases(approach_name, test_case_type, test_config_filepath, debug=False):
    """
    Run one approach on every (ground-truth preference, scenario) pair of a benchmark.

    Scenarios for which ``init_progress_save_file`` reports no need to run
    (result already recorded) are skipped, so an interrupted sweep can resume.

    Parameters:
        approach_name (str) - name understood by ``init_approach``.
        test_case_type (str) - benchmark name; ``pref_learning_benchmark/<type>.json`` is read.
        test_config_filepath (str) - YAML testing config.
        debug (bool) - forwarded to the approach constructor.
    """
    import shutil
    import subprocess

    # Load testing config
    with open(test_config_filepath, "r") as fin:
        configs = yaml.safe_load(fin)

    # Load approach to test on, and scoring function that will score how good the plan is
    approach, scoring_fn = init_approach(approach_name, configs, debug)

    # Load the test cases json
    with open(f"pref_learning_benchmark/{test_case_type}.json", "r") as fin:
        test_cases = json.load(fin)

    not_agreeable_threshold = configs["learn_pref_config"]["not_agreeable_threshold"]

    for gt_id, test_case in test_cases.items():
        gt_pref = test_case["gt_preference"]
        demos = test_case["demonstrations"]

        for scenario_id, scenario in enumerate(test_case["scenarios"]):
            run_test, progress_save_file = init_progress_save_file(approach.NAME, test_case_type, gt_id=gt_id, scenario_id=scenario_id, not_agreeable_plan_threshold=not_agreeable_threshold)
            if not run_test:
                continue

            logging_file = init_logging_file(approach.NAME, test_case_type, gt_id=gt_id, scenario_id=scenario_id)

            evaluate_one_scenario(approach, scoring_fn,
                                  gt_pref,
                                  demos,
                                  obj_to_put_away=scenario["objects_to_put_away"],
                                  scenario=scenario["initial_state"],
                                  progress_save_file=progress_save_file,
                                  logging_file=logging_file)

    # Best-effort spoken notification via the macOS `say` command; skipped when
    # it is not on PATH. Argument list, no shell, so names are never shell-parsed.
    if shutil.which("say"):
        subprocess.run(["say", f"Hey, the experiments for the approach, {approach_name}, on {test_case_type} test cases are done!"], check=False)
                    

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="APRICOT Preference Learning main function")

    parser.add_argument("--debug", action="store_true", default=False, help="Whether to allow debug")
    parser.add_argument("--approach", choices=["APRICOT", "NonInteractive", "Interactive"], required=True)
    parser.add_argument("--type", choices=["specific_location", "general_location", "relative_position", "subcategory_exception", "conditional", "real_robot"], default=None)
    parser.add_argument("--gt_id", type=int, default=None)
    parser.add_argument("--scenario_id", type=int, default=None)
    parser.add_argument("--test_config_path", type=str, default="testing_configs.yaml")

    args = parser.parse_args()

    # Modes:
    #   no --type                          -> every benchmark type below (not "real_robot")
    #   --type only                        -> every scenario of that type
    #   --type + --gt_id + --scenario_id   -> exactly that scenario, run unconditionally
    single_ids_given = (args.gt_id is not None, args.scenario_id is not None)

    if args.type is None:
        # Single-scenario ids are meaningless without --type; reject rather than
        # silently ignoring them and launching the whole suite.
        if any(single_ids_given):
            parser.error("--gt_id/--scenario_id require --type")

        answer = input("Are you sure you want to continue? You are about to run all the test cases")
        # Enter (or anything else) continues; an explicit "n"/"no" aborts.
        if answer.strip().lower() in ("n", "no"):
            raise SystemExit("Aborted.")

        for test_case_type in ["specific_location", "general_location", "relative_position", "subcategory_exception", "conditional"]:
            evaluate_test_cases(args.approach, test_case_type, args.test_config_path, args.debug)
    elif not any(single_ids_given):
        evaluate_test_cases(args.approach, args.type, args.test_config_path, args.debug)
    else:
        # Indexing scenarios with None would otherwise crash deep inside the run.
        if not all(single_ids_given):
            parser.error("--gt_id and --scenario_id must be given together")

        with open(args.test_config_path, "r") as fin:
            configs = yaml.safe_load(fin)

        approach, scoring_fn = init_approach(args.approach, configs, args.debug)

        with open(f"pref_learning_benchmark/{args.type}.json", "r") as fin:
            test_cases = json.load(fin)

        test_case = test_cases[str(args.gt_id)]
        scenario = test_case["scenarios"][args.scenario_id]

        # run_test is ignored: an explicitly requested scenario is always (re)run.
        _, progress_save_file = init_progress_save_file(approach.NAME, args.type, args.gt_id, args.scenario_id, not_agreeable_plan_threshold=configs["learn_pref_config"]["not_agreeable_threshold"])

        logging_file = init_logging_file(approach.NAME, args.type, args.gt_id, args.scenario_id)

        evaluate_one_scenario(approach,
                              scoring_fn,
                              test_case["gt_preference"],
                              test_case["demonstrations"],
                              scenario["objects_to_put_away"],
                              scenario["initial_state"],
                              progress_save_file,
                              logging_file)