import warnings
from shutil import copyfile
import inspect
import pickle
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
import torchmetrics
from einops import rearrange
import pytorch_lightning as pl
from pytorch_lightning import Trainer, loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, DeviceStatsMonitor, Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from omegaconf import OmegaConf, DictConfig, ListConfig
import os
import argparse
from einops import rearrange
from pytorch_lightning import Trainer, seed_everything
from earthformer.utils import layout_to_in_out_slice, get_parameter_names, SequentialLR, warmup_lambda, save_example_vis_results
from earthformer.cuboid_transformer import CuboidTransformerModel
from earthformer.nbody_mnist import NBodyMovingMNISTLightningDataModule
from earthformer.apex_ddp import ApexDDPPlugin


class CuboidNBodyPLModule(pl.LightningModule):

    def __init__(self,
                 total_num_steps: int,
                 oc_file: str = None,
                 save_dir: str = None):
        r"""
        Build the Cuboid Transformer and cache the experiment settings on the module.

        Parameters
        ----------
        total_num_steps:    int
            Number of optimizer steps of the whole run. It sizes the lr schedule in
            `configure_optimizers` and the logging interval in `set_trainer_kwargs`.
        oc_file:    str
            Path of a YAML config that is merged on top of `get_base_config()`.
            `None` keeps the built-in defaults.
        save_dir:   str
            Name of the experiment directory; `configure_save` places it under "experiments".
        """
        super(CuboidNBodyPLModule, self).__init__()
        self._max_train_iter = total_num_steps
        self.total_num_steps = total_num_steps

        # Load and merge the config once. The path (not an open handle) is handed to
        # OmegaConf, so no file object is left dangling.
        oc_from_file = OmegaConf.load(oc_file) if oc_file is not None else None
        oc = self.get_base_config(oc_from_file=oc_from_file)

        # `to_object` yields plain Python containers (dict / list), not OmegaConf nodes.
        model_cfg = OmegaConf.to_object(oc.model)
        num_blocks = len(model_cfg["enc_depth"])

        def expand_pattern(pattern):
            # A single pattern name is repeated for every block; a sequence is used as given.
            # `model_cfg` holds plain lists here, so `list()` is the right conversion
            # (`OmegaConf.to_container` only accepts OmegaConf containers).
            if isinstance(pattern, str):
                return [pattern] * num_blocks
            return list(pattern)

        enc_attn_patterns = expand_pattern(model_cfg["self_pattern"])
        dec_self_attn_patterns = expand_pattern(model_cfg["cross_self_pattern"])
        dec_cross_attn_patterns = expand_pattern(model_cfg["cross_pattern"])

        self.torch_nn_module = CuboidTransformerModel(
            input_shape=model_cfg["input_shape"],
            target_shape=model_cfg["target_shape"],
            base_units=model_cfg["base_units"],
            block_units=model_cfg["block_units"],
            scale_alpha=model_cfg["scale_alpha"],
            enc_depth=model_cfg["enc_depth"],
            dec_depth=model_cfg["dec_depth"],
            enc_use_inter_ffn=model_cfg["enc_use_inter_ffn"],
            dec_use_inter_ffn=model_cfg["dec_use_inter_ffn"],
            downsample=model_cfg["downsample"],
            downsample_type=model_cfg["downsample_type"],
            downsample_conv_cfg_list=model_cfg["downsample_conv_cfg_list"],
            downsample_conv_norm=model_cfg["downsample_conv_norm"],
            enc_attn_patterns=enc_attn_patterns,
            dec_self_attn_patterns=dec_self_attn_patterns,
            dec_cross_attn_patterns=dec_cross_attn_patterns,
            dec_cross_last_n_frames=model_cfg["cross_last_n_frames"],
            dec_use_first_self_attn=model_cfg["dec_use_first_self_attn"],
            num_heads=model_cfg["num_heads"],
            attn_drop=model_cfg["attn_drop"],
            proj_drop=model_cfg["proj_drop"],
            ffn_drop=model_cfg["ffn_drop"],
            upsample_type=model_cfg["upsample_type"],
            upsample_tconv_cfg_list=model_cfg["upsample_tconv_cfg_list"],
            activation=model_cfg["activation"],
            ffn_activation=model_cfg["ffn_activation"],
            gated_ffn=model_cfg["gated_ffn"],
            norm_layer=model_cfg["norm_layer"],
            # global vectors
            num_global_vectors=model_cfg["num_global_vectors"],
            use_dec_self_global=model_cfg["use_dec_self_global"],
            dec_self_update_global=model_cfg["dec_self_update_global"],
            use_dec_cross_global=model_cfg["use_dec_cross_global"],
            use_global_vector_ffn=model_cfg["use_global_vector_ffn"],
            use_global_self_attn=model_cfg["use_global_self_attn"],
            use_global_mode=model_cfg["use_global_mode"],
            separate_global_qkv=model_cfg["separate_global_qkv"],
            global_dim_ratio=model_cfg["global_dim_ratio"],
            # initial_downsample
            initial_downsample_scale=model_cfg["initial_downsample_scale"],
            initial_downsample_type=model_cfg["initial_downsample_type"],
            # initial_downsample_type=="conv"
            initial_downsample_conv_layers=model_cfg["initial_final_sample_num_conv"],
            final_upsample_conv_layers=model_cfg["initial_final_sample_num_conv"] - 1,
            # misc
            padding_type=model_cfg["padding_type"],
            z_init_method=model_cfg["z_init_method"],
            z_init_token_len=model_cfg["z_init_token_len"],
            checkpoint_level=model_cfg["checkpoint_level"],
            pos_embed_type=model_cfg["pos_embed_type"],
            use_relative_pos=model_cfg["use_relative_pos"],
            self_attn_use_final_proj=model_cfg["self_attn_use_final_proj"],
            # initialization
            attn_linear_init_mode=model_cfg["attn_linear_init_mode"],
            ffn_linear_init_mode=model_cfg["ffn_linear_init_mode"],
            conv_init_mode=model_cfg["conv_init_mode"],
            down_up_linear_init_mode=model_cfg["down_up_linear_init_mode"],
            norm_init_mode=model_cfg["norm_init_mode"],
        )

        self.save_hyperparameters(oc)
        self.oc = oc
        # layout
        self.in_len = oc.layout.in_len
        self.out_len = oc.layout.out_len
        self.layout = oc.layout.layout
        # optimization
        self.max_epochs = oc.optim.max_epochs
        self.optim_method = oc.optim.method
        self.lr = oc.optim.lr
        self.wd = oc.optim.wd
        # lr_scheduler
        self.lr_scheduler_mode = oc.optim.lr_scheduler_mode
        self.warmup_percentage = oc.optim.warmup_percentage
        self.min_lr_ratio = oc.optim.min_lr_ratio
        # logging
        self.save_dir = save_dir
        self.logging_prefix = oc.logging.logging_prefix
        # visualization
        self.train_example_data_idx_list = list(oc.vis.train_example_data_idx_list)
        self.val_example_data_idx_list = list(oc.vis.val_example_data_idx_list)
        self.test_example_data_idx_list = list(oc.vis.test_example_data_idx_list)
        self.eval_example_only = oc.vis.eval_example_only

        # Creates the output directories and rewrites `self.save_dir` to the full path.
        self.configure_save(cfg_file_path=oc_file)

        # Separate metric objects for validation and test so their states never mix.
        self.valid_mse = torchmetrics.MeanSquaredError()
        self.valid_mae = torchmetrics.MeanAbsoluteError()
        self.valid_ssim = torchmetrics.StructuralSimilarityIndexMeasure()
        self.test_mse = torchmetrics.MeanSquaredError()
        self.test_mae = torchmetrics.MeanAbsoluteError()
        self.test_ssim = torchmetrics.StructuralSimilarityIndexMeasure()

    def configure_save(self, cfg_file_path=None):
        self.save_dir = os.path.join("experiments", self.save_dir)
        os.makedirs(self.save_dir, exist_ok=True)
        self.scores_dir = os.path.join(self.save_dir, 'scores')
        os.makedirs(self.scores_dir, exist_ok=True)
        if cfg_file_path is not None:
            cfg_file_target_path = os.path.join(self.save_dir, "cfg.yaml")
            if (not os.path.exists(cfg_file_target_path)) or \
                    (not os.path.samefile(cfg_file_path, cfg_file_target_path)):
                copyfile(cfg_file_path, cfg_file_target_path)
        self.example_save_dir = os.path.join(self.save_dir, "examples")
        os.makedirs(self.example_save_dir, exist_ok=True)

    def get_base_config(self, oc_from_file=None):
        r"""
        Assemble the default config and optionally merge a loaded config on top of it.

        Parameters
        ----------
        oc_from_file:
            Config loaded from a file, or `None` to return the defaults unchanged.

        Returns
        -------
        The config with the sections layout, optim, logging, trainer, vis and model.
        """
        section_builders = (
            ("layout", self.get_layout_config),
            ("optim", self.get_optim_config),
            ("logging", self.get_logging_config),
            ("trainer", self.get_trainer_config),
            ("vis", self.get_vis_config),
            ("model", self.get_model_config),
        )
        base = OmegaConf.create()
        for section, build in section_builders:
            base[section] = build()
        if oc_from_file is None:
            return base
        # Values from the file take precedence over the defaults.
        return OmegaConf.merge(base, oc_from_file)

    @classmethod
    def get_model_config(cls):
        r"""
        Return the default model section of the config.

        Most keys are forwarded to `CuboidTransformerModel` in `__init__`.
        """
        height, width = 64, 64
        in_len, out_len = 10, 10
        data_channels = 1
        return OmegaConf.create({
            "input_shape": (in_len, height, width, data_channels),
            "target_shape": (out_len, height, width, data_channels),

            "base_units": 64,
            "block_units": None,  # multiply by 2 when downsampling in each layer
            "scale_alpha": 1.0,

            "enc_depth": [1, 1],
            "dec_depth": [1, 1],
            "enc_use_inter_ffn": True,
            "dec_use_inter_ffn": True,
            "dec_hierarchical_pos_embed": True,

            "downsample": 2,
            "downsample_type": "patch_merge",
            "upsample_type": "upsample",
            "downsample_conv_norm": None,
            "downsample_conv_cfg_list": None,
            "upsample_tconv_cfg_list": None,

            # global vectors
            "num_global_vectors": 8,
            "use_dec_self_global": True,
            "dec_self_update_global": True,
            "use_dec_cross_global": True,
            "use_global_vector_ffn": True,
            "use_global_self_attn": False,
            "use_global_mode": "direct",
            "separate_global_qkv": False,
            "global_dim_ratio": 1,

            # attention patterns: a single name, or one name per block
            "self_pattern": 'axial',
            "cross_self_pattern": 'axial',
            "cross_pattern": 'cross_1x1',
            "cross_last_n_frames": None,

            "attn_drop": 0.1,
            "proj_drop": 0.1,
            "ffn_drop": 0.1,
            "num_heads": 4,

            "activation": 'leaky',
            "ffn_activation": 'gelu',
            "gated_ffn": False,
            "norm_layer": 'layer_norm',
            "padding_type": 'zeros',
            "pos_embed_type": "t+hw",
            "use_relative_pos": True,
            "self_attn_use_final_proj": True,
            "dec_use_first_self_attn": False,

            "z_init_method": 'zeros',  # The method for initializing the first input of the decoder
            "z_init_token_len": None,
            "initial_downsample_type": "conv",
            "initial_downsample_scale": 2,
            "initial_final_sample_num_conv": 2,
            "checkpoint_level": 2,
            # initialization
            "attn_linear_init_mode": "0",
            "ffn_linear_init_mode": "0",
            "conv_init_mode": "0",
            "down_up_linear_init_mode": "0",
            "norm_init_mode": "0",
        })

    @staticmethod
    def get_dataset_config():
        r"""
        Return the default dataset section, used by `main` when no config file is given.

        `dataset_name` doubles as the dataset directory in `get_nbody_datamodule`.
        """
        return OmegaConf.create({
            "dataset_name": "nbody_digits3_len20_size64_r0_train20k",
            "num_train_samples": 20000,
            "num_val_samples": 1000,
            "num_test_samples": 1000,
            "digit_num": None,
            "img_size": 64,
            "raw_img_size": 128,
            "seq_len": 20,
            "raw_seq_len_multiplier": 5,
            "distractor_num": None,
            "distractor_size": 5,
            "max_velocity_scale": 2.0,
            "initial_velocity_range": (0.0, 2.0),
            "random_acceleration_range": (0.0, 0.0),
            "scale_variation_range": (1.0, 1.0),
            "rotation_angle_range": (-0, 0),
            "illumination_factor_range": (1.0, 1.0),
            "period": 5,
            "global_rotation_prob": 0.5,
            "index_range": (0, 40000),
            "mnist_data_path": None,
            # N-Body params
            "nbody_acc_mode": "r0",
            "nbody_G": 0.05,
            "nbody_softening_distance": 10.0,
            "nbody_mass": None,
        })

    @staticmethod
    def get_layout_config():
        r"""Return the default layout section: input / output lengths and the data layout."""
        return OmegaConf.create({
            "in_len": 10,
            "out_len": 10,
            "layout": "NTHWC",  # The layout of the data, not the model
        })

    @staticmethod
    def get_optim_config():
        r"""Return the default optimization section (batch sizes, optimizer, schedule, stopping)."""
        return OmegaConf.create({
            "seed": None,
            "total_batch_size": 32,
            "micro_batch_size": 8,

            "method": "adamw",
            "lr": 1E-3,
            "wd": 1E-5,
            "gradient_clip_val": 1.0,
            "max_epochs": 50,
            # scheduler
            "warmup_percentage": 0.2,
            "lr_scheduler_mode": "cosine",  # `configure_optimizers` implements 'cosine' only
            "min_lr_ratio": 0.1,
            "warmup_min_lr_ratio": 0.1,
            # early stopping
            "early_stop": False,
            "early_stop_mode": "min",
            "early_stop_patience": 5,
            "save_top_k": 1,
        })

    @staticmethod
    def get_logging_config():
        r"""Return the default logging section; the flags are consumed by `set_trainer_kwargs`."""
        return OmegaConf.create({
            "logging_prefix": "NBody",
            "monitor_lr": True,
            "monitor_device": False,
            "track_grad_norm": -1,
        })

    @staticmethod
    def get_trainer_config():
        r"""Return the default trainer section."""
        return OmegaConf.create({
            "check_val_every_n_epoch": 1,
            # `set_trainer_kwargs` logs every `log_step_ratio * total_num_steps` steps,
            # i.e. 0.001 means every 0.1% of the total training steps.
            "log_step_ratio": 0.001,
            "precision": 32,
        })

    @staticmethod
    def get_vis_config():
        r"""
        Return the default visualization section.

        The `*_example_data_idx_list` entries are compared against `batch_idx` to pick
        the batches that get visualized.
        """
        return OmegaConf.create({
            "train_example_data_idx_list": [0, ],
            "val_example_data_idx_list": [0, ],
            "test_example_data_idx_list": [0, ],
            "eval_example_only": False,
        })

    def forward(self, batch):
        seq = batch
        in_seq = seq[self.in_slice]
        out_seq = seq[self.out_slice]
        output = self.torch_nn_module(in_seq)
        loss = F.mse_loss(output, out_seq)
        return output, loss

    def configure_optimizers(self):
        r"""
        Create the optimizer and the step-wise lr scheduler.

        Parameters are split into two groups: one with weight decay `optim.wd` and one
        without. The lr follows a warmup phase of `warmup_percentage * total_num_steps`
        steps and then cosine annealing down to `min_lr_ratio * lr`.

        Returns
        -------
        A dict with the keys 'optimizer' and 'lr_scheduler'.

        Raises
        ------
        NotImplementedError
            If `optim.method` is not 'adamw' or `optim.lr_scheduler_mode` is not 'cosine'.
        """
        optim_cfg = self.oc.optim
        # Disable the weight decay for layer norm weights and all bias terms.
        # A set makes the membership test below O(1) per parameter instead of a list scan.
        decay_names = {
            name
            for name in get_parameter_names(self.torch_nn_module, [nn.LayerNorm])
            if "bias" not in name
        }
        decay_params, no_decay_params = [], []
        # One pass over the parameters fills both groups.
        for name, param in self.torch_nn_module.named_parameters():
            if name in decay_names:
                decay_params.append(param)
            else:
                no_decay_params.append(param)
        optimizer_grouped_parameters = [
            {'params': decay_params, 'weight_decay': optim_cfg.wd},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]

        if optim_cfg.method == 'adamw':
            optimizer = torch.optim.AdamW(params=optimizer_grouped_parameters,
                                          lr=optim_cfg.lr,
                                          weight_decay=optim_cfg.wd)
        else:
            raise NotImplementedError(
                f"Optimizer {optim_cfg.method!r} is not supported, only 'adamw' is.")

        warmup_iter = int(np.round(optim_cfg.warmup_percentage * self.total_num_steps))

        if optim_cfg.lr_scheduler_mode == 'cosine':
            warmup_scheduler = LambdaLR(optimizer,
                                        lr_lambda=warmup_lambda(warmup_steps=warmup_iter,
                                                                min_lr_ratio=optim_cfg.warmup_min_lr_ratio))
            cosine_scheduler = CosineAnnealingLR(optimizer,
                                                 T_max=(self.total_num_steps - warmup_iter),
                                                 eta_min=optim_cfg.min_lr_ratio * optim_cfg.lr)
            # Switch from warmup to cosine annealing once `warmup_iter` steps are done.
            lr_scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler],
                                        milestones=[warmup_iter])
            lr_scheduler_config = {
                'scheduler': lr_scheduler,
                'interval': 'step',
                'frequency': 1,
            }
        else:
            raise NotImplementedError(
                f"lr scheduler mode {optim_cfg.lr_scheduler_mode!r} is not supported, only 'cosine' is.")
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler_config}

    def set_trainer_kwargs(self, **kwargs):
        r"""
        Default kwargs used when initializing pl.Trainer

        Parameters
        ----------
        kwargs:
            Extra `Trainer` arguments. `callbacks` and `logger` (lists) are extended with
            the defaults created here; every other entry overrides the default of the same name.

        Returns
        -------
        A dict of keyword arguments for `Trainer`. Precedence, lowest to highest:
        defaults built here, entries of `oc.trainer` that `Trainer` accepts, `kwargs`.
        """
        checkpoint_callback = ModelCheckpoint(
            monitor="valid_loss_epoch",
            dirpath=os.path.join(self.save_dir, "checkpoints"),
            filename="model-{epoch:03d}",
            save_top_k=self.oc.optim.save_top_k,
            save_last=True,
            mode="min",
        )
        callbacks = kwargs.pop("callbacks", [])
        assert isinstance(callbacks, list)
        for ele in callbacks:
            assert isinstance(ele, Callback)
        # Work on a copy so that the list owned by the caller is not extended in place.
        callbacks = list(callbacks)
        callbacks.append(checkpoint_callback)
        if self.oc.logging.monitor_lr:
            callbacks.append(LearningRateMonitor(logging_interval='step'))
        if self.oc.logging.monitor_device:
            callbacks.append(DeviceStatsMonitor())
        if self.oc.optim.early_stop:
            callbacks.append(EarlyStopping(monitor="valid_loss_epoch",
                                           min_delta=0.0,
                                           patience=self.oc.optim.early_stop_patience,
                                           verbose=False,
                                           mode=self.oc.optim.early_stop_mode))

        # Same as above: copy before appending the default loggers.
        logger = list(kwargs.pop("logger", []))
        tb_logger = pl_loggers.TensorBoardLogger(os.path.join(self.save_dir, 'lightning_logs'))
        csv_logger = pl_loggers.CSVLogger(os.path.join(self.save_dir, 'lightning_logs'))
        logger += [tb_logger, csv_logger]

        # Short runs would truncate to 0, which is not a usable logging interval.
        log_every_n_steps = max(1, int(self.oc.trainer.log_step_ratio * self.total_num_steps))
        trainer_init_keys = inspect.signature(Trainer).parameters.keys()
        ret = dict(
            callbacks=callbacks,
            # log
            logger=logger,
            log_every_n_steps=log_every_n_steps,
            flush_logs_every_n_steps=log_every_n_steps,
            track_grad_norm=self.oc.logging.track_grad_norm,
            # save
            default_root_dir=self.save_dir,
            # ddp
            accelerator="gpu",
            # strategy="ddp",
            strategy=ApexDDPPlugin(find_unused_parameters=False, delay_allreduce=True),
            # optimization
            max_epochs=self.oc.optim.max_epochs,
            check_val_every_n_epoch=self.oc.trainer.check_val_every_n_epoch,
            gradient_clip_val=self.oc.optim.gradient_clip_val,
            # NVIDIA amp
            precision=self.oc.trainer.precision,
        )
        # Only the entries of `oc.trainer` that are real `Trainer.__init__` parameters are forwarded.
        oc_trainer_kwargs = OmegaConf.to_object(self.oc.trainer)
        oc_trainer_kwargs = {key: val for key, val in oc_trainer_kwargs.items() if key in trainer_init_keys}
        ret.update(oc_trainer_kwargs)
        ret.update(kwargs)
        return ret

    @classmethod
    def get_total_num_steps(
            cls,
            num_samples: int,
            total_batch_size: int,
            epoch: int = None):
        r"""
        Parameters
        ----------
        num_samples:    int
            The number of samples of the dataloader. `num_samples / micro_batch_size` is the number of steps per epoch.
        total_batch_size:   int
            `total_batch_size == micro_batch_size * world_size * grad_accum`
        """
        if epoch is None:
            epoch = cls.get_optim_config().max_epochs
        return int(epoch * num_samples / total_batch_size)

    @staticmethod
    def get_nbody_datamodule(dataset_oc,
                             micro_batch_size: int,
                             num_workers: int = 8):
        r"""
        Create the N-body MovingMNIST datamodule for an already generated dataset.

        Parameters
        ----------
        dataset_oc:
            Mapping with the dataset settings (see `get_dataset_config`). `dataset_name`
            is used as the dataset directory.
        micro_batch_size:   int
            Batch size passed to the datamodule.
        num_workers:    int
            Number of dataloader workers passed to the datamodule.

        Returns
        -------
        The `NBodyMovingMNISTLightningDataModule` instance.

        Raises
        ------
        ValueError
            If the dataset directory does not exist.
        AssertionError
            If a requested setting differs from the one stored with the dataset.
        """
        data_dir = dataset_oc["dataset_name"]
        if not os.path.exists(data_dir):
            raise ValueError(f"dataset in {data_dir} not exists!")
        load_dataset_cfg_path = os.path.join(data_dir, "nbody_dataset_cfg.yaml")
        # Pass the path so OmegaConf manages the file itself; no handle is left open.
        load_dataset_cfg = OmegaConf.to_object(OmegaConf.load(load_dataset_cfg_path).dataset)
        # The dataset is never regenerated here, so the request must match what is on disk.
        for key, val in dataset_oc.items():
            assert val == load_dataset_cfg[key], \
                f"dataset config {key} mismatches! " \
                f"{val} specified, but {load_dataset_cfg[key]} loaded."
        # Settings that are forwarded to the datamodule under the same name.
        forwarded_keys = (
            "num_train_samples", "num_val_samples", "num_test_samples",
            "digit_num", "img_size", "raw_img_size",
            "seq_len", "raw_seq_len_multiplier",
            "distractor_num", "distractor_size",
            "max_velocity_scale", "initial_velocity_range",
            "random_acceleration_range", "scale_variation_range",
            "rotation_angle_range", "illumination_factor_range",
            "period", "global_rotation_prob", "index_range", "mnist_data_path",
            # N-Body params
            "nbody_acc_mode", "nbody_G", "nbody_softening_distance", "nbody_mass",
        )
        dm = NBodyMovingMNISTLightningDataModule(
            data_dir=data_dir,
            force_regenerate=False,
            **{key: dataset_oc[key] for key in forwarded_keys},
            # datamodule_only
            batch_size=micro_batch_size,
            num_workers=num_workers,)
        return dm

    @property
    def in_slice(self):
        if not hasattr(self, "_in_slice"):
            in_slice, out_slice = layout_to_in_out_slice(layout=self.layout,
                                                         in_len=self.in_len,
                                                         out_len=self.out_len)
            self._in_slice = in_slice
            self._out_slice = out_slice
        return self._in_slice

    @property
    def out_slice(self):
        if not hasattr(self, "_out_slice"):
            in_slice, out_slice = layout_to_in_out_slice(layout=self.layout,
                                                         in_len=self.in_len,
                                                         out_len=self.out_len)
            self._in_slice = in_slice
            self._out_slice = out_slice
        return self._out_slice

    def training_step(self, batch, batch_idx):
        r"""
        Run one training step: forward pass, optional example visualization, loss logging.

        Returns
        -------
        The MSE loss computed by `forward`.
        """
        context_frames = batch[self.in_slice]
        target_frames = batch[self.out_slice]
        prediction, loss = self(batch)
        # Only saves something for the batches listed in `train_example_data_idx_list`.
        self.save_vis_step_end(
            batch_idx=batch_idx,
            in_seq=context_frames,
            target_seq=target_frames,
            pred_seq=prediction,
            mode="train",
        )
        self.log('train_loss', loss, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        seq = batch
        x = seq[self.in_slice]
        y = seq[self.out_slice]
        B, T_out, H, W, C = y.shape
        if not self.eval_example_only or batch_idx in self.val_example_data_idx_list:
            y_hat, _ = self(batch)
            self.save_vis_step_end(
                batch_idx=batch_idx,
                in_seq=x, target_seq=y,
                pred_seq=y_hat,
                mode="val"
            )
            if self.precision == 16:
                y_hat = y_hat.float()
            step_mse = self.valid_mse(y_hat, y) * H * W
            step_mae = self.valid_mae(y_hat, y) * H * W
            y_hat = rearrange(y_hat,
                              "b t h w c -> (b t) c h w")
            y = rearrange(y,
                          "b t h w c -> (b t) c h w")
            step_ssim = self.valid_ssim(y_hat, y)

            self.log('valid_mse_step', step_mse, prog_bar=True, on_step=True, on_epoch=False)
            self.log('valid_mae_step', step_mae, prog_bar=True, on_step=True, on_epoch=False)
            self.log('valid_ssim_step', step_ssim, prog_bar=True, on_step=True, on_epoch=False)
        return H, W

    def validation_epoch_end(self, outputs):
        H, W = outputs[0]
        frame_mse = self.valid_mse.compute() * H * W
        frame_mae = self.valid_mae.compute() * H * W
        valid_loss = frame_mse
        epoch_ssim = self.valid_ssim.compute()

        self.log('valid_loss_epoch', valid_loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log('valid_mse_epoch', frame_mse, prog_bar=True, on_step=False, on_epoch=True)
        self.log('valid_mae_epoch', frame_mae, prog_bar=True, on_step=False, on_epoch=True)
        self.log('valid_ssim_epoch', epoch_ssim, prog_bar=True, on_step=False, on_epoch=True)
        self.valid_mse.reset()
        self.valid_mae.reset()
        self.valid_ssim.reset()

    def test_step(self, batch, batch_idx):
        seq = batch
        x = seq[self.in_slice]
        y = seq[self.out_slice]
        B, T_out, H, W, C = y.shape
        if not self.eval_example_only or batch_idx in self.test_example_data_idx_list:
            y_hat, _ = self(batch)
            self.save_vis_step_end(
                batch_idx=batch_idx,
                in_seq=x, target_seq=y,
                pred_seq=y_hat,
                mode="test"
            )
            if self.precision == 16:
                y_hat = y_hat.float()
            step_mse = self.test_mse(y_hat, y) * H * W
            step_mae = self.test_mae(y_hat, y) * H * W
            y_hat = rearrange(y_hat,
                              "b t h w c -> (b t) c h w")
            y = rearrange(y,
                          "b t h w c -> (b t) c h w")
            step_ssim = self.test_ssim(y_hat, y)

            self.log('test_mse_step', step_mse, prog_bar=True, on_step=True, on_epoch=False)
            self.log('test_mae_step', step_mae, prog_bar=True, on_step=True, on_epoch=False)
            self.log('test_ssim_step', step_ssim, prog_bar=True, on_step=True, on_epoch=False)
        return H, W

    def test_epoch_end(self, outputs):
        H, W = outputs[0]
        frame_mse = self.test_mse.compute() * H * W
        frame_mae = self.test_mae.compute() * H * W
        epoch_ssim = self.test_ssim.compute()

        self.log('test_mse_epoch', frame_mse, prog_bar=True, on_step=False, on_epoch=True)
        self.log('test_mae_epoch', frame_mae, prog_bar=True, on_step=False, on_epoch=True)
        self.log('test_ssim_epoch', epoch_ssim, prog_bar=True, on_step=False, on_epoch=True)
        self.test_mse.reset()
        self.test_mae.reset()
        self.test_ssim.reset()

    def save_vis_step_end(
            self,
            batch_idx: int,
            in_seq: torch.Tensor, target_seq: torch.Tensor,
            pred_seq: torch.Tensor,
            mode: str = "train"):
        r"""
        Save a visualization of the current batch if it is one of the selected examples.

        Nothing happens on processes whose `local_rank` is not 0.

        Parameters
        ----------
        batch_idx:  int
            Index of the batch; compared against the example index list of `mode`.
        in_seq, target_seq, pred_seq:   torch.Tensor
            Input, ground-truth and predicted sequences in the layout `self.layout`.
        mode:   str
            One of 'train', 'val', 'test'.

        Raises
        ------
        ValueError
            If `mode` is not one of the three supported values.
        """
        if self.local_rank != 0:
            return

        idx_list_attr_by_mode = (
            ("train", "train_example_data_idx_list"),
            ("val", "val_example_data_idx_list"),
            ("test", "test_example_data_idx_list"),
        )
        for known_mode, attr_name in idx_list_attr_by_mode:
            if mode == known_mode:
                selected_batches = getattr(self, attr_name)
                break
        else:
            raise ValueError(f"Wrong mode {mode}! Must be in ['train', 'val', 'test'].")

        if batch_idx not in selected_batches:
            return

        def to_numpy(tensor):
            return tensor.detach().float().cpu().numpy()

        # Index of the first sample of this batch, used in the file name.
        batch_size = in_seq.shape[self.layout.find("N")]
        first_sample_idx = int(batch_idx * batch_size)
        save_example_vis_results(
            save_dir=self.example_save_dir,
            save_prefix=f'{mode}_epoch_{self.current_epoch}_data_{first_sample_idx}',
            in_seq=to_numpy(in_seq),
            target_seq=to_numpy(target_seq),
            pred_seq=to_numpy(pred_seq),
            layout=self.layout,
            plot_stride=1,
            label=self.oc.logging.logging_prefix)

def get_parser():
    """Build the command line parser of the training / testing script."""
    argument_specs = (
        ('--save', dict(default='tmp_nbody', type=str)),
        ('--gpus', dict(default=1, type=int)),
        ('--cfg', dict(default=None, type=str)),
        ('--test', dict(action='store_true')),
        ('--ckpt_name', dict(default=None, type=str,
                             help='The model checkpoint trained on N-Body MovingMNIST.')),
    )
    parser = argparse.ArgumentParser()
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)
    return parser

def main():
    """
    Train and / or test the Cuboid Transformer on N-body MovingMNIST.

    With `--cfg`, the dataset, batch sizes, epochs and seed come from the config file;
    otherwise the defaults of `CuboidNBodyPLModule` are used with a micro batch size of 1.
    With `--test`, only the checkpoint `--ckpt_name` is tested; otherwise the model is
    trained (resuming from `--ckpt_name` if it exists) and the best checkpoint is tested.
    """
    parser = get_parser()
    args = parser.parse_args()
    if args.cfg is not None:
        # Pass the path so OmegaConf manages the file itself; no handle is left open.
        oc_from_file = OmegaConf.load(args.cfg)
        dataset_oc = OmegaConf.to_object(oc_from_file.dataset)
        total_batch_size = oc_from_file.optim.total_batch_size
        micro_batch_size = oc_from_file.optim.micro_batch_size
        max_epochs = oc_from_file.optim.max_epochs
        seed = oc_from_file.optim.seed
    else:
        dataset_oc = OmegaConf.to_object(CuboidNBodyPLModule.get_dataset_config())
        micro_batch_size = 1
        total_batch_size = int(micro_batch_size * args.gpus)
        max_epochs = None
        seed = 0
    seed_everything(seed, workers=True)
    dm = CuboidNBodyPLModule.get_nbody_datamodule(
        dataset_oc=dataset_oc,
        micro_batch_size=micro_batch_size,
        num_workers=8,)
    dm.prepare_data()
    dm.setup()
    accumulate_grad_batches = total_batch_size // (micro_batch_size * args.gpus)
    if accumulate_grad_batches < 1:
        # The floor division gives 0 when one pass over all devices already exceeds
        # the requested total batch size; at least one batch is needed per step.
        warnings.warn(
            f"total_batch_size={total_batch_size} is smaller than "
            f"micro_batch_size * gpus = {micro_batch_size * args.gpus}; "
            f"using accumulate_grad_batches=1.")
        accumulate_grad_batches = 1
    total_num_steps = CuboidNBodyPLModule.get_total_num_steps(
        epoch=max_epochs,
        num_samples=dm.num_train_samples,
        total_batch_size=total_batch_size,
    )
    pl_module = CuboidNBodyPLModule(
        total_num_steps=total_num_steps,
        save_dir=args.save,
        oc_file=args.cfg)
    trainer_kwargs = pl_module.set_trainer_kwargs(
        devices=args.gpus,
        accumulate_grad_batches=accumulate_grad_batches,
    )
    trainer = Trainer(**trainer_kwargs)
    if args.test:
        assert args.ckpt_name is not None, "args.ckpt_name is required for test!"
        ckpt_path = os.path.join(pl_module.save_dir, "checkpoints", args.ckpt_name)
        trainer.test(model=pl_module,
                     datamodule=dm,
                     ckpt_path=ckpt_path)
    else:
        ckpt_path = None
        if args.ckpt_name is not None:
            ckpt_path = os.path.join(pl_module.save_dir, "checkpoints", args.ckpt_name)
            if not os.path.exists(ckpt_path):
                # A missing checkpoint is not fatal: training simply starts from scratch.
                warnings.warn(f"ckpt {ckpt_path} not exists! Start training from epoch 0.")
                ckpt_path = None
        trainer.fit(model=pl_module,
                    datamodule=dm,
                    ckpt_path=ckpt_path)
        trainer.test(ckpt_path="best",
                     datamodule=dm)

# Run the training / testing entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
