import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple
from scipy.special import expit
from typing import Union
import torch

@dataclass(frozen=True)
class CausalPCAMParams:
    """Coefficients of the three logistic models used by the PCAM causal simulation.

    Each model has the form ``expit(intercept + X @ coef)`` where ``X`` has
    ``d_x`` columns. A coefficient vector left as ``None`` is replaced by an
    all-zero vector of shape ``(d_x,)`` by ``_check_and_prepare_params``.

    NOTE(review): ``frozen=True`` only prevents rebinding attributes; the numpy
    arrays stored in the vector fields are themselves still mutable.
    """

    d_x: int = 5  # number of covariates (columns of X)

    # e(X): treatment propensity A|X
    gamma0: float = 0.0  # intercept of the treatment-propensity logit
    gamma: Optional[np.ndarray] = None  # (d_x,) coefficients; None -> zeros

    # pi(X): baseline metastasis S|X
    pi0: float = 0.0  # intercept of the metastasis logit
    pi_beta: Optional[np.ndarray] = None  # (d_x,) coefficients; None -> zeros

    # q(X): removal success among diseased treated
    q0: float = 0.0  # intercept of the removal-success logit
    q_beta: Optional[np.ndarray] = None  # (d_x,) coefficients; None -> zeros


def _check_and_prepare_params(p: CausalPCAMParams) -> CausalPCAMParams:
    """Validate ``p`` and return a normalized copy of it.

    Intercepts are coerced with ``float``. Each coefficient vector is turned
    into a float ndarray (``None`` becomes zeros) and must have shape
    ``(p.d_x,)``.

    Raises:
        ValueError: if ``p.d_x`` is not positive, or a coefficient vector does
            not have shape ``(p.d_x,)``.
    """
    dim = p.d_x
    if dim <= 0:
        raise ValueError(f"d_x must be positive. Got d_x={p.d_x}")
    expected_shape = (dim,)

    def as_coef_vector(raw, label: str) -> np.ndarray:
        # A missing coefficient vector defaults to all zeros.
        if raw is None:
            vec = np.zeros(expected_shape, dtype=float)
        else:
            vec = np.asarray(raw, dtype=float)
        if vec.shape == expected_shape:
            return vec
        raise ValueError(f"{label} must have shape (d_x,)={expected_shape}. Got {vec.shape}")

    # Normalize in declaration order so that, when several fields are invalid,
    # the first one encountered is the one reported.
    gamma0 = float(p.gamma0)
    gamma = as_coef_vector(p.gamma, "gamma")
    pi0 = float(p.pi0)
    pi_beta = as_coef_vector(p.pi_beta, "pi_beta")
    q0 = float(p.q0)
    q_beta = as_coef_vector(p.q_beta, "q_beta")

    return CausalPCAMParams(
        d_x=dim,
        gamma0=gamma0,
        gamma=gamma,
        pi0=pi0,
        pi_beta=pi_beta,
        q0=q0,
        q_beta=q_beta,
    )

def _fill_from_dataset(dataset, positions, indices, out, label):
    """Store ``dataset[indices[j]]``'s image into ``out[positions[j]]`` for every j.

    Dataset items are unpacked as ``(image, _)``; the second element is ignored.
    ``label`` is used only in the error message.

    Raises:
        TypeError: if the dataset yields an image that is not a torch tensor.
    """
    for pos, idx in zip(positions, indices):
        img, _ = dataset[int(idx)]
        if not torch.is_tensor(img):
            raise TypeError(
                f"{label} returned a non-tensor image. "
                "Add transforms.ToTensor() to your dataset transform."
            )
        out[int(pos)] = img


