import torch.nn as nn
import torch.nn.functional as F
import torch as th
import numpy as np
from utils.transformer import Transformer
from modules.encoders.transition_encoder import TransformerTransitionEncoder

class LocalEncoder(nn.Module):
    """Three-layer ReLU MLP mapping transition encodings to local encodings.

    Configuration is read from ``args``:
      * ``args.transition_encoding_dim`` -- input width,
      * ``args.encoder_hidden_dim``      -- hidden width of fc1/fc2,
      * ``args.encoding_dim``            -- output width.

    With ``is_club=True`` the last layer emits ``2 * encoding_dim`` features.
    ``forward`` splits them into a mean half and a softplus-activated
    ``local_var`` half. NOTE(review): ``is_club`` presumably refers to a
    CLUB-style variational estimator -- confirm with the caller.
    """

    def __init__(self, args, is_club=False) -> None:
        super().__init__()

        self.args = args
        self.is_club = is_club

        self.input_dim = args.transition_encoding_dim
        self.hidden_dim = args.encoder_hidden_dim
        self.output_dim = args.encoding_dim
        if is_club:
            # fc3 produces both heads at once; forward() splits them at encoding_dim.
            self.encoding_dim = args.encoding_dim
            self.output_dim *= 2

        # Creation order fc1 -> fc2 -> fc3 fixes the state_dict key order and
        # the sequence of RNG draws used for weight init; keep it stable.
        self.fc1 = nn.Linear(self.input_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc3 = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, temporal_encodings):
        """Encode ``temporal_encodings``, whose last dim must equal ``input_dim``.

        Returns:
            A tensor whose last dim is ``encoding_dim``. When ``is_club`` is
            set, it instead returns a ``(local_mean, local_var)`` pair. Each
            element has last dim ``encoding_dim``, and ``local_var`` passes
            through softplus.
        """
        hidden = temporal_encodings
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        out = self.fc3(hidden)

        if not self.is_club:
            return out

        # out's last dim is exactly 2 * encoding_dim, so this yields two halves.
        mean_part, scale_part = out.split(self.encoding_dim, dim=-1)
        return mean_part, F.softplus(scale_part)


class LocalRoleEncoder(nn.Module):
    """Three-layer ReLU MLP producing local encodings, optionally task-conditioned.

    Configuration is read from ``args``:
      * ``args.transition_encoding_dim`` -- width of the transition encoding,
      * ``args.encoder_hidden_dim``      -- hidden width of fc1/fc2,
      * ``args.encoding_dim``            -- output width and task-encoding width,
      * ``args.role_use_task_encoding``  -- optional flag, default True.

    When the flag is set, ``forward`` concatenates the task encoding onto the
    transition encoding along the last dim before the MLP. fc1 is therefore
    sized ``transition_encoding_dim + encoding_dim``.

    With ``is_club=True`` the last layer emits ``2 * encoding_dim`` features.
    They are split into a mean half and a softplus-activated ``local_var``
    half. NOTE(review): ``is_club`` presumably refers to a CLUB-style
    variational estimator -- confirm with the caller.
    """

    def __init__(self, args, is_club=False) -> None:
        super().__init__()

        self.args = args
        self.is_club = is_club
        self.use_task_encoding = getattr(args, "role_use_task_encoding", True)

        if self.use_task_encoding:
            self.input_dim = args.transition_encoding_dim + args.encoding_dim
        else:
            self.input_dim = args.transition_encoding_dim
        self.hidden_dim = args.encoder_hidden_dim
        self.output_dim = args.encoding_dim
        if self.is_club:
            # fc3 produces both heads at once; forward() splits them at encoding_dim.
            self.output_dim = 2 * args.encoding_dim
            self.encoding_dim = args.encoding_dim

        # Keep creation order fc1 -> fc2 -> fc3. It fixes state_dict keys and
        # the RNG draws used for weight init.
        self.fc1 = nn.Linear(self.input_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc3 = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, temporal_encodings, task_encodings=None):
        """Encode transition encodings, optionally concatenated with task encodings.

        Args:
            temporal_encodings: tensor whose last dim is
                ``args.transition_encoding_dim``.
            task_encodings: tensor whose last dim is ``args.encoding_dim``. Its
                leading dims must match ``temporal_encodings``, because the two
                are concatenated on the last dim. It is required when task
                conditioning is enabled and ignored otherwise, so it may be
                omitted in that case.

        Returns:
            A tensor whose last dim is ``encoding_dim``. When ``is_club`` is
            set, it instead returns ``(local_mean, local_var)``, each with last
            dim ``encoding_dim``, where ``local_var`` passes through softplus.

        Raises:
            ValueError: if task conditioning is enabled but ``task_encodings``
                is None.
        """
        if self.use_task_encoding:
            if task_encodings is None:
                raise ValueError(
                    "LocalRoleEncoder was built with role_use_task_encoding=True; "
                    "task_encodings must be provided"
                )
            encoder_input = th.cat([temporal_encodings, task_encodings], dim=-1)
        else:
            # Task conditioning disabled: task_encodings (if any) is ignored.
            encoder_input = temporal_encodings

        x = F.relu(self.fc1(encoder_input))
        x = F.relu(self.fc2(x))
        local_encodings = self.fc3(x)
        if self.is_club:
            local_mean = local_encodings[..., :self.encoding_dim]
            local_var = F.softplus(local_encodings[..., self.encoding_dim:])
            return local_mean, local_var
        return local_encodings