from functools import partial

import numpy as np
import pandas as pd
from sklearn.metrics import r2_score

from bayesflow.networks import InvertibleNetwork, SetTransformer
from bayesflow.amortizers import AmortizedPosterior
from bayesflow.trainers import Trainer

from configuration import configurator
from settings import EMBEDDING_SETTINGS, MODES
import sys
sys.path.append("../architectures/")
from hybrid import HybridSummaryNetwork, AmortizedHybrid


def get_amortized_setup(model, mode, budget, num_total_summaries=30):
    """Build the BayesFlow networks for ``mode`` and return a ready ``Trainer``.

    Args:
        model: Generative model. ``model.prior(1)['prior_draws']`` is used to
            infer the number of parameters, and ``model.name`` becomes part of
            the checkpoint path.
        mode: Estimation mode.
            - ``'direct_hybrid'``, ``'mmd_hybrid'`` and ``'generative_hybrid'``
              learn ``num_total_summaries // 2`` summaries.
            - ``'learner'`` learns all ``num_total_summaries``.
            - Any other value builds no summary network.
            ``'mmd_hybrid'`` additionally wraps the summary network in a
            ``HybridSummaryNetwork`` and uses an ``'MMD'`` summary loss.
            ``'generative_hybrid'`` adds a second invertible network over the
            learned summaries.
        budget: Only used to build the checkpoint path.
        num_total_summaries: Total number of summary statistics. The hybrid
            modes give half of them to the learned summary network.

    Returns:
        A ``Trainer`` whose configurator is ``configurator`` with ``mode`` bound,
        checkpointing to ``checkpoints/estimation/<model.name>_<mode>_<budget>``.
    """
    hybrid_modes = ('direct_hybrid', 'mmd_hybrid', 'generative_hybrid')

    # The parameter count is read off the shape of a single prior draw.
    num_params = model.prior(1)['prior_draws'].shape[1]

    if mode in hybrid_modes:
        summary_dim = num_total_summaries // 2
    elif mode == 'learner':
        summary_dim = num_total_summaries
    else:
        summary_dim = None
    summary_loss = 'MMD' if mode == 'mmd_hybrid' else None

    summary_net = None
    if mode in hybrid_modes or mode == 'learner':
        summary_net = SetTransformer(input_dim=3, summary_dim=summary_dim, use_layer_norm=False)
        if mode == 'mmd_hybrid':
            # The MMD variant wraps the learned summaries in an embedder network.
            summary_net = HybridSummaryNetwork(
                num_expert_summaries=num_total_summaries // 2,
                summary_net=summary_net,
                **EMBEDDING_SETTINGS,
            )

    inference_net = InvertibleNetwork(num_params=num_params)
    if mode == 'generative_hybrid':
        # A second flow is defined over the learned summaries themselves.
        amortizer = AmortizedHybrid(
            inference_net, summary_net, InvertibleNetwork(num_params=summary_dim)
        )
    else:
        amortizer = AmortizedPosterior(inference_net, summary_net, summary_loss_fun=summary_loss)

    return Trainer(
        configurator=partial(configurator, mode=mode),
        amortizer=amortizer,
        generative_model=model,
        checkpoint_path=f'checkpoints/estimation/{model.name}_{mode}_{budget}',
        max_to_keep=1,
    )


def evaluate_z_score(samples_per_method, ground_truths, parameter_names, *, method_names=None):
    """Tabulate the median absolute posterior z-score per method and parameter.

    Steps for each method:
    1. Take the posterior mean and standard deviation over axis 1 of its samples.
    2. Compute ``|ground_truth - mean| / std`` element-wise.
    3. Report the median over axis 0.

    Args:
        samples_per_method: Sequence with one posterior-sample array per method.
            Each is presumably indexed ``samples[dataset, draw, parameter]``,
            as implied by the reductions over axes 1 and 0. TODO: confirm with callers.
        ground_truths: True parameter values, broadcastable against each method's
            posterior means. Presumably ``(num_datasets, num_params)``.
        parameter_names: Column labels, one per parameter.
        method_names: Row labels, one per entry of ``samples_per_method``.
            Defaults to ``MODES`` from ``settings``.

    Returns:
        A DataFrame with one row per method and one column per parameter
        (rounded to 3 decimals). It also has a ``'Mean'`` column: the row mean
        of those rounded values, rounded to 2 decimals.

    Raises:
        ValueError: If the number of row labels differs from the number of methods.
    """
    num_methods = len(samples_per_method)
    # Evaluated lazily, so ``MODES`` is only needed when no labels are given.
    index = MODES if method_names is None else list(method_names)
    if len(index) != num_methods:
        raise ValueError(
            f'Got {len(index)} method names for {num_methods} sets of posterior samples.'
        )

    z_scores_per_par = np.zeros((num_methods, len(parameter_names)))
    for i, samples in enumerate(samples_per_method):
        post_mean = samples.mean(1)
        post_std = samples.std(1)
        # Absolute standardized error of the posterior mean, per data set and parameter.
        z_score = np.abs((ground_truths - post_mean) / post_std)
        z_scores_per_par[i] = np.median(z_score, axis=0)

    z_scores_df = pd.DataFrame(
        np.round(z_scores_per_par, 3),
        columns=parameter_names,
        index=index,
    )
    z_scores_df['Mean'] = np.round(z_scores_df.mean(axis=1), 2)
    return z_scores_df


