import torch
import torch.nn as nn


class Embedding(nn.Module):
    """Weight-normalised multilayer perceptron with an optional input skip.

    The network is a stack of ``nn.Linear`` layers, each wrapped in
    ``nn.utils.weight_norm``, with a ReLU after every layer except the last.
    Layer widths follow
    ``[input_dim] + [hidden_dim] * num_hidden_layers + [output_dim]``.

    When ``with_skip`` is true, the layer at index ``len(dims) - skip_at`` is
    built ``input_dim`` features narrower and, in ``forward``, its activated
    output is concatenated with the original input along the last dimension,
    which restores the nominal width expected by the next layer.

    NOTE: if that index is the final layer (``skip_at == 2``) the layer is
    still narrowed, but ``forward`` performs no concatenation on the final
    layer, so the output then has ``output_dim - input_dim`` features.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of every hidden layer.
        output_dim: Nominal number of output features.
        num_hidden_layers: Number of hidden layers between input and output.
        with_skip: Whether to use the input skip connection.
        skip_at: Position of the skip layer, counted from the end of the
            width list (``index = len(dims) - skip_at``).
        use_BE: If true, ``forward`` passes the network output through the
            module-level ``boltzmann_distribution`` (temperature 1) and
            returns the element-wise square root of the result.
        device: Device the layers are moved to at construction time.
    """

    def __init__(self,
                 input_dim,
                 hidden_dim,
                 output_dim,
                 num_hidden_layers,
                 with_skip,
                 skip_at=3,
                 use_BE=False,
                 device='cpu'):

        super().__init__()
        self.skip = with_skip
        self.skip_at = skip_at
        self.dims = [input_dim] + [hidden_dim] * num_hidden_layers + [output_dim]
        self.layers = nn.ModuleList()
        self.use_BE = use_BE

        skip_index = len(self.dims) - self.skip_at
        for i, (in_features, out_features) in enumerate(
                zip(self.dims[:-1], self.dims[1:])):
            if self.skip and i == skip_index:
                # Leave room for the input that forward() concatenates back.
                out_features -= input_dim
            self.layers.append(
                nn.utils.weight_norm(nn.Linear(in_features, out_features)))

        for layer in self.layers:
            layer.to(device)

        self.initialize_weights()

    def initialize_weights(self):
        """Apply Xavier-uniform initialisation to every linear layer.

        For weight-normalised layers ``weight`` is not a parameter: it is
        recomputed from ``weight_g`` and ``weight_v`` before each forward
        pass, so writing to ``weight`` has no lasting effect. The direction
        tensor ``weight_v`` is therefore initialised instead, and
        ``weight_g`` is set to its per-row norm so that the effective weight
        equals the Xavier-initialised tensor.
        """
        for m in self.modules():
            if not isinstance(m, nn.Linear):
                continue
            if hasattr(m, 'weight_g') and hasattr(m, 'weight_v'):
                with torch.no_grad():
                    nn.init.xavier_uniform_(m.weight_v)
                    # weight_norm uses dim=0 here: one magnitude per output row.
                    m.weight_g.copy_(m.weight_v.norm(dim=1, keepdim=True))
            else:
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):  # infer energy #
        """Run the network on ``x`` and return its output.

        If ``use_BE`` is set, the output is additionally mapped through
        ``boltzmann_distribution`` and an element-wise square root.
        """
        skip_input = x
        last_index = len(self.dims) - 2
        skip_index = len(self.dims) - self.skip_at

        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == last_index:
                # The last layer is purely linear: no ReLU, no skip.
                continue

            x = nn.functional.relu(x)
            if self.skip and i == skip_index:
                # --- residual layer: re-inject the original input --- #
                x = torch.cat((skip_input, x), dim=-1)

        # --- Next: associate energy with probability --- #
        if self.use_BE:
            x = boltzmann_distribution(x, temperature=1)
            x = torch.sqrt(x)

        return x
    



def logtrick(x):
    """Compute log-sum-exp over all elements of ``x`` in a numerically stable way.

    The maximum element is subtracted before exponentiating and added back
    afterwards, which avoids overflow for large finite values.

    Args:
        x: Tensor of any shape; the reduction runs over every element.

    Returns:
        A 0-dim tensor holding ``log(sum(exp(x)))``.
    """
    max_x = torch.max(x)
    # Only shift by the maximum when it is finite: with an infinite maximum
    # ``x - max_x`` would evaluate ``inf - inf`` and turn the result into NaN
    # instead of the correct -inf (all entries -inf) or +inf (any entry +inf).
    shift = torch.where(torch.isfinite(max_x), max_x, torch.zeros_like(max_x))
    return shift + torch.log(torch.sum(torch.exp(x - shift)))

def boltzmann_distribution(energies, temperature, k_B=None, normalize=False):
    """Convert energies into Boltzmann probabilities.

    Computes ``exp(-beta * E) / Z`` with ``beta = 1 / (k_B * temperature)``
    (or ``1 / temperature`` when ``k_B`` is not given), where ``Z`` is taken
    over every element of ``energies``. The result is then divided once more
    by its own sum plus ``1e-9``.

    Args:
        energies: Tensor of energies; the distribution spans all elements.
        temperature: Nonzero temperature.
        k_B: Optional Boltzmann constant; when falsy it is left out of beta.
        normalize: Accepted for backward compatibility; this function does
            not read it.

    Returns:
        Tensor of probabilities with the same shape as ``energies``.

    Raises:
        ValueError: If the temperature is ``None`` or zero, if the resulting
            beta is zero, or if the probabilities contain NaN or infinity
            before or after the final normalisation.
    """
    # Validate before dividing; otherwise a zero temperature surfaces as a
    # ZeroDivisionError and this check can never fire.
    if temperature is None or temperature == 0:
        raise ValueError('Temperature must be a nonzero number.')

    if k_B:
        beta = 1 / (k_B * temperature)
    else:
        beta = 1 / temperature

    # Kept from the original contract: a vanishing beta is rejected as well.
    if beta == 0:
        raise ValueError('Temperature must be a nonzero number.')

    scaled = -beta * energies

    # Compute log-partition function using logtrick for numerical stability
    log_partition_function = logtrick(scaled)

    # Compute probabilities using log-space exponentiation
    probabilities = torch.exp(scaled - log_partition_function)

    if torch.isnan(probabilities).any():
        # Dump the inputs that produced the NaN to aid debugging.
        print(energies)
        print(log_partition_function)

        raise ValueError('Probabilities before normalization contain NaN.')
    if torch.isinf(probabilities).any():
        raise ValueError('Probabilities before normalization contain infinity.')

    # Small epsilon guards against division by an exactly-zero sum.
    probabilities = probabilities / (torch.sum(probabilities) + 1e-9)

    if torch.isnan(probabilities).any():
        raise ValueError('Probabilities after normalization contain NaN.')
    if torch.isinf(probabilities).any():
        raise ValueError('Probabilities after normalization contain infinity.')

    return probabilities
