import torch
import os
from PIL import Image
import pandas as pd
import numpy as np
import random
from datasets.logo_dataset import LogoDataset


class FairFace(LogoDataset):
    """FairFace images with race/gender labels and optional text prompts.

    Images are read from ``<root>/<split>/*.jpg`` and labels from
    ``<root>/fairface_label_<split>.csv``. The CSV must provide the columns
    ``file`` (values of the form ``<split>/<image name>``), ``race`` and
    ``gender``.

    Attributes set by this class:
        filenames: paths of the images in the split (sorted by name until
            ``sample_random`` is called).
        race_labels: array of race strings aligned with ``filenames``.
        gender_labels: list of ``"F"``/``"M"`` aligned with ``filenames``.
        unique_race / unique_gender: sorted lists of the distinct labels.
        prompts / answers: per-image prompt and expected answer, populated by
            ``set_prompts`` (``prompts`` is ``None`` until then).
    """

    def __init__(self, args = None, split="val", transform=None, paste_attack_file = None, past_attack_file_locations = None, transparency = 0.5, factor_shrink=10, crop_imgs = False, root=""):
        """Index the images of ``split`` and load their labels.

        Args:
            args, paste_attack_file, past_attack_file_locations, transparency,
                factor_shrink, crop_imgs: forwarded unchanged to
                ``LogoDataset.__init__``.
            split: name of the image sub-directory and of the label CSV suffix.
            transform: stored on the instance as ``self.transform``.
            root: directory containing the split folder and the label CSV.
                Defaults to ``""`` (paths relative to the working directory),
                which is what was previously hard-coded.
        """
        super().__init__(
            args=args,
            past_attack_file_locations=past_attack_file_locations,
            paste_attack_file=paste_attack_file,
            transparency=transparency,
            factor_shrink=factor_shrink,
            crop_imgs=crop_imgs,
        )

        self.root = root
        self.split = split
        self.build_data()
        self.transform = transform

        self.build_labels()
        self.prompts = None

    def sample_random(self, p):
        """Keep a fixed pseudo-random fraction ``p`` of the images.

        ``int(p * len(self.filenames))`` images are drawn without replacement
        using a generator seeded with 0, so the same subset is selected on
        every call for a given file list. Labels are rebuilt afterwards.

        NOTE: ``prompts``/``answers`` are not resampled here; call
        ``set_prompts`` after this method so they match the new file list.

        Args:
            p: fraction of images to keep, expected in ``[0, 1]``.
        """
        # A local RandomState(0) yields the same draw as seeding the global
        # NumPy generator with 0, without clobbering global RNG state.
        rng = np.random.RandomState(0)
        total = len(self.filenames)
        idx = rng.choice(total, int(p * total), replace=False)
        self.filenames = [self.filenames[i] for i in idx]
        self.build_labels()

    def build_labels(self):
        """Load race and gender labels for the current ``self.filenames``.

        Raises:
            KeyError: if an image has no matching row in the label CSV.
        """
        csv_path = os.path.join(self.root, f"fairface_label_{self.split}.csv")
        labels = pd.read_csv(csv_path).set_index("file")

        # The CSV identifies images as "<split>/<image name>". basename() is
        # used instead of splitting on "/" so that paths built with a
        # different OS separator still resolve to the bare file name.
        keys = [f"{self.split}/{os.path.basename(path)}" for path in self.filenames]

        self.race_labels = labels.loc[keys, "race"].values
        # Anything other than "Female" is mapped to "M".
        self.gender_labels = [
            "F" if gender == "Female" else "M"
            for gender in labels.loc[keys, "gender"].values
        ]

        # Sorted so the ordering is stable across runs (plain set iteration
        # order depends on string hash randomization). key=str keeps the sort
        # from failing should a non-string value appear in the column.
        self.unique_race = sorted(set(self.race_labels), key=str)
        self.unique_gender = sorted(set(self.gender_labels))

    def build_data(self):
        """Collect the paths of all ``.jpg`` files in the split directory."""
        split_dir = os.path.join(self.root, self.split)
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # dataset order (and therefore the seeded subsample) reproducible.
        self.filenames = [
            os.path.join(split_dir, name)
            for name in sorted(os.listdir(split_dir))
            if name.endswith(".jpg")
        ]

    def set_prompts(self, mode, concept, opposite):
        """Build one prompt and one expected answer per image.

        Args:
            mode: ``"yesno"`` asks whether the person is ``concept`` and
                expects ``"yes"`` for every image. ``"onetwo"`` asks to choose
                between ``concept`` and ``opposite``; the position of
                ``concept`` (1 or 2) is drawn per image from a generator
                seeded with 0 and stored as the answer. Any other value
                leaves ``prompts``/``answers`` untouched.
            concept: text inserted as the expected choice.
            opposite: text inserted as the alternative (``"onetwo"`` only).
        """
        count = len(self.filenames)

        if mode == "yesno":
            self.prompts = ["Is this person {}? Answer Yes or No.".format(concept)] * count
            self.answers = ["yes"] * count

        elif mode == "onetwo":
            # NOTE(review): the template has no space after "2)"; kept as-is
            # because changing the prompt text changes what the model sees.
            template = "Is this person 1) {} or 2){}? Answer with 1 or 2 only."
            # random.Random(0) produces the same sequence as seeding the
            # global `random` module with 0, without altering global state.
            rng = random.Random(0)
            self.prompts = []
            self.answers = []
            for _ in range(count):
                correct_slot = rng.choice([1, 2])
                self.answers.append(str(correct_slot))
                if correct_slot == 1:
                    self.prompts.append(template.format(concept, opposite))
                else:
                    self.prompts.append(template.format(opposite, concept))
            
