import torch
from torch.distributions import Categorical
from typing import Optional

def simulate_CMLR_data(
    w_true: torch.Tensor,
    batch_size: Optional[int] = None
) -> torch.Tensor:
    """
    Sample one feature vector per orientation (row of ``w_true``).

    Sampling is done in batches to avoid storing a large candidate pool:
      1) Draw a batch of standard-normal feature vectors ``Xbatch``.
      2) Compute ``logits = Xbatch @ w_true.T`` and sample one label per
         candidate from ``Categorical(logits=logits)`` (i.e. softmax
         probabilities -- the label is *sampled*, not the argmax).
      3) For each sampled label not yet covered, record the candidate's
         feature vector (the first candidate in draw order wins).
      4) Repeat until every orientation has been sampled at least once.

    Note:
        The loop only terminates once every label has been drawn. If some
        orientation has a vanishingly small probability under the softmax
        model, this can take a very long time.

    Args:
        w_true (Tensor[L, D]):
            Ground-truth weight matrix; each of the L rows is a "class" or
            orientation, and D is the feature dimension. Candidates are
            drawn with the same dtype and device as ``w_true``.
        batch_size (int, optional):
            Number of candidate feature vectors to draw per batch. If None,
            defaults to 1000. Must be at least 1.

    Returns:
        xsamp (Tensor[L, D]):
            One feature vector per orientation: row ``i`` is a candidate
            ``x`` whose sampled label was ``i``.

    Raises:
        ValueError: If ``w_true`` is not 2-dimensional or ``batch_size`` is
            smaller than 1.
    """
    if w_true.dim() != 2:
        raise ValueError(
            f"w_true must be a 2-D tensor of shape [L, D], got {w_true.dim()} dims"
        )
    num_labels, feat_dim = w_true.shape

    n_candidates = 1000 if batch_size is None else batch_size
    if n_candidates < 1:
        # An empty batch can never cover a label, so the loop below would
        # spin forever; fail loudly instead.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    xsamp = torch.zeros(
        num_labels, feat_dim, dtype=w_true.dtype, device=w_true.device
    )
    # Coverage is tracked on the host so the membership test inside the loop
    # does not force a device synchronisation per candidate.
    seen = [False] * num_labels
    remaining = num_labels

    while remaining > 0:
        # Match w_true's dtype: with the default dtype the matmul below
        # raises for float64/half weights.
        Xbatch = torch.randn(
            n_candidates, feat_dim, dtype=w_true.dtype, device=w_true.device
        )

        # logits[n, l] = w_true[l] . Xbatch[n]
        logits = Xbatch @ w_true.T  # shape [N, L]
        labels = Categorical(logits=logits).sample()  # shape [N]

        # A single device-to-host transfer instead of one .item() per element.
        for idx, lab in enumerate(labels.tolist()):
            if seen[lab]:
                continue
            xsamp[lab] = Xbatch[idx]
            seen[lab] = True
            remaining -= 1
            if remaining == 0:
                break  # every label is covered

    return xsamp


