import os
import torch
import torchvision
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F

class water_land_bird(Dataset):
    """CUB-200-2011 birds with a binary water-bird (1) / land-bird (0) label.

    Expected layout under ``data_root`` (from the paths read below):
    ``CUB_200_2011/{train_test_split,images,classes,image_class_labels}.txt``,
    ``CUB_200_2011/images/<relative image path>`` and
    ``segmentations/<relative image path with .png instead of .jpg>``.

    A bird counts as a water bird when its image path contains one of the
    substrings in ``water_birds_list``. The smaller of the two classes is
    oversampled with a fixed seed (42) so that ``total_indices`` holds equally
    many water- and land-bird samples, water birds first.

    Items are ``(image, mask, label)``. With ``tensor_image=False`` they are
    PIL images. With ``tensor_image=True`` the image is the normalized,
    segmented bird (``image * mask``), shrunk so that the mask covers at most
    ``max_object_size`` of the canvas, then randomly rotated together with
    its mask.
    """

    def __init__(self, data_root, tensor_image = False, image_size = 128, train = False):
        self.data_root = data_root
        self.tensor_image = tensor_image
        self.image_size = image_size
        # Largest fraction of the canvas the bird mask may cover before it is shrunk.
        self.max_object_size = 0.15
        # Rotation angles are drawn from [-max_rotation_angle, max_rotation_angle) degrees.
        self.max_rotation_angle = 30
        self.train = train

        if self.tensor_image:
            self.image_transform = transforms.Compose(
                [transforms.Resize((self.image_size, self.image_size)),
                 transforms.ToTensor(),
                 transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                      std = [0.229, 0.224, 0.225])
                ])
            self.mask_transform = transforms.Compose(
                [transforms.Resize((self.image_size, self.image_size)),
                 transforms.ToTensor(),
                ])
        else:
            self.image_transform = None

        # Per-image split flag over the whole dataset: 1 = train, 0 = test.
        self.train_test_split = np.array(
            [int(flag) for flag in self._read_metadata_column('train_test_split.txt')])
        self.indices = np.where(self.train_test_split == (1 if self.train else 0))[0]

        # Relative image paths, restricted to the selected split.
        self.image_list = np.array(self._read_metadata_column('images.txt'))[self.indices]

        # Class names (all classes, not filtered by split).
        self.class_list = self._read_metadata_column('classes.txt')

        # Class ids are 1-based in the file; store them 0-based for the selected split.
        self.image2label = np.array(
            [int(cls) - 1 for cls in self._read_metadata_column('image_class_labels.txt')])[self.indices]

        self.water_birds_list = [
            'Albatross', # Seabirds
            'Auklet',
            'Cormorant',
            'Frigatebird',
            'Fulmar',
            'Gull',
            'Jaeger',
            'Kittiwake',
            'Pelican',
            'Puffin',
            'Tern',
            'Gadwall', # Waterfowl
            'Grebe',
            'Mallard',
            'Merganser',
            'Guillemot',
            'Pacific_Loon'
        ]

        # Water/land label per image: 1 if the (case-sensitive) path contains a water-bird name.
        self.unresampled_label = np.array([
            1 if any(bird in path for bird in self.water_birds_list) else 0
            for path in self.image_list
        ])

        # Balance the classes by oversampling (with replacement) the smaller one.
        water = np.where(self.unresampled_label == 1)[0]
        land = np.where(self.unresampled_label == 0)[0]
        rng = np.random.RandomState(seed=42)
        if len(water) <= len(land):
            # Usual CUB case: water birds are the minority.
            self.waterbird_resampled_indices = rng.choice(water, size=len(land) - len(water))
            self.landbird_resampled_indices = np.array([], dtype=land.dtype)
        else:
            # Previously this raised (negative sample size); oversample land birds instead.
            self.waterbird_resampled_indices = np.array([], dtype=water.dtype)
            self.landbird_resampled_indices = rng.choice(land, size=len(water) - len(land))
        self.waterbird_indices = np.concatenate([water, self.waterbird_resampled_indices])
        self.landbird_indices = np.concatenate([land, self.landbird_resampled_indices])

        # Indices into image_list: first half water birds, second half land birds.
        self.total_indices = np.concatenate([
            self.waterbird_indices, self.landbird_indices
        ])

    def _read_metadata_column(self, filename):
        """Return the second whitespace-separated field of every non-blank line
        of ``<data_root>/CUB_200_2011/<filename>``."""
        path = os.path.join(self.data_root, 'CUB_200_2011', filename)
        with open(path, 'r', encoding='utf-8') as f:
            # Skipping blank lines tolerates a trailing empty line at end of file.
            return [line.split()[1] for line in f if line.strip()]

    def resize(self, x, mask, ratio):
        """Shrink ``x`` and ``mask`` by area factor ``ratio`` and zero-pad back to size.

        Each side is scaled by ``sqrt(ratio)`` and the result is centred with zero
        padding. The padding is computed from the last dimension and applied to
        both spatial dimensions, so a square canvas is assumed. ``x`` and ``mask``
        are expected to be 3-D (C, H, W); the returned tensors are squeezed.
        ``ratio`` may be a Python number or a 0-d tensor.
        """
        if torch.is_tensor(ratio):
            ratio = ratio.item()
        pad_size = x.shape[-1] - int(x.shape[-1]*np.sqrt(ratio))
        new_x = F.interpolate(x[None], scale_factor = np.sqrt(ratio), mode = 'bilinear', align_corners = False)
        x = F.pad(new_x, (pad_size//2, pad_size-pad_size//2)*2,
                  mode = 'constant', value = 0).squeeze()

        new_mask = F.interpolate(mask[None], scale_factor = np.sqrt(ratio), mode = 'bilinear', align_corners = False)
        mask = F.pad(new_mask, (pad_size//2, pad_size-pad_size//2)*2,
                  mode = 'constant', value = 0).squeeze()
        return x, mask

    def rotate(self, x, mask):
        """Rotate image and mask by one shared random angle.

        The angle is drawn from the global NumPy RNG. The mask is returned
        reshaped to (1, image_size, image_size).
        """
        angle = (np.random.rand()*2-1)*self.max_rotation_angle
        x_rot = torchvision.transforms.functional.rotate(x[None], angle).squeeze()
        mask_rot = torchvision.transforms.functional.rotate(mask[None], angle).view(1,self.image_size,self.image_size)
        return x_rot, mask_rot

    def __len__(self):
        return len(self.total_indices)

    def __getitem__(self, idx):
        """Return ``(image, mask, label)`` for position ``idx`` of the balanced index list."""
        idx = self.total_indices[idx]
        rel_path = self.image_list[idx]
        image_path = os.path.join(self.data_root, './CUB_200_2011/images', rel_path)
        mask_path = os.path.join(self.data_root, 'segmentations',
                                 rel_path.split('.jpg')[0]+'.png')
        # Context managers release the file handles; convert()/copy() return
        # independent images that stay usable after the files are closed.
        with Image.open(image_path) as image_file:
            image = image_file.convert('RGB')
        with Image.open(mask_path) as mask_file:
            mask = mask_file.copy()
        label = self.unresampled_label[idx]

        if not self.tensor_image:
            return image, mask, label

        image = self.image_transform(image)
        # Keep only the first channel of the mask, shaped (1, H, W).
        mask = self.mask_transform(mask)[0][None]
        masked = image*mask
        mask_size = (mask > 0).float().mean()
        if mask_size > self.max_object_size:
            masked, mask = self.resize(masked, mask, self.max_object_size/mask_size)
        masked, mask = self.rotate(masked, mask)
        return masked, mask, label


class places(Dataset):
    """Scene images labelled as land (0) or water (1) backgrounds.

    Paths are ``<data_root>/<category>/<8-digit 1-based index>.jpg`` for the
    first ``images_per_class`` indices of every category in ``self.land`` and
    ``self.water``. The land images come first, so the lower part of the index
    space is label 0 and the upper part label 1.
    """

    def __init__(self, data_root, image_size = 128, tensor_image = True, train = False,
                 images_per_class = 5000):
        self.data_root = data_root
        self.tensor_image = tensor_image
        self.train = train
        self.image_size = image_size
        self.images_per_class = images_per_class

        if self.tensor_image:
            self.image_transform = transforms.Compose(
                [transforms.Resize((self.image_size, self.image_size)),
                 transforms.ToTensor(),
                 transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                      std = [0.229, 0.224, 0.225])
                ])
        else:
            self.image_transform = None

        # Category sub-directories relative to data_root.
        self.land = ['b/bamboo_forest', 'f/forest/broadleaf']
        self.water = ['l/lake/natural', 'o/ocean']

        # NOTE(review): train and test currently select the SAME files (the first
        # images_per_class of every category); `train` does not change the list.
        # A previous comment claimed a 2500/2500 split -- confirm which is intended.
        land_images = self._category_images(self.land)
        water_images = self._category_images(self.water)
        self.image_list = np.array(land_images + water_images)
        # Labels derived from the list sizes: 0 = land background, 1 = water background.
        self.label = np.array([0] * len(land_images) + [1] * len(water_images))

    def _category_images(self, categories):
        """List the image paths of ``categories`` in order, ``images_per_class`` per category."""
        return [os.path.join(self.data_root, category, '{:0>8}.jpg'.format(i+1))
                for category in categories
                for i in range(self.images_per_class)]

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is transformed when ``tensor_image`` is set."""
        with Image.open(self.image_list[idx]) as image_file:
            image = image_file.convert('RGB')
        label = self.label[idx]
        # Without tensor_image there is no transform; return the PIL image
        # instead of calling None.
        if self.image_transform is None:
            return image, label
        return self.image_transform(image), label


class WaterBirds_biased(Dataset):
    """Synthetic Waterbirds: segmented CUB birds pasted onto land/water scenes.

    Each item is ``(image, target, spurious)``. ``target`` is 1 for a water bird
    and ``spurious`` is 1 for a water background. Samples are laid out in four
    contiguous groups, in (target, spurious) order (1,1), (1,0), (0,1), (0,0),
    whose sizes are controlled by ``zeta`` and ``args.correlation_type`` (see
    ``generate_grps``). All composites are rendered eagerly in ``__init__``.

    Reads from ``args``: ``correlation_type``, ``CUB_root_zeta``,
    ``place_root_zeta`` and ``image_size``.
    """

    def __init__(self, zeta, args, train):

        self.correlation_type = args.correlation_type
        self.train=train
        self.CUB_root=args.CUB_root_zeta
        self.place_root=args.place_root_zeta
        self.image_size=args.image_size
        self.zeta = zeta
        self.args = args

        self.set_bird=water_land_bird(self.CUB_root,tensor_image=True,image_size=self.image_size,train=self.train)
        self.set_place=places(self.place_root,tensor_image=True,image_size=self.image_size,train=self.train)

        self.bird_label=torch.tensor(self.set_bird.unresampled_label[self.set_bird.total_indices])
        self.place_label=self.set_place.label

        self.N = len(self.set_bird)

        self.generate_grps()

        # Background label per sample (1 = water), following the group layout.
        self.spurious = [1] * self.group_size[0] + [0] * self.group_size[1] + [1] * self.group_size[2] + [0] * self.group_size[3]
        self.spurious = torch.tensor(self.spurious)

        # Bird label per sample (1 = water bird), following the group layout.
        self.targets = [1] * self.group_size[0] + [1] * self.group_size[1] + [0] * self.group_size[2] + [0] * self.group_size[3]
        self.targets = torch.tensor(self.targets)

        self.generate_indices()

    def generate_grps(self):
        """Set ``self.group_size`` and ``self.majority_grps`` for ``self.correlation_type``.

        Groups are ordered (y, a) = (1,1), (1,0), (0,1), (0,0). The two majority
        groups of each correlation type get ``int(zeta * N // 2)`` samples each,
        and the remaining groups absorb the rest, so the sizes sum to ``N``.

        Raises:
            ValueError: if ``zeta`` is outside [0, 1]; larger or negative values
                produce negative group sizes, which silently misalign targets
                and images.
            NotImplementedError: for an unknown ``correlation_type``.
        """
        if not 0 <= self.zeta <= 1:
            raise ValueError('zeta must be in [0, 1], got {}'.format(self.zeta))

        if self.correlation_type == 0:
            # Spurious correlation: p(y=0) = p(y=1);
            # p(y=0|a=0) >> p(y=0|a=1), p(y=1|a=1) >> p(y=1|a=0) -> alpha, beta > 0.5.
            # Majority: 11, 00; minority: 10, 01.
            group_0_size, group_3_size = int(self.zeta * self.N // 2), int(self.zeta * self.N // 2)
            group_1_size = int(self.N // 2 - group_0_size)
            group_2_size = self.N - int(group_0_size + group_1_size + group_3_size)
            self.group_size = [
                group_0_size,
                group_1_size,
                group_2_size,
                group_3_size,
            ]
            self.majority_grps = torch.tensor([1] * self.group_size[0] + [0] * self.group_size[1] + [0] * self.group_size[2] + [1] * self.group_size[3])

        elif self.correlation_type == 1:
            # Under-representation: p(a=0) << p(a=1), p(y=0) = p(y=1).
            # Majority: 11, 01; minority: 10, 00.
            group_0_size, group_2_size = int(self.zeta * self.N // 2), int(self.zeta * self.N // 2)
            group_1_size = int(self.N // 2 - group_0_size)
            group_3_size = self.N - int(group_0_size + group_1_size + group_2_size)
            self.group_size = [
                group_0_size,
                group_1_size,
                group_2_size,
                group_3_size,
            ]
            self.majority_grps = torch.tensor([1] * self.group_size[0] + [0] * self.group_size[1] + [1] * self.group_size[2] + [0] * self.group_size[3])

        elif self.correlation_type == 2:
            # Class imbalance: p(y=0) << p(y=1).
            # Majority: 11, 10; minority: 01, 00.
            group_0_size, group_1_size = int(self.zeta * self.N // 2), int(self.zeta * self.N // 2)
            group_2_size = int(self.N // 2 - group_0_size)
            group_3_size = self.N - int(group_0_size + group_1_size + group_2_size)
            self.group_size = [
                group_0_size,
                group_1_size,
                group_2_size,
                group_3_size,
            ]
            self.majority_grps = torch.tensor([1] * self.group_size[0] + [1] * self.group_size[1] + [0] * self.group_size[2] + [0] * self.group_size[3])

        else:
            raise NotImplementedError

    def resample_indices(self, indices, target_size):
        """Return ``target_size`` entries of ``indices`` using the global NumPy RNG.

        The result is a random subset without replacement when enough indices
        exist. Otherwise it is all of ``indices`` followed by extra draws with
        replacement.
        """
        if len(indices) >= target_size:
            return np.random.choice(indices, target_size, replace=False)
        else:
            return np.concatenate([indices, np.random.choice(indices, target_size - len(indices), replace=True)])

    def generate_indices(self):
        """Pair bird and place samples for every group and pre-render the composites.

        Bird indices [0, N // 2) are treated as water birds and [N // 2, N) as
        land birds, matching ``water_land_bird``, which lists its balanced
        water-bird samples first. The lower half of ``set_place`` holds land
        backgrounds and the upper half water backgrounds, per its labels.
        Sampling uses the global NumPy RNG, so the result depends on its state.
        """
        # Groups in order (1,1) (1,0) (0,1) (0,0) as (bird label, place label).
        waterbirds_indices = self.resample_indices(np.arange(self.N // 2), target_size=self.group_size[0]+self.group_size[1])
        landbirds_indices = self.resample_indices(np.arange(self.N // 2, self.N), target_size=self.group_size[2]+self.group_size[3])

        self.bird_indices = np.concatenate([
            waterbirds_indices[:self.group_size[0]],
            waterbirds_indices[self.group_size[0]:],
            landbirds_indices[:self.group_size[2]],
            landbirds_indices[self.group_size[2]:],
        ])

        water_indices = self.resample_indices(np.arange(len(self.set_place) // 2, len(self.set_place)), target_size=self.group_size[0]+self.group_size[2])
        land_indices = self.resample_indices(np.arange(len(self.set_place) // 2), target_size=self.group_size[1]+self.group_size[3])
        self.place_indices = np.concatenate([
            water_indices[:self.group_size[0]],
            land_indices[:self.group_size[1]],
            water_indices[self.group_size[0]:],
            land_indices[self.group_size[1]:],
        ])

        self.images = []
        for idx in tqdm(range(self.N)):
            bird, mask, _ = self.set_bird[self.bird_indices[idx]]
            place, _ = self.set_place[self.place_indices[idx]]
            # `bird` is already masked (image * mask), so it is pasted over the
            # background wherever the mask is on.
            self.images.append(place*(1-mask)+bird)

    def get_majority_grps(self, y, a):
        """Return a boolean selector of samples that belong to a majority group.

        Raises:
            NotImplementedError: for an unknown ``correlation_type``. Previously
                this silently returned ``None``.
        """
        if self.correlation_type == 0:
            return ((y == 1) & (a == 1)) | ((y == 0) & (a == 0))
        elif self.correlation_type == 1:
            return ((y == 1) & (a == 1)) | ((y == 0) & (a == 1))
        elif self.correlation_type == 2:
            return ((y == 1) & (a == 1)) | ((y == 1) & (a == 0))
        raise NotImplementedError

    def __len__(self):
        return len(self.set_bird)

    def __getitem__(self, idx):
        return self.images[idx], self.targets[idx], self.spurious[idx]


def get_waterbirds_zeta(zeta, args):
    """Build the (biased train, balanced test) pair of synthetic Waterbirds sets.

    The training set uses the requested ``zeta``. The test set always uses
    ``zeta=0.5``. The training set is constructed first.
    """
    return (
        WaterBirds_biased(zeta=zeta, args=args, train=True),
        WaterBirds_biased(zeta=0.5, args=args, train=False),
    )

