import numpy as np
import torch
from typing import Tuple
import math


def sample_arc_points(
    n: int,
    omega: float,
    seed: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample points with radius 1 and angles drawn uniformly up to ``omega``.

    Args:
        n: Number of points to draw.
        omega: Scale of the angular range; each angle is ``u * omega`` with
            ``u`` drawn uniformly from [0, 1).
        seed: Seed for ``np.random.default_rng``; the same seed reproduces
            the same sample.

    Returns:
        ``(r, theta)``: two float32 tensors of shape ``(n,)``. ``r`` is all
        ones and ``theta`` holds the sampled angles.
    """
    # Coerce to a plain Python float. A NumPy float64 scalar (e.g. the result
    # of np.deg2rad) is strongly typed under NumPy 2 promotion rules and would
    # silently turn theta into float64 while r stays float32.
    omega = float(omega)

    rng = np.random.default_rng(seed)
    theta = rng.random(n).astype(np.float32) * omega
    r = np.ones(n, dtype=np.float32)
    return torch.tensor(r), torch.tensor(theta)


def sample_edge_points(
    n: int,
    omega: float,
    r_min: float = 1e-4,
    r_power: float = 2.0,
    seed: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample matching radial points on the two edges theta = 0 and theta = omega.

    Radii are computed as ``r_min + (1 - r_min) * t ** r_power`` with ``t``
    drawn uniformly from [0, 1). For ``r_power > 1`` this concentrates the
    samples toward ``r_min``. Both edges use the same radii.

    Args:
        n: Number of points per edge.
        omega: Angle assigned to every point of the second edge.
        r_min: Smallest radius the mapping can produce (reached at ``t = 0``).
        r_power: Exponent applied to the uniform sample ``t``.
        seed: Seed for ``np.random.default_rng``.

    Returns:
        ``(r_e0, theta_e0, r_ew, theta_ew)``: four float32 tensors of shape
        ``(n,)``. ``theta_e0`` is all zeros, ``theta_ew`` is filled with
        ``omega``, and ``r_e0`` / ``r_ew`` hold equal values in separate
        storage.
    """
    # Plain Python floats keep the NumPy arithmetic below in float32 even when
    # callers pass NumPy float64 scalars (strongly typed under NumPy 2).
    omega = float(omega)
    r_min = float(r_min)
    r_power = float(r_power)

    rng = np.random.default_rng(seed)

    t = rng.random(n).astype(np.float32)
    r = r_min + (1.0 - r_min) * np.power(t, r_power)

    # torch.tensor always copies its input, so the two radius tensors do not
    # share memory and no explicit ndarray copy is needed.
    r_e0 = torch.tensor(r)
    r_ew = torch.tensor(r)

    # Pin the dtype explicitly: torch.zeros would otherwise follow the global
    # default dtype and could disagree with the float32 radii.
    theta_e0 = torch.zeros(n, dtype=torch.float32)
    theta_ew = torch.full((n,), omega, dtype=torch.float32)

    return r_e0, theta_e0, r_ew, theta_ew


def get_arc_target(
    theta: torch.Tensor,
    true_mu: float,
    bc_type: str,
) -> torch.Tensor:
    """Evaluate the angular target profile at the given angles.

    Args:
        theta: Angles at which to evaluate the target.
        true_mu: Factor multiplying ``theta`` inside the trigonometric function.
        bc_type: Selects the profile. ``"DD"`` and ``"DN"`` give
            ``sin(true_mu * theta)``; every other value gives
            ``cos(true_mu * theta)``.

    Returns:
        Tensor of target values with the same shape as ``theta``.
    """
    # Only these two labels select the sine profile; anything else falls
    # through to cosine.
    wants_sine = bc_type in ("DD", "DN")
    profile = torch.sin if wants_sine else torch.cos
    return profile(true_mu * theta)


def sample_interior_points(
    n: int,
    omega: float,
    r_min: float = 1e-4,
    r_power: float = 2.0,
    seed: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample ``(r, theta)`` points with independent radii and angles.

    Radii are computed as ``r_min + (1 - r_min) * t ** r_power`` with ``t``
    drawn uniformly from [0, 1). For ``r_power > 1`` this concentrates the
    samples toward ``r_min``. Angles are ``u * omega`` with ``u`` drawn
    uniformly from [0, 1), after the radii and from the same generator.

    Args:
        n: Number of points to draw.
        omega: Scale of the angular range.
        r_min: Smallest radius the mapping can produce (reached at ``t = 0``).
        r_power: Exponent applied to the uniform sample ``t``.
        seed: Seed for ``np.random.default_rng``.

    Returns:
        ``(r, theta)``: two float32 tensors of shape ``(n,)``.
    """
    # Plain Python floats keep the NumPy arithmetic below in float32 even when
    # callers pass NumPy float64 scalars (strongly typed under NumPy 2), so
    # both returned tensors always share one dtype.
    omega = float(omega)
    r_min = float(r_min)
    r_power = float(r_power)

    rng = np.random.default_rng(seed)

    # Draw order matters for reproducibility: radii first, then angles.
    t = rng.random(n).astype(np.float32)
    r = r_min + (1.0 - r_min) * np.power(t, r_power)

    theta = rng.random(n).astype(np.float32) * omega

    return torch.tensor(r), torch.tensor(theta)


def sample_evaluation_grid(
    n_r: int = 50,
    n_theta: int = 50,
    omega: float = 1.5 * math.pi,
    r_min: float = 1e-3,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Build a polar evaluation grid and its Cartesian coordinates.

    Radii are spaced logarithmically from ``r_min`` to 1 and angles linearly
    from 0 to ``omega`` (both endpoints included).

    Args:
        n_r: Number of radial samples.
        n_theta: Number of angular samples.
        omega: Largest angle of the grid.
        r_min: Smallest radius of the grid.

    Returns:
        ``(R, TH, X, Y)``: four arrays of shape ``(n_theta, n_r)`` holding the
        radius, the angle, and the Cartesian coordinates ``R * cos(TH)`` and
        ``R * sin(TH)`` of every grid node.
    """
    radii = np.logspace(start=np.log10(r_min), stop=0, num=n_r)
    angles = np.linspace(start=0, stop=omega, num=n_theta)

    # "xy" indexing puts the radial axis along columns and angles along rows.
    R, TH = np.meshgrid(radii, angles, indexing="xy")

    X = np.multiply(R, np.cos(TH))
    Y = np.multiply(R, np.sin(TH))
    return R, TH, X, Y


def get_true_solution(
    r: np.ndarray,
    theta: np.ndarray,
    true_mu: float,
    bc_type: str,
) -> np.ndarray:
    """Evaluate the reference solution ``r ** true_mu * f(true_mu * theta)``.

    Args:
        r: Radii; the absolute value is used.
        theta: Angles, broadcastable against ``r``.
        true_mu: Exponent of the radial factor and multiplier of ``theta``.
        bc_type: ``"DD"`` and ``"DN"`` select ``f = sin``; every other value
            selects ``f = cos``.

    Returns:
        Array holding the product of the radial and angular factors.
    """
    # The small offset keeps the base strictly positive, so the power stays
    # finite at r == 0 even for a negative exponent.
    radial = np.power(np.abs(r) + 1e-12, true_mu)

    angular_fn = np.sin if bc_type in ("DD", "DN") else np.cos
    return radial * angular_fn(true_mu * theta)


class WedgeDataLoader:
    """Generate freshly seeded batches of arc and edge sample points.

    Each call to ``get_batch`` derives a new seed from ``base_seed`` and an
    internal step counter, so successive batches differ while the whole
    sequence is reproducible after ``reset``.
    """

    def __init__(
        self,
        omega: float,
        true_mu: float,
        bc_type: str,
        n_arc: int = 512,
        n_edge: int = 512,
        r_min: float = 1e-4,
        r_power: float = 2.0,
        base_seed: int = 0,
        device: str = "cuda",
    ):
        """Store the sampling configuration and start the step counter at 0.

        Args:
            omega: Angle forwarded to the arc and edge samplers.
            true_mu: Parameter forwarded to ``get_arc_target``.
            bc_type: Boundary-condition label forwarded to ``get_arc_target``.
            n_arc: Number of arc points per batch.
            n_edge: Number of points per edge per batch.
            r_min: Minimum radius forwarded to the edge sampler.
            r_power: Radial exponent forwarded to the edge sampler.
            base_seed: Seed used for the first batch.
            device: Target passed to ``Tensor.to`` for every batch tensor.
        """
        # Problem definition.
        self.omega, self.true_mu, self.bc_type = omega, true_mu, bc_type
        # Batch sizes.
        self.n_arc, self.n_edge = n_arc, n_edge
        # Radial sampling parameters.
        self.r_min, self.r_power = r_min, r_power
        # Seeding and placement.
        self.base_seed, self.device = base_seed, device
        # Number of batches handed out since construction or the last reset.
        self.step = 0

    def get_batch(self) -> dict:
        """Sample one batch and move every tensor to ``self.device``.

        Returns:
            Dict with keys ``r_arc``, ``theta_arc``, ``target_arc``, ``r_e0``,
            ``theta_e0``, ``r_ew`` and ``theta_ew``.
        """
        batch_seed = self.base_seed + self.step
        self.step += 1

        arc_r, arc_theta = sample_arc_points(self.n_arc, self.omega, batch_seed)
        arc_target = get_arc_target(arc_theta, self.true_mu, self.bc_type)

        # The edge sampler gets an offset seed so it does not reuse the arc
        # sampler's seed within the same batch.
        edge0_r, edge0_theta, edgew_r, edgew_theta = sample_edge_points(
            self.n_edge,
            self.omega,
            self.r_min,
            self.r_power,
            batch_seed + 1000,
        )

        host_batch = {
            "r_arc": arc_r,
            "theta_arc": arc_theta,
            "target_arc": arc_target,
            "r_e0": edge0_r,
            "theta_e0": edge0_theta,
            "r_ew": edgew_r,
            "theta_ew": edgew_theta,
        }
        return {key: value.to(self.device) for key, value in host_batch.items()}

    def reset(self):
        """Rewind the step counter so the batch sequence starts over."""
        self.step = 0
