# Copyright 2021 Zhongyang Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import torch
import os.path as op
import numpy as np
import pickle as pkl
import torch.utils.data as data
import os
from torchvision import transforms
import os
import csv
import torch
from torch.utils import data
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms import RandomApply
from torch.nn import Sequential
import csv
from PIL import Image
import h5py
from torchvision.transforms import RandomApply
from torchvision.transforms import RandAugment
import random

def min_max_normalize(data):
    """Linearly rescale *data* so its minimum maps to 0 and its maximum to 1.

    Args:
        data: Array-like of numbers (numpy array, list, ...). It is converted
            with ``np.asarray`` so plain Python sequences work too.

    Returns:
        A numpy array with the same shape as *data*. If every element is
        equal (zero range), an all-zeros float array is returned instead of
        the NaNs a 0/0 division would produce. Empty input yields an empty
        float array.
    """
    arr = np.asarray(data)
    if arr.size == 0:
        # np.min/np.max raise on zero-size arrays; there is nothing to scale.
        return arr.astype(float)
    min_val = np.min(arr)
    max_val = np.max(arr)
    span = max_val - min_val
    if span == 0:
        # Constant input: every value sits at the minimum, i.e. maps to 0.
        return np.zeros_like(arr, dtype=float)
    return (arr - min_val) / span


class Random90Rotation:
    """Rotate an image by a quarter turn in a randomly chosen direction.

    Every call draws -90 or +90 degrees with equal probability from
    Python's global ``random`` generator and passes the angle to
    ``T.functional.rotate``. One of the two angles is always applied.
    """

    def __init__(self):
        # Stateless transform: there is nothing to configure.
        pass

    def __call__(self, img):
        # Pick the direction of the quarter turn: -90 or +90 degrees.
        quarter_turn = random.choice([-90, 90])
        rotate = T.functional.rotate
        return rotate(img, quarter_turn)

def z_score_normalize(data):
    """Standardize *data* to zero mean and unit (population) standard deviation.

    Args:
        data: Array-like of numbers (numpy array, list, ...). It is converted
            with ``np.asarray`` so plain Python sequences work too.

    Returns:
        A numpy array of the same shape, ``(data - mean) / std``, where the
        mean and std are computed over all elements (``np.std`` with its
        default ``ddof=0``). If the standard deviation is exactly 0 (constant
        input), an all-zeros float array is returned instead of the NaNs a
        0/0 division would produce. Empty input yields an empty float array.
    """
    arr = np.asarray(data)
    if arr.size == 0:
        # Avoid mean/std-of-empty RuntimeWarnings; nothing to normalize.
        return arr.astype(float)
    mean_val = np.mean(arr)
    std_val = np.std(arr)
    if std_val == 0:
        # Constant input: every element equals the mean, so its z-score is 0.
        return np.zeros_like(arr, dtype=float)
    return (arr - mean_val) / std_val


def softmax(x):
    """Return the softmax of *x*, normalized over ALL elements.

    The global maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged but prevents ``np.exp`` from overflowing
    to ``inf`` (and the result becoming NaN) for large inputs such as 1000.

    Args:
        x: Array-like of numbers. For multi-dimensional input the sum runs
            over every element, not along a particular axis.

    Returns:
        A numpy array of the same shape whose elements sum to 1. Empty input
        yields an empty array.
    """
    arr = np.asarray(x)
    if arr.size == 0:
        # np.max raises on empty arrays; exp of nothing is nothing.
        return np.exp(arr)
    exp_x = np.exp(arr - np.max(arr))
    return exp_x / np.sum(exp_x)


