import sys
sys.path.append('../')

import numpy as np
import pandas as pd
import pickle
from tqdm import tqdm

import tensorflow as tf

from models import get_model
from settings import MODES, MODEL_NAMES, PARAMETER_NAMES
from helpers import evaluate_contraction, evaluate_r2_score
from helpers import get_amortized_setup

# Budget handed to `get_amortized_setup`; also embedded in every output file name.
SIMULATION_BUDGET = 1000

# Number of posterior draws requested from the amortizer per evaluation.
NUM_POST_SAMPLES = 1000

# LaTeX column labels for the summary tables. They replace the MODEL_NAMES
# columns one-to-one, so this list must stay as long as MODEL_NAMES.
MODEL_TABLE_NAMES = [
    r'$\mathcal{M}_{%s}$' % label
    for label in (
        '1a', '1b', '1c', '2', '3',
        '4a', '4b', '5', '6', '7',
        '8', '9', '10', '11', '12',
    )
]


def evaluate_model(model_name):
    """Evaluate posterior contraction and R^2 for a single model.

    Loads the pickled test simulations for ``model_name``, draws
    ``NUM_POST_SAMPLES`` posterior samples with the amortizer of every entry
    in ``MODES``, and passes the collected samples to ``evaluate_contraction``
    and ``evaluate_r2_score``. The two resulting tables are concatenated
    column-wise and written to ``./tables/csv/`` and ``./tables/tex/``.

    Parameters
    ----------
    model_name : str
        Identifier used to look up the model (``get_model``), its parameter
        names (``PARAMETER_NAMES``) and the input/output file names.

    Returns
    -------
    tuple
        ``(contractions["Mean"], r2_scores["Mean"])``, i.e. the ``"Mean"``
        column of each metric table (presumably one entry per mode; verify
        against the helpers).

    Raises
    ------
    FileNotFoundError
        If the test simulations file or an output directory does not exist.
    """

    model = get_model(model_name)

    # Read-only mode is sufficient for loading, and the context manager makes
    # sure the handle is closed even if unpickling fails.
    # NOTE: pickle can execute arbitrary code while loading; only use files
    # that were produced by this project's own simulation scripts.
    with open(f'./simulations/test_{model_name}.pkl', 'rb') as sim_file:
        test_sims = pickle.load(sim_file)

    samples_per_method = []
    for mode in MODES:

        # Reset Keras' global state before building the setup for this mode
        tf.keras.backend.clear_session()

        trainer = get_amortized_setup(model, mode, SIMULATION_BUDGET)
        conf_sims = trainer.configurator(test_sims)

        samples = trainer.amortizer.sample(conf_sims, NUM_POST_SAMPLES)
        samples_per_method.append(samples)

    contractions = evaluate_contraction(samples_per_method, model, PARAMETER_NAMES[model_name])
    r2_scores = evaluate_r2_score(samples_per_method, test_sims['prior_draws'], PARAMETER_NAMES[model_name])

    # Side-by-side table: metric name on the first header level, the original
    # column labels (taken from the contraction table) on the second.
    together = pd.concat((contractions, r2_scores), axis=1)
    together.columns = pd.MultiIndex.from_product((
        ('Posterior Contraction', r'$R^2$-Score'), list(contractions.columns)
    ))

    together.to_csv(f"./tables/csv/{model_name}_{SIMULATION_BUDGET}.csv")
    with open(f'./tables/tex/{model_name}_{SIMULATION_BUDGET}.tex', 'w') as file:
        file.write(together.to_latex(float_format="%.2f"))

    return contractions["Mean"], r2_scores["Mean"]


if __name__ == '__main__':

    # Zero-filled (modes x models) tables, one per metric; every column is
    # overwritten by the evaluation loop below.
    table_shape = (len(MODES), len(MODEL_NAMES))
    mean_contractions = pd.DataFrame(np.zeros(table_shape), index=MODES, columns=MODEL_NAMES)
    mean_r2_scores = pd.DataFrame(np.zeros(table_shape), index=MODES, columns=MODEL_NAMES)

    # Evaluate the models one by one and store their mean scores column-wise
    for name in tqdm(MODEL_NAMES):
        contraction_means, r2_means = evaluate_model(name)
        mean_contractions[name] = contraction_means
        mean_r2_scores[name] = r2_means

    # Swap the plain model names for two-level headers (LaTeX label, metric)
    contraction_header = pd.MultiIndex.from_product(
        (MODEL_TABLE_NAMES, ['Mean posterior contraction'])
    )
    r2_header = pd.MultiIndex.from_product(
        (MODEL_TABLE_NAMES, [r'Mean $R^2$-score'])
    )
    mean_contractions.columns = contraction_header
    mean_r2_scores.columns = r2_header

    # Write the CSV versions first ...
    contraction_csv = f"./tables/csv/contraction_{SIMULATION_BUDGET}.csv"
    r2_csv = f"./tables/csv/r2_scores_{SIMULATION_BUDGET}.csv"
    mean_contractions.to_csv(contraction_csv)
    mean_r2_scores.to_csv(r2_csv)

    # ... then the LaTeX versions, rounded to two decimals
    contraction_tex = f'./tables/tex/contraction_{SIMULATION_BUDGET}.tex'
    r2_tex = f'./tables/tex/r2_scores_{SIMULATION_BUDGET}.tex'
    with open(contraction_tex, 'w') as tex_file:
        tex_file.write(mean_contractions.to_latex(float_format="%.2f"))
    with open(r2_tex, 'w') as tex_file:
        tex_file.write(mean_r2_scores.to_latex(float_format="%.2f"))
