
from functools import reduce
import gymnasium as gym
import numpy as np
from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX, DIR_TO_VEC
import os 
import gymnasium as gym
import pytest
from minigrid.wrappers import SymbolicObsWrapper, FullyObsWrapper
import argparse
from minigrid.utils.baby_ai_bot import BabyAIBot
from collections import defaultdict
import pandas as pd 
from tqdm import tqdm
from minigrid.envs.babyai.core.verifier import GoToInstr, ObjDesc
from multiprocessing import Pool
import sys
sys.path.append(".")
from source.reward.building_goal_detection_dataset import remove_abstract



# Reverse look-ups (index -> name) for MiniGrid's cell encoding tables.
IDX_TO_COLOR = {idx: name for name, idx in COLOR_TO_IDX.items()}
IDX_TO_OBJECT = {idx: name for name, idx in OBJECT_TO_IDX.items()}
IDX_TO_STATE = {idx: name for name, idx in STATE_TO_IDX.items()}

# Text names of the agent's facing direction, indexed like obs["direction"].
IDX_TO_DIRECTION = dict(enumerate(("right", "down", "left", "up")))
DIR_TO_IDX = {name: idx for idx, name in IDX_TO_DIRECTION.items()}

# Colour table extended with an "undefined" colour (index 6), used by the
# text parsers for objects whose colour is not stated.
modified_COLOR_TO_IDX = {**COLOR_TO_IDX, "undefined": 6}
# Shallow copy of OBJECT_TO_IDX used by the text parsers.
modified_OBJECT_TO_IDX = dict(OBJECT_TO_IDX)


def state_to_obs(state, grid_size):
    """Rebuild a MiniGrid-style encoded grid from a textual state description.

    ``state`` is expected in the line format emitted by ``obs_to_state``:
    numbered lines listing wall/lava/obstacle tiles, "... is on tile (x,y)"
    objects, the inventory, the agent tile and the facing direction.

    Args:
        state: Multi-line state text. Must not contain the word "goal".
        grid_size: Pair whose first two entries give the grid width and height.

    Returns:
        ``(grid, inventory)``. ``grid`` is a float array of shape
        ``(grid_size[0], grid_size[1], 3)`` holding (object, color, state)
        indices per tile, with the outer border filled with grey walls; the
        agent tile stores its facing-direction index in channel 2.
        ``inventory`` is the ``[object, color, state]`` indices of the carried
        item, or ``[0, 0, 0]`` when nothing is carried or no inventory line is
        present.

    Raises:
        ValueError: if ``state`` mentions a goal or contains a malformed tile.
        KeyError: for unknown object / color / state / direction names.
    """
    if "goal" in state:
        print(state)
        raise ValueError("Goal in state")

    wall_cell = (OBJECT_TO_IDX["wall"], COLOR_TO_IDX["grey"], STATE_TO_IDX["open"])
    lava_cell = (OBJECT_TO_IDX["lava"], COLOR_TO_IDX["red"], STATE_TO_IDX["open"])

    def parse_tiles(line):
        """Return the (x, y) pairs listed after the first ': ' of ``line``."""
        tiles = []
        # split() with no argument ignores stray / repeated whitespace.
        for tile in line.split(": ")[1].split():
            try:
                x, y = tile[1:-1].split(",")
                tiles.append((int(x), int(y)))
            except ValueError as err:
                raise ValueError(f"Malformed tile {tile!r} in line {line!r}") from err
        return tiles

    def tile_of(line):
        """Return the (x, y) tile of a '... is on tile (x,y)' line."""
        x, y = line.split(" is on tile ")[1][1:-1].split(",")
        return int(x), int(y)

    grid = np.zeros((grid_size[0], grid_size[1], 3))
    # The outer ring is always wall: obs_to_state only describes interior tiles.
    grid[:, [0, -1]] = wall_cell
    grid[[0, -1], :] = wall_cell

    inventory = [0, 0, 0]  # default when no inventory line is present
    agent = None
    for feature in state.split("\n"):
        # The misspelled marker is still accepted in case stored data uses it.
        if "The following tiles are wall" in feature or "The follinwg tiles are wall" in feature:
            for x, y in parse_tiles(feature):
                grid[x, y] = wall_cell
        elif "The following tiles are lava" in feature:
            for x, y in parse_tiles(feature):
                grid[x, y] = lava_cell
        elif "The following tiles are obstacle" in feature:
            # Obstacles are encoded as plain walls.
            for x, y in parse_tiles(feature):
                grid[x, y] = wall_cell
        elif "is on tile" in feature:
            words = feature.split(" is on tile ")[0].split(" ")
            obj = words[-1]
            if obj in ("chest", "box", "door"):
                # "... A <state> <color> <object> is on tile (x,y)"
                color, obj_state = words[-2], words[-3]
                x, y = tile_of(feature)
                grid[x, y] = (OBJECT_TO_IDX[obj], COLOR_TO_IDX[color], STATE_TO_IDX[obj_state])
            elif obj in ("key", "ball"):
                # "... A <color> <object> is on tile (x,y)"
                color = words[-2]
                x, y = tile_of(feature)
                grid[x, y] = (OBJECT_TO_IDX[obj], COLOR_TO_IDX[color], STATE_TO_IDX["open"])
        elif "Inventory" in feature:
            if "[]" in feature:
                inventory = [0, 0, 0]
            else:
                # "... Inventory : [a <color> <object>] "
                words = feature.split("Inventory : [")[1].split(" ")
                obj = words[2].replace("]", "")
                color = words[1]
                inventory = [OBJECT_TO_IDX[obj], COLOR_TO_IDX[color], STATE_TO_IDX["open"]]
        elif "The agent is currently at the following tile" in feature:
            x, y = feature.split("The agent is currently at the following tile: (")[1][:-1].split(",")
            agent = (int(x), int(y))
            grid[agent[0], agent[1], 0] = OBJECT_TO_IDX["agent"]
            grid[agent[0], agent[1], 1] = COLOR_TO_IDX["red"]
        elif "The agent is facing" in feature:
            direction = feature.split("The agent is facing ")[1]
            try:
                grid[agent[0], agent[1], 2] = DIR_TO_IDX[direction]
            except Exception:
                # Unknown direction, or facing line seen before any agent tile.
                print(state)
                raise
    return grid, inventory



