from typing import Optional, Literal, overload

import numpy as np
import pandas as pd
from scipy import stats
from xgboost import XGBRegressor

from utils.funcs import logit


def _broadcast_intervention(name, value, N):
    """Turn one intervention value into a fresh 1-D array of length N.

    Scalars (Python or NumPy) are broadcast to every sample. A 1-D sequence
    whose length divides N has each element repeated ``N // len`` times
    consecutively (``np.repeat`` semantics, not tiling). Anything else is
    rejected with ``ValueError``.
    """
    # np.array always copies, so the caller's data is never aliased.
    arr = np.array(value)

    if arr.ndim == 0:
        return np.full((N,), arr)

    if 0 < len(arr) < N and N % len(arr) == 0:
        arr = np.repeat(arr, N // len(arr), axis=0)

    if arr.shape != (N,):
        raise ValueError(
            f"intervention {name!r} must be a scalar or a 1-D sequence whose "
            f"length divides N={N}; got shape {np.shape(value)}"
        )
    return arr


@overload
def DGP(
    N: int, /,
    rs: Optional[np.random.RandomState] = ...,
    latents: bool = ...,
    *,
    return_X_y: Literal[True],
    **intv
) -> tuple[pd.DataFrame, pd.Series]:
    ...


@overload
def DGP(
    N: int, /,
    rs: Optional[np.random.RandomState],
    latents: bool,
    return_X_y: Literal[True],
    **intv
) -> tuple[pd.DataFrame, pd.Series]:
    ...


@overload
def DGP(
    N: int, /,
    rs: Optional[np.random.RandomState] = ...,
    latents: bool = ...,
    return_X_y: Literal[False] = ...,
    **intv
) -> pd.DataFrame:
    ...


def DGP(N, /, rs=None, latents=False, return_X_y=False, **intv):
    """Create N samples from the DGP, possibly subject to some interventions.

    Variables are generated in topological order::

        z ~ Beta(2, 5)
        u ~ Chi2(df=10)                          (latent)
        x = |Normal(z * (u - 5), 0.1)|
        a = |Exponential() + Normal(sqrt(x), 0.1)|
        b ~ Normal(5 * sin(a) - u / 10, 1)
        c ~ Normal(log(b ** 2 + 1), 0.5)
        y ~ Normal(logit(z) + (x / 10) ** 2 + c, 0.5)

    Parameters
    ----------
    N : number of samples (positional only).
    rs : random state used for all draws; a fresh unseeded
        ``np.random.RandomState`` is created when ``None``.
    latents : include the latent variable ``u`` as the first column.
    return_X_y : if true, return ``(features, y)`` instead of a single
        frame that contains a ``y`` column.
    **intv : interventions, keyed by variable name (``z``, ``u``, ``x``,
        ``a``, ``b``, ``c``, ``y``). Each value is a scalar or a 1-D
        sequence whose length divides N (elements are repeated
        consecutively to reach length N). An intervened variable is set to
        the given values instead of being sampled, so no random numbers are
        consumed for it. Keys that do not name a variable are validated but
        otherwise ignored.

    Returns
    -------
    ``pd.DataFrame`` with columns ``[u,] z, x, a, b, c, y`` or, when
    ``return_X_y`` is true, a ``(pd.DataFrame, pd.Series)`` pair.

    Raises
    ------
    ValueError
        If an intervention cannot be broadcast to shape ``(N,)``.
    """
    if rs is None:
        rs = np.random.RandomState()

    # Normalise every intervention up front so errors surface before sampling.
    fixed = {
        name: _broadcast_intervention(name, value, N)
        for name, value in intv.items()
    }

    # Create variables in topological order. Only apply intv if given.
    # The order of the draws below determines the random stream; keep it.

    if (z := fixed.get('z')) is None:
        z = rs.beta(a=2, b=5, size=N)

    if (u := fixed.get('u')) is None:
        u = stats.chi2(df=10).rvs(size=N, random_state=rs)

    if (x := fixed.get('x')) is None:
        x = np.abs(rs.normal(z * (u - 5), scale=.1))

    if (a := fixed.get('a')) is None:
        a = np.abs(rs.exponential(size=N) + rs.normal(np.sqrt(x), scale=.1))

    if (b := fixed.get('b')) is None:
        b = rs.normal(np.sin(a) * 5 - u / 10, scale=1)

    if (c := fixed.get('c')) is None:
        c = rs.normal(np.log(b ** 2 + 1), scale=.5)

    if (y := fixed.get('y')) is None:
        y = rs.normal(logit(z) + (x / 10) ** 2 + c, scale=.5)

    columns = dict(z=z, x=x, a=a, b=b, c=c)
    if latents:
        # The latent goes first, matching the original column layout.
        columns = dict(u=u, **columns)
    df = pd.DataFrame(columns)

    if return_X_y:
        return df, pd.Series(y)

    df['y'] = y
    return df
    

if __name__ == '__main__':
    # Script entry point: sample the DGP, save a scatter-matrix plot, fit a
    # regressor on the observed features, compare the model's response to
    # each feature against the DGP's interventional response, and write the
    # sampled data to CSV (with and without the latent column).
    import matplotlib.pyplot as plt
    from utils.file import make_path

    V = list('zxabc')  # observed feature columns, in DGP column order

    rs = np.random.RandomState(seed=123)

    # Generate 1000 observational points
    df = DGP(1000, rs=rs, latents=False)
    pd.plotting.scatter_matrix(df[df.columns[df.dtypes == float]], alpha=.5)
    plt.savefig(make_path('../../data/synthetic_5vars_hard.png'), dpi=300)

    print('Correlation matrix:')
    print(df.corr())

    # Hold out the last 100 rows for scoring.
    train, test = df.iloc[:-100], df.iloc[-100:]
    reg = XGBRegressor().fit(train[V], train.y)
    print('Predictor score:', reg.score(test[V], test.y))

    # Compute conditional and intv effects
    features = df.drop('y', axis=1)  # loop-invariant, so built once
    _, axes = plt.subplots(1, len(V), figsize=(6 * len(V), 4))
    for col, ax in zip(features.columns, axes):
        # 20 interior quantiles (the 0% and 100% endpoints are dropped).
        qs = df[col].quantile(np.linspace(0, 1, 22)[1:-1])
        marg, intv = [], []
        for q in qs:
            # Model response: force `col` to q for every row, then average
            # the predictions.
            df2 = features.copy()
            df2[col] = q
            pred = reg.predict(df2)
            marg.append((pred.mean(), pred.std() / np.sqrt(len(pred))))

            # Interventional response: resample the DGP under do(col = q).
            # NOTE(review): this `pred` is a pandas Series, whose .std()
            # defaults to ddof=1; if reg.predict returns a NumPy array (as
            # assumed), the .std() above uses ddof=0. Confirm the mismatch
            # is acceptable.
            pred = DGP(10000, rs=rs, **{col: q}).y
            intv.append((pred.mean(), pred.std() / np.sqrt(len(pred))))

        marg, intv = map(np.array, (marg, intv))
        # Column 0 is the mean, column 1 the standard error; bands are +/- 2 SE.
        ax.plot(qs, marg[:, 0], label='marginal', color='C0')
        ax.fill_between(qs, marg[:, 0] - 2 * marg[:, 1], marg[:, 0] + 2 * marg[:, 1], color='C0', alpha=.25)
        ax.plot(qs, intv[:, 0], label='intv', color='C1')
        ax.fill_between(qs, intv[:, 0] - 2 * intv[:, 1], intv[:, 0] + 2 * intv[:, 1], color='C1', alpha=.25)
        ax.set_xlabel(col)
        ax.set_ylabel('y')
        ax.legend()
    plt.show()

    # Re-seed before each export so both CSVs come from the same random draws.
    rs = np.random.RandomState(seed=123)
    df = DGP(1000, rs=rs)
    df.to_csv(make_path('../../data/synthetic_5vars_hard_1000.csv'), index=False)

    rs = np.random.RandomState(seed=123)
    df = DGP(1000, rs=rs, latents=True)
    df.to_csv(make_path('../../data/synthetic_5vars_hard_latents_1000.csv'), index=False)