from datasets.basic_dataset_scaffold import BaseDataset
import os
import warnings
import pandas as pd
from pandas.api.types import is_string_dtype
from pandas.api.types import is_numeric_dtype
import yaml
import numpy as np
import json
import sys

class CUB200Dataset(BaseDataset):
    """CUB200 split that can attach a per-image attribute ("color") table.

    When ``colors`` is passed, ``self.metadata`` is a DataFrame with one row per
    entry of ``self.image_list`` (in the same order) holding the image path and
    the attribute of that image's class; ``get_attribute`` reads from it.
    """

    def __init__(self, image_dict, opt, is_validation=False, colors=None):
        """Build the dataset and, optionally, its attribute metadata.

        Args:
            image_dict: class -> list of image paths; forwarded to ``BaseDataset``.
            opt: options namespace; forwarded to ``BaseDataset``.
            is_validation: forwarded to ``BaseDataset``.
            colors: optional DataFrame with ``'class'`` and ``'attribute'``
                columns. An image's class is the parent-folder name in its
                '/'-separated path; the first matching row supplies its attribute.

        Raises:
            KeyError: if an image's class folder has no row in ``colors``.
        """
        super(CUB200Dataset, self).__init__(image_dict, opt, is_validation=is_validation)

        if colors is None:
            return

        # NOTE(review): assumes BaseDataset.__init__ sets self.image_list to a
        # sequence whose items start with the image path -- confirm in BaseDataset.
        image_paths = [x[0] for x in self.image_list]

        # class name -> attribute, built once instead of filtering `colors` per
        # image; the first row per class wins (same as taking `.values[0]`).
        class_to_attribute = {}
        for cls, attr in zip(colors['class'].values, colors['attribute'].values):
            class_to_attribute.setdefault(cls, attr)

        image_colors = []
        for imagepath in image_paths:
            class_name = imagepath.split('/')[-2]
            try:
                image_colors.append(class_to_attribute[class_name])
            except KeyError:
                raise KeyError('No attribute row for class {!r} (image {})'.format(class_name, imagepath)) from None

        metadata = pd.DataFrame(columns=['imagepath', 'color'])
        metadata['imagepath'] = image_paths
        metadata['color'] = image_colors

        ##### RE-INDEX METADATA TO MATCH IMAGE LIST
        # Rows are already in image_list order; this makes 'imagepath' a
        # Categorical whose category order is image_list, and sorts by it.
        metadata['imagepath'] = pd.Categorical(metadata['imagepath'], image_paths)
        self.metadata = metadata.sort_values('imagepath')

    def get_attribute(self, idx, attributes=None):
        """Return metadata values for the row at position ``idx``.

        Args:
            idx: positional row index (or anything ``DataFrame.iloc`` accepts).
            attributes: metadata column names to return; defaults to ``['color']``.

        Returns:
            ``self.metadata[attributes].iloc[idx].values``, or ``None`` (with a
            warning) when the dataset was built without ``colors``.
        """
        if attributes is None:
            attributes = ['color']
        if not hasattr(self, 'metadata'):
            warnings.warn('No metadata object on CUB200 dataset; returning None.')
            return None
        return self.metadata[attributes].iloc[idx].values

def _majority_percent(num_classes, num_minority, imbalance_percent):
    """Return the training share for majority classes.

    Solves ``num_minority * imbalance_percent + (num_classes - num_minority) * p
    == num_classes`` for ``p``, so that the per-class shares sum to
    ``num_classes``.

    Raises:
        ValueError: if no majority class is left (the formula would divide by zero).
    """
    num_majority = num_classes - num_minority
    if num_majority <= 0:
        raise ValueError('Imbalance config selects {} minority classes out of {}; '
                         'at least one majority class is required.'.format(num_minority, num_classes))
    return (num_classes - imbalance_percent * num_minority) / num_majority


