import numpy as np
import os
import zipfile
from pointnet2_ops import pointnet2_utils
import mcubes
import trimesh
import torch
import gc
from tqdm import tqdm
from scipy.spatial import cKDTree as KDTree
from models.misc import MetricLogger
from typing import List, Tuple
import random
import traceback
import math


# transform npz to ply (point cloud) without color
def npz_to_ply(npz_file, ply_file):
    """Write the ``points`` array of an NPZ archive as an ASCII PLY point cloud.

    Only the first three components (x, y, z) of every point are written;
    no color or normal properties are emitted.

    Args:
        npz_file: Path of the ``.npz`` archive; it must contain a ``points``
            entry.
        ply_file: Path of the PLY file to create (overwritten if it exists).

    Raises:
        FileNotFoundError: If ``npz_file`` does not exist.
        ValueError: If ``npz_file`` is not a valid zip archive.
        KeyError: If the archive has no ``points`` entry.
    """
    if not os.path.exists(npz_file):
        raise FileNotFoundError(f"NPZ file not found: {npz_file}")

    if not zipfile.is_zipfile(npz_file):
        raise ValueError(f"File is not a valid zip file: {npz_file}")

    # np.load keeps the archive open until it is closed explicitly, so the
    # array is read inside a context manager to avoid leaking the handle.
    with np.load(npz_file) as data:
        points = data['points'].astype(np.float32)

    header = (
        'ply\n',
        'format ascii 1.0\n',
        'element vertex %d\n' % points.shape[0],
        'property float x\n',
        'property float y\n',
        'property float z\n',
        'end_header\n',
    )
    with open(ply_file, 'w') as f:
        f.writelines(header)
        # One buffered writelines call instead of one write per vertex.
        f.writelines('%f %f %f\n' % (point[0], point[1], point[2]) for point in points)

def npz_to_pointcloud(npz_file: str):
    """Load surface points and normals from an NPZ archive as float32 tensors.

    A missing ``points`` or ``normals`` entry is replaced by a ``(1024, 3)``
    array of zeros; a warning is printed only when ``points`` is missing.

    Args:
        npz_file: Path of the ``.npz`` archive.

    Returns:
        A ``(points, normals)`` tuple. ``points`` is moved to the current CUDA
        device when CUDA is available and otherwise stays on the CPU;
        ``normals`` always stays on the CPU.
    """
    # Read inside a context manager so the archive handle is closed, and use
    # membership tests instead of ``.get`` so this does not depend on NpzFile
    # implementing the full Mapping interface.
    with np.load(npz_file) as data:
        if 'points' in data:
            points = data['points']
        else:
            print(f"Warning: 'points' key not found in {npz_file}, using default values")
            points = np.zeros((1024, 3), dtype=np.float32)

        if 'normals' in data:
            normals = data['normals']
        else:
            normals = np.zeros((1024, 3), dtype=np.float32)

    points = torch.from_numpy(np.asarray(points, dtype=np.float32))
    normals = torch.from_numpy(np.asarray(normals, dtype=np.float32))

    # Hosts with a GPU get the same placement as before; CPU-only hosts no
    # longer crash on an unconditional .cuda() call.
    if torch.cuda.is_available():
        points = points.cuda()

    return points, normals




def sample_pointcloud(points: torch.Tensor, num_samples: int, sampling_method: str = 'fps'):
    """
    Pick ``num_samples`` rows of ``points``.

    ``'fps'`` delegates to ``pointnet2_utils.furthest_point_sample`` on a
    batch of size one; ``'uniform'`` draws distinct row indices from a fresh
    NumPy generator (no replacement). Any other method raises ValueError.
    """
    if sampling_method not in ('fps', 'uniform'):
        raise ValueError(f"Invalid sampling method: {sampling_method}")

    if sampling_method == 'uniform':
        rng = np.random.default_rng()
        chosen = rng.choice(points.shape[0], num_samples, replace=False)
        return points[chosen]

    # Farthest point sampling expects a leading batch dimension.
    batched = points.unsqueeze(0)
    fps_idx = pointnet2_utils.furthest_point_sample(batched, num_samples)
    return points[fps_idx.squeeze(0)]

