import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# Snapshot, taken once at import time, of whether torch can see a CUDA device.
USE_CUDA: bool = torch.cuda.is_available()

class Encoder(nn.Module):
    """One-layer feature encoder: ``Linear(din -> hidden_dim)`` then leaky ReLU.

    Args:
        din: size of the last dimension of the input tensor.
        hidden_dim: size of the last dimension of the produced embedding.
    """

    def __init__(self, din, hidden_dim):
        super().__init__()
        # Attribute name ``fc`` is part of the state_dict key ("fc.weight").
        self.fc = nn.Linear(din, hidden_dim)

    def forward(self, x):
        """Project ``x`` along its last dim and apply leaky ReLU (default slope)."""
        return F.leaky_relu(self.fc(x))

class Q_Net(nn.Module):
    """Linear output head with no activation: ``Linear(hidden_dim -> dout)``.

    Args:
        hidden_dim: size of the last dimension of the input tensor.
        dout: size of the last dimension of the output tensor.
    """

    def __init__(self, hidden_dim, dout):
        super().__init__()
        # Attribute name ``fc`` is part of the state_dict key ("fc.weight").
        self.fc = nn.Linear(hidden_dim, dout)

    def forward(self, x):
        """Apply the linear layer to the last dimension of ``x``."""
        return self.fc(x)

class DC2Net(nn.Module):
    """Per-row Q-network combined with three parallel team-level ("gl") branches.

    Data flow in ``forward``:
      * Local path: ``encoder`` -> ``fc1`` (leaky ReLU) -> ``fc2`` produces a
        ``num_actions``-sized output for every entry along dim 1.
      * Team path: three parallel branches each read max-pooled (over dim 1)
        encoder features and max-pooled fc1 features. Both are computed from
        detached tensors. Each branch produces a ``num_actions``-sized vector,
        and the three are averaged.
      * ``q_net`` maps the concatenation ``[local, averaged branch]``
        (size ``2 * num_actions``) to ``num_actions`` outputs.

    Args:
        n_agent: accepted but not used by this class.
            NOTE(review): presumably the size of dim 1 — confirm with callers.
        num_inputs: size of the last dimension of the input.
        hidden_dim: width of the encoder, ``fc1`` and the branches' first layers.
        num_actions: size of the last dimension of the returned ``q``.
    """

    def __init__(self, n_agent, num_inputs, hidden_dim, num_actions):
        super().__init__()

        # Local path. The module creation order below is deliberate: it fixes
        # the order of seeded parameter initialisation and of state_dict keys.
        self.encoder = Encoder(num_inputs, hidden_dim)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_actions)

        # Three parallel team-level branches. Each *_1 layer reads pooled
        # encoder features; each *_2 layer reads pooled fc1 features plus that
        # branch's *_1 output.
        self.gl_layer1_1 = nn.Linear(hidden_dim, hidden_dim)
        self.gl_layer1_2 = nn.Linear(hidden_dim, num_actions)
        self.gl_layer2_1 = nn.Linear(hidden_dim, hidden_dim)
        self.gl_layer2_2 = nn.Linear(hidden_dim, num_actions)
        self.gl_layer3_1 = nn.Linear(hidden_dim, hidden_dim)
        self.gl_layer3_2 = nn.Linear(hidden_dim, num_actions)

        # Mixes the local output with the averaged branch output.
        self.q_net = Q_Net(num_actions * 2, num_actions)

    def forward(self, x, mask):
        """Compute per-row outputs and a branch-similarity regulariser.

        Args:
            x: rank-3 tensor whose last dim is ``num_inputs``. Rank 3 is
                required by the ``dim=2`` concat and the 3-size ``expand``.
                Presumably ``(batch, n_agents, num_inputs)`` — TODO confirm.
            mask: accepted but never read.
                NOTE(review): the max-pooling covers every entry along dim 1,
                including any padded or inactive rows — confirm this is intended.

        Returns:
            q: tensor with ``x``'s first two dims and last dim ``num_actions``.
            reg: sum of the three pairwise cosine similarities between the
                branch outputs, shape ``(x.shape[0], 1)``.
        """
        h = self.encoder(x)

        # Pool once and share the result across the three branches. Using
        # ``detach`` stops the branches' gradients from reaching encoder/fc1.
        pooled_enc = self.team_pooling(h.detach())
        gl1 = F.leaky_relu(self.gl_layer1_1(pooled_enc))
        gl2 = F.leaky_relu(self.gl_layer2_1(pooled_enc))
        gl3 = F.leaky_relu(self.gl_layer3_1(pooled_enc))

        h = F.leaky_relu(self.fc1(h))
        pooled_hidden = self.team_pooling(h.detach())
        gl1 = self.gl_layer1_2(pooled_hidden + gl1)
        gl2 = self.gl_layer2_2(pooled_hidden + gl2)
        gl3 = self.gl_layer3_2(pooled_hidden + gl3)
        gl = (gl1 + gl2 + gl3) / 3

        local = self.fc2(h)
        # ``gl`` has size 1 on dim 1 (keepdim pooling), so expand broadcasts it
        # to every row without the copy that ``repeat`` would allocate.
        gl_rows = gl.expand(-1, local.shape[1], -1)
        q = self.q_net(torch.cat([local, gl_rows], dim=2))

        # Regularisation term: the sum of the three pairwise cosine similarities.
        reg = (
            torch.cosine_similarity(gl1, gl2, dim=2)
            + torch.cosine_similarity(gl2, gl3, dim=2)
            + torch.cosine_similarity(gl3, gl1, dim=2)
        )
        return q, reg

    def team_pooling(self, x):
        """Max-reduce ``x`` over dim 1, keeping it as a size-1 dimension.

        Presumably dim 1 is the agent axis — TODO confirm.
        """
        return torch.max(x, dim=1, keepdim=True).values









