from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional

import numpy as np
from numpy.linalg import solve
from sklearn.linear_model import Lasso

from ..data import UnpairedIVData
from ..instruments import is_one_hot_instrument
from ..linalg import center_cols, cross_cov_IX, cross_cov_IY, make_psd
from ..stats import default_lam
from .base import Estimator


@dataclass(frozen=True)
class UPGMMHDConfig:
    """Immutable settings consumed by :class:`UPGMMHD`.

    The field order is part of the positional-constructor interface and must
    not be changed.
    """

    # Eigenvalue floor handed to ``make_psd`` and diagonal loading added to
    # the linear systems that are solved.
    ridge: float = 1e-10
    # Number of folds the X-sample is split into (UPGMMHD rejects K < 2).
    K: int = 2
    # Number of random re-splits whose cross-fold moments are averaged;
    # at least one split is always drawn.
    redraw_B: int = 5
    # False: return the ridge-regularised linear solve. True: fit a Lasso.
    l1: bool = False
    # Multiplier forwarded to ``default_lam``.
    # TODO: implementing a data-dependent regularization parameter would be nice.
    lam_scale: float = 1.0
    # Only with ``l1=True``: re-solve by least squares on the Lasso support.
    post_refit: bool = True
    # Sample size used for the Lasso penalty: "m" selects ``data.m``,
    # every other value selects ``data.N``.
    lam_effective_n: str = "m"


