import torch
import torch.nn as nn
from tools import builder
from utils import misc, dist_utils
import time
from utils.logger import *
from utils.AverageMeter import AverageMeter

import numpy as np
from datasets import data_transforms
from pointnet2_ops import pointnet2_utils
from torchvision import transforms

# To calculate params and FLOPs
from thop import profile
# try:
#     from mmcv.cnn import get_model_complexity_info
# except ImportError:
#     raise ImportError('Please upgrade mmcv to >0.6.2')

import time
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table, flop_count_str

# train_transforms = transforms.Compose(
#     [
#          data_transforms.PointcloudScaleAndTranslate(),
#     ]
# )

# test_transforms = transforms.Compose(
#     [
#         data_transforms.PointcloudScaleAndTranslate(),
#     ]
# )

def rotate_y(points, angle):
    """Rotate a point cloud about the Y axis.

    Args:
        points: Floating-point tensor whose last dimension holds (x, y, z).
        angle: Rotation angle in degrees.

    Returns:
        A new tensor with the same shape, dtype and device as ``points``.
    """
    angle_rad = np.radians(angle)
    cos_ang = float(np.cos(angle_rad))
    sin_ang = float(np.sin(angle_rad))
    # Build the matrix directly in the input's dtype and on its device so the
    # matmul below does not fail on a dtype mismatch for non-float32 inputs.
    rot_matrix = torch.tensor(
        [
            [cos_ang, 0.0, sin_ang],
            [0.0, 1.0, 0.0],
            [-sin_ang, 0.0, cos_ang],
        ],
        dtype=points.dtype,
        device=points.device,
    )
    # Points are row vectors, hence the multiplication by the transpose.
    return points @ rot_matrix.T

def translate_z(points, delta):
    """Shift a point cloud along the Z axis.

    The input tensor is left untouched; the offset is applied to a copy.

    Args:
        points: Tensor whose last dimension holds (x, y, z).
        delta: Amount added to every z coordinate.

    Returns:
        A new tensor of the same shape with index 2 of the last dimension
        shifted by ``delta``.
    """
    shifted = points.clone()
    # The indexed slice is a view into the copy, so the in-place add
    # updates ``shifted`` itself.
    shifted[..., 2].add_(delta)
    return shifted

def scale_points(points, scale_range):
    """Scale a point cloud by one random factor drawn from a range.

    A single factor is sampled per call from NumPy's global random state and
    applied uniformly to every coordinate of every point.

    Args:
        points: Point-cloud tensor.
        scale_range: Indexable pair ``(low, high)`` bounding the factor.

    Returns:
        ``points`` multiplied by the sampled factor.
    """
    lower = scale_range[0]
    upper = scale_range[1]
    factor = np.random.uniform(lower, upper)
    return points * factor

def jitter_points(points, sigma):
    """Perturb a point cloud with zero-mean Gaussian noise.

    Args:
        points: Point-cloud tensor.
        sigma: Standard deviation of the noise.

    Returns:
        ``points`` plus noise of the same shape, sampled on ``points``' device.
    """
    noise = torch.normal(
        mean=0.0,
        std=sigma,
        size=points.shape,
        device=points.device,
    )
    return points + noise

def test_robustness(base_model, test_dataloader, args, config, logger=None):
    """Evaluate accuracy under a fixed suite of input perturbations.

    The model is put in eval mode and every (transform, parameter) pair of the
    suite below is evaluated with ``test_single_transform`` under
    ``torch.no_grad()``. Each result is logged as it is produced, followed by
    a one-row, tab-separated summary table.

    Args:
        base_model: Model to evaluate.
        test_dataloader: Loader handed to ``test_single_transform``.
        args: Not used by this function.
        config: Object exposing ``npoints``.
        logger: Forwarded to ``print_log``.

    Returns:
        ``{transform_name: {param: accuracy}}``. The baseline is stored under
        ``results["None"][""]``.
    """
    base_model.eval()
    npoints = config.npoints

    # Transform name -> list of parameters to evaluate.
    robustness_tests = {
        "None": [],  # baseline, no transform
        "Rotation (Y-axis)": [-90, 90, 180],  # rotation angles about Y
        "Translation (Z-axis)": [0.2, -0.2],  # offsets along Z
        "Scaling": [(0.5, 1.5), (0.6, 1.4), (0.7, 1.3)],  # scale-factor ranges
        "Jittering (σ)": [0.01, 0.02]  # noise standard deviations
    }

    # {transform name: {parameter: accuracy}}
    results = {}

    with torch.no_grad():
        for transform_type, params in robustness_tests.items():
            per_param = {}
            results[transform_type] = per_param

            if transform_type == "None":
                # Baseline run: no transform is applied.
                acc = test_single_transform(
                    base_model, test_dataloader, npoints,
                    transform_type=None, param=None,
                )
                per_param[""] = acc
                print_log(f"[Robustness] {transform_type}: acc = {acc:.4f}", logger=logger)
                continue

            for param in params:
                acc = test_single_transform(
                    base_model, test_dataloader, npoints,
                    transform_type=transform_type, param=param,
                )
                per_param[param] = acc
                print_log(f"[Robustness] {transform_type} {param}: acc = {acc:.4f}", logger=logger)

    # Summary table: one header line and one result line, tab separated.
    separator = "-" * 80
    print_log("\n[Robustness Summary]", logger=logger)
    print_log(separator, logger=logger)

    header = "\t".join(["Method", *robustness_tests])
    print_log(header, logger=logger)

    cells = ["Our Model"]
    for name, ps in robustness_tests.items():
        if name == "None":
            cells.append(f"{results[name]['']:.2f}")
        else:
            # Accuracies of one transform are joined with "/" in one cell.
            cells.append("/".join(f"{results[name][p]:.2f}" for p in ps))
    print_log("\t".join(cells), logger=logger)
    print_log(separator, logger=logger)

    return results

