from torch.utils.data import TensorDataset
import numpy as np
import logging
import os
import random
import torch
import time
from tqdm import tqdm
from _utils import *

import parseutils

logger = logging.getLogger(__name__)


def load_and_cache_gen_data(
    args, filename, pool, tokenizer, split_tag, only_src=False, is_sample=False, is_decoder=False
):
    """Read generation examples and build, or load from cache, their tensor dataset.

    Args:
        args: Run configuration. Reads ``data_num``, ``cache_path``, ``task``,
            ``parse_as_tree``, ``ip_lang``, ``op_lang`` and ``local_rank``.
        filename: Data file(s) forwarded to ``read_examples``.
        pool: Worker pool whose ``map`` runs the feature conversion.
        tokenizer: Tokenizer forwarded to ``calc_stats`` and the feature converters.
        split_tag: ``"train"``, ``"dev"`` or ``"test"``.
        only_src: Build a dataset that holds only the source ids.
        is_sample: Keep a random sample of at most 5000 examples; the cache is
            neither read nor written in this mode.
        is_decoder: Use ``decoder_convert_examples_to_features`` instead of
            ``convert_examples_to_features``.

    Returns:
        Tuple ``(examples, data)`` of the example list and a ``TensorDataset``.
    """
    # The test split is never subsampled; dev is capped at 500 examples.
    if split_tag == "dev":
        data_num = min(500, args.data_num)
    elif split_tag == "train":
        data_num = args.data_num
    else:
        data_num = -1

    # The cache name encodes everything that changes the cached tensors. The
    # "_dec" tag keeps decoder-mode caches apart from encoder-decoder ones,
    # which are produced by a different converter.
    data_tag = "_all" if data_num == -1 else "_%d" % data_num
    cache_name = split_tag + ("_src" if only_src else "") + ("_dec" if is_decoder else "") + data_tag
    cache_fn = "{}/{}.pt".format(args.cache_path, cache_name)

    # NOTE(review): read_examples subsamples randomly when data_num > 0, so a
    # cached dataset corresponds to the subset drawn when the cache was written.
    examples = read_examples(
        filename,
        data_num,
        args.task,
        args.parse_as_tree,
        args.ip_lang,
        args.op_lang,
    )

    if is_sample:
        examples = random.sample(examples, min(5000, len(examples)))

    if split_tag == "train":
        calc_stats(examples, tokenizer, is_tokenize=True)
    else:
        calc_stats(examples)

    if os.path.exists(cache_fn) and not is_sample:
        logger.info("Load cache data from %s", cache_fn)
        # torch.load unpickles the file: cache_path must only hold trusted files.
        data = torch.load(cache_fn)
    else:
        if is_sample:
            logger.info("Sample 5k data for computing bleu from %s", filename)
        else:
            logger.info("Create cache data into %s", cache_fn)
        tuple_examples = [(example, idx, tokenizer, args, split_tag) for idx, example in enumerate(examples)]

        if is_decoder:
            logger.info("[Decoder]: Converting examples to features")
            features = pool.map(decoder_convert_examples_to_features, tqdm(tuple_examples, total=len(tuple_examples)))
        else:
            logger.info("[Enc-Dec]: Converting examples to features")
            features = pool.map(convert_examples_to_features, tqdm(tuple_examples, total=len(tuple_examples)))

        if split_tag == "test" or only_src:
            # Evaluation only needs the model input ids.
            if is_decoder:
                logger.info("[Decoder]: returning source_ids_decoder as source examples")
                all_source_ids = torch.tensor([f.source_ids_decoder for f in features], dtype=torch.long)
                logger.info("[Decoder]: source_ids_decoder shape: %s", all_source_ids.shape)
            else:
                logger.info("[Enc-Dec]: returning source_ids as source examples")
                all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
                logger.info("[Enc-Dec]: source_ids shape: %s", all_source_ids.shape)

            data = TensorDataset(all_source_ids)
        else:
            all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
            all_target_ids = torch.tensor([f.target_ids for f in features], dtype=torch.long)

            # Guard against an empty feature list before peeking at the first item.
            if features and features[0].source_ids_decoder is not None:
                source_ids_decoder = torch.tensor([f.source_ids_decoder for f in features], dtype=torch.long)
                data = TensorDataset(all_source_ids, all_target_ids, source_ids_decoder)
            else:
                data = TensorDataset(all_source_ids, all_target_ids)

        # Only the main process writes the cache, and never for sampled data.
        if args.local_rank in [-1, 0] and not is_sample:
            torch.save(data, cache_fn)
    return examples, data


