import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.mar_utils_mol import convert_val_to_onehot

class MLPRes(nn.Module):
    """Fully connected network with optional skip connections.

    Layout: ``input_layer`` -> ``num_layers - 1`` square hidden layers ->
    ``output_layer``. A LeakyReLU follows the input layer and every hidden
    layer; the output layer is left linear.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width shared by the input layer's output and all hidden
            layers.
        output_dim: Size of the last dimension of the result.
        num_layers: Controls depth; ``num_layers - 1`` hidden layers are
            created (none when ``num_layers`` is 1 or less).
        res: If true, each hidden layer's activated output is summed with
            that layer's input.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, res=True):
        super().__init__()
        # Sub-modules are created in the same order as they are applied so
        # parameter ordering in the state dict follows the data flow.
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.hidden_layers = nn.ModuleList()
        for _ in range(num_layers - 1):
            self.hidden_layers.append(nn.Linear(hidden_dim, hidden_dim))
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.LeakyReLU()
        self.res = res

    def forward(self, x):
        """Run the network on ``x`` and return the un-activated output."""
        hidden = self.activation(self.input_layer(x))
        for linear in self.hidden_layers:
            activated = self.activation(linear(hidden))
            # Skip connection: add the layer's input back onto its output.
            hidden = activated + hidden if self.res else activated
        return self.output_layer(hidden)
    
class MLPResDual(nn.Module):
    """Pair of ``MLPRes`` heads fed by the same flattened one-hot input.

    ``mlp_ar`` emits ``K * L`` logits and ``mlp_marg`` emits a single value
    per example. Both take an input of width ``(K + 1) * L``.

    Args:
        hidden_dim: Hidden width passed to both ``MLPRes`` heads.
        K: Number of logits per position; also forwarded to
            ``convert_val_to_onehot``.
        L: Number of positions.
        num_layers: Depth argument passed to both ``MLPRes`` heads.
        res: Residual flag passed to both ``MLPRes`` heads.
    """

    def __init__(self, hidden_dim, K, L, num_layers, res=True):
        super().__init__()
        self.K = K
        self.L = L
        # NOTE(review): the (K + 1) input width presumably means the one-hot
        # encoding has one extra class per position -- confirm against
        # convert_val_to_onehot.
        in_features = L * (K + 1)
        self.mlp_ar = MLPRes(in_features, hidden_dim, L * K, num_layers, res)
        self.mlp_marg = MLPRes(in_features, hidden_dim, 1, num_layers, res)

    def forward(self, x):
        """Return ``(logp, logits)`` for ``x``.

        ``logp`` is the ``mlp_marg`` output with its trailing singleton
        dimension squeezed away; ``logits`` is the ``mlp_ar`` output reshaped
        to ``(batch, L, K)``, where ``batch`` is the first dimension of the
        flattened one-hot input. Only a single leading batch dimension is
        supported by that reshape.
        """
        encoded = convert_val_to_onehot(x, self.K)
        # Merge the last two axes into one feature vector per example.
        flat = encoded.flatten(start_dim=-2, end_dim=-1)
        logits = self.mlp_ar(flat)
        logp = self.mlp_marg(flat).squeeze(-1)
        batch = flat.shape[0]
        return logp, logits.reshape(batch, self.L, self.K)
    
class MLPResSingle(nn.Module):
    """Single ``MLPRes`` head producing ``K`` logits for each of ``L`` positions.

    The head takes an input of width ``(K + 1) * L`` and emits ``K * L``
    values.

    Args:
        hidden_dim: Hidden width passed to the ``MLPRes`` head.
        K: Number of logits per position; also forwarded to
            ``convert_val_to_onehot``.
        L: Number of positions.
        num_layers: Depth argument passed to the ``MLPRes`` head.
        res: Residual flag passed to the ``MLPRes`` head.
    """

    def __init__(self, hidden_dim, K, L, num_layers, res=True):
        super().__init__()
        self.K = K
        self.L = L
        # NOTE(review): the (K + 1) input width presumably means the one-hot
        # encoding has one extra class per position -- confirm against
        # convert_val_to_onehot.
        in_features = L * (K + 1)
        self.mlp_ar = MLPRes(in_features, hidden_dim, L * K, num_layers, res)

    def forward(self, x):
        """Return only the logits for ``x``, reshaped to ``(batch, L, K)``.

        ``batch`` is the first dimension of the flattened one-hot input, so
        only a single leading batch dimension is supported.
        """
        encoded = convert_val_to_onehot(x, self.K)
        # Merge the last two axes into one feature vector per example.
        flat = encoded.flatten(start_dim=-2, end_dim=-1)
        logits = self.mlp_ar(flat)
        batch = flat.shape[0]
        return logits.reshape(batch, self.L, self.K)