import re
from dataclasses import dataclass

import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from datasets import load_dataset
from torch.utils.data import Dataset
import copy
import random
import numpy as np
import string

# Default ``max_length`` of the dataset classes defined in this module.
MAX_INPUT_LENGTH = 4096
# Default ``seed`` of ``load_split_dataset`` (used as ``random_state`` for the validation sub-sampling).
VAL_SPLIT_SEED = 59
# ``test_size`` handed to ``train_test_split`` when a validation split is sub-sampled.
VAL_TEST_SIZE = 1000

# Input-side templates; ``{}`` is the slot for the raw input text.
INPUT_VERBALIZERS = [
    'input: {}',
    'text: {}',
    'sentence: {}',
    '{}',
]

# Output-side templates; ``{}`` is the slot for the label.
COMMON_OUTPUT_VERBALIZERS = [
    'output: {}',
    'target: {}',
    'label: {}',
    '{}',
]
SENTIMENT_OUTPUT_VERBALIZERS = [
    'emotion: {}',
    'sentiment: {}',
    'A {} one.',
    'It was {}.',
    'All in all {}.',
    'A {} piece.',
]
TOPIC_OUTPUT_VERBALIZERS = [
    'Topic: {}.',
    'Subject: {}.',
    'This is about {}.',
    'It is about {}.',
]

# Separator candidates (not referenced elsewhere in the visible part of this file;
# presumably input/output and example/example separators -- verify against callers).
SEPS = [" ", "\n"]
BIG_SEPS = [" ", "\n", "\n\n"]


class PEFTDataset(Dataset):
    """Tokenized ``input + sep + verbalized label (+ eos)`` sequences with a label-position mask.

    Every stored item consists of three 1-D ``LongTensor`` objects of equal length:
    the token ids, an all-ones attention mask, and a mask that is ``-1`` on the
    label tokens and ``0`` everywhere else.

    NOTE(review): no BOS token is prepended to the sequences built here (the original
    code computed a BOS prefix but never used it) -- confirm this is intended.
    """

    def __init__(self, samples, tokenizer, labels, template, bool_is_test,
                 max_length=MAX_INPUT_LENGTH,
                 ):
        """Tokenize all samples eagerly.

        Args:
            samples: iterable of ``(input_text, label)`` pairs.
            tokenizer: callable returning a mapping with an ``"input_ids"`` entry;
                must expose ``eos_token_id`` (and ``decode`` for ``print_tensorized_example``).
            labels: candidate label strings. In test mode every sample is paired
                with each of them and the sample's own label is ignored.
            template: object exposing ``inp_verbalizer``, ``out_verbalizer`` and ``sep``.
            bool_is_test: ``True`` to expand each sample over all ``labels``,
                ``False`` to use each sample's own label.
            max_length: stored on the instance; this class does not truncate.
        """
        self.input_ids = []
        self.label_incides = []  # (sic) attribute name kept for backward compatibility
        self.attention_mask = []
        self.template = template
        self.tokenizer = tokenizer
        self.labels = labels
        self.max_length = max_length

        for input_text, own_label in tqdm(samples):
            # Test mode scores every candidate label; train mode uses the gold one only.
            candidate_labels = labels if bool_is_test else [own_label]
            for label in candidate_labels:
                input_ids, attention_mask, label_indices = self.preprocess_sentence(input_text, label)

                self.input_ids.append(input_ids)
                self.attention_mask.append(attention_mask)
                self.label_incides.append(label_indices)

    def _get_input_ids(self, text):
        """Return the token ids of ``text`` without any special tokens."""
        return self.tokenizer(text, add_special_tokens=False)["input_ids"]

    def preprocess_sentence(self, input_text, label):
        """Build the token sequence for one ``(input_text, label)`` pair.

        Returns:
            ``(input_ids, attention_mask, loss_idx)`` as 1-D ``LongTensor`` objects of
            equal length; ``loss_idx`` is ``-1`` on the label tokens and ``0`` elsewhere.

        Raises:
            ValueError: if ``out_verbalizer`` does not contain exactly one ``{}`` slot.
        """
        input_text = self.template.inp_verbalizer.format(input_text)
        label_template_before, label_template_after = self.template.out_verbalizer.split('{}')

        input_tokenized = self._get_input_ids(input_text)
        sep_tokenized = self._get_input_ids(self.template.sep)
        before_tokenized = self._get_input_ids(label_template_before)
        label_tokenized = self._get_input_ids(label)
        after_tokenized = self._get_input_ids(label_template_after)

        eos_id = self.tokenizer.eos_token_id
        eos = [] if eos_id is None else [eos_id]
        input_ids = (input_tokenized + sep_tokenized + before_tokenized
                     + label_tokenized + after_tokenized + eos)

        # Half-open span [begin, end) of the label tokens inside ``input_ids``.
        begin = len(input_tokenized) + len(sep_tokenized) + len(before_tokenized)
        end = begin + len(label_tokenized)
        attention_mask = [1] * len(input_ids)

        loss_idx = torch.zeros(len(input_ids), dtype=torch.long)
        loss_idx[begin:end] = -1
        # Integer comparison on the id lists (no float round-trip through torch.Tensor).
        assert input_ids[begin:end] == label_tokenized

        return torch.LongTensor(input_ids), torch.LongTensor(attention_mask), loss_idx

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {"input_ids": self.input_ids[idx],
                "attention_mask": self.attention_mask[idx],
                "ICL_mask": self.label_incides[idx],
                }

    def print_tensorized_example(self, idx=0):
        """Print the token ids of item ``idx`` followed by their decoded text."""
        print(self.input_ids[idx])
        print(self.tokenizer.decode(self.input_ids[idx]))


