import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import numpy as np
import functools
import operator
import time

import sys

from models.r2d2_config import gamma, device, batch_size, sequence_length, burn_in_length, temperature
from models.dru import DRU
from models.gb_softmax import Gumbel_Softmax
from models.rao_gb_softmax import Rao_Gumbel_Softmax
from utils.pbmaze_config import env_config, multi_env_config
from phone_booth_colab_maze_final import PBCMaze
from multi_phone_booth_collab_maze import PBCMaze as MultiPBCMaze

# Discrete action ids used by the maze environments.
# Note the pairing: the six CTDU_*-prefixed names are unpacked from ACTIONS,
# while the seven unprefixed names (which include the two hint actions) are
# unpacked from CTDU_ACTIONS. CTDU_SEND and HINT_UP therefore share id 5.
ACTIONS = [*range(6)]
CTDU_ACTIONS = [*range(7)]
(
    CTDU_LEFT,
    CTDU_RIGHT,
    CTDU_UP,
    CTDU_DOWN,
    CTDU_NOOP,
    CTDU_SEND,
) = ACTIONS
(
    LEFT,
    RIGHT,
    UP,
    DOWN,
    NOOP,
    HINT_UP,
    HINT_DOWN,
) = CTDU_ACTIONS

def convert_msg_to_actions(comm_bit, msg):
    """Translate a discretised message tensor into a hint action id.

    Args:
        comm_bit: Number of communication bits carried by ``msg`` (1 or 2).
        msg: Message tensor. With one bit its single value is read directly;
            with two bits the product of its (squeezed) entries is used.

    Returns:
        ``HINT_UP`` when the decoded value equals -1, otherwise ``HINT_DOWN``.

    Raises:
        NotImplementedError: For any ``comm_bit`` other than 1 or 2.
    """
    if comm_bit == 1:
        decoded = msg.item()
    elif comm_bit == 2:
        # Two bits are collapsed into one signed value via their product.
        decoded = torch.prod(msg.squeeze()).item()
    else:
        raise NotImplementedError

    if decoded == -1:
        return HINT_UP
    return HINT_DOWN