def test_single_transform(base_model, test_dataloader, npoints, transform_type, param):
    """Compute classification accuracy under one transform setting.

    Every batch is moved to the GPU, down-sampled to ``npoints`` points with
    ``misc.fps``, optionally transformed, and classified by ``base_model``.

    This function neither switches the model to eval mode nor disables
    gradients; the callers in this file do both before calling it.

    Args:
        base_model: Model called as ``base_model(points)`` to get logits.
        test_dataloader: Iterable yielding ``(taxonomy_ids, model_ids, data)``
            where ``data[0]`` are the points and ``data[1]`` the labels.
        npoints: Number of points kept per cloud after sampling.
        transform_type: One of ``"Rotation (Y-axis)"``,
            ``"Translation (Z-axis)"``, ``"Scaling"``, ``"Jittering (σ)"``,
            or ``None`` / ``"None"`` for the untransformed baseline.
        param: Parameter for the chosen transform (angle in degrees, z offset,
            ``(low, high)`` scale range, or noise sigma respectively).

    Returns:
        Accuracy in percent, as produced by the tensor arithmetic below.

    Raises:
        ValueError: If ``transform_type`` is not a supported transform name.
    """
    # Map each supported transform name to the callable that applies it.
    appliers = {
        "Rotation (Y-axis)": lambda pts: rotate_y(pts, angle=param),
        "Translation (Z-axis)": lambda pts: translate_z(pts, delta=param),
        "Scaling": lambda pts: scale_points(pts, scale_range=param),
        "Jittering (σ)": lambda pts: jitter_points(pts, sigma=param),
    }
    if transform_type is None or transform_type == "None":
        apply_transform = None
    elif transform_type in appliers:
        apply_transform = appliers[transform_type]
    else:
        # Fail loudly: a misspelt name used to fall through silently and
        # report the baseline accuracy under the wrong label.
        raise ValueError(
            f"Unsupported transform type: {transform_type!r}; "
            f"expected None, 'None' or one of {list(appliers)}"
        )

    test_pred = []
    test_label = []

    for taxonomy_ids, model_ids, data in test_dataloader:
        points = data[0].cuda()
        label = data[1].cuda()

        # Sample down to npoints points, as the regular validation path does.
        points = misc.fps(points, npoints)

        if apply_transform is not None:
            points = apply_transform(points)

        logits = base_model(points)
        pred = logits.argmax(-1).view(-1)
        target = label.view(-1)

        test_pred.append(pred.detach())
        test_label.append(target.detach())

    # Accuracy over the whole loader, in percent.
    test_pred = torch.cat(test_pred, dim=0)
    test_label = torch.cat(test_label, dim=0)
    acc = (test_pred == test_label).sum() / float(test_label.size(0)) * 100.
    return acc

def test_robustness_single_param(base_model, test_dataloader, args, config, 
                                target_transform, target_param, logger=None):
    """Evaluate accuracy for one specific transform/parameter pair.

    The model is put in eval mode and evaluated once with
    ``test_single_transform`` under ``torch.no_grad()``.

    Args:
        base_model: Model to evaluate.
        test_dataloader: Loader handed to ``test_single_transform``.
        args: Not used by this function.
        config: Object exposing ``npoints``.
        target_transform: Name of the transform to apply.
        target_param: Parameter for that transform; it is used as a dict key
            in the result and must therefore be hashable.
        logger: Forwarded to ``print_log``, matching ``test_robustness``.
            Presumably ``print_log`` falls back to plain printing when this
            is None — TODO confirm against utils.logger.

    Returns:
        ``{target_transform: {target_param: accuracy}}``.
    """
    base_model.eval()
    npoints = config.npoints

    with torch.no_grad():
        acc = test_single_transform(
            base_model, test_dataloader, npoints,
            transform_type=target_transform,
            param=target_param
        )

    results = {target_transform: {target_param: acc}}

    # Report through the logging helper so the result honours the ``logger``
    # argument instead of bypassing it with bare print calls.
    print_log("\n[Single Robustness Test]", logger=logger)
    print_log(f"Transform: {target_transform}, Param: {target_param}", logger=logger)
    print_log(f"Accuracy: {acc:.4f}%", logger=logger)

    return results

class Acc_Metric:
    """Thin wrapper around an accuracy value used to compare checkpoints.

    The constructor accepts a raw accuracy value, a mapping with an ``'acc'``
    entry (the form produced by ``state_dict``), or another ``Acc_Metric``.
    """

    def __init__(self, acc = 0.):
        if isinstance(acc, dict):
            # isinstance also covers dict subclasses, which the previous
            # type-name comparison stored whole instead of unwrapping.
            self.acc = acc['acc']
        elif isinstance(acc, Acc_Metric) or type(acc).__name__ == 'Acc_Metric':
            # The name match is kept so a same-named metric class defined in
            # another module is still unwrapped, as before.
            self.acc = acc.acc
        else:
            self.acc = acc

    def better_than(self, other):
        """Return True when this accuracy is strictly greater than ``other``'s."""
        return bool(self.acc > other.acc)

    def state_dict(self):
        """Return the metric as a plain dict: ``{'acc': value}``."""
        return {'acc': self.acc}
    
# def cal_params_flops(model):
#     # model.eval()
#     input_shape = tuple([2048, 3])
#     flops, params = get_model_complexity_info(model, input_shape)

#     split_line = '=' * 30
#     print(f'{split_line}\nInput shape: {input_shape}\n' 
#         f'Flops: {flops}\nParams: {params}\n{split_line}')
#     print('!!!Please be cautious if you use the results in papers. '
#         'You may need to check if all ops are supported and verify that the '
#         'flops computation is correct.')
    # model.train()

def cal_flops(model, npoints=2048):
    """Count the FLOPs of one forward pass with fvcore and print a summary.

    The model is analysed in eval mode on a random CUDA input of shape
    ``(1, npoints, 3)``. Its previous train/eval mode is restored afterwards,
    so calling this no longer leaves the model in eval mode.

    Args:
        model: Module to analyse; it must accept a ``(1, npoints, 3)`` CUDA
            tensor.
        npoints: Number of points in the dummy input. Defaults to 2048, the
            value that used to be hard-coded.

    Returns:
        The total FLOP count reported by ``FlopCountAnalysis.total()``.
    """
    was_training = model.training
    model.eval()
    try:
        flops = FlopCountAnalysis(model, torch.rand(1, npoints, 3).cuda())
        print(flop_count_table(flops, max_depth=1))
        cnt = flops.total()
        print("[#FLOPs] cnt: ", cnt)
        return cnt
    finally:
        # Restore the mode the caller had, even if the analysis raised.
        model.train(was_training)

def calculate_total_param(base_model):
    """Print and return the total number of parameter elements in a model.

    Every tensor yielded by ``base_model.parameters()`` is counted, whether
    or not it requires gradients.

    Args:
        base_model: Object exposing ``parameters()`` that yields tensors.

    Returns:
        The total element count as an int. Earlier versions only printed it;
        callers that ignore the return value are unaffected.
    """
    total = sum(param.numel() for param in base_model.parameters())
    print("##################TOTAL PARAMETER NUMBER: " + str(total))
    return total

# def run_net(args, config, train_writer=None, val_writer=None):
#     if config.dataset.train._base_.NAME == "ModelNet": # ModelNet
#         train_transforms = transforms.Compose([
#             # data_transforms.PointcloudRotate(),
#             data_transforms.PointcloudScaleAndTranslate(),
#         ])
#     else:
#         train_transforms = transforms.Compose([
#             data_transforms.PointcloudRotate(),
#             # data_transforms.PointcloudScaleAndTranslate(),
#         ])
#     print("train_transforms: ", train_transforms)
#     # 【修改：移除logger，直接用print】
#     # logger = get_logger(args.log_name)
#     # build dataset
#     (train_sampler, train_dataloader), (_, test_dataloader),= builder.dataset_builder(args, config.dataset.train), \
#                                                             builder.dataset_builder(args, config.dataset.val)
#     # build model
#     base_model = builder.model_builder(config.model)
#     num_trainable_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
#     print(f"Total Number of trainable parameters: {num_trainable_params}")
    