def state_to_obs_spe(state, grid_size):
    """Rebuild an encoded grid from a state text with free-form object descriptions.

    Object colour and state are found by substring search over the known
    names (see ``characterise_object``), so a description without a colour
    maps to the ``"undefined"`` entry of ``modified_COLOR_TO_IDX``. An
    inventory line mentioning "obstacle" is encoded as a grey wall.

    NOTE(review): unlike ``state_to_obs``, lava lines are not parsed here —
    confirm this is intended.

    Args:
        state: Multi-line state text. Must not contain the word "goal".
        grid_size: Pair whose first two entries give the grid width and height.

    Returns:
        ``(grid, inventory)``. ``grid`` is a float array of shape
        ``(grid_size[0], grid_size[1], 3)`` with a grey-wall border.
        ``inventory`` is ``[object, color, state]`` indices, ``[0, 0, 0]``
        when empty or absent.

    Raises:
        ValueError: if ``state`` mentions a goal, a tile is malformed, or an
            inventory item names no known object.
        KeyError: for unknown object / direction names.
    """
    def characterise_object(desc):
        """Return ``(object, color, state)`` names occurring as substrings of ``desc``.

        When several names of one kind occur, the last one in table order wins.
        Colour defaults to "undefined" and state to "open".

        Raises:
            ValueError: if no known object name occurs in ``desc``.
        """
        color = "undefined"
        for c in modified_COLOR_TO_IDX:
            if c in desc:
                color = c
        obj_state = "open"
        for s in STATE_TO_IDX:
            if s in desc:
                obj_state = s
        obj = None
        for o in modified_OBJECT_TO_IDX:
            if o in desc:
                obj = o
        if obj is None:
            raise ValueError("Object not found")
        return obj, color, obj_state

    def parse_tiles(line):
        """Return the (x, y) pairs listed after the first ': ' of ``line``."""
        tiles = []
        # split() with no argument ignores stray / repeated whitespace.
        for tile in line.split(": ")[1].split():
            try:
                x, y = tile[1:-1].split(",")
                tiles.append((int(x), int(y)))
            except ValueError as err:
                raise ValueError(f"Malformed tile {tile!r} in line {line!r}") from err
        return tiles

    def tile_of(line):
        """Return the (x, y) tile of a '... is on tile (x,y)' line."""
        x, y = line.split(" is on tile ")[1][1:-1].split(",")
        return int(x), int(y)

    if "goal" in state:
        raise ValueError("Goal in state")

    wall_cell = (OBJECT_TO_IDX["wall"], COLOR_TO_IDX["grey"], STATE_TO_IDX["open"])
    grid = np.zeros((grid_size[0], grid_size[1], 3))
    # The outer ring is always wall.
    grid[:, [0, -1]] = wall_cell
    grid[[0, -1], :] = wall_cell

    inventory = [0, 0, 0]
    agent = None
    for feature in state.split("\n"):
        # The misspelled marker is still accepted in case stored data uses it.
        if "The following tiles are wall" in feature or "The follinwg tiles are wall" in feature:
            for x, y in parse_tiles(feature):
                grid[x, y] = wall_cell
        elif "The following tiles are obstacle" in feature:
            # Obstacles are encoded as plain walls.
            for x, y in parse_tiles(feature):
                grid[x, y] = wall_cell
        elif "is on tile" in feature:
            desc = feature.split(" is on tile ")[0]
            obj = desc.split(" ")[-1]
            if obj in ("chest", "box", "door"):
                _, color, obj_state = characterise_object(desc)
                x, y = tile_of(feature)
                grid[x, y] = (OBJECT_TO_IDX[obj], modified_COLOR_TO_IDX[color], STATE_TO_IDX[obj_state])
            elif obj in ("key", "ball"):
                _, color, _ = characterise_object(desc)
                x, y = tile_of(feature)
                grid[x, y] = (OBJECT_TO_IDX[obj], modified_COLOR_TO_IDX[color], STATE_TO_IDX["open"])
        elif "Inventory" in feature:
            if "[]" in feature:
                inventory = [0, 0, 0]
            elif "obstacle" in feature:
                inventory = [modified_OBJECT_TO_IDX["wall"], modified_COLOR_TO_IDX["grey"], STATE_TO_IDX["open"]]
            else:
                # Accept both "Inventory : [...]" and "Inventory: [...]".
                item = feature.replace(" :", ":").split("Inventory: [")[1].replace("]", "")
                obj, color, _ = characterise_object(item)
                inventory = [OBJECT_TO_IDX[obj], modified_COLOR_TO_IDX[color], STATE_TO_IDX["open"]]
        elif "The agent is currently at the following tile" in feature:
            x, y = feature.split("The agent is currently at the following tile: (")[1][:-1].split(",")
            agent = (int(x), int(y))
            grid[agent[0], agent[1], 0] = OBJECT_TO_IDX["agent"]
            grid[agent[0], agent[1], 1] = COLOR_TO_IDX["red"]
        elif "The agent is facing" in feature:
            direction = feature.split("The agent is facing ")[1]
            try:
                grid[agent[0], agent[1], 2] = DIR_TO_IDX[direction]
            except Exception:
                # Unknown direction, or facing line seen before any agent tile.
                print(state)
                raise
    return grid, inventory

