"""
The whole framework
"""

import os
import numpy as np
import os.path as osp
from tqdm import tqdm
from time import time
import random
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from .basetrainer import BaseTrainer
from models.renderer import RenderNet
from models.transmodel import ParticleNet
from models.encoder import MyParticleNetwork, GaussianGRU
from datasets.dataset import BlenderDataset
from utils.particles_utils import record2obj
from utils.point_eval import FluidErrors
from utils.lr_schedulers import ExponentialLR

from torch import distributions as torchd
from pytorch3d.ops import ball_query
import utils.utils as utils
from pytorch3d.loss import chamfer_distance


def img2mse(x, y):
    """Return the mean squared error between tensors ``x`` and ``y``."""
    return torch.mean((x - y) ** 2)


def mse2psnr(x):
    """Convert an MSE tensor ``x`` into PSNR (dB), i.e. ``-10 * log10(x)``.

    The ``log(10)`` constant is created on ``x``'s own device instead of a
    hard-coded ``.cuda()``, so the helper also works on CPU-only machines and
    on non-default GPUs.  The constant is a 1-element tensor, so the result
    has shape ``[1]`` for a scalar input (same as before).
    """
    return -10. * torch.log(x) / torch.log(torch.tensor([10.], device=x.device))


def to8b(x):
    """Clip array ``x`` to ``[0, 1]``, scale by 255 and cast to ``uint8``."""
    return (255 * np.clip(x, 0, 1)).astype(np.uint8)

