from typing import Literal

import geoopt
import numpy as np
import torch
from torch import Tensor
from .constraints import ConstrainedSet

__all__ = ["GaugeMap"]


class GaugeMap:

    def __init__(
            self,
            constraint: ConstrainedSet,
            ref_point: Tensor,
            dest: Literal["cube", "ball"] = "cube",
            manifold: geoopt.Manifold | None = None,
    ):
        """
        Gauge mapping between a (compact) convex constrained set and the unit ball.

        :param constraint: Constrained set.
        :param ref_point: Interior point of the constrained set,
            its feasibility is not checked, please make sure it is strictly feasible.
        """
        self.constraint = constraint
        self.x0 = ref_point.view(1, -1)
        self.dest = dest
        self.manifold = manifold
        match dest:
            case "cube":
                self.p = torch.inf
            case "ball":
                self.p = 2

    def log_map(self, xs: Tensor, ys: Tensor) -> Tensor:
        return (self.manifold.logmap(xs, ys) if self.manifold is not None
                else ys - xs)

    def exp_map(self, xs: Tensor, ys: Tensor) -> Tensor:
        return (self.manifold.expmap(xs, ys) if self.manifold is not None
                else xs + ys)

    def boundary_dist(
            self,
            x0: Tensor,
            us: Tensor,
            tol,
            thresh,
            device,
            **kwargs,
    ) -> Tensor:
        ds = (
            self.constraint.eval_manifold_intersection_v(x0, us, self.manifold, tol, thresh, device, **kwargs)
            if self.manifold is not None else
            self.constraint.eval_intersection_v(x0, us, tol, thresh, device)
        )
        return ds.view(-1, 1)

    def to_disk(
            self,
            xs: Tensor,
            tol: float = 1e-6,
            thresh: float = 1e8,
            device=torch.get_default_device(),
            **kwargs,
    ) -> Tensor:
        """
        Maps points on manifold into (scaled) unit ball.

        If the manifold is provided, the unit ball is inside the tangent space of the reference point

        This method implements the inverse (geodesic) gauge map:

        .. math::
            \Phi^{-1}(x) &= \gamma_\mathcal{U}(x, x^\circ) v_x \\
                         &= r_x v_x / d_\mathcal{U}(x^\circ, v_x) \\
                         &= u_x / d_\mathcal{U}(x^\circ, v_x) \\

        Where
            - :math:`\mathcal{U}` is a bounded (geodesic) convex set (over a manifold).
            - :math:`\gamma_\mathcal{U}` is the gauge function.
            - :math:`U_x = r_x v_x = \log_{x^\circ}(x)`
            - :math:`d_\mathcal{U}` is the distance to boundary function

        :param xs: Points on manifold to map, assumed to be already projected onto the manifold.
        :param tol: Tolerance.
        :param thresh: Threshold.
        :param device: Torch device.
        :param kwargs: Arguments passing to: method:`GaugeMap.boundary_dist`, :see:method:`ConstrainedSet.eval_manifold_intersection_v`.
        :return: Tensor of mapped points inside the unit ball in the tangent space.
        """
        x0 = self.x0.expand(xs.shape[0], -1)
        # tangent vectors at x0 to xs
        us = self.log_map(x0, xs)
        # distances from x0 to xs
        rs = torch.linalg.vector_norm(us, dim=-1, keepdim=True, ord=self.p).to(device)
        ii = torch.flatten(rs >= tol)
        # normalize the tangent vectors
        vs = us[ii] / rs[ii]
        zs = torch.zeros_like(xs)
        # point-to-boundary distance: d_\mathcal{U}(x^\circ, v_x)
        ds = self.boundary_dist(x0[ii], vs, tol, thresh, device, **kwargs)
        zs[ii, :] = us / ds

        return zs

    def from_disk(
            self,
            zs: Tensor,
            tol: float = 1e-6,
            thresh: float = 1e8,
            device=torch.get_default_device(),
            **kwargs,
    ) -> Tensor:
        """
        Maps points, which are in the (scaled) unit ball (in the tangent space of the reference point),
        to the (geodesic) convex set (over the manifold).

        If the manifold is provided, the unit ball is inside the tangent space of the reference point,
        and the mapped points are on the manifold.

        This method implements the (geodesic) gauge map:
        .. math::
                \Phi(z) &= \exp_{x^\circ}(d_\mathcal{U}(x^\circ, z/||z||) \cdot z) \\
                        &= \exp_{x^\circ}(||z|| d_\mathcal{U}(x^\circ, z) \cdot z) \\

        where
                - :math:`d_\mathcal{U}` is the distance to boundary function.
                - :math:`\exp_{x^\circ}` is the exponential map.

        :param zs: Points in the unit ball in the tangent space, not checked.
        :param tol: Tolerance.
        :param thresh: Threshold.
        :param device: Torch device.
        :return: Points on the manifold.
        """
        # z = x / ( |x| * phi(x) ) => |z| = 1 / phi(x)

        x0 = self.x0.expand(zs.shape[0], -1).to(device)

        # rs
        rs = torch.linalg.vector_norm(zs, dim=-1, keepdim=True, ord=self.p).to(device)

        ii = torch.flatten(rs >= tol)
        xs = torch.zeros_like(zs, device=device)
        xs[~ii] = x0[~ii]

        ds = self.boundary_dist(x0[ii], zs[ii], tol, thresh, device, **kwargs)
        xs[ii, :] = self.exp_map(x0[ii], ds * rs[ii] * zs[ii])

        return xs
    def mirror_forward_polytope(self, x: torch.Tensor, kappa=0.3, eps=2e-6) -> torch.Tensor:
        A = self.constraint.A_get
        b = self.constraint.b_get

        N, D = x.shape
        K = A.shape[0]
        device = x.device

        Ax = x @ A.T    # (N, K)
        b_expanded = b.view(1, K).expand(N, K).to(device)

        diff = b_expanded - Ax + eps   # (N, K)
        if torch.any(diff <= 0):
            raise ValueError("Some b_i - a_i^T x ≤ 0 — input too close to constraint boundary")

        weights = diff.pow(-kappa)     # (N, K)

        grad = weights @ A             # (N, D)

        return grad + x                # add x for mirror-like effect
    
    def mirror_forward(self, x: torch.Tensor, kappa=0.3, eps=1e-6) -> torch.Tensor:
        R = self.constraint.r  # from BallConstraint
        norm_sq = torch.sum(x ** 2, dim=1, keepdim=True)  # (N, 1)
        inside = torch.clamp(R**2 - norm_sq, min=eps)  # (N, 1)
        
        scale = 1 + 2 * inside.pow(-kappa)  # (N, 1)
        return x * scale  # (N, D)

    def mirror_backward(self, z: torch.Tensor, max_iter=10000, lr=1e-3, tol=1e-2, kappa=0.3, momentum=0.9, device='cpu'):
        R = self.constraint.r
        
        z = z.to(device).float()
        if z.ndim == 1:
            z = z.unsqueeze(0)

        x = torch.zeros_like(z)
        v = torch.zeros_like(z)  # velocity

        for i in range(max_iter):
            if i > 500:
                lr = 1e-4
            if i > 1000:
                lr = 1e-5
            x_lookahead = x + momentum * v               # (B, D)
            norm_sq = torch.sum(x_lookahead ** 2, dim=1, keepdim=True)                  # (B, K)
            diff = torch.clamp(R ** 2 - norm_sq, min=1e-6)     # (B, K)

            grad = z - x_lookahead  * (2 * diff.pow(-kappa) + 1)  # (B, D)

            grad_norm = grad.norm(dim=1).mean().item()
            if grad_norm < tol:
                if i % 100 == 0:
                    print(f"Converged at iteration {i}, grad norm = {grad_norm:.2e}")
                break

            v = momentum * v + lr * grad
            x = x + v

        return x.detach()
    def mirror_backward_polytope(self, z: torch.Tensor, max_iter=10000, lr=1e-3, tol=1e-2, kappa=0.3, momentum=0.9, device='cpu'):
        A = self.constraint.A_get.to(device)             # (K, D)
        b = self.constraint.b_get.view(-1).to(device)    # (K,)
        
        z = z.to(device).float()
        if z.ndim == 1:
            z = z.unsqueeze(0)

        x = torch.zeros_like(z)
        v = torch.zeros_like(z)  # velocity

        for i in range(max_iter):
            if i > 500:
                lr = 1e-4
            if i > 1000:
                lr = 1e-5
            x_lookahead = x + momentum * v               # (B, D)
            dot_ax = x_lookahead @ A.T                   # (B, K)
            diff = torch.clamp(b - dot_ax, min=1e-6)     # (B, K)

            grad = z - x_lookahead - (diff.pow(-kappa) @ A)  # (B, D)

            grad_norm = grad.norm(dim=1).mean().item()
            if grad_norm < tol:
                if i % 100 == 0:
                    print(f"Converged at iteration {i}, grad norm = {grad_norm:.2e}")
                break

            v = momentum * v + lr * grad
            x = x + v

        return x.detach()
