'''Modified from https://github.com/alinlab/LfF/blob/master/data/util.py'''

import os
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset, Subset
from torchvision import transforms as T
from glob import glob
import numpy as np
import random
from PIL import Image
from PIL import PngImagePlugin
from PIL import ImageFilter
# Size multiplier (in MiB) used for PIL's PNG text-chunk limit below.
LARGE_ENOUGH_NUMBER = 100
# Set PngImagePlugin.MAX_TEXT_CHUNK to 100 * 1024**2 bytes; presumably so PNG
# files carrying large text metadata chunks can still be opened -- TODO confirm.
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)


class IdxDataset(Dataset):
    """Wrap a dataset so that every sample is returned with its index prepended.

    An optional second dataset (the "openset") can be attached with
    :meth:`add_openset`; its samples are addressed by the indices that follow
    the wrapped dataset's own range.
    """

    def __init__(self, dataset):
        """Wrap ``dataset``, which must support ``len()`` and integer indexing."""
        self.dataset = dataset
        # Populated by add_openset(). Kept as None until then so that indexing
        # past the wrapped dataset raises IndexError (which also lets plain
        # ``for sample in idx_dataset`` iteration terminate cleanly).
        self.openset = None
        self.length = len(self.dataset)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(idx, *sample)`` from the wrapped dataset or the openset.

        Raises:
            IndexError: if ``idx`` is past the wrapped dataset and no openset
                has been attached.
        """
        base_size = len(self.dataset)
        if idx < base_size:
            return (idx, *self.dataset[idx])
        if self.openset is None:
            raise IndexError(
                f"index {idx} is out of range for a dataset of size {base_size}"
            )
        # Openset samples live directly after the wrapped dataset's indices.
        return (idx, *self.openset[idx - base_size])

    def subsampling(self, index):
        """Delegate to ``self.dataset.subsampling(index)`` and refresh the length.

        NOTE(review): the length is reset to the wrapped dataset's size only,
        so an openset attached earlier is no longer counted -- confirm that
        this is intended.
        """
        self.dataset.subsampling(index)
        self.length = len(self.dataset)

    def add_openset(self, dataset):
        """Attach ``dataset`` as the openset and extend the length by its size."""
        self.openset = dataset
        self.length += len(self.openset)
        print("Openset added...")

    def get_labels(self):
        """Return a ``LongTensor`` holding one label per index in ``range(len(self))``.

        Labels of the wrapped dataset are parsed from ``self.dataset.data[idx]``
        (second-to-last ``_``-separated field of the path); labels of the
        openset are read from ``self.openset.labels``.
        """
        base_size = len(self.dataset)
        labels = [
            int(self.dataset.data[idx].split('_')[-2])
            if idx < base_size
            else int(self.openset.labels[idx - base_size])
            for idx in range(self.length)
        ]
        return torch.tensor(labels, dtype=torch.long)

class ZippedDataset(Dataset):
    """Combine several datasets into one whose samples are stacked field-wise.

    The combined length is that of the longest member; shorter members are
    cycled through with modulo indexing.
    """

    def __init__(self, datasets):
        super().__init__()
        self.datasets = datasets
        self.dataset_sizes = list(map(len, datasets))

    def __len__(self):
        return max(self.dataset_sizes)

    def __getitem__(self, idx):
        """Return a list with one stacked tensor per sample field."""
        # One sample per member dataset, wrapping around the shorter ones.
        samples = [
            self.datasets[position][idx % size]
            for position, size in enumerate(self.dataset_sizes)
        ]
        # Transpose (dataset, field) -> (field, dataset) and stack each field.
        return [torch.stack(field, dim=0) for field in zip(*samples)]

class CMNISTDataset(Dataset):
    """CMNIST images discovered on disk with ``glob``.

    Both labels are parsed from the file path, which must end in
    ``_<target>_<bias>.<ext>``.
    """

    def __init__(self, root, split, transform=None, image_path_list=None):
        """Collect the image paths of ``split`` below ``root``.

        Args:
            root: directory containing ``align/``, ``conflict/`` and ``valid/``;
                the test images are looked up in ``<root>/../test``.
            split: ``'train'``, ``'valid'`` or ``'test'``.
            transform: optional callable applied to each loaded PIL image.
            image_path_list: stored as-is; not used inside this class.

        Raises:
            ValueError: if ``split`` is not one of the supported names.
        """
        super(CMNISTDataset, self).__init__()
        self.transform = transform
        self.given_transform = self.transform
        self.root = root
        self.image2pseudo = {}
        self.image_path_list = image_path_list
        self.split = split
        self.num_classes = 10

        if split == 'train':
            self.align = glob(os.path.join(root, 'align', "*", "*"))
            self.conflict = glob(os.path.join(root, 'conflict', "*", "*"))
            self.data = self.align + self.conflict
        elif split == 'valid':
            self.data = glob(os.path.join(root, split, "*"))
        elif split == 'test':
            self.data = glob(os.path.join(root, '../test', "*", "*"))
        else:
            # Fail here rather than later with a missing ``self.data`` attribute.
            raise ValueError(
                f"unknown split {split!r}; expected 'train', 'valid' or 'test'"
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, attr, index, path)`` with ``attr = [target, bias]``."""
        path = self.data[index]
        bias_label = int(path.split('_')[-1].split('.')[0])
        target_label = int(path.split('_')[-2])
        attr = torch.LongTensor([target_label, bias_label])
        image = Image.open(path).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, attr, index, path

    def get_open_img(self, label):
        """Open a randomly chosen open-set image whose label equals ``label``.

        Reads ``self.open_labels`` and ``self.open``, which are not set by this
        class. NOTE(review): presumably a 1-D label tensor and a parallel
        sequence of image paths assigned by the caller -- confirm.

        Raises:
            IndexError: if no open-set entry carries ``label``.
        """
        # torch.where(cond) returns a tuple with one index tensor per
        # dimension; take the tensor itself so the shuffle below permutes the
        # matching positions (indexing the tuple always gave the first match).
        pos = torch.where(self.open_labels == label)[0]
        if len(pos) == 0:
            raise IndexError(f"no open-set image with label {label!r}")
        pos_perm = pos[torch.randperm(len(pos))]
        return Image.open(self.open[pos_perm[0]])




