from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import numpy as np

from unpaired_iv.data import UnpairedIVData


def bounded_lognormal_scales(rng, m, base, logstd, min_mult=0.5, max_mult=2.0):
    """Draw ``m`` positive scale factors centred on ``base``.

    The scales are ``base`` times lognormal multipliers (log-mean 0,
    log-std ``logstd``), clipped to ``[base * min_mult, base * max_mult]``
    and finally rescaled so that their arithmetic mean equals ``base``.

    Args:
        rng: Object exposing ``lognormal(mean=, sigma=, size=)``, e.g. a
            ``numpy.random.Generator``.
        m: Number of scales to draw.
        base: Target mean of the returned scales.
        logstd: Standard deviation of the underlying normal distribution.
        min_mult: Lower clip bound, as a multiple of ``base``.
        max_mult: Upper clip bound, as a multiple of ``base``.

    Returns:
        Array of ``m`` scales whose mean is ``base``.
    """
    lower, upper = base * min_mult, base * max_mult
    raw = base * rng.lognormal(mean=0.0, sigma=logstd, size=m)
    clipped = np.clip(raw, lower, upper)
    # The mean correction is applied after clipping, so individual values
    # may end up slightly outside [lower, upper].
    return clipped * (base / clipped.mean())


def dense_beta(
    d: int,
    rng: np.random.Generator,
    beta_min: float = 0.5,
    beta_max: float = 1.0,
    random_signs: bool = True,
) -> np.ndarray:
    """Sample a length-``d`` coefficient vector with no zero-by-design entries.

    Magnitudes are drawn uniformly from ``[beta_min, beta_max)``.  When
    ``random_signs`` is true, each entry is additionally multiplied by an
    independently chosen ``-1.0`` or ``1.0``.

    Args:
        d: Number of coefficients (coerced with ``int``).
        rng: Random generator used for all draws.
        beta_min: Lower bound of the magnitude range.
        beta_max: Upper bound of the magnitude range.
        random_signs: Whether to flip signs at random.

    Returns:
        The sampled coefficient vector of shape ``(d,)``.
    """
    n_coef = int(d)
    magnitudes = rng.uniform(beta_min, beta_max, size=n_coef)
    if not random_signs:
        return magnitudes
    signs = rng.choice([-1.0, 1.0], size=n_coef)
    return magnitudes * signs


def make_dense_beta(
    d: int,
    *,
    beta_min: float = 0.5,
    beta_max: float = 1.0,
    random_signs: bool = True,
):
    """Build a one-argument sampler that wraps :func:`dense_beta`.

    Args:
        d: Number of coefficients each call should produce.
        beta_min: Lower magnitude bound forwarded to ``dense_beta``.
        beta_max: Upper magnitude bound forwarded to ``dense_beta``.
        random_signs: Sign-flip flag forwarded to ``dense_beta``.

    Returns:
        A callable ``sampler(rng)`` that returns
        ``dense_beta(d, rng, ...)`` with the settings captured here.
    """

    def sampler(rng):
        return dense_beta(
            d,
            rng,
            beta_min=beta_min,
            beta_max=beta_max,
            random_signs=random_signs,
        )

    return sampler


def make_sparse_beta(dgp_cls, d: int, s_star: int):
    """Build a one-argument sampler that wraps ``dgp_cls.sparse_beta``.

    Args:
        dgp_cls: Object providing ``sparse_beta(d, s_star, rng=...)``.
        d: Length of the coefficient vector to request.
        s_star: Support size to request.

    Returns:
        A callable ``sampler(rng)`` that returns
        ``dgp_cls.sparse_beta(d, s_star, rng=rng)``.
    """

    def sampler(rng):
        return dgp_cls.sparse_beta(d, s_star, rng=rng)

    return sampler


@dataclass(frozen=True)
class DiscreteEnvDGPConfig:
    """Immutable parameter set for the discrete-environment simulators."""

    # Loading of the hidden confounder ``u`` on the covariates.
    gamma_x: float = 0.2
    # Loading of the hidden confounder ``u`` on the outcome.
    gamma_y: float = 0.2
    # Standard deviation of the hidden confounder ``u``.
    sigma_u: float = 0.2
    # Base scale of the covariate noise.
    sigma_x: float = 1.0
    # Base scale of the outcome noise.
    sigma_eps: float = 0.2
    # Standard deviation of the per-environment mean vectors.
    env_sigma: float = 1.0

    # Log-scale spread of the per-environment outcome-noise scales.
    eps_env_logstd: float = 0.5
    # Log-scale spread of the per-environment covariate-noise scales.
    x_env_logstd: float = 0.5


