import sys 
sys.path.append(".")
from source.language_models.LLM import _LLM
from source.collection.collect_data_minigrid import IDX_TO_COLOR, IDX_TO_OBJECT, IDX_TO_STATE, IDX_TO_DIRECTION, obs_to_state
import gymnasium as gym
from minigrid.wrappers import FullyObsWrapper
import argparse
import numpy as np
from source.language_models.prompt_templates import prompt_sequence_minigrid, prompt_sequence_ICL_minigrid
from tqdm import tqdm 
from collections import defaultdict
import pandas as pd 
import os 
from multiprocessing import Pool
import re
import random
import glob
from collections import defaultdict
import pickle
import dill
import subprocess
from data.generate_alter_grammar import generate_alternative_goal

# Root folder of the datasets; the __main__ section below derives `path_folder` from it.
data_folder = "./data/datasets/"

# Registry of state-abstraction callables. online_evaluation_minigrid() stores the
# callable under the "LLM" key and obs_to_prompt() looks it up there.
dict_abstract = {}

def obs_to_prompt(envs, goal, model_name, seed=None, env_id=None):
    """Build one text prompt per environment from its current observation.

    :param envs: dict keyed by ``(env_id, seed)``; each value must provide the
        keys ``"obs"``, ``"env"``, ``"goal"`` and ``"initial_goal"``.
    :param goal: goal text inserted in the prompt. The sentinel ``"undefined"``
        means "use the goal stored in each environment entry".
    :param model_name: selects the prompt template. A name containing ``"ICL"``
        uses the in-context-learning template and must end with
        ``_<number of goals>``; a name containing ``"Llama"`` gets the
        ``"Sequence of actions : ["`` suffix appended.
    :param seed: when not ``None``, only the entry ``(env_id, seed)`` is processed.
    :param env_id: environment id paired with ``seed`` in single-entry mode.
    :return: list of prompts, one per processed environment, in iteration order.
    """
    if seed is not None:
        keys = [(env_id, seed)]
    else:
        keys = envs.keys()

    prompts = []
    for current_env_id, current_seed in keys:
        entry = envs[(current_env_id, current_seed)]

        # Resolve the goal per environment in a local variable. Overwriting the
        # `goal` argument here would replace the "undefined" sentinel after the
        # first iteration, so every later environment would be prompted with the
        # first environment's goal.
        if goal == "undefined":
            env_goal = entry["goal"]
        else:
            env_goal = goal
        env_goal = env_goal.lower()

        raw_obs = entry["obs"]
        inventory = entry["env"].unwrapped.carrying
        state = dict_abstract["LLM"](
            obs_to_state(raw_obs, inventory), current_env_id, entry["initial_goal"]
        )
        grid_size = (raw_obs["image"].shape[0], raw_obs["image"].shape[1])

        if "ICL" in model_name:
            nb_goal = int(model_name.split("_")[-1])
            prompt = prompt_sequence_ICL_minigrid(grid_size, state, goal=env_goal, nb_goal=nb_goal)
        else:
            prompt = prompt_sequence_minigrid(grid_size, state, goal=env_goal)
            if "Llama" in model_name:
                # Open the answer list so the model continues with the actions.
                prompt += "Sequence of actions : ["

        prompts.append(prompt)
    return prompts


def generate_sequence(envs, model, seed=None, env_id=None, goal=None, abstraction="manual"):
    """Query the model for actions and append them to the stored sequences.

    Prompts are built with ``obs_to_prompt``. When ``seed`` is given, only the
    entry ``(env_id, seed)`` is prompted and extended; otherwise every entry of
    ``envs`` is prompted and extended with the completion at the same position.
    The cleaned completion is a string and is added to the ``"sequence"`` list
    with ``+=``, i.e. one element per character. ``abstraction`` is accepted
    but not used here.
    """
    prompts = obs_to_prompt(envs, goal, model.name, seed, env_id)

    if "Llama" in model.name:
        completions = model.get_action(prompts, temperature=0.9, top_p=0.95, top_k=10, max_new_tokens=200)
        with_icl = "ICL" in model.name
        outputs = []
        for completion in completions:
            text = completion[0]
            # Keep only what precedes the first closing bracket.
            text = text.split("]")[0] if "]" in text else text.replace("]", "")
            if with_icl:
                # Keep the segment that follows the first opening bracket.
                if "[" in text:
                    text = text.split("[")[1]
                else:
                    text = text.replace("[", "").replace("]", "")
            # First line only, then drop everything that is not a digit.
            text = text.split("\n")[0]
            text = re.sub("[^0-9]", "", text)
            outputs.append([text])
    else:
        outputs = model.get_action(prompts, temperature=0.0, top_p=0.1, top_k=1, max_new_tokens=10)

    cleaned = []
    for output in outputs:
        text = output[0]
        for separator in (",", " ", "[", "]"):
            text = text.replace(separator, "")
        cleaned.append(text.strip())

    if seed is not None:
        envs[(env_id, seed)]["sequence"] += cleaned[0]
        return

    for position, key in enumerate(envs):
        envs[key]["sequence"] += cleaned[position]
    

