
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F 
from scipy.spatial.distance import directed_hausdorff

def extract_coordinates(heatmaps, subpixel=True):
    """Decode keypoint heatmaps into normalized (x, y) coordinates.

    The peak of every heatmap is located with ``argmax``. When ``subpixel`` is
    true, peaks that do not lie on the heatmap border are refined by fitting a
    parabola through the peak and its two neighbours along each axis; the
    resulting shift is limited to half a pixel per axis.

    Args:
        heatmaps: Tensor of shape ``[B, K, H, W]``. May be non-contiguous.
        subpixel: Whether to apply the sub-pixel refinement.

    Returns:
        Float tensor of shape ``[B, K * 2]`` laid out as
        ``x0, y0, x1, y1, ...`` where x is the (refined) pixel column divided
        by ``W`` and y is the (refined) pixel row divided by ``H``.
    """
    batch_size, num_kps, h, w = heatmaps.shape

    # reshape (not view) so that non-contiguous inputs are accepted.
    flat = heatmaps.reshape(batch_size, num_kps, -1)
    peak_idx = torch.argmax(flat, dim=2)

    # Integer floor division keeps the row index exact; going through a float
    # quotient can round incorrectly once indices exceed float precision.
    peak_y = torch.div(peak_idx, w, rounding_mode='floor')
    peak_x = peak_idx % w

    coords = torch.stack([peak_x, peak_y], dim=2).float()

    # A peak needs a neighbour on every side, so maps smaller than 3 pixels
    # along either axis have nothing to refine.
    if subpixel and h > 2 and w > 2:
        interior = (
            (peak_x > 0) & (peak_x < w - 1) & (peak_y > 0) & (peak_y < h - 1)
        )

        # Border peaks are clamped inwards only to keep the gather indices
        # valid; their offsets are discarded through `interior` below.
        col = peak_x.clamp(1, w - 2)
        row = peak_y.clamp(1, h - 2)

        def sample(rows, cols):
            # Read one value per (batch, keypoint) at the given pixel.
            return flat.gather(2, (rows * w + cols).unsqueeze(-1)).squeeze(-1)

        center = sample(row, col)
        left, right = sample(row, col - 1), sample(row, col + 1)
        up, down = sample(row - 1, col), sample(row + 1, col)

        curvature_x = 2 * center - left - right
        curvature_y = 2 * center - up - down

        # The small constant guards against division by zero on flat peaks.
        offset_x = (right - left) / (2 * curvature_x + 1e-7)
        offset_y = (down - up) / (2 * curvature_y + 1e-7)

        offsets = torch.stack([offset_x, offset_y], dim=2).clamp(-0.5, 0.5)
        offsets = offsets.to(coords.dtype)
        coords = coords + torch.where(
            interior.unsqueeze(-1), offsets, torch.zeros_like(offsets)
        )

    divisor = torch.tensor(
        [w, h], device=heatmaps.device, dtype=coords.dtype
    ).view(1, 1, 2)
    coords = coords / divisor

    return coords.reshape(batch_size, -1)


def euclidean_distance(preds, targets, per_point=False):
    """Compute the mean Euclidean distance between predicted and target points.

    Both inputs are viewed as ``[B, C, 2]`` (batch, points, xy) before the
    per-point distances are taken.

    Args:
        preds: Predicted coordinates; first dimension is the batch.
        targets: Target coordinates with the same layout as ``preds``.
        per_point: When true, also return the distance of every point.

    Returns:
        A ``[B]`` tensor with the mean distance per sample, or the tuple
        ``(mean [B], per-point distances [B, C])`` when ``per_point`` is true.
    """
    pred_pts = preds.view(preds.size(0), -1, 2)
    target_pts = targets.view(targets.size(0), -1, 2)

    # Distance of each point pair -> [B, C], then averaged per sample -> [B].
    point_dists = (pred_pts - target_pts).norm(dim=2)
    sample_means = point_dists.mean(dim=1)

    return (sample_means, point_dists) if per_point else sample_means

def save_net_opt(net, optimizer, path):
    """Write the network and optimizer state dicts to ``path``.

    The saved object is a dict with the keys ``'net'`` and ``'opt'``.
    ``path`` is converted with ``str`` before being handed to ``torch.save``.
    """
    torch.save(
        dict(net=net.state_dict(), opt=optimizer.state_dict()),
        str(path),
    )

def load_net_opt(net, optimizer, path, map_location=None):
    """Restore network and optimizer state from a file with 'net'/'opt' keys.

    Args:
        net: Module whose ``load_state_dict`` receives the ``'net'`` entry.
        optimizer: Optimizer whose ``load_state_dict`` receives the ``'opt'``
            entry.
        path: Checkpoint location; converted with ``str`` before loading.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            the previous behaviour; pass e.g. ``'cpu'`` to load a checkpoint
            that was saved on a device that is not available here.

    Raises:
        KeyError: If the file lacks the ``'net'`` or ``'opt'`` entry.
    """
    state = torch.load(str(path), map_location=map_location)
    net.load_state_dict(state['net'])
    optimizer.load_state_dict(state['opt'])