class DiscreteEnvDGP:
    """Simulator with ``m`` discrete environments and a hidden confounder.

    Each environment has its own mean vector and its own noise scales.
    The outcome rows and the covariate rows are generated by separate
    draws, so no ``Y`` row shares its randomness with an ``X`` row.
    """

    def __init__(self, cfg: DiscreteEnvDGPConfig = DiscreteEnvDGPConfig()):
        """Store the configuration read by :meth:`sample_unpaired`."""
        self.cfg = cfg

    @staticmethod
    def sparse_beta(
        d: int,
        s_star: int,
        beta_min: float = 0.5,
        beta_max: float = 1.0,
        rng: Optional[np.random.Generator] = None,
    ) -> np.ndarray:
        """Sample a length-``d`` vector with at most ``s_star`` non-zeros.

        The support is chosen without replacement and its entries are
        drawn uniformly from ``[beta_min, beta_max)``; all other entries
        are zero.

        Args:
            d: Length of the coefficient vector.
            s_star: Requested support size (capped at ``d``).
            beta_min: Lower bound for the non-zero entries.
            beta_max: Upper bound for the non-zero entries.
            rng: Random generator; a fresh default generator is used
                when omitted.

        Returns:
            The sparse coefficient vector of shape ``(d,)``.
        """
        if rng is None:
            rng = np.random.default_rng()
        coef = np.zeros(d)
        active = rng.choice(d, size=min(s_star, d), replace=False)
        coef[active] = rng.uniform(beta_min, beta_max, size=active.size)
        return coef

    def sample_unpaired(
        self,
        m: int,
        r_y: int,
        r_x: int,
        beta: np.ndarray,
        rng: Optional[np.random.Generator] = None,
    ) -> UnpairedIVData:
        """Draw an outcome sample and a separate covariate sample.

        For a row in environment ``e`` the model is::

            u     ~ N(0, sigma_u)
            x     = mu_e + gamma_x * u + eps_x
            y     = x @ beta + gamma_y * u + eps_y

        Outcome rows keep only ``y``; covariate rows are generated from
        fresh ``u`` and ``eps_x`` draws and keep only ``x``.

        Args:
            m: Number of environments.
            r_y: Outcome rows per environment.
            r_x: Covariate rows per environment.
            beta: Coefficient vector; its size sets the covariate dimension.
            rng: Random generator; a fresh default generator is used
                when omitted.

        Returns:
            ``UnpairedIVData`` built from one-hot environment indicators
            ``I_y`` (``m * r_y`` by ``m``) and ``I_x`` (``m * r_x`` by
            ``m``), outcomes ``Y`` and covariates ``X``.
        """
        if rng is None:
            rng = np.random.default_rng()
        coef = np.asarray(beta)
        dim = coef.size
        params = self.cfg

        n_y, n_x = m * r_y, m * r_x
        I_y = np.zeros((n_y, m))
        I_x = np.zeros((n_x, m))
        Y = np.zeros(n_y)
        X = np.zeros((n_x, dim))

        # Rows are grouped by environment: rows [e * r, (e + 1) * r) belong
        # to environment e.
        I_y[np.arange(n_y), np.repeat(np.arange(m), r_y)] = 1.0
        I_x[np.arange(n_x), np.repeat(np.arange(m), r_x)] = 1.0

        env_means = rng.normal(0.0, params.env_sigma, size=(m, dim))
        y_noise_scale = bounded_lognormal_scales(
            rng,
            m,
            params.sigma_eps,
            params.eps_env_logstd,
            min_mult=0.5,
            max_mult=3.0,
        )
        x_noise_scale = bounded_lognormal_scales(
            rng,
            m,
            params.sigma_x,
            params.x_env_logstd,
            min_mult=0.5,
            max_mult=2.0,
        )

        # The per-row draw order (u, eps_x, eps_y) is kept so that seeded
        # runs stay reproducible.
        for row in range(n_y):
            env = row // r_y
            confounder = rng.normal(0.0, params.sigma_u)
            x_noise = rng.normal(0.0, x_noise_scale[env], size=dim)
            y_noise = rng.normal(0.0, y_noise_scale[env])
            latent_x = env_means[env] + params.gamma_x * confounder + x_noise
            Y[row] = latent_x @ coef + params.gamma_y * confounder + y_noise

        for row in range(n_x):
            env = row // r_x
            confounder = rng.normal(0.0, params.sigma_u)
            x_noise = rng.normal(0.0, x_noise_scale[env], size=dim)
            X[row] = env_means[env] + params.gamma_x * confounder + x_noise

        return UnpairedIVData(I_y=I_y, Y=Y, I_x=I_x, X=X)


