import os
import numpy as np
import cv2
import shutil
import pandas as pd
from .. import const
import warnings
import multiprocessing
import torch
from .split_data.split_waterbirds import split_waterbirds

# NOTE: warnings.simplefilter edits the process-wide filter list, so these
# categories are silenced for every module in this process, not just this one.
for _ignored_category in (FutureWarning, pd.errors.PerformanceWarning):
    warnings.simplefilter(action='ignore', category=_ignored_category)
del _ignored_category


# Directory layout is defined centrally in the project's const module.
SEGMENTATIONS_DATA_DIR = const.SEGMENTATIONS_DATA_DIR  # bird masks, read by create_sample as <class dir>/<stem>.png
WB_DATASET_DIR = const.WB_DATASET_DIR  # output root used when is_double is False
WB_DOUBLE_DATASET_DIR = const.WB_DOUBLE_DATASET_DIR  # output root used when is_double is True
WB_RAW_DATA_DIR = const.WB_RAW_DATA_DIR  # generate_waterbirds requires this directory to exist

IMG_SIZE = 256  # side length, in pixels, of every generated (square) image



def create_images(df):
    '''
    Renders every example described by df and writes its four image files.

    :param df: Dataframe as produced by create_data_set. Each row must provide the inputs
        'bird_file', 'background_file', 'camera_artifacts' and the output paths
        'full_img', 'bird_seg', 'background_seg', 'full_seg'
    :raises OSError: If OpenCV fails to write any of the output images
    '''
    for _, row in df.iterrows():
        # create_sample returns (sample, masked_bird, masked_background, resized_mask)
        images = create_sample(row['bird_file'], row['background_file'], row['camera_artifacts'])
        targets = (row['full_img'], row['bird_seg'], row['background_seg'], row['full_seg'])
        for path, image in zip(targets, images):
            # cv2.imwrite signals most failures (bad directory, unsupported extension)
            # by returning False instead of raising; don't let a row silently vanish.
            if not cv2.imwrite(path, image):
                raise OSError(f'Failed to write image {path}')


def _imread_checked(path, *flags):
    '''
    Reads an image with cv2.imread, raising instead of returning None on failure.

    :param path: Image file path
    :param flags: Optional imread flags, forwarded unchanged (e.g. 0 for grayscale)
    :return: The decoded image array
    :raises FileNotFoundError: If OpenCV cannot read or decode the file
    '''
    image = cv2.imread(path, *flags)
    if image is None:
        raise FileNotFoundError(f'Could not read image {path} (missing, unreadable, or unsupported format).')
    return image


def create_sample(bird_file, background_file, camera_artifacts):
    '''
    Composites a bird cut-out onto a background image.

    The bird's segmentation mask is read from
    SEGMENTATIONS_DATA_DIR/<name of the bird file's parent directory>/<bird file stem>.png.
    :param bird_file: File path of bird image
    :param background_file: File path of background image (resized to the bird image's size)
    :param camera_artifacts: If truthy, black squares are drawn onto the masked background
    :return: Tuple (sample, masked_bird, masked_background, resized_mask), each resized to
        IMG_SIZE x IMG_SIZE. sample is the bitwise OR of masked_bird and masked_background;
        resized_mask is the thresholded bird mask after a cv2.resize with its default
        (linear) interpolation, so edge pixels may lie strictly between 0 and 255.
    :raises FileNotFoundError: If the bird image, background image, or mask cannot be read
    '''
    bird = _imread_checked(bird_file)
    background = _imread_checked(background_file)
    # Bring the background to the bird's resolution so one mask applies to both.
    background = cv2.resize(background, (bird.shape[1], bird.shape[0]))

    # os.path handles both separators on Windows; splitext only swaps the real extension.
    class_dir = os.path.basename(os.path.dirname(bird_file))
    mask_name = os.path.splitext(os.path.basename(bird_file))[0] + '.png'
    bird_seg_file = os.path.join(SEGMENTATIONS_DATA_DIR, class_dir, mask_name)

    bird_mask = _imread_checked(bird_seg_file, 0)  # 0 == grayscale
    _, bird_mask = cv2.threshold(bird_mask, 100, 255, cv2.THRESH_BINARY)  # Remove transparent parts of the mask
    background_mask = cv2.bitwise_not(bird_mask)

    resized_mask = cv2.resize(bird_mask, (IMG_SIZE, IMG_SIZE))

    masked_bird = cv2.bitwise_and(bird, bird, mask=bird_mask)
    masked_bird = cv2.resize(masked_bird, (IMG_SIZE, IMG_SIZE))
    masked_background = cv2.bitwise_and(background, background, mask=background_mask)
    masked_background = cv2.resize(masked_background, (IMG_SIZE, IMG_SIZE))

    if camera_artifacts:
        # Artifacts are black (0), so where they overlap the bird the OR below keeps the bird pixels.
        masked_background = add_camera_artifacts(masked_background)
    sample = cv2.bitwise_or(masked_bird, masked_background)

    return sample, masked_bird, masked_background, resized_mask


