import os
import math
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CelebA_biased(Dataset):
    """CelebA subset resampled into four (hair colour, gender) groups with a controllable bias.

    The target ``y`` is the black-hair indicator and the attribute ``a`` is the
    male indicator. Samples are partitioned into four groups, always in this
    order (written ``ya``):

    * 0 -> ``11``: black hair, male
    * 1 -> ``10``: black hair, not male
    * 2 -> ``01``: not black hair, male
    * 3 -> ``00``: not black hair, not male

    The training split draws ``floor(zeta * N / 2)`` samples for every majority
    group and the remainder of ``N // 2`` for every minority group. The test
    split draws ``N // 4`` samples per group. Which groups are the majority
    depends on ``args.correlation_type`` (see ``_MAJORITY_PATTERNS``).
    """

    # Whether each group (11, 10, 01, 00) is a majority (True) or a minority
    # (False) group, for every supported ``correlation_type``.
    _MAJORITY_PATTERNS = {
        # Spurious correlation: p(y=0) = p(y=1) and y, a tend to agree
        # (alpha, beta > 0.5). Majority 11, 00; minority 10, 01.
        0: (True, False, False, True),
        # Under-representation: p(a=0) << p(a=1), p(y=0) = p(y=1).
        # Majority 11, 01; minority 10, 00.
        1: (True, False, True, False),
        # Class imbalance: p(y=0) << p(y=1).
        # Majority 11, 10; minority 01, 00.
        2: (True, True, False, False),
    }

    def __init__(self, zeta, args, seed, train, N):
        """Load the attribute files and the split's index list, then resample groups.

        Args:
            zeta: Majority fraction for the training split. It is not used
                when ``train`` is False.
            args: Config namespace providing ``correlation_type``,
                ``celeba_root`` and ``image_size``.
            seed: Selects ``indices/indices_{train,test}_seed{seed}.txt``.
            train: If true, build the biased training split; otherwise build
                the group-balanced test split.
            N: Requested number of samples. Group sizes are floored, so the
                dataset holds at most ``N`` samples.
        """
        self.seed = seed
        self.correlation_type = args.correlation_type
        self.data_root = args.celeba_root
        self.image_size = args.image_size
        self.train = train
        self.zeta = zeta
        print(f"N={N}")
        self.N = N
        self.transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            # The commonly used ImageNet per-channel mean/std.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        self.black_hair = np.loadtxt(os.path.join(self.data_root, "black_hair.txt")).astype(np.float32)
        self.male = np.loadtxt(os.path.join(self.data_root, "male.txt"))
        with open(os.path.join(self.data_root, "hairColor_gender.txt"), "r") as file:
            # Strip only the newline. Slicing off the last character would
            # truncate a final line that has no trailing newline.
            self.names = np.array([line.rstrip("\n") for line in file])

        split = "train" if self.train else "test"
        self.indices = np.loadtxt(
            os.path.join(self.data_root, f"indices/indices_{split}_seed{self.seed}.txt")
        ).astype("int")

        # One indicator row per group, ordered 11, 10, 01, 00 (y = black hair,
        # a = male). This assumes both attribute files hold 0/1 values — TODO confirm.
        self.group_indices = np.array([
            self.black_hair * self.male,
            self.black_hair * (1 - self.male),
            (1 - self.black_hair) * self.male,
            (1 - self.black_hair) * (1 - self.male),
        ])

        # Restrict everything to this split's indices.
        self.names_subset = self.names[self.indices]
        self.group_indices_subset = self.group_indices[:, self.indices]
        self.black_hair_subset = self.black_hair[self.indices]
        self.male_subset = self.male[self.indices]
        self.original_group_sizes = self.group_indices_subset.sum(1)

        self.generate_grps()

        # Materialise the resampled dataset. Positions refer to the split subset.
        self.names_balance = self.names_subset[self.indices_balance]
        self.group_indices_balance = self.group_indices_subset[:, self.indices_balance]
        self.black_hair_balance = torch.tensor(self.black_hair_subset[self.indices_balance]).to(torch.int64)
        self.targets = self.black_hair_balance
        self.male_balance = self.male_subset[self.indices_balance]

    def resample_indices(self, indices, target_size):
        """Return exactly ``target_size`` entries drawn from ``indices``.

        If the group is large enough, it is subsampled without replacement.
        Otherwise every index is kept once, and the shortfall is filled with
        draws made with replacement.

        NOTE(review): this uses the global ``np.random`` state. ``self.seed``
        only picks the index file, so reproducibility relies on the caller
        seeding NumPy.

        Raises:
            ValueError: if ``indices`` is empty but samples are requested.
        """
        if len(indices) >= target_size:
            return np.random.choice(indices, target_size, replace=False)
        if len(indices) == 0:
            raise ValueError(f"cannot draw {target_size} samples from an empty group")
        extra = np.random.choice(indices, target_size - len(indices), replace=True)
        return np.concatenate([indices, extra])

    def generate_grps(self):
        """Choose per-group sample counts and draw ``self.indices_balance``.

        This sets the following attributes:

        * ``major_size`` and ``minor_size``.
        * ``group_size``: the counts for groups 11, 10, 01, 00.
        * ``majority_grps``: a 1/0 tensor aligned with ``indices_balance``.
        * ``indices_balance``: positions into the split subset, grouped in the
          order 11, 10, 01, 00.

        Raises:
            NotImplementedError: for an unsupported ``correlation_type``.
            ValueError: if ``zeta``/``N`` produce a negative group size, or a
                non-empty target must be drawn from an empty group.
        """
        if not self.train:
            # Balanced evaluation split: every group gets the same count.
            self.major_size = self.minor_size = int(self.N // 4)
        else:
            self.major_size = int(self.zeta * self.N // 2)
            # int() keeps the sizes integral even when N is passed as a float.
            self.minor_size = int(self.N // 2) - self.major_size

        pattern = self._MAJORITY_PATTERNS.get(self.correlation_type)
        if pattern is None:
            raise NotImplementedError(f"unsupported correlation_type: {self.correlation_type}")
        if self.major_size < 0 or self.minor_size < 0:
            raise ValueError(
                f"zeta={self.zeta} with N={self.N} gives negative group sizes "
                f"(major={self.major_size}, minor={self.minor_size})"
            )

        self.group_size = [self.major_size if is_major else self.minor_size for is_major in pattern]
        self.majority_grps = torch.tensor(
            [int(is_major) for is_major, size in zip(pattern, self.group_size) for _ in range(size)]
        )
        # Draw the groups in order 11, 10, 01, 00. The RNG call order matters
        # for reproducibility.
        self.indices_balance = np.concatenate([
            self.resample_indices(np.where(members)[0], target_size=size)
            for members, size in zip(self.group_indices_subset, self.group_size)
        ])

    def get_majority_grps(self, y, a):
        """Return an elementwise boolean mask that marks majority-group samples.

        Args:
            y: Target labels (black hair), compared elementwise against 0/1.
            a: Attribute values (male), compared elementwise against 0/1.

        Raises:
            NotImplementedError: for an unsupported ``correlation_type``.
        """
        if self.correlation_type == 0:
            # Spurious correlation: majority groups 11 and 00.
            return ((y == 1) & (a == 1)) | ((y == 0) & (a == 0))
        if self.correlation_type == 1:
            # Under-representation: majority groups 11 and 01.
            return ((y == 1) & (a == 1)) | ((y == 0) & (a == 1))
        if self.correlation_type == 2:
            # Class imbalance: majority groups 11 and 10.
            return ((y == 1) & (a == 1)) | ((y == 1) & (a == 0))
        raise NotImplementedError(f"unsupported correlation_type: {self.correlation_type}")

    def __len__(self):
        """Number of resampled samples."""
        return len(self.indices_balance)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label, attr)`` for the ``idx``-th resampled sample.

        ``label`` is an int64 tensor scalar (black hair). ``attr`` is the value
        read from ``male.txt``.
        """
        path = os.path.join(self.data_root, "Img/img_align_celeba", self.names_balance[idx])
        # The context manager releases the file handle promptly. convert("RGB")
        # guarantees three channels, matching the 3-channel Normalize.
        with Image.open(path) as raw:
            image = raw.convert("RGB")
        return self.transform(image), self.black_hair_balance[idx], self.male_balance[idx]

def get_celeba_zeta(zeta, args, seed, num_training, num_testing):
    """Build the biased training split and the group-balanced test split.

    Both splits share ``zeta``, ``args`` and ``seed``. The training set is
    built first, so any global-RNG draws happen in the same order on every call.

    Returns:
        ``(trainset_biased, testset_unbiased)``.
    """
    shared = dict(zeta=zeta, args=args, seed=seed)
    trainset_biased = CelebA_biased(train=True, N=num_training, **shared)
    testset_unbiased = CelebA_biased(train=False, N=num_testing, **shared)
    return trainset_biased, testset_unbiased
