import os
import numpy as np
import pandas as pd
import cv2
import shutil
from .. import const
import multiprocessing
import warnings
import torch
from .split_data.split_koa import split_koa

# Suppress FutureWarning and pandas' PerformanceWarning for the whole process
# (presumably triggered by the row-wise DataFrame updates in this module --
# TODO confirm)
warnings.simplefilter('ignore', FutureWarning)
warnings.simplefilter('ignore', pd.errors.PerformanceWarning)

# Directory settings re-exported from the project-wide constants module
KOA_RAW_DATA_DIR = const.KOA_RAW_DATA_DIR
KOA_DATASET_DIR = const.KOA_DATASET_DIR
KOA_DOUBLE_DATASET_DIR = const.KOA_DOUBLE_DATASET_DIR



def add_perturbation(df):
    '''
    Writes one output image per row of the dataframe, optionally drawing
    shortcut squares on it.

    For every row, the image at ``old_location`` is read in grayscale mode. A
    filled black square is drawn when ``aux_label`` is truthy and a filled
    white square is drawn when ``spurious_white_bar`` is truthy. The result is
    converted to three channels and written to ``LOCATION``.

    An example is reported as an error (and nothing is written for it) when the
    source file is missing or unreadable, when the source is not 256x256, or
    when drawing or writing fails for any reason.

    :param df: Dataframe containing all information to create examples; must
        provide the columns ``old_location``, ``LOCATION``, ``aux_label`` and
        ``spurious_white_bar``
    :return: A list of indices of examples that produced errors
    '''

    expected_shape = (256, 256)
    # (top-left corner, bottom-right corner, color) of each filled square
    black_square = ((86, 0), (122, 36), (0, 0, 0))
    white_square = ((136, 0), (172, 36), (255, 255, 255))

    error_index_list = []
    for index, row in df.iterrows():
        saved = False
        try:
            old_location = row['old_location']
            new_location = row['LOCATION']

            # Check the path first so a missing file is rejected explicitly
            # instead of through a failed attribute access on None
            if os.path.exists(old_location):
                # Reading an image in grayscale mode
                image = cv2.imread(old_location, 0)

                # imread returns None for unreadable files; wrongly sized
                # images are rejected before anything is written to disk
                if image is not None and image.shape[:2] == expected_shape:
                    if row['aux_label']:
                        start_point, end_point, color = black_square
                        image = cv2.rectangle(image, start_point, end_point, color, -1)

                    if row['spurious_white_bar']:
                        start_point, end_point, color = white_square
                        image = cv2.rectangle(image, start_point, end_point, color, -1)

                    # Make image have three dimensions so it can be used by models
                    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

                    # imwrite reports failure through its return value
                    saved = bool(cv2.imwrite(new_location, image.astype(np.uint8)))
        except Exception:
            # Any failure marks the example as bad; KeyboardInterrupt and
            # SystemExit are deliberately not swallowed
            saved = False

        if not saved:
            error_index_list.append(index)

    return error_index_list



def generate_koa(train_dist, test_dists, seed=None, is_double=False):
    '''
    Generates the koa dataset

    A training directory ``training_{seed}`` and a testing directory
    ``testing_{seed}`` (with one sub-directory per entry of ``test_dists``) are
    recreated from scratch inside the dataset directory. Images are rendered by
    ``add_perturbation`` and a csv describing the kept examples is written next
    to them. Examples with ``KLG >= 2`` are treated as osteoarthritis ("oa"),
    all others as healthy. Examples whose image could not be produced are
    removed from the csv; for the testing sets, rows with ``KLG == 2`` are
    removed from the csv as well.

    :param train_dist: Threshold in [0, 1] controlling how often a training
        example receives the shortcut square(s): oa examples get one when a
        draw from {0.00, ..., 0.99} is below it, healthy examples when the
        draw is at or above it
    :param test_dists: Iterable of thresholds, one testing dataset is generated
        per entry. With ``is_double`` the entry controls the white square and
        the black square uses a fixed threshold of 0.5
    :param seed: Random seed that determines how the dataset is generated
    :param is_double: If true, dataset will contain two shortcuts
    :raises FileNotFoundError: If the dataset directory or the raw data
        directory does not exist
    '''

    # Set numpy seed; every shortcut draw below uses numpy's global generator
    np.random.seed(seed)

    # Materialize once so one-shot iterables survive the two passes below
    test_dists = list(test_dists)

    # Get dataset directory
    dataset_dir = KOA_DOUBLE_DATASET_DIR if is_double else KOA_DATASET_DIR

    train_df, test_df = split_koa(seed)

    # Make sure directories exist
    if not os.path.exists(dataset_dir):
        raise FileNotFoundError('Directory ' + dataset_dir + ' does not exist.')
    if not os.path.exists(KOA_RAW_DATA_DIR):
        raise FileNotFoundError('Directory ' + KOA_RAW_DATA_DIR + ' does not exist.')

    # Generate a training dataset
    training_dir = os.path.join(dataset_dir, f'training_{seed}')
    if os.path.exists(training_dir):
        shutil.rmtree(training_dir)
    _make_koa_group_dirs(training_dir)

    _assign_koa_examples(train_df, training_dir, train_dist,
                         train_dist if is_double else None)

    # Render the images and drop the examples that could not be produced
    train_df.drop(index=_render_koa_examples(train_df), inplace=True)
    train_df.to_csv(os.path.join(training_dir, 'training.csv'))

    # Generate the testing datasets
    testing_dir = os.path.join(dataset_dir, f'testing_{seed}')
    if os.path.exists(testing_dir):
        shutil.rmtree(testing_dir)
    os.mkdir(testing_dir)
    for test_dist in test_dists:
        _make_koa_group_dirs(os.path.join(testing_dir, str(test_dist)))

    # Create dataset for each test distribution
    for test_dist in test_dists:
        dist_dir = os.path.join(testing_dir, str(test_dist))
        test_df_dist = test_df.copy()

        if is_double:
            # Black square is uninformative (0.5); the white one follows test_dist
            _assign_koa_examples(test_df_dist, dist_dir, 0.5, test_dist)
        else:
            _assign_koa_examples(test_df_dist, dist_dir, test_dist, None)

        # Render the images and drop the examples that could not be produced
        test_df_dist.drop(index=_render_koa_examples(test_df_dist), inplace=True)

        # Drop KL Grade 2 samples (their images have already been written)
        test_df_dist = test_df_dist[test_df_dist['KLG'] != 2]

        test_df_dist.to_csv(os.path.join(dist_dir, 'testing.csv'))