class ourPEFTDataset(Dataset):
    """Tokenized ``input + sep + label (+ eos)`` sequences with a per-token segment mask.

    Segment codes stored in ``sample_mask``:
        0 -- separator and EOS tokens
        1 -- input-verbalizer (format) tokens
        2 -- raw input tokens
        3 -- output-verbalizer (format) tokens
        4 -- label tokens

    NOTE(review): no BOS token is prepended to the sequences built here (the original
    code computed a BOS prefix but never used it) -- confirm this is intended.
    """

    def __init__(self, samples, tokenizer, labels, template, bool_is_test, args=None,
                 max_length=MAX_INPUT_LENGTH,
                 ):
        """Tokenize all samples eagerly.

        Args:
            samples: iterable of ``(input_text, label)`` pairs.
            tokenizer: callable returning a mapping with an ``"input_ids"`` entry;
                must expose ``eos_token_id``.
            labels: candidate label strings. In test mode every sample is paired
                with each of them and the sample's own label is ignored.
            template: object exposing ``inp_verbalizer``, ``out_verbalizer`` and ``sep``.
            bool_is_test: ``True`` to expand each sample over all ``labels``,
                ``False`` to use each sample's own label.
            args: stored on the instance; not read by this class.
            max_length: stored on the instance; this class does not truncate.
        """
        self.input_ids = []
        self.sample_mask = []
        self.attention_mask = []
        self.template = template
        self.tokenizer = tokenizer
        self.labels = labels
        self.max_length = max_length
        self.args = args

        for input_text, own_label in tqdm(samples):
            # Test mode scores every candidate label; train mode uses the gold one only.
            candidate_labels = labels if bool_is_test else [own_label]
            for label in candidate_labels:
                input_ids, attention_mask, sample_mask = self.preprocess_sentence(input_text, label)

                self.input_ids.append(input_ids)
                self.attention_mask.append(attention_mask)
                self.sample_mask.append(sample_mask)

    def _get_input_ids(self, text):
        """Return the token ids of ``text`` without any special tokens."""
        return self.tokenizer(text, add_special_tokens=False)["input_ids"]

    def preprocess_sentence(self, input_text, label):
        """Build the token sequence and segment mask for one ``(input_text, label)`` pair.

        The verbalizer pieces, the raw input and the label are tokenized separately
        and concatenated, so each token can be attributed to exactly one segment.

        Returns:
            ``(input_ids, attention_mask, sample_mask)`` as 1-D ``LongTensor`` objects
            of equal length.

        Raises:
            ValueError: if a verbalizer does not contain exactly one ``{}`` slot.
        """
        inp_before_text, inp_after_text = self.template.inp_verbalizer.split('{}')
        out_before_text, out_after_text = self.template.out_verbalizer.split('{}')

        eos_id = self.tokenizer.eos_token_id
        # (segment code, token ids) in sequence order.
        segments = [
            (1, self._get_input_ids(inp_before_text)),
            (2, self._get_input_ids(input_text)),
            (1, self._get_input_ids(inp_after_text)),
            (0, self._get_input_ids(self.template.sep)),
            (3, self._get_input_ids(out_before_text)),
            (4, self._get_input_ids(label)),
            (3, self._get_input_ids(out_after_text)),
            (0, [] if eos_id is None else [eos_id]),
        ]

        input_ids = []
        sample_mask = []
        for code, token_ids in segments:
            input_ids += token_ids
            sample_mask += [code] * len(token_ids)
        attention_mask = [1] * len(input_ids)

        assert len(sample_mask) == len(input_ids) == len(attention_mask)

        return torch.LongTensor(input_ids), torch.LongTensor(attention_mask), torch.LongTensor(sample_mask)

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "sample_mask": self.sample_mask[idx]
        }