class R2D2(nn.Module):
    """Recurrent dueling Q-network with an auxiliary communication head.

    An LSTM encodes the observation sequence. One branch turns the LSTM output
    into dueling Q-values (state value plus mean-centred advantages); a second
    branch produces ``num_comm_bits`` raw message outputs per time step.

    ``get_td_error``, ``get_mi_td_error``, ``get_obl_td_error`` and
    ``train_model`` are plain functions whose first parameter is named ``cls``.
    They are intentionally *not* decorated with ``@classmethod``: callers invoke
    them either as ``R2D2.get_td_error(R2D2, ...)`` or through a network
    instance (``net.train_model(net, ...)``), and a decorator would shift the
    positional arguments of the first form.
    """

    def __init__(self, num_inputs, num_outputs, num_comm_bits):
        """Build the network.

        Args:
            num_inputs: Size of one observation vector.
            num_outputs: Number of discrete actions (Q-values per step).
            num_comm_bits: Width of the message head output.
        """
        super(R2D2, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        self.lstm = nn.LSTM(input_size=num_inputs, hidden_size=16, batch_first=True)
        self.fc = nn.Linear(16, 128)
        self.fc_adv = nn.Linear(128, num_outputs)
        self.fc_val = nn.Linear(128, 1)
        self.comm_fc = nn.Linear(16, 128)
        self.comm_message = nn.Linear(128, num_comm_bits)

        # Xavier-initialise every linear layer's weights (biases keep defaults).
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x, hidden=None):
        """Run the network over a batch of sequences.

        Args:
            x: Tensor of shape [batch_size, sequence_length, num_inputs].
            hidden: Optional ``(h, c)`` LSTM state; ``None`` lets the LSTM use
                its default initial state.

        Returns:
            Tuple ``(qvalue, hidden, message)`` where ``qvalue`` has shape
            [batch_size, sequence_length, num_outputs], ``hidden`` is the LSTM
            state after the last step and ``message`` is the raw output of the
            communication head for every step.
        """
        n_batch = x.size()[0]
        n_steps = x.size()[1]
        out, hidden = self.lstm(x, hidden)

        env_out = F.relu(self.fc(out))
        adv = self.fc_adv(env_out).view(n_batch, n_steps, self.num_outputs)
        val = self.fc_val(env_out).view(n_batch, n_steps, 1)

        message = self.comm_message(F.relu(self.comm_fc(out)))

        # Dueling aggregation: subtract the mean advantage over actions.
        qvalue = val + (adv - adv.mean(dim=2, keepdim=True))
        return qvalue, hidden, message

    @staticmethod
    def _hidden_at(rnn_state, step):
        """Return the detached ``(h, c)`` pair stored at sequence index ``step``.

        ``rnn_state`` is laid out as [batch, sequence, 2, hidden]; each returned
        tensor gets a leading dimension of size 1 so it can be fed to the LSTM.
        """
        h, c = rnn_state[:, step, :, :].transpose(0, 1).contiguous()
        return h.unsqueeze(0).detach(), c.unsqueeze(0).detach()

    @staticmethod
    def _drop_burn_in(item):
        """Discard the first ``burn_in_length`` steps along the sequence axis."""
        return item[:, burn_in_length:, :]

    @staticmethod
    def _zero_beyond_length(td_error, lengths):
        """Zero, in place, the TD errors of padded steps past each sequence's length."""
        for idx, length in enumerate(lengths):
            td_error[idx][length - burn_in_length:][:] = 0
        return td_error

    @staticmethod
    def _mi_term(pred, mi_term_masks):
        """Compute the batch-mean entropy-style MI term from Q-values and masks.

        The Q-values are turned into a softmax policy (using ``temperature``),
        each per-step mask selects a subset of action probabilities, and
        ``-p * log2(p)`` of the selected mass is summed per sequence and then
        averaged over the batch.

        NOTE(review): the reshape uses the configured ``sequence_length`` while
        ``pred`` has already had burn-in removed, so this appears to assume
        ``burn_in_length == 0`` (or masks that match) - confirm against the
        replay buffer.
        """
        softmax_pred = torch.softmax(pred / temperature, dim=-1)

        # Flatten [batch][step] -> one list entry per (batch, step), then pad so
        # every entry carries the same number of masks.
        flat_masks = functools.reduce(operator.iconcat, mi_term_masks, [])
        padded_masks = torch.nn.utils.rnn.pad_sequence(flat_masks, batch_first=True)

        # Use the real action count rather than a hard-coded 6.
        flat_probs = softmax_pred.view(-1, softmax_pred.size(-1))
        flat_probs = flat_probs.unsqueeze(1).repeat(1, padded_masks.size(1), 1)
        dp = torch.sum(flat_probs * padded_masks, dim=-1) + 1e-10

        # Group by the actual batch size rather than a hard-coded 32.
        per_sequence = torch.sum(
            (-dp * torch.log2(dp)).view(pred.size(0), sequence_length, -1), dim=(1, 2)
        )
        return torch.mean(per_sequence)

    def get_td_error(cls, online_net, target_net, batch, lengths):
        """Double-Q style n-step TD error for a batch of stored sequences.

        The greedy next action is chosen by ``online_net`` and evaluated by
        ``target_net``; the bootstrap is discounted by ``gamma ** steps``.

        Args:
            cls: Unused placeholder (callers pass ``R2D2`` or a network).
            online_net: Network being trained.
            target_net: Network used to evaluate the bootstrap value.
            batch: Object exposing ``state``, ``next_state``, ``action``,
                ``reward``, ``mask``, ``step`` and ``rnn_state`` sequences.
            lengths: True (unpadded) length of each sequence in the batch.

        Returns:
            TD-error tensor with entries past each sequence's length zeroed.
        """
        n_seq = torch.stack(batch.state).size()[0]
        states = torch.stack(batch.state).view(n_seq, sequence_length, online_net.num_inputs)
        next_states = torch.stack(batch.next_state).view(n_seq, sequence_length, online_net.num_inputs)
        actions = torch.stack(batch.action).view(n_seq, sequence_length, -1).long().to(device)
        rewards = torch.stack(batch.reward).view(n_seq, sequence_length, -1).to(device)
        masks = torch.stack(batch.mask).view(n_seq, sequence_length, -1).to(device)
        steps = torch.stack(batch.step).view(n_seq, sequence_length, -1).to(device)
        rnn_state = torch.stack(batch.rnn_state).view(n_seq, sequence_length, 2, -1)
        hidden0 = R2D2._hidden_at(rnn_state, 0)
        hidden1 = R2D2._hidden_at(rnn_state, 1)

        pred, _, _ = online_net(states, hidden0)
        next_pred, _, _ = target_net(next_states, hidden1)
        next_pred_online, _, _ = online_net(next_states, hidden1)

        pred = R2D2._drop_burn_in(pred)
        next_pred = R2D2._drop_burn_in(next_pred)
        actions = R2D2._drop_burn_in(actions)
        rewards = R2D2._drop_burn_in(rewards)
        masks = R2D2._drop_burn_in(masks)
        steps = R2D2._drop_burn_in(steps)
        next_pred_online = R2D2._drop_burn_in(next_pred_online)

        pred = pred.gather(2, actions)

        _, greedy_next_action = next_pred_online.max(2)
        bootstrap = next_pred.gather(2, greedy_next_action.unsqueeze(2))
        target = rewards + masks * pow(gamma, steps) * bootstrap

        td_error = pred - target.detach()
        return R2D2._zero_beyond_length(td_error, lengths)

    def get_mi_td_error(cls, online_net, target_net, batch, lengths):
        """One-step double-Q TD error plus the MI term of the online policy.

        Same batch contract as ``get_td_error`` except that ``batch.step`` is
        not read (the bootstrap is discounted by a single ``gamma``) and
        ``batch.mi_term_mask`` must be present.

        Returns:
            Tuple ``(td_error, mi_term)``; ``mi_term`` is a scalar tensor.
        """
        n_seq = torch.stack(batch.state).size()[0]
        states = torch.stack(batch.state).view(n_seq, sequence_length, online_net.num_inputs)
        next_states = torch.stack(batch.next_state).view(n_seq, sequence_length, online_net.num_inputs)
        actions = torch.stack(batch.action).view(n_seq, sequence_length, -1).long().to(device)
        rewards = torch.stack(batch.reward).view(n_seq, sequence_length, -1).to(device)
        masks = torch.stack(batch.mask).view(n_seq, sequence_length, -1).to(device)
        rnn_state = torch.stack(batch.rnn_state).view(n_seq, sequence_length, 2, -1)
        hidden0 = R2D2._hidden_at(rnn_state, 0)
        hidden1 = R2D2._hidden_at(rnn_state, 1)

        pred, _, _ = online_net(states, hidden0)
        next_pred, _, _ = target_net(next_states, hidden1)
        next_pred_online, _, _ = online_net(next_states, hidden1)

        pred = R2D2._drop_burn_in(pred)
        next_pred = R2D2._drop_burn_in(next_pred)
        actions = R2D2._drop_burn_in(actions)
        rewards = R2D2._drop_burn_in(rewards)
        masks = R2D2._drop_burn_in(masks)
        next_pred_online = R2D2._drop_burn_in(next_pred_online)

        # The MI term is computed on the full Q-vector, before gathering.
        mi_term = R2D2._mi_term(pred, batch.mi_term_mask)

        pred = pred.gather(2, actions)

        _, greedy_next_action = next_pred_online.max(2)
        bootstrap = next_pred.gather(2, greedy_next_action.unsqueeze(2))
        target = rewards + masks * gamma * bootstrap

        td_error = pred - target.detach()
        return R2D2._zero_beyond_length(td_error, lengths), mi_term

    def get_obl_td_error(cls, online_net, batch, lengths, use_mi_loss = False):
        """TD error against pre-computed targets stored in the batch.

        Args:
            cls: Unused placeholder (callers pass ``R2D2`` or a network).
            online_net: Network being trained.
            batch: Object exposing ``state``, ``target``, ``action`` and
                ``rnn_state`` sequences (plus ``mi_term_mask`` when
                ``use_mi_loss`` is set).
            lengths: True (unpadded) length of each sequence in the batch.
            use_mi_loss: Whether to also compute the MI term.

        Returns:
            Tuple ``(td_error, mi_term)``; ``mi_term`` is the float ``0.0``
            when ``use_mi_loss`` is false, otherwise a scalar tensor.
        """
        n_seq = torch.stack(batch.state).size()[0]
        states = torch.stack(batch.state).view(n_seq, sequence_length, online_net.num_inputs)
        targets = torch.stack(batch.target).view(n_seq, sequence_length, -1).to(device)
        actions = torch.stack(batch.action).view(n_seq, sequence_length, -1).long().to(device)
        rnn_state = torch.stack(batch.rnn_state).view(n_seq, sequence_length, 2, -1)

        pred, _, _ = online_net(states, R2D2._hidden_at(rnn_state, 0))

        pred = R2D2._drop_burn_in(pred)
        actions = R2D2._drop_burn_in(actions)
        targets = R2D2._drop_burn_in(targets)

        mi_term = 0.0
        if use_mi_loss:
            mi_term = R2D2._mi_term(pred, batch.mi_term_mask)

        pred = pred.gather(2, actions)
        td_error = pred - targets.detach()
        return R2D2._zero_beyond_length(td_error, lengths), mi_term

    def train_model(cls, online_net, target_net, optimizer, batch, lengths, obl = False, use_mi_loss = False, use_only_mi_loss = False):
        """Run one optimisation step on ``online_net``.

        The loss is the mean squared TD error, optionally minus the MI term
        (subtracted so that minimising the loss maximises it), or the negated
        MI term alone when ``use_only_mi_loss`` is set. For the plain
        (non-OBL, non-MI) path ``use_only_mi_loss`` is ignored.

        Args:
            cls: Object providing the ``get_*_td_error`` functions (callers
                pass a network instance).
            online_net: Network being trained.
            target_net: Target network (unused on the OBL path).
            optimizer: Optimiser stepping ``online_net``'s parameters.
            batch: Sampled replay batch.
            lengths: True length of each sequence in the batch.
            obl: Use pre-computed targets via ``get_obl_td_error``.
            use_mi_loss: Include the MI term.
            use_only_mi_loss: Optimise the MI term alone.

        Returns:
            Tuple ``(loss, td_error)``.
        """
        if not obl and not use_mi_loss:
            td_error = cls.get_td_error(online_net, target_net, batch, lengths)
            loss = pow(td_error, 2).mean()
        else:
            if obl:
                td_error, mi_term = cls.get_obl_td_error(online_net, batch, lengths, use_mi_loss)
            else:
                td_error, mi_term = cls.get_mi_td_error(online_net, target_net, batch, lengths)
            # minus MI to maximize
            if use_only_mi_loss:
                loss = -1.0 * mi_term
            else:
                loss = pow(td_error, 2).mean() - 1.0 * mi_term

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss, td_error

    def get_action(self, state, hidden, return_softmax_q_values = True, temp = 1.0):
        """Pick the greedy action for a single observation.

        Args:
            state: One observation vector (batch and sequence axes are added
                here).
            hidden: LSTM state to continue from.
            return_softmax_q_values: Return the softmax policy instead of the
                raw Q-values as the first element.
            temp: Softmax temperature.

        Returns:
            Tuple ``(policy_or_qvalue, action, hidden, message)`` with the first
            two moved to CPU.
        """
        state = state.unsqueeze(0).unsqueeze(0)

        qvalue, hidden, message = self.forward(state, hidden)
        softmax_q_value = torch.softmax(qvalue / temp, dim = -1)
        _, action = torch.max(qvalue, 2)

        first = softmax_q_value if return_softmax_q_values else qvalue
        return first.cpu(), action.cpu(), hidden, message

    def get_stochastic_action(self, state, hidden, return_softmax_q_values = True, temp = 1.0):
        """Sample an action from the softmax over Q-values.

        Arguments and return value mirror ``get_action``; the only difference
        is that the action is drawn from a categorical distribution instead of
        being the arg-max.
        """
        state = state.unsqueeze(0).unsqueeze(0)

        qvalue, hidden, message = self.forward(state, hidden)
        softmax_q_value = torch.softmax(qvalue / temp, dim = -1)

        action = torch.distributions.Categorical(softmax_q_value.squeeze()).sample()

        first = softmax_q_value if return_softmax_q_values else qvalue
        return first.cpu(), action.cpu(), hidden, message


