import logging
import os
import time
from typing import List

import torch
from torch.utils.data import TensorDataset

logger = logging.getLogger(__name__)

class DataProcessor:
    """Abstract converter for multiple choice data sets.

    Subclasses are expected to override `get_examples`,
    `get_standard_datasplit_filename` and `get_labels`; the versions here
    only raise `NotImplementedError`.
    """

    def get_examples(self, path, datasplit='train'):
        """Return the `InputExample`s stored in the file at `path`.

        `datasplit` names the kind of examples being read.  It matters when
        train and test files are laid out differently (for instance, only one
        of them has a label column) or when a single file holds several
        splits.
        """
        raise NotImplementedError

    def get_standard_datasplit_filename(self, datasplit):
        """Return the conventional filename used for `datasplit`."""
        raise NotImplementedError

    def get_examples_for_datasplit(self, data_dir, datasplit):
        """Read the examples for `datasplit` from its standard file in `data_dir`."""
        filename = self.get_standard_datasplit_filename(datasplit)
        full_path = os.path.join(data_dir, filename)
        return self.get_examples(full_path, datasplit=datasplit)

    def get_labels(self):
        """Return the list of labels used by this data set."""
        raise NotImplementedError

    def is_valid_datasplit(self, datasplit):
        """Return True when `datasplit` is 'train', 'dev' or 'test'."""
        recognised_splits = {'train', 'dev', 'test'}
        return datasplit in recognised_splits


def load_and_cache_examples_from_file(select_field_fn, convert_examples_to_features_fn, args, filepath, processor, tokenizer, datasplit='train', no_cache=False, is_seq2mc=False):
    """Build a `TensorDataset` from the examples in `filepath`, using an on-disk feature cache.

    The cache file lives next to `filepath` and is named
    ``cached_<datasplit>_<model>_<max_seq_length>_<oversize_example_method>_<processor class>``,
    where ``<model>`` is the last non-empty ``/``-separated component of
    `args.model_name_or_path`.

    Args:
        select_field_fn: Called as ``select_field_fn(features, field_name)``
            with the field names "input_ids", "input_mask", "segment_ids"
            (and "choices_input_ids" when `is_seq2mc` is set); its result is
            passed to `torch.tensor(..., dtype=torch.long)`.
        convert_examples_to_features_fn: Called with the examples, the label
            list, `args.max_seq_length`, `tokenizer` and the keyword
            arguments `pad_on_left`, `pad_token_segment_id` and
            `oversize_example_method`.  Must return ``(features, meta)``
            where `meta` is a dict containing a 'num_oversize' entry.
        args: Object providing `local_rank`, `model_name_or_path`,
            `max_seq_length`, `oversize_example_method`, `overwrite_cache`
            and `model_type`.
        filepath: Path of the dataset file handed to `processor.get_examples`.
        processor: Object providing `get_labels()` and
            `get_examples(filepath, datasplit=...)`.
        tokenizer: Passed through unchanged to `convert_examples_to_features_fn`.
        datasplit: One of 'train', 'dev' or 'test'.
        no_cache: When true, the cache file is neither read nor written.
        is_seq2mc: When true, a sixth tensor built from the
            "choices_input_ids" field is appended to the dataset.

    Returns:
        A `TensorDataset` of (input_ids, input_mask, segment_ids, label_ids,
        dataset_idxs), plus choices_input_ids when `is_seq2mc` is true.

    Raises:
        ValueError: If `datasplit` is not 'train', 'dev' or 'test'.
    """
    # Explicit check instead of `assert`, which is stripped when running with -O.
    if datasplit not in {'train', 'dev', 'test'}:
        raise ValueError("datasplit must be one of 'train', 'dev' or 'test'; got {!r}".format(datasplit))

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Load data features from cache or dataset file
    model_name = list(filter(None, args.model_name_or_path.split("/"))).pop()
    cached_features_file = os.path.join(
        os.path.dirname(filepath),
        "cached_{}_{}_{}_{}_{}".format(
            datasplit,
            model_name,
            str(args.max_seq_length),
            str(args.oversize_example_method),
            str(type(processor).__name__),
        ),
    )
    if os.path.exists(cached_features_file) and not (args.overwrite_cache or no_cache):
        logger.info("Loading features from cached file %s", cached_features_file)
        # NOTE: torch.load unpickles arbitrary Python objects; only point this
        # at cache files produced by this function from trusted data.
        features = torch.load(cached_features_file)['features']
    else:
        # Log the file that is actually read rather than a directory taken from args.
        logger.info("Creating features from dataset file at %s", filepath)

        label_list = processor.get_labels()
        examples = processor.get_examples(filepath, datasplit=datasplit)
        num_examples = len(examples)

        logger.info("Number of %s examples: %d", datasplit, num_examples)
        is_xlnet = args.model_type in ["xlnet"]
        features, meta = convert_examples_to_features_fn(
            examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            pad_on_left=bool(is_xlnet),  # pad on the left for xlnet
            pad_token_segment_id=4 if is_xlnet else 0,
            oversize_example_method=args.oversize_example_method,
        )
        meta['creation_time'] = time.time()
        if meta['num_oversize'] != 0:
            fraction_oversize = meta['num_oversize'] / num_examples
            logger.warning(
                'There were %s oversize examples (handled via "%s").  This is %0.3f%% of the dataset.',
                meta['num_oversize'],
                args.oversize_example_method,
                100 * fraction_oversize,
            )
        # Only the first process (or a non-distributed run) writes the cache.
        if args.local_rank in [-1, 0] and not no_cache:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save({'features': features, 'meta': meta}, cached_features_file)

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor(select_field_fn(features, "input_ids"), dtype=torch.long)
    all_input_mask = torch.tensor(select_field_fn(features, "input_mask"), dtype=torch.long)
    all_segment_ids = torch.tensor(select_field_fn(features, "segment_ids"), dtype=torch.long)
    all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long)
    all_dataset_idxs = torch.tensor([f.dataset_idx for f in features], dtype=torch.long)

    tensors = [all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_dataset_idxs]
    if is_seq2mc:
        logger.info("Converting dataset to seq2mc format...")
        tensors.append(torch.tensor(select_field_fn(features, "choices_input_ids"), dtype=torch.long))

    return TensorDataset(*tensors)


