import sys
sys.path.append('../')
from functools import partial

from bayesflow.networks import InvertibleNetwork, TimeSeriesTransformer, SequenceNetwork
from bayesflow.amortizers import AmortizedPosterior
from bayesflow.trainers import Trainer


from settings import EMBEDDING_SETTINGS
from configuration import configurator
from architectures.hybrid import HybridSummaryNetwork, AmortizedHybrid


def get_amortized_setup(benchmark, mode, summary_type, num_total_summaries=18, budget=5000, ai_expert=False):
    """Build a ``Trainer`` whose amortizer matches the requested experiment mode.

    Parameters
    ----------
    benchmark : object
        Its ``generative_model`` attribute is handed to the trainer.
    mode : str
        One of ``'direct_hybrid'``, ``'mmd_hybrid'``, ``'generative_hybrid'``,
        ``'learner'`` or ``'expert'``.
    summary_type : str
        ``'sequence'`` or ``'transformer'``; selects the learned summary
        network. Only validated when ``mode`` is not ``'expert'``.
    num_total_summaries : int, default 18
        Total number of summaries. The hybrid modes give half of them to the
        learned summary network, ``'learner'`` gives it all of them.
    budget : int, default 5000
        Used here only as a path component of the checkpoint directory.
    ai_expert : bool, default False
        Forwarded to the configurator; when truthy, ``'_aiexp'`` is appended
        to the checkpoint directory name.

    Returns
    -------
    Trainer
        Trainer wired with the amortizer, the benchmark's generative model,
        the partially applied configurator and the checkpoint path.

    Raises
    ------
    NotImplementedError
        If ``mode`` is unknown, or if ``summary_type`` is unknown while a
        learned summary network is required.
    """

    # Dimensionality of the learned summary vector
    if mode in ('direct_hybrid', 'mmd_hybrid', 'generative_hybrid'):
        # Hardcoded split, so num learned == num expert
        summary_dim = num_total_summaries // 2
    elif mode == 'learner':
        summary_dim = num_total_summaries
    elif mode == 'expert':
        summary_dim = None
    else:
        raise NotImplementedError(f'No mode {mode} known!')

    # Only the MMD hybrid attaches a summary loss to the amortizer
    summary_loss = 'MMD' if mode == 'mmd_hybrid' else None

    # Learned summary network ('expert' mode uses none)
    summary_net = None
    if mode != 'expert':
        if summary_type == 'sequence':
            summary_net = SequenceNetwork(summary_dim=summary_dim)
        elif summary_type == 'transformer':
            # 2 time-series, 1 positional (temporal) encoding, 2 + 1 = 3
            summary_net = TimeSeriesTransformer(input_dim=3, summary_dim=summary_dim)
        else:
            raise NotImplementedError("No such summary network known...")

    # The MMD hybrid wraps the learned network together with the expert summaries
    if mode == 'mmd_hybrid':
        num_expert = num_total_summaries // 2
        summary_net = HybridSummaryNetwork(
            num_expert_summaries=num_expert,
            summary_net=summary_net,
            **EMBEDDING_SETTINGS
        )

    # Inference network and amortizer
    inference_net = InvertibleNetwork(num_params=4, coupling_design='spline')
    if mode != 'generative_hybrid':
        amortizer = AmortizedPosterior(inference_net, summary_net, summary_loss_fun=summary_loss)
    else:
        # Extra invertible network sized to the learned summary vector
        learner_net = InvertibleNetwork(num_params=summary_dim)
        amortizer = AmortizedHybrid(inference_net, summary_net, learner_net)

    # Checkpoint directory; AI-expert runs get their own suffix
    ckpt_suffix = '_aiexp' if ai_expert else ''
    ckpt = f'checkpoints/{budget}/{mode}_{summary_type}{ckpt_suffix}'

    configure = partial(configurator, mode=mode, summary_type=summary_type, ai_expert=ai_expert)
    return Trainer(
        amortizer=amortizer,
        generative_model=benchmark.generative_model,
        skip_checks=True,
        configurator=configure,
        checkpoint_path=ckpt,
        max_to_keep=1,
    )
