import torch

def clip_gradients(model, max_norm=1.0):
    """Clip the gradients of ``model``'s parameters in place.

    Delegates to ``torch.nn.utils.clip_grad_norm_`` with ``max_norm`` as the
    norm limit. Nothing is returned.
    """
    params = model.parameters()
    torch.nn.utils.clip_grad_norm_(params, max_norm)

def pretrain(model, device, train_loader, criterion, optimizer, tf_sched, lam_sched, cfg_train, step, max_grad_norm=1.0):
    """Run one pre-training epoch and return ``(mean_loss, step)``.

    For every batch the tensors are cast/moved to ``device``, the model is
    called, ``losses['total_loss']`` is back-propagated, gradients are clipped
    to ``max_grad_norm`` and the optimizer takes one step.

    Args:
        model: Module called as ``model(ap_coords, mask_data, masked_pos,
            mask_label, ori_data, step, tf_sched, lam_sched, cfg_train)``;
            it must return a mapping containing ``'total_loss'``.
        device: Target device for the batch tensors.
        train_loader: Iterable yielding 5-tuples
            ``(ap_coords, mask_data, masked_pos, mask_label, ori_data)``.
        criterion: Unused here; kept for interface compatibility.
        optimizer: Optimizer stepped once per batch.
        tf_sched, lam_sched, cfg_train: Forwarded to the model unchanged.
        step: Global step counter before this epoch.
        max_grad_norm: Gradient-norm limit passed to ``clip_gradients``.

    Returns:
        Tuple ``(epoch_loss, step)``: the mean of ``total_loss`` over the
        batches and the step counter advanced by the number of batches.
        An empty ``train_loader`` yields ``(0.0, step)``.
    """
    model.train()
    epoch_loss = 0.0
    batch_num = 0  # counted explicitly so an empty loader cannot leave it undefined
    for ap_coords, mask_data, masked_pos, mask_label, ori_data in train_loader:
        # masked_pos is only moved to the device; its dtype is left untouched.
        ap_coords = ap_coords.float().to(device)
        mask_data = mask_data.float().to(device)
        masked_pos = masked_pos.to(device)
        mask_label = mask_label.float().to(device)
        ori_data = ori_data.float().to(device)

        # NOTE: the dynamic detach switch (stable early training, end-to-end
        # later) driven by cfg_train['warmup_detach_steps'] is disabled in
        # this loop; `step` and `cfg_train` are forwarded to the model as-is.

        optimizer.zero_grad()
        losses = model(ap_coords, mask_data, masked_pos, mask_label, ori_data, step, tf_sched, lam_sched, cfg_train)
        total_loss = losses['total_loss']

        total_loss.backward()
        clip_gradients(model, max_grad_norm)
        optimizer.step()

        epoch_loss += total_loss.item()
        step += 1
        batch_num += 1

    if batch_num == 0:
        # Nothing was processed: avoid dividing by zero.
        return 0.0, step

    return epoch_loss / batch_num, step


def train(model, device, train_loader, criterion, optimizer, dist_map_tensor, mask_map3a_tensor):
    """Run one supervised training epoch and return the mean batch loss.

    The per-batch loss is ``criterion(outputs, labels).mean()`` plus
    ``soft_penalty_loss(outputs, labels, dist_map_tensor, mask_map3a_tensor)``.

    Args:
        model: Module called as ``model(ap_coords, data)``.
        device: Target device for the batch tensors.
        train_loader: Iterable yielding ``(ap_coords, data, labels)`` triples.
        criterion: Callable ``criterion(outputs, labels)`` returning a tensor;
            its mean is used as the location loss.
        optimizer: Optimizer stepped once per batch.
        dist_map_tensor: Distance map forwarded to ``soft_penalty_loss``.
        mask_map3a_tensor: Validity mask forwarded to ``soft_penalty_loss``.

    Returns:
        Mean loss over the batches as a Python float; ``0.0`` when
        ``train_loader`` is empty.
    """
    model.train()
    epoch_loss = 0.0
    batch_num = 0  # counted explicitly so an empty loader cannot leave it undefined
    for ap_coords, data, labels in train_loader:
        ap_coords = ap_coords.float().to(device)
        data = data.float().to(device)
        labels = labels.float().to(device)

        optimizer.zero_grad()
        outputs = model(ap_coords, data)
        loss_location = criterion(outputs, labels).mean()
        loss_map = soft_penalty_loss(outputs, labels, dist_map_tensor, mask_map3a_tensor)
        loss = loss_map + loss_location

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        batch_num += 1

    if batch_num == 0:
        # Nothing was processed: avoid dividing by zero.
        return 0.0

    return epoch_loss / batch_num

