import functools
import torch
import transformers
import argparse

import os
# Ask TensorFlow's GCS auth code to skip probing for a GCE metadata server.
# The variable it reads is upper-case NO_GCE_CHECK; environment variable names
# are case-sensitive on POSIX, so the former 'No_GCE_CHECK' spelling was ignored
# there. Set before tensorflow_datasets is imported so it is in place before any
# GCS access.
os.environ['NO_GCE_CHECK'] = 'true'
from tensorflow_datasets.core.utils import gcs_utils
# Stub out TFDS's GCS queries: no dataset-info files are ever found on GCS and no
# dataset is reported as hosted there (presumably so runs work from the local
# tfds_path only, without network access).
gcs_utils.gcs_dataset_info_files = lambda *args, **kwargs: None
gcs_utils.is_dataset_on_gcs = lambda *args, **kwargs: False

import globals
# Project-wide mutable state, initialised here before any project module uses it.
globals.best_checkpoint = 0
globals.best_result = 0.0

#######################################
local_root_path = './'
local_tfds_path = './tensorflow_datasets'
# Must be assigned before `data.mixtures` is imported below — presumably that
# module reads globals.tfds_path at import time (TODO confirm).
globals.tfds_path = local_tfds_path
#######################################

import data.mixtures  # imported for its side effects (presumably task/mixture registration)
from models.hf_model import HfPyTorchModel

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# Single device used by every model built in this script.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def _sequence_length(argv):
    """Return a new ``{"inputs": ..., "targets": ...}`` length dict built from ``argv``."""
    return {"inputs": argv.sequence_length_inputs, "targets": argv.sequence_length_targets}


def _adafactor(lr):
    """Return an optimizer factory for ``transformers.Adafactor`` with a fixed ``lr``.

    ``relative_step=False`` makes Adafactor use the given learning rate instead
    of its internal relative-step schedule.
    """
    return functools.partial(transformers.Adafactor, lr=lr, relative_step=False)


def main(argv):
    """Run the stages selected on the command line: train/distill, train_head, then eval.

    Args:
        argv: the parsed ``argparse.Namespace`` prepared by the ``__main__``
            block. Note: for WSC evaluation ``argv.task`` is rewritten in place
            from the "simple_train" to the "simple_eval" variant.
    """
    model = HfPyTorchModel(argv.model_dir, argv.save_dir, device, argv.teacher_model_dir,
                           argv.teacher_save_dir, argv.thresholds_list)

    if argv.train:
        # --distill switches from plain fine-tuning to distilling from the
        # teacher; both entry points take exactly the same arguments.
        fit = model.distill if argv.distill else model.train
        fit(
            argv=argv,
            mixture_or_task_name=argv.task,
            steps=argv.steps,
            save_steps=argv.save_steps,
            sequence_length=_sequence_length(argv),
            split="train",
            batch_size=argv.train_batch_size,
            optimizer=_adafactor(argv.lr),
            train_split_ratio=argv.train_split_ratio,
        )

    if argv.train_head:
        # The head is trained with its own fixed learning rate (1e-2); --lr is not used here.
        model.train_head(
            mixture_or_task_name=argv.task,
            steps=argv.steps,
            save_steps=argv.save_steps,
            sequence_length=_sequence_length(argv),
            split="train",
            batch_size=argv.train_batch_size,
            optimizer=_adafactor(1e-2),
        )

    if argv.eval:
        # WSC has separate train/eval task variants.
        if argv.task == 'super_glue_wsc_v102_simple_train':
            argv.task = 'super_glue_wsc_v102_simple_eval'
        # Only consider the teacher in evaluation for joint inference (not after distillation).
        if argv.distill:
            model.teacher_model_dir = None
        model.eval(
            argv.task,
            # Lengths must be passed explicitly (translation needs max_length).
            # compute_sequence_length=True cannot be enabled when a head is used:
            # the head was trained with a fixed length (e.g. 256), so exactly
            # that length has to be passed here.
            sequence_length=_sequence_length(argv),
            batch_size=argv.eval_batch_size,
            checkpoint_steps=argv.checkpoint_steps,
            split=argv.eval_split,
            head=argv.head,
            router=argv.router,
            head_prediction=argv.head_prediction,
            model_type=argv.model_type,
        )

