import numpy
import torch.utils.data as Data
import pandas as pd
import numpy as np
from PIL import Image
import csv
import random

def get_map(file_name_list):
    """Map each distinct label to a contiguous class index.

    :param file_name_list: 2-D array-like of index rows (an ndarray or a list of
        rows); column 1 holds the label.
    :return: dict ``{label: index}`` with indices assigned in ``np.unique``
        (sorted) order. Empty input yields ``{}``.
    """
    rows = np.asarray(file_name_list)
    # An empty list converts to a 1-D ``np.array([])`` that has no column 1 to
    # index, so treat any empty input as "no labels" instead of raising.
    if rows.size == 0:
        return {}
    labels = np.unique(rows[:, 1])
    return {label: idx for idx, label in enumerate(labels)}

# Class-balanced initial sampling (labmate Ma's version) — disabled, kept for reference only.
# class TrainDataset(Data.Dataset):
#     """
#     数据集
#     """
#     def __init__(self,select_ratio,base_csv_path, file_name_list, unlabeled_list, file_path, file_name,num_classes, select_data, transform=None):
#         """

#         :param file_path:数据集位置
#         :param file_name:样本索引目录
#         :param target_map:数据集种类标签和实际标签间的映射
#         :param transform:图像变换与增强

#         """
    
#         self.file_path = file_path  # '/root/disk/Zz/Dataset_rcs_structed/'
#         # if select_method:
#         #     self.file_name_list = select_method.select(pd.read_csv(file_name, header=None))
#         # else:
#         if select_data is None:

#             files = pd.read_csv(file_name, header=None).values
#             files = files[files[:, 1].argsort()] # ocean add line
#             number = np.array([0]*num_classes)
#             k=0
#             for r in range(len(files)):
#                 number[k] += 1
#                 if r+1<len(files):
#                     if files[r][1] != files[r+1][1] :
#                         k += 1
#                         assert k < len(number)
#             number = number*select_ratio
#             flag = np.zeros(num_classes)
#             self.file_name_list = []
#             self.unlabeled_list = []
#             label_list = np.unique(numpy.array(files)[:, 1])
#             label_list = np.array(label_list)
#             index = [k for k in range(len(files))]
#             # random.shuffle(index)
#             for id in index:
#                 files[id][1] = str(files[id][1])
#                 for i in range(num_classes):
#                     if files[id][1] == label_list[i]:# int(files[id][1]) if cifar100
#                         if flag[i] < number[i]:
#                             self.file_name_list.append(files[id])
#                             flag[i] += 1
#                         else:
#                             self.unlabeled_list.append(files[id])
#                         break

#             ft = open(base_csv_path, 'w', newline='')
#             ft_csv = csv.writer(ft)
#             ft_csv.writerows(self.file_name_list)
#             self.unlabeled_list = np.array(self.unlabeled_list)
#             self.file_name_list = np.array(self.file_name_list)
#             # print( self.file_name_list)

#         else:
#             if select_data != []:
#                 # print(file_name_list)
#                 self.file_name_list  = np.concatenate((np.array(file_name_list), np.array(select_data)), axis=0)

#             else:
#                 self.file_name_list = file_name_list
#             new_pool = []
#             # print(len(select_data))
#             # print(len(unlabeled_list))
#             o = 0
#             for i in unlabeled_list:
#                 flag = True
#                 if o != len(select_data):
#                     for j in select_data:
#                         if i[0] == j[0] and i[1] == j[1]:
#                             flag = False
#                             o += 1
#                             break
#                 if flag:
#                     new_pool.append(i)
#             self.unlabeled_list = np.array(new_pool)
#             # print(len(self.pool_list))
#         self.transform = transform
#         # print(self.file_name_list)
#         label_map = get_map(self.file_name_list)
#         self.target_map = label_map

#     def __len__(self):
#         return len(self.file_name_list)