def load_net(net, model_path, map_location='cpu'):
    """Load as many pretrained weights as fit into ``net``.

    The checkpoint may be a bare state dict or a dict wrapping one under the
    key ``'net'``, ``'model_state_dict'`` or ``'state_dict'`` (checked in that
    order). A leading ``'module.'`` is removed from key names. Only entries
    whose name exists in ``net`` and whose size matches are copied; all other
    entries are ignored.

    Args:
        net: Module that receives the weights.
        model_path: Location of the checkpoint file.
        map_location: Forwarded to ``torch.load``.

    Returns:
        None. If the file does not exist, a message is printed and nothing is
        loaded.
    """
    try:
        checkpoint = torch.load(model_path, map_location=map_location)
    except FileNotFoundError:
        print(f"Error: Pretrained weights file not found at {model_path}")
        return

    # Unwrap the state dict if it is stored under one of the known keys.
    for wrapper_key in ('net', 'model_state_dict', 'state_dict'):
        if wrapper_key in checkpoint:
            print(f"  - Found '{wrapper_key}' key. Extracting model weights...")
            state_dict = checkpoint[wrapper_key]
            break
    else:
        state_dict = checkpoint

    # Drop the 'module.' prefix from key names.
    prefix = 'module.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    # Keep only entries the current model can accept.
    net_state_dict = net.state_dict()
    pretrained_dict = {}
    for key, value in new_state_dict.items():
        if key not in net_state_dict:
            continue
        if value.size() == net_state_dict[key].size():
            pretrained_dict[key] = value

    net_state_dict.update(pretrained_dict)
    net.load_state_dict(net_state_dict, strict=False)

    print(f"Loaded pretrained weights from {model_path}")
    print(f"  - Total keys in pretrained file (after extraction): {len(new_state_dict)}")
    print(f"  - Total keys in current model: {len(net_state_dict)}")
    print(f"  - Successfully loaded {len(pretrained_dict)} matching keys.")


def replace_bn_with_in(module):
    """Recursively swap every ``nn.BatchNorm2d`` under ``module`` in place.

    Each BatchNorm2d child is replaced by a freshly initialised
    ``nn.InstanceNorm2d`` with the same number of features and
    ``affine=True``. The learned BatchNorm parameters are not copied. The new
    layer is moved to the device and dtype of the layer it replaces, so the
    function can be called after the model has been moved or cast.

    Args:
        module: Root module to modify; children are replaced via ``setattr``.

    Returns:
        None.
    """
    for name, child in list(module.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            # Reuse the channel count of the BatchNorm layer.
            new_layer = nn.InstanceNorm2d(child.num_features, affine=True)

            # Follow the replaced layer's device/dtype. The weight is absent
            # for non-affine BatchNorm, so fall back to the running mean.
            reference = child.weight if child.weight is not None else child.running_mean
            if reference is not None:
                new_layer.to(device=reference.device, dtype=reference.dtype)

            # Put the new layer in the old layer's place.
            setattr(module, name, new_layer)
            print(f"Replaced {name} (BatchNorm2d) with InstanceNorm2d.")
        else:
            # Not a BatchNorm layer: descend into its children.
            replace_bn_with_in(child)
            
def save_checkpoint(epoch, model, optimizer, scheduler, current_loss, save_path):
    """Save a training checkpoint to ``save_path``.

    The stored dict has the keys ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'``, ``'scheduler_state_dict'`` and ``'loss'``.
    A progress message is printed before and after the file is written.

    Args:
        epoch: Epoch number recorded in the checkpoint.
        model: Object providing ``state_dict()`` for the model weights.
        optimizer: Object providing ``state_dict()`` for the optimizer.
        scheduler: Object providing ``state_dict()`` for the scheduler.
        current_loss: Loss value recorded in the checkpoint.
        save_path: Destination passed to ``torch.save``.
    """
    print(f"==> Saving checkpoint at epoch {epoch} to {save_path}")

    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        loss=current_loss,
    )
    torch.save(checkpoint, save_path)

    print("==> Checkpoint saved successfully.")


