import functools

import numpy as np
import pandas as pd
import torch
from joblib import Parallel, delayed

from models import AutoEncoder, GenNet, VAE, EncoderOracle, DecoderOracle
from helpers.utils import to_torch, MedianHeuristicMMR, gen_params, compute_r2, get_trainer, choose_lambda

import seaborn as sns
import matplotlib.pyplot as plt


def gen_data(A_dim, Z_dim, params, net, sample_size, A_str):
    """Simulate one data set ``(X, A, Z)``.

    Args:
        A_dim: Number of columns of the sampled matrix ``A``.
        Z_dim: Length of the zero mean vector used for the Gaussian noise.
        params: Mapping that provides ``'M'`` (right-multiplied onto ``A``)
            and ``'cov_ez'`` (covariance of the Gaussian noise).
        net: Callable applied to ``to_torch(Z)``; its result must support
            ``.detach().numpy()`` (presumably a torch module -- confirm
            against the caller).
        sample_size: Number of rows to draw.
        A_str: Scalar multiplier applied to ``A @ params['M']``.

    Returns:
        Tuple ``(X, A, Z)`` where ``X`` is the output of ``net`` divided
        column-wise by its standard deviation (it is not mean-centred).
    """
    # The order of the two draws is kept fixed so the global NumPy random
    # stream is consumed exactly as before.
    A = np.random.uniform(-1, 1, size=(sample_size, A_dim))
    noise = np.random.multivariate_normal(
        mean=np.zeros(shape=(Z_dim,)),
        cov=params['cov_ez'],
        size=(sample_size,),
    )

    # Linear effect of A, scaled by A_str, plus additive noise.
    Z = (A @ params['M']) * A_str + noise

    raw_X = net(to_torch(Z)).detach().numpy()
    column_std = np.std(raw_X, axis=0, keepdims=True)

    return raw_X / column_std, A, Z


def get_reg_params(Z_dim, A_str, gen_net, params, sample_size, n_iter=None):
    """Select a regularisation weight for every intervention strength.

    For each value in ``A_str`` a data set is simulated with ``gen_data``
    and two ``AutoEncoder`` models are built on it, one with ``lmd=0`` and
    one with ``lmd=None``.  Both are passed to ``choose_lambda`` (with
    ``cut_off=0.2``) and the first element of its return value is recorded.

    Args:
        Z_dim: Passed to ``gen_data`` as both ``A_dim`` and ``Z_dim`` and to
            ``AutoEncoder`` as the latent dimension.
        A_str: Iterable of intervention strengths; each becomes a key of the
            returned dict.
        gen_net: Generator network forwarded to ``gen_data``.
        params: Parameter mapping forwarded to ``gen_data``.
        sample_size: Number of samples simulated per strength.
        n_iter: Third positional argument of ``choose_lambda``.  When left
            as ``None`` the module-level ``n_iter`` is used, which is how
            this function obtained the value before the parameter existed.

    Returns:
        Dict mapping each entry of ``A_str`` to the value chosen by
        ``choose_lambda``.

    Raises:
        NameError: If ``n_iter`` is not given and no module-level ``n_iter``
            exists (e.g. the module was imported rather than run).
    """
    ret_dict = {}
    for a_s in A_str:
        if n_iter is None:
            # Legacy fallback, resolved lazily so that an empty ``A_str``
            # still returns ``{}`` without needing the global.
            try:
                n_iter = globals()['n_iter']
            except KeyError:
                raise NameError(
                    "name 'n_iter' is not defined; pass n_iter=... to get_reg_params"
                ) from None

        # The latent draw returned by gen_data is not needed here.
        X, A, _ = gen_data(Z_dim, Z_dim, params, gen_net, sample_size, a_s)
        model_baseline = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=0, accelerator='cpu')
        model_MMR = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=None, accelerator='cpu')
        # Only the selected value is kept; the second return value is unused.
        chosen_lmd, _ = choose_lambda(model_MMR, model_baseline, n_iter, cut_off=0.2)
        ret_dict[a_s] = chosen_lmd

    return ret_dict


