from dataclasses import dataclass, field
from typing import List, Tuple, Union, Optional, Literal
import numpy as np
from scipy import stats
from scipy.ndimage import gaussian_filter
import os
import urllib.request
import gzip
import pickle
from pathlib import Path
from sklearn.datasets import make_moons, make_circles


@dataclass
class TabularDataSimulator:
    """Simulate tabular data generated from a small set of hidden variables.

    ``generate_data`` runs the following pipeline: sample hidden variables,
    apply the selected nonlinearity ``nonlinearity_level`` times, project the
    result through ``hidden_to_data_matrix``, add noise and finally zero out
    entries at rate ``data_sparsity``.

    All randomness comes from NumPy's global generator; when ``random_seed``
    is given, ``__post_init__`` seeds it with ``np.random.seed``.
    """
    n_samples: int = 1000
    n_hidden_variables: int = 5
    hidden_dist_type: Literal[
        "gaussian", "poisson", "binomial", "exponential", "beta", "uniform",
        "gumbel", "weibull", "circles", "halfmoons",
    ] = "gaussian"
    data_dim: int = 50
    nonlinearity_level: int = 1  # Number of transformations
    nonlinearity_type: Literal[
        "exponential", "polynomial", "sigmoid", "tanh", "trigonometric",
        "relu", "interaction",
    ] = "polynomial"
    polynomial_degree: int = 2  # For polynomial nonlinearity
    hidden_connectivity: float = 0.7  # Density of connections (0-1)
    data_sparsity: float = 0.1  # Dropout rate
    noise_variance: float = 0.1
    noise_mean: float = 0.0
    n_noise_components: int = 1  # Number of Gaussian mixture components
    noise_dimension: int = 1  # Dimension for multivariate Gaussian noise
    random_seed: Optional[int] = None

    @staticmethod
    def _unwrap_scalar(value):
        """Return the first element of a list/tuple, or ``value`` unchanged."""
        return value[0] if isinstance(value, (list, tuple)) else value

    @staticmethod
    def _valid_correlation(raw: np.ndarray) -> np.ndarray:
        """Symmetrise ``raw``, set a unit diagonal and shift it if an eigenvalue is negative."""
        corr = (raw + raw.T) / 2
        np.fill_diagonal(corr, 1.0)
        # The matrix is symmetric, so eigvalsh applies and always returns real values.
        min_eig = np.min(np.linalg.eigvalsh(corr))
        if min_eig < 0:
            corr += (-min_eig + 0.01) * np.eye(corr.shape[0])
        return corr

    def __post_init__(self):
        """Normalise options, seed NumPy and pre-draw projection and noise parameters."""
        # These two fields used to default to one-element tuples; keep accepting
        # list/tuple input and unwrap it so the integer comparisons below work.
        self.n_noise_components = self._unwrap_scalar(self.n_noise_components)
        self.noise_dimension = self._unwrap_scalar(self.noise_dimension)

        if self.random_seed is not None:
            np.random.seed(self.random_seed)

        # Projection from hidden variables to observed columns.
        self.hidden_to_data_matrix = self._create_sparse_matrix(
            self.n_hidden_variables, self.data_dim, self.hidden_connectivity)

        dim = self.noise_dimension
        if self.n_noise_components > 1:
            self.noise_component_variances = np.random.uniform(
                0.1 * self.noise_variance,
                2.0 * self.noise_variance,
                size=self.n_noise_components,
            )
            # Component means are spread evenly around the configured mean magnitude.
            self.noise_component_means = np.linspace(
                -1.0 * abs(self.noise_mean) - 0.1,
                2.0 * abs(self.noise_mean) + 0.1,
                num=self.n_noise_components,
            )
            self.noise_component_weights = (
                np.ones(self.n_noise_components) / self.n_noise_components
            )

            if dim > 1:
                # One covariance matrix per mixture component.
                self.noise_component_covariances = []
                for comp in range(self.n_noise_components):
                    corr = self._valid_correlation(
                        np.random.uniform(-0.5, 0.5, size=(dim, dim)))
                    variance_vector = np.random.uniform(
                        0.5 * self.noise_component_variances[comp],
                        1.5 * self.noise_component_variances[comp],
                        size=dim,
                    )
                    std_dev = np.sqrt(variance_vector)
                    self.noise_component_covariances.append(
                        np.outer(std_dev, std_dev) * corr)
        elif dim > 1:
            # Single-component noise with several noise dimensions.
            self.noise_mean_vector = np.linspace(
                -1.0 * abs(self.noise_mean) - 0.1,
                2.0 * abs(self.noise_mean) + 0.1,
                num=dim,
            )
            corr = self._valid_correlation(
                np.random.uniform(
                    self.noise_variance * 0.9,
                    self.noise_variance * 1.1,
                    size=(dim, dim),
                )
            )
            std_dev = np.sqrt(np.ones(dim) * self.noise_variance)
            self.noise_covariance = np.outer(std_dev, std_dev) * corr

    def _create_sparse_matrix(self, rows: int, cols: int, density: float) -> np.ndarray:
        """Return a ``rows x cols`` matrix of values in [0, 1] with entries zeroed at random.

        Each entry is kept when an independent uniform draw is ``<= density``
        and set to zero otherwise.
        """
        matrix = np.random.beta(1, 1, size=(rows, cols))
        dropped = np.random.rand(rows, cols) > density
        matrix[dropped] = 0
        return matrix

    def _sample_columns(self, draw) -> np.ndarray:
        """Fill an ``(n_samples, n_hidden_variables)`` array column by column using ``draw(i)``."""
        out = np.zeros((self.n_samples, self.n_hidden_variables))
        for i in range(self.n_hidden_variables):
            out[:, i] = draw(i)
        return out

    def _generate_hidden_variables(self) -> np.ndarray:
        """Sample the hidden variables for ``hidden_dist_type``.

        Per-variable distribution parameters are drawn first, then each column
        is sampled in order.

        Returns:
            Array of shape ``(n_samples, n_hidden_variables)``.

        Raises:
            ValueError: for an unknown ``hidden_dist_type``, or when
                ``"circles"``/``"halfmoons"`` is requested with
                ``n_hidden_variables != 2``.
        """
        kind = self.hidden_dist_type
        n, k = self.n_samples, self.n_hidden_variables

        if kind == "gaussian":
            means = np.random.uniform(0.0, 1.0, size=k)
            # Unit standard deviation, shifted by +1.0.
            return self._sample_columns(
                lambda i: np.random.normal(loc=means[i], scale=1.0, size=n) + 1.0)
        if kind == "poisson":
            lambdas = np.random.uniform(0.5, 2.0, size=k)
            return self._sample_columns(
                lambda i: np.random.poisson(lam=lambdas[i], size=n))
        if kind == "binomial":
            ns = np.random.randint(1, 10, size=k)
            ps = np.random.uniform(0.1, 0.9, size=k)
            return self._sample_columns(
                lambda i: np.random.binomial(n=ns[i], p=ps[i], size=n))
        if kind == "exponential":
            scales = np.random.uniform(0.5, 2.0, size=k)
            return self._sample_columns(
                lambda i: np.random.exponential(scale=scales[i], size=n))
        if kind == "beta":
            alphas = np.random.uniform(0.1, 5.0, size=k)
            # The second shape parameter is fixed at 0.5.
            return self._sample_columns(
                lambda i: np.random.beta(a=alphas[i], b=0.5, size=n))
        if kind == "uniform":
            return np.random.uniform(0, 1, size=(n, k))
        if kind == "gumbel":
            locs = np.random.uniform(0.0, 10.0, size=k)
            scales = np.random.uniform(0.5, 1.0, size=k)
            return self._sample_columns(
                lambda i: np.random.gumbel(loc=locs[i], scale=scales[i], size=n))
        if kind == "weibull":
            # These values are passed as the Weibull shape parameter ``a``.
            shapes = np.random.uniform(0.5, 2.0, size=k)
            return self._sample_columns(
                lambda i: np.random.weibull(a=shapes[i], size=n) * 0.3)
        if kind == "circles":
            if k != 2:
                raise ValueError("For 'circles', n_hidden_variables must be 2.")
            points, _ = make_circles(
                n_samples=n, noise=0.1, factor=0.5, random_state=self.random_seed)
            return points
        if kind == "halfmoons":
            if k != 2:
                raise ValueError("For 'halfmoons', n_hidden_variables must be 2.")
            points, _ = make_moons(
                n_samples=n, noise=0.1, random_state=self.random_seed)
            return points
        raise ValueError(f"Unknown distribution type: {self.hidden_dist_type}")

    def _apply_nonlinearity(self, data: np.ndarray) -> np.ndarray:
        """Apply the element-wise (or ``"interaction"``) transform named by ``nonlinearity_type``.

        Raises:
            ValueError: for an unknown ``nonlinearity_type``.
        """
        kind = self.nonlinearity_type
        if kind == "exponential":
            return np.exp(data)
        if kind == "polynomial":
            return np.power(data, self.polynomial_degree)
        if kind == "sigmoid":
            return 1 / (1 + np.exp(-data))
        if kind == "tanh":
            return np.tanh(data)
        if kind == "trigonometric":
            return np.sin(data)
        if kind == "relu":
            return np.maximum(0, data)
        if kind == "interaction":
            # A fresh binary mixing matrix is drawn on every call.
            interaction_matrix = np.random.binomial(
                1,
                self.hidden_connectivity,
                size=(self.n_hidden_variables, self.n_hidden_variables),
            )
            return data @ interaction_matrix
        raise ValueError(f"Unknown nonlinearity type: {self.nonlinearity_type}")

    def _mixture_noise(self, shape) -> np.ndarray:
        """Draw element-wise noise from the univariate Gaussian mixture."""
        total = int(np.prod(shape))
        # One mixture component id per element of the output.
        component_ids = np.random.choice(
            self.n_noise_components, size=total, p=self.noise_component_weights)
        flat_noise = np.zeros(total)
        for comp_id in range(self.n_noise_components):
            chosen = component_ids == comp_id
            # NOTE(review): the stored "variance" is passed as the scale
            # (standard deviation) argument of np.random.normal.
            flat_noise[chosen] = np.random.normal(
                self.noise_component_means[comp_id],
                self.noise_component_variances[comp_id],
                size=np.sum(chosen),
            )
        return flat_noise.reshape(shape)

    def _multivariate_noise(self, data: np.ndarray) -> np.ndarray:
        """Draw noise for ``noise_dimension > 1`` by summing one draw per noise dimension.

        NOTE(review): only the diagonal of each covariance matrix is used, and
        it is passed as the scale (standard deviation) of np.random.normal.
        """
        n_rows, n_features = data.shape[0], data.shape[1]
        noise = np.zeros_like(data)

        if self.n_noise_components > 1:
            # One mixture component per sample (row), drawn once up front.
            component_ids = np.random.choice(
                self.n_noise_components,
                size=n_rows,
                p=self.noise_component_weights,
            )
            for row, comp in enumerate(component_ids):
                covariance = self.noise_component_covariances[comp]
                for d in range(self.noise_dimension):
                    noise[row, :] += np.random.normal(
                        self.noise_component_means[comp],
                        covariance[d, d],
                        size=n_features,
                    )
        else:
            for d in range(self.noise_dimension):
                noise += np.random.normal(
                    self.noise_mean_vector[d],
                    self.noise_covariance[d, d],
                    size=(n_rows, n_features),
                )
        return noise

    def _add_noise(self, data: np.ndarray) -> np.ndarray:
        """Return ``data`` plus Gaussian (or Gaussian-mixture) noise.

        ``data`` is returned unchanged when ``noise_variance`` is zero.
        """
        if self.noise_variance == 0:
            return data

        if self.noise_dimension > 1:
            noise = self._multivariate_noise(data)
        elif self.n_noise_components == 1:
            # NOTE(review): noise_variance is passed as the scale (standard
            # deviation) argument of np.random.normal.
            noise = np.random.normal(
                self.noise_mean, self.noise_variance, size=data.shape)
        else:
            noise = self._mixture_noise(data.shape)

        return data + noise

    def _apply_sparsity(self, data: np.ndarray) -> np.ndarray:
        """Zero out each entry independently with probability ``data_sparsity``."""
        keep = np.random.rand(*data.shape) > self.data_sparsity
        return data * keep

    def generate_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Generate tabular data based on hidden variables.

        Returns:
            Tuple ``(observed_data, transformed_hidden, original_hidden)``:
            the noisy, sparsified observations; the hidden variables after
            the nonlinearities; and the hidden variables as sampled.
        """
        hidden_vars = self._generate_hidden_variables()
        hidden_vars_original = hidden_vars.copy()

        # Apply the nonlinearity repeatedly if requested.
        for _ in range(self.nonlinearity_level):
            hidden_vars = self._apply_nonlinearity(hidden_vars)

        # Project the transformed hidden variables to the observed space.
        data = hidden_vars @ self.hidden_to_data_matrix
        data = self._add_noise(data)
        data = self._apply_sparsity(data)

        return data, hidden_vars, hidden_vars_original


@dataclass
class ImageDataSimulator:
    """Simulate image data, either sampled from dSprites or rendered as textures.

    Images are handled as flat rows of length
    ``image_height * image_width * n_channels``; whenever they are reshaped
    the layout is ``(n_samples, height, width, channels)``.

    All randomness comes from NumPy's global generator; when ``random_seed``
    is given, ``__post_init__`` seeds it with ``np.random.seed``.
    """
    n_samples: int = 500
    n_hidden_variables: int = 10
    hidden_dist_type: Literal["gaussian", "uniform"] = "gaussian"
    image_height: int = 32
    image_width: int = 32
    n_channels: int = 1  # 1 for grayscale, 3 for RGB
    nonlinearity_level: int = 1
    nonlinearity_type: Literal["sigmoid", "tanh", "relu"] = "sigmoid"
    hidden_connectivity: float = 0.7
    data_sparsity: float = 0.1
    noise_variance: float = 0.05
    noise_mean: float = 0.0
    n_noise_components: int = 1
    # Image-specific parameters
    pattern_type: Literal["shapes", "textures"] = "shapes"  # Removed "frequency"
    pattern_complexity: int = 3  # Number of shapes/patterns to include
    pattern_size_variation: float = 0.3  # How much pattern sizes can vary
    pattern_rotation: bool = True  # Whether patterns can be rotated
    # Textures parameters
    geometric_function: Literal["waves", "spirals", "voronoi"] = "waves"
    texture_scale: float = 1.0  # Scale factor for texture patterns
    # Smoothing parameters
    smoothing_strength: float = 0.5  # 0 = no smoothing, 1 = maximum smoothing
    kernel_size: int = 3  # Size of the smoothing kernel (3x3, 5x5, etc.)
    # dsprites parameters
    dsprites_path: Optional[str] = None  # Path to dsprites dataset
    use_color: bool = False  # Whether to colorize dsprites (original is b&w)
    random_seed: Optional[int] = None

    def __post_init__(self):
        """Seed NumPy, pre-draw noise parameters and load dSprites when needed."""
        if self.random_seed is not None:
            np.random.seed(self.random_seed)

        self.data_dim = self.image_height * self.image_width * self.n_channels

        # Mapping from hidden variables to flattened pixels.
        self.hidden_to_data_matrix = self._create_sparse_matrix(
            self.n_hidden_variables, self.data_dim, self.hidden_connectivity)

        if self.n_noise_components > 1:
            self.noise_component_variances = np.random.uniform(
                0.1 * self.noise_variance,
                self.noise_variance,
                size=self.n_noise_components,
            )
            # Each component mean is a small random offset from noise_mean.
            self.noise_component_means = self.noise_mean + np.random.uniform(
                -0.1 * abs(self.noise_mean) - 0.01,
                0.1 * abs(self.noise_mean) + 0.01,
                size=self.n_noise_components,
            )
            self.noise_component_weights = (
                np.ones(self.n_noise_components) / self.n_noise_components
            )

        # The "shapes" pattern samples images straight from dSprites.
        if self.pattern_type == "shapes":
            self.dsprites_data, self.dsprites_latents = self._load_dsprites_dataset()

    def _load_dsprites_dataset(self) -> Tuple[np.ndarray, np.ndarray]:
        """Load (downloading first if missing) the dSprites archive.

        When ``dsprites_path`` is ``None`` it is set to a file in the user's
        home directory. Images are resized with scikit-image when the
        configured size is not 64x64.

        Returns:
            Tuple ``(images, latent_values)`` read from the archive's
            ``imgs`` and ``latents_values`` entries.
        """
        if self.dsprites_path is None:
            self.dsprites_path = os.path.join(
                os.path.expanduser("~"),
                "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz",
            )

        if not os.path.exists(self.dsprites_path):
            print("Downloading dsprites dataset...")
            url = "https://github.com/deepmind/dsprites-dataset/raw/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"
            urllib.request.urlretrieve(url, self.dsprites_path)

        # NOTE(review): allow_pickle=True lets a tampered archive execute code
        # when a pickled entry is read. Only 'imgs' and 'latents_values' are
        # accessed here; consider allow_pickle=False if those are plain arrays.
        # The context manager closes the underlying zip file once both arrays
        # have been read into memory.
        with np.load(self.dsprites_path, allow_pickle=True) as dataset:
            data = dataset['imgs']
            latents_values = dataset['latents_values']

        if self.image_height != 64 or self.image_width != 64:
            from skimage.transform import resize
            target = (self.image_height, self.image_width)
            resized_data = np.zeros((data.shape[0], self.image_height, self.image_width))
            for i in range(data.shape[0]):
                resized_data[i] = resize(
                    data[i], target, anti_aliasing=True, preserve_range=True)
            data = resized_data

        return data, latents_values

    def _create_sparse_matrix(self, rows: int, cols: int, density: float) -> np.ndarray:
        """Return a ``rows x cols`` matrix of ones.

        ``density`` is accepted for signature compatibility but is currently
        ignored: the masking step is disabled, so the result is fully dense.
        """
        return np.ones((rows, cols))

    def _generate_hidden_variables(self) -> np.ndarray:
        """Sample ``(n_samples, n_hidden_variables)`` hidden variables.

        Raises:
            ValueError: for an unknown ``hidden_dist_type``.
        """
        shape = (self.n_samples, self.n_hidden_variables)
        if self.hidden_dist_type == "gaussian":
            return np.random.randn(*shape)
        if self.hidden_dist_type == "uniform":
            return np.random.uniform(-1, 1, size=shape)
        raise ValueError(f"Unknown distribution type: {self.hidden_dist_type}")

    def _apply_nonlinearity(self, data: np.ndarray) -> np.ndarray:
        """Apply the element-wise nonlinearity named by ``nonlinearity_type``.

        Raises:
            ValueError: for an unknown ``nonlinearity_type``.
        """
        if self.nonlinearity_type == "sigmoid":
            return 1 / (1 + np.exp(-data))
        if self.nonlinearity_type == "tanh":
            return np.tanh(data)
        if self.nonlinearity_type == "relu":
            return np.maximum(0, data)
        raise ValueError(f"Unknown nonlinearity type: {self.nonlinearity_type}")

    def _render_texture(self, scale, phase, complexity: int,
                        xx: np.ndarray, yy: np.ndarray) -> np.ndarray:
        """Render one ``(image_height, image_width)`` texture for ``geometric_function``.

        Raises:
            ValueError: for an unknown ``geometric_function``.
        """
        shape = (self.image_height, self.image_width)

        if self.geometric_function == "waves":
            # Superposition of sine waves with random direction and frequency.
            pattern = np.zeros(shape)
            for _ in range(complexity):
                theta = np.random.uniform(0, 2 * np.pi)
                # Frequency is clamped from below at 5.
                freq = max(np.random.uniform(5, 10) * scale, 5)
                direction = xx * np.cos(theta) + yy * np.sin(theta)
                pattern += np.sin(freq * direction + phase) / complexity
            return (pattern + 1) / 2  # map [-1, 1] to [0, 1]

        if self.geometric_function == "spirals":
            r = np.sqrt(xx**2 + yy**2)
            phi = np.arctan2(yy, xx)
            pattern = np.zeros(shape)
            for _ in range(complexity):
                spiral_freq = np.random.uniform(1, 5) * scale
                pattern += np.sin(spiral_freq * (r * 5 + phi) + phase) / complexity
            return (pattern + 1) / 2  # map [-1, 1] to [0, 1]

        if self.geometric_function == "voronoi":
            pattern = np.ones(shape)
            points = np.random.uniform(-1, 1, (5 * complexity, 2))
            # Distance from every pixel to every seed point: (H, W, n_points).
            pixels = np.stack([xx, yy], axis=-1)
            distances = np.sqrt(
                np.sum((pixels[:, :, np.newaxis, :] - points) ** 2, axis=-1))
            # After partitioning at index 1, the first two entries along the
            # last axis are the smallest and second-smallest distances.
            nearest = np.partition(distances, 1, axis=-1)
            is_edge = np.abs(nearest[..., 0] - nearest[..., 1]) < 0.05 / scale
            pattern[is_edge] = 0
            return pattern

        raise ValueError(f"Unknown geometric function: {self.geometric_function}")

    def _add_pattern_components(self, images: np.ndarray) -> np.ndarray:
        """Overwrite each image with a dSprites sprite or a rendered texture.

        The first few entries of each row of ``images`` parameterise the
        pattern (sprite selection, or texture scale/phase/complexity); each
        is only read when ``n_hidden_variables`` exceeds its index.

        Args:
            images: array reshapeable to
                ``(n_samples, image_height, image_width, n_channels)``;
                it may be modified in place when the reshape yields a view.

        Returns:
            Array of shape ``(n_samples, -1)`` clipped to [0, 1].

        Raises:
            ValueError: for an unknown ``geometric_function`` when
                ``pattern_type`` is ``"textures"``.
        """
        images_reshaped = images.reshape(
            self.n_samples, self.image_height, self.image_width, self.n_channels)

        if self.pattern_type == "shapes":
            for i in range(self.n_samples):
                def to_index(col, factor, upper):
                    # Columns beyond n_hidden_variables contribute index 0.
                    raw = (images[i, col] + 1) * factor if self.n_hidden_variables > col else 0.0
                    return int(np.clip(raw, 0, upper))

                shape_idx = to_index(0, 1.5, 2.99)   # 0=square, 1=ellipse, 2=heart
                scale_idx = to_index(1, 3, 5.99)     # 6 scales
                orient_idx = to_index(2, 20, 39.99)  # 40 orientations
                pos_x_idx = to_index(3, 16, 31.99)   # 32 x positions
                pos_y_idx = to_index(4, 16, 31.99)   # 32 y positions

                # Closest latent row in L1 distance to the requested combination.
                target = [1.0, shape_idx/2, scale_idx/5, orient_idx/39,
                          pos_x_idx/31, pos_y_idx/31]
                latent_idx = np.argmin(
                    np.sum(np.abs(self.dsprites_latents - target), axis=1))
                sprite = self.dsprites_data[latent_idx]

                for c in range(self.n_channels):
                    if self.use_color and self.n_channels > 1:
                        # Random per-channel intensity when colourising.
                        color_factor = np.random.uniform(0.5, 1.0, size=self.n_channels)
                        images_reshaped[i, :, :, c] = sprite * color_factor[c]
                    else:
                        images_reshaped[i, :, :, c] = sprite

        elif self.pattern_type == "textures":
            # The coordinate grid is identical for every sample.
            xx, yy = np.meshgrid(
                np.linspace(-1, 1, self.image_width),
                np.linspace(-1, 1, self.image_height),
            )
            for i in range(self.n_samples):
                scale = self.texture_scale
                if self.n_hidden_variables > 0:
                    scale *= (1.0 + 0.5 * images[i, 0])

                phase = 0.0
                if self.n_hidden_variables > 1:
                    phase = np.pi * images[i, 1]

                complexity = self.pattern_complexity
                if self.n_hidden_variables > 2:
                    complexity = max(
                        1, int(self.pattern_complexity * (1.0 + 0.5 * images[i, 2])))

                pattern = self._render_texture(scale, phase, complexity, xx, yy)

                for c in range(self.n_channels):
                    if self.n_channels > 1:
                        # Random per-channel intensity for multi-channel images.
                        color_factor = np.random.uniform(0.5, 1.0)
                        images_reshaped[i, :, :, c] = pattern * color_factor
                    else:
                        images_reshaped[i, :, :, c] = pattern

        images_reshaped = np.clip(images_reshaped, 0, 1)
        return images_reshaped.reshape(self.n_samples, -1)

    def _apply_smoothing(self, data: np.ndarray) -> np.ndarray:
        """Gaussian-blur every image channel and return flattened rows.

        ``sigma`` is ``smoothing_strength * kernel_size / 6`` and the kernel is
        cut off at a radius of ``(kernel_size - 1) / 2`` pixels. A non-positive
        ``sigma`` leaves the images unchanged.
        """
        images = np.asarray(data, dtype=float).reshape(
            self.n_samples, self.image_height, self.image_width, self.n_channels)

        sigma = self.smoothing_strength * (self.kernel_size / 6.0)
        if sigma <= 0:
            return images.reshape(self.n_samples, -1).copy()

        # gaussian_filter's ``truncate`` is expressed in standard deviations,
        # so convert the desired pixel radius into units of sigma.
        radius = (self.kernel_size - 1) / 2
        truncate = radius / sigma

        smoothed = np.zeros_like(images)
        for i in range(self.n_samples):
            for c in range(self.n_channels):
                smoothed[i, :, :, c] = gaussian_filter(
                    images[i, :, :, c], sigma=sigma, truncate=truncate)

        return smoothed.reshape(self.n_samples, -1)

    def _add_noise(self, data: np.ndarray) -> np.ndarray:
        """Add Gaussian (or Gaussian-mixture) noise and clip the result to [0, 1].

        ``data`` is returned unchanged (and unclipped) when ``noise_variance``
        is zero.

        NOTE(review): the "variance" values are passed as the scale (standard
        deviation) argument of np.random.normal.
        """
        if self.noise_variance == 0:
            return data

        if self.n_noise_components == 1:
            noise = np.random.normal(
                self.noise_mean, self.noise_variance, size=data.shape)
        else:
            total = int(np.prod(data.shape))
            # One mixture component id per pixel value.
            component_ids = np.random.choice(
                self.n_noise_components, size=total, p=self.noise_component_weights)
            flat_noise = np.zeros(total)
            for comp_id in range(self.n_noise_components):
                chosen = component_ids == comp_id
                flat_noise[chosen] = np.random.normal(
                    self.noise_component_means[comp_id],
                    self.noise_component_variances[comp_id],
                    size=np.sum(chosen),
                )
            noise = flat_noise.reshape(data.shape)

        return np.clip(data + noise, 0, 1)  # Keep in valid image range

    def _apply_sparsity(self, data: np.ndarray) -> np.ndarray:
        """Zero out each pixel value independently with probability ``data_sparsity``."""
        keep = np.random.rand(*data.shape) > self.data_sparsity
        return data * keep

    def generate_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Generate image data and the variables that produced it.

        For ``pattern_type == "shapes"`` random dSprites images are drawn and
        their latent values are returned as the hidden variables. Otherwise
        hidden variables are sampled, projected, passed through the
        nonlinearity and used to parameterise the rendered pattern.
        Smoothing (if enabled), sparsity and noise are then applied in that
        order.

        Returns:
            Tuple ``(image_data, hidden_variables)`` where ``image_data`` has
            shape ``(n_samples, image_height * image_width * n_channels)``.
        """
        if self.pattern_type == "shapes":
            data_idx = np.random.randint(
                0, self.dsprites_data.shape[0], size=self.n_samples)
            # Fancy indexing already copies the selected rows; cast to float so
            # integer-typed sprites are not truncated by later processing.
            sprites = self.dsprites_data[data_idx].astype(float)
            hidden_vars = self.dsprites_latents[data_idx, :]
            # Replicate the single-channel sprite across channels and flatten,
            # matching the layout used by the other pattern types.
            data = np.repeat(
                sprites[..., np.newaxis], self.n_channels, axis=-1
            ).reshape(self.n_samples, -1)
        else:
            hidden_vars = self._generate_hidden_variables()
            data = hidden_vars @ self.hidden_to_data_matrix
            for _ in range(self.nonlinearity_level):
                data = self._apply_nonlinearity(data)
            data = self._add_pattern_components(data)

        if self.smoothing_strength > 0:
            data = self._apply_smoothing(data)

        data = self._apply_sparsity(data)
        data = self._add_noise(data)

        return data, hidden_vars


