import json
import logging
import os
from typing import List, Tuple

import tqdm
from transformers import PreTrainedTokenizer

import utils_base

logger = logging.getLogger(__name__)


class Seq2McInputExample:
    """One seq2mc example: an input text plus the candidate choices.

    Attributes are stored exactly as passed in:
      example_id -- identifier of the example
      text       -- the input text
      choices    -- the candidate answers
      label      -- the label of the example, or None when not supplied
    """

    def __init__(self, example_id, text, choices, label=None):
        # Identification and input side.
        self.example_id, self.text = example_id, text
        # Target side: the candidates and (optionally) the label.
        self.choices, self.label = choices, label


class Seq2McInputFeatures:
    """Container pairing one example's feature mapping with its label.

    `convert_examples_to_features` in this module fills `features` with a
    dict of token-id lists, `label` with the label's index in the label list
    and `dataset_idx` with the example's position in the input list.
    """

    def __init__(self, example_id, features, label, dataset_idx):
        # Where the example came from.
        self.example_id, self.dataset_idx = example_id, dataset_idx
        # What the model consumes.
        self.features, self.label = features, label


class Seq2McMultipleChoiceWrapper(utils_base.DataProcessor):
    """Adapter that exposes a multiple-choice data processor as a seq2mc task.

    Loading is delegated to the wrapped `processor`; every example it returns
    is re-packed into a `Seq2McInputExample` whose text holds the context, the
    question and the enumerated answer choices.
    """

    def __init__(self, processor):
        # Every lookup below is forwarded to this wrapped processor.
        self.processor = processor

    def get_examples(self, path, datasplit='train'):
        """Load the examples at `path` and convert them to seq2mc examples.

        `datasplit` is passed through to the wrapped processor; it names the
        kind of examples being read (useful when train/test files differ in
        format, e.g. a missing label column, or when a single file holds
        several splits).

        Raises:
            ValueError: if the contexts of an example are not all identical.
        """
        converted = []
        for mc_example in self.processor.get_examples(path, datasplit=datasplit):
            shared_context = mc_example.contexts[0]
            if any(context != shared_context for context in mc_example.contexts):
                raise ValueError("Need all contexts to be the same for seq2mc")

            # The choices are spelled out inside the input text, each one
            # prefixed with its index.
            numbered_choices = " ".join(
                "choice {}: {}".format(choice_idx, ending)
                for choice_idx, ending in enumerate(mc_example.endings)
            )
            source_text = "{} | {} | {}".format(
                shared_context, mc_example.question, numbered_choices
            )
            converted.append(
                Seq2McInputExample(
                    mc_example.example_id,
                    source_text,
                    mc_example.endings,
                    mc_example.label,
                )
            )
        return converted

    def get_standard_datasplit_filename(self, datasplit):
        """Return the wrapped processor's usual filename for `datasplit`."""
        return self.processor.get_standard_datasplit_filename(datasplit)

    def get_examples_for_datasplit(self, data_dir, datasplit):
        """Load `datasplit` from its standard file inside `data_dir`."""
        filename = self.get_standard_datasplit_filename(datasplit)
        full_path = os.path.join(data_dir, filename)
        return self.get_examples(full_path, datasplit=datasplit)

    def get_labels(self):
        """Return the list of labels reported by the wrapped processor."""
        return self.processor.get_labels()

    def is_valid_datasplit(self, datasplit):
        """Return True when `datasplit` is one of 'train', 'dev' or 'test'."""
        recognised_splits = {'train', 'dev', 'test'}
        return datasplit in recognised_splits


def convert_textcls_to_seq2mc(examples, class_label_strings):
    """Turn text-classification examples into seq2mc examples.

    Each input example keeps its id, text and label; the same
    `class_label_strings` object is attached to every result as its choices.
    """
    converted = []
    for textcls_example in examples:
        converted.append(
            Seq2McInputExample(
                textcls_example.example_id,
                textcls_example.text,
                class_label_strings,
                textcls_example.label,
            )
        )
    return converted