def process_sequence_actions(envs, model, t, goal, abstraction, max_retries=5):
    """Select the action to play at step ``t`` for every unfinished environment.

    For each entry of ``envs`` whose ``"done"`` flag is false, the action is
    read from ``entry["sequence"][t]`` and stored as an ``int`` in
    ``entry["action"]``. When the stored sequence is too short, new actions are
    requested through ``generate_sequence``.

    :param envs: dict keyed by ``(env_id, seed)``.
    :param model: model forwarded to ``generate_sequence``.
    :param t: index of the current step in the action sequence.
    :param goal: goal forwarded to ``generate_sequence``.
    :param abstraction: abstraction name forwarded to ``generate_sequence``.
    :param max_retries: maximum number of generation attempts per environment
        and per step. Without a bound, a model that keeps returning a
        completion with no usable character would loop forever.
    """
    for (env_id, seed), entry in envs.items():
        if entry["done"]:
            continue

        attempts = 0
        while t >= len(entry["sequence"]) and attempts < max_retries:
            generate_sequence(envs, model, seed, env_id, goal, abstraction)
            attempts += 1

        # The model produced nothing usable: pad with random actions, drawn from
        # the same range as the fallback for unparsable tokens below, so that
        # index t exists and later steps stay aligned.
        while t >= len(entry["sequence"]):
            entry["sequence"].append(str(np.random.randint(0, 4)))

        token = str(entry["sequence"][t])
        # Smallest digit in 0..6 contained in the token, if any.
        action = next((digit for digit in range(7) if str(digit) in token), None)
        if action is None:
            action = np.random.randint(0, 4)
        entry["action"] = int(action)


def envs_step(envs, step_id):
    """Play the stored action in every unfinished environment and update its bookkeeping.

    For each entry not flagged ``"done"``: steps the environment with
    ``entry["action"]``, increments ``"out of distrib"`` when both the image and
    the direction of the observation are unchanged, adds the reward to
    ``"cum_reward"``, records ``"done"``/``"length"`` (set to ``step_id``) when
    the third value returned by ``step`` is true, and stores the new observation.
    """
    for entry in envs.values():
        if entry["done"]:
            continue

        new_obs, reward, finished, _, _ = entry["env"].step(entry["action"])

        previous_obs = entry["obs"]
        same_image = np.array_equal(new_obs["image"], previous_obs["image"])
        if same_image and (new_obs["direction"] == previous_obs["direction"]):
            # The action left the observation untouched.
            entry["out of distrib"] += 1

        entry["cum_reward"] += reward
        if finished:
            entry["done"] = True
            entry["length"] = step_id
        entry["obs"] = new_obs

def all_envs_done(envs):
    """Return True when every entry of ``envs`` has a truthy ``"done"`` flag (True if empty)."""
    return all(entry["done"] for entry in envs.values())


        
            

