import abc
import yaml
from typing import Any
import torch
import numpy as np
from torch.utils.data import Dataset
from functools import partial 
import matplotlib.pyplot as plt
from utils.misc import dotdict

class DiscreteToyDataset(Dataset, abc.ABC):
    """Common interface for the discrete toy datasets defined in this module.

    Every hook below is a placeholder that does nothing and returns ``None``.
    None of them is declared abstract, so this class can be instantiated
    directly; the concrete mixtures in this module override all of them.
    """

    def __iter__(self):
        """Return the dataset itself; it serves as its own iterator."""
        return self

    def __next__(self):
        """Placeholder for producing the next batch; returns ``None``."""
        return None

    def __len__(self):
        """Placeholder for the dataset length; returns ``None``.

        Calling the builtin ``len()`` on an instance that does not override
        this therefore raises ``TypeError``.
        """
        return None

    def generate_cond(self, n_samples):
        """Placeholder for drawing ``n_samples`` conditions; returns ``None``."""
        return None

    def sample_from_condition(self, cond):
        """Placeholder for sampling data given ``cond``; returns ``None``."""
        return None

    def plot_samples(self, samples, cond, path):
        """Placeholder for visualising ``samples``; returns ``None``.

        NOTE(review): the mixtures in this module override this method with
        the signature ``(samples, path=None, ...)`` and take no ``cond``
        argument, so positional calls are not interchangeable.
        """
        return None
    

