import pandas as pd 
import argparse
from tqdm import tqdm
import glob
import os 
from collections import defaultdict
import sys
sys.path.append('.')
from source.language_models.LLM import _LLM
from source.language_models.prompt_templates import prompt_abstraction_feature
from source.abstraction.generate_goal import generate_goals
import random

# Seed the standard `random` module with a fixed value. No direct use of
# `random` is visible in this file; presumably this benefits imported
# project code -- TODO confirm.
random.seed(0)

# Root directory for datasets: the __main__ section reads its input CSVs from
# "<output_folder><name>/goal/*/collected_data.csv" and writes results below it.
output_folder = "./data/datasets/"

def parse_feature_selection(answer, obs):
    """Parse an LLM feature-selection answer and map observed objects to goal roles.

    The text between the first ``{`` and the following ``}`` of ``answer`` is
    read as a list of entries (one per line, or separated by ``, "`` when it
    is a single line).  Entries whose key is ``"goal object`` or
    ``"goal location`` and that contain at least one digit and a ``(`` are
    parsed into coordinate strings such as ``"(3,4)"``.

    Each line of ``obs`` that mentions a box, chest, key, ball or door and
    contains ``is on tile `` is then matched by coordinate against those
    entries.

    Args:
        answer: Raw text of the LLM answer.
        obs: Textual observation, one feature per line.  Object lines are
            expected to look like ``"<n>. A <name> is on tile (x,y)"``
            (inferred from the parsing below).

    Returns:
        A tuple ``(features_dict, goal_location)`` where ``features_dict``
        maps an object name (with "open"/"closed" removed) to the matching
        key (``'"goal object'`` or ``'"goal location'``, leading quote
        included) and ``goal_location`` is a list of ``(int, int)`` tuples
        collected from every parsed entry.

    Raises:
        IndexError: If ``answer`` contains no ``{``, if a selected entry has
            no ``": "`` separator, or if an object line has no ``"A"``
            before ``is on tile``.
    """
    keys = ['"goal object', '"goal location']
    features_dict_tmp = defaultdict(list)

    answer_feature = answer.split("{")[1].split("}")[0]

    # Decide on the brace content itself: a newline elsewhere in the answer
    # (e.g. in a preamble) must not prevent splitting a one-line dict.
    if "\n" in answer_feature:
        answer_feature = answer_feature.split("\n")
    else:
        pieces = answer_feature.split(', "')
        # The separator swallowed the opening quote of every entry but the
        # first; restore it on all of them (also safe for a single entry).
        answer_feature = pieces[:1] + ['"' + piece for piece in pieces[1:]]

    goal_location = []
    for entry in answer_feature:
        for k in keys:
            if not (k in entry and "(" in entry and any(ch.isdigit() for ch in entry)):
                continue
            entry = entry.replace(" or", "")
            if ")," in entry:
                # Keep only what precedes the first "),", e.g. to drop a trailing comma.
                entry = entry.split("),")[0] + ")"
            coords = (
                entry.split(":")[1].strip()
                .split("#")[0].strip()  # ignore a trailing "# comment"
                .replace("[", "").replace("]", "")
                .replace("),", ")")
                .split(" ")
            )
            coords = [x + ")" if ")" not in x else x for x in coords]
            coords = ["(" + x if "(" not in x else x for x in coords]
            features_dict_tmp[k] = coords

            tokens = entry.split(": ")[1].strip().replace("[", "").replace("]", "").split(" ")
            pairs = [
                (
                    x.replace("(", "").replace(")", "").split(",")[0].strip(),
                    x.replace("(", "").replace(")", "").split(",")[1].strip(),
                )
                for x in tokens
                if "," in x
            ]
            goal_location += [(int(a), int(b)) for a, b in pairs if a.isdigit() and b.isdigit()]
            break

    features_dict = {}
    object_words = ("box", "chest", "key", "ball", "door")
    for feature in obs.split("\n"):
        if not any(word in feature for word in object_words):
            continue
        if "is on tile " not in feature:
            # Not an object placement (e.g. an inventory line holding a key).
            continue
        placement = feature.split("is on tile ")
        obj_name = placement[0].split("A")[1].replace("open", "").replace("closed", "").strip()
        coord = placement[1]
        for k in keys:
            if coord in features_dict_tmp[k]:
                # First role found for an object name wins.
                features_dict.setdefault(obj_name, k)
                break
    print("feature dict temps : ", features_dict_tmp)
    return features_dict, goal_location
            




