import os
import json
import logging
import pathlib
import argparse
import copy
from typing import List, Dict, Any
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
from colorama import Fore

import eval_agent.tasks as tasks
import eval_agent.agents as agents
import eval_agent.envs as envs
from eval_agent.utils.datatypes import State
from eval_agent.utils.retrieve_action import get_next_action
from eval_agent.utils.documentsearch import DocumentSearch

# Shared module logger; its level is configured from the CLI flags
# (--verbose / --debug) in the __main__ block at the bottom of this file.
logger = logging.getLogger("agent_frame")


def interactive_loop(
    task: tasks.Task,
    agent_model: agents.LMAgent,
    world_model: agents.LMAgent,
    env_config: Dict[str, Any],
    search_engine: DocumentSearch
) -> State:
    """Run a single task episode to completion and return its final state.

    Each step first queries the world model on its own history, then the
    agent model. On the first step the agent generates freely; on every
    later step the "Action:" token is chosen by re-weighting the agent's
    token probabilities with retrieval results (``get_next_action``).

    Args:
        task: Task instance to solve.
        agent_model: LM driving thoughts/actions (uses ``state.history_ag``).
        world_model: LM producing auxiliary output (uses ``state.history_wm``).
        env_config: Environment config; ``env_class`` names the env class in
            ``eval_agent.envs`` and the remaining keys are passed through.
        search_engine: Vector document search consumed by ``get_next_action``.

    Returns:
        The final ``State`` after the episode finishes, succeeds, or errors.
    """
    logger.info(f"Loading environment: {env_config['env_class']}")
    env: envs.BaseEnv = getattr(envs, env_config["env_class"])(task, **env_config)
    # reset the environment and set the prompt
    observation, state = env.reset()

    init_msg = observation

    logger.info(f"\n{Fore.YELLOW}{init_msg}{Fore.RESET}")

    cur_step = 1
    while not state.finished:
        logger.info(f"\n{Fore.RED}Step {cur_step}{Fore.RESET}\n")
        cur_step += 1
        # agent act
        try:
            # cur_step was just incremented, so cur_step == 2 only on the very
            # first iteration; every later step takes this retrieval branch.
            if cur_step != 2:
                llm_wm_output: str = world_model(state.history_wm)

                # if env_config["env_class"] == "SciWorldEnv":
                #     history = ""
                #     for message in state.history_wm:
                #         history += message['content']+'\n'
                #     state.history_wm = [{"role":"user","content":history}]
                # else:
                # Keep the world-model history autoregressive by appending
                # its own output as an assistant turn.
                state.history_wm.append({"role":"assistant","content":llm_wm_output})
                print(llm_wm_output)
                
                # Generate only the "Thought" part, then force the action
                # prefix so the next generation continues after "Action: "
                # (presumably probs_gen stops before emitting an action —
                # TODO confirm against the agent implementation).
                Thought  = agent_model.probs_gen(state.history_ag,get_probs=False,gen_continue =False)
                Thought = Thought+"Action: "
                # Previous action text, recovered from the second-to-last
                # agent message (the last assistant turn before the new one).
                pre_action = state.history_ag[-2]["content"].split('Action: ')[-1]
                print(f"pre_action: {pre_action}")
                state.history_ag.append({"role":"assistant","content":Thought}) 
                # Token probabilities for the continuation after "Action: ".
                _,all_probs = agent_model.probs_gen(state.history_ag,get_probs=True,gen_continue=True)
                # Re-rank candidate action tokens by combining generation
                # probabilities with vector-search retrieval over documents.
                action_probs = get_next_action(
                    model_name = agent_model.config['model_name'],
                    gen_probs = all_probs,
                    task = env_config["env_class"],
                    llm_wm_output= llm_wm_output, # Summarize
                    search_by = "vector",
                    k = 3000,
                    gen_weight=1,
                    search_engine=search_engine,
                    pre_action=pre_action
                )
                # Greedy pick: highest combined probability wins.
                action  = max(action_probs,key= action_probs.get)
                
                state.history_ag[-1]["content"] += action
                # Non-numeric actions need an argument/object generated as a
                # continuation. AlfWorld/SciWorld insert a space before it
                # unless the continuation is '1' (the "wait1" case below);
                # WebShop concatenates directly.
                if not action.isdigit():
                    action_obj = agent_model.probs_gen(state.history_ag,get_probs=False,gen_continue=True)
                    if env_config["env_class"] == "AlfWorldEnv" or env_config["env_class"] == "SciWorldEnv":
                        if action_obj == '1': 
                            #wait1
                            state.history_ag[-1]["content"] +=action_obj
                        else:
                            state.history_ag[-1]["content"] += ' '+action_obj
                    elif env_config['env_class'] == "WebShopEnv":
                        state.history_ag[-1]["content"] += action_obj
                
                llm_ag_output = state.history_ag[-1]["content"]
                # Drop the temporary assistant turn; presumably env.step()
                # records the finished turn into the history itself —
                # TODO confirm against the env implementation.
                state.history_ag = state.history_ag[:-1]
                print(llm_ag_output)
                logger.info(
                    f"\n{Fore.GREEN}{llm_ag_output}{Fore.RESET}\n"
                )
            else:
                # First step: world model output is recorded, then the agent
                # generates its full thought+action without retrieval.
                llm_wm_output: str = world_model(state.history_wm)


                state.history_wm.append({"role":"assistant","content":llm_wm_output})
                print(llm_wm_output)
                #guideline
                
                # state.history_ag[-1]["content"] += '\n'+llm_wm_output
                
                # to generate until Action:
                llm_ag_output: str = agent_model(state.history_ag)
                print(llm_ag_output)


                logger.info(
                    f"\n{Fore.GREEN}{llm_ag_output}{Fore.RESET}\n"
                )
        except Exception as e:
            # NOTE(review): any model/runtime failure lands here, yet the
            # recorded terminate_reason always claims a length error —
            # consider recording str(e) instead (kept as-is for downstream
            # compatibility of saved state files).
            print(e)
            logger.info(f"Agent failed with error: {e}")
            state.success = False
            state.finished = True
            state.terminate_reason = "exceeding maximum input length"
            break
        # environment step
        observation, state = env.step(llm_ag_output)
        print(observation)
        # print(observation,'\n',llm_wm_output,'\n',llm_ag_output)
        # color the state in blue
        if not state.finished:
            # color the observation in blue
            logger.info(
                f"\n{Fore.BLUE}{observation}{Fore.RESET}\n"
            )

        if state.finished:
            break

    if state.reward is not None:
        logger.info(
            f"Task finished in {state.steps} steps. Success: {state.success}. Reward: {state.reward}"
        )
    else:
        logger.info(
            f"Task finished in {state.steps} steps. Success: {state.success}"
        )

    return state