class TensorMixture(DiscreteToyDataset):
    """Mixture of categorical distributions over an n-dimensional index grid.

    Each component tensor is normalised into a joint probability table. A
    sample is the integer coordinate of one cell of that table, and the
    condition attached to a sample is the index of the component it was
    drawn from. Iterating the dataset yields ``(samples, cond)`` batches
    forever.
    """

    def __init__(self, batch_size, tensors, weights=None):
        """
        Initialize a dataset that samples from a mixture of probability tensors.

        Args:
            batch_size: Number of samples per batch
            tensors: A non-empty list of tensors sharing one shape, each
                representing an (unnormalised) probability distribution.
                Entries are expected to be non-negative; this is not checked.
            weights: Optional list of weights for each tensor. If None, equal
                weights are used. Weights are normalised to sum to one.

        Raises:
            ValueError: If ``tensors`` is empty, the tensors differ in shape,
                the number of weights does not match the number of tensors,
                or a tensor / the weights do not have a positive sum.
        """
        if len(tensors) == 0:
            raise ValueError("At least one tensor is required")

        self.batch_size = batch_size
        self.tensors = tensors
        self.n_classes = len(tensors)

        # Set up weights for the mixture
        if weights is None:
            self.weights = torch.ones(self.n_classes).float() / self.n_classes
        else:
            weights_tensor = torch.tensor(weights, dtype=torch.float)
            # A mismatch would otherwise be silently truncated by zip() below
            # while generate_cond() could still emit out-of-range classes.
            if weights_tensor.numel() != self.n_classes:
                raise ValueError("Number of weights must match number of tensors")
            if not weights_tensor.sum() > 0:
                raise ValueError("Weights must have a positive sum")
            self.weights = weights_tensor / weights_tensor.sum()

        # Ensure all tensors have the same shape
        shapes = [t.shape for t in tensors]
        if len(set(shapes)) > 1:
            raise ValueError("All tensors must have the same shape")

        self.tensor_shape = shapes[0]
        self.context_len = len(self.tensor_shape)
        self.vocab_size = max(self.tensor_shape) if self.tensor_shape else 1

        # Normalize each tensor to create proper probability distributions
        self.tensor_probs = []
        self.flat_probs = []
        self.full_tensor = torch.zeros_like(tensors[0]).float()

        for w, tensor in zip(self.weights, tensors):
            total = tensor.sum()
            # Dividing by a zero (or NaN) sum would yield NaN probabilities
            # that only fail much later inside torch.multinomial.
            if not total > 0:
                raise ValueError("Each tensor must have a positive sum")
            probs = tensor / total
            self.tensor_probs.append(probs)
            self.flat_probs.append(probs.flatten())
            self.full_tensor += w * probs

        self.cond_dim = len(self.tensor_probs)

        # Lookup table: row k holds the n-d coordinate of flat index k
        if len(self.tensor_shape) > 0:
            ranges = [torch.arange(dim) for dim in self.tensor_shape]
            if len(ranges) == 1:
                self.indices = ranges[0].unsqueeze(1)
            else:
                meshgrids = torch.meshgrid(*ranges, indexing='ij')
                self.indices = torch.stack([mg.flatten() for mg in meshgrids], dim=1)
        else:
            # Handle scalar case
            self.indices = torch.tensor([[0]])

    def __len__(self):
        """Return a fixed nominal length of 1000."""
        return 1000

    def generate_cond(self, n_samples):
        """Generate component indices as conditions, shaped ``(n_samples, 1)``."""
        return torch.multinomial(self.weights, n_samples, replacement=True).reshape(-1, 1)

    def sample_from_condition(self, cond):
        """Sample one grid coordinate per row of ``cond``.

        Args:
            cond: Integer tensor of component indices; its first dimension is
                the number of samples and it is flattened before comparison.

        Returns:
            An int64 tensor of shape ``(cond.shape[0], context_len)``. Rows
            whose condition matches no component stay at zero.
        """
        n_samples = cond.shape[0]
        labels = cond.flatten()
        samples = torch.zeros((n_samples, self.context_len), dtype=torch.int64)

        # Sample for each class separately
        for class_idx, class_probs in enumerate(self.flat_probs):
            mask = labels == class_idx
            count = mask.sum().item()
            if count == 0:
                continue

            flat_indices = torch.multinomial(class_probs, count, replacement=True)
            # The lookup table covers the 1-D case too (it is then simply
            # arange(n).unsqueeze(1)), so no special-casing is needed.
            samples[mask] = self.indices[flat_indices]

        return samples

    def sample(self, n_samples=None):
        """Sample from the mixture; defaults to ``batch_size`` samples.

        Returns:
            A ``(samples, cond)`` tuple.
        """
        n = self.batch_size if n_samples is None else n_samples
        cond = self.generate_cond(n)
        samples = self.sample_from_condition(cond)
        return samples, cond

    def __iter__(self):
        """Return the dataset itself; it serves as its own iterator."""
        return self

    def __next__(self):
        """Draw a fresh batch; never raises ``StopIteration``."""
        samples, cond = self.sample()
        return samples, cond

    def get_guided_distribution(self, class_idx, w):
        """
        Apply classifier guidance to get a guided distribution between
        a specific component and the full mixture.

        The result is proportional to ``p_i**w * p**(1 - w)`` where ``p_i`` is
        the component and ``p`` the full mixture. Both are clamped to at least
        ``1e-10`` before taking logs, so cells with zero probability receive a
        tiny non-zero mass rather than exactly zero.

        Args:
            class_idx: Index of the component to guide towards
            w: Guidance weight (0 = full mixture, 1 = only the component)

        Returns:
            A tensor representing the guided distribution
        """
        p_i = self.tensor_probs[class_idx]
        p = self.full_tensor

        # Apply classifier guidance formula in log space for numerical stability
        log_p_i = torch.log(torch.clamp(p_i, min=1e-10))
        log_p = torch.log(torch.clamp(p, min=1e-10))

        tempered = torch.exp(w * log_p_i + (1 - w) * log_p)

        # Normalize to get a valid probability distribution
        if tempered.sum() > 0:
            return tempered / tempered.sum()
        return tempered

    def plot_samples(self, samples, path=None, plot_tensors=True, fig=None, ax=None):
        """
        Plot samples. The visualization method depends on tensor dimensionality.

        Args:
            samples: Tensor of sampled coordinates, one row per sample.
            path: If given, the figure is saved there and closed; otherwise
                it is shown with ``plt.show()``.
            plot_tensors: For 1-D/2-D data, also draw every component and the
                full mixture next to the empirical histogram. A new figure is
                always created in that mode.
            fig, ax: Optional existing figure/axis, used only for 1-D/2-D data
                when ``plot_tensors`` is False.
        """
        samples = samples.detach().cpu().numpy()

        if self.context_len == 1:
            # 1D case - use bar plots
            self._plot_1d_samples(samples, path, plot_tensors, fig, ax)
        elif self.context_len == 2:
            # 2D case - use heatmaps
            self._plot_2d_samples(samples, path, plot_tensors, fig, ax)
        else:
            # Higher dimensions - use pairwise projections
            self._plot_nd_samples(samples, path)

    @staticmethod
    def _to_numpy(values):
        """Convert a tensor to a NumPy array; pass anything else through."""
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return values

    @staticmethod
    def _save_or_show(fig, path):
        """Tighten the layout of ``fig``, then save-and-close or show it."""
        # Operate on ``fig`` itself: a caller-supplied figure need not be the
        # "current" figure that ``plt.tight_layout()`` would act on.
        fig.tight_layout()
        if path is not None:
            fig.savefig(path)
            plt.close(fig)
        else:
            plt.show()

    def _panel_names(self):
        """Titles for the component panels plus the two summary panels."""
        return [*(f'Class {i}' for i in range(len(self.tensors))), 'Full Prob', 'Empirical']

    def _plot_1d_samples(self, samples, path, plot_tensors, fig, ax):
        """Plot 1D tensor samples as bar charts."""
        size = self.tensor_shape[0]
        hist, _ = np.histogram(
            samples.flatten(),
            bins=size,
            range=(0, size),
            density=True
        )

        if plot_tensors:
            n_cols = self.n_classes + 2
            fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 4))

            # Use the normalised components so every panel shares the
            # 'Probability' scale of the mixture and the empirical density.
            panels = [*self.tensor_probs, self.full_tensor, hist]
            for axis, panel, name in zip(axes, panels, self._panel_names()):
                values = self._to_numpy(panel)
                axis.bar(np.arange(len(values)), values)
                axis.set_title(name)
                axis.set_xlabel('Position')
                axis.set_ylabel('Probability')
        else:
            if ax is None:
                fig, ax = plt.subplots(figsize=(8, 4))
            elif fig is None:
                # Draw on the supplied axis instead of discarding it
                fig = ax.figure

            ax.bar(np.arange(len(hist)), hist)
            ax.set_title('Samples Distribution')
            ax.set_xlabel('Position')
            ax.set_ylabel('Frequency')

        self._save_or_show(fig, path)

    def _plot_2d_samples(self, samples, path, plot_tensors, fig, ax):
        """Plot 2D tensor samples as heatmaps."""
        height, width = self.tensor_shape
        # x is the column coordinate (samples[:, 1]), so its bin count must be
        # ``width``; the transposed result then has shape (height, width).
        hist, _, _ = np.histogram2d(
            samples[:, 1], samples[:, 0],
            bins=[width, height],
            range=[[0, width], [0, height]],
            density=True
        )
        empirical = hist.T

        if plot_tensors:
            n_cols = self.n_classes + 2
            fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5))

            panels = [*self.tensor_probs, self.full_tensor, empirical]
            for axis, panel, name in zip(axes, panels, self._panel_names()):
                axis.imshow(self._to_numpy(panel), cmap='viridis', origin='lower')
                axis.set_title(name)
                axis.set_xlabel('Dim 1')
                axis.set_ylabel('Dim 0')
        else:
            if ax is None:
                fig, ax = plt.subplots(figsize=(8, 8))
            elif fig is None:
                fig = ax.figure

            ax.imshow(empirical, cmap='viridis', origin='lower')
            ax.set_title('Samples Distribution')
            ax.set_xlabel('Dim 1')
            ax.set_ylabel('Dim 0')

        self._save_or_show(fig, path)

    def _plot_nd_samples(self, samples, path):
        """Plot n-dimensional tensor samples using pairwise projections.

        At most the first 10 dimension pairs (in lexicographic order) are
        drawn. Nothing is drawn or saved when there is no pair to show, which
        is the case for 0-d tensors.
        """
        n_dims = self.context_len
        n_pairs = min(10, n_dims * (n_dims - 1) // 2)  # Limit to 10 plots
        if n_pairs == 0:
            # Avoids a division by zero in the grid-size computation below
            return

        dim_pairs = [
            (i, j) for i in range(n_dims) for j in range(i + 1, n_dims)
        ][:n_pairs]

        # Calculate grid size
        cols = min(5, n_pairs)
        rows = (n_pairs + cols - 1) // cols

        # Compute all projections first so the panels can share one colour
        # scale; otherwise the single colorbar would only describe one panel.
        hists = []
        for dim1, dim2 in dim_pairs:
            hist, _, _ = np.histogram2d(
                samples[:, dim1], samples[:, dim2],
                bins=[self.tensor_shape[dim1], self.tensor_shape[dim2]],
                range=[[0, self.tensor_shape[dim1]], [0, self.tensor_shape[dim2]]],
                density=True
            )
            hists.append(hist)

        peak = max(float(h.max()) for h in hists)
        # Fall back to matplotlib's autoscaling if the densities are NaN
        # (e.g. no sample fell inside the histogram range).
        scale = {'vmin': 0.0, 'vmax': peak} if np.isfinite(peak) else {}

        # Extra column (and figure width) reserved for the colorbar
        fig = plt.figure(figsize=(3 * cols + 1, 3 * rows))
        gs = fig.add_gridspec(rows, cols + 1, width_ratios=[1] * cols + [0.05])

        last_image = None
        for idx, ((dim1, dim2), hist) in enumerate(zip(dim_pairs, hists)):
            ax = fig.add_subplot(gs[idx // cols, idx % cols])
            # Transpose so dim1 runs along x and dim2 along y
            last_image = ax.imshow(hist.T, cmap='viridis', origin='lower', **scale)
            ax.set_title(f'Dims {dim1} vs {dim2}')
            ax.set_xlabel(f'Dimension {dim1}')
            ax.set_ylabel(f'Dimension {dim2}')

        # Colorbar axis spans the full height of the dedicated last column
        cbar_ax = fig.add_subplot(gs[:, -1])
        fig.colorbar(last_image, cax=cbar_ax, label='Density')

        self._save_or_show(fig, path)


class MatrixInputMixture(DiscreteToyDataset):
    """Mixture of categorical distributions over the cells of a 2-D grid.

    Each component matrix is normalised into a joint probability table over
    ``(i, j)`` coordinates. A sample is one ``(i, j)`` pair and its condition
    is the index of the component it was drawn from. Iterating the dataset
    yields ``(samples, cond)`` batches forever.
    """

    def __init__(self, batch_size, matrices, weights=None):
        """
        Initialize a dataset that samples from a mixture of probability matrices.

        Args:
            batch_size: Number of samples per batch
            matrices: A non-empty list of 2D tensors of (unnormalised)
                probabilities, each with shape (height, width). Entries are
                expected to be non-negative; this is not checked.
            weights: Optional list of weights for each matrix. If None, equal
                weights are used. Weights are normalised to sum to one.

        Raises:
            ValueError: If ``matrices`` is empty, the matrices differ in shape
                or are not 2-D, the number of weights does not match the
                number of matrices, or a matrix / the weights do not have a
                positive sum.
        """
        if len(matrices) == 0:
            raise ValueError("At least one matrix is required")

        self.batch_size = batch_size
        self.context_len = 2
        self.matrices = matrices
        self.n_classes = len(matrices)

        # Set up weights for the mixture
        if weights is None:
            # Equal weights if none provided
            self.weights = torch.ones(self.n_classes) / self.n_classes
        else:
            weights_tensor = torch.tensor(weights, dtype=torch.float)
            # A mismatch would otherwise be silently truncated by zip() below
            # while generate_cond() could still emit out-of-range classes.
            if weights_tensor.numel() != self.n_classes:
                raise ValueError("Number of weights must match number of matrices")
            if not weights_tensor.sum() > 0:
                raise ValueError("Weights must have a positive sum")
            # Normalize provided weights
            self.weights = weights_tensor / weights_tensor.sum()

        # Ensure all matrices have the same shape
        shapes = [m.shape for m in matrices]
        if len(set(shapes)) > 1:
            raise ValueError("All matrices must have the same shape")

        # Unpacking raises ValueError when the shape is not 2-D
        self.height, self.width = shapes[0]
        self.vocab_size = max(self.height, self.width)

        # Normalize each matrix to create proper probability distributions
        self.m_probs = []
        self.flat_probs = []
        self.full_matrix = 0
        for w, matrix in zip(self.weights, matrices):
            total = matrix.sum()
            # Dividing by a zero (or NaN) sum would yield NaN probabilities
            # that only fail much later inside torch.multinomial.
            if not total > 0:
                raise ValueError("Each matrix must have a positive sum")
            probs = matrix / total
            self.m_probs.append(probs)
            self.flat_probs.append(probs.flatten())
            self.full_matrix += w * probs

        self.cond_dim = len(self.m_probs)

        # Lookup table: row k holds the (i, j) coordinate of flat index k
        i_indices, j_indices = torch.meshgrid(
            torch.arange(self.height),
            torch.arange(self.width),
            indexing='ij'
        )
        self.indices = torch.stack([i_indices.flatten(), j_indices.flatten()], dim=1)

    def __len__(self):
        """Return a fixed nominal length of 1000."""
        return 1000

    def generate_cond(self, n_samples):
        """Draw ``n_samples`` component indices, shaped ``(n_samples, 1)``."""
        return torch.multinomial(self.weights, n_samples, replacement=True).reshape(-1, 1)

    def sample_from_condition(self, cond):
        """Sample one ``(i, j)`` coordinate per row of ``cond``.

        Args:
            cond: Integer tensor of component indices; its first dimension is
                the number of samples and it is flattened before comparison.

        Returns:
            An int64 tensor of shape ``(cond.shape[0], 2)``. Rows whose
            condition matches no component stay at ``(0, 0)``.
        """
        n_samples = cond.shape[0]
        labels = cond.flatten()
        samples = torch.zeros((n_samples, 2), dtype=torch.int64)

        # Sample for each class separately
        for class_idx, class_probs in enumerate(self.flat_probs):
            mask = labels == class_idx
            count = mask.sum().item()
            if count == 0:
                continue

            flat_indices = torch.multinomial(class_probs, count, replacement=True)
            # Convert flat indices to (i, j) coordinates
            samples[mask] = self.indices[flat_indices]

        return samples

    def sample(self, n_samples=None):
        """Sample from the mixture; defaults to ``batch_size`` samples.

        Returns:
            A ``(samples, cond)`` tuple.
        """
        n = self.batch_size if n_samples is None else n_samples
        cond = self.generate_cond(n)
        samples = self.sample_from_condition(cond)
        return samples, cond

    def __iter__(self):
        """Return the dataset itself; it serves as its own iterator."""
        return self

    def __next__(self):
        """Draw a fresh batch; never raises ``StopIteration``."""
        samples, cond = self.sample()
        return samples, cond

    def safe_exp(self, p, w):
        """Return ``p**w`` where ``p > 0`` and ``0`` elsewhere.

        This keeps zero entries at zero even for non-positive exponents.
        """
        return torch.where(p > 0., p**w, 0.)

    def get_guided_distribution(self, class_idx, w):
        """Return the guided table for component ``class_idx`` at weight ``w``.

        The tempered product ``p_i**w * p**(1 - w)`` (``p`` being the full
        mixture) is normalised and then multiplied cell-wise by a coefficient
        built from row/column marginals of the tempered table and of
        ``p_i`` / ``p``. The returned tensor is not renormalised afterwards.

        NOTE(review): the coefficient presumably corrects the joint table
        towards what per-dimension (factorised) guidance would produce;
        confirm against the derivation this was taken from.

        Args:
            class_idx: Index of the component to guide towards.
            w: Guidance weight.

        Returns:
            A ``(height, width)`` tensor.
        """
        p_i = self.m_probs[class_idx]
        p = self.full_matrix
        m_iw = p_i**w
        m_w = self.safe_exp(p, 1-w)
        tempered = m_iw * m_w
        tempered = tempered/tempered.sum()

        # Marginals of the tempered table: over columns (per row) and over
        # rows (per column) respectively.
        dim_1 = tempered.sum(-1)
        dim_0 = tempered.sum(-2)
        # The same tempering applied directly to the marginals of p and p_i.
        ddd1 = self.safe_exp(p.sum(-1), (1-w)) * p_i.sum(-1)**w
        ddd0 = self.safe_exp(p.sum(-2), (1-w)) * p_i.sum(-2)**w

        # Ratios default to 1 where the tempered marginal is zero to avoid 0/0
        inv_ci = torch.where(dim_0 != 0, ddd0 / dim_0, 1.)
        inv_cN = (ddd1).sum() / tempered.sum()
        inv_dj = torch.where(dim_1 != 0, ddd1 / dim_1, 1.)
        inv_dN = (ddd0).sum() / tempered.sum()

        den = inv_cN + inv_dN

        # Broadcast the per-column term along rows and the per-row term along
        # columns to obtain one coefficient per cell.
        coeff_M = inv_ci.unsqueeze(0) + inv_dj.unsqueeze(1)
        coeff_M /= den

        return tempered * coeff_M

    def plot_matrix_with_annotations(self, ax, matrix, title, xlabel='j', ylabel='i'):
        """
        Helper method to plot a probability matrix with text annotations.

        Args:
            ax: The matplotlib axis to plot on
            matrix: The matrix to plot
            title: Title for the plot
            xlabel: Label for x-axis
            ylabel: Label for y-axis

        Returns:
            The image object created by ``ax.imshow``.
        """
        # Normalize to get true probabilities if not already normalized
        matrix_prob = matrix / matrix.sum() if matrix.sum() > 0 else matrix

        im = ax.imshow(matrix_prob, cmap='viridis', origin='lower')
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        # Only add text annotations if the matrix is small enough
        # (typically matrices larger than 20x20 become too cluttered)
        max_size_for_annotations = 20
        n_rows, n_cols = matrix.shape[0], matrix.shape[1]
        if n_rows > max_size_for_annotations or n_cols > max_size_for_annotations:
            return im

        for y in range(matrix_prob.shape[0]):
            for x in range(matrix_prob.shape[1]):
                value = matrix_prob[y, x]
                # Only show text for cells with non-zero probability
                if not value > 0:
                    continue
                # Two decimals normally, four for very small values
                text = f'{value:.2f}' if value >= 0.01 else f'{value:.4f}'
                # Text is always black and placed at the cell centre
                ax.text(x, y, text, ha='center', va='center',
                        color='black', fontsize=8)

        return im

    @staticmethod
    def _save_or_show(fig, path):
        """Tighten the layout of ``fig``, then save-and-close or show it."""
        # Operate on ``fig`` itself: a caller-supplied figure need not be the
        # "current" figure that ``plt.tight_layout()`` would act on.
        fig.tight_layout()
        if path is not None:
            fig.savefig(path)
            plt.close(fig)
        else:
            plt.show()

    def plot_samples(self, samples, path=None, plot_matrices=True, fig=None, ax=None):
        """Plot the empirical distribution of ``samples`` as a heatmap.

        Args:
            samples: Tensor of ``(i, j)`` samples, one row per sample.
            path: If given, the figure is saved there and closed; otherwise
                it is shown with ``plt.show()``.
            plot_matrices: Also draw every component matrix and the full
                mixture next to the empirical histogram. A new figure is
                always created in that mode.
            fig, ax: Optional existing figure/axis, used only when
                ``plot_matrices`` is False.
        """
        # Convert samples to numpy for plotting
        samples = samples.detach().cpu().numpy()

        hist, _, _ = np.histogram2d(
            samples[:, 1],  # j coordinate (x-axis)
            samples[:, 0],  # i coordinate (y-axis)
            bins=[self.width, self.height],
            range=[[0, self.width], [0, self.height]],
            density=True
        )
        # Transpose to match matrix orientation (rows = i, columns = j)
        empirical = hist.T

        if plot_matrices:
            n_cols = self.n_classes + 2
            fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5))

            panels = [*self.matrices, self.full_matrix, empirical]
            names = [*(f'Class {i}' for i in range(len(self.matrices))), 'Full Prob', 'Empirical']

            for axis, mat, name in zip(axes, panels, names):
                self.plot_matrix_with_annotations(
                    axis,
                    mat.cpu().numpy() if torch.is_tensor(mat) else mat,
                    name
                )
        else:
            # Only plot the combined samples histogram
            if ax is None:
                fig, ax = plt.subplots(figsize=(8, 8))
            elif fig is None:
                # Draw on the supplied axis instead of discarding it
                fig = ax.figure

            self.plot_matrix_with_annotations(ax, empirical, 'Samples Distribution')

        self._save_or_show(fig, path)

class VectorInputMixture(DiscreteToyDataset):
    """Mixture of categorical distributions over the positions of a vector.

    Each component vector is normalised into a probability vector. A sample
    is a single position index and its condition is the index of the
    component it was drawn from. Iterating the dataset yields
    ``(samples, cond)`` batches forever.
    """

    def __init__(self, batch_size, vectors, weights=None):
        """
        Initialize a dataset that samples from a mixture of probability vectors.

        Args:
            batch_size: Number of samples per batch
            vectors: A non-empty list of equal-length tensors of
                (unnormalised) probabilities. Entries are expected to be
                non-negative; this is not checked.
            weights: Optional list of weights for each vector. If None, equal
                weights are used. Weights are normalised to sum to one.

        Raises:
            ValueError: If ``vectors`` is empty, the vectors differ in length,
                the number of weights does not match the number of vectors,
                or a vector / the weights do not have a positive sum.
        """
        if len(vectors) == 0:
            raise ValueError("At least one vector is required")

        self.batch_size = batch_size
        self.context_len = 1
        self.vectors = vectors
        self.n_classes = len(vectors)

        if weights is None:
            self.weights = torch.ones(self.n_classes).float() / self.n_classes
        else:
            weights_tensor = torch.tensor(weights, dtype=torch.float)
            # A mismatch would otherwise be silently truncated by zip() below
            # while generate_cond() could still emit out-of-range classes.
            if weights_tensor.numel() != self.n_classes:
                raise ValueError("Number of weights must match number of vectors")
            if not weights_tensor.sum() > 0:
                raise ValueError("Weights must have a positive sum")
            self.weights = weights_tensor / weights_tensor.sum()

        lengths = [v.shape[0] for v in vectors]
        if len(set(lengths)) > 1:
            raise ValueError("All vectors must have the same length")

        self.length = lengths[0]
        self.vocab_size = self.length

        # Normalize each vector to create proper probability distributions
        self.v_probs = []
        self.full_vector = torch.zeros_like(vectors[0]).float()
        for w, vector in zip(self.weights, vectors):
            total = vector.sum()
            # Dividing by a zero (or NaN) sum would yield NaN probabilities
            # that only fail much later inside torch.multinomial.
            if not total > 0:
                raise ValueError("Each vector must have a positive sum")
            probs = vector / total
            self.v_probs.append(probs)
            self.full_vector += w * probs

        self.cond_dim = len(self.v_probs)

        # Create indices for all possible positions
        self.indices = torch.arange(self.length).unsqueeze(1)

    def __len__(self):
        """Return a fixed nominal length of 1000."""
        return 1000

    def generate_cond(self, n_samples):
        """Draw ``n_samples`` component indices, shaped ``(n_samples, 1)``."""
        return torch.multinomial(self.weights, n_samples, replacement=True).reshape(-1, 1)

    def sample_from_condition(self, cond):
        """Sample one position per row of ``cond``.

        Args:
            cond: Integer tensor of component indices; its first dimension is
                the number of samples and it is flattened before comparison.

        Returns:
            An int64 tensor of shape ``(cond.shape[0], 1)``. Rows whose
            condition matches no component stay at zero.
        """
        n_samples = cond.shape[0]
        labels = cond.flatten()
        samples = torch.zeros((n_samples, 1), dtype=torch.int64)

        # Sample for each class separately
        for class_idx, class_probs in enumerate(self.v_probs):
            mask = labels == class_idx
            count = mask.sum().item()
            if count == 0:
                continue

            positions = torch.multinomial(class_probs, count, replacement=True)
            samples[mask] = positions.unsqueeze(1)

        return samples

    def sample(self, n_samples=None):
        """Sample from the mixture; defaults to ``batch_size`` samples.

        Returns:
            A ``(samples, cond)`` tuple.
        """
        n = self.batch_size if n_samples is None else n_samples
        cond = self.generate_cond(n)
        samples = self.sample_from_condition(cond)
        return samples, cond

    def __iter__(self):
        """Return the dataset itself; it serves as its own iterator."""
        return self

    def __next__(self):
        """Draw a fresh batch; never raises ``StopIteration``."""
        samples, cond = self.sample()
        return samples, cond

    def safe_exp(self, p, w):
        """Return ``p**w`` where ``p > 0`` and ``0`` elsewhere."""
        return torch.where(p > 0., p**w, 0.)

    def get_guided_distribution(self, class_idx, w):
        """Return the guided distribution for component ``class_idx``.

        The result is proportional to ``p_i**w * p**(1 - w)`` where ``p_i`` is
        the component and ``p`` the full mixture. Both are clamped to at least
        ``1e-10`` before taking logs, so positions with zero probability
        receive a tiny non-zero mass rather than exactly zero.

        Args:
            class_idx: Index of the component to guide towards.
            w: Guidance weight.

        Returns:
            A tensor of length ``self.length`` that sums to one.
        """
        p_i = self.v_probs[class_idx]
        p = self.full_vector

        # Apply classifier guidance formula in log space for numerical stability
        # log(p_i^w * p^(1-w)) = w*log(p_i) + (1-w)*log(p)
        log_p_i = torch.log(torch.clamp(p_i, min=1e-10))
        log_p = torch.log(torch.clamp(p, min=1e-10))

        tempered = torch.exp(w * log_p_i + (1 - w) * log_p)

        # Normalize to get a valid probability distribution
        if tempered.sum() > 0:
            return tempered / tempered.sum()
        return tempered

    @staticmethod
    def _save_or_show(fig, path):
        """Tighten the layout of ``fig``, then save-and-close or show it."""
        # Operate on ``fig`` itself: a caller-supplied figure need not be the
        # "current" figure that ``plt.tight_layout()`` would act on.
        fig.tight_layout()
        if path is not None:
            fig.savefig(path)
            plt.close(fig)
        else:
            plt.show()

    def plot_samples(self, samples, path=None, plot_vectors=True, fig=None, ax=None):
        """Plot the empirical distribution of ``samples`` as a bar chart.

        Args:
            samples: Tensor of sampled positions.
            path: If given, the figure is saved there and closed; otherwise
                it is shown with ``plt.show()``.
            plot_vectors: Also draw every (normalised) component and the full
                mixture next to the empirical histogram. A new figure is
                always created in that mode.
            fig, ax: Optional existing figure/axis, used only when
                ``plot_vectors`` is False.
        """
        # Convert samples to numpy for plotting
        samples = samples.detach().cpu().numpy().flatten()

        hist, _ = np.histogram(
            samples,
            bins=self.length,
            range=(0, self.length),
            density=True
        )
        positions = np.arange(self.length)

        if plot_vectors:
            n_cols = self.n_classes + 2
            fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 4))

            # Use the normalised components so every panel shares the
            # 'Probability' scale of the mixture and the empirical density.
            panels = [*self.v_probs, self.full_vector, hist]
            names = [*(f'Class {i}' for i in range(len(self.vectors))), 'Full Prob', 'Empirical']

            for axis, vec, name in zip(axes, panels, names):
                values = vec.detach().cpu().numpy() if torch.is_tensor(vec) else vec
                axis.bar(positions, values)
                axis.set_title(name)
                axis.set_xlabel('Position')
                axis.set_ylabel('Probability')
                axis.set_xlim(-0.5, self.length-0.5)
        else:
            # Only plot the combined samples histogram
            if ax is None:
                fig, ax = plt.subplots(figsize=(8, 4))
            elif fig is None:
                # Draw on the supplied axis instead of discarding it
                fig = ax.figure

            ax.bar(positions, hist)
            ax.set_title('Samples Distribution')
            ax.set_xlabel('Position')
            ax.set_ylabel('Frequency')
            ax.set_xlim(-0.5, self.length-0.5)

        self._save_or_show(fig, path)