class UPGMMHD(Estimator):
    """High-dimensional unpaired GMM estimator using cross-fold moments.

    The X-sample is split into folds; ``Cxx`` is built only from products of
    instrument/regressor cross-covariances computed on *different* folds and
    is averaged over several random re-splits.  ``Cxy`` is computed once on
    the full samples.  The coefficients solve ``Cxx @ beta = Cxy``, either by
    a ridge-loaded linear solve or by a Lasso with an optional refit on the
    selected support.
    """

    name = "up_gmm_hd"

    def __init__(self, cfg: UPGMMHDConfig = UPGMMHDConfig()):
        """Store the configuration.

        Raises:
            ValueError: if ``cfg.K < 2`` (cross-fold products need two folds).
        """
        if cfg.K < 2:
            raise ValueError("UPGMMHD requires K>=2.")
        self.cfg = cfg

    def _cross_fold_moment(
        self, data: UnpairedIVData, folds: List[np.ndarray]
    ) -> np.ndarray:
        """Return ``data.m`` times the mean of ``cov_h.T @ cov_k`` over h != k.

        ``folds`` holds row-index arrays into the X-sample.  With fewer than
        two folds no cross-fold product exists, so the full-sample plug-in
        ``data.m * B.T @ B`` is returned instead (same scale as the cross-fold
        moment).
        """
        if len(folds) < 2:
            B = cross_cov_IX(data.I_x, data.X)
            return data.m * (B.T @ B)

        covs = [cross_cov_IX(data.I_x[f], data.X[f]) for f in folds]
        acc = np.zeros((data.d, data.d))
        n_pairs = 0
        for k, cov_k in enumerate(covs):
            for h, cov_h in enumerate(covs):
                if h == k:
                    continue
                acc += cov_h.T @ cov_k
                n_pairs += 1
        return data.m * (acc / n_pairs)

    def _compute_Cxx_crossfold_once_unstratified(
        self, data: UnpairedIVData, rng: np.random.Generator
    ) -> np.ndarray:
        """Cross-fold ``Cxx`` from one uniformly random K-way split of the X rows.

        Folds with fewer than two rows are dropped, matching the stratified
        splitter; ``np.array_split`` yields such folds when ``n_x < 2 * K``.
        """
        perm = rng.permutation(data.n_x)
        folds = [f for f in np.array_split(perm, self.cfg.K) if f.size >= 2]
        return self._cross_fold_moment(data, folds)

    def _compute_Cxx_crossfold_once_stratified(
        self, data: UnpairedIVData, rng: np.random.Generator
    ) -> np.ndarray:
        """Cross-fold ``Cxx`` from one split stratified by instrument column.

        Rows whose entry in column ``e`` of ``data.I_x`` exceeds 0.5 form
        environment ``e``; each environment is permuted and dealt into the K
        folds separately.  If fewer than two folds with at least two rows
        result, the unstratified splitter is used instead.
        """
        K = self.cfg.K
        buckets: List[List[np.ndarray]] = [[] for _ in range(K)]
        for e in range(data.m):
            members = np.where(data.I_x[:, e] > 0.5)[0]
            if members.size == 0:
                continue
            parts = np.array_split(rng.permutation(members), K)
            for k, part in enumerate(parts):
                if part.size > 0:
                    buckets[k].append(part)

        folds = [np.concatenate(b) for b in buckets if b]
        folds = [f for f in folds if f.size >= 2]
        if len(folds) < 2:
            return self._compute_Cxx_crossfold_once_unstratified(data, rng)
        return self._cross_fold_moment(data, folds)

    def _compute_Cxx_crossfold_once(
        self, data: UnpairedIVData, rng: np.random.Generator
    ) -> np.ndarray:
        """Dispatch to the stratified splitter for one-hot instruments."""
        if is_one_hot_instrument(data.I_x):
            return self._compute_Cxx_crossfold_once_stratified(data, rng)
        return self._compute_Cxx_crossfold_once_unstratified(data, rng)

    def _compute_Cxy_full(self, data: UnpairedIVData) -> np.ndarray:
        """Return ``data.m * covIX.T @ covIY`` computed on the full samples."""
        covIX = cross_cov_IX(data.I_x, data.X)  # (m,d)
        covIY = cross_cov_IY(data.I_y, data.Y)  # (m,)
        return data.m * (covIX.T @ covIY)  # (d,)

    def _fit_l1(
        self, data: UnpairedIVData, Cxx: np.ndarray, Cxy: np.ndarray
    ) -> np.ndarray:
        """Lasso on the system ``Cxx @ beta = Cxy`` with optional support refit."""
        d = data.d
        # Any value other than "m" selects data.N.
        eff_n = float(data.m) if self.cfg.lam_effective_n == "m" else float(data.N)
        lam = default_lam(d, effective_n=eff_n, lam_scale=self.cfg.lam_scale)
        alpha = lam / max(Cxx.shape[0], 1.0)

        # Scale columns by sqrt(diag(Cxx)) so the penalty acts on comparable
        # units, then map the coefficients back.
        D = np.sqrt(np.maximum(np.diag(Cxx), 1e-12))
        Xs = Cxx / D[None, :]

        model = Lasso(alpha=alpha, fit_intercept=False, max_iter=50000, tol=1e-4)
        model.fit(Xs, Cxy)
        beta_l1 = model.coef_.copy() / D

        if not self.cfg.post_refit:
            return beta_l1

        supp = np.where(np.abs(beta_l1) > 1e-8)[0]
        if supp.size == 0:
            return beta_l1

        # Least-squares refit of Cxy on the selected columns of Cxx.
        Zs = Cxx[:, supp]
        G = Zs.T @ Zs + self.cfg.ridge * np.eye(supp.size)
        ref = solve(G, Zs.T @ Cxy)

        out = np.zeros(d)
        out[supp] = ref
        return out

    def fit(
        self, data: UnpairedIVData, rng: Optional[np.random.Generator] = None
    ) -> np.ndarray:
        """Fit estimator on unpaired data and return coefficient estimates.

        Args:
            data: unpaired sample; the attributes ``I_x``, ``X``, ``I_y``,
                ``Y``, ``m``, ``d``, ``n_x`` and ``N`` are read.
            rng: generator used for the random fold splits; a fresh
                ``np.random.default_rng()`` is created when omitted.

        Returns:
            Coefficient vector of length ``data.d``.
        """
        rng = np.random.default_rng() if rng is None else rng
        d = data.d

        Cxy = self._compute_Cxy_full(data)

        # Average the cross-fold moment over independent random re-splits.
        n_draws = max(1, self.cfg.redraw_B)
        Cxx_acc = np.zeros((d, d))
        for _ in range(n_draws):
            Cxx_acc += self._compute_Cxx_crossfold_once(data, rng)
        Cxx = make_psd(Cxx_acc / n_draws, min_eig=self.cfg.ridge)

        if not self.cfg.l1:
            return solve(Cxx + self.cfg.ridge * np.eye(d), Cxy)
        return self._fit_l1(data, Cxx, Cxy)


# -----------------------------
# Moment-lasso HD variant
# -----------------------------


