from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import numpy as np
from numpy.linalg import solve
from sklearn.linear_model import Lasso

from ..data import UnpairedIVData
from ..linalg import (
    center_cols,
    cross_cov_IX,
    cross_cov_IY,
    safe_cholesky_spd,
    safe_inv_spd,
)
from .base import Estimator


@dataclass(frozen=True)
class UPGMMConfig:
    """Immutable settings consumed by :class:`UPGMM`.

    Attributes:
        ridge: Small multiple of the identity added to matrices before they
            are solved, inverted or factorised.
        use_optimal_weight: Weight the moments by the inverse of an estimated
            moment covariance instead of the identity.
        l1: Solve an L1-penalised (lasso) version of the GMM problem.
        lam_scale: Multiplier applied to the lasso penalty level.
        post_refit: After the lasso, re-solve the unpenalised problem on the
            selected support.
        split_B: Number of random half-splits used for the pilot ``B'B``
            estimate; values ``<= 0`` disable splitting.
    """

    ridge: float = 1e-10
    use_optimal_weight: bool = True
    l1: bool = False
    # TODO: a data-dependent choice of the penalty scale would be nicer.
    lam_scale: float = 1.0
    post_refit: bool = True
    split_B: int = 0


class UPGMM(Estimator):
    """Unpaired GMM estimator with optional L1 regularization.

    ``a = cross_cov_IY(I_y, Y)`` (length ``m``) is estimated from one sample
    and ``B = cross_cov_IX(I_x, X)`` (``m x d``) from another. ``fit`` returns
    the ``beta`` minimising ``(a - B beta)' W (a - B beta)`` (plus a small
    ridge), where ``W`` is the identity or the inverse of an estimated moment
    covariance. With ``cfg.l1`` the problem is solved as a lasso instead,
    optionally followed by an unpenalised refit on the selected support.
    """

    name = "up_gmm"

    def __init__(self, cfg: UPGMMConfig = UPGMMConfig()):
        """Initialize UPGMM with configuration parameters.

        Args:
            cfg: Estimator settings. Sharing one default instance across
                estimators is safe because ``UPGMMConfig`` is frozen.
        """
        self.cfg = cfg

    def _omega_hat(self, data: UnpairedIVData, beta_init: np.ndarray) -> np.ndarray:
        """Estimate the covariance of the moment vector ``a - B @ beta_init``.

        Each sample contributes the sample covariance of its per-observation
        product terms (centered instruments times centered outcome, or times
        centered ``X @ beta_init``). That covariance is scaled by the inverse
        of the sample's share of ``N = n_y + n_x``. The sum is symmetrised and
        ridge-regularised.

        Args:
            data: Unpaired sample; rows of ``I_y`` align with ``Y`` and rows of
                ``I_x`` align with ``X``.
            beta_init: Pilot coefficients used to form ``X @ beta_init``.

        Returns:
            Symmetric ``(m, m)`` matrix.
        """
        Iy = center_cols(data.I_y)
        Yc = data.Y - data.Y.mean()
        Ix = center_cols(data.I_x)
        Xb = data.X @ beta_init
        Xb = Xb - Xb.mean()

        n_y, n_x = data.n_y, data.n_x
        N = n_y + n_x
        # Sample shares; max() guards against an empty combined sample.
        tau = n_y / max(N, 1)
        ttau = n_x / max(N, 1)

        IY = Iy * Yc[:, None]
        IYc = IY - IY.mean(axis=0, keepdims=True)
        cov_IY = (IYc.T @ IYc) / max(n_y - 1, 1)

        ZX = Ix * Xb[:, None]
        ZXc = ZX - ZX.mean(axis=0, keepdims=True)
        cov_ZX = (ZXc.T @ ZXc) / max(n_x - 1, 1)

        # Floor the shares so an empty sample does not divide by zero.
        Omega = (1.0 / max(tau, 1e-12)) * cov_IY + (1.0 / max(ttau, 1e-12)) * cov_ZX
        # Remove rounding asymmetry, then add the ridge.
        Omega = 0.5 * (Omega + Omega.T) + self.cfg.ridge * np.eye(data.m)
        return Omega

    def _split_avg_cxx(
        self,
        data: UnpairedIVData,
        rng: np.random.Generator,
        B: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Estimate ``B.T @ B`` for the pilot solve.

        With ``cfg.split_B > 0`` and at least 4 X-sample rows, the X sample is
        split into random disjoint halves ``split_B`` times. Each time, the
        symmetrised cross-product ``(B_A' B_C + B_C' B_A) / 2`` of the two
        half-sample estimates is formed, and these are averaged. The
        estimation noise in the two factors is then independent (presumably
        to avoid the upward bias of the plug-in ``B.T @ B``). Otherwise the
        plug-in estimate is returned.

        Args:
            data: Unpaired sample.
            rng: Source of the random permutations.
            B: Optional precomputed ``cross_cov_IX(data.I_x, data.X)``. It is
                computed on demand when omitted and needed.

        Returns:
            ``(d, d)`` matrix.
        """

        def plug_in() -> np.ndarray:
            # Full-sample estimate, only materialised when actually needed.
            full = cross_cov_IX(data.I_x, data.X) if B is None else B
            return full.T @ full

        if self.cfg.split_B <= 0 or data.n_x < 4:
            return plug_in()

        d = data.d
        idx = np.arange(data.n_x)
        acc = np.zeros((d, d))
        cnt = 0
        for _ in range(self.cfg.split_B):
            perm = rng.permutation(idx)
            mid = perm.size // 2
            A, C = perm[:mid], perm[mid:]
            # Cannot trigger while n_x >= 4; kept as a defensive guard.
            if A.size < 2 or C.size < 2:
                continue
            BA = cross_cov_IX(data.I_x[A], data.X[A])
            BC = cross_cov_IX(data.I_x[C], data.X[C])
            acc += 0.5 * (BA.T @ BC + BC.T @ BA)
            cnt += 1
        if cnt == 0:
            return plug_in()
        return acc / cnt

    def _weight_matrix(
        self,
        data: UnpairedIVData,
        a: np.ndarray,
        B: np.ndarray,
        rng: np.random.Generator,
    ) -> np.ndarray:
        """Return the ``(m, m)`` GMM weight matrix.

        Without ``cfg.use_optimal_weight`` this is the identity. Otherwise it
        is the inverse of ``_omega_hat`` evaluated at a ridge-regularised
        pilot estimate. The pilot (and any random splitting) is computed only
        on that branch because nothing else uses it.
        """
        if not self.cfg.use_optimal_weight:
            return np.eye(data.m)
        Cxx0 = self._split_avg_cxx(data, rng, B) + self.cfg.ridge * np.eye(data.d)
        beta0 = solve(Cxx0, B.T @ a)
        Omega = self._omega_hat(data, beta0)
        return safe_inv_spd(Omega, ridge=self.cfg.ridge)

    def _weighted_ls(self, a: np.ndarray, B: np.ndarray, W: np.ndarray) -> np.ndarray:
        """Solve ``(B' W B + ridge * I) beta = B' W a`` for ``beta``."""
        A = B.T @ W @ B + self.cfg.ridge * np.eye(B.shape[1])
        b = B.T @ W @ a
        return solve(A, b)

    def _lasso(
        self,
        data: UnpairedIVData,
        a: np.ndarray,
        B: np.ndarray,
        W: np.ndarray,
    ) -> np.ndarray:
        """Return the lasso estimate of ``beta`` in the original column scale.

        The weighted objective is whitened. With ``W = L @ L.T`` (assumes
        ``safe_cholesky_spd`` returns this lower factor; TODO confirm),
        ``||L.T @ (a - B beta)||^2 == (a - B beta)' W (a - B beta)``. An
        ordinary lasso on ``(L.T @ B, L.T @ a)`` is therefore fitted.
        Columns are scaled to unit norm before fitting, and the coefficients
        are mapped back afterwards.
        """
        d = data.d
        L = safe_cholesky_spd(W, ridge=self.cfg.ridge)
        Z = L.T @ B
        y = L.T @ a

        col_scales = np.maximum(np.linalg.norm(Z, axis=0), 1e-12)
        Zs = Z / col_scales[None, :]

        # NOTE(review): eff_n is the number of moments m, not a sample size;
        # confirm this is the intended scaling of the penalty level.
        eff_n = float(data.m)
        n_samples = float(Zs.shape[0])
        # Smallest alpha at which sklearn's Lasso (squared loss scaled by
        # 1 / (2 * n_samples)) returns all-zero coefficients.
        alpha_max = np.max(np.abs(Zs.T @ y)) / max(n_samples, 1.0)
        if alpha_max <= 0.0:
            # y is orthogonal to every column, so zero is a lasso solution
            # for any alpha >= 0. This also avoids fitting Lasso(alpha=0),
            # which sklearn advises against.
            return np.zeros(d)
        frac = self.cfg.lam_scale * np.sqrt(np.log(d + 1.0) / max(eff_n, 1.0))
        # Cap strictly below alpha_max so at least one coefficient can enter.
        alpha = min(frac * alpha_max, 0.99 * alpha_max)

        model = Lasso(alpha=alpha, fit_intercept=False, max_iter=50000, tol=1e-4)
        model.fit(Zs, y)
        return model.coef_ / col_scales

    def fit(
        self, data: UnpairedIVData, rng: Optional[np.random.Generator] = None
    ) -> np.ndarray:
        """Fit estimator on unpaired data and return coefficient estimates.

        Args:
            data: Unpaired sample.
            rng: Random generator for the split-sample pilot. A fresh unseeded
                generator is used when ``None``. It can only be drawn from when
                ``cfg.use_optimal_weight`` is set and ``cfg.split_B > 0``.

        Returns:
            Length-``d`` coefficient vector.
        """
        rng = np.random.default_rng() if rng is None else rng
        a = cross_cov_IY(data.I_y, data.Y)  # (m,)
        B = cross_cov_IX(data.I_x, data.X)  # (m, d)

        W = self._weight_matrix(data, a, B, rng)

        if not self.cfg.l1:
            return self._weighted_ls(a, B, W)

        beta_l1 = self._lasso(data, a, B, W)
        if not self.cfg.post_refit:
            return beta_l1

        # Unpenalised refit on the lasso support removes shrinkage from the
        # selected coefficients.
        supp = np.flatnonzero(np.abs(beta_l1) > 1e-8)
        if supp.size == 0:
            return beta_l1

        out = np.zeros(data.d)
        out[supp] = self._weighted_ls(a, B[:, supp], W)
        return out