@dataclass(frozen=True)
class ContinuousIVDGPConfig(DiscreteEnvDGPConfig):
    """Configuration for ContinuousIVDGP.

    Adds the instrument and first-stage settings on top of the fields
    inherited from ``DiscreteEnvDGPConfig``.
    """

    # Instrument entries are drawn with standard deviation z_scale / sqrt(m).
    z_scale: float = 1.0
    # Multiplier applied to the first-stage matrix ``Pi``.
    pi_scale: float = 1.0

    # Log-scale spread of the per-coordinate outcome-noise scales; with the
    # default 0.0 every scale equals ``sigma_eps``.
    eps_coord_logstd: float = 0.0
    # Log-scale spread of the per-coordinate covariate-noise scales; with the
    # default 0.0 every scale equals ``sigma_x``.
    x_coord_logstd: float = 0.0


class ContinuousIVDGP:
    """Simulator with dense Gaussian instruments and a hidden confounder.

    Instruments are continuous ``m``-vectors instead of one-hot
    environment indicators; covariates depend on them through a random
    first-stage matrix drawn inside :meth:`sample_unpaired`.
    """

    def __init__(self, cfg: ContinuousIVDGPConfig = ContinuousIVDGPConfig()):
        """Store the configuration read by :meth:`sample_unpaired`."""
        self.cfg = cfg

    @staticmethod
    def sparse_beta(
        d: int,
        s_star: int,
        beta_min: float = 0.5,
        beta_max: float = 1.0,
        rng: Optional[np.random.Generator] = None,
    ) -> np.ndarray:
        """Sample a sparse coefficient vector via ``DiscreteEnvDGP.sparse_beta``.

        All arguments are forwarded unchanged.
        """
        forwarded = {
            "d": d,
            "s_star": s_star,
            "beta_min": beta_min,
            "beta_max": beta_max,
            "rng": rng,
        }
        return DiscreteEnvDGP.sparse_beta(**forwarded)

    def sample_unpaired(
        self,
        m: int,
        r_y: int,
        r_x: int,
        beta: np.ndarray,
        rng: Optional[np.random.Generator] = None,
    ) -> UnpairedIVData:
        """Draw an outcome sample and a separate covariate sample.

        For a row with instrument vector ``z`` the model is::

            x = z @ Pi + gamma_x * u + eps_x
            y = x @ beta + gamma_y * u + eps_y

        The noise scale of a row is looked up by the index of the
        largest-magnitude entry of its instrument vector.

        Args:
            m: Instrument dimension; also sets the row counts.
            r_y: The outcome sample has ``m * r_y`` rows.
            r_x: The covariate sample has ``m * r_x`` rows.
            beta: Coefficient vector; its size sets the covariate dimension.
            rng: Random generator; a fresh default generator is used
                when omitted.

        Returns:
            ``UnpairedIVData`` built from the instrument matrices ``I_y``
            and ``I_x``, outcomes ``Y`` and covariates ``X``.
        """
        if rng is None:
            rng = np.random.default_rng()
        coef = np.asarray(beta)
        dim = int(coef.size)
        m = int(m)
        params = self.cfg

        # The draw order below is kept fixed so seeded runs are reproducible.
        first_stage = params.pi_scale * rng.normal(0.0, 1.0, size=(m, dim))

        n_y = m * int(r_y)
        n_x = m * int(r_x)
        instr_std = params.z_scale / np.sqrt(m)
        I_y = rng.normal(0.0, instr_std, size=(n_y, m))
        I_x = rng.normal(0.0, instr_std, size=(n_x, m))

        U_y = rng.normal(0.0, params.sigma_u, size=n_y)
        U_x = rng.normal(0.0, params.sigma_u, size=n_x)

        y_scales = bounded_lognormal_scales(
            rng=rng, m=m, base=params.sigma_eps, logstd=params.eps_coord_logstd
        )
        x_scales = bounded_lognormal_scales(
            rng=rng, m=m, base=params.sigma_x, logstd=params.x_coord_logstd
        )

        # Index of the dominant instrument coordinate of every row.
        dominant_y = np.abs(I_y).argmax(axis=1)
        dominant_x = np.abs(I_x).argmax(axis=1)

        unit_noise_y = rng.normal(0.0, 1.0, size=(n_y, dim))
        eps_x_y = unit_noise_y * x_scales[dominant_y][:, None]
        unit_noise_x = rng.normal(0.0, 1.0, size=(n_x, dim))
        eps_x_x = unit_noise_x * x_scales[dominant_x][:, None]

        eps_y = rng.normal(0.0, y_scales[dominant_y])

        latent_x = I_y @ first_stage + params.gamma_x * U_y[:, None] + eps_x_y
        Y = latent_x @ coef + params.gamma_y * U_y + eps_y

        X = I_x @ first_stage + params.gamma_x * U_x[:, None] + eps_x_x

        return UnpairedIVData(I_y=I_y, Y=Y, I_x=I_x, X=X)


