import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets.folder import *
from typing import *
import os, glob
import PIL
import json
from PIL import Image
from torchvision import transforms
from collections import OrderedDict

import cv2
import matplotlib.pyplot as plt


def CLAHE(image, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Return a contrast-enhanced grayscale copy of *image* using OpenCV CLAHE.

    Args:
        image: PIL image or array-like, cast to uint8. Multi-channel input is
            converted to grayscale with ``cv2.COLOR_RGB2GRAY`` (channels are
            assumed to be in RGB order); 2-D or single-channel input is used
            as-is.
        clip_limit: Contrast clip limit passed to ``cv2.createCLAHE``.
        tile_grid_size: Tile grid passed as ``tileGridSize`` to
            ``cv2.createCLAHE``.

    Returns:
        2-D uint8 array with the same height and width as the input.
    """
    image = np.array(image).astype(np.uint8)
    # Drop a trailing singleton channel so it is treated as grayscale.
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]
    # cvtColor rejects single-channel input, so only convert colour images.
    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tuple(tile_grid_size))
    return clahe.apply(image)
def transform_image(img, sigma=0.2, gamma_std=0.2, rng=None):
    """Randomly jitter the intensity range of *img* and apply a random gamma curve.

    The input range [low, high] (currently the exact min/max of *img*) is mapped
    onto a randomly perturbed range [low_new, high_new] via
    ``x -> ((x - low) / (high - low)) ** gamma * (high_new - low_new) + low_new``.
    Each bound ``b`` is perturbed as ``(b + 1) / (1 + N(0, sigma)) - 1`` and
    clipped to [-1, 1]; ``gamma`` is ``N(1, gamma_std)`` clipped below at 0.1.

    Args:
        img: Array of any shape, expected to lie in [-1, 1]. Integer input is
            promoted to float64 so the result is not truncated.
        sigma: Std of the multiplicative noise applied to each range bound.
        gamma_std: Std of the gamma exponent around 1.0.
        rng: Optional ``numpy.random.Generator`` for reproducible draws. When
            None, the legacy global ``np.random`` state is used.

    Returns:
        Array with the same shape as *img* and a floating dtype.
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)
    draw = np.random.randn if rng is None else rng.standard_normal

    low = np.quantile(img, 0.0)
    high = np.quantile(img, 1.0)

    # Random multiplicative jitter of each bound, measured from -1.
    # Draw order (low, high, gamma) is kept stable for reproducibility.
    noise_low = draw() * sigma
    noise_high = draw() * sigma
    low_new = np.clip((low + 1) / (1 + noise_low) - 1, -1, 1)
    high_new = np.clip((high + 1) / (1 + noise_high) - 1, -1, 1)

    # Independent jitter can cross the bounds; keep the target interval ordered.
    if low_new > high_new:
        low_new, high_new = high_new, low_new

    # Gamma must stay strictly positive for the power curve to stay monotonic.
    gamma = np.clip(draw() * gamma_std + 1.0, 0.1, None)

    transformed_img = np.zeros_like(img)

    # Pixels inside [low, high]: normalise to [0, 1], apply gamma, map to new range.
    mask = (img >= low) & (img <= high)
    if low != high:
        x_norm = (img[mask] - low) / (high - low)
        transformed_img[mask] = np.power(x_norm, gamma) * (high_new - low_new) + low_new
    else:
        # Constant input has no contrast to stretch: use the new range midpoint.
        transformed_img[mask] = (low_new + high_new) / 2

    # Pixels outside [low, high] saturate at the new bounds. With the 0/1
    # quantiles above these masks are empty, but they keep the mapping total.
    transformed_img[img < low] = low_new
    transformed_img[img > high] = high_new

    return transformed_img