def preprocess_pointcloud(points: torch.Tensor):
    """Center the cloud on its bounding-box midpoint and divide by the longest
    bounding-box edge, so the result fits in [-0.5, 0.5] on every axis.

    Only the first three columns are inspected. The result is a CPU tensor.
    """
    extents = []
    midpoints = []
    for axis in range(3):
        column = points[:, axis]
        upper = torch.max(column)
        lower = torch.min(column)
        extents.append(upper - lower)
        midpoints.append((upper + lower) / 2)

    # torch.tensor(...) on the per-axis scalars yields CPU tensors, hence the
    # explicit .cpu() on the points before combining them.
    longest_edge = torch.max(torch.tensor(extents))
    box_center = torch.tensor(midpoints)
    return (points.cpu() - box_center) / longest_edge

def npz_to_pointcloud_single(npz_file, num_samples=2048, sampling_method='fps', add_noise=False, noise_std=0.01):
    """
    Load one point cloud from an NPZ file, subsample it and normalize it.

    npz_file: the path of the npz file
    num_samples: the number of points to keep after sampling
    sampling_method: 'fps' (farthest point sampling, the default) or 'uniform'
    add_noise: when True, zero-mean Gaussian noise is added before normalization
    noise_std: standard deviation of that Gaussian noise

    Returns the sampled points after ``preprocess_pointcloud``.
    """
    # npz_to_pointcloud returns a (points, normals) tuple; only the points are
    # needed here. Passing the whole tuple on to the sampler would fail.
    points, _ = npz_to_pointcloud(npz_file)

    points = sample_pointcloud(points, num_samples, sampling_method)

    if add_noise:
        points = points + torch.randn_like(points) * noise_std

    return preprocess_pointcloud(points)

# laplace noise
def laplace_noise(shape, target_std):
    """Draw zero-mean Laplace noise with standard deviation ``target_std``.

    Uses inverse-CDF sampling: for ``u ~ U(-0.5, 0.5)`` the value
    ``-b * sign(u) * log(1 - 2|u|)`` follows Laplace(0, b).

    Args:
        shape: Shape of the returned tensor (forwarded to ``torch.rand``).
        target_std: Desired standard deviation of the noise.

    Returns:
        A tensor of the given shape holding finite Laplace samples.
    """
    # Laplace(0, b) has standard deviation b * sqrt(2).
    b = target_std / math.sqrt(2.0)
    u = torch.rand(shape) - 0.5          # U(-0.5,0.5)

    # torch.rand samples [0, 1), so u can be exactly -0.5, which would give
    # log1p(-1) = -inf. Clamp |u| to the largest representable value below
    # 0.5 (eps / 4 is the float spacing just below 0.5); every other sample
    # is left untouched.
    max_abs = 0.5 - torch.finfo(u.dtype).eps / 4
    abs_u = u.abs().clamp(max=max_abs)
    return -b * torch.sign(u) * torch.log1p(-2 * abs_u)


def discrete_noise(points, epsilon=0.01):
    """Return noise shaped like ``points`` whose entries are drawn uniformly
    from the three values ``{-epsilon, 0, +epsilon}`` on ``points``' device."""
    target_device = points.device
    picks = torch.randint(low=0, high=3, size=tuple(points.shape), device=target_device)
    levels = torch.tensor([-epsilon, 0.0, epsilon], device=target_device)
    return levels[picks]


