import math
import torch
import torch.nn as nn
import numpy as np
from typing import List, Optional, Tuple


class OrderedBoundedExponents(nn.Module):
    """K learnable exponents confined to ``(mu_min, mu_max)``, returned sorted.

    Each exponent is stored as an unconstrained value ``raw`` and mapped to
    ``mu_min + (mu_max - mu_min) * sigmoid(softness * raw)``.  ``forward``
    sorts the mapped values in ascending order before returning them.

    Args:
        K: Number of exponents.
        mu_min: Lower bound of the exponent range.
        mu_max: Upper bound of the exponent range; must exceed ``mu_min``.
        init_mus: Optional initial exponent values (length ``K``).  Values are
            clipped slightly inside the range so the inverse sigmoid stays
            finite.  When omitted, the normalised positions are spread evenly
            over ``[0.2, 0.8]`` of the range.
        softness: Slope applied to ``raw`` inside the sigmoid.

    Raises:
        AssertionError: If ``mu_max <= mu_min`` or ``init_mus`` does not have
            ``K`` entries.
        ValueError: If ``softness`` is zero or not finite (the initialisation
            divides by it, which would otherwise silently produce inf/NaN
            parameters).
    """

    def __init__(
        self,
        K: int,
        mu_min: float = 0.1,
        mu_max: float = 3.0,
        init_mus: Optional[List[float]] = None,
        softness: float = 4.0,
    ):
        super().__init__()
        self.K = int(K)
        self.mu_min = float(mu_min)
        self.mu_max = float(mu_max)
        self.softness = float(softness)
        assert self.mu_max > self.mu_min, "mu_max must be greater than mu_min"
        if self.softness == 0.0 or not math.isfinite(self.softness):
            raise ValueError(
                f"softness must be a finite, non-zero number, got {softness!r}"
            )

        if init_mus is not None:
            init_mus = np.array(init_mus, dtype=np.float32)
            assert len(init_mus) == self.K, "init_mus must contain K values"
            # Keep the requested exponents strictly inside the open interval.
            init_clipped = np.clip(
                init_mus, self.mu_min + 0.01, self.mu_max - 0.01
            )
            span = self.mu_max - self.mu_min
            normalized_init = (init_clipped - self.mu_min) / span
            # Second clip guards the logit below when the range is very narrow.
            normalized_init = np.clip(normalized_init, 0.01, 0.99)
        else:
            # Normalised positions in (0, 1), not exponent values.
            normalized_init = np.linspace(0.2, 0.8, self.K)

        # Inverse of sigmoid(softness * raw): logit(p) / softness.
        raw_init = np.log(normalized_init / (1 - normalized_init)) / self.softness
        self.raw = nn.Parameter(torch.tensor(raw_init, dtype=torch.float32))

    def forward(self) -> torch.Tensor:
        """Return the ``K`` exponents, bounded and sorted in ascending order."""
        normalized = torch.sigmoid(self.softness * self.raw)
        mu_unsorted = self.mu_min + (self.mu_max - self.mu_min) * normalized
        mu_sorted, _ = torch.sort(mu_unsorted)
        return mu_sorted