class Seq2McTextClassificationWrapper(utils_base.DataProcessor):
    """Adapter that exposes a text-classification data processor as a seq2mc task.

    Loading is delegated to the wrapped `processor`; the choices offered for
    every example are the strings 'class 0', 'class 1', ... with one entry per
    label of the wrapped processor.
    """

    def __init__(self, processor):
        # Every lookup below is forwarded to this wrapped processor.
        self.processor = processor

    def get_examples(self, path, datasplit='train'):
        """Load the examples at `path` and convert them to seq2mc examples.

        `datasplit` is passed through to the wrapped processor; it names the
        kind of examples being read (useful when train/test files differ in
        format, e.g. a missing label column, or when a single file holds
        several splits).
        """
        source_examples = self.processor.get_examples(path, datasplit=datasplit)
        num_classes = len(self.processor.get_labels())
        choice_strings = ['class {}'.format(class_idx) for class_idx in range(num_classes)]
        return convert_textcls_to_seq2mc(source_examples, choice_strings)

    def get_standard_datasplit_filename(self, datasplit):
        """Return the wrapped processor's usual filename for `datasplit`."""
        return self.processor.get_standard_datasplit_filename(datasplit)

    def get_examples_for_datasplit(self, data_dir, datasplit):
        """Load `datasplit` from its standard file inside `data_dir`."""
        filename = self.get_standard_datasplit_filename(datasplit)
        full_path = os.path.join(data_dir, filename)
        return self.get_examples(full_path, datasplit=datasplit)

    def get_labels(self):
        """Return the list of labels reported by the wrapped processor."""
        return self.processor.get_labels()

    def is_valid_datasplit(self, datasplit):
        """Return True when `datasplit` is one of 'train', 'dev' or 'test'."""
        recognised_splits = {'train', 'dev', 'test'}
        return datasplit in recognised_splits


