from pathlib import Path

import numpy as np
import pandas as pd
from transformers import BertTokenizer, TFBertModel

import cfg


def create():
    """Run the AG News preprocessing pipeline end to end.

    Loads the ``train`` and ``test`` splits, tags them in a ``subset`` column
    with ``cfg.tag_unlabelled`` and ``cfg.tag_evaluate`` respectively, stacks
    them into one frame with a fresh 0..n-1 index, then writes the per-sample
    text files / metadata and finally the BERT embeddings.
    """
    splits = (
        ('train', cfg.tag_unlabelled),
        ('test', cfg.tag_evaluate),
    )

    parts = []
    for split_name, tag in splits:
        part = get_metadata(split_name)
        part['subset'] = tag
        parts.append(part)
    combined = pd.concat(parts, ignore_index=True)

    # save_data adds the 'fname' column that save_embeddings reads afterwards
    save_data(combined)
    save_embeddings(combined)


def get_metadata(subset):
    """Load one AG News split and clean its text columns.

    Reads ``<cfg.path_agnews>/<subset>.csv``, appends one-hot label columns
    built from ``Class Index`` (prefixed with ``cfg.label_prefix`` and cast to
    int), and strips the news-source attribution from ``Title`` and
    ``Description``.

    Args:
        subset: CSV file name without extension, e.g. ``'train'`` or ``'test'``.

    Returns:
        The loaded DataFrame: original columns plus the one-hot label columns.
    """
    df = pd.read_csv(Path(cfg.path_agnews, f'{subset}.csv'))

    # one-hot label columns; get_dummies yields bools, so cast to int
    df_labels = pd.get_dummies(df['Class Index'], prefix=cfg.label_prefix).astype(int)
    df = pd.concat([df, df_labels], axis=1)

    # remove the text source from title and description
    df['Title'] = df['Title'].str.strip().apply(_strip_title_source)
    df['Description'] = df['Description'].apply(_strip_description_source)

    return df


def _strip_title_source(title):
    """Drop a trailing parenthesised part, e.g. ``'Foo bar (Reuters)'`` -> ``'Foo bar'``.

    Non-string values (e.g. NaN from an empty CSV cell) are returned unchanged.
    """
    if isinstance(title, str) and title.endswith(')') and '(' in title:
        return title[:title.rfind('(')].strip()
    return title


def _strip_description_source(description):
    """Drop everything up to and including the first ``' - '`` separator.

    Non-string values (e.g. NaN from an empty CSV cell) are returned unchanged
    instead of raising ``TypeError`` on the ``in`` test.
    """
    if isinstance(description, str) and ' - ' in description:
        return description.split(' - ', 1)[1]
    return description


def save_data(df):
    """Write each sample's text to its own file and save the metadata CSV.

    For every row, writes ``"<Title>: <Description>"`` (trailing whitespace
    removed) to ``<cfg.path_data>/agnews/data/<fname>``, creating that
    directory if needed. Then writes ``<cfg.path_data>/agnews/metadata.csv``
    with every column except ``Class Index``, ``Title`` and ``Description``,
    columns sorted by name and numeric columns cast to int.

    Side effect: adds an ``fname`` column (``agnews_<index:05d>.txt``) to
    ``df`` in place. ``save_embeddings`` reads this column afterwards.

    Args:
        df: Frame with ``Class Index``, ``Title`` and ``Description`` columns.
    """
    data_dir = Path(cfg.path_data, 'agnews', 'data')
    # make a fresh run work without manually preparing the output tree
    data_dir.mkdir(parents=True, exist_ok=True)

    # create fname (mutates the caller's frame on purpose)
    df['fname'] = df.index.map(lambda i: f'agnews_{i:05d}.txt')

    n_rows = len(df)
    for index, row in df.iterrows():
        print(f'PREPROCESSING AG News Corpus: Row {index} / {n_rows}')

        text = f"{row['Title']}: {row['Description']}".rstrip()

        # NOTE: platform-default encoding; the reader in save_embeddings uses the
        # same default, so change both together if pinning an encoding.
        with open(data_dir / row['fname'], 'w') as f:
            f.write(text)

    # keep only fname, label and subset columns
    meta = df.drop(['Class Index', 'Title', 'Description'], axis=1)
    meta = meta[sorted(meta.columns)]
    numeric_cols = meta.select_dtypes(include='number').columns
    meta[numeric_cols] = meta[numeric_cols].astype(int)
    meta.to_csv(Path(cfg.path_data, 'agnews', 'metadata.csv'), index=False)


def save_embeddings(df, batch_size=100):
    """Embed every sample's text with BERT and save the stacked vectors.

    Reads ``<cfg.path_data>/agnews/data/<fname>`` for each row of ``df`` in
    order, embeds the texts with ``bert-base-uncased`` in batches of
    ``batch_size``, keeps the last-hidden-state vector of the first token of
    each sequence, and saves the result to
    ``<cfg.path_data>/agnews/data_embedding.npy``. Row ``i`` of the saved array
    corresponds to the ``i``-th row of ``df`` regardless of ``df``'s index.

    Args:
        df: Frame with an ``fname`` column (added by ``save_data``).
        batch_size: Number of texts tokenized and embedded per model call.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    # load BERT model
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = TFBertModel.from_pretrained('bert-base-uncased')
    data_dir = Path(cfg.path_data, 'agnews', 'data')

    embeddings = []
    batch = []
    n_rows = len(df)
    for index, row in df.iterrows():
        print(f'PREPROCESSING agnews: Row {index} / {n_rows}')
        # NOTE: platform-default encoding, matching how save_data writes the files
        with open(data_dir / row['fname'], 'r') as f:
            batch.append(f.read())

        # flush on batch size, not on the index value, so every batch is full
        # and no rows are lost when the index is not 0..n-1
        if len(batch) == batch_size:
            embeddings.append(_embed_cls(tokenizer, model, batch))
            batch = []

    # flush the trailing partial batch
    if batch:
        embeddings.append(_embed_cls(tokenizer, model, batch))

    if embeddings:
        final_embedding = np.concatenate(embeddings, axis=0)
    else:
        # np.concatenate([]) raises; save an empty (0, hidden) array instead
        final_embedding = np.empty((0, model.config.hidden_size), dtype=np.float32)
    np.save(Path(cfg.path_data, 'agnews', 'data_embedding.npy'), final_embedding)


def _embed_cls(tokenizer, model, texts):
    """Return the last-hidden-state vector at position 0 (the CLS slice) per text."""
    inputs = tokenizer(texts, return_tensors='tf', padding=True, truncation=True)
    features = model(**inputs).last_hidden_state
    # slice before converting so only the needed rows are copied to NumPy
    return features[:, 0, :].numpy()


if __name__ == '__main__':
    # Run the full preprocessing pipeline when executed as a script.
    create()