from pathlib import Path

import jax
import matplotlib.pyplot as plt
import numpy as np
import wandb
from omegaconf import DictConfig, open_dict

from jadex.algorithms.vae.models import create_vae
from jadex.base.base_state import BaseState
from jadex.base.base_trainer import ModelTrainer
from jadex.data.datasets import create_dataset
from jadex.data.datasets.base_dataset import BaseDataset
from jadex.downstream.image_label.models import create_image_label_df_model
from jadex.downstream.image_label.models.image_label_df_model import BaseImageLabelDiscreteFlowModel
from jadex.global_configs import jadex_hydra_main
from jadex.global_configs.constants import JADEX_CHECKPOINT_DIR
from jadex.networks.variational.constants import LABEL, X
from jadex.utils import mplfig_to_npimage
from jadex.utils.plotting import use_backend
from jadex.utils.printing import print_green

# Directory (next to this script) where exported sample grids are written.
RESULTS_DIR = Path(__file__).parent.joinpath("results")
# Number of samples generated for each plotted label (columns of the grid).
NUM_PLOTS_PER_LABEL = 10


class ImageLabelDfTrainer(ModelTrainer):
    """Trainer for a label-conditioned discrete-flow model on top of a pretrained VAE.

    The dataset configuration is taken from the VAE checkpoint (with labels
    enabled). Expensive logging renders a grid of label-conditioned generations
    and logs it as a ``wandb.Image``.
    """

    # Labels rendered when the dataset has more than ``_MAX_ALL_LABELS`` classes.
    # NOTE(review): these look like hand-picked ImageNet class indices — confirm.
    _DEFAULT_PLOT_LABELS = (2, 96, 250, 440, 445, 527, 624, 643, 657, 724)
    # Datasets with at most this many classes get every class plotted.
    _MAX_ALL_LABELS = 10

    @staticmethod
    def _vae_checkpoint_dir(cfg):
        """Return the checkpoint directory of the VAE named by ``cfg.model.vae_checkpoint_name``."""
        return JADEX_CHECKPOINT_DIR / cfg.model.vae_checkpoint_name

    @classmethod
    def _select_plot_labels(cls, num_classes):
        """Return the class labels to render, one grid row per label.

        Small datasets (``num_classes <= _MAX_ALL_LABELS``) plot every class.
        Larger datasets use ``_DEFAULT_PLOT_LABELS`` when all of them are valid
        for ``num_classes``. Otherwise the same number of labels, evenly spaced
        over ``[0, num_classes)``, is used, so no out-of-range label is requested.
        """
        if num_classes <= cls._MAX_ALL_LABELS:
            return np.arange(num_classes)
        if max(cls._DEFAULT_PLOT_LABELS) < num_classes:
            return list(cls._DEFAULT_PLOT_LABELS)
        # Spacing is >= 10/9 for num_classes >= 11, so rounded values stay distinct.
        num_labels = len(cls._DEFAULT_PLOT_LABELS)
        return np.linspace(0, num_classes - 1, num=num_labels).round().astype(int)

    @classmethod
    def _load_datasets(cls, cfg, ctx):
        """Build the train dataset and, when validation is enabled, the test dataset.

        A test dataset is created only when ``cfg.test`` is set and
        ``cfg.job.validation_frequency_nsteps != 0``.

        Returns:
            dict with keys ``train_dataset`` and ``test_dataset`` (the latter may be None).
        """
        # cfg.dataset has been replaced by the trained VAE's dataset config at this point.
        train_dataset = create_dataset(cfg, "train", ctx)
        if cfg.get("test") is not None and cfg.job.validation_frequency_nsteps != 0:
            test_dataset = create_dataset(cfg, "test", ctx)
        else:
            # Requesting validation without a test config is a configuration error.
            assert cfg.job.validation_frequency_nsteps == 0, "test dataset needed for validation!"
            test_dataset = None

        return dict(train_dataset=train_dataset, test_dataset=test_dataset)

    @classmethod
    def create_trainer_kwargs_and_state(cls, cfg, ctx):
        """Load the pretrained VAE, build the datasets and initialise the flow model.

        Side effect: ``cfg.dataset`` is replaced by the VAE checkpoint's dataset
        config with ``include_labels = True``.

        Returns:
            ``(trainer_kwargs, state)``. ``trainer_kwargs`` holds the model,
            ``fid=None`` and the dataset kwargs; ``state`` is the freshly
            initialised flow-model state.
        """
        vae_ckpt_dir = cls._vae_checkpoint_dir(cfg)
        vae_cfg = BaseState.load_cfg(vae_ckpt_dir)

        # Override the VAE config's training section with this run's settings.
        with open_dict(vae_cfg):
            vae_cfg.train = cfg.train

        # Reuse the VAE's dataset config; labels are needed for conditioning.
        with open_dict(cfg):
            cfg.dataset = vae_cfg.dataset
            cfg.dataset.include_labels = True

        dataset_kwargs = cls._load_datasets(cfg, ctx)

        # Initialise a VAE state, then replace it with the checkpointed one.
        vae_model = create_vae(vae_cfg)
        vae_state: BaseState = vae_model.init(jax.random.PRNGKey(cfg.train.seed))
        vae_state = vae_state.load_checkpoint(
            vae_ckpt_dir,
            checkpoint_idx=cfg.model.vae_checkpoint_idx,
        )

        model = create_image_label_df_model(cfg, vae_cfg, vae_model, vae_state)
        state = model.init(jax.random.PRNGKey(cfg.train.seed))
        trainer_kwargs = dict(model=model, fid=None, **dataset_kwargs)
        return trainer_kwargs, state

    def plot_xmat(self, xmat, state: BaseState):
        """Render a grid of generated images and return it as a numpy image.

        Args:
            xmat: samples of shape ``(num_rows, num_cols, *sample_shape)``. Row
                ``i`` holds the samples for the i-th plotted label.
            state: current training state (provides ``step`` and ``scaler_vars``).

        Returns:
            The rendered figure converted with ``mplfig_to_npimage``.

        When ``cfg.job.export_data`` is set, the unscaled samples (``.npz``) and
        the figure (``.pdf``/``.png``) are also written under ``RESULTS_DIR``.
        """
        num_rows, num_cols = xmat.shape[:2]
        dset_name = self.cfg.vae_cfg.dataset.name

        # Undo the scaling applied during training so images are shown in data space.
        if self.cfg.dataset.scaler_mode == "online":
            xmat = self.model.apply_inverse_scaler(xmat, state.scaler_vars, X)
        elif self.cfg.dataset.scaler_mode == "data":
            xmat = self.train_dataset.apply_inverse_scaler(xmat)

        export_data = self.cfg.job.get("export_data", False)
        if export_data:
            fname_prefix = f"xmat_{state.step:09d}"
            data_dir = RESULTS_DIR / f"{self.cfg.vae_cfg.model.id}_{self.cfg.vae_cfg.dataset.id}"
            data_dir.mkdir(exist_ok=True, parents=True)
            np.savez(data_dir / f"{fname_prefix}.npz", data=np.array(xmat))

        # nrows comes first; squeeze=False keeps ``axs`` 2-D even for one row/column.
        fig, axs = plt.subplots(num_rows, num_cols, figsize=(5, 5), dpi=100, squeeze=False)
        try:
            for i, row_axes in enumerate(axs):
                for j, ax in enumerate(row_axes):
                    if dset_name in ("MNISTDataset", "MNISTContinuousDataset"):
                        ax.imshow(xmat[i, j], cmap="gray")
                    else:
                        image = np.clip(xmat[i, j], 0, 255).astype(np.uint8)
                        ax.imshow(image)
                    ax.axis("off")

            fig.subplots_adjust(wspace=0.02, hspace=0.02, left=0.01, right=0.99, top=0.99, bottom=0.01)
            img = mplfig_to_npimage(fig)

            if export_data:
                # Save this figure explicitly rather than pyplot's "current" figure.
                fig.savefig(data_dir / f"{fname_prefix}.pdf")
                fig.savefig(data_dir / f"{fname_prefix}.png")
        finally:
            # Always release the figure, even if plotting fails.
            plt.close(fig)
        return img

    def log_expensive(self, state, batch, metrics, val=False):
        """Return expensive-to-compute metrics for logging.

        On training (``val=False``) calls, this generates ``NUM_PLOTS_PER_LABEL``
        samples for each selected label and logs the grid under
        ``"label_generations"``. Validation calls return an empty dict.
        ``batch`` and ``metrics`` are unused here.
        """
        expensive_metrics = {}
        if val:
            return expensive_metrics

        generate_from_labels_fn = jax.jit(self.model.generate_from_labels)
        k_vals = self._select_plot_labels(self.cfg.vae_cfg.dataset.num_classes)

        xmat = np.zeros((len(k_vals), NUM_PLOTS_PER_LABEL, *self.model.x_dist.shape))
        for i, k in enumerate(k_vals):
            # The PRNG key is derived from the label, so each row uses the same key on every call.
            xmat[i] = generate_from_labels_fn(
                state, {LABEL: np.full((NUM_PLOTS_PER_LABEL,), k)}, jax.random.PRNGKey(k)
            )

        # Non-interactive backend: the figure is only rasterised, never shown.
        with use_backend("agg"):
            np_img = self.plot_xmat(xmat, state)

        expensive_metrics["label_generations"] = wandb.Image(np_img)
        return expensive_metrics

    @staticmethod
    def get_project_name(cfg: DictConfig):
        """Return ``"<model abbrev>_<dataset abbrev>"`` for the run's project name.

        The dataset class is looked up from the VAE checkpoint's config.
        """
        model_cls = BaseImageLabelDiscreteFlowModel.registered[cfg.model.name]
        vae_cfg = BaseState.load_cfg(ImageLabelDfTrainer._vae_checkpoint_dir(cfg))
        dataset_cls: type[BaseDataset] = BaseDataset.registered[vae_cfg.dataset.name]
        # NOTE(review): get_abbrev receives the trainer cfg, not vae_cfg, even though
        # the dataset class comes from vae_cfg — confirm which config it should read.
        project_name = f"{model_cls.get_abbrev(cfg)}_{dataset_cls.get_abbrev(cfg)}"
        return project_name


@jadex_hydra_main(config_name="image_label_df_config", config_path="./configs")
def main(cfg: DictConfig):
    """Hydra entry point: hand the composed config to the trainer's ``submit``."""
    trainer_cls = ImageLabelDfTrainer
    trainer_cls.submit(cfg)


if __name__ == "__main__":
    # Launch only when executed as a script, never on import.
    main()