def add_camera_artifacts(sample, num_artifacts=8, artifact_size=10):
    '''
    Draws solid black squares at random positions to mimic camera artifacts.

    Positions come from NumPy's global RNG and are drawn within an IMG_SIZE x IMG_SIZE frame.
    NOTE: cv2.rectangle treats both corners as inclusive, so each filled square spans
    artifact_size + 1 pixels per side.
    :param sample: Image to draw on (OpenCV draws into it) and return
    :param num_artifacts: Number of squares to draw
    :param artifact_size: Offset between each square's top-left and bottom-right corners
    :return: The image with the squares drawn on it
    '''
    black = (0, 0, 0)
    for _artifact in range(num_artifacts):
        # A single size=2 draw yields both coordinates of the top-left corner.
        left, top = np.random.randint(0, IMG_SIZE - artifact_size, size=2)
        bottom_right = (left + artifact_size, top + artifact_size)
        sample = cv2.rectangle(sample, (left, top), bottom_right, black, -1)  # -1 => filled

    return sample


# Final column order of the dataset index (also the column order of the saved CSVs).
_DATASET_COLUMNS = ['full_img', 'label', 'aux_label', 'bird_seg', 'background_seg', 'group',
                    'camera_artifacts', 'full_seg', 'bird_file', 'background_file']


def _group_records(birds, bird_start, backgrounds, background_start, count,
                   out_dir, setting, label, aux_label, group, draw_artifacts):
    '''
    Builds the dataframe records for one (bird type, background type) group.

    Record k (0-based) pairs birds[bird_start + k] with backgrounds[background_start + k];
    its output files are numbered k + 1 inside out_dir/setting.
    :param draw_artifacts: Zero-argument callable returning the 'camera_artifacts' flag;
        called once per record, in record order
    :return: List of dicts keyed by dataset column name
    '''
    records = []
    for k in range(count):
        i = k + 1
        records.append({
            'full_img': os.path.join(out_dir, setting, f'sample{i}.png'),
            'label': label,
            'aux_label': aux_label,
            'bird_seg': os.path.join(out_dir, setting, f'bird_seg{i}.png'),
            'background_seg': os.path.join(out_dir, setting, f'back_seg{i}.png'),
            'full_seg': os.path.join(out_dir, setting, f'full_seg{i}.png'),
            'group': group,
            'bird_file': birds[bird_start + k],
            'background_file': backgrounds[background_start + k],
            'camera_artifacts': draw_artifacts(),
        })
    return records