def main(args: argparse.Namespace):
    """Run the evaluation over every task in the selected split.

    Loads the experiment and model configs, builds the environment backend
    and the vector retrieval engine, runs ``interactive_loop`` for each
    not-yet-finished task, writes one ``<task_id>.json`` state file per
    task under ``outputs/``, and logs summary metrics at the end.

    Args:
        args: Parsed CLI arguments (see the argparse setup in ``__main__``).
    """
    # Load the experiment (task/env) config and the model config.
    with open(os.path.join(args.exp_path, f"{args.exp_config}.json")) as f:
        exp_config: Dict[str, Any] = json.load(f)

    with open(os.path.join(args.agent_path, f"{args.agent_config}.json")) as f:
        agent_model_config: Dict[str, Any] = json.load(f)
    # The world model starts from the same config as the agent model,
    # with the optional overrides applied below.
    world_model_config = copy.deepcopy(agent_model_config)

    ## agent init
    if args.agent_model_name is not None:
        agent_model_config['config']['model_name'] = args.agent_model_name
    if args.world_model_name is not None:
        world_model_config['config']['model_name'] = args.world_model_name
    # The world model only emits short auxiliary output, so cap its length.
    world_model_config['config']['max_new_tokens'] = 128

    output_path = os.path.join(
        "outputs",
        agent_model_config['config']['model_name'].replace('/', '_'),
        args.exp_config + args.exp_name,
    )
    pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

    # Log to stdout and to a per-run log file inside the output directory.
    file_handler = logging.FileHandler(os.path.join(output_path, "log.txt"), mode='w')
    logging.basicConfig(
        format="%(message)s",
        handlers=[logging.StreamHandler(), file_handler],
    )

    env_config = exp_config["env_config"]

    logger.info(f"Experiment config: \n{json.dumps(exp_config, indent=2)}")

    # Environment backends are imported lazily so only the dependencies of
    # the selected environment are required at runtime.
    if env_config['env_class'] == 'WebShopEnv':
        from webshop.web_agent_site.envs import WebAgentTextEnv
        env_config['env'] = WebAgentTextEnv(observation_mode="text", human_goals=True)
    elif env_config['env_class'] == 'SciWorldEnv':
        from scienceworld import ScienceWorldEnv
        from eval_agent.utils.replace_sciworld_score import sciworld_monkey_patch
        sciworld_monkey_patch()
        env_config['env'] = ScienceWorldEnv(
            "",
            serverPath=os.path.join(os.getcwd(), env_config['env_jar_path']),
            envStepLimit=200,
        )

    # Initialize the vector retrieval system used for action selection.
    print("vector search loading....")
    search_engine = DocumentSearch(file_path=env_config["document_path"], env=env_config['env_class'])
    search_engine.load_documents()
    print("vector search loaded.")

    # Initialize all the tasks for the requested split (optionally one part).
    task_config: Dict[str, Any] = exp_config["task"]
    task_class: tasks.Task = getattr(tasks, task_config["task_class"])
    all_tasks, n_tasks = task_class.load_tasks(args.split, args.part_num, args.part_idx)

    # Instantiate the agent and world models from their configs.
    agent: agents.LMAgent = getattr(agents, agent_model_config["agent_class"])(
        agent_model_config["config"]
    )
    world_model: agents.LMAgent = getattr(agents, world_model_config["agent_class"])(
        world_model_config["config"]
    )

    state_list = []

    # Resume support: pick up states of already-finished tasks unless --override.
    done_task_id = []
    if os.path.exists(output_path) and not args.override:
        for file in os.listdir(output_path):
            if not file.endswith('json'):
                continue
            # Use a context manager so the handle is closed (was leaked before).
            with open(os.path.join(output_path, file)) as f:
                state = State.load_json(json.load(f))
            state_list.append(state)
            done_task_id.append(file.split('.')[0])
        logger.info(f"Existing output file found. {len(done_task_id)} tasks done.")

    if len(done_task_id) == n_tasks:
        logger.info("All tasks done. Exiting.")
        return

    # run the loop for all tasks
    logging.info(f"Running interactive loop for {n_tasks} tasks.")
    n_todo_tasks = n_tasks - len(done_task_id)  # only run the remaining tasks

    with logging_redirect_tqdm():
        pbar = tqdm(total=n_todo_tasks)
        all_reward = 0
        all_round = 0
        for i, task in enumerate(all_tasks):
            # Only run the first 5 tasks in debug mode.
            if args.debug and i == 5:
                break

            # skip done tasks
            if task.task_id in done_task_id or str(task.task_id) in done_task_id:
                continue

            state = interactive_loop(
                task, agent, world_model, env_config, search_engine
            )
            if state.reward:
                all_reward += state.reward
            all_round += 1
            # Running mean reward over the tasks executed in this run.
            print(f"success rate: {all_reward/all_round}")
            state_list.append(state)
            # Persist the finished state; context manager closes the file
            # (the previous open() without close leaked the handle).
            with open(os.path.join(output_path, f"{task.task_id}.json"), 'w') as f:
                json.dump(state.to_dict(), f, indent=4)

            pbar.update(1)
        pbar.close()

    logger.warning("All tasks done.")
    logger.warning(f"Output saved to {output_path}")

    # Aggregate metrics over every state seen this run (resumed + new).
    reward_list = []
    success_list = []
    for state in state_list:
        if state.reward is not None:
            reward_list.append(state.reward)
        success_list.append(state.success)

    if reward_list:
        # Average only over tasks that actually reported a reward
        # (previously divided by the total task count, deflating the value).
        logger.warning(f"Average reward: {sum(reward_list)/len(reward_list):.4f}")
    if success_list:  # guard against ZeroDivisionError when nothing ran
        logger.warning(f"Success rate: {sum(success_list)/len(success_list):.4f}")