def online_evaluation_minigrid(envs, model, goal, num_steps=240, n_seed=1, model_name="Llama8", alternative_goal=False, env_id=0, abstraction="manual"):
    """
    Run an LLM policy in every environment of ``envs`` and collect episode statistics.

    :param envs: dict keyed by ``(env_id, seed)``; each value must hold the wrapped
        environment under ``"env"``. The bookkeeping keys (``"cum_reward"``,
        ``"done"``, ``"sequence"``, ``"goal"``, ``"initial_goal"``, ``"initial_pos"``,
        ``"obs"``, ...) are (re)initialised here. ``env_id`` is a string, either a
        plain integer or ``<name>_<integer>``; the integer seeds ``reset``.
    :param model: only used to register goals (``add_goal``) when ``model_name``
        contains ``"Qbot"`` or ``"BCbot"``; the model used for generation is built
        inside this function from ``model_name``.
    :param goal: goal given to the policy; ``"undefined"`` means the mission of
        each environment is used instead.
    :param num_steps: maximum number of steps played per environment.
    :param n_seed: not used by this function.
    :param model_name: ``"Llama8"``/``"Llama70"`` select the corresponding
        Meta-Llama-3 Instruct model; any other value is loaded as a fine-tuned
        model from ``./results/models/<model_name>``.
    :param alternative_goal: when true, the goal is replaced by a random choice
        among ``generate_alternative_goal(goal)``.
    :param env_id: not used; the name is rebound by the loops over ``envs``.
    :param abstraction: forwarded to the sequence-generation helpers.
    :return: ``defaultdict`` mapping ``(env_id, seed)`` to a dict with the keys
        ``"reward"``, ``"length_episode"``, ``"goal"`` and ``"out of distrib"``.
    :raises subprocess.CalledProcessError: if the abstraction subprocess fails.
    """

    for (env_id, seed) in envs:
        entry = envs[(env_id, seed)]
        entry["cum_reward"] = 0
        entry["broken"] = False
        entry["done"] = False
        entry["out of distrib"] = 0
        entry["action"] = None
        entry["sequence"] = []

        if "_" in env_id:
            env_id_int = int(env_id.split("_")[1])
        else:
            env_id_int = int(env_id)

        obs, _ = entry["env"].reset(seed=env_id_int)
        width = obs["image"].shape[0]
        height = obs["image"].shape[1]

        # Draw cells until an empty one is found, then put the agent there.
        pos = random.randint(1, width - 1), random.randint(1, height - 1)
        while entry["env"].grid.get(*pos) is not None:
            pos = random.randint(1, width - 1), random.randint(1, height - 1)

        entry["env"].place_agent_noRandom(pos[0], pos[1])
        entry["goal"] = goal
        if goal == "undefined":
            for _ in range(seed):
                try:
                    entry["env"].reroll_mission()
                except Exception:
                    # Kept as a best-effort fallback, but without swallowing
                    # KeyboardInterrupt/SystemExit as a bare `except:` would.
                    entry["env"].place_agent()

            entry["goal"] = entry["env"].mission.replace("the", "a")
            if "Qbot" in model_name or "BCbot" in model_name:
                model.add_goal(entry["goal"])

        # Read back from the environment: place_agent() above may have moved the agent.
        entry["initial_pos"] = entry["env"].agent_pos
        entry["initial_goal"] = entry["goal"]

        if alternative_goal:
            entry["goal"] = np.random.choice(generate_alternative_goal(entry["goal"]))

        entry["obs"] = entry["env"].observation(entry["env"].gen_obs())

    # One observation (the last one seen) and the list of initial goals per env_id,
    # handed over to the abstraction subprocess through the cache folder.
    dict_env_id = defaultdict(lambda: {"obs": None, "goal": []})
    for (env_id, seed) in envs:
        entry = envs[(env_id, seed)]
        inv = entry["env"].unwrapped.carrying
        dict_env_id[env_id]["obs"] = obs_to_state(entry["obs"], inv)
        dict_env_id[env_id]["goal"].append(entry["initial_goal"])

    os.makedirs("./cache", exist_ok=True)

    with open("./cache/dict_env_id.pkl", "wb") as f:
        pickle.dump(dict(dict_env_id), f)

    # check=True: fail loudly instead of loading a stale or missing cache file below.
    subprocess.run(["python", "source/abstraction/abstraction_process.py"], check=True)

    # NOTE: dill.load executes arbitrary code from the file; acceptable only
    # because the cache is written by the subprocess launched just above.
    with open('./cache/dict_abstract.pkl', 'rb') as handle:
        abstraction_builder = dill.load(handle)

    dict_abstract["LLM"] = lambda obs, env_id, goal: abstraction_builder.abstract(obs, env_id, goal)

    # Use the `model_name` parameter, not the script-level `args`, so the function
    # also works when this module is imported.
    if "Llama8" in model_name:
        model = _LLM(model_name="meta-llama/Meta-Llama-3-8B-Instruct")
    elif "Llama70" in model_name:
        model = _LLM(model_name="meta-llama/Meta-Llama-3-70B-Instruct")
    else:
        model = _LLM(model_name="FT_model", weight_path="./results/models/" + model_name, GPUs=[0, 1])

    model.name = model_name

    generate_sequence(envs, model, goal=goal, abstraction=abstraction)

    for i in tqdm(range(num_steps)):
        process_sequence_actions(envs, model, i, goal, abstraction=abstraction)
        envs_step(envs, i)
        if all_envs_done(envs):
            break

    for (env_id, seed) in envs:
        envs[(env_id, seed)]["env"].close()

    results = defaultdict(dict)
    for (env_id, seed) in envs:
        entry = envs[(env_id, seed)]
        if entry["broken"]:
            continue
        summary = results[(env_id, seed)]
        summary["reward"] = entry["cum_reward"]
        # Unfinished episodes are reported with the full step budget.
        summary["length_episode"] = entry["length"] if entry["done"] else num_steps
        summary["goal"] = entry["goal"]
        summary["out of distrib"] = entry["out of distrib"]

    return results