"""
def to_dual(A, x, b, c, eps=1e-4, kappa=0.1, random_proj=True):
    assert A.shape[0] == x.shape[1], "Mismatch in dimensions: A should be (D, N), x should be (B, D)"
    (D, N), B = A.shape, x.shape[0]
    device = x.device

    # Ensure b and c are tensors on the correct device
    if not torch.is_tensor(b):
        b = torch.tensor(b, dtype=torch.float32, device=device)
    if not torch.is_tensor(c):
        c = torch.tensor(c, dtype=torch.float32, device=device)

    # Dot product: (B, D) x (D, N) = (B, N)
    dot_ax = x @ A                      # shape: (B, N)
    dot_ax_proj = project(dot_ax, b - eps, c + eps, random_proj=random_proj)
    
    diff = b - dot_ax_proj                  # shape: (B, N)
    diff2 = dot_ax_proj - c 

    # Ensure we're inside the feasible set (interior)
    if torch.any(diff + diff2 <= eps):
        raise ValueError("x too close to boundary or outside constraint set (b - A^T x ≤ eps)")
    weights = diff.pow(-kappa) - diff2.pow(-kappa)         # shape: (B, N)
    grad = weights @ A.T               # shape: (B, D)
    y = x + grad                       # shape: (B, D)
    return y


def to_primal(A, y, b, c, kappa=0.1, max_iter=100000, lr=1e-5, tol=1e-2, momentum=0.9):
    assert A.shape[0] == y.shape[1], "A should be (D, N), y should be (B, D)"
    B, D = y.shape
    N = A.shape[1]

    device = y.device
    x = torch.zeros_like(y)
    v = torch.zeros_like(y)

    for i in range(max_iter):
        x_lookahead = x + momentum * v  # lookahead point
        dot_ax = x_lookahead @ A        # shape: (B, N)
        diff = torch.clamp(b - dot_ax, min=1e-6)
        diff2 = torch.clamp(dot_ax - c, min=1e-6)
        # Gradient of objective: z - ∇ψ(x) = y - x - ∑ a_i * (b_i - a_i^T x)^(-kappa)
        grad = y - x_lookahead - (diff ** (-kappa)) @ A.T + (diff2 ** (-kappa)) @ A.T  # shape: (B, D)

        grad_norm = grad.norm(dim=1).mean().item()
        if grad_norm < tol:
            print(f"[to_primal_mirror] Converged at iter {i}, grad norm = {grad_norm:.2e}")
            break
        if i % 1000 == 0:
            print(f"Step {i} | grad_norm: {grad_norm:.6f}")
        v = momentum * v + lr * grad
        x = x + v

    return x  # shape: (B, D)
"""