import torch
from torch import nn
from torch.nn import functional as F
import time

class Encoder(torch.nn.Module):
    """Abstract base class for the encoders defined in this module.

    It only records the problem dimensions; subclasses build their own
    layers and must override :meth:`forward`.

    Args:
        n_nodes: Stored as ``self.n_nodes``.
        n_games: Stored as ``self.n_games``.
        hidden_dim: Stored as ``self.hidden_dim``.
    """

    def __init__(self, n_nodes, n_games, hidden_dim):
        super().__init__()
        self.n_nodes = n_nodes
        self.n_games = n_games
        self.hidden_dim = hidden_dim

    def forward(self, X):
        """Encode ``X``; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, because the base class has no layers.
        """
        # ``NotImplemented`` is the sentinel reserved for binary-operator
        # dispatch; returning it here would hand callers a non-tensor value
        # instead of failing loudly at the point of misuse.
        raise NotImplementedError(
            f"{type(self).__name__} must override forward()"
        )

class MLPOnNodesEncoder(Encoder):
    """Two-layer perceptron applied along the last axis of the input.

    ``lin1`` maps ``n_games`` features to ``hidden_dim``, a ReLU follows,
    and ``lin2`` maps ``hidden_dim`` to ``hidden_dim``.

    Args:
        n_nodes: Forwarded to :class:`Encoder`.
        n_games: Size of the last axis of the input.
        hidden_dim: Size of the last axis of the output.
        dropout: Accepted but not used anywhere in this class.
    """

    def __init__(self, n_nodes, n_games, hidden_dim, dropout=0.):
        super().__init__(n_nodes, n_games, hidden_dim)
        # Creation order (lin1, then lin2) is kept so seeded initialisation
        # yields the same parameters.
        self.lin1 = nn.Linear(in_features=n_games, out_features=hidden_dim)
        self.lin2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)

    def forward(self, X):
        """Return ``lin2(relu(lin1(X)))``; the last axis of ``X`` must be ``n_games`` wide."""
        activated = torch.relu(self.lin1(X))
        return self.lin2(activated)

class MLPEncoder(Encoder):
    """Two-layer perceptron over the flattened (node, game) features.

    The input is reshaped to rows of ``n_nodes * n_games`` values, passed
    through ``lin1`` -> ReLU -> ``lin2``, and reshaped to
    ``[-1, n_nodes, hidden_dim]``.

    Args:
        n_nodes: Number of nodes used when flattening and un-flattening.
        n_games: Number of per-node input features.
        hidden_dim: Number of per-node output features.
        dropout: Accepted but not used anywhere in this class.
    """

    def __init__(self, n_nodes, n_games, hidden_dim, dropout=0.):
        super().__init__(n_nodes, n_games, hidden_dim)
        flat_in = n_nodes * n_games
        flat_out = n_nodes * hidden_dim
        self.lin1 = nn.Linear(flat_in, flat_out)
        self.lin2 = nn.Linear(flat_out, flat_out)

    def forward(self, X):
        """Encode ``X`` and return a tensor of shape ``[-1, n_nodes, hidden_dim]``."""
        flat = X.reshape(-1, self.n_nodes * self.n_games)
        hidden = torch.relu(self.lin1(flat))
        encoded = self.lin2(hidden)
        return encoded.reshape(-1, self.n_nodes, self.hidden_dim)


class TransformerEncoder(Encoder):
    """Transformer encoder stack followed by a linear projection.

    The stack uses ``d_model=n_games`` and a single attention head; its
    output is projected from ``n_games`` to ``hidden_dim`` features.

    Args:
        n_nodes: Forwarded to :class:`Encoder`.
        n_games: Feature width expected on the last axis of the input.
        hidden_dim: Feature width of the returned tensor.
        dropout: Dropout probability passed to the transformer layer.
        num_layers: Number of stacked transformer layers.
        transformer_feedforward_dim: ``dim_feedforward`` of each layer.
    """

    def __init__(self, n_nodes, n_games, hidden_dim, dropout=0., num_layers=1, transformer_feedforward_dim=100):
        super().__init__(n_nodes, n_games, hidden_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=n_games,
            nhead=1,
            dim_feedforward=transformer_feedforward_dim,
            dropout=dropout,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer=layer, num_layers=num_layers)
        self.lin = nn.Linear(n_games, hidden_dim)

    def forward(self, X):
        """Encode ``X`` and project its last axis to ``hidden_dim``.

        Axes 0 and 1 of ``X`` are swapped before the transformer stack and
        swapped back afterwards.
        """
        # Presumably X is [n_graphs, n_nodes, n_games] and the swap makes it
        # sequence-first for the stack -- TODO confirm against callers.
        swapped = X.transpose(0, 1)
        encoded = self.encoder(swapped)

        # Restore the original axis order, then map n_games -> hidden_dim.
        restored = encoded.transpose(0, 1)
        return self.lin(restored)