def _configure_imbalance(opt, datapath, total_conversion, df, meta_conversion):
    """Load the imbalance YAML config into ``opt.config`` and select minority classes.

    Adds ``min_classes`` (labels whose training share is scaled by
    ``imbalance_percent``) and ``majority_percent`` to ``opt.config``.

    Returns:
        ``(df, meta_conversion)``: unchanged for attribute type ``'class'``; for
        type ``'color'``, the table read from the configured aux file and, if its
        attribute column was non-numeric, the new integer-code mapping.

    Raises:
        Errors from opening/parsing the config or aux file (printed, then re-raised);
        ValueError for an unsupported attribute type or when no majority class remains.
    """
    ### MUST HAVE CONFIG FILE IF IMBALANCE SELECTED
    assert opt.config_file is not None

    try:
        with open(opt.config_file) as file:
            # FullLoader: the config is expected to be a trusted local file.
            opt.config = yaml.load(file, Loader=yaml.FullLoader)
    except FileNotFoundError:
        print('Imbalance config file not found: {}'.format(opt.config_file))
        raise
    except Exception:
        print('Problem opening imbalance config file {}: {}'.format(opt.config_file, sys.exc_info()[0]))
        raise

    ### MUST HAVE ATTRIBUTE (and SEED for class-based imbalance)
    assert 'attribute' in opt.config
    attribute = opt.config['attribute']

    if attribute['type'] == 'class':
        aux = attribute['aux']
        assert 'seed' in opt.config
        np.random.seed(opt.config['seed'])
        opt.config['min_classes'] = np.random.choice(list(total_conversion.keys()), size=aux['number'], replace=False)

    elif attribute['type'] == 'color':
        aux = attribute['aux']

        ### GET FULL FILEPATH FOR ATTRIBUTE INFORMATION
        aux['filename'] = os.path.join(datapath, 'attributes', 'attributes', aux['filename'])

        ### RETRIEVE ATTRIBUTE DATA -- must not silently fall back to the plurality table
        try:
            df = pd.read_csv(aux['filename'])
        except Exception:
            print('Problem opening attribute file {} in config file {}: {} '.format(aux['filename'], opt.config_file, sys.exc_info()[0]))
            raise

        ### MIN CLASSES ARE CLASSES WITH GIVEN ATTRIBUTE (CONVERT TO INT VALUES BY TOTAL_CONVERSION)
        classes_with_color = set(df['class'].values[df['attribute'] == aux['color']])
        opt.config['min_classes'] = [i for i, x in total_conversion.items() if x in classes_with_color]

        # Encode a textual attribute column as sorted integer codes.
        if is_string_dtype(df['attribute']) or not is_numeric_dtype(df['attribute']):
            unique_attributes = np.unique(df['attribute'].values)  # already sorted
            meta_conversion = {i: attr for i, attr in enumerate(unique_attributes)}
            df.replace(to_replace={'attribute': {val: key for key, val in meta_conversion.items()}}, inplace=True)

        if aux['confound']:
            ### REMOVE MIN CLASSES THAT ARE THE CONFOUNDING CLASSES
            confound_classes = [i for i, x in total_conversion.items() if x in aux['class']]
            opt.config['min_classes'] = [c for c in opt.config['min_classes'] if c not in confound_classes]

        aux['number'] = len(opt.config['min_classes'])

    else:
        raise ValueError('Attribute in config file {} not supported'.format(opt.config_file))

    ### RESET MAJORITY PERCENT OF CLASSES
    opt.config['majority_percent'] = _majority_percent(len(total_conversion.keys()), aux['number'], opt.config['imbalance_percent'])
    return df, meta_conversion