def test(model, device, test_loader, criterion, dist_map_tensor, mask_map3a_tensor):
    """Evaluate the model on ``test_loader`` and return the mean batch loss.

    The model is put in eval mode and gradients are disabled. The per-batch
    loss is computed as in ``train``: ``criterion(outputs, labels).mean()``
    plus ``soft_penalty_loss(outputs, labels, dist_map_tensor,
    mask_map3a_tensor)``.

    Args:
        model: Module called as ``model(ap_coords, data)``.
        device: Target device for the batch tensors.
        test_loader: Iterable yielding ``(ap_coords, data, labels)`` triples.
        criterion: Callable ``criterion(outputs, labels)`` returning a tensor;
            its mean is used as the location loss.
        dist_map_tensor: Distance map forwarded to ``soft_penalty_loss``.
        mask_map3a_tensor: Validity mask forwarded to ``soft_penalty_loss``.

    Returns:
        Mean loss over the batches as a Python float; ``0.0`` when
        ``test_loader`` is empty.
    """
    model.eval()
    epoch_loss = 0.0
    batch_num = 0  # counted explicitly so an empty loader cannot leave it undefined
    with torch.no_grad():
        for ap_coords, data, labels in test_loader:
            ap_coords = ap_coords.float().to(device)
            data = data.float().to(device)
            labels = labels.float().to(device)

            outputs = model(ap_coords, data)
            loss_location = criterion(outputs, labels).mean()
            loss_map = soft_penalty_loss(outputs, labels, dist_map_tensor, mask_map3a_tensor)

            # Sum as tensors first (as train() does) so only one host
            # transfer is needed per batch.
            epoch_loss += (loss_map + loss_location).item()
            batch_num += 1

    if batch_num == 0:
        # Nothing was processed: avoid dividing by zero.
        return 0.0

    return epoch_loss / batch_num


def soft_penalty_loss(preds, targets, dist_map_tensor, mask_tensor,
                      penalty_factor=0.5, max_penalty=5.0, exp_factor=0.5,
                      distance_thresh=0.7, distance_penalty_weight=2.0):
    """Soft penalty: map-validity term plus an over-threshold distance term.

    Every predicted point receives two penalties:

    * Map penalty. The point is mapped to a cell of ``dist_map_tensor`` by
      flooring its coordinates. Points outside the map get ``max_penalty``;
      points on cells where ``mask_tensor == 255`` get ``0``; all other
      in-map points get ``exp(dist_map * exp_factor) - 1``.
    * Distance penalty. ``max(0, ||pred - target|| - distance_thresh)``,
      scaled by ``distance_penalty_weight``.

    The result is ``penalty_factor`` times the mean of the summed penalties
    over all points.

    Args:
        preds: Predicted coordinates, shape ``(..., 2)`` with ``(x, y)`` in
            the last dimension (originally documented as ``[B, 7, 11, 2]``
            in metres).
        targets: Ground-truth coordinates with the same number of points.
        dist_map_tensor: ``[H, W]`` distance map (documented as pixel units).
        mask_tensor: ``[H, W]`` mask; the value 255 marks a valid cell.
        penalty_factor: Global scale of the returned loss.
        max_penalty: Penalty assigned to points outside the map.
        exp_factor: Growth rate of the exponential penalty on invalid cells.
        distance_thresh: Distance below which no distance penalty applies.
        distance_penalty_weight: Weight of the distance penalty.

    Returns:
        Scalar tensor with the weighted mean penalty.

    Raises:
        ValueError: If the last dimension of ``preds`` is not 2.

    NOTE(review): coordinates are used directly as map indices although they
    are documented in metres and the map in pixels -- confirm the scale.
    NOTE(review): the exponential penalty is not capped by ``max_penalty``,
    so an in-map invalid point may be penalised more than an out-of-map one.
    NOTE(review): the map penalty is looked up through integer indices, so it
    contributes no gradient with respect to ``preds``; only the distance term
    does.
    """
    if preds.shape[-1] != 2:
        raise ValueError(
            f"preds must have shape (..., 2), got {tuple(preds.shape)}"
        )

    preds_flat = preds.reshape(-1, 2)       # [num_points, 2]
    targets_flat = targets.reshape(-1, 2)   # [num_points, 2]
    height, width = dist_map_tensor.shape

    # Floor instead of truncating: a plain integer cast rounds toward zero,
    # which would map coordinates in (-1, 0) to cell 0 and treat them as
    # lying inside the map.
    x_pix = torch.floor(preds_flat[:, 0]).long()
    y_pix = torch.floor(preds_flat[:, 1]).long()

    # Which points fall inside the map.
    in_bound = (x_pix >= 0) & (x_pix < width) & (y_pix >= 0) & (y_pix < height)

    # Clamp so every point can be looked up; values read for out-of-map
    # points are overwritten with max_penalty below.
    x_safe = torch.clamp(x_pix, 0, width - 1)
    y_safe = torch.clamp(y_pix, 0, height - 1)

    valid_mask = mask_tensor[y_safe, x_safe] == 255
    distances = dist_map_tensor[y_safe, x_safe].float()

    # Invalid cells: exponential penalty; valid cells: no penalty;
    # out-of-map points: max_penalty.
    map_penalty = (torch.exp(distances * exp_factor) - 1).to(preds_flat.dtype)
    map_penalty = map_penalty.masked_fill(valid_mask, 0.0)
    map_penalty = map_penalty.masked_fill(~in_bound, max_penalty)

    # Distance penalty: only the part of the error exceeding the threshold.
    euclidean_dist = torch.norm(preds_flat - targets_flat, dim=1)  # [num_points]
    dist_penalty = torch.relu(euclidean_dist - distance_thresh)

    total_penalty = map_penalty + distance_penalty_weight * dist_penalty

    return penalty_factor * total_penalty.mean()