def safe_pow(r: torch.Tensor, mu: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Compute ``r ** mu`` as ``exp(mu * log(r))`` with ``r`` floored at ``eps``.

    Args:
        r: Base values; entries below ``eps`` are replaced by ``eps`` so the
            logarithm stays finite.
        mu: Exponent(s), broadcast against ``r``.
        eps: Lower clamp applied to ``r`` before taking the logarithm.

    Returns:
        Tensor with the broadcast shape of ``r`` and ``mu``.
    """
    floored = torch.clamp(r, min=eps)
    log_base = floored.log()
    return (mu * log_base).exp()


class MSNWedge2D(nn.Module):
    """Learnable expansion ``u(r, theta) = sum_k c_k * r**mu_k * phi(mu_k * theta)``.

    ``phi`` is ``sin`` when ``bc_type`` is ``"DD"`` or ``"DN"`` and ``cos``
    when it is ``"NN"`` or ``"ND"``.  The exponents ``mu_k`` come from an
    :class:`OrderedBoundedExponents` module (so they are bounded and sorted);
    the coefficients ``c_k`` are a free parameter vector initialised to zero.

    Args:
        K: Number of expansion terms.
        mu_min: Lower bound for the exponents.
        mu_max: Upper bound for the exponents.
        bc_type: One of ``"DD"``, ``"NN"``, ``"DN"``, ``"ND"`` (case
            insensitive).  Presumably the two letters are the Dirichlet /
            Neumann conditions on the edges ``theta = 0`` and
            ``theta = omega`` — confirm against the caller.
        init_mus: Optional initial exponent values, forwarded to
            :class:`OrderedBoundedExponents`.

    Raises:
        AssertionError: If ``bc_type`` is not one of the four supported codes.
    """

    def __init__(
        self,
        K: int = 6,
        mu_min: float = 0.1,
        mu_max: float = 3.0,
        bc_type: str = "DD",
        init_mus: Optional[List[float]] = None,
    ):
        super().__init__()
        self.K = K
        self.bc_type = bc_type.upper()
        assert self.bc_type in ["DD", "NN", "DN", "ND"]

        self.exps = OrderedBoundedExponents(
            K=K, mu_min=mu_min, mu_max=mu_max, init_mus=init_mus
        )
        self.coeffs = nn.Parameter(torch.zeros(K))

        # Angular factor: sin(mu * theta) for "D?" types, cos(mu * theta) otherwise.
        self.use_sin_basis = self.bc_type in ["DD", "DN"]
        # Exponent penalty: sin^2(mu * omega) for DD/NN, cos^2(mu * omega) for DN/ND.
        self.use_sin_constraint = self.bc_type in ["DD", "NN"]
        # Stored for callers; not read anywhere inside this class.
        self.second_edge_dirichlet = self.bc_type in ["DD", "ND"]

    def _coeff_weights(self) -> torch.Tensor:
        """Return ``|c_k| / (sum_j |c_j| + 1e-8)``, the relative term weights."""
        magnitude = torch.abs(self.coeffs)
        return magnitude / (magnitude.sum() + 1e-8)

    def forward(self, r: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Evaluate the expansion at polar coordinates ``(r, theta)``.

        Args:
            r: Radii (clamped from below inside ``safe_pow``).
            theta: Angles, broadcastable against ``r``.

        Returns:
            Tensor with the broadcast shape of ``r`` and ``theta``.
        """
        mu = self.exps()

        # A trailing exponent axis lets one broadcasted op cover all K terms
        # instead of looping over the exponents in Python.
        r_powers = safe_pow(r.unsqueeze(-1), mu)
        phase = theta.unsqueeze(-1) * mu
        angular = torch.sin(phase) if self.use_sin_basis else torch.cos(phase)

        basis = r_powers * angular
        return basis @ self.coeffs

    def forward_xy(self, xy: torch.Tensor) -> torch.Tensor:
        """Evaluate the expansion at Cartesian points ``xy`` (last axis = x, y)."""
        r, theta = self.cart_to_polar(xy)
        return self.forward(r.squeeze(-1), theta.squeeze(-1))

    @staticmethod
    def cart_to_polar(
        xy: torch.Tensor, eps: float = 1e-12
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert Cartesian points to polar ``(r, theta)``.

        Args:
            xy: Tensor whose last axis holds ``(x, y)``; any number of leading
                axes is accepted.
            eps: Added under the square root so ``r`` stays strictly positive
                at the origin.

        Returns:
            ``(r, theta)``, each with the shape of ``xy`` except that the last
            axis has length 1.
        """
        # Ellipsis indexing keeps the (N, 2) behaviour and also handles
        # inputs with extra leading axes.
        x, y = xy[..., 0:1], xy[..., 1:2]
        r = torch.sqrt(x * x + y * y + eps)
        # NOTE(review): atan2 yields angles in [-pi, pi], so points with y < 0
        # get a negative theta.  If wedges with opening angle > pi measured
        # from theta = 0 are used, the caller has to remap — confirm.
        theta = torch.atan2(y, x)
        return r, theta

    def constraint_loss(self, omega: float) -> torch.Tensor:
        """Coefficient-weighted penalty tying the exponents to the angle ``omega``.

        Each exponent contributes ``sin^2(mu_k * omega)`` (bc types DD/NN) or
        ``cos^2(mu_k * omega)`` (DN/ND), weighted by its relative coefficient
        magnitude.
        """
        mu = self.exps()
        phase = mu * omega
        trig = torch.sin(phase) if self.use_sin_constraint else torch.cos(phase)
        return torch.sum(self._coeff_weights() * trig**2)

    def small_mu_preference_loss(self) -> torch.Tensor:
        """Coefficient-weighted mean of the exponents (smaller is cheaper)."""
        return torch.sum(self._coeff_weights() * self.exps())

    def edge_bc_loss_theta0(self, r: torch.Tensor) -> torch.Tensor:
        """Return a scalar zero matching ``r``'s dtype and device.

        The angular basis already fixes the behaviour at ``theta = 0``
        (``sin(0) = 0`` for the sine basis, zero ``theta``-derivative for the
        cosine basis), so no residual is computed here.
        """
        return torch.zeros((), dtype=r.dtype, device=r.device)

    def edge_bc_loss_theta_omega(
        self, r: torch.Tensor, omega: float, target: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mean-squared value and ``theta``-derivative of ``u`` on ``theta = omega``.

        Args:
            r: Radii at which the edge is sampled.
            omega: Angle of the edge.
            target: Accepted for interface compatibility but not used; both
                residuals are measured against zero.
                NOTE(review): confirm whether non-zero edge data was intended.

        Returns:
            ``(dirichlet_loss, neumann_loss)``: the mean of ``u**2`` and the
            mean of ``(du/dtheta)**2`` over the sampled radii.
        """
        mu = self.exps()

        r_powers = safe_pow(r.unsqueeze(-1), mu)
        # The angular factors depend only on (mu, omega), so they are computed
        # once as length-K vectors and broadcast over the radii.
        phase = mu * omega
        if self.use_sin_basis:
            angular = torch.sin(phase)
            angular_deriv = mu * torch.cos(phase)
        else:
            angular = torch.cos(phase)
            angular_deriv = -mu * torch.sin(phase)

        u_omega = (r_powers * angular) @ self.coeffs
        du_dtheta_omega = (r_powers * angular_deriv) @ self.coeffs

        dirichlet_loss = torch.mean(u_omega**2)
        neumann_loss = torch.mean(du_dtheta_omega**2)

        return dirichlet_loss, neumann_loss

    def arc_bc_loss(
        self,
        theta: torch.Tensor,
        target: torch.Tensor,
        omega: float,
    ) -> torch.Tensor:
        """Mean-squared mismatch between ``u(1, theta)`` and ``target``.

        Args:
            theta: Angles sampled on the arc ``r = 1``.
            target: Desired values at those angles.
            omega: Not used by this method; kept in the signature for callers.
        """
        r_one = torch.ones_like(theta)
        u_pred = self.forward(r_one, theta)
        return torch.mean((u_pred - target) ** 2)

    @torch.no_grad()
    def get_exponents(self) -> np.ndarray:
        """Return the current sorted exponents as a NumPy array on the CPU."""
        return self.exps().detach().cpu().numpy()

    @torch.no_grad()
    def get_coeffs(self) -> np.ndarray:
        """Return the current coefficients as a NumPy array on the CPU."""
        return self.coeffs.detach().cpu().numpy()

    @torch.no_grad()
    def get_dominant_mu(self) -> float:
        """Return the exponent paired with the largest-magnitude coefficient."""
        coeffs = np.abs(self.get_coeffs())
        mus = self.get_exponents()
        return float(mus[np.argmax(coeffs)])


class NaiveMSNWedge2D(MSNWedge2D):
    """Variant of :class:`MSNWedge2D` with both regularisation terms disabled.

    ``constraint_loss`` and ``small_mu_preference_loss`` are overridden to
    return a constant scalar zero on the same device as the coefficients;
    everything else is inherited unchanged.
    """

    def _zero_scalar(self) -> torch.Tensor:
        """Return a 0-d zero tensor placed on the coefficients' device."""
        return torch.zeros((), device=self.coeffs.device)

    def constraint_loss(self, omega: float) -> torch.Tensor:
        """Return zero; ``omega`` is ignored."""
        return self._zero_scalar()

    def small_mu_preference_loss(self) -> torch.Tensor:
        """Return zero; no preference is placed on the exponents."""
        return self._zero_scalar()


def create_model(config) -> MSNWedge2D:
    """Build the wedge model described by ``config`` and move it to its device.

    ``config.method == "naive"`` selects :class:`NaiveMSNWedge2D`; any other
    value selects :class:`MSNWedge2D`.  Both receive the same constructor
    arguments, read from ``config.K``, ``config.mu_min``, ``config.mu_max``,
    ``config.bc_type`` and ``config.get_init_mus()``.

    Returns:
        The constructed model after ``.to(config.device)``.
    """
    initial_exponents = config.get_init_mus()
    model_cls = NaiveMSNWedge2D if config.method == "naive" else MSNWedge2D

    ctor_kwargs = {
        "K": config.K,
        "mu_min": config.mu_min,
        "mu_max": config.mu_max,
        "bc_type": config.bc_type,
        "init_mus": initial_exponents,
    }
    return model_cls(**ctor_kwargs).to(config.device)
