import os, glob
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
# import torch


class TaiBase(Dataset):
    """Image dataset pairing every ``*.jpg`` with seven noisy CSV labels.

    Each item is a dict containing:

    * ``"relative_file_path_"`` / ``"file_path_"`` - the path of the image as
      returned by ``glob`` (both entries hold the same value; the first key is
      kept only for backward compatibility),
    * ``"image"`` - the centre-cropped, optionally resized RGB image as a
      ``float32`` array scaled from ``[0, 255]`` to ``[-1, 1]``,
    * ``"y1"`` ... ``"y7"`` - the values of CSV columns 7-13 for the plan the
      image belongs to, each perturbed with fresh Gaussian noise on every
      access.

    The plan id is the part of the image file name before the first ``_``; it
    is matched against the first column of the CSV (mechanical-property data
    per plan, according to the original author's note).

    Args:
        data_root: Directory searched recursively for ``*.jpg`` files. A
            leading ``~`` is expanded.
        size: If not ``None``, the cropped image is resized to
            ``size x size``.
        sum_data_csv: Path of the label CSV. Defaults to
            ``DEFAULT_SUM_DATA_CSV``. A leading ``~`` is expanded. The file is
            read lazily on the first item access.
    """

    # CSV used when no explicit ``sum_data_csv`` is given.
    DEFAULT_SUM_DATA_CSV = "~/data/dustbin/class/sum1126.csv"

    # Column layout of the CSV as listed by the original author.
    TAG_COLUMNS = ('id', 'p1_temp', 'p1_t', 'p1_cw', 'p2_temp', 'p2_t',
                   'p2_cw', 'qfqd', 'klqd', 'dhscl', 'cjrx', 'lsyd',
                   'dtysqd', 'ljdlyb')

    # Some file names carry this suffix after the plan id; it is ignored when
    # matching against the CSV.
    PLAN_ID_SUFFIX = "-样似乎有问题"

    # (CSV column index, standard deviation of the added Gaussian noise) for
    # y1 ... y7, in that order.
    LABEL_NOISE = ((7, 50), (8, 50), (9, 1.5), (10, 3), (11, 2), (12, 70),
                   (13, 0.02))

    # Label statistics noted by the original author (order y1 ... y7):
    # mean: [946.0938, 1042.3915, 12.4565, 34.2724, 38.5077, 1417.1494, 0.1588]
    # std: [92.6953, 93.7420, 4.1877, 12.6809, 5.6539, 10.0590, 0.0529]
    # max: [1200, 1400, 20, 60, 60, 1800, 0.3]
    # min: [500, 800, 1.5, 4, 25, 1200, 0.04]

    def __init__(self,
                 data_root,
                 size=None,
                 sum_data_csv=None,
                 ):
        # glob does not expand "~", so do it explicitly.
        self.data_root = os.path.expanduser(data_root)
        # recursive=True makes "**" match any depth (including none); sorting
        # gives a stable index -> file mapping across runs.
        self.image_paths = sorted(glob.glob(
            os.path.join(self.data_root, "**", "*.jpg"), recursive=True))
        self._length = len(self.image_paths)
        # The glob results already contain data_root, so they must not be
        # joined with it a second time.
        self.labels = {
            "relative_file_path_": list(self.image_paths),
            "file_path_": list(self.image_paths),
        }

        self.size = size
        self.sum_data_csv = os.path.expanduser(
            self.DEFAULT_SUM_DATA_CSV if sum_data_csv is None else sum_data_csv)
        # Parsed CSV rows; filled on first use so that constructing the
        # dataset does not require the CSV to exist.
        self._tag_rows = None

    def __len__(self):
        return self._length

    def _load_tag_rows(self):
        """Return the CSV data rows (header skipped), reading the file once."""
        if self._tag_rows is None:
            with open(self.sum_data_csv, 'r') as csv_file:
                lines = csv_file.readlines()
            self._tag_rows = [line.rstrip('\r\n').split(',')
                              for line in lines[1:]]
        return self._tag_rows

    def _lookup_labels(self, plan_id):
        """Return the seven noisy labels of ``plan_id`` or ``[]`` if unusable.

        Only the first CSV row whose id matches is considered. The row is
        unusable when it is too short or when column 1 or column 7 is empty.
        """
        for tags in self._load_tag_rows():
            if plan_id != tags[0] and plan_id != tags[0] + self.PLAN_ID_SUFFIX:
                continue
            if len(tags) < len(self.TAG_COLUMNS) or tags[7] == "" or tags[1] == "":
                return []
            # Fresh noise is drawn on every call, one sample per label.
            return [np.float64(tags[column]) + sigma * np.random.normal()
                    for column, sigma in self.LABEL_NOISE]
        return []

    def __getitem__(self, i):
        """Load image ``i`` and attach its seven noisy labels.

        Raises:
            IndexError: If no usable CSV row exists for the image's plan id
                (same exception type as before, now with a message).
        """
        example = {key: values[i] for key, values in self.labels.items()}

        # The context manager guarantees the file handle is released.
        with Image.open(example["file_path_"]) as opened:
            image = opened.convert("RGB")

        # default to score-sde preprocessing: centre crop to a square
        img = np.array(image).astype(np.uint8)
        h, w = img.shape[0], img.shape[1]
        crop = min(h, w)
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size))

        image = np.array(image).astype(np.uint8)
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)

        # Match the plan id from the image name with the CSV data.
        plan_id = os.path.basename(example["file_path_"]).split('_')[0]
        continuous_labels = self._lookup_labels(plan_id)
        if len(continuous_labels) != len(self.LABEL_NOISE):
            raise IndexError(
                "no usable label row for plan id %r (image %s) in %s"
                % (plan_id, example["file_path_"], self.sum_data_csv))

        for number, value in enumerate(continuous_labels, start=1):
            example["y%d" % number] = value
        return example


