import functools

import numpy as np
import pandas as pd
import torch
from joblib import Parallel, delayed

from models import AutoEncoder, GenNet, AdditiveMLP, MLP, CFEstimator, GenNetY
from helpers.utils import to_torch, MedianHeuristicMMR, gen_params, compute_V, get_trainer, choose_lambda

import seaborn as sns
import matplotlib.pyplot as plt


def gen_data(A_dim, Z_dim, params, x_net, h_net, l_net, sample_size, A_range, A_star=None):
    """Draw ``sample_size`` synthetic observations ``(X, A, Z, Y)``.

    The sample is generated as follows (the order of the NumPy random draws
    is part of the behaviour and must not be changed):

    * ``A`` is uniform on ``[-A_range, A_range]`` with shape
      ``(sample_size, A_dim)``; if ``A_star`` is given, every row of ``A`` is
      ``A_star`` instead and ``A_dim`` / ``A_range`` are not used.
    * ``V`` is zero-mean Gaussian with covariance ``params['cov_ez']`` and
      shape ``(sample_size, Z_dim)``.
    * ``U = h_net(V) + 2 * standard normal noise`` of shape
      ``(sample_size, 1)``.
    * ``Z = (A @ params['M']) * 4 + 2 * V``.
    * ``Y = l_net(Z) + U`` (both flattened to 1-D).
    * ``X = x_net(Z)``.

    Note that ``V`` enters both ``Z`` and the outcome noise ``U``.

    Args:
        A_dim: Number of columns of ``A`` when it is sampled.
        Z_dim: Dimension of the Gaussian vector ``V``.
        params: Mapping providing ``'cov_ez'`` and ``'M'``.
        x_net: Network mapping ``Z`` to ``X``.
        h_net: Network mapping ``V`` to the systematic part of ``U``
            (assumed to return one column per row -- TODO confirm).
        l_net: Network mapping ``Z`` to the systematic part of ``Y``.
        sample_size: Number of rows to generate.
        A_range: Half-width of the uniform range for ``A``.
        A_star: Optional fixed value used for every row of ``A``.

    Returns:
        Tuple ``(X, A, Z, Y)`` of NumPy arrays.
    """

    def _evaluate(net, inputs):
        # Run a torch network on a NumPy array and hand back a NumPy array.
        return net(to_torch(inputs)).detach().numpy()

    # Draw 1: the variable A (skipped entirely when A is fixed to A_star).
    if A_star is not None:
        A = np.repeat([A_star], sample_size, axis=0)
    else:
        A = np.random.uniform(-A_range, A_range, size=(sample_size, A_dim))

    # Draw 2: the Gaussian vector V.
    V = np.random.multivariate_normal(
        mean=np.zeros(shape=(Z_dim,)),
        cov=params['cov_ez'],
        size=(sample_size,),
    )

    # Draw 3: additive noise for U, taken after h_net has been evaluated.
    h_of_V = _evaluate(h_net, V)
    U = h_of_V + 2 * np.random.normal(size=(V.shape[0], 1))

    Z = A @ params['M'] * 4 + 2 * V

    outcome_mean = _evaluate(l_net, Z).flatten()
    Y = outcome_mean + U.flatten()

    X = _evaluate(x_net, Z)

    return X, A, Z, Y


def get_fns(Z_dim, X_dim, seed=1):
    """Create the simulation parameters and the three generator networks.

    The torch RNG is seeded first, then ``gen_params(Z_dim, Z_dim, seed)`` is
    called, and finally the networks are constructed and initialised one
    after the other in the order ``x_net``, ``h_net``, ``l_net``.

    Args:
        Z_dim: Dimension passed (twice) to ``gen_params`` and used as the
            input size of every network.
        X_dim: Second constructor argument of the ``GenNet`` producing ``X``.
        seed: Seed given to ``torch.manual_seed`` and to ``gen_params``.

    Returns:
        Tuple ``(params, x_net, h_net, l_net)``.
    """

    def _with_weights(net):
        # Call the network's own initialiser and return the same object.
        net.init_weights()
        return net

    torch.manual_seed(seed)
    params = gen_params(Z_dim, Z_dim, seed)

    # Construction order is kept fixed so the seeded RNG is consumed in the
    # same sequence on every call.
    x_net = _with_weights(GenNet(Z_dim, X_dim))
    h_net = _with_weights(GenNetY(Z_dim))
    l_net = _with_weights(GenNetY(Z_dim))

    return params, x_net, h_net, l_net