class R2D2Agent():
    """Bundles an online/target ``R2D2`` pair with its optimiser and replay.

    The agent also owns a "communication utilizer" that post-processes the raw
    message produced by the online network.
    """

    def __init__(self, num_inputs, num_actions, num_comm_bits, memory, local_buffer, lr, batch_size, device, ct_util_dict):
        """Create the networks, optimiser and message post-processor.

        Args:
            num_inputs: Observation vector size.
            num_actions: Number of discrete actions.
            num_comm_bits: Width of the message head.
            memory: Replay memory; must provide ``sample``, ``push`` and
                ``update_prior``.
            local_buffer: Per-episode buffer; must provide ``sample``.
            lr: Adam learning rate.
            batch_size: Number of sequences sampled per training step.
            device: Device both networks are moved to.
            ct_util_dict: Settings for the message post-processor; its
                ``'ct_util_type'`` entry selects ``'dru'``, ``'gb_softmax'`` or
                ``'rao_gb_softmax'``. Any other value leaves ``ct_utilizer``
                unset.
        """
        self.num_inputs = num_inputs
        self.num_actions = num_actions
        self.num_comm_bits = num_comm_bits

        # Online network first, then the target network, which is synced to it.
        self.online_net = R2D2(num_inputs, num_actions, num_comm_bits)
        self.target_net = R2D2(num_inputs, num_actions, num_comm_bits)
        self.update_target_model()
        for net in (self.online_net, self.target_net):
            net.to(device)
        for net in (self.online_net, self.target_net):
            net.train()

        util_kind = ct_util_dict['ct_util_type']
        if util_kind == 'dru':
            self.ct_utilizer = DRU(
                ct_util_dict['sigma'],
                comm_narrow = ct_util_dict['comm_narrow'],
                hard = ct_util_dict['hard'],
            )
        elif util_kind == 'gb_softmax':
            self.ct_utilizer = Gumbel_Softmax(ct_util_dict['tau'], ct_util_dict['hard'])
        elif util_kind == 'rao_gb_softmax':
            self.ct_utilizer = Rao_Gumbel_Softmax(ct_util_dict['tau'], ct_util_dict['k'])

        self.lr = lr
        self.batch_size = batch_size
        self.optimizer = optim.Adam(self.online_net.parameters(), lr = lr)

        # Replay objects
        self.memory = memory
        self.local_buffer = local_buffer

    def update_target_model(self):
        """Copy the online network's parameters into the target network."""
        online_weights = self.online_net.state_dict()
        self.target_net.load_state_dict(online_weights)

    def train_model(self, obl = False, use_mi_loss = False, use_only_mi_loss = False):
        """Sample a batch, take one optimisation step and refresh priorities.

        Returns:
            Tuple ``(loss, td_error)`` from the network's ``train_model``.
        """
        batch, indexes, lengths = self.memory.sample(self.batch_size)
        net = self.online_net
        loss, td_error = net.train_model(
            net, self.target_net, self.optimizer, batch, lengths,
            obl, use_mi_loss, use_only_mi_loss,
        )
        self.memory.update_prior(indexes, td_error.cpu(), lengths)
        return loss, td_error

    def get_action(self, obs, epsilon, hidden, discretize_message = True):
        """Epsilon-greedy action selection for one observation.

        Returns:
            Tuple ``(policy, action, hidden, message)``. With probability
            ``epsilon`` the action is a uniformly random integer; otherwise it
            is the network's greedy action. The message has been passed through
            the communication utilizer.
        """
        policy, greedy_action, hidden, message = self.online_net.get_action(obs, hidden)
        message = self.ct_utilizer.forward(message, not discretize_message)

        explore = np.random.rand() <= epsilon
        chosen = np.random.randint(0, self.num_actions) if explore else greedy_action
        return policy, chosen, hidden, message

    def push_to_memory(self, use_mi_loss = False):
        """Move the local buffer's sequences into replay with TD-error priorities."""
        batch, lengths = self.local_buffer.sample()
        if use_mi_loss:
            td_error, _ = R2D2.get_mi_td_error(R2D2, self.online_net, self.target_net, batch, lengths)
        else:
            td_error = R2D2.get_td_error(R2D2, self.online_net, self.target_net, batch, lengths)
        self.memory.push(td_error.cpu(), batch, lengths)

    def save_model(self, filename_header):
        """Write both networks' state dicts to ``<filename_header>_{online,target}_net.pt``."""
        checkpoints = (
            (self.online_net, "_online_net.pt"),
            (self.target_net, "_target_net.pt"),
        )
        for net, suffix in checkpoints:
            torch.save(net.state_dict(), filename_header + suffix)