def load_checkpoint(model, optimizer, scheduler, checkpoint_path, device):
    """Resume training state from a checkpoint file, if one exists.

    Args:
        model: Module that receives ``checkpoint['model_state_dict']``. A
            leading ``'module.'`` on a key name is stripped before loading.
        optimizer: Optimizer to restore, or a falsy value to skip it.
        scheduler: Scheduler to restore, or a falsy value to skip it.
        checkpoint_path: Location of the checkpoint file.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The tuple ``(model, optimizer, scheduler, start_epoch)``.
        ``start_epoch`` is 0 when no file exists at ``checkpoint_path``;
        otherwise it is the stored ``'epoch'`` plus one (0 if the checkpoint
        has no ``'epoch'`` entry).

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    # Imported locally because this module does not import `os` at file level.
    import os

    if not os.path.exists(checkpoint_path):
        print(f"==> No checkpoint found at '{checkpoint_path}'. Starting from scratch.")
        return model, optimizer, scheduler, 0

    print(f"==> Loading checkpoint from '{checkpoint_path}'")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model_state = checkpoint['model_state_dict']
    prefix = 'module.'
    # Strip the prefix per key rather than deciding from the first key only,
    # so mixed or empty state dicts are handled. The dict is rebuilt only when
    # needed so an unprefixed state dict is passed through untouched.
    if any(key.startswith(prefix) for key in model_state):
        model_state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in model_state.items()
        }
    model.load_state_dict(model_state)

    if optimizer and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("  - Optimizer state loaded.")
    else:
        print("  - WARNING: Optimizer state not found in checkpoint or optimizer not provided.")

    if scheduler and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print("  - Scheduler state loaded.")
    else:
        print("  - WARNING: Scheduler state not found in checkpoint or scheduler not provided.")

    # The stored epoch is the last one completed, so training resumes after it.
    start_epoch = checkpoint.get('epoch', -1) + 1
    last_loss = checkpoint.get('loss', 'N/A')

    print(f"==> Checkpoint loaded. Resuming training from epoch {start_epoch}.")
    print(f"  - Last recorded loss: {last_loss}")

    return model, optimizer, scheduler, start_epoch

def calculate_aop(ps1: torch.Tensor, ps2: torch.Tensor, fh1: torch.Tensor) -> float:
    """Return the angle, in degrees, at vertex ``ps1`` between ``ps2`` and ``fh1``.

    The angle is measured between the vector from ``ps1`` to ``ps2`` and the
    vector from ``ps1`` to ``fh1``.

    NOTE(review): "aop" presumably stands for a named clinical angle; confirm
    the intended meaning against the caller.

    Args:
        ps1: 1-D tensor holding the vertex point.
        ps2: 1-D tensor holding the end point of the first ray.
        fh1: 1-D tensor holding the end point of the second ray.

    Returns:
        The angle in degrees as a Python float, in the range [0, 180].
    """
    # Both rays start at the vertex ps1.
    to_ps2 = ps2 - ps1
    to_fh1 = fh1 - ps1

    # cos(theta) = (a . b) / (|a| * |b|); eps avoids division by zero when a
    # point coincides with the vertex.
    eps = 1e-7
    lengths = torch.norm(to_ps2) * torch.norm(to_fh1)
    cosine = torch.dot(to_ps2, to_fh1) / (lengths + eps)

    # Rounding error can push the cosine slightly outside [-1, 1], which
    # would make acos return NaN.
    radians = torch.acos(torch.clamp(cosine, -1.0, 1.0))

    degrees = radians * 180.0 / torch.pi
    return degrees.item()

def differentiable_coords(heatmaps: torch.Tensor, beta: float = 20.0) -> torch.Tensor:
    """Differentiable coordinate decoding via a softmax-weighted expectation.

    Each heatmap is turned into a probability distribution with a softmax
    over all of its pixels, and the expected (x, y) position under that
    distribution is returned.

    NOTE(review): pixel ``i`` maps to ``i / (W - 1)`` here, whereas
    ``extract_coordinates`` in this file divides by ``W``; confirm which
    convention the targets use.

    Args:
        heatmaps: Tensor of shape ``[B, K, H, W]``. May be non-contiguous.
        beta: Factor applied to the heatmap values before the softmax. Larger
            values concentrate the distribution on the peak. Defaults to the
            previously hard-coded 20.0.

    Returns:
        Tensor of shape ``[B, K * 2]`` laid out as ``x0, y0, x1, y1, ...``
        with coordinates in ``[0, 1]``.
    """
    batch_size, num_kps, h, w = heatmaps.shape

    # Step 1: normalise each heatmap into a probability distribution.
    # reshape (not view) so that non-contiguous inputs are accepted.
    probs = F.softmax(heatmaps.reshape(batch_size, num_kps, -1) * beta, dim=2)
    probs = probs.view(batch_size, num_kps, h, w)

    # Step 2: coordinate grids already scaled to [0, 1], so no conversion is
    # needed afterwards. Built in the softmax output dtype so that they stay
    # fractional whatever the input dtype is.
    grid_x = torch.linspace(0.0, 1.0, w, device=probs.device, dtype=probs.dtype)
    grid_y = torch.linspace(0.0, 1.0, h, device=probs.device, dtype=probs.dtype)

    # Step 3: expected value along each axis.
    coord_x = torch.sum(probs * grid_x[None, None, None, :], dim=[2, 3])
    coord_y = torch.sum(probs * grid_y[None, None, :, None], dim=[2, 3])

    # Step 4: interleave as (x, y) pairs and flatten per sample.
    coords = torch.stack([coord_x, coord_y], dim=2)
    return coords.view(batch_size, -1)