#     def __getitem__(self, index):
#         image_path = self.file_path + self.file_name_list[index, 0]
#         img_name = self.file_name_list[index, 0]
#         img_lable = self.file_name_list[index, 1]
#         # img = Image.open('D:/Projects/Datasets/example/di.jp')g
#         # from PIL import ImageFile
#         # ImageFile.LOAD_TRUNCATED_IMAGES = True

#         img = Image.open(image_path).convert('RGB')
#         img = self.transform(img)
#         # img = np.array(img).astype(np.float32)
#         target = self.target_map[self.file_name_list[index, 1]]

#         return img, target, img_name, img_lable

# Active version: fully random (not class-balanced) initial labeled/unlabeled split.
class TrainDataset(Data.Dataset):
    """Labeled training dataset with a companion unlabeled pool.

    When ``select_data`` is None (presumably the first active-learning round),
    the index CSV ``file_name`` is read. A random ``select_ratio`` fraction of
    its rows becomes the labeled set and is written to ``base_csv_path``. The
    remaining rows are kept in ``self.unlabeled_list``.

    Otherwise, ``select_data`` rows are appended to ``file_name_list`` and
    removed from ``unlabeled_list``.

    Rows are ``(image_path_relative_to_file_path, label, ...)``.

    NOTE(review): labels are used exactly as read, so integer labels stay
    integers. ``TestDataset`` converts its labels to ``str``, and ``np.unique``
    orders ints numerically but strings lexicographically. With numeric labels,
    the two ``target_map``s may therefore assign different indices — confirm
    with the caller.
    """

    def __init__(self, select_ratio, base_csv_path, file_name_list, unlabeled_list, file_path, file_name, num_classes, select_data, transform=None):
        """
        :param select_ratio: fraction of CSV rows randomly labeled when ``select_data`` is None
        :param base_csv_path: path the initially labeled rows are written to
        :param file_name_list: current labeled rows (used when ``select_data`` is not None)
        :param unlabeled_list: current unlabeled rows (used when ``select_data`` is not None)
        :param file_path: dataset root prefix, concatenated directly with each image path
        :param file_name: index CSV (no header) read when ``select_data`` is None
        :param num_classes: not used by this random sampler; kept for interface compatibility
        :param select_data: None for the initial split, otherwise the rows newly moved
            from the unlabeled pool into the labeled set
        :param transform: optional callable applied to each PIL image
        """
        self.file_path = file_path  # e.g. '/root/disk/Zz/Dataset_rcs_structed/'
        if select_data is None:
            files = pd.read_csv(file_name, header=None).values
            # Sort by label first. Optional, but removing it changes which rows a
            # given random seed selects.
            files = files[files[:, 1].argsort()]
            total_num = int(len(files) * select_ratio)

            index = list(range(len(files)))
            random.shuffle(index)

            self.file_name_list = files[index[:total_num]]
            self.unlabeled_list = files[index[total_num:]]

            # Persist the initial labeled split; `with` guarantees the file is
            # flushed and closed.
            with open(base_csv_path, 'w', newline='') as ft:
                csv.writer(ft).writerows(self.file_name_list)
        else:
            # len() works for both lists and ndarrays (`ndarray != []` does not).
            if len(select_data) > 0:
                self.file_name_list = np.concatenate((np.array(file_name_list), np.array(select_data)), axis=0)
            else:
                self.file_name_list = file_name_list
            self.unlabeled_list = self._remove_selected(unlabeled_list, select_data)

        self.transform = transform
        # Build the label map over labeled + unlabeled rows. Class indices then do
        # not depend on which classes the random labeled subset happens to contain.
        self.target_map = get_map(self._all_rows(self.file_name_list, self.unlabeled_list))

    @staticmethod
    def _remove_selected(unlabeled_list, select_data):
        """Return the rows of ``unlabeled_list`` not consumed by ``select_data``.

        Rows are matched on their first two columns (path, label). Each selected
        row removes at most one matching unlabeled row. When nothing remains,
        the result keeps the column axis (shape ``(0, k)``) so it can still be
        indexed as ``rows[:, 1]``.
        """
        pending = {}
        for row in select_data:
            key = (row[0], row[1])
            pending[key] = pending.get(key, 0) + 1
        kept = []
        for row in unlabeled_list:
            key = (row[0], row[1])
            if pending.get(key, 0) > 0:
                pending[key] -= 1
            else:
                kept.append(row)
        if kept:
            return np.array(kept)
        return np.asarray(unlabeled_list)[:0]

    @staticmethod
    def _all_rows(labeled, unlabeled):
        """Stack labeled and unlabeled rows, skipping empty parts."""
        parts = [np.asarray(part) for part in (labeled, unlabeled) if len(part) > 0]
        if not parts:
            return np.asarray(labeled)
        return np.concatenate(parts, axis=0)

    def __len__(self):
        return len(self.file_name_list)

    def __getitem__(self, index):
        """Return ``(image, class_index, image_path, raw_label)`` for row ``index``."""
        row = self.file_name_list[index]
        img_name = row[0]
        img_lable = row[1]
        # Close the underlying file as soon as the pixels have been converted.
        with Image.open(self.file_path + img_name) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        target = self.target_map[img_lable]
        return img, target, img_name, img_lable


