import random
import copy

# Building names of the town map (based on the CARLA Simulator, see
# topview.png).  " building" is appended to each name to form the node
# class_name.  NOTE: "gray high" and "gray small" are spelled with a space
# rather than a hyphen; neither appears in MAP_SETTINGS.
BUILDINGS = [
    "orange-apartment",
    "orange",
    "purple-high",
    "pink",
    "gray high",
    "green",
    "yellow",
    "white-high",
    "parking",
    "gate",
    "black",
    "pink-twin",
    "pink-window",
    "purple-small",
    "gray small",
    "white-small",
    "green-parking",
    "blue",
    "beige",
    "black-high",
    "green-apartment",
]

# Road connectivity of the town map (based on the CARLA Simulator, see
# topview.png).
#
# Each entry is a (from_building, direction, to_building) tuple using the bare
# names from BUILDINGS.  GraphCarla.construct_map turns every tuple into a
# directed edge whose relation_type is the direction, and GraphCarla.step
# matches that direction against the second word of a movement action
# ("go straight" -> "straight", "turn left" -> "left", "turn right" -> "right").
# No (from_building, direction) pair is listed twice.
MAP_SETTINGS = [
    ("orange-apartment", "straight", "white-high"),
    ("orange-apartment", "right", "parking"),
    ("orange-apartment", "left", "purple-high"),
    ("green-apartment", "straight", "orange"),
    ("green-apartment", "left", "pink"),
    ("green-apartment", "right", "white-small"),
    ("orange", "straight", "orange-apartment"),
    ("orange", "right", "white-small"),
    ("orange", "left", "yellow"),
    ("parking", "straight", "orange-apartment"),
    ("parking", "left", "pink-window"),
    ("white-high", "right", "orange-apartment"),
    ("white-high", "left", "gate"),
    ("white-high", "straight", "yellow"),
    ("yellow", "straight", "orange"),
    ("yellow", "right", "green-apartment"),
    ("yellow", "left", "orange-apartment"),
    ("white-small", "straight", "orange"),
    ("white-small", "left", "pink"),
    ("white-small", "right", "green-apartment"),

    # Left-turn exits for buildings that otherwise only have a "straight" edge.
    ("blue", "left", "green-apartment"),
    ("pink", "left", "parking"),
    ("beige", "left", "green-apartment"),
    ("purple-high", "left", "white-high"),
    ("gate", "left", "yellow"),

    # Chain of "straight" edges that starts at "gate" and returns to it,
    # i.e. a closed loop through the buildings listed below.
    ("gate", "straight", "blue"),
    ("blue", "straight", "black"),
    ("black", "straight", "beige"),
    ("beige", "straight", "pink"),
    ("pink", "straight", "pink-twin"),
    ("pink-twin", "straight", "green"),
    ("green", "straight", "black-high"),
    ("black-high", "straight", "purple-high"),
    ("purple-high", "straight", "green-parking"),
    ("green-parking", "straight", "pink-window"),
    ("pink-window", "straight", "purple-small"),
    ("purple-small", "straight", "gate")
]

# Actions understood by GraphCarla.step.
SKILL_SET = ["turn right", "turn left", "go straight", "pack goods", "offloading"]

# System prompt describing the agent and its skills.  The skill count and the
# skill list are derived from SKILL_SET so the prompt cannot drift from the
# actual action set (the prompt used to announce "6 skills" while listing
# five), and the sentences are joined with a space (they used to run together
# as "offloading.Pack").
BASE_PROMPT = " ".join([
    "You are an autonomous driving agent.",
    "You can use {} skills, {}.".format(len(SKILL_SET), ", ".join(SKILL_SET)),
    "Pack, navigate, and deliver goods through the environment based on graphs.",
])

# (pickup, drop-off) building names of the delivery tasks.  The three public
# lists below are index-aligned views of this single table.
_DELIVERY_ROUTES = [
    ("green", "white-high"),
    ("orange", "purple-high"),
    ("yellow", "green-apartment"),
]

# Natural-language task descriptions.
TASKS_SET = [
    "Pack goods in {} building and deliver goods to {} building".format(_src, _dst)
    for _src, _dst in _DELIVERY_ROUTES
]

# Query strings, one per pickup building.
QUERIES_SET = [
    "{} building is called".format(_src) for _src, _dst in _DELIVERY_ROUTES
]

# (pickup class_name, drop-off class_name) pairs; same shape as the
# task_building entries GraphCarla compares against node class names.
SUCCESS_CONDITIONS = [
    (_src + " building", _dst + " building") for _src, _dst in _DELIVERY_ROUTES
]

# Named non-stationarity levels.
# NOTE(review): presumably each value is the period (in timesteps) handed to
# GraphCarla.reset as `non_stationarity`, so a smaller number means more
# frequent task changes -- confirm against the caller.
NON_STA_SETTINGS = dict(high=4, medium=6, low=8)