def Give(opt, datapath):
    """Build the CUB200 train/val/test/evaluation datasets.

    Every class appears in both train and test; the split is made *within* each
    class: the leading part of a class's sorted images (half by default) goes to
    training and the rest to testing. With ``opt.imbalance``, a YAML config
    (``opt.config_file``) selects minority classes whose training share is
    scaled by ``imbalance_percent``; the other classes use a recomputed
    ``majority_percent``.

    Side effects: writes ``<datapath>/cub200_meta_conversion.json``; with
    ``opt.imbalance``, stores the parsed config on ``opt.config``.

    Args:
        opt: options namespace; reads ``imbalance``, ``config_file``,
            ``use_tv_split`` and ``tv_split_perc`` (and whatever the dataset
            class reads).
        datapath: dataset root containing ``images/<class>/<file>`` and
            ``attributes/attributes/*.csv``.

    Returns:
        dict with keys 'training', 'validation' (None unless
        ``opt.use_tv_split``), 'testing', 'evaluation', 'evaluation_train'.
    """
    image_sourcepath = datapath + '/images'
    image_classes = sorted(os.listdir(image_sourcepath))
    # Integer label -> class folder name.
    total_conversion = {i: x for i, x in enumerate(image_classes)}

    ### Dictionary of structure class:list_of_samples_with_said_class
    # Paths use '/' because CUB200Dataset takes the class from the
    # second-to-last '/'-separated component.
    image_dict = {}
    for label, class_name in enumerate(image_classes):
        class_dir = image_sourcepath + '/' + class_name
        samples = sorted(class_dir + '/' + x for x in os.listdir(class_dir))
        if samples:  # classes without images get no entry
            image_dict[label] = samples

    ### All classes are used for both training and testing; samples are split
    ### within each class further below.
    keys = sorted(image_dict.keys())
    train, test = keys, keys

    df = pd.read_csv(os.path.join(datapath, 'attributes', 'attributes', 'plurality_attributes.csv'))
    unique_attributes = np.unique(df['attribute'].values)  # np.unique returns sorted values
    meta_conversion = {i: attr for i, attr in enumerate(unique_attributes)}

    #### OUTPUT META CONVERSION (integer code -> attribute)
    with open(os.path.join(datapath, "cub200_meta_conversion.json"), 'w') as fp:
        json.dump(meta_conversion, fp)

    # Encode the attribute column with the integer codes written above.
    df.replace(to_replace={'attribute': {val: key for key, val in meta_conversion.items()}}, inplace=True)

    if opt.imbalance:
        df, meta_conversion = _configure_imbalance(opt, datapath, total_conversion, df, meta_conversion)

    def train_cut(key, values):
        """Number of leading samples of class ``key`` assigned to training."""
        if not opt.imbalance:
            return len(values) // 2
        if key in opt.config['min_classes']:
            pct = opt.config['imbalance_percent']
        else:
            pct = opt.config['majority_percent']
        return int(len(values) // 2 * pct)

    train_base_image_dict = {key: values[:train_cut(key, values)] for key, values in image_dict.items()}
    test_base_image_dict = {key: values[train_cut(key, values):] for key, values in image_dict.items()}

    ### If required, split the training data into a train/val setup per class.
    if opt.use_tv_split:
        val = train
        train_image_dict, val_image_dict = {}, {}
        for key in train:
            samples = np.array(train_base_image_dict[key])
            train_ixs = np.random.choice(len(samples), int(len(samples) * opt.tv_split_perc), replace=False)
            # Ascending complement of train_ixs; stays an integer array even when empty.
            val_ixs = np.setdiff1d(np.arange(len(samples)), train_ixs)
            train_image_dict[key] = samples[train_ixs]
            val_image_dict[key] = samples[val_ixs]
        val_dataset = CUB200Dataset(val_image_dict, opt, is_validation=True, colors=df)
        val_dataset.conversion = {i: total_conversion[key] for i, key in enumerate(val)}
        val_dataset.meta_conversion = meta_conversion
    else:
        train_image_dict = {key: train_base_image_dict[key] for key in train}
        val_image_dict = None
        val_dataset = None

    ### Classes are shared by all splits, so labels map to the full class names.
    train_conversion = {i: total_conversion[key] for i, key in enumerate(train)}
    test_conversion = {i: total_conversion[key] for i, key in enumerate(test)}

    test_image_dict = {key: test_base_image_dict[key] for key in test}

    print('\nDataset Setup:\nUsing Train-Val Split: {0}\n#Classes: Train ({1}) | Val ({2}) | Test ({3})\n'.format(opt.use_tv_split, len(train_image_dict), len(val_image_dict) if val_image_dict else 'X', len(test_image_dict)))

    train_dataset = CUB200Dataset(train_image_dict, opt, colors=df)
    test_dataset = CUB200Dataset(test_image_dict, opt, is_validation=True, colors=df)
    eval_dataset = CUB200Dataset(train_image_dict, opt, is_validation=True, colors=df)
    eval_train_dataset = CUB200Dataset(train_image_dict, opt, is_validation=False, colors=df)

    train_dataset.conversion = train_conversion
    test_dataset.conversion = test_conversion
    eval_dataset.conversion = train_conversion
    eval_train_dataset.conversion = train_conversion

    for dataset in (train_dataset, test_dataset, eval_dataset, eval_train_dataset):
        dataset.meta_conversion = meta_conversion

    return {'training': train_dataset, 'validation': val_dataset, 'testing': test_dataset, 'evaluation': eval_dataset, 'evaluation_train': eval_train_dataset}
