from econml.dml import DML, LinearDML
from panel_dynamic_dml import DynamicPanelDML
from sklearn.linear_model import LinearRegression, LassoCV, Lasso, MultiTaskLasso, MultiTaskLassoCV, LogisticRegressionCV
from sklearn.model_selection import GroupKFold, cross_val_predict
from sklearn.base import clone
from dynamic_panel_dgp import DynamicPanelDGP
import warnings
import numpy as np
from dynamic_panel_dgp import SemiSynthetic
from copy import deepcopy
from econml.sklearn_extensions.linear_model import StatsModelsRLM
# Suppress every warning for the whole process (module-level side effect on import).
# Presumably meant to quiet the many repeated Lasso/DML fits in this script.
# NOTE(review): this also hides genuinely useful warnings (e.g. deprecations).
warnings.simplefilter('ignore')


def lasso_model(lr):
    """Return a fresh, unfitted single-target regressor.

    Used as the outcome nuisance model (``model_y``) throughout this script.

    Parameters
    ----------
    lr : bool
        When truthy, plain ordinary least squares is returned; otherwise a
        3-fold ``LassoCV`` searching 10 alphas with up to 2000 iterations.
    """
    if not lr:
        return LassoCV(cv=3, n_alphas=10, max_iter=2000)
    return LinearRegression()


def mlasso_model(lr):
    """Return a fresh, unfitted multi-output regressor.

    Used for the treatment nuisance models (``model_t``) and for the surrogate
    score model in ``debiased_inference``.

    Parameters
    ----------
    lr : bool
        When truthy, plain ordinary least squares is returned; otherwise a
        3-fold ``MultiTaskLassoCV`` searching 10 alphas with up to 2000
        iterations.
    """
    if not lr:
        return MultiTaskLassoCV(cv=3, n_alphas=10, max_iter=2000)
    return LinearRegression()


def _lag_effect_matrix(params_by_horizon, n_periods, n_treatments):
    """Lay out per-lag effects on a (treatment period, outcome period, treatment) grid.

    ``params_by_horizon[p]`` is a flat vector whose entry
    ``kappa * n_treatments + t`` is the lag-``kappa`` effect of treatment ``t``
    (``kappa = 0..p``).  That value is stored at ``[p - kappa, p, t]``; every
    other cell stays NaN.
    """
    grid = np.full((n_periods, n_periods, n_treatments), np.nan)
    for p in range(n_periods):
        params = np.ravel(params_by_horizon[p])
        for kappa in range(p + 1):
            grid[p - kappa, p] = params[kappa * n_treatments:(kappa + 1) * n_treatments]
    return grid


def _make_linear_dml(lr, cv):
    """Build a fresh LinearDML with this script's standard nuisance models."""
    return LinearDML(model_t=mlasso_model(lr),
                     model_y=lasso_model(lr),
                     linear_first_stages=False, cv=cv)


def _make_proxy_model(lr):
    """Build an unfitted regressor mapping surrogates to the total outcome."""
    return LinearRegression() if lr else LassoCV()