def _append_result_row(row, path_result):
    """Append ``row`` (a dict of scalars) as one CSV line; the header is written only for a new file."""
    frame = pd.DataFrame(row, index=[0])
    frame.to_csv(path_result, mode="a+", header=not os.path.exists(path_result))


def eval_per_goal(args, model, goal):
    """Evaluate one goal on every environment of ``args.env_id`` and append the results to a CSV.

    One environment is created per ``(env_id, seed)`` with ``seed`` in
    ``range(args.n_seed)``. The gym id is ``args.env`` when it is set, otherwise
    the part of ``env_id`` before the first ``"_"``.

    When ``goal`` is ``"undefined"``, one row per episode is written; otherwise
    one row per ``env_id`` with statistics aggregated over the seeds. Rows go to
    ``results/online_evaluation/<output_file>_<model_name>_<num_steps>_<n_seed>.csv``.

    :param args: parsed command-line namespace (``env``, ``env_id``, ``n_seed``,
        ``num_steps``, ``model_name``, ``alternative_goal``, ``output_file``).
    :param model: forwarded to ``online_evaluation_minigrid``.
    :param goal: goal text, or ``"undefined"``.
    """
    envs = {}
    for env_id in args.env_id:
        gym_id = env_id.split("_")[0] if args.env == "" else args.env
        for seed in range(args.n_seed):
            envs[(env_id, seed)] = {"env": FullyObsWrapper(gym.make(gym_id, render_mode="rgb_array", max_steps=args.num_steps))}

    results = online_evaluation_minigrid(envs, model, goal, args.num_steps, args.n_seed, model_name=args.model_name, alternative_goal=args.alternative_goal, abstraction="LLM")

    os.makedirs("results/online_evaluation", exist_ok=True)
    path_result = f"results/online_evaluation/{args.output_file}_{args.model_name}_{args.num_steps}_{args.n_seed}.csv"

    for ref_env_id in args.env_id:
        env_keys = [(env_id, seed) for (env_id, seed) in results if env_id == ref_env_id]

        if goal == "undefined":
            # One CSV row per episode.
            for key in env_keys:
                reward = results[key]["reward"]
                length_episode = results[key]["length_episode"]
                out_of_distrib_ratio = results[key]["out of distrib"] / max(1, length_episode)
                initial_pos = envs[key]["initial_pos"]
                success = 1 if reward > 0 else 0
                print("----------------------------------------------------")
                print(goal)
                print(ref_env_id)
                print(f"The cumulative reward is {reward}")
                print(f"The length of the episode is {length_episode}")
                print(f"The success rate is {success}")
                print(f"The out of distribution ratio is {out_of_distrib_ratio}")
                _append_result_row(
                    {
                        "env_id": ref_env_id,
                        "goal": results[key]["goal"],
                        "cumulative reward": reward,
                        "episode_length": length_episode,
                        "success rate": success,
                        "ood ratio": out_of_distrib_ratio,
                        "size": 1,
                        "initial_pos": f"{initial_pos[0]}_{initial_pos[1]}",
                        "initial_goal": envs[key]["initial_goal"],
                    },
                    path_result,
                )
        else:
            # One CSV row per env_id, aggregated over the seeds.
            list_cumulative_reward = [results[key]["reward"] for key in env_keys]
            list_length_episode = [results[key]["length_episode"] for key in env_keys]
            out_of_distrib_ratio = [
                results[key]["out of distrib"] / max(1, length)
                for key, length in zip(env_keys, list_length_episode)
            ]
            # Computed once, with the guarded denominator, so the printed value
            # and the stored value always agree.
            success_rate = np.sum([1 for x in list_cumulative_reward if x > 0]) / max(1, len(list_cumulative_reward)) * 100

            print(goal)
            print(ref_env_id)
            print(f"The cumulative reward is {np.mean(list_cumulative_reward)}")
            print(f"The length of the episode is {np.mean(list_length_episode)}")
            print(f"The success rate is {success_rate}")
            print(f"The out of distribution ratio is {np.mean(out_of_distrib_ratio)}")
            print("path result : ", path_result)
            _append_result_row(
                {
                    "env_id": ref_env_id,
                    "goal": goal,
                    "cumulative reward": np.mean(list_cumulative_reward),
                    "episode_length": np.mean(list_length_episode),
                    "success rate": success_rate,
                    "ood ratio": np.mean(out_of_distrib_ratio),
                    "cumulative reward std": np.std(list_cumulative_reward),
                    "episode_length std": np.std(list_length_episode),
                    "ood ratio std": np.std(out_of_distrib_ratio),
                    "size": len(list_length_episode),
                },
                path_result,
            )