@dataclass(frozen=True)
class UPGMMHDMomentConfig:
    """Immutable settings consumed by :class:`UPGMMHD_Moment`.

    The field order is part of the positional-constructor interface.
    """

    # Eigenvalue floor handed to ``make_psd`` and diagonal loading added to
    # the linear systems that are solved.
    ridge: float = 1e-10
    # Number of folds the X-sample is split into (UPGMMHD_Moment rejects K < 2).
    K: int = 2
    # Number of random re-splits; at least one split is always drawn.
    redraw_B: int = 5
    # False: return the ridge-regularised linear solve. True: fit a Lasso on
    # the stacked per-fold moments.
    l1: bool = False
    # Multiplier of the sqrt(log(d + 1) / effective_n) fraction of alpha_max
    # used as the Lasso penalty.
    lam_scale: float = 1.0
    # Only with ``l1=True``: re-solve the Cxx/Cxy system on the Lasso support.
    post_refit: bool = True
    # Effective sample size for the penalty; unknown values raise ValueError
    # in ``UPGMMHD_Moment._effective_n``.
    lam_effective_n: str = "m"  # "m","N","min_n","harmonic"


class UPGMMHD_Moment(Estimator):
    """Unpaired GMM estimator whose L1 step is a Lasso on stacked fold moments.

    ``Cxx`` is the cross-fold moment averaged over random re-splits and
    ``Cxy`` is the full-sample moment.  Without ``cfg.l1`` the estimator
    returns the ridge-loaded solve of ``Cxx @ beta = Cxy``.  With ``cfg.l1``
    the per-fold instrument/regressor cross-covariances are stacked row-wise
    and regressed on the (tiled) instrument/outcome cross-covariance.
    """

    name = "up_gmm_hd_moment"

    def __init__(self, cfg: UPGMMHDMomentConfig = UPGMMHDMomentConfig()):
        """Store the configuration.

        Raises:
            ValueError: if ``cfg.K < 2``.
        """
        if cfg.K < 2:
            raise ValueError("UPGMMHD_Moment requires K>=2.")
        self.cfg = cfg

    def _make_folds(
        self, data: UnpairedIVData, rng: np.random.Generator
    ) -> List[np.ndarray]:
        """Split the X rows into at most K folds of at least two rows each.

        One-hot instruments are split within each environment (rows whose
        entry in that column of ``data.I_x`` exceeds 0.5); otherwise, or when
        the stratified split leaves fewer than two usable folds, a uniformly
        random split is used.  The result may contain fewer than two folds.
        """
        K = self.cfg.K
        idx = np.arange(data.n_x)

        def random_folds() -> List[np.ndarray]:
            perm = rng.permutation(idx)
            return [f for f in np.array_split(perm, K) if f.size >= 2]

        if not is_one_hot_instrument(data.I_x):
            return random_folds()

        buckets: List[List[np.ndarray]] = [[] for _ in range(K)]
        for e in range(data.m):
            members = np.where(data.I_x[:, e] > 0.5)[0]
            if members.size == 0:
                continue
            parts = np.array_split(rng.permutation(members), K)
            for k, part in enumerate(parts):
                if part.size > 0:
                    buckets[k].append(part)

        folds = [np.concatenate(b) for b in buckets if b]
        folds = [f for f in folds if f.size >= 2]
        if len(folds) < 2:
            return random_folds()
        return folds

    def _fold_covIX_once(
        self, data: UnpairedIVData, rng: np.random.Generator
    ) -> List[np.ndarray]:
        """Per-fold ``cross_cov_IX`` matrices for one split, finite ones only."""
        folds = self._make_folds(data, rng)
        covs = [cross_cov_IX(data.I_x[f], data.X[f]) for f in folds]
        return [B for B in covs if np.all(np.isfinite(B))]

    def _compute_Cxy_full(self, data: UnpairedIVData) -> np.ndarray:
        """Return ``data.m * covIX.T @ covIY`` computed on the full samples."""
        covIX = cross_cov_IX(data.I_x, data.X)
        covIY = cross_cov_IY(data.I_y, data.Y)
        return data.m * (covIX.T @ covIY)

    def _plugin_Cxx(self, data: UnpairedIVData) -> np.ndarray:
        """Full-sample plug-in ``data.m * B.T @ B`` used when no fold pair exists.

        The ``data.m`` factor keeps the fallback on the same scale as the
        cross-fold moment and as ``Cxy``; without it the solved coefficients
        would be inflated by a factor of ``data.m``.
        """
        B = cross_cov_IX(data.I_x, data.X)
        return data.m * (B.T @ B)

    def _compute_Cxx_from_covs(
        self, data: UnpairedIVData, covs: List[np.ndarray]
    ) -> np.ndarray:
        """Return ``data.m`` times the mean of ``covs[h].T @ covs[k]`` over h != k.

        Falls back to the full-sample plug-in when fewer than two fold
        matrices are supplied.
        """
        if len(covs) < 2:
            return self._plugin_Cxx(data)

        acc = np.zeros((data.d, data.d))
        n_pairs = 0
        for k, cov_k in enumerate(covs):
            for h, cov_h in enumerate(covs):
                if h == k:
                    continue
                acc += cov_h.T @ cov_k
                n_pairs += 1
        return data.m * (acc / n_pairs)

    def _effective_n(self, data: UnpairedIVData) -> float:
        """Map ``cfg.lam_effective_n`` to the sample size used in the penalty.

        Raises:
            ValueError: for a value other than "m", "N", "min_n", "harmonic".
        """
        mode = self.cfg.lam_effective_n
        if mode == "m":
            return float(data.m)
        if mode == "N":
            return float(data.N)
        if mode == "min_n":
            return float(min(data.n_x, data.n_y))
        if mode == "harmonic":
            return float(data.n_x * data.n_y) / max(float(data.n_x + data.n_y), 1.0)
        raise ValueError(f"Unknown lam_effective_n={self.cfg.lam_effective_n!r}")

    def fit(
        self, data: UnpairedIVData, rng: Optional[np.random.Generator] = None
    ) -> np.ndarray:
        """Fit estimator on unpaired data and return coefficient estimates.

        Args:
            data: unpaired sample; the attributes ``I_x``, ``X``, ``I_y``,
                ``Y``, ``m``, ``d``, ``n_x``, ``n_y`` and ``N`` are read.
            rng: generator used for the random fold splits; a fresh
                ``np.random.default_rng()`` is created when omitted.

        Returns:
            Coefficient vector of length ``data.d``.
        """
        rng = np.random.default_rng() if rng is None else rng
        d = data.d

        Cxy = self._compute_Cxy_full(data)

        Cxx_acc = np.zeros((d, d))
        used = 0
        covs_for_l1: List[np.ndarray] = []

        for _ in range(max(1, self.cfg.redraw_B)):
            covs = self._fold_covIX_once(data, rng)

            # A split contributes to Cxx only if it has a pair of folds.
            if len(covs) >= 2:
                Cxx_acc += self._compute_Cxx_from_covs(data, covs)
                used += 1

            if self.cfg.l1:
                covs_for_l1.extend(covs)

        Cxx_raw = Cxx_acc / used if used else self._plugin_Cxx(data)
        Cxx = make_psd(Cxx_raw, min_eig=self.cfg.ridge)

        if not self.cfg.l1:
            return solve(Cxx + self.cfg.ridge * np.eye(d), Cxy)

        a = cross_cov_IY(data.I_y, data.Y)  # (m,)
        if len(covs_for_l1) == 0:
            return np.zeros(d)

        # Stack every fold's moment matrix against the same response vector.
        Z = np.vstack(covs_for_l1)  # (m*Keff, d)
        y = np.tile(a, len(covs_for_l1))

        # Unit-norm columns so the penalty treats all coefficients alike.
        col_scales = np.maximum(np.linalg.norm(Z, axis=0), 1e-12)
        Zs = Z / col_scales[None, :]

        eff_n = self._effective_n(data)
        n_samples = float(Zs.shape[0])
        # Smallest penalty for which sklearn's Lasso solution is all zeros.
        alpha_max = np.max(np.abs(Zs.T @ y)) / max(n_samples, 1.0)
        if alpha_max == 0.0:
            # Response is orthogonal to every column: the Lasso solution is
            # zero; skip the degenerate alpha=0 fit.
            return np.zeros(d)
        frac = self.cfg.lam_scale * np.sqrt(np.log(d + 1.0) / max(eff_n, 1.0))
        # Cap below alpha_max so the fit cannot be forced to all zeros.
        alpha = min(frac * alpha_max, 0.99 * alpha_max)

        model = Lasso(alpha=alpha, fit_intercept=False, max_iter=50000, tol=1e-4)
        model.fit(Zs, y)
        beta_l1 = model.coef_.copy() / col_scales

        if not self.cfg.post_refit:
            return beta_l1

        supp = np.where(np.abs(beta_l1) > 1e-8)[0]
        if supp.size == 0:
            return beta_l1

        # Refit on the selected support using the cross-fold moments.
        Cxx_ss = Cxx[np.ix_(supp, supp)] + self.cfg.ridge * np.eye(supp.size)
        out = np.zeros(d)
        out[supp] = solve(Cxx_ss, Cxy[supp])
        return out


