from __future__ import annotations

from typing import Optional

import numpy as np
from numpy.linalg import pinv, solve

from ..data import UnpairedIVData
from ..linalg import center_cols, cross_cov_IX, cross_cov_IY
from .base import Estimator


class NaiveOLS(Estimator):
    """Baseline: least squares on an arbitrary pairing of X rows with Y rows.

    The X and Y samples are not paired, so this estimator simply draws
    equally many rows from each (without replacement) and regresses the
    drawn Y values on the drawn X rows.
    """
    name = "naive_ols"

    def fit(
        self, data: UnpairedIVData, rng: Optional[np.random.Generator] = None
    ) -> np.ndarray:
        """Return least-squares coefficients for a random X/Y pairing.

        Args:
            data: Unpaired sample; reads ``n_x``, ``n_y``, ``X`` and ``Y``.
            rng: Generator used to draw the row subsets. A fresh
                ``np.random.default_rng()`` is created when omitted, so
                results are only reproducible if a seeded generator is given.

        Returns:
            ``pinv(Xc.T @ Xc) @ (Xc.T @ Yc)`` computed on the drawn rows,
            where ``Xc`` is the output of ``center_cols`` and ``Yc`` is Y
            minus its mean.
        """
        if rng is None:
            rng = np.random.default_rng()

        # Use as many pairs as the smaller sample allows. X indices are drawn
        # before Y indices; keep that order so seeded runs stay reproducible.
        n_pairs = min(data.n_x, data.n_y)
        rows_x = rng.choice(data.n_x, size=n_pairs, replace=False)
        rows_y = rng.choice(data.n_y, size=n_pairs, replace=False)

        # center_cols is a project helper (presumably subtracts column means).
        x_sub = center_cols(data.X[rows_x])
        y_drawn = data.Y[rows_y]
        y_sub = y_drawn - y_drawn.mean()

        gram = x_sub.T @ x_sub
        moment = x_sub.T @ y_sub
        # The pseudo-inverse keeps this defined even when the Gram matrix is
        # singular.
        return pinv(gram) @ moment


class EnvMeansOLS(Estimator):
    """OLS on environment means of X and Y (baseline).

    Each sample is assigned to the environment given by the arg-max of its
    row in the instrument matrix. Per-environment means of X and of Y are
    then regressed on each other with a small ridge term.
    """
    name = "env_means_ols"

    def __init__(self, ridge: float = 1e-10):
        """Store the ridge strength added to the diagonal of ``Xbar.T @ Xbar``."""
        self.ridge = float(ridge)

    def fit(
        self, data: UnpairedIVData, rng: Optional[np.random.Generator] = None
    ) -> np.ndarray:
        """Regress per-environment Y means on per-environment X means.

        Args:
            data: Unpaired sample; reads ``I_x``, ``I_y``, ``X``, ``Y``,
                ``m`` (number of environments) and ``d`` (number of
                coefficients).
            rng: Unused; accepted for signature parity with the other
                estimators in this module.

        Returns:
            Coefficient vector of length ``d``. A zero vector is returned
            when fewer than two environments have finite means in both
            samples.
        """
        # assumes I_x / I_y are one-hot environment indicators with one row
        # per sample -- TODO confirm against UnpairedIVData
        env_x = np.argmax(data.I_x, axis=1)
        env_y = np.argmax(data.I_y, axis=1)

        m, d = data.m, data.d
        # NaN marks environments that are absent from a sample.
        muX = np.full((m, d), np.nan, dtype=float)
        muY = np.full(m, np.nan, dtype=float)

        for e in range(m):
            in_x = env_x == e
            if in_x.any():
                muX[e] = data.X[in_x].mean(axis=0)
            in_y = env_y == e
            if in_y.any():
                muY[e] = data.Y[in_y].mean()

        # Keep only environments observed (with finite means) in both samples.
        ok = np.isfinite(muY) & np.all(np.isfinite(muX), axis=1)
        if int(ok.sum()) < 2:
            return np.zeros(d)

        Xbar = muX[ok]
        Ybar = muY[ok]

        # NOTE(review): the means are not centred here, so this fits a
        # regression through the origin, unlike the other estimators in this
        # module -- confirm this is intended.
        XtX = Xbar.T @ Xbar + self.ridge * np.eye(d)
        XtY = Xbar.T @ Ybar
        try:
            return solve(XtX, XtY)
        except np.linalg.LinAlgError:
            # Only a singular system falls back to the pseudo-inverse; any
            # other error (e.g. a shape mismatch) is a real bug and should
            # propagate instead of being masked.
            return pinv(XtX) @ XtY


class TS2SLS(Estimator):
    """Two-sample two-stage least squares baseline.

    The first stage is fitted on the X sample (instruments -> X). The second
    stage regresses Y on the first-stage predictions evaluated at the Y
    sample's instruments.
    """
    name = "ts_2sls"

    def __init__(self, ridge: float = 1e-10):
        """Store the ridge strength used in both linear solves."""
        self.ridge = float(ridge)

    def fit(
        self, data: UnpairedIVData, rng: Optional[np.random.Generator] = None
    ) -> np.ndarray:
        """Return the two-stage coefficient estimate.

        Args:
            data: Unpaired sample; reads ``I_x``, ``X``, ``I_y``, ``Y``,
                ``m`` and ``d``.
            rng: Unused; accepted for signature parity with the other
                estimators in this module.

        Returns:
            Solution of the ridge-regularised second-stage normal equations.
        """
        # center_cols is a project helper (presumably subtracts column means).
        inst_x = center_cols(data.I_x)
        feat_x = center_cols(data.X)
        inst_y = center_cols(data.I_y)
        resp_y = data.Y - data.Y.mean()

        n_env, n_feat = data.m, data.d

        # Stage 1: ridge-regularised regression of X on the instruments,
        # fitted on the X sample. Result has shape (m, d).
        first_stage = solve(
            inst_x.T @ inst_x + self.ridge * np.eye(n_env),
            inst_x.T @ feat_x,
        )

        # Stage 2: predict X for the Y sample, then regress Y on it.
        x_pred = inst_y @ first_stage  # (n_y, d)
        lhs = x_pred.T @ x_pred + self.ridge * np.eye(n_feat)
        rhs = x_pred.T @ resp_y
        return solve(lhs, rhs)


class TSIV(Estimator):
    """Two-sample IV baseline built from instrument cross-moments.

    Matches the instrument/Y moment vector to the instrument/X moment matrix
    in the least-squares sense, with a small ridge term.
    """
    name = "ts_iv"

    def __init__(self, ridge: float = 1e-10):
        """Store the ridge strength added to the normal equations."""
        self.ridge = float(ridge)

    def fit(
        self, data: UnpairedIVData, rng: Optional[np.random.Generator] = None
    ) -> np.ndarray:
        """Solve ``(B.T @ B + ridge * I) beta = B.T @ a`` for ``beta``.

        Args:
            data: Unpaired sample; reads ``I_y``, ``Y``, ``I_x``, ``X``
                and ``d``.
            rng: Unused; accepted for signature parity with the other
                estimators in this module.

        Returns:
            The coefficient vector ``beta``.
        """
        # Moments come from project helpers; the shapes noted here follow the
        # original annotations: a is (m,), B is (m, d).
        moment_iy = cross_cov_IY(data.I_y, data.Y)
        moment_ix = cross_cov_IX(data.I_x, data.X)

        normal_mat = moment_ix.T @ moment_ix + self.ridge * np.eye(data.d)
        return solve(normal_mat, moment_ix.T @ moment_iy)