if __name__ == "__main__":

    def _parse_bool(value):
        """argparse converter for boolean options.

        ``type=bool`` turns every non-empty string (including "False") into
        True; this converter reads the text instead.
        """
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    # Fixed seeds for the `random` and `numpy.random` generators.
    random.seed(0)
    np.random.seed(0)
    parser = argparse.ArgumentParser(description="Collect data from the minigrid environment")
    parser.add_argument("--dataset", type=str, default="BabyAI", help="The name of the dataset")
    parser.add_argument("--env", type=str, default="", help="Gym environment id used for every --env_id entry (default: taken from each env_id)")
    parser.add_argument("--num_steps", type=int, default=500, help="The number of steps to take")
    parser.add_argument("--n_seed", type=int, default=1, help="The number of seed to use")
    parser.add_argument("--output_file", type=str, default="", help="The name of the output file")
    parser.add_argument("--model_name", type=str, default="Llama8", help="The name of the model")
    parser.add_argument("--obs", type=_parse_bool, default=False, help="")
    parser.add_argument("--goals", nargs="+", type=str, default=["undefined"])
    parser.add_argument("--env_id", type=str, nargs="+", default=[])
    parser.add_argument("--alternative_goal", type=_parse_bool, default=False)
    args = parser.parse_args()

    path_folder = f"{data_folder}/"

    # No model is built here, whatever the model name: online_evaluation_minigrid
    # creates it from args.model_name. Parallel evaluation is disabled.
    model = None
    parallele = False

    if args.goals == []:
        # Discover the goals from the dataset folders: one sub-folder per goal
        # under <dataset>/env/<env_id>/.
        goal_per_env = True
        list_goals_per_env = defaultdict(list)
        for env_id in args.env_id:
            # NOTE(review): this used to read `args.name`, which the parser never
            # defines (AttributeError). `--dataset` is the only dataset-name
            # option, so it is assumed to be the intended one -- confirm.
            goals = [x.split("/")[-2] for x in glob.glob(f"{path_folder}{args.dataset.replace('_abstract','')}/env/{env_id}/*/")]

            for goal in goals:
                list_goals_per_env[goal].append(env_id)
        args.goals = list(list_goals_per_env.keys())
    else:
        goal_per_env = False

    if parallele:
        n_job = 40

        def run_process(goal):
            if goal_per_env:
                args.env_id = list_goals_per_env[goal]
            eval_per_goal(args, model, goal)

        # The context manager releases the worker processes once map() returns.
        with Pool(n_job) as pool:
            pool.map(run_process, args.goals)

    else:
        for goal in tqdm(args.goals, leave=False, desc="goals"):
            if goal_per_env:
                args.env_id = list_goals_per_env[goal]
            eval_per_goal(args, model, goal)