import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
from PIL import Image

 
def Im_name_labels(list_path='val.txt'):
    """Read image names and integer labels from a whitespace-separated list file.

    Each non-blank line is expected to look like ``<image name> <label> ...``;
    the first field is taken as the image name and the second as the label.
    Blank lines (including the empty line produced by a trailing newline)
    are skipped.

    Args:
        list_path: path of the list file. Defaults to ``'val.txt'`` relative
            to the current working directory, as before.

    Returns:
        ``(im_names, labels)``: a list of name strings and a list of ints,
        in file order.

    Raises:
        IndexError: if a non-blank line has fewer than two fields.
        ValueError: if the label field is not an integer.
    """
    im_names, labels = [], []
    with open(list_path, 'r') as file:
        for line in file:
            fields = line.split()
            if not fields:
                # Skip blank lines; a trailing newline used to crash here.
                continue
            im_names.append(fields[0])
            labels.append(int(fields[1]))
    return im_names, labels


class sample_from_imagenet_val(torch.utils.data.Dataset):
    """Validation images paired with a randomly augmented view of each.

    Image names and labels come from ``Im_name_labels()`` (which reads the
    list file ``val.txt``); each name is joined onto ``data_dir``.

    Each item is ``(orig_im, aug_im, label)``:
      * ``orig_im``: resize(256) + center-crop(``img_size``) tensor;
      * ``aug_im``: random resized crop, horizontal flip, colour jitter
        (p=0.8) and random grayscale (p=0.2), as a tensor.
    """

    def __init__(self, data_dir, img_size=224):
        self.img_size = img_size
        # Deterministic view: resize, then center-crop to img_size.
        self.data_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(self.img_size),
            transforms.ToTensor(),
        ])
        color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
        # Random view of the same image.
        self.aug_transforms = transforms.Compose([
            transforms.RandomResizedCrop(size=self.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ])

        im_names, labels = Im_name_labels()
        self.labels = labels
        self.im_path = [os.path.join(data_dir, name) for name in im_names]

    def __len__(self):
        return len(self.im_path)

    def __getitem__(self, idx):
        path = self.im_path[idx]
        # Decide on the PIL mode rather than the numpy shape: a grayscale
        # image that is 3 pixels wide has shape (H, 3) and was previously
        # left unconverted. The ``with`` closes the file handle; convert()
        # returns a new, fully loaded image.
        with Image.open(path) as img:
            image = img.convert('RGB')
        orig_im = self.data_transforms(image)
        aug_im = self.aug_transforms(image)
        return orig_im, aug_im, self.labels[idx]
    
    
    
    
class TargetSamplesImagenet(torch.utils.data.Dataset):
    """Every entry of one directory, each labelled with the same ``tar_lbl``.

    Items are ``(image, tar_lbl)`` where ``image`` is resized, center-cropped
    to ``img_size`` and converted to a tensor.
    """

    def __init__(self, data_path, tar_lbl, img_size=224):
        # os.listdir returns entries in arbitrary order; sort so indices
        # (and therefore get_path) are reproducible across runs/machines.
        im_names = sorted(os.listdir(data_path))
        self.labels = [tar_lbl] * len(im_names)
        self.image_paths = [os.path.join(data_path, name) for name in im_names]

        self.transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        # Convert by PIL mode (not numpy shape) and close the file handle.
        with Image.open(path) as img:
            image = img.convert('RGB')
        image = self.transforms(image)
        return image, self.labels[idx]

    def get_path(self, idx):
        """Return the file name (last path component) of item ``idx``.

        Both backslash and forward slash are treated as separators, so this
        works for Windows and POSIX paths alike (previously only the
        backslash was handled, returning the full path on POSIX).
        """
        path = self.image_paths[idx]
        return path.replace('\\', '/').rsplit('/', 1)[-1]
    

        
class CraftedTarSamples(torch.utils.data.Dataset):
    """Indexable view over an object saved with ``torch.save``.

    Item ``idx`` is ``data[idx]``, and the length is ``data.shape[0]``.
    ``img_size`` is accepted only for signature parity with the other
    datasets here; it is not used.
    """

    def __init__(self, data_dir, img_size=224):
        # NOTE(review): despite its name, ``data_dir`` is the path of a single
        # file. ``torch.load`` is pickle-based; only load trusted files.
        self.data = torch.load(data_dir)

    def __len__(self):
        # Size of the leading (sample) dimension.
        return self.data.shape[0]

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample
    



class SamplesFromImNames(torch.utils.data.Dataset):
    """Unlabelled images given by name, relative to ``data_path``.

    Each item is a single tensor: the image resized, center-cropped to
    ``img_size`` and converted with ``ToTensor``.
    """

    def __init__(self, data_path, im_names, img_size=224):
        self.image_paths = [os.path.join(data_path, name) for name in im_names]

        self.transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        # Convert by PIL mode (a numpy-shape check misses 3-pixel-wide
        # grayscale images) and close the file handle once loaded.
        with Image.open(path) as img:
            image = img.convert('RGB')
        image = self.transforms(image)
        return image
        



class SamplesFrom50000TrainData(torch.utils.data.Dataset):
    """Images listed as ``"<relative path> ... <label>"`` strings, with an augmented view.

    Each item is ``(orig_im, aug_im, label)``:
      * ``orig_im``: resize(256) + center-crop(``img_size``) tensor;
      * ``aug_im``: random resized crop, horizontal flip, colour jitter
        (p=0.8) and random grayscale (p=0.2), as a tensor.
    """

    def __init__(self, data_path, im_infoList, img_size=224):
        self.img_size = img_size
        print('number of labels', len(im_infoList))
        self.image_paths, self.labels = self._parse_im_info(data_path, im_infoList)

        self.transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ])

        color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
        self.aug_transforms = transforms.Compose([
            transforms.RandomResizedCrop(size=self.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _parse_im_info(data_path, im_infoList):
        """Turn info strings into ``(image_paths, labels)`` lists.

        The first whitespace-separated field is the path (joined onto
        ``data_path``); the last field is the integer label.
        """
        image_paths, labels = [], []
        for im_info in im_infoList:
            # NOTE(review): presumably meant to keep ImageNet synset paths
            # (e.g. 'n01440764/...') and drop blank lines, but any entry
            # containing the letter 'n' passes this filter.
            if 'n' not in im_info:
                continue
            # split() (not split(' ')) tolerates tabs, repeated and trailing
            # spaces, which previously produced empty fields / int('').
            fields = im_info.split()
            image_paths.append(os.path.join(data_path, fields[0]))
            labels.append(int(fields[-1]))
        return image_paths, labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        # Convert by PIL mode and close the file handle once loaded.
        with Image.open(path) as img:
            image = img.convert('RGB')
        orig_im = self.transforms(image)
        aug_im = self.aug_transforms(image)
        return orig_im, aug_im, self.labels[idx]
    


class Painting50000Data(torch.utils.data.Dataset):
    """Every entry of one directory, paired with an augmented view.

    Each item is ``(orig_im, aug_im, label)``, where ``label`` is the same
    placeholder value (``self.labels``, 9999) for every item:
      * ``orig_im``: resize(256) + center-crop(``img_size``) tensor;
      * ``aug_im``: random resized crop, horizontal flip, colour jitter
        (p=0.8) and random grayscale (p=0.2), as a tensor.
    """

    def __init__(self, data_path, img_size=224):
        self.img_size = img_size
        # Sorted because os.listdir order is arbitrary; this keeps indices
        # reproducible across runs/machines.
        im_names = sorted(os.listdir(data_path))
        self.image_paths = [os.path.join(data_path, name) for name in im_names]
        # A single placeholder int shared by all items (not a per-item list).
        self.labels = 9999

        self.transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ])

        color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
        self.aug_transforms = transforms.Compose([
            transforms.RandomResizedCrop(size=self.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        # Convert by PIL mode and close the file handle once loaded.
        with Image.open(path) as img:
            image = img.convert('RGB')
        orig_im = self.transforms(image)
        aug_im = self.aug_transforms(image)
        return orig_im, aug_im, self.labels