import os
import numpy as np
import pandas as pd
from .. import const
import torch
import shutil
from .split_data.split_food_review import split_food_review

# Directory locations for the food-review data, re-exported from the project
# constants module so the generator below can refer to them directly:
#   FR_RAW_DATA_DIR        - raw input data (must exist before generation)
#   FR_DATASET_DIR         - output directory for the single-attribute dataset
#   FR_DOUBLE_DATASET_DIR  - output directory for the two-attribute ("double") dataset
FR_RAW_DATA_DIR, FR_DATASET_DIR, FR_DOUBLE_DATASET_DIR = (
    const.FR_RAW_DATA_DIR,
    const.FR_DATASET_DIR,
    const.FR_DOUBLE_DATASET_DIR,
)



# Substitutions that inject the primary spurious attribute: common words get an
# "xxxxx" suffix. Applied in this order with str.replace.
_SPURIOUS_SUBSTITUTIONS = (
    (' the ', ' thexxxxx '),
    ('The ', 'Thexxxxx '),
    (' be ', ' bexxxxx '),
    ('Be ', 'Bexxxxx '),
)

# Substitutions that inject the secondary spurious attribute (only used for the
# "double" dataset): common words get a "yyyyy" suffix. Applied in this order.
_MORE_SPURIOUS_SUBSTITUTIONS = (
    (' a ', ' ayyyyy '),
    ('A ', 'Ayyyyy '),
    (' to ', ' toyyyyy '),
    ('To ', 'Toyyyyy '),
)


def _apply_substitutions(text, substitutions):
    '''Apply each (old, new) pair of `substitutions` to `text` with str.replace, in order.'''
    for old, new in substitutions:
        text = text.replace(old, new)
    return text


def _label_correlated_flag(dist, is_good_score):
    '''
    Draw one value from the global numpy RNG and turn it into a boolean flag.

    Good reviews get True when the draw is below `dist`; bad reviews get True when
    the draw is at or above `dist`. The flag is therefore positively correlated with
    the label for dist > 0.5 and independent of it at dist == 0.5. The draw is
    randint(100) / 100, so probabilities have 1% granularity.
    '''
    draw = np.random.randint(100) / 100
    return draw < dist if is_good_score else draw >= dist


def _inject_spurious_attributes(df, spurious_dist, more_spurious_dist=None):
    '''
    Inject spurious text into every review of `df` (in place) and add label columns.

    :param df: DataFrame with 'Score' and 'Text' columns; modified in place.
    :param spurious_dist: Correlation parameter for the primary ("xxxxx") attribute,
        see _label_correlated_flag.
    :param more_spurious_dist: Correlation parameter for the secondary ("yyyyy")
        attribute, or None to never add it (no RNG draw is made for it then).

    Adds/overwrites the columns:
      group              = 2 * aux_label + (Score >= 4), i.e. a value in 0..3
      aux_label          = 1 if the primary attribute was injected, else 0
      more_spurious_text = 1 if the secondary attribute was injected, else 0
    '''
    texts, groups, aux_labels, more_flags = [], [], [], []

    for score, text in zip(df['Score'], df['Text']):
        is_good_score = score >= 4

        # Draw order per review (primary, then secondary) is fixed so that a given
        # seed keeps producing the same dataset.
        add_spurious_text = _label_correlated_flag(spurious_dist, is_good_score)
        if more_spurious_dist is None:
            add_more_spurious_text = False
        else:
            add_more_spurious_text = _label_correlated_flag(more_spurious_dist, is_good_score)

        if add_spurious_text:
            text = _apply_substitutions(text, _SPURIOUS_SUBSTITUTIONS)
        if add_more_spurious_text:
            text = _apply_substitutions(text, _MORE_SPURIOUS_SUBSTITUTIONS)

        texts.append(text)
        groups.append(int(add_spurious_text) * 2 + int(is_good_score))
        aux_labels.append(int(add_spurious_text))
        more_flags.append(int(add_more_spurious_text))

    # Whole-column assignment is positional: it is much faster than per-row .loc
    # writes, and it is correct even if df's index contains duplicate labels.
    df['Text'] = texts
    # NOTE(review): earlier versions created these columns by per-row .loc
    # enlargement, which yields float64 columns; kept as float so the written CSVs
    # keep the same format for existing consumers.
    df['group'] = np.asarray(groups, dtype=float)
    df['aux_label'] = np.asarray(aux_labels, dtype=float)
    df['more_spurious_text'] = np.asarray(more_flags, dtype=float)


def generate_food_review(train_dist, test_dists, seed=None, is_double=False):
    '''
    Generate the food review dataset with injected spurious attributes.

    A review is "good" when its Score is >= 4. The primary spurious attribute appends
    "xxxxx" to the words the/The/be/Be. In the double dataset, a secondary attribute
    appends "yyyyy" to a/A/to/To.

    :param train_dist: Probability (0-1, 1% granularity) that a good training review
        receives the spurious attribute; bad training reviews receive it with
        probability 1 - train_dist. With is_double, the secondary attribute uses the
        same probabilities.
    :param test_dists: Iterable of probabilities; one test CSV is written per value.
        Without is_double, each value plays the role train_dist plays for training.
        With is_double, the primary attribute is drawn at 0.5 (independent of the
        label) and the value controls the secondary attribute.
    :param seed: Seed for numpy's global RNG and for the train/test split; also used
        in the output file names.
    :param is_double: Generate the two-attribute dataset into FR_DOUBLE_DATASET_DIR
        instead of the single-attribute dataset in FR_DATASET_DIR.

    Writes <dataset_dir>/training_{seed}.csv and
    <dataset_dir>/testing_{seed}/testing_{test_dist}.csv. Any existing
    testing_{seed} directory is deleted first.

    :raises FileNotFoundError: If the dataset directory or FR_RAW_DATA_DIR is missing.
    '''
    # Seed numpy's global RNG; all spurious-attribute draws below use it.
    np.random.seed(seed)

    dataset_dir = FR_DOUBLE_DATASET_DIR if is_double else FR_DATASET_DIR

    for required_dir in (dataset_dir, FR_RAW_DATA_DIR):
        if not os.path.exists(required_dir):
            raise FileNotFoundError(f'Directory {required_dir} does not exist.')

    train_csv_df, test_csv_df = split_food_review(seed)

    # Training split: every injected attribute is correlated with the label via train_dist.
    _inject_spurious_attributes(train_csv_df, train_dist, train_dist if is_double else None)
    train_csv_df.to_csv(os.path.join(dataset_dir, f'training_{seed}.csv'))

    # Start every run from an empty testing directory.
    testing_dir = os.path.join(dataset_dir, f'testing_{seed}')
    if os.path.exists(testing_dir):
        shutil.rmtree(testing_dir)
    os.mkdir(testing_dir)

    for test_dist in test_dists:
        test_csv_df_dist = test_csv_df.copy()
        if is_double:
            # Primary attribute at 0.5 is independent of the label; only the
            # secondary attribute's correlation varies across test sets.
            _inject_spurious_attributes(test_csv_df_dist, 0.5, test_dist)
        else:
            _inject_spurious_attributes(test_csv_df_dist, test_dist)
        test_csv_df_dist.to_csv(os.path.join(testing_dir, f'testing_{test_dist}.csv'))