import os
import numpy as np
import json
import pandas as pd
from torchvision import transforms
from datasets.basic_dataset_scaffold import BaseDataset

class LFWDataset(BaseDataset):
    """Image dataset that additionally carries a per-image attribute table.

    On top of what ``BaseDataset`` sets up, this class builds the image
    transform pipeline (``self.normal_transform``) and stores ``metadata``,
    a DataFrame with an ``imagepath`` column plus attribute columns, so that
    attributes can be fetched by row position with :meth:`get_attribute`.
    """

    def __init__(self, image_dict, opt, metadata, is_validation=False):
        """Set up transforms and align ``metadata`` with the image list.

        Args:
            image_dict: Forwarded unchanged to ``BaseDataset.__init__``.
            opt: Options object; ``opt.arch`` is read here to pick the
                normalisation constants. Also forwarded to the base class.
            metadata: DataFrame with an ``imagepath`` column. It is NOT
                modified; the dataset works on its own copy.
            is_validation: Forwarded to the base class. When
                ``self.is_validation`` is false a random horizontal flip is
                prepended to the transform pipeline.
        """
        super(LFWDataset, self).__init__(image_dict, opt, is_validation=is_validation)

        image_size = 256
        if 'bninception' not in opt.arch:
            # ImageNet alternative: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        else:
            # Suggested ImageNet bninception alternative: mean=[0.502, 0.4588, 0.4078], std=[1., 1., 1.]
            # NOTE(review): std 0.0039 is presumably 1/255 - confirm.
            normalize = transforms.Normalize(mean=[0.502, 0.4588, 0.4078], std=[0.0039, 0.0039, 0.0039])
        self.f_norm = normalize

        # self.is_validation and self.crop_size are not assigned in this class;
        # presumably they are set by BaseDataset.__init__ - verify there.
        transf_list = []
        if not self.is_validation:
            transf_list.append(transforms.RandomHorizontalFlip())
        transf_list.extend([
            transforms.Resize(image_size),
            transforms.CenterCrop(self.crop_size),
            transforms.ToTensor(),
            normalize,
        ])
        self.normal_transform = transforms.Compose(transf_list)

        ##### RE-INDEXING METADATA TO MATCH IMAGE LIST
        if hasattr(self, 'image_list'):
            # Work on a copy: the same metadata frame may be handed to several
            # datasets, and overwriting a column on the caller's object is a
            # hidden side effect (and a chained-assignment warning on slices).
            metadata = metadata.copy()
            # Ordering the paths as categories makes sort_values() follow the
            # order of self.image_list (x[0] is assumed to be the image path
            # of each entry - TODO confirm against BaseDataset).
            metadata['imagepath'] = pd.Categorical(metadata['imagepath'], [x[0] for x in self.image_list])
            metadata = metadata.sort_values('imagepath')
        self.metadata = metadata

    def get_attribute(self, idx, attributes=None):
        """Return the values of ``attributes`` for the metadata row(s) at ``idx``.

        Args:
            idx: Positional row index (anything ``DataFrame.iloc`` accepts).
            attributes: List of metadata column names. Defaults to
                ``['Pale Skin']``.

        Returns:
            The ``.values`` array of the selected cells.
        """
        if attributes is None:
            attributes = ['Pale Skin']
        items = self.metadata[attributes].iloc[idx].values
        return items

