import h5py
import numpy as np
import scipy.io as sio
import cv2 as sc
import glob
from tqdm import tqdm

# Parameters
height = 256
width = 256
channels = 3
num_classes = 4  # expected number of distinct mask values (classes)

# Prepare dataset
data_root_dir = "dataset/OCT3/"
Tr_list = glob.glob(data_root_dir + "images/" + '*.jpg')

# Size everything from the images actually found on disk instead of a
# hard-coded count: a mismatch between a fixed count and the real number of
# files either overflows the arrays while loading or breaks the shuffle.
# (The previous name for this count, `all`, shadowed the builtin.)
total_images = len(Tr_list)
if total_images == 0:
    raise FileNotFoundError(
        f"No .jpg images found under {data_root_dir}images/"
    )

# 70% train, 10% validation, and whatever remains goes to test.
train_number = int(total_images * 0.7)
val_number = int(total_images * 0.1)
test_number = total_images - train_number - val_number
print(f"Train/Val/Test split: {train_number}/{val_number}/{test_number}")

# Initialize arrays (np.zeros defaults to float64): one slot per image and
# one per mask.
Data_train = np.zeros([total_images, height, width, channels])
Label_train = np.zeros([total_images, height, width])

print('Reading', len(Tr_list), "images")
Tr_list_shuffled = []

# Collects every distinct pixel value seen across all masks
all_unique_values = set()

# Load every image together with its mask into the preallocated arrays.
for idx, image_path in enumerate(tqdm(Tr_list)):
    # Read image. cv2.imread signals failure by returning None rather than
    # raising, so check explicitly to report which file is at fault.
    img = sc.imread(image_path)
    if img is None:
        raise OSError(f"Could not read image (missing or unreadable): {image_path}")
    img = sc.resize(img, (width, height), interpolation=sc.INTER_LINEAR)
    Data_train[idx, :, :, :] = img

    # Read mask. The mask file name is taken from the 8 characters preceding
    # the 4-character extension of the image path.
    # NOTE(review): this assumes image stems are exactly 8 characters long;
    # confirm against the dataset naming convention.
    mask_name = image_path[-12:-4]
    mask_path = data_root_dir + "masks/" + mask_name + '.png'
    mask = sc.imread(mask_path, sc.IMREAD_GRAYSCALE)
    if mask is None:
        raise OSError(
            f"Could not read mask (missing or unreadable): {mask_path} "
            f"(for image {image_path})"
        )
    # Nearest-neighbour interpolation keeps the original class values
    # instead of blending them into new intermediate values.
    mask = sc.resize(mask, (width, height), interpolation=sc.INTER_NEAREST)

    # Record the distinct values of this mask as plain Python ints so the
    # summary printed later is readable regardless of the NumPy version.
    all_unique_values.update(np.unique(mask).tolist())

    Label_train[idx, :, :] = mask
    Tr_list_shuffled.append(image_path)

# Report which pixel values were found across all masks
print("\nMask value statistics:")
print(f"All unique values found in masks: {sorted(all_unique_values)}")
print(f"Actual number of classes in masks: {len(all_unique_values)}")

# Warn when the number of distinct mask values differs from the preset
# number of classes
class_count_matches = len(all_unique_values) == num_classes
if not class_count_matches:
    print(f"Warning: Number of classes in masks ({len(all_unique_values)}) differs from preset num_classes ({num_classes})")

# Shuffle dataset: draw one random permutation and apply it to the path
# list, the images and the masks so that all three stay aligned.
shuffle_order = np.random.permutation(len(Tr_list_shuffled))
Tr_list_shuffled[:] = [Tr_list_shuffled[pos] for pos in shuffle_order]
# Indexing with the permutation builds a reordered copy, which is then
# written back into the existing arrays in place.
Data_train[:] = Data_train[shuffle_order, ...]
Label_train[:] = Label_train[shuffle_order, ...]

# # Helper function: Find the top n most frequent colors in a mask
# def get_top_n_colors(mask, num_classes):
#     # Get the unique classes in the mask and their frequency
#     unique_classes, counts = np.unique(mask, return_counts=True)
#     # Sort by frequency and get the top n
#     top_n_classes = unique_classes[np.argsort(counts)[-num_classes:]]
#     return top_n_classes

# # Euclidean distance function in color space (for RGB)
# def color_distance(c1, c2):
#     return np.sqrt((c1 - c2) ** 2)

# # Process the labels to only keep the top n most frequent colors
# def process_mask(mask, num_classes):
#     top_n_colors = get_top_n_colors(mask, num_classes)
    
#     # For each pixel, if it's not one of the top n, replace it with the closest top n color
#     for i in range(mask.shape[0]):
#         for j in range(mask.shape[1]):
#             if mask[i, j] not in top_n_colors:
#                 # Find the closest top n color based on the Euclidean distance in the color space
#                 closest_color = min(top_n_colors, key=lambda x: color_distance(x, mask[i, j]))
#                 mask[i, j] = closest_color

#     # Re-map the colors to new categories: 0 to num_classes-1
#     unique_classes = np.unique(mask)
#     class_mapping = {old_class: new_class for new_class, old_class in enumerate(unique_classes)}
    
#     # Map all pixel values to their new class values
#     for i in range(mask.shape[0]):
#         for j in range(mask.shape[1]):
#             mask[i, j] = class_mapping[mask[i, j]]
            
#     return mask

# if num_classes >= 2:
#     print('Processing')
#     # Apply the mask processing to the training, validation, and test sets
#     for i in tqdm(range(all)):
#         Label_train[i, :, :] = process_mask(Label_train[i, :, :], num_classes=num_classes)

print('Saving dataset')

# Boundaries of the three consecutive slices: train, validation, test.
val_start = train_number
test_start = train_number + val_number
test_stop = test_start + test_number

# Make the training, validation and test sets (slices of the full arrays)
Train_img = Data_train[:val_start]
Validation_img = Data_train[val_start:test_start]
Test_img = Data_train[test_start:test_stop]

Train_mask = Label_train[:val_start]
Validation_mask = Label_train[val_start:test_start]
Test_mask = Label_train[test_start:test_stop]

# Save the dataset: np.save appends ".npy" to each name. Images are written
# first, then masks, each in train / test / val order.
outputs = (
    ('data_train', Train_img),
    ('data_test', Test_img),
    ('data_val', Validation_img),
    ('mask_train', Train_mask),
    ('mask_test', Test_mask),
    ('mask_val', Validation_mask),
)
for file_stem, subset in outputs:
    np.save(data_root_dir + file_stem, subset)

print('Saving your dataset finished')

