
"""
An example of integrating Flatland into MARLlib
"""

import os, sys
import numpy as np
import time
from collections import defaultdict
import yaml
import copy

from ray.rllib.env.multi_agent_env import MultiAgentEnv
import gym
from gym.spaces import Dict as GymDict, Box

from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.persistence import RailEnvPersister 
from flatland.envs.line_generators import SparseLineGen
from flatland.envs.rail_generators import SparseRailGen
from flatland.envs.malfunction_generators import (
    MalfunctionParameters,
    ParamMalfunctionGen,
)

from flatland_cutils import TreeObsForRailEnv as TreeCutils
from flatland.envs.observations import TreeObsForRailEnv as SlowTreeCutils

from marllib import marl
from marllib.envs.base_env import ENV_REGISTRY
from flatland.envs.rail_env import TrainState

# render
# Directory that contains this file; FlatlandEnv also uses it to locate the
# map YAML files under ``maps/``.
cur_path = os.path.dirname(os.path.abspath(__file__))
# Put that directory on the import path so the sibling module
# ``render_utils`` can be imported below.
sys.path.append(cur_path)
from render_utils import debug_show  


# Prefix of every agent id exposed to RLlib ("agent_0", "agent_1", ...);
# kept consistent with rllib.
agent_prefix = "agent_"
# Reported as "episode_limit" by get_env_info(); only influences the model,
# not the environment (update every 10*100 steps).
max_steps = 1000

# Policy-mapping description handed to MARLlib through get_env_info().
policy_mapping_dict = {
    "all_scenario": dict(
        description="train controls",
        team_prefix=(agent_prefix,),
        all_agents_one_policy=True,
        one_agent_one_policy=True,
    ),
}

