import warnings
from typing import Sequence
from shutil import copyfile
import inspect
from collections import OrderedDict
import multiprocessing
import os
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything, loggers as pl_loggers
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, DeviceStatsMonitor, Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from omegaconf import OmegaConf
import os
import argparse
from einops import rearrange

from prediff.datasets.nbody.nbody_mnist_torch_wrap import NBodyMovingMNISTLightningDataModule
from prediff.datasets.nbody.nbody_mnist import default_datasets_dir
from prediff.datasets.nbody.visualization import vis_nbody_seq
from prediff.utils.checkpoint import pl_load, pl_ckpt_to_pytorch_state_dict
from prediff.utils.optim import SequentialLR, warmup_lambda
from prediff.taming.vae import AutoencoderKL
from prediff.distributions import DiagonalGaussianDistribution
from prediff.taming.loss import LPIPSWithDiscriminator


# File names of the plain PyTorch state dicts exported by `main()` after training.
pytorch_state_dict_name = "vae_nbody.pt"
pytorch_loss_state_dict_name = "vae_loss_nbody.pt"
# Experiment outputs are stored in the `experiments` directory two levels above this file.
exps_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "experiments"))

class VAENbodyPLModule(pl.LightningModule):

    def __init__(self,
                 total_num_steps: int,
                 oc_file: str = None,
                 save_dir: str = None):
        r"""
        Build the VAE, its perceptual/adversarial loss, the evaluation metrics and the output directories.

        Parameters
        ----------
        total_num_steps:    int
            Total number of optimization steps. Used by `configure_optimizers` for the LR schedule
            and by `set_trainer_kwargs` to derive the logging interval.
        oc_file:    str
            Optional path to a YAML config that is merged on top of the defaults from `get_base_config`.
            When given, `configure_save` also copies it into the experiment directory.
        save_dir:   str
            Name of the experiment directory. `configure_save` joins it onto `exps_dir` with
            `os.path.join`, so a string is required in practice despite the `None` default.
        """
        super().__init__()
        # Load the override file exactly once. Passing the path (rather than an open file object)
        # lets OmegaConf open and close the file itself, so no file handle is leaked.
        oc_from_file = OmegaConf.load(oc_file) if oc_file is not None else None
        oc = self.get_base_config(oc_from_file=oc_from_file)
        model_cfg = OmegaConf.to_object(oc.model)

        self.torch_nn_module = AutoencoderKL(
            down_block_types=model_cfg["down_block_types"],
            in_channels=model_cfg["in_channels"],
            sample_size=model_cfg["sample_size"],  # not used
            block_out_channels=model_cfg["block_out_channels"],
            act_fn=model_cfg["act_fn"],
            latent_channels=model_cfg["latent_channels"],
            up_block_types=model_cfg["up_block_types"],
            norm_num_groups=model_cfg["norm_num_groups"],
            layers_per_block=model_cfg["layers_per_block"],
            out_channels=model_cfg["out_channels"], )
        loss_cfg = model_cfg["loss"]
        self.loss = LPIPSWithDiscriminator(
            disc_start=loss_cfg["disc_start"],
            kl_weight=loss_cfg["kl_weight"],
            disc_weight=loss_cfg["disc_weight"],
            perceptual_weight=loss_cfg["perceptual_weight"],
            disc_in_channels=loss_cfg["disc_in_channels"],)

        self.save_hyperparameters(oc)
        self.oc = oc
        # layout: position of each axis inside the layout string
        self.layout = oc.layout.layout
        self.channel_axis = self.layout.find("C")
        self.batch_axis = self.layout.find("N")
        self.t_axis = self.layout.find("T")
        self.h_axis = self.layout.find("H")
        self.w_axis = self.layout.find("W")
        self.channels = model_cfg["data_channels"]
        # optimization
        self.max_epochs = oc.optim.max_epochs
        self.optim_method = oc.optim.method
        self.lr = oc.optim.lr
        self.wd = oc.optim.wd
        # lr_scheduler
        self.total_num_steps = total_num_steps
        # logging
        self.save_dir = save_dir
        self.logging_prefix = oc.logging.logging_prefix
        # visualization
        self.train_example_data_idx_list = list(oc.vis.train_example_data_idx_list)
        self.val_example_data_idx_list = list(oc.vis.val_example_data_idx_list)
        self.test_example_data_idx_list = list(oc.vis.test_example_data_idx_list)
        self.eval_example_only = oc.vis.eval_example_only

        # Separate metric objects per stage so validation and test statistics never mix.
        self.valid_mse = torchmetrics.MeanSquaredError()
        self.valid_mae = torchmetrics.MeanAbsoluteError()
        self.valid_ssim = torchmetrics.StructuralSimilarityIndexMeasure()
        self.test_mse = torchmetrics.MeanSquaredError()
        self.test_mae = torchmetrics.MeanAbsoluteError()
        self.test_ssim = torchmetrics.StructuralSimilarityIndexMeasure()

        # Resolves `self.save_dir` to an absolute path and creates the output directories.
        self.configure_save(cfg_file_path=oc_file)
        # # Load pretrained torch.pt
        # # Notice that previously saved loss .pt gets wrong keys, so we need to fix it
        # state_dict = torch.load(os.path.join(self.save_dir, "checkpoints", pytorch_state_dict_name),
        #                         map_location=torch.device("cpu"))
        # self.torch_nn_module.load_state_dict(state_dict)
        # state_dict = torch.load(os.path.join(self.save_dir, "checkpoints", pytorch_loss_state_dict_name),
        #                         map_location=torch.device("cpu"))
        # loss_state_dict = OrderedDict()
        # for key, val in state_dict.items():
        #     loss_state_dict[key.replace("perceptual_", "perceptual_loss.")] = val
        # self.loss.load_state_dict(loss_state_dict)

    def configure_save(self, cfg_file_path=None):
        r"""
        Create the experiment directory tree and keep a copy of the config file inside it.

        Sets `self.save_dir` (joined onto `exps_dir`), `self.scores_dir` and `self.example_save_dir`.

        Parameters
        ----------
        cfg_file_path:  str
            Optional path of the config file to copy to `<save_dir>/cfg.yaml`.
            The copy is skipped when source and target already are the same file.
        """
        self.save_dir = os.path.join(exps_dir, self.save_dir)
        os.makedirs(self.save_dir, exist_ok=True)

        self.scores_dir = os.path.join(self.save_dir, 'scores')
        os.makedirs(self.scores_dir, exist_ok=True)

        if cfg_file_path is not None:
            target_cfg_path = os.path.join(self.save_dir, "cfg.yaml")
            # `samefile` is only evaluated when the target exists (it would raise otherwise).
            already_in_place = os.path.exists(target_cfg_path) \
                and os.path.samefile(cfg_file_path, target_cfg_path)
            if not already_in_place:
                copyfile(cfg_file_path, target_cfg_path)

        self.example_save_dir = os.path.join(self.save_dir, "examples")
        os.makedirs(self.example_save_dir, exist_ok=True)

    def get_base_config(self, oc_from_file=None):
        r"""
        Assemble the full default config and optionally merge overrides on top of it.

        Parameters
        ----------
        oc_from_file
            Optional OmegaConf config; when given, it is merged over the defaults with `OmegaConf.merge`.

        Returns
        -------
        The config with sections `layout`, `optim`, `logging`, `trainer`, `vis`, `model` and `dataset`.
        """
        section_factories = (
            ("layout", self.get_layout_config),
            ("optim", self.get_optim_config),
            ("logging", self.get_logging_config),
            ("trainer", self.get_trainer_config),
            ("vis", self.get_vis_config),
            ("model", self.get_model_config),
            ("dataset", self.get_dataset_config),
        )
        base_oc = OmegaConf.create()
        for section_name, make_section in section_factories:
            setattr(base_oc, section_name, make_section())
        if oc_from_file is None:
            return base_oc
        return OmegaConf.merge(base_oc, oc_from_file)

    @staticmethod
    def get_layout_config():
        r"""
        Default `layout` section: image height, image width and the axis-order string.
        """
        return OmegaConf.create({
            "img_height": 128,
            "img_width": 128,
            "layout": "NHWC",
        })

    @classmethod
    def get_model_config(cls):
        r"""
        Default `model` section: the `AutoencoderKL` architecture and the loss hyper-parameters.

        `in_channels`, `out_channels` and `loss.disc_in_channels` all default to `data_channels`.
        """
        cfg = OmegaConf.create()
        cfg.data_channels = 4
        # from stable-diffusion-v1-5
        cfg.down_block_types = ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D']
        cfg.in_channels = cfg.data_channels
        cfg.sample_size = 512  # not used
        cfg.block_out_channels = [128, 256, 512, 512]
        cfg.act_fn = 'silu'
        cfg.latent_channels = 4
        cfg.up_block_types = ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D']
        cfg.norm_num_groups = 32
        cfg.layers_per_block = 2
        cfg.out_channels = cfg.data_channels

        # Hyper-parameters forwarded to `LPIPSWithDiscriminator` in `__init__`.
        cfg.loss = OmegaConf.create()
        cfg.loss.disc_start = 50001
        cfg.loss.kl_weight = 1e-6
        cfg.loss.disc_weight = 0.5
        cfg.loss.perceptual_weight = 1.0
        cfg.loss.disc_in_channels = cfg.data_channels
        return cfg

    @classmethod
    def get_dataset_config(cls):
        r"""
        Default `dataset` section describing the N-body MNIST dataset to load.

        `get_nbody_datamodule` compares these values with the config stored next to the dataset
        and forwards them to `NBodyMovingMNISTLightningDataModule`.
        """
        return OmegaConf.create({
            "dataset_name": "nbody20k_digits3_len20_size64",
            "num_train_samples": 20000,
            "num_val_samples": 1000,
            "num_test_samples": 1000,
            "digit_num": None,
            "img_size": 64,
            "raw_img_size": 128,
            "seq_len": 20,
            "raw_seq_len_multiplier": 5,
            "distractor_num": None,
            "distractor_size": 5,
            "max_velocity_scale": 2.0,
            "initial_velocity_range": [0.0, 2.0],
            "random_acceleration_range": [0.0, 0.0],
            "scale_variation_range": [1.0, 1.0],
            "rotation_angle_range": [-0, 0],
            "illumination_factor_range": [1.0, 1.0],
            "period": 5,
            "global_rotation_prob": 0.5,
            "index_range": [0, 40000],
            "mnist_data_path": None,
            "aug_mode": "0",
            "ret_contiguous": False,
            # N-body params
            "nbody_acc_mode": "r0",
            "nbody_G": 0.035,
            "nbody_softening_distance": 0.01,
            "nbody_mass": None,
        })

    @staticmethod
    def get_optim_config():
        r"""
        Default `optim` section: batch sizes, optimizer settings, LR schedule and checkpoint/early-stop options.
        """
        return OmegaConf.create({
            "seed": None,
            "total_batch_size": 32,
            "micro_batch_size": 8,
            "float32_matmul_precision": "high",

            "method": "adam",
            "lr": 1E-3,
            "wd": 1E-5,
            "betas": (0.5, 0.9),
            "gradient_clip_val": 1.0,
            "max_epochs": 50,
            # scheduler
            "warmup_percentage": 0.2,
            "lr_scheduler_mode": "cosine",  # Can be strings like 'linear', 'cosine', 'plateau'
            "min_lr_ratio": 1.0E-3,
            "warmup_min_lr_ratio": 0.0,
            # checkpointing / early stopping
            "monitor": "val/total_loss",
            "early_stop": False,
            "early_stop_mode": "min",
            "early_stop_patience": 5,
            "save_top_k": 1,
        })

    @staticmethod
    def get_logging_config():
        r"""
        Default `logging` section: logging prefix and the switches for the optional monitors/loggers.
        """
        return OmegaConf.create({
            "logging_prefix": "Nbody",
            "monitor_lr": True,
            "monitor_device": False,
            "track_grad_norm": -1,
            "use_wandb": False,
        })

    @staticmethod
    def get_trainer_config():
        r"""
        Default `trainer` section; `set_trainer_kwargs` turns it into `pl.Trainer` arguments.
        """
        return OmegaConf.create({
            "check_val_every_n_epoch": 1,
            # Log every `log_step_ratio * total_num_steps` steps, i.e. every 0.1% of all training steps.
            "log_step_ratio": 0.001,
            "precision": 32,
            "find_unused_parameters": True,
            "num_sanity_val_steps": 2,
        })

    @staticmethod
    def get_vis_config():
        r"""
        Default `vis` section: which data indices are visualized per stage and how many images are drawn.
        """
        return OmegaConf.create({
            "train_example_data_idx_list": [0, ],
            "val_example_data_idx_list": [0, ],
            "test_example_data_idx_list": [0, ],
            "eval_example_only": False,
            "num_vis": 10,
        })

    def configure_optimizers(self):
        r"""
        Create the Adam optimizers (and LR schedulers) for the autoencoder and the discriminator.

        Returns
        -------
        A list of two dicts, autoencoder first and discriminator second, so that `optimizer_idx == 0`
        in `training_step` refers to the autoencoder and `optimizer_idx == 1` to the discriminator.
        Each dict has an "optimizer" entry and, unless `lr_scheduler_mode == 'none'`,
        an "lr_scheduler" entry that is stepped once per optimization step.

        Raises
        ------
        NotImplementedError
            If `optim.lr_scheduler_mode` is neither 'none' nor 'cosine'.
        """
        optim_cfg = self.oc.optim
        lr = optim_cfg.lr
        # Give Adam a plain tuple instead of an OmegaConf ListConfig, so the optimizer param groups
        # (which end up in saved optimizer states) contain only built-in Python types.
        betas = tuple(optim_cfg.betas)

        vae = self.torch_nn_module
        ae_params = [
            *vae.encoder.parameters(),
            *vae.decoder.parameters(),
            *vae.quant_conv.parameters(),
            *vae.post_quant_conv.parameters(),
        ]
        opt_ae = torch.optim.Adam(ae_params, lr=lr, betas=betas)
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=betas)

        scheduler_mode = optim_cfg.lr_scheduler_mode
        if scheduler_mode == 'none':
            return [{"optimizer": opt_ae}, {"optimizer": opt_disc}]
        if scheduler_mode != 'cosine':
            raise NotImplementedError(
                f"lr_scheduler_mode={scheduler_mode!r} is not supported; use 'cosine' or 'none'.")

        warmup_iter = int(np.round(optim_cfg.warmup_percentage * self.total_num_steps))

        def build_scheduler_config(optimizer):
            # Warmup for `warmup_iter` steps, then cosine annealing for the remaining steps.
            warmup_scheduler = LambdaLR(
                optimizer,
                lr_lambda=warmup_lambda(warmup_steps=warmup_iter,
                                        min_lr_ratio=optim_cfg.warmup_min_lr_ratio))
            cosine_scheduler = CosineAnnealingLR(
                optimizer,
                T_max=(self.total_num_steps - warmup_iter),
                eta_min=optim_cfg.min_lr_ratio * optim_cfg.lr)
            lr_scheduler = SequentialLR(
                optimizer,
                schedulers=[warmup_scheduler, cosine_scheduler],
                milestones=[warmup_iter])
            return {
                'scheduler': lr_scheduler,
                'interval': 'step',
                'frequency': 1, }

        # Both optimizers follow the same schedule, but each needs its own scheduler objects.
        return [
            {"optimizer": opt_ae, "lr_scheduler": build_scheduler_config(opt_ae)},
            {"optimizer": opt_disc, "lr_scheduler": build_scheduler_config(opt_disc)},
        ]

    def set_trainer_kwargs(self, **kwargs):
        r"""
        Default kwargs used when initializing pl.Trainer

        Parameters
        ----------
        kwargs
            Extra `pl.Trainer` arguments. `callbacks` (a list of `Callback`) and `logger` (a list of loggers)
            are extended with the defaults built here; the lists passed in are not modified.
            All other entries override the defaults.

        Returns
        -------
        ret:    dict
            Keyword arguments for `pl.Trainer`. Precedence from lowest to highest:
            defaults built here, matching keys of the `trainer` config section, `kwargs`.
        """
        checkpoint_callback = ModelCheckpoint(
            monitor=self.oc.optim.monitor,
            dirpath=os.path.join(self.save_dir, "checkpoints"),
            filename="{epoch:03d}",
            auto_insert_metric_name=False,
            save_top_k=self.oc.optim.save_top_k,
            save_last=True,
            mode="min",
        )
        callbacks = kwargs.pop("callbacks", [])
        assert isinstance(callbacks, list)
        for ele in callbacks:
            assert isinstance(ele, Callback)
        # Work on a copy so the caller's list is never mutated.
        callbacks = list(callbacks)
        callbacks.append(checkpoint_callback)
        if self.oc.logging.monitor_lr:
            callbacks.append(LearningRateMonitor(logging_interval='step'))
        if self.oc.logging.monitor_device:
            callbacks.append(DeviceStatsMonitor())
        if self.oc.optim.early_stop:
            # Watch the same metric as the checkpoint callback. The previously hard-coded
            # "valid_loss_epoch" is never logged by this module, so early stopping could not find it.
            callbacks.append(EarlyStopping(monitor=self.oc.optim.monitor,
                                           min_delta=0.0,
                                           patience=self.oc.optim.early_stop_patience,
                                           verbose=False,
                                           mode=self.oc.optim.early_stop_mode))

        # Copy for the same reason as `callbacks`.
        logger = list(kwargs.pop("logger", []))
        tb_logger = pl_loggers.TensorBoardLogger(save_dir=self.save_dir)
        csv_logger = pl_loggers.CSVLogger(save_dir=self.save_dir)
        logger += [tb_logger, csv_logger]
        if self.oc.logging.use_wandb:
            wandb_logger = pl_loggers.WandbLogger(project=self.oc.logging.logging_prefix,
                                                  save_dir=self.save_dir)
            logger.append(wandb_logger)

        log_every_n_steps = max(1, int(self.oc.trainer.log_step_ratio * self.total_num_steps))
        trainer_init_keys = inspect.signature(Trainer).parameters.keys()
        ret = dict(
            num_sanity_val_steps=self.oc.trainer.num_sanity_val_steps,
            callbacks=callbacks,
            # log
            logger=logger,
            log_every_n_steps=log_every_n_steps,
            track_grad_norm=self.oc.logging.track_grad_norm,
            # save
            default_root_dir=self.save_dir,
            # ddp
            accelerator="gpu",
            # strategy="ddp",
            strategy=DDPStrategy(find_unused_parameters=self.oc.trainer.find_unused_parameters),
            # optimization
            max_epochs=self.oc.optim.max_epochs,
            check_val_every_n_epoch=self.oc.trainer.check_val_every_n_epoch,
            gradient_clip_val=self.oc.optim.gradient_clip_val,
            # NVIDIA amp
            precision=self.oc.trainer.precision,
        )
        # Only forward the entries of the `trainer` section that `pl.Trainer.__init__` actually accepts.
        oc_trainer_kwargs = OmegaConf.to_object(self.oc.trainer)
        oc_trainer_kwargs = {key: val for key, val in oc_trainer_kwargs.items() if key in trainer_init_keys}
        ret.update(oc_trainer_kwargs)
        ret.update(kwargs)
        return ret

    @classmethod
    def get_total_num_steps(
            cls,
            num_samples: int,
            total_batch_size: int,
            epoch: int = None):
        r"""
        Compute the total number of optimization steps of a training run.

        Parameters
        ----------
        num_samples:    int
            The number of samples of the datasets.
        total_batch_size:   int
            `total_batch_size == micro_batch_size * world_size * grad_accum`,
            so `num_samples / total_batch_size` is the number of optimization steps per epoch.
        epoch:  int
            The number of training epochs. Falls back to `max_epochs` of the default optim config when `None`.

        Returns
        -------
        `epoch * num_samples / total_batch_size`, truncated to an int.
        """
        num_epochs = cls.get_optim_config().max_epochs if epoch is None else epoch
        return int(num_epochs * num_samples / total_batch_size)

    @staticmethod
    def get_nbody_datamodule(dataset_oc,
                             load_dir: str = None,
                             micro_batch_size: int = 1,
                             num_workers: int = 8):
        r"""
        Build the N-body MNIST datamodule for an already generated dataset.

        Parameters
        ----------
        dataset_oc
            Mapping with the dataset settings (see `get_dataset_config`).
        load_dir:   str
            Directory that contains the dataset folder. Defaults to `<default_datasets_dir>/nbody`.
        micro_batch_size:   int
            Batch size handed to the datamodule.
        num_workers:    int
            Number of workers handed to the datamodule.

        Returns
        -------
        dm: NBodyMovingMNISTLightningDataModule

        Raises
        ------
        ValueError
            If the dataset directory does not exist.
        AssertionError
            If a value in the `nbody_dataset_cfg.yaml` stored with the dataset differs from `dataset_oc`.
        """
        if load_dir is None:
            load_dir = os.path.join(default_datasets_dir, "nbody")
        data_dir = os.path.join(load_dir, dataset_oc["dataset_name"])
        if not os.path.exists(data_dir):
            raise ValueError(f"dataset in {data_dir} not exists!")
        load_dataset_cfg_path = os.path.join(data_dir, "nbody_dataset_cfg.yaml")
        # Passing the path lets OmegaConf open and close the file itself, so no handle is leaked.
        load_dataset_cfg = OmegaConf.to_object(OmegaConf.load(load_dataset_cfg_path).dataset)
        for key, val in load_dataset_cfg.items():
            if key in ["aug_mode", "ret_contiguous"]:
                continue  # exclude keys that can be different
            assert val == dataset_oc[key], \
                f"dataset config {key} mismatches!" \
                f"{dataset_oc[key]} specified, but {val} loaded."
        # Settings forwarded unchanged from `dataset_oc` to the datamodule.
        forwarded_keys = (
            "num_train_samples", "num_val_samples", "num_test_samples",
            "digit_num", "img_size", "raw_img_size",
            "seq_len", "raw_seq_len_multiplier",
            "distractor_num", "distractor_size",
            "max_velocity_scale", "initial_velocity_range", "random_acceleration_range",
            "scale_variation_range", "rotation_angle_range", "illumination_factor_range",
            "period", "global_rotation_prob", "index_range", "mnist_data_path",
            "aug_mode", "ret_contiguous",
            # N-Body params
            "nbody_acc_mode", "nbody_G", "nbody_softening_distance", "nbody_mass",
        )
        dm = NBodyMovingMNISTLightningDataModule(
            data_dir=data_dir,
            force_regenerate=False,
            **{key: dataset_oc[key] for key in forwarded_keys},
            # datamodule_only
            batch_size=micro_batch_size,
            num_workers=num_workers, )
        return dm

    def encode(self, x):
        r"""
        Run the encoder followed by `quant_conv` on `x` and wrap the result
        in a `DiagonalGaussianDistribution`.
        """
        vae = self.torch_nn_module
        return DiagonalGaussianDistribution(vae.quant_conv(vae.encoder(x)))

    def decode(self, z):
        r"""
        Map the latent `z` back to data space: `post_quant_conv` followed by the decoder.
        """
        vae = self.torch_nn_module
        return vae.decoder(vae.post_quant_conv(z))

    def get_last_layer(self):
        r"""
        Return the weight of the decoder's output convolution (`decoder.conv_out`).
        The step methods pass it to the loss as `last_layer`.
        """
        output_conv = self.torch_nn_module.decoder.conv_out
        return output_conv.weight

    def get_input(self, batch):
        r"""
        Pick one random time step from a batch of sequences and convert it to channel-first layout.

        Parameters
        ----------
        batch
            5-D batch indexed as (batch, time, height, width, channel).

        Returns
        -------
        target_bchw
            The selected frames rearranged to (batch, channel, height, width).
        mask
            Always `None`.
        """
        num_frames = self.oc.dataset.seq_len
        frame_idx = np.random.randint(low=0, high=num_frames)
        frame_bhwc = batch[:, frame_idx, :, :, :]
        return rearrange(frame_bhwc, "b h w c -> b c h w"), None

    def forward(self, target_bchw, sample_posterior=True):
        r"""
        Encode `target_bchw`, draw a latent from the posterior and decode it.

        Parameters
        ----------
        target_bchw
            Input passed to `encode`.
        sample_posterior:   bool
            If True, the latent is `posterior.sample()`; otherwise `posterior.mode()`.

        Returns
        -------
        pred_bchw, posterior
            The reconstruction and the posterior returned by `encode`.
        """
        posterior = self.encode(target_bchw)
        latent = posterior.sample() if sample_posterior else posterior.mode()
        return self.decode(latent), posterior

    def training_step(self, batch, batch_idx, optimizer_idx):
        r"""
        Run one optimization step for either the autoencoder or the discriminator.

        Parameters
        ----------
        batch
            Batch of sequences; one random frame per sequence is selected by `get_input`.
        batch_idx:  int
            Index of the batch within the epoch.
        optimizer_idx:  int
            0 trains encoder+decoder+logvar (logged as "aeloss"),
            1 trains the discriminator (logged as "discloss").

        Returns
        -------
        The loss of the selected optimizer, or `None` for any other `optimizer_idx`.
        """
        target_bchw, _ = self.get_input(batch=batch)
        pred_bchw, posterior = self(target_bchw)
        micro_batch_size = batch.shape[self.batch_axis]
        data_idx = int(batch_idx * micro_batch_size)
        # Check the cheap conditions first: `save_vis_step_end` only draws batches listed in
        # `train_example_data_idx_list`, so the device-to-host copies are skipped for all others.
        if self.current_epoch % self.oc.trainer.check_val_every_n_epoch == 0 \
                and self.local_rank == 0 \
                and data_idx in self.train_example_data_idx_list:
            self.save_vis_step_end(
                data_idx=data_idx,
                target=target_bchw.detach().float().cpu().numpy(),
                pred=pred_bchw.detach().float().cpu().numpy(),
                mode="train", )

        if optimizer_idx not in (0, 1):
            return None
        # The loss module selects the generator or the discriminator objective from `optimizer_idx`.
        loss_name = "aeloss" if optimizer_idx == 0 else "discloss"
        loss_value, log_dict = self.loss(target_bchw, pred_bchw, posterior, optimizer_idx, self.global_step,
                                         mask=None, last_layer=self.get_last_layer(), split="train")
        self.log(loss_name, loss_value, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=False)
        self.log_dict(log_dict, prog_bar=False, logger=True, on_step=True, on_epoch=False, sync_dist=False)
        return loss_value

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        r"""
        Reconstruct one random frame per sequence, log the validation losses and update the metrics.

        Parameters
        ----------
        batch
            5-D batch indexed as (batch, time, height, width, channel), see `get_input`.
        batch_idx:  int
            Index of the batch within the validation epoch.
        dataloader_idx: int
            Unused.

        Returns
        -------
        H, W
            Spatial size of the frames; `validation_epoch_end` uses it to rescale MSE and MAE.
        """
        micro_batch_size = batch.shape[self.batch_axis]
        # `get_input` indexes the batch as (N, T, H, W, C), so height and width are the third- and
        # second-to-last axes. Reading them via the layout string is wrong whenever the layout omits
        # "T" (the default "NHWC" would yield T and H instead of H and W).
        H, W = batch.shape[-3], batch.shape[-2]
        data_idx = int(batch_idx * micro_batch_size)
        if not self.eval_example_only or data_idx in self.val_example_data_idx_list:
            target_bchw, _ = self.get_input(batch=batch)
            pred_bchw, posterior = self(target_bchw)
            # Only copy tensors to the host for batches that `save_vis_step_end` will actually draw.
            if self.local_rank == 0 and data_idx in self.val_example_data_idx_list:
                self.save_vis_step_end(
                    data_idx=data_idx,
                    target=target_bchw.detach().float().cpu().numpy(),
                    pred=pred_bchw.detach().float().cpu().numpy(),
                    mode="val", )
            # Evaluate both the generator (0) and the discriminator (1) objectives for logging.
            _, log_dict_ae = self.loss(target_bchw, pred_bchw, posterior, 0, self.global_step,
                                       mask=None, last_layer=self.get_last_layer(), split="val")
            _, log_dict_disc = self.loss(target_bchw, pred_bchw, posterior, 1, self.global_step,
                                         mask=None, last_layer=self.get_last_layer(), split="val")
            self.log("val/rec_loss", log_dict_ae["val/rec_loss"], prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
            self.valid_mse(pred_bchw, target_bchw)
            self.valid_mae(pred_bchw, target_bchw)
            self.valid_ssim(pred_bchw, target_bchw)
        return H, W

    def validation_epoch_end(self, outputs):
        r"""
        Log the epoch-level validation metrics and reset the metric states.

        MSE and MAE are multiplied by `H * W` taken from the first `validation_step` output;
        SSIM is logged as computed.
        """
        H, W = outputs[0]
        epoch_scores = {
            'valid_mse_epoch': self.valid_mse.compute() * H * W,
            'valid_mae_epoch': self.valid_mae.compute() * H * W,
            'valid_ssim_epoch': self.valid_ssim.compute(),
        }
        for score_name, score in epoch_scores.items():
            self.log(score_name, score, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        for metric in (self.valid_mse, self.valid_mae, self.valid_ssim):
            metric.reset()

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        r"""
        Reconstruct one random frame per sequence, log the test losses and update the metrics.

        Parameters
        ----------
        batch
            5-D batch indexed as (batch, time, height, width, channel), see `get_input`.
        batch_idx:  int
            Index of the batch within the test epoch.
        dataloader_idx: int
            Unused.

        Returns
        -------
        H, W
            Spatial size of the frames; `test_epoch_end` uses it to rescale MSE and MAE.
        """
        micro_batch_size = batch.shape[self.batch_axis]
        # `get_input` indexes the batch as (N, T, H, W, C), so height and width are the third- and
        # second-to-last axes. Reading them via the layout string is wrong whenever the layout omits
        # "T" (the default "NHWC" would yield T and H instead of H and W).
        H, W = batch.shape[-3], batch.shape[-2]
        data_idx = int(batch_idx * micro_batch_size)
        if not self.eval_example_only or data_idx in self.test_example_data_idx_list:
            target_bchw, _ = self.get_input(batch=batch)
            pred_bchw, posterior = self(target_bchw)
            # Only copy tensors to the host for batches that `save_vis_step_end` will actually draw.
            if self.local_rank == 0 and data_idx in self.test_example_data_idx_list:
                self.save_vis_step_end(
                    data_idx=data_idx,
                    target=target_bchw.detach().float().cpu().numpy(),
                    pred=pred_bchw.detach().float().cpu().numpy(),
                    mode="test", )
            # Evaluate both the generator (0) and the discriminator (1) objectives for logging.
            _, log_dict_ae = self.loss(target_bchw, pred_bchw, posterior, 0, self.global_step,
                                       mask=None, last_layer=self.get_last_layer(), split="test")
            _, log_dict_disc = self.loss(target_bchw, pred_bchw, posterior, 1, self.global_step,
                                         mask=None, last_layer=self.get_last_layer(), split="test")
            self.log("test/rec_loss", log_dict_ae["test/rec_loss"], prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
            self.test_mse(pred_bchw, target_bchw)
            self.test_mae(pred_bchw, target_bchw)
            self.test_ssim(pred_bchw, target_bchw)
        return H, W

    def test_epoch_end(self, outputs):
        r"""
        Log the epoch-level test metrics and reset the metric states.

        MSE and MAE are multiplied by `H * W` taken from the first `test_step` output;
        SSIM is logged as computed.
        """
        H, W = outputs[0]
        epoch_scores = {
            'test_mse_epoch': self.test_mse.compute() * H * W,
            'test_mae_epoch': self.test_mae.compute() * H * W,
            'test_ssim_epoch': self.test_ssim.compute(),
        }
        for score_name, score in epoch_scores.items():
            self.log(score_name, score, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        for metric in (self.test_mse, self.test_mae, self.test_ssim):
            metric.reset()

    def save_vis_step_end(
            self,
            data_idx: int,
            target: np.ndarray,
            pred: np.ndarray,
            mode: str = "train",
            prefix: str = ""):
        r"""
        Save a target-vs-reconstruction figure for the selected example batches (local rank 0 only).

        Parameters
        ----------
        data_idx:   int
            Index of the first sample of the batch; a figure is only written if it is listed
            in the example index list of `mode`.
        target, pred:   np.ndarray
            Shape = (N, C, H, W). Axis 1 is squeezed before plotting, so C has to be 1.
        mode:   str
            One of "train", "val", "test".
        prefix: str
            Prepended to the file name of the saved figure.

        Raises
        ------
        ValueError
            If `mode` is not one of the supported values (only checked on local rank 0).
        """
        if self.local_rank != 0:
            return
        example_idx_by_mode = {
            "train": self.train_example_data_idx_list,
            "val": self.val_example_data_idx_list,
            "test": self.test_example_data_idx_list,
        }
        if mode not in example_idx_by_mode:
            raise ValueError(f"Wrong mode {mode}! Must be in ['train', 'val', 'test'].")
        if data_idx not in example_idx_by_mode[mode]:
            return
        file_name = f"{prefix}{mode}_epoch_{self.current_epoch}_data_{data_idx}.png"
        # Draw at most `num_vis` samples of the batch.
        num_vis = min(target.shape[0], self.oc.vis.num_vis)
        vis_nbody_seq(
            save_path=os.path.join(self.example_save_dir, file_name),
            in_seq=target[:num_vis].squeeze(1),
            pred_seq=pred[:num_vis].squeeze(1),
            pred_label=self.oc.logging.logging_prefix,
            plot_stride=1, fs=10, norm="none",)

def get_parser():
    """Build the command-line parser of this script.

    Options: ``--save`` (experiment directory name), ``--gpus`` (number of devices),
    ``--cfg`` (path to a YAML config), ``--test`` (run testing instead of training)
    and ``--ckpt_name`` (checkpoint file name inside ``<save_dir>/checkpoints``).
    """
    option_specs = (
        ('--save', dict(default='tmp_nbody', type=str)),
        ('--gpus', dict(default=1, type=int)),
        ('--cfg', dict(default=None, type=str)),
        ('--test', dict(action='store_true')),
        ('--ckpt_name', dict(default=None, type=str,
                             help='The model checkpoint trained on N-body MNIST.')),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser

def main():
    """Train or test the N-body VAE according to the command-line arguments.

    Steps: read the settings (from ``--cfg`` or the built-in defaults), build the datamodule,
    the Lightning module and the trainer, then either

    * ``--test``: run ``trainer.test`` (optionally from ``--ckpt_name``), or
    * otherwise: run ``trainer.fit`` (optionally resuming from ``--ckpt_name``), export the best
      checkpoint as plain PyTorch state dicts for the VAE and the loss module, and test the best model.
    """
    parser = get_parser()
    args = parser.parse_args()
    if args.cfg is not None:
        # Passing the path lets OmegaConf open and close the file itself, so no handle is leaked.
        oc_from_file = OmegaConf.load(args.cfg)
        dataset_cfg = OmegaConf.to_object(oc_from_file.dataset)
        total_batch_size = oc_from_file.optim.total_batch_size
        micro_batch_size = oc_from_file.optim.micro_batch_size
        max_epochs = oc_from_file.optim.max_epochs
        seed = oc_from_file.optim.seed
        float32_matmul_precision = oc_from_file.optim.float32_matmul_precision
    else:
        dataset_cfg = OmegaConf.to_object(VAENbodyPLModule.get_dataset_config())
        micro_batch_size = 1
        total_batch_size = int(micro_batch_size * args.gpus)
        max_epochs = None
        seed = 0
        float32_matmul_precision = "high"
    torch.set_float32_matmul_precision(float32_matmul_precision)
    seed_everything(seed, workers=True)
    dm = VAENbodyPLModule.get_nbody_datamodule(
        dataset_oc=dataset_cfg,
        micro_batch_size=micro_batch_size,
        num_workers=8,)
    dm.prepare_data()
    dm.setup()
    # total_batch_size == micro_batch_size * number of devices * accumulate_grad_batches
    accumulate_grad_batches = total_batch_size // (micro_batch_size * args.gpus)
    total_num_steps = VAENbodyPLModule.get_total_num_steps(
        epoch=max_epochs,
        num_samples=dm.num_train_samples,
        total_batch_size=total_batch_size,
    )
    pl_module = VAENbodyPLModule(
        total_num_steps=total_num_steps,
        save_dir=args.save,
        oc_file=args.cfg)
    trainer_kwargs = pl_module.set_trainer_kwargs(
        devices=args.gpus,
        accumulate_grad_batches=accumulate_grad_batches,
    )
    trainer = Trainer(**trainer_kwargs)

    checkpoints_dir = os.path.join(pl_module.save_dir, "checkpoints")
    ckpt_path = None
    if args.ckpt_name is not None:
        ckpt_path = os.path.join(checkpoints_dir, args.ckpt_name)

    if args.test:
        trainer.test(model=pl_module,
                     datamodule=dm,
                     ckpt_path=ckpt_path)
        return

    if ckpt_path is not None and not os.path.exists(ckpt_path):
        warnings.warn(f"ckpt {ckpt_path} not exists! Start training from epoch 0.")
        ckpt_path = None
    trainer.fit(model=pl_module,
                datamodule=dm,
                ckpt_path=ckpt_path)
    # save state_dict of VAE and discriminator
    pl_ckpt = pl_load(path_or_url=trainer.checkpoint_callback.best_model_path,
                      map_location=torch.device("cpu"))
    state_dict = pl_ckpt["state_dict"]
    # Split the Lightning state dict by attribute prefix and strip that prefix, so the results
    # can be loaded directly into the bare VAE / loss modules.
    vae_key = "torch_nn_module."
    vae_state_dict = OrderedDict()
    loss_key = "loss."
    loss_state_dict = OrderedDict()
    unexpected_keys = []
    for key, val in state_dict.items():
        if key.startswith(vae_key):
            vae_state_dict[key[len(vae_key):]] = val
        elif key.startswith(loss_key):
            loss_state_dict[key[len(loss_key):]] = val
        else:
            unexpected_keys.append(key)
    if unexpected_keys:
        # Report entries that end up in neither exported file instead of dropping them silently.
        warnings.warn(f"{len(unexpected_keys)} checkpoint entries are not exported: {unexpected_keys}")
    torch.save(vae_state_dict, os.path.join(checkpoints_dir, pytorch_state_dict_name))
    torch.save(loss_state_dict, os.path.join(checkpoints_dir, pytorch_loss_state_dict_name))
    # test
    trainer.test(ckpt_path="best",
                 datamodule=dm)

# Run training/testing only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