if __name__ == '__main__':
    # Simulation study: for every latent dimension and intervention strength,
    # repeatedly fit four models (vanilla AE, VAE, AE-MMR, oracle-initialised
    # AE-MMR), score the learned representation with compute_r2 on fresh
    # data, then write the results to CSV and plot them.
    #
    # Everything stays at module scope (no main() function) because
    # get_reg_params reads ``n_iter`` as a module-level global and the
    # worker function below reads the loop variables the same way.

    X_dim = 10

    n_repeat = 20
    n_iter = 1000
    sample_size = 1000

    A_str = [0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4]
    Z_dims = [2, 4]
    np.random.seed(1)

    # One frame per repetition; concatenated once after all runs finish
    # instead of growing a DataFrame inside the loops.
    result_frames = []
    for Z_dim in Z_dims:
        params = gen_params(Z_dim, Z_dim)

        torch.manual_seed(1)
        gen_net = GenNet(Z_dim, X_dim)
        gen_net.init_weights()

        reg_params = get_reg_params(Z_dim, A_str, gen_net, params, sample_size)

        for a_s in A_str:
            def compute_R2():
                """Run one repetition and return a 4-row frame (one row per method)."""
                X, A, Z = gen_data(Z_dim, Z_dim, params, gen_net, sample_size, a_s)

                model_baseline = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=0, accelerator='cpu')
                model_MMR = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=reg_params[a_s], accelerator='cpu')
                model_vae = VAE(X, Z_dim, lr=5e-3)
                # NOTE(review): this single callback instance is handed to both
                # trainer_MMR and trainer_oracle; confirm MedianHeuristicMMR
                # keeps no per-run state that the second fit could inherit.
                se_callback = MedianHeuristicMMR()

                trainer_MMR = get_trainer(max_epochs=n_iter, callback=se_callback, accelerator='cpu')
                trainer_baseline = get_trainer(max_epochs=n_iter, accelerator='cpu')
                trainer_VAE = get_trainer(max_epochs=n_iter, accelerator='cpu')

                trainer_baseline.fit(model_baseline)
                # Start the MMR model from the fitted baseline weights.
                model_MMR.load_state_dict(model_baseline.state_dict())
                trainer_MMR.fit(model_MMR)
                trainer_VAE.fit(model_vae)

                # fit oracle baselines
                trainer_encoder = get_trainer(max_epochs=n_iter, accelerator='cpu')
                trainer_decoder = get_trainer(max_epochs=n_iter, accelerator='cpu')
                trainer_oracle = get_trainer(max_epochs=n_iter, callback=se_callback, accelerator='cpu')

                # The encoder oracle is fitted on (X, Z); the decoder oracle is
                # then fitted on X and the encoder's predictions for X.
                oracle_encoder = EncoderOracle(X, Z)

                trainer_encoder.fit(oracle_encoder)
                Z_pred = oracle_encoder.encode(to_torch(X)).detach().numpy()
                oracle_decoder = DecoderOracle(X, Z_pred)
                trainer_decoder.fit(oracle_decoder)
                # initialize oracle with the above models
                model_oracle = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=reg_params[a_s], accelerator='cpu')
                model_oracle.encoder.load_state_dict(oracle_encoder.encoder.state_dict())
                model_oracle.decoder.load_state_dict(oracle_decoder.decoder.state_dict())

                trainer_oracle.fit(model_oracle)

                # Fresh evaluation sample; only X and Z are needed for scoring.
                X_new, _, Z_new = gen_data(Z_dim, Z_dim, params, gen_net, 10000, a_s)

                predZ_MMR = model_MMR.encode(to_torch(X_new)).detach().numpy()
                predZ_baseline = model_baseline.encode(to_torch(X_new)).detach().numpy()
                predZ_oracle = model_oracle.encode(to_torch(X_new)).detach().numpy()
                predZ_vae = model_vae.encode(to_torch(X_new), return_var=False).detach().numpy()

                r2_MMR = compute_r2(predZ_MMR, Z_new)
                r2_baseline = compute_r2(predZ_baseline, Z_new)
                r2_oracle = compute_r2(predZ_oracle, Z_new)
                r2_vae = compute_r2(predZ_vae, Z_new)

                ret_df = pd.DataFrame()

                # Row order of 'Method' and 'R2' must stay aligned.
                ret_df['Method'] = np.array(['AE-Vanilla', 'VAE', 'AE-MMR', 'AE-MMR-Oracle'])
                ret_df['R2'] = np.array([r2_baseline, r2_vae, r2_MMR, r2_oracle])
                ret_df['A_Str'] = np.repeat(a_s, 4)
                ret_df['Z_dim'] = np.repeat(Z_dim, 4)
                ret_df['lmd'] = np.repeat(reg_params[a_s], 4)

                return ret_df

            # The repetitions run in parallel on all available cores and are
            # consumed before the loop variables change, so the late-bound
            # names read by compute_R2 refer to the current (Z_dim, a_s).
            result_frames.extend(
                Parallel(n_jobs=-1)(
                    delayed(compute_R2)() for _ in range(n_repeat)
                )
            )

    # A single concat replaces the former pairwise reduce + per-iteration
    # concat, which re-copied the accumulated rows on every step.
    res_df = pd.concat(result_frames, ignore_index=True) if result_frames else pd.DataFrame()

    # Output and plotting live under the __main__ guard as well: they use
    # res_df and A_str, which exist only when the experiment above has run,
    # so at module level they raised NameError on import.
    res_df.to_csv('res_df.csv', index=False)

    # NOTE(review): ``ci=95`` is deprecated in seaborn >= 0.12 in favour of
    # ``errorbar=('ci', 95)``; left unchanged to stay compatible with the
    # seaborn version this script was written for -- confirm before switching.
    g = sns.relplot(
        data=res_df, x="A_Str", y="R2",
        col="Z_dim", hue="Method", style="Method",
        kind="line", markers=True, hue_order=['AE-MMR', 'AE-MMR-Oracle', 'AE-Vanilla', 'VAE'],
        style_order=['AE-MMR', 'AE-MMR-Oracle', 'AE-Vanilla', 'VAE'], aspect=1.4, height=2.75, ci=95,
        facet_kws={'sharey': False, 'sharex': True}
    )
    for ax in g.axes.flat:
        ax.set_xticks(ticks=A_str)

    # Previously a bare expression whose result was discarded when run as a
    # script; print it so the lmd summary is actually visible.
    print(res_df.groupby(['Method', 'Z_dim', 'A_Str'])['lmd'].mean())

    g.set_xlabels(r"Intervention Strength ($\alpha$)")
    g.set_ylabels(r"R-Squared")
    g.set_titles(col_template=r"The dimension of $Z$ = {col_name}")

    plt.savefig('simu_R2.pdf', bbox_inches='tight')
