import os
import time
import torch
import random
import logging
import numpy as np
import pickle as pkl

from typing import List, Optional
from dataclasses import dataclass, field
from torch.utils.data.dataset import Dataset
from torch.utils.data import TensorDataset, DataLoader

from transformers import PreTrainedTokenizer
from transformers import torch_distributed_zero_first

from lang_exps.data.processors.utils import InputFeatures
from lang_exps.data.processors.data import (
    convert_examples_to_features,
    output_modes,
    data_processors,
    dataset_subsampling,
    data_dir,
)

logger = logging.getLogger(__name__)


@dataclass
class CLDataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class into argparse arguments
    to be able to specify them on the command line. Each field's ``help``
    metadata is the text shown for the corresponding command-line option.

    ``task_name`` and ``data_dir`` have no default and must be supplied;
    every other field is optional.
    """

    # Key into ``data_processors``; the help text lists the registered keys.
    task_name: str = field(
        metadata={
            "help": "The name of the task to train on: "
            + ", ".join(data_processors.keys())
        }
    )
    # Root directory holding the raw task files.
    data_dir: str = field(
        metadata={
            "help": "The input data dir. Should contain the .tsv files (or other data files) for the task."
        }
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    # Fraction in (0, 1]; 1.0 keeps every example.
    subsampling: float = field(
        default=1.0,
        metadata={"help": "The fraction of samples to select from total examples."},
    )
    version_label: str = field(
        default="v1",
        metadata={"help": "Different datasets based upon different seeds."},
    )
    # Defaults to None, hence Optional rather than plain ``str``.
    task_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "Select from: task or dataset or None option and accordingly task identifier is added to the inputs."
        },
    )
    processed_data: bool = field(
        default=False, metadata={"help": "Whether to use processed train/dev datasets."}
    )
    # NOTE(review): no code visible in this module reads this flag - confirm
    # where it is consumed.
    disable_predefined_subsampling: bool = field(
        default=False, metadata={"help": "Whether to disable sampling."}
    )
    lll_mode: bool = field(
        default=False,
        metadata={"help": "Turn this ON only for datasets from the LLL paper."},
    )


class CLDataset(Dataset):
    """
    Map-style dataset of tokenized ``InputFeatures`` for one task.

    Features are built once from the task's processor and pickled next to
    the data so that later runs (and the other processes of a distributed
    job) can load them instead of re-tokenizing.

    This will be superseded by a framework-agnostic approach
    soon.
    """

    args: CLDataTrainingArguments
    output_mode: str
    features: List[InputFeatures]

    # Training-set fractions applied to specific tasks when the caller asks
    # for the train split without predefined subsampling.
    _TRAIN_SUBSAMPLING = {"yahooqa": 0.1, "yelp": 0.25, "mnli": 0.5, "qqp": 0.5}

    def __init__(
        self,
        args: CLDataTrainingArguments,
        tokenizer: PreTrainedTokenizer,
        limit_length: Optional[int] = None,
        evaluate=False,
        validation=False,
        train=False,
        local_rank=-1,
        use_predefined_subsampling=False,
        label_map=None,
    ):
        """
        Load cached features for the requested split, or build and cache them.

        Args:
            args: Data arguments. NOTE: ``args.subsampling`` is modified in
                place (predefined / per-task fraction, reset to 1.0 in
                ``lll_mode``).
            tokenizer: Tokenizer passed to ``convert_examples_to_features``;
                its class name is part of the cache file name.
            limit_length: Currently unused.
            evaluate: Select the evaluation split (``test`` with processed
                data, ``dev`` otherwise).
            validation: Select the ``dev`` split when ``args.processed_data``
                is set.
            train: Only enables the per-task default subsampling fractions.
            local_rank: Distributed rank; ``-1`` or ``0`` is treated as the
                main process, which is the only one allowed to overwrite an
                existing cache file.
            use_predefined_subsampling: Take the fraction from
                ``dataset_subsampling`` instead of the per-task defaults.
            label_map: Currently unused.
        """
        self.args = args
        processor = data_processors[args.task_name]()
        self.output_mode = output_modes[args.task_name]
        self.task_type = args.task_type

        if use_predefined_subsampling:
            args.subsampling = dataset_subsampling[args.task_name]
        elif train and args.task_name in self._TRAIN_SUBSAMPLING:
            args.subsampling = self._TRAIN_SUBSAMPLING[args.task_name]

        # LLL datasets are always used in full.
        if args.lll_mode:
            args.subsampling = 1.0

        cached_data_dir, split, subsampling = self._resolve_cache_location(
            args, evaluate, validation
        )

        cache_prefix = os.path.join(
            cached_data_dir,
            "cached_{}_{}_{}_{}_{}_{}".format(
                split,
                tokenizer.__class__.__name__,
                str(args.max_seq_length),
                args.task_name,
                subsampling,
                args.version_label,
            ),
        )

        # An existing cache file is only replaced on explicit request, and
        # only by the main process so that workers never write concurrently.
        may_overwrite = args.overwrite_cache and local_rank in (-1, 0)

        with torch_distributed_zero_first(local_rank):
            # Make sure only the first process in distributed training processes the dataset,
            # and the others will use the cache.

            # The sampled indices do not depend on the task type, so their
            # file name omits the task-type suffix.
            cached_indices_file = cache_prefix + "_indices.pkl"
            typed_prefix = cache_prefix + "_{}".format(self.task_type)
            cached_examples_file = typed_prefix + "_examples.pkl"
            cached_features_file = typed_prefix + "_features.pkl"

            sampled_indices = None
            self.features = None

            if os.path.exists(cached_features_file) and not args.overwrite_cache:
                load_start = time.time()
                self.features = self._load_pickle(cached_features_file)
                logger.info(
                    "Loading features from cached file %s [took %.3f s]",
                    cached_features_file,
                    time.time() - load_start,
                )

            elif os.path.exists(cached_indices_file) and not args.overwrite_cache:
                load_start = time.time()
                sampled_indices = self._load_pickle(cached_indices_file)
                logger.info(
                    "Loading indices from cached file %s [took %.3f s]",
                    cached_indices_file,
                    time.time() - load_start,
                )

            if self.features is None:
                build_start = time.time()
                logger.info("Creating features from dataset file at %s", args.data_dir)
                label_list = processor.get_labels()
                examples = self._load_examples(processor, args, evaluate, validation)

                # Only the training split is ever subsampled.
                if (not evaluate) and (not validation) and args.subsampling < 1.0:
                    sample_start = time.time()
                    n_examples = len(examples)
                    n_samples = int(args.subsampling * n_examples)
                    if sampled_indices is None:
                        # Fixed seed keeps the subset reproducible across runs.
                        sampled_indices = np.random.RandomState(42).choice(
                            n_examples, n_samples, replace=False
                        )
                        self._dump_pickle(sampled_indices, cached_indices_file)

                    examples = [examples[idx] for idx in sampled_indices]
                    logger.info(
                        "Sampled %s out of %s examples for task: %s.",
                        n_samples,
                        n_examples,
                        args.task_name,
                    )

                    if may_overwrite or not os.path.exists(cached_examples_file):
                        self._dump_pickle(examples, cached_examples_file)
                        logger.info(
                            "Saving examples into cached file %s [took %.3f s]",
                            cached_examples_file,
                            time.time() - sample_start,
                        )

                self.features = convert_examples_to_features(
                    examples,
                    tokenizer,
                    max_length=args.max_seq_length,
                    task=args.task_name,
                    label_list=label_list,
                    output_mode=self.output_mode,
                )

                if may_overwrite or not os.path.exists(cached_features_file):
                    self._dump_pickle(self.features, cached_features_file)
                    logger.info(
                        "Saving features into cached file %s [took %.3f s]",
                        cached_features_file,
                        time.time() - build_start,
                    )

    @staticmethod
    def _resolve_cache_location(args, evaluate, validation):
        """Return ``(cache_dir, split, subsampling_percent)`` for the request.

        Creates the ``processed`` / ``processed_lll`` cache directory when
        ``args.processed_data`` is set.
        """
        if args.processed_data:
            subdir = "processed_lll" if args.lll_mode else "processed"
            cached_data_dir = os.path.join(args.data_dir, subdir)
            # exist_ok avoids a crash when several processes get here at once.
            os.makedirs(cached_data_dir, exist_ok=True)
            if evaluate:
                split = "test"
            elif validation:
                split = "dev"
            else:
                split = "train"
        else:
            cached_data_dir = args.data_dir
            # NOTE(review): without processed data, ``validation`` does not
            # select a split of its own and falls through to "train" - confirm
            # this is intended.
            split = "dev" if evaluate else "train"

        # Non-train splits are never subsampled, so they are tagged as 100%.
        subsampling = int(args.subsampling * 100) if split == "train" else 100
        return cached_data_dir, split, subsampling

    def _load_examples(self, processor, args, evaluate, validation):
        """Fetch the raw examples of the requested split from ``processor``."""
        if not args.processed_data:
            getter = (
                processor.get_dev_examples
                if evaluate
                else processor.get_train_examples
            )
            return getter(args.data_dir, task_type=self.task_type)

        if evaluate:
            getter = processor.get_test_examples
        elif validation:
            getter = processor.get_dev_examples
        else:
            getter = processor.get_train_examples

        kwargs = {
            "task_type": self.task_type,
            "version": args.version_label,
            "processed_data": args.processed_data,
        }
        # ``lll_mode`` is only forwarded when it is switched on.
        if args.lll_mode:
            kwargs["lll_mode"] = args.lll_mode
        return getter(args.data_dir, **kwargs)

    @staticmethod
    def _load_pickle(path):
        """Unpickle ``path``, closing the file afterwards.

        NOTE: unpickling can execute arbitrary code; cache files must come
        from a trusted location.
        """
        with open(path, "rb") as handle:
            return pkl.load(handle)

    @staticmethod
    def _dump_pickle(obj, path):
        """Pickle ``obj`` to ``path``, closing (and flushing) the file."""
        with open(path, "wb") as handle:
            pkl.dump(obj, handle)

    def __len__(self):
        """Return the number of cached features."""
        return len(self.features)

    def __getitem__(self, i) -> InputFeatures:
        """Return the feature stored at index ``i``."""
        return self.features[i]

    def verify(self, features, features1):
        """Log how many entries of ``features`` share their ``input_ids``
        with the entry at the same position in ``features1``.

        ``features1`` must be at least as long as ``features``.
        """
        total = len(features)
        matched = sum(
            1
            for idx, feature in enumerate(features)
            if feature.input_ids == features1[idx].input_ids
        )

        if matched == total:
            logger.info("All examples matched!")
        else:
            logger.info("Examples matched %s / %s", matched, total)
