import os
import sys
import logging
import argparse
import random
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader, Subset
import torch.backends.cudnn as cudnn
from utils import test_single_volume
from importlib import import_module
from segment_anything import sam_model_registry
from load_LIDC_data import LIDC_IDRI, RandomGenerator
from utils import init_weights, init_weights_orthogonal_normal, l2_regularisation
import torch.nn.functional as F
from torch.distributions import Normal, Independent
from tqdm import tqdm
from scipy.ndimage import zoom
from einops import repeat
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from tensorboardX import SummaryWriter
from torch.utils.data.sampler import SubsetRandomSampler
from utils import generalized_energy_distance_iou, dice_max_cal1, dice_max_cal2, dice_avg_cal, hm_iou_cal
from ambiguous_sam_v2 import Ambiguous_Sam
import time  # 引入 time 模块

def show_box(box, ax, color):
    """Draw *box* as an unfilled rectangle outline on the matplotlib axes *ax*.

    *box* is indexed as (x_min, y_min, x_max, y_max): the rectangle is anchored
    at (box[0], box[1]) and its width/height are the differences to
    (box[2], box[3]). The outline is drawn in *color* with line width 2.
    """
    x_min, y_min = box[0], box[1]
    x_max, y_max = box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor=color,
        facecolor='none',
        lw=2,
    )
    ax.add_patch(outline)


# Evaluate a trained checkpoint on the held-out test split.

# NOTE(review): `writer` is never used for logging anywhere in this script.
writer = SummaryWriter('tf-logs/train_onestage')
# Weights used at evaluation time to combine the model's output masks
# (multiplied into outputs['masks'] and summed over dim 1).
weights = torch.load('/data/cxli/yuzhi/Ambiguous_SAM/ckpoint/onestage_weight/allsamples_100epoch_1e-4_80epoch.pt')

# Running totals of the evaluation metrics; summed over test images and
# divided by the test-set size at the end.
iou_score = 0
ged_score = 0
dice_random_score = 0
dice_max1_score = 0
dice_max2_score = 0
dice_avg_score = 0
hm_iou_score = 0

net = Ambiguous_Sam(lora_ckpt='/data/cxli/yuzhi/Ambiguous_SAM/ckpoint/weight_ckpoint/allsamples_50epoch.pth')
# Map the checkpoint to CPU so loading does not depend on the GPU it was saved
# from; load_state_dict copies the values into the model's existing parameters,
# so the model's own device placement is unchanged.
loaded_state_dict = torch.load("/data/cxli/yuzhi/Ambiguous_SAM/ckpoint/onestage/allsamples_100epoch_1e-4_80epoch.pth", map_location='cpu')
net.load_state_dict(loaded_state_dict)

# LIDC-IDRI dataset; RandomGenerator is configured with a 128x128 output size.
db = LIDC_IDRI(dataset_location='/data/cxli/yuzhi/Ambiguous_SAM/LIDC/data/', transform=transforms.Compose([
    RandomGenerator(output_size=[128, 128])
]), threshold=True)
dataset_size = len(db)

# Indices of every sample, in dataset order (the split is not shuffled).
indices = list(range(dataset_size))

# Split points: the first 60% is training, the next 20% validation and the
# remaining 20% test.
train_split = int(np.floor(0.6 * dataset_size))
validation_split = int(np.floor(0.8 * dataset_size))

# Contiguous index ranges for the three splits; the training split spans the
# full 60% computed above.
train_indices = indices[:train_split]
validation_indices = indices[train_split:validation_split]
test_indices = indices[validation_split:]

train_dataset = Subset(db, train_indices)
validation_dataset = Subset(db, validation_indices)
test_dataset = Subset(db, test_indices)
train_loader = DataLoader(train_dataset, batch_size=5, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=5, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)

# Print the split sizes as a sanity check on the partition.
print(f"Total dataset size: {dataset_size}")
print(f"Training set size: {len(train_indices)}")
print(f"Validation set size: {len(validation_indices)}")
print(f"Test set size: {len(test_indices)}")