if __name__ == '__main__':
    # Driver for the multi-dimensional extrapolation simulation:
    #   1. pick the MMR regularisation strength for every dimension,
    #   2. run `n_repeat` independent repetitions in parallel,
    #   3. save the raw predictions, compute per-repetition MSEs and plot them.
    X_dim = 10

    n_iter_s1 = 500    # epochs for the autoencoder stage
    n_iter_s2 = 1000   # epochs for the outcome models (AdditiveMLP / MLP)
    n_repeat = 10
    sample_size = 10000
    base_seed = 1
    np.random.seed(base_seed)
    Z_dim_list = [2, 4, 10]

    # Use Apple's MPS backend for the lambda search when it exists and fall
    # back to the CPU otherwise, so the script also runs on other machines.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        tuning_accelerator = 'mps'
    else:
        tuning_accelerator = 'cpu'

    # Stage 0: choose the regularisation parameter once per dimension.
    reg_dict = {}
    for Z_dim in Z_dim_list:
        params, x_net, h_net, l_net = get_fns(Z_dim, X_dim, seed=1)

        X, A, Z, Y = gen_data(Z_dim, Z_dim, params, x_net, h_net, l_net, sample_size, A_range=1)

        model_baseline = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=0)
        model_MMR = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=None)

        chosen_lmd, lmd_df = choose_lambda(model_MMR, model_baseline, n_iter_s1, cut_off=0.2,
                                           accelerator=tuning_accelerator)
        reg_dict[Z_dim] = chosen_lmd


    def compute_predictions(iter_i):
        """Run one repetition of the experiment for every entry of ``Z_dim_list``.

        For each dimension the function generates training data with ``A`` in
        ``[-1, 1]``, trains a baseline autoencoder, continues training an
        MMR-regularised autoencoder from the baseline weights, fits the
        additive outcome models (learned and oracle representation) and a
        plain MLP on ``A``, and then predicts at 100 test values of ``A``
        drawn uniformly from ``[-3, -1]``. The reference value for each test
        point is the mean of ``Y`` over 50000 fresh draws with ``A`` fixed.

        Args:
            iter_i: Index of the repetition; stored in the output and used to
                derive the NumPy seed of this repetition.

        Returns:
            DataFrame with columns ``idx``, ``CF``, ``CF_oracle``, ``MLP``,
            ``Y``, ``Z_dim`` and ``iter_i`` (one row per test point and
            dimension).
        """
        # This function runs in a joblib worker, which does not see the
        # parent's `np.random.seed` call. Seed explicitly so every repetition
        # is reproducible yet uses different data from the other repetitions
        # (and from the lambda-selection run seeded with `base_seed`).
        np.random.seed(base_seed + 1 + iter_i)

        per_dim_frames = []
        for Z_dim in Z_dim_list:
            reg_param = reg_dict[Z_dim]

            torch.manual_seed(1)

            params, x_net, h_net, l_net = get_fns(Z_dim, X_dim, seed=1)

            # Test interventions lie outside the training support of A.
            A_test = np.random.uniform(-3, -1, size=(100, Z_dim))
            A_range_train = 1

            X, A, Z, Y = gen_data(Z_dim, Z_dim, params, x_net, h_net, l_net, sample_size, A_range=A_range_train)

            model_baseline = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=0)
            model_MMR = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=reg_param)
            se_callback = MedianHeuristicMMR()

            trainer_baseline = get_trainer(max_epochs=n_iter_s1)
            trainer_MMR = get_trainer(max_epochs=n_iter_s1, callback=se_callback)

            # Warm start: the MMR model continues from the baseline weights.
            trainer_baseline.fit(model_baseline)
            model_MMR.load_state_dict(model_baseline.state_dict())
            trainer_MMR.fit(model_MMR)

            # Standardise the learned representation column-wise.
            predZ = model_MMR.encode(to_torch(X)).detach().numpy()
            predZ = (predZ - predZ.mean(axis=0)) / predZ.std(axis=0)

            V, lr_PredZ_A = compute_V(predZ, A)
            V_oracle, lr_Z_A = compute_V(Z, A)

            additive_NN = AdditiveMLP(Y, predZ, V, lr=5e-3)
            additive_oracle = AdditiveMLP(Y, Z, V_oracle, lr=5e-3)

            trainer_additive = get_trainer(max_epochs=n_iter_s2, accelerator='cpu')
            trainer_additive_oracle = get_trainer(max_epochs=n_iter_s2, accelerator='cpu')

            trainer_additive.fit(additive_NN)
            trainer_additive_oracle.fit(additive_oracle)

            mlp = MLP(Y, A, lr=5e-3)

            trainer_MLP = get_trainer(max_epochs=n_iter_s2, accelerator='cpu')
            trainer_MLP.fit(mlp)

            cfEstimator = CFEstimator(additive_NN, lr_PredZ_A)
            cfEstimator_oracle = CFEstimator(additive_oracle, lr_Z_A)
            cfEstimator.fit_bias(predZ, Y)
            cfEstimator_oracle.fit_bias(Z, Y)

            # One pass over the test points; the three predictors are called
            # in the same order for every point.
            pred_Y_CF = []
            pred_Y_CF_oracle = []
            pred_Y_MLP = []
            for a_star in A_test:
                A_star = [a_star]

                pred_Y_CF.append(cfEstimator.predict(A_star, V))
                pred_Y_CF_oracle.append(cfEstimator_oracle.predict(A_star, V_oracle))
                pred_Y_MLP.append(mlp(to_torch(np.array(A_star).reshape(1, -1))).item())

            ret_df = pd.DataFrame()
            ret_df['idx'] = np.arange(A_test.shape[0])
            ret_df['CF'] = pred_Y_CF
            ret_df['CF_oracle'] = pred_Y_CF_oracle
            ret_df['MLP'] = pred_Y_MLP

            # Monte Carlo reference: mean outcome with A held at each test
            # value (gen_data ignores A_range when A_star is given).
            true_Y = [
                gen_data(Z_dim, Z_dim, params, x_net, h_net, l_net, 50000,
                         A_range=1,
                         A_star=a_star)[3].mean()
                for a_star in A_test
            ]

            true_Y_df = pd.DataFrame()
            true_Y_df['idx'] = np.arange(A_test.shape[0])
            true_Y_df['Y'] = true_Y

            inner_df = pd.merge(ret_df, true_Y_df, how='inner', on='idx')
            inner_df['Z_dim'] = Z_dim
            inner_df['iter_i'] = iter_i

            per_dim_frames.append(inner_df)

        # Single concat instead of growing an (initially empty) frame.
        return pd.concat(per_dim_frames, ignore_index=True)


    repetition_frames = Parallel(n_jobs=n_repeat)(
        delayed(compute_predictions)(iter_i=i) for i in range(n_repeat)
    )
    prediction_df = pd.concat(repetition_frames, ignore_index=True)

    prediction_df.to_csv("prediction_df_multi_dim.csv", index=False)

    # Mean squared error per (dimension, method, repetition).
    mse_frames = []
    for Z_dim in Z_dim_list:
        df = prediction_df[prediction_df.Z_dim == Z_dim]
        for method in ['CF', 'CF_oracle', 'MLP']:
            squared_error = (df[method] - df['Y']) ** 2
            inner_df = squared_error.groupby(df['iter_i']).mean().rename('MSE').to_frame()
            inner_df['method'] = method
            inner_df['Z_dim'] = Z_dim
            mse_frames.append(inner_df)
    df_all = pd.concat(mse_frames, ignore_index=True)

    # Display names used in the figure legend.
    df_all = df_all.replace({'CF': 'Rep4Ex-CF', 'CF_oracle': 'Rep4Ex-CF-Oracle'})

    g = sns.catplot(
        data=df_all, x='method', y='MSE',
        col='Z_dim', kind='box', sharey=False, hue='method',
        estimator=False, height=2.5, aspect=1.4, legend='full',
        hue_order=['MLP', 'Rep4Ex-CF', 'Rep4Ex-CF-Oracle'], showfliers=False,
        order=['MLP', 'Rep4Ex-CF', 'Rep4Ex-CF-Oracle'], dodge=False
    )

    g.set(yscale="log")
    g.set_xlabels("")
    g.set_ylabels("Mean squared error")
    g.set_xticklabels("")
    g.set_titles('The dimension of A = {col_name}')
    plt.legend(ncols=3, loc='lower center', bbox_to_anchor=(-0.9, -0.32), fontsize=12)

    plt.savefig('simu_extrapolation_multi.pdf', bbox_inches='tight')