if __name__ == "__main__":
    # CLI entry point: parse flags, set the logger verbosity, run main().
    parser = argparse.ArgumentParser("Run the interactive loop.")
    parser.add_argument(
        "--exp_name",
        type=str,
        default="",
        help="The name of the experiment (appended to the output directory).",
    )
    parser.add_argument(
        "--exp_path",
        type=str,
        default="./eval_agent/configs/task",
        help="Config path of experiment.",
    )
    parser.add_argument(
        "--exp_config",
        type=str,
        default="webshop",
        help="Config of experiment.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Evaluation split.",
    )
    parser.add_argument(
        "--part_num",
        type=int,
        default=1,
        help="Number of parts the evaluation split is divided into.",
    )
    parser.add_argument(
        "--part_idx",
        type=int,
        default=-1,
        help="Index of the evaluation part to run.",
    )
    parser.add_argument(
        "--agent_path",
        type=str,
        default="./eval_agent/configs/model",
        help="Config path of model.",
    )
    parser.add_argument(
        "--agent_config",
        type=str,
        default="fastchat",
        help="Config of model.",
    )
    parser.add_argument(
        "--agent_model_name",
        type=str,
        required=False,
        help="Model name. It will override the 'model_name' in agent_config",
    )
    parser.add_argument(
        "--world_model_name",
        type=str,
        required=False,
        help="Model name. It will override the 'model_name' in agent_config",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose (INFO-level) logging.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Whether to run in debug mode (only the first 5 tasks are run).",
    )
    parser.add_argument(
        "--override",
        action="store_true",
        help="Whether to ignore done tasks.",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Whether to run in interactive mode for demo purpose.",
    )
    args = parser.parse_args()

    # --verbose takes precedence over --debug; default is WARNING only.
    if args.verbose:
        logger.setLevel(logging.INFO)
    elif args.debug:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.WARNING)

    main(args)

