import functools

import numpy as np
import pandas as pd
import torch
from joblib import Parallel, delayed

from models import AutoEncoder, GenNet, AdditiveMLP, MLP, CFEstimator
from helpers.utils import to_torch, MedianHeuristicMMR, gen_params, compute_V, get_trainer, choose_lambda

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.lines as mlines


def gen_data(A_dim, Z_dim, params, x_net, sample_size, A_range, A_star=None):
    """Simulate ``sample_size`` draws from the synthetic extrapolation model.

    The generative model is::

        A ~ Unif(-A_range, A_range)^A_dim      (or every row fixed to A_star)
        V ~ N(0, params['cov_ez'])             (Z_dim-dimensional)
        U = 0.2 * V**3 + N(0, 1)               (U depends on V, so V drives both Z and Y)
        Z = 4 * A @ params['M'] + V
        Y = -2 * Z + 10 * sin(Z) + U
        X = x_net(Z)

    Args:
        A_dim: number of columns of A when A is sampled uniformly.
        Z_dim: dimension of the latent noise V (and hence of Z).
        params: dict providing 'cov_ez' (covariance of V) and 'M' (linear map
            from A to Z).
        x_net: mixing network applied to Z; it is called on ``to_torch(Z)`` and
            its output is converted back with ``.detach().numpy()``.
        sample_size: number of rows to simulate.
        A_range: half-width of the uniform range for A; ignored when
            ``A_star`` is given.
        A_star: optional fixed value of A, repeated for every row.

    Returns:
        Tuple ``(X, A, Z, Y)`` of numpy arrays; ``Y`` is flattened to 1-D.
    """
    if A_star is not None:
        # Fixed intervention: every row of A equals A_star.
        A = np.repeat([A_star], sample_size, axis=0)
    else:
        A = np.random.uniform(-A_range, A_range, size=(sample_size, A_dim))

    # Draw order (V first, then U's additive noise) matters for seeded reproducibility.
    V = np.random.multivariate_normal(
        mean=np.zeros(shape=(Z_dim,)),
        cov=params['cov_ez'],
        size=(sample_size,),
    )
    eps_U = np.random.normal(size=V.shape)
    U = 0.2 * V ** 3 + eps_U

    Z = 4 * (A @ params['M']) + V
    Y = -2 * Z.ravel() + 10 * np.sin(Z).ravel() + U.ravel()

    Z_tensor = to_torch(Z)
    X = x_net(Z_tensor).detach().numpy()
    return X, A, Z, Y