class ConfigOption:
    """Description of one configuration option: its name, type, default and constraints."""

    def __init__(self, opt_name, opt_type=None, default_value=None, required=False, choices=None):
        """Record the settings for the option called `opt_name`.

        `opt_type` is a callable that casts a value to the desired type, or
        raises if the value cannot be converted; builtins such as `int` and
        `str` behave this way.  Passing None means any value is accepted.
        The remaining arguments are stored as given.
        """
        # Name and caster first, then the default/required/choices settings.
        self.opt_name, self.opt_type = opt_name, opt_type
        self.default_value, self.required, self.choices = default_value, required, choices

class Namespace:
    """Plain object with no predefined attributes."""

    def __init__(self):
        """Take no arguments and set no attributes."""

class ConfigParser:
    """Validates config dicts against a fixed list of option specifications."""

    def __init__(self, options: "List[ConfigOption]"):
        """Store `options` and index them by `opt_name`.

        If two options share a name, the later one is the one used for
        validation.
        """
        self.options = options
        self.option_map = {x.opt_name: x for x in self.options}

    def validate_config(self, config_dict):
        """Parse config dict, handling errors as needed, and return a new dict
        guaranteed to contain exactly the set of config options specified
        (missing values assigned to defaults, unknown options removed).

        Supplied values are passed through the option's `opt_type` (when it
        is not None) and then checked against its `choices` (when not None).
        Default values are used as-is, without casting or a choices check.
        Keys of the result follow the order in which the options were
        declared.

        Raises:
            ValueError: If a required option is missing, or a supplied value
                is not among the option's `choices`.  Whatever `opt_type`
                raises for an unconvertible value is propagated unchanged.
        """
        options_keys = set(self.option_map)
        known_keys = {key for key in config_dict if key in options_keys}
        unknown_keys = {key for key in config_dict if key not in options_keys}
        missing_keys = options_keys - known_keys
        missing_required_keys = {key for key in missing_keys if self.option_map[key].required}

        # Sort with key=str so keys of mixed types cannot break the message itself.
        if missing_required_keys:
            raise ValueError(
                "Missing one or more required keys while validating config: {}".format(
                    sorted(missing_required_keys, key=str)
                )
            )

        if missing_keys:
            logger.warning(
                'Missing one or more keys while validating config (the default values will be used): %s',
                sorted(missing_keys, key=str),
            )

        if unknown_keys:
            logger.warning(
                'Found one or more unknown keys while validating config (they will be ignored): %s',
                sorted(unknown_keys, key=str),
            )

        new_config = dict()

        # Walk the options in declaration order so the result (and which
        # error is reported first) does not depend on set iteration order.
        for opt, opt_spec in self.option_map.items():
            if opt not in known_keys:
                new_config[opt] = opt_spec.default_value
                continue

            val = config_dict[opt]
            if opt_spec.opt_type is not None:
                val = opt_spec.opt_type(val)

            if (opt_spec.choices is not None) and (val not in opt_spec.choices):
                raise ValueError("Got value {} for key {}.  Needed a value from the list: {}".format(val, opt, opt_spec.choices))

            new_config[opt] = val

        return new_config