# 9x9 diamond-shaped bump (peak 1.0 at the centre, zeros in the corners)
# used as the building block of the matrix datasets.
_BLOB_9X9 = (
    (0.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0),
    (0.0, 0.0, 0.0, 0.1, 0.2, 0.1, 0.0, 0.0, 0.0),
    (0.0, 0.0, 0.1, 0.3, 0.5, 0.3, 0.1, 0.0, 0.0),
    (0.0, 0.1, 0.3, 0.6, 0.8, 0.6, 0.3, 0.1, 0.0),
    (0.1, 0.2, 0.5, 0.8, 1.0, 0.8, 0.5, 0.2, 0.1),
    (0.0, 0.1, 0.3, 0.6, 0.8, 0.6, 0.3, 0.1, 0.0),
    (0.0, 0.0, 0.1, 0.3, 0.5, 0.3, 0.1, 0.0, 0.0),
    (0.0, 0.0, 0.0, 0.1, 0.2, 0.1, 0.0, 0.0, 0.0),
    (0.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0),
)


def _stamp_blobs(corners, size=30):
    """Return a ``size`` x ``size`` matrix with the 9x9 blob at each corner.

    ``corners`` lists the (row, column) of each blob's top-left cell. Where
    blobs overlap the element-wise maximum is kept, so the result does not
    depend on the order of ``corners``.
    """
    blob = torch.tensor(_BLOB_9X9)
    side = blob.shape[0]
    canvas = torch.zeros((size, size))
    for top, left in corners:
        rows, cols = slice(top, top + side), slice(left, left + side)
        canvas[rows, cols] = torch.maximum(canvas[rows, cols], blob)
    return canvas


