import torch
import torch.nn as nn
import torch.nn.functional as F

# Model definition: Q-network built from a single self-attention block.
class BIG_Q_ATTN(nn.Module):
    """Q-network built from one self-attention block over a padded sequence.

    Each sequence element is embedded with a linear layer + ReLU, passed
    through multi-head self-attention and a position-wise feed-forward layer
    (each followed by a residual add and LayerNorm), and finally projected to
    ``num_actions`` Q-values per element.

    Args:
        state_dim: Size of the last dimension of ``state``.
        num_actions: Number of Q-values produced per sequence element. The
            ``action`` branch of ``forward`` is written for two actions.
        hidden_dim: Width of the embedding / attention / feed-forward layers
            (default 1024, the original hard-coded width).
        num_heads: Number of attention heads (default 16); must divide
            ``hidden_dim``.
        eps: LayerNorm epsilon (default 1e-8).
    """

    def __init__(self, state_dim, num_actions, *, hidden_dim=1024,
                 num_heads=16, eps=1e-8):
        super().__init__()

        # NOTE(review): one LayerNorm module is reused after both the
        # attention and the feed-forward sublayers, so they share affine
        # parameters. Kept as-is so existing state_dicts still load.
        self.LayerNorm = nn.LayerNorm(hidden_dim, eps=eps)

        self.att = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)

        self.q1 = nn.Linear(state_dim, hidden_dim)    # input embedding
        self.q2 = nn.Linear(hidden_dim, hidden_dim)   # feed-forward sublayer
        self.q3 = nn.Linear(hidden_dim, num_actions)  # Q-value head

    def forward(self, state, mask, action=None):
        """Compute Q-values for every sequence element.

        Args:
            state: Tensor of shape (batch, seq, state_dim); attention is
                batch_first.
            mask: (batch, seq) tensor; nonzero / True marks a real element,
                0 / False marks padding. Bool masks are accepted.
                NOTE(review): a row that is entirely padding leaves no key to
                attend to and will likely yield NaN Q-values for that row.
            action: Optional (batch, seq) tensor of 0/1 actions.

        Returns:
            If ``action`` is None: the (batch, seq, num_actions) Q-value
            tensor. Otherwise: a (batch, 1) tensor with, per batch row, the
            sum over the sequence of Q(element, action[element]). Padded
            elements contribute nothing for action 0. Action 1 is not
            masked, so callers presumably pass 0 at padded positions
            (confirm).

        Side effects:
            Stores ``trans_sub``, ``trans_out`` and ``output`` on the module.
            These keep the last call's autograd graph alive until the next
            call; kept for callers that may read them.
        """
        # Embed each sequence element.
        embed = F.relu(self.q1(state))

        # key_padding_mask=True means "ignore this key". Comparing with 0
        # works for int, float and bool masks (``1 - bool_mask`` raises).
        # The same padding test is used in the action branch below.
        padding = mask == 0

        # Attention, residual add, norm.
        attn_out = self.att(embed, embed, embed, key_padding_mask=padding)[0]
        self.trans_sub = self.LayerNorm(attn_out + embed)

        # Feed-forward, residual add, norm.
        self.trans_out = self.LayerNorm(
            F.relu(self.q2(self.trans_sub)) + self.trans_sub)

        # (batch, seq, num_actions)
        self.output = self.q3(self.trans_out)

        if action is None:
            # For the next state: return every Q-value.
            return self.output

        # For the current state: per-element selection weights. Column 0
        # selects Q(a=0) where action == 0 (zeroed at padded positions).
        # Column 1 selects Q(a=1) where action == 1.
        weights = torch.stack(
            ((1 - action).long().masked_fill(padding, 0), action), dim=2)
        return (weights * self.output).sum(dim=-1).sum(dim=-1).reshape(-1, 1)