def convert_examples_to_features(
    examples: List[Seq2McInputExample],
    label_list: List[str],
    max_length: int,
    tokenizer: PreTrainedTokenizer,
    pad_token_segment_id=0,
    pad_on_left=False,
    pad_token=0,
    mask_padding_with_zero=True,
    oversize_example_method='truncate',
) -> Tuple[List[Seq2McInputFeatures], dict]:
    """Tokenize, pad and label `examples`; return `(features, meta)`.

    The input text of each example is encoded and padded to `max_length`.
    The choices of each example are encoded and padded to the longest choice
    (in tokens) found anywhere in `examples`.

    Args:
        examples: the examples to convert.
        label_list: all labels; an example's label becomes its index here.
        max_length: target length of `input_ids`, `input_mask` and
            `segment_ids`; also handed to the tokenizer as `max_length`.
        tokenizer: tokenizer used for both the input text and the choices.
        pad_token_segment_id: value used to pad the token type ids.
        pad_on_left: pad at the start of the sequences instead of the end.
        pad_token: value used to pad the input ids.
        mask_padding_with_zero: when true, real tokens get mask value 1 and
            padding gets 0; when false the two values are swapped.
        oversize_example_method: what to do with an example for which the
            tokenizer reports truncated tokens -- 'prune' drops it,
            'truncate' keeps the truncated encoding.

    Returns:
        A tuple `(features, meta)`. `features` is the list of
        `Seq2McInputFeatures` (pruned examples are absent). `meta` is a dict
        with the keys 'num_oversize', 'oversize_example_method', 'max_length'
        and 'max_target_length'.

    Raises:
        ValueError: if an oversize example is met and
            `oversize_example_method` is neither 'prune' nor 'truncate'.
        KeyError: if an example's label is not in `label_list`.
    """
    meta = {
        'num_oversize': 0,
        'oversize_example_method': oversize_example_method,
        'max_length': max_length,
        'max_target_length': -1,
    }

    label_map = {label: i for i, label in enumerate(label_list)}

    # First pass: find the longest tokenized choice so that every choice of
    # every example can be padded to one common length.
    max_target_length = -1
    for example in tqdm.tqdm(examples, desc="calculate max target length"):
        choice_ids = tokenizer.batch_encode_plus(example.choices)['input_ids']
        # `default` keeps max() from failing on an example without choices.
        max_target_length = max(max_target_length, max(map(len, choice_ids), default=-1))
    # Record the measured length so the returned metadata reflects it.
    meta['max_target_length'] = max_target_length

    logger.info("Max target length found: {}".format(max_target_length))

    features = []
    for (ex_index, example) in enumerate(tqdm.tqdm(examples, desc="convert examples to features")):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))
        text_a = example.text

        # return_overflowing_tokens is needed to make the tokenizer return the number of truncated tokens
        inputs = tokenizer.encode_plus(text_a, '', add_special_tokens=True, max_length=max_length, return_overflowing_tokens=True, return_token_type_ids=True)
        if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
            meta['num_oversize'] += 1
            if oversize_example_method == 'prune':
                logger.info("Attention! Examples are being pruned from the dataset for length.  You should use a bigger max seq length!")
                continue
            elif oversize_example_method == 'truncate':
                logger.info("Attention! you are cropping tokens. ")
            else:
                raise ValueError("Unknown oversize_example_method {}".format(oversize_example_method))

        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

        # The mask has 1 for real tokens and 0 for padding tokens (or the
        # reverse when mask_padding_with_zero is false). Only real tokens are
        # attended to.
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Pad all three sequences up to max_length on the requested side.
        padding_length = max_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
            token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
        else:
            input_ids = input_ids + ([pad_token] * padding_length)
            attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
            token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

        # Second encoding of the choices, this time padded to the common length.
        choices_input_ids = tokenizer.batch_encode_plus(example.choices, max_length=max_target_length, padding='max_length', truncation=True)['input_ids']

        # Internal sanity checks on the padded shapes.
        assert len(input_ids) == max_length
        assert len(attention_mask) == max_length
        assert len(token_type_ids) == max_length
        assert all(len(choice_input_ids) == max_target_length for choice_input_ids in choices_input_ids)

        label = label_map[example.label]

        # Log the first two examples in full for eyeballing.
        if ex_index < 2:
            logger.info("*** Example ***")
            logger.info("id: {}".format(example.example_id))
            logger.info("input_ids: {}".format(" ".join(map(str, input_ids))))
            logger.info("attention_mask: {}".format(" ".join(map(str, attention_mask))))
            logger.info("token_type_ids: {}".format(" ".join(map(str, token_type_ids))))
            logger.info("label: {}".format(label))

        example_features = {'input_ids': input_ids, 'input_mask': attention_mask, 'segment_ids': token_type_ids, 'choices_input_ids': choices_input_ids}
        # dataset_idx is the position in `examples`, so it stays stable even
        # when earlier examples were pruned.
        features.append(Seq2McInputFeatures(example_id=example.example_id, features=example_features, label=label, dataset_idx=ex_index))

    return features, meta

def select_field(features, field):
    """Collect the `field` entry of every item's `features` mapping, in order."""
    selected = []
    for item in features:
        selected.append(item.features[field])
    return selected

def load_and_cache_examples_from_file(args, filepath, processor, tokenizer, datasplit='train'):
    """Delegate to `utils_base.load_and_cache_examples_from_file` for seq2mc.

    This module's `select_field` and `convert_examples_to_features` are
    supplied as the first two arguments and `is_seq2mc=True` is set; all
    other arguments are forwarded unchanged.
    """
    return utils_base.load_and_cache_examples_from_file(
        select_field,
        convert_examples_to_features,
        args,
        filepath,
        processor,
        tokenizer,
        datasplit=datasplit,
        is_seq2mc=True,
    )