def abs_state_to_obs(state, grid_size, goal):
    """Rebuild an encoded grid from an abstracted textual state description.

    ``state`` is first passed through ``remove_abstract(state, goal)`` and the
    resulting text is parsed line by line using ``modified_OBJECT_TO_IDX`` /
    ``modified_COLOR_TO_IDX``. A word "a"/"the" where a colour is expected
    maps to ``"undefined"``. Where a state is expected, "a"/"the" maps to
    ``"open"``. A line containing "closed door" becomes an undefined-colour
    closed door. Goal objects are accepted both on tiles and in the inventory.

    Args:
        state: Multi-line state text (before ``remove_abstract``).
        grid_size: Pair whose first two entries give the grid width and height.
        goal: Passed to ``remove_abstract``; also printed on parse errors.

    Returns:
        ``(grid, inventory, direction)``. ``grid`` is a float array of shape
        ``(grid_size[0], grid_size[1], 3)`` with a grey-wall border.
        ``inventory`` is ``[object, color, state]`` indices (``[0, 0, 0]``
        when empty or absent). ``direction`` is the ``DIR_TO_IDX`` index of
        the facing line, or ``None`` when there is no facing line.

    Raises:
        Any parsing error is re-raised after printing ``goal`` and the
        offending line. ``ValueError("Inventory not recognized")`` is raised
        for an inventory item that is neither two nor three words.
    """
    text = remove_abstract(state, goal)
    lines = text.split("\n")

    wall_cell = (modified_OBJECT_TO_IDX["wall"], modified_COLOR_TO_IDX["grey"], STATE_TO_IDX["open"])
    lava_cell = (modified_OBJECT_TO_IDX["lava"], modified_COLOR_TO_IDX["red"], STATE_TO_IDX["open"])

    def parse_tiles(line):
        """Return the (x, y) pairs listed after the first ': ' of ``line``."""
        tiles = []
        # split() with no argument ignores stray / repeated whitespace.
        for tile in line.split(": ")[1].split():
            try:
                x, y = tile[1:-1].split(",")
                tiles.append((int(x), int(y)))
            except ValueError as err:
                raise ValueError(f"Malformed tile {tile!r} in line {line!r}") from err
        return tiles

    def tile_of(line):
        """Return the (x, y) tile of a '... is on tile (x,y)' line."""
        x, y = line.split(" is on tile ")[1][1:-1].split(",")
        return int(x), int(y)

    grid = np.zeros((grid_size[0], grid_size[1], 3))
    # The outer ring is always wall.
    grid[:, [0, -1]] = wall_cell
    grid[[0, -1], :] = wall_cell

    inventory = [0, 0, 0]
    direction = None
    agent = None
    feature = None
    try:
        for feature in lines:
            if "The following tiles are wall" in feature:
                for x, y in parse_tiles(feature):
                    grid[x, y] = wall_cell
            elif "The following tiles are lava" in feature:
                for x, y in parse_tiles(feature):
                    grid[x, y] = lava_cell
            elif "The following tiles are obstacle" in feature:
                # Obstacles are encoded as plain walls.
                for x, y in parse_tiles(feature):
                    grid[x, y] = wall_cell
            elif "is on tile" in feature:
                # NOTE(review): the word positions below rely on the layout of
                # remove_abstract's output — confirm against that function.
                words = feature.split(" is on tile ")[0].split(" ")
                obj = words[-1]
                if "closed door" in feature:
                    # A closed door whose colour has been abstracted away.
                    x, y = tile_of(feature)
                    grid[x, y] = (
                        modified_OBJECT_TO_IDX["door"],
                        modified_COLOR_TO_IDX["undefined"],
                        STATE_TO_IDX["closed"],
                    )
                elif obj in ("chest", "box", "door"):
                    color = words[-2]
                    obj_state = "open" if len(words) == 2 else words[-3]
                    x, y = tile_of(feature)
                    # An article where a colour/state was expected means "not stated".
                    if color in ("a", "the"):
                        color = "undefined"
                    if obj_state in ("a", "the"):
                        obj_state = "open"
                    grid[x, y] = (
                        modified_OBJECT_TO_IDX[obj],
                        modified_COLOR_TO_IDX[color],
                        STATE_TO_IDX[obj_state],
                    )
                elif obj in ("key", "ball"):
                    color = words[-2]
                    x, y = tile_of(feature)
                    if color in ("a", "the"):
                        color = "undefined"
                    grid[x, y] = (
                        modified_OBJECT_TO_IDX[obj],
                        modified_COLOR_TO_IDX[color],
                        STATE_TO_IDX["open"],
                    )
                elif obj == "goal":
                    # NOTE(review): words[-1] is the same token as obj (always
                    # "goal"), so this looks up modified_COLOR_TO_IDX["goal"];
                    # confirm the intended goal-type token (words[-2]?).
                    goal_type = words[-1]
                    x, y = tile_of(feature)
                    grid[x, y] = (
                        modified_OBJECT_TO_IDX[obj],
                        modified_COLOR_TO_IDX[goal_type],
                        STATE_TO_IDX["open"],
                    )
            elif "Inventory" in feature:
                # Accept both "Inventory : [...]" and "Inventory: [...]".
                inv_line = feature.replace(" :", ":")
                if "[]" in inv_line:
                    inventory = [0, 0, 0]
                elif "goal" in inv_line:
                    # "Inventory: [goal <type>]"
                    content = inv_line.split("Inventory: [")[1].split(" ")
                    obj = content[0]
                    goal_type = content[1].replace("]", "")
                    inventory = [modified_OBJECT_TO_IDX[obj], modified_COLOR_TO_IDX[goal_type], STATE_TO_IDX["open"]]
                elif "obstacle" in inv_line:
                    inventory = [modified_OBJECT_TO_IDX["wall"], modified_COLOR_TO_IDX["grey"], STATE_TO_IDX["open"]]
                else:
                    inv_words = inv_line.replace("]", "").strip().split("Inventory: [")[1].split(" ")
                    if len(inv_words) == 2:
                        # "[a <object>]": colour not stated.
                        obj, color = inv_words[1], "undefined"
                    elif len(inv_words) == 3:
                        # "[a <color> <object>]"
                        obj, color = inv_words[2], inv_words[1]
                    else:
                        print(inv_words)
                        raise ValueError("Inventory not recognized")
                    inventory = [modified_OBJECT_TO_IDX[obj], modified_COLOR_TO_IDX[color], STATE_TO_IDX["open"]]
            elif "The agent is currently at the following tile" in feature:
                x, y = feature.split("The agent is currently at the following tile: (")[1][:-1].split(",")
                agent = (int(x), int(y))
                grid[agent[0], agent[1], 0] = modified_OBJECT_TO_IDX["agent"]
                grid[agent[0], agent[1], 1] = modified_COLOR_TO_IDX["red"]
            elif "The agent is facing" in feature:
                direction = feature.split("The agent is facing ")[1]
                try:
                    grid[agent[0], agent[1], 2] = DIR_TO_IDX[direction]
                except Exception:
                    # Unknown direction, or facing line seen before any agent tile.
                    print(text)
                    raise
    except Exception:
        print("goal : ", goal)
        print("feature : ", feature)
        raise

    if direction is None:
        return grid, inventory, None
    # Already validated by the facing branch above, so this lookup cannot fail.
    return grid, inventory, DIR_TO_IDX[direction]





