from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Tuple

import numpy as np
from sklearn.datasets import make_moons


@dataclass()
class MoonsWarpConfig:
    """Settings for drawing a source/target pair of two-moons datasets.

    Fields, in positional order: per-domain sample counts, the noise level
    passed to the moons generator, the strength of the radial twist applied
    to the target domain, and the base random seed.
    """

    n_source: int = 2_000  # number of samples drawn for the source domain
    n_target: int = 2_000  # number of samples drawn for the target domain
    noise: float = 0.1  # noise level handed to the moons generator
    alpha: float = 0.0  # warp strength; 0.0 leaves the target unwarped
    seed: int = 0  # base seed for the random draws


def radial_twist(xy: np.ndarray, alpha: float) -> np.ndarray:
    """Rotate each 2-D point about the origin by an angle that decays with radius.

    A point with polar coordinates ``(r, theta)`` is mapped to
    ``(r, theta + alpha * exp(-r))``. The radius is preserved exactly. Points
    near the origin are rotated by up to ``alpha`` radians, and distant points
    are left almost unchanged. The map is invertible: subtract the same angle.

    Args:
        xy: Array-like whose first two columns are Cartesian ``(x, y)``
            coordinates, one point per row (expected shape ``(N, 2)``).
        alpha: Twist strength in radians. ``0`` gives the identity map (up to
            floating-point round-off from the polar round trip).

    Returns:
        A new array of shape ``(N, 2)`` holding the twisted coordinates.
    """
    xy = np.asarray(xy)
    x, y = xy[:, 0], xy[:, 1]
    # hypot is overflow/underflow-safe, unlike sqrt(x*x + y*y). No epsilon is
    # needed: arctan2(0, 0) == 0, so the origin stays exactly at the origin
    # instead of being pushed out to (1e-12, 0).
    r = np.hypot(x, y)
    theta = np.arctan2(y, x)
    theta_new = theta + alpha * np.exp(-r)
    return np.stack([r * np.cos(theta_new), r * np.sin(theta_new)], axis=1)


def sample_moons_warp(config: MoonsWarpConfig) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Draw a labelled source/target pair of two-moons datasets.

    The source is a plain ``make_moons`` draw seeded with ``config.seed``.
    The target is an independent ``make_moons`` draw seeded with
    ``config.seed + 1``. When ``config.alpha`` is non-zero, the target is then
    passed through ``radial_twist``. That function only moves the points, so
    the target labels are kept as drawn.

    Args:
        config: Sample sizes, noise level, warp strength and seed.

    Returns:
        ``(Xs, ys, Xt, yt)``: source points (float), source labels (int),
        target points (float) and target labels (int).
    """
    Xs, ys = make_moons(n_samples=config.n_source, noise=config.noise, random_state=config.seed)
    # Use a different seed so the target is a fresh draw, not a copy of the source.
    Xt, yt = make_moons(n_samples=config.n_target, noise=config.noise, random_state=config.seed + 1)
    Xt = Xt.astype(float)
    if config.alpha != 0.0:
        # Smooth invertible warp of the coordinates; labels are untouched.
        Xt = radial_twist(Xt, config.alpha)
    return Xs.astype(float), ys.astype(int), Xt, yt.astype(int)




