import torch.nn as nn
import torch.nn.functional as F
import torch as th
import numpy as np
from utils.transformer import Transformer
from modules.encoders.transition_encoder import TransformerTransitionEncoder

class GlobalEncoder(nn.Module):
    """Reduce a set of encodings to a single global encoding.

    The input is run through a ``Transformer``; the output found at index 0
    of dim 1 is then fed to a linear projection.  When ``is_club`` is set the
    projection is twice as wide and is split into a mean half and a
    softplus-activated variance half.
    """

    def __init__(self, args, is_club=False) -> None:
        """Create the transformer and the output projection.

        Args:
            args: Configuration object.  The attributes read here are
                ``transition_encoding_dim``, ``encoding_dim``, ``head`` and
                ``depth``.
            is_club: If true, ``forward`` returns a ``(mean, var)`` pair
                instead of a single tensor.
                NOTE(review): presumably feeds a CLUB-style estimator -
                confirm against the caller.
        """
        super().__init__()

        self.args = args
        self.is_club = is_club
        self.encoding_dim = args.transition_encoding_dim

        if is_club:
            # One projection produces both halves (mean and raw variance),
            # so it has to be twice the size of the final encoding.
            self.final_encoding_dim = args.encoding_dim
            self.output_dim = 2 * args.encoding_dim
        else:
            self.output_dim = args.encoding_dim

        # Input and output width of the transformer are both encoding_dim.
        self.transformer = Transformer(
            self.encoding_dim, args.head, args.depth, self.encoding_dim
        )
        self.out = nn.Linear(self.encoding_dim, self.output_dim)

    def forward(self, temporal_encodings):
        """Encode ``temporal_encodings`` into a global encoding.

        Args:
            temporal_encodings: Input tensor; per the original author's note
                its layout is ``[bs, n_agents, z_dim]``.

        Returns:
            The projected encoding (``[bs, z_dim]`` per the original note),
            or - when ``is_club`` is set - a ``(global_mean, global_var)``
            tuple where ``global_var`` has gone through ``softplus``.
        """
        # The transformer's second argument is always passed as None here.
        attended = self.transformer(temporal_encodings, None)
        # Only the entry at index 0 along dim 1 is kept as the summary.
        projected = self.out(attended[:, 0, :])

        if not self.is_club:
            return projected

        split_at = self.final_encoding_dim
        global_mean = projected[..., :split_at]
        # softplus maps the raw second half to non-negative values.
        global_var = F.softplus(projected[..., split_at:])
        return global_mean, global_var