def obs_to_state(obs, inv=None):
    """Render a fully observable grid observation as a numbered text description.

    Only interior cells are scanned (the outer ring of ``obs["image"]`` is
    skipped). The lines produced are, in order:

    * one "The following tiles are <kind>: (x,y) ..." line per non-empty
      wall / lava / obstacle group;
    * "A <state> <color> <object> is on tile (x,y)" for chests, boxes, doors;
    * "A <color> <object> is on tile (x,y)" for keys and balls;
    * the inventory line;
    * the agent tile line, if an agent cell was found;
    * the facing direction.

    Lines are prefixed "<n>. " starting at 0 and joined with newlines; the
    last line has no trailing newline.

    Args:
        obs: Observation dict with an ``"image"`` array of (object, color,
            state) indices and a ``"direction"`` index.
        inv: Carried object exposing ``color`` and ``type``, or ``None``.

    Returns:
        The text description.
    """
    image = obs["image"]
    width, height = image.shape[0], image.shape[1]
    tracked = ("chest", "door", "key", "ball", "box", "lava", "wall", "agent", "obstacle")
    found = {kind: [] for kind in tracked}

    for x in range(1, width - 1):
        for y in range(1, height - 1):
            cell = image[x][y]
            kind = IDX_TO_OBJECT[cell[0]]
            color = IDX_TO_COLOR[cell[1]]
            # The agent's third channel is its direction, not a door/box state.
            status = "" if kind == "agent" else IDX_TO_STATE[cell[2]]
            if kind in found:
                found[kind].append((x, y, color, status))

    entries = []
    for kind in ("wall", "lava", "obstacle"):
        if found[kind]:
            coords = " ".join(f"({x},{y})" for x, y, _, _ in found[kind])
            entries.append(f"The following tiles are {kind}: {coords}")
    for kind in ("chest", "box", "door"):
        entries.extend(
            f"A {status} {color} {kind} is on tile ({x},{y})"
            for x, y, color, status in found[kind]
        )
    for kind in ("key", "ball"):
        entries.extend(
            f"A {color} {kind} is on tile ({x},{y})"
            for x, y, color, _ in found[kind]
        )

    if inv is None:
        entries.append("Inventory : [] ")
    else:
        entries.append(f"Inventory : [a {inv.color} {inv.type}] ")

    if found["agent"]:
        agent_x, agent_y = found["agent"][0][0], found["agent"][0][1]
        entries.append(f"The agent is currently at the following tile: ({agent_x},{agent_y})")

    entries.append(f"The agent is facing {IDX_TO_DIRECTION[obs['direction']]}")
    return "\n".join(f"{number}. {entry}" for number, entry in enumerate(entries))


