import torch
import torchvision 
import os
from PIL import Image
import numpy as np
import random
from datasets.logo_dataset import LogoDataset

class UTKFace(LogoDataset):
    """UTKFace split whose labels are parsed from the image file names.

    Images are looked up in ``<root>/images``. Only names that contain
    ``.jpg`` and consist of exactly four ``_``-separated fields are kept; the
    first three fields are read as age, gender code and race code.

    The instance keeps parallel, index-aligned lists: ``filenames``,
    ``age_labels``, ``gender_labels`` and ``race_labels`` (plus ``prompts`` and
    ``answers`` once ``set_prompts`` has been called).
    """

    def __init__(self, args=None, split="val", transform=None, paste_attack_file = None, past_attack_file_locations = None, transparency = 0.5, factor_shrink=10, crop_imgs = False, root=""):
        """Index the image directory and keep only the requested split.

        Args:
            args: Forwarded unchanged to ``LogoDataset.__init__``.
            split: One of ``"train"``, ``"val"`` or ``"test"``; any other value
                raises ``KeyError`` in ``build_split``.
            transform: Stored on ``self.transform``; not applied in this class.
            paste_attack_file: Forwarded unchanged to ``LogoDataset.__init__``.
            past_attack_file_locations: Forwarded unchanged to
                ``LogoDataset.__init__``.
            transparency: Forwarded unchanged to ``LogoDataset.__init__``.
            factor_shrink: Forwarded unchanged to ``LogoDataset.__init__``.
            crop_imgs: Forwarded unchanged to ``LogoDataset.__init__``.
            root: Directory that contains the ``images`` folder. The default
                ``""`` resolves ``images`` relative to the working directory.
        """
        LogoDataset.__init__(self,args=args, past_attack_file_locations = past_attack_file_locations, paste_attack_file = paste_attack_file, transparency = transparency, factor_shrink=factor_shrink, crop_imgs = crop_imgs)

        self.root = root
        self.split = split

        # Integer codes used in the file names -> human-readable labels.
        self.num_to_race = {
            0: "White",
            1: "Black",
            2: "Asian",
            3: "Indian",
            4: "Others",
        }
        self.num_to_gender = {
            0: "Male",
            1: "Female",
        }

        self.build_data()
        self.transform = transform

        self.build_split()
        self.prompts = None

        # Sorted so the order is reproducible: iterating a set of strings
        # depends on per-process hash randomisation.
        self.unique_race = sorted(set(self.race_labels))
        self.unique_gender = sorted(set(self.gender_labels))

    def sample_random(self, p):
        """Keep a random fraction ``p`` of the samples, drawn with seed 0.

        All index-aligned lists are subset with the same indices so labels
        (and prompts/answers, if already set) stay matched to ``filenames``.
        Note that this reseeds NumPy's global random state.

        Args:
            p: Fraction of the current samples to keep; ``int(p * len)`` items
                are drawn without replacement.
        """
        np.random.seed(0)
        total = len(self.filenames)
        idx = np.random.choice(total, int(p * total), replace=False)

        def take(values):
            return [values[i] for i in idx]

        self.filenames = take(self.filenames)
        for attr in ("race_labels", "gender_labels", "age_labels"):
            setattr(self, attr, take(getattr(self, attr)))

        # Prompts/answers exist only after set_prompts(); subset them when
        # they were aligned with the file list before sampling.
        for attr in ("prompts", "answers"):
            values = getattr(self, attr, None)
            if values is not None and len(values) == total:
                setattr(self, attr, take(values))

        # This class defines no build_labels(); call it only when the base
        # class supplies one, instead of failing with AttributeError.
        rebuild_labels = getattr(self, "build_labels", None)
        if callable(rebuild_labels):
            rebuild_labels()

    def build_split(self):
        """Restrict every per-sample list to the slice that belongs to ``self.split``.

        The splits are contiguous ranges of the file list: the first 80% is
        ``train``, the next 10% ``val`` and the last 10% ``test``.

        Raises:
            KeyError: If ``self.split`` is not a known split name.
        """
        self.split_to_ratio = {
            "train": [0.0, 0.8],
            "val": [0.8, 0.9],
            "test": [0.9, 1.0],
        }

        lower, upper = self.split_to_ratio[self.split]
        total = len(self.filenames)
        start_index = int(total * lower)
        end_index = int(total * upper)
        window = slice(start_index, end_index)

        self.filenames = self.filenames[window]
        self.race_labels = self.race_labels[window]
        self.gender_labels = self.gender_labels[window]
        # Ages are sliced too so they stay aligned with the other lists.
        self.age_labels = self.age_labels[window]

    def build_data(self):
        """List the image files and parse age, gender and race from each name.

        Raises:
            ValueError: If an age, gender or race field is not an integer.
            KeyError: If a gender or race code has no entry in the lookup
                tables.
        """
        image_dir = os.path.join(self.root, "images")
        # NOTE(review): os.listdir() returns entries in arbitrary order, so
        # split membership follows the directory order on this machine. The
        # order is left untouched because a plain sort would group the
        # contiguous splits by the leading age field.
        self.filenames = [
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            # Keep only "<age>_<gender>_<race>_<rest>" style names.
            if ".jpg" in name and len(name.split("_")) == 4
        ]

        self.race_labels = []
        self.gender_labels = []
        self.age_labels = []

        for path in self.filenames:
            # basename() instead of split("/") so the parsing also works
            # where os.path.join() uses a different separator.
            age, gender, race = os.path.basename(path).split("_")[:3]

            self.age_labels.append(int(age))
            self.gender_labels.append(self.num_to_gender[int(gender)])
            self.race_labels.append(self.num_to_race[int(race)])

    def set_prompts(self, mode, concept, opposite):
        """Create one prompt and one expected answer per sample.

        Args:
            mode: ``"yesno"`` asks whether the person is ``concept`` and stores
                ``"yes"`` as every answer. ``"onetwo"`` asks to choose between
                two numbered options and stores the option number under which
                ``concept`` is listed. Any other value leaves ``prompts`` and
                ``answers`` untouched.
            concept: Attribute name inserted into the prompt.
            opposite: Alternative attribute; only used in ``"onetwo"`` mode.
        """
        if mode == "yesno":
            self.prompts = ["Is this person {}? Answer Yes or No.".format(concept)] * len(self.filenames)
            self.answers = ["yes"] * len(self.filenames)

        elif mode == "onetwo":
            self.prompts = []
            self.answers = []
            # Fixed seed: the position of `concept` is random per sample but
            # reproducible across runs. This reseeds the global `random` state.
            random.seed(0)
            for _ in self.filenames:
                random_correct_choice = random.choice([1, 2])
                self.answers.append(str(random_correct_choice))
                # The prompt text (including the missing space after "2)") is
                # kept exactly as before so model inputs do not change.
                if random_correct_choice == 1:
                    self.prompts.append("Is this person 1) {} or 2){}? Answer with 1 or 2 only.".format(concept, opposite))
                else:
                    self.prompts.append("Is this person 1) {} or 2){}? Answer with 1 or 2 only.".format(opposite, concept))
            