class CIFAR10Dataset(Dataset):
    """CIFAR-10 style images discovered on disk with ``glob``.

    Both labels are parsed from the file path, which must end in
    ``_<target>_<bias>.<ext>``.
    """

    def __init__(self, root, split, transform=None, image_path_list=None):
        """Collect the image paths of ``split`` below ``root``.

        Args:
            root: directory containing ``align/``, ``conflict/`` and ``valid/``;
                the test images are looked up in ``<root>/../test``.
            split: ``'train'``, ``'valid'`` or ``'test'``.
            transform: optional callable applied to each loaded PIL image.
            image_path_list: stored as-is; not used inside this class.

        Raises:
            ValueError: if ``split`` is not one of the supported names.
        """
        super(CIFAR10Dataset, self).__init__()
        self.transform = transform
        self.given_transform = self.transform
        self.root = root
        self.image2pseudo = {}
        self.image_path_list = image_path_list
        self.split = split
        self.num_classes = 10

        if split == 'train':
            self.align = glob(os.path.join(root, 'align', "*", "*"))
            self.conflict = glob(os.path.join(root, 'conflict', "*", "*"))
            self.data = self.align + self.conflict
        elif split == 'valid':
            self.data = glob(os.path.join(root, split, "*", "*"))
        elif split == 'test':
            self.data = glob(os.path.join(root, '../test', "*", "*"))
        else:
            # Fail here rather than later with a missing ``self.data`` attribute.
            raise ValueError(
                f"unknown split {split!r}; expected 'train', 'valid' or 'test'"
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, attr, index, path)`` with ``attr = [target, bias]``."""
        path = self.data[index]
        bias_label = int(path.split('_')[-1].split('.')[0])
        target_label = int(path.split('_')[-2])
        attr = torch.LongTensor([target_label, bias_label])
        image = Image.open(path).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, attr, index, path

    

class bFFHQDataset(Dataset):
    """bFFHQ images discovered on disk with ``glob``.

    File names must look like ``<id>_<target>_<bias>.<ext>``; both labels are
    parsed from the file name.
    """

    def __init__(self, root, split, transform=None, image_path_list=None):
        """Collect the image paths of ``split``.

        Args:
            root: directory containing ``align/`` and ``conflict/``; the
                ``valid`` and ``test`` folders are looked up next to it, in
                ``os.path.dirname(root)``.
            split: ``'train'``, ``'valid'`` or ``'test'``. For ``'test'`` only
                files whose target and bias fields differ are kept.
            transform: optional callable applied to each loaded PIL image.
            image_path_list: stored as-is; not used inside this class.

        Raises:
            ValueError: if ``split`` is not one of the supported names.
        """
        super(bFFHQDataset, self).__init__()
        self.transform = transform
        self.given_transform = self.transform
        self.root = root

        self.image2pseudo = {}
        self.image_path_list = image_path_list
        self.split = split
        self.num_classes = 2

        if split == 'train':
            self.align = glob(os.path.join(root, 'align', "*", "*"))
            self.conflict = glob(os.path.join(root, 'conflict', "*", "*"))
            self.data = self.align + self.conflict
        elif split == 'valid':
            self.data = glob(os.path.join(os.path.dirname(root), split, "*"))
        elif split == 'test':
            candidates = glob(os.path.join(os.path.dirname(root), split, "*"))
            data_conflict = []
            for path in candidates:
                target_label, bias_label = self._label_fields(path)
                # Compared as strings, exactly as they appear in the file name.
                if target_label != bias_label:
                    data_conflict.append(path)
            self.data = data_conflict
        else:
            # Fail here rather than later with a missing ``self.data`` attribute.
            raise ValueError(
                f"unknown split {split!r}; expected 'train', 'valid' or 'test'"
            )

    @staticmethod
    def _label_fields(path):
        """Return the ``(target, bias)`` substrings of ``<id>_<target>_<bias>.<ext>``."""
        # os.path.basename instead of split('/') so the parsing does not depend
        # on '/' being the platform's path separator.
        fields = os.path.basename(path).split('.')[0].split('_')
        return fields[1], fields[2]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, attr, index, path)`` with ``attr = [target, bias]``."""
        path = self.data[index]
        target_field, bias_field = self._label_fields(path)
        attr = torch.LongTensor([int(target_field), int(bias_field)])
        image = Image.open(path).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, attr, index, path

    def subsampling(self, index):
        """Keep only the paths at the positions listed in ``index``, in that order."""
        self.data = [self.data[idx] for idx in index]
        

