import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))

from transformers import AutoTokenizer, GemmaForCausalLM, AutoModelForCausalLM
import itertools
import time
import numpy as np
import random
import torch
import pickle as pkl

from ExRAP.src.model import LLMAgent
from ExRAP.src.knowledge_graph import KG
from ExRAP.src.retriever import EmbeddingFnsClass, InContextExpert
from ExRAP.src.value_fn import find_distance, information_function, action_information_gain, query_execution
from simulation.environment.unity_environment import UnityEnvironment
from simulation.environment.continual_instruction import TASKS_SET, QUERIES_SET, SUCCESS_CONDITIONS, OBJECTS, ROOMS, BASE_PROMPT, NON_STA_ACTION, NON_STA_SETTINGS

EXPLORATION_VALUE = True
TEMPORAL_CONSISTENCY = True

def icl_prompt_template(knowledge_graph, task, prev_action, action):
    """Render a four-line prompt entry that ends with a filled-in action.

    The driver script hands this function to ``InContextExpert.make_context``
    as the template for in-context examples.

    Args:
        knowledge_graph: Value shown after ``Knowledge graph:``.
        task: Value shown after ``Required executing callback:``.
        prev_action: Value shown after ``Previous action:``.
        action: Value shown after ``Next action:``.

    Returns:
        The four labelled lines joined by newlines, with no trailing newline.
    """
    labelled_lines = (
        f"Knowledge graph: {knowledge_graph}",
        f"Required executing callback: {task}",
        f"Previous action: {prev_action}",
        f"Next action: {action}",
    )
    return "\n".join(labelled_lines)

def prompt_template(knowledge_graph, history, task_list, prev_action, queries):
    """Assemble the full prompt whose last line is an open ``Next action:``.

    Layout, in order: ``BASE_PROMPT``, the comma-separated query set, the
    history entries (one per line), the knowledge graph line, the required
    callback line, the previous action line and an unfilled ``Next action:``.

    Args:
        knowledge_graph: Value shown after ``Knowledge graph:``.
        history: Iterable of strings; joined with newlines.
        task_list: Value shown after ``Required executing callback:``.
        prev_action: Value shown after ``Previous action:``.
        queries: Iterable of strings; joined with ``", "``.

    Returns:
        The assembled prompt string.
    """
    query_line = ", ".join(queries)
    history_block = "\n".join(history)
    sections = [
        BASE_PROMPT,
        "\nSet of queries: {}\n".format(query_line),
        history_block + "\n",
        f"Knowledge graph: {knowledge_graph}\n",
        f"Required executing callback: {task_list}\n",
        f"Previous action: {prev_action}\n",
        "Next action:",
    ]
    return "".join(sections)

