import os
import os.path as op
import sys 
import time
import json
import copy
import pandas as pd
import numpy as np
from datasets import load_dataset, DatasetDict, Dataset, load_from_disk
from transformers import TrainingArguments, Trainer, AutoTokenizer
import datasets
import torch


def add_suffix(sample):
    """Append the binary sentiment word to ``sample['text']``.

    A label of 0 maps to ``'negative'``; every other label maps to
    ``'positive'``. The word is attached with the separator ``' # '``.
    The sample dict is updated in place and also returned, as
    ``Dataset.map`` expects.
    """
    sentiment = 'negative' if sample['label'] == 0 else 'positive'
    sample['text'] = ' # '.join((sample['text'], sentiment))
    return sample
      
def add_suffix_ag(sample):
    """Append the AG News topic word to ``sample['text']``.

    Labels 0, 1 and 2 map to ``'world'``, ``'sports'`` and ``'business'``;
    any other label maps to ``'technology'``. The word is attached with the
    separator ``' // '``. The sample dict is updated in place and returned.
    """
    topic_by_label = ((0, 'world'), (1, 'sports'), (2, 'business'))
    label = sample['label']
    # first matching label wins; fall back to 'technology' like the old else-branch
    topic = next(
        (name for code, name in topic_by_label if label == code),
        'technology',
    )
    sample['text'] = ' // '.join((sample['text'], topic))
    return sample

def remove_commawhitespace(sample):
    """Remove the space directly in front of each comma in ``sample['text']``.

    Every ``' ,'`` becomes ``','``; text after the comma is untouched.
    The sample dict is updated in place and returned.
    """
    text = sample['text']
    sample['text'] = text.replace(" ,", ",")
    return sample

def remove_qwhitespace(sample):
    """Collapse every ``' ? '`` in ``sample['text']`` to a bare ``'?'``.

    Note that the space *after* the question mark is removed as well, so
    ``'Why ? Because'`` becomes ``'Why?Because'``. The sample dict is
    updated in place and returned.
    """
    text = sample['text']
    sample['text'] = text.replace(" ? ", "?")
    return sample

def remove_dwhitespace(sample):
    """Collapse every ``' . '`` in ``sample['text']`` to a bare ``'.'``.

    The spaces on both sides of the period are removed, so
    ``'end . Next'`` becomes ``'end.Next'``. The sample dict is updated in
    place and returned.
    """
    text = sample['text']
    sample['text'] = text.replace(" . ", ".")
    return sample

def remove_exwhitespace(sample):
    """Collapse every ``' ! '`` in ``sample['text']`` to a bare ``'!'``.

    The spaces on both sides of the exclamation mark are removed, so
    ``'Wow ! Great'`` becomes ``'Wow!Great'``. The sample dict is updated
    in place and returned.
    """
    text = sample['text']
    sample['text'] = text.replace(" ! ", "!")
    return sample

def remove_swhitespace(sample):
    """Re-attach split possessives in ``sample['text']``.

    Every ``" 's"`` becomes ``"'s"`` (e.g. ``"John 's"`` -> ``"John's"``).
    The sample dict is updated in place and returned.
    """
    text = sample['text']
    sample['text'] = text.replace(" 's", "'s")
    return sample

def strip_whitespace(sample):
    """Trim trailing whitespace from ``sample['text']``.

    Only the right end is stripped; leading whitespace is preserved.
    The sample dict is updated in place and returned.
    """
    text = sample['text']
    sample['text'] = text.rstrip()
    return sample

def _clean_whitespace(dataset, include_punctuation=True):
    """Normalise spacing in the ``'text'`` column of a HF dataset.

    Applies, in order: ``remove_commawhitespace``, ``remove_swhitespace``,
    then (only if ``include_punctuation``) ``remove_qwhitespace``,
    ``remove_exwhitespace``, ``remove_dwhitespace``, and finally
    ``strip_whitespace``. Returns the mapped dataset.
    """
    dataset = dataset.map(remove_commawhitespace)
    dataset = dataset.map(remove_swhitespace)
    if include_punctuation:
        dataset = dataset.map(remove_qwhitespace)
        dataset = dataset.map(remove_exwhitespace)
        dataset = dataset.map(remove_dwhitespace)
    # remove trailing white space last, after all replacements
    return dataset.map(strip_whitespace)