def preprocess_lfw(opt, image_sourcepath, annotationspath, image_classes, no_nans=True, constant_attributes=frozenset({'Male', 'Asian', 'White', 'Black', 'Indian', 'Pale Skin', 'Brown Eyes'})):
    """Build the per-image attribute table from ``lfw_attributes.txt``.

    Every attribute score is binarised (``> 0`` becomes 1, otherwise 0) and
    an ``imagepath`` column is added. Images found on disk without an
    annotation row are handled per person:

    * person has other annotated images: a row is added whose
      ``constant_attributes`` take the per-person mode and whose remaining
      attributes are NaN;
    * person has no annotation at all: with ``no_nans=False`` rows with all
      attributes NaN are added, otherwise the class name goes into ``nans``.

    With ``no_nans=True`` an integer ``Race`` column is derived from the
    White/Black/Asian/Indian flags, rows without exactly one of these flags
    are dropped, and the index-to-label mapping is written to
    ``../datasets/lfw_meta_conversion.json`` (relative to the current
    working directory).

    Args:
        opt: Unused here; kept so the signature stays unchanged.
        image_sourcepath: Directory holding one sub-directory per class.
        annotationspath: Directory containing ``lfw_attributes.txt``.
        image_classes: Sub-directory names (underscore-separated names).
        no_nans: See above.
        constant_attributes: Attribute names that are copied to unannotated
            images of an annotated person.

    Returns:
        tuple: ``(df, nans)`` - the attribute DataFrame and the list of
        class names (underscore-separated) that should be excluded.
    """
    id_columns = ['person', 'imagenum', 'imagepath']
    race_columns = ['White', 'Black', 'Asian', 'Indian']

    ### Read in attributes file and shift the column names left by one
    ### (presumably the header line starts with an extra leading token - verify against the file)
    df = pd.read_csv(os.path.join(annotationspath, 'lfw_attributes.txt'), skiprows=1, sep='\t')
    readjust_cols = list(df.columns[1:])
    df = df.iloc[:, :-1].copy()
    df.columns = readjust_cols

    ### Pre-processing floats to 0/1 values for each attribute
    df.iloc[:, 2:] = (df.iloc[:, 2:].values > 0).astype(int)

    ### Create imagepath column: <source>/<First_Last>/<First_Last>_<4-digit imagenum>.jpg
    imagefiles = ['_'.join([*person.split(), "{0:0=4d}.jpg".format(imagenum)]) for person, imagenum in zip(df['person'], df['imagenum'])]
    df['imagepath'] = [os.path.join(image_sourcepath, '_'.join(person.split(' ')), imagefile) for imagefile, person in zip(imagefiles, df['person'].values)]

    attribute_columns = [col for col in df.columns if col not in id_columns]
    known_people = set(df['person'].values)

    nans = []
    # Rows are collected and concatenated once: DataFrame.append no longer
    # exists in pandas >= 2.0 and appending row by row is quadratic anyway.
    # This is safe because each class maps to its own person name, so rows
    # added for one class never influence the handling of another.
    new_rows = []

    image_dirs = {key: [image_sourcepath+'/'+key+'/'+x for x in os.listdir(image_sourcepath+'/'+key)] for key in image_classes}
    ### Find missing images and add them
    for key, image_paths in image_dirs.items():

        ### Get name from the key
        name = ' '.join(key.split('_'))

        if name not in known_people:
            if no_nans:
                nans.append(key)
            else:
                for imagefile in image_paths:
                    newrow = {'person': name, 'imagenum': int(imagefile[:-4].split('_')[-1]), 'imagepath': imagefile}
                    newrow.update({col: np.nan for col in attribute_columns})
                    new_rows.append(newrow)
            continue

        ### Select rows with person images
        items = df.loc[df['person'] == name]
        missed_images = set(image_paths) - set(items['imagepath'].values)
        if not missed_images:
            continue

        ### Majority vote over the person's rows for the constant / immutable
        ### attributes; transient attributes are set to np.nan. Computed once
        ### per person since it does not depend on the individual image.
        mode_attributes = items.drop(labels=id_columns, axis=1).mode(axis=0, numeric_only=True).iloc[0].astype(float)
        transient = [col for col in mode_attributes.index if col not in constant_attributes]
        if transient:
            mode_attributes.loc[transient] = np.nan
        shared_values = mode_attributes.to_dict()

        # sorted() keeps the order of the added rows reproducible
        for imagefile in sorted(missed_images):
            newrow = {'person': name, 'imagenum': int(imagefile[:-4].split('_')[-1]), 'imagepath': imagefile}
            newrow.update(shared_values)
            new_rows.append(newrow)

    if new_rows:
        df = pd.concat([df, pd.DataFrame(new_rows, columns=df.columns)], ignore_index=True)

    if no_nans:
        race_flags = df[race_columns].to_numpy()
        # Concatenate the names of all set race flags; only rows with exactly
        # one flag survive the filter below, so kept labels are single names.
        df['Race'] = [''.join(race for race, flag in zip(race_columns, row) if flag == 1) for row in race_flags]
        counts = race_flags.sum(axis=1)

        before = set(df['person'].values)
        df = df[counts == 1].copy()
        after = set(df['person'].values)

        ### People that lost all of their rows are reported for exclusion
        nans.extend("_".join(name.split()) for name in sorted(before - after))

        meta_conversion = dict(enumerate(sorted(df['Race'].unique())))
        label_to_index = {label: index for index, label in meta_conversion.items()}
        df['Race'] = df['Race'].map(label_to_index).astype(int)
        with open(os.path.join('../datasets', 'lfw_meta_conversion.json'), 'w') as fp:
            json.dump(meta_conversion, fp)

    return df, nans

def _metadata_for_images(metadata, image_dict):
    """Return a copy of the ``metadata`` rows whose ``imagepath`` occurs in ``image_dict``."""
    wanted = {str(path) for paths in image_dict.values() for path in paths}
    return metadata[metadata['imagepath'].isin(wanted)].copy()