class TensorDataset(Dataset):
    """In-context-learning evaluation set: a shared demonstration context plus one test sample.

    The context is built once from ``examples``; every ``(test sample, candidate label)``
    pair is then appended to it. Besides the token ids, a segment mask is kept so that
    each token can be attributed to a demonstration / the test sample and to its role
    (input format, input, output format, label).
    """

    def __init__(self, test_samples, tokenizer, labels, template,
                 examples=None,
                 method='direct',
                 max_length=MAX_INPUT_LENGTH,
                 only_icl=False,
                 args=None
                 ):
        """Build the context and (unless ``only_icl``) tokenize every sample/label pair.

        Args:
            test_samples: iterable of input texts to evaluate.
            tokenizer: callable returning a mapping with an ``"input_ids"`` entry; must
                expose ``bos_token_id``, ``eos_token_id`` and ``decode``.
            labels: candidate label strings; each test sample is paired with all of them.
            template: object exposing ``inp_verbalizer``, ``out_verbalizer``, ``sep``
                and ``big_sep``.
            examples: demonstrations as ``(input_text, label)`` pairs (default: none).
            method: ``'channel'`` swaps input and label in ``preprocess_sentence``.
            max_length: token budget used when truncating the context.
            only_icl: when true, only the context attributes are populated.
            args: optional object providing ``context_seperator_before_length`` and
                ``context_seperator_after_length``.
        """
        self.args = args
        if examples is None:
            examples = []
        if examples and isinstance(examples[0], list):
            # NOTE(review): apart from this length check, per-input example lists get no
            # special handling below -- confirm whether that mode is still supported.
            assert len(examples) == len(test_samples), "Examples are a list of lists but their length does not match " \
                                                       "the length of the eval dataset."

        self.input_ids = []
        self.attention_mask = []
        self.token_type_ids = []
        self.input_context_ids = []
        self.template = template
        self.tokenizer = tokenizer
        self.labels = labels
        self.examples = examples
        self.method = method
        self.max_length = max_length

        self.ICL_mask = None
        self.FORMAT_X_MASK = 1
        self.X_MASK = 2
        self.FORMAT_Y_MASK = 3
        self.Y_MASK = 4

        context_text, context_ids, context_mask = self.create_context_from_exmaples(examples)

        bos_id = self.tokenizer.bos_token_id
        if bos_id is not None:
            context_mask = [0] + context_mask
            context_ids = [bos_id] + context_ids

        self.ICL_mask = context_mask
        self.ICL_input_ids = context_ids
        self.text_context = context_text
        if only_icl:
            return

        for input_text in tqdm(test_samples):
            for label in labels:
                input_ids, attention_mask, token_type_ids, input_context_ids = self.preprocess_sentence(
                    input_text,
                    label,
                    list(context_ids)
                )

                self.input_ids.append(input_ids)
                self.attention_mask.append(attention_mask)
                self.token_type_ids.append(token_type_ids)
                self.input_context_ids.append(input_context_ids)

    def create_context_from_exmaples(self, examples):
        """Concatenate the demonstrations into one context.

        Layout: ``big_sep * before`` + for each example
        ``big_sep + verbalized input + sep + verbalized label`` + ``big_sep * after``,
        where ``before``/``after`` come from ``self.args`` (defaults 0 and 1).

        The mask holds ``0`` for separators and, for the k-th example (0-based),
        ``4k+1`` (input format), ``4k+2`` (input), ``4k+3`` (output format) and
        ``4k+4`` (label).

        Returns:
            ``(context_text, context_ids, context_mask)``; the two lists always have
            the same length.
        """
        big_sep_text = self.template.big_sep
        big_sep_ids = self._get_input_ids(big_sep_text)

        seps_before = 0
        seps_after = 1
        if self.args is not None:
            seps_before = self.args.context_seperator_before_length
            seps_after = self.args.context_seperator_after_length

        context_text = big_sep_text * seps_before
        list_context_ids = big_sep_ids * seps_before
        # One mask entry per separator *token* (a separator may span several tokens).
        list_context_mask = [0] * len(list_context_ids)

        example_number = 0
        for input_text, label in examples:
            inp_before_text, inp_after_text = self.template.inp_verbalizer.split('{}')
            out_before_text, out_after_text = self.template.out_verbalizer.split('{}')
            sep_text = self.template.sep

            inp_before_ids = self._get_input_ids(inp_before_text)
            input_ids = self._get_input_ids(input_text)
            inp_after_ids = self._get_input_ids(inp_after_text)
            sep_ids = self._get_input_ids(sep_text)
            out_before_ids = self._get_input_ids(out_before_text)
            label_ids = self._get_input_ids(label)
            out_after_ids = self._get_input_ids(out_after_text)

            list_context_ids += (big_sep_ids + inp_before_ids + input_ids + inp_after_ids
                                 + sep_ids + out_before_ids + label_ids + out_after_ids)
            context_text += (big_sep_text + inp_before_text + input_text + inp_after_text
                             + sep_text + out_before_text + label + out_after_text)
            list_context_mask += (
                [0] * len(big_sep_ids)
                + [example_number + 1] * len(inp_before_ids)
                + [example_number + 2] * len(input_ids)
                + [example_number + 1] * len(inp_after_ids)
                + [0] * len(sep_ids)
                + [example_number + 3] * len(out_before_ids)
                + [example_number + 4] * len(label_ids)
                + [example_number + 3] * len(out_after_ids)
            )
            example_number += 4

        trailing_ids = big_sep_ids * seps_after
        context_text += big_sep_text * seps_after
        list_context_ids += trailing_ids
        list_context_mask += [0] * len(trailing_ids)

        return context_text, list_context_ids, list_context_mask

    def add_examples_to_context(self, examples, method):
        """Render ``examples`` as plain text, joined by ``big_sep``.

        For a ``method`` containing ``'channel'`` the label is rendered before the
        input and the result starts with a single space.
        """
        if 'channel' in method:
            return " " + self.template.big_sep.join(
                f"{self.template.out_verbalizer.format(example_output)}{self.template.sep}"
                f"{self.template.inp_verbalizer.format(example_input)}"
                for example_input, example_output in examples)
        else:
            return self.template.big_sep.join(
                f"{self.template.inp_verbalizer.format(example_input)}{self.template.sep}"
                f"{self.template.out_verbalizer.format(example_output)}"
                for example_input, example_output in examples)

    def _get_input_ids(self, text):
        """Return the token ids of ``text`` without any special tokens."""
        return self.tokenizer(text, add_special_tokens=False)["input_ids"]

    def preprocess_sentence(self, input_text, label, context_tokenized):
        """Append one verbalized test sample to the (possibly truncated) context.

        Args:
            input_text: test input (swapped with ``label`` when ``method == 'channel'``).
            label: candidate label.
            context_tokenized: token ids of the demonstration context.

        Returns:
            ``(input_ids, attention_mask, label_tokens, input_context_ids)``. The first
            three are 1-D ``LongTensor`` objects of equal length; ``label_tokens`` is 1
            on the label tokens. ``input_context_ids`` is a plain list covering only the
            sample part, with codes 1/2/3/4 for input format / input / output format /
            label and 0 for separator and EOS.

        Note:
            The context is truncated at its end to fit ``max_length``; the sample
            itself is never truncated, so the result can exceed ``max_length`` when the
            sample alone is longer than the budget.
        """
        inp_before_text, inp_after_text = self.template.inp_verbalizer.split('{}')
        out_before_text, out_after_text = self.template.out_verbalizer.split('{}')
        if self.method == 'channel':
            label, input_text = input_text, label

        inp_before_ids = self._get_input_ids(inp_before_text)
        input_tokenized = self._get_input_ids(input_text)
        inp_after_ids = self._get_input_ids(inp_after_text)
        sep_tokenized = self._get_input_ids(self.template.sep)
        out_before_ids = self._get_input_ids(out_before_text)
        out_tokenized = self._get_input_ids(label)
        out_after_ids = self._get_input_ids(out_after_text)

        eos_id = self.tokenizer.eos_token_id
        eos = [] if eos_id is None else [eos_id]

        sample_ids = (inp_before_ids + input_tokenized + inp_after_ids + sep_tokenized
                      + out_before_ids + out_tokenized + out_after_ids + eos)
        all_len = len(sample_ids)
        assert len(context_tokenized) < self.max_length, "Context is too long"

        # Clamp at 0: a negative bound would slice from the end of the context instead
        # of dropping it entirely.
        context_budget = max(self.max_length - all_len, 0)
        kept_context = context_tokenized[:context_budget]
        input_ids = kept_context + sample_ids

        input_context_ids = (
            [1] * len(inp_before_ids)
            + [2] * len(input_tokenized)
            + [1] * len(inp_after_ids)
            + [0] * len(sep_tokenized)
            + [3] * len(out_before_ids)
            + [4] * len(out_tokenized)
            + [3] * len(out_after_ids)
            + [0] * len(eos)
        )
        assert len(input_context_ids) == all_len

        # Half-open span [begin, end) of the label tokens inside ``input_ids``.
        begin = (len(kept_context) + len(inp_before_ids) + len(input_tokenized)
                 + len(inp_after_ids) + len(sep_tokenized) + len(out_before_ids))
        end = begin + len(out_tokenized)
        attention_mask = [1] * len(input_ids)
        # Pad with zeros up to the full sequence length so the mask always lines up with
        # ``input_ids`` (verbalizer suffix tokens and a missing EOS included).
        label_tokens = [0] * begin + [1] * (end - begin) + [0] * (len(input_ids) - end)

        to_predict = self.tokenizer.decode(input_ids[begin:end]).strip()
        gt = self.tokenizer.decode(out_tokenized).strip()
        assert to_predict == gt

        return torch.LongTensor(input_ids), torch.LongTensor(attention_mask), torch.LongTensor(
            label_tokens), input_context_ids

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return item ``idx``; ``sample_mask`` is the context mask followed by the
        sample's segment codes shifted by the largest context code."""
        context_ICL = torch.tensor(self.ICL_mask, dtype=torch.long)
        sample_codes = torch.tensor(self.input_context_ids[idx], dtype=torch.long)

        # The offset is taken over the full context; an empty context contributes 0.
        offset = context_ICL.max() if context_ICL.numel() > 0 else 0
        # NOTE(review): the offset is also added to the sample's separator/EOS entries
        # (code 0), so they end up equal to the largest context code -- confirm intended.
        sample_ICL = sample_codes + offset

        # Keep only as much of the context mask as survived truncation in
        # ``preprocess_sentence`` so the mask lines up with ``input_ids``.
        kept_context_len = len(self.input_ids[idx]) - len(sample_codes)
        ICL_mask = torch.cat([context_ICL[:kept_context_len], sample_ICL])

        return {"input_ids": self.input_ids[idx],
                "attention_mask": self.attention_mask[idx],
                "token_type_ids": self.token_type_ids[idx],
                "sample_mask": ICL_mask
                }

    def print_tensorized_example(self, idx=0):
        """Print the token ids of item ``idx`` followed by their decoded text."""
        print(self.input_ids[idx])
        print(self.tokenizer.decode(self.input_ids[idx]))


