import jax
import jax.numpy as jnp
from omegaconf import DictConfig

from jadex.algorithms.vae.models import BaseVAEModel, create_vae
from jadex.base.base_state import BaseState
from jadex.base.base_trainer import ModelTrainer
from jadex.data.datasets import BaseDataset, create_dataset
from jadex.data.datasets.base_dataset import DSET
from jadex.global_configs import jadex_hydra_main
from jadex.networks.variational.constants import X
from jadex.networks.variational.variational_network import merge_nn_cfg
from jadex.utils.plotting import plot_prediction


class VAETrainer(ModelTrainer):
    """Trainer that wires a VAE model to its datasets, metric logging and validation."""

    @classmethod
    def create_trainer_kwargs_and_state(cls, cfg, ctx):
        """Build the trainer keyword arguments and the initial model state.

        Args:
            cfg: Run configuration. Reads ``cfg.test``, ``cfg.train.seed`` and
                ``cfg.job.validation_frequency_nsteps``.
            ctx: Passed through unchanged to ``create_dataset``.

        Returns:
            A ``(trainer_kwargs, state)`` tuple. ``trainer_kwargs`` holds the model,
            the train dataset, the test dataset (or ``None``) and the FID helper
            (or ``None``).

        Raises:
            AssertionError: If validation is enabled (non-zero frequency) but no
                ``test`` config is present.
        """
        train_dataset = create_dataset(cfg, "train", ctx)

        if cfg.get("test") is not None and cfg.job.validation_frequency_nsteps != 0:
            test_dataset = create_dataset(cfg, "test", ctx)
        else:
            assert cfg.job.validation_frequency_nsteps == 0, "test dataset needed for validation!"
            test_dataset = None

        model = create_vae(cfg)
        state: BaseState = model.init(jax.random.PRNGKey(cfg.train.seed))

        # FID is only set up when there is an image test dataset to validate on.
        fid = None
        if test_dataset is not None and test_dataset.dset_type == DSET.IMAGE:
            # Lazy import required (flax import interferes with multiprocessing)
            from jadex.utils.fidjax import create_fid

            fid = create_fid(cfg)

        trainer_kwargs = dict(
            model=model,
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            fid=fid,
        )
        return trainer_kwargs, state

    def _descale_x_hats(self, x_hats, state):
        """Apply the inverse scaler selected by ``cfg.dataset.scaler_mode``.

        ``"online"`` uses the model's scaler with ``state.scaler_vars``; ``"data"``
        uses the train dataset's scaler. Any other mode returns ``x_hats``
        unchanged, which is what ``get_aux_metrics`` has always done.
        """
        scaler_mode = self.cfg.dataset.scaler_mode
        if scaler_mode == "online":
            return self.model.apply_inverse_scaler(x_hats, state.scaler_vars, X)
        if scaler_mode == "data":
            return self.train_dataset.apply_inverse_scaler(x_hats)
        return x_hats

    def extract_features_from_batch(self, p_batch):
        """Return ``{X: feature}`` for a batch, restricted to the kept feature indices.

        When ``self.keep_feature_idxs`` has an entry for ``X``, the last axis of the
        feature is indexed with it.
        """
        features = {X: self.train_dataset.get_feature_from_batch(p_batch, X)}

        if self.keep_feature_idxs is not None and X in self.keep_feature_idxs:
            features[X] = features[X][..., self.keep_feature_idxs[X]]

        return features

    def log_expensive(self, state, batch, metrics, val=False):
        """Build the prediction plots for logging.

        Args:
            state: Current state; its ``scaler_vars`` are used in ``"online"`` mode.
            batch: Feature dict holding the inputs under ``X``.
            metrics: Metrics dict holding ``"train_x_hats"`` or ``"val_x_hats"``.
            val: Whether this is a validation call (selects the ``"val"`` prefix).

        Returns:
            The metrics produced by ``plot_prediction``, or an empty dict when
            ``val`` is set and ``metrics`` is empty.
        """
        expensive_metrics = {}
        if val and not metrics:
            return expensive_metrics

        prefix = "val" if val else "train"
        x_hats = self._descale_x_hats(metrics[f"{prefix}_x_hats"], state)

        # Only the "data" mode inverse-scales the inputs; otherwise they are used as given.
        xs = batch[X]
        if self.cfg.dataset.scaler_mode == "data":
            xs = self.train_dataset.apply_inverse_scaler(xs)

        wandb_metrics = plot_prediction(x_hats, xs, self.cfg, prefix)
        expensive_metrics.update(wandb_metrics)

        return expensive_metrics

    def get_aux_metrics(self, state: BaseState, batch, metrics, val=False):
        """Compute the auxiliary train L2 loss; returns an empty dict for validation."""
        aux_metrics = {}

        # val already computes l2 loss
        if not val:
            prefix = "train"
            x_hats = self._descale_x_hats(metrics[f"{prefix}_x_hats"], state)

            aux_metrics[f"{prefix}_l2_loss"] = self.train_dataset.compute_l2_loss(
                xs=batch[X], descaled_x_hats=x_hats
            )

        return aux_metrics

    def run_validation(self, state):
        """Run the model over the whole test dataloader and aggregate metrics.

        Args:
            state: Current state, used for predictions, inverse scaling and as the
                RNG source for choosing which batch to return.

        Returns:
            A ``(ret_batch, metrics)`` tuple. ``ret_batch`` is one randomly chosen
            validation batch. ``metrics`` holds the ``val_``-prefixed means of the
            per-batch metrics, ``val_x_hats`` for the chosen batch, ``val_l2_loss``
            and, when FID is enabled, ``val_fid_score``.

        Raises:
            AssertionError: If there is no test dataloader.
            ValueError: If the test dataloader yields no batches.
            RuntimeError: If the chosen batch index is never reached.
        """
        assert self.test_dataloader is not None, "test config must be set!"

        def _get_predictions(state, val_batch, rng_key):
            val_x_hats, val_metrics = self.model.get_predictions(state, val_batch, rng_key)

            if self.fid is not None:
                images = self._descale_x_hats(val_x_hats, state)
                images = jnp.clip(images, 0, 255).astype(jnp.uint8)
                val_metrics["fid_acts"] = self.fid.compute_acts(images)

            return val_x_hats, val_metrics

        # NOTE(review): a fresh pmap-wrapped closure is built on every call; consider
        # caching it if recompilation shows up in profiles.
        p_val_step = jax.pmap(_get_predictions, in_axes=(None, 0, None), out_axes=(0, 0), axis_name="batch")

        # return a random batch to visualize different batches when logging
        # NOTE: validation is not shuffled, and last batch is not dropped
        # NOTE(review): randint's maxval is exclusive, so the final batch is never
        # selected -- presumably intentional because it may be partial; confirm.
        # Converted to a Python int once so the loop compares plain integers.
        ret_idx = int(
            jax.random.randint(
                state.rng_key, shape=(), minval=0, maxval=self.test_dataloader.num_batches - 1
            )
        )

        # iterator will be exhausted after the first loop, since we aren't using the infinite sampler
        self.test_dataloader.reset()

        all_metrics = []
        all_fid_acts = []
        l2_losses = []
        ret_batch = None
        ret_x_hats = None
        for i, p_val_full_batch in enumerate(self.test_dataloader):
            p_val_batch = self.extract_features_from_batch(p_val_full_batch)
            p_val_x_hats, p_val_metrics = p_val_step(state, p_val_batch, jax.random.PRNGKey(i))
            # Merge the leading per-device axis back into a single batch axis.
            val_batch = jax.tree.map(jnp.concatenate, p_val_batch)
            val_x_hats = jax.tree.map(jnp.concatenate, p_val_x_hats)

            descaled_val_x_hats = self._descale_x_hats(val_x_hats, state)

            l2_losses.append(
                self.train_dataset.compute_l2_loss(xs=val_batch[X], descaled_x_hats=descaled_val_x_hats)
            )

            if i == ret_idx:
                ret_batch = val_batch
                ret_x_hats = val_x_hats

            val_metrics = jax.tree.map(lambda x: jnp.concatenate(jnp.atleast_2d(x)), p_val_metrics)
            # FID activations are collected per batch and kept out of the averaged metrics.
            if "fid_acts" in val_metrics:
                all_fid_acts.append(val_metrics.pop("fid_acts"))

            all_metrics.append(jax.tree.map(jnp.mean, val_metrics))

        if not all_metrics:
            raise ValueError("test dataloader yielded no batches; cannot run validation")
        if ret_batch is None:
            raise RuntimeError(
                f"validation batch index {ret_idx} was never reached; "
                "test dataloader yielded fewer batches than num_batches reports"
            )

        # Mean over batches of the per-batch means.
        combined_metrics = jax.tree.map(lambda *args: jnp.stack(args), *all_metrics)
        metrics = jax.tree.map(jnp.mean, combined_metrics)

        if self.fid is not None:
            stats = self.fid.compute_stats(all_fid_acts)
            fid_score = self.fid.compute_score(stats)
            metrics["fid_score"] = fid_score

        metrics = {f"val_{key}": value for key, value in metrics.items()}
        metrics["val_x_hats"] = ret_x_hats
        metrics["val_l2_loss"] = jnp.array(l2_losses).mean()

        return ret_batch, metrics

    @staticmethod
    def get_project_name(cfg: DictConfig):
        """Return ``"<model abbrev>_<dataset abbrev>"`` for the configured model and dataset."""
        model_cls = BaseVAEModel.registered[cfg.model.name]
        dataset_cls: BaseDataset = BaseDataset.registered[cfg.dataset.name]
        project_name = f"{model_cls.get_abbrev(cfg)}_{dataset_cls.get_abbrev(cfg)}"
        return project_name


@jadex_hydra_main(config_name="vae_config", config_path="./configs")
def main(cfg: DictConfig):
    """Merge the network config into ``cfg`` and submit a VAE training run."""
    VAETrainer.submit(merge_nn_cfg(cfg))


# Run the entry point only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