#     # model param
#     calculate_total_param(base_model)
    
#     # parameter setting
#     start_epoch = 0
#     best_metrics = Acc_Metric(0.)
#     best_metrics_vote = Acc_Metric(0.)
#     metrics = Acc_Metric(0.)

#     # resume ckpts
#     if args.resume:
#         start_epoch, best_metric = builder.resume_model(base_model, args)  # 【修改：移除logger】
#         best_metrics = Acc_Metric(best_metrics)
#     else:
#         if args.ckpts is not None:
#             base_model.load_model_from_ckpt(args.ckpts)
#         else:
#             print('Training from scratch')  # 【修改：用print替代print_log】

#     if args.use_gpu:    
#         base_model.to(args.local_rank)
#     # DDP
#     if args.distributed:
#         # Sync BN
#         if args.sync_bn:
#             base_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(base_model)
#             print('Using Synchronized BatchNorm ...')  # 【修改：用print替代print_log】
#         base_model = nn.parallel.DistributedDataParallel(base_model, device_ids=[args.local_rank % torch.cuda.device_count()])
#         print('Using Distributed Data parallel ...')  # 【修改：用print替代print_log】
#     else:
#         print('Using Data parallel ...')  # 【修改：用print替代print_log】
#         base_model = nn.DataParallel(base_model).cuda()
#     # optimizer & scheduler
#     optimizer, scheduler = builder.build_opti_sche(base_model, config)
    
#     if args.resume:
#         builder.resume_optimizer(optimizer, args)  # 【修改：移除logger】
        
#     # model statistics
#     flops_cnt = cal_flops(base_model) / 1000000000.
#     print(f'[FLOPs: %.3f G]' % flops_cnt)  # 【修改：用print替代print_log】
    
#     # trainval
#     base_model.zero_grad()
#     for epoch in range(start_epoch, config.max_epoch + 1):
#         if args.distributed:
#             train_sampler.set_epoch(epoch)
#         base_model.train()

#         epoch_start_time = time.time()
#         batch_start_time = time.time()
#         batch_time = AverageMeter()
#         data_time = AverageMeter()
#         losses = AverageMeter(['loss', 'acc'])
#         num_iter = 0
#         base_model.train()
#         n_batches = len(train_dataloader)

#         npoints = config.npoints
#         for idx, (taxonomy_ids, model_ids, data) in enumerate(train_dataloader):
#             num_iter += 1
#             n_itr = epoch * n_batches + idx
            
#             data_time.update(time.time() - batch_start_time)
            
#             points = data[0].cuda()
#             label = data[1].cuda()

#             if npoints == 1024:
#                 point_all = 1200
#             elif npoints == 2048:
#                 point_all = 2400
#             elif npoints == 4096:
#                 point_all = 4800
#             elif npoints == 8192:
#                 point_all = 8192
#             else:
#                 raise NotImplementedError()

#             if points.size(1) < point_all:
#                 point_all = points.size(1)

#             fps_idx = pointnet2_utils.furthest_point_sample(points, point_all)
#             fps_idx = fps_idx[:, np.random.choice(point_all, npoints, False)]
#             points = pointnet2_utils.gather_operation(points.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()
#             points = train_transforms(points)
            
#             ret = base_model(points)

#             loss, acc = base_model.module.get_loss_acc(ret, label)

#             _loss = loss
#             _loss.backward()

#             if num_iter == config.step_per_update:
#                 if config.get('grad_norm_clip') is not None:
#                     torch.nn.utils.clip_grad_norm_(base_model.parameters(), config.grad_norm_clip, norm_type=2)
#                 num_iter = 0
#                 optimizer.step()
#                 base_model.zero_grad()

#             if args.distributed:
#                 loss = dist_utils.reduce_tensor(loss, args)
#                 acc = dist_utils.reduce_tensor(acc, args)
#                 losses.update([loss.item(), acc.item()])
#             else:
#                 losses.update([loss.item(), acc.item()])

#             if args.distributed:
#                 torch.cuda.synchronize()

#             if train_writer is not None:
#                 train_writer.add_scalar('Loss/Batch/Loss', loss.item(), n_itr)
#                 train_writer.add_scalar('Loss/Batch/TrainAcc', acc.item(), n_itr)
#                 train_writer.add_scalar('Loss/Batch/LR', optimizer.param_groups[0]['lr'], n_itr)

#             batch_time.update(time.time() - batch_start_time)
#             batch_start_time = time.time()

#         if isinstance(scheduler, list):
#             for item in scheduler:
#                 item.step(epoch)
#         else:
#             scheduler.step(epoch)
#         epoch_end_time = time.time()

#         if train_writer is not None:
#             train_writer.add_scalar('Loss/Epoch/Loss', losses.avg(0), epoch)

#         # 【修改：用print替代print_log，仅主进程打印】
#         if not args.distributed or args.local_rank == 0:
#             print(f'[Training] EPOCH: {epoch} EpochTime = {epoch_end_time - epoch_start_time:.3f} (s) Losses = {["%.4f" % l for l in losses.avg()]} lr = {optimizer.param_groups[0]["lr"]:.6f}', flush=True)
        
#         if epoch % args.val_freq == 0 and epoch != 0:
#             # 【修改：调用validate时指定单个鲁棒性参数（按需切换）】
#             # Example 1: test "rotation about the Y axis by 180°"
#             metrics = validate(
#                 base_model, test_dataloader, epoch, val_writer, args, config,
#                 target_transform="Rotation (Y-axis)",  # 变换类型
#                 target_param=180  # 具体参数
#             )
            
#             # 示例2：测试“沿Z轴平移+0.2”（取消注释即可切换）
#             # metrics = validate(
#             #     base_model, test_dataloader, epoch, val_writer, args, config,
#             #     target_transform="Translation (Z-axis)",
#             #     target_param=0.2
#             # )

#             better = metrics.better_than(best_metrics)
#             if better:
#                 best_metrics = metrics
#                 builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, 'ckpt-best', args)  # 【修改：移除logger】
#                 # 【修改：用print替代print_log】
#                 if not args.distributed or args.local_rank == 0:
#                     print("--------------------------------------------------------------------------------------------", flush=True)