class MetricDataset(Dataset):
    """Dataset of (image, condition-vector) pairs for one split of a metric dataset.

    Layout read from ``data_root``:
      * ``ann_<split>.json`` - JSON object mapping an image stem to its
        condition values (presumably a 7-element list matching ``cond_min`` /
        ``cond_max`` - verify against the data).
      * ``<split>/<stem>.jpg`` - the image for each annotation.

    ``__getitem__`` returns ``(image, cond)``:
      * ``image`` is a channel-first 3 x S x S float64 array. S is ``size``,
        or the centre-crop side when ``size`` is None. It is standardised
        with the modality's mean/std.
      * ``cond`` is the condition vector min-max scaled by
        ``cond_min``/``cond_max``, jittered with Gaussian noise of std
        ``cond_noise`` and clipped to [0, 1].

    Args:
        data_root: Dataset directory; a leading ``~`` is expanded.
        split: Split name used for the annotation file and image sub-directory.
        size: Output side length after resizing, or None to skip resizing.
        image_transform: If True, apply the random intensity/gamma jitter of
            ``transform_image`` before standardisation.
        modality: 'om' or 'sem', selecting the normalisation statistics. When
            None it is inferred from ``data_root`` (see ``_infer_modality``).
        cond_noise: Std of the Gaussian noise added to normalised conditions.
            The default 0.05 matches the previous hard-coded value; 0 disables
            the noise (e.g. for evaluation splits).

    Raises:
        ValueError: If the modality is neither given nor inferable.
    """

    # Fixed per-dimension ranges used to min-max scale the condition vector.
    _COND_MAX = (1200, 1400, 20, 60, 60, 1800, 0.3)
    _COND_MIN = (500, 800, 1.5, 4, 25, 1200, 0.04)

    # Per-modality normalisation statistics. The image mean/std is the same
    # for all three channels because images are replicated grayscale.
    _MODALITY_STATS = {
        'om': {
            'image_mean': 0.577,
            'image_std': 0.241,
            'cond_means': (946.2, 1042.3, 12.47, 34.27, 38.53, 1417, 0.1588),
            'cond_stds': (77.97, 79.29, 3.914, 12.32, 5.298, 72.65, 0.04888),
        },
        'sem': {
            'image_mean': 0.523,
            'image_std': 0.195,
            'cond_means': (938.9, 1034, 12.36, 34.12, 38.74, 1417, 0.1584),
            # NOTE(review): the last std (0.4698) is ~10x the OM value (0.04888);
            # possibly a typo for 0.04698 - verify. Only matters if z-score
            # condition normalisation is re-enabled.
            'cond_stds': (95.39, 84.13, 4.105, 11.06, 5.306, 73.36, 0.4698),
        },
    }

    def __init__(
        self,
        data_root: str,
        split: str = 'train',
        size: Optional[int] = 256,
        image_transform: bool = False,
        modality: Optional[str] = None,
        cond_noise: float = 0.05,
    ):
        super().__init__()
        self.image_transform = image_transform
        # open() does not expand '~', so resolve it here.
        self.data_root = os.path.expanduser(data_root)
        self.split = split
        self.size = size
        self.cond_noise = cond_noise

        # Resolve the modality before touching the filesystem so a bad root
        # fails fast instead of raising AttributeError later in __getitem__.
        if modality is None:
            modality = self._infer_modality(self.data_root)
        if modality is None or str(modality).lower() not in self._MODALITY_STATS:
            raise ValueError(
                'Cannot determine modality for data_root {!r}; pass '
                "modality='om' or modality='sem'.".format(data_root)
            )
        self.modality = str(modality).lower()

        self.ann_path = os.path.join(self.data_root, 'ann_{}.json'.format(split))
        with open(self.ann_path, 'r', encoding='utf-8') as rf:
            # List of (stem, condition values) pairs in file order.
            self.anns = list(OrderedDict(json.load(rf)).items())
        self._length = len(self.anns)

        self.cond_max = np.array(self._COND_MAX)
        self.cond_min = np.array(self._COND_MIN)

        stats = self._MODALITY_STATS[self.modality]
        self.image_means = np.full(3, stats['image_mean'])
        self.image_stds = np.full(3, stats['image_std'])
        # Kept for z-score condition normalisation; currently unused here.
        self.cond_means = np.array(stats['cond_means'])
        self.cond_stds = np.array(stats['cond_stds'])

    @staticmethod
    def _infer_modality(path: str) -> Optional[str]:
        """Guess the imaging modality ('om' or 'sem') from a dataset path.

        Whole path tokens (split on non-alphanumerics) are matched, deepest
        first. So ``/home/u/metric_sem`` yields 'sem' instead of matching
        the 'om' inside 'home'. If no token matches, fall back to a substring
        test on the final path component. Returns None when nothing matches.
        """
        import re

        norm = os.path.normpath(path).lower()
        for token in reversed(re.split(r'[^0-9a-z]+', norm)):
            if token in ('om', 'sem'):
                return token
        base = os.path.basename(norm)
        for name in ('om', 'sem'):
            if name in base:
                return name
        return None

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        stem, cond = self.anns[i]
        image = self._normalize_image(self._load_image(stem))
        return image, self._normalize_cond(cond)

    def _load_image(self, stem):
        """Load ``<split>/<stem>.jpg``, centre-crop, resize and apply CLAHE.

        Returns an H x W x 3 uint8 array whose three channels are identical.
        """
        image_path = os.path.join(self.data_root, self.split, '{}.jpg'.format(stem))
        # The context manager releases the file handle PIL holds for lazy decoding.
        with Image.open(image_path) as pil_image:
            if pil_image.mode != 'RGB':
                pil_image = pil_image.convert('RGB')
            img = np.array(pil_image).astype(np.uint8)

        # Centre-crop to the largest square.
        h, w = img.shape[0], img.shape[1]
        crop = min(h, w)
        top, left = (h - crop) // 2, (w - crop) // 2
        img = img[top:top + crop, left:left + crop]
        image = Image.fromarray(img)

        if self.size is not None:
            image = image.resize((self.size, self.size))

        # CLAHE yields a single grayscale channel; replicate it to 3 channels.
        gray = CLAHE(image)
        return np.stack([gray, gray, gray], axis=-1)

    def _normalize_image(self, image):
        """Scale an H x W x 3 uint8 image to a standardised 3 x H x W float array."""
        if self.image_transform:
            image = np.transpose(image, [2, 0, 1])
            # Shift to [-1, 0] for transform_image, then shift back by +1.
            image = image.astype(np.float32) / 255 - 1.0
            image = transform_image(image, sigma=0.1, gamma_std=0.05) + 1.0
            return (image - self.image_means[:, None, None]) / self.image_stds[:, None, None]
        image = image.astype(np.float32) / 255
        image = (image - self.image_means) / self.image_stds
        return np.transpose(image, [2, 0, 1])

    def _normalize_cond(self, cond):
        """Min-max scale *cond*, add optional Gaussian jitter and clip to [0, 1]."""
        cond = np.array(cond)
        cond = (cond - self.cond_min) / (self.cond_max - self.cond_min)
        if self.cond_noise:
            cond = cond + np.random.normal(loc=0, scale=self.cond_noise, size=cond.shape)
        return np.clip(cond, a_min=0, a_max=1)


if __name__ == '__main__':
    # Smoke test: build the OM training split and print one batch's shapes.
    # open() does not expand '~', so expand it explicitly.
    dataset = MetricDataset(
        data_root=os.path.expanduser('~/Project/pytorch/new_data/metric_om'),
        split='train',
        size=256,
    )
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    for images, conds in dataloader:
        print(images.shape, images.dtype, conds.shape, conds.dtype)
        break