def create_data_set(water_birds, land_birds, water_backgrounds, land_backgrounds, dir, main_dist, double_dist, is_double):
    '''
    Builds the shuffled index of a waterbirds dataset (no images are written here).

    The first int(len(water_birds) * main_dist) water birds are placed over water backgrounds
    and the rest over land backgrounds; likewise the first int(len(land_birds) * main_dist)
    land birds go over land and the rest over water. Backgrounds are consumed in order and
    never reused.
    Groups: 0 = land bird/land, 1 = water bird/land, 2 = land bird/water, 3 = water bird/water.
    :param water_birds: List of water bird image paths
    :param land_birds: List of land bird image paths
    :param water_backgrounds: List of water background image paths
    :param land_backgrounds: List of land background image paths
    :param dir: Output directory; image paths are built under dir/<group setting>/
        (name kept for backward compatibility even though it shadows the builtin)
    :param main_dist: Fraction in [0, 1] of each bird type placed on its matching background
    :param double_dist: Camera-artifact correlation, used only when is_double. Each row draws
        r from {0, 0.01, ..., 0.99}; water birds get artifacts when r >= double_dist,
        land birds when r < double_dist
    :param is_double: Whether to draw the camera-artifact attribute (otherwise always False)
    :return: Dataframe with one row per bird, shuffled with a fixed random_state=42
        (independent of NumPy's global seed)
    :raises ValueError: If main_dist is outside [0, 1] or there are too few backgrounds
    '''
    if not 0 <= main_dist <= 1:
        raise ValueError(f'main_dist must be in [0, 1], got {main_dist}')

    n_water, n_land = len(water_birds), len(land_birds)
    n_water_on_water = int(n_water * main_dist)
    n_water_on_land = n_water - n_water_on_water
    n_land_on_land = int(n_land * main_dist)
    n_land_on_water = n_land - n_land_on_land

    if len(water_backgrounds) < n_water_on_water + n_land_on_water:
        raise ValueError(f'Need {n_water_on_water + n_land_on_water} water backgrounds, '
                         f'got {len(water_backgrounds)}')
    if len(land_backgrounds) < n_water_on_land + n_land_on_land:
        raise ValueError(f'Need {n_water_on_land + n_land_on_land} land backgrounds, '
                         f'got {len(land_backgrounds)}')

    def water_artifacts():
        if not is_double:
            return False
        return np.random.randint(100) / 100 >= double_dist

    def land_artifacts():
        if not is_double:
            return False
        return np.random.randint(100) / 100 < double_dist

    # (birds, bird start, backgrounds, background start, count, setting, label, aux_label, group, artifacts)
    # Order matters: it fixes the pre-shuffle row order and the sequence of RNG draws.
    groups = (
        (water_birds, 0, water_backgrounds, 0, n_water_on_water,
         'water_bird_over_water', 1, 1, 3, water_artifacts),
        (water_birds, n_water_on_water, land_backgrounds, 0, n_water_on_land,
         'water_bird_over_land', 1, 0, 1, water_artifacts),
        (land_birds, 0, land_backgrounds, n_water_on_land, n_land_on_land,
         'land_bird_over_land', 0, 0, 0, land_artifacts),
        (land_birds, n_land_on_land, water_backgrounds, n_water_on_water, n_land_on_water,
         'land_bird_over_water', 0, 1, 2, land_artifacts),
    )

    records = []
    for birds, bird_start, backgrounds, back_start, count, setting, label, aux_label, group, artifacts in groups:
        records.extend(_group_records(birds, bird_start, backgrounds, back_start, count,
                                      dir, setting, label, aux_label, group, artifacts))

    # All columns are declared up front (object dtype, as before) instead of being
    # created one at a time by df.at enlargement.
    df = pd.DataFrame(records, columns=_DATASET_COLUMNS, dtype=object)

    # Shuffle dataset
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)

    return df



def _create_images_in_parallel(df):
    '''
    Renders every row of df with create_images, spread over a process pool.

    One core is left free for the parent, but at least one worker is always used:
    cpu_count() - 1 is 0 on a single-core machine, which both np.array_split and
    multiprocessing.Pool reject.
    :param df: Dataframe as produced by create_data_set
    '''
    num_workers = max(1, multiprocessing.cpu_count() - 1)
    chunks = np.array_split(df, num_workers)
    # Leaving the with-block terminates the pool, so workers are cleaned up even if
    # map raises; on success close/join first so they exit normally.
    with multiprocessing.Pool(num_workers) as pool:
        pool.map(create_images, chunks)
        pool.close()
        pool.join()