#             if args.vote:
#                 if metrics.acc > 92.5 or (better and metrics.acc > 92):
#                     metrics_vote = validate_vote(base_model, test_dataloader, epoch, val_writer, args, config)  # 【修改：移除logger】
#                     if metrics_vote.better_than(best_metrics_vote):
#                         best_metrics_vote = metrics_vote
#                         if not args.distributed or args.local_rank == 0:
#                             print("****************************************************************************************", flush=True)
#                             builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics_vote, 'ckpt-best_vote', args)  # 【修改：移除logger】

#         builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, 'ckpt-last', args)  # 【修改：移除logger】
#         # 【修改：用print替代print_log】
#         if not args.distributed or args.local_rank == 0:
#             print(f'[BEST MODEL] acc = {best_metrics.acc:.6f}', flush=True)  
#             GB = 1024. * 1024. * 1024.
#             gpu_memory = torch.cuda.max_memory_allocated()/GB
#             res_gpu_memory = torch.cuda.max_memory_reserved()/GB
#             print(f'[GPU Mem] MEM = {gpu_memory:.3f} GB | Reserved MEM = {res_gpu_memory:.3f} GB', flush=True) 
    
#     calculate_total_param(base_model)
    
#     if train_writer is not None:
#         train_writer.close()
#     if val_writer is not None:
#         val_writer.close()

# # 【修改：validate函数支持指定单个鲁棒性参数】
# def validate(
#     base_model, test_dataloader, epoch, val_writer, args, config, 
#     target_transform=None,  # 【新增：指定变换类型】
#     target_param=None,      # 【新增：指定变换参数】
#     logger=None
#     ):
#     base_model.eval()
#     npoints = config.npoints
    
#     # 【新增：定义所有可能的变换】
#     all_transforms = {
#         "None": [],  # 基准（必跑）
#         "Rotation (Y-axis)": [-90, 90, 180],
#         "Translation (Z-axis)": [0.2, -0.2],
#         "Scaling": [(0.5, 1.5), (0.6, 1.4), (0.7, 1.3)],
#         "Jittering (σ)": [0.01, 0.02]
#     }
    
#     # 【新增：确定本次要运行的变换（基准+指定参数）】
#     runs = [("None", "")]  # 始终跑基准
#     if target_transform is not None and target_param is not None:
#         if target_transform not in all_transforms:
#             raise ValueError(f"不支持的变换类型：{target_transform}，可选：{list(all_transforms.keys())}")
#         if target_param not in all_transforms[target_transform] and target_transform != "None":
#             raise ValueError(f"不支持的参数：{target_param}，{target_transform}可选：{all_transforms[target_transform]}")
#         runs.append((target_transform, target_param))  # 追加指定变换
    
#     results = {}
    
#     with torch.no_grad():
#         for transform_type, param in runs:
#             test_pred = []
#             test_label = []
#             val_time = []
            
#             for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):
#                 start_time = time.time()
#                 points = data[0].cuda()
#                 label = data[1].cuda()
#                 points = misc.fps(points, npoints)
                
#                 # 【新增：应用指定变换】
#                 if transform_type == "Rotation (Y-axis)":
#                     points = rotate_y(points, angle=param)
#                 elif transform_type == "Translation (Z-axis)":
#                     points = translate_z(points, delta=param)
#                 elif transform_type == "Scaling":
#                     points = scale_points(points, scale_range=param)
#                 elif transform_type == "Jittering (σ)":
#                     points = jitter_points(points, sigma=param)
#                 # 基准模式无变换
                
#                 logits = base_model(points)
#                 pred = logits.argmax(-1).view(-1)
#                 target = label.view(-1)
                
#                 test_pred.append(pred.detach())
#                 test_label.append(target.detach())
#                 val_time.append(time.time() - start_time)
            
#             # 计算准确率
#             test_pred = torch.cat(test_pred, dim=0)
#             test_label = torch.cat(test_label, dim=0)
#             if args.distributed:
#                 test_pred = dist_utils.gather_tensor(test_pred, args)
#                 test_label = dist_utils.gather_tensor(test_label, args)
#             acc = (test_pred == test_label).sum() / float(test_label.size(0)) * 100.
#             results[(transform_type, param)] = acc
            
#             # 【修改：用print输出结果，仅主进程打印】
#             if not args.distributed or args.local_rank == 0:
#                 if transform_type == "None":
#                     print(f'[Validation] EPOCH: {epoch} | 基准准确率 = {acc:.4f}% | 耗时 = {np.sum(val_time):.4f}s', flush=True)
#                 else:
#                     print(f'[Validation] EPOCH: {epoch} | {transform_type} {param} | 准确率 = {acc:.4f}% | 耗时 = {np.sum(val_time):.4f}s', flush=True)
    
#     if val_writer is not None:
#         val_writer.add_scalar('Metric/ACC', results[("None", "")], epoch)
    
#     return Acc_Metric(results[("None", "")])


# def validate_vote(base_model, test_dataloader, epoch, val_writer, args, config, logger = None, times = 10):
#     if config.dataset.train._base_.NAME == "ModelNet":
#         test_transforms = transforms.Compose([
#             # data_transforms.PointcloudRotate(),
#             data_transforms.PointcloudScaleAndTranslate(),
#         ])
#     else:
#         test_transforms = transforms.Compose([
#             data_transforms.PointcloudRotate(),
#             # data_transforms.PointcloudScaleAndTranslate(),
#         ])
#     print("val_vote test_transforms: ", test_transforms)  # 【修改：用print替代print_log】
#     base_model.eval()

#     test_pred  = []
#     test_label = []
#     npoints = config.npoints
#     with torch.no_grad():
#         for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):
#             points_raw = data[0].cuda()
#             label = data[1].cuda()
#             if npoints == 1024:
#                 point_all = 1200
#             elif npoints == 2048:
#                 point_all = 2400
#             elif npoints == 4096:
#                 point_all = 4800
#             elif npoints == 8192:
#                 point_all = 8192
#             else:
#                 raise NotImplementedError()
                
#             if points_raw.size(1) < point_all:
#                 point_all = points_raw.size(1)

#             fps_idx_raw = pointnet2_utils.furthest_point_sample(points_raw, point_all)
#             local_pred = []

#             for kk in range(times):
#                 fps_idx = fps_idx_raw[:, np.random.choice(point_all, npoints, False)]
#                 points = pointnet2_utils.gather_operation(points_raw.transpose(1, 2).contiguous(), 
#                                                         fps_idx).transpose(1, 2).contiguous()
#                 points = test_transforms(points)
#                 logits = base_model(points)
#                 target = label.view(-1)
#                 local_pred.append(logits.detach().unsqueeze(0))

#             pred = torch.cat(local_pred, dim=0).mean(0)
#             _, pred_choice = torch.max(pred, -1)

#             test_pred.append(pred_choice)
#             test_label.append(target.detach())

#         test_pred = torch.cat(test_pred, dim=0)
#         test_label = torch.cat(test_label, dim=0)