if __name__ == "__main__":
    # Script entry point: loads the language model, starts the Unity
    # environment, then runs a 200-step loop in which the agent scores
    # candidate actions, optionally adds an information-gain bonus, executes
    # the best one and prints running statistics.

    # Successfully executed actions only; the last entry is fed to the prompt
    # as "Previous action".
    action_history = ["First step"]
    # Every executed action, successful or not.  This list was previously
    # appended to without being created, which raised NameError on step one.
    action_history_full = []
    # Upper bound on how many of the ranked actions are re-scored each step.
    max_candidates = 40
    model_name = "meta-llama/Meta-Llama-3-8B" #"google/gemma-7b"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): this registers a new '[PAD]' token but the model's
    # embedding matrix is not resized here -- confirm LLMAgent never feeds the
    # pad id through the embedding layer, or call resize_token_embeddings.
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)

    # Select the queries, executions, success conditions and non-stationary
    # actions that belong to the chosen task indices.
    TASK_IDX = [0, 1, 2, 3, 4]
    queries = [QUERIES_SET[i] for i in TASK_IDX]
    executions = [TASKS_SET[i] for i in TASK_IDX]
    success_condition = [SUCCESS_CONDITIONS[i] for i in TASK_IDX]
    non_sta_action = [NON_STA_ACTION[i] for i in TASK_IDX]
    non_sta_setting = NON_STA_SETTINGS["medium"]

    # Environment start
    env = UnityEnvironment(url='127.0.0.1', seed=int(time.time()))
    env.set_task({"required_condition":  success_condition, "prohibited_condition": []})

    obs = env.reset(environment_id=0)
    print(env.task_relevant_orginal_pos)
    # Append the recorded original position to every query that mentions the
    # corresponding key.  Later keys are matched against the already-extended
    # query text, exactly as before.
    for idx, query in enumerate(queries):
        for key in env.task_relevant_orginal_pos:
            if key in query:
                query = query + env.task_relevant_orginal_pos[key]
        queries[idx] = query
    embedding_fns = EmbeddingFnsClass("facebook/dpr-ctx_encoder-single-nq-base")

    agent = LLMAgent(tokenizer, model, KG(env.get_position_graph()), embedding_fns=embedding_fns)
    agent.kg_add(obs['visible_graph'], 0, use_refinement=True)
    agent.kg_add(obs['agent_graph'], 0, use_refinement=True)
    query_temp_eval, inferred_query_execs, selected_executions = agent.evaluate_query(queries, executions, temporal_consistency=TEMPORAL_CONSISTENCY)
    # Only the first selected execution is pursued at a time.
    # NOTE(review): if evaluate_query ever returns no selected executions,
    # temp_execution is empty and temp_execution[0] below raises IndexError.
    temp_execution = selected_executions[:1]
    failure_action, reward_mean, aoe, earn_reward, valid_timesteps = [], [], [], 0, 0
    accurate, total_query = 0, 0

    # The action catalogue depends only on imported constants, so build it
    # once instead of on every step.  Order: walk, then grab/switch/open,
    # then put/putin.
    base_action_list = [
        ' '.join(comb)
        for verbs, targets in (
            (["walk"], OBJECTS + ROOMS),
            (["grab", "switch", "open"], OBJECTS),
            (["put", "putin"], OBJECTS),
        )
        for comb in itertools.product(verbs, targets)
    ]

    # In-context examples are loaded from a small demonstration file.
    ICL_Learner = InContextExpert("demonstrations.pkl")
    for timesteps in range(200):
        start_time = time.time()
        temp_kg, icl_prompt = ICL_Learner.make_context(agent, temp_execution, icl_prompt_template, queries)
        prompts = prompt_template(", ".join(temp_kg), icl_prompt, ", ".join(temp_execution), action_history[-1], queries)
        make_prompt_time = time.time()

        # Exclude actions that failed since the last successful step.
        # Filtering (rather than list.remove) cannot raise if an entry is
        # missing from the catalogue.
        failed_actions = set(failure_action)
        action_list = [act for act in base_action_list if act not in failed_actions]

        action, action_value = agent.predict(prompts, action_list, return_list=True)
        prediction_time = time.time()

        # Columns 0 and 1 each receive half of column 2.
        # presumably column 2 is an "unknown" mass split over two outcomes --
        # TODO confirm against agent.evaluate_query.
        query_temp_eval = np.array(query_temp_eval)
        binary_query_temp_eval = query_temp_eval[:, :2] + np.stack([query_temp_eval[:, 2], query_temp_eval[:, 2]], axis=-1)/2
        # Cap at the number of actions actually returned so that a short
        # action list no longer raises IndexError.
        num_candidates = min(max_candidates, len(action_value))
        candidate_actions = [action_value[i][0] for i in range(num_candidates)]
        temp_info, act_info_list, temp_info_timesteps = action_information_gain(agent.knowledge_graph, candidate_actions, queries_list=queries, temp_query_eval=binary_query_temp_eval, timestep=timesteps+1, embedding_fns=agent.embedding_fns)

        # The bonus regime is the same for every candidate in this step.
        exploring = temp_execution[0] == "Explore the home" or temp_execution[0].split()[0] == "Find"
        pending_executions = len(executions) - len(selected_executions)
        final_actions = []
        for i in range(num_candidates):
            candidate, base_value = action_value[i][0], action_value[i][1]
            if not EXPLORATION_VALUE:
                final_actions.append((candidate, base_value))
            elif exploring:
                # The 1e+6 weight lets the information-gain difference
                # dominate the model score while exploring or searching.
                final_actions.append((candidate, base_value + (act_info_list[i] - temp_info) * 1e+6))
            else:
                final_actions.append((candidate, base_value + (act_info_list[i] - temp_info) * 10 * pending_executions))
        # One stable sort after the loop yields the same order as re-sorting
        # after every append, without the repeated work.
        final_actions = sorted(final_actions, key=lambda x: x[1], reverse=True)

        information_cal_time = time.time()

        chosen_action = final_actions[0][0]
        obs, reward, done, info = env.step(chosen_action)
        if not info['success']:
            # Remember the failure so the action is excluded on the next step.
            failure_action.append(chosen_action)
        else:
            failure_action = []
            action_history.append(chosen_action)
        agent.kg_add(obs['visible_graph'], timesteps+1, use_refinement=True)
        agent.kg_add(obs['agent_graph'], timesteps+1, use_refinement=True)
        step_and_update_time = time.time()

        query_temp_eval, inferred_query_execs, selected_executions = agent.evaluate_query(queries, executions, query_temp_eval, temporal_consistency=TEMPORAL_CONSISTENCY)
        # Keep pursuing the current execution while it is still selected.
        if temp_execution[0] not in selected_executions:
            temp_execution = selected_executions[:1]

        query_execution_time = time.time()
        reward_mean.append(reward)

        # Every non_sta_setting steps (including step 0) move one randomly
        # chosen object, retrying until the environment accepts the change.
        # NOTE(review): this loops forever if reset_object_pos never succeeds.
        if timesteps % non_sta_setting == 0:
            non_sta_check = False
            while not non_sta_check:
                next_change = random.choice(non_sta_action)
                non_sta_check = env.reset_object_pos(next_change[0], next_change[1])

        # Count the steps on which the reward differs from the number of
        # executions; used for the "Path length" figure below.
        if reward != len(executions):
            valid_timesteps += 1

        action_history_full.append(chosen_action)
        earn_reward += info['earn_reward']
        # Mean difference between the current step and the clipped per-query
        # timesteps returned by action_information_gain.
        aoe.append(np.mean((timesteps + 1) - np.clip(np.array(temp_info_timesteps), a_min=-1, a_max=200)))
        print("========================================================")
        print("Temp. action: ", chosen_action)
        print("Timestep: {} | Goal Success: {} | Path length: {}".format(timesteps + 1, earn_reward, valid_timesteps / max(earn_reward, 1)))
        print("Avg. age of information = {}".format(np.mean(aoe)))
        print("Temp reward: {} | Mean reward: {}".format(reward, np.mean(reward_mean)))
        print("Make prompts time:%.5f"%(make_prompt_time - start_time))
        print("Prediction time:%.5f"%(prediction_time - make_prompt_time))
        print("Calculate information gain time:%.5f"%(information_cal_time - prediction_time))
        print("Env step and update time:%.5f"%(step_and_update_time - information_cal_time))
        print("Query execution time time:%.5f"%(query_execution_time - step_and_update_time))
        print("========================================================")
        print("")