# The point cloud probably does not need to be normalized here,
# because it is already normalized to [-0.5 + padding, 0.5 - padding]^3.
def npz_to_pointcloud_noise2noise(npz_file: str, num_samples: int = 2048, sampling_method: str = 'fps', noise_std: float = 0.005, scale: float = 1.0, noise_type: str = "gaussian", noise_mean: float = 0.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build one noise2noise sample: two independently corrupted copies of the
    same subsampled ground-truth point cloud.

    npz_file: the path of the npz file
    num_samples: the number of points to sample from the point cloud
    sampling_method: 'fps' (farthest point sampling, the default) or 'uniform'
    noise_std: the standard deviation of the noise
    scale: factor the loaded surface is multiplied by before sampling
    noise_type: 'gaussian', 'uniform', 'discrete' or 'laplace'; every variant
        is parameterized so that its standard deviation equals noise_std
    noise_mean: the mean of the noise (added as a constant bias)

    Returns (gt_points, noise1_points, noise2_points, surface, normals), all
    on the CPU. The first three have shape (num_samples, 3).

    Raises ValueError for an unknown noise_type.
    """
    # load the gt point cloud
    surface, normals = npz_to_pointcloud(npz_file)
    surface = surface * scale

    # sample the point cloud; the noise is generated on the CPU
    points: torch.Tensor = sample_pointcloud(surface, num_samples, sampling_method).cpu()
    # constant offset that shifts the noise mean to noise_mean
    mean_bias = torch.full_like(points, noise_mean)

    if noise_type == "gaussian":
        noise1: torch.Tensor = torch.randn_like(points) * noise_std + mean_bias
        noise2: torch.Tensor = torch.randn_like(points) * noise_std + mean_bias
    elif noise_type == "uniform":
        # U(-a, a) has std a / sqrt(3), so a = noise_std * sqrt(3)
        a = noise_std * math.sqrt(3.0)
        noise1 = (torch.rand_like(points) - 0.5) * 2 * a + mean_bias
        noise2 = (torch.rand_like(points) - 0.5) * 2 * a + mean_bias
    elif noise_type == "discrete":
        # uniform over {-eps, 0, eps} has std eps * sqrt(2/3)
        epsilon = noise_std * (3 ** 0.5 / 2 ** 0.5)
        # Each realization needs its own random draw; sharing one index
        # tensor would make noise1 and noise2 identical.
        noise1 = discrete_noise(points, epsilon) + mean_bias
        noise2 = discrete_noise(points, epsilon) + mean_bias
    elif noise_type == "laplace":
        noise1 = laplace_noise(points.shape, noise_std) + mean_bias
        noise2 = laplace_noise(points.shape, noise_std) + mean_bias
    else:
        raise ValueError(f"Invalid noise type: {noise_type}")

    gt_points: torch.Tensor = points
    noise1_points: torch.Tensor = gt_points + noise1
    noise2_points: torch.Tensor = gt_points + noise2

    return gt_points.cpu(), noise1_points.cpu(), noise2_points.cpu(), surface.cpu(), normals.cpu()


def get_grid_aligned(density):
    """Return a float32 tensor of shape (1, (density + 1)**3, 3) holding the
    nodes of a regular lattice over [-1, 1]^3.

    The nodes are flattened from ``np.meshgrid`` with its default indexing,
    so the z coordinate varies fastest, then x, then y.
    """
    ticks = np.linspace(-1, 1, density + 1)
    mesh_x, mesh_y, mesh_z = np.meshgrid(ticks, ticks, ticks)
    coords = np.stack([mesh_x, mesh_y, mesh_z]).astype(np.float32)
    # (3, n, n, n) -> (3, N) -> (N, 3) -> (1, N, 3)
    return torch.from_numpy(coords).view(3, -1).t().unsqueeze(0)

def matching_cube_batch_aligned(model: torch.nn.Module,
                       device: torch.device,
                       surface: torch.Tensor,
                       density: int = 256,
                       batch_size: int = 256*256*5) -> trimesh.Trimesh:
    """Extract the zero level set of the model's logits as a triangle mesh.

    The model is switched to eval mode and queried, in chunks of at most
    ``batch_size`` points, on the lattice from ``get_grid_aligned(density)``.
    The logits volume has its first two axes swapped before marching cubes,
    and the resulting vertices are mapped from index space back to [-1, 1].

    Args:
        model: Called as ``model(surface, queries)``; must return a mapping
            with a ``'logits'`` entry.
        device: Device the surface and the query chunks are moved to.
        surface: Conditioning point cloud; reshaped to (1, -1, 3).
        density: Number of lattice cells per axis.
        batch_size: Maximum number of lattice nodes per forward pass.

    Returns:
        The extracted ``trimesh.Trimesh``.
    """
    model.eval()

    lattice = get_grid_aligned(density)
    nodes_per_axis = density + 1
    cell_size = 2.0 / density

    with torch.no_grad():
        cond = surface.to(device, non_blocking=True).reshape(1, -1, 3)

        flat_nodes = lattice.cpu().numpy().reshape(-1, 3)  # [N, 3], N = (density + 1)^3
        node_count = flat_nodes.shape[0]
        flat_logits = np.zeros(node_count, dtype=np.float32)

        for lo in tqdm(range(0, node_count, batch_size), desc="Processing Batches"):
            hi = min(lo + batch_size, node_count)

            chunk = flat_nodes[lo:hi].astype(np.float32)
            chunk_tensor = torch.from_numpy(chunk).view(1, -1, 3).to(device)
            logits = model(cond, chunk_tensor)['logits']  # [1, hi - lo]
            flat_logits[lo:hi] = logits.cpu().numpy().reshape(-1)

            # Release the per-chunk buffers before the next forward pass.
            del chunk_tensor, logits
            torch.cuda.empty_cache()
            gc.collect()

        volume = flat_logits.reshape(nodes_per_axis, nodes_per_axis, nodes_per_axis)
        # Swap the first two axes (meshgrid's default indexing puts y first).
        volume = volume.transpose(1, 0, 2)

        verts, faces = mcubes.marching_cubes(volume, 0)

        # Index space [0, density] -> coordinate space [-1, 1].
        verts *= cell_size
        verts -= 1

        mesh = trimesh.Trimesh(vertices=verts, faces=faces)

    return mesh



def get_grid(density):
    """
    Build the lattice of query points used for marching cubes.

    Returns a float32 tensor of shape (1, (density + 1)**3, 3) covering
    [-1, 1]^3; nodes are flattened from ``np.meshgrid`` with its default
    indexing, so z varies fastest, then x, then y.
    """
    ticks = np.linspace(-1, 1, density + 1)
    mesh_x, mesh_y, mesh_z = np.meshgrid(ticks, ticks, ticks)
    coords = np.stack([mesh_x, mesh_y, mesh_z]).astype(np.float32)
    # (3, n, n, n) -> (3, N) -> (N, 3) -> (1, N, 3)
    flat = torch.from_numpy(coords).view(3, -1)
    return flat.t().unsqueeze(0)



def matching_cube_batch(model: torch.nn.Module,
                       device: torch.device,
                       surface: torch.Tensor,
                       density: int = 128,
                       batch_size: int = 128*128*128) -> trimesh.Trimesh:
    """Extract the zero level set of the model's logits as a triangle mesh.

    The model is switched to eval mode and queried, in chunks of at most
    ``batch_size`` points, on the lattice from ``get_grid(density)``. The
    logits are stored in a (density + 1)^3 volume in the lattice's flattened
    order, marching cubes is run at level 0, and the vertices are mapped
    from index space back to [-1, 1].

    NOTE(review): unlike ``matching_cube_batch_aligned`` no axis transpose is
    applied to the volume, so the vertex axis order follows the flattened
    lattice order as-is -- confirm this is the convention callers expect.

    Args:
        model: Called as ``model(surface, queries)``; must return a mapping
            with a ``'logits'`` entry.
        device: Device the surface and the query chunks are moved to.
        surface: Conditioning point cloud; reshaped to (1, -1, 3).
        density: Number of lattice cells per axis.
        batch_size: Maximum number of lattice nodes per forward pass.

    Returns:
        The extracted ``trimesh.Trimesh``.
    """
    model.eval()

    lattice = get_grid(density)
    nodes_per_axis = density + 1
    cell_size = 2.0 / density

    with torch.no_grad():
        cond = surface.to(device, non_blocking=True).reshape(1, -1, 3)

        flat_nodes = lattice.cpu().numpy().reshape(-1, 3)  # [N, 3], N = (density + 1)^3
        node_count = flat_nodes.shape[0]

        full_logits = np.zeros((nodes_per_axis, nodes_per_axis, nodes_per_axis), dtype=np.float32)
        # Flat view onto the volume: writing flat index k is the same as
        # writing element (k // n^2, (k % n^2) // n, k % n) of the volume.
        flat_view = full_logits.reshape(-1)

        for lo in range(0, node_count, batch_size):
            hi = min(lo + batch_size, node_count)

            chunk = flat_nodes[lo:hi].astype(np.float32)
            chunk_tensor = torch.from_numpy(chunk).view(1, -1, 3).to(device)
            logits = model(cond, chunk_tensor)['logits']  # [1, hi - lo]
            flat_view[lo:hi] = logits.cpu().numpy().reshape(-1)

            # Release the per-chunk buffers before the next forward pass.
            del chunk_tensor, logits
            torch.cuda.empty_cache()
            gc.collect()

        verts, faces = mcubes.marching_cubes(full_logits, 0)

        # Index space [0, density] -> coordinate space [-1, 1].
        verts *= cell_size
        verts -= 1.0

        mesh = trimesh.Trimesh(vertices=verts, faces=faces)

    return mesh



def validate_tensors_require_grad(*tensors):
    """Check that every given tensor has ``requires_grad`` set.

    Args:
        *tensors: Tensors to check; passing none is a no-op.

    Raises:
        RuntimeError: For the first tensor whose ``requires_grad`` is False.
            The message names its position and its shape, dtype and device.
    """
    for position, tensor in enumerate(tensors):
        if not tensor.requires_grad:
            raise RuntimeError(
                f"Tensor at position {position} does not require grad: "
                f"shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}"
            )


def _binary_iou(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the IoU between the ``logits >= 0`` mask and ``labels``."""
    mask = (logits >= 0.0).float().squeeze(0)
    intersection = (mask * labels).sum()
    union = (mask + labels).gt(0).sum()
    return intersection * 1.0 / union


def _gradient_normals(model: torch.nn.Module,
                      cond_points: torch.Tensor,
                      surface: torch.Tensor,
                      device: torch.device) -> torch.Tensor:
    """Estimate unit normals at ``surface`` as the normalized gradient of the
    model's logits with respect to the query positions.

    Returns a zero tensor shaped like the batched surface queries when the
    gradient cannot be computed.
    """
    try:
        with torch.enable_grad():
            queries = surface.clone().requires_grad_(True).unsqueeze(0).to(device)
            validate_tensors_require_grad(queries)

            output = model(cond_points, queries)['logits']
            grad = torch.autograd.grad(
                output.sum(),
                queries,
                create_graph=False,
                retain_graph=False
            )[0]
            # The small epsilon avoids division by zero for vanishing gradients.
            return grad / (grad.norm(dim=-1, keepdim=True) + 1e-6)
    except Exception as e:
        print(f"Warning: normal estimation failed ({type(e).__name__}: {e}); using zero normals")
        # Shaped like the surface queries so it can be compared against the
        # per-surface-point reference normals.
        return torch.zeros_like(surface.unsqueeze(0).to(device))


def _normal_consistency(pred_normals: torch.Tensor, normals: torch.Tensor) -> float:
    """Return the mean absolute dot product of predicted and reference normals."""
    return torch.abs(torch.sum(pred_normals.cpu() * normals, dim=-1)).mean().item()


def evaluate_model_v2(fix_model: torch.nn.Module,
                   fine_tune_model: torch.nn.Module,
                   device: torch.device,
                   dataset: torch.utils.data.Dataset,
                   save_path: str = None,
                   use_train_data: bool = False,
                   use_one_shape_only: bool = False,
                   num_samples: int = None):
    """Evaluate the fixed and the fine-tuned model on noisy point clouds.

    For every evaluated shape three configurations are scored against the
    fixed model's logits on the clean input and against the dataset labels:
    the fixed model on the noisy input ("noise"), the fine-tuned model on the
    noisy input ("denoised") and the fixed model on the clean input ("gt").
    The metrics are MSE, MAE, IoU and normal consistency. For up to 10
    randomly chosen shapes, meshes are extracted with ``matching_cube_batch``
    and a Chamfer distance between the denoised mesh and the surface samples
    is accumulated.

    Evaluation runs with a fixed seed (42). The Python, NumPy and torch RNG
    states, the cuDNN deterministic flag and the parameters' ``requires_grad``
    flags are restored afterwards, even if evaluation raises. Both models are
    left in eval mode.

    Args:
        fix_model: Reference model; called as ``model(points, queries)`` and
            expected to return a mapping with a ``'logits'`` entry.
        fine_tune_model: Model under evaluation, same calling convention.
        device: Device the inputs are moved to.
        dataset: Must provide ``val_model_list``, ``train_model_list`` and
            ``get_eval_item(index, use_train_data, use_one_shape_only)``.
        save_path: Currently unused.
        use_train_data: Evaluate on ``train_model_list`` instead of
            ``val_model_list``.
        use_one_shape_only: Restrict the evaluation to the first shape.
        num_samples: If given and smaller than the number of shapes, evaluate
            a random subset of that size.

    Returns:
        ``(mesh_results, summary)``. ``mesh_results`` is a list of dicts with
        keys ``'index'``, ``'gt_mesh'``, ``'denoised_mesh'`` and
        ``'noise_mesh'``. ``summary`` maps ``'denoised'``, ``'noise'`` and
        ``'gt'`` to the averaged metrics, and ``'cd'`` to the mean Chamfer
        distance (NaN when no mesh was evaluated).

    Raises:
        RuntimeError: If no sample could be evaluated.
    """
    # Snapshot everything that is modified below so it can be restored.
    original_state = random.getstate()
    original_np_state = np.random.get_state()
    original_torch_state = torch.get_rng_state()
    # Only touch the CUDA generators if CUDA is already initialized, to avoid
    # forcing context creation on CPU-only runs.
    original_cuda_states = None
    if torch.cuda.is_available() and torch.cuda.is_initialized():
        original_cuda_states = torch.cuda.get_rng_state_all()
    original_cudnn_deterministic = torch.backends.cudnn.deterministic
    original_fix_requires_grad = [p.requires_grad for p in fix_model.parameters()]
    original_finetune_requires_grad = [p.requires_grad for p in fine_tune_model.parameters()]

    metric_logger = MetricLogger()
    noise_metrics = MetricLogger()
    gt_metrics = MetricLogger()

    mesh_results = []
    cd_list = []
    indices = []

    try:
        eval_seed = 42
        torch.manual_seed(eval_seed)
        torch.cuda.manual_seed_all(eval_seed)
        np.random.seed(eval_seed)
        random.seed(eval_seed)
        torch.backends.cudnn.deterministic = True

        fix_model.eval()
        fine_tune_model.eval()

        for p in fix_model.parameters():
            p.requires_grad = True
        for p in fine_tune_model.parameters():
            p.requires_grad = True

        val_model_list = dataset.val_model_list if not use_train_data else dataset.train_model_list
        if use_one_shape_only:
            val_model_list = val_model_list[:1]

        indices = list(range(len(val_model_list)))
        if num_samples is not None and 0 < num_samples < len(indices):
            indices = random.sample(indices, num_samples)

        # random.sample raises ValueError if asked for more items than are
        # available (e.g. with use_one_shape_only), so cap the request.
        mesh_indices = set(random.sample(indices, min(10, len(indices))))

        for i in tqdm(indices, desc='processing'):
            try:
                gt_points, noise1_points, original_queries, surface, labels, normals = dataset.get_eval_item(i, use_train_data, use_one_shape_only)

                gt_points = gt_points.unsqueeze(0).to(device)
                noise1_points = noise1_points.unsqueeze(0).to(device)

                with torch.no_grad():
                    queries_no_grad = original_queries.clone().unsqueeze(0).to(device)

                    gt_logits = fix_model(gt_points, queries_no_grad)['logits']
                    # Second identical pass: the "gt" MSE/MAE measure how
                    # repeatable the fixed model's output is.
                    gt_logits_2 = fix_model(gt_points, queries_no_grad)['logits']
                    noise_logits = fix_model(noise1_points, queries_no_grad)['logits']
                    noise1_logits = fine_tune_model(noise1_points, queries_no_grad)['logits']

                    gt_mse = torch.mean((gt_logits - gt_logits_2) ** 2).item()
                    gt_mae = torch.mean(torch.abs(gt_logits - gt_logits_2)).item()
                    noise_mse = torch.mean((noise_logits - gt_logits) ** 2).item()
                    noise_mae = torch.mean(torch.abs(noise_logits - gt_logits)).item()
                    denoised_mse = torch.mean((noise1_logits - gt_logits) ** 2).item()
                    denoised_mae = torch.mean(torch.abs(noise1_logits - gt_logits)).item()

                    labels = labels.to(device)

                    noise_iou = _binary_iou(noise_logits, labels)
                    denoised_iou = _binary_iou(noise1_logits, labels)
                    gt_iou = _binary_iou(gt_logits, labels)

                    if i in mesh_indices:
                        gt_mesh = matching_cube_batch(fix_model, device, gt_points, density=128, batch_size=128*128*50)
                        denoised_mesh = matching_cube_batch(fine_tune_model, device, noise1_points, density=128, batch_size=128*128*50)
                        noise_mesh = matching_cube_batch(fix_model, device, noise1_points, density=128, batch_size=128*128*50)

                        mesh_results.append({
                            'index': i,
                            'gt_mesh': gt_mesh,
                            'denoised_mesh': denoised_mesh,
                            'noise_mesh': noise_mesh
                        })

                        # Symmetric Chamfer distance between samples of the
                        # denoised mesh and the reference surface points.
                        pred = denoised_mesh.sample(100000)
                        surface_np = surface.cpu().numpy()

                        dist, _ = KDTree(pred).query(surface_np)
                        gt_to_gen_chamfer = np.mean(dist)

                        dist, _ = KDTree(surface_np).query(pred)
                        gen_to_gt_chamfer = np.mean(dist)

                        cd_list.append(gt_to_gen_chamfer + gen_to_gt_chamfer)

                # Normals need gradients w.r.t. the query positions, so they
                # are computed outside the no_grad block.
                gt_normals_compare = _gradient_normals(fix_model, gt_points, surface, device)
                noise_normals = _gradient_normals(fix_model, noise1_points, surface, device)
                denoised_normals = _gradient_normals(fine_tune_model, noise1_points, surface, device)

                noise_normal_consistency = _normal_consistency(noise_normals, normals)
                denoised_normal_consistency = _normal_consistency(denoised_normals, normals)
                gt_normal_consistency = _normal_consistency(gt_normals_compare, normals)

                noise_metrics.update_metrics(metrics={
                    'mse': noise_mse, 'mae': noise_mae, 'iou': noise_iou,
                    'normal_consistency': noise_normal_consistency
                })

                metric_logger.update_metrics(metrics={
                    'mse': denoised_mse, 'mae': denoised_mae, 'iou': denoised_iou,
                    'normal_consistency': denoised_normal_consistency
                })

                gt_metrics.update_metrics(metrics={
                    'mse': gt_mse, 'mae': gt_mae, 'iou': gt_iou,
                    'normal_consistency': gt_normal_consistency
                })

                del gt_points, noise1_points, original_queries, surface, labels
                del queries_no_grad, gt_normals_compare, noise_normals, denoised_normals
                torch.cuda.empty_cache()
                gc.collect()

            except Exception:
                # Keep evaluating the remaining shapes, but do not hide why
                # this one was skipped.
                print(f"Warning: evaluation failed for sample {i}; skipping")
                traceback.print_exc()
                continue
    finally:
        for p, requires_grad in zip(fix_model.parameters(), original_fix_requires_grad):
            p.requires_grad = requires_grad
        for p, requires_grad in zip(fine_tune_model.parameters(), original_finetune_requires_grad):
            p.requires_grad = requires_grad

        random.setstate(original_state)
        np.random.set_state(original_np_state)
        torch.set_rng_state(original_torch_state)
        if original_cuda_states is not None:
            torch.cuda.set_rng_state_all(original_cuda_states)
        torch.backends.cudnn.deterministic = original_cudnn_deterministic

    # Fail with the intended message before touching the (empty) averages.
    if len(metric_logger.metrics['mse']) == 0:
        raise RuntimeError("No samples were successfully evaluated, please check the input data")

    noise_avg = noise_metrics.average()
    denoised_avg = metric_logger.average()
    gt_avg = gt_metrics.average()

    cd_mean = np.mean(cd_list) if cd_list else float('nan')

    print("\n=== Original ===")
    print(f"Number of evaluation samples: {len(indices)}")
    print(f"MSE: {noise_avg['mse']:.4f}")
    print(f"MAE: {noise_avg['mae']:.4f}")
    print(f"IoU: {noise_avg['iou']:.4f}")
    print(f"Normal Consistency: {noise_avg['normal_consistency']:.4f}")

    print("\n=== Denoised Results ===")
    print(f"Number of evaluation samples: {len(indices)}")
    print(f"MSE: {denoised_avg['mse']:.4f}")
    print(f"MAE: {denoised_avg['mae']:.4f}")
    print(f"IoU: {denoised_avg['iou']:.4f}")
    print(f"Normal Consistency: {denoised_avg['normal_consistency']:.4f}")
    print(f'CD: {cd_mean}')

    print("\n=== GT Results ===")
    print(f"Number of evaluation samples: {len(indices)}")
    print(f"MSE: {gt_avg['mse']:.4f}")
    print(f"MAE: {gt_avg['mae']:.4f}")
    print(f"IoU: {gt_avg['iou']:.4f}")
    print(f"Normal Consistency: {gt_avg['normal_consistency']:.4f}")

    return mesh_results, {'denoised': denoised_avg, 'noise': noise_avg, 'gt': gt_avg, 'cd': cd_mean}