def experiment(n_periods, n_units, dgp, cv, random_seed, semi, lr):
    """Run one simulation comparing long-range treatment-effect estimators.

    Parameters
    ----------
    n_periods, n_units : int
        Panel dimensions.
    dgp : object
        Generator used when ``semi`` is true (its ``gen_data``, ``n_proxies``
        and ``n_treatments`` are read).  When ``semi`` is false it is replaced
        by a fresh DGP built with ``synthetic``.
    cv : int
        Cross-fitting folds for every DML estimator (``n_cfit_splits`` for
        ``DynamicPanelDML``, ``cv`` for ``LinearDML``/``DML``).
    random_seed : int
        Seeds both the global NumPy RNG and the local ``RandomState`` that
        draws all DGP parameters.
    semi : bool
        Use the semi-synthetic generator ``dgp`` instead of ``synthetic``.
    lr : bool
        Use plain ``LinearRegression`` nuisance/proxy models instead of Lasso.

    Returns
    -------
    tuple
        ``((true_long_range_effects, est_on_true_effects, est_on_surrogate_effects,
        est_on_true_adj_effects, est_on_adj_surr_effects, est_on_adj_surr_effects_rlm),
        (novel_true_long_range_effects, novel_effects, debiased_novel_effects))``
        where the ``*_long_range_effects`` entries are per-treatment arrays and
        the remaining entries are the objects returned by ``intercept__inference()``.
    """
    np.random.seed(random_seed)
    random_state = np.random.RandomState(random_seed)
    # NOTE: the order of random_state draws below determines reproducibility;
    # do not reorder them.
    max_seed = np.iinfo('uint16').max

    # Gen Data
    if semi:
        panelX, panelT, panelY, panelGroups, true_effect = dgp.gen_data(
            n_units, n_periods,
            random_state.uniform(0, 1, size=(dgp.n_proxies, dgp.n_treatments)),
            random_state.randint(0, max_seed))
    else:
        gamma = .0
        sigma_t = .5
        s_t = 10
        dgp = synthetic(n_units, n_periods, random_state, lr)
        Y, T, X, groups = dgp.observational_data(
            n_units, gamma, s_t, sigma_t, random_seed=random_state.randint(0, max_seed))
        true_effect = dgp.true_effect
        panelX = X.reshape(-1, n_periods, X.shape[1])
        panelT = T.reshape(-1, n_periods, T.shape[1])
        panelY = Y.reshape(-1, n_periods)
        panelGroups = groups.reshape(-1, n_periods)

    n_x = panelX.shape[2]
    n_treatments = dgp.n_treatments

    # Estimate Dynamic Effects: one estimator per horizon, fit on the first t periods.
    ests = [
        DynamicPanelDML(model_t=mlasso_model(lr),
                        model_y=lasso_model(lr),
                        n_cfit_splits=cv).fit(panelY[:, :t].reshape(-1,),
                                              panelT[:, :t, :].reshape(-1, n_treatments),
                                              panelX[:, :t, :].reshape(-1, n_x),
                                              panelGroups[:, :t].reshape(-1,))
        for t in range(1, n_periods + 1)
    ]

    # Create matrix of dynamic effects (estimated and true)
    effect = _lag_effect_matrix([est.param for est in ests], n_periods, n_treatments)
    true_eff = _lag_effect_matrix([np.ravel(true_effect)] * n_periods,
                                  n_periods, n_treatments)

    # Truth: sum of the period-0 treatment's effects over all outcome periods.
    true_long_range_effects = np.nansum(true_eff, axis=1)[0]

    # Benchmark 1: DML on the summed outcome.
    est_on_true = _make_linear_dml(lr, cv)
    est_on_true.fit(np.sum(panelY, axis=1),
                    panelT[:, 0], X=None, W=panelX[:, 0])
    est_on_true_effects = est_on_true.intercept__inference()

    # Benchmark 2: DML on an unadjusted surrogate index.
    XS = np.hstack([panelX[:, 1], panelY[:, :1]])
    TotalY = np.sum(panelY, axis=1)
    unadjusted_proxy_model = _make_proxy_model(lr).fit(XS, TotalY)
    sindex = unadjusted_proxy_model.predict(XS)
    est = _make_linear_dml(lr, cv).fit(sindex, panelT[:, 0], W=panelX[:, 0])
    est_on_surrogate_effects = est.intercept__inference()

    # Benchmark 3: strip the estimated effects of treatments from periods 1..i
    # out of Y_i, keeping only the period-0 treatment's contribution.
    panelYadj = panelY.copy()
    for i in range(n_periods):
        for j in range(i):
            panelYadj[:, i] -= panelT[:, i - j, :] @ effect[i - j, i]
    TotalYadj = np.sum(panelYadj, axis=1)
    est_on_true_adj = _make_linear_dml(lr, cv).fit(
        TotalYadj, panelT[:, 0], W=panelX[:, 0])
    est_on_true_adj_effects = est_on_true_adj.intercept__inference()

    # Proposed Method: DML on the adjusted surrogate index.
    XS = np.hstack([panelX[:, 1], panelYadj[:, :1]])
    proxy_model = _make_proxy_model(lr).fit(XS, TotalYadj)
    sindex_adj = proxy_model.predict(XS)
    est = _make_linear_dml(lr, cv).fit(sindex_adj, panelT[:, 0], W=panelX[:, 0])
    est_on_adj_surr_effects = est.intercept__inference()

    # Proposed Method RLM
    est = DML(model_t=mlasso_model(lr),
              model_y=lasso_model(lr),
              linear_first_stages=False,
              model_final=StatsModelsRLM(fit_intercept=False),
              cv=cv)
    est.fit(sindex_adj, panelT[:, 0], W=panelX[:, 0])
    est_on_adj_surr_effects_rlm = est.intercept__inference()

    # Novel Treatment Analysis
    if semi:
        novelpanelX, novelpanelT, novelpanelY, novelpanelGroups, novel_true_effect = dgp.gen_data(
            n_units, n_periods,
            random_state.uniform(0, 1, size=(dgp.n_proxies, dgp.n_treatments)),
            random_state.randint(0, max_seed))
    else:
        novel_dgp = dgp.create_instance(dgp.s_x, dgp.sigma_x, dgp.sigma_y,
                                        dgp.conf_str,
                                        random_state.uniform(-1, 1, size=n_treatments),
                                        random_state.uniform(-1, 1,
                                                             size=(dgp.n_x, dgp.n_treatments)),
                                        dgp.hetero_strength, dgp.hetero_inds,
                                        dgp.autoreg, .25,
                                        random_seed=dgp.random_seed)
        novelY, novelT, novelX, novelgroups = novel_dgp.observational_data(
            n_units, gamma, s_t, sigma_t, random_seed=random_state.randint(0, max_seed))
        novel_true_effect = novel_dgp.true_effect
        novelpanelX = novelX.reshape(-1, n_periods, n_x)
        novelpanelT = novelT.reshape(-1, n_periods, n_treatments)
        novelpanelY = novelY.reshape(-1, n_periods)
        novelpanelGroups = novelgroups.reshape(-1, n_periods)

    # Apply the observationally trained proxy to the novel sample.
    novelXS = np.hstack([novelpanelX[:, 1], novelpanelY[:, :1]])
    novelsindex_adj = proxy_model.predict(novelXS)
    # cache_values=True keeps residuals_ available for debiased_inference.
    novelest = _make_linear_dml(lr, cv).fit(
        novelsindex_adj, novelpanelT[:, 0], W=novelpanelX[:, 0], cache_values=True)
    novel_effects = novelest.intercept__inference()

    debiased_novel_effects = debiased_inference(panelX, panelT, panelY, panelGroups, TotalYadj,
                                                novelpanelX, novelpanelY[:, :1],
                                                proxy_model, novelest, ests, lr)

    novel_true_eff = _lag_effect_matrix([np.ravel(novel_true_effect)] * n_periods,
                                        n_periods, n_treatments)
    novel_true_long_range_effects = np.nansum(novel_true_eff, axis=1)[0]

    return ((true_long_range_effects, est_on_true_effects, est_on_surrogate_effects,
             est_on_true_adj_effects, est_on_adj_surr_effects, est_on_adj_surr_effects_rlm),
            (novel_true_long_range_effects, novel_effects, debiased_novel_effects))