#         if args.distributed:
#             test_pred = dist_utils.gather_tensor(test_pred, args)
#             test_label = dist_utils.gather_tensor(test_label, args)

#         acc = (test_pred == test_label).sum() / float(test_label.size(0)) * 100.
#         # 【修改：用print替代print_log】
#         if not args.distributed or args.local_rank == 0:
#             print(f'[Validation_vote] EPOCH: {epoch}  acc_vote = {acc:.4f}', flush=True)
#             GB = 1024. * 1024. * 1024.
#             val_gpu_memory = torch.cuda.max_memory_allocated() / GB
#             print(f'[Val GPU Mem] MEM = {val_gpu_memory:.3f} GB', flush=True)

#         if args.distributed:
#             torch.cuda.synchronize()

#     if val_writer is not None:
#         val_writer.add_scalar('Metric/ACC_vote', acc, epoch)

#     return Acc_Metric(acc)
            

def run_net(args, config, train_writer=None, val_writer=None):
    """Train a classification model, validating and checkpointing as it goes.

    Steps, in order: pick the training augmentation from the dataset name,
    build the data loaders and the model, optionally resume or load weights,
    wrap the model for (distributed) data parallelism, build the optimizer
    and scheduler, report FLOPs, then run the epoch loop. Each epoch trains
    on ``train_dataloader``; every ``args.val_freq`` epochs (except epoch 0)
    the model is validated and the best checkpoints are saved. ``ckpt-last``
    is saved after every epoch.

    Args:
        args: Run options. Fields read here: ``log_name``, ``resume``,
            ``ckpts``, ``use_gpu``, ``local_rank``, ``distributed``,
            ``sync_bn``, ``val_freq``, ``vote``.
        config: Experiment config. Fields read here: ``dataset``, ``model``,
            ``max_epoch``, ``npoints``, ``step_per_update`` and the optional
            ``grad_norm_clip``.
        train_writer: Optional summary writer for training scalars; closed
            before returning.
        val_writer: Optional summary writer passed to validation; closed
            before returning.

    Raises:
        NotImplementedError: If ``config.npoints`` is not one of 1024, 2048,
            4096 or 8192 (raised on the first training batch).
    """
    if config.dataset.train._base_.NAME == "ModelNet": # ModelNet
        train_transforms = transforms.Compose([
            # data_transforms.PointcloudRotate(),
            data_transforms.PointcloudScaleAndTranslate(),
        ])
    else:
        train_transforms = transforms.Compose([
            data_transforms.PointcloudRotate(),
            # data_transforms.PointcloudScaleAndTranslate(),
        ])
    print("train_transforms: ", train_transforms)
    logger = get_logger(args.log_name)
    # build dataset
    (train_sampler, train_dataloader), (_, test_dataloader),= builder.dataset_builder(args, config.dataset.train), \
                                                            builder.dataset_builder(args, config.dataset.val)
    # build model
    base_model = builder.model_builder(config.model)
    num_trainable_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
    print(f"Total Number of trainable parameters: {num_trainable_params}")
    
    # model param
    calculate_total_param(base_model)
    
    # parameter setting
    start_epoch = 0
    best_metrics = Acc_Metric(0.)
    best_metrics_vote = Acc_Metric(0.)
    metrics = Acc_Metric(0.)

    # resume ckpts
    if args.resume:
        start_epoch, best_metric = builder.resume_model(base_model, args, logger = logger)
        # Wrap the metric restored from the checkpoint. This used to re-wrap
        # the zero-initialised ``best_metrics``, which discarded the resumed
        # best accuracy.
        best_metrics = Acc_Metric(best_metric)
    else:
        if args.ckpts is not None:
            base_model.load_model_from_ckpt(args.ckpts)
        else:
            print_log('Training from scratch', logger = logger)

    if args.use_gpu:    
        base_model.to(args.local_rank)
    # DDP
    if args.distributed:
        # Sync BN
        if args.sync_bn:
            base_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(base_model)
            print_log('Using Synchronized BatchNorm ...', logger = logger)
        base_model = nn.parallel.DistributedDataParallel(base_model, device_ids=[args.local_rank % torch.cuda.device_count()])
        print_log('Using Distributed Data parallel ...' , logger = logger)
    else:
        print_log('Using Data parallel ...' , logger = logger)
        base_model = nn.DataParallel(base_model).cuda()
    # optimizer & scheduler
    optimizer, scheduler = builder.build_opti_sche(base_model, config)
    
    if args.resume:
        builder.resume_optimizer(optimizer, args, logger = logger)
        
    # model statistics
    flops_cnt = cal_flops(base_model) / 1000000000.
    print_log('[FLOPs: %.3f G]' % flops_cnt, logger = logger)

    # trainval
    # training
    base_model.zero_grad()
    for epoch in range(start_epoch, config.max_epoch + 1):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        base_model.train()  # set model to training mode

        epoch_start_time = time.time()
        batch_start_time = time.time()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter(['loss', 'acc'])
        num_iter = 0
        n_batches = len(train_dataloader)

        npoints = config.npoints
        for idx, (taxonomy_ids, model_ids, data) in enumerate(train_dataloader):
            num_iter += 1
            n_itr = epoch * n_batches + idx
            
            data_time.update(time.time() - batch_start_time)
            
            points = data[0].cuda()
            label = data[1].cuda()

            # Size of the FPS candidate pool from which ``npoints`` points
            # are then drawn at random.
            if npoints == 1024:
                point_all = 1200
            elif npoints == 2048:
                point_all = 2400
            elif npoints == 4096:
                point_all = 4800
            elif npoints == 8192:
                point_all = 8192
            else:
                raise NotImplementedError()

            # Never ask FPS for more points than the cloud contains.
            if points.size(1) < point_all:
                point_all = points.size(1)

            fps_idx = pointnet2_utils.furthest_point_sample(points, point_all)  # (B, npoint)
            fps_idx = fps_idx[:, np.random.choice(point_all, npoints, False)]
            points = pointnet2_utils.gather_operation(points.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)
            points = train_transforms(points)
            
            ret = base_model(points)

            loss, acc = base_model.module.get_loss_acc(ret, label)

            loss.backward()

            # Gradient accumulation: step the optimizer only once every
            # ``config.step_per_update`` batches.
            if num_iter == config.step_per_update:
                if config.get('grad_norm_clip') is not None:
                    torch.nn.utils.clip_grad_norm_(base_model.parameters(), config.grad_norm_clip, norm_type=2)
                num_iter = 0
                optimizer.step()
                base_model.zero_grad()

            if args.distributed:
                loss = dist_utils.reduce_tensor(loss, args)
                acc = dist_utils.reduce_tensor(acc, args)
            losses.update([loss.item(), acc.item()])

            if args.distributed:
                torch.cuda.synchronize()

            if train_writer is not None:
                train_writer.add_scalar('Loss/Batch/Loss', loss.item(), n_itr)
                train_writer.add_scalar('Loss/Batch/TrainAcc', acc.item(), n_itr)
                train_writer.add_scalar('Loss/Batch/LR', optimizer.param_groups[0]['lr'], n_itr)

            batch_time.update(time.time() - batch_start_time)
            batch_start_time = time.time()

        if isinstance(scheduler, list):
            for item in scheduler:
                item.step(epoch)
        else:
            scheduler.step(epoch)
        epoch_end_time = time.time()

        if train_writer is not None:
            train_writer.add_scalar('Loss/Epoch/Loss', losses.avg(0), epoch)

        print_log('[Training] EPOCH: %d EpochTime = %.3f (s) Losses = %s lr = %.6f' %
            (epoch,  epoch_end_time - epoch_start_time, ['%.4f' % l for l in losses.avg()],optimizer.param_groups[0]['lr']), logger = logger)
        if epoch % args.val_freq == 0 and epoch != 0:
            # Validate the current model
            metrics = validate(base_model, test_dataloader, epoch, val_writer, args, config, logger=logger)

            better = metrics.better_than(best_metrics)
            # Save checkpoints
            if better:
                best_metrics = metrics
                builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, 'ckpt-best', args, logger = logger)
                print_log("--------------------------------------------------------------------------------------------", logger=logger)
            if args.vote:
                # NOTE(review): validate_vote is not defined in the visible
                # part of this file (only a commented-out copy exists) —
                # confirm it is defined or imported elsewhere.
                if metrics.acc > 92.5 or (better and metrics.acc > 92):
                    metrics_vote = validate_vote(base_model, test_dataloader, epoch, val_writer, args, config, logger=logger)
                    if metrics_vote.better_than(best_metrics_vote):
                        best_metrics_vote = metrics_vote
                        print_log(
                            "****************************************************************************************",
                            logger=logger)
                        builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics_vote, 'ckpt-best_vote', args, logger = logger)

        builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, 'ckpt-last', args, logger = logger)
        print_log('[BEST MODEL] acc = %.6f' % (best_metrics.acc), logger=logger)  
        GB = 1024. * 1024. * 1024.
        gpu_memory = torch.cuda.max_memory_allocated()/GB
        res_gpu_memory = torch.cuda.max_memory_reserved()/GB
        print_log('[GPU Mem] MEM = %.3f GB | Reserved MEM = %.3f GB' % (gpu_memory, res_gpu_memory), logger=logger) 
    # model param
    calculate_total_param(base_model)
    
    if train_writer is not None:
        train_writer.close()
    if val_writer is not None:
        val_writer.close()

