import os
import random

"""
Both tasks can be derived automatically from Twitter data. We construct a binary “sentiment” task
by identifying a subset of emojis which are associated with positive and negative sentiment,
identifying tweets containing these emojis, assigning
them with the corresponding sentiment and removing the emojis. Tweets containing emojis
from both sentiment lists are discarded. The binary mention task is to determine if a tweet mentions another user, i.e., classifying conversational
vs. non-conversational tweets. We derive this
dataset by identifying tweets that include @mentions tokens, and removing all such tokens from
the tweets.
Protected: Race. The race annotation is based
on the dialectal tweets (DIAL) corpus from (Blodgett et al., 2016), consisting of 59.2 million tweets
by 2.8 million users. Each tweet is associated with
predicted “race” information which was predicted
using a technique that takes into account the geolocation of the author and the words in the tweet.
We focus on the AAE (African-American English) and SAE (Standard American English) categories, which we use as proxies for non-Hispanic
blacks and non-Hispanic whites.
We chose only annotations with confidence (the
probability of the authors’ race) of above 80%.
Due to its construction, the race annotations in this
dataset are highly correlated with the language being used. As such, the data reflects an extreme
case in which the underlying language is very predictive of the protected attribute.

"""


# positive: 1, negative: 0
# afro-american: 1, white: 0
def _append_rows(rows, texts, first_label, second_label, max_length, longest):
    """Append one formatted row per sufficiently short text and track the longest kept.

    A text is kept only when it has strictly fewer than ``max_length`` tokens,
    where tokens are obtained with ``text.split(' ')``. Each kept text is
    appended to ``rows`` (mutated in place) as
    ``'<text>\\t<first_label>\\t<second_label>\\n'``.

    Returns the larger of ``longest`` and the token count of the longest text
    kept by this call.
    """
    for text in texts:
        # Split once; the count is needed for both the filter and the maximum.
        n_tokens = len(text.split(' '))
        if n_tokens < max_length:
            rows.append('{}\t{}\t{}\n'.format(text, first_label, second_label))
            longest = max(longest, n_tokens)
    return longest


def get_labeled_data(pos_pos, pos_neg, neg_pos, neg_neg, test_s, val_s, train_s, max_length=92):
    """Build labelled train/test/validation rows from four groups of texts.

    Every group is cut at the same offsets: the first ``train_s`` items go to
    train, the next ``test_s`` items to test and the following ``val_s`` items
    to validation. Texts with ``max_length`` or more space-separated tokens
    are dropped. Each surviving text becomes a tab-separated row
    ``'<text>\\t<label>\\t<label>\\n'`` with these label pairs:

    * ``pos_pos`` -> ``1, 1``
    * ``pos_neg`` -> ``1, 0``
    * ``neg_pos`` -> ``0, 1``
    * ``neg_neg`` -> ``0, 0``

    Going by the comments above this function, the first label is presumably
    the task label (positive: 1, negative: 0) and the second the race label;
    confirm against the code that produced the input files.

    The running sizes of the three splits are printed after each group, and
    the maximum token count per split is printed at the end (``-1`` for a
    split that received no rows).

    Args:
        pos_pos, pos_neg, neg_pos, neg_neg: sequences of text strings, one
            per label combination.
        test_s: number of items per group taken for the test split.
        val_s: number of items per group taken for the validation split.
        train_s: number of items per group taken for the train split.
        max_length: exclusive upper bound on the token count of a kept text.

    Returns:
        A tuple ``(x_train, x_test, x_val)`` of lists of formatted rows. Note
        the order: train, test, validation.
    """
    x_train, x_test, x_val = [], [], []
    max_train = max_val = max_test = -1

    # Slice boundaries shared by every group.
    test_end = train_s + test_s
    val_end = test_end + val_s

    labelled_groups = (
        (pos_pos, 1, 1),
        (pos_neg, 1, 0),
        (neg_pos, 0, 1),
        (neg_neg, 0, 0),
    )
    for texts, first_label, second_label in labelled_groups:
        max_train = _append_rows(
            x_train, texts[:train_s], first_label, second_label, max_length, max_train)
        max_test = _append_rows(
            x_test, texts[train_s:test_end], first_label, second_label, max_length, max_test)
        max_val = _append_rows(
            x_val, texts[test_end:val_end], first_label, second_label, max_length, max_val)

        # Progress report: cumulative split sizes after this group.
        print(len(x_train), len(x_test), len(x_val))

    print('Max Length train ', max_train)
    print('Max Length val ', max_val)
    print('Max Length test ', max_test)
    return x_train, x_test, x_val


def open_file(file):
    """Read the text file at path ``file`` and return its lines without newlines.

    The file is opened in text mode with the platform's default encoding.
    Every ``'\\n'`` character is removed from each line; blank lines are kept
    as empty strings.
    """
    stripped = []
    with open(file, 'r') as handle:
        for raw_line in handle:
            stripped.append(raw_line.replace('\n', ''))
    return stripped


if __name__ == '__main__':
    # Fractions of the smallest group assigned to each split.
    train_proportion, val_proportion, test_proportion = 0.87, 0.03, 0.1
    SEED = 45

    suffix = 'sentiment'  # 'sentiment'  # 'mention'
    folder_path = 'processed_{}'.format(suffix)
    new_folder_path = 'processed_{}_splitted'.format(suffix)
    os.makedirs(new_folder_path, exist_ok=True)

    # Seed once so the shuffles below are reproducible.
    random.seed(SEED)

    # Load all four groups first, then shuffle them in the same fixed order.
    group_names = ('pos_pos', 'pos_neg', 'neg_pos', 'neg_neg')
    groups = [open_file(os.path.join(folder_path, name)) for name in group_names]
    for group in groups:
        random.shuffle(group)
    pos_pos, pos_neg, neg_pos, neg_neg = groups

    # Size every split relative to the smallest group so that all four
    # groups contribute the same number of candidates.
    length = min(len(group) for group in groups)
    test_s = int(length * test_proportion)
    val_s = int(length * val_proportion)
    train_s = int(length * train_proportion)

    x_train, x_test, x_val = get_labeled_data(
        pos_pos, pos_neg, neg_pos, neg_neg, test_s, val_s, train_s)

    # Mix the groups within each split before writing.
    for split_rows in (x_train, x_test, x_val):
        random.shuffle(split_rows)

    # Rows already end with a newline, so they are written as-is.
    for out_name, split_rows in (('x_test', x_test), ('x_val', x_val), ('x_train', x_train)):
        with open(os.path.join(new_folder_path, out_name), 'w') as file:
            file.writelines(split_rows)
