import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from transformers import BertTokenizer, TFBertModel

import cfg


def create():
    """Run the whole Reuters preprocessing pipeline.

    Seeds Python's and NumPy's random number generators with 0, builds the
    metadata frame with get_metadata(), and hands that frame first to
    save_data() and then to save_embeddings().
    """
    # fixed seed for both RNGs
    seed = 0
    random.seed(seed)
    np.random.seed(seed)

    metadata = get_metadata()

    # order matters: save_data() writes the text files that save_embeddings() reads
    for step in (save_data, save_embeddings):
        step(metadata)


def get_metadata(min_label_count=200):
    """Load the raw Reuters CSV files and build the metadata frame.

    Every entry of ``cfg.path_reuters`` is read with ``pd.read_csv`` and the
    results are stacked. The text is cleaned, one 0/1 indicator column is
    created per topic label (named ``cfg.label_prefix + label``), and each row
    is assigned to a subset based on its ``lewis_split`` value.

    Args:
        min_label_count: minimum number of rows a label must occur in for its
            indicator column to be kept. Counted over all loaded rows, i.e.
            before rows outside the two subsets are removed.

    Returns:
        A DataFrame with a fresh 0..n-1 index and the columns ``text``,
        ``title``, the kept label columns (sorted by label), ``subset`` and
        ``fname``.

    Raises:
        FileNotFoundError: if ``cfg.path_reuters`` contains no entries.
    """
    # os.listdir() returns entries in arbitrary order; sort them so that the row
    # order (and hence the fname assigned to each document) is reproducible
    file_names = sorted(os.listdir(cfg.path_reuters))
    if not file_names:
        raise FileNotFoundError(f'No data files found in {cfg.path_reuters}')

    # read everything first and concatenate once instead of growing a frame in a loop
    frames = [pd.read_csv(Path(cfg.path_reuters, file_name)) for file_name in file_names]
    df_all = pd.concat(frames, ignore_index=True)

    # keep only cols of interest; copy so the assignments below act on an independent frame
    df = df_all[['text', 'title', 'topics', 'lewis_split']].copy()
    df['text'] = (
        df['text']
        .str.replace('REUTER', '', regex=False)
        .str.replace('\n', ' ', regex=False)
        .str.replace(r'\s+', ' ', regex=True)
    )

    # parse the topics string (brackets and single quotes removed, split on whitespace)
    labels_per_row = (
        df['topics']
        .str.strip('[]')
        .str.replace("'", '', regex=False)
        .str.split()
    )
    # rows without a topics value (NaN) simply carry no labels
    label_sets = [
        set(labels) if isinstance(labels, list) else set()
        for labels in labels_per_row
    ]

    # create one indicator column per label, keeping only sufficiently frequent labels;
    # sorted() makes the column order independent of string hash randomisation
    label_columns = {}
    for label in sorted(set().union(*label_sets)):
        flags = [int(label in labels) for labels in label_sets]
        if sum(flags) >= min_label_count:
            label_columns[cfg.label_prefix + label] = flags
    df = pd.concat([df, pd.DataFrame(label_columns, index=df.index)], axis=1)

    # create subset column and drop rows that belong to neither subset
    df['subset'] = df['lewis_split'].replace(
        {'"TRAIN"': cfg.tag_unlabelled, '"TEST"': cfg.tag_evaluate}
    )
    df = df[df['subset'].isin([cfg.tag_unlabelled, cfg.tag_evaluate])]
    df = df.drop(columns=['topics', 'lewis_split'])
    df = df.reset_index(drop=True)

    # create fname column from the row position
    df['fname'] = [f"reuters_{i:04d}.txt" for i in range(len(df))]
    return df


def save_data(df):
    """Write one text file per row and the metadata CSV.

    Text files go to ``cfg.path_data/reuters/data/<fname>`` and contain
    ``"<title>: <text>"`` with trailing whitespace and a trailing ``Reuter``
    removed. The remaining columns are written, sorted by name, to
    ``cfg.path_data/reuters/metadata.csv``.

    Args:
        df: frame with at least the columns ``title``, ``text`` and ``fname``.
            It is not modified.
    """
    data_dir = Path(cfg.path_data, 'reuters', 'data')
    # make sure the output tree exists; open(..., 'w') does not create directories
    data_dir.mkdir(parents=True, exist_ok=True)

    total = len(df)
    for index, row in df.iterrows():
        print(f'PREPROCESSING Reuters: Row {index} / {total}')

        # create text
        # NOTE(review): a missing (NaN) title or text ends up as the literal 'nan' via str()
        text = str(row['title']) + ': ' + str(row['text'])
        text = text.rstrip()  # remove trailing whitespace
        if text.endswith('Reuter'):
            text = text[:-len('Reuter')].rstrip()

        # save text
        with open(data_dir / row['fname'], 'w') as f:
            f.write(text)

    # save metadata: everything except the raw text columns, columns sorted by name
    metadata = df.drop(['text', 'title'], axis=1)
    metadata = metadata[sorted(metadata.columns)]
    # cast numeric columns to int; astype() returns a new frame, so nothing is
    # assigned onto a column-subset of another frame
    numeric_cols = metadata.select_dtypes(include='number').columns
    metadata = metadata.astype({col: int for col in numeric_cols})
    metadata.to_csv(data_dir.parent / 'metadata.csv', index=False)


def save_embeddings(df, batch_size=100):
    """Embed every saved text file with BERT and store the CLS vectors.

    For each ``fname`` in ``df`` the text file at
    ``cfg.path_data/reuters/data/<fname>`` is read, the texts are run through
    ``bert-base-uncased`` in batches, and position 0 of the last hidden state
    is kept per text. The stacked result is saved to
    ``cfg.path_data/reuters/data_embedding.npy`` in row order.

    Args:
        df: frame with an ``fname`` column; one embedding row is produced per
            frame row, whatever the frame's index looks like.
        batch_size: number of texts passed to the model at once.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1 or ``df`` has no rows.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    fnames = df['fname'].tolist()
    total = len(fnames)
    if total == 0:
        # fail before the model is loaded instead of inside np.concatenate
        raise ValueError('save_embeddings() needs at least one row')

    # load BERT model
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = TFBertModel.from_pretrained('bert-base-uncased')

    data_dir = Path(cfg.path_data, 'reuters', 'data')
    embeddings = []
    # batch by position, not by index label, so every batch is full (except
    # possibly the last) and no row is skipped for frames with a custom index
    for start in range(0, total, batch_size):
        batch = []
        for position, fname in enumerate(fnames[start:start + batch_size], start=start):
            print(f'PREPROCESSING reuters: Row {position} / {total}')
            with open(data_dir / fname, 'r') as f:
                batch.append(f.read())

        # Text tokenizer
        inputs = tokenizer(batch, return_tensors='tf', padding=True, truncation=True)

        # BERT embedding; keep only the first token position (CLS) of each text
        features = model(**inputs).last_hidden_state
        embeddings.append(features.numpy()[:, 0, :])

    final_embedding = np.concatenate(embeddings, axis=0)
    np.save(Path(cfg.path_data, 'reuters', 'data_embedding.npy'), final_embedding)


# Run the full dataset-creation pipeline only when this file is executed as a script.
if __name__ == '__main__':
    create()