# Sub-directories created inside every generated dataset directory
_KOA_SETTINGS = ('oa_normal', 'healthy_normal', 'oa_spurious', 'healthy_spurious')


def _make_koa_group_dirs(root):
    '''
    Creates ``root`` together with one sub-directory per setting
    :param root: Directory to create; must not exist yet
    '''
    os.mkdir(root)
    for setting in _KOA_SETTINGS:
        os.mkdir(os.path.join(root, setting))


def _sample_koa_shortcut(is_oa, dist):
    '''
    Draws whether a single example receives a shortcut square
    :param is_oa: Whether the example is labelled as osteoarthritis
    :param dist: Threshold compared against a draw from {0.00, ..., 0.99}
    :return: ``draw < dist`` for oa examples, ``draw >= dist`` otherwise
    '''
    draw = np.random.randint(100) / 100
    return draw < dist if is_oa else draw >= dist


def _assign_koa_examples(df, target_dir, black_dist, white_dist):
    '''
    Decides the shortcuts of every example and records them in ``df`` in place
    :param df: Dataframe with the columns ``KLG`` and ``LOCATION``; gains the
        columns ``old_location``, ``group``, ``aux_label`` and
        ``spurious_white_bar`` while ``LOCATION`` is replaced by the output path
    :param target_dir: Directory holding the setting sub-directories
    :param black_dist: Threshold used for the black square
    :param white_dist: Threshold used for the white square, or None to never
        draw it (in which case no random number is consumed for it)
    '''
    for index, row in df.iterrows():
        is_oa = row['KLG'] >= 2

        # The draw order (black, then white) is part of what a seed reproduces
        spurious_black_bar = _sample_koa_shortcut(is_oa, black_dist)
        if white_dist is None:
            spurious_white_bar = False
        else:
            spurious_white_bar = _sample_koa_shortcut(is_oa, white_dist)

        old_location = row['LOCATION']
        file_name = old_location.split('/')[-1]
        setting = ('oa' if is_oa else 'healthy') + ('_spurious' if spurious_black_bar else '_normal')
        new_location = os.path.join(target_dir, setting, file_name)

        # NOTE(review): kept as per-row assignments so the dtypes of the
        # columns written to the csv stay exactly as before
        df.loc[index, ['old_location']] = old_location
        df.loc[index, ['LOCATION']] = new_location
        df.loc[index, ['group']] = int(spurious_black_bar) * 2 + int(is_oa)
        df.loc[index, ['aux_label']] = int(spurious_black_bar)
        df.loc[index, ['spurious_white_bar']] = int(spurious_white_bar)


def _render_koa_examples(df):
    '''
    Runs ``add_perturbation`` over ``df`` using a pool of worker processes
    :param df: Dataframe prepared by ``_assign_koa_examples``
    :return: Flat list of the indices that ``add_perturbation`` reported as errors
    '''
    # Always use at least one worker, also on single-core machines
    num_workers = max(1, multiprocessing.cpu_count() - 1)

    # Split row positions rather than the dataframe itself
    row_chunks = np.array_split(np.arange(len(df)), num_workers)
    df_split = [df.iloc[rows] for rows in row_chunks]

    pool = multiprocessing.Pool(num_workers)
    try:
        failures = pool.map(add_perturbation, df_split)
    finally:
        # Shut the workers down gracefully even when a task raised
        pool.close()
        pool.join()

    return [index for chunk in failures for index in chunk]