import math
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, List, Tuple


def safe_abs(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Smooth, everywhere-differentiable stand-in for ``|x|``.

    Computes ``sqrt(x**2 + eps)``. At ``x == 0`` the result is ``sqrt(eps)``
    rather than zero, and for other inputs it slightly overestimates ``|x|``.
    """
    squared = x * x
    return (squared + eps).sqrt()


def safe_pow(r: torch.Tensor, mu: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Compute ``r ** mu`` as ``exp(mu * log(r))`` with ``r`` floored at ``eps``.

    Flooring keeps ``log`` finite, so values of ``r`` at or below ``eps``
    behave as ``eps``. ``r`` and ``mu`` are combined with ordinary
    elementwise broadcasting.
    """
    log_r = r.clamp_min(eps).log()
    return (mu * log_r).exp()


def compute_bc_adaptive_bounds(omega: float, bc_type: str) -> Tuple[float, float]:
    """Heuristic ``(mu_min, mu_max)`` bounds for wedge exponents.

    The first two reference exponents are ``mu_1 = pi/omega`` and
    ``mu_2 = 2*pi/omega`` when ``bc_type`` (case-insensitive) is "DD" or "NN".
    For any other code they are ``mu_1 = 0.5*pi/omega`` and
    ``mu_2 = 1.5*pi/omega``.

    The bounds are:

    * ``mu_min = max(0.05, 0.3 * mu_1)``
    * ``mu_max = min(2.5 * mu_1, 1.2 * (mu_1 + mu_2) / 2)``, widened to at
      least ``mu_min + 0.5``.

    With these ratios ``mu_max`` stays below ``mu_2`` unless the 0.5-width
    floor kicks in, so the interval is centred on the leading exponent.

    Returns:
        ``(mu_min, mu_max)`` as Python floats.
    """
    # Integer multiples of pi/omega for matching sides, half-integer otherwise.
    first, second = (1.0, 2.0) if bc_type.upper() in ("DD", "NN") else (0.5, 1.5)
    lead = first * math.pi / omega
    next_up = second * math.pi / omega

    lower = max(0.05, lead * 0.3)
    upper = min(lead * 2.5, (lead + next_up) / 2 * 1.2)
    # Guarantee a usable interval width even for very wide wedges.
    upper = max(upper, lower + 0.5)

    return float(lower), float(upper)


def compute_init_exponents(
    K: int, omega: float, bc_type: str, include_fundamental: bool = True
) -> List[float]:
    """Return up to ``K`` sorted starting exponents inside the adaptive bounds.

    ``mu_1`` is ``pi/omega`` for "DD"/"NN" and ``0.5*pi/omega`` otherwise.
    ``mu_min``/``mu_max`` come from ``compute_bc_adaptive_bounds``.

    With ``include_fundamental`` the candidates are:

    * ``0.98 * mu_1`` (just below the leading exponent);
    * ``(K-1)//2`` values evenly spaced strictly between ``mu_min`` and
      ``mu_1``;
    * the remaining ``K-1 - (K-1)//2`` values evenly spaced strictly between
      ``mu_1`` and ``mu_max``.

    Otherwise ``K`` values are spaced linearly over
    ``[1.1*mu_min, 0.9*mu_max]``.

    Every candidate is clamped to ``[mu_min + 0.01, mu_max - 0.01]``, then
    the list is sorted and truncated to ``K`` entries.
    """
    code = bc_type.upper()
    if code in ("DD", "NN"):
        lead = math.pi / omega
    else:
        lead = 0.5 * math.pi / omega

    lo_bound, hi_bound = compute_bc_adaptive_bounds(omega, code)

    if include_fundamental:
        spare = K - 1
        n_low = spare // 2
        n_high = spare - n_low
        # For spare <= 0 both ranges below are empty, leaving only the
        # fundamental candidate.
        low_side = [
            lo_bound + (lead - lo_bound) * j / (n_low + 1)
            for j in range(1, n_low + 1)
        ]
        high_side = [
            lead + (hi_bound - lead) * j / (n_high + 1)
            for j in range(1, n_high + 1)
        ]
        candidates = [lead * 0.98, *low_side, *high_side]
    else:
        candidates = list(np.linspace(lo_bound * 1.1, hi_bound * 0.9, K))

    lower_clip = lo_bound + 0.01
    upper_clip = hi_bound - 0.01
    clamped = (max(lower_clip, min(upper_clip, m)) for m in candidates)
    return sorted(clamped)[:K]


class OrderedBoundedExponents(nn.Module):
    """Learnable vector of ``K`` exponents kept inside ``(mu_min, mu_max)``.

    Each exponent is ``mu_min + (mu_max - mu_min) * sigmoid(softness * raw)``
    for an unconstrained parameter ``raw``. ``forward`` returns the values
    sorted in ascending order.

    Args:
        K: Number of exponents.
        mu_min: Lower bound of the interval.
        mu_max: Upper bound; must be strictly greater than ``mu_min``.
        init: Optional initial exponent values (any array-like; flattened).
            If fewer than ``K`` are given, the rest are filled with values
            evenly spaced over ``[mu_min + 0.1, mu_max - 0.1]``. Extra
            entries are dropped. All values are clipped to
            ``[mu_min + 0.01, mu_max - 0.01]``, and their normalised position
            is clipped to ``[0.01, 0.99]``. Without ``init`` the normalised
            positions are evenly spaced over ``[0.2, 0.8]``.
        softness: Factor applied to ``raw`` before the sigmoid; must be
            non-zero.

    Raises:
        ValueError: If ``mu_max <= mu_min`` or ``softness == 0``.
    """

    def __init__(
        self,
        K: int,
        mu_min: float = 0.1,
        mu_max: float = 3.0,
        init: Optional[List[float]] = None,
        softness: float = 4.0,
    ):
        super().__init__()
        self.K = int(K)
        self.mu_min = float(mu_min)
        self.mu_max = float(mu_max)
        self.softness = float(softness)

        # Explicit raises instead of ``assert`` so validation survives
        # ``python -O``. Equal bounds would divide by zero in
        # _normalize_init, and zero softness gives infinite raw values.
        if not self.mu_max > self.mu_min:
            raise ValueError("mu_max must be greater than mu_min")
        if self.softness == 0:
            raise ValueError("softness must be non-zero")

        self.raw = nn.Parameter(torch.zeros(self.K))

        if init is not None:
            normalized = self._normalize_init(init)
        else:
            normalized = np.linspace(0.2, 0.8, self.K)

        # Inverse of forward(): raw = logit(normalized) / softness.
        raw_init = np.log(normalized / (1 - normalized)) / self.softness
        with torch.no_grad():
            self.raw.copy_(torch.as_tensor(raw_init, dtype=self.raw.dtype))

    def _normalize_init(self, init) -> np.ndarray:
        """Pad/trim ``init`` to ``K`` values and map them into ``[0.01, 0.99]``."""
        values = np.asarray(init, dtype=np.float32).ravel()
        if len(values) < self.K:
            fill = np.linspace(
                self.mu_min + 0.1, self.mu_max - 0.1, self.K - len(values)
            )
            values = np.concatenate([values, fill])
        values = values[: self.K]
        clipped = np.clip(values, self.mu_min + 0.01, self.mu_max - 0.01)
        normalized = (clipped - self.mu_min) / (self.mu_max - self.mu_min)
        # Keep away from 0/1 so the logit in __init__ stays finite.
        return np.clip(normalized, 0.01, 0.99)

    def forward(self) -> torch.Tensor:
        """Return the ``K`` exponents, sorted ascending.

        Because of the sort, position ``k`` always holds the k-th smallest
        exponent. If two ``raw`` entries cross during training, the exponents
        they produce swap positions.
        """
        normalized = torch.sigmoid(self.softness * self.raw)
        mu_unsorted = self.mu_min + (self.mu_max - self.mu_min) * normalized
        mu_sorted, _ = torch.sort(mu_unsorted)
        return mu_sorted


class MSNWedge2D(nn.Module):
    """Learnable singular expansion for a 2-D wedge of opening angle ``omega``.

    The model evaluates

        u(r, theta) = sum_k c_k * r**mu_k * phi(mu_k * theta)

    where:

    * ``phi`` is ``sin`` for ``bc_type`` "DD"/"DN" and ``cos`` for "NN"/"ND";
    * the exponents ``mu_k`` come from an ``OrderedBoundedExponents`` module;
    * the coefficients ``c_k`` are a parameter vector initialised to zero.

    NOTE(review): the two letters of ``bc_type`` appear to name the boundary
    condition on the rays ``theta = 0`` and ``theta = omega`` respectively.
    This matches the sin/cos choice and ``constraint_loss`` — confirm.

    Args:
        K: Number of expansion terms.
        omega: Wedge opening angle (radians, compared against ``atan2``
            angles).
        bc_type: One of "DD", "NN", "DN", "ND" (case-insensitive).
        mu_min, mu_max: Exponent bounds. Missing values come from
            ``compute_bc_adaptive_bounds`` when ``use_adaptive_bounds`` is
            true, otherwise from the defaults 0.1 / 4.0.
        init_mus: Initial exponents. Defaults to
            ``compute_init_exponents(K, omega, bc_type)``.
        use_adaptive_bounds: Selects how missing bounds are filled (see
            above).

    Raises:
        ValueError: If ``bc_type`` is not one of the four supported codes.
    """

    def __init__(
        self,
        K: int = 6,
        omega: float = 3 * math.pi / 2,
        bc_type: str = "DD",
        mu_min: Optional[float] = None,
        mu_max: Optional[float] = None,
        init_mus: Optional[List[float]] = None,
        use_adaptive_bounds: bool = True,
    ):
        super().__init__()
        self.K = int(K)
        self.omega = float(omega)
        self.bc_type = bc_type.upper()

        # Raise instead of ``assert`` so the check is not stripped by ``python -O``.
        if self.bc_type not in ("DD", "NN", "DN", "ND"):
            raise ValueError(
                f"bc_type must be one of DD, NN, DN, ND, got {bc_type}"
            )

        # Angular factor: sin(mu * theta) for DD/DN, cos(mu * theta) for NN/ND.
        self.use_sin = self.bc_type in ("DD", "DN")
        # constraint_loss penalises sin(mu * omega) for DD/NN and
        # cos(mu * omega) for DN/ND.
        self.sin_constraint = self.bc_type in ("DD", "NN")

        if use_adaptive_bounds and (mu_min is None or mu_max is None):
            adaptive_min, adaptive_max = compute_bc_adaptive_bounds(
                self.omega, self.bc_type
            )
            mu_min = adaptive_min if mu_min is None else mu_min
            mu_max = adaptive_max if mu_max is None else mu_max
        else:
            mu_min = 0.1 if mu_min is None else mu_min
            mu_max = 4.0 if mu_max is None else mu_max

        self.mu_min = mu_min
        self.mu_max = mu_max

        if init_mus is None:
            # Pass the int-normalised self.K: a float K such as 6.0 would
            # otherwise reach range() inside compute_init_exponents and fail.
            init_mus = compute_init_exponents(
                self.K, self.omega, self.bc_type, include_fundamental=True
            )

        self.exps = OrderedBoundedExponents(
            K=self.K, mu_min=mu_min, mu_max=mu_max, init=init_mus
        )

        # Zero-initialised, so the model output is identically zero until
        # the coefficients are trained.
        self.coeffs = nn.Parameter(torch.zeros(self.K))

    def cart_to_polar(self, xy: torch.Tensor, eps: float = 1e-12):
        """Convert Cartesian points to polar coordinates.

        Args:
            xy: Points with x in column 0 and y in column 1. Assumed 2-D
                ``(N, 2)`` — TODO confirm callers never pass extra leading
                dims.
            eps: Added under the square root so ``r`` is never exactly zero.

        Returns:
            ``(r, theta)``, each keeping a trailing singleton column
            (``(N, 1)`` for 2-D input). Negative ``atan2`` angles are shifted
            by ``2*pi``, so ``theta`` is non-negative.
        """
        x, y = xy[:, 0:1], xy[:, 1:2]
        r = torch.sqrt(x * x + y * y + eps)
        theta = torch.atan2(y, x)
        theta = torch.where(theta < 0, theta + 2 * math.pi, theta)
        return r, theta

    def forward(self, xy: torch.Tensor) -> torch.Tensor:
        """Evaluate ``u`` at the points ``xy`` (see ``cart_to_polar``).

        Returns:
            One value per point: shape ``(N,)`` for ``(N, 2)`` input.
        """
        r, theta = self.cart_to_polar(xy)
        mu = self.exps()

        # (N, 1) polar columns broadcast against the (K,) exponents to give
        # the (N, K) basis matrix in one shot. This is numerically identical
        # to a per-term loop and also works for K == 0 (returns zeros).
        radial = safe_pow(r, mu)
        if self.use_sin:
            angular = torch.sin(mu * theta)
        else:
            angular = torch.cos(mu * theta)
        return (radial * angular) @ self.coeffs

    def compute_angular_derivative(self, xy: torch.Tensor) -> torch.Tensor:
        """Return ``(1/r) * du/dtheta`` at the points ``xy``.

        The radial factor is ``r**(mu_k - 1)``, not ``r**mu_k``. The result
        is therefore the tangential (theta) component of the gradient, not
        the raw ``du/dtheta``. Here ``r`` includes the ``eps`` regulariser
        from ``cart_to_polar``.

        Returns:
            Shape ``(N,)`` for ``(N, 2)`` input.
        """
        r, theta = self.cart_to_polar(xy)
        mu = self.exps()

        radial = safe_pow(r, mu - 1)
        if self.use_sin:
            angular_deriv = mu * torch.cos(mu * theta)
        else:
            angular_deriv = -mu * torch.sin(mu * theta)
        return (radial * angular_deriv) @ self.coeffs

    def _coeff_weights(self) -> torch.Tensor:
        """Return ``|c_k| / (sum_j |c_j| + 1e-8)``.

        These are the coefficient magnitudes normalised to sum to about 1
        (all zero when every coefficient is zero).
        """
        coeffs_abs = torch.abs(self.coeffs)
        return coeffs_abs / (coeffs_abs.sum() + 1e-8)

    def constraint_loss(self) -> torch.Tensor:
        """Coefficient-weighted penalty on exponents violating the BC at ``omega``.

        The penalty is ``sin(mu_k * omega)**2`` for DD/NN and
        ``cos(mu_k * omega)**2`` for DN/ND, weighted by ``_coeff_weights()``.
        With the zero-initialised coefficients this loss is 0.
        """
        mu = self.exps()
        weights = self._coeff_weights()

        if self.sin_constraint:
            penalty = torch.sin(mu * self.omega) ** 2
        else:
            penalty = torch.cos(mu * self.omega) ** 2

        return torch.sum(weights * penalty)

    def small_mu_preference_loss(self, strength: float = 0.01) -> torch.Tensor:
        """Return ``strength * sum_k w_k * log(mu_k + 0.1)``.

        The weights are ``w_k = _coeff_weights()``. Since ``log`` is
        increasing, minimising this term favours smaller exponents on
        heavily weighted terms.
        """
        mu = self.exps()
        weights = self._coeff_weights()
        return strength * torch.sum(weights * torch.log(mu + 0.1))

    def get_dominant_exponent(self) -> float:
        """Return the exponent paired with the largest-magnitude coefficient.

        Ties resolve to the first such index (``np.argmax``).
        """
        mu = self.exps().detach().cpu().numpy()
        coeffs = self.coeffs.detach().cpu().numpy()
        dominant_idx = np.argmax(np.abs(coeffs))
        return float(mu[dominant_idx])

    def get_true_exponent(self, mode: int = 1) -> float:
        """Return the reference exponent of the given ``mode``.

        This is ``mode * pi / omega`` for DD/NN and
        ``(mode - 0.5) * pi / omega`` for DN/ND.
        """
        if self.sin_constraint:
            return mode * math.pi / self.omega
        return (mode - 0.5) * math.pi / self.omega

    @torch.no_grad()
    def get_exponents(self):
        """Return the current (sorted) exponents as a NumPy array."""
        return self.exps().cpu().numpy()

    @torch.no_grad()
    def get_coeffs(self):
        """Return a NumPy copy of the current coefficients.

        The detach is explicit because ``coeffs`` is a grad-requiring leaf,
        and ``no_grad`` alone does not clear ``requires_grad`` on it. The
        copy keeps callers from mutating the live parameter through the
        returned array, since on CPU ``.numpy()`` shares storage.
        """
        return self.coeffs.detach().cpu().numpy().copy()


def build_optimizers(model: MSNWedge2D, lr_w: float = 1e-2, lr_mu: float = 1e-4):
    """Create separate Adam optimizers for the exponents and everything else.

    Parameters of ``model.exps`` go to one optimizer with learning rate
    ``lr_mu``. Every other parameter of ``model`` (matched by object
    identity) goes to a second one with learning rate ``lr_w``.

    Returns:
        ``(opt_w, opt_mu)``: the non-exponent optimizer first, then the
        exponent optimizer.
    """
    exponent_params = list(model.exps.parameters())
    exponent_ids = set(map(id, exponent_params))
    rest = [p for p in model.parameters() if id(p) not in exponent_ids]

    # Tuple elements are evaluated left to right: opt_w is built before opt_mu.
    return (
        torch.optim.Adam(rest, lr=lr_w),
        torch.optim.Adam(exponent_params, lr=lr_mu),
    )


if __name__ == "__main__":
    print("Testing MSNWedge2D with adaptive bounds...")

    # 3*pi/2 radians, i.e. a 270-degree wedge, used for every BC combination.
    corner_angle = 3 * np.pi / 2

    for bc in ("DD", "NN", "DN", "ND"):
        net = MSNWedge2D(K=6, omega=corner_angle, bc_type=bc)

        print(f"\n{bc} (omega={np.degrees(corner_angle):.1f}deg):")
        print(f"  True exponent (m=1): {net.get_true_exponent(1):.4f}")
        print(f"  Mu bounds: [{net.mu_min:.4f}, {net.mu_max:.4f}]")
        print(f"  Initial exponents: {net.get_exponents()}")
        print(f"  Initial coeffs: {net.get_coeffs()}")

        # Smoke-test forward pass, constraint loss and angular derivative on
        # 10 random points drawn uniformly from [0, 1)^2.
        pts = torch.rand(10, 2)
        print(f"  Forward pass output shape: {net(pts).shape}")
        print(f"  Constraint loss: {net.constraint_loss().item():.4f}")
        print(f"  Angular derivative shape: {net.compute_angular_derivative(pts).shape}")
