from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple

import numpy as np


@dataclass()
class BlobsShiftConfig:
    """Parameters for the two-blob source/target shift dataset.

    Field order is part of the positional constructor, so it must not change.
    """

    # Feature dimensionality; sample_blobs_shift only supports 2.
    dim: int = 2
    # Distance between the two blob centres, which sit at (+/- separation / 2, 0).
    separation: float = 4.0
    # Per-axis standard deviation of each blob (covariance is cov_scale**2 * I).
    cov_scale: float = 1.0
    # Counter-clockwise rotation, in degrees, applied to produce the target domain.
    rotation_deg: float = 0.0
    # Shift applied to the target domain along the x-axis (shift magnitude).
    translation: float = 0.0
    # Total number of source samples, split between the two blobs.
    n_source: int = 2_000
    # Target sample count. sample_blobs_shift derives the target point-by-point from
    # the source set, so its size follows n_source; keep the two equal.
    n_target: int = 2_000
    # Seed passed to np.random.default_rng.
    seed: int = 0


def _rotation_matrix(theta: float) -> np.ndarray:
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]], dtype=float)


def sample_blobs_shift(config: BlobsShiftConfig) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Sample a two-blob 2-D classification task and a rotated/translated copy of it.

    The source domain is two isotropic Gaussian blobs centred at
    ``(-separation / 2, 0)`` (label 0) and ``(+separation / 2, 0)`` (label 1), each
    with covariance ``cov_scale**2 * I``. The target domain is the same set of points
    mapped through ``x -> R x + t``. Here ``R`` rotates counter-clockwise about the
    origin by ``rotation_deg`` degrees and ``t = (translation, 0)``. Each target point
    keeps the label of the source point it came from.

    Args:
        config: Sampling parameters. ``dim`` must be 2, ``n_source`` must be
            non-negative, and ``n_target`` must equal ``n_source``. The last
            constraint holds because the target is a point-for-point image of the
            source.

    Returns:
        ``(Xs, ys, Xt, yt)``:
            - source features of shape ``(n_source, 2)``;
            - source labels of shape ``(n_source,)`` (0 or 1, grouped by blob);
            - target features of shape ``(n_source, 2)``;
            - target labels, a copy of ``ys``.

    Raises:
        ValueError: if ``dim != 2``, ``n_source < 0``, or ``n_target != n_source``.
    """
    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if config.dim != 2:
        raise ValueError("BlobsShift implemented for 2D")
    if config.n_source < 0:
        raise ValueError(f"n_source must be non-negative, got {config.n_source}")
    if config.n_target != config.n_source:
        # Previously n_target was silently ignored; the target size is dictated by
        # the source because every target point is a transformed source point.
        raise ValueError(
            "BlobsShift builds the target from the source points, so n_target must "
            f"equal n_source (got n_target={config.n_target}, "
            f"n_source={config.n_source})"
        )

    rng = np.random.default_rng(config.seed)

    centers = np.array([
        [-config.separation / 2, 0.0],
        [config.separation / 2, 0.0],
    ])
    cov = (config.cov_scale ** 2) * np.eye(2)

    # Split n_source across the two blobs. An odd remainder goes to blob 0 so the
    # total is exactly n_source. For even n_source this is the same even split, in
    # the same RNG draw order, as before.
    n_small = config.n_source // 2
    counts = (config.n_source - n_small, n_small)

    Xs = np.vstack([
        rng.multivariate_normal(c, cov, size=n) for c, n in zip(centers, counts)
    ])
    ys = np.concatenate([np.full(n, c_id) for c_id, n in enumerate(counts)])

    # Target domain: rotate about the origin, then translate along x. Labels travel
    # with their points, so P_t(y | x) = P_s(y | T^{-1}(x)). The class boundary moves
    # with the data; P(y | x) as a function of x is generally NOT preserved.
    theta = np.deg2rad(config.rotation_deg)
    R = _rotation_matrix(theta)
    t = np.array([config.translation, 0.0])

    Xt = (Xs @ R.T) + t  # row-vector form of R @ x + t
    yt = ys.copy()
    return Xs, ys, Xt, yt




