import torch
import torch.nn as nn
from torch.nn import functional as F
import math

class PositionalEncoding(nn.Module):
    '''
        Adds fixed sinusoidal positional encodings to an embedded sequence,
        then applies dropout.

        Even feature columns hold sin(pos / 10000^(2i / d_model)) and odd
        columns hold the matching cos term. Both even and odd `d_model`
        are supported.

        d_model: feature dimension of the inputs
        max_len: number of precomputed positions; inputs with seq_len > max_len
                 are not supported (the addition in `forward` fails to broadcast)
        dropout: dropout probability applied after adding the encodings
    '''
    def __init__(self, d_model, max_len=500, dropout=0.1):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1).float()  # [max_len, 1]
        # 1 / 10000^(2i / d_model) for 2i = 0, 2, 4, ... < d_model, computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 frequencies feed the cos terms.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading batch axis so the table broadcasts over [batch_size, seq_len, d_model].
        pe = pe.unsqueeze(0)
        self.dropout = nn.Dropout(dropout)
        # Buffer rather than parameter: stored in state_dict and moved by .to(), but not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
            x: [batch_size, seq_len, d_model] with seq_len <= max_len
            returns: dropout(x + pe[:seq_len]), same shape as x
        '''
        # NOTE(review): sequences longer than max_len raise a shape error here;
        # truncating the input to max_len was considered but not adopted.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

class PropertyPredictor(nn.Module):
    '''
        Predicts a property of a token sequence with a transformer encoder.

        Pipeline:
        - Tokens are embedded and given sinusoidal positional encodings.
        - They pass through a transformer encoder, with padding positions
          masked out of attention.
        - The result is mean-pooled over the non-padding positions.
        - An MLP head maps the pooled features to the prediction.

        Optionally, a second head predicts a per-sequence variance.
    '''
    def __init__(self, config, alphabet_size, output_size=None, pad_idx=1):
        '''
            config: config object. Reads:
                    - `config.classifier_guidance.model.{hidden_dim, num_heads,
                      dropout, num_layers, num_classes}`
                    - `config.classifier_guidance.as_regression`
                    - `config.classifier_guidance.train.model_log_var`
            alphabet_size: number of token ids, i.e. the embedding table size
            output_size: width of the prediction head. When None, it is 1 if
                         `as_regression` is set and `num_classes` otherwise.
            pad_idx: token id treated as padding in `forward`. Defaults to 1,
                     the value previously hard-coded (presumably the
                     tokenizer's pad id — confirm).
        '''
        super().__init__()
        # NOTE(review): original note says this size is "for the concatenation
        # of the input and the condition" — confirm its meaning with callers.
        self.alphabet_size = alphabet_size
        self.config = config
        self.pad_idx = pad_idx
        model_cfg = config.classifier_guidance.model
        hidden_dim = model_cfg.hidden_dim

        self.embedder = nn.Embedding(self.alphabet_size, hidden_dim)
        self.pos_encoder = PositionalEncoding(hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=model_cfg.num_heads,
            dim_feedforward=hidden_dim,
            dropout=model_cfg.dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=model_cfg.num_layers,
            norm=nn.LayerNorm(hidden_dim),
        )

        if output_size is None:
            output_size = 1 if config.classifier_guidance.as_regression else model_cfg.num_classes
        self.predictor_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_size),
        )

        # Optional head predicting a per-sequence log-variance; forward()
        # exponentiates it so the returned variance is positive.
        if config.classifier_guidance.train.model_log_var:
            self.log_var_head = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1),
            )
        else:
            self.log_var_head = None

    def forward(self, seq):
        '''
            seq: [batch_size, seq_len] token ids. Positions equal to
                 `self.pad_idx` are treated as padding.
            returns: `mean` of shape [batch_size, output_size]. When the
                 log-variance head is enabled, returns the tuple
                 `(mean, var)` instead, where `var` has shape [batch_size, 1].
        '''
        src_key_padding_mask = seq == self.pad_idx
        feat = self.embedder(seq)
        feat = self.pos_encoder(feat)
        feat = self.transformer(feat, src_key_padding_mask=src_key_padding_mask)

        # Masked mean pooling. Encoder outputs at padded positions are not
        # meaningful, and can differ between nn.TransformerEncoder's training
        # and inference code paths. Averaging over them made the prediction
        # depend on how much padding the batch contained.
        # clamp(min=1) guards against an all-padding row.
        # NOTE(review): for padded batches this differs from the old unmasked
        # mean; re-validate checkpoints trained with the previous pooling.
        keep = (~src_key_padding_mask).unsqueeze(-1).to(feat.dtype)
        feat = (feat * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)

        mean = self.predictor_head(feat)
        if self.log_var_head is None:
            return mean
        log_var = self.log_var_head(feat)
        var = torch.exp(log_var)
        return mean, var