class Trainer(BaseTrainer):
    def init_fn(self):
        """Initialise trainer state: latent, datasets, models, optimizers and start particles.

        The training stage is derived from ``options.TRAIN.LR``:

        * ``latent_lr != 0``  -> stage 1 (printed as "stage B"); the latent is
          wrapped in ``nn.Parameter`` and gets its own optimizer.
        * otherwise ``encoder_lr != 0`` -> stage 2 (printed as "stage C"); the
          encoder optimizer is built.
        * both zero -> the stage falls back to 0 (unless it was already set),
          so the ``self.current_stage > 1`` checks in ``build_model`` and
          ``train`` no longer fail with ``AttributeError``.

        Raises:
            ValueError: if ``options.TRAIN.get_feat`` is not a supported mode.
        """
        self.start_step = 0
        self.eval_count = 0
        train_opts = self.options.TRAIN
        lr_opts = train_opts.LR
        get_feat = train_opts.get_feat

        self.encoder_dim = self.options['encoder']['stoch']
        # Modes whose name contains 'sample' use a latent twice as wide.
        latent_dim = self.encoder_dim * 2 if 'sample' in get_feat else self.encoder_dim
        grid_res = train_opts.grid_res
        if get_feat == 'grid_sample_multi':
            # one distribution per grid cell -> grid_res**3 latent rows
            latent_dim = [grid_res**3, latent_dim]
        if get_feat == 'particle_sample_multi':
            # one latent row per particle
            latent_dim = [train_opts.particle_res, latent_dim]
        self.latent = torch.randn(latent_dim)
        self.adapt = train_opts.adapt
        self.build_dataloader()

        # The stage must be known before build_model(), which reads it to
        # decide whether a stage-1 latent has to be loaded.
        train_latent = lr_opts.latent_lr != 0
        train_encoder = not train_latent and lr_opts.encoder_lr != 0
        if train_latent:
            self.current_stage = 1
        elif train_encoder:
            self.current_stage = 2
        else:
            self.current_stage = getattr(self, 'current_stage', 0)
        self.build_model()

        # build_model() may have replaced self.latent with a loaded tensor, so
        # the Parameter wrapping has to happen afterwards.
        if train_latent:
            self.latent = nn.Parameter(self.latent)
            self.build_latent_optimizer()
            print('\033[1;35m Current stage: stage B\033[0m')
        elif train_encoder:
            self.build_encoder_optimizer()
            print('\033[1;35m Current stage: stage C\033[0m')

        self.set_RGB_criterion()
        self.set_L1_criterion()
        self.save_interval = train_opts.save_interval
        self.log_interval = train_opts.log_interval

        # Select how per-particle features are obtained from the latent.
        if get_feat in ['grid_sample', 'grid_sample_multi']:
            self.feat_fn = lambda pos: self.get_feat_from_grid(pos=pos, grid_res=grid_res)
        elif get_feat == 'particle_sample':
            self.feat_fn = self.get_feat
        elif get_feat == 'particle_deter':
            self.feat_fn = self.get_feat_deter
        elif get_feat == 'particle_sample_multi':
            self.feat_fn = self.get_feat_multi
        else:
            raise ValueError(f'unsupported TRAIN.get_feat: {get_feat!r}')

        self._discrete = False

        # The transition model is frozen here.
        self.transition_model.requires_grad_(False)
        if lr_opts.renderer_lr != 0. or lr_opts.trans_lr != 0.:
            self.build_optimizer()

        init_particle_path = train_opts.init_particle_path
        if init_particle_path:
            print('---> Initial position', init_particle_path)
            particle_file = np.load(init_particle_path)
            # Positions are stored under 'particles' or, failing that, 'pos'.
            try:
                init_particles = particle_file['particles']
            except KeyError:
                init_particles = particle_file['pos']
            self.init_pos = torch.Tensor(init_particles).to(self.device)
            num_particles = self.init_pos.shape[0]
            if train_opts.particle_res < num_particles:
                # Randomly subsample down to the configured particle count.
                rand_idx = np.random.permutation(num_particles)[:train_opts.particle_res]
                self.init_pos = self.init_pos[rand_idx]
        else:
            self.init_pos = None

        init_target_particle_path = train_opts.target_init_particle_path
        if init_target_particle_path:
            print('---> Target Initial position', init_target_particle_path)
            self.target_init_pos = torch.Tensor(np.load(init_target_particle_path)['pos']).to(self.device)
        else:
            self.target_init_pos = None

        # Best distances seen so far (lower is better).
        self.best_gt2pred = np.inf
        self.best_true_pred2gt = np.inf


    def build_dataloader(self):
        """Create the train, test and target ``BlenderDataset`` objects.

        The train and test datasets both read ``options.train.path``; they
        differ in the view names and in the image size / scale (``TRAIN`` vs
        ``TEST`` options).  The target dataset reads ``options.target.path``
        with the single view ``'view_0'``.  All three use ``split='train'``.
        """
        opts = self.options
        train_cfg, test_cfg = opts.TRAIN, opts.TEST
        self.train_view_names = opts['train'].views.dynamic
        self.test_viewnames = opts['test'].views

        source_path = opts.train.path
        self.dataset = BlenderDataset(
            source_path, opts,
            imgW=train_cfg.imgW, imgH=train_cfg.imgH, imgscale=train_cfg.scale,
            viewnames=self.train_view_names, split='train',
        )
        self.dataset_length = len(self.dataset)

        # Test and target datasets share the TEST image geometry.
        eval_kwargs = dict(imgW=test_cfg.imgW, imgH=test_cfg.imgH,
                           imgscale=test_cfg.scale, split='train')
        self.test_dataset = BlenderDataset(
            source_path, opts, viewnames=self.test_viewnames, **eval_kwargs)
        self.target_dataset = BlenderDataset(
            opts.target.path, opts, viewnames=['view_0'], **eval_kwargs)
        self.test_dataset_length = len(self.test_dataset)
        print('---> dataloader has been build')


    def build_model(self):
        """Build the transition model, renderer and (optionally) encoder, and load checkpoints.

        * ``ParticleNet`` and ``RenderNet`` are always created and moved to
          ``self.device``; pretrained weights are loaded when the matching
          ``TRAIN`` option is a non-empty path.
        * With ``TRAIN.use_encoder`` the encoder and the prior GRU are created
          and their weights are read from ``TRAIN.pretrained_transition_model``
          (keys ``'encoder'`` and ``'prior_gru'``); a ``'latent'`` entry in that
          checkpoint replaces ``self.latent``.
        * ``TRAIN.pretrained_latent`` (if set) overrides ``self.latent``.
          Otherwise, for ``self.current_stage > 1``, the latent is read from a
          fixed path relative to ``self.exppath``.

        Raises:
            NotImplementedError: if the stage > 1 fallback latent cannot be
                loaded (the original error is chained as ``__cause__``).
        """
        train_opts = self.options.TRAIN

        # build model
        gravity = self.options.gravity
        print('---> set gravity', gravity)
        self.transition_model = ParticleNet(gravity=gravity, other_feats_channels=self.encoder_dim, adapt=self.adapt).to(self.device)
        self.renderer = RenderNet(self.options.RENDERER, near=self.options.near, far=self.options.far).to(self.device)

        # load pretrained checkpoints
        if train_opts.pretrained_transition_model != '':
            self.load_pretained_transition_model(train_opts.pretrained_transition_model)
            print('\033[1;35m load: \033[0m', train_opts.pretrained_transition_model)
        if train_opts.pretained_renderer != '':
            self.load_pretained_renderer_model(train_opts.pretained_renderer, partial_load=train_opts.partial_load)
            print('\033[1;35m load: \033[0m', train_opts.pretained_renderer)

        if train_opts.use_encoder:
            encoder_opts = self.options['encoder']
            if encoder_opts['input_last_latent']:
                # The encoder also receives the previous latent (mean, and std
                # when 'use_std' is set) as extra per-particle features.
                encoder_dim = encoder_opts['stoch']
                if encoder_opts.get('use_std', False):
                    encoder_dim *= 2
                self.encoder = MyParticleNetwork(other_feats_channels=encoder_dim).to(self.device)
            else:
                self.encoder = MyParticleNetwork().to(self.device)
            self.prior_gru = GaussianGRU(output_size=encoder_opts['stoch'], mean_act=encoder_opts['mean_act']).to(self.device)
            # NOTE(review): this load runs even when pretrained_transition_model
            # is '' -- confirm that use_encoder always comes with a checkpoint.
            checkpoint = torch.load(train_opts.pretrained_transition_model)
            self.encoder.load_state_dict(checkpoint['encoder'], strict=True)
            print('\n load encoder encoder:')
            self.prior_gru.load_state_dict(checkpoint['prior_gru'], strict=True)
            print('\n load encoder gru:')
            if 'latent' in checkpoint:
                self.latent = checkpoint['latent']
                print('\n loaded latent')

        if train_opts.pretrained_latent != '':
            checkpoint = torch.load(train_opts.pretrained_latent)
            self.latent = checkpoint['latent']
            print('\n loaded latent')
        elif self.current_stage > 1:
            # Later stages need the latent optimised in stage 1; fall back to
            # the hard-coded location of its last checkpoint.
            latent_path = os.path.join(self.exppath, '../../stage_b/latent1e-4/models/98000.pt')
            try:
                checkpoint = torch.load(latent_path)
                self.latent = checkpoint['latent']
            except Exception as err:
                # Same exception type as before, but with a message and the
                # real cause attached instead of swallowing it.
                raise NotImplementedError(
                    f'could not load the stage-1 latent from {latent_path}; '
                    'set TRAIN.pretrained_latent to an existing checkpoint'
                ) from err
            print('\n loaded latent')
            print('loaded stage 1 last checkpoint:\n --->', latent_path)


    def build_optimizer(self):
        if self.options.TRAIN.loss_weight['encoder_KL_loss'] != 0.:
            self.encoder.requires_grad_(False)
        renderer_lr = self.options.TRAIN.LR.renderer_lr
        transition_lr = self.options.TRAIN.LR.trans_lr
        seperate_render_transition = self.options.TRAIN.seperate_render_transition
        if seperate_render_transition:
            self.optimizer = torch.optim.Adam([
                {'params': self.renderer.parameters(), 'lr': renderer_lr},
            ])
            if self.adapt:
                self.transition_optimizer = torch.optim.AdamW([
                    {'params': [param for name, param in self.transition_model.named_parameters() if 'bypass' in name], 'lr': transition_lr,
                    'weight_decay': self.options.TRAIN.LR.trans_weight_decay
                    },
                ])
                for param_group in self.transition_optimizer.param_groups:
                    print(param_group)
            else:
                self.transition_optimizer = torch.optim.Adam([
                    {'params': self.transition_model.parameters(), 'lr': transition_lr},
                ])
        else:
            raise NotImplementedError('adapt bypass parameters not implemented')
            self.optimizer = torch.optim.Adam([
                {'params': self.renderer.parameters(), 'lr': renderer_lr},
                {'params': self.transition_model.parameters(), 'lr': transition_lr}
                ])
        if self.options.TRAIN.LR.use_scheduler:
            boundaries = [
                50000,  # 10k
                100000,  # 75k
                200000,  # 150k
            ]
            lr_values = [
                1.0,
                0.5,
                0.25,
                0.125,
            ]

            def lrfactor_fn(x):
                factor = lr_values[0]
                for b, v in zip(boundaries, lr_values[1:]):
                    if x > b:
                        factor = v
                    else:
                        break
                return factor

            self.optim_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lrfactor_fn)

            if seperate_render_transition:
                boundaries_trans = [
                    80000, # 10k
                    120000,
                    160000,
                    200000,
                    300000,
                ]
                lr_values_trans = [
                    1.0,
                    0.5,
                    0.25,
                    0.125,
                    0.5 * 0.125,
                    0.25 * 0.125,
                    0.125 * 0.125,
                ]

                def lrfactor_fn_transition(x):
                    factor = lr_values[0]
                    for b, v in zip(boundaries_trans, lr_values_trans[1:]):
                        if x > b:
                            factor = v
                        else:
                            break
                    return factor

                self.optim_lr_scheduler_transition = torch.optim.lr_scheduler.LambdaLR(self.transition_optimizer, lrfactor_fn_transition)


    def build_latent_optimizer(self):
        latent_lr = self.options.TRAIN.LR.latent_lr
        self.latent_optimizer = torch.optim.Adam([
            {'params': self.latent, 'lr': latent_lr}
        ])
        if self.options.TRAIN.LR.use_scheduler_latent:
            self.optim_lr_scheduler_latent = torch.optim.lr_scheduler.CosineAnnealingLR(self.latent_optimizer, T_max=5000, eta_min=latent_lr*0.1)

    def build_encoder_optimizer(self):
        encoder_lr = self.options.TRAIN.LR.encoder_lr
        self.encoder_optimizer = torch.optim.Adam([
            {'params': self.encoder.parameters(), 'lr': encoder_lr},
            {'params': self.prior_gru.parameters(), 'lr': encoder_lr}
        ])
        if self.options.TRAIN.LR.use_scheduler_encoder:
            self.optim_lr_scheduler_encoder = torch.optim.lr_scheduler.CosineAnnealingLR(self.encoder_optimizer, T_max=5000, eta_min=encoder_lr*0.1)

    def resume(self, ckpt_file):
        checkpoint = torch.load(ckpt_file)
        self.start_step = checkpoint['step']
        self.renderer.load_state_dict(checkpoint['renderer_state_dict'], strict=True)
        self.transition_model.load_state_dict(checkpoint['transition_model_state_dict'], strict=True)
        self.latent = nn.Parameter(checkpoint['latent'])
        latent_lr = self.options.TRAIN.LR.latent_lr
        self.latent_optimizer = torch.optim.Adam([
            {'params': self.latent, 'lr': latent_lr}
        ])

        if self.options.TRAIN.LR.use_scheduler_latent:
            self.optim_lr_scheduler_latent = torch.optim.lr_scheduler.CosineAnnealingLR(self.latent_optimizer, T_max=3000, eta_min=1e-4)

            print('\n \033[1;35m-----!!!Resume and reload latent!!!------\033[0m')
            print('start step:', self.start_step, end='\n')

    def save_checkpoint(self, global_step, is_best=False):
        if self.options.TRAIN.LR.latent_lr != 0:
            model_dicts = {'step':global_step,
                            'renderer_state_dict':self.renderer.state_dict(),
                            'transition_model_state_dict':self.transition_model.state_dict(),
                            'latent': self.latent,
                            'latent_optimizer_state_dict': self.latent_optimizer.state_dict()
                            }
        # elif self.options.TRAIN.LR.latent_lr == 0 and self.options.TRAIN.LR.encoder_lr != 0:
        else:
            model_dicts = {'step':global_step,
                            'renderer_state_dict':self.renderer.state_dict(),
                            'transition_model_state_dict':self.transition_model.state_dict(),
                            'latent': self.latent,
                            'encoder': self.encoder.state_dict(),
                            'prior_gru': self.prior_gru.state_dict(),
                            }
        if self.adapt:
            model_dicts['transition_optimizer_state_dict'] = self.transition_optimizer.state_dict()
        torch.save(model_dicts,
                    osp.join(self.exppath, 'models', f'{global_step}.pt'))
        if is_best:
            torch.save(model_dicts,
                    osp.join(self.exppath, 'models', f'best.pt'))


    def train(self,):
        """Run the training loop with periodic evaluation and checkpointing.

        Starting from ``self.start_step``, every epoch iterates over all
        indices of ``self.dataset`` in order.  Each item is passed to
        ``train_step`` to obtain a loss, and the loss is then handed to one of
        the ``update_*`` methods, chosen from which learning rates in
        ``options.TRAIN.LR`` are non-zero.  Every ``self.save_interval`` steps
        the model is evaluated and a checkpoint is written.  After the last
        epoch a final checkpoint is saved and ``eval_end2end`` is run.
        """
        # prepare training
        global_step = self.start_step
        # Evaluate once before training (skipped when no epochs are requested).
        if self.options.TRAIN.epochs != 0:
            if self.current_stage > 1:
                self.eval_target(global_step)
                # exit()
            self.eval(global_step)
        view_num = len(self.train_view_names)
        imgW, imgH = self.options.TRAIN.imgW, self.options.TRAIN.imgH
        img_scale = self.options.TRAIN.scale
        # Image size after down-scaling; used for pixel sampling in train_step.
        H = int(imgH // img_scale)
        W = int(imgW // img_scale)

        # self.transition_model.eval()
        self.renderer.eval()

        # NOTE(review): the first epoch is derived from start_step with a
        # hard-coded 49 -- presumably the number of steps per epoch; confirm it
        # matches self.dataset_length.
        for epoch_idx in tqdm(range(int(self.start_step / 49), self.options.TRAIN.epochs), total=self.options.TRAIN.epochs, desc='Epoch:'):
            # Fresh error accumulator for the train-time distance logging.
            self.tmp_fluid_error = FluidErrors(log_emd=False)
            for data_idx in range(self.dataset_length):
                data = self.dataset[data_idx]
                # Keep only the entries used during training and move tensors to the device.
                # NOTE(review): 'particles_pos' is not kept, but both transition-step
                # methods read data['particles_pos'] when self.init_pos is None -- verify.
                keys = ['particles_vel', 'particles_pos_1', 'cw_1', 'rgb_1', 'rays_1', 'focal', 'box', 'box_normals']
                data = {k: data[k].to(self.device) if isinstance(data[k], torch.Tensor) else data[k] for k in keys}

                # data = self.dataset[data_idx]
                # data = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k,v in data.items()}
                # training
                loss = self.train_step(data, data_idx, view_num, H, W, global_step)
                # Pick the update routine from the non-zero learning rates.
                if self.options.TRAIN.LR.latent_lr != 0:
                    if self.options.TRAIN.LR.renderer_lr == 0. and self.options.TRAIN.LR.trans_lr == 0.:
                        self.update_step_latents(loss, global_step)
                        if global_step == 0:
                            print('latent optimizing')
                    elif self.options.TRAIN.LR.renderer_lr != 0. or self.options.TRAIN.LR.trans_lr != 0.:
                        self.update_step_latents_models(loss, global_step)
                        if global_step == 0:
                            print('latent and model optimizing')
                elif self.options.TRAIN.LR.latent_lr == 0 and self.options.TRAIN.LR.encoder_lr != 0:
                    if self.options.TRAIN.LR.renderer_lr == 0. and self.options.TRAIN.LR.trans_lr == 0.:
                        self.update_encoder(loss, global_step) # train encoder only
                        if global_step == 0:
                            print('encoder optimizing')
                    elif self.options.TRAIN.LR.renderer_lr != 0. or self.options.TRAIN.LR.trans_lr != 0.:
                        # train encoder, transition, renderer together
                        self.update_encoder_and_models(loss, global_step)
                        if global_step == 0:
                            print('encoder and model optimizing')
                elif self.options.TRAIN.LR.latent_lr == 0 and self.options.TRAIN.LR.encoder_lr == 0:
                    # Neither latent nor encoder is trained: update the models only.
                    self.update_step_models(loss, global_step)
                    if global_step == 0:
                        print('model optimizing')
                global_step += 1

                # evaluation
                if global_step != 0 and global_step % self.save_interval == 0:
                    self.eval(global_step)
                    self.save_checkpoint(global_step)

            # From stage 2 on, also evaluate on the target data every 10th epoch.
            if self.current_stage > 1 and epoch_idx % 10 == 0:
                self.eval_target(global_step)
        self.save_checkpoint(global_step)
        # self.eval(global_step)
        self.eval_end2end()

    def second_trainsition_step_for_training(self, data, data_idx):
        """Advance the particles by one step using features from the encoder + prior GRU.

        At ``data_idx == 0`` the rollout state (``self.pos_for_next_step``,
        ``self.vel_for_next_step``) and the GRU hidden state are reset.  On
        every call the encoder output is fed through ``self.prior_gru`` to get
        per-particle features, which condition ``self.transition_model``.  The
        predicted state is stored (detached) as the input of the next call.

        Args:
            data: dict providing at least ``'box'``, ``'box_normals'`` and, on
                the first step, ``'particles_vel'`` (plus ``'particles_pos'``
                when ``self.init_pos`` is ``None``).
            data_idx: index of the step within the sequence; 0 resets the state.

        Returns:
            ``(pred_pos, kl_loss, in_mask_proportion)``.  ``kl_loss`` is ``0.``
            when the KL loss weight is 0; ``in_mask_proportion`` is ``1``
            unless ``TRAIN.outside_clip`` is enabled.

        Raises:
            NotImplementedError: if the KL loss is enabled together with a
                grid-based ``TRAIN.get_feat`` mode.
        """
        box = data['box']
        box_normals = data['box_normals']

        if data_idx == 0:
            # Start of a sequence: reset the rollout state.
            self.vel_for_next_step = data['particles_vel']
            # load initial particles extracted by static DVGO
            if self.init_pos is not None:
                self.pos_for_next_step = self.init_pos
            else:
                self.pos_for_next_step = data['particles_pos']
            self.prior_gru.init_hidden(self.options.TRAIN.particle_res)
        self.prior_gru.stop_gradient()
        if self.options['encoder']['input_last_latent']:
            if data_idx == 0:
                # No previous latent yet: feed zeros (twice as wide when the std is included).
                if self.options['encoder'].get('use_std', False):
                    particle_feat = torch.zeros([self.options.TRAIN.particle_res, 2 * self.encoder_dim]).to(box.device)
                else:
                    particle_feat = torch.zeros([self.options.TRAIN.particle_res, self.encoder_dim]).to(box.device)
            else:
                # Feed back the prior statistics produced by the previous call.
                # NOTE(review): if 'use_mean' is false, particle_feat is never
                # assigned on this path and the check below raises
                # UnboundLocalError -- confirm 'use_mean' is always set here.
                if self.options['encoder']['use_mean']:
                    particle_feat = self.prior_feat_mean
                    if self.options['encoder'].get('use_std', False):
                        particle_feat = torch.cat([self.prior_feat_mean, self.prior_feat_std], dim=-1)
        else:
            particle_feat = None
        if particle_feat is not None:
            # The fed-back latent is treated as a constant input (no gradient).
            particle_feat = particle_feat.detach().clone()
            particle_feat.requires_grad = False
        input_prior = [self.pos_for_next_step, self.vel_for_next_step, particle_feat, box, box_normals]
        h = self.encoder(input_prior)
        particle_feat, prior_stat = self.prior_gru(h)
        # Keep the statistics so the next call can feed them back.
        self.prior_feat_mean = prior_stat['mean']
        self.prior_feat_std = prior_stat['std']
        pred_pos, pred_vel, num_fluid_nn = self.transition_model(self.pos_for_next_step, self.vel_for_next_step, box, box_normals, feats=particle_feat)

        in_mask_proportion = 1
        if self.options.TRAIN.get('outside_clip', False):
            # Clamp predictions to [pos_min, pos_max], recompute the velocity from the
            # clamped positions and record the fraction that was strictly inside.
            # NOTE(review): self.pos_min / self.pos_max are not set in the visible
            # part of this class -- verify they are defined before this runs.
            in_mask = (pred_pos > self.pos_min).all(dim=-1) & (pred_pos < self.pos_max).all(dim=-1)
            pred_pos = torch.where(pred_pos > self.pos_max, self.pos_max, pred_pos)
            pred_pos = torch.where(pred_pos < self.pos_min, self.pos_min, pred_pos)
            pred_vel = (pred_pos - self.pos_for_next_step) / self.transition_model.time_step
            in_mask_num = in_mask.sum(dim=-1)
            in_mask_proportion = in_mask_num / pred_pos.shape[0]

        kl_loss = 0.
        if self.options.TRAIN.loss_weight['encoder_KL_loss'] != 0.:
            # KL divergence between the GRU prior and the distribution built from self.latent.
            dist = lambda x: self.prior_gru.get_dist(x)
            kl_loss = torchd.kl.kl_divergence(dist(prior_stat)._dist, self.get_dist(self.latent)._dist)
            if 'grid' in self.options.TRAIN.get_feat:
                raise NotImplementedError('adapt bypass parameters not implemented')
                # NOTE(review): unreachable after the raise above, and `mask`
                # is not defined in this method.
                kl_loss = (mask * kl_loss).mean()  # Add mask for points for particles with no neighbors
            else:
                kl_loss = kl_loss.mean()

        # Store detached copies so gradients do not flow across time steps.
        self.pos_for_next_step, self.vel_for_next_step = pred_pos.clone().detach(),pred_vel.clone().detach()
        self.pos_for_next_step.requires_grad = False
        self.vel_for_next_step.requires_grad = False
        return pred_pos, kl_loss, in_mask_proportion

    def first_trainsition_step_for_training(self, data, data_idx):
        """Advance the particles by one step using features derived via ``self.feat_fn``.

        At ``data_idx == 0`` the rollout state is reset from ``self.init_pos``
        (or from ``data``).  Per-particle features come from ``self.feat_fn``
        and condition ``self.transition_model``.  Optionally the prediction is
        clamped to a bounding box and a KL term between the encoder
        distribution and the latent distribution is computed.  The predicted
        state is stored (detached) as the input of the next call.

        Args:
            data: dict providing at least ``'box'``, ``'box_normals'`` and, on
                the first step, ``'particles_vel'`` (plus ``'particles_pos'``
                when ``self.init_pos`` is ``None``).
            data_idx: index of the step within the sequence; 0 resets the state.

        Returns:
            ``(pred_pos, kl_loss, in_mask_proportion)``.  ``kl_loss`` is ``0.``
            when the KL loss weight is 0; ``in_mask_proportion`` is ``1``
            unless ``TRAIN.outside_clip`` is enabled.
        """
        box = data['box']
        box_normals = data['box_normals']
        if data_idx == 0:
            self.vel_for_next_step = data['particles_vel']
            # load initial particles extracted by static DVGO
            if self.init_pos is not None:
                self.pos_for_next_step = self.init_pos
                if self.adapt:
                    # In adapt mode the dataset velocity is expected to be ~0 and is replaced by exact zeros.
                    assert self.vel_for_next_step.abs().mean() < 1e-5
                    self.vel_for_next_step = torch.zeros_like(self.init_pos)
                if self.options.TRAIN.particle_res != self.vel_for_next_step.shape[0]:
                    # Velocity count does not match particle_res: start from zero velocity.
                    self.vel_for_next_step = torch.zeros_like(self.pos_for_next_step)
            else:
                self.pos_for_next_step = data['particles_pos']
            # Remember the initial positions; used for the 'grid_sample_multi' features below.
            self.pos0 = self.pos_for_next_step.clone().detach()
            # NOTE(review): the next two blocks are identical, so feat_fn is called
            # twice, and self.fluid_feats is not read in the visible part of this
            # method -- confirm whether one (or both) can be dropped.
            if self.options.TRAIN.get_feat == 'grid_sample_multi':
                self.fluid_feats = self.feat_fn(self.pos_for_next_step)
            if self.options.TRAIN.get_feat == 'grid_sample_multi':
                self.fluid_feats = self.feat_fn(self.pos_for_next_step)
        if self.options.TRAIN.get_feat != 'grid_sample_multi':
            # Features are looked up at the current particle positions.
            fluid_feats = self.feat_fn(self.pos_for_next_step)
            pred_pos, pred_vel, num_fluid_nn = self.transition_model(self.pos_for_next_step, self.vel_for_next_step, box, box_normals, feats=fluid_feats)
        else:
            # 'grid_sample_multi': features are looked up at the initial positions.
            fluid_feats = self.feat_fn(self.pos0)
            pred_pos, pred_vel, num_fluid_nn = self.transition_model(self.pos_for_next_step, self.vel_for_next_step, box, box_normals, feats=fluid_feats)

        in_mask_proportion = 1
        if self.options.TRAIN.get('outside_clip', False):
            # Clamp predictions to [pos_min, pos_max], recompute the velocity from the
            # clamped positions and record the fraction that was strictly inside.
            # NOTE(review): self.pos_min / self.pos_max are not set in the visible
            # part of this class -- verify they are defined before this runs.
            in_mask = (pred_pos > self.pos_min).all(dim=-1) & (pred_pos < self.pos_max).all(dim=-1)
            pred_pos = torch.where(pred_pos > self.pos_max, self.pos_max, pred_pos)
            pred_pos = torch.where(pred_pos < self.pos_min, self.pos_min, pred_pos)
            pred_vel = (pred_pos - self.pos_for_next_step) / self.transition_model.time_step
            in_mask_num = in_mask.sum(dim=-1)
            in_mask_proportion = in_mask_num / pred_pos.shape[0]

        # if data_idx == 0:
        # pos_min = torch.cat([self.pos_for_next_step, pred_pos], axis=0).min(axis=0)[0].cuda()
        # pos_max = torch.cat([self.pos_for_next_step, pred_pos], axis=0).max(axis=0)[0].cuda()
        kl_loss = 0.
        if self.options.TRAIN.loss_weight['encoder_KL_loss'] != 0.:
            # Grid extent: particle bounding box for 'grid_sample', fixed bounds for 'grid_sample_multi'.
            if self.options.TRAIN.get_feat == 'grid_sample':
                pos_min = self.pos_for_next_step.min(axis=0).values
                pos_max = self.pos_for_next_step.max(axis=0).values
            elif self.options.TRAIN.get_feat == 'grid_sample_multi':
                pos_min = torch.Tensor([-1, -1, -1]).to(self.device)
                pos_max = torch.Tensor([1, 1, 2.4552]).to(self.device)

            if 'grid' in self.options.TRAIN.get_feat:
                grid_res = self.options.TRAIN.grid_res # cfg['grid_res']
                # Regular grid_res^3 lattice of query points spanning [pos_min, pos_max].
                grid_xyz = torch.stack(
                    torch.meshgrid(
                        torch.linspace(pos_min[0], pos_max[0], grid_res),
                        torch.linspace(pos_min[1], pos_max[1], grid_res),
                        torch.linspace(pos_min[2], pos_max[2], grid_res),
                    ), -1).to(pos_min.device)
                grid_input = grid_xyz.reshape(-1, 3)
                search_radius = self.encoder.filter_extent * 4  # filter_extent * kernel size, kernel size is hard coded as 4
                # Query the nearest particle (K=1) of every grid point in both the
                # current and the predicted particle set.
                all_particles = [self.pos_for_next_step, pred_pos]
                all_particles = torch.stack(all_particles)
                dists, indices, neighbors = ball_query(p1=grid_input.unsqueeze(0).repeat(2, 1, 1),
                                                    p2=all_particles,
                                                    radius=search_radius,
                                                    K=1)
                # True where the returned distance is non-zero in both sets; presumably
                # this marks grid points that have a neighbour within the radius
                # (ball_query padding behaviour -- TODO confirm).
                mask = torch.all(dists != 0, dim=0).squeeze(-1).detach()

            # NOTE(review): grid_xyz (and mask) are only defined for grid modes; with
            # a non-grid get_feat and a non-zero KL weight the next line raises
            # NameError -- verify this combination is never configured.
            input_prior = []
            input_prior.append([self.pos_for_next_step, self.vel_for_next_step,  None, box, box_normals, grid_xyz])
            input_prior.append([pred_pos, pred_vel,  None, box, box_normals, grid_xyz])

            prior_feat, prior_stat = self.encoder(input_prior)
            # KL divergence between the encoder distribution and the one built from self.latent.
            dist = lambda x: self.encoder.get_dist(x)
            kl_loss = torchd.kl.kl_divergence(dist(prior_stat)._dist, self.get_dist(self.latent)._dist)
            if 'grid' in self.options.TRAIN.get_feat:
                kl_loss = (mask * kl_loss).mean()  # Add mask for points for particles with no neighbors
            else:
                kl_loss = kl_loss.mean()

        # Store detached copies so gradients do not flow across time steps.
        self.pos_for_next_step, self.vel_for_next_step = pred_pos.clone().detach(),pred_vel.clone().detach()
        self.pos_for_next_step.requires_grad = False
        self.vel_for_next_step.requires_grad = False
        return pred_pos, kl_loss, in_mask_proportion


    def train_step(self, data, data_idx, view_num, H, W, global_step):
        """Compute the training loss for one step of the sequence.

        The particles are first advanced by one transition step.  The
        predicted particles are then rendered from one randomly chosen
        training view on a random subset of ``ray_chunk`` pixels and compared
        with the ground-truth colours.  Boundary and KL terms are added when
        their loss weights are non-zero.

        Args:
            data: dict for this step (see ``train`` for the keys that are kept).
            data_idx: index of the step within the sequence.
            view_num: number of training views to choose from.
            H, W: image height and width used for pixel sampling.
            global_step: global training step, used for logging.

        Returns:
            The total loss (RGB loss plus the weighted optional terms).
        """
        # -----
        # particle transition
        # -----
        # NOTE(review): if neither use_latent nor use_encoder is set, pred_pos is
        # never assigned and the code below fails -- confirm one is always enabled.
        if self.options.TRAIN.use_latent:
            pred_pos, kl_loss, in_mask_proportion = self.first_trainsition_step_for_training(data, data_idx)
        elif not self.options.TRAIN.use_latent and self.options.TRAIN.use_encoder:
            pred_pos, kl_loss, in_mask_proportion = self.second_trainsition_step_for_training(data, data_idx)
        if global_step % self.log_interval == 0 and global_step != 0:
            # Log the distance between predicted and ground-truth particles in both directions.
            pos_t1 = data['particles_pos_1']
            # dist_pred2gt, dist_emd = self.tmp_fluid_error.cal_errors(pred_pos, pos_t1, data_idx+1)
            dist_gt2pred = self.tmp_fluid_error.cal_errors(pred_pos.detach().cpu().numpy(), pos_t1.detach().cpu().numpy(), data_idx+1)
            dist_pred2gt = self.tmp_fluid_error.cal_errors(pos_t1.detach().cpu().numpy(), pred_pos.detach().cpu().numpy(), data_idx+1)
            # NOTE(review): the tag 'Train/pred2gt_distance' receives dist_gt2pred and
            # 'Train/true_pred2gt_distance' receives dist_pred2gt -- check the naming
            # against the argument order of FluidErrors.cal_errors.
            self.summary_writer.add_scalar(f'Train/pred2gt_distance', dist_gt2pred, global_step)
            self.summary_writer.add_scalar(f'Train/true_pred2gt_distance', dist_pred2gt, global_step)
            if self.options.TRAIN.get('outside_clip', False):
               self.summary_writer.add_scalar(f'Train/in_mask_proportion', in_mask_proportion, global_step)
            # self.summary_writer.add_scalar(f'Train/emd_distance', dist_emd, global_step)

        # -----
        # rendering
        # -----
        ray_chunk = self.options.RENDERER.ray.ray_chunk
        N_importance = self.options.RENDERER.ray.N_importance
        total_loss = 0.
        # for view_idx in range(view_num):
        # Only one randomly chosen view is rendered per step.
        view_idx = random.choice(list(range(view_num)))
        # -------
        # render by a nerf model, and then calculate mse loss
        # -------
        view_name = self.train_view_names[view_idx]
        cw_t1 = data['cw_1'][view_idx]
        rgbs_t1 = data['rgb_1'][view_idx]
        focal_length = data['focal'][view_idx]
        rays_t1 = data['rays_1'][view_idx]
        # randomly sample pixel
        coords = self.random_sample_coords(H,W,global_step)
        coords = torch.reshape(coords, [-1,2])
        # Draw ray_chunk distinct pixel coordinates (no replacement).
        select_inds = np.random.choice(coords.shape[0], size=[ray_chunk], replace=False)
        select_coords = coords[select_inds].long()
        rays_t1 = rays_t1[select_coords[:, 0], select_coords[:, 1]]
        rgbs_t1 = rgbs_t1.view(H, W, -1)[select_coords[:, 0], select_coords[:, 1]]
        ro_t1 = self.renderer.set_ro(cw_t1)
        render_ret = self.render_image(pred_pos, ray_chunk, ro_t1, rays_t1, focal_length, cw_t1)
        # calculate mse loss
        rgbloss_0 = self.rgb_criterion(render_ret['pred_rgbs_0'], rgbs_t1[:ray_chunk])
        if N_importance>0:
            # With importance sampling, the second prediction ('pred_rgbs_1') is supervised as well.
            rgbloss_1 = self.rgb_criterion(render_ret['pred_rgbs_1'], rgbs_t1[:ray_chunk])
            rgbloss = rgbloss_0 + rgbloss_1
        else:
            rgbloss = rgbloss_0
        total_loss = total_loss+rgbloss

        # log
        if global_step % self.log_interval == 0 and global_step != 0:
            self.summary_writer.add_scalar(f'{view_name}/rgbloss_0', rgbloss_0.item(), global_step)
            self.summary_writer.add_scalar(f'{view_name}/rgbloss', rgbloss.item(), global_step)
            self.summary_writer.add_histogram(f'{view_name}/num_neighbors_0', render_ret['num_nn_0'], global_step)
            if N_importance>0:
                self.summary_writer.add_scalar(f'{view_name}/rgbloss_1', rgbloss_1.item(), global_step)
                self.summary_writer.add_histogram(f'{view_name}/num_neighbors_1', render_ret['num_nn_1'], global_step)

        # NOTE(review): the two optional terms below are logged when
        # (global_step + 1) is a multiple of log_interval, while everything above
        # logs when global_step itself is -- confirm the offset is intended.
        if self.options.TRAIN.loss_weight['boundary_loss'] != 0.:
            bd_loss = self.cal_boundary_loss(pred_pos)
            total_loss = total_loss + bd_loss * self.options.TRAIN.loss_weight['boundary_loss']
            if (global_step+1) % self.log_interval == 0:
                self.summary_writer.add_scalar(f'boudary_loss', bd_loss.item(), global_step)
        if self.options.TRAIN.loss_weight['encoder_KL_loss'] != 0.:
            total_loss = total_loss + kl_loss * self.options.TRAIN.loss_weight['encoder_KL_loss']
            if (global_step+1) % self.log_interval == 0:
                self.summary_writer.add_scalar(f'encoder_kl_loss', kl_loss.item(), global_step)
        return total_loss


    def update_step_latents_models(self, loss, global_step):
        grad_clip_value = self.options.TRAIN.grad_clip_value
        seperate_render_transition = self.options.TRAIN.seperate_render_transition

        self.latent_optimizer.zero_grad()
        self.optimizer.zero_grad()
        if seperate_render_transition:
            self.transition_optimizer.zero_grad()

        loss.backward()
        if grad_clip_value != 0:
            torch.nn.utils.clip_grad_norm_(self.latent, grad_clip_value)
            self.summary_writer.add_histogram('latent_grad', self.latent.grad.norm(), global_step)
            torch.nn.utils.clip_grad_norm_(self.renderer.parameters(), grad_clip_value)
            torch.nn.utils.clip_grad_norm_(self.transition_model.parameters(), grad_clip_value)

        self.latent_optimizer.step()
        if self.options.TRAIN.LR.use_scheduler_latent:
            self.optim_lr_scheduler_latent.step()
        self.optimizer.step()
        if seperate_render_transition:
            self.transition_optimizer.step()
        if self.options.TRAIN.LR.use_scheduler:
            self.optim_lr_scheduler.step()
            if seperate_render_transition:
                self.optim_lr_scheduler_transition.step()

        if global_step != 0 and global_step % self.log_interval == 0:
            lrs = self.get_learning_rate(self.latent_optimizer)
            for i,lr in enumerate(lrs):
                self.summary_writer.add_scalar(f'learning_rate/lr_latent_{i}', lr, global_step)
            lrs = self.get_learning_rate(self.optimizer)
            for i,lr in enumerate(lrs):
                self.summary_writer.add_scalar(f'learning_rate/lr_{i}', lr, global_step)
            if seperate_render_transition:
                lrs = self.get_learning_rate(self.transition_optimizer)
                for i,lr in enumerate(lrs):
                    self.summary_writer.add_scalar(f'learning_rate/lr_transition_{i}', lr, global_step)

    def update_step_latents(self,loss, global_step):
        grad_clip_value = self.options.TRAIN.grad_clip_value
        seperate_render_transition = self.options.TRAIN.seperate_render_transition

        self.latent_optimizer.zero_grad()
        if self.adapt:
            self.optimizer.zero_grad()
            self.transition_optimizer.zero_grad()
        loss.backward()
        if grad_clip_value != 0:
            torch.nn.utils.clip_grad_norm_(self.latent, grad_clip_value)
            self.summary_writer.add_histogram('latent_grad', self.latent.grad.norm(), global_step)
        self.latent_optimizer.step()
        if self.options.TRAIN.LR.use_scheduler_latent:
            self.optim_lr_scheduler_latent.step()
        if self.adapt:
            self.optimizer.step()
            if seperate_render_transition:
                self.transition_optimizer.step()
            if self.options.TRAIN.LR.use_scheduler:
                self.optim_lr_scheduler.step()
                if seperate_render_transition:
                    self.optim_lr_scheduler_transition.step()

        if global_step != 0 and global_step % self.log_interval == 0:
            lrs = self.get_learning_rate(self.latent_optimizer)
            for i,lr in enumerate(lrs):
                self.summary_writer.add_scalar(f'learning_rate/lr_latent_{i}', lr, global_step)
            if self.adapt:
                lrs = self.get_learning_rate(self.optimizer)
                for i,lr in enumerate(lrs):
                    self.summary_writer.add_scalar(f'learning_rate/lr_{i}', lr, global_step)
                if seperate_render_transition:
                    lrs = self.get_learning_rate(self.transition_optimizer)
                    for i,lr in enumerate(lrs):
                        self.summary_writer.add_scalar(f'learning_rate/lr_transition_{i}', lr, global_step)


    def update_encoder(self,loss, global_step):
        grad_clip_value = self.options.TRAIN.grad_clip_value
        seperate_render_transition = self.options.TRAIN.seperate_render_transition

        # encoder_grad = self.cal_grad_norm(self.encoder)
        # self.summary_writer.add_histogram('encoder_grad/encoder_grad_before', encoder_grad, global_step)

        self.encoder_optimizer.zero_grad()
        if self.adapt:
            self.optimizer.zero_grad()
            self.transition_optimizer.zero_grad()
        loss.backward()

        if grad_clip_value != 0:
            torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), grad_clip_value)
            encoder_grad = self.cal_grad_norm(self.encoder)
            self.summary_writer.add_histogram('encoder_grad', encoder_grad, global_step)
        self.encoder_optimizer.step()
        if self.options.TRAIN.LR.use_scheduler_encoder:
            self.optim_lr_scheduler_encoder.step()
        if self.adapt:
            self.optimizer.step()
            if seperate_render_transition:
                self.transition_optimizer.step()
            if self.options.TRAIN.LR.use_scheduler:
                self.optim_lr_scheduler.step()
                if seperate_render_transition:
                    self.optim_lr_scheduler_transition.step()

        if global_step != 0 and global_step % self.log_interval == 0:
            lrs = self.get_learning_rate(self.encoder_optimizer)
            for i,lr in enumerate(lrs):
                self.summary_writer.add_scalar(f'learning_rate/lr_encoder_{i}', lr, global_step)
            if self.adapt:
                lrs = self.get_learning_rate(self.optimizer)
                for i,lr in enumerate(lrs):
                    self.summary_writer.add_scalar(f'learning_rate/lr_{i}', lr, global_step)
                if seperate_render_transition:
                    lrs = self.get_learning_rate(self.transition_optimizer)
                    for i,lr in enumerate(lrs):
                        self.summary_writer.add_scalar(f'learning_rate/lr_transition_{i}', lr, global_step)


    def update_encoder_and_models(self, loss, global_step):
        grad_clip_value = self.options.TRAIN.grad_clip_value
        seperate_render_transition = self.options.TRAIN.seperate_render_transition

        self.optimizer.zero_grad()
        self.encoder_optimizer.zero_grad()
        if seperate_render_transition:
            self.transition_optimizer.zero_grad()
        loss.backward()
        if grad_clip_value != 0:
            torch.nn.utils.clip_grad_norm_(self.renderer.parameters(), grad_clip_value)
            torch.nn.utils.clip_grad_norm_(self.transition_model.parameters(), grad_clip_value)
        self.optimizer.step()
        self.encoder_optimizer.step()

        if seperate_render_transition:
            self.transition_optimizer.step()
        if self.options.TRAIN.LR.use_scheduler:
            self.optim_lr_scheduler.step()
            if seperate_render_transition:
                self.optim_lr_scheduler_transition.step()

        if self.options.TRAIN.LR.use_scheduler_encoder:
            self.optim_lr_scheduler_encoder.step()

        if global_step != 0 and global_step % self.log_interval == 0:
            lrs = self.get_learning_rate(self.optimizer)
            for i,lr in enumerate(lrs):
                self.summary_writer.add_scalar(f'learning_rate/lr_{i}', lr, global_step)
            if seperate_render_transition:
                lrs = self.get_learning_rate(self.transition_optimizer)
                for i,lr in enumerate(lrs):
                    self.summary_writer.add_scalar(f'learning_rate/lr_transition_{i}', lr, global_step)

            lrs = self.get_learning_rate(self.encoder_optimizer)
            for i,lr in enumerate(lrs):
                self.summary_writer.add_scalar(f'learning_rate/lr_encoder_{i}', lr, global_step)


    def eval(self, step_idx):
        """
        Roll the transition model out over the test sequence and log its errors.

        The rollout is autoregressive: the state predicted for one frame is the
        input of the next one. For every frame the distances between predicted
        and ground-truth particles are computed (``FluidErrors`` in both
        directions plus the Chamfer distance scaled by 1000) and written to the
        summary writer. On selected rounds the particles are exported as
        ``.obj`` files and frames 20 and 30 are rendered for every test view.
        The models are switched to eval mode for the rollout and back to train
        mode afterwards.

        Args:
            step_idx: training step of this evaluation; used as logging step
                and in output file / checkpoint names.
        """
        self.eval_count += 1
        self.transition_model.eval()
        self.renderer.eval()
        if self.options.TRAIN.LR.latent_lr == 0 and self.options.TRAIN.LR.encoder_lr != 0:
            self.encoder.eval()
            self.prior_gru.eval()
        view_num = len(self.test_viewnames)
        N_importance = self.options.RENDERER.ray.N_importance
        with torch.no_grad():
            dist_pred2gt_all = []
            dist_chamfer_all = []
            dist_true_pred2gt_all = []
            fluid_error = FluidErrors(log_emd=False)
            new_fluid_error = FluidErrors(log_emd=False)
            for data_idx in tqdm(range(self.test_dataset_length), desc='Step {} Eval:'.format(step_idx)):
                data = self.test_dataset[data_idx]
                keys = ['box', 'box_normals', 'particles_pos', 'particles_vel', 'particles_pos_1', 'cw_1', 'rgb_1', 'rays_1', 'focal']
                data = {k: data[k].to(self.device) if isinstance(data[k], torch.Tensor) else data[k] for k in keys}

                box = data['box']
                box_normals = data['box_normals']
                if data_idx == 0:
                    if not self.options.TRAIN.use_latent and self.options.TRAIN.use_encoder:
                        self.prior_gru.init_hidden(self.options.TRAIN.particle_res)
                    pos_for_next_step, vel_for_next_step = data['particles_pos'], data['particles_vel']
                    # Error of the initial particle estimate w.r.t. the first
                    # ground-truth frame. Without an initial estimate the
                    # rollout starts from the ground truth itself, so that is
                    # what gets compared; this avoids dereferencing a None
                    # `self.init_pos` before the check further down.
                    start_pos = self.init_pos if self.init_pos is not None else pos_for_next_step
                    # NOTE(review): this entry is recorded with frame id
                    # data_idx+1, the same id the first prediction uses below
                    # - confirm that FluidErrors is meant to receive both.
                    dist_pred2gt = fluid_error.cal_errors(start_pos.detach().cpu().numpy(), pos_for_next_step.detach().cpu().numpy(), data_idx+1)
                    chamfer = 1000 * chamfer_distance(start_pos.unsqueeze(0), pos_for_next_step.unsqueeze(0))[0].item()
                    dist_pred2gt_all.append(dist_pred2gt)
                    dist_chamfer_all.append(chamfer)
                    dist_true_pred2gt = new_fluid_error.cal_errors(pos_for_next_step.detach().cpu().numpy(), start_pos.detach().cpu().numpy(), data_idx+1)
                    dist_true_pred2gt_all.append(dist_true_pred2gt)
                    # load initial particles extracted by static DVGO
                    if self.init_pos is not None:
                        pos_for_next_step = self.init_pos
                        if self.adapt:
                            assert vel_for_next_step.abs().mean() < 1e-5
                            vel_for_next_step = torch.zeros_like(pos_for_next_step)
                        # here vel equals 0
                        if self.options.TRAIN.particle_res != vel_for_next_step.shape[0]:
                            vel_for_next_step = torch.zeros_like(pos_for_next_step)
                    if self.options.TRAIN.use_latent:
                        particle_feat = self.feat_fn(pos_for_next_step)
                if not self.options.TRAIN.use_latent and self.options.TRAIN.use_encoder:
                    if self.options['encoder']['input_last_latent']:
                        if data_idx == 0:
                            # No previous latent exists for the first frame: feed zeros.
                            if self.options['encoder'].get('use_std', False):
                                particle_feat = torch.zeros([self.options.TRAIN.particle_res, 2 * self.encoder_dim]).to(box.device)
                            else:
                                particle_feat = torch.zeros([self.options.TRAIN.particle_res, self.encoder_dim]).to(box.device)
                        else:
                            # Otherwise particle_feat keeps the value produced
                            # by prior_gru on the previous frame.
                            if self.options['encoder']['use_mean']:
                                particle_feat = self.prior_feat_mean
                                if self.options['encoder'].get('use_std', False):
                                    particle_feat = torch.cat([self.prior_feat_mean, self.prior_feat_std], dim=-1)
                    else:
                        particle_feat = None
                    input_prior = [pos_for_next_step, vel_for_next_step, particle_feat, box, box_normals]
                    h = self.encoder(input_prior)
                    particle_feat, prior_stat = self.prior_gru(h)
                    self.prior_feat_mean = prior_stat['mean']
                    self.prior_feat_std = prior_stat['std']

                pred_pos, pred_vel, _ = self.transition_model(pos_for_next_step, vel_for_next_step, box, box_normals, feats=particle_feat)
                if self.options.TRAIN.get('outside_clip', False):
                    # Clamp predictions to [pos_min, pos_max] and recompute the
                    # velocity from the clamped displacement.
                    pred_pos = torch.where(pred_pos > self.pos_max, self.pos_max, pred_pos)
                    pred_pos = torch.where(pred_pos < self.pos_min, self.pos_min, pred_pos)
                    pred_vel = (pred_pos - pos_for_next_step) / self.transition_model.time_step

                pos_for_next_step, vel_for_next_step = pred_pos.clone(), pred_vel.clone()

                # evaluate transition model
                pos_t1 = data['particles_pos_1']
                # eval pred2gt distance
                dist_pred2gt = fluid_error.cal_errors(pred_pos.detach().cpu().numpy(), pos_t1.detach().cpu().numpy(), data_idx+1)
                dist_pred2gt_all.append(dist_pred2gt)
                chamfer = 1000 * chamfer_distance(pred_pos.unsqueeze(0), pos_t1.unsqueeze(0))[0].item()
                dist_chamfer_all.append(chamfer)
                dist_true_pred2gt = new_fluid_error.cal_errors(pos_t1.detach().cpu().numpy(), pred_pos.detach().cpu().numpy(), data_idx+1)
                dist_true_pred2gt_all.append(dist_true_pred2gt)
                self.summary_writer.add_scalar(f'pred2gt_distance', dist_pred2gt, self.eval_count*self.test_dataset_length+data_idx+1)
                self.summary_writer.add_scalar(f'true_pred2gt_distance', dist_true_pred2gt, self.eval_count*self.test_dataset_length+data_idx+1)
                self.summary_writer.add_scalar(f'chamfer_distance', chamfer, self.eval_count*self.test_dataset_length+data_idx+1)
                # save to obj
                if (step_idx / self.save_interval) % 5 == 0:
                    os.makedirs(osp.join(self.particlepath, f'{step_idx}'), exist_ok=True)
                    particle_name = osp.join(self.particlepath, f'{step_idx}/pred_{data_idx+1}.obj')
                    with open(particle_name, 'w') as fp:
                        record2obj(pred_pos, fp, color=[255, 0, 0]) # red
                    particle_name = osp.join(self.particlepath, f'{step_idx}/gt_{data_idx+1}.obj')
                    with open(particle_name, 'w') as fp:
                        record2obj(pos_t1, fp, color=[3, 168, 158])

                # rendering results
                # to save time, we only render several frames
                if (step_idx / self.save_interval) % 20 == 0:
                    if data_idx in [20,30]:
                        for view_idx in range(view_num):
                            view_name = self.test_viewnames[view_idx]
                            cw = data['cw_1'][view_idx]
                            ro = self.renderer.set_ro(cw)
                            focal_length = data['focal'][view_idx]
                            rgbs = data['rgb_1'][view_idx]
                            rays = data['rays_1'][view_idx].view(-1, 6)
                            render_ret = self.render_image(pred_pos, rays.shape[0], ro, rays, focal_length, cw, iseval=True)
                            pred_rgbs_0 = render_ret['pred_rgbs_0']
                            mask_0 = render_ret['mask_0']
                            psnr_0 = mse2psnr(img2mse(pred_rgbs_0, rgbs.detach().cpu()))
                            self.summary_writer.add_scalar(f'{view_name}/psnr_{data_idx}_0', psnr_0.item(), step_idx)
                            self.visualization(pred_rgbs_0, rgbs, step_idx, mask=mask_0, prefix=f'coarse_{data_idx}_{view_name}')
                            if N_importance>0:
                                pred_rgbs_1 = render_ret['pred_rgbs_1']
                                mask_1 = render_ret['mask_1']
                                psnr_1 = mse2psnr(img2mse(pred_rgbs_1, rgbs.detach().cpu()))
                                self.summary_writer.add_scalar(f'{view_name}/psnr_{data_idx}_1', psnr_1.item(), step_idx)
                                self.visualization(pred_rgbs_1, rgbs, step_idx, mask=mask_1, prefix=f'fine_{data_idx}_{view_name}')
            fluid_error.save(osp.join(self.particlepath, f'res_{step_idx}.json'))
            path = osp.join(self.exppath, f'avg_pred2gt.json')
            mean_pred2gt = np.mean(dist_pred2gt_all)
            mean_chamfer = np.mean(dist_chamfer_all)
            mean_true_pred2gt = np.mean(dist_true_pred2gt_all)
            self.update_json(path, step_idx, mean_pred2gt, mean_chamfer)
            self.summary_writer.add_scalar('avg_pred2gt_distance/avg_pred2gt_distance', mean_pred2gt, step_idx)
            self.summary_writer.add_scalar('avg_pred2gt_distance/avg_pred2gt_distance_0-49', np.mean(dist_pred2gt_all[:49]), step_idx)
            self.summary_writer.add_scalar('avg_pred2gt_distance/avg_pred2gt_distance_49', np.mean(dist_pred2gt_all[-1]), step_idx)
            self.summary_writer.add_scalar('avg_pred2gt_distance/avg_true_pred2gt_distance', mean_true_pred2gt, step_idx)
            self.summary_writer.add_scalar('avg_pred2gt_distance/avg_true_pred2gt_distance_49', dist_true_pred2gt_all[-1], step_idx)
            self.summary_writer.add_scalar('avg_pred2gt_distance/avfg_chamfer', mean_chamfer, step_idx)

            # Checkpoints are only selected on these metrics in stage 1.
            if self.current_stage == 1:
                if mean_pred2gt< self.best_gt2pred:
                    self.best_gt2pred = mean_pred2gt
                    self.save_checkpoint(step_idx, is_best=True)

                if mean_true_pred2gt < self.best_true_pred2gt:
                    self.best_true_pred2gt = mean_true_pred2gt
                    self.save_checkpoint(step_idx-1)

        self.transition_model.train()
        self.renderer.train()
        if self.options.TRAIN.LR.latent_lr == 0 and self.options.TRAIN.LR.encoder_lr != 0:
            self.encoder.train()
            self.prior_gru.train()

    def eval_target(self, step_idx):
        """
        Roll the transition model out over the target sequence and log its errors.

        Counterpart of the test-set evaluation that reads ``self.target_dataset``
        and starts from ``self.target_init_pos``. The rollout is autoregressive;
        for every frame the distances between predicted and ground-truth
        particles are computed (``FluidErrors`` in both directions plus the
        Chamfer distance scaled by 1000) and written to the summary writer. On
        selected rounds the predictions are exported as ``.obj`` and ``.npz``
        files. The models are switched to eval mode for the rollout and back to
        train mode afterwards.

        Args:
            step_idx: training step of this evaluation; used as logging step
                and in output file / checkpoint names.
        """
        self.eval_count += 1
        self.transition_model.eval()
        self.renderer.eval()
        if self.options.TRAIN.LR.latent_lr == 0 and self.options.TRAIN.LR.encoder_lr != 0:
            self.encoder.eval()
            self.prior_gru.eval()
        with torch.no_grad():
            dist_pred2gt_all = []
            dist_chamfer_all = []
            dist_true_pred2gt_all = []
            fluid_error = FluidErrors(log_emd=False)
            new_fluid_error = FluidErrors(log_emd=False)
            # NOTE(review): the loop length comes from the test dataset while
            # the frames are read from the target dataset - confirm both have
            # the same length.
            for data_idx in tqdm(range(self.test_dataset_length), desc='Step {} Eval:'.format(step_idx)):
                data = self.target_dataset[data_idx]
                keys = ['box', 'box_normals', 'particles_pos', 'particles_vel', 'particles_pos_1', 'cw_1', 'rgb_1', 'rays_1', 'focal']
                data = {k: data[k].to(self.device) if isinstance(data[k], torch.Tensor) else data[k] for k in keys}

                box = data['box']
                box_normals = data['box_normals']
                if data_idx == 0:
                    pos_for_next_step, vel_for_next_step = data['particles_pos'], data['particles_vel']
                    # Without an initial estimate the rollout starts from the
                    # ground truth itself; picking the start here avoids
                    # dereferencing a None `self.target_init_pos` before the
                    # check further down.
                    start_pos = self.target_init_pos if self.target_init_pos is not None else pos_for_next_step
                    num_particles = start_pos.shape[0]
                    if not self.options.TRAIN.use_latent and self.options.TRAIN.use_encoder:
                        self.prior_gru.init_hidden(num_particles)
                    # Error of the starting particles w.r.t. the first ground-truth frame.
                    dist_pred2gt = fluid_error.cal_errors(start_pos.detach().cpu().numpy(), pos_for_next_step.detach().cpu().numpy(), data_idx+1)
                    chamfer = 1000 * chamfer_distance(start_pos.unsqueeze(0), pos_for_next_step.unsqueeze(0))[0].item()
                    dist_pred2gt_all.append(dist_pred2gt)
                    dist_chamfer_all.append(chamfer)
                    dist_true_pred2gt = new_fluid_error.cal_errors(pos_for_next_step.detach().cpu().numpy(), start_pos.detach().cpu().numpy(), data_idx+1)
                    dist_true_pred2gt_all.append(dist_true_pred2gt)
                    # load initial particles extracted by static DVGO
                    if self.target_init_pos is not None:
                        pos_for_next_step = self.target_init_pos
                        if self.adapt:
                            assert vel_for_next_step.abs().mean() < 1e-5
                            vel_for_next_step = torch.zeros_like(pos_for_next_step)
                        # here vel equals 0
                        if num_particles != vel_for_next_step.shape[0]:
                            vel_for_next_step = torch.zeros_like(pos_for_next_step)
                    if self.options.TRAIN.use_latent:
                        particle_feat = self.feat_fn(pos_for_next_step)
                if not self.options.TRAIN.use_latent and self.options.TRAIN.use_encoder:
                    if self.options['encoder']['input_last_latent']:
                        if data_idx == 0:
                            # No previous latent exists for the first frame: feed zeros.
                            if self.options['encoder'].get('use_std', False):
                                particle_feat = torch.zeros([num_particles, 2 * self.encoder_dim]).to(box.device)
                            else:
                                particle_feat = torch.zeros([num_particles, self.encoder_dim]).to(box.device)
                        else:
                            # Otherwise particle_feat keeps the value produced
                            # by prior_gru on the previous frame.
                            if self.options['encoder']['use_mean']:
                                particle_feat = prior_feat_mean
                                if self.options['encoder'].get('use_std', False):
                                    particle_feat = torch.cat([prior_feat_mean, prior_feat_std], dim=-1)
                    else:
                        particle_feat = None
                    input_prior = [pos_for_next_step, vel_for_next_step, particle_feat, box, box_normals]
                    h = self.encoder(input_prior)
                    particle_feat, prior_stat = self.prior_gru(h)
                    prior_feat_mean = prior_stat['mean']
                    prior_feat_std = prior_stat['std']

                pred_pos, pred_vel, _ = self.transition_model(pos_for_next_step, vel_for_next_step, box, box_normals, feats=particle_feat)
                if self.options.TRAIN.get('outside_clip', False):
                    # Clamp predictions to [pos_min, pos_max] and recompute the
                    # velocity from the clamped displacement.
                    pred_pos = torch.where(pred_pos > self.pos_max, self.pos_max, pred_pos)
                    pred_pos = torch.where(pred_pos < self.pos_min, self.pos_min, pred_pos)
                    pred_vel = (pred_pos - pos_for_next_step) / self.transition_model.time_step

                pos_for_next_step, vel_for_next_step = pred_pos.clone(), pred_vel.clone()

                # evaluate transition model
                pos_t1 = data['particles_pos_1']
                # eval pred2gt distance
                dist_pred2gt = fluid_error.cal_errors(pred_pos.detach().cpu().numpy(), pos_t1.detach().cpu().numpy(), data_idx+1)
                dist_pred2gt_all.append(dist_pred2gt)
                chamfer = 1000 * chamfer_distance(pred_pos.unsqueeze(0), pos_t1.unsqueeze(0))[0].item()
                dist_chamfer_all.append(chamfer)
                dist_true_pred2gt = new_fluid_error.cal_errors(pos_t1.detach().cpu().numpy(), pred_pos.detach().cpu().numpy(), data_idx+1)
                dist_true_pred2gt_all.append(dist_true_pred2gt)
                self.summary_writer.add_scalar(f'target_pred2gt_distance', dist_pred2gt, self.eval_count*self.test_dataset_length+data_idx+1)
                self.summary_writer.add_scalar(f'target_true_pred2gt_distance', dist_true_pred2gt, self.eval_count*self.test_dataset_length+data_idx+1)
                self.summary_writer.add_scalar(f'target_chamfer_distance', chamfer, self.eval_count*self.test_dataset_length+data_idx+1)
                # save to obj
                if (step_idx / self.save_interval) % 5 == 0:
                    os.makedirs(osp.join(self.target_particlepath, f'{step_idx}'), exist_ok=True)
                    particle_name = osp.join(self.target_particlepath, f'{step_idx}/pred_{data_idx+1}.obj')
                    with open(particle_name, 'w') as fp:
                        record2obj(pred_pos, fp, color=[255, 0, 0]) # red
                    particle_name = osp.join(self.target_particlepath, f'{step_idx}/gt_{data_idx+1}.obj')
                    with open(particle_name, 'w') as fp:
                        record2obj(pos_t1, fp, color=[3, 168, 158])

                    # Raw predicted state of this frame, next to the .obj files.
                    np.savez(
                        os.path.join(self.target_particlepath, f'{step_idx}', 'fluid_%04d.npz' % (data_idx+1)),
                        pos=pred_pos.detach().cpu().numpy(),
                        vel=pred_vel.detach().cpu().numpy()
                    )

            os.makedirs(self.target_particlepath, exist_ok=True)
            fluid_error.save(osp.join(self.target_particlepath, f'res_{step_idx}.json'))
            path = osp.join(self.exppath, f'avg_target_pred2gt.json')
            mean_pred2gt = np.mean(dist_pred2gt_all)
            mean_chamfer = np.mean(dist_chamfer_all)
            mean_true_pred2gt = np.mean(dist_true_pred2gt_all)
            self.update_json(path, step_idx, mean_pred2gt, mean_chamfer)
            self.summary_writer.add_scalar('avg_target_pred2gt_distance/avg_pred2gt_distance', mean_pred2gt, step_idx)
            self.summary_writer.add_scalar('avg_target_pred2gt_distance/avg_pred2gt_distance_0-49', np.mean(dist_pred2gt_all[:49]), step_idx)
            self.summary_writer.add_scalar('avg_target_pred2gt_distance/avg_pred2gt_distance_49', np.mean(dist_pred2gt_all[-1]), step_idx)
            self.summary_writer.add_scalar('avg_target_pred2gt_distance/avg_true_pred2gt_distance', mean_true_pred2gt, step_idx)
            self.summary_writer.add_scalar('avg_target_pred2gt_distance/avg_true_pred2gt_distance_49', dist_true_pred2gt_all[-1], step_idx)
            self.summary_writer.add_scalar('avg_target_pred2gt_distance/avfg_chamfer', mean_chamfer, step_idx)

            if mean_pred2gt< self.best_gt2pred:
                self.best_gt2pred = mean_pred2gt
                self.save_checkpoint(step_idx, is_best=True)

            # NOTE(review): the Chamfer mean (scaled by 1000) is tracked in
            # `best_true_pred2gt` here, not `mean_true_pred2gt` - confirm this
            # is the intended selection metric for the target sequence.
            if mean_chamfer < self.best_true_pred2gt:
                self.best_true_pred2gt = mean_chamfer
                self.save_checkpoint(step_idx-1)

        self.transition_model.train()
        self.renderer.train()
        if self.options.TRAIN.LR.latent_lr == 0 and self.options.TRAIN.LR.encoder_lr != 0:
            self.encoder.train()
            self.prior_gru.train()


    def get_feat(self, pos):
        """Draw one sample per particle from the distribution built on ``self.latent``.

        Args:
            pos: particle positions; only ``pos.shape[0]`` (the number of
                particles) is used.

        Returns:
            The result of sampling with sample shape ``[pos.shape[0]]``.
        """
        sample_shape = [pos.shape[0]]
        latent_dist = self.get_dist(self.latent)
        return latent_dist.sample(sample_shape)

    def get_feat_multi(self, pos):
        """Draw a single sample from the latent distribution and split it per particle.

        Args:
            pos: particle positions; only ``pos.shape[0]`` (the number of
                particles) is used.

        Returns:
            The sample reshaped to ``(pos.shape[0], -1)``.
        """
        particle_count = pos.shape[0]
        drawn = self.get_dist(self.latent).sample()
        return drawn.reshape(particle_count, -1)

    def get_feat_from_grid(self, pos, grid_res=10):
        """Sample a latent grid and interpolate it at the particle positions.

        With ``TRAIN.get_feat == 'grid_sample'`` the distribution is sampled
        once per grid cell; otherwise a single sample is drawn and reshaped
        into the grid. The grid is then interpolated at ``pos`` with
        ``utils.interpolation``.

        Args:
            pos: particle positions passed to the interpolation.
            grid_res: number of grid cells along each of the three axes.

        Returns:
            The per-particle latent returned by ``utils.interpolation``.
        """
        per_cell_sampling = self.options.TRAIN.get_feat == 'grid_sample'
        latent_dist = self.get_dist(self.latent)
        if per_cell_sampling:
            grid = latent_dist.sample([grid_res, grid_res, grid_res])
        else:
            grid = latent_dist.sample().reshape([grid_res, grid_res, grid_res, -1])
        # Add a leading axis, then move the last axis to position 1.
        grid = grid.unsqueeze(0).permute(0, 4, 1, 2, 3)

        if self.latent.dim() < 3:
            # Bounds taken from the particles themselves.
            lower = pos.min(axis=0).values
            upper = pos.max(axis=0).values
        else:
            # Fixed bounds.
            lower = torch.Tensor([-1, -1, -1]).to(self.device)
            upper = torch.Tensor([1, 1, 2.4552]).to(self.device)
        return utils.interpolation(grid=grid, xyz=pos, xyz_min=lower, xyz_max=upper)

    def get_feat_deter(self, pos):
        num_particles = pos.shape[0]
        latent = self.latent.repeat(num_particles, 1)
        return latent

    def get_dist(self, latent, dtype=None):
        """Build the distribution parameterised by ``latent``.

        In discrete mode (``self._discrete``) ``latent`` is used as the logits
        of a one-hot distribution. Otherwise it is split in two halves along
        the last axis: the first half becomes the mean (after the activation
        named by ``options['encoder']['mean_act']``) and the second half the
        standard deviation (after ``2 * sigmoid(x / 2)``) of a normal
        distribution.

        Args:
            latent: tensor holding the distribution parameters.
            dtype: unused.

        Returns:
            The distribution object, with the last axis treated as event axis.

        Raises:
            KeyError: if ``mean_act`` is not one of 'none', 'tanh5', 'tanh',
                'tanh1' (continuous mode only).
        """
        if self._discrete:
            return torchd.independent.Independent(utils.OneHotDist(latent), 1)

        raw_mean, raw_std = latent.chunk(2, dim=-1)
        mean_activations = {
            'none': lambda m: m,
            'tanh5': lambda m: 5.0 * torch.tanh(m / 5.0),
            'tanh': lambda m: 5.0 * torch.tanh(m),
            'tanh1': lambda m: 1.0 * torch.tanh(m),
        }
        activate = mean_activations[self.options['encoder']['mean_act']]
        mean = activate(raw_mean)
        std = 2 * torch.sigmoid(raw_std / 2)
        normal = torchd.normal.Normal(mean, std)
        return utils.ContDist(torchd.independent.Independent(normal, 1))

    def update_json(self, path, step_idx, pred2gt_dist, chamfer_dist):
        """Record the metrics of one evaluation in the JSON history at ``path``.

        The file maps the evaluation step (as a string, since JSON object keys
        are strings) to its ``pred2gt_dist`` and ``chamfer_dist``. An existing
        entry for the same step is replaced. A missing, unreadable or malformed
        history is replaced by a new one containing only this entry.

        Args:
            path: location of the JSON history file.
            step_idx: evaluation step the metrics belong to.
            pred2gt_dist: mean prediction-to-ground-truth distance.
            chamfer_dist: mean Chamfer distance.
        """
        new_res = {step_idx: {'pred2gt_dist': pred2gt_dist,
                              'chamfer_dist': chamfer_dist}}
        print(new_res)
        try:
            with open(path) as f:
                content = json.load(f)
        except (OSError, ValueError):
            # Best effort: no usable history (missing file, bad JSON), start over.
            content = {}
        if not isinstance(content, dict):
            content = {}
        # Loaded keys are strings; using the string form of the step avoids
        # writing the same step twice (once as int, once as str).
        content[str(step_idx)] = new_res[step_idx]
        # Serialise before opening the file for writing so that a failure
        # cannot leave a truncated history behind.
        payload = json.dumps(content, indent=4)
        with open(path, "w") as f:
            f.write(payload)

    def eval_end2end(self):
        """Roll the transition model out over the test sequence and report particle errors.

        The method rebuilds ``self.test_dataset`` for the single view
        ``self.options.TEST.test_view`` (after overriding ``end_index`` to 60),
        switches the models to eval mode and, under ``torch.no_grad()``,
        iterates over every frame:

        * On the first frame the state is initialised from the dataset
          particles (positions replaced by ``self.init_pos`` when it is set).
        * Depending on the ``TRAIN.use_latent`` / ``TRAIN.use_encoder`` options
          the per-particle feature comes from ``self.feat_fn`` (computed once)
          or from ``self.encoder`` followed by ``self.prior_gru`` (every frame).
        * ``self.transition_model`` predicts the next positions/velocities,
          which are fed back as the input of the following frame.
        * The prediction is compared with ``data['particles_pos_1']`` through
          ``FluidErrors.cal_errors`` in both argument orders, and both point
          sets are written as ``.obj`` files.

        Files written below ``<self.exppath>/eval_end2end/``:
        ``particles/pred_<i>.obj``, ``particles/gt_<i>.obj``, ``pred2gt.pt``
        (joblib dump) and ``mean.json`` (summary plus per-frame values).

        The summary indexes element 48 of the per-frame lists, so fewer than
        49 evaluated frames raise ``IndexError`` before anything is saved.

        Side effects: mutates ``self.options``, ``self.test_viewnames``,
        ``self.test_dataset``, ``self.test_dataset_length``,
        ``self.fluid_error``, ``self.prior_feat_mean`` and
        ``self.prior_feat_std``; the models are put back into train mode at
        the end (not on an exception, there is no ``try``/``finally``).

        Returns:
            None.
        """
        # NOTE(review): `eval_utils` is not referenced anywhere in the code shown
        # here, and `imageio` is only used by the nested `visualization` helper.
        from utils import eval_utils
        import imageio
        import joblib
        # Unlock the options object so `end_index` can be overridden below.
        # NOTE(review): presumably a yacs-style config; it is not frozen again
        # in this method -- confirm that is intended.
        self.options.defrost()
        self.options.end_index = 60
        self.test_viewnames = [self.options.TEST.test_view]
        # NOTE(review): the path is read from `options.train` (lower case) while
        # every other access here uses `options.TRAIN`, and the evaluation
        # dataset is built with split='train' -- confirm both are intended.
        self.test_dataset = BlenderDataset(self.options.train.path, self.options,
                                            imgW=self.options.TEST.imgW, imgH=self.options.TEST.imgH,
                                            imgscale=self.options.TEST.scale, viewnames=self.test_viewnames, split='train')
        print('eval data length:', len(self.test_dataset))
        print('dataset:', self.options.dataset)
        self.test_dataset_length = len(self.test_dataset)
        os.makedirs(os.path.join(self.exppath, 'eval_end2end'), exist_ok=True)
        # NOTE(review): neither nested helper below is called in the code shown
        # for this method.
        def visualization(pred_rgbs, gt_rgbs,prefix=None, data_idx=0):
            # Reshape the predicted colours to an image, convert to uint8 and
            # save it as eval_end2end/images/<data_idx>.png.
            # `gt_rgbs` and `prefix` are only referenced by the commented-out
            # code below.
            pred_image = vis_rgbs(pred_rgbs)

            if not os.path.exists(osp.join(self.exppath, 'eval_end2end', 'images')):
                os.makedirs(osp.join(self.exppath, 'eval_end2end', 'images'))

            # save res
            # gt_rgb8 = to8b(gt_image)
            # filename = '{}/{}/GT/{:05d}.png'.format(self.imgpath, prefix, data_idx)
            # imageio.imwrite(filename, gt_rgb8)

            pred_rgb8 = to8b(pred_image)
            filename = '{:05d}.png'.format(data_idx)
            imageio.imwrite(f'{self.exppath}/eval_end2end/images/{filename}', pred_rgb8)
            return pred_image, pred_rgb8

        def vis_rgbs(rgbs, channel=3):
            # Reshape a flat colour tensor to (imgH, imgW, channel) at the
            # down-scaled test resolution and return it as a numpy array.
            imgW = int(self.options.TEST.imgW // self.options.TEST.scale)
            imgH = int(self.options.TEST.imgH // self.options.TEST.scale)
            image = rgbs.reshape(imgH, imgW, channel).detach().cpu().numpy()
            return image

        # Evaluation mode for the whole rollout; restored at the end of the method.
        self.transition_model.eval()
        self.renderer.eval()
        # The encoder and prior GRU are only toggled when the latent learning
        # rate is zero and the encoder learning rate is not.
        if self.options.TRAIN.LR.latent_lr == 0 and self.options.TRAIN.LR.encoder_lr != 0:
            self.encoder.eval()
            self.prior_gru.eval()
        # NOTE(review): `view_num` and the lists `dist_chamfer_all`, `rgbs`,
        # `psnrs`, `ssims`, `lpips_vgg` and `rgbs_8` are never read or appended
        # to in the code shown here.
        view_num = len(self.test_viewnames)
        with torch.no_grad():
            dist_pred2gt_all = []
            dist_true_pred2gt_all = []
            dist_chamfer_all = []
            rgbs = []
            psnrs = []
            ssims = []
            lpips_vgg = []
            rgbs_8 = []
            # Two separate FluidErrors instances: one is called with
            # (prediction, target), the other with (target, prediction).
            self.fluid_error = FluidErrors()
            fluid_error = FluidErrors()
            for data_idx in tqdm(range(self.test_dataset_length), total=self.test_dataset_length):
                data = self.test_dataset[data_idx]
                # Keep only the entries listed in `keys`, moving tensors to self.device.
                keys = ['box', 'box_normals', 'particles_pos', 'particles_vel', 'particles_pos_1', 'cw_1', 'rgb_1', 'rays_1', 'focal']
                data = {k: data[k].to(self.device) if isinstance(data[k], torch.Tensor) else data[k] for k in keys}
                box = data['box']
                box_normals = data['box_normals']
                if data_idx ==0:
                    # First frame: initialise the rollout state.
                    if not self.options.TRAIN.use_latent and self.options.TRAIN.use_encoder:
                        self.prior_gru.init_hidden(self.options.TRAIN.particle_res)
                    pos_for_next_step, vel_for_next_step = data['particles_pos'],data['particles_vel']
                    # `self.init_pos`, when set, replaces the dataset's initial positions.
                    if self.init_pos is not None:
                        pos_for_next_step = self.init_pos
                    # With `use_latent` the feature is computed once from the initial
                    # positions and reused unchanged for every later frame.
                    if self.options.TRAIN.use_latent:
                        particle_feat = self.feat_fn(pos_for_next_step)

                    # When the number of loaded velocities differs from
                    # `particle_res`, use zeros shaped like the positions instead.
                    if self.options.TRAIN.particle_res != vel_for_next_step.shape[0]:
                        vel_for_next_step = torch.zeros_like(pos_for_next_step)

                # NOTE(review): `particle_feat` is only assigned by the `use_latent`
                # branch above or by the encoder branch below; with both options
                # off, the transition-model call would raise NameError.
                if not self.options.TRAIN.use_latent and self.options.TRAIN.use_encoder:
                    if self.options['encoder']['input_last_latent']:
                        if data_idx == 0:
                            # No previous feature on the first frame: feed zeros
                            # (twice as wide when the std is passed as well).
                            if self.options['encoder'].get('use_std', False):
                                particle_feat = torch.zeros([self.options.TRAIN.particle_res, 2 * self.encoder_dim]).to(box.device)
                            else:
                                particle_feat = torch.zeros([self.options.TRAIN.particle_res, self.encoder_dim]).to(box.device)
                        else:
                            # With `use_mean` the previous prior mean (optionally
                            # concatenated with its std) is fed back; otherwise
                            # `particle_feat` keeps the value returned by
                            # `prior_gru` on the previous frame.
                            if self.options['encoder']['use_mean']:
                                particle_feat = self.prior_feat_mean
                                if self.options['encoder'].get('use_std', False):
                                    particle_feat = torch.cat([self.prior_feat_mean, self.prior_feat_std], dim=-1)
                    else:
                        particle_feat = None
                    input_prior = [pos_for_next_step, vel_for_next_step, particle_feat, box, box_normals]
                    h = self.encoder(input_prior)
                    particle_feat, prior_stat = self.prior_gru(h)
                    # Stored on self so the next frame can feed them back.
                    self.prior_feat_mean = prior_stat['mean']
                    self.prior_feat_std = prior_stat['std']

                # One simulation step from the current (predicted) state.
                # `num_fluid_nn` is not used afterwards.
                pred_pos, pred_vel, num_fluid_nn = self.transition_model(pos_for_next_step, vel_for_next_step, box, box_normals, feats=particle_feat)
                # NOTE(review): `in_mask_proportion` is computed but never read below.
                in_mask_proportion = 1
                if self.options.TRAIN.get('outside_clip', False):
                    # Clamp predictions to [self.pos_min, self.pos_max] and recompute
                    # the velocity as clamped displacement over the model's time step.
                    in_mask = (pred_pos > self.pos_min).all(dim=-1) & (pred_pos < self.pos_max).all(dim=-1)
                    pred_pos = torch.where(pred_pos > self.pos_max, self.pos_max, pred_pos)
                    pred_pos = torch.where(pred_pos < self.pos_min, self.pos_min, pred_pos)
                    pred_vel = (pred_pos - pos_for_next_step) / self.transition_model.time_step
                    in_mask_num = in_mask.sum(dim=-1)
                    in_mask_proportion = in_mask_num / pred_pos.shape[0]
                # Feed the prediction back as the input of the next frame.
                pos_for_next_step, vel_for_next_step = pred_pos.clone(), pred_vel.clone()

                # --------
                # evaluate transition model
                # --------
                pos_t1 = data['particles_pos_1']
                # eval pred2gt distance
                dist_pred2gt = self.fluid_error.cal_errors(pred_pos.cpu().numpy(), pos_t1.cpu().numpy(), data_idx+1)
                dist_pred2gt_all.append(dist_pred2gt)
                # Same call with the two point sets swapped.
                dist_true_pred2gt = fluid_error.cal_errors(pos_t1.detach().cpu().numpy(), pred_pos.detach().cpu().numpy(), data_idx+1)
                dist_true_pred2gt_all.append(dist_true_pred2gt)
                # save to obj
                if not osp.exists(os.path.join(self.exppath, 'eval_end2end', 'particles')):
                    os.makedirs(os.path.join(self.exppath, 'eval_end2end', 'particles'))
                particle_name = osp.join(self.exppath, 'eval_end2end', 'particles', f'pred_{data_idx+1}.obj')
                with open(particle_name, 'w') as fp:
                    record2obj(pred_pos, fp, color=[255, 0, 0]) # red
                particle_name = osp.join(self.exppath, 'eval_end2end', 'particles', f'gt_{data_idx+1}.obj')
                with open(particle_name, 'w') as fp:
                    record2obj(pos_t1, fp, color=[3, 168, 158])
            # Summary over the first 49 entries (indices 0..48); index 48 must
            # exist or the lines below raise IndexError.
            print('----------------- trained 50 steps ------------------------')
            print('Pred2GT:', np.mean(dist_pred2gt_all[0:49]))
            print('Pred2GT-10:', np.mean(dist_pred2gt_all[:10]))
            print('Pred2GT-end:', dist_pred2gt_all[48])
            print('true Pred2GT:', np.mean(dist_true_pred2gt_all[0:49]))
            print('true Pred2GT-10:', np.mean(dist_true_pred2gt_all[:10]))
            print('true Pred2GT-end:', dist_true_pred2gt_all[48])
            # print('avg_psnrs:', np.mean(psnrs[0:49]))
            # print('avg_ssims',  np.mean(ssims[0:49]))
            # print('avg_lpips (vgg)', np.mean(lpips_vgg[0:49]))

            # Summary over the last 10 entries.
            print('\n----------------- rollout 10 steps ------------------------')
            print('Pred2GT:', np.mean(dist_pred2gt_all[-10:]))
            # NOTE(review): `[-5]` selects a single entry (fifth from the end), not
            # a slice -- confirm whether `[-5:]` was intended.
            print('Pred2GT-5:', np.mean(dist_pred2gt_all[-5]))
            print('Pred2GT-end:', dist_pred2gt_all[-1])
            # print('rollout-avg_psnrs', np.mean(psnrs[-10:]))
            # print('rollout-avg_ssims', np.mean(ssims[-10:]))
            # print('rollout-avg_lpips (vgg)', np.mean(lpips_vgg[-10:]))


            # Raw per-frame values, serialised with joblib despite the .pt suffix.
            joblib.dump({'dist': dist_pred2gt_all}, osp.join(self.exppath, 'eval_end2end', 'pred2gt.pt'))

            # Same statistics as printed above, plus the full per-frame lists.
            # NOTE(review): assumes the values returned by `cal_errors` are JSON
            # serialisable -- confirm against FluidErrors.
            with open(os.path.join(self.exppath, 'eval_end2end', 'mean.json'), 'w') as f:
                info = {}
                info['Pred2GT'] = np.mean(dist_pred2gt_all[0:49])
                info['Pred2GT-10'] = np.mean(dist_pred2gt_all[:10])
                info['Pred2GT-end'] = dist_pred2gt_all[48]
                info['true Pred2GT'] = np.mean(dist_true_pred2gt_all[0:49])
                info['true Pred2GT-10'] = np.mean(dist_true_pred2gt_all[:10])
                info['true Pred2GT-end'] = dist_true_pred2gt_all[48]

                info['rollout-Pred2GT'] = np.mean(dist_pred2gt_all[-10:])
                info['rollout-Pred2GT-5'] = np.mean(dist_pred2gt_all[-5])
                info['rollout-Pred2GT-end'] = dist_pred2gt_all[-1]

                info['Pred2GT_all'] = dist_pred2gt_all
                info['true Pred2GT_all'] = dist_true_pred2gt_all
                json.dump(info, f, indent=4)
        # Back to training mode, mirroring the eval() calls above.
        self.transition_model.train()
        self.renderer.train()

        if self.options.TRAIN.LR.latent_lr == 0 and self.options.TRAIN.LR.encoder_lr != 0:
            self.encoder.train()
            self.prior_gru.train()