@dataclass(frozen=True)
class LowRankEnvDGPConfig(DiscreteEnvDGPConfig):
    """Configuration for LowRankEnvDGP.

    Adds the low-rank factor settings on top of the fields inherited from
    ``DiscreteEnvDGPConfig``.
    """
    # Number of latent factors, i.e. columns of the loading matrix ``A``.
    k: int = 60
    # ``A`` is a standard-normal matrix multiplied by A_scale / sqrt(k).
    A_scale: float = 1.0


class LowRankEnvDGP(DiscreteEnvDGP):
    """Discrete-environment simulator whose environment means are low rank.

    A loading matrix ``A`` of shape ``(d, k)`` is drawn once at
    construction.  Each call to :meth:`sample_unpaired` draws latent
    factors ``Z`` of shape ``(m, k)`` and uses ``Z @ A.T`` as the
    environment means, so the means span at most ``k`` dimensions.
    """

    def __init__(self, cfg: LowRankEnvDGPConfig, d: int, rng: np.random.Generator):
        """Copy the base settings and draw the loading matrix.

        Args:
            cfg: Low-rank configuration.
            d: Covariate dimension (coerced with ``int``).
            rng: Random generator used to draw ``A``.
        """
        # Forward every inherited field.  ``eps_env_logstd`` and
        # ``x_env_logstd`` are read by ``sample_unpaired``; leaving them out
        # would silently replace user-supplied values with the defaults.
        super().__init__(
            DiscreteEnvDGPConfig(
                gamma_x=cfg.gamma_x,
                gamma_y=cfg.gamma_y,
                sigma_u=cfg.sigma_u,
                sigma_x=cfg.sigma_x,
                sigma_eps=cfg.sigma_eps,
                env_sigma=cfg.env_sigma,
                eps_env_logstd=cfg.eps_env_logstd,
                x_env_logstd=cfg.x_env_logstd,
            )
        )
        self.lrcfg = cfg
        self.d = int(d)
        self.k = int(cfg.k)

        A = rng.normal(0.0, 1.0, size=(self.d, self.k))
        A = (cfg.A_scale / np.sqrt(self.k)) * A
        self.A = A  # (d,k)

    def sample_unpaired(
        self,
        m: int,
        r_y: int,
        r_x: int,
        beta: np.ndarray,
        rng: Optional[np.random.Generator] = None,
    ) -> UnpairedIVData:
        """Draw an outcome sample and a separate covariate sample.

        Identical in structure to the base discrete-environment model,
        except that the environment means are ``Z @ A.T`` with fresh
        standard-normal ``Z`` on every call.

        Args:
            m: Number of environments.
            r_y: Outcome rows per environment.
            r_x: Covariate rows per environment.
            beta: Coefficient vector; expected to have ``self.d`` entries.
            rng: Random generator; a fresh default generator is used
                when omitted.

        Returns:
            ``UnpairedIVData`` built from one-hot environment indicators
            ``I_y`` and ``I_x``, outcomes ``Y`` and covariates ``X``.
        """
        rng = np.random.default_rng() if rng is None else rng
        beta = np.asarray(beta)

        Z = rng.normal(0.0, 1.0, size=(m, self.k))
        env_means = Z @ self.A.T  # (m, d), rank at most k

        n_y, n_x = m * r_y, m * r_x
        I_y = np.zeros((n_y, m))
        I_x = np.zeros((n_x, m))
        Y = np.zeros(n_y)
        X = np.zeros((n_x, self.d))

        cfg = self.cfg

        sigma_eps_env = bounded_lognormal_scales(
            rng, m, cfg.sigma_eps, cfg.eps_env_logstd, min_mult=0.5, max_mult=3.0
        )
        sigma_x_env = bounded_lognormal_scales(
            rng, m, cfg.sigma_x, cfg.x_env_logstd, min_mult=0.5, max_mult=2.0
        )

        # Outcome sample: the covariates are latent and only Y is stored.
        idx = 0
        for e in range(m):
            for _ in range(r_y):
                I_y[idx, e] = 1.0
                u = rng.normal(0.0, cfg.sigma_u)
                eps_x = rng.normal(0.0, sigma_x_env[e], size=self.d)
                eps_y = rng.normal(0.0, sigma_eps_env[e])
                x_lat = env_means[e] + cfg.gamma_x * u + eps_x
                Y[idx] = x_lat @ beta + cfg.gamma_y * u + eps_y
                idx += 1

        # Covariate sample: fresh confounder and noise draws.
        idx = 0
        for e in range(m):
            for _ in range(r_x):
                I_x[idx, e] = 1.0
                u = rng.normal(0.0, cfg.sigma_u)
                eps_x = rng.normal(0.0, sigma_x_env[e], size=self.d)
                X[idx] = env_means[e] + cfg.gamma_x * u + eps_x
                idx += 1

        return UnpairedIVData(I_y=I_y, Y=Y, I_x=I_x, X=X)


