import argparse
import sys 
sys.path.append(".")
from source.language_models.LLM import _LLM
from source.language_models.prompt_templates import prompt_goal_detection_minigrid, remove_abstract
import pandas as pd
import glob
import itertools
import random
from tqdm import tqdm
import os 
import glob 
from multiprocessing import Pool
from itertools import groupby


# Root directory that holds every dataset. It is concatenated directly with the
# dataset name inside f-strings (no path join), so the trailing slash is required.
data_folder = "./data/datasets/"


def filter_state(state, goal):
    """Return only the lines of ``state`` that matter for goal detection.

    A line is kept when it mentions ``"goal"``, ``"agent is currently at"`` or
    ``"Inventory"``. When the goal text contains ``"open"``, lines containing
    ``"closed"`` are kept as well. The surviving lines are re-joined with
    newlines in their original order.
    """
    always_relevant = ("goal", "agent is currently at", "Inventory")

    def is_relevant(line):
        for marker in always_relevant:
            if marker in line:
                return True
        # Only consulted when none of the generic markers matched.
        return "open" in goal and "closed" in line

    return "\n".join(line for line in state.split("\n") if is_relevant(line))




def _parse_goal_achieved(answer):
    """Extract the LLM verdict from one answer: 1 (true), 0 (false) or None if unparsable."""
    try:
        verdict = answer[0].lower().split('goal achieved":')[1].split(")")[0].strip()
    except (IndexError, KeyError, TypeError, AttributeError):
        # Marker missing or answer not shaped as expected: treat as "no label".
        return None

    if "true" in verdict:
        return 1
    if "false" in verdict:
        return 0
    return None


def generate_dataset_per_goal(args, data_folder, goal, env_id):
    """Build and save the goal-detection dataset for one (environment, goal) pair.

    Reads ``collected_data_abstract.csv`` from the goal folder, reduces every
    distinct observation to its goal-relevant lines, asks the module-level
    ``llm`` whether the goal is achieved in each resulting state, and writes
    the labelled states to ``goal_detection_dataset.csv`` in the same folder.
    States whose answer cannot be parsed as true/false are left out.

    Args:
        args: Parsed CLI namespace. ``name``, ``size`` and ``grid_size`` are
            read; ``shuffle`` is honoured when present (defaults to True).
        data_folder: Root datasets directory, including its trailing slash.
        goal: Goal text; also the name of the goal sub-folder.
        env_id: Name of the environment sub-folder.
    """
    path_folder = f"{data_folder}{args.name}/env/{env_id}/{goal}"
    obs_df = pd.read_csv(f"{path_folder}/collected_data_abstract.csv")

    # Drop transitions whose observation contains the malformed "A open" text.
    # The explicit "== False" also discards rows where "obs" is missing (NaN).
    has_bug = obs_df["obs"].str.contains("A open")
    obs_df = obs_df[has_bug == False]  # noqa: E712

    distinct_obs = set(obs_df["next_obs"].drop_duplicates().to_list()) | set(
        obs_df["obs"].drop_duplicates().to_list()
    )

    # Different raw observations can collapse to the same filtered text; keep
    # one copy of each. Sorting makes the order, and therefore the seeded
    # shuffle below, independent of set iteration order.
    filtered_obs = sorted({filter_state(x, goal) for x in distinct_obs})

    if getattr(args, "shuffle", True):
        random.shuffle(filtered_obs)

    filtered_obs = filtered_obs[:args.size]
    filtered_obs = [remove_abstract(x, goal) for x in filtered_obs]

    list_prompt = [prompt_goal_detection_minigrid(args.grid_size, x, goal) for x in filtered_obs]

    list_answer = llm.get_action(
        list_prompt,
        use_tqdm=True,
        temperature=0.0,
        top_p=0.95,
        top_k=1,
        max_new_tokens=800,
    )

    dataset = []
    for state, answer in zip(filtered_obs, list_answer):
        label = _parse_goal_achieved(answer)
        if label is not None:
            dataset.append({"state": state, "goal achieved": label})

    # Explicit columns keep the header even when no answer could be parsed,
    # so the file can always be read back with pd.read_csv.
    pd.DataFrame(dataset, columns=["state", "goal achieved"]).to_csv(
        f"{path_folder}/goal_detection_dataset.csv"
    )
   

def run_process(p):
    """Unpack one packed work item and generate its goal-detection dataset.

    ``p`` must be a 4-tuple ``(args, data_folder, goal, env_id)``; the values
    are forwarded unchanged to ``generate_dataset_per_goal``.
    NOTE(review): presumably meant as a ``multiprocessing.Pool`` worker, but it
    is not called anywhere in the visible code.
    """
    cli_args, dataset_root, goal_name, env_name = p
    generate_dataset_per_goal(cli_args, dataset_root, goal_name, env_name)

if __name__ == "__main__":
    # Script entry point: parse the CLI, load the LLM once, then build a
    # goal-detection dataset for every goal folder that does not have one yet.

    def _str_to_bool(text):
        """Parse a CLI boolean.

        ``type=bool`` would turn any non-empty string (including "False") into
        True. Here the usual "false" spellings and the empty string map to
        False; every other value stays True, as before.
        """
        return text.strip().lower() not in {"", "0", "false", "f", "no", "n"}

    parser = argparse.ArgumentParser(description="Run the experiment")
    parser.add_argument("--name", type=str, default="minigrid_go_to_few_env", help="The name of the dataset")
    parser.add_argument("--size", type=int, default=10**15, help="Maximum number of states labelled per goal")
    parser.add_argument("--seed", type=int, default=0, help="Seed for Python's random module")
    parser.add_argument("--shuffle", type=_str_to_bool, default=True, help="Shuffle flag, given as true/false")
    parser.add_argument("--env_id", nargs="+", type=str, default=[],
                        help="Environment ids to process (default: every folder under the dataset's env/ directory)")
    parser.add_argument("--grid_size", type=int, nargs="+", default=[22, 22],
                        help="Grid size passed to the goal-detection prompt")

    args = parser.parse_args()
    random.seed(args.seed)
    args.size = int(args.size)

    # Must stay a module-level name: generate_dataset_per_goal reads ``llm``
    # as a global.
    llm = _LLM(model_name="meta-llama/Meta-Llama-3-70B-Instruct")

    env_root = f"{data_folder}{args.name}/env"

    if args.env_id:
        list_env_id = args.env_id
    else:
        # The glob pattern ends with "/" so only directories match; normpath
        # strips that trailing separator before taking the folder name.
        list_env_id = [
            os.path.basename(os.path.normpath(env_dir))
            for env_dir in glob.glob(f"{env_root}/*/")
        ]

    for env_id in tqdm(list_env_id):
        goals = []
        for goal_dir in glob.glob(f"{env_root}/{env_id}/*/"):
            # NOTE(review): this substring test covers the whole path, not just
            # the goal folder name; kept as in the original.
            if "csv" in goal_dir:
                continue
            # Skip goals whose dataset has already been generated.
            if os.path.exists(os.path.join(goal_dir, "goal_detection_dataset.csv")):
                continue
            goals.append(os.path.basename(os.path.normpath(goal_dir)))

        for goal in tqdm(goals):
            generate_dataset_per_goal(args, data_folder, goal, env_id)