import functools

import numpy as np
import pandas as pd
import torch
from joblib import Parallel, delayed

from models import AutoEncoder, GenNet
from helpers.utils import to_torch, MedianHeuristicMMR, gen_params, compute_r2, get_trainer, compute_mse

import seaborn as sns
import matplotlib.pyplot as plt


def gen_data(A_dim, Z_dim, params, net, sample_size, A_str, rng=None):
    """Simulate one dataset (X, A, Z) from the latent-variable model.

    Draws A ~ Uniform(-1, 1) with shape (sample_size, A_dim) and noise
    V ~ N(0, params['cov_ez']) with shape (sample_size, Z_dim), forms
    Z = (A @ params['M']) * A_str + V, maps Z through ``net`` to get X, and
    rescales every column of X to unit standard deviation (columns are not
    centred).

    Args:
        A_dim: number of columns of A.
        Z_dim: dimension of the Gaussian noise V (and hence of Z).
        params: dict providing 'M' (right-multiplied onto A; presumably of
            shape (A_dim, Z_dim) so the sum with V is well-defined) and
            'cov_ez' (covariance matrix of V).
        net: callable taking a torch tensor and returning a torch tensor.
        sample_size: number of rows to draw.
        A_str: scalar multiplier on the A @ M term.
        rng: optional numpy random source exposing ``uniform`` and
            ``multivariate_normal`` (e.g. ``np.random.default_rng(seed)``).
            Defaults to the global ``np.random`` state, which is what this
            function has always used.

    Returns:
        Tuple ``(X, A, Z)`` of numpy arrays, each with ``sample_size`` rows.
    """
    if rng is None:
        rng = np.random

    A = rng.uniform(-1, 1, size=(sample_size, A_dim))
    V = rng.multivariate_normal(mean=np.zeros(shape=(Z_dim,)), cov=params['cov_ez'],
                                size=(sample_size,))
    Z = A @ params['M'] * A_str + V

    # Pure inference: no_grad avoids recording an autograd graph that would be
    # thrown away by .detach() anyway.
    with torch.no_grad():
        X = net(to_torch(Z)).detach().numpy()
    # Per-column scaling to unit std; a constant column would yield inf/nan here.
    X = X / np.std(X, axis=0, keepdims=True)

    return X, A, Z


# Multiplier on the A @ M term of Z in every gen_data call.
A_STRENGTH = 4
# Size of the fresh sample used to score the trained encoders.
TEST_SAMPLE_SIZE = 10000


def _run_trial(Z_dim, reg, params, gen_net, sample_size, n_iter, seed):
    """Run one simulation repeat for a single regularisation value.

    Simulates a training set and pre-trains a vanilla AutoEncoder (``lmd=0``)
    for ``n_iter`` epochs. Its weights are then copied into an MMR-regularised
    AutoEncoder (``lmd=reg``). The MMR model trains for ``n_iter`` epochs and
    the vanilla model continues for another ``n_iter`` epochs. Both encoders
    are scored by R2 against the true latents of a fresh sample.

    Args:
        Z_dim: latent dimension (also passed as A's dimension to gen_data).
        reg: MMR regularisation weight, passed to AutoEncoder as ``lmd``.
        params: simulation parameters from ``gen_params``.
        gen_net: generator network mapping Z to X.
        sample_size: number of training samples.
        n_iter: ``max_epochs`` for each training phase.
        seed: seed for numpy's and torch's global RNGs in this trial.

    Returns:
        Two-row DataFrame (AE-Vanilla, AE-MMR) with columns
        Method, R2, MSE, Reg, Z_dim.
    """
    # This runs inside a joblib worker. Seed explicitly so each trial is
    # reproducible instead of drawing from an unseeded per-process RNG state.
    np.random.seed(seed)
    torch.manual_seed(seed)

    X, A, _ = gen_data(Z_dim, Z_dim, params, gen_net, sample_size, A_STRENGTH)

    model_baseline = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=0)
    model_MMR = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=reg)

    trainer_MMR = get_trainer(max_epochs=n_iter, callback=MedianHeuristicMMR())
    trainer_baseline_pre = get_trainer(max_epochs=n_iter)
    trainer_baseline_post = get_trainer(max_epochs=n_iter)

    # Warm start: both models enter the second phase from the same
    # pre-trained vanilla weights.
    trainer_baseline_pre.fit(model_baseline)
    model_MMR.load_state_dict(model_baseline.state_dict())
    trainer_MMR.fit(model_MMR)
    trainer_baseline_post.fit(model_baseline)

    X_new, _, Z_new = gen_data(Z_dim, Z_dim, params, gen_net, TEST_SAMPLE_SIZE, A_STRENGTH)

    predZ_MMR = model_MMR.encode(to_torch(X_new)).detach().numpy()
    predZ_baseline = model_baseline.encode(to_torch(X_new)).detach().numpy()

    r2_MMR = compute_r2(predZ_MMR, Z_new)
    r2_baseline = compute_r2(predZ_baseline, Z_new)

    # NOTE(review): R2 is scored on the fresh sample X_new, but MSE is scored
    # on the training sample X. Confirm this asymmetry is intended.
    mse_MMR = compute_mse(model_MMR, X)
    mse_baseline = compute_mse(model_baseline, X)

    # np.array wrapping keeps the resulting column dtypes independent of the
    # exact scalar type compute_r2 / compute_mse return.
    return pd.DataFrame({
        'Method': np.array(['AE-Vanilla', 'AE-MMR']),
        'R2': np.array([r2_baseline, r2_MMR]),
        'MSE': np.array([mse_baseline, mse_MMR]),
        'Reg': np.repeat(reg, 2),
        'Z_dim': np.repeat(Z_dim, 2),
    })