def get_dataset(config):
    """Load, clean and split one of the supported text datasets.

    Supported ``config.dataset_name`` values: ``'imdb'``, ``'ag_news'``,
    ``'amazon_polarity'``, ``'yelp_polarity'`` and ``'enron'``.

    Attributes read from ``config``: ``dataset_name``, ``dataset_size``,
    ``tokenizer``; ``task`` for the non-enron datasets; ``max_seq_len`` and
    ``suffix_length`` for ``'enron'``.

    Side effects: seeds NumPy's global RNG with 13 and sets
    ``config.tokenizer.padding_side = 'left'``.

    For ``'enron'`` two extra string columns are added: ``'suffixes'`` (the
    decoded last ``suffix_length`` tokens of each tokenized text) and
    ``'prefixes'`` (the decoded tokens before them).

    Returns:
        ``DatasetDict`` with ``'train'`` (the first ``dataset_size`` rows) and
        ``'test'`` (the next ``int(0.2 * dataset_size)`` rows).

    Raises:
        ValueError: if ``config.dataset_name`` is not supported.
    """

    def tokenize_text(batch):
        # Pad every row to exactly max_seq_len. With padding=True each map()
        # batch is padded only to its own longest row, which yields ragged
        # lengths that the 2-D numpy slicing below cannot handle.
        return config.tokenizer(batch["text"], truncation=True,
                                padding='max_length',
                                max_length=config.max_seq_len)

    print(f'1) Process {config.dataset_name} data ...')
    # fix random seed for reproducibility
    np.random.seed(13)
    # set padding on left side to get prefixes / suffixes right: this is important (not optional)
    config.tokenizer.padding_side = 'left'
    # The test split is the block of rows directly after the training rows.
    # It is computed up front because every branch needs it for the split below.
    num_train = int(config.dataset_size)
    test_set_frac = 0.2
    test_set_size = int(config.dataset_size * test_set_frac)

    if config.dataset_name == 'imdb':
        dataset = load_dataset('imdb', split='train')
        dataset = _clean_whitespace(dataset)
        # add the sentiment label to the end if we do next token prediction
        if config.task == 'generation':
            dataset = dataset.map(add_suffix)
    elif config.dataset_name == 'ag_news':
        dataset = load_dataset(config.dataset_name, split='train')
        # adds a character-count 'len' column; no length filtering is done here
        dataset = dataset.map(lambda example: {'len': len(example['text'])})
        dataset = _clean_whitespace(dataset)
        # ag_news has four topic labels, so use the topic-word suffix
        if config.task == 'generation':
            dataset = dataset.map(add_suffix_ag)
    elif config.dataset_name in ('amazon_polarity', 'yelp_polarity'):
        dataset = load_dataset(config.dataset_name, split='train')
        if config.dataset_name == 'amazon_polarity':
            dataset = dataset.rename_column("content", "text")
        # only the comma / "'s" fixes for these corpora (no '?', '!', '.' passes)
        dataset = _clean_whitespace(dataset, include_punctuation=False)
        # add label to end
        if config.task == 'generation':
            dataset = dataset.map(add_suffix)
    elif config.dataset_name == 'enron':
        # only load the rows needed for the train + test splits
        dataset = load_dataset('snoop2head/enron_aeslc_emails',
                               split=f'train[0:{int(config.dataset_size + test_set_size)}]')
        # NOTE: the comma fix has always been applied twice for enron; the
        # extra pass is kept so the preprocessed text matches earlier runs.
        dataset = dataset.map(remove_commawhitespace)
        dataset = _clean_whitespace(dataset)
        # Split each fixed-length, left-padded token sequence into its last
        # `suffix_length` tokens (suffix) and everything before (prefix);
        # pad tokens are dropped when decoding.
        tokenizer = config.tokenizer
        data_tokenized = dataset.map(tokenize_text, batched=True)
        input_ids = np.array(data_tokenized['input_ids'])
        suffixes = tokenizer.batch_decode(
            input_ids[:, -config.suffix_length:].tolist(), skip_special_tokens=True)
        prefixes = tokenizer.batch_decode(
            input_ids[:, 0:-config.suffix_length].tolist(), skip_special_tokens=True)
        dataset = dataset.add_column('prefixes', prefixes)
        dataset = dataset.add_column('suffixes', suffixes)
    else:
        raise ValueError(f'This >{config.dataset_name}< is not supported, yet.')

    # get train / test splits: first `num_train` rows, then the next `test_set_size`
    dataset_train = dataset.select(range(num_train))
    dataset_test = dataset.select(range(num_train, num_train + test_set_size))
    print('train size', len(dataset_train))
    print('test size', len(dataset_test))
    return DatasetDict({"train": dataset_train, "test": dataset_test})


class DS(torch.utils.data.Dataset):
    """Torch dataset serving one split of a dataset dict with looked-up embeddings.

    Each item is the split's row dict, extended with:

    * ``'input_embeds'``: rows of ``embedding_matrix`` selected by the row's
      ``input_ids``, truncated to ``max_seq_len`` tokens, with
      ``requires_grad`` enabled;
    * ``'attention_mask'``: an all-ones tensor with the same length as the
      truncated ids.

    The row's own ``'input_ids'`` entry is returned unchanged.

    Args:
        embedding_matrix: tensor indexed by token id; presumably the model's
            input-embedding weights of shape (vocab_size, hidden), TODO confirm.
        dataset_dict: mapping from split name to a dataset whose rows are
            dicts containing ``'input_ids'`` (e.g. a ``datasets.DatasetDict``).
        partition_key: which split to serve.
        max_seq_len: maximum number of tokens kept per item.
    """

    def __init__(self,
                 embedding_matrix,
                 dataset_dict,
                 partition_key="train",
                 max_seq_len: int = 128):
        self.partition = dataset_dict[partition_key]
        self.embedding_matrix = embedding_matrix
        self.max_seq_len = max_seq_len

    def __getitem__(self, index):
        """Return row ``index`` with ``'input_embeds'`` and ``'attention_mask'`` attached."""
        data = self.partition[int(index)]
        # HF datasets return python lists by default and tensors after
        # set_format('torch'); normalise to a long tensor so indexing and
        # ones_like work in both cases (and for empty rows).
        token_ids = torch.as_tensor(data['input_ids'], dtype=torch.long)
        # Truncate BEFORE the lookup so embeddings and mask have equal length.
        token_ids = token_ids[:self.max_seq_len]
        inputs_embeds = self.embedding_matrix[token_ids]
        # enable gradients with respect to the looked-up embeddings
        inputs_embeds.requires_grad_(True)
        data['input_embeds'] = inputs_embeds
        # NOTE(review): the all-ones mask also covers any pad tokens in input_ids.
        data['attention_mask'] = torch.ones_like(token_ids)
        return data

    def __len__(self):
        """Number of rows in the selected split."""
        return self.partition.num_rows