import os
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
import torch
from torch.cuda.amp import autocast as autocast
from sklearn.metrics import confusion_matrix
from utils import save_imgs
import torchvision


def train_one_epoch(train_loader,
                    model,
                    criterion,
                    optimizer,
                    scheduler,
                    epoch,
                    logger,
                    config,
                    scaler=None):
    """Train ``model`` for one epoch, then step the LR scheduler once.

    Args:
        train_loader: iterable yielding ``(images, targets)`` batches.
        model: network to optimize; switched to train mode here.
        criterion: loss callable, invoked as ``criterion(model(images), targets)``.
        optimizer: torch optimizer, stepped once per batch.
        scheduler: LR scheduler; ``scheduler.step()`` is called once after the epoch.
        epoch: epoch index, used only in log messages.
        logger: receives a progress line every ``config.print_interval`` iterations.
        config: object providing ``amp`` (bool) and ``print_interval`` (int).
        scaler: gradient scaler exposing ``scale``/``step``/``update``
            (e.g. ``torch.cuda.amp.GradScaler``); required when ``config.amp`` is true.

    Raises:
        ValueError: if ``config.amp`` is enabled but ``scaler`` is None.
    """
    # Fail before doing any work instead of an AttributeError on the first batch.
    if config.amp and scaler is None:
        raise ValueError('config.amp is enabled but no GradScaler was passed as `scaler`')

    model.train()
    loss_list = []

    for step, (images, targets) in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        images = images.cuda(non_blocking=True).float()
        # Trailing singleton axis is dropped so targets match the model output layout.
        targets = targets.cuda(non_blocking=True).float().squeeze(-1)
        if config.amp:
            with autocast():
                out = model(images)
                loss = criterion(out, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(images)
            loss = criterion(out, targets)
            loss.backward()
            optimizer.step()

        loss_list.append(loss.item())

        if step % config.print_interval == 0:
            # Read the lr straight from param_groups: optimizer.state_dict() would
            # repack every group and all per-parameter state just to fetch one value.
            now_lr = optimizer.param_groups[0]['lr']
            log_info = f'train: epoch {epoch}, iter:{step}, loss: {np.mean(loss_list):.4f}, lr: {now_lr}'
            logger.info(log_info)
    scheduler.step()


def val_one_epoch(test_loader,
                    model,
                    criterion,
                    epoch,
                    logger,
                    config,
                    info):
    """Evaluate ``model`` on ``test_loader`` and log the loss and segmentation metrics.

    Args:
        test_loader: iterable yielding ``(img, msk)`` batches.
        model: network to evaluate; switched to eval mode here.
        criterion: loss callable applied to the raw model output and the mask.
        epoch: not used in this function; kept for call-site compatibility.
        logger: receives one summary line.
        config: object providing ``num_classes``.
        info: not used in this function; kept for call-site compatibility.

    Returns:
        ``(mean_loss, metrics)`` where ``metrics`` is the dict returned by
        ``calculate_segmentation_metrics``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    preds, gts, loss_list = [], [], []
    with torch.no_grad():
        for img, msk in tqdm(test_loader):
            img = img.cuda(non_blocking=True).float()
            msk = msk.cuda(non_blocking=True).float().squeeze(-1)
            out = model(img)

            # The loss sees the raw output, which may be a tuple of outputs;
            # only the first element is used for the predicted labels.
            loss = criterion(out, msk)
            loss_list.append(loss.item())
            gts.append(msk.squeeze(1).cpu().numpy().astype(np.int32))

            if isinstance(out, tuple):
                out = out[0]
            # NOTE(review): after argmax(1) the class axis is gone, so squeeze(1)
            # only removes a size-1 axis that follows it — confirm this is intended.
            out = out.argmax(1).squeeze(1).cpu().numpy().astype(np.int32)
            preds.append(out)

    if not preds:
        raise ValueError('test_loader yielded no batches; cannot compute metrics')

    # Merge all batches along the first axis before computing metrics.
    gts = np.concatenate(gts, axis=0)
    preds = np.concatenate(preds, axis=0)
    metrics = calculate_segmentation_metrics(preds, gts, config.num_classes)

    mean_loss = np.mean(loss_list)
    log_info = 'test' + f' of best model, loss: {mean_loss:.4f}, metrics: {metrics}'
    logger.info(log_info)

    return mean_loss, metrics

def predict_one_epoch(test_loader,
                    model,
                    criterion,
                    logger,
                    config):
    """Run inference on ``test_loader``, print a metrics summary and log it.

    Also creates an ``output_images`` directory; nothing in this function
    writes to it.

    Args:
        test_loader: iterable yielding ``(img, msk)`` batches.
        model: network to evaluate; switched to eval mode here.
        criterion: loss callable applied to the raw model output and the mask.
        logger: receives one summary line.
        config: object providing ``num_classes``.

    Returns:
        Mean of the per-batch losses.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()

    preds, gts, loss_list = [], [], []
    # Folder intended for saved images (not written to here).
    output_folder = 'output_images'
    os.makedirs(output_folder, exist_ok=True)
    with torch.no_grad():
        for img, msk in tqdm(test_loader):
            img = img.cuda(non_blocking=True).float()
            # squeeze(-1) matches train_one_epoch / val_one_epoch, so the criterion
            # and the metrics see the same target layout as during training.
            msk = msk.cuda(non_blocking=True).float().squeeze(-1)
            out = model(img)

            loss = criterion(out, msk)
            loss_list.append(loss.item())
            gts.append(msk.squeeze(1).cpu().numpy().astype(np.int32))
            # Output may be a tuple; only its first element is used for predictions.
            if isinstance(out, tuple):
                out = out[0]
            out = out.argmax(1).squeeze(1).cpu().numpy().astype(np.int32)
            preds.append(out)

    if not preds:
        raise ValueError('test_loader yielded no batches; cannot compute metrics')

    # Merge all batches along the first axis before computing metrics.
    gts = np.concatenate(gts, axis=0)
    preds = np.concatenate(preds, axis=0)
    metrics = calculate_segmentation_metrics(preds, gts, config.num_classes)

    log_info = 'test' + f' of best model, loss: {np.mean(loss_list):.4f}, metrics: {metrics}'
    # Per-class lists skip index 0 (background).
    print(  ' miou:', metrics["mean_iou"], 
            ' iou:', [m["iou"] for m in metrics["class_metrics"][1:]], '\n',
            ' mean_dice:', metrics["mean_dice"], 
            ' dice:', [m["dice"] for m in metrics["class_metrics"][1:]], '\n',
            ' mean_acc:', metrics["mean_pixel_accuracy"], 
            ' acc:', [m["pixel_accuracy"] for m in metrics["class_metrics"][1:]], '\n',
            ' mean_assd:', metrics["mean_assd"], '\n',
            ' best dice:', metrics["mean_dice"])
    logger.info(log_info)

    return np.mean(loss_list)

import numpy as np
from scipy.spatial.distance import directed_hausdorff
from skimage.morphology import binary_erosion, binary_dilation
from scipy import spatial

def calculate_segmentation_metrics(preds, gts, num_classes, ignore_background=True):
    """Compute pixel accuracy, Dice, IoU and ASSD for stacked label maps.

    Hausdorff distance is not computed.

    Args:
        preds: integer label map of predictions. Axis 0 is iterated as slices
            when computing surface distances (e.g. ``(N, H, W)``).
        gts: ground truth, either one-hot with the class axis at position 1
            (``preds.ndim + 1`` dims, reduced with ``argmax``) or a label map
            with the same shape as ``preds``.
        num_classes: classes ``0 .. num_classes - 1`` are evaluated.
        ignore_background: if true, class 0 is excluded from the ``mean_*`` entries.

    Returns:
        dict with a global ``pixel_accuracy``, a ``class_metrics`` list (one dict
        per class: ``class``, ``pixel_accuracy``, ``TP``, ``FP``, ``FN``, ``TN``,
        ``dice``, ``iou``, ``assd``) and ``mean_dice``, ``mean_iou``,
        ``mean_assd``, ``mean_pixel_accuracy`` (NaN-aware means over the
        selected classes).

    Raises:
        AssertionError: if predictions and (reduced) ground truth differ in shape.
    """
    preds = np.asarray(preds)
    gts = np.asarray(gts)
    # One-hot ground truth carries an extra class axis at position 1; reduce it to
    # a label map. A label map with the same rank as preds is used as-is.
    if gts.ndim == preds.ndim + 1:
        gts = np.argmax(gts, axis=1)
    assert preds.shape == gts.shape, "预测结果与真实标签形状必须一致"

    total = gts.size
    metrics = {
        "pixel_accuracy": np.sum(preds == gts) / total if total > 0 else 0,
        "class_metrics": [],
    }

    def _mean_surface_distance(mask_a, mask_b):
        """Average symmetric surface distance, averaged over slices along axis 0.

        Slices where either mask, or either mask's boundary, is empty are
        skipped because the distance is undefined there. Returns NaN if every
        slice was skipped.
        """
        per_slice = []
        for slice_a, slice_b in zip(mask_a, mask_b):
            if not slice_a.any() or not slice_b.any():
                continue
            # Boundary = foreground pixels removed by a one-step erosion.
            coords_a = np.argwhere(slice_a & ~binary_erosion(slice_a))
            coords_b = np.argwhere(slice_b & ~binary_erosion(slice_b))
            # A region filling the whole slice can have no boundary after erosion;
            # querying an empty point set would yield NaN/inf and poison the mean.
            if len(coords_a) == 0 or len(coords_b) == 0:
                continue
            dist_a_to_b = spatial.cKDTree(coords_b).query(coords_a)[0]
            dist_b_to_a = spatial.cKDTree(coords_a).query(coords_b)[0]
            per_slice.append((np.mean(dist_a_to_b) + np.mean(dist_b_to_a)) / 2)
        return np.mean(per_slice) if per_slice else float('nan')

    for cls in range(num_classes):
        # Build each class mask once and reuse it for every count.
        pred_mask = preds == cls
        gt_mask = gts == cls
        TP = np.sum(pred_mask & gt_mask)
        FP = np.sum(pred_mask & ~gt_mask)
        FN = np.sum(~pred_mask & gt_mask)
        TN = total - TP - FP - FN

        metrics["class_metrics"].append({
            "class": cls,
            # Fraction of pixels where pred and gt agree on membership in this class.
            "pixel_accuracy": (TP + TN) / total if total > 0 else 0,
            "TP": int(TP), "FP": int(FP), "FN": int(FN), "TN": int(TN),
            # 1e-10 avoids division by zero when the class is absent from both.
            "dice": 2 * TP / (2 * TP + FP + FN + 1e-10),
            "iou": TP / (TP + FP + FN + 1e-10),
            "assd": _mean_surface_distance(pred_mask, gt_mask),
        })

    # Class metrics are stored in class order, so skipping background is a slice.
    if ignore_background:
        selected = metrics["class_metrics"][1:]
    else:
        selected = metrics["class_metrics"]

    if selected:
        metrics["mean_dice"] = np.nanmean([m["dice"] for m in selected])
        metrics["mean_iou"] = np.nanmean([m["iou"] for m in selected])
        metrics["mean_assd"] = np.nanmean([m["assd"] for m in selected])
        metrics["mean_pixel_accuracy"] = np.nanmean([m["pixel_accuracy"] for m in selected])
    else:
        metrics.update({
            "mean_dice": 0,
            "mean_iou": 0,
            "mean_assd": float('nan'),
            "mean_pixel_accuracy": 0,
        })

    return metrics