@dataclass(frozen=True)
class LowRankContinuousIVDGPConfig(ContinuousIVDGPConfig):
    """Configuration for LowRankContinuousIVDGP.

    Adds the low-rank factor settings on top of the fields inherited from
    ``ContinuousIVDGPConfig``.
    """
    # Number of latent factors, i.e. columns of the loading matrix ``A``.
    k: int = 60
    # ``A`` is a standard-normal matrix multiplied by A_scale / sqrt(k).
    A_scale: float = 1.0


class LowRankContinuousIVDGP(ContinuousIVDGP):
    """Continuous-instrument simulator with a low-rank first stage.

    A loading matrix ``A`` of shape ``(d, k)`` is drawn once at
    construction.  Each call to :meth:`sample_unpaired` draws latent
    factors ``Z`` of shape ``(m, k)`` and uses ``pi_scale * Z @ A.T`` as
    the first-stage matrix.
    """

    def __init__(
        self, cfg: LowRankContinuousIVDGPConfig, d: int, rng: np.random.Generator
    ):
        """Copy the base settings and draw the loading matrix.

        Args:
            cfg: Low-rank configuration.
            d: Covariate dimension (coerced with ``int``).
            rng: Random generator used to draw ``A``.
        """
        # Forward every inherited field so that ``self.cfg`` mirrors the
        # values the caller supplied.
        super().__init__(
            ContinuousIVDGPConfig(
                gamma_x=cfg.gamma_x,
                gamma_y=cfg.gamma_y,
                sigma_u=cfg.sigma_u,
                sigma_x=cfg.sigma_x,
                sigma_eps=cfg.sigma_eps,
                env_sigma=cfg.env_sigma,
                eps_env_logstd=cfg.eps_env_logstd,
                x_env_logstd=cfg.x_env_logstd,
                z_scale=cfg.z_scale,
                pi_scale=cfg.pi_scale,
                eps_coord_logstd=cfg.eps_coord_logstd,
                x_coord_logstd=cfg.x_coord_logstd,
            )
        )
        self.lrcfg = cfg
        self.d = int(d)
        self.k = int(cfg.k)

        A = rng.normal(0.0, 1.0, size=(self.d, self.k))
        A = (cfg.A_scale / np.sqrt(self.k)) * A
        self.A = A  # (d,k)

    def sample_unpaired(
        self,
        m: int,
        r_y: int,
        r_x: int,
        beta: np.ndarray,
        rng: Optional[np.random.Generator] = None,
    ) -> UnpairedIVData:
        """Draw an outcome sample and a separate covariate sample.

        For a row with instrument vector ``z`` the model is::

            x = z @ Pi + gamma_x * u + eps_x,   Pi = pi_scale * Z @ A.T
            y = x @ beta + gamma_y * u + eps_y

        The noise scale of a row is looked up by the index of the
        largest-magnitude entry of its instrument vector.

        Args:
            m: Instrument dimension; also sets the row counts.
            r_y: The outcome sample has ``m * r_y`` rows.
            r_x: The covariate sample has ``m * r_x`` rows.
            beta: Coefficient vector; expected to have ``self.d`` entries.
            rng: Random generator; a fresh default generator is used
                when omitted.

        Returns:
            ``UnpairedIVData`` built from the instrument matrices ``I_y``
            and ``I_x``, outcomes ``Y`` and covariates ``X``.
        """
        rng = np.random.default_rng() if rng is None else rng
        beta = np.asarray(beta)
        m = int(m)

        cfg = self.cfg

        Z = rng.normal(0.0, 1.0, size=(m, self.k))
        Pi = cfg.pi_scale * (Z @ self.A.T)  # (m, d), rank at most k

        n_y, n_x = m * int(r_y), m * int(r_x)
        std = cfg.z_scale / np.sqrt(m)

        I_y = rng.normal(0.0, std, size=(n_y, m))
        I_x = rng.normal(0.0, std, size=(n_x, m))

        U_y = rng.normal(0.0, cfg.sigma_u, size=n_y)
        U_x = rng.normal(0.0, cfg.sigma_u, size=n_x)

        sigma_eps_coord = bounded_lognormal_scales(
            rng=rng, m=m, base=cfg.sigma_eps, logstd=cfg.eps_coord_logstd
        )
        sigma_x_coord = bounded_lognormal_scales(
            rng=rng, m=m, base=cfg.sigma_x, logstd=cfg.x_coord_logstd
        )

        # Index of the dominant instrument coordinate of every row.
        coord_y = np.argmax(np.abs(I_y), axis=1)
        coord_x = np.argmax(np.abs(I_x), axis=1)

        eps_x_y = (
            rng.normal(0.0, 1.0, size=(n_y, self.d)) * sigma_x_coord[coord_y][:, None]
        )
        eps_x_x = (
            rng.normal(0.0, 1.0, size=(n_x, self.d)) * sigma_x_coord[coord_x][:, None]
        )
        eps_y = rng.normal(0.0, sigma_eps_coord[coord_y])

        X_lat_y = I_y @ Pi + cfg.gamma_x * U_y[:, None] + eps_x_y
        Y = X_lat_y @ beta + cfg.gamma_y * U_y + eps_y

        X = I_x @ Pi + cfg.gamma_x * U_x[:, None] + eps_x_x

        return UnpairedIVData(I_y=I_y, Y=Y, I_x=I_x, X=X)


