from setenc_modules import *

'''
This code is based on the implementation of SetTransformer (https://github.com/juho-lee/set_transformer)
'''

class DeepSet(nn.Module):
  """DeepSets-style set encoder: per-element MLP, mean pooling, output MLP.

  Each set element is mapped through `enc` (Linear layers act on the last
  axis), the results are averaged over axis -2, and the pooled vector is
  decoded by `dec` into `num_outputs * dim_output` values. Those values are
  reshaped to (-1, num_outputs, dim_output).

  Args:
    dim_input: feature size of each set element (last axis of the input).
    num_outputs: number of output vectors produced per set.
    dim_output: feature size of each output vector.
    dim_hidden: width of every hidden Linear layer.
  """

  def __init__(self, dim_input, num_outputs, dim_output, dim_hidden=128):
    super().__init__()
    self.num_outputs = num_outputs
    self.dim_output = dim_output

    def mlp(first_in, last_out):
      # Four Linear layers with a ReLU between consecutive ones and none after
      # the last. Layer order (and thus Sequential indices 0..6) is fixed so
      # parameter names in state_dict stay enc.0/2/4/6 and dec.0/2/4/6.
      layers = [nn.Linear(first_in, dim_hidden)]
      for _ in range(2):
        layers += [nn.ReLU(), nn.Linear(dim_hidden, dim_hidden)]
      layers += [nn.ReLU(), nn.Linear(dim_hidden, last_out)]
      return nn.Sequential(*layers)

    self.enc = mlp(dim_input, dim_hidden)
    self.dec = mlp(dim_hidden, num_outputs * dim_output)

  def forward(self, X):
    """Encode each element, mean-pool over axis -2, decode and reshape.

    Presumably X is (batch, set_size, dim_input); TODO confirm against callers.
    A 2-D input (set_size, dim_input) yields a leading dimension of 1.
    """
    pooled = self.enc(X).mean(-2)
    flat = self.dec(pooled)
    return flat.reshape(-1, self.num_outputs, self.dim_output)


class SetPool(nn.Module):
  """Set Transformer-style set encoder: attention encoder + PMA pooling decoder.

  The architecture is selected by substring tests on `mode`:

  Encoder (`self.enc`):
    * 'sab' in mode -> two SAB blocks (dim_input -> dim_hidden -> dim_hidden).
    * otherwise     -> two ISAB blocks, each also given `num_inds`.

  Decoder (`self.dec`):
    * 'PF' in mode  -> PMA(num_outputs seeds), then Linear(dim_hidden, dim_output).
    * 'P' in mode   -> PMA only; the output feature size stays dim_hidden and
                       `dim_output` is not used.
    * otherwise     -> PMA, two SAB blocks, then Linear(dim_hidden, dim_output).

  Other text in `mode` (e.g. 'sm' in the variants "", "sm", "sab", "sabsm")
  is not inspected by this class. `mode=None` is treated as "".

  Args:
    dim_input: feature size of each input set element.
    num_outputs: passed to PMA (presumably the number of pooled seed vectors).
    dim_output: feature size of the final Linear layer's output.
    num_inds: passed to ISAB (presumably the number of inducing points).
    dim_hidden: feature size inside the encoder/decoder.
    num_heads: number of attention heads passed to SAB/ISAB/PMA.
    ln: passed through to SAB/ISAB/PMA as `ln` (presumably enables LayerNorm).
    mode: architecture selector string; see above.
  """

  def __init__(self, dim_input, num_outputs, dim_output,
               num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
    super(SetPool, self).__init__()
    # The default is None, but the substring tests below need a string;
    # without this, SetPool(...) with the default mode raises TypeError.
    if mode is None:
      mode = ''

    if 'sab' in mode:
      self.enc = nn.Sequential(
        SAB(dim_input, dim_hidden, num_heads, ln=ln),
        SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
    else:
      self.enc = nn.Sequential(
        ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
        ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))

    # 'PF' must be tested before 'P': every mode containing 'PF' also
    # contains 'P'.
    if 'PF' in mode:
      self.dec = nn.Sequential(
        PMA(dim_hidden, num_heads, num_outputs, ln=ln),
        nn.Linear(dim_hidden, dim_output))
    elif 'P' in mode:
      # Pool only: no projection to dim_output.
      self.dec = nn.Sequential(
        PMA(dim_hidden, num_heads, num_outputs, ln=ln))
    else:
      self.dec = nn.Sequential(
        PMA(dim_hidden, num_heads, num_outputs, ln=ln),
        SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
        SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
        nn.Linear(dim_hidden, dim_output))

  def forward(self, X):
    """Apply the encoder then the pooling decoder.

    Presumably X is (batch, set_size, dim_input) and the result is
    (batch, num_outputs, dim_output), or (..., dim_hidden) in 'P' mode.
    TODO confirm against the SAB/ISAB/PMA definitions in setenc_modules.
    """
    return self.dec(self.enc(X))