def validate(base_model, test_dataloader, epoch, val_writer, args, config, logger = None):
    """Run one plain (non-voting) validation pass and return its accuracy.

    Every batch is reduced to ``config.npoints`` points with ``misc.fps``,
    classified by ``base_model`` and the argmax predictions are compared with
    the labels. Accuracy (in percent), per-batch wall-clock time and peak CUDA
    memory are reported through ``print_log``.

    Args:
        base_model: Model to evaluate; switched to eval mode here.
        test_dataloader: Iterable yielding ``(taxonomy_ids, model_ids, data)``
            where ``data[0]`` holds the points and ``data[1]`` the labels.
        epoch: Epoch number used in log lines and as the TensorBoard step.
        val_writer: Optional writer; when not ``None`` the accuracy is logged
            under ``Metric/ACC``.
        args: Run arguments; only ``args.distributed`` is read here.
        config: Experiment config; only ``config.npoints`` is read here.
        logger: Optional logger forwarded to ``print_log``.

    Returns:
        ``Acc_Metric`` wrapping the accuracy.
    """
    base_model.eval()

    num_points = config.npoints
    batch_preds, batch_targets, batch_seconds = [], [], []

    with torch.no_grad():
        for _, _, batch in test_dataloader:
            tick = time.time()

            cloud = batch[0].cuda()
            gt = batch[1].cuda()
            cloud = misc.fps(cloud, num_points)  # 64 2048 3

            scores = base_model(cloud)

            batch_preds.append(scores.argmax(-1).view(-1).detach())
            batch_targets.append(gt.view(-1).detach())
            batch_seconds.append(time.time() - tick)

        all_preds = torch.cat(batch_preds, dim=0)
        all_targets = torch.cat(batch_targets, dim=0)

        # Under distributed evaluation, collect the results of every rank first.
        if args.distributed:
            all_preds = dist_utils.gather_tensor(all_preds, args)
            all_targets = dist_utils.gather_tensor(all_targets, args)

        num_correct = (all_preds == all_targets).sum()
        acc = num_correct / float(all_targets.size(0)) * 100.
        print_log('[Validation] EPOCH: %d  acc = %.4f' % (epoch, acc), logger=logger)
        print_log('[Validation] EPOCH: %d  total time = %.4f  batch avg time = %.4f' % (epoch, np.sum(batch_seconds), np.mean(batch_seconds)), logger=logger)

        if args.distributed:
            torch.cuda.synchronize()

        bytes_per_gb = 1024. * 1024. * 1024.
        peak_allocated = torch.cuda.max_memory_allocated() / bytes_per_gb
        peak_reserved = torch.cuda.max_memory_reserved() / bytes_per_gb
        print_log('[Val GPU Mem] MEM = %.3f GB | Reserved MEM = %.3f GB' % (peak_allocated, peak_reserved), logger=logger)

        # Optional profiling snippet (disabled):
        # dump_input = torch.ones(64, 2048, 3).cuda()
        # with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False, profile_memory=False) as prof:
        #     outputs = base_model(dump_input)
        # print(prof.table())
        # prof.export_chrome_trace('./mamba_profile.json')

    # Add testing results to TensorBoard
    if val_writer is not None:
        val_writer.add_scalar('Metric/ACC', acc, epoch)

    return Acc_Metric(acc)


