import os
import pdb
import sys
import torch
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from data.MultiframeDataset import DHBDataset, NLDriveDataset, KITTI_hplflownet, KITTI_flownet3d
from model.NeuralPCI import NeuralPCI
from pytorch3d.loss import chamfer_distance
from utils.smooth_loss import compute_smooth
from utils.utils import *
sys.path.append('./utils/EMD')
sys.path.append('./utils/CD')
from emd import earth_mover_distance
import chamfer3D.dist_chamfer_3D
from torch.nn.utils import clip_grad_norm_
import timeit


class Runner:
    def __init__(self, args):
        self.args = args
        # Init Dataset
        if args.dataset == 'DHB':
            self.dataset = DHBDataset(data_root = args.dataset_path, 
                                      scene_list = args.scenes_list, 
                                      interval = args.interval) 
        elif args.dataset == 'NL_Drive':
            self.dataset = NLDriveDataset(data_root = args.dataset_path, 
                                          scene_list = args.scenes_list, 
                                          interval = args.interval,
                                          num_points = args.num_points, 
                                          num_frames = args.num_frames)
        elif args.dataset == 'KITTI_s':
            self.dataset = KITTI_hplflownet(data_root = args.dataset_path, num_points = args.num_points)
        elif args.dataset == 'KITTI_o':
            self.dataset = KITTI_flownet3d(root = args.dataset_path, num_points = args.num_points)

        # Init DataLoader
        self.data_loader = DataLoader(self.dataset, batch_size=1, shuffle=False, drop_last=False, num_workers=4)

        self.chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()

        self.model = NeuralPCI(args).cuda()
        print("Number of parameters in model is {:.3f}M".format(sum(tensor.numel() for tensor in self.model.parameters())/1e6))
        # self.best_cd_val = 100
        self.accumulate = 0

        # Loss recorder
        self.recorder = self.get_recorder(args)

        # Define time stamp for input and interpolation frames 
        self.time_seq, self.time_intp = self.get_timestamp(args)

    def weighted_chamfer_distance(self, pc_pred, pc_input, NL_flag, alpha=1.6, dist_threshold=10.0, z_min=-1.3, z_max=3.99, z_weight=1.9):
        dist1, dist2, _, _ = self.chamLoss(pc_pred, pc_input)
        dist1 = torch.clamp(dist1, max=dist_threshold)
        dist2 = torch.clamp(dist2, max=dist_threshold)
        # if NL_flag:
        #     z1 = pc_pred[:, :, 2:]
        #     z2 = pc_input[:, :, 2:]
        #     z_mask1 = (z1 >= z_min) & (z1 <= z_max)
        #     z_mask2 = (z2 >= z_min) & (z2 <= z_max)
        #     z_dist1 = torch.where(z_mask1.squeeze(-1), z_weight * dist1, dist1)
        #     z_dist2 = torch.where(z_mask2.squeeze(-1), z_weight * dist2, dist2)
        # else:
        #     z_dist1 = dist1
        #     z_dist2 = dist2

        z_dist1 = dist1
        z_dist2 = dist2
        weight1 = torch.tanh(alpha * z_dist1)  # (B, N)
        weight2 = torch.tanh(alpha * z_dist2)  # (B, M)

        return (weight1 * z_dist1 + weight2 * z_dist2).sum()
    
    def save_tensors_to_dict(self, input_list, gt_list, file_path):
        input_list = [tensor.detach().cpu().numpy() for tensor in input_list]
        gt_list = [tensor.detach().cpu().numpy() for tensor in gt_list]
        max_length = max(len(input_list), len(gt_list))
        input_list.extend([np.array([])] * (max_length - len(input_list)))
        gt_list.extend([np.array([])] * (max_length - len(gt_list)))
        zipped = zip(input_list, gt_list)

        data_dict = {}
        for i, (inp, gt) in enumerate(zipped):
            data_dict[f'input_{i}'] = inp
            data_dict[f'gt_{i}'] = gt
        np.save(file_path, data_dict)
    
    def load_tensors_from_dict(self, file_path):
        data_dict = np.load(file_path, allow_pickle=True).item()
        input_list = []
        gt_list = []
        for key, value in data_dict.items():
            if key.startswith('input'):
                input_list.append(torch.from_numpy(value))
            elif key.startswith('gt'):
                gt_list.append(torch.from_numpy(value))

        return input_list, gt_list   

    def get_recorder(self, args):
        recorder = {}
        recorder['loss_all_CD'] = []
        recorder['loss_all_EMD'] = []

        for i in range(args.interval - 1):
            recorder['loss_frame{}_CD'.format(i+1)] = []
            recorder['loss_frame{}_EMD'.format(i+1)] = []

        return recorder
        

    def get_timestamp(self, args):
        time_seq = [t for t in np.linspace(args.t_begin, args.t_end, args.num_frames)]
        t_left = time_seq[args.num_frames//2 - 1]
        t_right = time_seq[args.num_frames//2]
        time_intp = [t for t in np.linspace(t_left, t_right, args.interval+1)]
        time_intp = time_intp[1:-1]
        
        return time_seq, time_intp


    def loop(self):
        """Optimize a fresh NeuralPCI model on every sample and report the averages.

        For each ``(input, gt)`` pair produced by ``self.data_loader`` the
        frames are stripped of their leading batch dimension and moved to the
        GPU as contiguous float tensors, a new model is optimized with
        ``optimize_NeuralPCI`` and the resulting errors are stored via
        ``recorder_error``. ``print_result`` is called once at the end.
        """
        args = self.args
        num_samples = len(self.data_loader)
        print("Optimization Start for {} samples!".format(num_samples))

        # get one sample from DataLoader (for-loop)
        for idx, (frames, gt_frames) in tqdm(enumerate(self.data_loader), total=num_samples, file=sys.stdout):
            print("\n[Sample: {}]".format(idx+1))

            # Replace the entries in place so both sequences keep their type.
            for k, pc in enumerate(frames):
                frames[k] = pc.squeeze(0).cuda().contiguous().float()
            for k, pc in enumerate(gt_frames):
                gt_frames[k] = pc.squeeze(0).cuda().contiguous().float()

            # Every sample is fitted from scratch with its own model.
            model = NeuralPCI(args).cuda()
            _, best_cd, best_emd, intp_cd_best, intp_emd_best = self.optimize_NeuralPCI(
                frames, model, args, gt_frames, idx=idx)
            self.recorder_error(best_cd, best_emd, intp_cd_best, intp_emd_best)

        # print overall average result
        self.print_result()

    def load_single_seq(self, directory):
        input_files = [os.path.join(directory, f) for f in os.listdir(directory) if f.startswith('input_')]
        gt_files = [os.path.join(directory, f) for f in os.listdir(directory) if f.startswith('gt_')]
        input_files.sort()
        gt_files.sort()
        input = []
        gt = []

        for f in input_files:
            input.append(torch.from_numpy(np.load(f)))

        for f in gt_files:
            gt.append(torch.from_numpy(np.load(f)))
        return input, gt

    def demo(self):
        """Optimize one stored sequence and save the interpolated frames.

        Loads the sequence from a fixed ``.npy`` file, optimizes a new model
        on it and writes the point clouds returned by ``infer_NeuralPCI`` to a
        second fixed ``.npy`` file.
        """
        # NOTE(review): infer_NeuralPCI is not defined in the part of the class
        # visible here - confirm it exists before relying on this method.
        args = self.args
        stored_inputs, stored_gts = self.load_tensors_from_dict(
            "./NeuroGauss4Dlog/visualizations/compare_pc/NL_PC_42.npy")

        model = NeuralPCI(args).cuda()

        def _to_gpu(tensors):
            return [t.cuda().contiguous().float() for t in tensors]

        input_list = _to_gpu(stored_inputs)
        gt = _to_gpu(stored_gts)

        # Recompute the timestamps from the current configuration.
        self.time_seq, self.time_intp = self.get_timestamp(args)
        result = self.optimize_NeuralPCI(input_list, model, args, gt, idx=0)
        best_model_weight = result[0]

        pc_intp = self.infer_NeuralPCI(input_list, model, best_model_weight)
        pc_intp_save = [t.cpu().numpy() for t in pc_intp]
        np.save("./NeuroGauss4Dlog/visualizations/compare_pc/NL_42_ours_pred.npy", pc_intp_save)


    def get_optimizer(self, model, args):
        if args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
            # optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': [self.emd_weight], 'lr': 0.01}], lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'adamw':
            optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'sgdm':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay, nesterov=True)
        elif args.optimizer == 'radam':
            optimizer = torch.optim.RAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'nadam':
            optimizer = torch.optim.NAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.optimizer == 'rmsprop':
            optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, alpha=0.99, momentum=0.9, weight_decay=args.weight_decay)
        return optimizer


    def get_scheduler(self, optimizer, args):
        # from warmup_scheduler import GradualWarmupScheduler
        if args.scheduler == 'step':
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.9)
        elif args.scheduler == 'cosine':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.iters, eta_min=1e-6)
        # elif args.scheduler == 'cosine_warm':
        #     scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=args.iters-1500, T_mult=1, eta_min=1e-6)
        #     scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=500, after_scheduler=scheduler)
        elif args.scheduler == 'poly':
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: (1-step/args.iters)**0.9 if step>1500 else (step/1500)*1.5)
        elif args.scheduler == 'plateau':
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, min_lr=1e-6)
        else:
            raise ValueError(f'Unknown scheduler: {args.scheduler}')
        return scheduler


    def compute_loss(self, pc_current, pc_pred, pc_input, iter):  
        args = self.args
        loss = 0
        # dist1, dist2, _, _ = self.chamLoss(pc_pred.unsqueeze(0), pc_input.unsqueeze(0))
        # chamfer_loss = (dist1 + dist2).sum() * 0.5
        if iter<500 or (iter>2500 and iter<3100):
            dist1, dist2, _, _ = self.chamLoss(pc_pred.unsqueeze(0), pc_input.unsqueeze(0))
            chamfer_loss = (dist1 + dist2).sum() * 0.5
        else:
            chamfer_loss = self.weighted_chamfer_distance(pc_pred.unsqueeze(0), pc_input.unsqueeze(0), args.dataset == 'NL_Drive')
        loss = loss + chamfer_loss * args.factor_cd

        if args.dataset == 'DHB':
            dist = earth_mover_distance(pc_pred.unsqueeze(0), pc_input.unsqueeze(0), transpose=False)
            emd_loss = (dist / pc_pred.shape[0]).mean()
            loss = loss + emd_loss * args.factor_emd

        elif args.dataset == 'NL_Drive':
            flow = pc_pred - pc_current
            smooth_loss = compute_smooth(pc_current.unsqueeze(0), flow.unsqueeze(0), k=9)
            smooth_loss = smooth_loss.squeeze(0).sum()
            loss = loss + smooth_loss * args.factor_smooth
        return loss


    def optimize_NeuralPCI(self, input, model, args, gt, idx=None):
        model.train()
        optimizer = self.get_optimizer(model, args)
        if args.scheduler:
            scheduler = self.get_scheduler(optimizer, args)

        # record optimization best result
        best_cd = np.inf
        best_emd = np.inf
        best_iter = 0
        intp_emd_best=0
        intp_cd_best =0
        best_model_weight = model.state_dict()
        patience = args.patience
        counter = 0
        pc_pred_list = []
        for iter in range(args.iters):
            optimizer.zero_grad()
            loss = 0
            # input current point cloud to predict point cloud at time_pred
            for i in range(len(input)//2 - 1, len(input)//2 + 1):
                pc_current = input[i]
                time_current = self.time_seq[i]
                
                for j in range(len(input)):
                    time_pred = self.time_seq[j]
                    pc_pred = pc_current + model(pc_current, time_current, time_pred)
                    loss = loss + self.compute_loss(pc_current, pc_pred, input[j], iter)

            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            if args.scheduler:
                if args.scheduler == 'plateau':
                    scheduler.step(loss.item())
                else:
                    scheduler.step()
            if (iter >1800) and (iter % args.eval_freq) == 0:
            # if iter >= 0:
                cd_error, emd_error, intp_cd_error, intp_emd_error = self.eval_Train(input, gt, model)
                if (best_cd > cd_error.item()) or (best_emd > emd_error.item()):
                    best_iter = iter
                    if best_cd > cd_error.item():
                        best_cd = cd_error.item()
                        intp_cd_best = intp_cd_error
                        best_model_weight = model.state_dict()
                    if best_emd > emd_error.item():
                        best_emd = emd_error.item()
                        intp_emd_best = intp_emd_error
                    counter = 0
                    if (iter > 500) and (iter % 50 == 0):
                        current_lr = optimizer.param_groups[0]['lr']
                        print(f"[Iter: {iter}] [Loss: {loss.item():.5f}] [LR: {current_lr:.5f}] [Best Iter: {best_iter}] [Best CD: {best_cd:.5f}] [Best EMD: {best_emd:.5f}]")
                else:
                    counter += 1
                    if counter >= patience:
                        print(f"Early stopping at iteration {iter}")
                        current_lr = optimizer.param_groups[0]['lr']
                        print(f"[Iter: {iter}] [Loss: {loss.item():.5f}] [LR: {current_lr:.5f}] [Best Iter: {best_iter}] [Best CD: {best_cd:.5f}] [Best EMD: {best_emd:.5f}]")
                        break
                if args.Accumulate_LiDAR and (cd_error<(best_cd+0.006)):
                    pc_pred_list = self.accumulate_pc(input, gt, model, pc_pred_list)
                    if pc_pred_list is None:
                        continue
                model.train()
        return best_model_weight, best_cd, best_emd, intp_cd_best, intp_emd_best


    def eval_NeuralPCI(self, input, gt, model, weight, args, idx, save_if=False):
        model.eval()
        with torch.no_grad():
            model.load_state_dict(weight)
            pc_intp = []
            # Predict the inpolation frame at time_intp
            for j in range(len(self.time_intp)):
                pc_pred = []
                time_pred = self.time_intp[j]

                # Use nearest input frame as the reference
                for i in range(len(input)//2 - 1, len(input)//2 + 1):
                    pc_current = input[i]
                    time_current = self.time_seq[i]

                    pc_pred.append(pc_current + model(pc_current, time_current, time_pred))

                # NN-Intp
                if j < 0.5 * len(self.time_intp):
                    pc_intp.append(pc_pred[0].cuda())
                else:
                    pc_intp.append(pc_pred[1].cuda())
            
            intp_cd_error = []
            intp_emd_error = []
            for i in range(len(pc_intp)): 
                cd_error = chamfer_distance(pc_intp[i].unsqueeze(0), gt[i].unsqueeze(0))[0]

                dist = earth_mover_distance(pc_intp[i].unsqueeze(0), gt[i].unsqueeze(0), transpose=False)
                emd_error = (dist / pc_intp[i].shape[0]).mean()

                intp_cd_error.append(cd_error)
                intp_emd_error.append(emd_error)

            print("CD CD Eval Eval Eval Eval Eval Eval: {}".format(torch.mean(torch.tensor(intp_cd_error))))
            print("EMD EMD Eval Eval Eval Eval Eval: {}".format(torch.mean(torch.tensor(intp_emd_error))))
            if save_if:
                save_path = "./NeuroGauss4Dlog/visualizations/pred_other/NeuralPCI"
                file_name = os.path.basename(args.scenes_list)[:-4]
                save_dir = os.path.join(save_path, file_name, f"{idx:03d}")
                for pc, name in zip([gt, pc_intp], ["gt", "pc_intp"]):
                    for j, tensor in enumerate(pc):
                        save_file = os.path.join(save_dir, f"{name}_{j}.npy")
                        self.save_pointcloud(tensor, save_file)
                for pc, name in zip([input], ["input"]):
                    for j, tensor in enumerate(pc):
                        save_file = os.path.join(save_dir, f"{name}_{j}.npy")
                        self.save_pointcloud(tensor, save_file)
        
    
    def print_result(self):
        recorder = self.recorder
        args = self.args
        print("\n=======================================")
        print("Final Result CD Loss is:{}".format(np.mean(recorder['loss_all_CD'])))
        print("Final Result EMD Loss is:{}".format(np.mean(recorder['loss_all_EMD'])))
        for i in range(args.interval - 1):
            print("=======================================")
            print("Final Frame-{} Result CD Loss is:{}".format(i+1, np.mean(recorder['loss_frame{}_CD'.format(i+1)])))
            print("Final Frame-{} Result EMD Loss is:{}".format(i+1, np.mean(recorder['loss_frame{}_EMD'.format(i+1)])))
        print("=======================================")


    def eval_Train(self, input, gt, model):
        model.eval()
        with torch.no_grad():
            
            pc_intp = []
            # Predict the inpolation frame at time_intp
            # pdb.set_trace()
            for j in range(len(self.time_intp)):
                pc_pred = []
                time_pred = self.time_intp[j]

                # Use nearest input frame as the reference
                for i in range(len(input)//2 - 1, len(input)//2 + 1):
                    pc_current = input[i]
                    time_current = self.time_seq[i]
                    pred_time_pc = pc_current + model(pc_current, time_current, time_pred, train=False)
                    pc_pred.append(pred_time_pc)
                if j < 0.5 * len(self.time_intp):
                    pc_intp.append(pc_pred[0].cuda())
                else:
                    pc_intp.append(pc_pred[1].cuda())
            
            intp_cd_error = []
            intp_emd_error = []
            for i in range(len(pc_intp)):
                cd_error = chamfer_distance(pc_intp[i].unsqueeze(0), gt[i].unsqueeze(0))[0]
                dist = earth_mover_distance(pc_intp[i].unsqueeze(0), gt[i].unsqueeze(0), transpose=False)
                emd_error = (dist / pc_intp[i].shape[0]).mean()
                intp_cd_error.append(cd_error)
                intp_emd_error.append(emd_error)
            # if self.best_cd_val > torch.mean(torch.tensor(intp_cd_error)):
            #     self.best_cd_val = torch.mean(torch.tensor(intp_cd_error))
            #     print("self.best_cd_val:",self.best_cd_val)
            #     save_dir = "./NeuroGauss4Dlog/visualizations/pred_other/NeuralPCI/soldier/035"
            #     for pc, name in zip([pc_intp], ["pc_intp"]):
            #         for j, tensor in enumerate(pc):
            #             save_file = os.path.join(save_dir, f"{name}_{j}.npy")
            #             self.save_pointcloud(tensor, save_file)

            return torch.mean(torch.tensor(intp_cd_error)), torch.mean(torch.tensor(intp_emd_error)), intp_cd_error, intp_emd_error

    def accumulate_pc(self, input, gt, model, pc_pred):
        model.eval()
        with torch.no_grad():
            pc_intp = []
            # Predict the inpolation frame at time_intp
            for j in range(len(self.time_intp)):
                pc_pred_ = []
                time_pred = self.time_intp[j]

                # Use nearest input frame as the reference
                for i in range(len(input)//2 - 1, len(input)//2 + 1):
                    pc_current = input[i]
                    time_current = self.time_seq[i]
                    pred_time_pc = pc_current + model(pc_current, time_current, time_pred, train=False)
                    pc_pred_.append(pred_time_pc)
                if j < 0.5 * len(self.time_intp):
                    pc_pred.append(pc_pred_[0].cuda())
                else:
                    pc_pred.append(pc_pred_[1].cuda())

                    # Apply small noise perturbations to pc_current
                    # pc_current_perturbed_1 = pc_current + torch.randn_like(pc_current) * 0.01
                    # pc_current_perturbed_2 = pc_current + torch.randn_like(pc_current) * 0.01

                    # pred_time_pc_1 = pc_current_perturbed_1 + model(pc_current_perturbed_1, time_current, time_pred, train=False)
                    # pred_time_pc_2 = pc_current_perturbed_2 + model(pc_current_perturbed_2, time_current, time_pred, train=False)
                    # pred_time_pc = torch.cat([pred_time_pc_1, pred_time_pc_2], dim=0)
                    # pc_pred.append(pred_time_pc)
            self.accumulate += 1
            print("Accumulate point cloud:", self.accumulate)

            if self.accumulate >= 80:
                dense_pc = torch.cat(pc_pred, dim=0) 
                save_path = os.path.join('./NeuroGauss4Dlog/visualizations/dense_pointclouds', 'Test_dense.ply')
                # pdb.set_trace()
                self.save_pointcloud_as_ply(dense_pc, (123, 102, 140), save_path)
                self.save_pointcloud(dense_pc, save_path[:-4]+".npy")
                pc_pred = None
                self.accumulate = 0
            return pc_pred
        
    def recorder_error(self, best_cd, best_emd, intp_cd_best, intp_emd_best):
        print("Eval: Average Interpolation CD Error: {}".format(best_cd))
        print("Eval: Average Interpolation EMD Error: {}".format(best_emd))
        self.recorder['loss_all_CD'].append(torch.mean(torch.tensor(intp_cd_best)).item())
        self.recorder['loss_all_EMD'].append(torch.mean(torch.tensor(intp_emd_best)).item())
        for i in range(len(intp_cd_best)):
            self.recorder['loss_frame{}_CD'.format(i+1)].append(intp_cd_best[i].item())
            self.recorder['loss_frame{}_EMD'.format(i+1)].append(intp_emd_best[i].item())
        
    def save_pointcloud_as_ply(self, tensor, rgb, save_path):
        points = tensor.cpu().numpy()
        ply_header = """ply
            format ascii 1.0
            element vertex {}
            property float x
            property float y
            property float z
            property uchar red
            property uchar green
            property uchar blue
            end_header
            """.format(points.shape[0])
        points_with_rgb = np.concatenate((points, np.tile(np.array(rgb, dtype=np.uint8), (points.shape[0], 1))), axis=1)

        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, 'w') as ply_file:
            ply_file.write(ply_header)
            np.savetxt(ply_file, points_with_rgb, fmt="%f %f %f %d %d %d")
        
        print(f"Point cloud saved as {save_path}")

    def save_pointcloud(self, tensor, save_path):
        tensor = tensor.cpu().numpy()
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        np.save(save_path, tensor)
        # print(f"Point cloud saved as {save_path}")




    def optimize_4DGS_SF(self, pc_current, gt_pc, model, args, gt):
        model.train()
        optimizer = self.get_optimizer(model, args)
        if args.scheduler:
            scheduler = self.get_scheduler(optimizer, args)
        best_epe, best_outlier = 10000000, 10000000
        for iter in range(args.iters):
            optimizer.zero_grad()
            loss = 0
            flow = model(pc_current, 0, 1)
            pc_pred = pc_current + flow
            loss = loss + self.compute_sf_loss(pc_pred, gt_pc, iter)
            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            if (iter >400) and (iter % args.eval_freq) == 0:
            # if iter >= 0:
                sf_metrics = self.eval_sf_Train(pc_current, gt, model)
                if (best_epe > metrics['EPE3d']) or (best_outlier > metrics['outlier']):
                    best_iter = iter
                    if best_epe > metrics['EPE3d']:
                        best_epe = metrics['EPE3d']
                        best_model_weight = model.state_dict()
                    if best_outlier > metrics['outlier']:
                        best_outlier = metrics['outlier']
                    counter = 0
                    if (iter > 500) and (iter % 50 == 0):
                        current_lr = optimizer.param_groups[0]['lr']
                        print(f"[Iter: {iter}] [Loss: {loss.item():.5f}] [LR: {current_lr:.5f}] [Best Iter: {best_iter}] [Best CD: {best_epe:.5f}] [Best EMD: {best_outlier:.5f}]")
                else:
                    counter += 1
                    if counter >= patience:
                        print(f"Early stopping at iteration {iter}")
                        current_lr = optimizer.param_groups[0]['lr']
                        print(f"[Iter: {iter}] [Loss: {loss.item():.5f}] [LR: {current_lr:.5f}] [Best Iter: {best_iter}] [Best CD: {best_epe:.5f}] [Best EMD: {best_outlier:.5f}]")
                        break
                model.train()
        return best_model_weight, sf_metrics

    def eval_sf_Train(self, pc_current, flow_gt, model):
        # Set the model to evaluation mode
        model.eval()
        
        # Start timing the evaluation
        start = timeit.default_timer()
        flow_pred = model(pc_current, 0, 1)
        stop = timeit.default_timer()
        if np.random.rand(1) < 0.2:
            print(f"Eval time: {(stop - start):.4f} seconds")
        
        # Calculate the 3D end-point error (EPE)
        epe3d_map = torch.sqrt(torch.sum((flow_pred - flow_gt) ** 2, dim=1))
        
        # Create a mask for valid points if the ground truth contains a validity mask
        if flow_gt.shape[1] > 3:
            valid_mask = flow_gt[:, 3] > 0
            flow_gt = flow_gt[:, :3]
        else:
            valid_mask = torch.ones(flow_gt.shape[0], dtype=torch.bool, device=flow_pred.device)
        
        # Filter out invalid points and NaN values in the EPE map
        valid_mask = torch.logical_and(valid_mask, torch.logical_not(torch.isnan(epe3d_map)))
        
        # Calculate metrics only for valid flow points
        valid_epe3d = epe3d_map[valid_mask]
        metrics = {
            'EPE3d': valid_epe3d.mean().item(),
            '5cm': (valid_epe3d < 0.05).float().mean().item() * 100.0,
            '10cm': (valid_epe3d < 0.1).float().mean().item() * 100.0,
            'outlier': (valid_epe3d > 0.3).float().mean().item() * 100.0
        }
        return metrics
        ## print(f"EPE: {metrics['EPE3d']:.4f}, 5cm: {metrics['5cm']:.4f}%, 10cm: {metrics['10cm']:.4f}%, outlier: {metrics['outlier']:.4f}%")

    def loop_sf(self):
        """Run scene-flow optimization on every sample and print the mean metrics.

        Each ``(pc1, pc2, flow_3d)`` sample from ``self.data_loader`` is
        optimized with ``optimize_4DGS_SF``; the returned metrics are averaged
        over all samples that finished without an exception. A failing sample
        is reported and skipped.
        """
        args = self.args
        print(f"Optimization Start for {len(self.data_loader)} samples!")
        # NOTE(review): one model is shared by all samples here, whereas
        # loop() builds a fresh model per sample - confirm this is intended.
        model = NeuralPCI(args).cuda()
        cumulative_metrics = {'EPE3d': 0, '5cm': 0, '10cm': 0, 'outlier': 0}
        sample_count = 0
        for idx, (pc1, pc2, flow_3d) in tqdm(enumerate(self.data_loader), total=len(self.data_loader), file=sys.stdout):
            print(f"\n[Sample: {idx + 1}]")
            try:
                # Same preparation as in loop(): drop the batch dimension the
                # DataLoader adds and move the data to the model's device.
                # NOTE(review): assumes the dataset yields tensors - confirm.
                pc1, pc2, flow_3d = (
                    t.squeeze(0).cuda().contiguous().float() for t in (pc1, pc2, flow_3d)
                )

                # Perform optimization and get metrics. The call must go
                # through self: the bare name raised a NameError that the
                # handler below swallowed for every sample.
                _, sf_metrics = self.optimize_4DGS_SF(pc1, pc2, model, args, flow_3d)

                # Accumulate metrics
                for key in sf_metrics:
                    cumulative_metrics[key] += sf_metrics[key]
                sample_count += 1
            except Exception as e:
                # Best effort: keep processing the remaining samples.
                print(f"Error processing sample {idx + 1}: {e}")

        # Calculate average metrics
        if sample_count > 0:
            for key in cumulative_metrics:
                cumulative_metrics[key] /= sample_count

        print("#### Average Metrics Across All Samples ####")
        print(f"EPE: {cumulative_metrics['EPE3d']:.4f}, 5cm: {cumulative_metrics['5cm']:.4f}%, 10cm: {cumulative_metrics['10cm']:.4f}%, outlier: {cumulative_metrics['outlier']:.4f}%")
