from copy import deepcopy

import jax
import numpy as np
from omegaconf import DictConfig, OmegaConf, open_dict

from jadex.algorithms.vae.models import create_vae
from jadex.base.base_state import BaseState
from jadex.base.base_trainer import ModelTrainer
from jadex.data.datasets import create_dataset
from jadex.data.datasets.base_dataset import BaseDataset
from jadex.downstream.traj_gpt.models import create_traj_gpt_model
from jadex.downstream.traj_gpt.models.traj_gpt_model import BaseTrajGptModel
from jadex.global_configs import jadex_hydra_main
from jadex.global_configs.constants import JADEX_CHECKPOINT_DIR
from jadex.networks.variational.constants import GOAL, TEXT, X
from jadex.utils.plotting import plot_prediction


class TrajGptTrainer(ModelTrainer):
    """Trainer for a TrajGPT model built on top of a VAE restored from a checkpoint.

    The VAE checkpoint named by ``cfg.model.vae_checkpoint_name`` provides the dataset
    config and the VAE weights. The TrajGPT model is created from both the TrajGPT
    config and the VAE.
    """

    def extract_features_from_batch(self, p_batch):
        """Extract the X, TEXT and GOAL features from a batch.

        If ``self.keep_feature_idxs`` contains an entry for X, only those indices of
        X's last axis are kept.

        Args:
            p_batch: A batch in whatever format ``self.train_dataset`` produces.

        Returns:
            dict mapping X, TEXT and GOAL to the corresponding features.
        """
        features = {
            key: self.train_dataset.get_feature_from_batch(p_batch, key)
            for key in (X, TEXT, GOAL)
        }

        keep_idxs = self.keep_feature_idxs
        if keep_idxs is not None and X in keep_idxs:
            # Drop discarded columns along the last (feature) axis of X.
            features[X] = features[X][..., keep_idxs[X]]

        return features

    @classmethod
    def _load_datasets(cls, vae_cfg, cfg, ctx):
        """Create the training dataset from a copy of the VAE config.

        The dataloader returns trajectory history + current trajectory
        (2 * input_len steps); the split into the two halves is done manually
        elsewhere, not here.

        Args:
            vae_cfg: Config loaded from the VAE checkpoint (not modified).
            cfg: TrajGPT config; its ``spaces.text`` / ``spaces.goal`` override the VAE's.
            ctx: Context forwarded to ``create_dataset``.

        Returns:
            The "train" dataset.
        """
        dataloader_cfg = deepcopy(vae_cfg)
        with open_dict(dataloader_cfg):
            # Double the window: history + current trajectory.
            dataloader_cfg.trajectory.input_len *= 2
            dataloader_cfg.spaces.text = cfg.spaces.text
            dataloader_cfg.spaces.goal = cfg.spaces.goal
            # Kept inside open_dict in case dataset construction adds keys to the cfg.
            train_dataset = create_dataset(dataloader_cfg, "train", ctx)

        return train_dataset

    @staticmethod
    def _compute_keep_feature_idxs(vae_cfg, x_num_features):
        """Return the sorted indices of X's last axis that should be kept.

        A fixed set of columns is discarded for UnitreeH1 ``RelSites`` inputs that use
        positions and rotations but not velocities. Every other setup keeps all columns.
        NOTE(review): the meaning of the discarded indices is not documented here.
        """
        # Attribute access stays lazy (inside the `and` chain) so configs whose input
        # space has no `params` node do not fail in struct mode.
        h1rel_feature_cond = (
            vae_cfg.environment.name == "UnitreeH1"
            and vae_cfg.spaces.input.name == "RelSites"
            and vae_cfg.spaces.input.params.use_positions
            and vae_cfg.spaces.input.params.use_rotations
            and not vae_cfg.spaces.input.params.use_velocities
        )

        # Add conditions for other robots (e.g. the G1) here.
        if h1rel_feature_cond:
            discard_idxs = np.array([55, 56, 57, 58, 59, 60, 65, 66, 89, 90, 125, 126])
        else:
            discard_idxs = np.array([], dtype=int)

        return np.setdiff1d(np.arange(x_num_features), discard_idxs)

    @classmethod
    def create_trainer_kwargs_and_state(cls, cfg, ctx):
        """Build the dataset, restore the VAE, and initialise the TrajGPT model.

        Side effects on ``cfg``: sets ``cfg.vae_cfg`` (the resolved VAE config),
        ``cfg.dataset``, and ``cfg.dists.goal_dist.shape`` / ``param_shape``
        (taken from the dataset's goal feature shape).

        Args:
            cfg: TrajGPT config.
            ctx: Context forwarded to dataset creation.

        Returns:
            ``(trainer_kwargs, state)``. ``trainer_kwargs`` holds ``model``,
            ``train_dataset``, ``test_dataset=None``, ``fid=None`` and
            ``keep_feature_idxs``. ``state`` is the freshly initialised TrajGPT state.
        """
        vae_ckpt_dir = JADEX_CHECKPOINT_DIR / cfg.model.vae_checkpoint_name
        vae_cfg = BaseState.load_cfg(vae_ckpt_dir)

        with open_dict(vae_cfg):
            vae_cfg.train = cfg.train

        train_dataset = cls._load_datasets(vae_cfg, cfg, ctx)

        # Attach the resolved VAE cfg to the main cfg and size the goal distribution
        # from the data.
        OmegaConf.resolve(vae_cfg)
        goal_shape = train_dataset.get_feature_shape("goal")
        with open_dict(cfg):
            cfg.vae_cfg = vae_cfg
            cfg.dataset = vae_cfg.dataset
            cfg.dists.goal_dist.shape = goal_shape
            cfg.dists.goal_dist.param_shape = goal_shape

        # Initialise the VAE, then overwrite its state with the checkpoint.
        vae_model = create_vae(vae_cfg)
        vae_state: BaseState = vae_model.init(jax.random.PRNGKey(cfg.train.seed))
        vae_state = vae_state.load_checkpoint(
            vae_ckpt_dir,
            checkpoint_idx=cfg.model.vae_checkpoint_idx,
        )

        model = create_traj_gpt_model(cfg, vae_cfg, vae_model, vae_state)
        state = model.init(jax.random.PRNGKey(cfg.train.seed))

        x_num_features = train_dataset.get_feature_shape(X)[-1]
        keep_feature_idxs = {X: cls._compute_keep_feature_idxs(vae_cfg, x_num_features)}

        trainer_kwargs = dict(
            model=model,
            train_dataset=train_dataset,
            test_dataset=None,
            fid=None,
            keep_feature_idxs=keep_feature_idxs,
        )

        return trainer_kwargs, state

    def log_expensive(self, state, batch, metrics, val=False):
        """Plot predicted vs. ground-truth trajectories for a training batch.

        Nothing is logged when ``val`` is True. ``metrics`` is currently unused.

        Args:
            state: Current training state (its ``rng_key`` is used for prediction).
            batch: Batch passed to ``model.predict_traj``.
            metrics: Unused.
            val: Whether this is a validation call.

        Returns:
            dict of plotting metrics returned by ``plot_prediction`` (empty for val).

        Raises:
            ValueError: if ``cfg.dataset.scaler_mode`` is not ``"online"``. Only the
                online scaler can be inverted here.
        """
        expensive_metrics = {}
        if val:
            return expensive_metrics

        # Validate before running the (jitted, expensive) prediction. Previously an
        # unknown mode fell through and crashed with UnboundLocalError.
        scaler_mode = self.cfg.dataset.scaler_mode
        if scaler_mode != "online":
            raise ValueError(
                f"log_expensive only supports dataset.scaler_mode='online', got {scaler_mode!r}"
            )

        model: BaseTrajGptModel = self.model
        predict_traj_fn = jax.jit(model.predict_traj)
        x_cur_hat_norm, x_cur_recon_norm, x_cur = predict_traj_fn(state, batch, state.rng_key)

        # Map normalised prediction/reconstruction back through the VAE's scaler;
        # x_cur is passed to the plot as returned by predict_traj.
        scaler_vars = model.vae_state.scaler_vars
        x_cur_hat = model.vae_model.apply_inverse_scaler(x_cur_hat_norm, scaler_vars, X)
        x_cur_recon = model.vae_model.apply_inverse_scaler(x_cur_recon_norm, scaler_vars, X)

        wandb_metrics = plot_prediction(
            x_cur_hat, x_cur, self.cfg, "train", plt_kwargs=dict(x_recon=x_cur_recon)
        )
        expensive_metrics.update(wandb_metrics)

        return expensive_metrics

    @staticmethod
    def get_project_name(cfg: DictConfig):
        """Return ``"<model abbrev>_<dataset abbrev>"`` for this run.

        The dataset class is looked up by the dataset name stored in the VAE
        checkpoint's config.
        NOTE(review): ``dataset_cls.get_abbrev`` receives the TrajGPT ``cfg``, not
        ``vae_cfg``; confirm that is what it expects.
        """
        model_cls: type[BaseTrajGptModel] = BaseTrajGptModel.registered[cfg.model.name]
        vae_cfg = BaseState.load_cfg(JADEX_CHECKPOINT_DIR / cfg.model.vae_checkpoint_name)
        dataset_cls: type[BaseDataset] = BaseDataset.registered[vae_cfg.dataset.name]
        project_name = f"{model_cls.get_abbrev(cfg)}_{dataset_cls.get_abbrev(cfg)}"
        return project_name


@jadex_hydra_main(config_name="traj_gpt_config", config_path="configs")
def main(cfg: DictConfig):
    """Entry point: submit a TrajGPT training run with the composed config.

    ``cfg`` is supplied by the ``jadex_hydra_main`` decorator, built from the
    ``traj_gpt_config`` config under ``configs``.
    """
    trainer_cls = TrajGptTrainer
    trainer_cls.submit(cfg)


if __name__ == "__main__":
    # No arguments: the decorator is responsible for providing ``cfg``.
    main()
