import os
import numpy as np
import pandas as pd
from warnings import simplefilter
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)
from ... import const


# Dataset locations, all taken from the project-level `const` module.
# CUB_DATA_DIR is expected to contain an `images` sub-directory (see get_birds),
# and WATER_BIRDS_BACKGROUNDS_DIR the `water_easy` and `land_easy`
# sub-directories (see get_backgrounds).
CUB_DATA_DIR = const.CUB_DATA_DIR
SEGMENTATIONS_DATA_DIR = const.SEGMENTATIONS_DATA_DIR
WATER_BIRDS_BACKGROUNDS_DIR = const.WATER_BIRDS_BACKGROUNDS_DIR
WB_RAW_DATA_DIR = const.WB_RAW_DATA_DIR

# There are 2483 waterbirds and 9305 landbirds
# Name fragments that identify waterbird species: get_birds() treats an entry
# of the CUB `images` directory as a waterbird class when its name contains any
# of these substrings (case-sensitive); every other entry counts as a landbird.
WATER_BIRD_LIST = [
	'Albatross', 'Auklet', 'Cormorant', 'Frigatebird', 'Fulmar', 'Gull', 'Jaeger',
	'Kittiwake', 'Pelican', 'Puffin', 'Tern', 'Gadwall', 'Grebe', 'Mallard',
	'Merganser', 'Guillemot', 'Pacific_Loon'
]
# Same value as the waterbird count quoted above; not referenced by the
# functions visible in this part of the file.
NUM_BIRDS = 2483
# NOTE(review): presumably the side length (in pixels) of generated images;
# not referenced by the functions visible here -- confirm against callers.
IMG_SIZE = 128

# Fraction of each shuffled list that goes to the training split:
# split_waterbirds() multiplies every list length by this value to find the
# train/test cut index.
TRAINING_SIZE = const.TRAINING_SIZE



def get_birds():
    '''
    Gets the bird image paths from the CUB_200 dataset.

    Every directory inside ``<CUB_DATA_DIR>/images`` is treated as one species
    folder. A folder whose name contains any entry of ``WATER_BIRD_LIST``
    contributes its files to the waterbird list; every other folder contributes
    to the landbird list.

    Directory listings are sorted before shuffling. ``os.listdir`` returns
    entries in arbitrary, filesystem-dependent order, so without sorting the
    same NumPy seed could produce different shuffles on different machines.

    Returns:
        A tuple ``(water_birds_list, land_birds_list)`` of two lists of file
        paths, each shuffled with ``np.random.shuffle`` (global NumPy RNG).

    Raises:
        FileNotFoundError: if ``<CUB_DATA_DIR>/images`` does not exist.
    '''
    images_dir = os.path.join(CUB_DATA_DIR, 'images')
    water_birds_list = []
    land_birds_list = []

    for species_dir in sorted(os.listdir(images_dir)):
        species_path = os.path.join(images_dir, species_dir)

        # Stray files next to the species folders (e.g. OS metadata files)
        # are not image directories and cannot be listed.
        if not os.path.isdir(species_path):
            continue

        is_water_bird = any(water_bird in species_dir for water_bird in WATER_BIRD_LIST)
        target_list = water_birds_list if is_water_bird else land_birds_list
        target_list.extend(
            os.path.join(species_path, bird_img_file)
            for bird_img_file in sorted(os.listdir(species_path))
        )

    # Shuffle water first, then land, so RNG consumption order is unchanged.
    np.random.shuffle(water_birds_list)
    np.random.shuffle(land_birds_list)

    return water_birds_list, land_birds_list



def get_backgrounds():
    '''
    Gets the background image paths from the background dataset.

    Reads the ``water_easy`` and ``land_easy`` sub-directories of
    ``WATER_BIRDS_BACKGROUNDS_DIR`` and returns the full path of every entry.

    Directory listings are sorted before shuffling. ``os.listdir`` returns
    entries in arbitrary, filesystem-dependent order, so without sorting the
    same NumPy seed could produce different shuffles on different machines.

    Returns:
        A tuple ``(water_backgrounds_list, land_backgrounds_list)`` of two
        NumPy string arrays of file paths, each shuffled with
        ``np.random.shuffle`` (global NumPy RNG). An empty directory yields an
        empty string array.

    Raises:
        FileNotFoundError: if either sub-directory does not exist.
    '''

    def list_background_paths(subdir):
        # Full paths of all entries of one background sub-directory, as an array.
        background_dir = os.path.join(WATER_BIRDS_BACKGROUNDS_DIR, subdir)
        paths = [
            os.path.join(background_dir, file_name)
            for file_name in sorted(os.listdir(background_dir))
        ]
        # Explicit dtype keeps an empty listing a string array rather than
        # NumPy's default float64 empty array.
        return np.array(paths, dtype=str)

    water_backgrounds_list = list_background_paths('water_easy')
    land_backgrounds_list = list_background_paths('land_easy')

    # Shuffle water first, then land, so RNG consumption order is unchanged.
    np.random.shuffle(water_backgrounds_list)
    np.random.shuffle(land_backgrounds_list)

    return water_backgrounds_list, land_backgrounds_list
    


def split_waterbirds(random_seed):
    '''
    Splits bird and background image paths into training and test subsets.

    Seeds the global NumPy RNG with ``random_seed``, collects the shuffled
    path collections from ``get_birds()`` and ``get_backgrounds()``, and cuts
    each one at ``int(len(collection) * TRAINING_SIZE)``: everything before
    the cut is training data, the rest is test data.

    Args:
        random_seed: seed passed to ``np.random.seed`` before any shuffling.

    Returns:
        An 8-tuple, in this order:
        ``(train_water_birds, train_land_birds, train_water_backgrounds,
        train_land_backgrounds, test_water_birds, test_land_birds,
        test_water_backgrounds, test_land_backgrounds)``.

    Raises:
        FileNotFoundError: if any of the required data directories is missing.
    '''

    # Set random seed
    np.random.seed(random_seed)

    # Make sure directories exist. The backgrounds directory is checked too,
    # because get_backgrounds() reads from it.
    required_dirs = (
        CUB_DATA_DIR,
        SEGMENTATIONS_DATA_DIR,
        WB_RAW_DATA_DIR,
        WATER_BIRDS_BACKGROUNDS_DIR,
    )
    for required_dir in required_dirs:
        if not os.path.exists(required_dir):
            raise FileNotFoundError('Directory ' + required_dir + ' does not exist.')

    # Get bird and background images
    water_birds, land_birds = get_birds()
    water_backgrounds, land_backgrounds = get_backgrounds()

    # Need to make sure there are enough unique backgrounds for each landbird
    max_land_birds = 8672
    land_birds = land_birds[:max_land_birds]

    def split_train_test(items):
        # Cut a shuffled collection into (train, test) at the TRAINING_SIZE fraction.
        cut = int(len(items) * TRAINING_SIZE)
        return items[:cut], items[cut:]

    train_water_birds, test_water_birds = split_train_test(water_birds)
    train_land_birds, test_land_birds = split_train_test(land_birds)
    train_water_backgrounds, test_water_backgrounds = split_train_test(water_backgrounds)
    train_land_backgrounds, test_land_backgrounds = split_train_test(land_backgrounds)

    return (train_water_birds,
            train_land_birds,
            train_water_backgrounds,
            train_land_backgrounds,
            test_water_birds,
            test_land_birds,
            test_water_backgrounds,
            test_land_backgrounds)