class DiscreteEnvDGPCorrelated:
    """Discrete-environment simulator whose two samples share a hidden term.

    Row ``i`` of the outcome sample and row ``i`` of the covariate sample
    are generated together and both contain the same uniform draw ``h``,
    so the two samples are dependent row by row.
    """

    def __init__(self, cfg: DiscreteEnvDGPConfig = DiscreteEnvDGPConfig()):
        """Store the configuration read by :meth:`sample_unpaired`."""
        self.cfg = cfg

    @staticmethod
    def sparse_beta(
        d: int,
        s_star: int,
        beta_min: float = 0.5,
        beta_max: float = 1.0,
        rng: Optional[np.random.Generator] = None,
    ) -> np.ndarray:
        """Sample a length-``d`` vector with at most ``s_star`` non-zeros.

        The support is chosen without replacement and its entries are
        drawn uniformly from ``[beta_min, beta_max)``.

        Args:
            d: Length of the coefficient vector.
            s_star: Requested support size (capped at ``d``).
            beta_min: Lower bound for the non-zero entries.
            beta_max: Upper bound for the non-zero entries.
            rng: Random generator; a fresh default generator is used
                when omitted.

        Returns:
            The sparse coefficient vector of shape ``(d,)``.
        """
        rng = np.random.default_rng() if rng is None else rng
        beta = np.zeros(d)
        supp = rng.choice(d, size=min(s_star, d), replace=False)
        beta[supp] = rng.uniform(beta_min, beta_max, size=supp.size)
        return beta

    def sample_unpaired(
        self,
        m: int,
        r_y: int,
        r_x: int,
        beta: np.ndarray,
        rng: Optional[np.random.Generator] = None,
    ) -> UnpairedIVData:
        """Draw row-wise dependent outcome and covariate samples.

        For row ``i`` in environment ``e``::

            h    ~ Uniform(-1, 1)
            y_i  = (mu_e + gamma_x * u + eps_x + h) @ beta + gamma_y * u + eps_y + h
            x_i  = mu_e + eps_x' + h + w,   w ~ N(0, 1) scalar

        where ``eps_x'`` is a second, independent covariate-noise draw.

        Args:
            m: Number of environments.
            r_y: Outcome rows per environment.
            r_x: Covariate rows per environment; must equal ``r_y``
                because both samples are filled by the same loop.
            beta: Coefficient vector; its size sets the covariate dimension.
            rng: Random generator; a fresh default generator is used
                when omitted.

        Returns:
            ``UnpairedIVData`` built from one-hot environment indicators
            ``I_y`` and ``I_x``, outcomes ``Y`` and covariates ``X``.

        Raises:
            ValueError: If ``r_y != r_x``.
        """
        # Explicit check instead of ``assert`` so it still runs under
        # ``python -O``; done first so a bad call consumes no RNG state.
        if r_y != r_x:
            raise ValueError(
                f"r_y and r_x must be equal for the correlated sampler, "
                f"got r_y={r_y} and r_x={r_x}"
            )

        rng = np.random.default_rng() if rng is None else rng
        beta = np.asarray(beta)
        d = beta.size

        # NOTE(review): unit standard deviation is hard-coded here rather
        # than read from cfg.env_sigma -- confirm this is intended.
        env_means = rng.normal(0.0, 1.0, size=(m, d))

        n_y, n_x = m * r_y, m * r_x
        I_y = np.zeros((n_y, m))
        I_x = np.zeros((n_x, m))

        Y = np.zeros(n_y)
        X = np.zeros((n_x, d))

        cfg = self.cfg

        sigma_eps_env = bounded_lognormal_scales(
            rng, m, cfg.sigma_eps, cfg.eps_env_logstd, min_mult=0.5, max_mult=3.0
        )
        sigma_x_env = bounded_lognormal_scales(
            rng, m, cfg.sigma_x, cfg.x_env_logstd, min_mult=0.5, max_mult=2.0
        )

        idx = 0
        for e in range(m):
            for _ in range(r_y):
                I_y[idx, e] = 1.0
                u = rng.normal(0.0, cfg.sigma_u)
                # Shared term that links outcome row idx to covariate row idx.
                h = rng.uniform(-1.0, 1.0)
                eps_x = rng.normal(0.0, sigma_x_env[e], size=d)
                eps_y = rng.normal(0.0, sigma_eps_env[e])
                x_lat = env_means[e] + cfg.gamma_x * u + eps_x + h
                Y[idx] = x_lat @ beta + cfg.gamma_y * u + eps_y + h

                I_x[idx, e] = 1.0
                eps_x = rng.normal(0.0, sigma_x_env[e], size=d)
                X[idx] = env_means[e] + eps_x + h + rng.normal(0.0, 1.0)
                idx += 1

        return UnpairedIVData(I_y=I_y, Y=Y, I_x=I_x, X=X)