if __name__ == '__main__':

    X_dim = 10

    n_repeat = 20
    n_iter = 1000
    sample_size = 1000

    reg_params = [1e-1, 1, 1e+1, 1e+2, 1e+3, 1e+4, 1e+5]

    frames = []
    for Z_dim in [4]:
        params = gen_params(Z_dim, Z_dim, seed=1)

        torch.manual_seed(1)
        gen_net = GenNet(Z_dim, X_dim)
        gen_net.init_weights()

        # A single Parallel call covers every (reg, repeat) pair, so there is
        # no synchronisation barrier between regularisation values. Row order
        # matches the original nested loop (reg-major, then repeat).
        frames.extend(Parallel(n_jobs=-1)(
            delayed(_run_trial)(Z_dim, reg, params, gen_net, sample_size, n_iter,
                                seed=reg_idx * n_repeat + rep)
            for reg_idx, reg in enumerate(reg_params)
            for rep in range(n_repeat)
        ))

    # Concatenate once. Pairwise concatenation is quadratic, and concatenating
    # onto an empty DataFrame triggers a pandas FutureWarning.
    res_df = pd.concat(frames, ignore_index=True)

def plot_results(res_df, csv_path='res_df_reg.csv', fig_path='simu_regularization.pdf'):
    """Save the raw results and plot R-squared and MSE against lambda.

    Writes ``res_df`` to ``csv_path``. It then draws the AE-MMR R-squared
    curve on the left axis and both methods' MSE curves on a twin right axis,
    with a shared log-scaled x axis, and saves the figure to ``fig_path``.

    Args:
        res_df: results with columns Method, R2, MSE and Reg. Not modified.
        csv_path: destination of the raw results CSV.
        fig_path: destination of the figure.
    """
    res_df.to_csv(csv_path, index=False)

    # Work on a copy so the caller's results stay untouched. The vanilla AE is
    # trained with lmd=0 regardless of lambda, so its MSE is replaced by its
    # overall mean and drawn as a flat reference line.
    plot_df = res_df.copy()
    is_vanilla = plot_df.Method == 'AE-Vanilla'
    plot_df.loc[is_vanilla, 'MSE'] = plot_df.loc[is_vanilla, 'MSE'].mean()

    fig, ax1 = plt.subplots()
    sns.lineplot(
        data=plot_df[plot_df.Method == 'AE-MMR'], x="Reg", y="R2",
        hue="Method", style="Method",
        markers=['D'], errorbar=('ci', 95), palette=['black'], ax=ax1,
    )
    ax1.legend_.set_title('R-Squared')
    ax1.set_xlabel(r'Regularization parameter ($\lambda$)')
    ax1.set_ylabel('R-Squared')

    ax2 = ax1.twinx()
    sns.lineplot(
        data=plot_df, x="Reg", y="MSE",
        hue="Method", style="Method",
        markers=True, hue_order=['AE-MMR', 'AE-Vanilla'],
        style_order=['AE-MMR', 'AE-Vanilla'], errorbar=('ci', 95), ax=ax2,
    )
    ax2.legend_.set_title('MSE')
    ax2.set_xlabel(r'Regularization parameter ($\lambda$)')
    # ax2 shares its x axis with ax1, so the log scale applies to both.
    ax2.set_xscale('log')

    fig.savefig(fig_path, bbox_inches='tight')
    plt.close(fig)


if __name__ == '__main__':
    # res_df only exists when the experiment above ran as a script. Guarding
    # here keeps the module importable without a NameError.
    plot_results(res_df)
