import os
import pdb
import sys
import torch
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from data.MultiframeDataset import DHBDataset, NLDriveDataset
from model.NeuralPCI import NeuralPCI
from pytorch3d.loss import chamfer_distance
from utils.smooth_loss import compute_smooth
from utils.utils import *
sys.path.append('./utils/EMD')
sys.path.append('./utils/CD')
from emd import earth_mover_distance
import chamfer3D.dist_chamfer_3D
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
# import mayavi.mlab


class Runner:
    def __init__(self, args):
        self.args = args
        # Init Dataset
        if args.dataset == 'DHB':
            self.dataset = DHBDataset(data_root = args.dataset_path, 
                                      scene_list = args.scenes_list, 
                                      interval = args.interval) 
        elif args.dataset == 'NL_Drive':
            self.dataset = NLDriveDataset(data_root = args.dataset_path, 
                                          scene_list = args.scenes_list, 
                                          interval = args.interval,
                                          num_points = args.num_points, 
                                          num_frames = args.num_frames)

        # Init DataLoader
        self.data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False, drop_last=False, num_workers=8)

        self.chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()

        self.model = NeuralPCI(args).cuda()
        print("Number of parameters in model is {:.3f}M".format(sum(tensor.numel() for tensor in self.model.parameters())/1e6))

        # Loss recorder
        self.recorder = self.get_recorder(args)

        # Define time stamp for input and interpolation frames 
        self.time_seq, self.time_intp = self.get_timestamp(args)
        # self.emd_weight = torch.nn.Parameter(torch.tensor(0.0))

    def weighted_chamfer_distance(self, pc_pred, pc_input, NL_flag, alpha=1.6, dist_threshold=5.0, z_min=-1.4, z_max=1.99, z_weight=1.4):
        dist1, dist2, _, _ = self.chamLoss(pc_pred, pc_input)
        dist1 = torch.clamp(dist1, max=dist_threshold)
        dist2 = torch.clamp(dist2, max=dist_threshold)
        if NL_flag:
            z1 = pc_pred[:, :, 2:]
            z2 = pc_input[:, :, 2:]
            z_mask1 = (z1 >= z_min) & (z1 <= z_max)
            z_mask2 = (z2 >= z_min) & (z2 <= z_max)

            z_dist1 = torch.where(z_mask1.squeeze(-1), z_weight * dist1, dist1)
            z_dist2 = torch.where(z_mask2.squeeze(-1), z_weight * dist2, dist2)
            # pdb.set_trace()
        else:
            z_dist1 = dist1
            z_dist2 = dist2

        weight1 = torch.tanh(alpha * z_dist1)  # (B, N)
        weight2 = torch.tanh(alpha * z_dist2)  # (B, M)

        return (weight1 * z_dist1 + weight2 * z_dist2).sum() #* 0.5
    
    def save_tensors_to_dict(self, input_list, gt_list, file_path):
        input_list = [tensor.detach().cpu().numpy() for tensor in input_list]
        gt_list = [tensor.detach().cpu().numpy() for tensor in gt_list]
        max_length = max(len(input_list), len(gt_list))
        input_list.extend([np.array([])] * (max_length - len(input_list)))
        gt_list.extend([np.array([])] * (max_length - len(gt_list)))
        zipped = zip(input_list, gt_list)

        data_dict = {}
        for i, (inp, gt) in enumerate(zipped):
            data_dict[f'input_{i}'] = inp
            data_dict[f'gt_{i}'] = gt
        np.save(file_path, data_dict)
    
    def load_tensors_from_dict(self, file_path):
        data_dict = np.load(file_path, allow_pickle=True).item()
        input_list = []
        gt_list = []
        for key, value in data_dict.items():
            if key.startswith('input'):
                input_list.append(torch.from_numpy(value))
            elif key.startswith('gt'):
                gt_list.append(torch.from_numpy(value))

        return input_list, gt_list
        

    def get_recorder(self, args):
        recorder = {}
        recorder['loss_all_CD'] = []
        recorder['loss_all_EMD'] = []

        for i in range(args.interval - 1):
            recorder['loss_frame{}_CD'.format(i+1)] = []
            recorder['loss_frame{}_EMD'.format(i+1)] = []

        return recorder
        

    def get_timestamp(self, args):
        time_seq = [t for t in np.linspace(args.t_begin, args.t_end, args.num_frames)]
        t_left = time_seq[args.num_frames//2 - 1]
        t_right = time_seq[args.num_frames//2]
        time_intp = [t for t in np.linspace(t_left, t_right, args.interval+1)]
        time_intp = time_intp[1:-1]
        
        return time_seq, time_intp


    def loop(self):
        """Run optimisation and evaluation for the first data-loader sample only.

        In its current form this is a single-sample debug run. The first batch
        is fetched and moved to the GPU, but its contents are then replaced by
        the frames stored in a hard-coded ``.npy`` file. The loop always stops
        after that first iteration, then the accumulated results are printed.
        """
        args = self.args
        n_samples = len(self.data_loader)
        print("Optimization Start for {} samples!".format(n_samples))
        # get one sample from DataLoader (for-loop)
        progress = tqdm(enumerate(self.data_loader), total=n_samples, file=sys.stdout)
        for idx, (frames, targets) in progress:
            print("\n[Sample: {}]".format(idx+1))
            # input point cloud frames (x,y,z); each entry has its leading
            # batch dimension removed and is moved to the GPU as float32
            for k, pc in enumerate(frames):
                frames[k] = pc.squeeze(0).cuda().contiguous().float()
            for k, pc in enumerate(targets):
                targets[k] = pc.squeeze(0).cuda().contiguous().float()

            # Definition: a fresh model is optimised for every sample
            model = NeuralPCI(args).cuda()
            if idx == 0:
                # Replace the loaded sample by the stored reference sample.
                stored_inputs, stored_gts = self.load_tensors_from_dict("./NeuroGauss4Dlog/visualizations/compare_pc/NL_PC_25.npy")
                frames = [t.cuda().contiguous().float() for t in stored_inputs]
                targets = [t.cuda().contiguous().float() for t in stored_gts]

                result = self.optimize_NeuralPCI(frames, model, args, targets, idx=idx)
                best_model_weight, best_cd, best_emd, intp_cd_best, intp_emd_best = result
                model.load_state_dict(best_model_weight)
                self.recorder_error(best_cd, best_emd, intp_cd_best, intp_emd_best)

                intp_cd_error, intp_emd_error = self.eval_NeuralPCI(frames, targets, model, best_model_weight, args)
                print("intp_cd_error, intp_emd_error:",intp_cd_error, intp_emd_error)
            # Only the first sample is processed.
            break
        # print overall average result
        self.print_result()


    def demo(self):
        """Optimise a model on one stored sample and save its interpolated frames.

        Loads input and ground-truth frames from a hard-coded ``.npy`` file,
        optimises a fresh ``NeuralPCI`` model on them, runs inference with the
        returned weights and writes the predicted frames to a hard-coded
        ``.npy`` path.

        This no longer stops at a debugger breakpoint before optimisation.
        """
        args = self.args
        input_list, gt_list = self.load_tensors_from_dict("./NeuroGauss4Dlog/visualizations/compare_pc/NL_PC_42.npy")

        model = NeuralPCI(args).cuda()
        # input = [pc1, pc2, pc3, pc4], a list of point clouds (presumably each
        # of shape (N, 3) — confirm against the stored file)
        # input_path = args.input_path
        # input = np.load(input_path).tolist()
        input_list = [tensor.cuda().contiguous().float() for tensor in input_list]
        gt = [tensor.cuda().contiguous().float() for tensor in gt_list]

        # Recompute the time stamps from the current args before optimising.
        self.time_seq, self.time_intp = self.get_timestamp(args)
        best_model_weight, best_cd, best_emd, intp_cd_best, intp_emd_best = self.optimize_NeuralPCI(input_list, model, args, gt, idx=0)

        pc_intp = self.infer_NeuralPCI(input_list, model, best_model_weight)
        pc_intp_save = [tensor.cpu().numpy() for tensor in pc_intp]
        np.save("./NeuroGauss4Dlog/visualizations/compare_pc/NL_42_ours_pred.npy", pc_intp_save)


    def get_optimizer(self, model, args):
        if args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
            # optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': [self.emd_weight], 'lr': 0.01}], lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'adamw':
            optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'sgdm':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay, nesterov=True)
        elif args.optimizer == 'radam':
            optimizer = torch.optim.RAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'nadam':
            optimizer = torch.optim.NAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'rmsprop':
            optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, alpha=0.99, momentum=0.9, weight_decay=args.weight_decay)
        return optimizer


    def get_scheduler(self, optimizer, args):
        # from warmup_scheduler import GradualWarmupScheduler
        if args.scheduler == 'step':
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.9)
        elif args.scheduler == 'cosine':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.iters, eta_min=1e-6)
        # elif args.scheduler == 'cosine_warm':
        #     scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=args.iters-1500, T_mult=1, eta_min=1e-6)
        #     scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=500, after_scheduler=scheduler)
        elif args.scheduler == 'poly':
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: (1-step/args.iters)**0.9 if step>1500 else (step/1500)*1.5)
        elif args.scheduler == 'plateau':
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, min_lr=1e-6)
        else:
            raise ValueError(f'Unknown scheduler: {args.scheduler}')
        return scheduler


    def compute_loss(self, pc_current, pc_pred, pc_input, iter):  
        args = self.args
        loss = 0

        if iter<500 or (iter>2500 and iter<3500):
            dist1, dist2, _, _ = self.chamLoss(pc_pred.unsqueeze(0), pc_input.unsqueeze(0))
            chamfer_loss = (dist1 + dist2).sum() * 0.5
        else:
            chamfer_loss = self.weighted_chamfer_distance(pc_pred.unsqueeze(0), pc_input.unsqueeze(0), args.dataset == 'NL_Drive')
        loss = loss + chamfer_loss * args.factor_cd

        if args.dataset == 'DHB':
            dist = earth_mover_distance(pc_pred.unsqueeze(0), pc_input.unsqueeze(0), transpose=False)
            emd_loss = (dist / pc_pred.shape[0]).mean()
            loss = loss + emd_loss * args.factor_emd

        elif args.dataset == 'NL_Drive':
            flow = pc_pred - pc_current
            smooth_loss = compute_smooth(pc_current.unsqueeze(0), flow.unsqueeze(0), k=9)
            smooth_loss = smooth_loss.squeeze(0).sum()
            loss = loss + smooth_loss * args.factor_smooth

            # dist = earth_mover_distance(pc_pred.unsqueeze(0), pc_input.unsqueeze(0), transpose=False)
            # emd_loss = (dist / pc_pred.shape[0]).mean()
            # emd_weight = 10 * torch.sigmoid(self.emd_weight)
            # loss = loss + smooth_loss * args.factor_smooth + emd_weight * emd_loss

        return loss

    def optimize_NeuralPCI(self, input, model, args, gt, idx=None):
        model.train()
        
        optimizer = self.get_optimizer(model, args)
        
        if args.scheduler:
            scheduler = self.get_scheduler(optimizer, args)

        # record optimization best result
        # best_loss = np.inf
        best_cd = np.inf
        best_emd = np.inf
        best_iter = 0
        intp_emd_best=0
        intp_cd_best =0

        best_model_weight = model.state_dict()

        # early stopping parameters
        patience = args.patience
        counter = 0

        # # optimize xx iterations for one sample
        for iter in range(args.iters):
            optimizer.zero_grad()
            loss = 0
            # input current point cloud to predict point cloud at time_pred
            for i in range(len(input)//2 - 1, len(input)//2 + 1):
                pc_current = input[i]
                time_current = self.time_seq[i]
                for j in range(len(input)):
                    time_pred = self.time_seq[j]
                    flow = model(pc_current, time_current, time_pred)
                    pc_pred = pc_current + flow
                    # pc_pred = model(pc_current, time_current, time_pred)
                    loss = loss + self.compute_loss(pc_current, pc_pred, input[j], iter)
                    
            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            if args.scheduler:
                if args.scheduler == 'plateau':
                    scheduler.step(loss.item())
                else:
                    scheduler.step()

            if (iter >500) and (iter % args.eval_freq) == 0:
                cd_error, emd_error, intp_cd_error, intp_emd_error = self.eval_Train(input, gt, model)
                # if cd_error<0.35:
                #     if cd_error<0.30:
                #         self.save_tensors_to_dict(input, gt, "./log/visualizations/compare_pc/NL_PC_{}".format(idx))
                #         # input_list, gt_list = self.load_tensors_from_dict("./NeuroGauss4Dlog/visualizations/compare_pc/NL_PC_0.npy")
                # else:
                #     intp_emd_best=0
                #     intp_cd_best =0
                #     break
                model.train()
                if (best_cd > cd_error.item()) or (best_emd > emd_error.item()):
                    best_model_weight = model.state_dict()
                    best_iter = iter
                    if best_cd > cd_error.item():
                        best_cd = cd_error.item()
                        intp_cd_best = intp_cd_error
                    if best_emd > emd_error.item():
                        best_emd = emd_error.item()
                        intp_emd_best = intp_emd_error
                    counter = 0
                    if (iter > 500) and (iter % 50 == 0):
                        current_lr = optimizer.param_groups[0]['lr']
                        print(f"[Iter: {iter}] [Loss: {loss.item():.5f}] [LR: {current_lr:.5f}] [Best Iter: {best_iter}] [Best CD: {best_cd:.5f}] [Best EMD: {best_emd:.5f}]")
                else:
                    counter += 1
                    if counter >= patience:
                        print(f"Early stopping at iteration {iter}")
                        current_lr = optimizer.param_groups[0]['lr']
                        print(f"[Iter: {iter}] [Loss: {loss.item():.5f}] [LR: {current_lr:.5f}] [Best Iter: {best_iter}] [Best CD: {best_cd:.5f}] [Best EMD: {best_emd:.5f}]")
                        break
        return best_model_weight, best_cd, best_emd, intp_cd_best, intp_emd_best


    def eval_NeuralPCI(self, input, gt, model, weight, args):
        """Interpolate frames with the given weights, dump them to disk and score them.

        For every time in ``self.time_intp`` a frame is predicted from both
        middle input frames. The first half of the interpolation times keeps
        the prediction from the earlier reference, the rest keeps the one from
        the later reference. Inputs, predictions and ground truth are saved as
        ``.npy`` files in a hard-coded directory, which is created if missing.

        This no longer stops at debugger breakpoints inside the loops.

        Args:
            input: List of input point clouds; the first four are saved.
            gt: Ground-truth frames, indexed like the interpolated frames.
            model: Module called as ``model(pc, t_cur, t_pred, train=False)``.
            weight: State dict loaded into ``model`` before inference.
            args: Configuration object; unused here.

        Returns:
            Tuple ``(intp_cd_error, intp_emd_error)`` of per-frame error lists.
        """
        save_dir = "./NeuroGauss4Dlog/visualizations/NL_Drive/NL25_Neaurl"
        model.eval()
        with torch.no_grad():
            model.load_state_dict(weight)
            # np.save does not create missing directories.
            os.makedirs(save_dir, exist_ok=True)
            for k in range(4):
                np.save("{}/input_{}".format(save_dir, k), input[k].cpu().numpy())

            mid = len(input) // 2
            n_intp = len(self.time_intp)
            pc_intp = []
            # Predict the interpolation frame at time_intp
            for j in range(n_intp):
                pc_pred = []
                time_pred = self.time_intp[j]

                # Use the two middle input frames as references
                for i in range(mid - 1, mid + 1):
                    pc_current = input[i]
                    time_current = self.time_seq[i]
                    flow = model(pc_current, time_current, time_pred, train=False)
                    pc_pred.append(pc_current + flow)

                # NN-Intp: keep the prediction from the nearer reference frame
                if j < 0.5 * n_intp:
                    pc_intp.append(pc_pred[0].cuda())
                else:
                    pc_intp.append(pc_pred[1].cuda())

            intp_cd_error = []
            intp_emd_error = []

            # The ground-truth file name carries the index of the last
            # interpolation step as its first number. This matches the names
            # written so far — TODO confirm whether that number is intended.
            last_j = n_intp - 1
            for i in range(len(pc_intp)):
                np.save("{}/pred_ours_{}".format(save_dir, i), pc_intp[i].cpu().numpy())
                np.save("{}/gt_{}_{}".format(save_dir, last_j, i), gt[i].cpu().numpy())
                # self.save_pointcloud_as_ply(pc_intp[i], (220,0,0), "./NeuroGauss4Dlog/visualizations/NL_Drive/NL_f_pred_{}.ply".format(i))
                # self.save_pointcloud_as_ply(gt[i], (0,220,0), "./NeuroGauss4Dlog/visualizations/NL_Drive/NL_f_gt_{}.ply".format(i))

                cd_error = chamfer_distance(pc_intp[i].unsqueeze(0), gt[i].unsqueeze(0))[0]

                dist = earth_mover_distance(pc_intp[i].unsqueeze(0), gt[i].unsqueeze(0), transpose=False)
                emd_error = (dist / pc_intp[i].shape[0]).mean()

                intp_cd_error.append(cd_error)
                intp_emd_error.append(emd_error)

            print("Eval: Average Interpolation CD Error: {}".format(torch.mean(torch.tensor(intp_cd_error))))
            print("Eval: Average Interpolation EMD Error: {}".format(torch.mean(torch.tensor(intp_emd_error))))

            # self.recorder['loss_all_CD'].append(torch.mean(torch.tensor(intp_cd_error)).item())
            # self.recorder['loss_all_EMD'].append(torch.mean(torch.tensor(intp_emd_error)).item())
            # for i in range(args.interval - 1):
            #     self.recorder['loss_frame{}_CD'.format(i+1)].append(intp_cd_error[i].item())
            #     self.recorder['loss_frame{}_EMD'.format(i+1)].append(intp_emd_error[i].item())

            return intp_cd_error, intp_emd_error
        
    def recorder_error(self, best_cd, best_emd, intp_cd_best, intp_emd_best):
        print("Eval: Average Interpolation CD Error: {}".format(best_cd))
        print("Eval: Average Interpolation EMD Error: {}".format(best_emd))
        
        self.recorder['loss_all_CD'].append(torch.mean(torch.tensor(intp_cd_best)).item())
        self.recorder['loss_all_EMD'].append(torch.mean(torch.tensor(intp_emd_best)).item())
        for i in range(4 - 1):
            self.recorder['loss_frame{}_CD'.format(i+1)].append(intp_cd_best[i].item())
            self.recorder['loss_frame{}_EMD'.format(i+1)].append(intp_emd_best[i].item())
        
    def save_pointcloud_as_ply(self, tensor, rgb, save_path):
        points = tensor.cpu().numpy()
        ply_header = """ply
            format ascii 1.0
            element vertex {}
            property float x
            property float y
            property float z
            property uchar red
            property uchar green
            property uchar blue
            end_header
            """.format(points.shape[0])

        points_with_rgb = np.concatenate((points, np.tile(np.array(rgb, dtype=np.uint8), (points.shape[0], 1))), axis=1)

        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, 'w') as ply_file:
            ply_file.write(ply_header)
            np.savetxt(ply_file, points_with_rgb, fmt="%f %f %f %d %d %d")
        
        print(f"Point cloud saved as {save_path}")

    def eval_Train(self, input, gt, model):
        """Score the current model on the interpolated frames during optimisation.

        The model is put in eval mode and left there. For every time in
        ``self.time_intp`` a frame is predicted from both middle input frames.
        The first half of the interpolation times keeps the prediction from
        the earlier reference, the rest keeps the one from the later reference.

        Args:
            input: List of input point clouds.
            gt: Ground-truth frames, indexed like the interpolated frames.
            model: Module called as ``model(pc, t_cur, t_pred, train=False)``.

        Returns:
            Tuple ``(mean_cd, mean_emd, intp_cd_error, intp_emd_error)``: the
            mean errors as tensors, followed by the per-frame error lists.
        """
        model.eval()
        with torch.no_grad():
            mid = len(input) // 2
            n_intp = len(self.time_intp)

            # Predict the interpolation frame at each time in time_intp
            pc_intp = []
            for j, time_pred in enumerate(self.time_intp):
                # One candidate per middle reference frame (earlier, later)
                candidates = []
                for i in range(mid - 1, mid + 1):
                    reference = input[i]
                    candidates.append(reference + model(reference, self.time_seq[i], time_pred, train=False))

                # NN-Intp: first half from the earlier frame, rest from the later
                chosen = candidates[0] if j < 0.5 * n_intp else candidates[1]
                pc_intp.append(chosen.cuda())

            intp_cd_error = []
            intp_emd_error = []
            for k, predicted in enumerate(pc_intp):
                # self.save_pointcloud_as_ply(pc_intp[i], (220,0,0), "./NeuroGauss4Dlog/visualizations/NL_Drive/NL_f_pred_{}.ply".format(i))
                # self.save_pointcloud_as_ply(gt[i], (0,220,0), "./NeuroGauss4Dlog/visualizations/NL_Drive/NL_f_gt_{}.ply".format(i))
                pred_batched = predicted.unsqueeze(0)
                target_batched = gt[k].unsqueeze(0)

                intp_cd_error.append(chamfer_distance(pred_batched, target_batched)[0])

                emd = earth_mover_distance(pred_batched, target_batched, transpose=False)
                intp_emd_error.append((emd / predicted.shape[0]).mean())

            mean_cd = torch.mean(torch.tensor(intp_cd_error))
            mean_emd = torch.mean(torch.tensor(intp_emd_error))
            return mean_cd, mean_emd, intp_cd_error, intp_emd_error
        
    
    def print_result(self):
        recorder = self.recorder
        args = self.args
        print("\n=======================================")
        print("Final Result CD Loss is:{}".format(np.mean(recorder['loss_all_CD'])))
        print("Final Result EMD Loss is:{}".format(np.mean(recorder['loss_all_EMD'])))
        for i in range(args.interval - 1):
            print("=======================================")
            print("Final Frame-{} Result CD Loss is:{}".format(i+1, np.mean(recorder['loss_frame{}_CD'.format(i+1)])))
            print("Final Frame-{} Result EMD Loss is:{}".format(i+1, np.mean(recorder['loss_frame{}_EMD'.format(i+1)])))
        print("=======================================")


    def infer_NeuralPCI(self, input, model, weight):
        """Predict the interpolated frames using the given model weights.

        For every time in ``self.time_intp`` a frame is predicted from both
        middle input frames. The first half of the interpolation times keeps
        the prediction from the earlier reference, the rest keeps the one from
        the later reference.

        Args:
            input: List of input point clouds.
            model: Module called as ``model(pc, t_cur, t_pred, train=False)``.
            weight: State dict loaded into ``model`` before inference.

        Returns:
            List with one predicted frame per entry of ``self.time_intp``.
        """
        with torch.no_grad():
            model.load_state_dict(weight)
            model.eval()

            mid = len(input) // 2
            n_intp = len(self.time_intp)
            pc_intp = []
            for j, time_pred in enumerate(self.time_intp):
                # One candidate per middle reference frame (earlier, later)
                candidates = []
                for i in range(mid - 1, mid + 1):
                    reference = input[i]
                    candidates.append(reference + model(reference, self.time_seq[i], time_pred, train=False))

                chosen = candidates[0] if j < 0.5 * n_intp else candidates[1]
                pc_intp.append(chosen.cuda())
        return pc_intp

    # def viz_and_save(self, points, idx, save_dir="./log/visualizations/longdress/"):
    #     # Mayavi visualization and save the figure
    #     mayavi.mlab.figure(bgcolor=(1, 1, 1), size=(600, 800))
    #     mayavi.mlab.points3d(points[:, 0], points[:, 1], points[:, 2],
    #                          mode="point",
    #                          colormap='spectral',
    #                          scale_factor=0.04)

    #     mayavi.mlab.view(azimuth=0, elevation=0, distance=3.5)
    #     filename = f"{save_dir}/frame_{idx:04d}.png"
    #     mayavi.mlab.savefig(filename)
    #     mayavi.mlab.close()

    def viz_and_save(self, points, idx, save_dir="./log/visualizations/NL_Drive"):  # ./log/visualizations/swing/NeuralPCI/npy
        # Ensure the save directory exists, if not create it
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        
        # Save the points array as a .npy file
        filename = f"{save_dir}/frame_{idx:04d}.npy"
        np.save(filename, points)
    
    def visualization(self):
        """Optimise a model for every sample and save its frame sequence to disk.

        For each sample a fresh ``NeuralPCI`` model is optimised. The sequence
        returned by ``out_NeuralPCI`` is then written frame by frame through
        ``viz_and_save``, with file indices that continue across samples.
        """
        args = self.args
        n_samples = len(self.data_loader)
        print("Optimization Start for {} samples!".format(n_samples))
        progress = tqdm(enumerate(self.data_loader), total=n_samples, file=sys.stdout)
        for idx, (input, gt) in progress:
            print("\n[Sample: {}]".format(idx + 1))
            # Remove the leading batch dimension and move everything to the GPU as float32
            input = [frame.squeeze(0).cuda().contiguous().float() for frame in input]
            gt = [frame.squeeze(0).cuda().contiguous().float() for frame in gt]

            # Definition: optimise a fresh model on this sample
            model = NeuralPCI(args).cuda()
            best_model_weight = self.optimize_NeuralPCI(input, model, self.args, gt)[0]
            model.load_state_dict(best_model_weight)

            # Model inference and visualization
            pc_all = self.out_NeuralPCI(input, gt, model, best_model_weight, self.args)
            first_index = idx * len(pc_all)
            for offset, pc in enumerate(pc_all):
                self.viz_and_save(pc, first_index + offset)
    def out_NeuralPCI(self, input, gt, model, weight, args):
        """Assemble a frame sequence of original and interpolated point clouds.

        The sequence starts with ``input[0]``. For each time in
        ``self.time_intp`` the first input interval ``[t_i, t_{i+1})`` that
        contains it is located, and a frame is predicted from that interval's
        left frame. Finally ``input[-1]`` is appended, unless the last
        interpolation time reaches the last input time stamp.

        Args:
            input: List of input point clouds, indexed like ``self.time_seq``.
            gt: Unused here.
            model: Module called as ``model(pc, t_cur, t_pred, train=False)``.
            weight: State dict loaded into ``model`` before inference.
            args: Unused here.

        Returns:
            List of numpy arrays: the original and interpolated frames in the
            order they were appended.
        """
        model.eval()
        with torch.no_grad():
            model.load_state_dict(weight)

            # The list to store both original and interpolated point clouds
            pc_all = []
            
            # Add the initial point cloud
            pc_all.append(input[0].cpu().numpy())

            # Predict the interpolation frame at time_intp
            for j in range(len(self.time_intp)):
                time_pred = self.time_intp[j]
                pc_pred = []

                # Interpolate between each pair of input frames
                for i in range(len(input) - 1):
                    pc_current = input[i]
                    pc_next = input[i + 1]
                    time_current = self.time_seq[i]
                    time_next = self.time_seq[i + 1]

                    # Assuming linear interpolation for time
                    if time_current <= time_pred < time_next:
                        # Interpolate and add to the list
                        interpolated_pc = pc_current + model(pc_current, time_current, time_pred, train=False)
                        pc_pred.append(interpolated_pc.cpu().numpy())
                        break  # Break the loop after finding the interval for time_pred

                # Add the interpolated point cloud to the list
                if pc_pred:
                    pc_all.extend(pc_pred)

                # Add the next original point cloud in sequence
                # NOTE(review): time_next and pc_next are whatever the inner
                # loop last assigned (the matched interval, or the final pair
                # when nothing matched); they are undefined if the inner loop
                # never ran, i.e. with fewer than two input frames.
                if j < len(self.time_intp) - 1:
                    next_time_pred = self.time_intp[j + 1]
                    # If the next interpolated time is greater than the next time in the sequence,
                    # add the next original point cloud
                    if next_time_pred >= time_next:
                        pc_all.append(pc_next.cpu().numpy())

            # Add the last point cloud if it wasn't added
            # NOTE(review): this test only compares the last interpolation time
            # with the last input time stamp; it does not inspect pc_all.
            if not (len(self.time_intp) > 0 and self.time_intp[-1] >= self.time_seq[-1]):
                pc_all.append(input[-1].cpu().numpy())

        return pc_all