class PerGameTransformerEncoder(Encoder):
    """Multi-head attention across nodes, applied to per-game features.

    Every scalar entry of the input is lifted to a ``head_dim`` feature
    vector.  Attention weights between nodes are computed per head from the
    features of all games jointly, and are then used to mix the per-game
    features of the nodes.  The mixed features are combined with the
    original ones by an MLP plus a linear skip connection.

    Args:
        n_nodes: Forwarded to :class:`Encoder`.
        n_games: Forwarded to :class:`Encoder`.
        hidden_dim: Forwarded to :class:`Encoder` only.  It does NOT size any
            layer here: ``self.hidden_dim`` is overwritten with ``head_dim``,
            as it was when the value was hard-coded.
        dropout: Accepted but not used in this class.
        num_layers: Accepted but not used in this class.
        use_B: Stored as ``self.use_B``; not read in this class.
        max_norm: Stored as ``self.max_norm``; not read in this class.
        n_heads: Number of attention heads (previously hard-coded to 10).
        head_dim: Per-head feature width (previously hard-coded to 10).
    """

    def __init__(self, n_nodes, n_games, hidden_dim, dropout=0., num_layers=1, use_B=True, max_norm=100,
                 n_heads=10, head_dim=10):
        super(PerGameTransformerEncoder, self).__init__(n_nodes, n_games, hidden_dim)

        # Deliberately replaces the value stored by the base class, to keep
        # the historical behaviour where ``self.hidden_dim`` was always 10.
        self.hidden_dim = head_dim
        self.n_heads = n_heads
        multi_head_dim = head_dim * n_heads

        # Submodules are created in the same order as before so that seeded
        # initialisation and state_dict keys stay unchanged.
        self.W_Q = nn.Linear(head_dim, multi_head_dim, bias=False)
        self.W_K = nn.Linear(head_dim, multi_head_dim, bias=False)

        self.mlp = nn.Sequential(
            nn.Linear(multi_head_dim + head_dim, multi_head_dim),
            nn.ReLU(inplace=True),
            nn.Linear(multi_head_dim, multi_head_dim),
        )

        self.lin = nn.Linear(1, head_dim)
        # Not used by forward(); kept so existing checkpoints still load.
        self.lin2 = nn.Linear(head_dim, 1)
        self.linF = nn.Linear(head_dim, multi_head_dim)
        self.use_B = use_B
        self.max_norm = max_norm

    def forward(self, X):
        """Encode ``X``.

        Args:
            X: Tensor indexed as ``[n_graphs, n_nodes, n_games]``.

        Returns:
            Tensor of shape
            ``[n_graphs, n_nodes, n_games, n_heads * hidden_dim]``.
        """
        n_graphs, n_nodes, n_games = X.shape[0], X.shape[1], X.shape[2]
        n_heads, head_dim = self.n_heads, self.hidden_dim

        # Lift every scalar entry to a feature vector:
        # [n_graphs, n_nodes, n_games, head_dim]
        feats = F.relu(self.lin(X.unsqueeze(3)))

        # Main idea: attention weights consider the behaviour of the nodes in
        # all games, while the mixing below is done separately per game.

        # Queries: [n_graphs, n_heads, n_nodes, n_games, head_dim]
        #       -> [n_graphs * n_heads, n_nodes, n_games * head_dim]
        queries = self.W_Q(feats).reshape(n_graphs, n_nodes, n_games, n_heads, head_dim)
        queries = queries.permute(0, 3, 1, 2, 4).reshape(n_graphs * n_heads, n_nodes, -1)

        # Keys: [n_graphs, n_heads, n_games, head_dim, n_nodes]
        #    -> [n_graphs * n_heads, n_games * head_dim, n_nodes]
        keys = self.W_K(feats).reshape(n_graphs, n_nodes, n_games, n_heads, head_dim)
        keys = keys.permute(0, 3, 2, 4, 1).reshape(n_graphs * n_heads, -1, n_nodes)

        # Unscaled dot-product scores, normalised over the attended node:
        # [n_graphs, n_heads, n_nodes, n_nodes]
        scores = torch.bmm(queries, keys).reshape(n_graphs, n_heads, n_nodes, n_nodes)
        attn = F.softmax(scores, dim=3)
        # Stack heads along the row axis so one bmm handles all of them:
        # [n_graphs, n_heads * n_nodes, n_nodes]
        attn = attn.reshape(n_graphs, n_heads * n_nodes, n_nodes)

        # Values are the un-projected features:
        # [n_graphs, n_nodes, n_games * head_dim]
        values = feats.reshape(n_graphs, n_nodes, -1)

        # [n_graphs, n_heads, n_nodes, n_games, head_dim]
        mixed = torch.bmm(attn, values).reshape(n_graphs, n_heads, n_nodes, n_games, -1)
        # [n_graphs, n_nodes, n_games, n_heads * head_dim]
        mixed = mixed.permute(0, 2, 3, 1, 4).reshape(n_graphs, n_nodes, n_games, -1)

        # Linear skip connection plus MLP over [features, attended features]:
        # [n_graphs, n_nodes, n_games, n_heads * head_dim]
        Z = self.linF(feats) + self.mlp(torch.cat([feats, mixed], dim=3))

        return Z
