import numpy as np
import json
from scipy.special import expit
from sklearn.preprocessing import StandardScaler
from typing import List, Optional, Tuple
from dataclasses import dataclass
import torch.nn as nn
import torch.optim as optim

import torch
import time

class SimplexLogits(nn.Module):
    """
        A learnable parameterization of a probability vector on the simplex
        using logits.

        The initial weights ``U_init`` are floored at ``eps``. Their logarithm
        is stored as a float64 ``nn.Parameter``, so ``forward()`` at
        initialization returns ``U_init`` normalized to sum to 1 along ``dim``.

        Args:
            U_init: initial (unnormalized, non-negative) weights. Entries below
                ``eps``, including zeros and negatives, are raised to ``eps``
                so their log stays finite.
            eps: positive floor applied before taking the log. The default
                1e-12 matches the original hard-coded value.
            dim: axis along which ``forward`` normalizes. The default 0 gives
                the original 1-D behavior. Other values let a 2-D ``U_init``
                hold several simplices, one per row or column.

        Raises:
            ValueError: if ``eps`` is not strictly positive.
    """
    def __init__(self, U_init: np.ndarray, *, eps: float = 1e-12, dim: int = 0):
        super().__init__()
        if not eps > 0:
            # A zero/negative/NaN floor would let np.log produce -inf or NaN logits.
            raise ValueError(f"eps must be strictly positive, got {eps!r}")
        self.eps = float(eps)
        self.dim = int(dim)
        U = np.maximum(np.asarray(U_init, np.float64), self.eps)
        # softmax(log U) == U / U.sum(dim), so the initial output is U_init normalized.
        self.logits = nn.Parameter(torch.tensor(np.log(U), dtype=torch.float64))

    def forward(self) -> torch.Tensor:
        """Return the current probabilities: softmax of the logits along ``self.dim``."""
        return torch.softmax(self.logits, dim=self.dim)

    def extra_repr(self) -> str:
        """Summarize the parameter shape and normalization axis for ``repr``."""
        return f"shape={tuple(self.logits.shape)}, dim={self.dim}, eps={self.eps}"