class BARDataset(Dataset):
    """BAR images discovered on disk with ``glob``.

    Labels are parsed from the file path, which must end in
    ``_<target>_<bias>.<ext>``; for training images the bias label is then
    overridden depending on the folder the file lives in.
    """

    def __init__(self, root, split, transform=None, percent=None, image_path_list=None):
        """Collect the image paths of ``split`` below ``root``.

        Args:
            root: directory containing ``train/``, ``valid/`` and ``test/``.
            split: ``'train'``, ``'valid'`` or ``'test'``.
            transform: optional callable applied to each loaded PIL image.
            percent: name of the sub-folder of ``train/conflict`` to read.
            image_path_list: stored as-is; not used inside this class.

        Raises:
            ValueError: if ``split`` is not one of the supported names.
        """
        super(BARDataset, self).__init__()
        self.transform = transform
        self.given_transform = self.transform
        # Stored for parity with the other dataset classes in this module.
        self.root = root
        self.percent = percent
        self.split = split
        self.image2pseudo = {}
        self.image_path_list = image_path_list
        self.num_classes = 6

        if self.split == 'train':
            self.train_align = glob(os.path.join(root, 'train/align', "*/*"))
            self.train_conflict = glob(os.path.join(root, 'train/conflict', f"{self.percent}/*/*"))
            self.data = self.train_align + self.train_conflict
        elif self.split == 'valid':
            self.data = glob(os.path.join(root, 'valid', "*/*"))
        elif self.split == 'test':
            self.data = glob(os.path.join(root, 'test', "*/*"))
        else:
            # Fail here rather than later with a missing ``self.data`` attribute.
            raise ValueError(
                f"unknown split {split!r}; expected 'train', 'valid' or 'test'"
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, attr, index, path)`` with ``attr = [target, bias]``."""
        image_path = self.data[index]
        bias_label = int(image_path.split('_')[-1].split('.')[0])
        target_label = int(image_path.split('_')[-2])
        attr = torch.LongTensor([target_label, bias_label])
        image = Image.open(image_path).convert('RGB')

        if 'bar/train/conflict' in image_path and bias_label != -1:
            # Conflicting sample: the bias label becomes the next class,
            # wrapping around after the sixth one.
            attr[1] = (attr[0] + 1) % 6
        elif 'bar/train/align' in image_path:
            # Aligned sample: the bias label equals the target label.
            attr[1] = attr[0]

        if self.transform is not None:
            image = self.transform(image)

        return image, attr, index, image_path



class DogCatDataset(Dataset):
    """Dogs-and-cats images discovered on disk with ``glob``.

    Both labels are parsed from the file path, which must end in
    ``_<target>_<bias>.<ext>``.
    """

    def __init__(self, root, split, transform=None, image_path_list=None):
        """Collect the image paths of ``split`` below ``root``.

        Args:
            root: directory containing ``align/``, ``conflict/`` and ``valid/``;
                the test images are looked up in ``<root>/../test``.
            split: ``'train'``, ``'valid'`` or ``'test'``.
            transform: optional callable applied to each loaded PIL image.
            image_path_list: stored as-is; not used inside this class.

        Raises:
            ValueError: if ``split`` is not one of the supported names.
        """
        super(DogCatDataset, self).__init__()
        self.transform = transform
        self.given_transform = self.transform
        self.root = root
        # Present for parity with the other dataset classes in this module.
        self.image2pseudo = {}
        self.image_path_list = image_path_list
        self.split = split
        self.num_classes = 2

        if split == "train":
            self.align = glob(os.path.join(root, "align", "*", "*"))
            self.conflict = glob(os.path.join(root, "conflict", "*", "*"))
            self.data = self.align + self.conflict
        elif split == "valid":
            self.data = glob(os.path.join(root, split, "*"))
        elif split == "test":
            self.data = glob(os.path.join(root, "../test", "*", "*"))
        else:
            # Fail here rather than later with a missing ``self.data`` attribute.
            raise ValueError(
                f"unknown split {split!r}; expected 'train', 'valid' or 'test'"
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, attr, index, path)`` with ``attr = [target, bias]``."""
        path = self.data[index]
        bias_label = int(path.split('_')[-1].split('.')[0])
        target_label = int(path.split('_')[-2])
        attr = torch.LongTensor([target_label, bias_label])
        image = Image.open(path).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, attr, index, path