def validate_vote(base_model, test_dataloader, epoch, val_writer, args, config, logger = None, times = 10):
    """Validate with test-time voting and return the voted accuracy.

    For every batch a candidate pool of points is drawn once with farthest
    point sampling. Then, ``times`` times, ``config.npoints`` points are
    randomly picked from that pool (without replacement), augmented with the
    test-time transform and classified. The logits of all votes are averaged
    and the argmax of the average is the prediction.

    The transform is ``PointcloudScaleAndTranslate`` when the training dataset
    is named ``"ModelNet"`` and ``PointcloudRotate`` otherwise.

    Args:
        base_model: Model to evaluate; switched to eval mode here.
        test_dataloader: Iterable yielding ``(taxonomy_ids, model_ids, data)``
            where ``data[0]`` holds the points and ``data[1]`` the labels.
        epoch: Epoch number used in log lines and as the TensorBoard step.
        val_writer: Optional writer; when not ``None`` the accuracy is logged
            under ``Metric/ACC_vote``.
        args: Run arguments; only ``args.distributed`` is read here.
        config: Experiment config; ``config.npoints`` and
            ``config.dataset.train._base_.NAME`` are read here.
        logger: Optional logger forwarded to ``print_log``.
        times: Number of votes per batch.

    Returns:
        ``Acc_Metric`` wrapping the voted accuracy (in percent).

    Raises:
        NotImplementedError: If ``config.npoints`` is not one of
            1024, 2048, 4096 or 8192.
        ValueError: If a batch holds fewer than ``config.npoints`` points per
            cloud, so sampling without replacement is impossible.
    """
    if config.dataset.train._base_.NAME == "ModelNet": # ModelNet
        test_transforms = transforms.Compose([
            # data_transforms.PointcloudRotate(),
            data_transforms.PointcloudScaleAndTranslate(),
        ])
    else:
        test_transforms = transforms.Compose([
            data_transforms.PointcloudRotate(),
            # data_transforms.PointcloudScaleAndTranslate(),
        ])
    print("val_vote test_transforms: ", test_transforms)
    print_log(f"[VALIDATION_VOTE] epoch {epoch}", logger = logger)
    base_model.eval()  # set model to eval mode

    npoints = config.npoints
    # Size of the FPS candidate pool each vote samples from, keyed by npoints.
    # Resolved once up front so an unsupported setting fails before any batch
    # is moved to the GPU.
    vote_pool_sizes = {1024: 1200, 2048: 2400, 4096: 4800, 8192: 8192}
    if npoints not in vote_pool_sizes:
        raise NotImplementedError('voting is not implemented for npoints = %s' % npoints)
    max_pool_size = vote_pool_sizes[npoints]

    test_pred  = []
    test_label = []
    with torch.no_grad():
        for _, _, data in test_dataloader:
            points_raw = data[0].cuda()
            target = data[1].cuda().view(-1)

            # The pool cannot be larger than the cloud itself.
            point_all = min(max_pool_size, points_raw.size(1))
            if point_all < npoints:
                # np.random.choice(..., replace=False) would raise an opaque
                # ValueError here; report the actual cause instead.
                raise ValueError(
                    'cannot sample %d points without replacement from a cloud of %d points'
                    % (npoints, point_all))

            fps_idx_raw = pointnet2_utils.furthest_point_sample(points_raw, point_all)  # (B, npoint)
            # Channel-first copy for gather_operation; identical for every
            # vote, so build it once per batch.
            points_raw_t = points_raw.transpose(1, 2).contiguous()

            local_pred = []
            for _ in range(times):
                fps_idx = fps_idx_raw[:, np.random.choice(point_all, npoints, False)]
                points = pointnet2_utils.gather_operation(points_raw_t,
                                                          fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)

                points = test_transforms(points)

                logits = base_model(points)
                local_pred.append(logits.detach())

            # Average the logits over all votes, then pick the best class.
            pred_choice = torch.stack(local_pred, dim=0).mean(0).argmax(-1)

            test_pred.append(pred_choice)
            test_label.append(target.detach())

        test_pred = torch.cat(test_pred, dim=0)
        test_label = torch.cat(test_label, dim=0)

        if args.distributed:
            test_pred = dist_utils.gather_tensor(test_pred, args)
            test_label = dist_utils.gather_tensor(test_label, args)

        acc = (test_pred == test_label).sum() / float(test_label.size(0)) * 100.
        print_log('[Validation_vote] EPOCH: %d  acc_vote = %.4f' % (epoch, acc), logger=logger)

        GB = 1024. * 1024. * 1024.
        val_gpu_memory = torch.cuda.max_memory_allocated() / GB
        print_log('[Val GPU Mem] MEM = %.3f GB' % val_gpu_memory, logger=logger)

        if args.distributed:
            torch.cuda.synchronize()

    # Add testing results to TensorBoard
    if val_writer is not None:
        val_writer.add_scalar('Metric/ACC_vote', acc, epoch)

    return Acc_Metric(acc)



def test_net(args, config):
    """Evaluate a trained checkpoint on the test split.

    Builds the test dataloader and the model from ``config``, loads the
    weights given by ``args.ckpts``, runs ``test`` and then a single
    robustness check: ``test_robustness_single_param`` with the
    ``"Rotation (Y-axis)"`` transform and parameter ``90``.

    Args:
        args: Run arguments; ``log_name``, ``distributed``, ``ckpts``,
            ``use_gpu`` and ``local_rank`` are read here.
        config: Experiment config; ``config.dataset.test`` and
            ``config.model`` are read here.

    Raises:
        NotImplementedError: If ``args.distributed`` is set; distributed
            testing is not supported.
    """
    logger = get_logger(args.log_name)
    print_log('Tester start ... ', logger = logger)

    # Fail fast: reject DDP before paying for the dataset build, the model
    # build and the checkpoint load.
    if args.distributed:
        raise NotImplementedError('test_net does not support distributed (DDP) testing')

    _, test_dataloader = builder.dataset_builder(args, config.dataset.test)
    base_model = builder.model_builder(config.model)
    # load checkpoints
    builder.load_model(base_model, args.ckpts, logger = logger) # for finetuned transformer
    # base_model.load_model_from_ckpt(args.ckpts) # for BERT
    if args.use_gpu:
        base_model.to(args.local_rank)

    test(base_model, test_dataloader, args, config, logger=logger)

    # Added robustness test
    # print_log("\nStarting robustness tests...", logger=logger)
    # test_robustness(base_model, test_dataloader, args, config, logger=logger)
    test_robustness_single_param(
        base_model, test_dataloader, args, config,
        target_transform="Rotation (Y-axis)",  # transform type to apply
        target_param=90,  # parameter for that transform
        logger=logger
    )
    