def _hypercube_tensor(shape, cube_bounds):
    """Return a zero tensor of ``shape`` with ones inside an axis-aligned box.

    ``cube_bounds[d]`` is the inclusive ``(low, high)`` index range of the
    box along dimension ``d``.
    """
    tensor = torch.zeros(shape)
    box = tuple(slice(low, high + 1) for low, high in cube_bounds)
    tensor[box] = 1.0
    return tensor


def get_dataset(name, batch_size=6):
    """Build one of the predefined toy datasets.

    Supported names:
        ``'matrix-disjoint'``: two 30x30 components, each the maximum of two
            overlapping blobs (corners 7/1 and 13/19 on the diagonal).
        ``'matrix-intersection'``: two 30x30 components, each made of two
            blobs that touch only at a zero-valued corner cell (corners 1/9
            and 11/19 on the diagonal), so taking the maximum there equals
            plain assignment.
        ``'cubes'``: two 3^5 cubes of ones inside a 5^5 grid, occupying
            indices 0..2 and 2..4 along every dimension.

    All datasets use equal mixture weights.

    Args:
        name: Dataset identifier (see above).
        batch_size: Batch size of the returned dataset. It is honoured for
            every dataset, including ``'cubes'``, which previously ignored it
            and always used 32.

    Returns:
        A ``MatrixInputMixture`` or ``TensorMixture``. For an unknown name a
        message is printed and ``None`` is returned.
    """
    weights = [0.5, 0.5]

    if name == 'matrix-disjoint':
        matrix1 = _stamp_blobs([(7, 7), (1, 1)])
        matrix2 = _stamp_blobs([(13, 13), (19, 19)])
        return MatrixInputMixture(batch_size=batch_size, matrices=[matrix1, matrix2], weights=weights)

    if name == 'matrix-intersection':
        matrix1 = _stamp_blobs([(1, 1), (9, 9)])
        matrix2 = _stamp_blobs([(11, 11), (19, 19)])
        return MatrixInputMixture(batch_size=batch_size, matrices=[matrix1, matrix2], weights=weights)

    if name == 'cubes':
        D = 5
        grid_size = 5
        tensor_shape = tuple([grid_size] * D)

        # Cube 1: corner cube starting at origin (3^D cube)
        tensor1 = _hypercube_tensor(tensor_shape, tuple((0, 2) for _ in range(D)))
        # Cube 2: opposite corner cube (3^D cube); the two cubes share only
        # the single cell at index 2 along every dimension
        tensor2 = _hypercube_tensor(tensor_shape, tuple((2, 4) for _ in range(D)))

        return TensorMixture(batch_size=batch_size, tensors=[tensor1, tensor2], weights=weights)

    print('Dataset is not implemented')
    return None
