from __future__ import division

import os
import random
import shutil

import numpy as np
import torch
import transformers
from torch.utils.data import IterableDataset, Dataset
from transformers import BertTokenizer

# Restrict the transformers library's logging to error-level messages.
transformers.logging.set_verbosity_error()
# Module-level tokenizer, created once at import time and used by the
# convert_*_to_ids helpers in this file.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


class bucket_dataset(Dataset):
    """Map-style dataset over the lines of a tab-separated text file.

    Every line of the data file is one example. Lines are kept as raw
    strings in ``self.data`` and are only tokenized when indexed.

    Args:
        args: Configuration object. ``__getitem__`` reads ``args.pair_input``
            and passes ``args`` on to the conversion helpers.
        source_data_file: Path of the file holding the examples.
        target_data_file: Optional path of a working copy. When given, the
            source is copied (or written shuffled) to this path and the
            examples are read back from it. When ``None``, the source file
            is read directly and nothing is written to disk.
        need_shuffle: Whether to shuffle the examples (uses the ``random``
            module, so ``random.seed`` controls the order).
        size: Optional cap on the number of examples kept.
    """

    def __init__(self, args, source_data_file, target_data_file=None, need_shuffle=False, size=None):
        super(bucket_dataset, self).__init__()
        self.args = args
        if target_data_file is None:
            # No working copy requested: read the source as-is and, if asked,
            # shuffle in memory instead of trying to write to a ``None`` path.
            self.data = self._read_lines(source_data_file)
            if need_shuffle:
                random.shuffle(self.data)
        else:
            if need_shuffle:
                self.__random__(source_data_file, target_data_file)
            elif source_data_file != target_data_file:
                # Plain byte copy without going through a shell, so paths
                # containing spaces or shell metacharacters are safe.
                try:
                    shutil.copyfile(source_data_file, target_data_file)
                except shutil.SameFileError:
                    # Two spellings of the same file: nothing to copy.
                    pass
            self.data = self._read_lines(target_data_file)
        if size is not None:
            self.data = self.data[:size]

    @staticmethod
    def _read_lines(path):
        """Return all lines of *path* (line endings kept), closing the file."""
        with open(path, "r") as reader:
            return reader.readlines()

    def __random__(self, source_data_file, target_data_file):
        """Write the lines of the source file to the target in random order.

        Each line is whitespace-stripped before shuffling; the output lines
        are joined with ``"\\n"`` and no trailing newline is written. The
        source is read completely before the target is opened, so both paths
        may refer to the same file.
        """
        with open(source_data_file) as reader:
            data_list = [line.strip() for line in reader]
        random.shuffle(data_list)
        with open(target_data_file, 'w') as writer:
            writer.write('\n'.join(data_list))

    def __len__(self):
        """Return the number of examples held in memory."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize and return the example at position *idx*.

        The pair-sentence converter is used only when ``args.pair_input`` is
        exactly the string ``"True"``; any other value selects the
        single-sentence converter.
        """
        if self.args.pair_input == "True":
            encodings = convert_pair_sentence_to_ids(self.args, self.data[idx])
        else:
            encodings = convert_single_sentence_to_ids(self.args, self.data[idx])

        return encodings


def convert_single_sentence_to_ids(args, input):
    """Encode one ``text<TAB>label`` line into a dict of model inputs.

    The line is stripped and split on tabs: the first field is the sentence
    and the second is parsed as an integer label. Any further columns are
    ignored. The sentence is tokenized with the module-level ``tokenizer``,
    padded/truncated to ``args.max_length``.

    Args:
        args: Configuration object providing the result keys
            (``INPUT_IDS``, ``ATTENTION_MASK``, ``TOKEN_TYPE_IDS``,
            ``LABEL``, ``TEXT``) plus ``max_length`` and ``device``.
        input: One raw line of the data file.

    Returns:
        dict: Tensors for input ids, token type ids, attention mask and
        label (all moved to ``args.device``), and under ``args.TEXT`` the
        whole stripped input line.

    Raises:
        IndexError: If the line has fewer than two tab-separated fields.
        ValueError: If the label field is not an integer.
    """
    encodings = dict.fromkeys((args.INPUT_IDS, args.ATTENTION_MASK, args.TOKEN_TYPE_IDS))

    def on_device(values):
        # Go through numpy so the tensor dtype follows numpy's integer default.
        return torch.from_numpy(np.array(values)).to(args.device)

    raw_line = input.strip()
    fields = raw_line.split("\t")
    sentence = fields[0]
    label = int(fields[1])

    tokenized = tokenizer.encode_plus(
        sentence,
        add_special_tokens=True,  # wrap with '[CLS]' ... '[SEP]'
        max_length=args.max_length,
        padding='max_length',  # every example comes out at max_length
        return_attention_mask=True,
        truncation=True,
    )

    tensor_fields = (
        (args.INPUT_IDS, 'input_ids'),
        (args.TOKEN_TYPE_IDS, 'token_type_ids'),
        (args.ATTENTION_MASK, 'attention_mask'),
    )
    for key, tokenizer_field in tensor_fields:
        encodings[key] = on_device(tokenized[tokenizer_field])
    encodings[args.LABEL] = on_device(label)
    encodings[args.TEXT] = raw_line
    return encodings


def convert_pair_sentence_to_ids(args, input):
    """Encode one ``target<TAB>claim<TAB>label`` line into a dict of model inputs.

    The line is stripped and split on tabs: the first two fields are
    tokenized together as a sentence pair with the module-level
    ``tokenizer`` (padded/truncated to ``args.max_length``), and the third
    field is parsed as an integer label.

    Args:
        args: Configuration object providing the result keys
            (``INPUT_IDS``, ``ATTENTION_MASK``, ``TOKEN_TYPE_IDS``,
            ``LABEL``, ``TEXT``) plus ``max_length`` and ``device``.
        input: One raw line of the data file.

    Returns:
        dict: Tensors for input ids, token type ids, attention mask and
        label (all moved to ``args.device``), and under ``args.TEXT`` the
        whole stripped input line.

    Raises:
        IndexError: If the line has fewer than three tab-separated fields.
        ValueError: If the label field is not an integer.
    """
    encodings = dict.fromkeys((args.INPUT_IDS, args.ATTENTION_MASK, args.TOKEN_TYPE_IDS))

    def on_device(values):
        # Go through numpy so the tensor dtype follows numpy's integer default.
        return torch.from_numpy(np.array(values)).to(args.device)

    raw_line = input.strip()
    fields = raw_line.split("\t")
    target = fields[0]
    claim = fields[1]
    label = int(fields[2])

    tokenized = tokenizer.encode_plus(
        target, claim,
        add_special_tokens=True,  # '[CLS]' target '[SEP]' claim '[SEP]'
        max_length=args.max_length,
        padding='max_length',  # every example comes out at max_length
        return_attention_mask=True,
        truncation=True,
    )

    tensor_fields = (
        (args.INPUT_IDS, 'input_ids'),
        (args.TOKEN_TYPE_IDS, 'token_type_ids'),
        (args.ATTENTION_MASK, 'attention_mask'),
    )
    for key, tokenizer_field in tensor_fields:
        encodings[key] = on_device(tokenized[tokenizer_field])
    encodings[args.LABEL] = on_device(label)
    encodings[args.TEXT] = raw_line

    return encodings