class FlatlandEnv(MultiAgentEnv):

    def __init__(self, env_config):

        self.shared_reward = env_config.get('shared_reward', 'global')
        self.obs_builder = env_config.get('obs_builder', 'fast_tree')
        map_name = env_config.get('map_name', 'test_00')
        self.map_name = map_name.split('-')[0]

        with open(os.path.join(cur_path, f'maps/{self.map_name}.yaml'), 'r') as rf:
            map_cfg = yaml.safe_load(rf)
        print(f'\n=== Loaded MAP {self.map_name} with SETUP ===\n{map_cfg}\n')

        if self.obs_builder == 'fast_tree': 
            obs_builder_object = TreeCutils(31, 500) # https://github.com/RoboEden/flatland-marl/blob/main/flatland_cutils/setup.py

        # TODO: support other obs_builder, also need to change observation_space, _process_obs
        elif self.obs_builder == 'slow_tree':  
            obs_builder_object = SlowTreeCutils(max_depth=2) # https://gitlab.aicrowd.com/flatland/flatland/-/blob/master/flatland/envs/observations.py
            raise NotImplementedError

        self.env = RailEnv(
            number_of_agents=map_cfg['number_of_agents'],
            width=map_cfg['width'],
            height=map_cfg['height'],
            rail_generator=SparseRailGen(
                max_num_cities=map_cfg['max_num_cities'],
                grid_mode=map_cfg['grid_mode'],
                max_rails_between_cities=map_cfg['max_rails_between_cities'],
                max_rail_pairs_in_city=map_cfg['max_rail_pairs_in_city'],
            ),
            line_generator=SparseLineGen(speed_ratio_map=map_cfg['speed_ratio_map']),
            malfunction_generator=ParamMalfunctionGen(
                MalfunctionParameters(
                    malfunction_rate=map_cfg['malfunction_rate'], min_duration=map_cfg['malfunction_min_duration'], max_duration=map_cfg['malfunction_max_duration']
                )
            ),
            obs_builder_object=obs_builder_object,
            random_seed=env_config.get('seed', 2023)
        )

        # TODO: environment wrappers
        # https://gitlab.aicrowd.com/flatland/neurips2020-flatland-baselines/-/tree/flatland3/envs/flatland/utils



        n_actions = 5
        n_obses = 606
        self.num_agents = self.env.number_of_agents
        self.agents = [agent_prefix+str(i) for i in list(range(self.num_agents))]
        self.action_space = gym.spaces.Discrete(n_actions)
        self.observation_space = GymDict({
            "obs": Box(low=-float("inf"), high=float("inf"), shape=(n_obses,), dtype=np.dtype("float32")),
            "action_mask": Box(-1.0, 1.0, shape=(n_actions,)),
            # "other_obs": Box(-100.0, 100.0, shape=(self.max_neighbors_num, n_obses)),
            # "other_mask": Box(-1.0, 1.0, shape=(self.max_neighbors_num, )),
            "agent_mask": Box(-1.0, 1.0, shape=(self.num_agents, ), dtype=np.dtype("float32")),
            "last_agent_mask": Box(-1.0, 1.0, shape=(self.num_agents,), dtype=np.dtype("float32")),
            })
        self.env_config = env_config


        self.last_agent_mask  = {}

        # assume all agent same action/obs space
        self.max_neighbors_node = 4 # 10  4=first depth
        self.max_neighbors_num = self.num_agents-1 # 3
        self.neighbors = {}


    def final_metric(self,):
       assert self.env.dones["__all__"]
       env = self.env
    
       n_arrival, n_no_departure = 0, 0
       for a in env.agents:
           if a.position is None and a.state != TrainState.READY_TO_DEPART:
               n_arrival += 1
           elif a.position is None and a.state == TrainState.READY_TO_DEPART:
                n_no_departure += 1

       arrival_ratio = n_arrival / env.get_num_agents()
       departure_ratio = 1 - n_no_departure / env.get_num_agents()
       total_reward = sum(list(env.rewards_dict.values()))
       norm_reward = 1 + total_reward / env._max_episode_steps / env.get_num_agents()

       deadlock_ratio = np.mean(list(env.deadlocks_dict.values()))

       print(f'\n=== Episode Ends! ===\n# Steps:{env._elapsed_steps}\n# Agents:{env.get_num_agents()}\nArrival Ratio:{arrival_ratio:.3f}\nDeparture Ratio:{departure_ratio:.3f}\nDeadlock Ratio: {deadlock_ratio:.3f}\nTotal Reward:{total_reward:.3f}\nNorm Reward:{norm_reward:.3f}\n')
       return arrival_ratio, departure_ratio, deadlock_ratio, total_reward, norm_reward


    def _action_required(self):
        return {
            i: self.env.action_required(agent)
            for i, agent in enumerate(self.env.agents)
        }


    def _update_env_properties(self):
        self.obs_properties = {}
        properties = self.env.obs_builder.get_properties()
        env_config, agents_properties, valid_actions = properties
        self.obs_properties.update(env_config)
        self.obs_properties.update(agents_properties)
        self.obs_properties["valid_actions"] = np.array(valid_actions)


    def _extract_agents(self, numbers):
        n_agents = self.num_agents

        new_numbers, new_agents = [], []
        for n in numbers: # node 30
            if n > 0:
                n = int(n * n_agents)
                agents = [i for i, x in enumerate(list(bin(n)[::-1])) if x =='1']
                new_agents.append(agents)
                new_numbers.append( 1.0*len(agents)/n_agents )
            else:
                new_numbers.append(n) # 0 or -1
                new_agents.append([])
        return np.array(new_numbers), new_agents


    def _process_obs(self, feature):

        feature_list = {}
        feature_list["agent_attr"] = np.array(feature[0])
        feature_list["forest"] = np.array(feature[1][0])
        feature_list["forest"][feature_list["forest"] == np.inf] = -1
        feature_list["adjacency"] = np.array(feature[1][1])
        feature_list["node_order"] = np.array(feature[1][2])
        feature_list["edge_order"] = np.array(feature[1][3])
        valid_actions = self.obs_properties["valid_actions"]


        obs = {}
        self.neighbors = {}
        for i, name in enumerate(self.agents):
            
            # --- extract neighbors ---
            # same direction -5
            # opposite direction -4
            same_cnt, same_agents = self._extract_agents(feature_list["forest"][i][:,-5])
            opposite_cnt, opposite_agents = self._extract_agents(feature_list["forest"][i][:,-4])
            opposite_agents = list(set([a for agents in opposite_agents[:self.max_neighbors_node] for a in agents]))
            same_agents = list(set([a for agents in same_agents[:self.max_neighbors_node] for a in agents]))
            self.neighbors[name] = list(set(opposite_agents+same_agents))[:self.max_neighbors_num]
            feature_list["forest"][i][:,-5] = same_cnt
            feature_list["forest"][i][:,-4] = opposite_cnt
            # --------------------------

            # TODO: train_id = 0 to create homogeneous agents
            feature_list['agent_attr'][i][-13] = 0.0

            obs[name] = {"obs": np.concatenate([feature_list['agent_attr'][i], 
                                        feature_list["forest"][i].flatten(),
                                        feature_list["adjacency"][i].flatten(),
                                        feature_list["node_order"][i].flatten(),
                                        feature_list["edge_order"][i].flatten()], axis=-1).astype(np.float32),
                         "action_mask": valid_actions[i].flatten().astype(np.float32)
                        }

        # in two directions...
        self.new_neighbors = copy.deepcopy(self.neighbors)
        for i, name in enumerate(self.agents):
            neighbors = self.neighbors[name]
            for n in neighbors:
                self.new_neighbors[agent_prefix+str(n)].append(i)
        self.neighbors = self.new_neighbors


        all_obs = np.stack([v["obs"] for k, v in obs.items()], axis=0)
        for i, name in enumerate(self.agents):
            neighbors = self.neighbors[name]
            agents_pad = np.array(neighbors + [i] * max(0, self.max_neighbors_num+1 - len(neighbors)) ).astype(int)
            multihot = np.eye(len(self.agents))[agents_pad] # [n_neighbors, n_all]
            # other_mask = np.array([0.0] * len(neighbors) + [1.0] * max(0, self.max_neighbors_num - len(neighbors)))
            # obs[name]["other_mask"] = other_mask # [n,]
            # obs[name]["other_obs"] = np.matmul(multihot, all_obs) # [n_neighbors, obs_dim]
            obs[name]["agent_mask"] = 1 - np.sum(multihot, axis=0).astype(bool).astype(np.float32)
            obs[name]["last_agent_mask"] = self.last_agent_mask[name].astype(np.float32)
            self.last_agent_mask[name] = obs[name]["agent_mask"]


        return obs


    def _process_action(self, action_dict):
        action_required = self._action_required()
        parsed_action = dict()
        for idx, act in action_dict.items():
            simple_idx = int(idx.split('_')[-1])
            if action_required[simple_idx]:
                parsed_action[simple_idx] = RailEnvActions(act)
        return parsed_action


    def reset(self):
        obs, _ = self.env.reset()
        self._update_env_properties()
        for i, name in enumerate(self.agents):
            self.last_agent_mask[name] = 1 - np.eye(len(self.agents))[i]
        return self._process_obs(obs)


    def step(self, action_dict):

        # print(f'\n--- Step: {self.env._elapsed_steps} ---')
        # process actions
        action = self._process_action(action_dict)

        # env.step
        obs, reward, done, info = self.env.step(action)

        # process obs/reward/done
        self._update_env_properties()
        new_obs = self._process_obs(obs)
        new_reward, new_done, new_info = {}, {}, defaultdict(dict)
        # calculate global reward
        total_reward = sum(list(reward.values())) / self.env.get_num_agents()
        for k, v in reward.items():
            if self.shared_reward == 'global' :
                new_reward[agent_prefix+str(k)] = total_reward
            elif self.shared_reward == 'local' :
                new_reward[agent_prefix+str(k)] = reward[k]
                for n in self.neighbors[agent_prefix+str(k)]:
                    new_reward[agent_prefix+str(k)] += reward[n] 
                new_reward[agent_prefix+str(k)] /= 1 + len(self.neighbors[agent_prefix+str(k)])
            else:
                new_reward[agent_prefix+str(k)] = reward[k]
            # M: rllib requires all done the same
            # new_done[agent_prefix+str(k)] = done[k]
            new_done[agent_prefix+str(k)] = self.env.dones['__all__']
            for ki, vi in info.items():
                new_info[agent_prefix+str(k)][ki] = info[ki][k]
        new_done['__all__'] = self.env.dones['__all__']

        if new_done['__all__']:
            stat = self.final_metric()
        return new_obs, new_reward, new_done, new_info


    def close(self):
        """Close the wrapped environment by calling ``self.env.close()``."""
        self.env.close()


    def render(self, mode=None):
        """Render the wrapped environment by delegating to ``debug_show``.

        Args:
            mode: forwarded unchanged to ``debug_show``; the accepted values
                are defined by ``render_utils``.

        Returns:
            Whatever ``debug_show(self.env, mode)`` returns.
        """
        return debug_show(self.env, mode)


    def get_env_info(self):
        """Return the environment description dict.

        The dict carries the observation/action spaces, the number of agents,
        the module-level ``max_steps`` as ``"episode_limit"`` and the
        module-level ``policy_mapping_dict`` as ``"policy_mapping_info"``.
        """
        return dict(
            space_obs=self.observation_space,
            space_act=self.action_space,
            num_agents=self.num_agents,
            episode_limit=max_steps,
            policy_mapping_info=policy_mapping_dict,
        )



if __name__ == '__main__':
    # Register the wrapper under the name "flatland" and build it through MARLlib.
    ENV_REGISTRY["flatland"] = FlatlandEnv
    env = marl.make_env(environment_name="flatland", map_name='test_00')

    # Manual smoke test, kept for reference:
    # obs = env[0].reset()
    # action = {f'agent_{i}':0 for i in range(50)}
    # o, r, d, i = env[0].step(action)
    # print(o, r, d)

    # MAPPO with the "test" hyper-parameter source and the "tree" core architecture.
    mappo = marl.algos.mappo(hyperparam_source="test")
    model = marl.build_model(env, mappo, {"core_arch": "tree"})

    # Start learning.
    stop_criteria = {'episode_reward_mean': 1000, 'timesteps_total': int(1e6)}
    run_options = dict(
        local_mode=True,
        num_gpus=1,
        num_workers=2,
        share_policy='all',
        checkpoint_freq=int(1e4),
    )
    mappo.fit(env, model, stop=stop_criteria, **run_options)