@dataclass
class SST2Dataset:
    """Static description of the ``sst2`` dataset as read by ``load_split_dataset``."""
    # Identifier handed to ``load_dataset`` and the split used for validation.
    dataset_name = "sst2"
    val_split = "validation"
    # Source columns holding the text and the integer class id.
    input_col = "sentence"
    target_col = "label"
    # Integer class id -> label word.
    labels_mapping = {0: 'negative', 1: 'positive'}
    # Templates available for this task.
    input_verbalizers = INPUT_VERBALIZERS
    output_verbalizers = [*COMMON_OUTPUT_VERBALIZERS, *SENTIMENT_OUTPUT_VERBALIZERS]


@dataclass
class DBPediaDataset:
    """Static description of the ``dbpedia_14`` dataset as read by ``load_split_dataset``."""
    # Identifier handed to ``load_dataset`` and the split used for validation.
    dataset_name = "dbpedia_14"
    val_split = "test"
    # Source columns holding the text and the integer class id.
    input_col = "content"
    target_col = "label"
    # Integer class id -> label phrase.
    labels_mapping = {
        0: "Company",
        1: "Educational Institution",
        2: "Artist",
        3: "Athlete",
        4: "Office Holder",
        5: "Mean Of Transportation",
        6: "Building",
        7: "Natural Place",
        8: "Village",
        9: "Animal",
        10: "Plant",
        11: "Album",
        12: "Film",
        13: "Written Work",
    }
    # Templates available for this task.
    input_verbalizers = INPUT_VERBALIZERS
    output_verbalizers = [*COMMON_OUTPUT_VERBALIZERS, *TOPIC_OUTPUT_VERBALIZERS]


