from time import (
    time,
    process_time
)
from typing import Dict, List, Optional

import numpy as np
import tensorflow as tf
from tensorflow_probability import distributions as tfd


class TimerCallback(tf.keras.callbacks.Callback):
    """Log elapsed wall-clock and CPU time since the start of training.

    At the end of every epoch ``logs["time"]`` receives the seconds elapsed
    since ``on_train_begin`` according to :func:`time.time`, and
    ``logs["process_time"]`` the CPU time this process consumed over the same
    span according to :func:`time.process_time`.
    """

    def __init__(
        self,
    ):
        # Keras callbacks are expected to initialise their base class.
        super().__init__()
        # Clock readings taken in ``on_train_begin``; None until training starts.
        self.time_reference = None
        self.process_time_reference = None

    def on_train_begin(
        self,
        logs=None
    ):
        """Seed the timing entries with zero and start both clocks.

        Args:
            logs: Optional dict of training logs. When given, its ``"time"``
                and ``"process_time"`` entries are set to ``[0.]``.
        """
        # ``logs`` defaults to None; only write into it when a dict is provided.
        # NOTE(review): depending on the Keras version, this dict may be a
        # temporary that does not reach ``History`` — confirm the intent.
        if logs is not None:
            logs["time"] = [0.]
            logs["process_time"] = [0.]

        self.time_reference = time()
        self.process_time_reference = process_time()

    def on_epoch_end(
        self,
        epoch,
        logs=None
    ):
        """Record the elapsed wall-clock and CPU seconds in ``logs``.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Optional dict of epoch logs. Nothing is recorded when None.
        """
        if logs is None:
            return
        logs["time"] = time() - self.time_reference
        logs["process_time"] = process_time() - self.process_time_reference


class ELBOCallback(tf.keras.callbacks.Callback):
    """Periodically estimate the ELBO of the model being trained and log it.

    The estimate is the Monte Carlo mean of ``log p(x) - log q(x)`` over
    ``sample_size`` draws ``x`` from the trained model ``q`` (``self.model``).
    Here ``p`` is ``p_model``. The estimate is recomputed at the end of the
    first epoch and after every ``elbo_epochs``-th epoch (counting from 1).
    Between recomputations the last value is reused, so ``logs["ELBO"]`` is
    filled every epoch.

    Args:
        elbo_epochs: Recompute period, in epochs. Must be >= 1.
        p_model: Distribution whose ``log_prob`` scores the samples.
        observed_values: Observed values forwarded to ``self.model.sample``.
        sample_size: Number of Monte Carlo samples per estimate.
        sample_kwargs: Extra keyword arguments for ``self.model.sample``.
            Defaults to an empty dict.
        log_prob_kwargs: Extra keyword arguments for ``self.model.log_prob``.
            Defaults to an empty dict.

    Raises:
        ValueError: If ``elbo_epochs`` is smaller than 1.
    """

    def __init__(
        self,
        elbo_epochs: int,
        p_model: tfd.Distribution,
        observed_values: Dict[str, tf.Tensor],
        sample_size: int = 1_000,
        sample_kwargs: Optional[Dict] = None,
        log_prob_kwargs: Optional[Dict] = None
    ):
        super().__init__()
        # Fail early instead of with a ZeroDivisionError in a later epoch.
        if elbo_epochs < 1:
            raise ValueError(
                f"elbo_epochs must be >= 1, got {elbo_epochs}"
            )
        # Most recent ELBO estimate; None until the first epoch ends.
        self.ELBO = None
        self.elbo_epochs = elbo_epochs
        self.p_model = p_model
        self.observed_values = observed_values
        self.sample_size = sample_size
        # Per-instance dicts: a shared mutable default would leak changes
        # between callback instances.
        self.sample_kwargs = {} if sample_kwargs is None else sample_kwargs
        self.log_prob_kwargs = (
            {} if log_prob_kwargs is None else log_prob_kwargs
        )

    @tf.function
    def recompute_ELBO(
        self
    ) -> tf.Tensor:
        """Return the Monte Carlo mean of ``p.log_prob - q.log_prob``.

        NOTE(review): assumes ``self.model`` exposes ``sample`` (accepting
        ``observed_values`` / ``return_observed_values``) and ``log_prob`` —
        confirm against the model class.
        """
        q_sample = self.model.sample(
            sample_shape=(self.sample_size,),
            observed_values=self.observed_values,
            return_observed_values=True,
            **self.sample_kwargs
        )
        q = self.model.log_prob(
            q_sample,
            **self.log_prob_kwargs
        )
        p = self.p_model.log_prob(
            q_sample
        )
        return tf.reduce_mean(p - q)

    def on_epoch_end(
        self,
        epoch,
        logs=None
    ):
        """Refresh the ELBO on schedule and store it in ``logs["ELBO"]``.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            logs: Optional dict of epoch logs. Nothing is written when None.
        """
        if epoch == 0 or epoch % self.elbo_epochs == self.elbo_epochs - 1:
            self.ELBO = self.recompute_ELBO().numpy()

        if logs is not None:
            logs["ELBO"] = self.ELBO


