import os
import json
import torch
import argparse
import logging
import numpy as np
import pandas as pd

from copy import deepcopy
from collections import Counter
from collections import defaultdict
from transformers import BertTokenizer
from nltk.tokenize import word_tokenize

from lang_exps.data.processors.utils import DataProcessor, InputExample, InputFeatures
from lang_exps.data.processors.data import (
    convert_examples_to_features,
    output_modes,
    data_processors,
    tasks_num_labels,
    data_dir,
    dataset_subsampling,
)
from lang_exps.data.metrics import compute_metrics
from lang_exps.data.util import never_split


def process_train_dev_lll(
    task,
    n_train_examples=115000,
    n_dev_examples=5000,
    n_test_examples=7600,
    processed_dir="processed_lll",
    prefix_data_dir="data",
):
    """Invoke the task processor's ``process_train_dev_lll`` for ``task``.

    The processor class is looked up in ``data_processors`` and the task
    directory is resolved as ``prefix_data_dir/data_dir[task]``. File names
    and parsing options are chosen per task, then forwarded together with the
    requested example counts.

    Args:
        task: Key into ``data_processors`` and ``data_dir``.
        n_train_examples: Forwarded unchanged to the processor.
        n_dev_examples: Forwarded unchanged to the processor.
        n_test_examples: Forwarded unchanged to the processor.
        processed_dir: Forwarded only for agnews, dbpedia, yahooqa, yelp and
            amzn; for every other task the processor is called without it.
        prefix_data_dir: Root directory that contains the task directory.

    Raises:
        KeyError: If ``task`` is missing from ``data_processors`` or
            ``data_dir``.
    """
    processor = data_processors[task]()
    task_data_dir = os.path.join(prefix_data_dir, data_dir[task])

    print(f"Processing - {task}")
    print(f"Task data dir - {task_data_dir}")

    # Arguments every branch passes to the processor.
    shared = {
        "data_dir": task_data_dir,
        "task_name": task,
        "n_train_examples": n_train_examples,
        "n_dev_examples": n_dev_examples,
        "n_test_examples": n_test_examples,
    }

    csv_tasks = ("agnews", "dbpedia", "yahooqa", "yelp", "amzn")
    if task in csv_tasks:
        # Comma-separated sources; only this branch forwards processed_dir.
        file_options = {
            "train_filename": "train.csv",
            "dev_filename": "test.csv",
            "test_filename": "test.tsv",
            "delimiter": ",",
            "quotechar": '"',
            "processed_dir": processed_dir,
        }
    elif task == "mnli":
        file_options = {
            "train_filename": "train.tsv",
            "dev_filename": "dev_matched.tsv",
            "test_filename": "test_m.tsv",
            "delimiter": "\t",
            "quotechar": None,
        }
    else:
        file_options = {
            "train_filename": "train.tsv",
            "dev_filename": "dev.tsv",
            "test_filename": "test.tsv",
            "delimiter": "\t",
            "quotechar": None,
        }

    processor.process_train_dev_lll(**shared, **file_options)


def run(task, processed_data=True, lll_mode=False, prefix_data_dir="data"):
    """Load the train/dev/test examples for ``task`` and print their counts.

    Args:
        task: Key into ``data_processors`` and ``data_dir``.
        processed_data: Forwarded to each ``get_*_examples`` call.
        lll_mode: Forwarded to the ``get_*_examples`` calls only for agnews,
            yahooqa, dbpedia, yelp and amzn; other tasks are called without it.
        prefix_data_dir: Root directory that contains the task directory.

    Raises:
        KeyError: If ``task`` is missing from ``data_processors`` or
            ``data_dir``.
    """
    processor = data_processors[task]()
    task_data_dir = os.path.join(prefix_data_dir, data_dir[task])

    loader_kwargs = {
        "data_dir": task_data_dir,
        "task_type": None,
        "processed_data": processed_data,
    }
    if task in ("agnews", "yahooqa", "dbpedia", "yelp", "amzn"):
        loader_kwargs["lll_mode"] = lll_mode

    train_examples = processor.get_train_examples(**loader_kwargs)
    dev_examples = processor.get_dev_examples(**loader_kwargs)
    test_examples = processor.get_test_examples(**loader_kwargs)

    # All three splits are loaded before anything is reported.
    reports = (
        ("No. of train examples for task {} : {} ", train_examples),
        ("No. of dev examples for task {} : {} ", dev_examples),
        ("No. of test examples for task {} : {} ", test_examples),
    )
    for template, examples in reports:
        print(template.format(task, len(examples)))


