# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# src/data_util.py

import os
import random

from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.datasets import ImageFolder
from torchvision.transforms import InterpolationMode
from scipy import io
from PIL import ImageOps, Image
import torch
import torchvision.transforms as transforms
import h5py as h5
import numpy as np


# Lookup table from the user-facing ``resizer`` option name to the matching
# torchvision interpolation mode used by ``transforms.Resize``.
resizer_collection = dict(
    nearest=InterpolationMode.NEAREST,
    box=InterpolationMode.BOX,
    bilinear=InterpolationMode.BILINEAR,
    hamming=InterpolationMode.HAMMING,
    bicubic=InterpolationMode.BICUBIC,
    lanczos=InterpolationMode.LANCZOS,
)

class RandomCropLongEdge(object):
    """
    Randomly crop a square whose side equals the image's short edge.

    The crop offset is 0 along the short edge and drawn uniformly from every
    valid position along the long edge. Expects an object whose ``size`` is a
    (width, height) pair, e.g. a PIL image.

    this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch
    MIT License
    Copyright (c) 2019 Andy Brock
    """
    def __call__(self, img):
        width, height = img.size
        side = min(width, height)
        # ``np.random.randint`` excludes ``high``; the "+ 1" makes the last
        # valid offset (long edge - side) reachable as well.
        left = 0 if side == width else np.random.randint(low=0, high=width - side + 1)
        top = 0 if side == height else np.random.randint(low=0, high=height - side + 1)
        # torchvision's functional.crop signature is (img, top, left, height, width).
        return transforms.functional.crop(img, top, left, side, side)

    def __repr__(self):
        return self.__class__.__name__


class CenterCropLongEdge:
    """
    Center-crop a square whose side equals the image's short edge.

    this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch
    MIT License
    Copyright (c) 2019 Andy Brock
    """
    def __call__(self, img):
        short_edge = min(img.size)
        return transforms.functional.center_crop(img, short_edge)

    def __repr__(self):
        return type(self).__name__