# Splits that every entry of the transform tables provides.
_TRANSFORM_SPLITS = ("train", "valid", "test")


def _to_tensor_only():
    """Return a fresh pipeline that only converts the image to a tensor."""
    return T.Compose([T.ToTensor()])


def _resize_224_to_tensor():
    """Return a fresh pipeline that resizes to 224x224, then converts to a tensor."""
    return T.Compose([T.Resize((224, 224)), T.ToTensor()])


# Plain (augmentation-free) transforms, keyed by dataset name and then split.
# Every dataset also has an "<name>_open" entry built from the same pipeline.
transforms = {
    name + suffix: {split: build() for split in _TRANSFORM_SPLITS}
    for suffix in ("", "_open")
    for name, build in (
        ("cmnist", _to_tensor_only),
        ("cifar10c", _to_tensor_only),
        ("bar", _resize_224_to_tensor),
        ("bffhq", _resize_224_to_tensor),
        ("dogs_and_cats", _resize_224_to_tensor),
    )
}


def _normalized(*steps):
    """Compose ``steps`` followed by ``ToTensor`` and the fixed ``Normalize``.

    The mean/std constants are shared by every normalised pipeline in this
    table (presumably CIFAR-10 channel statistics -- TODO confirm).
    """
    return T.Compose([
        *steps,
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])