class Random_policy:
    """Random exploration policy that avoids repeating an action from the same state.

    States are keyed by their ``obs_to_state`` text. For each state the
    policy remembers the actions already taken there. It redraws until it
    gets one not yet tried; once every allowed action has been tried, that
    state's memory is cleared. Action 4 is only allowed while the agent
    carries something, and action 6 is never sampled.
    """

    def __init__(self, env):
        self.env = env
        # state text -> list of actions already taken from that state
        self.history = defaultdict(list)

    def replan(self, action):
        """Return the next action. The ``action`` argument is accepted but unused."""
        state_key = obs_to_state(self.env.gen_obs(), self.env.carrying)
        tried = self.history[state_key]
        carrying = self.env.carrying is not None
        allowed = [0, 1, 2, 3, 4, 5] if carrying else [0, 1, 2, 3, 5]

        pick = np.random.choice(allowed)
        if len(tried) == len(allowed):
            # Every allowed action has been tried from here: start over.
            tried = []
        while pick in tried:
            pick = np.random.choice(allowed)
        self.history[state_key] = tried + [pick]

        return pick


def collect_traj(args, seed):
    """Roll out ``args.traj_per_env`` trajectories on one seeded environment.

    For trajectory ``i`` the environment is reset with ``seed``, then
    ``reroll_mission`` is called ``i`` times. If it raises, ``place_agent`` is
    called instead. This gives each trajectory a different but reproducible
    start. Observations are rendered to text with ``obs_to_state``.

    Args:
        args: Parsed CLI namespace; reads ``env``, ``max_number_of_steps``,
            ``traj_per_env`` and ``random``.
        seed: Seed passed to ``env.reset``; also part of the ``env_id`` column.

    Returns:
        ``defaultdict(pd.DataFrame)`` mapping a key to a DataFrame with columns
        ``obs``, ``action``, ``next_obs``, ``terminated``, ``grid_size`` and
        ``env_id``. With ``args.random`` every trajectory is stored under
        ``"random"``. Otherwise only trajectories whose last step terminated
        are kept, keyed by the mission string.
    """
    env_id = args.env
    num_steps = args.max_number_of_steps
    nb_traj = args.traj_per_env
    random = args.random
    env_tag = f"{env_id}_{seed}"

    env = FullyObsWrapper(gym.make(env_id, render_mode="rgb_array"))

    # One DataFrame per kept trajectory, concatenated once at the end instead
    # of re-copying the accumulated frame after every trajectory.
    frames = defaultdict(list)
    try:
        for i in tqdm(range(nb_traj), leave=True, desc=f"nb_ traj {seed}", disable=False):
            if i % 1000 == 0:
                print(f"seed : {seed} - traj : {i}")

            env.reset(seed=seed)
            for _ in range(i):
                try:
                    env.reroll_mission()
                except Exception:
                    env.place_agent()

            goal = env.gen_obs()["mission"]
            expert = Random_policy(env) if random else BabyAIBot(env)

            # NOTE(review): the first action is executed but not recorded, so the
            # initial transition is absent from the data — confirm this is intended.
            action = expert.replan(None)
            obs, _, terminated, _, _ = env.step(action)
            last_action = action

            list_obs, list_action, list_next_obs, list_terminated = [], [], [], []
            current_text = obs_to_state(obs, env.carrying)
            # NOTE(review): `truncated` is ignored; rollouts only stop on
            # termination, expert failure or after `num_steps` steps.
            for _step in tqdm(range(num_steps), leave=False, desc=f"nb_ step {seed}", disable=False):
                try:
                    action = expert.replan(last_action)
                except Exception as e:
                    print(e)
                    break
                if action is None:
                    break
                obs, _, terminated, _, _ = env.step(action)
                next_text = obs_to_state(obs, env.carrying)
                # Append all columns together so they always have equal length.
                list_obs.append(current_text)
                list_action.append(action)
                list_terminated.append(terminated)
                list_next_obs.append(next_text)
                current_text = next_text
                last_action = action
                if terminated:
                    break

            if random or terminated:
                key = "random" if random else goal
                frames[key].append(pd.DataFrame({
                    "obs": list_obs,
                    "action": list_action,
                    "next_obs": list_next_obs,
                    "terminated": list_terminated,
                    "grid_size": [obs["image"].shape] * len(list_obs),
                    "env_id": [env_tag] * len(list_obs),
                }))
    finally:
        env.close()

    results = defaultdict(pd.DataFrame)
    for key, key_frames in frames.items():
        results[key] = pd.concat(key_frames)
    return results