def evaluate_contraction(samples_per_method, model, parameter_names, *, method_names=None,
                         num_prior_draws=50000):
    """Tabulate the median posterior contraction per method and parameter.

    Contraction is ``1 - posterior_variance / prior_variance``. The posterior
    variance is taken over axis 1 of each method's samples. The prior variance
    is a Monte-Carlo estimate from ``num_prior_draws`` draws of ``model.prior``.
    The median over axis 0 is reported.

    Args:
        samples_per_method: Sequence with one posterior-sample array per method.
            Each is presumably indexed ``samples[dataset, draw, parameter]``.
            TODO: confirm with callers.
        model: Object whose ``prior(n)`` returns a dict with a ``'prior_draws'``
            array of shape ``(n, num_params)``.
        parameter_names: Column labels, one per parameter.
        method_names: Row labels, one per entry of ``samples_per_method``.
            Defaults to ``MODES`` from ``settings``.
        num_prior_draws: Number of prior draws used to estimate the prior variance.

    Returns:
        A DataFrame with one row per method and one column per parameter
        (rounded to 3 decimals). It also has a ``'Mean'`` column: the row mean
        of those rounded values, rounded to 2 decimals.

    Raises:
        ValueError: If the number of row labels differs from the number of methods.
    """
    num_methods = len(samples_per_method)
    # Evaluated lazily, so ``MODES`` is only needed when no labels are given.
    index = MODES if method_names is None else list(method_names)
    if len(index) != num_methods:
        raise ValueError(
            f'Got {len(index)} method names for {num_methods} sets of posterior samples.'
        )

    contractions_per_par = np.zeros((num_methods, len(parameter_names)))
    prior_var = np.var(model.prior(num_prior_draws)['prior_draws'], axis=0)
    for i, samples in enumerate(samples_per_method):
        post_var = samples.var(1)
        # prior_var has shape (num_params,) and broadcasts over the data-set axis.
        cont = 1 - post_var / prior_var
        contractions_per_par[i] = np.median(cont, axis=0)

    contraction_df = pd.DataFrame(
        np.round(contractions_per_par, 3),
        columns=parameter_names,
        index=index,
    )
    contraction_df['Mean'] = np.round(contraction_df.mean(axis=1), 2)
    return contraction_df


def evaluate_r2_score(samples_per_method, ground_truths, parameter_names, *, method_names=None):
    """Tabulate the R^2 of posterior-median estimates per method and parameter.

    For each method, the posterior median over axis 1 of its samples is the
    point estimate. ``sklearn.metrics.r2_score`` then compares it with the
    ground truth column by column.

    Args:
        samples_per_method: Sequence with one posterior-sample array per method.
            Each is presumably indexed ``samples[dataset, draw, parameter]``.
            TODO: confirm with callers.
        ground_truths: 2-D array of true values, indexed ``[dataset, parameter]``.
        parameter_names: Column labels, one per parameter. Only the first
            ``len(parameter_names)`` parameter columns are scored.
        method_names: Row labels, one per entry of ``samples_per_method``.
            Defaults to ``MODES`` from ``settings``.

    Returns:
        A DataFrame with one row per method and one column per parameter
        (rounded to 3 decimals). It also has a ``'Mean'`` column: the row mean
        of those rounded values, rounded to 2 decimals.

    Raises:
        ValueError: If the number of row labels differs from the number of methods.
    """
    num_methods = len(samples_per_method)
    # Evaluated lazily, so ``MODES`` is only needed when no labels are given.
    index = MODES if method_names is None else list(method_names)
    if len(index) != num_methods:
        raise ValueError(
            f'Got {len(index)} method names for {num_methods} sets of posterior samples.'
        )

    num_params = len(parameter_names)
    r2_score_per_par = np.zeros((num_methods, num_params))
    for i, samples in enumerate(samples_per_method):
        # Point estimate per data set: the median over posterior draws (axis 1).
        post_medians = np.median(samples, axis=1)
        for p in range(num_params):
            r2_score_per_par[i, p] = r2_score(ground_truths[:, p], post_medians[:, p])

    r2_score_df = pd.DataFrame(
        np.round(r2_score_per_par, 3),
        columns=parameter_names,
        index=index,
    )
    r2_score_df['Mean'] = np.round(r2_score_df.mean(axis=1), 2)
    return r2_score_df
