import json
import logging
from typing import List

import tqdm
from transformers import PreTrainedTokenizer

import utils_base
from utils_base import DataProcessor


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class TextClassificationInputExample:
    """A single raw (untokenized) example for text classification."""

    def __init__(self, example_id, text, label=None):
        """Store the fields of one example.

        Args:
            example_id: Unique id for the example.
            text: The untokenized text to classify.
            label: (Optional) string. The label of the example. This should be
                specified for train and dev examples, but not for test examples.
        """
        # Assigned left to right, so attribute order matches the signature.
        self.example_id, self.text, self.label = example_id, text, label


class TextClassificationInputFeatures:
    """Model-ready features for one example, plus bookkeeping fields.

    Attributes:
        example_id: Id copied from the source example.
        features: Container of per-example model inputs; it is indexed by
            field name in ``select_field``.
        label: Label value associated with the example.
        dataset_idx: Index of the example within the dataset it came from.
    """

    def __init__(self, example_id, features, label, dataset_idx):
        # Tuple assignments run left to right, keeping the attribute order
        # example_id, features, label, dataset_idx.
        self.example_id, self.features = example_id, features
        self.label, self.dataset_idx = label, dataset_idx

class NLabelTextClassificationJsonlProcessor(DataProcessor):
    """Processor for text classification datasets in the "standard" JSONL format.

    Every non-blank line of an input file is parsed as one JSON object. The
    fields ``id`` and ``text`` are required; ``label`` is required for the
    training split and optional otherwise. Labels are compared as strings,
    and the label set is ``"0"`` .. ``str(num_labels - 1)``.
    """

    def __init__(self, num_labels):
        """Create a processor for a task with ``num_labels`` classes."""
        super().__init__()
        self.num_labels = num_labels

    def get_standard_datasplit_filename(self, datasplit):
        """Return the conventional file name (``<datasplit>.jsonl``) of a split.

        Raises:
            ValueError: If ``datasplit`` is rejected by ``is_valid_datasplit``.
        """
        if not self.is_valid_datasplit(datasplit):
            raise ValueError(f"Unknown datasplit {datasplit}")

        return datasplit + '.jsonl'

    def get_examples(self, path, datasplit='train'):
        """See base class.

        Reads the JSONL file at ``path`` and returns its examples.

        Raises:
            ValueError: If ``datasplit`` is rejected by ``is_valid_datasplit``,
                or if it is ``'train'`` and a record has no label.
        """
        if not self.is_valid_datasplit(datasplit):
            raise ValueError(f"Unknown datasplit {datasplit}")
        logger.info("LOOKING AT %s %s", path, datasplit)
        return self._create_examples(self._read_jsonl(path), datasplit)

    def get_labels(self):
        """See base class."""
        return [str(label_num) for label_num in range(self.num_labels)]

    @staticmethod
    def _read_jsonl(input_file):
        """Reads a JSONL file and returns the list of decoded records.

        Blank (whitespace-only) lines, such as a trailing empty line at the
        end of the file, are skipped instead of being handed to the JSON
        decoder, which would reject them.
        """
        with open(input_file, "r", encoding='utf-8') as f:
            # Iterate the file lazily rather than materializing readlines().
            return [json.loads(line) for line in f if line.strip()]

    def _create_examples(self, lines: List[dict], type: str):
        """Creates examples from decoded JSONL records.

        Args:
            lines: Decoded records, each providing ``id`` and ``text`` and,
                when labelled, ``label``.
            type: Name of the datasplit the records belong to. (The name
                shadows the builtin; it is kept for interface compatibility.)

        Returns:
            A list of ``TextClassificationInputExample``. A record without a
            ``label`` key yields an example whose label is ``None``.

        Raises:
            ValueError: If ``type`` is ``"train"`` and any record lacks a label.
        """
        # any() over an empty list is False, so an empty file yields [] rather
        # than failing on lines[0].
        if type == "train" and any('label' not in line for line in lines):
            raise ValueError("For training, the input file must contain a label column.")

        return [
            TextClassificationInputExample(
                example_id=line['id'],
                text=line['text'],
                # Labels are normalized to str so they match get_labels().
                label=str(line['label']) if 'label' in line else None,
            )
            for line in lines
        ]

def textcls_jsonl_processor_factory(num_labels):
    """Build a processor class whose constructor takes no arguments.

    The returned class is a subclass of
    ``NLabelTextClassificationJsonlProcessor`` that always initializes itself
    with the ``num_labels`` given to this factory call.
    """

    class ManufacturedTextclsJsonlProcessor(NLabelTextClassificationJsonlProcessor):
        # ``num_labels`` is captured from the enclosing factory call, so
        # instances are created with a bare ``cls()``.
        def __init__(self):
            super().__init__(num_labels)

    return ManufacturedTextclsJsonlProcessor


def _pad_to_length(values, pad_value, target_length, pad_on_left):
    """Return ``values`` extended with ``pad_value`` up to ``target_length``."""
    # A non-positive difference produces an empty padding list, leaving
    # ``values`` unchanged.
    padding = [pad_value] * (target_length - len(values))
    return padding + values if pad_on_left else values + padding


