import sys
import os
import time
import pathlib
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

sys.path.append('.')
sys.path.append('./leaps')

from llm import LLMProgramGenerator
from prog_policies.karel import KarelDSL
from prog_policies.karel_tasks import get_task_cls
from prog_policies.search_space import get_search_space_cls
from prog_policies.search_methods import get_search_method_cls
from prog_policies.utils.save_file import inside_seed_save_log_file, outside_seed_save_log_file
from prog_policies.utils.evaluate_and_search import *

if __name__ == '__main__':
    
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    
    # LatentSpace, ProgrammaticSpace
    parser.add_argument('--search_space', default='ProgrammaticSpace', help='Search space class name')
    # Scheduled_HillClimbing, HillClimbing, HillClimbingLatent, CEM, CEBS
    parser.add_argument('--search_method', default='Scheduled_HillClimbing', help='Search method class name')
    parser.add_argument('--seed', type=int, default=0, help='Random seed for searching')
    parser.add_argument('--num_envs', type=int, default=32, help='Number of environments to search')
    # StairClimberSparse, MazeSparse, FourCorners, TopOff, Harvester, CleanHouse
    # Hard: DoorKey, OneStroke, Seeder, Snake
    # New: PathFollow, WallAvoider
    parser.add_argument('--task', default='DoorKey', help='Task class name')
    parser.add_argument('--sigma', type=float, default=0.1, help='Standard deviation for Gaussian noise in Latent Space')
    parser.add_argument('--k', type=int, default=1024, help='Number of neighbors to consider')
    parser.add_argument('--start_k', type=int, default=32, help='Number of neighbors to consider')
    parser.add_argument('--end_k', type=int, default=2048, help='Number of neighbors to consider')
    parser.add_argument('--e', type=int, default=2, help='Number of elite candidates in CEM-based methods')
    parser.add_argument("--max_program_nums", type=int, default=1000000)
    
    # Prompt
    parser.add_argument('--llm_program_num', type=int, default=32, help='Number of programs generated by LLM')
    # LLM
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    
    # Output
    parser.add_argument("--output_dir", type=str, default="output")
    parser.add_argument("--output_name", type=str, default="0")
    parser.add_argument("--save_step", type=int, default=5000)
    
    args = parser.parse_args()
    
    print(vars(args))
    
    output_dir = os.path.join(args.output_dir, args.task, args.output_name)
    output_dir_seed = os.path.join(output_dir, str(args.seed))
    
    if os.path.isdir(output_dir_seed):
        assert 0, "Duplicated seed."
    
    pathlib.Path(output_dir_seed).mkdir(parents=True, exist_ok=True)
    
    dsl = KarelDSL()
    
    env_args = {
        "env_height": 8,
        "env_width": 8,
        "crashable": False,
        "leaps_behaviour": True,
        "max_calls": 10000
    }
    
    if (
        args.task == "StairClimber"
        or args.task == "StairClimberSparse"
        or args.task == "TopOff"
        or args.task == "FourCorners"
    ):
        env_args["env_height"] = 12
        env_args["env_width"] = 12

    if args.task == "CleanHouse":
        env_args["env_height"] = 14
        env_args["env_width"] = 22
    
    if args.task == "WallAvoider":
        env_args["env_height"] = 8
        env_args["env_width"] = 5
    
    task_cls = get_task_cls(args.task)
    task_envs = [task_cls(env_args, i) for i in range(args.num_envs)]
    
    search_space_cls = get_search_space_cls(args.search_space)
    search_space = search_space_cls(dsl, args.sigma)
    search_space.set_seed(args.seed)
    
    search_method_cls = get_search_method_cls(args.search_method)
    if args.search_method == "Scheduled_HillClimbing":
        search_method = search_method_cls(args.k, args.e, args.start_k, args.end_k, args.max_program_nums)
    else:
        search_method = search_method_cls(args.k, args.e)
    
    best_reward = -float('inf')
    best_prog = None
    
    log = {}
    log['args'] = vars(args)
    log['seed'] = args.seed
    
    llm_program_generator = LLMProgramGenerator(args.seed, args.task, dsl, args.llm_program_num, args.temperature, args.top_p)
    program_list, llm_log = llm_program_generator.get_program_list_python_to_dsl()
    log['llm_log'] = [llm_log]
    
    init_time = time.time()

    program_reward = []
    best_prog, best_reward, _ = record_evaluate_program_list(best_prog, best_reward, program_list, search_method, task_envs, dsl, program_reward, output_dir_seed, log, init_time, output_dir, args.task, args.seed, args.max_program_nums)
        
    program_and_reward = zip(program_list, program_reward)
    program_and_reward = sorted(program_and_reward, key=lambda x: x[1], reverse=True)
    program_list = [p[0] for p in program_and_reward]
    
    previous_save_program_num = task_envs[0].program_num
    
    while task_envs[0].program_num < args.max_program_nums and best_reward < 1:
        for program in program_list:
            previous_save_program_num = check_save_time(best_prog, best_reward, args.save_step, previous_save_program_num, search_method, task_envs, dsl, output_dir_seed, log, init_time, output_dir, args.task, args.seed)
            best_prog, best_reward, _ = record_search(best_prog, best_reward, search_method, search_space, task_envs, dsl, output_dir_seed, log, init_time, output_dir, args.task, args.seed, program=program)

            if best_reward >= 1:
                break
    
    search_method.record[task_envs[0].program_num] = best_reward
    search_method.program_record[task_envs[0].program_num] = dsl.parse_node_to_str(best_prog)

    inside_seed_save_log_file(log, output_dir_seed, task_envs[0].program_num, init_time, dsl.parse_node_to_str(best_prog), best_reward, search_method.record, search_method.program_record)
    outside_seed_save_log_file(output_dir, args.task, args.seed, task_envs[0].program_num, init_time, dsl.parse_node_to_str(best_prog), best_reward, search_method.record, search_method.program_record)