class OBLR2D2Agent(R2D2Agent):
    """R2D2 agent trained on fictitious transitions sampled from a belief model.

    In addition to the base agent's replay it keeps a second (IQL) replay and a
    private "pseudo" environment in which two-step fictitious transitions are
    rolled out: this agent acts, then the partner acts, then the target network
    bootstraps from the resulting observation.
    """

    def __init__(self, num_inputs, num_actions, num_comm_bits, iql_memory, iql_buffer, memory, local_buffer, lr, batch_size, device, agent_idx, belief_model, ct_util_dict, use_mi_loss = False, multi_pb = False):
        """Create the agent.

        Args beyond those of ``R2D2Agent``:
            iql_memory: Replay memory for the IQL phase.
            iql_buffer: Local buffer for the IQL phase.
            agent_idx: Index of this agent (0 or 1); the partner gets the other.
            belief_model: Object providing ``sample_belief()`` for the partner.
            use_mi_loss: Whether MI masks are stored with each transition.
            multi_pb: Selects ``MultiPBCMaze`` instead of ``PBCMaze`` for the
                pseudo environment.
        """
        super(OBLR2D2Agent, self).__init__(num_inputs, num_actions, num_comm_bits, memory, local_buffer, lr, batch_size, device, ct_util_dict)

        self.iql_memory = iql_memory
        self.iql_buffer = iql_buffer

        self.agent_idx = agent_idx
        self.other_agent_idx = 1 if self.agent_idx == 0 else 0
        # Belief model of another agent
        self.belief_model = belief_model

        # Pseudo env for OBL sampling
        if(multi_pb == False):
            self.pseudo_env = PBCMaze(env_args = env_config)
        else:
            self.pseudo_env = MultiPBCMaze(env_args = multi_env_config)

        # Whether to add MI term in the loss
        self.use_mi_loss = use_mi_loss

    def get_action(self, obs, hidden, argmax = False, return_softmax_q_values = True, discretize_message = True, temp = 1.0):
        """Select an action greedily (``argmax``) or by sampling the softmax policy.

        Returns:
            Tuple ``(policy, action, hidden, message)`` with the message passed
            through the communication utilizer.
        """
        select = self.online_net.get_action if argmax else self.online_net.get_stochastic_action
        policy, action, hidden, message = select(obs, hidden, return_softmax_q_values, temp)

        message = self.ct_utilizer.forward(message, not discretize_message)
        return policy, action, hidden, message

    def get_iql_action(self, obs, epsilon, hidden, discretize_message = True):
        """Epsilon-greedy selection that also reports whether it explored.

        Returns:
            Tuple ``(policy, action, hidden, message, explored)``.
        """
        policy, action, hidden, message = self.online_net.get_action(obs, hidden)
        message = self.ct_utilizer.forward(message, not discretize_message)
        if np.random.rand() <= epsilon:
            return policy, np.random.randint(0, self.num_actions), hidden, message, True
        return policy, action, hidden, message, False

    def _simulate_obl_transition(self, first_hidden, first_next_hidden, first_policy, action, curr_env_config, other_agent, other_agent_hidden, message, translate_send):
        """Roll out one fictitious two-step transition and push it to the local buffer.

        Shared implementation of ``obl_sampling`` and ``obl_sampling_flat``.
        When ``translate_send`` is true, a ``CTDU_SEND`` action chosen by agent 0
        is replaced by the hint action decoded from the accompanying message
        before it is applied to the pseudo environment.
        """
        _, belief = self.belief_model.sample_belief()
        # Load in the belief for fictitious transition
        self.pseudo_env.load_env_config_obl(self.other_agent_idx, belief, curr_env_config)

        # This agent moves in the pseudo-environment
        first_obs = torch.Tensor(self.pseudo_env.get_obs(self.agent_idx)).to(device)
        env_action = action
        if translate_send and self.agent_idx == 0 and action == CTDU_SEND:
            env_action = convert_msg_to_actions(message.size(-1), message)
        first_reward, done, first_info = self.pseudo_env.step(self.agent_idx, env_action, policy = first_policy)

        # Falls back to the first observation when the episode ends early.
        third_obs = first_obs
        if not done:
            # The other agent moves in the pseudo-environment
            second_obs = torch.Tensor(self.pseudo_env.get_obs(self.other_agent_idx)).to(device)
            second_policy, second_action, _, second_message = other_agent.online_net.get_stochastic_action(second_obs, other_agent_hidden)
            second_env_action = second_action
            if translate_send and self.other_agent_idx == 0 and second_action == CTDU_SEND:
                second_env_action = convert_msg_to_actions(second_message.size(-1), second_message)
            second_reward, done, _ = self.pseudo_env.step(
                self.other_agent_idx,
                second_env_action,
                policy = second_policy.squeeze().detach().numpy(),
            )

            if not done:
                # Bootstrap from the target network; no graph is needed because
                # the value is only used as a constant target.
                third_obs = torch.Tensor(self.pseudo_env.get_obs(self.agent_idx)).to(device)
                with torch.no_grad():
                    qvalue, _, _ = self.target_net.forward(third_obs.unsqueeze(0).unsqueeze(0), first_next_hidden)
                # UNSURE (kept from the original author): the partner's reward
                # is added undiscounted and the bootstrap is discounted once.
                target = first_reward + second_reward + (gamma) * torch.max(qvalue.squeeze()).detach()
            else:
                target = first_reward + second_reward
        else:
            target = first_reward

        mask = 0 if done else 1
        if self.use_mi_loss:
            mi_masks = torch.Tensor(first_info['mi_term_masks']).to(device)
            self.local_buffer.push(first_obs, third_obs, target, action, mask, mi_masks, first_hidden)
        else:
            self.local_buffer.push(first_obs, target, action, mask, first_hidden)

    def obl_sampling(self, first_hidden, first_next_hidden, first_policy, action, curr_env_config, other_agent, other_agent_hidden, message = None):
        """Sample a fictitious transition, decoding SEND actions into hints.

        We need the other agent's policy to sample action.
        Open question from the original author: do we need to multiply the
        target with policy probability and belief probability?

        Args:
            first_hidden: This agent's LSTM state before acting (stored).
            first_next_hidden: This agent's LSTM state after acting, used for
                the target-network bootstrap.
            first_policy: This agent's policy, forwarded to the environment.
            action: Action this agent took.
            curr_env_config: Current real-environment configuration.
            other_agent: Partner agent whose online network picks the reply.
            other_agent_hidden: Partner's LSTM state.
            message: Message accompanying ``action``; it is read when this is
                agent 0 and ``action`` is ``CTDU_SEND``.
        """
        self._simulate_obl_transition(
            first_hidden, first_next_hidden, first_policy, action,
            curr_env_config, other_agent, other_agent_hidden,
            message, True,
        )

    def obl_sampling_flat(self, first_hidden, first_next_hidden, first_policy, action, curr_env_config, other_agent, other_agent_hidden):
        """Sample a fictitious transition, applying actions to the env unchanged.

        Same as ``obl_sampling`` but without translating ``CTDU_SEND`` into a
        hint action.
        """
        self._simulate_obl_transition(
            first_hidden, first_next_hidden, first_policy, action,
            curr_env_config, other_agent, other_agent_hidden,
            None, False,
        )

    def train_iql_model(self, obl = False, use_mi_loss = False):
        """One optimisation step on a batch sampled from the IQL replay.

        Returns:
            Tuple ``(loss, td_error)``.
        """
        batch, indexes, lengths = self.iql_memory.sample(self.batch_size)
        loss, td_error = self.online_net.train_model(self.online_net, self.target_net, self.optimizer, batch, lengths, obl, use_mi_loss)
        self.iql_memory.update_prior(indexes, td_error.cpu(), lengths)
        return loss, td_error

    def push_to_memory(self):
        """Move the OBL local buffer into replay, prioritised by OBL TD error."""
        batch, lengths = self.local_buffer.sample()
        td_error, _ = R2D2.get_obl_td_error(R2D2, self.online_net, batch, lengths)
        self.memory.push(td_error.cpu(), batch, lengths)

    def push_to_iql_memory(self, use_mi_loss = False):
        """Move the IQL local buffer into the IQL replay with TD-error priorities."""
        batch, lengths = self.iql_buffer.sample()
        if use_mi_loss:
            td_error, _ = R2D2.get_mi_td_error(R2D2, self.online_net, self.target_net, batch, lengths)
        else:
            td_error = R2D2.get_td_error(R2D2, self.online_net, self.target_net, batch, lengths)
        self.iql_memory.push(td_error.cpu(), batch, lengths)