def process_train_dev(
    task, n_dev_examples, n_test_examples=1001, prefix_data_dir="data"
):
    """Invoke the task processor's ``process_train_dev`` for ``task``.

    The processor class is looked up in ``data_processors`` and the task
    directory is resolved as ``prefix_data_dir/data_dir[task]``. Each known
    task has its own set of file names and options; any task without an
    explicit entry uses the generic ``train.tsv``/``dev.tsv``/``test.tsv``
    layout.

    Args:
        task: Key into ``data_processors`` and ``data_dir``.
        n_dev_examples: Forwarded for every task except rocstory.
        n_test_examples: Forwarded only for mnli, yahooqa, splityahooqa, yelp
            and tasks that use the generic layout.
        prefix_data_dir: Root directory that contains the task directory.

    Raises:
        KeyError: If ``task`` is missing from ``data_processors`` or
            ``data_dir``.
    """
    processor = data_processors[task]()
    task_data_dir = os.path.join(prefix_data_dir, data_dir[task])

    print(f"Processing - {task}")
    print(f"Task data dir - {task_data_dir}")

    # Reusable fragments of the per-task keyword sets.
    csv_source = {
        "train_filename": "train.csv",
        "dev_filename": "test.csv",
        "test_filename": "test.tsv",
        "delimiter": ",",
        "quotechar": '"',
    }
    jsonl_source = {
        "train_filename": "train.jsonl",
        "dev_filename": "val.jsonl",
        "test_filename": "test.jsonl",
    }
    dev_only = {"n_dev_examples": n_dev_examples}
    dev_and_test = {
        "n_dev_examples": n_dev_examples,
        "n_test_examples": n_test_examples,
    }

    per_task = {
        "mnli": {
            "train_filename": "train.tsv",
            "dev_filename": "dev_matched.tsv",
            "test_filename": "test_m.tsv",
            "dev_filename1": "dev_mismatched.tsv",
            "test_filename1": "test_mm.tsv",
            **dev_and_test,
        },
        "scitail": {
            "train_filename": "scitail_1.0_train.tsv",
            "dev_filename": "scitail_1.0_dev.tsv",
            "test_filename": "test.tsv",
            **dev_only,
        },
        "agnews": {**csv_source, **dev_only},
        "dbpedia": {**csv_source, **dev_only},
        "yahooqa": {**csv_source, **dev_and_test},
        "splityahooqa": {**csv_source, **dev_and_test},
        "yelp": {**csv_source, **dev_and_test},
        "boolq": {**jsonl_source, **dev_only},
        "multirc": {**jsonl_source, **dev_only},
        # rocstory passes no example counts at all.
        "rocstory": {
            "train_filename": "train.txt",
            "dev_filename": "dev.txt",
            "test_filename": "test.txt",
        },
    }
    generic = {
        "train_filename": "train.tsv",
        "dev_filename": "dev.tsv",
        "test_filename": "test.tsv",
        **dev_and_test,
    }

    task_options = per_task.get(task, generic)
    processor.process_train_dev(
        data_dir=task_data_dir, task_name=task, **task_options
    )


def main():
    """Parse command-line options, preprocess one task and report split sizes.

    ``--mode lll`` runs ``process_train_dev_lll`` and then ``run`` with
    ``lll_mode=True``; ``--mode processed`` runs ``process_train_dev`` and
    then ``run`` with ``lll_mode=False``.

    Raises:
        ValueError: If ``--mode`` is neither ``lll`` nor ``processed``.
    """
    parser = argparse.ArgumentParser()
    cli_options = (
        ("--task", {"required": True}),
        ("--data_dir", {"required": True}),
        ("--n_train_examples", {"default": 115000, "type": int}),
        ("--n_dev_examples", {"default": 5000, "type": int}),
        ("--n_test_examples", {"default": 7600, "type": int}),
        ("--processed_dir", {"default": "processed_lll"}),
        ("--mode", {"default": "lll"}),
    )
    for flag, settings in cli_options:
        parser.add_argument(flag, **settings)

    args = parser.parse_args()

    use_lll = args.mode == "lll"
    if not use_lll and args.mode != "processed":
        raise ValueError("Incorrect mode option!")

    if use_lll:
        # 5-dataset-NLP
        process_train_dev_lll(
            task=args.task,
            n_train_examples=args.n_train_examples,
            n_dev_examples=args.n_dev_examples,
            n_test_examples=args.n_test_examples,
            processed_dir=args.processed_dir,
            prefix_data_dir=args.data_dir,
        )
    else:
        # Split YahooQA
        process_train_dev(
            task=args.task,
            n_dev_examples=args.n_dev_examples,
            n_test_examples=args.n_test_examples,
            prefix_data_dir=args.data_dir,
        )

    # Both modes finish by loading the processed splits and printing counts.
    run(
        task=args.task,
        processed_data=True,
        lll_mode=use_lll,
        prefix_data_dir=args.data_dir,
    )


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