def load_and_cache_clone_data(args, filename, pool, tokenizer, split_tag, is_sample=False):
    """Read clone-detection examples and build, or load from cache, their dataset.

    Args:
        args: Run configuration. Reads ``data_num``, ``cache_path``, ``task``,
            ``parse_as_tree``, ``ip_lang``, ``op_lang`` and ``local_rank``.
        filename: Data file forwarded to ``read_examples``.
        pool: Worker pool whose ``map`` runs the feature conversion.
        tokenizer: Tokenizer forwarded to ``calc_stats`` and the feature converter.
        split_tag: Split name used as the cache file prefix.
        is_sample: Keep a random 10% of the examples; the cache is neither read
            nor written in this mode so it always describes the full split.

    Returns:
        Tuple ``(examples, data)`` where ``data`` is a ``TensorDataset`` of
        ``(source_ids, labels)``.
    """
    # Compute the size tag first: writing the conditional inline after
    # `split_tag + ...` binds it to the whole sum and drops the split name.
    data_tag = "_all" if args.data_num == -1 else "_%d" % args.data_num
    cache_fn = "{}/{}.pt".format(args.cache_path, split_tag + data_tag)

    examples = read_examples(
        filename,
        args.data_num,
        args.task,
        args.parse_as_tree,
        args.ip_lang,
        args.op_lang,
    )
    if is_sample:
        examples = random.sample(examples, int(len(examples) * 0.1))

    calc_stats(examples, tokenizer, is_tokenize=True)
    if os.path.exists(cache_fn) and not is_sample:
        logger.info("Load cache data from %s", cache_fn)
        data = torch.load(cache_fn)
    else:
        if is_sample:
            logger.info("Sample 10 percent of data from %s", filename)
        elif args.data_num == -1:
            logger.info("Create cache data into %s", cache_fn)
        tuple_examples = [(example, idx, tokenizer, args) for idx, example in enumerate(examples)]
        features = pool.map(
            convert_clone_examples_to_features,
            tqdm(tuple_examples, total=len(tuple_examples)),
        )
        all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
        data = TensorDataset(all_source_ids, all_labels)

        # Only the main process caches, and only the complete, unsampled split.
        if args.local_rank in [-1, 0] and args.data_num == -1 and not is_sample:
            torch.save(data, cache_fn)
    return examples, data


def load_and_cache_defect_data(args, filename, pool, tokenizer, split_tag, is_sample=False):
    """Read defect-detection examples and build, or load from cache, their dataset.

    Args:
        args: Run configuration. Reads ``data_num``, ``cache_path``, ``task``,
            ``parse_as_tree``, ``ip_lang``, ``op_lang`` and ``local_rank``.
        filename: Data file forwarded to ``read_examples``.
        pool: Worker pool whose ``map`` runs the feature conversion.
        tokenizer: Tokenizer forwarded to ``calc_stats`` and the feature converter.
        split_tag: Split name; also the cache file name inside ``cache_path``.
        is_sample: Keep a random 10% of the examples.

    Returns:
        Tuple ``(examples, data)`` where ``data`` is a ``TensorDataset`` of
        ``(source_ids, labels)``.
    """
    cache_fn = os.path.join(args.cache_path, split_tag)
    examples = read_examples(
        filename,
        args.data_num,
        args.task,
        args.parse_as_tree,
        args.ip_lang,
        args.op_lang,
    )
    if is_sample:
        examples = random.sample(examples, int(len(examples) * 0.1))

    calc_stats(examples, tokenizer, is_tokenize=True)

    # The cache file name carries neither the sample flag nor data_num, so it
    # is only valid for the complete split. Sampled or subsampled runs must
    # bypass it, otherwise `examples` and `data` would disagree.
    cacheable = args.data_num == -1 and not is_sample
    if cacheable and os.path.exists(cache_fn):
        logger.info("Load cache data from %s", cache_fn)
        data = torch.load(cache_fn)
    else:
        if is_sample:
            logger.info("Sample 10 percent of data from %s", filename)
        elif args.data_num == -1:
            logger.info("Create cache data into %s", cache_fn)
        tuple_examples = [(example, idx, tokenizer, args) for idx, example in enumerate(examples)]
        features = pool.map(
            convert_defect_examples_to_features,
            tqdm(tuple_examples, total=len(tuple_examples)),
        )
        all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
        data = TensorDataset(all_source_ids, all_labels)

        # Only the main process writes the cache.
        if args.local_rank in [-1, 0] and cacheable:
            torch.save(data, cache_fn)
    return examples, data


