import torch
import torch.nn.functional as F
from torch import nn

from models import MLP
from action_utils import select_action, translate_action

import numpy as np
import math

class CommNetMLP(nn.Module):
    """
    MLP based CommNet. Uses communication vector to communicate info
    between agents.

    Besides the usual CommNet message passing, ``forward`` computes a
    consensus-based intrinsic reward (stored on ``self.intrinsic_reward``)
    and supports three message modes selected by ``args.noise_mode``:

    * 0 -- agents broadcast their true hidden state;
    * 1 -- agents broadcast ``tanh(g([hidden, binary noise]))``;
    * 2 -- for agents whose communication flag ``m`` is 0, the message is
      replaced by a state derived from Gaussian-noise inputs.

    With ``args.byzantine`` the agents are split into two communities with
    separate weights: the first ``nagents - nfault`` agents and the last
    ``nfault`` agents.

    Only the recurrent configuration can be constructed (its cell is an
    ``nn.LSTMCell``); the feed-forward variant raises ``NotImplementedError``
    in ``__init__``.
    """
    def __init__(self, args, num_inputs):
        """Build the encoders, communication, recurrent and output layers.

        Arguments:
            args {Namespace} -- Parsed args namespace. Read here: nagents,
                hid_size, noise_size, comm_passes, recurrent, hard_consensus,
                last_consensus, ir_amplifier, consensus_prior, noise_mode,
                continuous (+ dim_actions) or naction_heads, comm_mask_zero,
                byzantine (+ nfault), share_weights, comm_init and, optionally,
                init_std. ``forward`` additionally reads rnn_type, hard_attn
                and, optionally, comm_mode.
            num_inputs {number} -- Environment observation dimension for agents

        Raises:
            NotImplementedError -- if ``args.recurrent`` is false (no g module
                exists for the feed-forward variant).
        """

        super(CommNetMLP, self).__init__()
        self.args = args
        self.nagents = args.nagents
        self.hid_size = args.hid_size
        self.noise_size = args.noise_size
        self.comm_passes = args.comm_passes
        self.recurrent = args.recurrent
        self.hard_consensus = args.hard_consensus
        self.last_consensus = args.last_consensus
        self.ir_amplifier = args.ir_amplifier
        self.consensus_prior = args.consensus_prior
        self.noise_mode = args.noise_mode

        self.continuous = args.continuous
        if self.continuous:
            self.action_mean = nn.Linear(args.hid_size, args.dim_actions)
            self.action_log_std = nn.Parameter(torch.zeros(1, args.dim_actions))
        else:
            self.heads = nn.ModuleList([nn.Linear(args.hid_size, o)
                                        for o in args.naction_heads])
        # Std used by init_weights. Previously this tested for 'comm_init_std'
        # but read 'init_std', so a configured init_std was silently ignored.
        self.init_std = getattr(args, 'init_std', 0.2)

        # Mask for communication: all zeros, or every agent except itself.
        if self.args.comm_mask_zero:
            self.comm_mask = torch.zeros(self.nagents, self.nagents)
        else:
            self.comm_mask = torch.ones(self.nagents, self.nagents) \
                            - torch.eye(self.nagents, self.nagents)

        # For multiple communities (e.g., Byzantine mode). Community 0 is the
        # first nagents - nfault agents, community 1 the last nfault agents.
        if args.byzantine:
            self.ncommunities = 2
            self.nfault = args.nfault
            self.communities = [self.nagents - args.nfault, args.nfault]
        else:
            self.ncommunities = 1
            self.communities = [self.nagents]

        # Linear layers accept any number of leading dimensions, so the agent
        # dimension is handled implicitly. This is function r in the paper,
        # encoding the initial environment observation.
        if self.ncommunities == 1:
            self.encoder = nn.Linear(num_inputs, args.hid_size)
        else:
            self.encoders = nn.ModuleList([nn.Linear(num_inputs, args.hid_size)
                                           for _ in range(self.ncommunities)])

        if args.recurrent:
            if self.ncommunities == 1:
                self.hidd_encoder = nn.Linear(args.hid_size, args.hid_size)
            else:
                self.hidd_encoders = nn.ModuleList([nn.Linear(args.hid_size, args.hid_size)
                                                    for _ in range(self.ncommunities)])

            # f: recurrent state update; g: maps [hidden, noise] to the
            # reported message used when noise_mode == 1.
            if self.ncommunities == 1:
                self.f_module = nn.LSTMCell(args.hid_size, args.hid_size)
                self.g_module = nn.Linear(args.hid_size + args.noise_size, args.hid_size)
            else:
                self.f_modules = nn.ModuleList([nn.LSTMCell(args.hid_size, args.hid_size)
                                                for _ in range(self.ncommunities)])
                self.g_modules = nn.ModuleList([nn.Linear(args.hid_size + args.noise_size, args.hid_size)
                                                for _ in range(self.ncommunities)])
        else:
            # Feed-forward variant: f modules are sketched, but there is no g
            # module yet, so construction always fails.
            if args.share_weights:
                if self.ncommunities == 1:
                    self.f_module = nn.Linear(args.hid_size, args.hid_size)
                    self.f_modules = nn.ModuleList([self.f_module
                                                    for _ in range(self.comm_passes)])
                    raise NotImplementedError("Not Implemented G module")
                else:
                    self.f_modules_tmp = [nn.Linear(args.hid_size, args.hid_size) for _ in range(self.ncommunities)]
                    self.f_modules = nn.ModuleList([self.f_modules_tmp[i]
                                                    for i in range(self.ncommunities) for j in range(self.comm_passes)])
                    raise NotImplementedError("Not Implemented G module")
            else:
                self.f_modules = nn.ModuleList([nn.Linear(args.hid_size, args.hid_size)
                                                for _ in range(self.comm_passes*self.ncommunities)])
                raise NotImplementedError("Not Implemented G module")

        # C: transforms the aggregated incoming messages. Indexing is
        # community * comm_passes + pass (see forward).
        if args.share_weights:
            if self.ncommunities == 1:
                self.C_module = nn.Linear(args.hid_size, args.hid_size)
                self.C_modules = nn.ModuleList([self.C_module
                                                for _ in range(self.comm_passes)])
            else:
                self.C_module_tmp = [nn.Linear(args.hid_size, args.hid_size) for _ in range(self.ncommunities)]
                self.C_modules = nn.ModuleList([self.C_module_tmp[i]
                                                for i in range(self.ncommunities) for _ in range(self.comm_passes)])
        else:
            self.C_modules = nn.ModuleList([nn.Linear(args.hid_size, args.hid_size)
                                            for _ in range(self.comm_passes*self.ncommunities)])

        # Initialise communication weights (not biases) to 0 if requested.
        if args.comm_init == 'zeros':
            for i in range(len(self.C_modules)):
                self.C_modules[i].weight.data.zero_()
        self.tanh = nn.Tanh()

        # Set by every forward() call.
        self.intrinsic_reward = None

        if self.ncommunities == 1:
            self.value_head = nn.Linear(self.hid_size, 1)
        else:
            self.value_heads = nn.ModuleList([nn.Linear(self.hid_size, 1)
                                             for _ in range(self.ncommunities)])

    def get_agent_mask(self, batch_size, info):
        """Build the alive-agent mask used to gate communication.

        Arguments:
            batch_size {int} -- Leading dimension of the returned mask.
            info {dict} -- May contain ``'alive_mask'`` (numpy array with one
                entry per agent). Without it every agent counts as alive.

        Returns:
            tuple -- ``(num_agents_alive, agent_mask)``. ``agent_mask`` has
                shape (batch_size, N, N, 1) with ``agent_mask[b, i, j]`` equal
                to agent j's alive flag. It is an expanded view that may share
                memory with ``info['alive_mask']`` (``torch.from_numpy``), so
                it must not be modified in place.
        """
        n = self.nagents

        if 'alive_mask' in info:
            agent_mask = torch.from_numpy(info['alive_mask'])
            num_agents_alive = agent_mask.sum()
        else:
            agent_mask = torch.ones(n)
            num_agents_alive = n

        agent_mask = agent_mask.view(1, 1, n)
        agent_mask = agent_mask.expand(batch_size, n, n).unsqueeze(-1)

        return num_agents_alive, agent_mask

    def forward_state_encoder(self, x):
        """Encode the observation and unpack the recurrent state.

        Arguments:
            x -- Recurrent mode: ``(obs, extras)`` where ``extras`` is
                ``(hidden, cell)`` when ``args.rnn_type == 'LSTM'`` and the
                hidden state otherwise. Feed-forward mode: the observation.

        Returns:
            tuple -- ``(encoded_obs, hidden_state, cell_state)``. In
                feed-forward mode ``hidden_state`` is ``tanh(encoder(x))``
                (the same tensor as ``encoded_obs``) and ``cell_state`` is None.
        """
        hidden_state, cell_state = None, None

        if self.args.recurrent:
            x, extras = x
            if self.ncommunities == 1:
                x = self.encoder(x)
            elif self.ncommunities == 2:
                # Agents are split along dim 1: community 0 first.
                x = torch.cat(
                        (self.encoders[0](x[:, :self.communities[0]]),
                         self.encoders[1](x[:, self.communities[0]:])), dim=1)
            else:
                raise NotImplementedError("Not Implemented")

            if self.args.rnn_type == 'LSTM':
                hidden_state, cell_state = extras
            else:
                hidden_state = extras
        else:
            if self.ncommunities == 1:
                x = self.encoder(x)
                x = self.tanh(x)
            else:
                raise NotImplementedError("Not Implemented")
            hidden_state = x

        return x, hidden_state, cell_state

    def _report_state(self, hidden_state):
        """Return the noisy message ``tanh(g([hidden, noise]))`` (noise_mode 1).

        ``noise`` is a fresh 0/1 tensor of shape (rows, noise_size) with the
        hidden state's dtype/device. For two communities, rows ``[:sep]`` go
        through ``g_modules[0]`` and the rest through ``g_modules[1]``
        (NOTE(review): this row split assumes batch_size == 1).
        """
        noise = torch.randint(2, (hidden_state.size(0), self.noise_size),
                              dtype=hidden_state.dtype, device=hidden_state.device)
        if self.ncommunities == 1:
            return self.tanh(self.g_module(torch.cat((hidden_state, noise), dim=1)))
        if self.ncommunities == 2:
            sep = self.communities[0]
            rep0 = self.g_modules[0](torch.cat((hidden_state[:sep], noise[:sep]), dim=1))
            rep1 = self.g_modules[1](torch.cat((hidden_state[sep:], noise[sep:]), dim=1))
            return self.tanh(torch.cat((rep0, rep1), dim=0))
        raise NotImplementedError("Not Implemented")

    def forward(self, x, info=None):
        """Run the encoder, ``comm_passes`` communication rounds and the heads.

        B: batch size (normally 1 when stepping through an episode)
        N: number of agents

        Arguments:
            x -- Recurrent mode: ``(obs, (hidden, cell))``. ``obs`` is
                presumably B x N x num_inputs (the multi-community encoder
                slices agents on dim 1); ``hidden`` and ``cell`` are
                (B*N) x hid_size, as fed to ``nn.LSTMCell``.
            info {dict} -- Optional. Reads ``'alive_mask'`` and, when
                ``args.hard_attn`` is set, ``'comm_action'`` (one 0/1 flag per
                agent; 1 = talk). ``info`` is not modified.

        Returns:
            tuple -- ``(action, value_head, (hidden, cell))`` in recurrent mode,
                ``(action, value_head)`` otherwise. ``action`` is a list of
                per-head log-probabilities (discrete) or
                ``(mean, log_std, std)`` (continuous). ``value_head`` has
                shape (B*N) x 1.

        Side effects:
            Sets ``self.intrinsic_reward`` to the consensus score, divided by
            ``comm_passes`` and scaled by ``ir_amplifier``.
        """
        if info is None:
            info = {}

        adversarial_state = None
        if self.noise_mode == 2:
            # Gaussian stand-ins for (obs, (hidden, cell)); assumes the LSTM
            # input layout. Only needed in noise mode 2.
            noise = [torch.randn_like(x[0]),
                     (torch.randn_like(x[1][0]), torch.randn_like(x[1][1]))]
        x, hidden_state, cell_state = self.forward_state_encoder(x)
        if self.noise_mode == 2:
            adversarial_state = self.forward_state_encoder(noise)[1]

        batch_size = x.size()[0]
        n = self.nagents

        num_agents_alive, agent_mask = self.get_agent_mask(batch_size, info)

        comm_action = None
        m = 1

        # Hard Attention - action whether an agent communicates or not
        if self.args.hard_attn:
            comm_action = torch.tensor(info['comm_action'])
            comm_action_mask = comm_action.expand(batch_size, n, n).unsqueeze(-1)
            # action 1 is talk, 0 is silent i.e. act as dead for comm purposes.
            # Out-of-place on purpose: agent_mask is an expanded view that may
            # alias info['alive_mask'].
            agent_mask = agent_mask * comm_action_mask.to(agent_mask.dtype)
            # NOTE(review): this view assumes batch_size == 1.
            m = comm_action.double().view(batch_size, n, 1).expand(batch_size, n, self.hid_size)
        agent_mask_transpose = agent_mask.transpose(1, 2)

        # Adversarial Attack: noisy messages for the first pass.
        if self.noise_mode == 1:
            reported_state = self._report_state(hidden_state)

        reward = None
        for i in range(self.comm_passes):
            # Choose the broadcast message: true or reported hidden state.
            if self.noise_mode in [0, 2]:
                comm = hidden_state.view(batch_size, n, self.hid_size) if self.args.recurrent else hidden_state
            else:
                comm = reported_state.view(batch_size, n, self.hid_size) if self.args.recurrent else reported_state

            # Map messages from [-1, 1] to [0, 1] predictions; silent agents
            # (m == 0) contribute 0.5, or the adversarial state in mode 2.
            if self.noise_mode == 2:
                adv_comm = adversarial_state.view(batch_size, n, self.hid_size) if self.args.recurrent else adversarial_state
                masked_comm = comm * m + adv_comm * (1 - m)
                peer_prediction = (masked_comm + 1.0) / 2.0
            else:
                peer_prediction = (comm * m + 1.0) / 2.0
            # keepdim=True so the per-sample mean broadcasts over agents for
            # any batch size (without it expand only worked for batch 1).
            consensus = peer_prediction.mean(dim=1, keepdim=True).expand(batch_size, n, self.hid_size)

            if self.hard_consensus:
                consensus = (torch.sign(consensus - 0.5) + 1.0) / 2.0
            else:
                c_amplifier = self.consensus_prior
                consensus = torch.clamp((consensus - 0.5)*c_amplifier, -1, 1)/2.0 + 0.5  #=> [0, 1]

            # comm[b, i, j] = message of agent i, replicated for each j.
            comm = comm.unsqueeze(-2).expand(-1, n, n, self.hid_size)

            # Create mask for masking self communication
            mask = self.comm_mask.view(1, n, n)
            mask = mask.expand(comm.shape[0], n, n)
            mask = mask.unsqueeze(-1)

            mask = mask.expand_as(comm)
            comm = comm * mask

            # Mask communication from dead agents
            comm = comm * agent_mask
            # Mask communication to dead agents
            comm = comm * agent_mask_transpose

            if self.noise_mode == 2:
                comm = comm * m + adv_comm * (1 - m)
                comm = comm * mask  # self communication
                if getattr(self.args, 'comm_mode', None) == 'avg':
                    comm = comm / self.nagents
            else:
                if getattr(self.args, 'comm_mode', None) == 'avg' \
                        and num_agents_alive > 1:
                    comm = comm / (num_agents_alive - 1)

            # Combine all of C_j for an ith agent which essentially are h_j
            comm_sum = comm.sum(dim=1)  #=> [B, nagents, n_hid]

            # Binary cross-entropy of each prediction against the consensus,
            # shifted by log(2) so an uninformative 0.5 prediction scores 0.
            # NOTE(review): log(0) gives inf/nan if a message saturates at +-1.
            cross_entropy = - consensus * torch.log(peer_prediction) \
                - (1 - consensus) * torch.log(1 - peer_prediction)
            score = - cross_entropy + np.log(2)
            # NOTE(review): divided by hid_size and then averaged over the
            # hidden dim below -- confirm the double normalisation is intended.
            score = score / self.hid_size

            pass_reward = score.mean(dim=2).squeeze()
            if self.last_consensus or reward is None:
                reward = pass_reward
            else:
                reward = reward + pass_reward

            if comm_action is not None:
                reward = reward * comm_action.double()

            if self.ncommunities == 1:
                c = self.C_modules[i](comm_sum)
            elif self.ncommunities == 2:
                c = torch.cat(
                        (self.C_modules[i](comm_sum[:, :self.communities[0]]),
                         self.C_modules[i + self.comm_passes](comm_sum[:, self.communities[0]:])), dim=1)
            else:
                raise NotImplementedError("Not Implemented")

            if self.args.recurrent:
                # skip connection - combine comm. matrix and encoded input for all agents
                inp = x + c
                inp = inp.view(batch_size * n, self.hid_size)

                if self.ncommunities == 1:
                    if self.noise_mode == 2:
                        inp_noise = torch.randn_like(x) + c
                        inp_noise = inp_noise.view(batch_size * n, self.hid_size)
                        hidden_noise = torch.randn_like(hidden_state)
                        cell_noise = torch.randn_like(cell_state)
                        adversarial_state = self.f_module(inp_noise, (hidden_noise, cell_noise))[0]

                    hidden_state, cell_state = self.f_module(inp, (hidden_state, cell_state))
                elif self.ncommunities == 2:
                    if self.noise_mode == 2:
                        raise NotImplementedError("Not Implemented")
                    # NOTE(review): row split by sep assumes batch_size == 1.
                    sep = self.communities[0]
                    h0, c0 = self.f_modules[0](inp[:sep], (hidden_state[:sep], cell_state[:sep]))
                    h1, c1 = self.f_modules[1](inp[sep:], (hidden_state[sep:], cell_state[sep:]))
                    hidden_state = torch.cat((h0, h1), dim=0)
                    cell_state = torch.cat((c0, c1), dim=0)
                else:
                    raise NotImplementedError("Not Implemented")

                if self.noise_mode == 1:
                    # Report the *updated* hidden state for the next pass
                    # (previously the single-community path reused the stale one).
                    reported_state = self._report_state(hidden_state)

            else:  # MLP|RNN
                # Get next hidden state from f node
                # and Add skip connection from start and sum them
                if self.ncommunities == 1:
                    hidden_state = sum([x, self.f_modules[i](hidden_state), c])
                    hidden_state = self.tanh(hidden_state)
                else:
                    raise NotImplementedError("Not Implemented")

        if self.ncommunities == 1:
            value_head = self.value_head(hidden_state)
        elif self.ncommunities == 2:
            sep = self.communities[0]
            value_head = torch.cat((
                self.value_heads[0](hidden_state[:sep]),
                self.value_heads[1](hidden_state[sep:])), dim=0)
        else:
            raise NotImplementedError("Not Implemented")
        h = hidden_state.view(batch_size, n, self.hid_size)

        if self.continuous:
            action_mean = self.action_mean(h)
            action_log_std = self.action_log_std.expand_as(action_mean)
            action_std = torch.exp(action_log_std)
            # will be used later to sample
            action = (action_mean, action_log_std, action_std)
        else:
            # discrete actions
            action = [F.log_softmax(head(h), dim=-1) for head in self.heads]

        ir_amplifier = self.ir_amplifier  # beta
        self.intrinsic_reward = reward / self.comm_passes * ir_amplifier
        if self.args.recurrent:
            return action, value_head, (hidden_state.clone(), cell_state.clone())
        else:
            return action, value_head

    def init_weights(self, m):
        """Re-draw a Linear layer's weight from N(0, init_std); for ``Module.apply``."""
        if isinstance(m, nn.Linear):
            m.weight.data.normal_(0, self.init_std)

    def init_hidden(self, batch_size):
        """Return zero ``(hidden, cell)`` states of shape (batch_size * nagents, hid_size).

        Both tensors are created with ``requires_grad=True``.
        """
        return (torch.zeros(batch_size * self.nagents, self.hid_size, requires_grad=True),
                torch.zeros(batch_size * self.nagents, self.hid_size, requires_grad=True))

