# Slightly adapted from a version from Dr. Abhishek Das

import math

import torch
import torch.nn.functional as F
from torch import nn

class TarCommNetMLP(nn.Module):
    """
    MLP based CommNet with TarMAC-style targeted (attentional) communication.

    Each agent encodes its observation. Then, for ``comm_passes`` rounds, every
    agent derives a query, key and value from its hidden state. It attends over
    the keys of the agents that are allowed to send and aggregates their values.
    Keys and aggregated messages can optionally pass through a noisy channel
    (``args.use_comms_channel``) or a fake quantizer
    (``args.use_fake_quantization``). ``forward`` reports the resulting
    communication loss and the number of bits transmitted.
    """

    def __init__(self, args, num_inputs):
        """Set up the encoder, communication, attention and output networks.

        Arguments:
            args {Namespace} -- Parsed args. Read here: nagents, hid_size,
                comm_passes, recurrent, continuous, dim_actions (continuous)
                or naction_heads (discrete), init_std (optional, default 0.2),
                comm_mask_zero, share_weights, comm_init, and quant_bits when
                use_fake_quantization is set.
            num_inputs {int} -- Environment observation dimension per agent.
        """
        super(TarCommNetMLP, self).__init__()
        self.args = args
        self.nagents = args.nagents
        self.hid_size = args.hid_size
        self.comm_passes = args.comm_passes
        self.recurrent = args.recurrent
        self.device = "cpu"

        # NOTE: submodules are created in the same order as before, so
        # parameter order (and any saved optimizer state) is unchanged.
        self.continuous = args.continuous
        if self.continuous:
            self.action_mean = nn.Linear(args.hid_size, args.dim_actions)
            self.action_log_std = nn.Parameter(torch.zeros(1, args.dim_actions))
        else:
            self.heads = nn.ModuleList([nn.Linear(args.hid_size, o)
                                        for o in args.naction_heads])
        # Std used by init_weights. The previous code tested for a
        # 'comm_init_std' attribute but read 'init_std', so a configured
        # init_std was ignored.
        self.init_std = getattr(args, 'init_std', 0.2)

        # Mask for communication: all zeros, or everyone except self.
        if self.args.comm_mask_zero:
            self.comm_mask = torch.zeros(self.nagents, self.nagents)
        else:
            self.comm_mask = torch.ones(self.nagents, self.nagents) \
                            - torch.eye(self.nagents, self.nagents)

        # Encoder (function r in the CommNet paper). nn.Linear acts on the
        # last dimension, so the agent dimension is handled automatically.
        self.encoder = nn.Linear(num_inputs, args.hid_size)

        if args.recurrent:
            # Not referenced by this class's forward; kept so checkpoints
            # still load.
            self.hidd_encoder = nn.Linear(args.hid_size, args.hid_size)
            self.f_module = nn.LSTMCell(args.hid_size, args.hid_size)
        elif args.share_weights:
            self.f_module = nn.Linear(args.hid_size, args.hid_size)
            self.f_modules = nn.ModuleList([self.f_module
                                            for _ in range(self.comm_passes)])
        else:
            self.f_modules = nn.ModuleList([nn.Linear(args.hid_size, args.hid_size)
                                            for _ in range(self.comm_passes)])

        # C: maps the aggregated message into the hidden-state update.
        if args.share_weights:
            self.C_module = nn.Linear(args.hid_size, args.hid_size)
            self.C_modules = nn.ModuleList([self.C_module
                                            for _ in range(self.comm_passes)])
        else:
            self.C_modules = nn.ModuleList([nn.Linear(args.hid_size, args.hid_size)
                                            for _ in range(self.comm_passes)])

        # Optionally start with zero communication weights.
        if args.comm_init == 'zeros':
            for i in range(self.comm_passes):
                self.C_modules[i].weight.data.zero_()
        self.tanh = nn.Tanh()

        self.value_head = nn.Linear(self.hid_size, 1)

        # [TarMAC] attentional communication: 16-dim queries/keys and
        # hid_size-dim values.
        self.state2query = nn.Linear(args.hid_size, 16)
        self.state2key = nn.Linear(args.hid_size, 16)
        self.state2value = nn.Linear(args.hid_size, args.hid_size)

        # Fake quantization onto an unsigned integer grid [0, 2**bits - 1].
        if hasattr(args, 'use_fake_quantization') and args.use_fake_quantization:
            self.quant_bits = args.quant_bits
            self.quant_min = 0
            self.quant_max = 2**args.quant_bits - 1

    @staticmethod
    def _expand_mask(mask, target):
        """Return ``mask`` broadcast to ``target``'s shape (as a view when needed)."""
        return mask if mask.shape == target.shape else mask.expand_as(target)

    def apply_fake_quantization(self, tensor, mask):
        """Fake-quantize the masked entries of ``tensor`` to ``quant_bits`` bits.

        A per-tensor affine quantizer is fitted to the min/max of the masked
        entries. These statistics are computed without gradient. The masked
        entries are rounded onto the ``[quant_min, quant_max]`` integer grid
        and mapped back to floats. Unmasked entries are returned unchanged.

        Gradients use the straight-through estimator. ``torch.round`` has a
        zero derivative, so without it no gradient would reach the masked
        entries, and the quantization loss computed on them could not train.

        Arguments:
            tensor {Tensor} -- values to quantize.
            mask {Tensor} -- non-zero where a value is transmitted; must be
                expandable to ``tensor``'s shape.

        Returns:
            Tensor -- same shape as ``tensor``.
        """
        binary_mask = self._expand_mask(mask, tensor).bool()

        if torch.any(binary_mask):
            masked_values = tensor[binary_mask].detach()
            min_val = masked_values.min()
            max_val = masked_values.max()
        else:
            min_val = torch.tensor(0.0, device=tensor.device)
            max_val = torch.tensor(1.0, device=tensor.device)

        if min_val == max_val:
            # Degenerate range: widen it so the scale below is non-zero.
            min_val = min_val - 0.01
            max_val = max_val + 0.01

        scale = (max_val - min_val) / (self.quant_max - self.quant_min)
        zero_point = torch.round(self.quant_min - min_val / scale)

        levels = torch.clamp(torch.round(tensor.detach() / scale + zero_point),
                             self.quant_min, self.quant_max)
        dequantized = (levels - zero_point) * scale
        # Straight-through estimator: (tensor - tensor.detach()) is exactly
        # zero in the forward pass but carries an identity gradient.
        dequantized = dequantized + (tensor - tensor.detach())

        return torch.where(binary_mask, dequantized, tensor)

    def compute_quantization_loss(self, tensor, mask):
        """Sum of log2(2 * quant_max * |v| + 1) over the masked entries of ``tensor``.

        Returns a zero tensor when nothing is masked in.
        """
        binary_mask = self._expand_mask(mask, tensor).bool()

        if not torch.any(binary_mask):
            return torch.tensor(0.0, device=tensor.device)

        masked_tensor = tensor[binary_mask]
        return torch.log2(2 * self.quant_max * torch.abs(masked_tensor) + 1).sum()

    def compute_quantization_bits(self, tensor, mask):
        """Bits sent when every masked entry of ``tensor`` uses ``quant_bits`` bits."""
        num_values = self._expand_mask(mask, tensor).sum()
        return self.quant_bits * num_values

    def get_agent_mask(self, batch_size, info):
        """Build the per-link agent mask for a batch.

        Arguments:
            batch_size {int} -- B.
            info {dict} -- may hold 'alive_mask', a numpy array of N alive
                flags; without it every agent is treated as alive.

        Returns:
            tuple -- (num_agents_alive, agent_mask). agent_mask has shape
            (B x N x N x 1), and agent_mask[b, i, j, 0] is the flag of agent j
            (the sender), repeated for every receiver i.
        """
        n = self.nagents

        if 'alive_mask' in info:
            agent_mask = torch.from_numpy(info['alive_mask'])
            num_agents_alive = agent_mask.sum()
        else:
            agent_mask = torch.ones(n)
            num_agents_alive = n
        agent_mask = agent_mask.view(1, 1, n)
        # clone() materialises the expanded view so it can be modified in place.
        agent_mask = agent_mask.expand(batch_size, n, n).unsqueeze(-1).clone()

        return num_agents_alive, agent_mask

    def forward_state_encoder(self, x):
        """Encode observations and unpack any recurrent state.

        Recurrent: ``x`` is ``(obs, extras)``. ``extras`` is ``(hidden, cell)``
        when rnn_type is 'LSTM', otherwise the hidden state alone. ``obs`` is
        encoded without a non-linearity.
        Non-recurrent: ``x`` is encoded and passed through tanh, and the result
        also serves as the initial hidden state.

        Returns:
            tuple -- (encoded x, hidden_state, cell_state); cell_state may be None.
        """
        hidden_state, cell_state = None, None

        if self.args.recurrent:
            x, extras = x
            x = self.encoder(x)

            if self.args.rnn_type == 'LSTM':
                hidden_state, cell_state = extras
            else:
                hidden_state = extras
        else:
            x = self.encoder(x)
            x = self.tanh(x)
            hidden_state = x

        return x, hidden_state, cell_state

    def _transmit(self, target, mask):
        """Send ``target`` through the configured channel and account for it.

        Returns:
            tuple -- (transmitted tensor, communication loss, bits used):
              * use_comms_channel: adds uniform noise (see get_comms_noise);
                the loss comes from compute_component_log_loss; 32 bits per
                masked value.
              * use_fake_quantization: fake-quantizes the masked values; the
                loss comes from compute_quantization_loss; quant_bits per
                masked value.
              * otherwise: unchanged; loss and bits are both the 32-bit count.
        """
        if self.args.use_comms_channel:
            target = target + self.get_comms_noise(mask, target)
            return (target,
                    self.compute_component_log_loss(target, mask),
                    self.compute_num_bits_used(target, mask))
        if getattr(self.args, 'use_fake_quantization', False):
            target = self.apply_fake_quantization(target, mask)
            return (target,
                    self.compute_quantization_loss(target, mask),
                    self.compute_quantization_bits(target, mask))
        bits = self.compute_num_bits_used(target, mask)
        return target, bits, bits

    def _discrete_action_log_probs(self, h, info):
        """Return log-probabilities for every discrete action head.

        The first head (environment action) is masked with
        ``info['avail_actions']`` when present: unavailable actions get a logit
        of -1e8 before the softmax. All remaining heads (e.g. the
        communication head) are left unmasked. Previously only the second head
        was kept, and a single-head configuration raised IndexError.
        """
        logits = [head(h) for head in self.heads]

        env_logits = logits[0]
        if 'avail_actions' in info:
            avail_mask = info['avail_actions']
            if isinstance(avail_mask, torch.Tensor):
                avail_mask = avail_mask.bool()
            else:
                avail_mask = torch.as_tensor(avail_mask, dtype=torch.bool,
                                             device=env_logits.device)
            env_logits = env_logits.masked_fill(~avail_mask, -1e8)

        return [F.log_softmax(env_logits, dim=-1)] + \
               [F.log_softmax(other, dim=-1) for other in logits[1:]]

    def forward(self, x, info=None):
        """Run the encoder, ``comm_passes`` rounds of attentional communication,
        and the policy/value heads.

        B: batch size (normally 1 when stepping an episode). N: number of agents.

        Arguments:
            x -- Non-recurrent: observations, expected as (B x N x num_inputs).
                Recurrent: a tuple ``(obs, extras)``. ``extras`` is
                ``(hidden, cell)`` for rnn_type 'LSTM', otherwise the hidden
                state, each expected as ((B * N) x hid_size).
            info {dict} -- Optional per-step data. Keys read:
                'alive_mask' (numpy array of N flags),
                'comm_action' (N talk/silent flags, required when hard_attn),
                'avail_actions' (mask broadcastable to the first head's logits).

        Returns:
            Non-recurrent: ``(action, value, (comm_loss, comm_bits))``.
            Recurrent: ``(action, value, (hidden, cell), (comm_loss, comm_bits))``.
            ``action`` is ``(mean, log_std, std)`` for continuous actions.
            Otherwise it is a list of log-probability tensors, one per action
            head.
        """
        if info is None:
            info = {}

        x, hidden_state, cell_state = self.forward_state_encoder(x)

        comm_loss = 0
        comm_bits = 0

        batch_size = x.size()[0]
        n = self.nagents

        _, agent_mask = self.get_agent_mask(batch_size, info)
        # Liveness only; agent_mask may additionally be gated by comm actions.
        agent_mask_alive = agent_mask.clone()

        # Hard attention: whether each agent communicates at all.
        if self.args.hard_attn:
            comm_action = torch.tensor(info['comm_action'])
            comm_action_mask = comm_action.expand(batch_size, n, n).unsqueeze(-1)
            # action 1 is talk, 0 is silent i.e. act as dead for comm purposes.
            agent_mask *= comm_action_mask.double()

        agent_mask_transpose = agent_mask.transpose(1, 2)
        # (B, N, N, 1): non-zero where both sender and receiver are active.
        key_mask = agent_mask * agent_mask_transpose
        # (B, N, N): sender-side mask used for the attention weights.
        sender_mask = agent_mask.squeeze(-1)
        # (B, N, hid): dead agents do not receive messages; alive agents that
        # chose not to talk still receive.
        recv_mask = agent_mask_alive.squeeze(-1)[:, 0].unsqueeze(-1).expand(
            batch_size, n, self.hid_size)

        for i in range(self.comm_passes):
            # Recurrent states are flat ((B * N) x hid); restore the agent axis.
            comm = hidden_state.view(batch_size, n, self.hid_size) if self.args.recurrent else hidden_state

            if self.args.comm_mask_zero:
                comm = comm * torch.zeros_like(comm)

            # [TarMAC] soft attention replaces CommNet's averaging.
            query = self.state2query(comm)
            key = self.state2key(comm)
            value = self.state2value(comm)

            # Give each receiver i a copy of every sender j's key:
            # (B, N, N, 16). This is done for every channel mode, so the
            # matmul below never broadcasts the batch axis against the agent
            # axis when B > 1.
            key = key.unsqueeze(1).expand(batch_size, n, n, -1)
            key, loss, bits = self._transmit(key, key_mask)
            comm_loss += loss
            comm_bits += bits

            # NOTE: scaled by sqrt(hid_size), not the 16-dim key size; kept
            # as-is so trained models behave the same.
            scores = torch.matmul(query.unsqueeze(-2), key.transpose(
                -2, -1)) / math.sqrt(self.hid_size)
            scores = scores.squeeze(-2)
            scores = scores.masked_fill(sender_mask == 0, -1e9)

            attn = F.softmax(scores, dim=-1)
            # If every sender is masked the softmax is uniform; zero it out.
            attn = attn * sender_mask
            comm = torch.matmul(attn, value)

            # In place on purpose: keeps comm's dtype even if the mask is float64.
            comm *= recv_mask
            comm, loss, bits = self._transmit(comm, recv_mask)
            comm_loss += loss
            comm_bits += bits

            c = self.C_modules[i](comm)

            if self.args.recurrent:
                # Skip connection: combine the message with the encoded input.
                inp = (x + c).view(batch_size * n, self.hid_size)
                hidden_state, cell_state = self.f_module(inp, (hidden_state, cell_state))
            else:
                # Next hidden state from f plus skip connections from x and c.
                hidden_state = self.tanh(x + self.f_modules[i](hidden_state) + c)

        value_head = self.value_head(hidden_state)
        h = hidden_state.view(batch_size, n, self.hid_size)

        if self.continuous:
            action_mean = self.action_mean(h)
            action_log_std = self.action_log_std.expand_as(action_mean)
            action_std = torch.exp(action_log_std)
            # will be used later to sample
            action = (action_mean, action_log_std, action_std)
        else:
            action = self._discrete_action_log_probs(h, info)

        if self.args.recurrent:
            return action, value_head, (hidden_state.clone(), cell_state.clone()), (comm_loss, comm_bits)
        return action, value_head, (comm_loss, comm_bits)

    def init_weights(self, m):
        """Re-initialise Linear weights from N(0, init_std); use with ``self.apply``."""
        if isinstance(m, nn.Linear):
            m.weight.data.normal_(0, self.init_std)

    def init_hidden(self, batch_size):
        """Return zero (hidden, cell) states of shape ((batch_size * nagents) x hid_size)."""
        return (torch.zeros(batch_size * self.nagents, self.hid_size, requires_grad=True),
                torch.zeros(batch_size * self.nagents, self.hid_size, requires_grad=True))

    def compute_component_log_loss(self, z, mask):
        """
        Communication penalty: sum over masked entries of log2(2 * M * |z| + 1),
        where M is args.num_messages.
        """
        M = self.args.num_messages

        loss = torch.log2(2 * M * z.abs() + 1)
        loss = loss * mask
        return torch.sum(loss)

    def compute_num_bits_used(self, target, mask):
        """Bits sent if every masked entry of ``target`` is a 32-bit float."""
        bits_used = mask.expand_as(target) * 32
        return torch.sum(bits_used)

    def get_comms_noise(self, mask, target):
        """Uniform noise in [-delta/2, delta/2), delta = 1 / num_messages, zeroed outside ``mask``."""
        noise = (torch.rand_like(target, device=target.device) - 0.5) * 2
        delta = (1 / self.args.num_messages)
        noise = noise * delta * 0.5

        # Only perturb values that are actually transmitted.
        return noise * mask