# (inputs, targets) sequence lengths for GLUE tasks; MNLI depends on the eval
# split and is handled in _set_sequence_lengths.
_GLUE_LENGTHS = {
    'rte': (284, 6),
    'cola': (42, 2),
    'mrpc': (114, 6),
    'qnli': (310, 6),
    'sst2': (84, 2),
}

# (inputs, targets) sequence lengths for SuperGLUE tasks (the v102 mixtures).
_SUPER_GLUE_LENGTHS = {
    'boolq': (1142, 4),
    'cb': (301, 5),
    'copa': (53, 3),
    'multirc': (687, 4),
    'record': (1142, 31),  # NOTE(review): originally marked "there is an error" for this task
    'rte': (256, 6),       # presumably capped at 256 (283 was noted alongside)
    'wic': (256, 4),       # 77 was noted alongside
    'wsc': (84, 84),       # text-to-text: inputs and targets lengths must be equal
}

# Tasks addressed by their full name, for any --benchmark. They get these
# lengths (applied last) and their name is used verbatim as the task.
_TASK_LENGTHS = {
    'wmt15_enfr_v003': (172, 172),
    'wmt16_enro_v003': (211, 211),  # text-to-text: both lengths should be equal
    'wmt14_ende_v003': (167, 167),
    'wmt14_enfr_v003': (128, 191),
    'anli': (231, 3),   # for distillation
    'snli': (139, 5),   # for distillation
    'squad_v010': (956, 55),
    # NOTE(review): identical to snli's lengths — looks copy-pasted; a 5-token
    # summary target seems short for this task. Confirm before relying on it.
    'cnn_dailymail_v002': (139, 5),
}