def load_and_cache_multi_gen_data(args, pool, tokenizer, split_tag, only_src=False, is_sample=False):
    """Build, or load from cache, the datasets of every multi-task sub-task.

    For each (task, sub_task) pair the function overwrites ``args.task``,
    ``args.sub_task``, ``args.max_source_length`` and ``args.max_target_length``
    before reading and converting that sub-task's examples.

    Returns:
        Dict mapping ``"<task>_<sub_task>"`` (or just ``"<task>"`` when the
        sub-task is ``"none"``) to an ``(examples, TensorDataset)`` tuple.
    """
    cache_fn = os.path.join(args.cache_path, split_tag)
    if not is_sample and os.path.exists(cache_fn):
        logger.info("Load cache data from %s", cache_fn)
        return torch.load(cache_fn)

    # One row per sub-task: (task, sub_task, max_source_length, max_target_length).
    schedule = (
        [("summarize", lang, 256, 128) for lang in ("ruby", "javascript", "go", "python", "java", "php")]
        + [("translate", direction, 320, 256) for direction in ("java-cs", "cs-java")]
        + [("refine", "small", 130, 120), ("refine", "medium", 240, 240)]
        + [("concode", "none", 320, 150)]
        + [("defect", "none", 512, 3)]  # target length 3: no lang ids need to be added
    )

    collected = {}
    for task, sub_task, src_len, tgt_len in schedule:
        args.task = task
        args.sub_task = sub_task
        args.max_source_length = src_len
        args.max_target_length = tgt_len

        filename = get_filenames(args.data_dir, args.task, args.sub_task, split_tag)
        examples = read_examples(
            filename,
            args.data_num,
            args.task,
            args.parse_as_tree,
            args.ip_lang,
            args.op_lang,
        )
        if is_sample:
            examples = random.sample(examples, min(5000, len(examples)))

        # Token-level statistics are only computed for the training split.
        if split_tag == "train":
            calc_stats(examples, tokenizer, is_tokenize=True)
        else:
            calc_stats(examples)

        jobs = [(example, idx, tokenizer, args, split_tag) for idx, example in enumerate(examples)]
        if args.data_num == -1:
            # Full dataset: fan the conversion out over the worker pool.
            features = pool.map(convert_examples_to_features, tqdm(jobs, total=len(jobs)))
        else:
            features = [convert_examples_to_features(job) for job in jobs]

        tensors = [torch.tensor([f.source_ids for f in features], dtype=torch.long)]
        if not only_src:
            tensors.append(torch.tensor([f.target_ids for f in features], dtype=torch.long))
        data = TensorDataset(*tensors)

        key = task if sub_task == "none" else "{}_{}".format(task, sub_task)
        collected[key] = (examples, data)

    # Only the main process writes the cache, and never for sampled data.
    if args.local_rank in [-1, 0] and not is_sample:
        torch.save(collected, cache_fn)
        logger.info("Save data into %s", cache_fn)
    return collected