def debiased_inference(panelX, panelT, panelY, panelGroups, TotalYadj,
                       novelpanelX, novelpanelY,
                       proxy_model, novelest, ests, lr):
    """Debias the novel-sample surrogate estimate using the observational sample.

    Observational residuals of the adjusted total outcome are weighted by the
    odds of belonging to the novel sample and by a surrogate score.  The
    resulting correction is added to ``novelest``'s intercept estimate.  Its
    standard error is inflated by a joint sandwich variance that also covers
    the dynamic-effect estimators in ``ests``, whose lag effects were used to
    build ``TotalYadj``.

    Parameters
    ----------
    panelX, panelT, panelY : ndarray
        Observational panels indexed ``[unit, period, ...]``.  Periods are
        read from axis 1 of ``panelX``.
    panelGroups : ndarray
        Unused; kept for interface compatibility.
    TotalYadj : ndarray
        Adjusted total outcome per observational unit.
    novelpanelX : ndarray
        Novel-sample covariate panel.
    novelpanelY : ndarray
        Novel-sample outcomes; only the first column is used.
    proxy_model : sklearn regressor
        Surrogate-to-outcome model; it is cloned and refit out-of-fold here.
    novelest : LinearDML
        Fitted on the novel sample.  It must expose ``residuals_``, so it must
        be fit with ``cache_values=True``.
    ests : list
        Per-horizon dynamic-effect estimators.  Their ``resT``,
        ``res_epsilon`` and ``_M`` attributes are read.
    lr : bool
        Selects the surrogate-score model via ``mlasso_model``.

    Returns
    -------
    A deep copy of ``novelest.intercept__inference()`` with ``pred`` shifted
    by the correction and ``pred_stderr`` inflated.
    """
    ne = novelpanelX.shape[0]
    no = panelX.shape[0]
    scale = no / ne

    # Classify novel (1) vs observational (0) rows on period-0/period-1
    # covariates plus the period-0 outcome.
    SS = np.vstack([
        np.hstack([novelpanelX[:, 0], novelpanelX[:, 1], novelpanelY[:, :1]]),
        np.hstack([panelX[:, 0], panelX[:, 1], panelY[:, :1]]),
    ])
    pop = np.concatenate((np.ones(ne), np.zeros(no)))
    prop_model = LogisticRegressionCV(max_iter=1000).fit(SS, pop)
    # Odds (not log-odds) of the novel environment, on the observational rows.
    prop_obs = prop_model.predict_proba(SS[ne:])[:, [1]]
    odds_obs = prop_obs / (1 - prop_obs)

    # Surrogate score: treatment residuals regressed on period-1 covariates in
    # the novel sample, evaluated on the observational sample.
    _, Tres, _, _ = novelest.residuals_
    score_model = mlasso_model(lr).fit(novelpanelX[:, 1], Tres)
    surscore = score_model.predict(panelX[:, 1])
    weights = surscore * odds_obs  # shared by the moment and the Jacobian

    # Out-of-fold observational residuals of the adjusted total outcome.
    XS = np.hstack([panelX[:, 1], panelY[:, :1]])
    sindex_adj = cross_val_predict(
        clone(proxy_model, safe=False), XS, TotalYadj, cv=3)
    residual = (TotalYadj - sindex_adj).reshape(-1, 1)

    J = (Tres.T @ Tres) / Tres.shape[0]
    invJ = np.linalg.pinv(J)
    correction = scale * invJ @ np.mean(weights * residual, axis=0)

    n_treatments = Tres.shape[1]
    n_periods = panelX.shape[1]

    # Stack the correction moment with the moments of the lag effects used
    # in the Y adjustment; collect blocks and concatenate once.
    moment_blocks = [weights * residual]
    for K in range(1, n_periods):
        for tau in range(K):
            M = K - tau  # the period of the treatment T
            moment_blocks.append(
                ests[K].resT[M][M] * ests[K].res_epsilon[M].reshape(-1, 1))
    moment = np.hstack(moment_blocks)
    Sigma = (scale ** 2) * (moment.T @ moment) / no

    Jac = np.zeros(Sigma.shape)
    Jac[:n_treatments, :n_treatments] = J
    start = n_treatments
    for K in range(1, n_periods):
        end = start + K * n_treatments
        # Treatments of periods K, K-1, ..., 1, flattened in that order.
        flatT = panelT[:, 1:(K + 1)][:, ::-1].reshape((-1, n_treatments * K))
        Jac[:n_treatments, start:end] = scale * (weights.T @ flatT) / no
        # Drop the last lag block (the period-0 treatment's effect), which is
        # not used in the Y adjustment.
        Jac[start:end, start:end] = scale * ests[K]._M[:-n_treatments, :-n_treatments]
        start = end

    invJac = np.linalg.pinv(Jac)
    Vobsfull = invJac @ Sigma @ invJac.T / no
    # Standard error (not variance) contributed by the observational stage.
    se_correction = np.sqrt(np.diag(Vobsfull))[:n_treatments]

    inf = deepcopy(novelest.intercept__inference())
    inf.pred_stderr = np.sqrt(inf.pred_stderr ** 2 + se_correction ** 2)
    inf.pred += correction
    return inf


def synthetic(n_units, n_periods, random_state, lr):
    """Construct a fully synthetic ``DynamicPanelDGP`` instance.

    Uses two treatments and 20 covariates when ``lr`` is truthy, 100
    otherwise.  It draws two uniform(-1, 1) parameter arrays and an integer
    seed from ``random_state``, in that order.  ``n_units`` is accepted for
    signature symmetry but unused.
    """
    n_treatments = 2
    n_x = 20 if lr else 100

    # Fixed settings forwarded positionally to create_instance.
    s_x, sigma_x, sigma_y = 10, .5, .5
    conf_str = 6
    hetero_strength, hetero_inds = 0, None
    autoreg, state_effect = .1, .1

    generator = DynamicPanelDGP(n_periods, n_treatments, n_x)
    # Draw order on random_state matters for reproducibility.
    draw_t = random_state.uniform(-1, 1, size=n_treatments)
    draw_xt = random_state.uniform(-1, 1, size=(n_x, n_treatments))
    seed = random_state.randint(0, np.iinfo('uint16').max)

    return generator.create_instance(s_x, sigma_x, sigma_y, conf_str,
                                     draw_t, draw_xt,
                                     hetero_strength, hetero_inds,
                                     autoreg, state_effect,
                                     random_seed=seed)
