import math

import torch
import torch.distributions as td
from scipy.spatial.transform import Rotation

from group_discovery.geometry_2d import (
    angle_to_matrix,
    angle_to_vector,
    matrix_to_angle,
    vector_to_angle,
    wrap_angle,
)
from group_discovery.geometry_3d import ZYZ_angles_to_matrix, quaternion_to_matrix
from group_discovery.utils import blogm

# Short alias for the batched matrix exponential; used below as the exp map
# from Lie-algebra elements to group elements.
expm = torch.linalg.matrix_exp


def Cn_elements_on_R2(group_order, representation="angle"):
    """
    Build the elements of the cyclic group C_n acting on the plane.

    The k-th element is the rotation by 2*pi*k/n, with its angle wrapped into
    [-pi, pi) via ``wrap_angle``.

    Args:
        group_order (int): The order n of C_n; must be at least 1.
        representation (str): Output format: "angle", "vector" or "matrix".

    Returns:
        torch.Tensor:
            - "angle": the wrapped angles, shape [n, 1]
            - "vector": ``angle_to_vector`` of those angles (documented [n, 2])
            - "matrix": ``angle_to_matrix`` of those angles (documented [n, 2, 2])

    Raises:
        ValueError: if ``group_order`` < 1 or ``representation`` is unknown.
    """
    if group_order < 1:
        raise ValueError("group_order must be >= 1")

    step = 2 * torch.pi / group_order
    angles = wrap_angle(torch.arange(group_order) * step).unsqueeze(-1)  # [n, 1]

    if representation == "angle":
        return angles
    if representation == "vector":
        return angle_to_vector(angles)
    if representation == "matrix":
        return angle_to_matrix(angles)

    raise ValueError(
        f"Invalid representation: {representation}. Must be one of 'angle',"
        " 'vector', or 'matrix'."
    )


def Dn_elements_on_R2(group_order):
    """
    Returns the elements of the dihedral group D_n as 2x2 matrices.

    The first n elements are the rotations of C_n (from ``Cn_elements_on_R2``).
    The last n are those rotations composed with the reflection across the
    x-axis, i.e. ``R_k @ diag(1, -1)``.

    Args:
        group_order (int): n, the number of rotations. D_n itself has 2n
            elements.

    Returns:
        torch.Tensor: [2n, 2, 2]

    Raises:
        ValueError: propagated from ``Cn_elements_on_R2`` if ``group_order`` < 1.
    """
    rotations = Cn_elements_on_R2(group_order, representation="matrix")  # [n, 2, 2]

    # Reflection across the x-axis. Build it with the rotations' dtype and
    # device so the batched matmul below never mixes devices.
    reflection_x = torch.tensor(
        [[1, 0], [0, -1]], dtype=rotations.dtype, device=rotations.device
    )

    # Broadcasted matmul: [n, 2, 2] @ [2, 2] -> [n, 2, 2]
    reflections = rotations @ reflection_x

    return torch.cat([rotations, reflections], dim=0)  # [2n, 2, 2]


def finite_group_elements_on_R3(group, axis="Z", dtype=torch.float32):
    """
    Return the rotation matrices of a finite subgroup of SO(3).

    This is a thin wrapper around
    ``scipy.spatial.transform.Rotation.create_group``.

    Args:
        group (str): SciPy group name: "I", "O", "T", or "Cn" / "Dn" with an
            order, e.g. "C4" or "D3".
        axis (str): Principal axis of the cyclic/dihedral groups, passed
            through to SciPy. Defaults to "Z", SciPy's own default.
        dtype (torch.dtype): dtype of the returned tensor. Defaults to float32.

    Returns:
        torch.Tensor: [|G|, 3, 3] rotation matrices.
    """
    elements = Rotation.create_group(group, axis=axis).as_matrix()  # ndarray [|G|, 3, 3]
    return torch.as_tensor(elements, dtype=dtype)


# Ref: https://docs.pytorch.org/rl/main/reference/generated/torchrl.modules.Delta.html
class DiscreteDeltaMixture(td.Distribution):
    """
    Mixture of delta functions at discrete locations.

    Each sample returns one of the predefined locations ``locs[k]`` with
    probability ``weights[k]``.

    ``log_prob(value)`` is the log of the summed weights of every location
    exactly equal to ``value`` (elementwise ``==``, no tolerance). It is -inf
    when no location matches, so it is only meaningful for values produced by
    ``sample`` or copied from ``locs``.
    """

    arg_constraints = {}
    support = torch.distributions.constraints.real
    has_rsample = False

    def __init__(
        self,
        group: str,
        locs: torch.Tensor,
        weights: torch.Tensor = None,
        validate_args=None,
    ):
        """
        Args:
            group: Free-form label stored as ``self.group``.
            locs: [K, *event_shape] tensor of possible outcomes.
            weights: [K] non-negative mixture weights, normalized here to sum
                to 1. If None, the mixture is uniform.
            validate_args: Forwarded to ``td.Distribution``.
        """
        self.group = group

        self.locs = locs  # [K, *event_shape]
        self.K = locs.shape[0]

        if weights is None:
            weights = torch.ones(self.K, device=locs.device)
        else:
            weights = torch.as_tensor(weights, device=locs.device)
        # Always normalize. Comparing a float sum to 1 exactly is unreliable,
        # and dividing by a sum that is exactly 1 leaves the weights unchanged.
        self.weights = weights / weights.sum()  # [K]
        self.categorical = td.Categorical(probs=self.weights)

        super().__init__(
            batch_shape=torch.Size(),
            event_shape=locs.shape[1:],
            validate_args=validate_args,
        )

    def sample(self, sample_shape=torch.Size()):
        """Return ``locs[k]`` with k ~ Categorical(weights); shape [*sample_shape, *event_shape]."""
        indices = self.categorical.sample(sample_shape)  # [*sample_shape]
        return self.locs[indices]

    def log_prob(self, value):
        """
        Args:
            value: [..., *event_shape] tensor.

        Returns:
            [...] tensor: log of the summed weights of the locations equal to
            ``value``; -inf where no location matches.
        """
        n_event = len(self.event_shape)

        # Insert the component axis just before the event dims so that
        # [..., 1, *event_shape] broadcasts against locs [K, *event_shape].
        value_exp = value.unsqueeze(-n_event - 1)
        matches = value_exp == self.locs  # [..., K, *event_shape]
        if n_event > 0:
            # A location matches only if every event entry is equal.
            matches = matches.flatten(start_dim=-n_event).all(dim=-1)  # [..., K]

        log_weights = torch.log(self.weights)  # [K]
        log_probs = torch.where(
            matches, log_weights, log_weights.new_tensor(float("-inf"))
        )  # [..., K]
        return torch.logsumexp(log_probs, dim=-1)  # sum over K


