import torch
import copy
from torch import nn
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
import torch
import copy
from torch import nn
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ...existing code...

def _infer_coord_dim_from_model(model: nn.Module, default_dim: int = 3) -> int:
    """Infer the spatial coordinate dimension (2 or 3) a model works with.

    Lookup order:
      1. ``model.in_features`` when it converts to 2 or 3;
      2. ``model.in_dim`` when it converts to 2 or 3;
      3. ``default_dim``.

    Args:
        model: Object to inspect. Only the two attributes above are read.
        default_dim: Value returned when neither attribute yields 2 or 3.

    Returns:
        The inferred dimension as an ``int``.
    """
    for attr_name in ("in_features", "in_dim"):
        raw_value = getattr(model, attr_name, None)
        if raw_value is None:
            continue
        try:
            dim = int(raw_value)
        except (TypeError, ValueError, OverflowError, RuntimeError):
            # Attribute exists but is not integer-like; try the next source.
            continue
        if dim in (2, 3):
            return dim
    return int(default_dim)


def extract_mesh_2d(model: nn.Module,
                    scale: float = 2,
                    N: int = 512,
                    device: str = "cpu"):
    """Extract the zero level set of a 2D template SDF as contour points.

    The SDF is sampled through ``model.template_sdf`` on a regular ``N x N``
    grid covering ``[-scale, scale]^2`` and the zero iso-contours are traced
    with ``skimage.measure.find_contours`` (marching squares).

    Device handling: CUDA is used whenever it is available, which matches the
    historical behaviour of this function (it always ran on CUDA). ``device``
    is honoured when it names a CUDA device or when CUDA is unavailable, so
    the function no longer crashes on CPU-only hosts.

    Side effects: ``model`` is moved to the compute device and switched to
    eval mode; neither is restored afterwards.

    Args:
        model: Module exposing ``template_sdf``; it is called with a
            ``[1, M, 2]`` coordinate tensor and must return ``M`` SDF values.
        scale: Half extent of the sampled square.
        N: Grid resolution per axis (must be >= 2).
        device: Requested device, see "Device handling" above.

    Returns:
        Tuple ``(mesh_points, faces, normals, values)``:
          * ``mesh_points``: float32 tensor ``[V, 2]`` with all contour
            vertices concatenated (a single zero vertex when no contour was
            found or extraction failed);
          * ``faces``: empty ``(0, 2)`` int32 array;
          * ``normals``: empty ``(0, 2)`` float32 array;
          * ``values``: empty ``(0,)`` float32 array.
    """
    import numpy as np
    import skimage.measure as measure

    requested_device = torch.device(device)
    if requested_device.type == "cuda" or not torch.cuda.is_available():
        compute_device = requested_device
    else:
        # Legacy behaviour: always evaluate on the current CUDA device.
        compute_device = torch.device("cuda")

    model = model.to(compute_device)
    model.eval()

    num_samples = N ** 2
    max_batch = 64 ** 2
    voxel_origin = np.array([-1., -1.]) * scale
    voxel_size = 2.0 / (N - 1) * scale

    flat_index = torch.arange(num_samples, dtype=torch.long, device=compute_device)
    # Columns: x, y, sdf.
    samples = torch.zeros(num_samples, 3, device=compute_device)

    # Flat index -> integer grid position; x is the slow axis, y the fast one.
    samples[:, 1] = flat_index % N
    samples[:, 0] = (flat_index // N) % N

    # Grid position -> world coordinates.
    samples[:, 0] = samples[:, 0] * voxel_size + float(voxel_origin[0])
    samples[:, 1] = samples[:, 1] * voxel_size + float(voxel_origin[1])

    # Evaluate the SDF in batches, staying on the compute device.
    for head in range(0, num_samples, max_batch):
        tail = min(head + max_batch, num_samples)
        xy = samples[head:tail, 0:2]  # [M, 2]
        sdf = model.template_sdf(xy[None, :, :]).detach().reshape(-1)
        samples[head:tail, 2] = sdf.to(device=samples.device, dtype=samples.dtype)

    # Axis 0 of the grid follows x, axis 1 follows y (see index decoding above).
    sdf_grid = samples[:, 2].reshape(N, N).detach().cpu().numpy()

    try:
        contours = measure.find_contours(sdf_grid, level=0.0)
    except Exception as e:
        # Deliberate best effort: report and fall back to the placeholder vertex.
        print(f'2D contour extraction failed: {e}')
        contours = []

    verts_list = []
    for contour in contours:
        # find_contours yields (row, col) positions; rows index axis 0 (x)
        # and columns index axis 1 (y) of sdf_grid.
        xs = voxel_origin[0] + contour[:, 0] * voxel_size
        ys = voxel_origin[1] + contour[:, 1] * voxel_size
        verts_list.append(np.stack([xs, ys], axis=1).astype('float32'))

    if verts_list:
        verts = np.concatenate(verts_list, axis=0)
    else:
        verts = np.zeros((1, 2), dtype='float32')

    faces = np.zeros((0, 2), dtype=np.int32)
    normals = np.zeros((0, 2), dtype='float32')
    values = np.zeros((0,), dtype='float32')

    mesh_points = torch.from_numpy(verts).to(compute_device)
    return mesh_points, faces, normals, values


def extract_mesh_3d(model: nn.Module,
                    scale: float = 2,
                    N: int = 512,
                    truncate: bool = False,
                    device: str = "cpu"):
    """Extract the zero level set of a 3D template SDF with marching cubes.

    The SDF is sampled through ``model.template_sdf`` on a regular
    ``N x N x N`` grid covering ``[-scale, scale]^3`` and meshed with
    ``skimage.measure.marching_cubes``.

    Device handling: CUDA is used whenever it is available, which matches the
    historical behaviour of this function (it always ran on CUDA). ``device``
    is honoured when it names a CUDA device or when CUDA is unavailable. The
    success and failure paths now return tensors on the same device.

    Side effects: ``model`` is moved to the compute device and switched to
    eval mode; neither is restored afterwards.

    Args:
        model: Module exposing ``template_sdf``; it is called with a
            ``[1, M, 3]`` coordinate tensor and must return ``M`` SDF values.
        scale: Half extent of the sampled cube.
        N: Grid resolution per axis (must be >= 2).
        truncate: When True, the mesh is post-processed by
            ``utility.truncate_airway.truncate_airway_mesh`` using
            ``carina_z = -scale``.
        device: Requested device, see "Device handling" above.

    Returns:
        Tuple ``(mesh_points, faces, normals, values)``. ``mesh_points`` is a
        float32 tensor ``[V, 3]`` in world coordinates; the other three are
        the arrays returned by marching cubes. If marching cubes fails, a
        single zero vertex and empty arrays are returned.
        NOTE(review): with ``truncate=True`` the returned ``normals`` and
        ``values`` are not re-indexed to the truncated vertices - confirm
        callers do not rely on a per-vertex correspondence.
    """
    import numpy as np
    import skimage.measure as measure

    requested_device = torch.device(device)
    if requested_device.type == "cuda" or not torch.cuda.is_available():
        compute_device = requested_device
    else:
        # Legacy behaviour: always evaluate on the current CUDA device.
        compute_device = torch.device("cuda")

    model = model.to(compute_device)
    model.eval()

    num_samples = N ** 3
    max_batch = 64 ** 3
    voxel_origin = np.array([-1., -1., -1.]) * scale
    voxel_size = 2.0 / (N - 1) * scale

    # The int64 index is kept on the CPU (as before) so it does not add to
    # peak accelerator memory; slice assignment copies across devices.
    flat_index = torch.arange(0, num_samples, 1, dtype=torch.long)
    # Columns: x, y, z, sdf.
    samples = torch.zeros(num_samples, 4, device=compute_device)

    # Flat index -> integer grid position; x is the slowest axis, z the fastest.
    samples[:, 2] = flat_index % N
    samples[:, 1] = (flat_index // N) % N
    samples[:, 0] = ((flat_index // N) // N) % N
    del flat_index

    # Grid position -> world coordinates.
    for axis in range(3):
        samples[:, axis] = samples[:, axis] * voxel_size + float(voxel_origin[axis])

    # Evaluate the SDF in batches, staying on the compute device.
    for head in range(0, num_samples, max_batch):
        tail = min(head + max_batch, num_samples)
        xyz = samples[head:tail, 0:3]  # [M, 3]
        sdf = model.template_sdf(xyz[None, :, :]).detach().reshape(-1)
        samples[head:tail, 3] = sdf.to(device=samples.device, dtype=samples.dtype)

    sdf_grid = samples[:, 3].reshape(N, N, N).cpu().numpy()

    try:
        verts, faces, normals, values = measure.marching_cubes(
            sdf_grid, level=0., spacing=[voxel_size] * 3
        )
    except Exception:
        # Deliberate best effort (e.g. no zero crossing inside the volume):
        # return a placeholder mesh instead of raising.
        print('invalid template')
        verts = np.zeros((1, 3), dtype='float32')
        faces = np.zeros((0, 3), dtype=np.int32)
        normals = np.zeros((0, 3), dtype='float32')
        values = np.zeros((0,), dtype='float32')
        mesh_points = torch.from_numpy(verts).to(compute_device)
        return mesh_points, faces, normals, values

    # marching_cubes already applied the spacing; only the origin shift remains.
    verts = verts.astype('float32')
    mesh_points = np.zeros_like(verts)
    for axis in range(3):
        mesh_points[:, axis] = voxel_origin[axis] + verts[:, axis]

    mesh_points = torch.from_numpy(mesh_points).to(compute_device)

    if truncate:
        from utility.truncate_airway import truncate_airway_mesh
        # Imported but currently unused: the learned carina height is
        # overridden by -scale below. Kept so the import-time contract of
        # this branch is unchanged.
        from utility.INFO import LEARNED_TEMPLATE_AIRWAY_CARINA_Z
        carina_z = -scale
        mesh_points, faces, _ = truncate_airway_mesh(mesh_points, faces, carina_z)
        mesh_points = torch.from_numpy(mesh_points).to(compute_device)
    return mesh_points, faces, normals, values
    

def extract_mesh(model: nn.Module,
                 scale: float = 1.2,
                 N: int = 512,
                 truncate: bool = False,
                 device: str = "cpu"):
    """Entry point: dispatch to the 2D or 3D extractor for ``model``.

    The coordinate dimension comes from ``_infer_coord_dim_from_model``.

    Args:
        model: Model handed to the selected extractor.
        scale: Sampling half extent; forwarded to the 3D extractor only.
        N: Grid resolution per axis, forwarded to both extractors.
        truncate: Forwarded to the 3D extractor only.
        device: Forwarded to both extractors.

    Returns:
        Whatever the selected extractor returns.
    """
    if _infer_coord_dim_from_model(model) == 2:
        # The 2D branch always samples with scale=2; the caller's ``scale``
        # and ``truncate`` are intentionally not forwarded here.
        return extract_mesh_2d(model, scale=2, N=N, device=device)

    return extract_mesh_3d(model, scale=scale, N=N, device=device, truncate=truncate)





class SinePE(nn.Module):
    """Sinusoidal positional encoding.

    Each input coordinate is multiplied by a set of frequencies and passed
    through ``sin`` and ``cos``; the results are concatenated along the last
    axis, optionally preceded by the raw input.
    """

    def __init__(self,
                 num_encoding_functions: int = 6,
                 include_input: bool = True,
                 log_sampling: bool = True):
        """Store the encoding configuration.

        Args:
            num_encoding_functions: Number of frequency bands (sin/cos pairs).
            include_input: Whether the raw input is kept in the output.
            log_sampling: True for frequencies ``2**0 .. 2**(n-1)`` spaced
                evenly in the exponent, False for frequencies spaced linearly
                between ``2**0`` and ``2**(n-1)``.
        """
        super().__init__()
        self.num_funcs = num_encoding_functions
        self.include_input = include_input
        self.log_sampling = log_sampling

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape ``[..., D]``.

        Returns:
            ``x`` itself when no frequency band is configured; otherwise a
            tensor of shape ``[..., D * (1 + 2 * n)]`` (input included) or
            ``[..., D * 2 * n]`` (input excluded), laid out as
            ``[x?, sin(f0 x), cos(f0 x), sin(f1 x), cos(f1 x), ...]``.
        """
        n_bands = self.num_funcs
        if n_bands == 0:
            return x

        tensor_opts = {"dtype": x.dtype, "device": x.device}
        if not self.log_sampling:
            # Linear spacing between the lowest and highest frequency.
            bands = torch.linspace(
                2.0 ** 0.0, 2.0 ** (n_bands - 1), n_bands, **tensor_opts
            )
        else:
            # Linear spacing in the exponent, i.e. powers of two.
            bands = 2.0 ** torch.linspace(0.0, n_bands - 1, n_bands, **tensor_opts)

        pieces = [x] if self.include_input else []
        for band in bands:
            scaled = x * band
            pieces += [torch.sin(scaled), torch.cos(scaled)]

        return torch.cat(pieces, dim=-1)



class BaseNet(nn.Module):
    """Base network holding input configuration and input-preparation helpers.

    Configuration keys read from ``cfg`` (via ``cfg.get``):
      * ``CovariateNames`` (default ``[]``): covariates expanded per point.
      * ``NumInstances`` (default ``1``).
      * ``InFeatures`` (default ``3``): raw coordinate dimension.
      * ``PosEncConfig`` (default: positional encoding disabled): mapping
        with ``enabled`` and, when enabled, ``pe_dims``, ``include_input``
        and ``log_sampling``.

    Attributes set here: ``covariate_names``, ``num_instances``,
    ``in_features``, ``pos_enc``, ``pos_encoder`` and ``in_dim`` (the width
    of the encoded coordinates).
    """

    def __init__(self, cfg):
        super().__init__()
        self.covariate_names = cfg.get("CovariateNames", [])
        self.num_instances = cfg.get("NumInstances", 1)

        # A missing or empty PosEncConfig means "no positional encoding"
        # instead of raising KeyError on 'enabled'.
        pe_config = cfg.get("PosEncConfig") or {}

        self.in_features = cfg.get("InFeatures", 3)
        self.pos_enc = pe_config.get('enabled', False)

        if self.pos_enc:
            pe_dims = pe_config['pe_dims']
            include_input = pe_config['include_input']
            self.pos_encoder = SinePE(
                num_encoding_functions=pe_dims,
                include_input=include_input,
                log_sampling=pe_config['log_sampling']
            )
            encoded_width = 2 * pe_dims * self.in_features
            if pe_dims == 0:
                # SinePE returns its input unchanged when it has no
                # frequency bands, regardless of include_input.
                self.in_dim = self.in_features
            elif include_input:
                self.in_dim = self.in_features + encoded_width
            else:
                self.in_dim = encoded_width
        else:
            self.pos_encoder = nn.Identity()
            self.in_dim = self.in_features

    def encode_coord(self, coords: torch.Tensor) -> torch.Tensor:
        """Apply the configured positional encoder to ``coords``."""
        return self.pos_encoder(coords)

    def prepare_model_input(self, model_input: dict, pts_on_template: torch.Tensor = None, retain: bool = False):
        """Build the model input: resolve coordinates, then expand covariates.

        Args:
            model_input: Dict with ``'covariates'`` and, unless
                ``pts_on_template`` is given, ``'coords'``.
            pts_on_template: Optional coordinates replacing ``'coords'``.
            retain: When False every non-coordinate entry is deep-copied;
                when True entries are shared with ``model_input``.

        Returns:
            A new dict ready to be fed to the network.
        """
        model_input = self.update_input_coords(model_input, pts_on_template, retain=retain)
        model_input = self.update_input_covariates(model_input)

        return model_input

    def update_input_coords(self, model_input: dict, pts_on_template: torch.Tensor = None, retain=False) -> dict:
        """Return a copy of ``model_input`` with its coordinates resolved.

        ``'coords'`` becomes ``pts_on_template`` when that is given, otherwise
        a detached copy of the original coordinates with gradients enabled.
        Other entries are deep-copied unless ``retain`` is True.
        """
        model_input_on_template = {}
        for i_key, i_value in model_input.items():
            if i_key == 'coords':
                if pts_on_template is not None:
                    model_input_on_template[i_key] = pts_on_template
                else:
                    # Fresh leaf so spatial gradients can be taken w.r.t. it.
                    model_input_on_template[i_key] = i_value.detach().requires_grad_(True)
            elif not retain:
                model_input_on_template[i_key] = copy.deepcopy(i_value)
            else:
                model_input_on_template[i_key] = i_value
        if 'coords' not in model_input_on_template and pts_on_template is not None:
            model_input_on_template['coords'] = pts_on_template
        return model_input_on_template

    def update_input_deformation(self, model_input: dict, dict_deforms: dict) -> dict:
        """Store ``dict_deforms`` under ``'deformation'`` (in place) and return the dict."""
        model_input['deformation'] = dict_deforms
        return model_input

    def update_input_covariates(self, model_input):
        """Expand each configured covariate to one value per point.

        Covariates that are already 3-D with a point axis matching
        ``coords.shape[1]`` are left untouched; all others get a point axis
        inserted at dim 1 and are expanded (as a view) to the number of
        points. ``model_input['covariates']`` is modified in place, so when
        the dict is shared with the caller (``retain=True`` upstream) the
        caller sees the expanded tensors too.
        """
        num_points = model_input['coords'].shape[1]
        covariates = model_input['covariates']

        for ith_cov in self.covariate_names:
            value = covariates[ith_cov]
            already_per_point = value.ndim == 3 and value.shape[1] == num_points
            if not already_per_point:
                # e.g. (batch, 1) -> (batch, num_points, 1)
                covariates[ith_cov] = value.unsqueeze(1).expand(-1, num_points, -1)
        return model_input
    



class TemporalExpansionConstraint(nn.Module):
    """Temporal expansion constraint: penalise volume that shrinks over time.

    Idea:
      * a spatial constraint ``det(d(deformation)/d(coords)) > 0`` only
        guarantees spatial regularity;
      * this constraint targets ``d(volume)/dt > 0``, i.e. monotone expansion.

    Method:
      1. total Jacobian ``J_total = I + d(deformation)/d(coords)`` (because
         the warped position is ``coords + deformation``);
      2. volume change rate via a central finite difference of
         ``det(J_total)`` in time;
      3. penalise negative (and very small positive) rates.

    ``expansion_penalty_weight`` is stored but not applied inside this class;
    NOTE(review): presumably the caller weights the returned loss - confirm.
    """

    def __init__(self,
                 time_delta: float = 1e-1,
                 expansion_penalty_weight: float = 1.0,
                 sampling_ratio: float = 0.2):
        """Store the hyper-parameters and log them.

        Args:
            time_delta: Half step of the central finite difference in time.
            expansion_penalty_weight: Stored for callers; unused here.
            sampling_ratio: Fraction of points evaluated per call; values
                ``>= 1.0`` disable sub-sampling.
        """
        super().__init__()
        self.time_delta = time_delta
        self.expansion_penalty_weight = expansion_penalty_weight
        self.sampling_ratio = sampling_ratio

        logger.info("Initialized TemporalExpansionConstraint:")
        logger.info("  time_delta: %s (finite difference step)", time_delta)
        logger.info("  expansion_penalty_weight: %s", expansion_penalty_weight)
        logger.info("  sampling_ratio: %s (efficiency)", sampling_ratio)

    def compute_volume_change_rate(self,
                                   deformation_network,
                                   coords: torch.Tensor,
                                   t: torch.Tensor) -> dict:
        """Estimate ``d(det(J_total))/dt`` with a central finite difference.

        ``rate = [det J_total(t + dt) - det J_total(t - dt)] / (2 * dt)``

        Args:
            deformation_network: Callable ``(coords, t) -> deformation`` whose
                output has the same trailing dimension as ``coords``.
            coords: Points, indexed here as ``[B, N, D]``. When no
                sub-sampling happens, ``requires_grad`` is switched on for
                this tensor in place.
            t: Per-point times, indexed here as ``[B, N, 1]``.

        Returns:
            Dict with ``volume_change_rate``, ``det_J_minus`` and
            ``det_J_plus`` (tensors ``[B, N_sampled]``) plus the Python floats
            ``expansion_ratio`` (fraction of points with a positive rate) and
            ``mean_expansion_rate``.
        """
        num_points = coords.shape[1]

        # Random sub-sampling keeps the Jacobian computation affordable.
        if self.sampling_ratio < 1.0:
            # At least one point, otherwise the means below would be NaN.
            num_sampled = max(1, int(num_points * self.sampling_ratio))
            sampled_indices = torch.randperm(num_points)[:num_sampled]
            coords_sampled = coords[:, sampled_indices, :]
            t_sampled = t[:, sampled_indices, :]
        else:
            coords_sampled = coords
            t_sampled = t

        # Spatial derivatives are taken w.r.t. these coordinates.
        coords_sampled = coords_sampled.requires_grad_(True)

        det_J_minus = self._det_total_jacobian(
            deformation_network, coords_sampled, t_sampled - self.time_delta
        )
        det_J_plus = self._det_total_jacobian(
            deformation_network, coords_sampled, t_sampled + self.time_delta
        )

        volume_change_rate = (det_J_plus - det_J_minus) / (2 * self.time_delta)

        return {
            "volume_change_rate": volume_change_rate,
            "det_J_minus": det_J_minus,
            "det_J_plus": det_J_plus,
            "expansion_ratio": torch.mean((volume_change_rate > 0).float()).item(),
            "mean_expansion_rate": torch.mean(volume_change_rate).item()
        }

    def _det_total_jacobian(self, deformation_network, coords: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Return ``det(I + d(deformation)/d(coords))`` per point at time ``t``."""
        deformation = deformation_network(coords, t)
        J_spatial = self._compute_spatial_jacobian(deformation, coords)
        identity = torch.eye(
            J_spatial.shape[-1], dtype=J_spatial.dtype, device=J_spatial.device
        )
        return torch.det(identity + J_spatial)

    def _compute_spatial_jacobian(self, deformation: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
        """Compute the spatial Jacobian ``d(deformation)/d(coords)``.

        Returns a tensor ``[B, N, D_out, D_in]`` whose entry ``[..., i, j]``
        is ``d deformation_i / d coords_j``. One backward pass is done per
        output component; ``create_graph=True`` keeps the result
        differentiable so a loss on the Jacobian can be back-propagated.
        """
        rows = []
        for i in range(deformation.shape[-1]):
            # One-hot selector for output component i.
            grad_outputs = torch.zeros_like(deformation)
            grad_outputs[..., i] = 1.0

            (grads,) = torch.autograd.grad(
                outputs=deformation,
                inputs=coords,
                grad_outputs=grad_outputs,
                create_graph=True,
                retain_graph=True,
            )
            rows.append(grads)

        return torch.stack(rows, dim=-2)

    def compute_expansion_loss(self, deformation_network, coords: torch.Tensor, t: torch.Tensor) -> dict:
        """Compute the temporal-expansion loss terms.

        Returns:
            Dict with ``expansion_loss`` (sum of the two penalties below),
            ``negative_expansion_loss`` (mean squared negative rate),
            ``weak_expansion_penalty`` (0.1 x mean squared shortfall below a
            rate of 0.01) and the float statistics ``expansion_ratio`` and
            ``mean_expansion_rate``.
        """
        volume_stats = self.compute_volume_change_rate(deformation_network, coords, t)
        volume_change_rate = volume_stats["volume_change_rate"]

        # Penalise shrinking volume.
        negative_expansion = torch.clamp(-volume_change_rate, min=0)
        expansion_loss = torch.mean(negative_expansion ** 2)

        # Mild push towards a clearly positive expansion rate.
        weak_expansion_threshold = 0.01
        weak_expansion = torch.clamp(weak_expansion_threshold - volume_change_rate, min=0)
        weak_expansion_penalty = torch.mean(weak_expansion ** 2) * 0.1

        total_loss = expansion_loss + weak_expansion_penalty

        return {
            "expansion_loss": total_loss,
            "negative_expansion_loss": expansion_loss,
            "weak_expansion_penalty": weak_expansion_penalty,
            "expansion_ratio": volume_stats["expansion_ratio"],
            "mean_expansion_rate": volume_stats["mean_expansion_rate"]
        }
    

class LandmarkRescaler:
    """Helper for landmark-based rescaling of coordinates.

    Responsibilities:
      1. multiply coordinates by a landmark scale;
      2. read the landmark scale from a model-input dict, normalised by
         ``reference_scale``.
    """

    def __init__(self, reference_scale: float = 1.0):
        """Store the scale every extracted landmark scale is divided by."""
        self.reference_scale = reference_scale
        logger.info("Initialized LandmarkRescaler with reference_scale: %s", reference_scale)

    def apply_rescaling_to_coords(self, coords: torch.Tensor, landmark_scale) -> torch.Tensor:
        """Multiply ``coords`` by ``landmark_scale``.

        Args:
            coords: Coordinates, expected as ``[B, N, D]``.
            landmark_scale: Python scalar, NumPy array or tensor. Accepted
                layouts: scalar, ``[B]``, ``[B, 1]``, ``[B, N]``,
                ``[B, 1, 1]`` and ``[B, N, 1]``. Any other shape is passed to
                the multiplication unchanged and follows normal broadcasting.

        Returns:
            The rescaled coordinates.
        """
        if not isinstance(landmark_scale, torch.Tensor) and hasattr(landmark_scale, 'shape'):
            # NumPy input (as extract_landmark_scale may return): bring it
            # onto the coordinates' dtype and device first.
            landmark_scale = torch.as_tensor(
                landmark_scale, dtype=coords.dtype, device=coords.device
            )

        if isinstance(landmark_scale, torch.Tensor):
            if landmark_scale.dim() == 1:
                # [B] -> [B, 1, 1]
                landmark_scale = landmark_scale.unsqueeze(-1).unsqueeze(-1)
            elif (
                landmark_scale.dim() == 2
                and coords.dim() == 3
                and landmark_scale.shape[0] == coords.shape[0]
                and landmark_scale.shape[1] in (1, coords.shape[1])
            ):
                # [B, N] / [B, 1] -> trailing axis added so the scale lines
                # up with the point axis instead of the coordinate axis.
                landmark_scale = landmark_scale.unsqueeze(-1)

        return coords * landmark_scale

    def extract_landmark_scale(self, model_input: dict):
        """Read ``'landmark_scale'`` from ``model_input`` and normalise it.

        Args:
            model_input: Model-input dict; a missing key counts as ``1.0``.

        Returns:
            ``landmark_scale / reference_scale``. Values exposing ``.shape``
            (tensors, NumPy arrays) keep their type; anything else is
            returned as a Python ``float``.
        """
        landmark_scale = model_input.get('landmark_scale', 1.0)
        normalised = landmark_scale / self.reference_scale
        if hasattr(landmark_scale, 'shape'):
            return normalised
        return float(normalised)
    