def convert_examples_to_features(
    examples: List[TextClassificationInputExample],
    label_list: List[str],
    max_length: int,
    tokenizer: PreTrainedTokenizer,
    pad_token_segment_id=0,
    pad_on_left=False,
    pad_token=0,
    mask_padding_with_zero=True,
    oversize_example_method='truncate',
) -> tuple:
    """
    Converts examples into padded `TextClassificationInputFeatures`.

    Args:
        examples: Examples to convert.
        label_list: All label strings; a label's position in this list becomes
            its integer id.
        max_length: Length every output sequence is padded to; also passed to
            the tokenizer as ``max_length``.
        tokenizer: Tokenizer whose ``encode_plus`` encodes each text.
        pad_token_segment_id: Value used to pad ``segment_ids``.
        pad_on_left: Pad at the start of each sequence instead of the end.
        pad_token: Value used to pad ``input_ids``.
        mask_padding_with_zero: If true, real tokens get mask value 1 and
            padding gets 0; otherwise the two values are swapped.
        oversize_example_method: What to do with an example for which the
            tokenizer reports truncated tokens: ``'truncate'`` keeps the
            tokenizer's output, ``'prune'`` drops the example.

    Returns:
        A ``(features, meta)`` tuple. ``features`` is the list of
        `TextClassificationInputFeatures` (pruned examples are absent) and
        ``meta`` is a dict with the keys ``num_oversize``,
        ``oversize_example_method`` and ``max_length``.

    Raises:
        ValueError: If an oversize example is met and
            ``oversize_example_method`` is neither ``'truncate'`` nor ``'prune'``.
        KeyError: If an example has a label that is not in ``label_list``.
    """
    meta = {
        'num_oversize': 0,
        'oversize_example_method': oversize_example_method,
        'max_length': max_length,
    }

    label_map = {label: i for i, label in enumerate(label_list)}

    # The mask has 1 for real tokens and 0 for padding tokens (or the reverse
    # when mask_padding_with_zero is false). Only real tokens are attended to.
    real_mask_value = 1 if mask_padding_with_zero else 0
    pad_mask_value = 0 if mask_padding_with_zero else 1

    num_examples = len(examples)
    features = []
    # total= is required for a percentage/ETA because enumerate() has no length.
    progress = tqdm.tqdm(enumerate(examples), total=num_examples, desc="convert examples to features")
    for ex_index, example in progress:
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d", ex_index, num_examples)

        # return_overflowing_tokens is needed to make the tokenizer return the number of truncated tokens
        inputs = tokenizer.encode_plus(
            example.text,
            '',
            add_special_tokens=True,
            max_length=max_length,
            return_overflowing_tokens=True,
            return_token_type_ids=True,
        )
        if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
            meta['num_oversize'] += 1
            if oversize_example_method == 'prune':
                logger.info("Attention! Examples are being pruned from the dataset for length.  You should use a bigger max seq length!")
                continue
            elif oversize_example_method == 'truncate':
                logger.info("Attention! you are cropping tokens. ")
            else:
                raise ValueError("Unknown oversize_example_method {}".format(oversize_example_method))

        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
        attention_mask = [real_mask_value] * len(input_ids)

        # NOTE: The padding here pads input_ids with 0 by default, but it
        # seems that is the wrong pad token id for roberta. However, fixing
        # this doesn't appear to improve accuracy and several runs were
        # already done using the wrong padding, so for consistency it is
        # being left as-is for the time being.

        # Pad all three sequences up to max_length on the same side.
        input_ids = _pad_to_length(input_ids, pad_token, max_length, pad_on_left)
        attention_mask = _pad_to_length(attention_mask, pad_mask_value, max_length, pad_on_left)
        token_type_ids = _pad_to_length(token_type_ids, pad_token_segment_id, max_length, pad_on_left)

        # Sanity checks: these fail if the tokenizer returned more than
        # max_length tokens, since padding never shortens a sequence.
        assert len(input_ids) == max_length
        assert len(attention_mask) == max_length
        assert len(token_type_ids) == max_length

        # Unlabeled examples (label=None) keep a None label instead of failing
        # the lookup. NOTE(review): confirm downstream consumers accept None.
        label = None if example.label is None else label_map[example.label]

        if ex_index < 2:
            logger.info("*** Example ***")
            logger.info("id: %s", example.example_id)
            logger.info("input_ids: %s", " ".join(map(str, input_ids)))
            logger.info("attention_mask: %s", " ".join(map(str, attention_mask)))
            logger.info("token_type_ids: %s", " ".join(map(str, token_type_ids)))
            logger.info("label: %s", label)

        example_features = {'input_ids': input_ids, 'input_mask': attention_mask, 'segment_ids': token_type_ids}
        features.append(
            TextClassificationInputFeatures(
                example_id=example.example_id,
                features=example_features,
                label=label,
                dataset_idx=ex_index,
            )
        )

    return features, meta


def select_field(features, field):
    """Collect ``item.features[field]`` for every item in ``features``.

    Returns a new list with one entry per item, in iteration order.
    """
    selected = []
    for item in features:
        selected.append(item.features[field])
    return selected

def load_and_cache_examples_from_file(args, filepath, processor, tokenizer, datasplit='train'):
    """Delegate to ``utils_base.load_and_cache_examples_from_file``.

    This module's ``select_field`` and ``convert_examples_to_features`` are
    supplied as the first two arguments; the remaining arguments are passed
    through unchanged and the delegate's result is returned as-is.
    """
    shared_loader = utils_base.load_and_cache_examples_from_file
    return shared_loader(
        select_field,
        convert_examples_to_features,
        args,
        filepath,
        processor,
        tokenizer,
        datasplit=datasplit,
    )


# Number of target classes for each supported text classification task.
TEXT_CLASSIFICATION_TASKS_NUM_LABELS = {
    "mr": 2,
    'agn': 4,
}

# Task name -> processor class (instantiable without arguments) for that task.
processors = {
    task_name: textcls_jsonl_processor_factory(label_count)
    for task_name, label_count in TEXT_CLASSIFICATION_TASKS_NUM_LABELS.items()
}