class RandomDiscreteDeltaMixture(DiscreteDeltaMixture):
    """
    Uniformly weighted delta mixture whose modes default to random rotations.

    When ``locs`` is omitted, ``num_modes`` 4-vectors are drawn from a standard
    normal, normalized to unit quaternions and converted to rotation matrices
    with ``quaternion_to_matrix``.
    """

    def __init__(
        self,
        num_modes: int,
        locs: torch.Tensor = None,
        validate_args=None,
    ):
        if locs is None:
            raw = torch.randn(num_modes, 4)
            unit_quats = raw / raw.norm(dim=-1, keepdim=True)
            locs = quaternion_to_matrix(unit_quats)  # rotation matrices, [K, 3, 3]

        label = f"{num_modes} modes on SO(3)"
        # weights=None -> DiscreteDeltaMixture uses uniform weights.
        super().__init__(
            group=label,
            locs=locs,
            weights=None,
            validate_args=validate_args,
        )


class GaussianMixture(td.Distribution):
    """
    Mixture of isotropic Gaussians centred at discrete locations.

    Sampling picks a component ``k`` with probability ``weights[k]``, then
    returns ``locs[k]`` plus independent ``Normal(0, scale)`` noise on every
    event entry. For matrix-valued ``locs`` (e.g. rotations) the samples live
    in the ambient space and are generally not on the group.
    """

    arg_constraints = {}
    support = torch.distributions.constraints.real
    has_rsample = False

    def __init__(
        self,
        group: str,
        locs: torch.Tensor,
        scale: float = 0.1,
        weights: torch.Tensor = None,
        validate_args=None,
    ):
        """
        Args:
            group: Free-form label stored as ``self.group``.
            locs: [K, *event_shape] tensor of component means.
            scale: Standard deviation. Either a float or a tensor
                broadcastable to the shape of ``locs``.
            weights: [K] non-negative mixture weights, normalized here to sum
                to 1. If None, the mixture is uniform.
            validate_args: Forwarded to ``td.Distribution``.
        """
        self.group = group

        self.locs = locs  # [K, *event_shape]
        self.scale = scale * torch.ones_like(locs)  # [K, *event_shape]
        self.K = locs.shape[0]

        if weights is None:
            weights = torch.ones(self.K, device=locs.device)
        else:
            weights = torch.as_tensor(weights, device=locs.device)
        # Always normalize. Comparing a float sum to 1 exactly is unreliable,
        # and dividing by a sum that is exactly 1 leaves the weights unchanged.
        self.weights = weights / weights.sum()  # [K]

        # Every axis after the component axis belongs to the event. This works
        # for scalar, vector and matrix locations alike.
        event_shape = locs.shape[1:]

        mix = td.Categorical(probs=self.weights)
        # Reinterpret the event axes so each component has batch shape [K].
        comp = td.Independent(td.Normal(self.locs, self.scale), len(event_shape))
        self.dist = td.MixtureSameFamily(mix, comp)

        super().__init__(
            batch_shape=torch.Size(),
            event_shape=event_shape,
            validate_args=validate_args,
        )

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape [*sample_shape, *event_shape]."""
        return self.dist.sample(sample_shape)

    def log_prob(self, value):
        """Mixture log-density of ``value`` ([..., *event_shape]) -> [...]."""
        return self.dist.log_prob(value)


class RandomGaussianMixture(GaussianMixture):
    """
    Uniformly weighted Gaussian mixture (scale 0.05) whose centres default to
    random rotation matrices.

    When ``locs`` is omitted, ``num_modes`` 4-vectors are drawn from a standard
    normal, normalized to unit quaternions and converted with
    ``quaternion_to_matrix``.
    """

    def __init__(
        self,
        num_modes: int,
        locs: torch.Tensor = None,
        validate_args=None,
    ):
        if locs is None:
            raw = torch.randn(num_modes, 4)
            unit_quats = raw / raw.norm(dim=-1, keepdim=True)
            locs = quaternion_to_matrix(unit_quats)  # rotation matrices, [K, 3, 3]

        label = f"{num_modes} gaussian mixture on SO(3)"
        # weights=None -> GaussianMixture uses uniform weights.
        super().__init__(
            group=label,
            locs=locs,
            scale=0.05,
            weights=None,
            validate_args=validate_args,
        )


class SO2PushforwardDistribution(td.Distribution):
    """
    Distribution on SO(2) parameterized by the rotation angle.

    A sample is an angle ``theta ~ coeff_dist``. It is returned as-is
    ("angle"), as ``angle_to_vector(theta)`` ("vector") or as
    ``angle_to_matrix(theta)`` ("matrix").
    """

    arg_constraints = {}
    support = td.constraints.real  # Output is real matrices
    has_rsample = False

    def __init__(
        self,
        representation="matrix",
        coeff_dist=None,
        validate_args=False,
    ):
        """
        Pushforward distribution on SO(2) via the exponential map.

        Args:
            representation: "matrix" (event [2, 2]), "vector" (event [2]) or
                "angle" (event [1]).
            coeff_dist: Distribution over the scalar coefficient of the so(2)
                generator [[0, -1], [1, 0]], i.e. the rotation angle.
                Default: Independent(Uniform(-pi, pi)) with event shape [1].
                ``to`` expects ``base_dist.low``/``base_dist.high`` attributes.
            validate_args: Forwarded to ``td.Distribution``.

        Raises:
            NotImplementedError: for an unknown ``representation``.
        """
        self.group = "SO(2)"
        self.representation = representation
        if representation == "matrix":
            event_shape = (2, 2)
        elif representation == "vector":
            event_shape = (2,)
        elif representation == "angle":
            event_shape = (1,)
        else:
            raise NotImplementedError(
                f"Unsupported representation: {representation!r}"
            )

        super().__init__(
            batch_shape=torch.Size(),
            event_shape=event_shape,
            validate_args=validate_args,
        )

        # Generator of so(2) (kept for reference / device moves).
        self.basis = torch.tensor(
            [[0, -1], [1, 0]],
            dtype=torch.float32,
        )  # shape [2, 2]

        if coeff_dist is None:
            low = torch.tensor([-torch.pi])
            high = torch.tensor([torch.pi])
            self.coeff_dist = td.Independent(td.Uniform(low, high), 1)
        else:
            self.coeff_dist = coeff_dist

    def to(self, device):
        """Move the generator and the angle bounds to ``device`` in place; returns ``self``."""
        self.basis = self.basis.to(device)
        self.coeff_dist.base_dist.low = self.coeff_dist.base_dist.low.to(device)
        self.coeff_dist.base_dist.high = self.coeff_dist.base_dist.high.to(device)
        return self

    def sample(self, sample_shape=torch.Size()):
        """Sample angles from ``coeff_dist`` and convert them to ``self.representation``."""
        theta = self.coeff_dist.sample(sample_shape)
        if self.representation == "angle":
            return theta
        if self.representation == "vector":
            return angle_to_vector(theta)
        if self.representation == "matrix":
            return angle_to_matrix(theta)
        raise NotImplementedError(
            f"Unsupported representation: {self.representation!r}"
        )

    def log_prob(self, x):
        """
        Evaluate log-probability under pushforward from angle base distribution.

        ``x`` is first mapped back to an angle. The result is the angle density
        ``coeff_dist.log_prob(theta)``; no change-of-variables term is added.
        """
        if self.representation == "angle":
            theta = x
        elif self.representation == "vector":
            theta = vector_to_angle(x)
        elif self.representation == "matrix":
            theta = matrix_to_angle(x)
        else:
            raise NotImplementedError(
                f"Unsupported representation: {self.representation!r}"
            )

        return self.coeff_dist.log_prob(theta)


class GL2PlusPushforwardDistribution(td.Distribution):
    """
    Distribution on GL⁺(2) obtained by pushing Lie-algebra coefficients
    through the matrix exponential.

    A sample is ``expm(sum_i c_i E_i)`` with ``c ~ coeff_dist`` and
    ``(E_11, E_12, E_21, E_22)`` the standard basis of gl(2, R). Since
    ``det(expm(A)) = exp(tr A) > 0``, every sample lies in GL⁺(2). When
    ``det_range`` is given, samples are further restricted to
    ``det_range[0] <= det(X) <= det_range[1]`` by rejection sampling.
    """

    arg_constraints = {}
    support = td.constraints.real
    has_rsample = False

    def __init__(
        self,
        coeff_dist=None,
        det_range=None,  # None means no restriction on determinants
        validate_args=False,
        estimate_acceptance_rate=True,
        n_samples_for_estimation=10_000,
    ):
        """
        Args:
            coeff_dist: Distribution over the 4 coefficients of gl(2) in the
                standard basis. Default: independent Uniform(-pi/2, pi/2) on
                each coefficient. ``to`` expects ``base_dist.low`` and
                ``base_dist.high`` attributes (as on ``Independent(Uniform)``).
            det_range: Optional ``(min, max)`` bounds on ``det(X)``, enforced by
                rejection sampling. None only requires ``det(X) > 0``.
            validate_args: Forwarded to ``td.Distribution``.
            estimate_acceptance_rate: If True and ``det_range`` is set, estimate
                the rejection-sampling acceptance rate by Monte Carlo. Its log
                is subtracted in ``log_prob`` to renormalize the density.
            n_samples_for_estimation: Monte Carlo sample count for that estimate.
        """
        self.group = "GL(2)"
        self.det_range = det_range

        super().__init__(
            batch_shape=torch.Size(),
            event_shape=(2, 2),
            validate_args=validate_args,
        )

        # Standard basis for gl(2, R). It is orthonormal under the Frobenius
        # inner product, so log_prob recovers coefficients as <A, E_k>_F.
        self.basis = torch.tensor(
            [
                [[1, 0], [0, 0]],  # E_11
                [[0, 1], [0, 0]],  # E_12
                [[0, 0], [1, 0]],  # E_21
                [[0, 0], [0, 1]],  # E_22
            ],
            dtype=torch.float32,
        )  # shape [4, 2, 2]

        if coeff_dist is None:
            low = torch.full((4,), -torch.pi / 2)
            high = torch.full((4,), torch.pi / 2)
            self.coeff_dist = td.Independent(td.Uniform(low, high), 1)  # [4]
        else:
            self.coeff_dist = coeff_dist

        if self.det_range is not None and estimate_acceptance_rate:
            self._estimate_acceptance_rate(n_samples_for_estimation)
        else:
            # log(1) = 0: log_prob applies no renormalization.
            self.log_acceptance_rate = torch.tensor(0.0)

    def to(self, device):
        """Move the basis, the coefficient bounds and the acceptance-rate
        estimate to ``device`` in place; returns ``self``."""
        self.basis = self.basis.to(device)
        self.coeff_dist.base_dist.low = self.coeff_dist.base_dist.low.to(device)
        self.coeff_dist.base_dist.high = self.coeff_dist.base_dist.high.to(device)
        self.log_acceptance_rate = self.log_acceptance_rate.to(device)
        return self

    def _estimate_acceptance_rate(self, n_samples):
        """Monte Carlo estimate of the rejection-sampling acceptance rate.

        Stores its log in ``self.log_acceptance_rate``.
        """
        with torch.no_grad():
            # Sample many matrices from the unrestricted distribution.
            coeffs = self.coeff_dist.sample((n_samples,))
            X = expm(self._coeffs_to_matrix(coeffs))

            acceptance_rate = self._is_valid_det(X).float().mean()

            # Small epsilon avoids log(0) when nothing was accepted.
            eps = 1e-8
            self.log_acceptance_rate = torch.log(acceptance_rate + eps)

            print(f"Estimated acceptance rate: {acceptance_rate.item():.4f}")

    def _coeffs_to_matrix(self, coeffs):
        """Map coefficients [..., 4] to Lie-algebra matrices [..., 2, 2]."""
        return torch.einsum("...i,ijk->...jk", coeffs, self.basis)

    def _is_valid_det(self, X):
        """Boolean mask [...] of matrices inside the support.

        The support is det > 0, or det within ``det_range`` when one is set.
        """
        det = torch.linalg.det(X)
        if self.det_range is None:
            return det > 0
        low, high = self.det_range
        return (det >= low) & (det <= high)

    def sample(self, sample_shape=torch.Size(), max_attempts=1000):
        """Sample from GL⁺(2), rejection-sampling against ``det_range`` if set.

        If some samples are still missing after ``max_attempts`` rounds, a
        warning is printed and those entries are filled with the identity.
        """
        sample_shape = torch.Size(sample_shape)

        if self.det_range is None:
            # No restriction on determinants: push coefficients through exp.
            coeffs = self.coeff_dist.sample(sample_shape)  # [*sample_shape, 4]
            A = self._coeffs_to_matrix(coeffs)  # [*sample_shape, 2, 2]
            return expm(A)  # [*sample_shape, 2, 2]

        # Rejection sampling to fit the determinant range.
        total_shape = sample_shape + self.event_shape
        batch_size = sample_shape.numel()

        device = self.basis.device
        dtype = self.basis.dtype
        result = torch.empty(total_shape, device=device, dtype=dtype)

        # Flat view over the output plus a mask of slots still to be filled.
        remaining_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
        flat_result = result.view(batch_size, *self.event_shape)

        attempts = 0
        while remaining_mask.any() and attempts < max_attempts:
            n_remaining = remaining_mask.sum().item()
            coeffs = self.coeff_dist.sample((n_remaining,))
            X_candidates = expm(self._coeffs_to_matrix(coeffs))

            valid = self._is_valid_det(X_candidates)
            if valid.any():
                remaining_indices = torch.where(remaining_mask)[0]
                valid_indices = torch.where(valid)[0]

                # Use as many accepted candidates as there are open slots.
                n_to_fill = min(len(valid_indices), len(remaining_indices))
                flat_result[remaining_indices[:n_to_fill]] = X_candidates[
                    valid_indices[:n_to_fill]
                ]
                remaining_mask[remaining_indices[:n_to_fill]] = False

            attempts += 1

        if remaining_mask.any():
            print(
                "Warning: Could not generate all samples after"
                f" {max_attempts} attempts. {remaining_mask.sum().item()} samples"
                " remaining."
            )
            # Fallback: identity matrices. NOTE: det(I) = 1 may lie outside
            # det_range, so these entries are not guaranteed to be in the support.
            I = torch.eye(2, device=device, dtype=dtype)
            flat_result[remaining_mask] = I

        return result

    def project_to_manifold(self, X: torch.Tensor, eps=1e-6) -> torch.Tensor:
        """Project matrices to GL⁺(2) by ensuring positive determinant.

        Matrices with det < 0 get their first column negated. Matrices with
        |det| < eps afterwards get ``eps * I`` added. The input is not modified.
        """
        det = torch.linalg.det(X)

        # For negative determinants, flip the sign of one column.
        negative_mask = det < 0
        if negative_mask.any():
            X = X.clone()
            X[negative_mask, :, 0] *= -1

        # For near-zero determinants, add a small multiple of the identity.
        small_det_mask = torch.abs(torch.linalg.det(X)) < eps
        if small_det_mask.any():
            X = X.clone()
            I = torch.eye(2, device=X.device, dtype=X.dtype)
            X[small_det_mask] += eps * I

        return X

    def log_prob(self, X):
        """
        Log-density of X under the pushforward of ``coeff_dist`` through exp.

        Uses the change of variables
        ``log p(X) = log p_c(c) - log|det J_exp(A)|``. Here ``A = blogm(X)``
        (presumably the principal matrix logarithm) and ``c`` are the
        coordinates of ``A`` in the standard basis. ``self.log_acceptance_rate``
        is also subtracted; it is 0 unless an acceptance rate was estimated for
        a ``det_range``.

        Args:
            X: Tensor of shape [..., 2, 2]

        Returns:
            Tensor of shape [...]. It is -inf where X is outside the support
            (det(X) <= 0, or det(X) outside ``det_range`` when set).
        """
        batch_shape = X.shape[:-2]
        device = X.device
        dtype = X.dtype

        valid_mask = self._is_valid_det(X)

        result = torch.full(batch_shape, float("-inf"), device=device, dtype=dtype)
        if not valid_mask.any():
            return result  # all invalid

        X_valid = X[valid_mask]  # [N_valid, 2, 2]
        basis = self.basis.to(device=device, dtype=dtype)  # [4, 2, 2]

        # Matrix logarithm, then Frobenius projection onto the orthonormal basis.
        A = blogm(X_valid)  # [N_valid, 2, 2]
        coeffs = torch.einsum("...ij,kij->...k", A, basis)  # [N_valid, 4]

        # Clamp onto the base distribution's box so that round-off in blogm
        # cannot push in-support matrices slightly outside it. All four
        # coefficients are treated alike: in the standard basis none of them is
        # an angle, so none is wrapped. NOTE: torch's Uniform excludes ``high``,
        # so values clamped onto the upper bound still score -inf.
        base = getattr(self.coeff_dist, "base_dist", None)
        if hasattr(base, "low") and hasattr(base, "high"):
            low = base.low.to(device=device, dtype=coeffs.dtype)
            high = base.high.to(device=device, dtype=coeffs.dtype)
            coeffs = coeffs.clamp(low, high)
        log_p_base = self.coeff_dist.log_prob(coeffs)  # [N_valid]

        coeffs_flat = coeffs.detach().reshape(-1, 4)  # [N_valid, 4]
        basis_flat = basis.reshape(4, -1)  # [4, 4]; identity for the standard basis

        def summed_exp(c_flat):
            # Samples are independent, so the derivative of the batch-summed
            # output w.r.t. sample m is sample m's own Jacobian. That makes the
            # Jacobian [4, N_valid, 4] instead of [N_valid, 4, N_valid, 4].
            A_batch = (c_flat @ basis_flat).view(-1, 2, 2)  # [N_valid, 2, 2]
            return expm(A_batch).reshape(-1, 4).sum(dim=0)  # [4]

        jac = torch.autograd.functional.jacobian(
            summed_exp, coeffs_flat, create_graph=False, vectorize=True
        )  # [4, N_valid, 4]
        J = jac.permute(1, 0, 2)  # [N_valid, 4 (output), 4 (input)]

        # For 2x2 matrices det(d exp_A) = e^{2 tr A} * (sinh(d/2) / (d/2))^2 >= 0,
        # with d the eigenvalue gap of A. So the determinant of J itself is
        # needed, not the determinant of its elementwise absolute value.
        _, log_abs_det = torch.linalg.slogdet(J)  # [N_valid]

        # log_acceptance_rate is 0 unless a det_range acceptance rate was
        # estimated, so one expression covers both cases.
        result_flat = log_p_base - log_abs_det - self.log_acceptance_rate
        result[valid_mask] = result_flat.to(dtype)

        return result


class GL2ComplexPushforwardDistribution(td.Distribution):
    """
    Distribution on GL(2, C) obtained by pushing complex Lie-algebra
    coefficients through the matrix exponential.

    A sample is ``expm(sum_i z_i E_i)``. Here ``E_i`` is the standard basis of
    gl(2, C), and ``z_i = re_i + 1j * im_i`` with ``(re_i, im_i)`` taken from
    ``coeff_dist`` samples of shape [..., 4, 2]. When ``det_range`` is given,
    samples are restricted to ``det_range[0] <= |det X| <= det_range[1]`` by
    rejection sampling. No ``log_prob`` is implemented.
    """

    arg_constraints = {}
    support = td.constraints.real
    has_rsample = False

    def __init__(
        self,
        coeff_dist=None,
        det_range=None,  # None means no restriction on determinants
        dtype=torch.complex64,
        validate_args=False,
        estimate_acceptance_rate=True,
        n_samples_for_estimation=10_000,
    ):
        """
        Pushforward distribution on GL(2,C) via exp from a uniform on Lie algebra coefficients.

        Args:
            coeff_dist: Distribution over the (real, imaginary) parts of the 4
                coefficients of gl(2, C), event shape [4, 2]. Default:
                independent Uniform(-pi/2, pi/2) on all 8 real parameters.
                ``to`` expects ``base_dist.low``/``base_dist.high`` attributes.
            det_range: Optional ``(min, max)`` bounds on ``|det(X)|``, enforced
                by rejection sampling.
            dtype: Complex dtype of the basis and of the samples.
            validate_args: Forwarded to ``td.Distribution``.
            estimate_acceptance_rate: If True and ``det_range`` is set, estimate
                the rejection-sampling acceptance rate by Monte Carlo.
            n_samples_for_estimation: Monte Carlo sample count for that estimate.
        """
        self.group = "GL(2,C)"
        self.det_range = det_range

        super().__init__(
            batch_shape=torch.Size(),
            event_shape=(2, 2),
            validate_args=validate_args,
        )

        # Standard basis for gl(2,C) - 4 complex matrices
        self.basis = torch.tensor(
            [
                [[1, 0], [0, 0]],  # E_11
                [[0, 1], [0, 0]],  # E_12
                [[0, 0], [1, 0]],  # E_21
                [[0, 0], [0, 1]],  # E_22
            ],
            dtype=dtype,
        )

        if coeff_dist is None:
            # 4 complex coefficients = 8 real parameters, laid out as [4, 2].
            low = torch.full((4, 2), -torch.pi / 2, dtype=torch.float32)
            high = torch.full((4, 2), torch.pi / 2, dtype=torch.float32)
            self.coeff_dist = td.Independent(td.Uniform(low, high), 2)  # [4, 2]
        else:
            self.coeff_dist = coeff_dist

        if self.det_range is not None and estimate_acceptance_rate:
            self._estimate_acceptance_rate(n_samples_for_estimation)
        else:
            self.log_acceptance_rate = torch.tensor(0.0, dtype=torch.float32)

    def to(self, device):
        """Move the basis, the coefficient bounds and the acceptance-rate
        estimate to ``device`` in place; returns ``self``."""
        self.basis = self.basis.to(device)
        self.coeff_dist.base_dist.low = self.coeff_dist.base_dist.low.to(device)
        self.coeff_dist.base_dist.high = self.coeff_dist.base_dist.high.to(device)
        self.log_acceptance_rate = self.log_acceptance_rate.to(device)
        return self

    def _estimate_acceptance_rate(self, n_samples):
        """Monte Carlo estimate of the rejection-sampling acceptance rate.

        Stores its log in ``self.log_acceptance_rate``.
        """
        with torch.no_grad():
            coeffs = self.coeff_dist.sample((n_samples,))
            X = expm(self._coeffs_to_matrix(coeffs))

            acceptance_rate = self._is_valid_det(X).float().mean()

            # Small epsilon avoids log(0) when nothing was accepted.
            eps = 1e-8
            self.log_acceptance_rate = torch.log(acceptance_rate + eps)

            print(f"Estimated acceptance rate: {acceptance_rate.item():.4f}")

    def _coeffs_to_matrix(self, coeffs):
        """Map real/imag coefficient pairs [..., 4, 2] to complex matrices [..., 2, 2]."""
        complex_coeffs = torch.complex(coeffs[..., 0], coeffs[..., 1])  # [..., 4]
        # torch.complex yields complex64 for float32 parts. Match the basis
        # dtype (e.g. complex128) so the einsum does not mix precisions.
        complex_coeffs = complex_coeffs.to(self.basis.dtype)
        return torch.einsum("...i,ijk->...jk", complex_coeffs, self.basis)

    def _is_valid_det(self, X):
        """Boolean mask [...]: |det X| > 0, or |det X| within ``det_range`` if set."""
        abs_det = torch.abs(torch.linalg.det(X))
        if self.det_range is None:
            return abs_det > 0
        low, high = self.det_range
        return (abs_det >= low) & (abs_det <= high)

    def sample(self, sample_shape=torch.Size(), max_attempts=1000):
        """Sample from GL(2, C), rejection-sampling against ``det_range`` if set.

        If some samples are still missing after ``max_attempts`` rounds, a
        warning is printed and those entries are filled with the identity.
        """
        sample_shape = torch.Size(sample_shape)

        if self.det_range is None:
            coeffs = self.coeff_dist.sample(sample_shape)  # [*sample_shape, 4, 2]
            A = self._coeffs_to_matrix(coeffs)  # [*sample_shape, 2, 2]
            return expm(A)  # [*sample_shape, 2, 2]

        total_shape = sample_shape + self.event_shape
        batch_size = sample_shape.numel()

        device = self.basis.device
        dtype = self.basis.dtype
        result = torch.empty(total_shape, device=device, dtype=dtype)

        # Flat view over the output plus a mask of slots still to be filled.
        remaining_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
        flat_result = result.view(batch_size, *self.event_shape)

        attempts = 0
        while remaining_mask.any() and attempts < max_attempts:
            n_remaining = remaining_mask.sum().item()
            coeffs = self.coeff_dist.sample((n_remaining,))
            X_candidates = expm(self._coeffs_to_matrix(coeffs))

            valid = self._is_valid_det(X_candidates)
            if valid.any():
                remaining_indices = torch.where(remaining_mask)[0]
                valid_indices = torch.where(valid)[0]

                # Use as many accepted candidates as there are open slots.
                n_to_fill = min(len(valid_indices), len(remaining_indices))
                flat_result[remaining_indices[:n_to_fill]] = X_candidates[
                    valid_indices[:n_to_fill]
                ]
                remaining_mask[remaining_indices[:n_to_fill]] = False

            attempts += 1

        if remaining_mask.any():
            print(
                "Warning: Could not generate all samples after"
                f" {max_attempts} attempts. {remaining_mask.sum().item()} samples"
                " remaining."
            )
            # Fallback: identity matrices. NOTE: |det(I)| = 1 may lie outside
            # det_range.
            I = torch.eye(2, device=device, dtype=dtype)
            flat_result[remaining_mask] = I

        return result


class SO3UniformDistribution(td.Distribution):
    """
    Uniform distribution on SO(3) as 3x3 rotation matrices.

    Samples normalize a standard-normal 4-vector to a unit quaternion and
    convert it with ``quaternion_to_matrix``. ``rsample`` does the same with a
    reparameterized normal draw.
    """

    arg_constraints = {}
    support = td.constraints.real  # Output is real matrices
    has_rsample = True  # rsample below is implemented via Normal.rsample

    def __init__(
        self,
        validate_args=False,
    ):
        """
        Use Gaussian normalization over quaternions to create uniform SO(3) distribution

        Args:
            validate_args: Forwarded to ``td.Distribution``.
        """
        self.group = "SO(3)"

        super().__init__(
            batch_shape=torch.Size(),
            event_shape=(3, 3),
            validate_args=validate_args,
        )

        # Standard normal in R^4 for quaternions before normalization.
        self.coeff_dist = td.Independent(td.Normal(torch.zeros(4), torch.ones(4)), 1)

    def to(self, device):
        """Rebuild the quaternion base distribution on ``device``; returns ``self``."""
        self.coeff_dist = td.Independent(
            td.Normal(torch.zeros(4, device=device), torch.ones(4, device=device)), 1
        )
        return self

    def sample(self, sample_shape=torch.Size()):
        """Draw rotation matrices of shape [*sample_shape, 3, 3]."""
        shape = sample_shape + self.batch_shape
        z = self.coeff_dist.sample(shape)  # Unnormalized quaternions [..., 4]
        q = z / z.norm(dim=-1, keepdim=True)  # Unit quaternions [..., 4]
        return quaternion_to_matrix(q)

    def rsample(self, sample_shape=torch.Size()):
        """Reparameterized version of ``sample``."""
        shape = sample_shape + self.batch_shape
        z = self.coeff_dist.rsample(shape)  # Unnormalized quaternions [..., 4]
        q = z / z.norm(dim=-1, keepdim=True)  # Unit quaternions [..., 4]
        return quaternion_to_matrix(q)

    def log_prob(self, value):
        """
        Compute log_prob on SO(3).

        A value counts as a rotation when, within tolerance 1e-5:
        - R^T R == I (max absolute deviation), and
        - det(R) == 1.
        Such values get the constant -log(8 * pi^2). Everything else gets -inf.

        Args:
            value: [..., 3, 3] tensor.

        Returns:
            [...] tensor of log-probabilities.

        Raises:
            ValueError: if the trailing dims are not (3, 3).
        """
        if value.shape[-2:] != (3, 3):
            raise ValueError(f"Expected (..., 3, 3) shape, got {value.shape}")

        # Orthogonality check: max |R^T R - I| over the 3x3 entries.
        RtR = value.transpose(-2, -1) @ value
        I = torch.eye(3, device=value.device, dtype=value.dtype)
        orthogonality_error = (RtR - I).abs().amax(dim=(-2, -1))

        # Determinant check (rules out reflections).
        det_error = (torch.linalg.det(value) - 1).abs()

        tol = 1e-5
        valid_mask = (orthogonality_error < tol) & (det_error < tol)  # [...]

        log_uniform_density = -math.log(8 * math.pi**2)  # ≈ -4.37

        out = torch.full(
            value.shape[:-2], float("-inf"), device=value.device, dtype=value.dtype
        )
        return out.masked_fill(valid_mask, log_uniform_density)


class SO3ZYZDistribution(td.Distribution):
    """
    SO(3) distribution built from independent uniform ZYZ Euler angles.

    Samples alpha, beta and gamma uniformly from their ranges and maps them
    through ``ZYZ_angles_to_matrix``.

    NOTE: ``log_prob`` returns the constant ``-log(volume of the angle box)``
    for every valid rotation matrix. It does not check whether the matrix's
    Euler angles fall inside the box, and this is a density in Euler-angle
    coordinates rather than with respect to the Haar measure on SO(3).
    """

    arg_constraints = {}
    support = td.constraints.real  # Output is real matrices
    has_rsample = False

    def __init__(
        self,
        alpha_range=(0, 2 * torch.pi),
        beta_range=(0, torch.pi),
        gamma_range=(0, 2 * torch.pi),
        validate_args=False,
    ):
        """
        SO(3) Distribution with constraints on ZYZ Euler angles

        Args:
            alpha_range: (low, high) for the first Z rotation angle.
            beta_range: (low, high) for the Y rotation angle.
            gamma_range: (low, high) for the second Z rotation angle.
            validate_args: Forwarded to ``td.Distribution``.
        """
        self.group = "SO(3)"

        # Stored for log_prob, which needs the box volume.
        self.alpha_range = alpha_range
        self.beta_range = beta_range
        self.gamma_range = gamma_range

        super().__init__(
            batch_shape=torch.Size(),
            event_shape=(3, 3),
            validate_args=validate_args,
        )

        # Uniform distributions for each Euler angle.
        self.alpha_dist = td.Uniform(*alpha_range)
        self.beta_dist = td.Uniform(*beta_range)
        self.gamma_dist = td.Uniform(*gamma_range)

    def to(self, device):
        """Move the angle bounds to ``device`` in place; returns ``self``."""
        for dist in (self.alpha_dist, self.beta_dist, self.gamma_dist):
            dist.low = dist.low.to(device)
            dist.high = dist.high.to(device)
        return self

    def _sample_angles(self, sample_shape=torch.Size()):
        """Draw independent (alpha, beta, gamma), each of shape ``sample_shape``."""
        alpha = self.alpha_dist.sample(sample_shape)
        beta = self.beta_dist.sample(sample_shape)
        gamma = self.gamma_dist.sample(sample_shape)
        return alpha, beta, gamma

    def sample(self, sample_shape=torch.Size()):
        """Draw rotation matrices from uniformly sampled ZYZ angles."""
        alpha, beta, gamma = self._sample_angles(sample_shape)
        return ZYZ_angles_to_matrix(alpha, beta, gamma)

    def log_prob(self, value):
        """
        Returns constant log-density over the Euler angle box,
        or -inf for invalid rotation matrices.

        Args:
            value: [..., 3, 3] tensor.

        Returns:
            [...] tensor.

        Raises:
            ValueError: if the trailing dims are not (3, 3).
        """
        if value.shape[-2:] != (3, 3):
            raise ValueError(f"Expected (..., 3, 3) shape, got {value.shape}")

        # Same rotation test as a uniform SO(3) check: R^T R ≈ I and det ≈ 1.
        RtR = value.transpose(-2, -1) @ value
        I = torch.eye(3, device=value.device, dtype=value.dtype)
        orthogonality_error = (RtR - I).abs().amax(dim=(-2, -1))
        det_error = (torch.linalg.det(value) - 1).abs()

        tol = 1e-5
        valid_mask = (orthogonality_error < tol) & (det_error < tol)

        volume = (
            (self.alpha_range[1] - self.alpha_range[0])
            * (self.beta_range[1] - self.beta_range[0])
            * (self.gamma_range[1] - self.gamma_range[0])
        )
        log_density = -math.log(float(volume))

        out = torch.full(
            value.shape[:-2], float("-inf"), device=value.device, dtype=value.dtype
        )
        return out.masked_fill(valid_mask, log_density)


class GL3ComplexPushforwardDistribution(td.Distribution):
    """
    Distribution on GL(3, C): push a coefficient distribution on the Lie
    algebra gl(3, C) forward through the matrix exponential.

    Each sample is expm(sum_k c_k E_k). Here E_k are the nine elementary
    matrix units and c_k are complex coefficients. Their real and imaginary
    parts are drawn from `coeff_dist`. If `det_range` is given, samples whose
    |det| falls outside it are rejected and redrawn.
    """

    arg_constraints = {}
    support = td.constraints.real
    has_rsample = False

    def __init__(
        self,
        coeff_dist=None,
        det_range=None,  # None means no restriction on determinants
        dtype=torch.complex64,
        validate_args=False,
        estimate_acceptance_rate=True,
        n_samples_for_estimation=10_000,
    ):
        """
        Pushforward distribution on GL(3,C) via exp from a distribution on Lie algebra coefficients.

        Args:
            coeff_dist: Distribution over the coefficients of gl(3, C). Its
                samples must have shape [..., 9, 2], holding the (real, imag)
                part of each of the 9 coefficients. Defaults to independent
                Uniform(-pi/2, pi/2) on all 18 reals.
            det_range: Optional (low, high) bounds on |det| enforced by
                rejection sampling. None means no restriction.
            dtype: Complex dtype of the basis and of the sampled matrices.
            validate_args: Forwarded to `td.Distribution`.
            estimate_acceptance_rate: If True and `det_range` is set, estimate
                the rejection-sampling acceptance rate at construction time.
            n_samples_for_estimation: Number of draws used for that estimate.
        """
        self.group = "GL(3,C)"
        self.det_range = det_range

        event_shape = (3, 3)
        super().__init__(
            batch_shape=torch.Size(),
            event_shape=event_shape,
            validate_args=validate_args,
        )

        # Standard basis for gl(3,C): the 9 matrix units E_11, E_12, ..., E_33.
        # Row k of eye(9) reshaped to 3x3 has its 1 at (k // 3, k % 3).
        self.basis = torch.eye(9, dtype=dtype).reshape(9, 3, 3)

        if coeff_dist is None:
            # 9 complex coefficients = 18 real parameters, laid out as [9, 2]
            # so that _coeffs_to_matrix can split them into (real, imag).
            low = torch.full((9, 2), -torch.pi / 2, dtype=torch.float32)
            high = torch.full((9, 2), torch.pi / 2, dtype=torch.float32)
            self.coeff_dist = td.Independent(td.Uniform(low, high), 2)  # [9, 2]
        else:
            self.coeff_dist = coeff_dist

        # Estimate acceptance rate if requested
        if self.det_range is not None and estimate_acceptance_rate:
            self._estimate_acceptance_rate(n_samples_for_estimation)
        else:
            self.log_acceptance_rate = torch.tensor(0.0, dtype=torch.float32)

    def to(self, device):
        """
        Move the basis, the coefficient bounds and the acceptance rate to `device`.

        Assumes `coeff_dist` exposes `base_dist.low` / `base_dist.high`, as the
        default `td.Independent(td.Uniform(...))` does. Returns `self`.
        """
        self.basis = self.basis.to(device)
        self.coeff_dist.base_dist.low = self.coeff_dist.base_dist.low.to(device)
        self.coeff_dist.base_dist.high = self.coeff_dist.base_dist.high.to(device)
        self.log_acceptance_rate = self.log_acceptance_rate.to(device)
        return self

    def _estimate_acceptance_rate(self, n_samples):
        """Estimate the acceptance rate for rejection sampling."""
        with torch.no_grad():
            # Sample many matrices from the unrestricted distribution
            coeffs = self.coeff_dist.sample((n_samples,))
            A = self._coeffs_to_matrix(coeffs)
            X = expm(A)

            # Check which ones satisfy the determinant constraint
            accepted = self._is_valid_det(X)

            # Estimate acceptance rate
            acceptance_rate = accepted.float().mean()

            # Add small epsilon to avoid log(0)
            eps = 1e-8
            self.log_acceptance_rate = torch.log(acceptance_rate + eps)

            print(f"Estimated acceptance rate: {acceptance_rate.item():.4f}")

    def _coeffs_to_matrix(self, coeffs):
        """Map real coefficients [..., 9, 2] to gl(3, C) matrices [..., 3, 3]."""
        # Cast to the basis dtype so e.g. float32 coefficients work with a
        # complex128 basis (einsum does not promote dtypes).
        complex_coeffs = torch.complex(coeffs[..., 0], coeffs[..., 1]).to(
            self.basis.dtype
        )  # [..., 9]
        return torch.einsum("...i,ijk->...jk", complex_coeffs, self.basis)

    def _is_valid_det(self, X):
        """Check if matrices have |det| in the allowed range (non-zero if unrestricted)."""
        if self.det_range is None:
            return torch.abs(torch.linalg.det(X)) > 0

        det = torch.abs(torch.linalg.det(X))
        return (det >= self.det_range[0]) & (det <= self.det_range[1])

    def sample(self, sample_shape=torch.Size(), max_attempts=1000):
        """
        Sample matrices from GL(3, C), of shape [*sample_shape, 3, 3].

        With `det_range` set, rejection sampling runs for at most
        `max_attempts` rounds. Any slots still unfilled afterwards are set to
        the identity and a warning is printed.
        """
        # Normalize so lists/tuples work with the torch.Size arithmetic below.
        sample_shape = torch.Size(sample_shape)

        if self.det_range is None:
            # No restriction on determinants, just sample from the Lie algebra
            coeffs = self.coeff_dist.sample(sample_shape)  # [*sample_shape, 9, 2]
            A = self._coeffs_to_matrix(coeffs)  # [*sample_shape, 3, 3]
            return expm(A)  # [*sample_shape, 3, 3]

        # Use rejection sampling to fit determinant range.
        total_shape = sample_shape + self.event_shape
        batch_size = sample_shape.numel()

        # Initialize output
        device = self.basis.device
        dtype = self.basis.dtype
        result = torch.empty(total_shape, device=device, dtype=dtype)

        # Track which samples still need to be filled; flat_result is a view,
        # so writes into it land in `result`.
        remaining_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
        flat_result = result.view(batch_size, *self.event_shape)

        attempts = 0
        while remaining_mask.any() and attempts < max_attempts:
            # Sample candidates for remaining positions
            n_remaining = remaining_mask.sum().item()
            coeffs = self.coeff_dist.sample((n_remaining,))
            A = self._coeffs_to_matrix(coeffs)
            X_candidates = expm(A)

            # Check which candidates are valid
            valid = self._is_valid_det(X_candidates)

            if valid.any():
                # Fill in valid samples
                remaining_indices = torch.where(remaining_mask)[0]
                valid_indices = torch.where(valid)[0]

                # Take as many valid samples as we can use
                n_to_fill = min(len(valid_indices), len(remaining_indices))
                flat_result[remaining_indices[:n_to_fill]] = X_candidates[
                    valid_indices[:n_to_fill]
                ]

                # Update remaining mask
                remaining_mask[remaining_indices[:n_to_fill]] = False

            attempts += 1

        if remaining_mask.any():
            print(
                "Warning: Could not generate all samples after"
                f" {max_attempts} attempts. {remaining_mask.sum().item()} samples"
                " remaining."
            )
            # Fill remaining with identity matrices
            I = torch.eye(3, device=device, dtype=dtype)
            flat_result[remaining_mask] = I

        return result


class ObjectTransformDistribution(td.Distribution):
    """
    Distribution over transformed copies of a fixed object.

    Draw a matrix R from `base_dist` and apply it to every D-dimensional point
    of `base_object` (coordinates along the last axis): y = R @ x.
    """

    arg_constraints = {}
    support = td.constraints.real
    has_rsample = False

    def __init__(
        self,
        base_dist,
        base_object,
        validate_args=False,
    ):
        """
        Args:
            base_dist: a distribution whose `sample(shape)` returns transformation
                matrices of shape [*shape, D, D].
            base_object: the object to transform. Its `.data` is copied into a
                tensor of shape [..., D].
        """
        self.base_dist = base_dist
        data = base_object.data
        if isinstance(data, torch.Tensor):
            # Same copy as torch.tensor(data), without the copy-construct warning.
            self.base_object = data.detach().clone()
        else:
            self.base_object = torch.tensor(data)

        super().__init__(
            batch_shape=torch.Size(),
            event_shape=self.base_object.shape,
            validate_args=validate_args,
        )

    def to(self, device):
        """Move the base object and base distribution to `device`; returns self."""
        self.base_object = self.base_object.to(device)
        self.base_dist.to(device)
        return self

    def sample(self, sample_shape=torch.Size(), return_transform=False):
        """
        Sample R ~ base_dist and return R @ x for every point x of the object.

        Args:
            sample_shape: leading shape of the draws.
            return_transform: if True, also return the sampled matrices R.

        Returns:
            y of shape [*sample_shape, *base_object.shape]; also R of shape
            [*sample_shape, D, D] when `return_transform` is True.
        """
        sample_shape = torch.Size(sample_shape)
        R = self.base_dist.sample(sample_shape)  # [*sample_shape, D, D]

        # Cast (e.g. real -> complex) and move the object to match R.
        x = self.base_object.to(device=R.device, dtype=R.dtype)  # [..., D]

        # Give R one singleton axis per non-coordinate axis of the object, so
        # it broadcasts over every point for objects of any rank, not just
        # [N, D]. Apply R (not its adjoint) to each column vector x.
        n_point_axes = x.dim() - 1
        R_b = R.reshape(tuple(R.shape[:-2]) + (1,) * n_point_axes + tuple(R.shape[-2:]))
        y = (R_b @ x.unsqueeze(-1)).squeeze(-1)  # [*sample_shape, ..., D]

        if return_transform:
            return y, R
        return y

    def log_prob(self, value: torch.Tensor):
        """
        Recover each transform by least squares and score it under base_dist.

        Args:
            value: [B, *base_object.shape], transformed objects

        Returns:
            base_dist.log_prob of the recovered [B, D, D] matrices
        """
        B = value.shape[0]
        D = self.event_shape[-1]
        value_flat = value.reshape(B, -1, D)  # [B, N, D]

        # Match value's device and dtype so lstsq sees consistent operands.
        x_flat = self.base_object.reshape(-1, D).to(
            device=value.device, dtype=value.dtype
        )
        x_flat = x_flat.unsqueeze(0).expand(B, -1, -1)  # [B, N, D]

        # Solve x @ R^T = value (row-vector form of y = R x)
        R_hat = torch.linalg.lstsq(x_flat, value_flat).solution
        R_hat = R_hat.transpose(-2, -1)  # [B, D, D]

        return self.base_dist.log_prob(R_hat)