def get_filenames(data_root, task, sub_task, split=""):
    """Return the data file path(s) of a task for one split or for all splits.

    Tasks whose source and target live in separate files yield a single string
    of the form ``"<source_path>,<target_path>"``.

    Args:
        data_root: Root directory of the datasets.
        task: Task name, e.g. ``"summarize"``, ``"translate"``, ``"avatar"``.
        sub_task: Sub-task name; only used by the tasks that have sub-tasks.
        split: ``"train"``, ``"dev"`` or ``"test"`` to get one path string;
            any other value returns all three.

    Returns:
        One path string, or the tuple ``(train_fn, dev_fn, test_fn)``.

    Raises:
        ValueError: If ``task`` is unknown, or ``sub_task`` is not supported by
            ``fixeval`` / ``avatar``.
    """
    flat_dir = "{}/{}".format(data_root, task)
    nested_dir = "{}/{}/{}".format(data_root, task, sub_task)

    # Each template is one file; "{d}" is the data directory and "{s}" the
    # split name as it appears on disk. Several templates are joined with ",".
    templates = None
    dev_name = "valid"

    if task == "concode":
        # concode is the only task whose dev file is named "dev".
        data_dir, templates, dev_name = flat_dir, ["{d}/{s}.json"], "dev"
    elif task == "summarize":
        data_dir, templates = nested_dir, ["{d}/{s}.jsonl"]
    elif task == "refine":
        data_dir = nested_dir
        templates = ["{d}/{s}.buggy-fixed.buggy", "{d}/{s}.buggy-fixed.fixed"]
    elif task == "translate":
        data_dir = flat_dir
        templates = ["{d}/{s}.java-cs.txt.java", "{d}/{s}.java-cs.txt.cs"]
        if sub_task == "cs-java":
            templates.reverse()
    elif task == "clone":
        data_dir, templates = flat_dir, ["{d}/{s}.txt"]
    elif task in ("defect", "mbpp"):
        data_dir, templates = flat_dir, ["{d}/{s}.jsonl"]
    elif task == "mathqa":
        data_dir, templates = flat_dir, ["{d}/mathqa_python_{s}.json"]
    elif task == "fixeval":
        data_dir = "{}/FixEval/data/{}".format(data_root, sub_task)
        if sub_task == "java":
            templates = [
                "{d}/processed/src_{s}.java-java.java",
                "{d}/processed/tgt_{s}.java-java.java",
            ]
    elif task == "conala":
        data_dir, templates = flat_dir, ["{d}/{s}.json"]
    elif task == "avatar":
        data_dir = "{}/AVATAR/data/structure-data".format(data_root)
        java_fn = "{d}/{s}.java-python.java.json"
        python_fn = "{d}/{s}.java-python.python.json"
        if sub_task == "java-python":
            templates = [java_fn, python_fn]
        elif sub_task == "python-java":
            templates = [python_fn, java_fn]

    if templates is None:
        # Previously this fell through to an UnboundLocalError on train_fn.
        raise ValueError(
            "Unsupported task/sub_task combination: task={!r}, sub_task={!r}".format(task, sub_task)
        )

    def build(split_name):
        return ",".join(template.format(d=data_dir, s=split_name) for template in templates)

    if split == "train":
        return build("train")
    if split == "dev":
        return build(dev_name)
    if split == "test":
        return build("test")
    return build("train"), build(dev_name), build("test")


def read_examples(filename, data_num, task, parse_as_tree, ip_lang, op_lang):
    """Load the examples of a task, optionally subsample and tree-parse them.

    Args:
        filename: Data file(s) passed to the task-specific reader.
        data_num: Number of examples to keep; values <= 0 keep everything.
        task: Task name selecting the reader; an unknown name raises ``KeyError``.
        parse_as_tree: When true, replace source/target by the output of
            ``parseutils.parsecode`` and drop examples left empty.
        ip_lang: Language used to parse ``example.source``; ``None`` skips it.
        op_lang: Language used to parse ``example.target``; ``None`` skips it.

    Returns:
        List of example objects.
    """
    read_example_dict = {
        "summarize": read_summarize_examples,
        "refine": read_refine_examples,
        "translate": read_translate_examples,
        "concode": read_concode_examples,
        "clone": read_clone_examples,
        "defect": read_defect_examples,
        "mathqa": read_mathqa_examples,
        "fixeval": read_fixeval_examples,
        "mbpp": read_mbpp_examples,
        "conala": read_conala_examples,
        "avatar": read_avatar_examples,
    }

    # Always load every example first (-1), then subsample below, so that the
    # kept subset is a random draw rather than the head of the file.
    example_lst = read_example_dict[task](filename, -1)

    n_examples = len(example_lst)
    if data_num > 0:
        idxs = np.random.permutation(n_examples)[:data_num]
        example_lst = [example_lst[idx] for idx in idxs]
        print(f"#Subsampling: Total examples: {n_examples}. Subsampled: {len(example_lst)}")

    if parse_as_tree:
        print("Parsing code as trees")
        # Convert input/output code to its tree form, with one leading space.
        for example in example_lst:
            if ip_lang is not None:
                srcstr = parseutils.parsecode(example.source, ip_lang, True)
                if srcstr:
                    srcstr = " " + srcstr.strip()
                example.source = srcstr

            if op_lang is not None:
                tgtstr = parseutils.parsecode(example.target, op_lang, True)
                if tgtstr:
                    tgtstr = " " + tgtstr.strip()
                example.target = tgtstr

        # Drop examples whose source or target ended up None or blank.
        example_lst = [example for example in example_lst if not remove_example(example)]

        # The number printed is the count of examples kept after filtering.
        print(f"{len(example_lst)} examples filtered")

    # Show the first example as a sanity check; an empty or fully filtered
    # dataset must not crash here with an IndexError.
    if example_lst:
        print("\t Input:", example_lst[0].source)
        print("\t Output:", example_lst[0].target)
    else:
        logger.warning("No examples available from %s", filename)

    return example_lst