if __name__ == '__main__':
    # Simulation design: latent Z of dimension Z_dim, observed through the
    # mixing network x_net as X of dimension X_dim.
    X_dim = 2
    Z_dim = 1

    n_iter = 200  # passed as the epoch budget to the autoencoder training
    n_repeat = 10  # independent repetitions per training range
    sample_size = 10000

    # Seed before creating the data-generating parameters and mixing network so
    # the simulated setting is reproducible.
    torch.manual_seed(1)
    np.random.seed(1)
    params = gen_params(Z_dim, Z_dim)
    x_net = GenNet(Z_dim, X_dim)
    x_net.init_weights()

    # Predictions are evaluated on [-A_range_test, A_range_test], which is wider
    # than every training range in A_range_train (extrapolation).
    A_range_test = 3.5
    A_lins = np.linspace(-A_range_test, A_range_test, 100)
    A_range_train = [0.2, 0.7, 1.2]

    # 'mps' only exists on Apple silicon; fall back to CPU elsewhere instead of
    # failing when the accelerator is unavailable.
    accelerator = 'mps' if torch.backends.mps.is_available() else 'cpu'

    # Select the MMR regularisation strength separately for each training range.
    reg_dict = {}
    for A_range in A_range_train:
        X, A, Z, Y = gen_data(Z_dim, Z_dim, params, x_net, sample_size, A_range=A_range)

        model_baseline = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=0)
        # lmd=None presumably defers the penalty choice to choose_lambda — TODO confirm.
        model_MMR = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=None)

        chosen_lmd, _ = choose_lambda(model_MMR, model_baseline, n_iter, cut_off=0.2,
                                      accelerator=accelerator)
        reg_dict[A_range] = chosen_lmd

    def compute_predictions(A_range):
        """Fit the Rep4Ex pipeline and an MLP baseline on one fresh training set.

        Simulates ``sample_size`` training points with A drawn uniformly from
        ``(-A_range, A_range)``, trains the ``lmd=0`` baseline autoencoder,
        warm-starts the MMR autoencoder from its weights (using the penalty
        stored in ``reg_dict[A_range]``), then fits the additive outcome model
        and a plain MLP of Y on A.

        Args:
            A_range: half-width of the training range for A; must be a key of
                ``reg_dict``.

        Returns:
            Long-format DataFrame with columns ``A``, ``Model`` ('Rep4Ex-CF' or
            'MLP'), ``Y`` (the prediction at each point of ``A_lins``) and
            ``A_range``.
        """
        X, A, _, Y = gen_data(Z_dim, Z_dim, params, x_net, sample_size, A_range=A_range)

        model_baseline = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=0)
        model_MMR = AutoEncoder(X, A, Z_dim, lr=5e-3, lmd=reg_dict[A_range])
        se_callback = MedianHeuristicMMR()

        trainer_baseline = get_trainer(max_epochs=n_iter)
        trainer_MMR = get_trainer(se_callback, max_epochs=n_iter)

        # Warm-start the MMR model from the fitted baseline weights.
        trainer_baseline.fit(model_baseline)
        model_MMR.load_state_dict(model_baseline.state_dict())
        trainer_MMR.fit(model_MMR)

        # Only the encoded values are needed, so skip building an autograd graph.
        with torch.no_grad():
            predZ = model_MMR.encode(to_torch(X)).detach().numpy()
        # Standardise each latent coordinate separately. For a 1-d latent this
        # equals global standardisation; for Z_dim > 1 it keeps every coordinate
        # at zero mean / unit scale instead of pooling them.
        predZ = (predZ - predZ.mean(axis=0, keepdims=True)) / predZ.std(axis=0, keepdims=True)

        V, lr_PredZ_A = compute_V(predZ, A)

        additive_NN = AdditiveMLP(Y, predZ, V, lr=5e-3)
        mlp = MLP(Y, A)

        trainer_additive = get_trainer(max_epochs=30)
        trainer_MLP = get_trainer(max_epochs=30)

        trainer_additive.fit(additive_NN)
        trainer_MLP.fit(mlp)

        cfEstimator = CFEstimator(additive_NN, lr_PredZ_A)
        cfEstimator.fit_bias(predZ, Y)

        ret_df = pd.DataFrame()

        # Evaluate both models at every grid value of A.
        pred_Y_CF = []
        pred_Y_MLP = []
        for a_star in A_lins:
            A_star = [a_star]

            pred_Y_CF.append(cfEstimator.predict(A_star, V))
            with torch.no_grad():
                pred_Y_MLP.append(mlp(to_torch(np.array(A_star).reshape(1, -1))).item())

        ret_df['A'] = A_lins
        ret_df['Rep4Ex-CF'] = pred_Y_CF
        ret_df['MLP'] = pred_Y_MLP

        # Long format (one row per (A, Model)) for seaborn's hue='Model'.
        ret_df = pd.melt(ret_df, id_vars=['A'], value_vars=['Rep4Ex-CF', 'MLP'], value_name='Y', var_name='Model')
        ret_df['A_range'] = A_range

        return ret_df


    # Ground-truth curve: Monte-Carlo mean of Y when A is fixed to each grid value.
    true_Y = []
    for a_star in A_lins:
        _, _, _, Y_new = gen_data(Z_dim, Z_dim, params, x_net, 50000, A_range=A_range_test,
                                  A_star=[a_star])
        true_Y.append(Y_new.mean())

    true_Y_df = pd.DataFrame()
    true_Y_df['A'] = A_lins
    true_Y_df['Model'] = 'True conditional mean'
    true_Y_df['Y'] = true_Y

    # Collect every frame first and concatenate once. Concatenating inside the loop
    # re-copies all accumulated rows on each step.
    frames = []
    for A_range in A_range_train:
        # n_repeat independent fits for this training range, run in parallel.
        frames.extend(Parallel(n_jobs=-1)(
            delayed(compute_predictions)(A_range=A_range) for _ in range(n_repeat)
        ))
        # A tagged copy of the ground truth so it appears in every facet.
        frames.append(true_Y_df.assign(A_range=A_range))
    final_df = pd.concat(frames, ignore_index=True)

    final_df.to_csv("simu_extrapolation_result.csv", index=False)
    # Build the figures from the saved results.
    final_df = pd.read_csv("simu_extrapolation_result.csv")

    # ---- Figure 1: one facet per training range ----
    cols = sns.color_palette('tab10')
    yellow = sns.color_palette("hls", 8)[1]  # colour used for training points
    g = sns.relplot(final_df, x='A', y='Y',
                    hue='Model', col='A_range',
                    palette=[cols[0], cols[2], cols[3]],
                    legend='brief', kind='line', height=3.5, aspect=0.9)

    sns.move_legend(g, "upper center", bbox_to_anchor=(0.32, 0.97), ncol=3,
                    frameon=True, title=None, prop={'size': 14})
    plt.tight_layout()
    g.figure.subplots_adjust(top=0.75)

    # Number of training points drawn per facet: a display subsample, not the
    # sample_size used for fitting.
    # NOTE(review): assumes facet columns follow the ascending order of A_range_train.
    training_size = [500, 1000, 3000]
    for col_idx, ax in enumerate(g.axes[0]):
        _, A, _, Y = gen_data(Z_dim, Z_dim, params, x_net, training_size[col_idx], A_range=A_range_train[col_idx])
        _, A_test, _, Y_test = gen_data(Z_dim, Z_dim, params, x_net, 10000, A_range=A_range_test)
        sns.scatterplot(x=A_test.flatten(), s=3, y=Y_test.flatten(), alpha=.6, linewidth=0,
                        color='lightgrey', ax=ax, legend=False, zorder=0)
        sns.scatterplot(x=A.flatten(), s=3, y=Y.flatten(), alpha=.8, linewidth=0,
                        color=yellow, ax=ax, legend=False, zorder=0)

    # Manual legend entries for the scatter layers (they were drawn with legend=False).
    yellow_point = mlines.Line2D([], [], color=yellow, marker='o', linestyle='None',
                                 markersize=6, alpha=1, label='Training data')
    grey_point = mlines.Line2D([], [], color='lightgrey', marker='o', linestyle='None',
                               markersize=6, alpha=.75, label='Test data')
    first_legend = plt.legend(handles=[yellow_point, grey_point],
                              bbox_to_anchor=(0.3, 1.375), loc='upper center', ncol=2, prop={'size': 14})
    plt.gca().add_artist(first_legend)

    g.set_titles(r"A$\sim$Unif$(-{col_name}, {col_name})$")
    g.set(ylim=(-16, 16))
    plt.savefig('simu_extrapolation.pdf', bbox_inches='tight')

    # ---- Figure 2: the widest training range only (reuses cols / yellow) ----
    g = sns.relplot(final_df[final_df.A_range == A_range_train[-1]], x='A', y='Y',
                    hue='Model', palette=[cols[0], cols[2], cols[3]],
                    legend='brief', kind='line', height=5, aspect=0.95)

    sns.move_legend(g, "upper center", bbox_to_anchor=(0.42, 0.95), ncol=3,
                    frameon=True, title='', prop={'size': 11})
    g.figure.subplots_adjust(top=0.8, bottom=0.25)

    _, A, _, Y = gen_data(Z_dim, Z_dim, params, x_net, 3000, A_range=A_range_train[-1])
    _, A_test, _, Y_test = gen_data(Z_dim, Z_dim, params, x_net, 10000, A_range=A_range_test)
    sns.scatterplot(x=A_test.flatten(), y=Y_test.flatten(), alpha=.6, linewidth=0,
                    color='lightgrey', ax=g.axes[0, 0], legend=False, zorder=0, s=6)
    sns.scatterplot(x=A.flatten(), y=Y.flatten(), alpha=.8, linewidth=0, edgecolor='yellow',
                    color=yellow, ax=g.axes[0, 0], legend=False, zorder=0, s=6)

    yellow_point = mlines.Line2D([], [], color=yellow, marker='o', linestyle='None',
                                 markersize=6, alpha=1, label='Training data')
    grey_point = mlines.Line2D([], [], color='lightgrey', marker='o', linestyle='None',
                               markersize=6, alpha=.75, label='Test data')
    # Placed below the axes; subplots_adjust(bottom=0.25) above leaves room for it.
    first_legend = plt.legend(handles=[yellow_point, grey_point],
                              bbox_to_anchor=(0.5, -0.20), loc='upper center', ncol=2, prop={'size': 11})
    plt.gca().add_artist(first_legend)

    plt.title(None)
    plt.ylabel(r"Y")
    plt.xlabel(r"Intervention on $A$")
    g.set(ylim=(-16, 16))

    plt.savefig('intro_extrapolation.pdf')