@dataclass
class AGNewsDataset:
    """Static description of the ``ag_news`` dataset as read by ``load_split_dataset``."""
    # Identifier handed to ``load_dataset`` and the split used for validation.
    dataset_name = "ag_news"
    val_split = "test"
    # Source columns holding the text and the integer class id.
    input_col = "text"
    target_col = "label"
    # Integer class id -> label word.
    labels_mapping = {0: "World", 1: "Sports", 2: "Business", 3: "Technology"}
    # Templates available for this task.
    input_verbalizers = INPUT_VERBALIZERS
    output_verbalizers = [*COMMON_OUTPUT_VERBALIZERS, *TOPIC_OUTPUT_VERBALIZERS]


@dataclass
class TRECDataset:
    """Static description of the ``trec`` dataset as read by ``load_split_dataset``."""
    # Identifier handed to ``load_dataset`` and the split used for validation.
    dataset_name = "trec"
    val_split = "test"
    # Source columns holding the text and the integer class id.
    input_col = "text"
    target_col = "coarse_label"
    # Integer class id -> label word.
    # NOTE(review): verify this order against the ``coarse_label`` class names of the
    # installed ``trec`` dataset version -- TODO confirm.
    labels_mapping = {
        0: "Description",
        1: "Entity",
        2: "Expression",
        3: "Human",
        4: "Location",
        5: "Number",
    }
    # Templates available for this task.
    input_verbalizers = INPUT_VERBALIZERS
    output_verbalizers = [*COMMON_OUTPUT_VERBALIZERS, *TOPIC_OUTPUT_VERBALIZERS]


class SEQLANGUAGEDataset:
    """Static description of the synthetic ``seq_language`` dataset."""
    dataset_name = "seq_language"
    val_split = "test"
    # Keys under which ``get_seq_language`` stores inputs and targets.
    input_col = "text"
    target_col = "coarse_label"
    # Placeholder; ``load_split_dataset`` overwrites it for the ``seq_language`` dataset.
    labels_mapping = {0: "None"}
    # Inputs and labels are used verbatim (bare ``{}`` slot).
    input_verbalizers = ['{}']
    output_verbalizers = ['{}']


# Registry: short dataset key (as accepted by ``load_split_dataset``) -> description class.
DATASET_TO_DATACLASS = {
    "sst2": SST2Dataset,
    "dbpedia": DBPediaDataset,
    "agnews": AGNewsDataset,
    "trec": TRECDataset,
    'seq_language': SEQLANGUAGEDataset,
}


def load_split_dataset(dataset_name, seed=VAL_SPLIT_SEED, cache_dir='~/.cache/huggingface/datasets', args=None,
                       tokenizer=None):
    """Load a dataset and return its train frame, a shortened validation frame and the label mapping.

    Args:
        dataset_name: key of ``DATASET_TO_DATACLASS``.
        seed: ``random_state`` of the validation sub-sampling (dbpedia, agnews) and seed
            of the synthetic ``seq_language`` generator.
        cache_dir: cache directory passed to ``load_dataset``.
        args: forwarded to ``get_seq_language`` (``seq_language`` only).
        tokenizer: forwarded to ``get_seq_language`` (``seq_language`` only).

    Returns:
        ``(train, val, labels_mapping)``. Both frames have the columns ``input`` and
        ``target``; targets are label strings. ``val`` is cut to its first
        ``100 * factor`` rows, with a dataset-specific ``factor``.

    Raises:
        KeyError: if ``dataset_name`` is not registered.

    Note:
        For ``seq_language`` the ``labels_mapping`` class attribute of the description
        class is overwritten with the labels that occur in the validation split.
    """
    dataset_dataclass = DATASET_TO_DATACLASS[dataset_name]
    input_col = dataset_dataclass.input_col
    target_col = dataset_dataclass.target_col
    val_split = dataset_dataclass.val_split
    is_synthetic = dataset_name == 'seq_language'

    if is_synthetic:
        # Forward the caller's seed instead of always using the module default.
        dataset = get_seq_language(seed=seed, args=args, tokenizer=tokenizer)
        # sorted(): plain set iteration order of strings changes between interpreter
        # runs (hash randomization), which would make the label ids unstable.
        dataset_dataclass.labels_mapping = dict(enumerate(sorted(set(dataset[val_split][target_col]))))
    else:
        dataset = load_dataset(dataset_dataclass.dataset_name, cache_dir=cache_dir)

    train = pd.DataFrame({
        'input': dataset['train'][input_col],
        'target': dataset['train'][target_col]
    })
    if not is_synthetic:
        # Synthetic targets are already label strings; the others are integer class ids.
        train['target'] = train.target.map(dataset_dataclass.labels_mapping)

    val_inputs = dataset[val_split][input_col]
    val_targets = dataset[val_split][target_col]
    subsample_val = dataset_name in ('dbpedia', 'agnews')
    if subsample_val:
        # these datasets' validation splits are too big, so we split them.
        _, val_inputs, _, val_targets = train_test_split(val_inputs, val_targets, test_size=VAL_TEST_SIZE,
                                                         random_state=seed)
    val = pd.DataFrame({
        'input': val_inputs,
        'target': val_targets
    })
    if subsample_val:
        # some dbpedia inputs contain {} which will break templates formatting
        val['input'] = val.input.apply(lambda x: re.sub('[{}]', '', x.strip()))
    if not is_synthetic:
        val['target'] = val.target.map(dataset_dataclass.labels_mapping)

    # Cut the validation frame to 100 * factor rows.
    val_size_factor = {'sst2': 1, 'dbpedia': 7, 'agnews': 2, 'trec': 3, 'seq_language': 1}
    factor = val_size_factor.get(dataset_name, 1)
    val = val.iloc[:100 * factor]

    return train, val, dataset_dataclass.labels_mapping