def generate_waterbirds(train_dist=None, test_dists=None, seed=None, is_double=False):
    '''
    Generates the waterbirds dataset on disk.

    Writes <dataset_dir>/training_<seed>/ (one sub-directory per group plus training.csv) and
    <dataset_dir>/testing_<seed>/<test_dist>/ (same layout plus testing.csv) for each test
    distribution. Existing training_<seed> / testing_<seed> directories are deleted first.
    NOTE(review): camera-artifact positions are drawn with np.random inside the worker
    processes (create_images -> create_sample -> add_camera_artifacts), so whether they are
    reproducible from `seed` depends on the multiprocessing start method and on cpu_count().
    :param train_dist: Fraction of training examples whose background matches the bird type
        (also used as the camera-artifact correlation when is_double)
    :param test_dists: Iterable of such fractions; one testing set is generated per value,
        each with a camera-artifact correlation of 0.5
    :param seed: Seed for NumPy's global RNG; also passed to split_waterbirds and used in
        the output directory names
    :param is_double: If True, write to WB_DOUBLE_DATASET_DIR and draw the camera-artifact attribute
    :raises ValueError: If train_dist or test_dists is not provided
    :raises FileNotFoundError: If the dataset directory or WB_RAW_DATA_DIR does not exist
    '''
    # Validate before anything is deleted: otherwise a missing argument would only fail
    # after the existing training_<seed> directory had been removed below.
    if train_dist is None or test_dists is None:
        raise ValueError('train_dist and test_dists must both be provided.')
    # Materialize so a generator survives the two passes over it below.
    test_dists = list(test_dists)

    # Set numpy seed (create_data_set draws from the global RNG in this process)
    np.random.seed(seed)

    dataset_dir = WB_DOUBLE_DATASET_DIR if is_double else WB_DATASET_DIR

    for required_dir in (dataset_dir, WB_RAW_DATA_DIR):
        if not os.path.exists(required_dir):
            raise FileNotFoundError('Directory ' + required_dir + ' does not exist.')

    # Split order as consumed here: train water birds, train land birds, train water
    # backgrounds, train land backgrounds, then the same four for testing.
    splits = split_waterbirds(seed)
    train_water_birds, train_land_birds = splits[0], splits[1]
    train_water_backgrounds, train_land_backgrounds = splits[2], splits[3]
    test_water_birds, test_land_birds = splits[4], splits[5]
    test_water_backgrounds, test_land_backgrounds = splits[6], splits[7]

    settings = ('water_bird_over_water', 'water_bird_over_land',
                'land_bird_over_land', 'land_bird_over_water')

    # Generate a training dataset
    training_dir = os.path.join(dataset_dir, f'training_{seed}')
    if os.path.exists(training_dir):
        shutil.rmtree(training_dir)
    os.mkdir(training_dir)
    for setting in settings:
        os.mkdir(os.path.join(training_dir, setting))

    df_training = create_data_set(train_water_birds, train_land_birds, train_water_backgrounds,
                                  train_land_backgrounds, training_dir, train_dist, train_dist, is_double)
    _create_images_in_parallel(df_training)
    df_training.to_csv(os.path.join(training_dir, 'training.csv'))

    # Generate a testing dataset per test distribution
    testing_dir = os.path.join(dataset_dir, f'testing_{seed}')
    if os.path.exists(testing_dir):
        shutil.rmtree(testing_dir)
    os.mkdir(testing_dir)
    for test_dist in test_dists:
        dist_dir = os.path.join(testing_dir, str(test_dist))
        os.mkdir(dist_dir)
        for setting in settings:
            os.mkdir(os.path.join(dist_dir, setting))

    for test_dist in test_dists:
        dist_dir = os.path.join(testing_dir, str(test_dist))
        df_testing = create_data_set(test_water_birds, test_land_birds, test_water_backgrounds,
                                     test_land_backgrounds, dist_dir, test_dist, 0.5, is_double)
        _create_images_in_parallel(df_testing)
        df_testing.to_csv(os.path.join(dist_dir, 'testing.csv'))