class WeightsSaverCallback(tf.keras.callbacks.Callback):
    """Periodically dump the model's weights via ``model.save_weights_np``.

    Every ``save_epochs`` epochs (counting from 1), the weights are written
    to ``weights_dir + "epoch_<N>_weights_" + weights_suffix``, where ``N``
    is the 1-based number of the epoch that just ended. ``weights_dir`` is
    concatenated verbatim, so it should end with a path separator.
    """

    def __init__(
        self,
        save_epochs: int,
        weights_suffix: str,
        weights_dir: str
    ) -> None:
        super().__init__()
        self.weights_dir = weights_dir
        self.weights_suffix = weights_suffix
        self.save_epochs = save_epochs

    def on_epoch_end(
        self,
        epoch,
        logs=None
    ) -> None:
        """Save the weights if the epoch just finished is a save point."""
        period = self.save_epochs
        if epoch % period != period - 1:
            # Not a checkpoint epoch; nothing to do.
            return

        target = "".join(
            (
                self.weights_dir,
                f"epoch_{epoch + 1}_weights_",
                self.weights_suffix,
            )
        )
        self.model.save_weights_np(path=target)


class BatchELBOCallback(tf.keras.callbacks.Callback):
    """Estimate the ELBO every few training batches and keep a history.

    The estimate is the Monte Carlo mean of ``log p(x) - log q(x)`` over
    ``sample_size`` draws ``x`` from the trained model ``q`` (``self.model``).
    Here ``p`` is ``p_model``. It is recomputed when the batch index is 0 and
    after every ``elbo_batches``-th batch. Otherwise the last value is
    reused. Each batch's value is appended to ``self.ELBOs`` and written to
    ``logs["ELBO"]``.

    Args:
        elbo_batches: Recompute period, in batches. Must be >= 1.
        p_model: Distribution whose ``log_prob`` scores the samples.
        observed_values: Observed values that are encoded and used to
            condition sampling from ``self.model``.
        sample_size: Number of Monte Carlo samples per estimate.

    Raises:
        ValueError: If ``elbo_batches`` is smaller than 1.
    """

    def __init__(
        self,
        elbo_batches: int,
        p_model: tfd.Distribution,
        observed_values: Dict[str, tf.Tensor],
        sample_size: int = 1_000
    ):
        super().__init__()
        # Fail early instead of with a ZeroDivisionError in a later batch.
        if elbo_batches < 1:
            raise ValueError(
                f"elbo_batches must be >= 1, got {elbo_batches}"
            )
        # Most recent ELBO estimate; None until the first batch ends.
        self.ELBO = None
        # One entry per completed training batch.
        self.ELBOs = []
        self.elbo_batches = elbo_batches
        self.p_model = p_model
        self.observed_values = observed_values
        self.sample_size = sample_size

    @tf.function
    def recompute_ELBO(
        self
    ) -> tf.Tensor:
        """Return the Monte Carlo mean of ``p.log_prob - q.log_prob``.

        The observed values are encoded once, and the same encodings are
        passed to both ``self.model.sample`` and ``self.model.log_prob``.
        NOTE(review): the meaning of ``hbm_type="full"``,
        ``repeat_observed_rvs_and_encodings`` and ``encodings_repeats`` is
        defined by the model class — confirm there.
        """
        encodings = self.model.encode_data(
            self.observed_values
        )
        q_sample, latent_values = self.model.sample(
            sample_shape=(self.sample_size,),
            observed_values=self.observed_values,
            hbm_type="full",
            encodings=encodings,
            return_observed_values=True,
            return_latent_values=True,
            repeat_observed_rvs_and_encodings=True
        )
        q = self.model.log_prob(
            values=q_sample,
            latent_values=latent_values,
            hbm_type="full",
            encodings=encodings,
            encodings_repeats=self.sample_size
        )
        p = self.p_model.log_prob(
            q_sample
        )
        return tf.reduce_mean(p - q)

    def on_train_batch_end(
        self,
        batch,
        logs=None
    ):
        """Refresh the ELBO on schedule, record it, and store it in ``logs``.

        Args:
            batch: Batch index supplied by Keras (presumably reset every
                epoch — confirm for the Keras version in use).
            logs: Optional dict of batch logs. Nothing is written when None.
        """
        if batch == 0 or batch % self.elbo_batches == self.elbo_batches - 1:
            self.ELBO = self.recompute_ELBO().numpy()

        self.ELBOs.append(self.ELBO)
        if logs is not None:
            logs["ELBO"] = self.ELBO


class EMAResetCallback(tf.keras.callbacks.Callback):
    """Fill every tensor in ``model.encodings_memory`` with NaN after each epoch.

    Each memory entry keeps its shape and dtype; only its values are
    overwritten. NOTE(review): presumably NaN marks the memory as "empty"
    for the model's moving-average update — confirm against the model code.
    """

    def __init__(self) -> None:
        super().__init__()

    def on_epoch_end(
        self,
        epoch,
        logs=None
    ) -> None:
        """Overwrite each encodings-memory variable with NaNs."""
        memories = self.model.encodings_memory
        for level in memories:
            current = memories[level]
            current.assign(tf.ones_like(current) * np.nan)