class Dataset_(Dataset):
    """
    Image dataset wrapper that applies a fixed torchvision transform pipeline.

    Samples come from one of these sources:
      * an HDF5 file (``hdf5_path``) holding ``imgs`` and ``labels`` datasets,
        read lazily per item or loaded fully into memory;
      * torchvision CIFAR10 / CIFAR100 (downloaded into ``data_dir``);
      * ``CUBDataset`` when ``data_name == "Waterbirds"``;
      * otherwise an ``ImageFolder`` rooted at ``data_dir/train`` or
        ``data_dir/valid``.

    Args:
        data_name: dataset identifier ("CIFAR10", "CIFAR100", "Waterbirds", ...).
        data_dir: root directory of the raw dataset.
        train: selects the training split when truthy, the validation split otherwise.
        crop_long_edge: center-crop to a square first (ignored for HDF5 input).
        resize_size: target size for ``transforms.Resize`` (ignored for HDF5 input).
        resizer: key into ``resizer_collection``; "wo_resize" disables resizing.
        random_flip: add a random horizontal flip.
        normalize: if True, output float tensors normalized with mean/std 0.5
            per channel; otherwise output raw tensors from ``PILToTensor``.
        hdf5_path: optional HDF5 file used instead of the raw dataset.
        load_data_in_memory: with ``hdf5_path``, read all images/labels up front.
        return_attr: return ``(img, label, attr)`` (only supported without HDF5).
        loss_weights_path: when set, return ``(img, label, weight)``
            (only supported without HDF5).
    """
    def __init__(self,
                 data_name,
                 data_dir,
                 train,
                 crop_long_edge=False,
                 resize_size=None,
                 resizer="lanczos",
                 random_flip=False,
                 normalize=True,
                 hdf5_path=None,
                 load_data_in_memory=False,
                 return_attr=False,
                 loss_weights_path=None):
        super().__init__()
        self.data_name = data_name
        self.data_dir = data_dir
        self.train = train
        self.random_flip = random_flip
        self.normalize = normalize
        self.hdf5_path = hdf5_path
        self.load_data_in_memory = load_data_in_memory
        self.return_attr = return_attr
        self.loss_weights_path = loss_weights_path
        self.trsf_list = []

        if self.hdf5_path is None:
            if crop_long_edge:
                self.trsf_list += [CenterCropLongEdge()]
            if resize_size is not None and resizer != "wo_resize":
                self.trsf_list += [transforms.Resize(resize_size, interpolation=resizer_collection[resizer])]
        else:
            # HDF5 images are stored as arrays; convert them back to PIL so the
            # remaining transforms see the same input type as the raw datasets.
            # (assumes arrays are in a layout ToPILImage accepts — TODO confirm)
            self.trsf_list += [transforms.ToPILImage()]

        if self.random_flip:
            self.trsf_list += [transforms.RandomHorizontalFlip()]

        if self.normalize:
            self.trsf_list += [transforms.ToTensor()]
            self.trsf_list += [transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
        else:
            self.trsf_list += [transforms.PILToTensor()]

        self.trsf = transforms.Compose(self.trsf_list)

        self.load_dataset()

    def load_dataset(self):
        """Populate ``self.data`` (and ``self.labels`` / ``self.num_dataset`` for HDF5)."""
        if self.hdf5_path is not None:
            with h5.File(self.hdf5_path, "r") as f:
                data, labels = f["imgs"], f["labels"]
                self.num_dataset = data.shape[0]
                if self.load_data_in_memory:
                    print("Load {path} into memory.".format(path=self.hdf5_path))
                    self.data = data[:]
                    self.labels = labels[:]
            return

        if self.data_name == "CIFAR10":
            self.data = CIFAR10(root=self.data_dir, train=self.train, download=True)
        elif self.data_name == "CIFAR100":
            self.data = CIFAR100(root=self.data_dir, train=self.train, download=True)
        elif self.data_name == "Waterbirds":
            self.data = CUBDataset(root=self.data_dir, train=self.train, return_attr=self.return_attr, loss_weights_path=self.loss_weights_path)
        else:
            mode = "train" if self.train else "valid"
            root = os.path.join(self.data_dir, mode)
            self.data = ImageFolder(root=root)

    def _get_hdf5(self, index):
        # Re-open the file on every access instead of keeping a handle open.
        with h5.File(self.hdf5_path, "r") as f:
            return f["imgs"][index], f["labels"][index]

    def __len__(self):
        if self.hdf5_path is None:
            return len(self.data)
        return self.num_dataset

    def _load_raw(self, index):
        """Return the untransformed sample tuple stored at ``index``."""
        if self.hdf5_path is None:
            return self.data[index]
        # The HDF5 layout only provides "imgs" and "labels"; there is no
        # source for attributes or per-sample weights, so fail loudly instead
        # of raising an obscure AttributeError / unpacking error.
        if self.return_attr or self.loss_weights_path is not None:
            raise NotImplementedError(
                "return_attr and loss_weights_path are not supported together with hdf5_path: "
                "the HDF5 file only provides 'imgs' and 'labels'.")
        if self.load_data_in_memory:
            return self.data[index], self.labels[index]
        return self._get_hdf5(index)

    def __getitem__(self, index):
        """
        Return the transformed sample at ``index``.

        Returns ``(img, label, attr)`` when ``return_attr`` is set,
        ``(img, label, weight)`` when ``loss_weights_path`` is set, and
        ``(img, label)`` otherwise. ``label`` (and ``attr``) are cast to int.

        Raises:
            NotImplementedError: if ``return_attr`` or ``loss_weights_path``
                is combined with ``hdf5_path``.
        """
        sample = self._load_raw(index)
        if self.return_attr:
            img, label, attr = sample
            return self.trsf(img), int(label), int(attr)
        if self.loss_weights_path is not None:
            img, label, weight = sample
            return self.trsf(img), int(label), weight
        img, label = sample
        return self.trsf(img), int(label)

    
import pandas as pd
    
class CUBDataset(Dataset):
    """
    CUB / Waterbirds dataset (already cropped and centered).

    Reads ``metadata_csv_name`` from ``root`` (columns ``img_filename``, ``y``,
    ``place`` and ``split``) and keeps only the rows of the requested split:
    "train" (split == 0) when ``train`` is truthy, otherwise "val" (split == 1).

    NOTE: metadata_df is one-indexed.

    Args:
        root: directory containing the metadata CSV and the image files.
        train: selects the "train" split when truthy, the "val" split otherwise.
        return_attr: if True, ``__getitem__`` returns ``(img, y, place)``.
        loss_weights_path: optional path to per-sample weights loaded with
            ``torch.load``; when given (and ``return_attr`` is False)
            ``__getitem__`` returns ``(img, y, weight)``.
        metadata_csv_name: file name of the metadata CSV inside ``root``.

    Raises:
        ValueError: if ``root`` does not exist.
    """
    def __init__(
        self,
        root,
        train,
        return_attr,
        loss_weights_path=None,
        metadata_csv_name="metadata.csv"
    ):
        super().__init__()
        self.root = root
        self.return_attr = return_attr
        split = "train" if train else "val"

        if not os.path.exists(self.root):
            raise ValueError(
                f"{self.root} does not exist yet. Please generate the dataset first."
            )

        # Read in metadata
        metadata_path = os.path.join(self.root, metadata_csv_name)
        print(f"Reading '{metadata_path}'")
        self.metadata_df = pd.read_csv(metadata_path)

        # Class labels
        self.y_array = self.metadata_df["y"].values

        # We only support one confounder ("place") for CUB for now
        self.confounder_array = self.metadata_df["place"].values

        # Extract filenames and splits
        self.filename_array = self.metadata_df["img_filename"].values
        self.split_array = self.metadata_df["split"].values
        self.split_dict = {
            "train": 0,
            "val": 1,
            "test": 2,
        }

        # Keep only the rows belonging to the selected split, in CSV order.
        indices = np.where(self.split_array == self.split_dict[split])[0]
        self.filename = self.filename_array[indices]
        self.targets = self.y_array[indices]
        self.biases = self.confounder_array[indices]

        self.loss_weights_path = loss_weights_path
        if self.loss_weights_path is not None:
            # NOTE(review): torch.load unpickles the file — only load trusted
            # weight files. Weights are indexed by position within this split,
            # so this presumably expects one entry per split sample in CSV
            # order — TODO confirm against the script that writes them.
            self.loss_weights = torch.load(loss_weights_path)
        else:
            self.loss_weights = torch.ones(len(self.targets))

    def __len__(self):
        return len(self.filename)

    def __getitem__(self, index):
        """Return ``(img, y, place)``, ``(img, y, weight)`` or ``(img, y)``; ``img`` is an RGB PIL image."""
        path = os.path.join(self.root, self.filename[index])
        # Use a context manager so the underlying file handle is released
        # deterministically; ``convert`` returns an independent, loaded image.
        with Image.open(path) as raw:
            X = raw.convert("RGB")
        y = self.targets[index]

        if self.return_attr:
            return X, y, self.biases[index]
        if self.loss_weights_path is not None:
            return X, y, self.loss_weights[index]
        return X, y