def get_seq_language(seed, args, tokenizer=None):
    """Generate the synthetic ``seq_language`` classification dataset.

    The unique words of a fixed vocabulary are split into *class words* (the first
    word seen for every distinct first-character token id) and *sample words* (all
    others). Class ``c`` owns the ``c``-th consecutive slice of
    ``args.seq_langauge_len`` sample words. Each of the 1000 generated samples is a
    space-joined random selection of ``args.seq_langauge_each_seq_len`` words from
    one randomly chosen class, labelled with that class word.

    Words are de-duplicated in first-occurrence order, so the result depends only on
    ``seed``, ``args`` and the tokenizer, and class words never occur inside samples.

    Args:
        seed: seed for ``random`` / ``numpy.random`` used while sampling.
        args: object providing ``seq_langauge_classes``, ``seq_langauge_len`` and
            ``seq_langauge_each_seq_len``.
        tokenizer: callable returning a mapping with an ``"input_ids"`` entry.

    Returns:
        ``{'train': {...}, <val_split>: {...}}`` where every inner dict maps the input
        and target column names of the ``seq_language`` description class to lists
        (first 900 samples for training, last 100 for validation).

    Raises:
        TypeError: if ``tokenizer`` is ``None``.
        ValueError: if the vocabulary cannot supply the requested number of classes
            or sample words.
    """
    if tokenizer is None:
        raise TypeError("get_seq_language requires a tokenizer, got None")

    words = [
        'apple', 'tornado', 'window', 'pizza', 'moon', 'bat', 'wing', 'iceberg', 'ball', 'ice',
        'flower', 'bridge', 'castle', 'lion', 'violin', 'giraffe', 'night', 'orange', 'pencil', 'river',
        'snake', 'tiger', 'sun', 'train', 'unicorn', 'vase', 'heart', 'kite', 'elephant', 'frog',
        'cherry', 'dog', 'house', 'island', 'juice', 'notebook', 'ocean', 'piano', 'queen', 'ship',
        'tree', 'star', 'telephone', 'whale', 'xylophone', 'yacht', 'zoo', 'airplane', 'book', 'car',
        'door', 'egg', 'fish', 'garden', 'hat', 'icecream', 'jelly', 'key', 'lamp', 'monkey', 'owl',
        'quilt', 'rain', 'dance', 'engine', 'guitar', 'helmet', 'igloo', 'jacket', 'lake', 'mountain',
        'nest', 'octopus', 'pillow', 'quilt', 'rocket', 'tulip', 'umbrella', 'vase', 'window', 'xenon',
        'yellow', 'zebra', 'cookie', 'engine', 'feather', 'glove', 'honey', 'iceberg', 'jellybean',
        'kettle', 'ladder', 'moon', 'needle', 'orchestra', 'penguin', 'rose', 'ship', 'tulip',
        'volcano', 'waterfall', 'xenon', 'yogurt', 'zero', 'anchor', 'bell', 'cloud', 'dragon',
        'eagle', 'feather', 'gold', 'hippo', 'iguana', 'jungle', 'kangaroo', 'lighthouse', 'marble',
        'nest', 'puzzle', 'question', 'rainbow', 'snow', 'tornado', 'universe', 'wing', 'xenon',
        'yak', 'zipline', 'angel', 'bat', 'cloud', 'dragon', 'eagle', 'feather', 'glove', 'honey',
        'iceberg', 'jellybean', 'kettle', 'ladder', 'moon', 'needle', 'orchestra', 'penguin',
        'rose', 'ship', 'tulip', 'universe', 'wing', 'xenon', 'yak', 'zipline', 'anchor', 'bell',
        'cloud', 'dragon', 'eagle', 'feather', 'gold', 'hippo', 'iguana', 'backpack', 'forest', 'camera',
        'raincoat', 'peach', 'computer', 'stadium', 'guitar', 'hero', 'jellyfish', 'diamond', 'butterfly',
        'rainbow', 'paper', 'flute', 'explosion', 'mirror', 'knight', 'bicycle', 'piano', 'arrow', 'bookstore',
        'elevator', 'chess', 'whistle', 'forest', 'waterfall', 'mountain', 'ship', 'river', 'candle', 'television',
        'airplane', 'suitcase', 'painting', 'lighthouse', 'library', 'telescope', 'moonlight', 'snowflake',
        'helicopter', 'postcard', 'notebook', 'sandcastle', 'mountain', 'drum', 'magnet', 'lightbulb',
        'typewriter', 'spaceship', 'volcano', 'glove', 'treasure', 'chocolate', 'astronaut', 'firetruck',
        'puzzle', 'baseball', 'carpet', 'garden', 'scissors', 'hammer', 'robot', 'keyboard', 'blanket',
        'microphone', 'octopus', 'sword', 'laptop', 'lantern', 'football', 'compass', 'cherry', 'parrot',
        'bridge', 'waterfall', 'strawberry', 'turtle', 'eagle', 'backpack', 'robot', 'lamp', 'mountain',
        'ice', 'tiger', 'forest', 'umbrella', 'clock', 'cactus', 'submarine', 'island', 'guitar', 'camera',
        'windmill', 'dinosaur', 'balloon', 'pirate', 'train', 'skateboard', 'ladder', 'toothbrush', 'stethoscope',
        'wrench', 'airplane', 'helicopter', 'notebook', 'sunflower', 'spoon', 'helicopter', 'caterpillar',
        'microscope', 'surfboard', 'spaceship', 'telescope', 'giraffe', 'puzzle', 'magnet', 'waterfall',
        'parachute', 'guitar', 'pirate', 'suitcase', 'whistle', 'volcano', 'lighthouse', 'raincoat', 'apple',
        'kangaroo', 'moonlight', 'candle', 'treasure', 'lantern', 'chess', 'waterfall', 'mountain', 'ship',
        'robot', 'piano', 'compass', 'baseball', 'sword', 'rocket', 'sunflower', 'blanket', 'jellyfish',
        'telescope', 'lantern', 'rainbow', 'microphone', 'pencil', 'crayon', 'suitcase', 'lighthouse',
        'jellybean', 'mountain', 'garden', 'parrot', 'helicopter', 'notebook', 'astronaut', 'spaceship',
        'snowflake', 'treasure', 'backpack', 'robot', 'lightbulb', 'volcano', 'cherry', 'parrot', 'rainbow',
        'bridge', 'sunflower', 'pencil', 'mountain', 'river', 'cactus', 'surfboard', 'suitcase', 'whistle',
        'volcano', 'lantern', 'mirror', 'rainbow', 'backpack', 'baseball', 'compass', 'blanket', 'parachute',
        'bicycle', 'whistle', 'guitar', 'camera', 'explosion', 'lighthouse', 'telescope', 'robot', 'jellyfish',
        'octopus', 'treasure', 'spaceship', 'dinosaur', 'balloon', 'pirate', 'train', 'skateboard', 'ladder',
        'toothbrush', 'robot', 'mountain', 'ice', 'moonlight', 'stethoscope', 'rainbow', 'island', 'surfboard',
        'windmill', 'camera', 'cactus', 'microphone', 'strawberry', 'helicopter', "resink", "transversomedial",
        "pharyngopathy", "postmineral", "myelosyphilis",
        "silverer", "evincement", "phrygium", "punnigram", "imminution", "environmental",
        "sleepify", "nope", "wauken", "indignance", "knotwort", "apocodeine", "escortee",
        "dogwatch", "eaglewood", "unbrotherliness", "mulse", "dermobranchiata", "typhic",
        "poststertorous", "indevout", "anatomicopathologic", "unimpenetrable", "hoggy",
        "urrhodin", "Dioecia", "unchapter", "nonumbilicate", "zwitterionic", "apportionable",
        "ferulic", "statefulness", "pharyngotonsillitis", "Mimulus", "recce", "mutinously",
        "reboant", "marshwort", "lupoid", "chromatophilic", "lauder", "nirles",
        "esthesiometer", "semisocial", "unbeing", "kangaroo", "takosis", "inconvertibility",
        "anesthetist", "rumorproof", "thoracoscopy", "euphorbium", "bizet", "song",
        "dolichocephali", "platemaker", "vesicupapular", "electroforming", "dilatingly",
        "meethelp", "loincloth", "avowably", "counterindicate", "treacliness", "Epigonus",
        "airmark", "polarography", "precomposition", "lemography", "Apinage", "Taal",
        "logology", "probeer", "randomization", "poditic", "individualize", "castigate",
        "Biloculina", "overscrub", "koolah", "weetless", "erased", "layery", "discontinuee",
        "anaphylatoxin", "unwounded", "personalism", "howitzer", "hexahydroxy", "koku",
        "reamer", "tonguiness", "microgametocyte", "baba", "ludefisk", "novelwright",
        "swinehull", "Odonata", "indefinable", "faineance", "nidologist", "supracargo",
        "beriberic", "betso", "archheart", "snary", "Viminal", "Pygopodidae", "acetylenediurein",
        "asphalt", "preimpress", "fountainlet", "bejel", "unpictorially", "heliophyte",
        "chimopeelagic", "warison", "antivaccinist", "overtwine", "preremove", "nerval",
        "bufonite", "eradicator", "turtling", "winrace", "psychographic", "impalpably",
        "amygdalase", "Octogynia", "brimming", "grist", "casave", "brazilein", "afluking",
        "meliceris", "portative", "unsteck", "Madelon", "barramunda", "optotechnics",
        "metapterygium", "unromanticalness", "Jacobinism", "pricklingly", "blameless",
        "elderhood", "committeewoman", "comicocynical", "aggrate", "stentoronic", "flatwise",
        "bipyridyl", "untastable", "aegerian", "unmistrusted", "quadrigamist", "Meleagrinae",
        "helvite", "neuralist", "Swietenia", "unpleadable", "colorably", "mogilalism",
        "consequently", "atamasco", "inhospitality", "noncarnivorous", "counterruin",
        "gryposis", "ringe", "displeasedly", "incenter", "gallycrow", "whincow", "repudiationist",
        "unagile", "chaplain", "bekerchief", "subproduct", "pointingly", "Physonectae",
        "bumpingly", "hateful", "endogenous", "facticide", "velours", "carmoisin",
        "reaccomplish", "protistic", "recuperance", "tech", "withywind", "Galen",
        "Slavistic", "escropulo", "deglutination", "hydramnios", "Amphion", "beguilement",
        "glottiscope", "propagation", "entrancement", "disbelief", "goatlike", "Tanyoan",
        "thecium", "deforciant", "coachwhip", "enviableness", "duroquinone", "smirchy",
        "whisky", "forcing", "homosystemic", "underact", "autosyndesis", "sybaritism",
        "scorching", "testiere", "nonporphyritic", "cephalhematoma", "oxyquinoline", "azo",
        "scrimshorn", "unreeling", "burnt", "kilocycle", "lactenin", "Decimus", "patter",
        "jetbead", "Pygidium", "bitterroot", "thoke", "septated", "trinodal", "qualitied",
        "gotten", "unclassable", "Akhissar", "wholewise", "curse", "organophyly", "teleseme",
        "heptitol", "whitehass", "asclepiadeous", "labionasal", "presumptuous", "triketo",
        "thrombolymphangitis", "spokeswomanship", "unprejudicial", "ungoverned", "nonrectangular",
        "pleocrystalline", "slurbow", "unhandsome", "scoliotic", "phreatophyte", "criminalistics",
        "imitability", "zygozoospore", "nonliability", "disafforestation", "epigenetic", "Aves",
        "schistaceous", "verbomania", "epitenon", "proscriber", "histology", "propulsion",
        "elaidinic", "driftless", "upcover", "duteous", "distasteful", "dermatoxerasia",
        "Kaibartha", "spydom", "tonsor", "paedogenesis", "anticipatory", "unastonished",
        "Blackbeard", "gradient", "metachemistry", "caravan", "circumnutate", "infract",
        "adenia", "louse", "koel", "tributyrin", "lifey", "Desmidiaceae", "vinelet",
        "ingrowth", "uniovulate", "toying", "nonretentive", "conjunct", "pinnaglobin",
        "vastity", "appendicle", "redecline", "trekker", "hereby", "dicatalexis", "slackerism",
        "inaxon", "tribase", "cryptostome", "nonpresbyter", "breezily", "opusculum",
        "methought", "yarl", "fringent", "forsooth", "plicater", "balneotherapeutics",
        "uredosporic", "perceptive", "embouchure", "heterolysis", "imperence", "perfervidness",
        "pobs", "thorax", "perikronion", "charlatanic", "Bonapartism", "whilom", "outferret",
        "kelty", "macrospore", "predisposedly", "squawk", "extrafoliaceous", "inveteracy",
        "obtect", "decoyer", "Pelecaniformes", "dodecane", "gambet", "electrodynamism",
        "henter", "sunless", "burroweed", "busket", "trepidatory", "fermerer", "prewound",
        "thrifty", "twibil", "amateur", "myasthenia", "goave", "toolmark", "ornithon",
        "geminately", "unrhymed", "Serridentines", "presbyopia", "unoperably", "preventer",
        "salten", "grimily", "conduplicated", "theomania", "gyromagnetic", "antimycotic",
        "malacanthid", "sensationistic", "vibraculoid", "eater", "zig", "bordello", "hounding",
        "outweary", "hoyle", "nonrendition", "potlike", "surflike", "rubification", "dulcifluous",
        "Saturnalian", "unconfidence", "Apneumona", "hedgy", "subharmonic", "undisputed"]

    # dict.fromkeys drops duplicates but keeps first-occurrence order; set() ordering of
    # strings differs between interpreter runs and would make the dataset irreproducible.
    unique_words = list(dict.fromkeys(words))

    # The first word seen for each distinct first-character token id becomes a class
    # label; every other word goes to the pool the sample sequences are drawn from.
    list_classes = []
    list_words = []
    seen_first_ids = set()
    for word in unique_words:
        first_id = tokenizer(word[0], add_special_tokens=False)['input_ids'][0]
        if first_id in seen_first_ids:
            list_words.append(word)
        else:
            seen_first_ids.add(first_id)
            list_classes.append(word)
    print(list_classes)

    # create args.seq_langauge_classes seq, each length args.seq_langauge_len couple with class
    N_classes = args.seq_langauge_classes
    seq_len = args.seq_langauge_len
    each_seq_len = args.seq_langauge_each_seq_len

    if N_classes > len(list_classes):
        raise ValueError(
            f"Requested {N_classes} classes but only {len(list_classes)} class words are available")
    if N_classes * seq_len > len(list_words):
        raise ValueError(
            f"Requested {N_classes} x {seq_len} sample words but only {len(list_words)} are available")

    list_classes_samples = [list_words[c * seq_len:(c + 1) * seq_len] for c in range(N_classes)]
    list_classes_labels = list_classes[:N_classes]

    # samples args.seq_langauge_each_seq_len for each seq, and classify it as the class
    dataset_dataclass = DATASET_TO_DATACLASS['seq_language']

    list_output = []
    list_input = []

    random.seed(seed)
    np.random.seed(seed)
    for _ in range(1000):
        sample_class = random.sample(range(0, N_classes), 1)[0]
        word_positions = random.sample(range(0, seq_len), each_seq_len)

        class_words = list_classes_samples[sample_class]
        list_input.append(' '.join(class_words[j] for j in word_positions))
        list_output.append(list_classes_labels[sample_class])

    N_train = 900
    N_test = 1000 - N_train
    dict_res = {
        'train': {
            dataset_dataclass.input_col: list_input[:N_train],
            dataset_dataclass.target_col: list_output[:N_train]
        },
        dataset_dataclass.val_split: {
            dataset_dataclass.input_col: list_input[-N_test:],
            dataset_dataclass.target_col: list_output[-N_test:]
        }
    }

    return dict_res