class TestDataset(Data.Dataset):
    """Evaluation dataset that serves every row of an index CSV.

    Column 0 is an image path relative to ``file_path``. Column 1 is the label,
    stored as a Python ``str``.
    """

    def __init__(self, file_path, file_name, transform=None):
        """
        :param file_path: dataset root prefix, concatenated directly with each image path
        :param file_name: index CSV (no header) listing the samples
        :param transform: optional callable applied to each PIL image
        """
        self.file_path = file_path  # e.g. '/root/disk/Zz/Dataset_rcs_structed/'
        rows = pd.read_csv(file_name, header=None).values
        # Normalise labels to str, whatever dtype pandas inferred for the column.
        rows[:, 1] = [str(label) for label in rows[:, 1]]
        self.file_name_list = rows
        self.transform = transform
        self.target_map = get_map(self.file_name_list)

    def __len__(self):
        return len(self.file_name_list)

    def __getitem__(self, index):
        """Return ``(image, class_index, image_path, raw_label)`` for row ``index``."""
        image_path = self.file_path + self.file_name_list[index, 0]
        img_name = self.file_name_list[index, 0]
        img_lable = self.file_name_list[index, 1]
        # Close the underlying file as soon as the pixels have been converted.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        target = self.target_map[img_lable]
        return img, target, img_name, img_lable
class Pool(Data.Dataset):
    """Dataset over the unlabeled pool rows ``(image_path, label, ...)``."""

    def __init__(self, file_path, unlabeled_list, transform=None, target_map=None):
        """
        :param file_path: dataset root prefix, concatenated directly with each image path
        :param unlabeled_list: 2-D array of index rows; column 1 holds the label
        :param transform: optional callable applied to each PIL image
        :param target_map: optional ``{label: index}`` mapping to reuse (e.g. a
            training dataset's ``target_map``) so pool targets share its class
            indices. When omitted, the map is derived from ``unlabeled_list``
            alone, as before. With an explicit map, every pool label must be a key.
        """
        self.file_path = file_path
        self.file_name_list = unlabeled_list
        self.transform = transform
        self.target_map = get_map(self.file_name_list) if target_map is None else target_map

    def __len__(self):
        return len(self.file_name_list)

    def __getitem__(self, index):
        """Return ``(image, class_index, image_path, raw_label)`` for row ``index``."""
        image_path = self.file_path + self.file_name_list[index, 0]
        img_name = self.file_name_list[index, 0]
        img_lable = self.file_name_list[index, 1]
        # Close the underlying file as soon as the pixels have been converted.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        target = self.target_map[img_lable]
        return img, target, img_name, img_lable