import numpy as np
import random
import os
from keras.models import load_model
from skimage import io
from skimage.transform import resize
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import f1_score, jaccard_score
from PIL import Image

import warnings


# Suppress every warning category for the whole process (presumably to hide
# noisy library warnings during batch inference). Note this also hides
# warnings that may point at real problems.
warnings.simplefilter("ignore")

def batch_data_test(input_name, n, input_size_1=256, input_size_2=256,
                    data_dir="~/tai/sdcopy/tai_data/sem/2000x"):
    """Load one test image and prepare it as a single-sample model batch.

    Args:
        input_name: Sequence of image file names located in ``data_dir``.
        n: 1-based position of the image to load within ``input_name``.
        input_size_1: Height the image is resized to before prediction.
        input_size_2: Width the image is resized to before prediction.
        data_dir: Directory holding the images; a leading ``~`` is expanded.

    Returns:
        ``(batch_input, original_info)``: ``batch_input`` is a float array of
        shape ``(1, input_size_1, input_size_2, 3)`` holding the resized
        pixel values divided by 255; ``original_info`` is
        ``[(image_path, (height, width))]`` describing the source image.

    Raises:
        IndexError: if ``n`` is outside ``1..len(input_name)``.
    """
    # Reject out-of-range positions explicitly: n == 0 would otherwise index
    # input_name[-1] and silently load the last image instead.
    if not 1 <= n <= len(input_name):
        raise IndexError(
            f"n={n} is out of range for {len(input_name)} input images (1-based)"
        )

    # "~" is a shell convention; Python's file APIs do not resolve it, so
    # expand it explicitly and join with a proper path separator.
    img_path = os.path.join(os.path.expanduser(data_dir), input_name[n - 1])
    img = io.imread(img_path).astype("float")
    original_shape = img.shape[:2]

    # Normalise the channel layout to RGB before resizing: a 2-D grayscale
    # image is replicated into 3 channels, and an alpha channel is dropped so
    # resize does not blend it into the colour channels.
    if img.ndim == 2:
        img = np.stack([img, img, img], axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 4:
        img = img[..., :3]

    img = resize(img, [input_size_1, input_size_2, 3])
    # Add the leading batch axis expected by model.predict and scale by 255.
    batch_input = np.reshape(img, (1, input_size_1, input_size_2, 3)) / 255
    original_info = [(img_path, original_shape)]

    return batch_input, original_info


def save_predictions(pred_Y, original_info, output_dir='~/tai/sdcopy/unet/unet_sem/2000x'):
    """Save each prediction as an 8-bit image with its source image's size.

    Args:
        pred_Y: Model output indexed by sample; ``pred_Y[i]`` corresponds to
            ``original_info[i]``. Values are assumed to lie in [0, 1]
            (e.g. a sigmoid output — confirm against the model); anything
            outside that range is clipped.
        original_info: List of ``(image_path, (height, width))`` tuples, the
            second element of ``batch_data_test``'s return value.
        output_dir: Destination directory; a leading ``~`` is expanded and
            the directory is created if it does not exist.

    Each prediction is written under the source image's base name, so the
    output format follows the source file extension (a ``.jpg`` source will
    typically yield a lossily compressed mask).
    """
    # "~" is not resolved by os.makedirs; without this a literal "~"
    # directory would be created under the current working directory.
    output_dir = os.path.expanduser(output_dir)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    for i, (img_path, original_shape) in enumerate(original_info):
        img_filename = os.path.basename(img_path)
        pred = pred_Y[i]

        # Scale the prediction back to the source image's height and width
        # (skimage keeps any trailing channel axis not named in the shape).
        pred_resized = resize(pred, original_shape, anti_aliasing=True)
        # Drop a singleton channel axis so the result is a plain 2-D
        # grayscale image that image writers accept.
        if pred_resized.ndim == 3 and pred_resized.shape[-1] == 1:
            pred_resized = pred_resized[..., 0]

        # Clip before scaling: values outside [0, 1] would otherwise wrap
        # around when cast to uint8.
        pred_uint8 = (np.clip(pred_resized, 0.0, 1.0) * 255).astype(np.uint8)

        output_path = os.path.join(output_dir, img_filename)
        io.imsave(output_path, pred_uint8)
        print(f"预测结果已保存到: {output_path}")

# Load the trained model. "~" is expanded explicitly because Python's file
# APIs (and, presumably, the HDF5 loader) do not resolve it.
model = load_model(os.path.expanduser('~/tai/sdcopy/unet/dengzhou.h5'))

# Collect the test images. os.listdir does not expand "~", so expand first.
# The listing is sorted for a deterministic processing order, and hidden
# files / subdirectories are skipped because they cannot be read as images.
test_dir = os.path.expanduser("~/tai/sdcopy/tai_data/sem/2000x")
test_name = sorted(
    name for name in os.listdir(test_dir)
    if not name.startswith(".") and os.path.isfile(os.path.join(test_dir, name))
)
n_test = len(test_name)

for i in range(1, n_test + 1):
    # batch_data_test takes a 1-based index into test_name.
    test_X, original_info = batch_data_test(test_name, i)

    # Run inference on the single-image batch and write the result to disk.
    pred_Y = model.predict(test_X)
    save_predictions(pred_Y, original_info)