@dataclass
class TimeSeriesDataSimulator:
    """Class for simulating time series data with configurable hidden variables.

    ``generate_data`` draws hidden variables, projects them through a sparse
    random matrix to one base feature per channel, applies the configured
    nonlinearity, and expands every base feature into a series made of a
    linear trend, sinusoidal seasonality, random level shifts and an
    autoregressive component. Dropout and Gaussian noise are applied last.

    All randomness comes from NumPy's global ``np.random`` state; when
    ``random_seed`` is given, ``__post_init__`` reseeds that global state.
    """
    n_samples: int = 100  # Number of independent series
    n_hidden_variables: int = 5
    hidden_dist_type: Literal["gaussian", "uniform"] = "gaussian"
    sequence_length: int = 200  # Time steps per series
    data_dim: int = 10  # Number of channels/variables in the time series
    nonlinearity_level: int = 1  # How many times the nonlinearity is applied
    nonlinearity_type: Literal["sigmoid", "tanh", "exponential"] = "tanh"
    hidden_connectivity: float = 0.7  # Probability that a mapping weight is non-zero
    data_sparsity: float = 0.1  # Probability that an output value is zeroed
    # NOTE: despite the name, this value is handed to np.random.normal as its
    # ``scale`` argument, i.e. it acts as a standard deviation.
    noise_variance: float = 0.1
    noise_mean: float = 0.0
    n_noise_components: int = 1  # >1 switches to a mixture of Gaussians
    # Time series specific parameters
    seasonality_strength: float = 0.5
    trend_strength: float = 0.3
    n_seasonal_patterns: int = 2
    seasonal_periods: List[int] = field(default_factory=lambda: [20, 50])
    autocorrelation_lag: int = 5  # Order of the autoregressive component
    autocorrelation_strength: float = 0.7  # Base of the geometric lag weights
    changepoint_probability: float = 0.05  # Per-step chance of a level shift
    random_seed: Optional[int] = None

    # Upper bound on the sum of absolute AR coefficients. Keeping the sum
    # below 1 is a sufficient condition for the AR recursion to stay bounded.
    # Not annotated, so the dataclass machinery does not treat it as a field.
    _MAX_AR_WEIGHT_SUM = 0.95

    def __post_init__(self):
        """Seed the RNG and draw the random structures reused by every call."""
        if self.random_seed is not None:
            np.random.seed(self.random_seed)

        if len(self.seasonal_periods) != self.n_seasonal_patterns:
            # randint needs low < high (high is exclusive). For short
            # sequences ``sequence_length // 4`` is <= 10, which used to raise
            # ValueError; fall back to the minimum period of 10 instead.
            upper = max(self.sequence_length // 4, 11)
            self.seasonal_periods = [
                np.random.randint(10, upper)
                for _ in range(self.n_seasonal_patterns)
            ]

        # Initialize mapping from hidden variables to time series features
        self.hidden_to_feature_matrix = self._create_sparse_matrix(
            self.n_hidden_variables, self.data_dim, self.hidden_connectivity)

        # Initialize noise component variances and means
        if self.n_noise_components > 1:
            self.noise_component_variances = np.random.uniform(
                0.1 * self.noise_variance,
                self.noise_variance,
                size=self.n_noise_components
            )
            # Add slight variations to the mean for different components
            self.noise_component_means = self.noise_mean + np.random.uniform(
                -0.1 * abs(self.noise_mean) - 0.01,
                0.1 * abs(self.noise_mean) + 0.01,
                size=self.n_noise_components
            )
            # Components are equally likely
            self.noise_component_weights = np.ones(self.n_noise_components) / self.n_noise_components

    def _create_sparse_matrix(self, rows: int, cols: int, density: float) -> np.ndarray:
        """Create a Gaussian random matrix whose entries are kept with probability ``density``."""
        matrix = np.random.randn(rows, cols)
        mask = np.random.rand(rows, cols) > density
        matrix[mask] = 0
        return matrix

    def _generate_hidden_variables(self) -> np.ndarray:
        """Generate hidden variables according to the specified distribution.

        Returns:
            Array of shape (n_samples, n_hidden_variables).

        Raises:
            ValueError: If ``hidden_dist_type`` is not a supported name.
        """
        shape = (self.n_samples, self.n_hidden_variables)
        if self.hidden_dist_type == "gaussian":
            return np.random.randn(*shape)
        elif self.hidden_dist_type == "uniform":
            return np.random.uniform(-1, 1, size=shape)
        else:
            raise ValueError(f"Unknown distribution type: {self.hidden_dist_type}")

    def _apply_nonlinearity(self, data: np.ndarray) -> np.ndarray:
        """Apply specified nonlinearity to the data.

        Raises:
            ValueError: If ``nonlinearity_type`` is not a supported name.
        """
        if self.nonlinearity_type == "sigmoid":
            return 1 / (1 + np.exp(-data))
        elif self.nonlinearity_type == "tanh":
            return np.tanh(data)
        elif self.nonlinearity_type == "exponential":
            # Clip for stability
            data = np.clip(data, -10, 10)
            return np.exp(data)
        else:
            raise ValueError(f"Unknown nonlinearity type: {self.nonlinearity_type}")

    def _ar_lag_weights(self) -> np.ndarray:
        """Return the AR coefficients for lags 1..``autocorrelation_lag``.

        The raw coefficients are ``autocorrelation_strength ** lag``. Their
        sum can exceed 1 (about 1.94 for the defaults), which makes the
        recursion grow exponentially. When the absolute sum is above
        ``_MAX_AR_WEIGHT_SUM`` the coefficients are scaled down
        proportionally so the sum equals that bound; otherwise they are
        returned unchanged.
        """
        lags = np.arange(1, self.autocorrelation_lag + 1)
        weights = float(self.autocorrelation_strength) ** lags
        total = np.abs(weights).sum()
        if total > self._MAX_AR_WEIGHT_SUM:
            weights = weights * (self._MAX_AR_WEIGHT_SUM / total)
        return weights

    def _simulate_ar_series(self, weights: np.ndarray) -> np.ndarray:
        """Simulate one autoregressive series driven by N(0, 0.1) innovations.

        The first ``len(weights)`` steps stay at zero; every later step is the
        weighted sum of the preceding values plus a fresh innovation.
        """
        length = self.sequence_length
        order = len(weights)
        series = np.zeros(length)
        innovations = np.random.normal(0, 0.1, size=max(length - order, 0))
        # series[t - order:t] runs oldest to newest, so flip the weights to
        # pair lag 1 with the most recent value.
        oldest_first = weights[::-1]
        for offset, t in enumerate(range(order, length)):
            series[t] = oldest_first @ series[t - order:t] + innovations[offset]
        return series

    def _level_shifts(self, feature: float) -> np.ndarray:
        """Return a piecewise-constant level built from random changepoints.

        Each time step is a changepoint with probability
        ``changepoint_probability``; at a changepoint the level jumps by
        ``uniform(-1, 1) * feature`` and then persists.
        """
        is_changepoint = np.random.rand(self.sequence_length) < self.changepoint_probability
        jumps = np.zeros(self.sequence_length)
        jumps[is_changepoint] = np.random.uniform(
            -1, 1, size=int(is_changepoint.sum())) * feature
        # Running sum of the jumps is the level in effect at each step
        return np.cumsum(jumps)

    def _add_time_series_components(self, base_features: np.ndarray) -> np.ndarray:
        """Add time series specific components like trends, seasonality, and autocorrelation.

        Args:
            base_features: One value per (sample, channel); it scales every
                component generated for that channel.

        Returns:
            Array of shape (n_samples, sequence_length, data_dim).
        """
        length = self.sequence_length
        result = np.zeros((self.n_samples, length, self.data_dim))

        # Pieces that do not depend on the sample or channel
        ramp = np.linspace(0, 1, length)
        angular_steps = 2 * np.pi * np.arange(length)
        use_ar = self.autocorrelation_strength > 0
        ar_weights = self._ar_lag_weights() if use_ar else None

        for i in range(self.n_samples):
            features = base_features[i]

            for j in range(self.data_dim):
                feature = features[j]
                time_series = np.zeros(length)

                # Linear trend with a random slope, scaled by the feature
                if self.trend_strength > 0:
                    trend_coef = np.random.uniform(-self.trend_strength, self.trend_strength)
                    time_series += ramp * trend_coef * feature

                # One sinusoid per seasonal period, random amplitude and phase
                if self.seasonality_strength > 0:
                    for period in self.seasonal_periods:
                        amplitude = np.random.uniform(0, self.seasonality_strength * abs(feature))
                        phase = np.random.uniform(0, 2 * np.pi)
                        time_series += amplitude * np.sin(angular_steps / period + phase)

                # Random level shifts
                if self.changepoint_probability > 0:
                    time_series += self._level_shifts(feature)

                # Autoregressive component, scaled by the feature magnitude
                if use_ar:
                    time_series += self._simulate_ar_series(ar_weights) * abs(feature)

                result[i, :, j] = time_series

        return result

    def _add_noise(self, data: np.ndarray) -> np.ndarray:
        """Add Gaussian (or mixture-of-Gaussians) noise to the time series data.

        ``noise_variance`` and the per-component values derived from it are
        passed to ``np.random.normal`` as the ``scale`` argument.
        """
        # Skip adding noise if variance is zero
        if self.noise_variance == 0:
            return data

        if self.n_noise_components == 1:
            noise = np.random.normal(self.noise_mean, self.noise_variance, size=data.shape)
        else:
            # True mixture of Gaussians: sample component ID for each element
            component_ids = np.random.choice(
                self.n_noise_components,
                size=data.size,
                p=self.noise_component_weights
            )

            # Fill each component's positions from that component's Gaussian
            flat_noise = np.zeros(data.size)
            for comp_id in range(self.n_noise_components):
                idx = (component_ids == comp_id)
                flat_noise[idx] = np.random.normal(
                    self.noise_component_means[comp_id],
                    self.noise_component_variances[comp_id],
                    size=np.sum(idx)
                )

            # Reshape back to original dimensions
            noise = flat_noise.reshape(data.shape)

        return data + noise

    def _apply_sparsity(self, data: np.ndarray) -> np.ndarray:
        """Apply dropout to simulate missing values in time series."""
        mask = np.random.rand(*data.shape) > self.data_sparsity
        return data * mask

    def generate_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate time series data based on hidden variables.

        Returns:
            Tuple containing (time_series_data, hidden_variables), where
            time_series_data has shape (n_samples, sequence_length, data_dim)
            and hidden_variables has shape (n_samples, n_hidden_variables).
        """
        # Generate hidden variables
        hidden_vars = self._generate_hidden_variables()

        # Transform hidden variables to base features
        base_features = hidden_vars @ self.hidden_to_feature_matrix

        # Apply nonlinearities
        for _ in range(self.nonlinearity_level):
            base_features = self._apply_nonlinearity(base_features)

        # Generate time series from base features
        time_series = self._add_time_series_components(base_features)

        # Apply sparsity (missing values)
        time_series = self._apply_sparsity(time_series)

        # Add noise (applied after dropout, so zeroed entries still receive noise)
        time_series = self._add_noise(time_series)

        # Return the time series data and hidden variables
        return time_series, hidden_vars