print("Number of training/test patches:", (len(train_indices), len(test_indices)))



device = 'cuda:3'

# Number of stochastic predictions drawn per test image.
num_samples = 4

# Inference timing: wall-clock time of the forward passes, averaged per image.
total_inference_time = 0
num_inferences = 0  # number of test images timed (not batches)

# Put the model in inference mode on the evaluation device (.to() is a no-op
# if it already lives there).
# NOTE(review): assumes the diversity of the num_samples predictions comes from
# the model's own sampling (the train=False path), not from train-mode dropout;
# confirm the samples still differ after eval().
net.to(device)
net.eval()

# Fusion weights for the output masks, moved to the evaluation device once so
# a tensor restored onto another GPU cannot cause a device mismatch.
mask_weights = weights.to(device).unsqueeze(-1)

# Draw num_samples predictions per test image, then score each image against
# its annotations. No gradients are needed, so autograd is disabled to avoid
# keeping every forward pass's graph alive in GPU memory.
with torch.no_grad():
    for i_batch, sampled_batch in enumerate(test_loader):
        print('第', i_batch, '张')
        image_batch = sampled_batch['image'].to(device)
        label_batch = sampled_batch['label'].to(device)
        image_batch_oc = sampled_batch['image_oc'].to(device)
        box1024_batch = sampled_batch['box_1024'].to(device)
        boxshift_batch = sampled_batch['box_shift'].to(device)
        # Annotations stay on the CPU, where the binarized predictions are scored.
        label_four_batch = sampled_batch['label_four']

        assert image_batch.max() <= 3, f'image_batch max: {image_batch.max()}'
        batch_size = image_batch.shape[0]

        # pred_list[j] collects the fused logits of every sample for image j.
        pred_list = [[] for _ in range(batch_size)]

        start_time = time.perf_counter()
        for _ in range(num_samples):
            outputs = net(image_batch, image_batch_oc, box1024_batch, boxshift_batch, label_batch, train=False)
            # Weighted sum of the output masks over dim 1 -> one logit map per image.
            fused_logits = (outputs['masks'].to(device) * mask_weights).sum(1, keepdim=True)
            for j in range(batch_size):
                pred_list[j].append(fused_logits[j])
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize(device)
        total_inference_time += time.perf_counter() - start_time
        num_inferences += batch_size

        for index, preds in enumerate(pred_list):
            labels = label_four_batch[index]
            # Concatenate the samples along dim 0 and binarize the logits at 0
            # into an int64 tensor of {0, 1}.
            pred_eval = (torch.cat(preds, 0) > 0).cpu().long()

            iou_score_iter, ged_score_iter = generalized_energy_distance_iou(pred_eval, labels)
            hm_iou_score += hm_iou_cal(pred_eval, labels)
            dice_max1_score += dice_max_cal1(pred_eval, labels)
            dice_max2_score += dice_max_cal2(pred_eval, labels)
            # NOTE(review): unlike the other metrics this receives the raw
            # per-sample logits, not the binarized masks; confirm that is what
            # dice_avg_cal expects.
            dice_avg_score += dice_avg_cal(preds, labels)
            iou_score += iou_score_iter
            ged_score += ged_score_iter

# Every test image is scored exactly once, so average over the test-set size.
num_test = len(test_indices)
ged = ged_score / num_test
iou = iou_score / num_test
dice_max1 = dice_max1_score / num_test
dice_max2 = dice_max2_score / num_test
dice_avg = dice_avg_score / num_test
hm_iou = hm_iou_score / num_test

average_inference_time = total_inference_time / num_inferences
print("iou_score: ", iou, "ged_score: ", ged, "dice_max_score1", dice_max1, "dice_max_score2", dice_max2, "dice_avg_score", dice_avg, "hm_iou", hm_iou)
print(f"Average Inference Time per Image: {average_inference_time:.4f} seconds")
