# if __name__ == "__main__":
#     from multiagentenv import MultiAgentEnv
# else:
#     from .multiagentenv import MultiAgentEnv
from multiagentenv import MultiAgentEnv
from pbcmaze_belief_model import ReceiverBeliefModel, SenderBeliefModel
from models.r2d2_config import device
# from utils.dict2namedtuple import convert # This was breaking locally, so define manually
from collections import namedtuple
import copy
import numpy as np
import random
import itertools
import copy
import torch

# Discrete action ids shared by both agents. The receiver only uses the first
# five (the movement actions and NOOP); the two hint actions are sender-only.
ACTIONS = [*range(7)]
(LEFT, RIGHT, UP, DOWN, NOOP, HINT_UP, HINT_DOWN) = ACTIONS

def convert(dictionary):
    """Turn a dict into a ``GenericDict`` namedtuple with one field per key."""
    generic_dict_type = namedtuple('GenericDict', list(dictionary))
    return generic_dict_type(**dictionary)

class PBCMaze(MultiAgentEnv):
    def __init__(self, **kwargs):

        args = kwargs["env_args"]
        if isinstance(args, dict):
            args = convert(args)

        self.lengths = tuple(int(l) for l in args.lengths)
        self.starts = tuple(int(l) for l in args.starts)
        self.booth_loc = int(args.booth_loc)
        self.receiver_booth_loc = int(args.receiver_booth_loc)
        self.episode_limit = int(args.episode_limit)
        self.right_r = float(args.right_r)
        self.wrong_r = float(args.wrong_r)
        self.num_sender_decoy_booths = int(args.num_sender_decoy_booths)
        if(self.num_sender_decoy_booths > 0):
            # All positions along the length[0] are possible decoy positions except the starting one
            self.possible_decoy_pos = [i for i in range(1, self.lengths[0] + 1) if i != self.starts[0]]
            if(self.num_sender_decoy_booths > len(self.possible_decoy_pos)):
                print("The number of decoy booths cannot be greater than the length of the corridor - 1")
                exit()
            #self._gen_decoy_booths_locs()
            #print(self.decoy_booth_locs)
        # If fixed, decoy booths' locations would not reinitialized if reset
        self.decoy_booths_fixed = int(args.decoy_booths_fixed)
        self.decoy_booth_locs = None
        self.use_intermediate_reward = args.use_intermediate_reward
        self.use_mi_shaping = args.use_mi_shaping
        self.use_mi_loss = args.use_mi_loss
        self.use_oh_token = args.use_oh_token
        # set in reset:
        self.timestep = None
        self.agent0_loc = [None, None]
        self.agent1_loc = [None, None]
        self.sender_in_booth = None
        self.receiver_in_booth = None
        #self.reset()
        #print(self.decoy_booth_locs)

    def reset(self):
        """Start a new episode.

        Returns:
            ``([obs_agent0, obs_agent1], state)`` where each observation is the
            tuple built by ``get_grid_obs_for_agent`` and ``state`` is the array
            built by ``get_state``.
        """
        self.timestep = 0
        self.sender_timestep_to_pb = 0
        self.receiver_timestep_to_pb = 0
        self.agent0_loc = [self.starts[0], 0]
        self.agent1_loc = [self.starts[1], 0]
        self.goal = np.random.choice([UP, DOWN])
        # Communication token as written by state_transition:
        # 0 = nothing received, -1 = HINT_UP received, 1 = HINT_DOWN received.
        self.comm_token = 0
        self.comm_token_oh = np.zeros(2)
        self.sender_in_booth = False
        self.sender_reached_booth = False
        self.receiver_in_booth = False
        self.receiver_reached_booth = False
        # Forget the previous episode's actions. Without this the scalar
        # last-action records are stale (or missing before the first move),
        # and load_env_config_obl reads the sender's one.
        self._last_sender_action = None
        self._last_receiver_action = None
        self._last_sender_action_vec = np.zeros(7)
        self._last_receiver_action_vec = np.zeros(5)
        if self.num_sender_decoy_booths > 0:
            if self.decoy_booths_fixed != 1 or self.decoy_booth_locs is None:
                # Sample new decoy locations (this also clears the flags).
                self._gen_decoy_booths_locs()
            else:
                # Fixed layout: keep the locations, only clear the flags.
                self.sender_in_decoy_booth = [False] * len(self.decoy_booth_locs)
        return [self.get_grid_obs_for_agent(0), self.get_grid_obs_for_agent(1)], self.get_state()

    def save_env_config(self):
        # Function to save snapshot of the environment for fictitious transition of OBL
        # env_config = {
        #     'lengths' : self.lengths,
        #     'starts' : self.starts,
        #     'booth_loc': self.booth_loc,
        #     'receiver_booth_loc' : self.receiver_booth_loc,
        #     'num_sender_decoy_booths' : self.num_sender_decoy_booths,
        #     'decoy_booth_locs' : self.decoy_booth_locs,
        #     "right_r" : self.right_r,
        #     "wrong_r" : self.left_r,
        #     "timestep"
        # }

        return copy.deepcopy(self.__dict__)

    def load_env_config(self, config):
        for key in config.keys():
            setattr(self, key, config[key])

    def load_env_config_obl(self, index, belief, config):
        """
        Loading env config for OBL sampling, additional checks to ensure the flags are consistent

        Case 1: Loading env config for sampling of sender's belief
                Nothing's special other than updating flags
        Case 2: Loading env config for sampling of receiver's belief
                Updating flags + handle comm token if the belief is that the receiver is in a phonebooth
                and sender performed a cheap talk action
                *if receiver has its comm token set, it must know where the sender is?
        """
        self.load_env_config(config)

        if(index == 0):
            self.agent0_loc = list(belief)
        else:
            self.agent1_loc = list(belief)

        a0_x, a0_y = self.agent0_loc
        a1_x, a1_y = self.agent1_loc

        if a0_x == self.booth_loc and a0_y == 0: # booth obs for agent 0
            self.sender_in_booth = True
        else:
            # For resetting after going into the booth
            self.sender_in_booth = False

        if(a1_x == self.receiver_booth_loc and a1_y == 0):
            self.receiver_in_booth = True
        else:
            # For resetting after going into the booth
            self.receiver_in_booth = False

        if(index == 0):
            # Receiver's turn
            # Sampled sender's belief, i.e. receiver's OBL update
            if(self.num_sender_decoy_booths > 0):
                self.sender_in_decoy_booth = [False for flag in self.sender_in_decoy_booth]
            if(self.num_sender_decoy_booths > 0 and (a0_x+1, a0_y) in self.decoy_booth_locs):
                self.sender_in_decoy_booth[self.decoy_booth_locs.index((a0_x+1, a0_y))] = True
            if(self.sender_in_booth and (self._last_sender_action == HINT_UP or self._last_sender_action == HINT_DOWN)):
                if(self._last_sender_action == HINT_UP):
                    self.comm_token = -1
                    self.comm_token_oh[0] = -1
                    self.comm_token_oh[1] = -1
                else:
                    self.comm_token = 1
                    self.comm_token_oh[0] = 1
                    self.comm_token_oh[1] = 1
            else:
                self.comm_token = 0
                self.comm_token_oh = np.zeros(2)

        elif(index == 1):
            # Sampled receiver's belief, i.e. sender's OBL update
            pass
        else:
            print("load_env_config_obl: this should not happen")
            exit()

    def state_transition(self, index, loc, action):
        """Apply ``action`` for agent ``index`` standing at ``loc``.

        Besides computing the new location, this updates the communication
        token, the booth and decoy-booth flags, and the record of the acting
        agent's last action.

        The one-hot token is always replaced by a fresh array rather than
        written in place, so arrays handed out earlier (for example inside an
        observation tuple) are not changed retroactively.

        Args:
            index: 0 for the sender, 1 for the receiver.
            loc: current ``(x, y)`` location of the acting agent.
            action: one of the module-level action ids.

        Returns:
            The new ``(x, y)`` location. When there are no decoy booths and the
            agent has already left the corridor row, ``loc`` is returned
            unchanged and nothing else is updated.
        """
        x, y = loc
        # Without decoy booths, an agent that moved up or down is stuck.
        if self.num_sender_decoy_booths == 0 and y != 0:
            return loc

        length = self.lengths[index]

        # The receiver is acting: clear a token left by the sender's last hint.
        if index == 1 and self.comm_token != 0:
            up_hint_shown = self._last_sender_action == HINT_UP and self.comm_token == -1
            down_hint_shown = self._last_sender_action == HINT_DOWN and self.comm_token == 1
            if up_hint_shown or down_hint_shown:
                self.comm_token = 0
                self.comm_token_oh = np.zeros(2)
            else:
                print("Comm token reset: This should not happen")
                exit()

        # Decoy booth locations are stored as (x+1, y) relative to the agent's
        # coordinate system, which explains the offsets in the checks below.
        new_x, new_y = x, y
        if action == LEFT:
            if y == 0:
                new_x, new_y = (x-1, y)
            elif index == 0 and (x, y) in self.decoy_booth_locs:
                # Sideways move between adjacent decoy booths.
                new_x, new_y = (x-1, y)
        elif action == RIGHT:
            if y == 0:
                new_x, new_y = (x+1, y)
            elif index == 0 and (x+2, y) in self.decoy_booth_locs:
                new_x, new_y = (x+1, y)
        elif action == UP:
            if index == 1 and x == length-1: # receiver can only leave at the end
                new_x, new_y = (x, y+1)
            if index == 0 and self.num_sender_decoy_booths > 0:
                if (x+1, y+1) in self.decoy_booth_locs:
                    # Moving to decoy booths
                    new_x, new_y = (x, y+1)
                elif y+1 == 0:
                    # Moving away from decoy booths
                    new_x, new_y = (x, y+1)
        elif action == DOWN:
            if index == 1 and x == length-1: # receiver can only leave at the end
                new_x, new_y = (x, y-1)
            if index == 0 and self.num_sender_decoy_booths > 0:
                if (x+1, y-1) in self.decoy_booth_locs:
                    new_x, new_y = (x, y-1)
                elif y-1 == 0:
                    new_x, new_y = (x, y-1)
        elif action == HINT_UP or action == HINT_DOWN:
            if index == 1:
                print("This agent does not have this action as part of its action space")
                exit()
            # Only has an effect if both sender and receiver are in the right booth
            if self.sender_in_booth and self.receiver_in_booth:
                token = -1 if action == HINT_UP else 1
                self.comm_token = token
                self.comm_token_oh = np.full(2, float(token))
        else:
            assert action == NOOP, (action, NOOP)

        new_x = np.clip(new_x, 0, length-1)
        new_y = np.clip(new_y, -2, 2)

        # In booth flags update. Each flag is refreshed by its own agent's move
        # and left alone on the other agent's turn while that agent is still
        # in its booth; otherwise it is cleared.
        if index == 0 and new_x == self.booth_loc and new_y == 0: # booth obs for agent 0
            self.sender_in_booth = True
        elif index == 1 and self.agent0_loc[0] == self.booth_loc and self.agent0_loc[1] == 0:
            pass
        else:
            # For resetting after going into the booth
            self.sender_in_booth = False
        if index == 1 and new_x == self.receiver_booth_loc and new_y == 0:
            self.receiver_in_booth = True
        elif index == 0 and self.agent1_loc[0] == self.receiver_booth_loc and self.agent1_loc[1] == 0:
            pass
        else:
            # For resetting after going into the booth
            self.receiver_in_booth = False
        if index == 0 and self.num_sender_decoy_booths > 0 and (new_x+1, new_y) in self.decoy_booth_locs:
            self.sender_in_decoy_booth[self.decoy_booth_locs.index((new_x+1, new_y))] = True
        elif index == 1 and self.num_sender_decoy_booths > 0 and (self.agent0_loc[0]+1, self.agent0_loc[1]) in self.decoy_booth_locs:
            pass
        else:
            # For resetting after leaving a decoy booth
            if self.num_sender_decoy_booths > 0:
                self.sender_in_decoy_booth = [False for flag in self.sender_in_decoy_booth]

        # Update last action for comm token reset
        if index == 0:
            self._last_sender_action = action
            self._last_sender_action_vec = np.zeros(7)
            self._last_sender_action_vec[action] = 1
        else:
            self._last_receiver_action = action
            self._last_receiver_action_vec = np.zeros(5)
            self._last_receiver_action_vec[action] = 1
        return new_x, new_y

    def step(self, idx, action, policy = None, belief_model = None):
        """Let agent ``idx`` play ``action`` and advance the episode by one step.

        Args:
            idx: 0 for the sender, anything else is treated as the receiver.
            action: action id forwarded to ``state_transition``.
            policy: optional sender action distribution; when given for the
                sender it enables the MI shaping reward and/or MI masks.
            belief_model: forwarded to ``calculate_mi_reward``.

        Returns:
            ``(reward, done, info)``.
        """
        reward = 0
        second_term_masks = None
        sender_with_policy = idx == 0 and policy is not None

        if sender_with_policy and self.use_mi_shaping:
            second_term_masks, mi_reward = self.calculate_mi_reward(policy, idx, belief_model)
            reward += mi_reward * 2
            self.last_mi_reward = mi_reward

        if sender_with_policy and self.use_mi_loss and second_term_masks is None:
            # Shaping is off but the MI loss still needs the masks.
            second_term_masks, _ = self.calculate_mi_reward(policy, idx, belief_model)

        # Set when the acting agent enters its booth for the first time.
        self.just_reached_booth = False
        if idx == 0:
            self.agent0_loc = self.state_transition(0, self.agent0_loc, action)
            # Count sender steps until its booth is first reached.
            if not self.sender_reached_booth:
                if self.sender_in_booth:
                    self.sender_reached_booth = True
                    self.just_reached_booth = True
                self.sender_timestep_to_pb += 1
        else:
            self.agent1_loc = self.state_transition(1, self.agent1_loc, action)
            # Count receiver steps until its booth is first reached.
            if not self.receiver_reached_booth:
                if self.receiver_in_booth:
                    self.receiver_reached_booth = True
                    self.just_reached_booth = True
                self.receiver_timestep_to_pb += 1

        a1_x, a1_y = self.agent1_loc
        self.timestep += 1
        # The episode ends at the horizon or as soon as the receiver exits.
        done = (self.timestep >= self.episode_limit) or a1_y != 0

        # Terminal reward: the receiver left through the upper or lower exit.
        if a1_y == 1:
            reward += self.right_r if self.goal == UP else self.wrong_r
        if a1_y == -1:
            reward += self.right_r if self.goal == DOWN else self.wrong_r

        # Intermediate reward while both agents are in their booths.
        if self.use_intermediate_reward and self.sender_in_booth and self.receiver_in_booth:
            reward += 1.0

        # Build the info dict; the decoy entry exists only when decoys are used.
        info = {
            'hint': 1 if self.goal == UP else 0,
            "sender_in_booth": 1 if self.sender_in_booth else 0,
            "sender_time_to_booth": self.sender_timestep_to_pb if self.sender_reached_booth else 20,
        }
        if self.num_sender_decoy_booths != 0:
            info["sender_in_decoy_booth"] = [1 if flag else 0 for flag in self.sender_in_decoy_booth]
        info["receiver_in_booth"] = 1 if self.receiver_in_booth else 0
        info["receiver_time_to_booth"] = self.receiver_timestep_to_pb if self.receiver_reached_booth else 20
        info["mi_term_masks"] = second_term_masks

        return reward, done, info


    def calculate_mi_reward(self, policy, idx = 0, a1_belief = None):
        """Compute the mutual-information term between the sender's action and
        the receiver's next observation.

        Every candidate sender action is played on a deep copy of the
        environment (with the communication token cleared first); actions are
        grouped by the receiver observation they lead to, and
        ``H(a) + H(o') - H(a, o')`` is evaluated under ``policy``.

        ``policy`` is never modified: the small constant that avoids
        ``log(0)`` is added to a copy. Lists are supported as well as objects
        exposing ``shape`` (arrays / tensors).

        Args:
            policy: sender action probabilities; its length selects how many of
                the leading ``ACTIONS`` are considered.
            idx: acting agent; must be 0 (the sender).
            a1_belief: unused.

        Returns:
            ``(second_term_masks, mi)`` where ``second_term_masks`` has one row
            per distinct receiver observation, with 1.0 at every action that
            leads to it.
        """
        if idx != 0:
            print("calculate_mi_reward: should only work for the sender")
            exit()

        # Snapshot used to rewind the scratch environment after every probe.
        start_env_config = self.save_env_config()
        start_env_config['use_mi_shaping'] = False
        start_env_config['comm_token'] = 0
        start_env_config['comm_token_oh'] = np.zeros(2)
        pseudo_env = copy.deepcopy(self)
        pseudo_env.use_mi_shaping = False
        pseudo_env.comm_token = 0
        pseudo_env.comm_token_oh = np.zeros(2)

        # To avoid log(0); built out of place so the caller's policy is untouched.
        if isinstance(policy, list):
            policy = [p + 1e-10 for p in policy]
            num_actions = len(policy)
        else:
            policy = policy + 1e-10
            num_actions = policy.shape[-1]

        # A shorter policy means the hint actions are not treated as separate
        # actions (ctdu agent) when computing the MI term.
        candidate_actions = ACTIONS[:num_actions]

        obs2_a1_dict = {}
        first_term = 0.0
        for a in candidate_actions:
            old_obs_1 = pseudo_env.get_obs(1)
            pseudo_env.step(idx, a)
            obs_2 = "".join(str(int(i)) for i in pseudo_env.get_obs(1))
            obs2_a1_dict.setdefault(obs_2, []).append(a)
            # Rewind the scratch environment and check the rewind worked.
            pseudo_env.load_env_config(copy.deepcopy(start_env_config))
            assert (old_obs_1 == pseudo_env.get_obs(1)).all()

            # First term: entropy of the action distribution.
            first_term += -policy[a] * np.log2(policy[a])

        # Second term: entropy of the induced observation distribution.
        second_term = 0.0
        second_term_masks = []
        for actions_for_obs in obs2_a1_dict.values():
            p_o2 = 0.0
            mask = np.zeros(len(candidate_actions))
            for a in candidate_actions:
                p_o2_a1 = 1.0 if a in actions_for_obs else 0.0
                mask[a] = p_o2_a1
                p_o2 += p_o2_a1 * policy[a]
            second_term += -p_o2 * np.log2(p_o2)
            second_term_masks.append(mask)
        second_term_masks = np.array(second_term_masks)

        # Third term: joint entropy; transitions are deterministic, so the
        # joint probability of (action, observation) is just policy[a].
        third_term = 0.0
        for actions_for_obs in obs2_a1_dict.values():
            for a in actions_for_obs:
                p_o2_a1 = 1.0 * policy[a]
                third_term += -p_o2_a1 * np.log2(p_o2_a1)

        mi = first_term + second_term - third_term

        return second_term_masks, mi

    def calculate_soc_influence_reward(self, influencee_agent, influencee_agent_hidden, influencee_agent_idx = 1):
        """Compare the influencee's policy with and without the comm token.

        The influencee's action distribution is queried twice through
        ``influencee_agent.get_action``: once on the current observation and
        once on a counterfactual copy of the environment whose communication
        token has been cleared. The result of ``kl_div`` on the two
        log-distributions is returned.

        NOTE(review): ``kl_div(input, target)`` is called with the actual
        policy as ``input`` and the counterfactual one as ``target``, using the
        default reduction -- confirm that this direction and reduction are the
        intended ones.
        """
        with torch.no_grad():
            # Policy under the observation as it actually is.
            actual_env = copy.deepcopy(self)
            actual_obs = torch.Tensor(actual_env.get_obs(influencee_agent_idx)).to(device)
            policy, _, _, _ = influencee_agent.get_action(actual_obs, 0.0, influencee_agent_hidden)

            # Policy under the same situation with no message received.
            cf_pseudo_env = copy.deepcopy(self)
            cf_pseudo_env.use_mi_shaping = False
            cf_pseudo_env.comm_token = 0
            cf_pseudo_env.comm_token_oh = np.zeros(2)
            cf_obs = torch.Tensor(cf_pseudo_env.get_obs(influencee_agent_idx)).to(device)
            cf_policy, _, _, _ = influencee_agent.get_action(cf_obs, 0.0, influencee_agent_hidden)

            log_policy = policy.log()
            log_cf_policy = cf_policy.log()
            return torch.nn.functional.kl_div(log_policy, log_cf_policy, log_target = True)

    def turn_off_mi_training(self):
        self.use_mi_shaping = False

    def _goal_encoding(self):
        """Encode the goal as a one-element list: ``[1]`` for UP, else ``[-1]``."""
        sign = 1 if self.goal == UP else -1
        return [sign]

    def _gen_decoy_booths_locs(self):
        num_up_isle = random.randint(0, self.num_sender_decoy_booths)
        num_down_isle = self.num_sender_decoy_booths - num_up_isle
        db_y_list = [-1] * num_up_isle + [1] * num_down_isle
        random.shuffle(db_y_list)
        db_x_list = random.sample(self.possible_decoy_pos, self.num_sender_decoy_booths)
        self._db_x_list = db_x_list
        self._db_y_list = db_y_list
        self.decoy_booth_locs = [(x, y) for x, y in zip(db_x_list, db_y_list)]
        self.sender_in_decoy_booth = [False for i in range(len(self.decoy_booth_locs))]

    def get_grid_obs_for_agent(self, indx):
        # Obs shape: (num_channel, num_rows, length of corridor)
        # Channels: wall channel, booth channel, agent location channel
        obs = None
        grid_obs = np.zeros((3, 3, self.lengths[indx] + 2))
        # Set walls
        grid_obs[0, 0, :] = 1
        grid_obs[0, 2, :] = 1
        grid_obs[0, :, 0] = 1
        grid_obs[0, :, -1] = 1
        if(indx == 0):
            # Agent 0
            # Unset those for decoy booths for wall channel and set for booth channel
            if(self.num_sender_decoy_booths > 0):
                for loc in self.decoy_booth_locs:
                    # "-1" needed to invert, because np array iterates in the opposite direction
                    grid_obs[0, -1 * loc[1] + 1, loc[0]] = 0
                    # Set booths
                    grid_obs[1, -1 * loc[1] + 1, loc[0]] = 1
            # Set functional booth
            grid_obs[1, 1, self.booth_loc + 1] = 1
            # agent location channel
            grid_obs[2, -1 * self.agent0_loc[1] + 1, self.agent0_loc[0] + 1] = 1
            # obs for agent 0: (grid obs, goal)
            goal_vec = np.zeros(2)
            if(self.goal == 2):
                goal_vec[0] = 1
            elif(self.goal == 3):
                goal_vec[1] = 1
            obs = (grid_obs, self._last_sender_action_vec, goal_vec)
        else:
            # Agent 1
            # Exit column not walls
            grid_obs[0, :, -2] = 0
            # Set receiver booth
            grid_obs[1, 1, self.receiver_booth_loc + 1] = 1
            # agent location channel
            grid_obs[2, -1 * self.agent1_loc[1]+1, self.agent1_loc[0]] = 1
            # obs for agent 1: (grid obs, communication token)
            if(self.use_oh_token):
                obs = (grid_obs, self._last_receiver_action_vec, self.comm_token_oh)
            else:
                obs = (grid_obs, self._last_receiver_action_vec, self.comm_token)
            # obs = (grid_obs, self.comm_token_oh, self._last_receiver_action_vec)
        return obs

    def get_obs_for_agent(self, indx):
        loc = [self.agent0_loc, self.agent1_loc][indx]
        obs = list(loc)
        x,y = loc
        if indx == 0: # booth obs for agent 0
            obs = obs + self._goal_encoding()
            obs = obs + ([1] if self.sender_in_booth else [-1]) # "if x == self.booth_loc" for local obs of booth
            if(self.num_sender_decoy_booths > 0):
                for flag in self.sender_in_decoy_booth:
                    obs = obs + ([1] if flag else [-1])
        else:
            # obs = obs + [self.comm_token]
            obs = obs + [self.comm_token_oh]
            obs = obs + ([1] if self.receiver_in_booth else [-1])
            if(self.num_sender_decoy_booths > 0):
                for flag in self.sender_in_decoy_booth:
                    obs = obs + [0]
        return obs

    def get_obs(self, indx, grid = True, flatten = True):
        if(grid):
            grid_obs, single_feat, prev_act_feat = self.get_grid_obs_for_agent(indx)
            return np.append(np.append(grid_obs.reshape(-1), prev_act_feat), single_feat)
            # grid_obs, single_feat = self.get_grid_obs_for_agent(indx)
            # return np.append(grid_obs.reshape(-1),single_feat)
        else:
            return self.get_obs_for_agent(indx)

    def get_state(self):
        """Return the global state as a 1-D array.

        Layout: sender location, receiver location, goal encoding, sender
        in-booth flag, one flag per decoy booth (only when decoys are used),
        comm token, receiver in-booth flag. Flags are encoded as ``1`` / ``0``.
        """
        state = list(self.agent0_loc) + list(self.agent1_loc) + self._goal_encoding()
        state.append(1 if self.sender_in_booth else 0)
        if self.num_sender_decoy_booths != 0:
            state.extend(1 if flag else 0 for flag in self.sender_in_decoy_booth)
        state.append(self.comm_token)
        state.append(1 if self.receiver_in_booth else 0)
        return np.array(state)


    def get_obs_size(self, indx):
        # Gives flatten shape
        obs = self.get_obs(indx)
        #return obs[0].reshape(-1).shape[0] + 1
        return obs.shape[0]

    def get_state_size(self):
        """Return the number of entries in the global state."""
        state = self.get_state()
        return len(state)

    def get_avail_actions(self, agent_id):
        """Return the list of action ids available to ``agent_id``.

        Agent 0 (sender) gets every action; agent 1 (receiver) gets the first
        five. The per-agent lists have different lengths, so they are kept in a
        tuple: packing them into one ``np.array`` is a ragged conversion, which
        recent NumPy versions reject. A copy is returned so callers cannot
        modify the module-level ``ACTIONS`` list.

        Raises:
            IndexError: if ``agent_id`` is not a valid agent index.
        """
        per_agent_actions = (ACTIONS, ACTIONS[:5])
        return list(per_agent_actions[agent_id])

    def get_total_actions(self):
        """Return the size of the full action set (sender actions included)."""
        total = len(ACTIONS)
        return total

    def get_obs_agent(self, agent_id):
        """Return the non-grid observation of ``agent_id`` as an array."""
        flat_obs = self.get_obs_for_agent(agent_id)
        return np.array(flat_obs)

    def get_avail_agent_actions(self, agent_id):
        """Return a 0/1 availability mask over the full action set.

        The sender (agent 0) may use every action; the receiver (agent 1) only
        the first five.

        Raises:
            NotImplementedError: for any other ``agent_id``.
        """
        if agent_id == 0:
            return np.ones(len(ACTIONS))
        if agent_id == 1:
            mask = np.zeros(len(ACTIONS))
            mask[:len(ACTIONS[:5])] = 1
            return mask
        raise NotImplementedError

    def close(self):
        pass

    def get_stats(self):
        pass

    def seed(self):
        raise NotImplementedError

    def render(self):
        """
        Print both corridors side by side (sender on the left) as three rows.

        a: agent
        b: booth
        o: booth occupied by agent
        """
        # Sender corridor: top wall row, corridor row, bottom wall row.
        sender_len = self.lengths[0]
        a0line1 = ["."] * (sender_len + 2)
        a0line2 = ["."] + [" "] * sender_len + ["."]
        a0line3 = ["."] * (sender_len + 2)
        a0line2[self.booth_loc+1] = "b"
        if self.num_sender_decoy_booths > 0:
            rows_by_y = {-1: a0line3, 0: a0line2, 1: a0line1}
            # Decoy x-coordinates already include the column shift.
            for decoy in self.decoy_booth_locs:
                rows_by_y[decoy[1]][decoy[0]] = "b"
            sender_row = rows_by_y[self.agent0_loc[1]]
            # "o" if in booth
            in_any_booth = True in self.sender_in_decoy_booth or self.sender_in_booth
            sender_row[self.agent0_loc[0]+1] = "o" if in_any_booth else "a"
        else:
            a0line2[self.agent0_loc[0]+1] = "a"

        # Receiver corridor: the wall rows have an opening at the exit column.
        receiver_len = self.lengths[1]
        a1line1 = ["."] * receiver_len + [" ", "."]
        a1line2 = ["."] + [" "] * receiver_len + ["."]
        a1line3 = ["."] * receiver_len + [" ", "."]
        a1line2[self.receiver_booth_loc+1] = "b"
        receiver_row = [a1line3, a1line2, a1line1][self.agent1_loc[1]+1]
        receiver_row[self.agent1_loc[0]+1] = "o" if self.receiver_in_booth else "a"

        # Print the three rows, sender part first.
        row_pairs = ((a0line1, a1line1), (a0line2, a1line2), (a0line3, a1line3))
        for left_cells, right_cells in row_pairs:
            print("".join(left_cells) + "".join(right_cells))