class LowRankEnvDGPCorrelated(DiscreteEnvDGPCorrelated):
    """Row-wise dependent simulator with low-rank environment means.

    A loading matrix ``A`` of shape ``(d, k)`` is drawn once at
    construction.  Each call to :meth:`sample_unpaired` draws latent
    factors ``Z`` of shape ``(m, k)`` and uses ``Z @ A.T`` as the
    environment means.
    """

    def __init__(self, cfg: LowRankEnvDGPConfig, d: int, rng: np.random.Generator):
        """Copy the base settings and draw the loading matrix.

        Args:
            cfg: Low-rank configuration.
            d: Covariate dimension (coerced with ``int``).
            rng: Random generator used to draw ``A``.
        """
        # Forward every inherited field.  ``eps_env_logstd`` and
        # ``x_env_logstd`` are read by ``sample_unpaired``; leaving them out
        # would silently replace user-supplied values with the defaults.
        super().__init__(
            DiscreteEnvDGPConfig(
                gamma_x=cfg.gamma_x,
                gamma_y=cfg.gamma_y,
                sigma_u=cfg.sigma_u,
                sigma_x=cfg.sigma_x,
                sigma_eps=cfg.sigma_eps,
                env_sigma=cfg.env_sigma,
                eps_env_logstd=cfg.eps_env_logstd,
                x_env_logstd=cfg.x_env_logstd,
            )
        )
        self.lrcfg = cfg
        self.d = int(d)
        self.k = int(cfg.k)

        A = rng.normal(0.0, 1.0, size=(self.d, self.k))
        A = (cfg.A_scale / np.sqrt(self.k)) * A
        self.A = A  # (d,k)

    def sample_unpaired(
        self,
        m: int,
        r_y: int,
        r_x: int,
        beta: np.ndarray,
        rng: Optional[np.random.Generator] = None,
    ) -> UnpairedIVData:
        """Draw row-wise dependent outcome and covariate samples.

        For row ``i`` in environment ``e``, with ``mu_e`` the ``e``-th row
        of ``Z @ A.T``::

            h    ~ Uniform(-1, 1)
            y_i  = (mu_e + gamma_x * u + eps_x + h) @ beta + gamma_y * u + eps_y + h
            x_i  = mu_e + eps_x' + h + w,   w ~ N(0, 1) scalar

        where ``eps_x'`` is a second, independent covariate-noise draw.

        Args:
            m: Number of environments.
            r_y: Outcome rows per environment.
            r_x: Covariate rows per environment; must equal ``r_y``
                because both samples are filled by the same loop.
            beta: Coefficient vector; expected to have ``self.d`` entries.
            rng: Random generator; a fresh default generator is used
                when omitted.

        Returns:
            ``UnpairedIVData`` built from one-hot environment indicators
            ``I_y`` and ``I_x``, outcomes ``Y`` and covariates ``X``.

        Raises:
            ValueError: If ``r_y != r_x``.
        """
        # Explicit check instead of ``assert`` so it still runs under
        # ``python -O``; done first so a bad call consumes no RNG state.
        if r_y != r_x:
            raise ValueError(
                f"r_y and r_x must be equal for the correlated sampler, "
                f"got r_y={r_y} and r_x={r_x}"
            )

        rng = np.random.default_rng() if rng is None else rng
        beta = np.asarray(beta)

        Z = rng.normal(0.0, 1.0, size=(m, self.k))
        env_means = Z @ self.A.T  # (m, d), rank at most k

        n_y, n_x = m * r_y, m * r_x
        I_y = np.zeros((n_y, m))
        I_x = np.zeros((n_x, m))
        Y = np.zeros(n_y)
        X = np.zeros((n_x, self.d))

        cfg = self.cfg

        sigma_eps_env = bounded_lognormal_scales(
            rng, m, cfg.sigma_eps, cfg.eps_env_logstd, min_mult=0.5, max_mult=3.0
        )
        sigma_x_env = bounded_lognormal_scales(
            rng, m, cfg.sigma_x, cfg.x_env_logstd, min_mult=0.5, max_mult=2.0
        )

        idx = 0
        for e in range(m):
            for _ in range(r_y):
                I_y[idx, e] = 1.0
                u = rng.normal(0.0, cfg.sigma_u)
                # Shared term that links outcome row idx to covariate row idx.
                h = rng.uniform(-1.0, 1.0)
                eps_x = rng.normal(0.0, sigma_x_env[e], size=self.d)
                eps_y = rng.normal(0.0, sigma_eps_env[e])
                x_lat = env_means[e] + cfg.gamma_x * u + eps_x + h
                Y[idx] = x_lat @ beta + cfg.gamma_y * u + eps_y + h

                I_x[idx, e] = 1.0
                eps_x = rng.normal(0.0, sigma_x_env[e], size=self.d)
                X[idx] = env_means[e] + eps_x + h + rng.normal(0.0, 1.0)
                idx += 1

        return UnpairedIVData(I_y=I_y, Y=Y, I_x=I_x, X=X)