class TaiTrain(TaiBase):
    """Thin ``TaiBase`` subclass that only supplies a default ``data_root``.

    All further keyword arguments are handed to ``TaiBase.__init__`` as they
    are; the class adds no behaviour of its own.
    """

    def __init__(self, data_root="~/tai/sdcopy/unet/unet_om/om", **kwargs):
        # data_root is a named parameter, so it can never already be in kwargs.
        kwargs["data_root"] = data_root
        super().__init__(**kwargs)


class TaiValidation(TaiBase):
    """Thin ``TaiBase`` subclass that only supplies a default ``data_root``.

    All further keyword arguments are handed to ``TaiBase.__init__`` as they
    are; the class adds no behaviour of its own.
    """

    def __init__(self, data_root="~/tai/sdcopy/unet/unet_om/om", **kwargs):
        # data_root is a named parameter, so it can never already be in kwargs.
        kwargs["data_root"] = data_root
        super().__init__(**kwargs)



class TaiDualBase(Dataset):
    """Dataset of paired images: a target image and its condition image.

    Both directories are searched recursively for ``*.jpg`` files; the sorted
    file lists are paired by position, so both trees must contain the same
    number of images in matching sort order.

    Each item is a dict with

    * ``"image"`` - the target image as a height x width x channel tensor
      normalised with mean 0.5 / std 0.5 per channel,
    * ``"image_condition"`` - the condition image as a height x width x
      channel tensor normalised with ``clip_transform``,
    * ``"file_path_"`` - the path of the target image.

    Args:
        image_root: Directory containing the target images.
        cond_image_root: Directory containing the condition images.
        size: If not ``None``, both images are resized to ``size x size``;
            with ``None`` the resize step is skipped.
        sum_data_csv: Path of the label CSV; defaults to
            ``DEFAULT_SUM_DATA_CSV``. A leading ``~`` is expanded.

    Raises:
        AssertionError: If the two directories hold a different number of
            images.
    """

    # CSV used when no explicit ``sum_data_csv`` is given.
    DEFAULT_SUM_DATA_CSV = "~/data/dustbin/class/sum1126.csv"

    # Some file names carry this suffix after the plan id; it is ignored when
    # matching against the CSV.
    PLAN_ID_SUFFIX = "-样似乎有问题"

    # (CSV column index, standard deviation of the added Gaussian noise) for
    # the seven continuous labels.
    LABEL_NOISE = ((7, 50), (8, 50), (9, 1.5), (10, 3), (11, 2), (12, 70),
                   (13, 0.02))

    def __init__(self,
                 image_root,
                 cond_image_root,
                 size=None,
                 sum_data_csv=None,
                 ):
        self.image_root = image_root
        self.cond_image_root = cond_image_root
        self.size = size

        self.image_paths = self._list_jpgs(self.image_root)
        self.cond_image_paths = self._list_jpgs(self.cond_image_root)
        # Raised explicitly (instead of ``assert``) so the check also runs
        # under ``python -O``; the exception type is unchanged.
        if len(self.image_paths) != len(self.cond_image_paths):
            raise AssertionError("原图和条件图数量不一致！")

        self._length = len(self.image_paths)

        # Label file; open() does not expand "~", so do it explicitly.
        self.sum_data_csv = os.path.expanduser(
            self.DEFAULT_SUM_DATA_CSV if sum_data_csv is None else sum_data_csv)
        with open(self.sum_data_csv, 'r') as csv_file:
            self.tag_csv_lines = csv_file.readlines()

        self.tag_dict1 = ['id', 'p1_temp', 'p1_t', 'p1_cw', 'p2_temp', 'p2_t', 'p2_cw', 'qfqd', 'klqd', 'dhscl', 'cjrx', 'lsyd', 'dtysqd', 'ljdlyb']

        # Resize((None, None)) fails when applied, so the step is only added
        # when a size was actually requested.
        resize_steps = [] if self.size is None else [
            transforms.Resize((self.size, self.size))]
        self.transform = transforms.Compose(resize_steps + [
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])
        # NOTE(review): the statistics below look like rounded CLIP image
        # preprocessing constants - confirm against the consuming encoder.
        self.clip_transform = transforms.Compose(resize_steps + [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4815, 0.4578, 0.4082],
                                 std=[0.2686, 0.2613, 0.2758])
        ])

    @staticmethod
    def _list_jpgs(root):
        """Return the sorted paths of all ``*.jpg`` files below ``root``."""
        pattern = os.path.join(os.path.expanduser(root), "**", "*.jpg")
        return sorted(glob.glob(pattern, recursive=True))

    def __len__(self):
        return self._length

    def _lookup_labels(self, plan_id):
        """Return the seven noisy labels of ``plan_id`` or ``[]`` if unusable.

        Only the first CSV row whose id matches is considered. The row is
        unusable when column 1 or column 7 is missing or empty, or when a
        label column is missing or cannot be parsed as a float (the latter
        is reported on stdout).
        """
        for line in self.tag_csv_lines[1:]:  # skip the header row
            tags = line.strip().split(',')
            if plan_id != tags[0] and plan_id != tags[0] + self.PLAN_ID_SUFFIX:
                continue
            # The length check keeps a truncated row from raising IndexError.
            if len(tags) < 8 or tags[7] == "" or tags[1] == "":
                return []
            try:
                # Fresh noise is drawn on every call, one sample per label.
                return [float(tags[column]) + sigma * np.random.normal()
                        for column, sigma in self.LABEL_NOISE]
            except (ValueError, IndexError) as e:
                print(f"[标签读取错误] {plan_id}: {e}")
                return []
        return []

    def __getitem__(self, idx):
        """Load the ``idx``-th image pair and return it as a dict."""
        img_path = self.image_paths[idx]
        cond_path = self.cond_image_paths[idx]

        # Context managers guarantee the file handles are released.
        with Image.open(img_path) as opened:
            img = opened.convert("RGB")
        with Image.open(cond_path) as opened:
            cond = opened.convert("RGB")

        # ToTensor yields channel-first tensors; permute to H x W x C.
        img = self.transform(img).permute(1, 2, 0)
        cond = self.clip_transform(cond).permute(1, 2, 0)

        # The labels are not part of the returned example at the moment; the
        # lookup is kept so that images without labels are still reported.
        plan_id = os.path.basename(img_path).split('_')[0]
        continuous_labels = self._lookup_labels(plan_id)
        if len(continuous_labels) != 7:
            print(f"[标签缺失] {img_path}")

        example = {
            "image": img,                         # target image
            "image_condition": cond,              # condition image
            "file_path_": img_path,
        }
        return example

class TaiDualTrain(TaiDualBase):
    """``TaiDualBase`` subclass whose constructor requires an explicit ``size``.

    The three arguments are forwarded to ``TaiDualBase.__init__`` by keyword;
    the class adds no behaviour of its own.
    """

    def __init__(self, image_root, cond_image_root, size):
        base_kwargs = dict(
            image_root=image_root,
            cond_image_root=cond_image_root,
            size=size,
        )
        super().__init__(**base_kwargs)


class TaiDualValidation(TaiDualBase):
    """``TaiDualBase`` subclass whose constructor requires an explicit ``size``.

    The three arguments are forwarded to ``TaiDualBase.__init__`` by keyword;
    the class adds no behaviour of its own.
    """

    def __init__(self, image_root, cond_image_root, size):
        base_kwargs = dict(
            image_root=image_root,
            cond_image_root=cond_image_root,
            size=size,
        )
        super().__init__(**base_kwargs)
