from functools import partial

import tensorflow as tf
import tensorflow_probability as tfp

import bayesflow as bf

tfd = tfp.distributions
tfpl = tfp.layers


class HeterokedasticAmortizer(tf.keras.Model):
    """Custom heteroskedastic (diagonal-Gaussian) amortizer interacting with BayesFlow.

    A summary network compresses the raw conditions. A small MLP then maps the
    summary to the parameters of an independent normal distribution over
    ``num_params`` target parameters. Training minimizes the mean negative
    log-likelihood of the true parameters under that predicted normal.

    Parameters
    ----------
    num_params : int
        Number of target parameters (event size of the output normal).
    summary_net : tf.keras.Model
        Network applied to the raw conditions before the inference MLP.
    hidden_units : tuple of int, optional
        Width of each hidden ReLU layer of the inference MLP. The default,
        two layers of 128 units, is the original architecture, so existing
        checkpoints remain loadable.
    dropout_rate : float, optional
        Dropout rate applied after every hidden layer (default 0.05).
    """

    def __init__(self, num_params, summary_net, *, hidden_units=(128, 128), dropout_rate=0.05):
        super().__init__()

        layers = []
        for units in hidden_units:
            layers.append(tf.keras.layers.Dense(units, activation='relu'))
            layers.append(tf.keras.layers.Dropout(dropout_rate))
        # Project to exactly as many outputs as the distribution layer consumes,
        # then turn them into an independent normal over `num_params` dims.
        layers.append(tf.keras.layers.Dense(tfpl.IndependentNormal.params_size(num_params)))
        layers.append(tfpl.IndependentNormal(num_params))

        self.inference_net = tf.keras.Sequential(layers)
        self.summary_net = summary_net
        # Exposed under this name presumably because BayesFlow reads
        # `latent_dim` from amortizers — TODO confirm against the caller.
        self.latent_dim = num_params

    def call(self, x, **kwargs):
        """Return the predicted normal distribution for conditions ``x``.

        Keyword arguments (e.g. ``training``) are forwarded to both the summary
        network and the inference network, so dropout follows the same mode.
        """
        out = self.summary_net(x, **kwargs)
        out = self.inference_net(out, **kwargs)
        return out

    def predict(self, x, **kwargs):
        """Return the predicted means and standard deviations.

        Runs in inference mode (``training=False``), so dropout is disabled.
        The last axis holds the ``num_params`` means followed by the
        ``num_params`` standard deviations. Extra ``kwargs`` are accepted but
        ignored.

        NOTE: this intentionally shadows ``tf.keras.Model.predict`` with a
        different signature.
        """
        out = self(x, training=False)
        return tf.concat((out.mean(), out.stddev()), axis=-1)

    def sample(self, x, n_samples):
        """Draw ``n_samples`` samples per condition in ``x``, in inference mode.

        The distribution returns samples as ``(n_samples, batch, num_params)``.
        They are transposed to ``(batch, n_samples, num_params)`` so the sample
        axis comes second. The fixed 3-axis permutation assumes a single batch
        dimension.
        """
        out = self(x, training=False)
        return tf.transpose(out.sample(n_samples), (1, 0, 2))

    def compute_loss(self, input_dict, **kwargs):
        """Return the mean negative log-likelihood of the true parameters.

        Parameters
        ----------
        input_dict : dict
            Must contain ``'summary_conditions'`` (network input) and
            ``'parameters'`` (targets scored under the predicted normal).
        **kwargs
            Forwarded to ``call`` (e.g. ``training``).

        Returns
        -------
        dict
            ``{'H.L': loss}`` with the scalar mean NLL over the batch.
        """
        pred = self(input_dict['summary_conditions'], **kwargs)
        loss = tf.reduce_mean(-pred.log_prob(input_dict['parameters']))
        return {'H.L': loss}


def get_pretrained(benchmark, *, checkpoint_path='pretrained_expert', num_params=4, summary_dim=32):
    """Build the heteroskedastic amortizer and restore its pretrained weights.

    Parameters
    ----------
    benchmark : object
        Object exposing ``generative_model`` and a ``configurator`` that accepts
        ``as_summary_condition`` (presumably a BayesFlow benchmark — confirm).
    checkpoint_path : str, optional
        Directory containing the pretrained checkpoint (default
        ``'pretrained_expert'``).
    num_params : int, optional
        Number of target parameters; must match the checkpoint (default 4).
    summary_dim : int, optional
        Output size passed to ``bf.networks.SequenceNetwork``; must match the
        checkpoint (default 32).

    Returns
    -------
    tuple
        ``(amortizer, configurator)`` taken from the constructed trainer.

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` contains no checkpoint. Without this check,
        BayesFlow's Trainer would fall back to freshly initialized networks,
        and this function would silently return an untrained model.
    """
    if tf.train.latest_checkpoint(checkpoint_path) is None:
        raise FileNotFoundError(
            f"No pretrained checkpoint found in {checkpoint_path!r}; "
            "refusing to return an untrained amortizer."
        )

    summary_net = bf.networks.SequenceNetwork(summary_dim)
    hetero = HeterokedasticAmortizer(num_params=num_params, summary_net=summary_net)

    # The Trainer restores the networks from `checkpoint_path` on construction.
    trainer = bf.trainers.Trainer(
        amortizer=hetero,
        generative_model=benchmark.generative_model,
        configurator=partial(benchmark.configurator, as_summary_condition=True),
        checkpoint_path=checkpoint_path,
    )
    return trainer.amortizer, trainer.configurator