def simulate_causal_pcam_observed_tensors(
    n: int,
    dataset_0,  # PCAM subset: no metastasis
    dataset_1,  # PCAM subset: metastasis
    params: CausalPCAMParams = CausalPCAMParams(),
    x_loc: float = 0.0,
    x_scale: float = 1.0,
    clip_propensity: Optional[Tuple[float, float]] = None,
    rng: Optional[np.random.Generator] = None,
    return_propensity: bool = False,
    device: Optional[torch.device] = None,
    dtype_x: torch.dtype = torch.float32,
) -> Union[
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, np.ndarray,
        torch.Tensor, torch.Tensor, torch.Tensor,
    ],
]:
    """
    Observed-data causal simulation with images as outcomes, returning torch tensors.

    X ~ N(x_loc, x_scale^2 I)
    A|X ~ Bernoulli(e(X)) with e(X)=expit(gamma0 + X @ gamma)
    S|X ~ Bernoulli(pi(X)) with pi(X)=expit(pi0 + X @ pi_beta)
    q(X)=P(removal success | S=1,A=1,X) with q(X)=expit(q0 + X @ q_beta)

    Outcome:
      if S=0              -> sample Y from dataset_0
      if S=1 and A=0      -> sample Y from dataset_1
      if S=1 and A=1      -> dataset_0 w.p. q(X), else dataset_1

    Args:
      n: number of samples; must be positive.
      dataset_0, dataset_1: indexable, sized datasets whose items unpack as
        ``(image_tensor, label)``; the label is ignored. Both must be non-empty.
      params: logistic-model coefficients (validated/normalized internally).
      x_loc, x_scale: mean and standard deviation of each covariate.
      clip_propensity: optional ``(lo, hi)`` with ``0 < lo < hi < 1``; e(X) is
        clipped to this range before A is drawn.
      rng: numpy Generator; a fresh unseeded one is created if omitted.
      return_propensity: if True, also return ``source``, e, pi and q.
      device: if given, all returned tensors are moved to it.
      dtype_x: dtype of the returned X tensor.

    Returns:
      X: (n, d_x) tensor of dtype ``dtype_x``
      A: (n,) torch.int64 in {0,1}
      Y: (n, *image_shape) tensor, stacked from the dataset images
      If ``return_propensity`` is True, additionally (in this order):
        source: (n,) numpy int64 array, 0 -> image drawn from dataset_0,
          1 -> from dataset_1 (a numpy array, NOT moved to ``device``)
        e, pi, q: (n,) torch.float32 tensors

    Raises:
      ValueError: if n <= 0, a dataset is empty, params are invalid, or
        clip_propensity is out of range.
      TypeError: if a dataset returns a non-tensor image.
    """

    if n <= 0:
        raise ValueError(f"n must be positive. Got n={n}")
    if len(dataset_0) == 0:
        raise ValueError("dataset0 is empty.")
    if len(dataset_1) == 0:
        raise ValueError("dataset1 is empty.")

    p = _check_and_prepare_params(params)

    # Validate the clipping range before any random draws are made.
    if clip_propensity is not None:
        lo, hi = clip_propensity
        if not (0.0 < lo < hi < 1.0):
            raise ValueError(f"clip_propensity must satisfy 0 < lo < hi < 1. Got {clip_propensity}")

    rng = np.random.default_rng() if rng is None else rng

    # NOTE: the order of rng calls below is kept stable so seeded runs reproduce.

    # 1) Covariates, in numpy.
    X_np = rng.normal(loc=x_loc, scale=x_scale, size=(n, p.d_x)).astype(float)

    # 2) e(X) and treatment A|X.
    e_np = expit(p.gamma0 + X_np @ p.gamma)
    if clip_propensity is not None:
        e_np = np.clip(e_np, lo, hi)
    A_np = rng.binomial(n=1, p=e_np, size=n).astype(np.int64)

    # 3) pi(X) and baseline metastasis S|X.
    pi_np = expit(p.pi0 + X_np @ p.pi_beta)
    S_np = rng.binomial(n=1, p=pi_np, size=n).astype(np.int64)

    # 4) q(X): removal success probability (only used where S=1, A=1).
    q_np = expit(p.q0 + X_np @ p.q_beta)

    # 5) Source dataset per sample: 0 -> dataset_0, 1 -> dataset_1.
    source = np.zeros(n, dtype=np.int64)
    source[(S_np == 1) & (A_np == 0)] = 1

    treated_diseased = (S_np == 1) & (A_np == 1)
    n_treated_diseased = int(treated_diseased.sum())
    if n_treated_diseased > 0:
        success = rng.binomial(
            n=1, p=q_np[treated_diseased], size=n_treated_diseased
        ).astype(np.int64)
        source[treated_diseased] = 1 - success  # success -> 0 (healthy), failure -> 1

    # 6) Draw image indices, then fetch images (all dataset_0 draws first, then
    #    dataset_1, in sample order, matching the original access order).
    pos0 = np.flatnonzero(source == 0)
    pos1 = np.flatnonzero(source == 1)
    idx0 = rng.integers(0, len(dataset_0), size=pos0.size) if pos0.size > 0 else np.array([], dtype=int)
    idx1 = rng.integers(0, len(dataset_1), size=pos1.size) if pos1.size > 0 else np.array([], dtype=int)

    Y_list = [None] * n
    _fill_from_dataset(dataset_0, pos0, idx0, Y_list, "dataset0")
    _fill_from_dataset(dataset_1, pos1, idx1, Y_list, "dataset1")

    # Every position is filled exactly once since source is in {0, 1}.
    Y = torch.stack(Y_list, dim=0)

    X = torch.as_tensor(X_np, dtype=dtype_x)
    A = torch.as_tensor(A_np, dtype=torch.int64)
    if device is not None:
        X, A, Y = X.to(device), A.to(device), Y.to(device)

    if not return_propensity:
        return X, A, Y

    e = torch.as_tensor(e_np, dtype=torch.float32)
    pi = torch.as_tensor(pi_np, dtype=torch.float32)
    q = torch.as_tensor(q_np, dtype=torch.float32)
    if device is not None:
        e, pi, q = e.to(device), pi.to(device), q.to(device)
    return X, A, Y, source, e, pi, q