class StandardData(data.Dataset):
    """Training dataset that yields two independently augmented views per image.

    Image paths and integer labels are read from a CSV file with ``path`` and
    ``label`` header columns. Each ``__getitem__`` call runs the same random
    augmentation pipeline twice on the loaded RGB image, so the two returned
    tensors are (generally) different views of the same sample. Presumably
    this feeds a two-view / contrastive training loop; confirm with callers.

    Args:
        csv_file: Path to a CSV file with ``path`` and ``label`` columns.
        train: Stored as ``self.train``; not used by this class itself.
        image_size: Output side length of the square ``RandomResizedCrop``.
    """

    def __init__(self, csv_file, train=True, image_size=64):
        self.csv_file = csv_file
        self.train = train
        self.paths, self.labels = self._read_index(self.csv_file)

        # Random pipeline, applied independently to each of the two views.
        color_jitter = T.ColorJitter(0.5, 0.5, 0.5, 0.2)
        self.DEFAULT_AUG = T.Compose([
            T.RandomApply([color_jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            # Disabled alternatives, kept for experimentation:
            # T.RandomVerticalFlip(p=0.5),
            T.RandomHorizontalFlip(p=0.5),
            # T.RandomApply([Random90Rotation()], p=0.5),  # custom Random90Rotation
            T.RandomApply([T.GaussianBlur((3, 3), (1.0, 2.0))], p=0.2),
            # Square crop (aspect ratio fixed at 1.0) covering 30%-100% of the
            # image area, resized to image_size x image_size.
            T.RandomResizedCrop(size=image_size, scale=(0.3, 1.0), ratio=(1.0, 1.0)),
            T.ToTensor(),
        ])

    @staticmethod
    def _read_index(csv_file):
        """Parse *csv_file* and return ``(paths, labels)`` as parallel lists."""
        paths, labels = [], []
        # newline='' lets the csv module handle line endings and quoted fields.
        with open(csv_file, 'r', newline='') as f:
            for row in csv.DictReader(f):
                paths.append(row['path'])
                labels.append(int(row['label']))
        return paths, labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(view1, view2, label)`` for sample *idx*; label is a long tensor."""
        # The context manager closes the underlying file as soon as the RGB
        # copy exists, instead of leaving the handle open until GC.
        with Image.open(self.paths[idx]) as src:
            img = src.convert("RGB")

        label = torch.tensor(self.labels[idx], dtype=torch.long)
        # Two independent draws from the random pipeline -> two views.
        x1 = self.DEFAULT_AUG(img)
        x2 = self.DEFAULT_AUG(img)

        return x1, x2, label



class ValStandardData(data.Dataset):
    """Evaluation dataset with deterministic preprocessing.

    Image paths and integer labels are read from a CSV file with ``path`` and
    ``label`` header columns. Each image is resized to
    ``image_size x image_size`` and converted to a tensor. Because there is no
    random augmentation, ``__getitem__`` returns the same tensor object twice,
    ``(x, x, label)``. Presumably this matches the 3-tuple layout the training
    loop expects; confirm with callers.

    Args:
        csv_file: Path to a CSV file with ``path`` and ``label`` columns.
        train: Stored as ``self.train``; not used by this class itself.
        image_size: Side length images are resized to.
    """

    def __init__(self, csv_file, train=True, image_size=64):
        self.csv_file = csv_file
        self.train = train
        self.image_size = image_size
        self.paths, self.labels = self._read_index(self.csv_file)

        # Deterministic preprocessing only: fixed-size resize, then to tensor.
        self.DEFAULT_AUG = T.Compose([
            T.Resize((self.image_size, self.image_size)),
            T.ToTensor(),
        ])

    @staticmethod
    def _read_index(csv_file):
        """Parse *csv_file* and return ``(paths, labels)`` as parallel lists."""
        paths, labels = [], []
        # newline='' lets the csv module handle line endings and quoted fields.
        with open(csv_file, 'r', newline='') as f:
            for row in csv.DictReader(f):
                paths.append(row['path'])
                labels.append(int(row['label']))
        return paths, labels

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(x, x, label)`` for sample *idx*; label is a long tensor."""
        # The context manager closes the underlying file as soon as the RGB
        # copy exists, instead of leaving the handle open until GC.
        with Image.open(self.paths[idx]) as src:
            img = src.convert("RGB")

        label = torch.tensor(self.labels[idx], dtype=torch.long)
        x = self.DEFAULT_AUG(img)

        # Same tensor object in both slots: in-place edits to one affect both.
        return x, x, label