# -----------------------------
# Analytic U-statistic correction
# -----------------------------


@dataclass(frozen=True)
class UPGMMHDAnalyticConfig:
    """Immutable settings consumed by :class:`UPGMMHDAnalytic`.

    The field order is part of the positional-constructor interface.
    """

    # Eigenvalue floor handed to ``make_psd`` and diagonal loading added to
    # the linear systems that are solved.
    ridge: float = 1e-10
    # False: return the ridge-regularised linear solve. True: fit a Lasso.
    l1: bool = False
    # Multiplier forwarded to ``default_lam``.
    lam_scale: float = 1.0
    # Only with ``l1=True``: re-solve by least squares on the Lasso support.
    post_refit: bool = True
    # Sample size used for the Lasso penalty: "m" selects ``data.m``,
    # every other value selects ``data.N``.
    lam_effective_n: str = "m"  # "m" or "N"


class UPGMMHDAnalytic(Estimator):
    """Analytic high-dimensional unpaired GMM estimator (SplitUP).

    ``Cxx`` is computed in closed form rather than by random splitting, so
    ``fit`` is deterministic and ignores its ``rng`` argument.
    """

    name = "up_gmm_hd_analytic"

    def __init__(self, cfg: UPGMMHDAnalyticConfig = UPGMMHDAnalyticConfig()):
        """Initialize UPGMMHDAnalytic with configuration parameters."""
        self.cfg = cfg

    def _Cxx_analytic_generic(self, data: UnpairedIVData) -> np.ndarray:
        """Closed-form ``Cxx`` for arbitrary instruments.

        Sums ``(I_i . I_j) * x_i x_j^T`` over all ordered pairs of distinct
        X rows (centred instruments and regressors), normalised by
        ``n * (n - 1)`` and scaled by ``data.m`` and ``(data.m - 1) / data.m``.

        Returns:
            ``(d, d)`` matrix, PSD-projected with eigenvalue floor
            ``cfg.ridge`` and with ``cfg.ridge`` added on the diagonal.
        """
        Ix = center_cols(data.I_x)
        Xc = center_cols(data.X)
        n = data.n_x
        d = data.d

        # All pairs (i, j), including i == j.
        S = Ix.T @ Xc
        all_pairs = S.T @ S

        # Self-pairs i == j, weighted by the squared instrument norm.
        sq_norm = np.sum(Ix**2, axis=1)
        self_pairs = (Xc * sq_norm[:, None]).T @ Xc

        scale = 1.0 / max(n * (n - 1), 1)
        Cxx = data.m * scale * (all_pairs - self_pairs)

        # NOTE(review): purpose of the (m - 1) / m factor is not visible from
        # this file; kept as in the original.
        Cxx = ((data.m - 1) / data.m) * Cxx
        return make_psd(Cxx, min_eig=self.cfg.ridge) + self.cfg.ridge * np.eye(d)

    def _Cxx_analytic_categorical(self, data: UnpairedIVData) -> np.ndarray:
        """Analytic H→∞ limit for *within-environment stratified splits* (sample-weighted).

        Environments are the columns of ``data.I_x``; a row belongs to the
        environment whose entry exceeds 0.5.  Environments with fewer than two
        rows are ignored, and if fewer than two environments remain the
        generic formula is used.

        Returns:
            ``(d, d)`` matrix, PSD-projected with eigenvalue floor
            ``cfg.ridge`` and with ``cfg.ridge`` added on the diagonal.
        """
        Xc = center_cols(data.X)
        m = data.m
        d = data.d

        members = [np.where(data.I_x[:, e] > 0.5)[0] for e in range(m)]
        kept = [ix for ix in members if ix.size >= 2]
        if len(kept) < 2:
            return self._Cxx_analytic_generic(data)

        # Sample-share weight of each retained environment.
        sizes = np.array([ix.size for ix in kept], dtype=float)
        w = sizes / float(np.sum(sizes))

        mus = []
        Ms = []
        for ix in kept:
            n_e = ix.size
            Xe = Xc[ix]

            S = Xe.sum(axis=0)
            Q = Xe.T @ Xe
            mus.append(S / n_e)
            # Mean of x_i x_j^T over ordered pairs i != j inside the environment.
            Ms.append((np.outer(S, S) - Q) / (n_e * (n_e - 1)))

        Mu = np.stack(mus, axis=0)
        Mstack = np.stack(Ms, axis=0)

        # Pair drawn from the same environment, environment weighted by w.
        E_same = np.tensordot(w, Mstack, axes=(0, 0))

        # Pair with independently drawn environments: different environments
        # contribute products of means, equal environments the pair moment.
        sum_mu = (w[:, None] * Mu).sum(axis=0)
        sum_mumu_same = (w**2)[:, None, None] * (Mu[:, :, None] * Mu[:, None, :])
        sum_mumu_same = sum_mumu_same.sum(axis=0)
        cross_env = np.outer(sum_mu, sum_mu) - sum_mumu_same
        same_env = np.tensordot(w**2, Mstack, axes=(0, 0))
        E_random = cross_env + same_env

        Cxx = E_same - E_random

        return make_psd(Cxx, min_eig=self.cfg.ridge) + self.cfg.ridge * np.eye(d)

    def _Cxx_analytic(self, data: UnpairedIVData) -> np.ndarray:
        """Dispatch to the categorical formula for one-hot instruments."""
        if is_one_hot_instrument(data.I_x):
            return self._Cxx_analytic_categorical(data)
        return self._Cxx_analytic_generic(data)

    def _Cxy_full(self, data: UnpairedIVData) -> np.ndarray:
        """Return ``data.m * covIX.T @ covIY`` computed on the full samples."""
        covIX = cross_cov_IX(data.I_x, data.X)
        covIY = cross_cov_IY(data.I_y, data.Y)
        return data.m * (covIX.T @ covIY)

    def _fit_l1(
        self, data: UnpairedIVData, Cxx: np.ndarray, Cxy: np.ndarray
    ) -> np.ndarray:
        """Lasso on the system ``Cxx @ beta = Cxy`` with optional support refit."""
        d = data.d
        # Any value other than "m" selects data.N.
        eff_n = float(data.m) if self.cfg.lam_effective_n == "m" else float(data.N)
        lam = default_lam(d, effective_n=eff_n, lam_scale=self.cfg.lam_scale)
        alpha = lam / max(Cxx.shape[0], 1.0)

        # Scale columns by sqrt(diag(Cxx)) so the penalty acts on comparable
        # units, then map the coefficients back.
        D = np.sqrt(np.maximum(np.diag(Cxx), 1e-12))
        Xs = Cxx / D[None, :]

        model = Lasso(alpha=alpha, fit_intercept=False, max_iter=50000, tol=1e-4)
        model.fit(Xs, Cxy)
        beta_l1 = model.coef_.copy() / D

        if not self.cfg.post_refit:
            return beta_l1

        supp = np.where(np.abs(beta_l1) > 1e-8)[0]
        if supp.size == 0:
            return beta_l1

        # Least-squares refit of Cxy on the selected columns of Cxx.
        Zs = Cxx[:, supp]
        G = Zs.T @ Zs + self.cfg.ridge * np.eye(supp.size)
        ref = solve(G, Zs.T @ Cxy)

        out = np.zeros(d)
        out[supp] = ref
        return out

    def fit(
        self, data: UnpairedIVData, rng: Optional[np.random.Generator] = None
    ) -> np.ndarray:
        """Fit estimator on unpaired data and return coefficient estimates.

        Args:
            data: unpaired sample; the attributes ``I_x``, ``X``, ``I_y``,
                ``Y``, ``m``, ``d``, ``n_x`` and ``N`` are read.
            rng: accepted for interface compatibility; unused because the
                estimator involves no random splitting.

        Returns:
            Coefficient vector of length ``data.d``.
        """
        Cxx = self._Cxx_analytic(data)
        Cxy = self._Cxy_full(data)

        if not self.cfg.l1:
            # _Cxx_analytic already returns the matrix with cfg.ridge on the
            # diagonal; adding it again here would regularise with 2 * ridge.
            return solve(Cxx, Cxy)
        return self._fit_l1(data, Cxx, Cxy)
