from dataclasses import dataclass
from typing import List, Tuple, Union, Optional, Literal
import numpy as np
import random

@dataclass
class TabularMMDataSimulator:
    """Simulate two-modality tabular data driven by shared and modality-specific hidden variables.

    Each modality's observed matrix is produced from the concatenation of the
    shared hidden variables and that modality's own hidden variables, projected
    through a sparse weight matrix, then perturbed with noise and dropout.
    """
    # Number of rows (samples) generated for every array.
    n_samples: int = 1000
    # Number of hidden variables shared by both modalities.
    n_shared_variables: int = 2
    # Distribution of the shared hidden variables. _generate_hidden_variables also
    # accepts names that are not listed in this Literal (e.g. "5-class-gaussian",
    # "geometric").
    shared_hidden_dist_type: Literal["gaussian", "poisson", "binomial", "exponential", "beta", "uniform", "gumbel", "weibull"] = "binomial"
    # Number of modality-specific hidden variables: (modality 1, modality 2).
    n_hidden_variables: tuple = (5, 5)
    # Distribution names of the modality-specific hidden variables: (modality 1, modality 2).
    hidden_dist_types: tuple = ("poisson", "uniform")
    # Number of observed features: (modality 1, modality 2).
    data_dims: tuple = (50, 100)
    # How many times the nonlinearity is applied to the hidden variables.
    nonlinearity_level: int = 1  # Number of transformations
    # Name of the nonlinearity; _apply_nonlinearity additionally handles "tanh".
    nonlinearity_type: Literal["exponential", "polynomial", "sigmoid", "trigonometric", "relu"] = "polynomial"
    # Exponent used when nonlinearity_type is "polynomial".
    polynomial_degree: int = 2  # For polynomial nonlinearity
    # Density of connections (0-1). Entries 1 and 2 are the densities of the
    # modality-1 and modality-2 weight matrices; entry 0 is not read by the
    # active code in __post_init__.
    hidden_connectivities: tuple = (0.6, 0.3, 0.8)  # Density of connections (0-1)
    # Probability that an observed entry is zeroed out by _apply_sparsity.
    data_sparsity: float = 0.1  # Dropout rate
    # Despite the name, _add_noise uses this as a standard-deviation scale
    # (directly, or multiplied by the standard deviation of the data).
    noise_variance: float = 0.1
    # Mean of the additive Gaussian noise.
    noise_mean: float = 0.0
    # When not None, seeds both numpy's global RNG and the stdlib `random` module.
    random_seed: Optional[int] = None

    def __post_init__(self):
        """Seed the global RNGs (if requested) and draw both latent-to-data weight matrices.

        The matrix for modality ``k`` has ``n_shared_variables +
        n_hidden_variables[k]`` rows and ``data_dims[k]`` columns, with density
        ``hidden_connectivities[k + 1]``. ``hidden_connectivities[0]`` is not
        used here. Modality 1 is drawn before modality 2, which matters for
        reproducibility under a fixed seed.
        """
        seed = self.random_seed
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

        matrix_attrs = ("mod1_hidden_to_data_matrix", "mod2_hidden_to_data_matrix")
        for idx, attr_name in enumerate(matrix_attrs):
            weights = self._create_sparse_matrix(
                self.n_shared_variables + self.n_hidden_variables[idx],
                self.data_dims[idx],
                self.hidden_connectivities[idx + 1],
            )
            setattr(self, attr_name, weights)
    
    def _create_sparse_matrix(self, rows: int, cols: int, density: float) -> np.ndarray:
        """Return a ``rows x cols`` weight matrix in which roughly ``density`` of the entries are non-zero.

        Values are drawn from Beta(1, 1), so they lie between 0 and 1. A second,
        independent uniform draw per entry decides whether it is kept: an entry
        is zeroed when its uniform draw exceeds ``density``.
        """
        shape = (rows, cols)
        weights = np.random.beta(1, 1, size=shape)
        dropped = np.random.rand(*shape) > density
        return np.where(dropped, 0.0, weights)
    
    def _generate_hidden_variables(self, hidden_dist_type, n_hidden) -> np.ndarray:
        """Sample an ``(n_samples, n_hidden)`` array of hidden variables.

        Supported ``hidden_dist_type`` values:

        * Per-column parametric draws: ``"gaussian"``, ``"poisson"``,
          ``"exponential"``, ``"beta"``, ``"gumbel"``, ``"weibull"``,
          ``"geometric"``; plus ``"binomial"`` and ``"uniform"`` drawn in one call.
        * Class-structured draws, which also store the per-sample class index in
          ``self.classes``. The number in the name is not the number of classes
          actually drawn:

          - ``"5-class-gaussian"``: 5 classes, one scalar mean per class.
          - ``"3-class-gaussian"``: 3 classes, one mean vector per class.
          - ``"20-class-gaussian"``: 10 classes; columns 0-9 are class
            Gaussians, 10-19 exponentials. Requires ``n_hidden >= 20``.
          - ``"50-class-gaussian"``: 30 classes; columns 0-9 class Gaussians,
            10-19 betas, 20-29 Weibulls, 30-49 combinations of the former.
            Requires ``n_hidden >= 50``.
          - ``"10-class"``: 5 classes; column 0 is the class index, followed by
            Bernoulli columns and beta columns.
          - ``"20-class"``: 5 classes; column 0 is the class index, the rest are
            betas with class-specific parameters.

        For the fixed-layout types, columns beyond the documented layout stay zero.
        The order of random draws inside each branch is kept stable so that
        seeded runs remain reproducible.

        Args:
            hidden_dist_type: Name of the distribution (see above).
            n_hidden: Number of hidden variables (columns) to generate.

        Returns:
            Array with ``self.n_samples`` rows and ``n_hidden`` columns.

        Raises:
            ValueError: If ``hidden_dist_type`` is unknown, or if ``n_hidden`` is
                too small for a fixed-layout type.
        """
        n_samples = self.n_samples

        if n_hidden == 0:
            # NOTE: this also overwrites self.classes with zeros, whichever
            # group of variables (shared or modality-specific) is being drawn.
            self.classes = np.zeros((n_samples,))
            return np.zeros((n_samples, 0))

        if hidden_dist_type == "gaussian":
            # One random mean per column, unit standard deviation, shifted by +1.
            means = np.random.uniform(0.0, 1.0, size=n_hidden)
            out = np.zeros((n_samples, n_hidden))
            for i in range(n_hidden):
                out[:, i] = np.random.normal(loc=means[i], scale=1.0, size=n_samples) + 1.0
            return out

        elif hidden_dist_type == "poisson":
            # One random rate per column; counts are halved.
            lambdas = np.random.uniform(0.5, 2.0, size=n_hidden)
            out = np.zeros((n_samples, n_hidden))
            for i in range(n_hidden):
                out[:, i] = np.random.poisson(lam=lambdas[i], size=n_samples) * 0.5
            return out

        elif hidden_dist_type == "binomial":
            # Fair Bernoulli draws scaled by n_shared_variables (values are 0 or
            # n_shared_variables), regardless of which group is being generated.
            return np.random.binomial(n=1, p=0.5, size=(n_samples, n_hidden)) * self.n_shared_variables

        elif hidden_dist_type == "5-class-gaussian":
            # One scalar mean per class, shared by every column of a sample.
            means = np.random.uniform(1.0, 10.0, size=(5,))
            classes = np.random.randint(0, 5, size=n_samples)
            self.classes = classes
            out = np.zeros((n_samples, n_hidden))
            for i in range(n_samples):
                out[i, :] = np.random.normal(loc=means[classes[i]], scale=1.0, size=n_hidden) + 1.0
            return out

        elif hidden_dist_type == "3-class-gaussian":
            # One mean vector per class; the three classes use different ranges.
            means = np.array([
                np.random.uniform(1.0, 3.0, size=n_hidden),
                np.random.uniform(1.0, 5.0, size=n_hidden),
                np.random.uniform(5.0, 10.0, size=n_hidden),
            ])
            classes = np.random.randint(0, 3, size=n_samples)
            self.classes = classes
            out = np.zeros((n_samples, n_hidden))
            for i in range(n_samples):
                out[i, :] = np.random.normal(loc=means[classes[i]], scale=0.5) + 1.0
            return out

        elif hidden_dist_type == "20-class-gaussian":
            if n_hidden < 20:
                raise ValueError(
                    f"'20-class-gaussian' fills 20 fixed columns; got n_hidden={n_hidden}")
            means = np.array([np.random.uniform(2 * i, 2 * (i + 3), size=10) for i in range(10)])
            classes = np.random.randint(0, 10, size=n_samples)
            self.classes = classes
            out = np.zeros((n_samples, n_hidden))
            # Columns 0-9: class-dependent Gaussians.
            for i in range(n_samples):
                out[i, :10] = np.random.normal(loc=means[classes[i]], scale=0.1) + 1.0
            # Columns 10-19: class-independent exponentials.
            scales = np.random.uniform(0.5, 2.0, size=10)
            for i in range(10):
                out[:, i + 10] = np.random.exponential(scale=scales[i], size=n_samples)
            return out * 0.1

        elif hidden_dist_type == "50-class-gaussian":
            if n_hidden < 50:
                raise ValueError(
                    f"'50-class-gaussian' fills 50 fixed columns; got n_hidden={n_hidden}")
            means = np.array([np.random.uniform(2 * i, 2 * (i + 3), size=10) for i in range(30)])
            classes = np.random.randint(0, 30, size=n_samples)
            self.classes = classes
            out = np.zeros((n_samples, n_hidden))
            # Columns 0-9: class-dependent Gaussians.
            for i in range(n_samples):
                out[i, :10] = np.random.normal(loc=means[classes[i]], scale=0.1) + 1.0
            # Columns 10-19: betas.
            alphas = np.random.uniform(0.5, 2.0, size=10)
            betas = np.random.uniform(0.5, 2.0, size=10)
            for i in range(10):
                out[:, 10 + i] = np.random.beta(a=alphas[i], b=betas[i], size=n_samples) * 0.5
            # Columns 20-29: Weibulls.
            scales = np.random.uniform(0.5, 2.0, size=10)
            for i in range(10):
                out[:, 20 + i] = np.random.weibull(a=scales[i], size=n_samples) * 0.3
            # Columns 30-49: products / sums of the columns above.
            out[:, 30:40] = out[:, 0:10] * out[:, 10:20]
            out[:, 40:50] = out[:, 10:20] + out[:, 20:30]
            return out * 0.1

        elif hidden_dist_type == "10-class":
            classes = np.random.randint(0, 5, size=n_samples)
            self.classes = classes
            out = np.zeros((n_samples, n_hidden))
            # Column 0 carries the class index itself.
            out[:, 0] = classes
            half = n_hidden // 2
            # Clamp at zero so a single-column request yields just the class column
            # instead of asking numpy for a negative-sized draw.
            n_binary = max(half - 1, 0)
            out[:, 1:1 + n_binary] = np.random.binomial(n=1, p=0.5, size=(n_samples, n_binary))
            alphas = np.random.uniform(0.5, 2.0, size=half)
            betas = np.random.uniform(0.5, 2.0, size=half)
            for i in range(half):
                out[:, half + i] = np.random.beta(a=alphas[i], b=betas[i], size=n_samples)
            return out

        elif hidden_dist_type == "20-class":
            n_classes = 5
            classes = np.random.randint(0, n_classes, size=n_samples)
            self.classes = classes
            out = np.zeros((n_samples, n_hidden))
            # Column 0 carries the class index itself.
            out[:, 0] = classes
            # Remaining columns: betas whose parameters depend on the class.
            alphas = np.random.uniform(0.5, 2.0, size=(n_classes, n_hidden - 1))
            betas = np.random.uniform(0.5, 2.0, size=(n_classes, n_hidden - 1))
            for i in range(n_samples):
                out[i, 1:] = np.random.beta(a=alphas[classes[i]], b=betas[classes[i]])
            return out

        elif hidden_dist_type == "exponential":
            scales = np.random.uniform(0.5, 2.0, size=n_hidden)
            out = np.zeros((n_samples, n_hidden))
            for i in range(n_hidden):
                out[:, i] = np.random.exponential(scale=scales[i], size=n_samples) * 0.1
            return out

        elif hidden_dist_type == "beta":
            # Random alpha per column; beta is fixed at 0.5.
            alphas = np.random.uniform(0.1, 5.0, size=n_hidden)
            out = np.zeros((n_samples, n_hidden))
            for i in range(n_hidden):
                out[:, i] = np.random.beta(a=alphas[i], b=0.5, size=n_samples)
            return out

        elif hidden_dist_type == "uniform":
            return np.random.uniform(0, 1, size=(n_samples, n_hidden))

        elif hidden_dist_type == "gumbel":
            locs = np.random.uniform(0.0, 10.0, size=n_hidden)
            scales = np.random.uniform(0.5, 1.0, size=n_hidden)
            out = np.zeros((n_samples, n_hidden))
            for i in range(n_hidden):
                out[:, i] = np.random.gumbel(loc=locs[i], scale=scales[i], size=n_samples)
            return out

        elif hidden_dist_type == "weibull":
            # The uniform draw is used as the Weibull shape parameter `a`.
            shapes = np.random.uniform(0.5, 2.0, size=n_hidden)
            out = np.zeros((n_samples, n_hidden))
            for i in range(n_hidden):
                out[:, i] = np.random.weibull(a=shapes[i], size=n_samples) * 0.3
            return out

        elif hidden_dist_type == "geometric":
            p = np.random.uniform(0.1, 0.9, size=n_hidden)
            out = np.zeros((n_samples, n_hidden))
            for i in range(n_hidden):
                out[:, i] = np.random.geometric(p[i], size=n_samples)
            return out

        else:
            raise ValueError(f"Unknown distribution type: {hidden_dist_type}")
        
    def _apply_nonlinearity(self, data: np.ndarray) -> np.ndarray:
        """Transform ``data`` element-wise with the nonlinearity named by ``self.nonlinearity_type``.

        Recognised names are "exponential", "polynomial" (raises to
        ``self.polynomial_degree``), "sigmoid", "tanh", "trigonometric" (sine)
        and "relu".

        Raises:
            ValueError: If the configured name is not one of the above.
        """
        requested = self.nonlinearity_type
        # (name, transform) pairs, checked in order with plain equality.
        transforms = (
            ("exponential", np.exp),
            ("polynomial", lambda values: np.power(values, self.polynomial_degree)),
            ("sigmoid", lambda values: 1 / (1 + np.exp(-values))),
            ("tanh", np.tanh),
            ("trigonometric", np.sin),
            ("relu", lambda values: np.maximum(0, values)),
        )
        for name, transform in transforms:
            if requested == name:
                return transform(data)
        raise ValueError(f"Unknown nonlinearity type: {self.nonlinearity_type}")
    
    def _add_noise(self, data: np.ndarray, method: str = "signal-to-noise") -> np.ndarray:
        """Add Gaussian noise with mean ``self.noise_mean`` to ``data``.

        ``self.noise_variance`` is used as a standard-deviation scale, not a
        variance:

        * ``"absolute"``: the noise standard deviation is ``noise_variance``.
        * ``"signal-to-noise"``: the noise standard deviation is
          ``noise_variance`` times the standard deviation of all of ``data``.

        The noise is drawn with the same shape as ``data``, so the method works
        for any input shape (not only arrays with ``self.n_samples`` rows).

        Args:
            data: Array to perturb.
            method: ``"absolute"`` or ``"signal-to-noise"``.

        Returns:
            ``data + noise``, or ``data`` itself (unchanged, same object) when
            ``noise_variance`` is not positive.

        Raises:
            ValueError: If ``method`` is not one of the two supported names.
        """
        if method not in ("absolute", "signal-to-noise"):
            raise ValueError(f"Unknown noise addition method: {method}")

        if not self.noise_variance > 0:
            return data

        if method == "absolute":
            noise_std = self.noise_variance
        else:
            noise_std = np.std(data) * self.noise_variance

        noise = np.random.normal(loc=self.noise_mean, scale=noise_std, size=data.shape)
        return data + noise
    
    def _apply_sparsity(self, data: np.ndarray) -> np.ndarray:
        """Zero out entries of ``data`` at random to mimic missing observations.

        One uniform draw is made per entry; the entry survives only when that
        draw is strictly greater than ``self.data_sparsity``.
        """
        uniform_draws = np.random.rand(*data.shape)
        survivors = uniform_draws > self.data_sparsity
        return data * survivors
    
    def generate_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Generate two-modality tabular data based on hidden variables.

        Steps: draw shared and modality-specific hidden variables, apply the
        configured nonlinearity ``nonlinearity_level`` times, project each
        modality's concatenated hidden variables through its weight matrix, then
        add noise and dropout.

        Labels (one row per sample, three columns):

        * column 0: class label derived from the *shared* hidden variables. If
          the shared draw stored class indices in ``self.classes`` those are
          used; otherwise each distinct row of the (transformed) shared
          variables gets its own index. Class indices stored by the
          modality-specific draws are deliberately not used here, so a modality
          with zero hidden variables can no longer zero out the labels.
        * column 1 / column 2: row sums of the transformed modality-1 /
          modality-2 hidden variables.

        Returns:
            Tuple ``(data_1, data_2, hidden_vars_shared, hidden_vars_mod1,
            hidden_vars_mod2, labels)``.
        """
        # Forget class indices left over from a previous call so the snapshot
        # taken below can only come from this call's shared draw.
        self.__dict__.pop("classes", None)

        # Generate hidden variables (shared first, then modality 1, then 2).
        hidden_vars_shared = self._generate_hidden_variables(
            self.shared_hidden_dist_type, self.n_shared_variables)
        # Snapshot now: the modality draws below may overwrite self.classes.
        shared_classes = getattr(self, "classes", None)
        hidden_vars_mod1 = self._generate_hidden_variables(
            self.hidden_dist_types[0], self.n_hidden_variables[0])
        hidden_vars_mod2 = self._generate_hidden_variables(
            self.hidden_dist_types[1], self.n_hidden_variables[1])

        # Apply the nonlinearity to all hidden variables.
        for _ in range(self.nonlinearity_level):
            hidden_vars_shared = self._apply_nonlinearity(hidden_vars_shared)
            hidden_vars_mod1 = self._apply_nonlinearity(hidden_vars_mod1)
            hidden_vars_mod2 = self._apply_nonlinearity(hidden_vars_mod2)

        labels = np.zeros((self.n_samples, 3))

        # Class labels from the shared hidden variables.
        if shared_classes is not None:
            labels[:, 0] = shared_classes
        else:
            # Index of each row within the sorted set of distinct rows.
            _, row_to_unique = np.unique(hidden_vars_shared, axis=0, return_inverse=True)
            # Flatten: some numpy versions return the inverse with an extra axis.
            labels[:, 0] = np.reshape(row_to_unique, -1)

        # Continuous labels from modality-specific variables.
        labels[:, 1] = hidden_vars_mod1.sum(axis=1)
        labels[:, 2] = hidden_vars_mod2.sum(axis=1)

        # Concatenate shared and modality-specific hidden variables.
        X_1 = np.concatenate([hidden_vars_shared, hidden_vars_mod1], axis=1)
        X_2 = np.concatenate([hidden_vars_shared, hidden_vars_mod2], axis=1)

        # Project to the observed data space.
        data_1 = X_1 @ self.mod1_hidden_to_data_matrix
        data_2 = X_2 @ self.mod2_hidden_to_data_matrix

        data_1 = self._add_noise(data_1)
        data_2 = self._add_noise(data_2)

        data_1 = self._apply_sparsity(data_1)
        data_2 = self._apply_sparsity(data_2)

        return data_1, data_2, hidden_vars_shared, hidden_vars_mod1, hidden_vars_mod2, labels

@dataclass
class HierarchicalTabularMMDataSimulator:
    """
    Class for simulating hierarchical multi-modal data where shared variables
    influence the generation of modality-specific variables.

    The shared variables are drawn first; the distribution parameters of each
    modality's hidden variables are then a linear function of the shared
    variables.
    """
    # Number of rows (samples) generated for every array.
    n_samples: int = 1000
    # Number of shared (parent) hidden variables.
    n_shared_variables: int = 2
    # Distribution name for the shared hidden variables.
    shared_hidden_dist_type: Literal["gaussian", "poisson", "binomial", "exponential", "beta", "uniform", "gumbel", "weibull"] = "binomial"
    # Number of modality-specific hidden variables: (modality 1, modality 2).
    n_hidden_variables: tuple = (5, 5)
    # Distribution names of the modality-specific hidden variables: (modality 1, modality 2).
    hidden_dist_types: tuple = ("gaussian", "poisson")
    # Number of observed features: (modality 1, modality 2).
    data_dims: tuple = (50, 100)
    # How many times the nonlinearity is applied to the hidden variables.
    nonlinearity_level: int = 1
    # Name of the nonlinearity applied to the hidden variables.
    nonlinearity_type: Literal["exponential", "polynomial", "sigmoid", "trigonometric", "relu"] = "polynomial"
    # Exponent used when nonlinearity_type is "polynomial".
    polynomial_degree: int = 2
    # Connection densities; entries 1 and 2 are used for the modality-1 and
    # modality-2 weight matrices, entry 0 is not read in __post_init__.
    hidden_connectivities: tuple = (0.6, 0.3, 0.8)
    # Probability that an observed entry is zeroed out by _apply_sparsity.
    data_sparsity: float = 0.1
    # Despite the name, _add_noise passes this as the standard deviation of the noise.
    noise_variance: float = 0.1
    # Mean of the additive Gaussian noise.
    noise_mean: float = 0.0
    # When not None, seeds both numpy's global RNG and the stdlib `random` module.
    random_seed: Optional[int] = None

    def __post_init__(self):
        """Seed the global RNGs (if requested) and draw all fixed weight matrices.

        Draw order (relevant for seeded reproducibility): modality-1 data
        matrix, modality-2 data matrix, then the two shared-to-parameter
        matrices for modality 1 and modality 2.
        """
        seed = self.random_seed
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

        n_shared = self.n_shared_variables
        n_specific = self.n_hidden_variables
        densities = self.hidden_connectivities

        # Sparse maps from the concatenated latent space (shared + specific) to data space.
        self.mod1_hidden_to_data_matrix = self._create_sparse_matrix(
            n_shared + n_specific[0], self.data_dims[0], densities[1])
        self.mod2_hidden_to_data_matrix = self._create_sparse_matrix(
            n_shared + n_specific[1], self.data_dims[1], densities[2])

        # Dense Gaussian maps from shared variables to the parameters of the
        # modality-specific variables.
        self.shared_to_mod1_params_matrix = np.random.randn(n_shared, n_specific[0])
        self.shared_to_mod2_params_matrix = np.random.randn(n_shared, n_specific[1])

    def _create_sparse_matrix(self, rows: int, cols: int, density: float) -> np.ndarray:
        """Return a ``rows x cols`` Beta(1, 1) matrix with roughly ``density`` of its entries non-zero.

        An entry is zeroed when its own independent uniform draw is strictly
        below ``1 - density``.
        """
        shape = (rows, cols)
        weights = np.random.beta(1, 1, size=shape)
        zero_threshold = 1 - density
        dropped = np.random.rand(*shape) < zero_threshold
        return np.where(dropped, 0.0, weights)

    def _generate_independent_hidden_variables(self, hidden_dist_type: str, n_hidden: int) -> np.ndarray:
        """Draw ``(n_samples, n_hidden)`` variables that have no parents (used for the shared variables).

        Branches "binomial", "gaussian", "exponential" and "beta" also store a
        per-sample class index in ``self.classes``; "uniform" and the fallback
        do not touch it. Any other name silently falls back to standard-normal
        draws.
        """
        n_rows = self.n_samples
        shape = (n_rows, n_hidden)

        if hidden_dist_type == "binomial":
            # Independent trial count and success probability per column.
            trial_counts = np.random.randint(1, 10, size=n_hidden)
            success_probs = np.random.uniform(0, 1, size=n_hidden)
            draws = np.random.binomial(n=trial_counts, p=success_probs, size=shape)
            # The class of a sample is the index of its largest column.
            self.classes = np.argmax(draws, axis=1)
            return draws

        if hidden_dist_type == "gaussian":
            # Five classes, each with its own mean vector.
            self.classes = np.random.randint(0, 5, size=n_rows)
            draws = np.zeros(shape)
            class_means = [np.random.uniform(-2.0, 2.0, size=n_hidden) for _ in range(5)]
            for row, label in enumerate(self.classes):
                # NOTE(review): each row is scaled by its row index (so row 0 is
                # all zeros), not by its class - confirm this is intended.
                draws[row] = np.random.normal(loc=class_means[label], scale=1.0, size=n_hidden) * row
            return draws

        if hidden_dist_type == "exponential":
            # Five classes, each with its own rate vector.
            self.classes = np.random.randint(0, 5, size=n_rows)
            draws = np.zeros(shape)
            class_rates = [np.random.uniform(0, 1, size=n_hidden) for _ in range(5)]
            for row, label in enumerate(self.classes):
                # NOTE(review): scaled by (row index + 1), not by class - confirm.
                draws[row] = np.random.exponential(scale=1/class_rates[label], size=n_hidden) * (row + 1)
            return draws

        if hidden_dist_type == "uniform":
            # Per-column interval [low, low + width) with width in [0, 2).
            lows = np.random.uniform(-2.0, 2.0, size=n_hidden)
            highs = lows + np.random.uniform(0, 2.0, size=n_hidden)
            return np.random.uniform(low=lows, high=highs, size=shape)

        if hidden_dist_type == "beta":
            # Five classes; the class index sets alpha and shifts the values.
            self.classes = np.random.randint(0, 5, size=n_rows)
            draws = np.zeros(shape)
            for row, label in enumerate(self.classes):
                draws[row] = np.random.beta(a=label + 1, b=1, size=n_hidden) + label
            return draws

        # Fallback for simplicity: any other name yields standard-normal draws.
        return np.random.randn(n_rows, n_hidden)

    def _generate_dependent_hidden_variables(self, parent_vars: np.ndarray, hidden_dist_type: str, n_hidden: int, param_matrix: np.ndarray) -> np.ndarray:
        """
        Draw modality-specific variables whose distribution parameters depend on ``parent_vars``.

        The per-sample, per-variable base parameter is
        ``parent_vars @ param_matrix + 1``. Depending on ``hidden_dist_type`` it
        is used as the Gaussian mean, or (after taking the absolute value) as
        the Poisson rate, Weibull scale factor, beta alpha, gamma shape, uniform
        lower bound, exponential rate or Gumbel location. Secondary parameters
        (Gaussian/Gumbel scale, beta's second parameter, gamma scale) are one
        random positive value per variable. Unknown names fall back to a
        unit-variance Gaussian centred on the base parameter.
        """
        base_params = parent_vars @ param_matrix + 1
        shape = (self.n_samples, n_hidden)

        def positive_row():
            # One |N(0, 1)| value per hidden variable, broadcast over samples.
            return np.abs(np.random.standard_normal(size=(1, n_hidden)))

        if hidden_dist_type == "gaussian":
            spread = positive_row()
            return np.random.normal(loc=base_params, scale=spread)

        if hidden_dist_type == "poisson":
            return np.random.poisson(lam=np.abs(base_params))

        if hidden_dist_type == "weibull":
            # Fixed shape 1.5; the parent-driven value rescales the draw.
            stretch = np.abs(base_params)
            return np.random.weibull(a=1.5, size=shape) * stretch

        if hidden_dist_type == "beta":
            first_param = np.abs(base_params)
            second_param = positive_row()
            return np.random.beta(a=first_param, b=second_param, size=shape)

        if hidden_dist_type == "gamma":
            shape_param = np.abs(base_params)
            scale_param = positive_row()
            return np.random.gamma(shape=shape_param, scale=scale_param, size=shape)

        if hidden_dist_type == "uniform":
            # Unit-width interval starting at the parent-driven value.
            lower = np.abs(base_params)
            upper = lower + 1.0
            return np.random.uniform(low=lower, high=upper, size=shape)

        if hidden_dist_type == "exponential":
            rate = np.abs(base_params)
            return np.random.exponential(scale=1/rate, size=shape)

        if hidden_dist_type == "gumbel":
            location = np.abs(base_params)
            spread = positive_row()
            return np.random.gumbel(loc=location, scale=spread, size=shape)

        # Fallback for simplicity.
        return np.random.normal(loc=base_params, scale=1.0)


    def _apply_nonlinearity(self, data: np.ndarray) -> np.ndarray:
        """Transform ``data`` element-wise with the nonlinearity named by ``self.nonlinearity_type``.

        Supported names are "polynomial" (raises to ``self.polynomial_degree``),
        "exponential", "sigmoid", "tanh", "trigonometric" (sine) and "relu",
        matching the non-hierarchical simulator. Previously every
        non-polynomial name was silently replaced by squaring, so the configured
        type had no effect. Names outside this list still fall back to squaring
        rather than raising, to keep the lenient behaviour for unknown values.
        """
        kind = self.nonlinearity_type
        if kind == "polynomial":
            return np.power(data, self.polynomial_degree)
        if kind == "exponential":
            return np.exp(data)
        if kind == "sigmoid":
            return 1 / (1 + np.exp(-data))
        if kind == "tanh":
            return np.tanh(data)
        if kind == "trigonometric":
            return np.sin(data)
        if kind == "relu":
            return np.maximum(0, data)
        # Unknown name: keep the historical squaring fallback.
        return np.power(data, 2)

    def _add_noise(self, data: np.ndarray) -> np.ndarray:
        """Add Gaussian noise of the same shape as ``data``.

        ``self.noise_mean`` is the noise mean and ``self.noise_variance`` is
        passed as its standard deviation. When ``noise_variance`` is not
        positive, ``data`` is returned untouched.
        """
        if not self.noise_variance > 0:
            return data
        jitter = np.random.normal(
            loc=self.noise_mean,
            scale=self.noise_variance,
            size=data.shape,
        )
        return data + jitter

    def _apply_sparsity(self, data: np.ndarray) -> np.ndarray:
        """Randomly zero entries of ``data``; an entry survives when its uniform draw exceeds ``self.data_sparsity``."""
        uniform_draws = np.random.rand(*data.shape)
        survivors = uniform_draws > self.data_sparsity
        return data * survivors

    def generate_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Generates hierarchical multi-modal data.

        The shared variables are drawn first; each modality's hidden variables
        are then drawn with parameters that depend on the shared ones. All
        hidden variables go through the nonlinearity ``nonlinearity_level``
        times before being projected to data space, where noise and dropout are
        applied.

        Labels (one row per sample, three columns):

        * column 0: ``self.classes`` when that attribute exists; otherwise the
          index of each sample's row among the distinct rows of the transformed
          shared variables.
        * column 1 / column 2: row sums of the transformed modality-1 /
          modality-2 hidden variables.

        Diagnostic information about the classes is printed to stdout.

        Returns:
            Tuple ``(data_1, data_2, hidden_vars_shared, hidden_vars_mod1,
            hidden_vars_mod2, labels)``.
        """
        # 1. Generate the shared (parent) variables independently.
        hidden_vars_shared = self._generate_independent_hidden_variables(
            self.shared_hidden_dist_type, self.n_shared_variables)

        # 2. Generate modality-specific variables dependent on the shared variables.
        hidden_vars_mod1 = self._generate_dependent_hidden_variables(
            hidden_vars_shared, self.hidden_dist_types[0], self.n_hidden_variables[0], self.shared_to_mod1_params_matrix)

        hidden_vars_mod2 = self._generate_dependent_hidden_variables(
            hidden_vars_shared, self.hidden_dist_types[1], self.n_hidden_variables[1], self.shared_to_mod2_params_matrix)

        # Apply nonlinearity to all hidden variables.
        for _ in range(self.nonlinearity_level):
            hidden_vars_shared = self._apply_nonlinearity(hidden_vars_shared)
            hidden_vars_mod1 = self._apply_nonlinearity(hidden_vars_mod1)
            hidden_vars_mod2 = self._apply_nonlinearity(hidden_vars_mod2)

        # Create labels.
        labels = np.zeros((self.n_samples, 3))
        if not hasattr(self, 'classes'):
            # `row_to_unique` maps every sample to its distinct row; `counts`
            # holds how many samples share each distinct row.
            shared_unique, row_to_unique, counts = np.unique(
                hidden_vars_shared, axis=0, return_inverse=True, return_counts=True)
            print("Shared variable classes:", shared_unique)
            print("Counts of shared variable classes:", counts)
            # Flatten: some numpy versions return the inverse with an extra axis.
            labels[:, 0] = np.reshape(row_to_unique, -1)
        else:
            labels[:, 0] = self.classes
            print("Shared variable classes from self.classes:", np.unique(self.classes))
        labels[:, 1] = hidden_vars_mod1.sum(axis=1)
        labels[:, 2] = hidden_vars_mod2.sum(axis=1)

        # Concatenate and project to data space.
        X_1 = np.concatenate([hidden_vars_shared, hidden_vars_mod1], axis=1)
        X_2 = np.concatenate([hidden_vars_shared, hidden_vars_mod2], axis=1)

        data_1 = X_1 @ self.mod1_hidden_to_data_matrix
        data_2 = X_2 @ self.mod2_hidden_to_data_matrix

        # Add noise and sparsity.
        data_1 = self._add_noise(data_1)
        data_2 = self._add_noise(data_2)

        data_1 = self._apply_sparsity(data_1)
        data_2 = self._apply_sparsity(data_2)

        return data_1, data_2, hidden_vars_shared, hidden_vars_mod1, hidden_vars_mod2, labels

@dataclass
class GeometricTabularMMDataSimulator:
    """
    Generates multi-modal data where the shared latent space is sampled
    from distinct geometric manifolds (hypersphere and hypercube).

    The first ``n_samples // 2`` shared points (before shuffling) lie on the
    surface of a unit hypersphere (class 0); the remaining ones lie on the
    surface of the ``[-1, 1]`` hypercube (class 1). Modality-specific hidden
    variables are then drawn from distributions whose parameters depend on
    the shared point, and each modality's data is a linear projection of the
    concatenated ``[shared, modality-specific]`` hidden variables.
    """
    n_samples: int = 2000
    n_shared_variables: int = 2  # The dimension of the geometric space
    n_hidden_variables: tuple = (5, 5)  # (mod1, mod2) modality-specific hidden dims
    hidden_dist_types: tuple = ("poisson", "beta")  # (mod1, mod2) distribution names
    data_dims: tuple = (50, 100)  # (mod1, mod2) observed feature counts
    # NOTE: the three nonlinearity settings below are not read by any method
    # of this class; `_apply_nonlinearity` is currently the identity.
    nonlinearity_level: int = 1
    nonlinearity_type: Literal["polynomial", "relu"] = "polynomial"
    polynomial_degree: int = 3
    # Only indices 1 (mod1) and 2 (mod2) are read; index 0 is unused here.
    hidden_connectivities: tuple = (0.6, 0.3, 0.8)
    # NOTE: `data_sparsity` and `noise_variance` are not read by any method
    # of this class; `generate_data` returns the noiseless dense projection.
    data_sparsity: float = 0.1
    noise_variance: float = 0.1
    random_seed: Optional[int] = None

    def __post_init__(self):
        """Seed the global RNGs (if requested) and draw the fixed weight matrices."""
        if self.random_seed is not None:
            # Seeds the *global* numpy and stdlib generators.
            np.random.seed(self.random_seed)
            random.seed(self.random_seed)

        # Mapping concatenated latent -> data
        self.mod1_hidden_to_data_matrix = self._create_sparse_matrix(
            self.n_shared_variables + self.n_hidden_variables[0], self.data_dims[0], self.hidden_connectivities[1])
        self.mod2_hidden_to_data_matrix = self._create_sparse_matrix(
            self.n_shared_variables + self.n_hidden_variables[1], self.data_dims[1], self.hidden_connectivities[2])

        # Mapping shared -> parameters of specific
        self.shared_to_mod1_params_matrix = np.random.randn(self.n_shared_variables, self.n_hidden_variables[0])
        self.shared_to_mod2_params_matrix = np.random.randn(self.n_shared_variables, self.n_hidden_variables[1])

    def _create_sparse_matrix(self, rows: int, cols: int, density: float) -> np.ndarray:
        """Return a (rows, cols) standard-normal matrix with entries randomly zeroed.

        Each entry is independently set to zero with probability
        ``1 - density``, so ``density`` is the expected fraction of entries
        that keep their drawn value.
        """
        matrix = np.random.randn(rows, cols)
        mask = np.random.rand(rows, cols) < (1 - density)
        matrix[mask] = 0
        return matrix

    def _sample_from_hypersphere(self, n_points: int, n_dim: int) -> np.ndarray:
        """Samples points uniformly from the surface of an N-sphere."""
        # Normalising isotropic Gaussian vectors gives unit-length directions.
        vec = np.random.randn(n_points, n_dim)
        norm = np.linalg.norm(vec, axis=1, keepdims=True)
        return vec / norm

    def _sample_from_hypercube(self, n_points: int, n_dim: int) -> np.ndarray:
        """Samples points uniformly from the surface of an N-cube."""
        points = np.random.uniform(-1, 1, size=(n_points, n_dim))

        # For each point, choose one dimension to fix to the edge (-1 or 1)
        fixed_dim = np.random.randint(0, n_dim, size=n_points)
        fixed_val = np.random.choice([-1, 1], size=n_points)

        points[np.arange(n_points), fixed_dim] = fixed_val
        return points

    def _generate_geometric_shared_variables(self) -> Tuple[np.ndarray, np.ndarray]:
        """Generates the shared variables by sampling from the two shapes.

        Returns:
            A pair ``(hidden_vars, labels)``: an array of shape
            ``(n_samples, n_shared_variables)`` and an integer array of
            length ``n_samples`` with 0 for sphere points and 1 for cube
            points. Both are shuffled with the same permutation.
        """
        # Split samples between the two classes
        n_class_0 = self.n_samples // 2
        n_class_1 = self.n_samples - n_class_0

        # Generate points for each class
        points_sphere = self._sample_from_hypersphere(n_class_0, self.n_shared_variables)
        points_cube = self._sample_from_hypercube(n_class_1, self.n_shared_variables)

        # Create labels and concatenate
        labels = np.array([0] * n_class_0 + [1] * n_class_1)
        hidden_vars = np.vstack([points_sphere, points_cube])

        # Shuffle the data and labels together
        p = np.random.permutation(self.n_samples)
        return hidden_vars[p], labels[p]

    def _generate_dependent_hidden_variables(self, parent_vars: np.ndarray, hidden_dist_type: str, n_hidden: int, param_matrix: np.ndarray) -> np.ndarray:
        """Generates modality-specific variables hierarchically.

        The parent variables are mapped through ``param_matrix`` and squashed
        with ``tanh`` into the interval [1, 5]; the result parameterises the
        requested distribution element-wise.

        Args:
            parent_vars: Array with one row per sample and as many columns
                as ``param_matrix`` has rows.
            hidden_dist_type: One of "gaussian", "poisson", "weibull",
                "beta", "gamma", "uniform", "exponential", "gumbel". Any
                other value falls back to a unit-variance Gaussian.
            n_hidden: Number of hidden variables to draw per sample; expected
                to equal the number of columns of ``param_matrix``.
            param_matrix: Linear map from parent variables to parameters.

        Returns:
            Array of shape ``(parent_vars.shape[0], n_hidden)``.
        """
        # We can add a non-linear transform here to make the dependency more complex
        params = np.tanh(parent_vars @ param_matrix) * 2 + 3 # Center around 3

        # The number of rows follows the parent array rather than
        # self.n_samples, so every branch accepts the same inputs (the
        # gaussian/poisson/fallback branches already behaved this way
        # through broadcasting).
        out_shape = (parent_vars.shape[0], n_hidden)

        if hidden_dist_type == "gaussian":
            # Parent vars control the mean
            # Ensure scale is positive
            scales = np.abs(np.random.standard_normal(size=(1, n_hidden)))
            return np.random.normal(loc=params, scale=scales)
        elif hidden_dist_type == "poisson":
            # Parent vars control the rate (lambda)
            # Ensure lambda is positive
            lam = np.abs(params)
            return np.random.poisson(lam=lam)
        elif hidden_dist_type == "weibull":
            # Parent vars control the scale (lambda)
            # Ensure lambda is positive
            scales = np.abs(params)
            return np.random.weibull(a=1.5, size=out_shape) * scales
        elif hidden_dist_type == "beta":
            # Parent vars control alpha; beta is drawn once per hidden variable
            alpha = np.abs(params)
            beta = np.abs(np.random.standard_normal(size=(1, n_hidden)))
            return np.random.beta(a=alpha, b=beta, size=out_shape)
        elif hidden_dist_type == "gamma":
            # Parent vars control the shape (k); scale (theta) is drawn once per hidden variable
            k = np.abs(params)
            theta = np.abs(np.random.standard_normal(size=(1, n_hidden)))
            return np.random.gamma(shape=k, scale=theta, size=out_shape)
        elif hidden_dist_type == "uniform":
            # Parent vars control the lower bound; the interval has unit width
            low = np.abs(params)
            high = low + 1.0
            return np.random.uniform(low=low, high=high, size=out_shape)
        elif hidden_dist_type == "exponential":
            # Parent vars control the rate (lambda)
            lam = np.abs(params)
            return np.random.exponential(scale=1/lam, size=out_shape)
        elif hidden_dist_type == "gumbel":
            # Parent vars control the location (mu); scale (beta) is drawn once per hidden variable
            mu = np.abs(params)
            beta = np.abs(np.random.standard_normal(size=(1, n_hidden)))
            return np.random.gumbel(loc=mu, scale=beta, size=out_shape)
        else:
            # Fallback for simplicity
            return np.random.normal(loc=params, scale=1.0)

    def _apply_nonlinearity(self, data: np.ndarray) -> np.ndarray:
        """Return ``data`` unchanged (identity placeholder)."""
        # A simple non-linearity is sufficient given the complex latent space
        return data

    def generate_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Simulate one dataset for both modalities.

        Returns:
            ``(data_1, data_2, hidden_vars_shared, hidden_vars_mod1,
            hidden_vars_mod2, labels)``. ``labels`` has three columns: the
            geometric class (0 = sphere, 1 = cube), the row sum of the
            modality-1 hidden variables, and the row sum of the modality-2
            hidden variables.
        """
        # 1. Generate the geometrically structured shared variables and their class labels
        hidden_vars_shared, class_labels = self._generate_geometric_shared_variables()

        # 2. Generate modality-specific variables dependent on the shared variables
        hidden_vars_mod1 = self._generate_dependent_hidden_variables(
            hidden_vars_shared, self.hidden_dist_types[0], self.n_hidden_variables[0], self.shared_to_mod1_params_matrix)
        hidden_vars_mod2 = self._generate_dependent_hidden_variables(
            hidden_vars_shared, self.hidden_dist_types[1], self.n_hidden_variables[1], self.shared_to_mod2_params_matrix)

        # Create labels array
        labels = np.zeros((self.n_samples, 3))
        labels[:, 0] = class_labels
        labels[:, 1] = hidden_vars_mod1.sum(axis=1)
        labels[:, 2] = hidden_vars_mod2.sum(axis=1)

        # Concatenate and project to data space
        X_1 = np.concatenate([hidden_vars_shared, hidden_vars_mod1], axis=1)
        X_2 = np.concatenate([hidden_vars_shared, hidden_vars_mod2], axis=1)

        data_1 = X_1 @ self.mod1_hidden_to_data_matrix
        data_2 = X_2 @ self.mod2_hidden_to_data_matrix

        # Noise and sparsity can be added here if needed

        return data_1, data_2, hidden_vars_shared, hidden_vars_mod1, hidden_vars_mod2, labels