import torch
from torch import Tensor
import math
from typing import Dict, List, Optional,Union
import numpy as np

def constrained_to_unconstrained(
    theta: Tensor,
    a: Union[float, Tensor],
    b: Union[float, Tensor],
    eps: float = 1e-6,
) -> Tensor:
    """
    Invert a sigmoid‐rescale mapping from ℝ to [a, b].

    Suppose a latent variable z ∈ ℝ is mapped to θ ∈ [a, b] via:
        θ = a + (b - a) * sigmoid(z).
    This function recovers z from θ using the logit transform.

    Parameters
    ----------
    theta : Tensor
        Value(s) in the interval [a, b]. After normalization to the unit
        interval, values are clamped to [eps, 1 - eps], so endpoints (and
        anything outside the interval) map to finite values.
    a : float or Tensor
        Lower bound of the interval.
    b : float or Tensor
        Upper bound of the interval. No check is made that b != a; equal
        bounds lead to a division by zero in the normalization step.
    eps : float, optional
        Clamp margin applied to the normalized value before the logit.
        Must satisfy 0 < eps < 0.5. The default (1e-6) matches the
        previously hard-coded value. NOTE: for low-precision dtypes such
        as float16, 1 - 1e-6 is not distinguishable from 1, so a larger
        margin should be passed to keep the result finite.

    Returns
    -------
    z : Tensor
        Unconstrained latent value(s) in ℝ such that
            theta ≈ a + (b - a) * sigmoid(z).

    Raises
    ------
    ValueError
        If ``eps`` is not strictly between 0 and 0.5.
    """
    if not 0.0 < eps < 0.5:
        raise ValueError(f"eps must lie in (0, 0.5), got {eps}")

    # 1) Normalize theta into the unit interval: y ∈ [0, 1]
    y = (theta - a) / (b - a)

    # 2) Clamp to [ε, 1−ε] to avoid logit(0) or logit(1) = ±∞
    y = torch.clamp(y, min=eps, max=1 - eps)

    # 3) Apply logit: z = log(y / (1 − y))
    return torch.log(y / (1 - y))


def get_grid_uniform(
    T: int,
    label_params: Dict[str, List[float]]
) -> torch.Tensor:
    """
    Generate a uniform grid of points over the 1D or 2D label range.

    This grid can be used for Riemann-sum approximations or for
    uniformly evaluating functions (e.g., GP posterior means).

    Args:
        T : int
            Desired number of grid points.
        label_params : dict
            Configuration dictionary with keys:
              - 'no_of_outputs': 1 or 2
              - 'dimension_1_range': [min, max] for first axis
              - 'dimension_2_range': [min, max] for second axis (if 2D)

    Returns:
        labels_uniform_grid : torch.Tensor
            Grid points in shape:
              - [T, 1] for 1D (both endpoints included)
              - [n_pts, 2] for 2D, where n_pts is only approximately T.
                The grid is NOT truncated, so n_pts may be somewhat
                smaller or larger than T (e.g. T=3 on a unit square
                yields a 2x2 grid of 4 points).

    Raises:
        ValueError: if 'no_of_outputs' is neither 1 nor 2.
    """
    # --- 0) Validate dimensionality up front so malformed configs fail loudly
    #        instead of silently being treated as 2D ---
    dims = label_params['no_of_outputs']
    if dims not in (1, 2):
        raise ValueError(f"Unsupported number of dimensions: {dims} (expected 1 or 2)")

    xmin, xmax = label_params['dimension_1_range']

    # --- 1) 1D: evenly spaced points along the single axis ---
    if dims == 1:
        # torch.linspace includes both endpoints; unsqueeze gives shape [T, 1]
        return torch.linspace(xmin, xmax, steps=T).unsqueeze(1)

    # --- 2) 2D: square lattice whose cell area is (total area / T) ---
    ymin, ymax = label_params['dimension_2_range']
    span_x = xmax - xmin
    span_y = ymax - ymin
    # spacing s chosen so that (#steps_x * #steps_y) ≈ T
    s = math.sqrt((span_x * span_y) / T)

    # Coordinate vectors start at the lower bound of each axis. The stop is
    # pulled in by half a step, so no point is generated within s/2 of the
    # upper bound.
    xv = torch.arange(xmin, xmax - s / 2, step=s)
    yv = torch.arange(ymin, ymax - s / 2, step=s)

    # Full meshgrid, then flatten into a list of (x, y) pairs: [n_pts, 2]
    X, Y = torch.meshgrid(xv, yv, indexing='xy')
    labels_uniform_grid = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=1)

    return labels_uniform_grid

def get_circular_error(theta_true: torch.Tensor, theta_est: torch.Tensor) -> torch.Tensor:
    """
    Return the unsigned angular distance between two sets of angles.

    The raw difference ``theta_true - theta_est`` is shifted by +pi, reduced
    modulo 2*pi, and shifted back by -pi, which folds it into the principal
    range around zero. The absolute value of that wrapped difference is
    returned, elementwise.
    """
    half_turn = np.pi
    full_turn = 2 * np.pi

    # Shift, reduce modulo a full turn, and shift back to center on zero.
    shifted = (theta_true - theta_est) + half_turn
    wrapped = shifted.remainder(full_turn) - half_turn

    # Sign is irrelevant for a distance.
    return wrapped.abs()

def get_circular_error_np(theta_true, theta_est):
    """
    Compute the absolute angular error between true and estimated angles.
    Handles wrap-around at the 2π boundary by mapping differences into [-π, π].

    Parameters
    ----------
    theta_true : array_like
        Ground-truth angles, in radians. Plain Python sequences (lists,
        tuples) and scalars are accepted in addition to ndarrays.
    theta_est : array_like
        Estimated angles, in radians (same shape as theta_true, or
        broadcastable to it).

    Returns
    -------
    abs_diff : ndarray
        Absolute angular distance between true and estimated angles,
        in radians, with values in [0, π].
    """
    # 1) Compute raw difference between true and estimated angles.
    #    np.subtract (rather than the `-` operator) is used so that plain
    #    lists/tuples work as the docstring's "array_like" promises.
    diff = np.subtract(theta_true, theta_est)

    # 2) Wrap differences into the principal interval [-π, π]:
    #    - Add π to shift the target range to [0, 2π)
    #    - Apply modulo 2π to wrap around
    #    - Subtract π to shift back so the range is centered on zero
    diff_wrapped = (diff + np.pi) % (2 * np.pi) - np.pi

    # 3) Take absolute value to get the unsigned angular distance
    abs_diff = np.abs(diff_wrapped)

    return abs_diff
