from human_model import HumanModel
from agent_model_5 import AgentModel

import yaml
import alfworld
import alfworld.agents.environment

import os 
import json 

from alfworld.agents.utils.misc import add_task_to_grammar, Demangler
import textworld

import gym 

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import torch 

import numpy as np 
import sys 
import copy 

import argparse

# Root directory of the ALFWorld data. It is set at import time, i.e. before
# main() builds the environment. The empty value looks like a placeholder to be
# filled in before running. NOTE(review): confirm the expected path.
os.environ['ALFWORLD_DATA'] = ""


def get_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with ``seed`` (int, default 0) and ``split``
        (str, default "eval_out_of_distribution").
    """
    option_specs = (
        ("--seed", int, 0),
        ("--split", str, "eval_out_of_distribution"),
    )
    arg_parser = argparse.ArgumentParser()
    for flag, kind, fallback in option_specs:
        arg_parser.add_argument(flag, type=kind, default=fallback)
    return arg_parser.parse_args()


def process_ob(ob):
    """Strip boilerplate from a raw ALFWorld observation string.

    - Starts with 'You arrive at loc ': drop everything up to and including
      the first '. ', keeping only the description that follows. If there is
      no '. ', the string is returned unchanged.
    - Starts with '-= Welcome' (the game intro): drop the first paragraph
      (paragraphs are separated by blank lines) and join the remaining
      paragraphs with single newlines.
    - Anything else is returned unchanged.

    Args:
        ob: a single observation string.

    Returns:
        The cleaned observation string.
    """
    if ob.startswith('You arrive at loc '):
        # partition() instead of find(): when '. ' was absent, find() returned
        # -1 and the old slice ob[-1 + 2:] silently dropped the first character.
        _, sep, rest = ob.partition('. ')
        if sep:
            ob = rest
    elif ob.startswith('-= Welcome'):
        ob = '\n'.join(ob.split('\n\n')[1:])
    return ob


def get_facts_location(facts_):
    """Describe every 'inreceptacle' fact as an English sentence.

    Each fact is indexed positionally. ``fact[0]`` is the predicate name; for
    'inreceptacle' facts, ``fact[1]`` and ``fact[2]`` are rendered as
    "<fact[1]> is in <fact[2]>. ". Facts with any other predicate are ignored.

    Returns:
        The sentences concatenated in input order (each ends with a trailing
        space), or "" when no fact matches.
    """
    return "".join(
        fact[1] + " is in " + fact[2] + ". "
        for fact in facts_
        if fact[0] == "inreceptacle"
    )


def _episode_name(info):
    """Return an id for the game loaded in the (single) env.

    The id is the two path components immediately preceding the file name of
    ``info['extra.gamefile'][0]``, joined by '/'. ``main`` matches its start
    against the task-type prefixes.
    """
    return '/'.join(info['extra.gamefile'][0].split('/')[-3:-1])


def _action_candidates(admissible_commands):
    """Build the action list offered to the agent for each env in the batch.

    The env's 'look' and 'inventory' commands are dropped, and the agent-side
    pseudo-actions 'think' and 'ask' are appended. New lists are built, so
    ``info`` is never mutated. A missing 'look'/'inventory' no longer raises
    ``ValueError``, as the old ``list.remove`` did.
    """
    excluded = ('look', 'inventory')
    return [
        [cmd for cmd in commands if cmd not in excluded] + ['think', 'ask']
        for commands in admissible_commands
    ]


def _task_type_index(name, prefixes):
    """Return the position of the first key of ``prefixes`` that ``name``
    starts with, or ``None`` if no key matches."""
    for idx, prefix in enumerate(prefixes):
        if name.startswith(prefix):
            return idx
    return None


def _start_episode(env, human_model, agent):
    """Reset ``env`` and prime both models for the new game.

    Prints the game id, initializes the human model's prompt with the
    object-location facts, and initializes the agent's prompt with the game id.

    Returns:
        ``(ob, info, name)``: the processed first observation with the " >"
        turn separator appended, the env info dict, and the game id.
    """
    ob, info = env.reset()
    ob = process_ob(ob[0]) + " >"
    name = _episode_name(info)
    print(name)
    facts = get_facts_location(info['facts'][0])
    human_model.init_prompt(facts)
    agent.init_prompt(name)
    return ob, info, name


def main(args):
    """Evaluate the ask/think agent on every game of an ALFWorld eval split.

    At each step the agent picks from the env's admissible commands (minus
    'look'/'inventory') plus the pseudo-actions 'think' and 'ask'.

    - The chosen action is always sent to ``env.step``.
    - For 'ask <question>' (when the episode is not over), the observation is
      replaced by ``human_model.answer(question)``.
    - For 'think' (when the episode is not over), the observation is replaced
      by 'OK.'.

    The per-step reward is the env's 'won' flag. After every episode, the
    summed reward, the episode count and ``human_model.ask_cnt`` are credited
    to that episode's task type, and the running statistics are printed.

    Args:
        args: namespace with ``seed`` and ``split`` (see ``get_args``).

    Raises:
        NotImplementedError: if ``args.split`` is not one of the eval splits.
    """
    set_seed(args.seed)

    # === initialize the environment ===
    with open('') as reader:
        config = yaml.safe_load(reader)

    env_batch_size = 1

    split = args.split  # only "eval_out_of_distribution" / "eval_in_distribution" are supported
    print("seed:", args.seed)
    print("split:", args.split)
    if split == 'eval_out_of_distribution':
        episode_total = 134
    elif split == 'eval_in_distribution':
        episode_total = 140
    else:
        raise NotImplementedError

    alfred_env = getattr(alfworld.agents.environment, config["env"]["type"])(config, train_eval=split)
    env = alfred_env.init_env(batch_size=env_batch_size)

    model_path_human_model = ""
    model_path_agent_model = ""
    tokenizer_path = ""

    peft_path = None

    # === initialize the human model ===
    human_model = HumanModel(model_path=model_path_human_model,
                             tokenizer_path=tokenizer_path,
                             prompt_path="")

    # === initialize the agent ===
    agent = AgentModel(prompt_path="",
                       model_path=model_path_agent_model,
                       tokenizer_path=tokenizer_path,
                       peft_path=peft_path)

    reward_cnt_per_type = [0] * 6
    episode_cnt_per_type = [0] * 6
    ask_cnt_per_type = [0] * 6

    # Task-type prefix -> slot in the *_per_type counters (insertion order).
    prefixes = {
        'pick_and_place': 0,
        'pick_clean_then_place': 1,
        'pick_heat_then_place': 2,
        'pick_cool_then_place': 3,
        'look_at_obj': 4,
        'pick_two_obj': 5
    }

    episode_reward = 0

    # === start evaluation ===
    ob, info, name = _start_episode(env, human_model, agent)
    print(ob)

    while np.sum(episode_cnt_per_type) < episode_total:
        admissible_actions = _action_candidates(info['admissible_commands'])
        action = agent.act_greedy_search(ob, admissible_actions[0])

        observation, reward, done, info = env.step([action])
        # Reward is the 'won' flag; the episode also ends when the env says done.
        observation, reward, done = process_ob(observation[0]), info['won'][0], (info['won'][0] or done[0])

        # Keep the human model's object-location facts in sync with the env.
        facts = get_facts_location(info['facts'][0])
        human_model.update_facts(facts)

        if action.startswith('ask'):
            # 'ask <question>': the reply comes from the human model, not the env.
            question = action[4:]
            if not done:
                observation = human_model.answer(question.strip())
        elif action.startswith('think'):
            if not done:
                observation = 'OK.'

        episode_reward += reward

        print(action)
        print(observation)
        sys.stdout.flush()

        ob += f' {action} > {observation} >'

        if not done:
            continue

        # Credit the finished episode to ITS OWN task type. This must happen
        # before env.reset(): previously `name` was recomputed from the next
        # game first, so stats were attributed to the wrong task type.
        type_idx = _task_type_index(name, prefixes)
        if type_idx is None:
            print("warning: episode not counted, unknown task type:", name)
        else:
            reward_cnt_per_type[type_idx] += episode_reward
            episode_cnt_per_type[type_idx] += 1
            ask_cnt_per_type[type_idx] += human_model.ask_cnt
        episode_reward = 0

        # Types with no episodes yet divide by zero -> nan; that is expected
        # in this progress print, so the RuntimeWarning is silenced.
        with np.errstate(divide='ignore', invalid='ignore'):
            avg_per_type = np.array(reward_cnt_per_type) / np.array(episode_cnt_per_type)
            avg_overall = np.sum(reward_cnt_per_type) / np.sum(episode_cnt_per_type)
        print("perR:", reward_cnt_per_type, "\tperE:", episode_cnt_per_type, "\tperAsk:", ask_cnt_per_type,
              "\tperAvgR:", avg_per_type,
              "\tAvgR:", avg_overall, "\tSumE:", np.sum(episode_cnt_per_type))
        sys.stdout.flush()

        # Stop before an unnecessary reset once the split is exhausted.
        if np.sum(episode_cnt_per_type) >= episode_total:
            break

        ob, info, name = _start_episode(env, human_model, agent)


if __name__ == '__main__':
    # Script entry point: parse CLI options and run the evaluation.
    main(get_args())