def _preprocess_224():
    """Return fresh train/valid/test pipelines for the 224x224 datasets."""
    return {
        "train": _normalized(
            T.Resize((224, 224)),
            T.RandomCrop(224, padding=4),
            T.RandomHorizontalFlip(),
        ),
        "valid": _normalized(T.Resize((224, 224))),
        "test": _normalized(T.Resize((224, 224))),
    }


# Transforms with augmentation (train) and normalisation, keyed by dataset
# name and then split; selected when preprocessing is requested.
transforms_preprcs = {
    "cmnist": {
        split: T.Compose([T.ToTensor()]) for split in ("train", "valid", "test")
    },
    "bar": _preprocess_224(),
    "bffhq": _preprocess_224(),
    "cifar10c": {
        "train": _normalized(
            T.RandomCrop(32, padding=4),
            T.RandomHorizontalFlip(),
        ),
        "valid": _normalized(),
        "test": _normalized(),
    },
    "dogs_and_cats": _preprocess_224(),
}


def get_dataset(dataset, data_dir, dataset_split, transform_split, percent, use_preprocess=None, image_path_list=None):
    """Build the dataset object for ``dataset`` and the requested split.

    Args:
        dataset: dataset name; the part before the first ``-`` selects the
            transform table entry, the full name selects the dataset class.
        data_dir: base directory; the dataset folder is appended to it.
        dataset_split: ``'train'``, ``'valid'``, ``'test'`` or ``'eval'``
            (``'eval'`` is treated as ``'valid'``).
        transform_split: split key used to pick the transform.
        percent: sub-folder name passed on to the dataset (part of ``root``
            for every dataset except ``bar``, which receives it separately).
        use_preprocess: if truthy, take the transform from
            ``transforms_preprcs`` instead of ``transforms``.
        image_path_list: forwarded unchanged to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: if the dataset category or ``transform_split`` has no entry
            in the selected transform table.
        SystemExit: with a non-zero status if ``dataset`` is not supported.
    """
    dataset_category = dataset.split("-")[0]
    transform_table = transforms_preprcs if use_preprocess else transforms
    transform = transform_table[dataset_category][transform_split]

    dataset_split = "valid" if (dataset_split == "eval") else dataset_split

    if dataset == 'cmnist':
        root = data_dir + f"/cmnist/{percent}"
        dataset = CMNISTDataset(root=root, split=dataset_split, transform=transform, image_path_list=image_path_list)

    elif dataset == 'cifar10c':
        root = data_dir + f"/cifar10c/{percent}"
        dataset = CIFAR10Dataset(root=root, split=dataset_split, transform=transform, image_path_list=image_path_list)

    elif dataset == "bffhq":
        root = data_dir + f"/bffhq/{percent}"
        dataset = bFFHQDataset(root=root, split=dataset_split, transform=transform, image_path_list=image_path_list)

    elif dataset == "bar":
        root = data_dir + f"/bar"
        dataset = BARDataset(root=root, split=dataset_split, transform=transform, percent=percent, image_path_list=image_path_list)

    elif dataset == "dogs_and_cats":
        root = data_dir + f"/dogs_and_cats/{percent}"
        dataset = DogCatDataset(root=root, split=dataset_split, transform=transform, image_path_list=image_path_list)

    else:
        print('wrong dataset ...')
        import sys
        # Exit with a failure status: status 0 would tell the calling process
        # that the run succeeded.
        sys.exit(1)

    return dataset



