import os
import pandas as pd
from PIL import Image

import torch
from torchvision import transforms
import torchvision.transforms as transforms


class MultiGroupLoader(torch.utils.data.Dataset):
    """Image dataset for one split of Waterbirds or CelebA, with group labels.

    Every sample has a binary label ``y`` and a binary confounder ``c``
    (``place`` for Waterbirds, ``Male`` for CelebA; the CelebA label is
    ``Blond_Hair``).  The group index of a sample is 0 for ``(y=0, c=0)``,
    1 for ``(y=0, c=1)``, 2 for ``(y=1, c=0)`` and 3 otherwise.

    Attributes populated by the constructor (all restricted to the
    requested split and kept in metadata order):
        all_data: list of ``[image_path, label]`` pairs.
        all_cate: two lists of image paths, indexed by label.
        labels: label of every sample.
        group_array: group index of every sample.
        spurious: confounder value of every sample.
    """

    def __init__(self, dataset_root_dir, train_split, transform=None, dataset='waterbirds'):
        """Read the metadata CSV file(s) and index the requested split.

        Args:
            dataset_root_dir: Directory holding the metadata CSV file(s).
                For CelebA the images are read from its
                ``img_align_celeba`` sub-directory.
            train_split: One of ``'train'``, ``'val'`` or ``'test'``.
            transform: Callable applied to each RGB PIL image.  When
                ``None`` a resize / center-crop / normalize pipeline is used.
            dataset: ``'waterbirds'`` or ``'celeba'``.

        Raises:
            NotImplementedError: If ``dataset`` is not supported.
            KeyError: If ``train_split`` is not a known split name.
        """
        if dataset == 'waterbirds':
            self.metadata_df = pd.read_csv(
                os.path.join(dataset_root_dir, 'metadata.csv'))
            self.confounder_array = self.metadata_df['place'].values
            self.split_array = self.metadata_df['split'].values
            self.y_array = self.metadata_df['y'].values
            self.filename_array = self.metadata_df['img_filename'].values
            self.dataset_root_dir = dataset_root_dir
        elif dataset == 'celeba':
            self.metadata_df = pd.read_csv(
                os.path.join(dataset_root_dir, 'list_attr_celeba.csv'))
            # Attributes are stored as -1 / 1; remap -1 to 0.  The remap is
            # done on the DataFrame rather than by writing into `.values`,
            # because that array is read-only when pandas Copy-on-Write is
            # active.
            for column in ('Male', 'Blond_Hair'):
                self.metadata_df[column] = self.metadata_df[column].replace(-1, 0)
            self.confounder_array = self.metadata_df['Male'].values
            self.y_array = self.metadata_df['Blond_Hair'].values
            self.filename_array = self.metadata_df["image_id"].values

            # NOTE: the partition file is assumed to list images in the same
            # row order as the attribute file; rows are matched by position.
            self.split_df = pd.read_csv(
                os.path.join(dataset_root_dir, "list_eval_partition.csv"))
            self.split_array = self.split_df["partition"].values
            self.dataset_root_dir = os.path.join(dataset_root_dir, "img_align_celeba")
        else:
            raise NotImplementedError(f"Unsupported dataset: {dataset!r}")

        self.n_classes = 2
        # Integer codes used in the split column of the metadata.
        self.split_dict = {
            'train': 0,
            'val': 1,
            'test': 2
        }

        self.train_split = train_split
        if transform is not None:
            self.augment_transform = transform
        else:
            self.augment_transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.CenterCrop((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

        self.all_data, self.all_cate, self.labels, self.group_array, self.spurious = self.make_dataset()

    def make_dataset(self):
        """Collect the samples that belong to ``self.train_split``.

        Returns:
            A tuple ``(all_data, all_cate, all_labels, all_groups,
            all_confounders)``; see the class docstring for each element.

        Raises:
            KeyError: If ``self.train_split`` is not a key of
                ``self.split_dict``.
        """
        # Resolve the split code once instead of on every row.
        target_split = self.split_dict[self.train_split]

        all_data = []
        all_labels = []
        all_groups = []
        all_confounders = []
        all_cate = [[], []]

        rows = zip(self.filename_array, self.y_array,
                   self.confounder_array, self.split_array)
        for filename, label, confounder, split in rows:
            if split != target_split:
                continue

            path = os.path.join(self.dataset_root_dir, filename)
            all_data.append([path, label])
            all_labels.append(label)
            all_confounders.append(confounder)
            all_cate[label].append(path)

            if label == 0 and confounder == 0:
                all_groups.append(0)
            elif label == 0 and confounder == 1:
                all_groups.append(1)
            elif label == 1 and confounder == 0:
                all_groups.append(2)
            else:
                all_groups.append(3)

        return all_data, all_cate, all_labels, all_groups, all_confounders

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.all_data)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``."""
        img_path, label = self.all_data[index]
        image = Image.open(img_path).convert("RGB")
        return self.augment_transform(image), label


class GroupTest(torch.utils.data.Dataset):
    """Image dataset restricted to one (label, confounder) group of a split.

    Only samples whose label equals ``group[0]`` and whose confounder equals
    ``group[1]`` are kept.  The confounder is ``place`` for Waterbirds and
    ``Male`` for CelebA; the CelebA label is ``Blond_Hair``.
    """

    def __init__(self, dataset_root_dir, split, subsample=1, group=None, dataset='waterbirds'):
        """Read the metadata CSV file(s) and index the requested group.

        Args:
            dataset_root_dir: Directory holding the metadata CSV file(s).
                For CelebA the images are read from its
                ``img_align_celeba`` sub-directory.
            split: One of ``'train'``, ``'val'`` or ``'test'``.
            subsample: Stored on the instance; not used by this class.
            group: ``[label, confounder]`` pair selecting the group.
                Defaults to ``[0, 0]``.
            dataset: ``'waterbirds'`` selects the Waterbirds layout; any
                other value is treated as CelebA.

        Raises:
            ValueError: If ``split`` is not a known split name.
        """
        self.subsample = subsample
        self.split = split
        self.dataset_root_dir = dataset_root_dir

        if dataset == 'waterbirds':
            self.metadata_df = pd.read_csv(
                os.path.join(dataset_root_dir, 'metadata.csv'))
            self.y_array = self.metadata_df['y'].values
            self.confounder_array = self.metadata_df['place'].values
            # Extract filenames and splits
            self.filename_array = self.metadata_df['img_filename'].values
            self.split_array = self.metadata_df['split'].values
        else:
            # Every non-'waterbirds' value falls through to the CelebA layout.
            self.metadata_df = pd.read_csv(
                os.path.join(dataset_root_dir, 'list_attr_celeba.csv'))
            # Attributes are stored as -1 / 1; remap -1 to 0.  The remap is
            # done on the DataFrame rather than by writing into `.values`,
            # because that array is read-only when pandas Copy-on-Write is
            # active.
            for column in ('Male', 'Blond_Hair'):
                self.metadata_df[column] = self.metadata_df[column].replace(-1, 0)
            self.confounder_array = self.metadata_df['Male'].values
            self.y_array = self.metadata_df['Blond_Hair'].values
            self.filename_array = self.metadata_df["image_id"].values

            # NOTE: the partition file is assumed to list images in the same
            # row order as the attribute file; rows are matched by position.
            self.split_df = pd.read_csv(
                os.path.join(dataset_root_dir, "list_eval_partition.csv"))
            self.split_array = self.split_df["partition"].values
            self.dataset_root_dir = os.path.join(dataset_root_dir, "img_align_celeba")

        self.n_classes = 2
        # Integer codes used in the split column of the metadata.
        self.split_dict = {
            'train': 0,
            'val': 1,
            'test': 2
        }
        # A fresh list per instance: a list literal in the signature would be
        # shared by every instance that relies on the default.
        self.group = [0, 0] if group is None else group

        self.all_data, self.labels = self.make_dataset()

        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def make_dataset(self):
        """Collect the samples of ``self.split`` that belong to ``self.group``.

        Returns:
            A tuple ``(all_data, all_labels)`` where ``all_data`` is a list
            of ``[image_path, label]`` pairs and ``all_labels`` holds the
            matching labels, both in metadata order.

        Raises:
            ValueError: If ``self.split`` is not a key of ``self.split_dict``.
        """
        try:
            target_split = self.split_dict[self.split]
        except KeyError:
            raise ValueError(
                "Unknown split {!r}; expected one of {}".format(
                    self.split, sorted(self.split_dict))
            ) from None

        wanted_label, wanted_confounder = self.group[0], self.group[1]

        all_data = []
        all_labels = []
        rows = zip(self.filename_array, self.y_array,
                   self.confounder_array, self.split_array)
        for filename, label, confounder, split in rows:
            if split != target_split:
                continue
            if label == wanted_label and confounder == wanted_confounder:
                path = os.path.join(self.dataset_root_dir, filename)
                all_data.append([path, label])
                all_labels.append(label)

        return all_data, all_labels

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``."""
        img_path, label = self.all_data[index]
        image = Image.open(img_path).convert("RGB")
        return self.transform(image), label

    def __len__(self):
        """Return the number of samples in the selected group."""
        return len(self.all_data)