class GraphCarla(object):
    """Graph-based goods-delivery environment over a town map.

    The world is a list of node dicts (``id``, ``class_name``, ``state``) and a
    list of directed edge dicts (``from_id``, ``relation_type``, ``to_id``).
    Node 0 is the agent ("character"); every building gets its own node whose
    ``class_name`` is ``"<name> building"``.

    State labels used in ``node["state"]``:
      * character: ``"unpacked"`` / ``"packed"``
      * building: ``"quite"`` (sic, the idle label -- kept verbatim because it
        is part of the emitted graphs), ``"called"``, ``"waiting"``,
        ``"delivered"``
    """

    # Node id reserved for the agent; building ids start right after it.
    CHARACTER_ID = 0

    def __init__(self, max_timesteps=200, buildings=None, map_settings=None):
        """Create an empty environment; call ``reset`` before ``step``.

        Args:
            max_timesteps: Number of ``step`` calls after which ``done`` is
                reported as True.
            buildings: Optional list of bare building names. Defaults to the
                module-level ``BUILDINGS``.
            map_settings: Optional list of ``(from, direction, to)`` tuples over
                those names. Defaults to the module-level ``MAP_SETTINGS``.
        """
        self.buildings = list(BUILDINGS if buildings is None else buildings)
        self.map_settings = list(MAP_SETTINGS if map_settings is None else map_settings)
        self.nodes = []
        self.edges = []
        self.name_to_id = {}
        self.id_to_name = {}
        self.character_location = None
        self.task_building = None
        self.non_stationarity = None
        self.timesteps = 0
        self.max_timesteps = max_timesteps

    def construct_map(self):
        """Rebuild nodes, edges and the name<->id lookup tables from scratch.

        Building ids start at ``CHARACTER_ID + 1`` so that no building shares
        the character's id (previously the first building also received id 0,
        which made "pack goods" raise ``ValueError``).

        Raises:
            KeyError: If a map entry names a building that is not in
                ``self.buildings``.
        """
        self.nodes = [{"id": self.CHARACTER_ID, "class_name": "character", "state": ["unpacked"]}]
        self.edges = []
        self.name_to_id = {}
        self.id_to_name = {}

        for node_id, building in enumerate(self.buildings, start=self.CHARACTER_ID + 1):
            class_name = building + " building"
            self.nodes.append({"id": node_id, "class_name": class_name, "state": ["quite"]})
            self.name_to_id[class_name] = node_id
            self.id_to_name[node_id] = class_name

        for source, direction, destination in self.map_settings:
            self.edges.append({
                "from_id": self.name_to_id[source + " building"],
                "relation_type": direction,
                "to_id": self.name_to_id[destination + " building"],
            })

    def reset(self, task_building, non_stationarity):
        """Start a new episode and return the initial observation.

        Args:
            task_building: Sequence of ``(source class_name, target
                class_name)`` tuples, e.g. ``("green building",
                "white-high building")``.
            non_stationarity: Period in timesteps used by
                ``non_stationary_dynamics``.

        Returns:
            Dict with the keys ``"visible_graph"`` and ``"agent_graph"``.
        """
        self.construct_map()
        # The agent starts at a uniformly random building.
        self.character_location = random.choice(self.buildings) + " building"
        self.task_building = task_building
        self.non_stationarity = non_stationarity
        self.stationary_setting()
        self.timesteps = 0
        return self._observe()

    def reward(self):
        """Return the number of (node, task) pairs whose target is delivered.

        A pair counts when the node's ``class_name`` equals the task's target
        and the node currently carries the ``"delivered"`` label.
        """
        reward = 0
        for node in self.nodes:
            for target in self.task_building:
                if node['class_name'] == target[1] and 'delivered' in node['state']:
                    reward += 1
        return reward

    def stationary_setting(self):
        """Activate every task: sources become "called", targets lose "delivered"."""
        for task in self.task_building:
            self._activate_task(task)

    def non_stationary_dynamics(self):
        """Re-activate one randomly chosen task every ``non_stationarity`` timesteps.

        Does nothing when no period or no task is configured (a ``None`` or
        zero period previously raised on the modulo).
        """
        if not self.non_stationarity or not self.task_building:
            return
        if self.timesteps % self.non_stationarity == 0:
            self._activate_task(random.choice(self.task_building))

    def step(self, action):
        """Apply one action and advance the clock by one timestep.

        Args:
            action: One of "go straight", "turn right", "turn left",
                "pack goods" or "offloading". Any other value leaves the world
                unchanged but still consumes a timestep.

        Returns:
            ``(obs, reward, done, info)``. ``done`` becomes True once
            ``max_timesteps`` steps have been taken (it used to fire one step
            late); ``info`` is an empty dict.
        """
        if action in ("go straight", "turn right", "turn left"):
            # The second word of the action is the edge relation to follow.
            self._move(action.split()[1])
        elif action == "pack goods":
            self._pack_goods()
        elif action == "offloading":
            self._offload()

        self.timesteps += 1
        done = self.timesteps >= self.max_timesteps
        return self._observe(), self.reward(), done, {}

    def get_visible_observations(self):
        """Return the node of the current building plus its "IS" state edges."""
        edges = []
        nodes = []
        for node in self.nodes:
            if node['class_name'] == self.character_location:
                nodes.append(copy.deepcopy(node))
                edges.extend(self._state_edges(node))
        return {"nodes": nodes, "edges": edges}

    def get_agent_graph(self):
        """Return the character and current-building nodes.

        Edges are a "CLOSE" edge from the character to the current building
        followed by the character's "IS" state edges.
        """
        edges = [{
            "from_id": self.CHARACTER_ID,
            "relation_type": "CLOSE",
            "to_id": self.name_to_id[self.character_location],
        }]
        nodes = []
        for node in self.nodes:
            is_character = node['class_name'] == 'character'
            if is_character or node['class_name'] == self.character_location:
                nodes.append(copy.deepcopy(node))
                if is_character:
                    edges.extend(self._state_edges(node))
        return {"nodes": nodes, "edges": edges}

    def get_position_graph(self):
        """Return deep copies of all nodes and map edges plus every "IS" state edge."""
        edges = copy.deepcopy(self.edges)
        nodes = copy.deepcopy(self.nodes)
        for node in nodes:
            edges.extend(self._state_edges(node))
        return {"nodes": nodes, "edges": edges}

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _observe(self):
        """Build the observation dict shared by ``reset`` and ``step``."""
        return {
            'visible_graph': self.get_visible_observations(),
            'agent_graph': self.get_agent_graph(),
        }

    @staticmethod
    def _state_edges(node):
        """Return one "IS" edge per state label of ``node``."""
        return [{"from_id": node['id'], "relation_type": "IS", "to_id": label}
                for label in node['state']]

    def _node_by_id(self, node_id):
        """Return the node dict with the given id; raise KeyError if absent."""
        for node in self.nodes:
            if node['id'] == node_id:
                return node
        raise KeyError(node_id)

    def _activate_task(self, task):
        """Mark the task's source as "called" and clear "delivered" on its target."""
        source_name, target_name = task[0], task[1]
        for node in self.nodes:
            if node['class_name'] == source_name and "quite" in node['state']:
                node['state'].remove("quite")
                node['state'].append("called")
            if node['class_name'] == target_name and "delivered" in node['state']:
                node['state'].remove("delivered")

    def _move(self, direction):
        """Follow the outgoing edge labelled ``direction`` from the current building."""
        here = self.name_to_id[self.character_location]
        destination = None
        for edge in self.edges:
            # No early exit: if several edges match, the last one listed wins.
            if edge['from_id'] == here and edge['relation_type'] == direction:
                destination = edge['to_id']
        if destination is None:
            print("no actions")
            return
        self.character_location = self.id_to_name[destination]
        print("Temp location", self.character_location)

    def _pack_goods(self):
        """Load goods at the current building if it is "called".

        Requires the character to be "unpacked"; packing while already loaded
        is a no-op (it previously raised ``ValueError``).
        """
        character = self._node_by_id(self.CHARACTER_ID)
        building = self._node_by_id(self.name_to_id[self.character_location])
        if "unpacked" not in character['state'] or "called" not in building['state']:
            print("no actions")
            return
        character['state'].remove("unpacked")
        character['state'].append("packed")
        building['state'].remove("called")
        building['state'].append("waiting")

    def _offload(self):
        """Unload goods at the current building if the character is "packed".

        The building is labelled "delivered" (at most once, so repeated
        offloading cannot stack labels that a later reset would only partly
        clear) and every "waiting" building returns to "quite".
        """
        character = self._node_by_id(self.CHARACTER_ID)
        if "packed" not in character['state']:
            print("no actions")
            return
        building = self._node_by_id(self.name_to_id[self.character_location])
        if "delivered" not in building['state']:
            building['state'].append("delivered")
        character['state'].remove("packed")
        character['state'].append("unpacked")
        for node in self.nodes:
            if "waiting" in node['state']:
                node['state'].remove("waiting")
                node['state'].append("quite")


if __name__ == "__main__":
    # Interactive smoke test: type one action per prompt (see SKILL_SET).
    env = GraphCarla()
    # reset() requires the task list and the non-stationarity period; calling
    # it without arguments raised TypeError before the loop even started.
    env.reset(SUCCESS_CONDITIONS, NON_STA_SETTINGS["medium"])
    done = False
    while not done:
        try:
            action = input("ACTION: ")
        except EOFError:
            # stdin closed (e.g. piped input exhausted): leave quietly.
            break
        _obs, _reward, done, _info = env.step(action)