import os
import pdb
import sys
import torch
from torch.utils.data import DataLoader
import numpy as np
import timeit
from tqdm import tqdm
from data.MultiframeDataset import DHBDataset, NLDriveDataset, KITTI_hplflownet, KITTI_flownet3d, KittiPointCloudDataset

from pytorch3d.loss import chamfer_distance
from utils.smooth_loss import compute_smooth
from utils.utils import *
sys.path.append('./utils/EMD')
sys.path.append('./utils/CD')
from emd import earth_mover_distance
import chamfer3D.dist_chamfer_3D
from torch.nn.utils import clip_grad_norm_
from tqdm import trange
import time

class Runner:
    def __init__(self, args):
        self.args = args
        # Init Dataset
        if args.dataset == 'DHB':
            self.dataset = DHBDataset(data_root = args.dataset_path, 
                                      scene_list = args.scenes_list, 
                                      interval = args.interval) 
        elif args.dataset == 'NL_Drive':
            self.dataset = NLDriveDataset(data_root = args.dataset_path, 
                                          scene_list = args.scenes_list, 
                                          interval = args.interval,
                                          num_points = args.num_points, 
                                          num_frames = args.num_frames)
        elif args.dataset == 'KITTI_s':
            self.dataset = KITTI_hplflownet(data_root = args.dataset_path, num_points = args.num_points)
        elif args.dataset == 'KITTI_o':
            self.dataset = KITTI_flownet3d(root = args.dataset_path, num_points = args.num_points)
        # Init DataLoader
        self.data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False, drop_last=False, num_workers=4)

        self.chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()

        if args.PCI_Model == 'NeuroGauss4D':
            from model.NeuroGauss4D import NeuroGauss4D
            self.model = NeuroGauss4D(args).cuda()
        elif args.PCI_Model == 'NeuralPCI':
            from model.NeuralPCI import NeuralPCI
            self.model = NeuralPCI(args).cuda()
        print("Number of parameters in model is {:.3f}M".format(sum(tensor.numel() for tensor in self.model.parameters())/1e6))
        # self.best_cd_val = 100
        self.accumulate = 0

        # Loss recorder
        self.recorder = self.get_recorder(args)

        # Define time stamp for input and interpolation frames 
        self.time_seq, self.time_intp = self.get_timestamp(args)

    def weighted_chamfer_distance(self, pc_pred, pc_input, NL_flag, alpha=1.6, dist_threshold=10.0, z_min=-1.3, z_max=3.99, z_weight=1.9):
        dist1, dist2, _, _ = self.chamLoss(pc_pred, pc_input)
        dist1 = torch.clamp(dist1, max=dist_threshold)
        dist2 = torch.clamp(dist2, max=dist_threshold)
        if NL_flag:
            z1 = pc_pred[:, :, 2:]
            z2 = pc_input[:, :, 2:]
            z_mask1 = (z1 >= z_min) & (z1 <= z_max)
            z_mask2 = (z2 >= z_min) & (z2 <= z_max)
            z_dist1 = torch.where(z_mask1.squeeze(-1), z_weight * dist1, dist1)
            z_dist2 = torch.where(z_mask2.squeeze(-1), z_weight * dist2, dist2)
        else:
            z_dist1 = dist1
            z_dist2 = dist2

        z_dist1 = dist1
        z_dist2 = dist2
        weight1 = torch.tanh(alpha * z_dist1)  # (B, N)
        weight2 = torch.tanh(alpha * z_dist2)  # (B, M)

        return (weight1 * z_dist1 + weight2 * z_dist2).sum()
    
    def save_tensors_to_dict(self, input_list, gt_list, file_path):
        """Store two tensor lists as one pickled dict via ``np.save``.

        Tensors are detached, moved to CPU and converted to NumPy arrays. The
        shorter list is padded with empty arrays so both lists contribute the
        same number of entries. Keys are ``input_<i>`` and ``gt_<i>``.

        Args:
            input_list: Input-frame tensors.
            gt_list: Ground-truth tensors.
            file_path: Destination passed to ``np.save``.
        """
        inputs_np = [t.detach().cpu().numpy() for t in input_list]
        gts_np = [t.detach().cpu().numpy() for t in gt_list]
        n_pairs = max(len(inputs_np), len(gts_np))
        filler = np.array([])

        data_dict = {}
        for i in range(n_pairs):
            data_dict[f'input_{i}'] = inputs_np[i] if i < len(inputs_np) else filler
            data_dict[f'gt_{i}'] = gts_np[i] if i < len(gts_np) else filler
        np.save(file_path, data_dict)
    
    def load_tensors_from_dict(self, file_path):
        """Load a pickled dict from a ``.npy`` file and split it into two tensor lists.

        Every key is printed. Values whose key starts with ``input`` go to the
        first list, values whose key starts with ``gt`` go to the second, and
        any other key is ignored. Dict iteration order is preserved.

        NOTE: ``allow_pickle=True`` unpickles arbitrary objects; only load
        files from a trusted source.

        Args:
            file_path: Path of the ``.npy`` file holding the dict.

        Returns:
            Tuple ``(input_list, gt_list)`` of lists of torch tensors.
        """
        stored = np.load(file_path, allow_pickle=True).item()
        input_list, gt_list = [], []
        for name in stored:
            print("key:", name)
            if name.startswith('input'):
                bucket = input_list
            elif name.startswith('gt'):
                bucket = gt_list
            else:
                continue
            bucket.append(torch.from_numpy(stored[name]))

        return input_list, gt_list

    def get_recorder(self, args):
        """Create the empty result recorder.

        Args:
            args: Configuration; only ``args.interval`` is read.

        Returns:
            Dict with empty lists under ``loss_all_CD`` and ``loss_all_EMD``,
            plus ``loss_frame<k>_CD`` / ``loss_frame<k>_EMD`` for
            ``k = 1 .. args.interval - 1``.
        """
        recorder = {'loss_all_CD': [], 'loss_all_EMD': []}

        for frame_no in range(1, args.interval):
            for metric in ('CD', 'EMD'):
                recorder['loss_frame{}_{}'.format(frame_no, metric)] = []

        return recorder
        

    def get_timestamp(self, args):
        """Compute time stamps for the input frames and the frames to interpolate.

        Input frames are spaced evenly over ``[args.t_begin, args.t_end]``.
        Interpolation times are the interior points of an even split of the
        gap between the two middle input frames into ``args.interval`` steps.

        Args:
            args: Configuration; reads ``t_begin``, ``t_end``, ``num_frames``
                and ``interval``.

        Returns:
            Tuple ``(time_seq, time_intp)`` of lists of NumPy floats.
        """
        time_seq = list(np.linspace(args.t_begin, args.t_end, args.num_frames))
        mid = args.num_frames // 2
        gap_samples = np.linspace(time_seq[mid - 1], time_seq[mid], args.interval + 1)
        # Drop both end points: they coincide with the two middle input frames
        time_intp = list(gap_samples[1:-1])

        return time_seq, time_intp

    def adjust_point_clouds(self, point_clouds):
        """Force every cloud to have as many points as the first cloud.

        Larger clouds are truncated to their first rows; smaller clouds are
        tiled along the first dimension and then truncated. Clouds that
        already match are passed through unchanged.

        Args:
            point_clouds: Sequence of clouds. ``pc.repeat(k, 1)`` is used for
                tiling, so 2-D torch tensors are assumed — TODO confirm, the
                NumPy ``repeat`` signature means something different.

        Returns:
            List of clouds whose first dimension equals that of
            ``point_clouds[0]``. An empty input yields an empty list.

        Raises:
            ValueError: If a cloud with zero points would have to be tiled.
        """
        if len(point_clouds) == 0:
            return []

        target_count = point_clouds[0].shape[0]

        adjusted_point_clouds = []
        for pc in point_clouds:
            current_count = pc.shape[0]
            if current_count > target_count:
                adjusted_pc = pc[:target_count]
            elif current_count < target_count:
                if current_count == 0:
                    raise ValueError('Cannot upsample an empty point cloud')
                # Ceiling division: smallest tile count that reaches target_count
                repeat_times = (target_count + current_count - 1) // current_count
                adjusted_pc = pc.repeat(repeat_times, 1)[:target_count]
            else:
                adjusted_pc = pc
            adjusted_point_clouds.append(adjusted_pc)

        return adjusted_point_clouds

    def loop(self, sample_indices=(2,)):
        """Optimize the model on selected samples and print the averaged results.

        Each selected sample is moved to the GPU, optimized with
        ``optimize_NeuralPCI`` and its best errors are recorded. For the
        ``NeuroGauss4D`` model a fresh model is created after each sample.

        Args:
            sample_indices: Zero-based data-loader indices to process. The
                default ``(2,)`` reproduces the previously hard-coded behaviour
                of processing only the third sample. Pass ``None`` to process
                every sample.
        """
        args = self.args
        selected = None if sample_indices is None else set(sample_indices)
        print("Optimization Start for {} samples!".format(len(self.data_loader)))
        # get one sample from DataLoader (for-loop)
        for idx, (frames, gt) in tqdm(enumerate(self.data_loader), total=len(self.data_loader), file=sys.stdout):
            print("\n[Sample: {}]".format(idx+1))
            if selected is not None and idx not in selected:
                continue

            # Drop the batch dimension added by the DataLoader and move to GPU
            for i in range(len(frames)):
                frames[i] = frames[i].squeeze(0).cuda().contiguous().float()
            for i in range(len(gt)):
                gt[i] = gt[i].squeeze(0).cuda().contiguous().float()

            best_model_weight, best_cd, best_emd, intp_cd_best, intp_emd_best = self.optimize_NeuralPCI(frames, self.model, args, gt, idx=idx)
            if args.PCI_Model == 'NeuroGauss4D':
                # Start the next sample from a freshly initialised model
                from model.NeuroGauss4D import NeuroGauss4D
                self.model = NeuroGauss4D(args).cuda()

            self.recorder_error(best_cd, best_emd, intp_cd_best, intp_emd_best)

        # print overall average result
        self.print_result()

    def load_single_seq(self, directory):
        """Load one saved sequence of ``input_*`` and ``gt_*`` arrays from a directory.

        Files are ordered by the integer after the last underscore in the file
        stem, so ``input_10`` comes after ``input_2``. The previous plain
        string sort placed it before. Files without such a numeric suffix are
        placed last, in lexicographic order.

        Args:
            directory: Directory containing the ``.npy`` files.

        Returns:
            Tuple ``(input, gt)`` of lists of torch tensors in frame order.
        """
        def frame_order(path):
            # Sort key: numeric frame index first, path as tie-breaker
            stem = os.path.splitext(os.path.basename(path))[0]
            suffix = stem.rpartition('_')[2]
            if suffix.isdecimal():
                return (0, int(suffix), path)
            return (1, 0, path)

        names = os.listdir(directory)
        input_files = sorted(
            (os.path.join(directory, f) for f in names if f.startswith('input_')),
            key=frame_order,
        )
        gt_files = sorted(
            (os.path.join(directory, f) for f in names if f.startswith('gt_')),
            key=frame_order,
        )

        frames = [torch.from_numpy(np.load(f)) for f in input_files]
        gt = [torch.from_numpy(np.load(f)) for f in gt_files]
        return frames, gt

    def demo(self):
        """Optimize on one stored sequence and save the interpolated frames.

        Loads the input and ground-truth frames from a fixed ``.npy`` file,
        moves them to the GPU, runs ``optimize_NeuralPCI`` to obtain the best
        weights, runs ``self.infer_NeuralPCI`` with them, and saves the
        predictions to a fixed output path.

        The optimisation call was previously commented out, so
        ``best_model_weight`` was undefined and the method raised ``NameError``.

        NOTE(review): ``infer_NeuralPCI`` is not defined in the visible part of
        this file — presumably it is defined further down; confirm.
        """
        args = self.args
        input_list, gt_list = self.load_tensors_from_dict("./NeuroGauss4D/log/visualizations/dense_pointclouds/kitti02_2404_2408.npy")

        input_list = [tensor.cuda().contiguous().float() for tensor in input_list]
        gt_list = [tensor.cuda().contiguous().float() for tensor in gt_list]

        # Per-scene optimisation produces the weights used for inference below
        best_model_weight, _, _, _, _ = self.optimize_NeuralPCI(input_list, self.model, args, gt_list, idx=0)

        pc_intp = self.infer_NeuralPCI(input_list, self.model, best_model_weight)
        pc_intp_save = [tensor.cpu().numpy() for tensor in pc_intp]
        save_path = "./NeuroGauss4D/log/visualizations/compare_pc/NL_42_ours_pred.npy"
        # np.save does not create missing directories
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        np.save(save_path, pc_intp_save)


    def get_optimizer(self, model, args):
        """Create the optimizer named by ``args.optimizer`` for ``model``.

        Args:
            model: Module whose ``parameters()`` are optimized.
            args: Configuration; reads ``optimizer``, ``lr`` and ``weight_decay``.

        Returns:
            A ``torch.optim`` optimizer instance.

        Raises:
            ValueError: If ``args.optimizer`` is not a supported name.
                Previously this surfaced as ``UnboundLocalError``.
        """
        if args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'adamw':
            optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'sgdm':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay, nesterov=True)
        elif args.optimizer == 'radam':
            optimizer = torch.optim.RAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'nadam':
            optimizer = torch.optim.NAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'rmsprop':
            optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, alpha=0.99, momentum=0.9, weight_decay=args.weight_decay)
        else:
            raise ValueError(f'Unknown optimizer: {args.optimizer}')
        return optimizer


    def get_scheduler(self, optimizer, args):
        """Create the learning-rate scheduler named by ``args.scheduler``.

        Supported names: ``step``, ``cosine``, ``poly`` and ``plateau``.
        ``poly`` ramps the LR factor linearly up to 1.5 over the first 700
        steps, then decays it as ``(1 - step / args.iters) ** 0.95``.

        Args:
            optimizer: Optimizer to wrap.
            args: Configuration; reads ``scheduler`` and, for ``cosine`` and
                ``poly``, ``iters``.

        Returns:
            A ``torch.optim.lr_scheduler`` instance.

        Raises:
            ValueError: If ``args.scheduler`` is not a supported name.
        """
        schedulers = torch.optim.lr_scheduler
        kind = args.scheduler

        if kind == 'step':
            return schedulers.StepLR(optimizer, step_size=200, gamma=0.95)

        if kind == 'cosine':
            return schedulers.CosineAnnealingLR(optimizer, T_max=args.iters, eta_min=1e-6)

        if kind == 'poly':
            def poly_factor(step):
                # Warm-up for the first 700 steps, polynomial decay afterwards
                if step > 700:
                    return (1-step/args.iters)**0.95
                return (step/700)*1.5

            return schedulers.LambdaLR(optimizer, lr_lambda=poly_factor)

        if kind == 'plateau':
            return schedulers.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, min_lr=1e-6)

        raise ValueError(f'Unknown scheduler: {args.scheduler}')


    def compute_loss(self, pc_current, pc_pred, pc_input, iter):  
        args = self.args
        loss = 0
        # dist1, dist2, _, _ = self.chamLoss(pc_pred.unsqueeze(0), pc_input.unsqueeze(0))
        # chamfer_loss = (dist1 + dist2).sum() * 0.5

        # if pc_current.shape[0]>16384:
        #     for i in range(0, pc_current.shape[0], 16384):
        #         end = min(i + 16384, pc_current.shape[0])
        #         dist1, dist2, _, _ = self.chamLoss(pc_pred[i:end].unsqueeze(0), pc_input[i:end].unsqueeze(0))
        #         chamfer_loss += (dist1 + dist2).sum() * 0.5
        if iter<500 or (iter>2500 and iter<3100):
            dist1, dist2, _, _ = self.chamLoss(pc_pred.unsqueeze(0), pc_input.unsqueeze(0))
            chamfer_loss = (dist1 + dist2).sum() * 0.5
        else:
            chamfer_loss = self.weighted_chamfer_distance(pc_pred.unsqueeze(0), pc_input.unsqueeze(0), args.dataset == 'NL_Drive')
        loss = loss + chamfer_loss * args.factor_cd

        if args.dataset == 'DHB':
            dist = earth_mover_distance(pc_pred.unsqueeze(0), pc_input.unsqueeze(0), transpose=False)
            emd_loss = (dist / pc_pred.shape[0]).mean()
            loss = loss + emd_loss * args.factor_emd

        elif args.dataset == 'NL_Drive':
            flow = pc_pred - pc_current
            smooth_loss = compute_smooth(pc_current.unsqueeze(0), flow.unsqueeze(0), k=9)
            smooth_loss = smooth_loss.squeeze(0).sum()
            loss = loss + smooth_loss * args.factor_smooth
        return loss


    def optimize_NeuralPCI(self, input, model, args, gt, idx=None):
        """Optimize ``model`` on one sequence and track the best interpolation errors.

        In every iteration the two middle input frames are used as references
        and the model predicts every input frame from each of them; the summed
        ``compute_loss`` is back-propagated with gradient-norm clipping at 1.0.
        After iteration 200, every ``args.eval_freq`` iterations the model is
        scored with ``eval_Train`` and the best CD/EMD are updated. Training
        stops early once ``patience`` consecutive evaluations bring no
        improvement.

        Fixes over the previous version:
          * The best weights are cloned. ``model.state_dict()`` returns
            references to the live parameters, so the stored "best" weights
            used to keep changing with later optimizer steps.
          * After ``accumulate_pc`` flushes and returns ``None``, the buffer is
            reset to a new list and the model is put back in training mode.
            Previously ``continue`` skipped ``model.train()`` and the next
            accumulation received ``None``.

        Args:
            input: List of input frame tensors.
            model: Model called as ``model(pc, time_current, time_pred)``.
            args: Configuration; reads ``iters``, ``scheduler``, ``eval_freq``,
                ``patience``, ``Accumulate_LiDAR`` and ``Number_Acc``.
            gt: Ground-truth interpolation frames, forwarded to ``eval_Train``.
            idx: Sample index, used only when saving at early stop.

        Returns:
            Tuple ``(best_model_weight, best_cd, best_emd, intp_cd_best,
            intp_emd_best)``. The per-frame lists stay ``0`` and the errors
            stay ``inf`` if no evaluation ever ran.
        """
        def snapshot_weights():
            # Detached copies, so later optimizer steps cannot alter the snapshot
            return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

        model.train()
        optimizer = self.get_optimizer(model, args)
        scheduler = self.get_scheduler(optimizer, args) if args.scheduler else None

        # record optimization best result
        best_cd = np.inf
        best_emd = np.inf
        best_iter = 0
        intp_emd_best = 0
        intp_cd_best = 0
        best_model_weight = snapshot_weights()
        total_train_time = 0
        train_iterations = 0

        counter = 0
        pc_pred_list = []
        for it in range(args.iters):
            optimizer.zero_grad()
            loss = 0
            start_train_time = time.time()
            # Use each of the two middle frames as reference and predict every input frame
            for i in range(len(input)//2 - 1, len(input)//2 + 1):
                pc_current = input[i]
                time_current = self.time_seq[i]

                for j in range(len(input)):
                    time_pred = self.time_seq[j]
                    if args.Accumulate_LiDAR:
                        # NOTE(review): pc_pred repeats each point Number_Acc times
                        # consecutively, while the clouds passed to compute_loss are
                        # whole-cloud tilings (repeat(1, k, 1) on a 2-D tensor), so
                        # row order differs — confirm this is intended.
                        pc_pred = pc_current.unsqueeze(1).repeat(1, args.Number_Acc, 1).view(args.Number_Acc*pc_current.shape[0], 3) + model(pc_current, time_current, time_pred).view(args.Number_Acc*pc_current.shape[0], 3)
                        loss = loss + self.compute_loss(pc_current.repeat(1, args.Number_Acc, 1).view(args.Number_Acc*pc_current.shape[0], 3), pc_pred, input[j].repeat(1, args.Number_Acc, 1).view(args.Number_Acc*pc_current.shape[0], 3), it)
                    else:
                        pc_pred = pc_current + model(pc_current, time_current, time_pred)
                        loss = loss + self.compute_loss(pc_current, pc_pred, input[j], it)

            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            if scheduler is not None:
                if args.scheduler == 'plateau':
                    # ReduceLROnPlateau needs the monitored metric
                    scheduler.step(loss.item())
                else:
                    scheduler.step()
            # Wait for queued GPU work so the wall-clock timing is meaningful
            torch.cuda.synchronize()
            end_train_time = time.time()
            total_train_time += end_train_time - start_train_time
            train_iterations += 1

            if (it > 200) and (it % args.eval_freq) == 0:
                cd_error, emd_error, intp_cd_error, intp_emd_error, average_eval_time = self.eval_Train(input, gt, model, args)
                if (best_cd > cd_error.item()) or (best_emd > emd_error.item()):
                    best_iter = it
                    if best_cd > cd_error.item():
                        best_cd = cd_error.item()
                        intp_cd_best = intp_cd_error
                        best_model_weight = snapshot_weights()
                    if best_emd > emd_error.item():
                        best_emd = emd_error.item()
                        intp_emd_best = intp_emd_error
                    counter = 0
                    current_lr = optimizer.param_groups[0]['lr']
                    print(f"[Iter: {it}] [Loss: {loss.item():.5f}] [LR: {current_lr:.5f}] [Best Iter: {best_iter}] [Best CD: {best_cd:.5f}] [Best EMD: {best_emd:.5f}]")
                    print("average_eval_time:", average_eval_time)
                else:
                    counter += 1
                    # Allow far more stalled evaluations while the best CD is still poor
                    if best_cd > 0.92:
                        patience = args.eval_freq*200
                    else:
                        patience = args.patience
                    if counter >= patience:
                        print(f"Early stopping at iteration {it}")
                        current_lr = optimizer.param_groups[0]['lr']
                        print(f"[Iter: {it}] [Loss: {loss.item():.5f}] [LR: {current_lr:.5f}] [Best Iter: {best_iter}] [Best CD: {best_cd:.5f}] [Best EMD: {best_emd:.5f}]")
                        print("average_eval_time:", average_eval_time)

                        # Final evaluation only for its side effect of saving the clouds
                        self.eval_Train(input, gt, model, args, idx=idx, Save_fag=True)
                        break
                if args.Accumulate_LiDAR and (cd_error<(best_cd+0.02)):
                    pc_pred_list = self.accumulate_pc(input, gt, model, pc_pred_list, args)
                    if pc_pred_list is None:
                        # accumulate_pc flushed its buffer to disk; start a new one
                        pc_pred_list = []
                # eval_Train / accumulate_pc switch to eval mode; switch back
                model.train()
        torch.cuda.empty_cache()
        average_train_time = total_train_time / train_iterations if train_iterations else 0
        print("average_train_time:", average_train_time)
        return best_model_weight, best_cd, best_emd, intp_cd_best, intp_emd_best


    def eval_NeuralPCI(self, input, gt, model, weight, args, idx, save_if=False):
        """Load ``weight`` into ``model``, interpolate, and print mean CD / EMD.

        For every interpolation time both middle input frames produce a
        prediction. The one from the earlier frame is kept for the first half
        of the interpolation times, the one from the later frame for the rest.
        Each kept prediction is scored against the ground-truth frame with the
        same index.

        Args:
            input: List of input frame tensors.
            gt: List of ground-truth interpolation frames.
            model: Model called as ``model(pc, time_current, time_pred)``.
            weight: State dict loaded into ``model`` before predicting.
            args: Configuration; ``scenes_list`` is read only when saving.
            idx: Sample index used in the output directory name.
            save_if: When true, save ``gt``, predictions and ``input`` as
                ``.npy`` files.
        """
        model.eval()
        with torch.no_grad():
            model.load_state_dict(weight)

            mid = len(input) // 2
            n_intp = len(self.time_intp)
            pc_intp = []
            # Predict the interpolation frame at each time in time_intp
            for j, time_pred in enumerate(self.time_intp):
                candidates = [
                    input[i] + model(input[i], self.time_seq[i], time_pred)
                    for i in range(mid - 1, mid + 1)
                ]
                # Nearest-neighbour choice between the two reference frames
                chosen = candidates[0] if j < 0.5 * n_intp else candidates[1]
                pc_intp.append(chosen.cuda())

            intp_cd_error = []
            intp_emd_error = []
            for i, pred in enumerate(pc_intp):
                pred_batch = pred.unsqueeze(0)
                gt_batch = gt[i].unsqueeze(0)
                cd_error = chamfer_distance(pred_batch, gt_batch)[0]
                dist = earth_mover_distance(pred_batch, gt_batch, transpose=False)
                emd_error = (dist / pred.shape[0]).mean()
                intp_cd_error.append(cd_error)
                intp_emd_error.append(emd_error)

            print("CD CD Eval Eval Eval Eval Eval Eval: {}".format(torch.mean(torch.tensor(intp_cd_error))))
            print("EMD EMD Eval Eval Eval Eval Eval: {}".format(torch.mean(torch.tensor(intp_emd_error))))

            if save_if:
                save_path = "./NeuroGauss4D/log/visualizations/pred_other/NeuralPCI"
                # Strip the last four characters of the scene-list file name
                file_name = os.path.basename(args.scenes_list)[:-4]
                save_dir = os.path.join(save_path, file_name, f"{idx:03d}")
                for name, clouds in (("gt", gt), ("pc_intp", pc_intp), ("input", input)):
                    for j, tensor in enumerate(clouds):
                        self.save_pointcloud(tensor, os.path.join(save_dir, f"{name}_{j}.npy"))
        
    
    def print_result(self):
        recorder = self.recorder
        args = self.args
        print("\n=======================================")
        print("Final Result CD Loss is:{}".format(np.mean(recorder['loss_all_CD'])))
        print("Final Result EMD Loss is:{}".format(np.mean(recorder['loss_all_EMD'])))
        for i in range(args.interval - 1):
            print("=======================================")
            print("Final Frame-{} Result CD Loss is:{}".format(i+1, np.mean(recorder['loss_frame{}_CD'.format(i+1)])))
            print("Final Frame-{} Result EMD Loss is:{}".format(i+1, np.mean(recorder['loss_frame{}_EMD'.format(i+1)])))
        print("=======================================")


    def eval_Train(self, input, gt, model, args, idx=100, Save_fag=False):
        """Score the current model on the interpolation frames and time the forward passes.

        For each interpolation time both middle input frames produce a
        prediction; the earlier frame's prediction is kept for the first half
        of the interpolation times, the later frame's for the rest. CD and EMD
        are computed against the ground-truth frame with the same index.

        Args:
            input: List of input frame tensors.
            gt: List of ground-truth interpolation frames.
            model: Model to evaluate; it is switched to eval mode and left there.
            args: Configuration; reads ``Accumulate_LiDAR`` and ``Number_Acc``.
            idx: Sample index used in the saved file names.
            Save_fag: When true, save the first prediction and first
                ground-truth frame as ``.npy`` files.

        Returns:
            Tuple ``(mean_cd, mean_emd, intp_cd_error, intp_emd_error,
            average_eval_time)``; the means are tensors, the error lists hold
            one entry per interpolation frame.
        """
        model.eval()
        eval_iterations = 0
        total_eval_time = 0
        with torch.no_grad():
            mid = len(input) // 2
            pc_intp = []
            # Predict the interpolation frame at each time in time_intp
            for j, time_pred in enumerate(self.time_intp):
                candidates = []

                # The two middle input frames serve as references
                for i in (mid - 1, mid):
                    pc_current = input[i]
                    time_current = self.time_seq[i]
                    tic = time.time()
                    eval_iterations += 1
                    if args.Accumulate_LiDAR:
                        n_out = args.Number_Acc*pc_current.shape[0]
                        base = pc_current.unsqueeze(1).repeat(1, args.Number_Acc, 1).view(n_out, 3)
                        pred_time_pc = base + model(pc_current, time_current, time_pred).view(n_out, 3)
                    else:
                        pred_time_pc = pc_current + model(pc_current, time_current, time_pred, train=False)
                    # Wait for queued GPU work before reading the clock
                    torch.cuda.synchronize()
                    toc = time.time()
                    candidates.append(pred_time_pc)
                    total_eval_time += toc - tic

                pick = 0 if j < 0.5 * len(self.time_intp) else 1
                pc_intp.append(candidates[pick].cuda())

            intp_cd_error = []
            intp_emd_error = []
            for i, pred in enumerate(pc_intp):
                pred_batch = pred.unsqueeze(0)
                gt_batch = gt[i].unsqueeze(0)
                cd_error = chamfer_distance(pred_batch, gt_batch)[0]
                dist = earth_mover_distance(pred_batch, gt_batch, transpose=False)
                emd_error = (dist / pred.shape[0]).mean()
                intp_cd_error.append(cd_error)
                intp_emd_error.append(emd_error)

            if Save_fag:
                vis_dir = "./NeuroGauss4D/log/visualizations/vis/NL_Drive_test4_6/ours_womlp"
                self.save_pointcloud(pc_intp[0], os.path.join(vis_dir, f"interpc1_{idx}.npy"))
                self.save_pointcloud(gt[0], os.path.join(vis_dir, f"gtpc1_{idx}.npy"))

            if eval_iterations:
                average_eval_time = total_eval_time / eval_iterations
            else:
                average_eval_time = 0
            mean_cd = torch.mean(torch.tensor(intp_cd_error))
            mean_emd = torch.mean(torch.tensor(intp_emd_error))
            return mean_cd, mean_emd, intp_cd_error, intp_emd_error, average_eval_time

    def accumulate_pc(self, input, gt, model, pc_pred, args):
        """Append the current interpolation predictions to a running buffer.

        For every interpolation time, the predictions from both middle input
        frames are appended to ``pc_pred``; the nearest-frame prediction is
        then appended a second time (kept as in the original code). After 40
        calls the buffer is concatenated into one dense cloud and written to
        disk together with ``input[1]``, ``input[2]`` and ``gt[0]``, the call
        counter is reset, and ``None`` is returned.

        Changes over the previous version: a leftover ``pdb.set_trace()`` after
        the flush is removed, and a ``None`` buffer is treated as empty so the
        method can be called again with its own return value.

        Args:
            input: List of input frame tensors.
            gt: List of ground-truth frames; only ``gt[0]`` is saved.
            model: Model to run; it is switched to eval mode and left there.
            pc_pred: List of previously accumulated predictions, or ``None``.
            args: Configuration; reads ``Accumulate_LiDAR`` and ``Number_Acc``.

        Returns:
            The extended buffer list, or ``None`` right after a flush to disk.
        """
        if pc_pred is None:
            pc_pred = []
        model.eval()
        with torch.no_grad():
            # Predict the interpolation frame at each time in time_intp
            for j in range(len(self.time_intp)):
                pc_pred_ = []
                time_pred = self.time_intp[j]

                # The two middle input frames serve as references
                for i in range(len(input)//2 - 1, len(input)//2 + 1):
                    pc_current = input[i]
                    time_current = self.time_seq[i]
                    if args.Accumulate_LiDAR:
                        pred_time_pc = pc_current.unsqueeze(1).repeat(1, args.Number_Acc, 1).view(args.Number_Acc*pc_current.shape[0], 3) + model(pc_current, time_current, time_pred).view(args.Number_Acc*pc_current.shape[0], 3)
                    else:
                        pred_time_pc = pc_current + model(pc_current, time_current, time_pred, train=False)
                    pc_pred_.append(pred_time_pc)
                    pc_pred.append(pred_time_pc)
                if j < 0.5 * len(self.time_intp):
                    pc_pred.append(pc_pred_[0].cuda())
                else:
                    pc_pred.append(pc_pred_[1].cuda())
            self.accumulate += 1
            print("Accumulate point cloud:", self.accumulate)
            if self.accumulate >= 40:
                dense_pc = torch.cat(pc_pred, dim=0)
                save_path = os.path.join('./NeuroGauss4D/log/visualizations/dense_pointclouds', 'Test_dense.ply')
                self.save_pointcloud_as_ply(dense_pc, (123, 102, 140), save_path)
                self.save_pointcloud(dense_pc, save_path[:-4]+".npy")
                self.save_pointcloud_as_ply(input[1], (100, 20, 0), save_path[:-4]+"_pc1.ply")
                self.save_pointcloud_as_ply(input[2], (100, 20, 0), save_path[:-4]+"_pc2.ply")
                self.save_pointcloud_as_ply(gt[0], (0, 200, 0), save_path[:-4]+"_gt.ply")
                # Signal the flush to the caller and restart the call counter
                pc_pred = None
                self.accumulate = 0
            return pc_pred
        
    def recorder_error(self, best_cd, best_emd, intp_cd_best, intp_emd_best):
        print("Eval: Average Interpolation CD Error: {}".format(best_cd))
        print("Eval: Average Interpolation EMD Error: {}".format(best_emd))
        self.recorder['loss_all_CD'].append(torch.mean(torch.tensor(intp_cd_best)).item())
        self.recorder['loss_all_EMD'].append(torch.mean(torch.tensor(intp_emd_best)).item())
        for i in range(len(intp_cd_best)):
            self.recorder['loss_frame{}_CD'.format(i+1)].append(intp_cd_best[i].item())
            self.recorder['loss_frame{}_EMD'.format(i+1)].append(intp_emd_best[i].item())
        
    def save_pointcloud_as_ply(self, tensor, rgb, save_path):
        """Write a point cloud as an ASCII PLY file with one colour for all points.

        Fixes over the previous version:
          * The header is written without leading whitespace. The indented
            triple-quoted literal used before put spaces in front of every
            header line and in front of the first vertex.
          * The tensor is detached before conversion, so tensors that require
            grad can be saved.
          * A bare file name (no directory part) no longer fails in
            ``os.makedirs``.

        Args:
            tensor: Point tensor; written with three float columns, so an
                (N, 3) layout is assumed.
            rgb: Three colour components, stored as ``uint8``.
            save_path: Destination path; missing parent directories are created.
        """
        points = tensor.detach().cpu().numpy()
        header_lines = [
            "ply",
            "format ascii 1.0",
            "element vertex {}".format(points.shape[0]),
            "property float x",
            "property float y",
            "property float z",
            "property uchar red",
            "property uchar green",
            "property uchar blue",
            "end_header",
        ]
        colours = np.tile(np.array(rgb, dtype=np.uint8), (points.shape[0], 1))
        points_with_rgb = np.concatenate((points, colours), axis=1)

        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(save_path, 'w') as ply_file:
            ply_file.write("\n".join(header_lines) + "\n")
            np.savetxt(ply_file, points_with_rgb, fmt="%f %f %f %d %d %d")

        print(f"Point cloud saved as {save_path}")

    def save_pointcloud(self, tensor, save_path):
        """Save a tensor as a ``.npy`` file, creating parent directories as needed.

        The tensor is detached before conversion so tensors that require grad
        can be saved. A bare file name (no directory part) is written to the
        current directory; previously ``os.makedirs('')`` raised for it.

        Args:
            tensor: Tensor to save.
            save_path: Destination path passed to ``np.save``.
        """
        array = tensor.detach().cpu().numpy()
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        np.save(save_path, array)

################################ ################################ ################################ ################################ ################################ 
    
    def optimize_4DGS_SF(self, pc_current, gt_pc, model, args, gt, idx, ii):
        """Optimize ``model`` for scene flow on one frame pair and track the best metrics.

        Each iteration predicts a flow with ``model(pc_current, 0, 1)``,
        computes ``compute_sf_loss`` and takes a clipped optimizer step. After
        iteration 200, every ``args.eval_freq`` iterations the model is scored
        with ``eval_sf_Train``; training stops early once ``patience``
        consecutive evaluations bring no improvement.

        Fixes over the previous version:
          * The scheduler is stepped only when one was created, and
            ``plateau`` receives the loss, as in ``optimize_NeuralPCI``.
            Before, ``scheduler.step()`` raised ``NameError`` without a
            scheduler and ``TypeError`` for ``plateau``.
          * ``counter`` and ``best_iter`` are initialised up front; they could
            be read before assignment when the first evaluation did not
            improve (for example a NaN EPE).
          * The best weights are cloned, because ``model.state_dict()`` returns
            references to the live parameters.

        Args:
            pc_current: Source point cloud.
            gt_pc: Target cloud forwarded to ``compute_sf_loss``.
            model: Model called as ``model(pc_current, 0, 1)``.
            args: Configuration; reads ``iters``, ``scheduler``, ``eval_freq``
                and ``patience``.
            gt: Ground-truth flow forwarded to ``eval_sf_Train``.
            idx: Unused here; kept for interface compatibility.
            ii: Unused here; kept for interface compatibility.

        Returns:
            Tuple ``(best_model_weight, sf_metrics_best)``.
        """
        def snapshot_weights():
            # Detached copies, so later optimizer steps cannot alter the snapshot
            return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

        model.train()
        optimizer = self.get_optimizer(model, args)
        scheduler = self.get_scheduler(optimizer, args) if args.scheduler else None
        best_epe, best_outlier = 10000, 10000
        best_iter = 0
        counter = 0
        patience = args.patience
        sf_metrics_best = {'EPE3d': 1, '5cm': 100, '10cm': 100, 'outlier': 100, 'time': 0}
        best_model_weight = snapshot_weights()
        for it in range(args.iters):
            optimizer.zero_grad()
            flow = model(pc_current, 0, 1)
            pc_pred = pc_current + flow
            loss = self.compute_sf_loss(pc_current, pc_pred, gt_pc, flow, it, args)
            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            if scheduler is not None:
                if args.scheduler == 'plateau':
                    # ReduceLROnPlateau needs the monitored metric
                    scheduler.step(loss.item())
                else:
                    scheduler.step()

            if (it > 200) and (it % args.eval_freq) == 0:
                sf_metrics = self.eval_sf_Train(pc_current, gt, model)
                if (best_epe > sf_metrics['EPE3d']) or (best_outlier > sf_metrics['outlier']):
                    best_iter = it
                    if (best_epe > sf_metrics['EPE3d']):
                        best_epe = sf_metrics['EPE3d']
                        best_model_weight = snapshot_weights()
                        sf_metrics_best = sf_metrics
                    if (best_outlier > sf_metrics['outlier']):
                        best_outlier = sf_metrics['outlier']
                    counter = 0
                    if (it > 200) and (it % 2 == 0):
                        current_lr = optimizer.param_groups[0]['lr']
                        print(f"[Iter: {it}] [Loss: {loss.item():.5f}] [LR: {current_lr:.5f}] [Best Iter: {best_iter}] [Best EPE: {best_epe:.5f}] [Best Outlier: {best_outlier:.5f}]")
                else:
                    # Extend patience while results are poor; the extension is
                    # kept for the rest of the run, as in the original code
                    if (best_epe > 0.04) or (best_outlier > 30.0):
                        patience = args.patience*1.5
                    counter += 1
                    if counter >= patience:
                        print(f"Early stopping at iteration {it}")
                        current_lr = optimizer.param_groups[0]['lr']
                        print(f"[Iter: {it}] [Loss: {loss.item():.5f}] [LR: {current_lr:.5f}] [Best Iter: {best_iter}] [Best EPE: {best_epe:.5f}] [Best best_outlier: {best_outlier:.5f}]")
                        break
                # eval_sf_Train switches to eval mode; switch back
                model.train()
                if best_epe > 1 or loss.item() > 1000:
                    # Re-initialise the weights when the optimisation has diverged
                    self.weights_init(model)
                sf_metrics_best['outlier'] = best_outlier
        return best_model_weight, sf_metrics_best

    def eval_sf_Train(self, pc_current, flow_gt, model):
        """Evaluate the model's scene-flow prediction against ground truth.

        The model is switched to evaluation mode and is left in that mode;
        callers that keep optimising must call ``model.train()`` afterwards.

        Args:
            pc_current: Point tensor handed to ``model(pc_current, 0, 1)``.
            flow_gt: Ground-truth flow with one row per point. Columns 0..2
                are the flow vector; when more than three columns are
                present, column 3 is a validity flag (``> 0`` means valid).
            model: Callable module returning the predicted flow
                (assumed to be ``(N, 3)`` -- TODO confirm for all models).

        Returns:
            dict with ``'EPE3d'`` (mean end-point error over valid points),
            ``'5cm'`` / ``'10cm'`` (percentage of valid points whose absolute
            or relative error is below 0.05 / 0.1), ``'outlier'`` (percentage
            with absolute error > 0.3 or relative error > 0.1) and ``'time'``
            (seconds spent in the forward pass).

        Raises:
            ZeroDivisionError: if no point is valid.
        """
        # Set the model to evaluation mode
        model.eval()

        # Only synchronise when the data actually lives on a CUDA device so
        # the method also works on CPU tensors.
        on_cuda = pc_current.is_cuda

        # Inference only: no autograd graph is needed to compute metrics.
        with torch.no_grad():
            # Drain pending GPU work *before* starting the clock so that only
            # the forward pass is timed.
            if on_cuda:
                torch.cuda.synchronize()
            start = timeit.default_timer()
            flow_pred = model(pc_current, 0, 1)
            if on_cuda:
                torch.cuda.synchronize()
            stop = timeit.default_timer()

            # Compare against the flow vector only; an optional 4th column of
            # flow_gt is a validity flag, not a flow component.
            flow_3d_target = flow_gt[:, :3]
            epe3d_map = torch.sqrt(torch.sum((flow_pred - flow_3d_target) ** 2, dim=1))

            if flow_gt.shape[1] > 3:
                flow_3d_mask = flow_gt[:, 3] > 0
            else:
                flow_3d_mask = torch.ones(flow_gt.shape[0], dtype=torch.bool, device=flow_pred.device)

            # Filter out invalid points and NaN values in the EPE map
            flow_3d_mask = torch.logical_and(flow_3d_mask, torch.logical_not(torch.isnan(epe3d_map)))

            # Norm of the ground-truth flow vectors, used for relative error
            flow_3d_norm = torch.sqrt(torch.sum(flow_3d_target ** 2, dim=1))

            # Calculate metrics only for valid flow points
            valid_epe3d = epe3d_map[flow_3d_mask]
            relative_epe3d = valid_epe3d / flow_3d_norm[flow_3d_mask]
            num_valid = flow_3d_mask.sum().item()

            def percentage(condition):
                # Share of valid points satisfying `condition`, in percent.
                return torch.count_nonzero(condition).item() / num_valid * 100.0

            metrics = {
                'EPE3d': valid_epe3d.mean().item(),
                '5cm': percentage(torch.logical_or(valid_epe3d < 0.05, relative_epe3d < 0.05)),
                '10cm': percentage(torch.logical_or(valid_epe3d < 0.1, relative_epe3d < 0.1)),
                'outlier': percentage(torch.logical_or(valid_epe3d > 0.3, relative_epe3d > 0.1)),
                'time': stop - start,
            }
        return metrics


    def loop_sf(self):
        """Optimise the model on every sample of ``self.data_loader`` and print metrics.

        Each ``(pc1, pc2, flow_3d)`` batch has its leading dimension squeezed
        and is moved to CUDA as a contiguous float tensor before being passed
        to ``optimize_4DGS_SF``. The weights returned for the sample are loaded
        back into ``self.model``, the sample's metrics are printed and summed,
        and the per-key averages are printed once all samples are done.
        """
        args = self.args
        num_samples = len(self.data_loader)
        print(f"Optimization Start for {num_samples} samples!")

        def to_gpu(batch):
            # Drop the batch dimension and hand over a contiguous float CUDA tensor.
            return batch.squeeze(0).cuda().contiguous().float()

        totals = {'EPE3d': 0, '5cm': 0, '10cm': 0, 'outlier': 0, 'time': 0}
        processed = 0
        progress = tqdm(enumerate(self.data_loader), total=num_samples, file=sys.stdout)
        for idx, (pc1, pc2, flow_3d) in progress:
            print(f"\n[Sample: {idx + 1}]")

            # Per-sample optimisation; returns the best weights and their metrics.
            source = to_gpu(pc1)
            target = to_gpu(pc2)
            gt_flow = to_gpu(flow_3d)
            best_model_weight, sf_metrics_best = self.optimize_4DGS_SF(
                source, target, self.model, args, gt_flow, idx, 1)
            self.model.load_state_dict(best_model_weight)

            # Running sums for the final average.
            for name, value in sf_metrics_best.items():
                totals[name] += value
            processed += 1

            reported = tuple(sf_metrics_best[name] for name in ('EPE3d', '5cm', '10cm', 'outlier'))
            print("Validation KITTI EPE: %.4f, 5cm: %.4f, 10cm: %.4f, outlier: %.4f" % reported)

        # Turn the sums into means (left untouched when nothing was processed).
        if processed > 0:
            for name in totals:
                totals[name] /= processed

        print("#### Average Metrics Across All Samples ####")
        print(f"EPE: {totals['EPE3d']:.4f}, 5cm: {totals['5cm']:.4f}%, 10cm: {totals['10cm']:.4f}%, outlier: {totals['outlier']:.4f}%")
    
    def Accumul_sf(self):
        """Run scene-flow optimisation chunk by chunk over dense point-cloud pairs.

        Every ``(pc1, pc2)`` pair yielded by ``KittiPointCloudDataset`` is cut
        into chunks of ``chunk_size`` points along dim 0. Chunks with the same
        index are paired, their difference ``pc2_chunk - pc1_chunk`` is passed
        as the flow target, and ``optimize_4DGS_SF`` is called for each pair
        with the sample index and the chunk index. Return values of the
        optimisation are discarded.

        Only full-size chunks are used, so a shorter remainder at the end of a
        cloud is skipped and paired chunks always have matching shapes.
        """
        data_path = './NeuroGauss4D/log/visualizations/dense_pointclouds/02_dense_sf/bin/'
        chunk_size = 16384
        dataset = KittiPointCloudDataset(data_path)
        args = self.args
        # Report the dataset that is actually iterated below
        # (assumes KittiPointCloudDataset implements __len__).
        print(f"Optimization Start for {len(dataset)} samples!")

        def full_chunks(points):
            # Split along dim 0 and keep only complete chunks; a blanket
            # ``[:-1]`` would also discard a full final chunk whenever the
            # point count is an exact multiple of chunk_size.
            pieces = torch.split(torch.from_numpy(points), chunk_size, dim=0)
            return [piece for piece in pieces if piece.shape[0] == chunk_size]

        for idx, (pc1, pc2) in enumerate(dataset):
            # zip stops at the shorter list, so unmatched chunks are ignored.
            chunk_pairs = zip(full_chunks(pc1), full_chunks(pc2))
            for i, (pc1_chunk, pc2_chunk) in enumerate(chunk_pairs):
                flow_3d = pc2_chunk - pc1_chunk
                self.optimize_4DGS_SF(pc1_chunk.float().cuda(), pc2_chunk.float().cuda(), self.model, args, flow_3d.float().cuda(), idx, i)

    def compute_sf_loss(self, pc_current, pc_pred, gt_pc, flow, iter, args):
        """Return the weighted sum of a chamfer term and a flow-smoothness term.

        Chamfer term: for ``iter < 500`` or ``2500 < iter < 3100`` the two
        distance outputs of ``self.chamLoss`` are added, summed and halved;
        for every other iteration ``self.weighted_chamfer_distance`` is used
        instead. The term is scaled by ``args.factor_cd``.

        Smoothness term: ``compute_smooth`` on the current points and the flow
        with ``k=9``, summed and scaled by ``args.factor_smooth``.

        All tensors receive a leading batch dimension of size 1 before being
        handed to the loss functions.
        """
        pred_batch = pc_pred.unsqueeze(0)
        target_batch = gt_pc.unsqueeze(0)

        in_plain_chamfer_phase = iter < 500 or 2500 < iter < 3100
        if not in_plain_chamfer_phase:
            # Third argument flags the NL_Drive dataset
            # (its effect lives in weighted_chamfer_distance -- not visible here).
            chamfer_loss = self.weighted_chamfer_distance(pred_batch, target_batch, args.dataset == 'NL_Drive')
        else:
            dist1, dist2, _, _ = self.chamLoss(pred_batch, target_batch)
            chamfer_loss = (dist1 + dist2).sum() * 0.5

        smooth_map = compute_smooth(pc_current.unsqueeze(0), flow.unsqueeze(0), k=9)
        smooth_loss = smooth_map.squeeze(0).sum()

        # Accumulate from 0 in the same order as before: chamfer, then smoothness.
        total = 0
        for term, weight in ((chamfer_loss, args.factor_cd), (smooth_loss, args.factor_smooth)):
            total = total + term * weight
        return total

    def weights_init(self, m):
        """Re-initialise the parameters of a single layer in place, by layer type.

        * ``nn.Conv2d``: Kaiming-normal weights (``fan_out``, ReLU), zero bias.
        * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
        * ``nn.Linear``: Kaiming-uniform weights, zero bias.

        Only ``m`` itself is inspected: child modules are not visited, and a
        module of any other type is left untouched.
        NOTE(review): presumably meant to be used via ``Module.apply`` so that
        nested layers are reached -- confirm against the call sites.
        """
        from torch import nn

        init = nn.init
        # The three layer types are unrelated classes, so branch order is irrelevant.
        if isinstance(m, nn.Linear):
            init.kaiming_uniform_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                init.constant_(m.bias, 0)