def _build_arg_parser():
    """Build the command-line parser for training / distillation / evaluation runs.

    argparse %-formats help strings (ArgumentDefaultsHelpFormatter appends
    ``%(default)s``), so a literal percent sign must be written as ``%%``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Each process runs on 1 GPU device specified by the local_rank argument.
    parser.add_argument("--local_rank", type=int, help="Local rank. Necessary for using the torch.distributed.launch utility.")
    parser.add_argument("--ddp", action="store_true", help="DDP Training")
    parser.add_argument("--train", action="store_true", help="Training")
    parser.add_argument("--train_split_ratio", type=float, help="The split ratio for train set - for distillation purpose - default: 100%%", default=None)
    parser.add_argument("--teacher_train_split_ratio", type=float,
                        help="The split ratio for train set - default: 100%% - this arg is just for finding the right teacher_save_dir at evaluation not training", default=None)

    parser.add_argument("--train_head", action="store_true", help="Training head")
    parser.add_argument("--distill", action="store_true", help="Distilling Student from Teacher")
    parser.add_argument("--eval", action="store_true", help="Evaluating")
    parser.add_argument("--eval_split", type=str, help="Evaluation split, e.g., validation_matched for mnli", default="validation")
    parser.add_argument("--checkpoint_steps", type=str, help="which checkpoint used for inference. e.g., 'all'", default=None)
    parser.add_argument("--model_dir", type=str, help="Directory for reading T5 models.", default=None)
    parser.add_argument("--model_type", type=str, help="Type of T5 model.", default='small')
    parser.add_argument("--save_dir", type=str, help="Directory for saving models.", default=None)
    parser.add_argument("--head_path", type=str, help="Filepath for head checkpoint (relative to save_dir).", default=None)
    parser.add_argument("--head", action="store_true", help="Use the head during evaluation.")
    parser.add_argument("--router", type=str, help="energy, entropy, softmax, random", default='energy')
    parser.add_argument("--head_prediction", action="store_true", help="Do the predictions with the extra head")

    parser.add_argument("--teacher_model_dir", type=str, help="Directory for reading T5 teacher models.", default=None)
    parser.add_argument("--teacher_save_dir", type=str, help="Directory for saving teacher models.", default=None)
    parser.add_argument("--teacher_model_type", type=str, help="Type of Teacher T5 model.", default=None)
    parser.add_argument("--thresholds_list", type=str, help="list of thresholds, given as a single string.", default='0.0')
    parser.add_argument("--lr", type=float, help="Learning rate.", default=1e-3)

    parser.add_argument("--benchmark", type=str, help="Benchmark name (e.g., glue).", default="glue")
    parser.add_argument("--task", type=str, help="GLUE task.", default="cola")
    parser.add_argument("--steps", type=int, help="Number of training steps.", default=7000)
    parser.add_argument("--save_steps", type=int, help="Saving steps.", default=1000)
    parser.add_argument("--sequence_length_inputs", type=int, help="Input sequence length (overridden for known tasks).", default=None)
    parser.add_argument("--sequence_length_targets", type=int, help="Target sequence length (overridden for known tasks).", default=None)
    parser.add_argument("--train_batch_size", type=int, help="Training batch size.", default=8)
    parser.add_argument("--eval_batch_size", type=int, help="Evaluation batch size.", default=8)
    return parser


def _run_dir(benchmark, task, model_type, train_split_ratio):
    """Return the relative checkpoint directory for one (benchmark, task, model size) run.

    NOTE: the '-bs8-AdaFactor-lr.001' tag is fixed and does not reflect
    --train_batch_size / --lr; it is kept so existing checkpoint dirs are found.
    """
    tsr_tag = '-tsr' + str(train_split_ratio) if train_split_ratio else ''
    return ('/' + benchmark + '/' + task + '/t5/' + benchmark + '_' + task
            + '_t5-' + model_type + '-bs8-AdaFactor-lr.001' + tsr_tag)


def _resolve_paths(argv):
    """Fill in model/save/head/teacher paths not given on the command line (mutates argv).

    Uses the short task name, so it must run before _resolve_task_name.
    """
    if argv.model_dir is None:
        argv.model_dir = local_root_path + './t5-' + argv.model_type
    if argv.save_dir is None:
        distill_tag = '-distill_11b' if argv.distill else ''
        argv.save_dir = (local_root_path
                         + _run_dir(argv.benchmark, argv.task, argv.model_type, argv.train_split_ratio)
                         + distill_tag)

    if argv.head_path:
        argv.head_path = argv.save_dir + '/' + argv.head_path

    if argv.teacher_model_type:
        # Allow head_prediction only in single-model mode (not joint mode).
        argv.head_prediction = False
        if argv.teacher_model_dir is None:
            argv.teacher_model_dir = local_root_path + './t5-' + argv.teacher_model_type
        if argv.teacher_save_dir is None:
            argv.teacher_save_dir = local_root_path + _run_dir(
                argv.benchmark, argv.task, argv.teacher_model_type, argv.teacher_train_split_ratio)


def _set_sequence_lengths(argv):
    """Set argv.sequence_length_inputs/targets for known tasks (mutates argv).

    For GLUE MNLI this also forces eval_split to 'validation_matched' unless
    'validation_mismatched' was requested. Tasks not listed keep the lengths
    given on the command line. Uses the short task name, so it must run
    before _resolve_task_name.
    """
    lengths = None
    if argv.benchmark == 'glue':
        if argv.task == 'mnli':
            if argv.eval_split == 'validation_mismatched':
                lengths = (268, 5)
            else:
                argv.eval_split = 'validation_matched'
                lengths = (272, 5)
        else:
            lengths = _GLUE_LENGTHS.get(argv.task)
    elif argv.benchmark == 'super_glue':
        lengths = _SUPER_GLUE_LENGTHS.get(argv.task)
    if lengths is not None:
        argv.sequence_length_inputs, argv.sequence_length_targets = lengths

    # SuperGLUE training always uses a fixed input length.
    if argv.benchmark == 'super_glue' and argv.train:
        argv.sequence_length_inputs = 256

    # Full-name tasks override whatever was set above.
    lengths = _TASK_LENGTHS.get(argv.task)
    if lengths is not None:
        argv.sequence_length_inputs, argv.sequence_length_targets = lengths


def _resolve_task_name(argv):
    """Expand the short --task into the full task/mixture name passed to the model (mutates argv)."""
    if argv.task in _TASK_LENGTHS:
        return  # already a full task name; used verbatim
    task_name = argv.benchmark + '_' + argv.task
    if argv.benchmark == 'glue':
        task_name += '_v002'
    elif argv.benchmark == 'super_glue':
        if argv.task == 'wsc':
            training = argv.train or argv.train_head
            task_name += '_v102_simple_train' if training else '_v102_simple_eval'
        else:
            task_name += '_v102'
    argv.task = task_name


if __name__ == "__main__":
    argv = _build_arg_parser().parse_args()
    # Order matters: paths and lengths are keyed on the short task name,
    # which _resolve_task_name then replaces.
    _resolve_paths(argv)
    _set_sequence_lengths(argv)
    _resolve_task_name(argv)
    main(argv)