import torch
# Interactive driver: play both agents from the keyboard for a few episodes,
# printing observations, rewards and belief-related info after every move.
if __name__ == "__main__":
    # Environment configuration used for the manual run (with decoy booths).
    d =  {
        "lengths":(8,4),
        "starts": (4,2),
        "receiver_booth_loc":0,
        "booth_loc":7,
        "episode_limit": 16,
        "right_r": 1.0,
        "wrong_r":-0.5,
        "num_sender_decoy_booths": 2,
        "decoy_booths_fixed": 1,
        "use_intermediate_reward": False,
        "use_mi_shaping": True,
        "use_mi_loss": False,
        "use_oh_token" : True
    }
    # Alternative configuration without decoy booths:
    # d = {
    #     "lengths":(1,5),
    #     "starts": (0,2),
    #     "receiver_booth_loc":0,
    #     "booth_loc":0,
    #     "episode_limit": 40,
    #     "right_r": 1.0,
    #     "wrong_r":-0.5,
    #     "num_sender_decoy_booths": 0,
    #     "decoy_booths_fixed": 1,
    #     "use_intermediate_reward": False,
    #     "use_mi_shaping": True,
    #     "use_mi_loss": True
    # }

    # Fix both random generators used by the environment for reproducibility.
    np.random.seed(1)
    random.seed(1)
    env = PBCMaze(env_args=d)
    obs, state = env.reset()
    # policy = np.zeros(7)
    # policy[4] = 1.0
    # policy += 1e-10
    # print(env.calculate_mi_reward(policy))
    # exit()

    num_ep = 10
    for i in range(num_ep):
        print("ep:", i+1)
        obs, state = env.reset()
        # Uniform prior policies handed to the belief models.
        receiver_pi_0 = [0.2, 0.2, 0.2, 0.2, 0.2]
        sender_pi_0 = [1/7, 1/7, 1/7, 1/7, 1/7, 1/7, 1/7]
        rb_model = ReceiverBeliefModel(receiver_pi_0, env)
        sb_model = SenderBeliefModel(sender_pi_0, env)
        print(env.decoy_booth_locs)
        print("\nobs0:", obs[0])
        print("\nobs1:", obs[1])
        print("state:", state)
        print("goal:", env.goal)
        env.render()
        # env_config = env.save_env_config()
        # print(env_config)
        done = False
        while not done:
            # Keyboard key -> action id.
            action_map = {"a": LEFT,
                          "d": RIGHT,
                          "w": UP,
                          "s": DOWN,
                          "n": NOOP,
                          "q": HINT_UP,
                          "e": HINT_DOWN}

            # Agent 0 move
            action0 = None
            # Keep asking until a known key is entered.
            while action0 not in action_map:
                print("Action0?")
                print("actions allowed: w,a,s,d,n, q, e")
                action0 = input()
            # A fresh uniform policy tensor is created for every sender step.
            sender_pi_0 = torch.tensor([1/7, 1/7, 1/7, 1/7, 1/7, 1/7, 1/7])
            # sender_pi_0 = [3/35, 3/35, 3/35, 3/35, 3/35, 2/7, 2/7]
            reward, done, info = env.step(0, action_map[action0], sender_pi_0, rb_model)
            obs, state = env.get_obs(0), env.get_state()
            sb_model.update_belief(comm_token = env.comm_token)
            obs1 = env.get_obs(1)
            env.render()
            print("Agent 0 moved")
            print("obs 0:", obs)
            # print("obs1: ", obs1)
            #print("state:", state)
            print("done:", done)
            print("info:", info)
            print("goal:", env.goal)
            print("reward:", reward)
            print("comm token: " + str(env.comm_token))
            print("comm token one hot: " + str(env.comm_token_oh))
            print()

            # Agent 1 move
            # NOTE(review): the receiver is prompted even if `done` was already
            # set by the sender's step (e.g. horizon reached) -- confirm intended.
            print("agent 1 obs agent 0 moved: " + str(env.get_obs(1)))
            action1 = None
            # Any key of action_map is accepted here, including the hint keys.
            while action1 not in action_map:
                print("Action1?")
                print("actions allowed: w,a,s,d,n")
                action1 = input()
            reward1, done, info1 = env.step(1, action_map[action1])
            obs1, state1 = env.get_obs(1), env.get_state()
            rb_model.update_belief()
            env.render()
            print("Agent 1 moved")
            print("obs1:", obs1)
            #print("state:", state1)
            print("done:", done)
            print("info:", info1)
            print("reward:", reward1)
            print("comm token: " + str(env.comm_token))
            # print(env_config)
            # print("break")
            # print(env.__dict__)
            # print("Reset env")
            # env.load_env_config(env_config)
            # env.render()
            # print("break")

        print("\nEND\n")
        # Clear both belief models at the end of the episode.
        sb_model.reset_belief()
        rb_model.reset_belief()