class opendataset(Dataset):
    """Unlabelled image collection that yields two transformed views per image."""

    def __init__(self, root, transform, num_classes):
        """Collect image paths below ``root``.

        Args:
            root: directory that holds the images, either one level down
                (``root/*/*``) or, if that finds nothing, directly in ``root``.
            transform: callable applied twice to every loaded image.
            num_classes: stored as-is; not used inside this class.
        """
        super(opendataset, self).__init__()
        self.transform = transform

        self.root = root
        self.num_classes = num_classes
        self.data = glob(os.path.join(root, '*', '*'))
        if not self.data:
            # No nested layout found: fall back to a flat directory of images.
            self.data = glob(os.path.join(root, '*'))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(view1, view2, index)``: two transform results of one image."""
        image = Image.open(self.data[index]).convert('RGB')
        image1 = self.transform(image)
        image2 = self.transform(image)
        return image1, image2, index

    def sampling(self, indices):
        """Keep only the paths at the positions listed in ``indices``, in that order."""
        self.data = [self.data[idx] for idx in indices]

    def print(self, pos):
        """Print a blank-ish line, then at most the first 30 paths selected by ``pos``."""
        data = [self.data[idx] for idx in pos]
        print(' ')
        # Slice instead of indexing range(30) so that fewer than 30 selected
        # paths no longer raise IndexError.
        for path in data[:30]:
            print(path)

    def add_images(self, data):
        """Append the paths in ``data`` and switch to a stronger augmentation.

        The transform is replaced by a pipeline of random resized crop (224),
        colour jitter, random grayscale, Gaussian blur, horizontal flip,
        tensor conversion and normalisation.
        """
        self.data += data
        self.transform = T.Compose([
            T.RandomResizedCrop(224, scale=(0.2, 1.)),
            T.RandomApply([
                T.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
            ], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])


def get_opendataset(dataset, data_dir, open_type, num_classes):
    """Build the open-set dataset that matches ``dataset`` and ``open_type``.

    The image folder is chosen below ``<data_dir>/open``:

    * ``open_type`` containing ``'caltech'``  -> ``224_<open_type>``
    * ``dataset`` containing ``'cifar10c'``   -> ``32_<open_type>``
    * ``open_type`` containing ``'imgnet'``    -> ``224_imgnet``
    * ``open_type`` containing ``'webvision'`` -> ``224_webvision``

    The first matching rule wins. The training entry of ``transforms_preprcs``
    for the dataset category (the part of ``dataset`` before the first ``-``)
    is used as the transform.

    Raises:
        KeyError: if the dataset category has no entry in ``transforms_preprcs``.
        ValueError: if no rule matches ``dataset`` / ``open_type``.
    """
    dataset_category = dataset.split("-")[0]
    transform = transforms_preprcs[dataset_category]['train']

    if 'caltech' in open_type:
        folder = f'224_{open_type}'
    elif 'cifar10c' in dataset:
        folder = f'32_{open_type}'
    elif 'imgnet' in open_type:
        folder = '224_imgnet'
    elif 'webvision' in open_type:
        folder = '224_webvision'
    else:
        # Previously ``root`` was simply left unbound in this case.
        raise ValueError(
            f"unsupported open_type {open_type!r} for dataset {dataset!r}"
        )

    # os.path.join gives the same path as plain concatenation when data_dir
    # ends with a separator, and a correct one when it does not.
    root = os.path.join(data_dir, 'open', folder)

    dataset = opendataset(root, transform, num_classes)
    return dataset




class GaussianBlur:
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Each call blurs the image with a radius drawn uniformly from the
    configured ``sigma`` range.
    """

    def __init__(self, sigma=[.1, 2.]):
        # Two-element sequence: lower and upper bound of the blur radius.
        self.sigma = sigma

    def __call__(self, x):
        """Return ``x`` filtered with a randomly sized Gaussian blur."""
        lower = self.sigma[0]
        upper = self.sigma[1]
        radius = random.uniform(lower, upper)
        return x.filter(ImageFilter.GaussianBlur(radius=radius))