import os
import pandas as pd
from PIL import Image

import torch
from torchvision import transforms
import torchvision.transforms as transforms


class WB_MultiDomainLoader(torch.utils.data.Dataset):
    """Augmented image dataset over one split of a Waterbirds-style metadata.csv.

    Reads ``metadata.csv`` from ``dataset_root_dir`` (columns ``y``, ``place``,
    ``img_filename`` and ``split``), keeps the rows whose ``split`` code matches
    ``train_split`` and serves ``(augmented_image_tensor, y)`` pairs.

    Args:
        dataset_root_dir: Directory containing ``metadata.csv``; the
            ``img_filename`` entries are joined onto this directory.
        train_split: One of ``'train'``, ``'val'`` or ``'test'`` (mapped to the
            integer codes 0, 1, 2 stored in the ``split`` column).

    Attributes built by ``make_dataset``:
        all_data: ``[image_path, y]`` pairs for the selected split.
        all_cate: Two lists of image paths bucketed by label
            (``all_cate[0]`` for y == 0, ``all_cate[1]`` for y == 1).
        labels: ``y`` for each entry of ``all_data``.
        groups: Group id per entry: 0, 1, 2 for ``(y, place)`` equal to
            ``(0, 0)``, ``(0, 1)``, ``(1, 0)``; 3 for any other combination.
    """

    # (y, place) -> group id; any combination not listed falls into group 3.
    _GROUP_IDS = {(0, 0): 0, (0, 1): 1, (1, 0): 2}

    def __init__(self, dataset_root_dir, train_split):
        self.metadata_df = pd.read_csv(
            os.path.join(dataset_root_dir, 'metadata.csv'))
        self.y_array = self.metadata_df['y'].values
        self.n_classes = 2
        # NOTE(review): 'place' is presumably the spurious background attribute
        # of Waterbirds — confirm against the dataset documentation.
        self.confounder_array = self.metadata_df['place'].values
        # Extract filenames and splits
        self.filename_array = self.metadata_df['img_filename'].values
        self.split_array = self.metadata_df['split'].values
        # Integer codes used in the 'split' column.
        self.split_dict = {
            'train': 0,
            'val': 1,
            'test': 2
        }

        self.train_split = train_split
        self.dataset_root_dir = dataset_root_dir
        self.augment_transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
            transforms.RandomHorizontalFlip(),
            # Positional args: brightness, contrast, saturation, hue.
            transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
            transforms.RandomGrayscale(),
            transforms.ToTensor(),
            # Commonly used ImageNet per-channel mean/std.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        self.all_data, self.all_cate, self.labels, self.groups = self.make_dataset()

    def make_dataset(self):
        """Select the rows of ``self.train_split`` and derive per-sample metadata.

        Returns:
            ``(all_data, all_cate, all_labels, all_groups)`` as described in the
            class docstring.

        Raises:
            KeyError: if ``self.train_split`` is not a key of ``self.split_dict``.
            IndexError: if a selected label is not 0 or 1 (``all_cate`` has
                exactly two buckets).
        """
        try:
            split_id = self.split_dict[self.train_split]
        except KeyError:
            raise KeyError(
                f"unknown split {self.train_split!r}; "
                f"expected one of {sorted(self.split_dict)}") from None

        all_data = []
        all_labels = []
        all_groups = []
        for filename, label, place, split in zip(
                self.filename_array, self.y_array,
                self.confounder_array, self.split_array):
            if split != split_id:
                continue
            all_data.append(
                [os.path.join(self.dataset_root_dir, filename), label])
            all_labels.append(label)
            all_groups.append(self._GROUP_IDS.get((label, place), 3))

        # Bucket image paths by class label.
        all_cate = [[], []]
        for path, label in all_data:
            all_cate[label].append(path)

        return all_data, all_cate, all_labels, all_groups

    def __len__(self):
        return len(self.all_data)

    def __getitem__(self, index):
        """Return ``(augmented_image_tensor, label)`` for sample ``index``."""
        img_path, label = self.all_data[index]
        # Context manager guarantees the underlying file handle is released.
        with Image.open(img_path) as img:
            img_rgb = img.convert("RGB")
        return self.augment_transform(img_rgb), label


class WB_DomainTest(torch.utils.data.Dataset):
    """Non-augmented dataset of a single ``(y, place)`` group within one split.

    Reads ``metadata.csv`` from ``dataset_root_dir`` (columns ``y``, ``place``,
    ``img_filename`` and ``split``) and keeps rows whose split matches ``split``
    and whose ``(y, place)`` equals ``group``. Items are
    ``(image_tensor, y)`` with a deterministic resize/normalize transform.

    Args:
        dataset_root_dir: Directory containing ``metadata.csv``; the
            ``img_filename`` entries are joined onto this directory.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        subsample: Stored as ``self.subsample``; not used within this class.
        group: ``[y, place]`` pair selecting the group. Defaults to ``[0, 0]``.
    """

    def __init__(self, dataset_root_dir, split, subsample=1, group=None):
        self.subsample = subsample
        self.split = split
        self.dataset_root_dir = dataset_root_dir

        self.metadata_df = pd.read_csv(
            os.path.join(dataset_root_dir, 'metadata.csv'))
        self.y_array = self.metadata_df['y'].values
        self.n_classes = 2
        # NOTE(review): 'place' is presumably the spurious background attribute
        # of Waterbirds — confirm against the dataset documentation.
        self.confounder_array = self.metadata_df['place'].values
        # Extract filenames and splits
        self.filename_array = self.metadata_df['img_filename'].values
        self.split_array = self.metadata_df['split'].values
        # Integer codes used in the 'split' column.
        self.split_dict = {
            'train': 0,
            'val': 1,
            'test': 2
        }
        # None sentinel: a fresh list per instance instead of a shared
        # mutable default argument.
        self.group = [0, 0] if group is None else group

        self.all_data, self.labels = self.make_dataset()

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # Commonly used ImageNet per-channel mean/std.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def make_dataset(self):
        """Collect ``[image_path, y]`` pairs for the configured split and group.

        Returns:
            ``(all_data, all_labels)`` where ``all_data`` holds
            ``[image_path, y]`` pairs and ``all_labels`` the matching ``y``.

        Raises:
            ValueError: if ``self.split`` is not a key of ``self.split_dict``
                (previously this surfaced as an ``UnboundLocalError``).
        """
        if self.split not in self.split_dict:
            raise ValueError(
                f"unknown split {self.split!r}; "
                f"expected one of {sorted(self.split_dict)}")
        split_id = self.split_dict[self.split]
        target_y, target_place = self.group[0], self.group[1]

        all_data = []
        all_labels = []
        for filename, label, place, split in zip(
                self.filename_array, self.y_array,
                self.confounder_array, self.split_array):
            if split == split_id and label == target_y and place == target_place:
                all_data.append(
                    [os.path.join(self.dataset_root_dir, filename), label])
                all_labels.append(label)

        return all_data, all_labels

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` for sample ``index``."""
        img_path, label = self.all_data[index]
        # Context manager guarantees the underlying file handle is released.
        with Image.open(img_path) as img:
            img_rgb = img.convert("RGB")
        return self.transform(img_rgb), label

    def __len__(self):
        return len(self.all_data)