def collect_per_env(args, seed):
    """Collect trajectories for one seed and append them to per-goal CSV files.

    Each key returned by ``collect_traj`` is written to
    ``{data_folder}/{name}/goal/{goal}/collected_data.csv``. The header row is
    written only when the file is new or empty, so repeated calls (one per
    seed) append data rows without interleaving extra header lines.

    NOTE(review): pool workers may append to the same goal file concurrently;
    the exists-check and the write are not atomic across processes.
    """
    results = collect_traj(args, seed)
    for goal, frame in results.items():
        goal_dir = f"{args.data_folder}/{args.name}/goal/{goal}"
        os.makedirs(goal_dir, exist_ok=True)
        csv_path = f"{goal_dir}/collected_data.csv"
        write_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
        frame.to_csv(csv_path, mode="a", header=write_header)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the experiment")
    parser.add_argument("--name", type=str, default="BabyAI", help="The name of the dataset")
    parser.add_argument("--env", type=str, default="BabyAI-BossCustomLevel-v0", help="Gymnasium environment id to collect trajectories from")
    parser.add_argument("--number_env", type=int, default=4, help="Number of seeds (0 .. number_env-1); one job per seed")
    parser.add_argument("--max_number_of_steps", type=int, default=500, help="Maximum number of recorded steps per trajectory")
    parser.add_argument("--data_folder", type=str, default="./data/datasets", help="Root folder where datasets are written")
    parser.add_argument("--traj_per_env", type=int, default=2, help="Number of trajectories collected per seed")
    parser.add_argument("--random", type=int, default=0, help="If non-zero, use a random policy instead of the BabyAI bot")
    args = parser.parse_args()

    # Output layout: {data_folder}/{name}/goal/<goal>/collected_data.csv
    os.makedirs(f"{args.data_folder}/{args.name}/goal", exist_ok=True)

    seeds = list(range(args.number_env))
    # Cap the pool at 100 workers, never start more workers than seeds,
    # and keep at least one (Pool(0) is invalid).
    n_job = max(1, min(100, args.number_env))
    # Map the module-level collect_per_env over (args, seed) pairs; unlike a
    # function defined inside this guard, it can be pickled under the
    # "spawn" start method as well as "fork".
    with Pool(n_job) as pool:
        pool.starmap(collect_per_env, [(args, seed) for seed in seeds])
    
    

        