def Give(opt, datapath):
    """Create the LFW train / validation / test / evaluation datasets.

    Expects ``datapath`` to contain ``lfw_funneled/`` (one directory per
    class), ``split/peopleDevTrain.txt`` and ``split/peopleDevTest.txt``,
    and ``anno/`` (read by ``preprocess_lfw``).

    Args:
        opt: Options object. ``opt.use_tv_split``, ``opt.tv_split_by_samples``
            and ``opt.tv_split_perc`` are read here; ``opt`` is also passed
            on to ``preprocess_lfw`` and ``LFWDataset``.
        datapath: Root directory of the dataset.

    Returns:
        dict: Datasets under the keys ``'training'``, ``'validation'``
        (``None`` unless ``opt.use_tv_split``), ``'testing'``,
        ``'evaluation'`` and ``'evaluation_train'``.
    """
    image_sourcepath = os.path.join(datapath, 'lfw_funneled')
    splitpath = os.path.join(datapath, 'split')
    annotationspath = os.path.join(datapath, 'anno')

    trainsplit = list(pd.read_csv(os.path.join(splitpath, 'peopleDevTrain.txt'), sep='\t', header=None, skiprows=1).iloc[:,0].values)
    testsplit = list(pd.read_csv(os.path.join(splitpath, 'peopleDevTest.txt'), sep='\t', header=None, skiprows=1).iloc[:,0].values)

    assert sorted(os.listdir(image_sourcepath)) == sorted(trainsplit + testsplit), \
        'Class directories in {} do not match the train/test split files'.format(image_sourcepath)

    image_classes = trainsplit + testsplit

    # The original code branched on opt.parade_features here, but both
    # branches ran preprocess_lfw with no_nans=True (explicitly or by
    # default), so a single call is equivalent.
    # NOTE(review): confirm the non-'sensitive' case was not meant to differ.
    metadata, nans = preprocess_lfw(opt, image_sourcepath, annotationspath, image_classes, no_nans=True)

    ### Drop the classes that preprocessing excluded
    excluded = set(nans)
    trainsplit = [person for person in trainsplit if person not in excluded]
    testsplit = [person for person in testsplit if person not in excluded]
    image_classes = trainsplit + testsplit

    total_conversion = dict(enumerate(image_classes))

    ### Dictionary of structure class:list_of_samples_with_said_class,
    ### restricted to images that have a metadata row
    known_paths = set(metadata['imagepath'].values)
    image_dict = {}
    for class_idx, person in enumerate(image_classes):
        person_dir = image_sourcepath + '/' + person
        for img_path in sorted(person_dir + '/' + x for x in os.listdir(person_dir)):
            if img_path in known_paths:
                image_dict.setdefault(class_idx, []).append(img_path)

    ### Switch from class to index; train/test membership comes from the split files
    train_names, test_names = set(trainsplit), set(testsplit)
    train = [idx for idx, person in total_conversion.items() if person in train_names]
    test = [idx for idx, person in total_conversion.items() if person in test_names]

    ### If required, split the training data into a train/val setup either by or per class.
    if opt.use_tv_split:
        if not opt.tv_split_by_samples:
            train_val_split = int(len(train)*opt.tv_split_perc)
            train, val = train[:train_val_split], train[train_val_split:]
            ### Both dictionaries are re-keyed from 0 in this mode
            train_image_dict = {i: image_dict[key] for i, key in enumerate(train)}
            val_image_dict = {i: image_dict[key] for i, key in enumerate(val)}
        else:
            val = train
            train_image_dict, val_image_dict = {}, {}
            for key in train:
                n_images = len(image_dict[key])
                train_ixs = np.random.choice(n_images, int(n_images*opt.tv_split_perc), replace=False)
                # Complement of train_ixs in ascending order; always an integer
                # array, so it stays a valid index even when it is empty.
                val_ixs = np.setdiff1d(np.arange(n_images), train_ixs)
                class_images = np.array(image_dict[key])
                train_image_dict[key] = class_images[train_ixs]
                val_image_dict[key] = class_images[val_ixs]

        val_conversion = {i: total_conversion[key] for i, key in enumerate(val)}
        val_metadata = _metadata_for_images(metadata, val_image_dict)
        val_dataset = LFWDataset(val_image_dict, opt, val_metadata, is_validation=True)
        ###
        val_dataset.conversion = val_conversion
    else:
        train_image_dict = {key: image_dict[key] for key in train}
        val_image_dict = None
        val_dataset = None

    ###
    train_conversion = {i: total_conversion[key] for i, key in enumerate(train)}
    test_conversion = {i: total_conversion[key] for i, key in enumerate(test)}

    ###
    test_image_dict = {key: image_dict[key] for key in test}

    ###
    print('\nDataset Setup:\nUsing Train-Val Split: {0}\n#Classes: Train ({1}) | Val ({2}) | Test ({3})\n'.format(opt.use_tv_split, len(train_image_dict), len(val_image_dict) if val_image_dict else 'X', len(test_image_dict)))

    ###
    train_metadata = _metadata_for_images(metadata, train_image_dict)
    test_metadata = _metadata_for_images(metadata, test_image_dict)

    ###
    train_dataset = LFWDataset(train_image_dict, opt, train_metadata)
    test_dataset = LFWDataset(test_image_dict, opt, test_metadata, is_validation=True)
    eval_dataset = LFWDataset(train_image_dict, opt, train_metadata, is_validation=True)
    eval_train_dataset = LFWDataset(train_image_dict, opt, train_metadata, is_validation=False)
    train_dataset.conversion = train_conversion
    test_dataset.conversion = test_conversion
    eval_dataset.conversion = train_conversion
    eval_train_dataset.conversion = train_conversion

    return {'training': train_dataset, 'validation': val_dataset, 'testing': test_dataset, 'evaluation': eval_dataset, 'evaluation_train': eval_train_dataset}