def make_abstraction_function(answer, grid):
    """Build a goal-specific observation abstraction from an LLM answer.

    ``parse_feature_selection(answer[0], grid)`` determines which objects of
    the initial observation ``grid`` play a goal role and where the goal
    coordinates are.  The returned closure rewrites any later observation so
    that goal objects are renamed to their role, irrelevant objects are
    collapsed into a single "obstacle" line, and -- when the agent stands in
    the single room that contains every goal coordinate -- only features of
    that room are kept.

    Args:
        answer: Indexable whose element ``0`` is the LLM answer text.
        grid: Textual observation used to resolve object names.

    Returns:
        A function ``abstraction_function(obs) -> str`` producing the sorted,
        newline-joined abstract state.
    """
    features_dict, goal_location = parse_feature_selection(answer[0], grid)
    features_dict = {k: v.replace('"', "") for k, v in features_dict.items()}

    # Rooms as (min0, min1, max0, max1), bounds inclusive.  Adjacent rooms
    # share their boundary tiles; the first matching room wins.
    rooms = [
        (1, 1, 7, 7), (1, 7, 7, 14), (1, 14, 7, 20),
        (7, 1, 14, 7), (7, 7, 14, 14), (7, 14, 14, 20),
        (14, 1, 20, 7), (14, 7, 20, 14), (14, 14, 20, 20),
    ]

    # The goal rooms only depend on goal_location, so resolve them once here
    # instead of on every locality check.
    goal_location_rooms = []
    for location in goal_location:
        for index, room in enumerate(rooms):
            if (room[0] <= location[0] <= room[2]) and (room[1] <= location[1] <= room[3]):
                goal_location_rooms.append(index)
                break
    # Local filtering is only meaningful when all goals share a single room.
    goal_room = rooms[goal_location_rooms[0]] if len(set(goal_location_rooms)) == 1 else None

    def local_function(x, y):
        """Return True when (x, y) lies in the unique goal room, if there is one."""
        if goal_room is None:
            return False
        return (goal_room[0] <= x <= goal_room[2]) and (goal_room[1] <= y <= goal_room[3])

    def _normalize(name):
        """Strip the open/closed state exactly as parse_feature_selection does."""
        # The former replace("close", "") left a stray "d" behind for
        # "closed ...", so closed goal objects never matched features_dict.
        return name.replace("open", "").replace("closed", "").strip()

    def _goal_role(name):
        """Return the goal role recorded for ``name``, or None."""
        role = features_dict.get(name)
        if role is None or "goal" not in role:
            return None
        return role

    def _parse_coord(text):
        """Turn "(a,b)" into the integer pair (a, b)."""
        parts = text.strip().replace("(", "").replace(")", "").split(",")
        return int(parts[0]), int(parts[1])

    def abstraction_function(obs):
        """Return the abstract textual state for the observation ``obs``."""
        agent_pos = [
            int(x.strip())
            for x in obs.split("The agent is currently at the following tile: (")[1].split(")")[0].split(",")
        ]
        state = []
        obstacles = []

        for feature in obs.split("\n"):
            if "wall" in feature:
                state.append(feature.split(". ")[1])

            elif "Inventory" in feature:
                held = _normalize(feature.split("[")[1].split("]")[0].replace("a ", "").replace("the ", ""))
                if held == "":
                    abstract_object = ""
                else:
                    role = _goal_role(held)
                    if role is not None:
                        abstract_object = role
                    elif "key" in held:
                        # A non-goal key stays relevant while a locked door of its colour exists.
                        color = held.split(" ")[0]
                        abstract_object = held if f"locked {color} door" in obs else "obstacle"
                    else:
                        abstract_object = "obstacle"
                state.append(f"Inventory: [{abstract_object}]")

            elif ("box" in feature) or ("chest" in feature) or ("ball" in feature):
                coord = feature.split("is on tile ")[1]
                obj_name = feature.split("is on tile ")[0].split("A")[1].strip()
                role = _goal_role(_normalize(obj_name))
                abstract_object = "obstacle" if role is None else role
                if "obstacle" in abstract_object:
                    obstacles.append(coord)
                else:
                    state.append(feature.replace(obj_name, abstract_object).split(". ")[1])

            elif "key" in feature:
                coord = feature.split("is on tile ")[1]
                obj_name = feature.split("is on tile ")[0].split("A")[1].strip()
                role = _goal_role(obj_name)
                if role is not None:
                    abstract_object = role
                else:
                    color = obj_name.split(" ")[0]
                    abstract_object = obj_name if f"locked {color} door" in obs else "obstacle"
                if "obstacle" in abstract_object:
                    obstacles.append(coord)
                else:
                    state.append(feature.replace(obj_name, abstract_object).split(". ")[1])

            elif "door" in feature:
                obj_name = feature.split("is on tile ")[0].split("A")[1].strip()
                abstract_objects = []
                if "close" in obj_name:
                    abstract_objects.append("closed door")
                elif "lock" in obj_name:
                    abstract_objects.append(obj_name)
                role = _goal_role(_normalize(obj_name))
                if role is not None:
                    abstract_objects.append(role)
                # A door can yield two lines: its state and its goal role.
                for a_o in abstract_objects:
                    state.append(feature.replace(obj_name, a_o).split(". ")[1])

            elif "agent" in feature:
                state.append(feature.split(". ")[1])

        # Position in the list is irrelevant: the state is sorted right after.
        state.append("The following tiles are obstacle: " + " ".join(obstacles))
        state.sort()

        if not local_function(*agent_pos):
            return "\n".join(state)

        # The agent is in the goal room: keep only what is in that room,
        # plus the inventory, goal and agent lines.
        filtered_state = []
        for feature in state:
            if "Inventory" in feature:
                filtered_state.append(feature)
            elif "obstacle" in feature:
                tokens = [tok for tok in feature.split(": ")[1].split(" ") if tok != ""]
                all_coords = [_parse_coord(tok) for tok in tokens]
                kept = [coord for coord in all_coords if local_function(*coord)]
                if kept:
                    joined = " ".join([f"({x},{y})" for (x, y) in kept])
                    filtered_state.append(f"The following tiles are obstacle: {joined}")
            elif "goal" in feature:
                filtered_state.append(feature)
            elif ("box" in feature) or ("chest" in feature) or ("key" in feature) or ("ball" in feature):
                if local_function(*_parse_coord(feature.split("is on tile ")[1])):
                    filtered_state.append(feature)
            elif "agent" in feature:
                filtered_state.append(feature)
        filtered_state.sort()
        return "\n".join(filtered_state)

    return abstraction_function