def remove_example(example):
    """Return True when the example's source or target is missing or blank.

    A field counts as missing when it is ``None`` and as blank when it is a
    string (including ``str`` subclasses) made only of whitespace. Non-string,
    non-``None`` values such as integer labels are always kept.
    """

    def _is_unusable(value):
        return value is None or (isinstance(value, str) and not value.strip())

    return _is_unusable(example.source) or _is_unusable(example.target)


def calc_stats(examples, tokenizer=None, is_tokenize=False):
    """Log average and maximum source/target lengths of the examples.

    Lengths are counted in whitespace-separated words. When ``is_tokenize`` is
    true, a second log line reports the same statistics in tokenizer tokens.

    Args:
        examples: Sequence of objects with ``source`` and ``target`` attributes.
        tokenizer: Object with a ``tokenize(text)`` method; required only when
            ``is_tokenize`` is true.
        is_tokenize: Whether to also report tokenizer-level lengths.
    """
    if not examples:
        # mean/max are undefined on empty input; max([]) would raise.
        logger.info("Read 0 examples; no length statistics to report")
        return

    def _log_lengths(message, prefix_args, src_lens, trg_lens):
        logger.info(
            message,
            *prefix_args,
            np.mean(src_lens),
            np.mean(trg_lens),
            max(src_lens),
            max(trg_lens),
        )

    # Targets are passed through str() because they may be non-string labels.
    src_words = [len(ex.source.split()) for ex in examples]
    trg_words = [len(str(ex.target).split()) for ex in examples]
    _log_lengths(
        "Read %d examples, avg src len: %d, avg trg len: %d, max src len: %d, max trg len: %d",
        (len(examples),),
        src_words,
        trg_words,
    )

    if is_tokenize:
        src_tokens = [len(tokenizer.tokenize(ex.source)) for ex in examples]
        trg_tokens = [len(tokenizer.tokenize(str(ex.target))) for ex in examples]
        _log_lengths(
            "[TOKENIZE] avg src len: %d, avg trg len: %d, max src len: %d, max trg len: %d",
            (),
            src_tokens,
            trg_tokens,
        )


def get_elapse_time(t0):
    """Format the wall-clock time elapsed since ``t0`` as ``"<H>h<M>m"`` or ``"<M>m"``.

    Args:
        t0: Start timestamp as returned by ``time.time()``.

    Returns:
        ``"{hours}h{minutes}m"`` once at least one full hour has passed,
        otherwise ``"{minutes}m"``. Seconds are truncated.
    """
    # Clamp at zero: a start time in the future (e.g. after a clock adjustment)
    # would otherwise wrap around through the modulo to a bogus minute count.
    elapsed_seconds = int(max(time.time() - t0, 0.0))
    hours, remainder = divmod(elapsed_seconds, 3600)
    minutes = remainder // 60
    # `hours` is non-zero from exactly 3600 s on, so one hour reads "1h0m".
    if hours:
        return "{}h{}m".format(hours, minutes)
    return "{}m".format(minutes)