def test(base_model, test_dataloader, args, config, logger = None, vote_rounds = 299):
    """Report plain test accuracy, then the best accuracy over repeated voting runs.

    First, each batch is reduced to ``config.npoints`` points with
    ``misc.fps``, classified, and the overall accuracy (in percent) is logged
    as ``[TEST]``. Afterwards ``test_vote`` is called ``vote_rounds`` times
    with 10 votes each, and the best voted accuracy seen so far is logged
    after every round.

    Args:
        base_model: Model to evaluate; switched to eval mode here.
        test_dataloader: Iterable yielding ``(taxonomy_ids, model_ids, data)``
            where ``data[0]`` holds the points and ``data[1]`` the labels.
        args: Run arguments; only ``args.distributed`` is read here.
        config: Experiment config; only ``config.npoints`` is read here
            (``test_vote`` reads more).
        logger: Optional logger forwarded to ``print_log``.
        vote_rounds: Number of ``test_vote`` runs. Defaults to 299, the
            number of rounds the previously hard-coded ``range(1, 300)``
            performed.
    """
    base_model.eval()  # set model to eval mode

    test_pred  = []
    test_label = []
    npoints = config.npoints

    with torch.no_grad():
        for _, _, data in test_dataloader:
            points = data[0].cuda()
            label = data[1].cuda()

            points = misc.fps(points, npoints)

            logits = base_model(points)
            target = label.view(-1)

            pred = logits.argmax(-1).view(-1)

            test_pred.append(pred.detach())
            test_label.append(target.detach())

        test_pred = torch.cat(test_pred, dim=0)
        test_label = torch.cat(test_label, dim=0)

        if args.distributed:
            test_pred = dist_utils.gather_tensor(test_pred, args)
            test_label = dist_utils.gather_tensor(test_label, args)

        acc = (test_pred == test_label).sum() / float(test_label.size(0)) * 100.
        print_log('[TEST] acc = %.4f' % acc, logger=logger)

        GB = 1024. * 1024. * 1024.
        gpu_memory = torch.cuda.max_memory_allocated() / GB
        print_log('[GPU Mem] MEM = %.3f GB' % gpu_memory, logger=logger)

        if args.distributed:
            torch.cuda.synchronize()

        print_log("[TEST_VOTE]", logger = logger)
        best_vote_acc = 0.
        # The counter is deliberately not called `time`: that name would
        # shadow the imported `time` module inside this function.
        for vote_round in range(1, vote_rounds + 1):
            this_acc = test_vote(base_model, test_dataloader, 1, None, args, config, logger=logger, times=10)
            if best_vote_acc < this_acc:
                best_vote_acc = this_acc
            print_log('[TEST_VOTE_time %d]  acc = %.4f, best acc = %.4f' % (vote_round, this_acc, best_vote_acc), logger=logger)
        print_log('[TEST_VOTE] acc = %.4f' % best_vote_acc, logger=logger)

def test_vote(base_model, test_dataloader, epoch, val_writer, args, config, logger = None, times = 10):
    """Evaluate with test-time voting and return the voted accuracy.

    For every batch a candidate pool of points is drawn once with farthest
    point sampling. Then, ``times`` times, ``config.npoints`` points are
    randomly picked from that pool (without replacement), augmented with the
    test-time transform and classified. The logits of all votes are averaged
    and the argmax of the average is the prediction.

    The transform is ``PointcloudScaleAndTranslate`` when the training dataset
    is named ``"ModelNet"`` and ``PointcloudRotate`` otherwise.

    Args:
        base_model: Model to evaluate; switched to eval mode here.
        test_dataloader: Iterable yielding ``(taxonomy_ids, model_ids, data)``
            where ``data[0]`` holds the points and ``data[1]`` the labels.
        epoch: Step used when writing to ``val_writer``.
        val_writer: Optional writer; when not ``None`` the accuracy is logged
            under ``Metric/ACC_vote``.
        args: Run arguments; only ``args.distributed`` is read here.
        config: Experiment config; ``config.npoints`` and
            ``config.dataset.train._base_.NAME`` are read here.
        logger: Optional logger forwarded to ``print_log``.
        times: Number of votes per batch.

    Returns:
        The voted accuracy in percent, as the raw computed value (not wrapped
        in ``Acc_Metric``).

    Raises:
        NotImplementedError: If ``config.npoints`` is not one of
            1024, 2048, 4096 or 8192.
        ValueError: If a batch holds fewer than ``config.npoints`` points per
            cloud, so sampling without replacement is impossible.
    """
    if config.dataset.train._base_.NAME == "ModelNet": # ModelNet
        test_transforms = transforms.Compose([
            # data_transforms.PointcloudRotate(),
            data_transforms.PointcloudScaleAndTranslate(),
        ])
    else:
        test_transforms = transforms.Compose([
            data_transforms.PointcloudRotate(),
            # data_transforms.PointcloudScaleAndTranslate(),
        ])
    print("test_vote test_transforms: ", test_transforms)
    base_model.eval()  # set model to eval mode

    npoints = config.npoints
    # Size of the FPS candidate pool each vote samples from, keyed by npoints.
    # Resolved once up front so an unsupported setting fails before any batch
    # is moved to the GPU.
    vote_pool_sizes = {1024: 1200, 2048: 2400, 4096: 4800, 8192: 8192}
    if npoints not in vote_pool_sizes:
        raise NotImplementedError('voting is not implemented for npoints = %s' % npoints)
    max_pool_size = vote_pool_sizes[npoints]

    test_pred  = []
    test_label = []
    with torch.no_grad():
        for _, _, data in test_dataloader:
            points_raw = data[0].cuda()
            target = data[1].cuda().view(-1)

            # The pool cannot be larger than the cloud itself.
            point_all = min(max_pool_size, points_raw.size(1))
            if point_all < npoints:
                # np.random.choice(..., replace=False) would raise an opaque
                # ValueError here; report the actual cause instead.
                raise ValueError(
                    'cannot sample %d points without replacement from a cloud of %d points'
                    % (npoints, point_all))

            fps_idx_raw = pointnet2_utils.furthest_point_sample(points_raw, point_all)  # (B, npoint)
            # Channel-first copy for gather_operation; identical for every
            # vote, so build it once per batch.
            points_raw_t = points_raw.transpose(1, 2).contiguous()

            local_pred = []
            for _ in range(times):
                fps_idx = fps_idx_raw[:, np.random.choice(point_all, npoints, False)]
                points = pointnet2_utils.gather_operation(points_raw_t,
                                                          fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)

                points = test_transforms(points)

                logits = base_model(points)
                local_pred.append(logits.detach())

            # Average the logits over all votes, then pick the best class.
            pred_choice = torch.stack(local_pred, dim=0).mean(0).argmax(-1)

            test_pred.append(pred_choice)
            test_label.append(target.detach())

        test_pred = torch.cat(test_pred, dim=0)
        test_label = torch.cat(test_label, dim=0)

        if args.distributed:
            test_pred = dist_utils.gather_tensor(test_pred, args)
            test_label = dist_utils.gather_tensor(test_label, args)

        acc = (test_pred == test_label).sum() / float(test_label.size(0)) * 100.

        if args.distributed:
            torch.cuda.synchronize()

        GB = 1024. * 1024. * 1024.
        gpu_memory_vote = torch.cuda.max_memory_allocated() / GB
        print_log('[Vote GPU Mem] MEM = %.3f GB' % gpu_memory_vote, logger=logger)

    # Add testing results to TensorBoard
    if val_writer is not None:
        val_writer.add_scalar('Metric/ACC_vote', acc, epoch)
    # print_log('[TEST] acc = %.4f' % acc, logger=logger)

    return acc