def build_abstraction_function(llm, grid, list_goal):
    """Create one abstraction function per goal with a single batched LLM call.

    A prompt is built for every goal in ``list_goal`` from ``grid``, all
    prompts are sent to ``llm.get_action`` at once, and the answer at index
    ``i`` is turned into the abstraction function of the ``i``-th goal.

    Returns:
        dict mapping each goal to its abstraction function.
    """
    prompts = [prompt_abstraction_feature(grid, goal) for goal in list_goal]
    llm_answers = llm.get_action(
        prompts,
        temperature=0.0,
        top_p=0.95,
        top_k=1,
        max_new_tokens=8000,
    )
    return {
        goal: make_abstraction_function(llm_answers[index], grid)
        for index, goal in enumerate(list_goal)
    }



class LLM_abstraction():
    """Holds an LLM and the abstraction functions it produced, keyed by env id then goal."""

    def __init__(self, model_name="meta-llama/Meta-Llama-3-70B-Instruct"):
        # env_id -> {goal -> abstraction function}
        self.abstraction_function_dict = defaultdict(dict)
        self.llm = _LLM(model_name=model_name)

    def generate_new_abstraction_function(self, grid, env_id, list_goal):
        """Build the abstraction functions of ``list_goal`` for ``env_id``, replacing any stored ones."""
        functions_by_goal = build_abstraction_function(self.llm, grid, list_goal)
        self.abstraction_function_dict[env_id] = functions_by_goal

    def abstract(self, obs, env_id, goal):
        """Apply the abstraction function stored for (``env_id``, ``goal``) to ``obs``."""
        abstraction_function = self.abstraction_function_dict[env_id][goal]
        return abstraction_function(obs)