class DIAL_OBLR2D2Agent(OBLR2D2Agent):
    """OBL agent that can additionally be trained with DIAL.

    In the DIAL step the message produced by this agent's online network is
    written into the partner's observation, so the partner's TD error
    back-propagates through the message into this agent's network.
    """

    def __init__(self, num_inputs, num_actions, num_comm_bits, iql_memory, iql_buffer, memory, local_buffer, dial_memory, dial_buffer, lr, batch_size, dial_batch_size, device, agent_idx, belief_model, ct_util_dict, use_mi_loss = False, multi_pb = False):
        """Create the agent.

        Args beyond those of ``OBLR2D2Agent``:
            dial_memory: Replay memory for DIAL sequences.
            dial_buffer: Local buffer for DIAL sequences.
            dial_batch_size: Number of sequences sampled per DIAL step.
        """
        super(DIAL_OBLR2D2Agent, self).__init__(num_inputs, num_actions, num_comm_bits, iql_memory, iql_buffer, memory, local_buffer, lr, batch_size, device, agent_idx, belief_model, ct_util_dict, use_mi_loss, multi_pb)
        self.dial_memory = dial_memory
        self.dial_buffer = dial_buffer
        self.dial_batch_size = dial_batch_size

    @staticmethod
    def _require_full_window(dial_transitions):
        """Abort the process unless the transition window has exactly six entries."""
        if(len(dial_transitions) != 6):
            print(dial_transitions)
            print("something's wrong")
            exit()

    @staticmethod
    def _hidden_at(rnn_state, step):
        """Return the detached ``(h, c)`` pair stored at sequence index ``step``.

        ``rnn_state`` is laid out as [batch, sequence, 2, hidden]; each returned
        tensor gets a leading dimension of size 1 for the LSTM.
        """
        h, c = rnn_state[:, step, :, :].transpose(0, 1).contiguous()
        return h.unsqueeze(0).detach(), c.unsqueeze(0).detach()

    # Old, suspect to have stale target problem
    def dial_precompute_target(self, dial_transitions, other_agent):
        """Compute a DIAL target immediately and push it to the DIAL buffer.

        Kept for reference; the original author flagged it as likely suffering
        from stale targets because the bootstrap value is frozen at push time.

        dial_transitions: [a0_t, a1_t+1, both_in_booth_flag, a0_t+2, a1_t+3]
        (as described by the original author; the length check requires six
        entries - TODO confirm the exact layout against the caller).
        We are trying to compute for a1 at t+1 so it's equal to reward of
        a1_t+1 + q_max for a1 at the last transition.
        """
        self._require_full_window(dial_transitions)
        first, second, flag, last = dial_transitions[0], dial_transitions[1], dial_transitions[2], dial_transitions[-1]
        with torch.no_grad():
            # This is not using the last transition's reward
            qvalues, _, _, _ = other_agent.get_action(last[0], last[1], return_softmax_q_values = False)
            # agent 1's reward + agent 1's max q value
            target = second[5] + torch.max(qvalues)
        self.dial_buffer.push(
            first[0], first[3], first[2], first[5], first[6], first[4],
            second[0], second[1], second[2], target, second[6],
            flag,
        )

        # Clear memory if done
        if(dial_transitions[-1][6] == 0 or dial_transitions[-2][6] == 0):
            self.dial_buffer.clear_local()

    def dial_push_to_local_buffer(self, dial_transitions, other_agent):
        """Push one DIAL training tuple (and a terminal one if the episode ended).

        dial_transitions: [a0_t, a1_t+1, both_in_booth_flag, a0_t+2, a1_t+3]
        (as described by the original author; the length check requires six
        entries and index 5 is used as the flag of the terminal tuple - TODO
        confirm the exact layout against the caller).
        The stored reward target is the reward of a1_t+1 plus ``gamma`` times
        the reward of a1_t+3; the bootstrap value is added later in
        ``dial_compute_error``.

        ``other_agent`` is accepted for interface compatibility and not used.
        """
        self._require_full_window(dial_transitions)
        a0_first, a1_first = dial_transitions[0], dial_transitions[1]
        a0_second, a1_second = dial_transitions[3], dial_transitions[4]

        target_reward = a1_first[5] + gamma * a1_second[5]
        self.dial_buffer.push(
            a0_first[0], a0_first[3], a0_first[2], a0_first[5], a0_first[6], a0_first[1],
            a1_first[0], a1_first[1], a1_first[2], a1_first[6],
            target_reward, a1_second[3], a1_second[1], a1_second[6],
            dial_transitions[2],
        )

        if(a0_second[6] == 0 or a1_second[6] == 0):
            # Add terminal transition
            terminal_target_reward = a1_second[5]
            self.dial_buffer.push(
                a0_second[0], a0_second[3], a0_second[2], a0_second[5], a0_second[6], a0_second[1],
                a1_second[0], a1_second[1], a1_second[2], a1_second[6],
                terminal_target_reward, a1_second[3], a1_second[1], a1_second[6],
                dial_transitions[5],
            )
            # Clear memory if done
            self.dial_buffer.clear_local()

    def dial_compute_error(self, batch, lengths, other_agent, use_precompute_target = False):
        """TD error of the partner when fed this agent's differentiable message.

        Steps:
          1. Run this agent's online network to obtain per-step messages.
          2. For steps flagged as ``communicated``, replace the last
             ``num_comm_bits`` features of the partner's observation with the
             (non-discretised) message.
          3. Evaluate the partner's online network on the edited observations
             and compare against ``target_reward + gamma * mask * max_a Q_target``.

        Args:
            batch: Object exposing ``state``, ``other_state``, ``other_action``,
                ``rnn_state``, ``other_hidden``, ``communicated``,
                ``target_state``, ``target_hidden``, ``target_mask`` and
                ``target_reward`` sequences.
            lengths: True length of each sequence in the batch.
            other_agent: Partner agent (its online and target networks are used).
            use_precompute_target: No longer used.

        Returns:
            TD-error tensor with padded steps zeroed.
        """
        n_seq = torch.stack(batch.state).size()[0]
        states = torch.stack(batch.state).view(n_seq, sequence_length, self.online_net.num_inputs)
        other_states = torch.stack(batch.other_state).view(n_seq, sequence_length, other_agent.online_net.num_inputs)
        other_actions = torch.stack(batch.other_action).view(n_seq, sequence_length, -1).long().to(device)
        communicated = torch.stack(batch.communicated).view(n_seq, sequence_length, -1).to(device) > 0
        target_states = torch.stack(batch.target_state).view(n_seq, sequence_length, other_agent.online_net.num_inputs)
        target_masks = torch.stack(batch.target_mask).view(n_seq, sequence_length, -1).to(device)
        target_rewards = torch.stack(batch.target_reward).view(n_seq, sequence_length, -1).to(device)

        rnn_state = torch.stack(batch.rnn_state).view(n_seq, sequence_length, 2, -1)
        other_rnn_state = torch.stack(batch.other_hidden).view(n_seq, sequence_length, 2, -1)
        target_rnn_state = torch.stack(batch.target_hidden).view(n_seq, sequence_length, 2, -1)

        h0, c0 = self._hidden_at(rnn_state, 0)
        other_h0, other_c0 = self._hidden_at(other_rnn_state, 0)
        # The target network must start from the stored *target* hidden state;
        # previously the partner's pre-step hidden state was used by mistake.
        target_h0, target_c0 = self._hidden_at(target_rnn_state, 0)

        _, _, a0_message = self.online_net(states, (h0, c0))

        # To use the whole episode for DIAL :
        # 1. Flatten the predictions
        # 2. Flatten the other agent's observations
        # 3. Assign 2 to 1 and reshape it back to do back propagation
        # 4. Compute target and filter by length
        flat_messages = a0_message.view(-1, self.num_comm_bits)
        sent = communicated.view(-1, 1).squeeze()

        m = self.ct_utilizer.forward(flat_messages[sent], True)

        # Remove the message from the environment in agent 1's state and concatenate m
        flat_other_states = other_states.view(-1, other_states.size(2))
        flat_other_states[sent] = torch.cat((flat_other_states[sent][:, : -self.num_comm_bits], m), 1)
        dial_other_states = flat_other_states.view(other_states.size())

        a1_pred, _, _ = other_agent.online_net(dial_other_states, (other_h0, other_c0))
        a1_pred = a1_pred.gather(2, other_actions)

        target_q = other_agent.target_net(target_states, (target_h0, target_c0))[0]
        a1_targets = target_rewards + gamma * target_masks * torch.max(target_q, dim = 2)[0].unsqueeze(2)

        td_error = a1_pred - a1_targets.detach()

        # Zero beyond length error here.
        # NOTE(review): no burn-in slice is applied in this method, yet the
        # cut-off is shifted by burn_in_length - confirm this is intended when
        # burn_in_length > 0.
        for idx, length in enumerate(lengths):
            td_error[idx][length-burn_in_length:][:] = 0

        return td_error

    def train_model_dial(self, other_agent):
        """One DIAL optimisation step updating both agents' online networks.

        Gradients of each network are clipped to a max norm of 10 before the
        two optimisers step.

        Returns:
            The detached scalar loss.
        """
        batch, indexes, lengths = self.dial_memory.sample(self.dial_batch_size)
        # Perform DIAL
        td_error = self.dial_compute_error(batch, lengths, other_agent)
        loss = pow(td_error, 2).mean()

        self.optimizer.zero_grad()
        other_agent.optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(parameters=list(self.online_net.parameters()), max_norm=10)
        clip_grad_norm_(parameters=list(other_agent.online_net.parameters()), max_norm=10)
        self.optimizer.step()
        other_agent.optimizer.step()

        self.dial_memory.update_prior(indexes, td_error.cpu(), lengths)
        return loss.detach()

    def push_to_dial_memory(self, other_agent):
        """Move the DIAL local buffer into DIAL replay with TD-error priorities."""
        batch, lengths = self.dial_buffer.sample()
        with torch.no_grad():
            td_error = self.dial_compute_error(batch, lengths, other_agent)
        self.dial_memory.push(td_error.cpu(), batch, lengths)

    def push_to_iql_memory(self, use_mi_loss = False):
        """Move the IQL local buffer into the IQL replay (same as the parent class)."""
        super(DIAL_OBLR2D2Agent, self).push_to_iql_memory(use_mi_loss)