if __name__ == "__main__":
    # Script entry point: for every environment found in the collected data,
    # ask the LLM for one abstraction per goal, rewrite the "obs" and
    # "next_obs" columns with it, and store one CSV per goal.
    parser = argparse.ArgumentParser(description="Filter the observation")
    parser.add_argument("--name", type=str, default="BabyAI", help="The name of the dataset")
    parser.add_argument("--goals", type=str, nargs="+", default=[])
    parser.add_argument("--fraction", type=int, default=100,
                        help="Percentage of each environment's rows to keep (0-100)")
    parser.add_argument("--output_name", type=str, default="",
                        help="Name of the output dataset folder (defaults to --name)")
    args = parser.parse_args()

    goal_df = defaultdict(pd.DataFrame)

    path_df = glob.glob(f"{output_folder}{args.name}/goal/*/collected_data.csv")

    print(f"{output_folder}{args.name}/goal/*/collected_data.csv")

    if args.output_name == "":
        args.output_name = args.name

    # Read every CSV exactly once and keep the frames, instead of re-reading
    # all files for each environment.  This also works when the glob matched
    # nothing (the former `del df` raised NameError in that case).
    frames = [pd.read_csv(path) for path in tqdm(path_df, desc="Loading data")]

    env_list = []
    for frame in frames:
        env_list += frame["env_id"].unique().tolist()
    # De-duplicate while keeping first-seen order so runs are reproducible.
    env_list = list(dict.fromkeys(env_list))

    llm = _LLM(model_name="meta-llama/Meta-Llama-3-70B-Instruct")

    for env in tqdm(env_list, desc="Building abstraction", total=len(env_list)):
        # ignore_index gives every row a unique label; with the labels
        # inherited from the individual files, label-based dropping also
        # removed valid rows that shared a label with a stray header row.
        df = pd.concat([frame[frame["env_id"] == env] for frame in frames], ignore_index=True)

        # Remove rows whose "obs" is the literal "obs" (presumably header
        # lines repeated inside the data -- TODO confirm).
        df = df[df["obs"] != "obs"]
        df = df[:int(len(df) * args.fraction / 100)]

        if len(df) == 0:
            continue

        first_obs = df["obs"].to_list()[0]
        list_goals = generate_goals(first_obs)

        if args.goals != []:
            list_goals = [goal for goal in list_goals if goal in args.goals]

        if len(list_goals) == 0:
            continue

        abstraction_function_dict = build_abstraction_function(llm, first_obs, list_goals)

        for goal in tqdm(list_goals, desc="Applying abstraction"):
            df_copy = df.copy()
            abstraction_function = abstraction_function_dict[goal]

            # Observations repeat a lot; abstract each distinct text once.
            memory = {}
            for column in ("obs", "next_obs"):
                new_obs_list = []
                for obs in df_copy[column].to_list():
                    if obs not in memory:
                        memory[obs] = abstraction_function(obs)
                    new_obs_list.append(memory[obs])
                df_copy[column] = new_obs_list

            goal_df[goal] = pd.concat([goal_df[goal], df_copy])

    for goal in tqdm(goal_df, desc="Saving data"):
        path = output_folder + f"{args.output_name}/env/0/{goal}"
        os.makedirs(path, exist_ok=True)
        csv_path = path + "/collected_data_abstract.csv"
        if os.path.exists(csv_path):
            # Append to an existing file without repeating the header.
            goal_df[goal].to_csv(csv_path, mode='a', header=False, index=False)
        else:
            print(csv_path)

            goal_df[goal].to_csv(csv_path, index=False)
          
   