# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import, division, print_function

import argparse
import glob
import logging
import logzero
from logzero import logger
import os, sys, time, os.path as osp
import random

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from models.modeling_multilabel_aplc import (
    BertAPLC,
    RobertaAPLC,
    XLNetAPLC)

from models.modeling_multilabel_lightxml import (
    BertLightXML,
)

from models.modeling_multilabel_bert import (
    BertForMultiLabelSequenceClassification,
    BertAttentionXML,
    BertCombined,
    DistilBertForMultiLabelSequenceClassification,
)

from models.modeling_multilabel_xlnet import (
    XLNetForMultiLabelSequenceClassification,
    XLNetAttentionXML,
)

from models.modeling_multilabel_roberta import (
    RobertaForMultiLabelSequenceClassification,
    RobertaAttentionXML,
    #RobertaCombined,
)

from transformers import (
    DistilBertConfig,
    RobertaConfig,
    #BertConfig,
    XLNetConfig
)

from pytorch_transformers import BertConfig

from transformers import AdamW, get_linear_schedule_with_warmup
from utils_multi_label import convert_examples_to_features,output_modes, processors,\
    eval_batch, eval_precision, metric_pk, count_parameters, get_one_hot

from multi_label_dataset import MultiLabelDataset
import shutil

sys.path.append('/usr0/home/ruohongz/XMTC/XMR_tool')
from utils.evaluation_metric import *

# ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in [XLNetConfig]), ())

# Registry mapping each accepted --model_type value to its
# (config class, model class) pair; main() looks the pair up by key.
# Key order is significant: it is the order shown in the --model_type help text.
MODEL_CLASSES = dict(
    distilbert=(DistilBertConfig, DistilBertForMultiLabelSequenceClassification),
    bert=(BertConfig, BertForMultiLabelSequenceClassification),
    bertattn=(BertConfig, BertAttentionXML),
    bertlightxml=(BertConfig, BertLightXML),
    bertcombined=(BertConfig, BertCombined),
    roberta=(RobertaConfig, RobertaForMultiLabelSequenceClassification),
    robertaattn=(RobertaConfig, RobertaAttentionXML),
    xlnet=(XLNetConfig, XLNetForMultiLabelSequenceClassification),
    xlnetattn=(XLNetConfig, XLNetAttentionXML),
    xlnetaplc=(XLNetConfig, XLNetAPLC),
    bertaplc=(BertConfig, BertAPLC),
    robertaaplc=(RobertaConfig, RobertaAPLC),
)


def set_seed(args):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy and torch with ``args.seed`` (and all CUDA
    devices as well when ``args.n_gpu > 0``), switches cuDNN to deterministic,
    non-benchmarking mode, and exports ``PYTHONHASHSEED`` into the environment.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)


def train(args, train_dataset, model, model_class, start_step=0):
    """Train ``model`` on ``train_dataset``, logging running metrics and saving checkpoints.

    Args:
        args: parsed command-line namespace. Besides the optimisation, logging,
            saving, fp16 and distributed settings it reads, this function writes
            ``args.train_batch_size`` and, when ``args.max_steps > 0``,
            ``args.num_train_epochs``.
        train_dataset: dataset whose batches are tuples fed to the model as
            ``(input_ids, attention_mask, token_type_ids, labels)``.
        model: the model, possibly wrapped in DataParallel /
            DistributedDataParallel; its forward returns ``(loss, logits, ...)``.
        model_class: class whose ``get_param(model, ...)`` builds the optimizer
            parameter groups with per-part learning rates.
        start_step: optimizer step count to resume from (used with --finetune).

    Returns:
        ``(global_step, average_loss)``. ``average_loss`` is the accumulated
        (gradient-accumulation-scaled) loss divided by the number of optimizer
        steps taken in this call.
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        # max(1, ...) avoids a ZeroDivisionError when the dataloader has fewer
        # batches than the gradient-accumulation window.
        steps_per_epoch = max(1, len(train_dataloader) // args.gradient_accumulation_steps)
        args.num_train_epochs = args.max_steps // steps_per_epoch + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    optimizer_grouped_parameters = model_class.get_param(model,
                                                         learning_rate_x=args.learning_rate_x,
                                                         learning_rate_h=args.learning_rate_h,
                                                         learning_rate_a=args.learning_rate_a,
                                                         weight_decay=args.weight_decay)
    optimizer = AdamW(optimizer_grouped_parameters, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    logger.info(f" logging steps {args.logging_steps}, save steps {args.save_steps}")

    st = time.time()
    global_step = start_step
    tr_loss = 0.0
    model.zero_grad()

    set_seed(args)  # Added here for reproducibility

    reached_max_steps = False
    for num_epoch in range(int(args.num_train_epochs)):
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])

        # Running metrics, reset at the start of every epoch.
        p1_list, p3_list, p5_list, loss_list = [], [], [], []
        for step, batch in enumerate(epoch_iterator, 1):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            # NOTE(review): token_type_ids are passed only for the plain 'bert' and
            # 'xlnet' types, not for their attn/aplc/lightxml/combined variants -
            # confirm this is intended (evaluate() uses the same rule).
            inputs = {'input_ids':      batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
                      'labels':         batch[3]}
            outputs = model(**inputs)
            loss, logits = outputs[:2]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            hit_list = metric_pk(logits, inputs['labels'], aplc=args.aplc)
            p1_list.append(hit_list[0])
            p3_list.append(hit_list[1])
            p5_list.append(hit_list[2])
            loss_value = loss.item()
            loss_list.append(loss_value)
            tr_loss += loss_value

            if step % args.logging_steps == 0 and args.local_rank in [-1, 0]:
                lapse = time.time() - st
                # Seconds per optimizer step taken in this call: global_step starts
                # at start_step when resuming and can still equal it on the first
                # logged batch, hence the subtraction and the max(1, ...).
                per_step = lapse / max(1, global_step - start_step)
                logger.info(f"epoch {num_epoch} step {step}, loss =  {np.mean(loss_list):.6f}, "
                            f"p1 = {np.mean(p1_list):.2f}, p3= {np.mean(p3_list):.2f}, p5 = {np.mean(p5_list):.2f}, "
                            f"total time {lapse:.2f}, per batch {per_step:.4f}, estimated total {per_step * t_total:.2f}")

            if step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    if args.save_epoch:
                        output_dir = os.path.join(args.output_dir, 'checkpoint-{:010d}'.format(global_step))
                    else:
                        output_dir = os.path.join(args.output_dir, 'checkpoint')
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training

                    model_to_save.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            # Stop as soon as the scheduled number of optimizer steps is reached;
            # the flag also ends the epoch loop so no further epoch is started.
            if args.max_steps > 0 and global_step >= args.max_steps:
                reached_max_steps = True
                epoch_iterator.close()
                break
        if reached_max_steps:
            break

    steps_taken = max(1, global_step - start_step)
    return global_step, tr_loss / steps_taken

@torch.no_grad()
def evaluate(args, dataset, model, mode='dev', prefix="", save='none', checkpoint_name='checkpoint'):
    """Run ``model`` over ``dataset`` and write the artefacts requested by ``save``.

    ``save`` is matched by substring, so several artefacts can be requested at once:

    * ``'result'``  - per-instance target ranks/scores (plus the concatenated raw
      outputs when ``'raw'`` is also present) saved to ``results.{mode}.npy`` in
      ``args.output_dir``, and instance metrics logged;
    * ``'top'``     - top-``args.topk`` scores and indices saved to ``top.{mode}.npy``;
    * ``'attn'``    - for attention model types, the attention rows of the top-5
      predicted labels saved to ``attn.{mode}.npy``;
    * ``'feature'`` - the model's last output saved to
      ``{args.data_dir}/{args.model_name_or_path}.{mode}.npy``.

    ``save`` is also forwarded to the model as the ``save`` keyword. ``prefix`` only
    appears in the log header; ``checkpoint_name`` is accepted for compatibility
    but unused. Returns None.
    """
    eval_output_dir = args.output_dir

    eval_dataset = dataset
    labels = dataset.label_ids
    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir)

    # max(1, ...) as in train(): on CPU n_gpu is 0, which would give batch_size 0
    # and make DataLoader raise.
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Sequential order keeps every batch aligned with labels[batch_start_idx:batch_end_idx].
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    target_rank, target_score = [], []
    feature_list, attn_list = [], []
    top_score, top_idx = [], []
    raw = []

    model.eval()
    batch_start_idx = 0
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)
        batch_end_idx = batch_start_idx + len(batch[0])

        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
                  'labels':         batch[3]}
        outputs = model(**inputs, save=save)
        tmp_eval_loss, logits = outputs[:2]
        eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1

        if 'top' in save:
            ts, ti = logits.topk(k=args.topk, dim=-1)
            top_score.append(ts.cpu().numpy())
            top_idx.append(ti.cpu().numpy())

        pred = logits.cpu().numpy()
        if 'raw' in save:
            if 'combined' in args.model_type and len(outputs) == 4:  # combined: loss, logits, pred1, pred2
                pred1 = outputs[2].cpu().numpy()
                pred2 = outputs[3].cpu().numpy()
                raw.append(np.concatenate([pred, pred1, pred2], -1))
            else:
                raw.append(pred)
        if 'result' in save:
            trank, tscore = get_target_rank_from_pred(pred, labels[batch_start_idx:batch_end_idx])
            target_rank.extend(trank)
            target_score.extend(tscore)

        if 'attn' in save and 'attn' in args.model_type and len(outputs) == 3:  # loss, logits, attn
            top_pred_id = np.argsort(pred)[:, ::-1][:, :5]  # indices of the 5 highest-scoring labels per row
            attn = outputs[-1].cpu().numpy()  # indexed below as [row, label]; assumes (batch, label, seq_len) - TODO confirm
            for i in range(len(top_pred_id)):
                attn_list.append(attn[i, top_pred_id[i]])

        if 'feature' in save:
            feature = outputs[-1].cpu().numpy()
            feature_list.append(feature)

        batch_start_idx = batch_end_idx

    # max(1, ...) keeps an empty dataset from raising ZeroDivisionError.
    eval_loss = eval_loss / max(1, nb_eval_steps)
    logger.info(f"eval loss: {eval_loss}")

    if 'result' in save:
        d = {'rank': target_rank, 'score': target_score}
        if 'raw' in save:
            raw = np.concatenate(raw, 0)
            d['raw'] = raw
        res_path = osp.join(eval_output_dir, f'results.{mode}.npy')
        logger.info(f'save results to {res_path}')
        np.save(res_path, d)
        res, name = instance_metrics(target_rank, model_name=f'{args.model_name_or_path}', title=True)
        logger.info(f"{name}")
        logger.info(f"{res}")

    if 'top' in save:
        d = {'top_score': np.concatenate(top_score, 0),
             'top_idx': np.concatenate(top_idx, 0)}
        res_path = osp.join(eval_output_dir, f'top.{mode}.npy')
        logger.info(f'top results to {res_path}')
        np.save(res_path, d)

    if 'attn' in save and 'attn' in args.model_type:
        if attn_list:
            res_path = osp.join(eval_output_dir, f'attn.{mode}.npy')
            np.save(res_path, np.concatenate(attn_list, 0))
        else:
            # np.concatenate([]) would raise; this happens when the model did not
            # return exactly (loss, logits, attn).
            logger.warning(f"no attention weights collected; attn.{mode}.npy not written")

    if 'feature' in save:
        feature_list = np.concatenate(feature_list, axis=0)
        feature_path = os.path.join(args.data_dir, f'{args.model_name_or_path}.{mode}.npy')
        np.save(feature_path, feature_list)


def load_dataset(args, data_type='train'):
    """Load one split's cached features and labels as a ``MultiLabelDataset``.

    Features are read from
    ``{args.data_dir}/{data_type}.{args.model_name_or_path}.{args.max_seq_length}{args.data_suffix}``
    and labels from ``{args.data_dir}/label.{data_type}.{args.label_suffix}.npy``;
    both must already exist (produced by a separate preprocessing step).

    Exits the process with status 1 if either file is missing, so shell
    pipelines see the failure.
    """
    cached_data_file = os.path.join(args.data_dir,
                                    f'{data_type}.{args.model_name_or_path}.{args.max_seq_length}{args.data_suffix}')
    label_file = osp.join(args.data_dir, f'label.{data_type}.{args.label_suffix}.npy')

    # Check both inputs before the (potentially large) feature load.
    if not os.path.exists(cached_data_file):
        logger.error(f"preprocess data first for {cached_data_file}")
        sys.exit(1)
    if not osp.exists(label_file):
        logger.error(f"label not exist for {label_file}")
        sys.exit(1)

    # NOTE: torch.load and np.load(allow_pickle=True) unpickle arbitrary objects;
    # only point --data_dir at trusted files.
    logger.info(f"Loading features from cached file {cached_data_file}")
    features = torch.load(cached_data_file)
    # The label file holds a 2-element object array: (labels, num_labels).
    labels, num_labels = np.load(label_file, allow_pickle=True)
    return MultiLabelDataset(features, labels, num_labels, args.pos_label, aplc=args.aplc)


def _build_arg_parser():
    """Return the command-line parser for training and evaluating the multi-label models."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,)
    parser.add_argument("--task_name", default=None, type=str, required=True,
                        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--log_dir", default=None, type=str,
                        help="Path of the log file to write; its parent directory is created if needed.")

    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation during training at each logging step.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")

    parser.add_argument("--learning_rate_x", default=5e-5, type=float,
                        help="The initial learning rate for XLNet.")
    parser.add_argument("--learning_rate_h", default=1e-4, type=float,
                        help="The initial learning rate for the last hidden layer.")
    parser.add_argument("--learning_rate_a", default=1e-3, type=float,
                        help="The initial learning rate for APLC.")

    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--use_seq_summary", action='store_true')

    parser.add_argument('--logging_steps', type=int, default=1000,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=1000,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--eval_checkpoints", type=str, default='last',
                        help="Which saved checkpoints to evaluate: 'last', 'first', 'last_epoch' "
                             "(second to last), or any other value for all of them.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--save_epoch', action='store_true',
                        help="save model with epoch")
    parser.add_argument('--seed', type=int, default=313,
                        help="random seed for initialization")

    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")

    parser.add_argument('--num_label', type=int, default=2, help="the number of labels.")
    parser.add_argument('--pos_label', type=int, default=2, help="the number of maximum labels for one sample.")
    parser.add_argument('--adaptive_cutoff', nargs='+', type=int, default=[], help="the number of labels in different clusters")
    parser.add_argument('--div_value', type=float, default=2.0, help="the decay factor of the dimension of the hidden state")
    parser.add_argument('--last_hidden_size', type=int, default=768, help="the dimension of last hidden layer")
    parser.add_argument('--gpu', type=str, default='', help="the GPUs to use ")

    parser.add_argument('--save', type=str, default='none')
    parser.add_argument('--mode', type=str, default='test')
    parser.add_argument("--data_suffix", default="", type=str, )
    parser.add_argument("--label_suffix", default="", type=str,)
    parser.add_argument("--finetune", action='store_true')
    parser.add_argument("--aplc", action='store_true')
    parser.add_argument("--topk", type=int, default=20)
    parser.add_argument("--keyword_only", action='store_true')

    # model config
    parser.add_argument("--num_hidden_layers", default=None, type=int)
    parser.add_argument("--hidden_dropout_prob", default=None, type=float)
    parser.add_argument("--attention_probs_dropout_prob", default=None, type=float)
    parser.add_argument("--positional_encoding", default='absolute', type=str)
    parser.add_argument("--bottleneck_size", default=None, type=int)
    parser.add_argument("--feature_layers", default=5, type=int)
    return parser


def _select_checkpoints(checkpoints, which):
    """Return the subset of the sorted ``checkpoints`` chosen by ``--eval_checkpoints``.

    'last' and 'first' keep one end of the list, 'last_epoch' keeps the second to
    last entry, and any other value keeps every checkpoint.

    Raises:
        FileNotFoundError: a single-checkpoint choice was requested but the list is
            too short to satisfy it.
    """
    positions = {'last': -1, 'first': 0, 'last_epoch': -2}
    if which not in positions:
        return checkpoints
    index = positions[which]
    needed = -index if index < 0 else index + 1
    if len(checkpoints) < needed:
        raise FileNotFoundError(
            f"--eval_checkpoints={which} needs at least {needed} checkpoint(s), found {len(checkpoints)}")
    return [checkpoints[index]]


def main():
    """Command-line entry point: load data, build the model, then train and/or evaluate.

    Returns the ``results`` dict, which is currently always empty.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    global test_dataset

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Restrict the visible GPUs before anything (e.g. torch.load of the cached
    # features below) can initialise CUDA. Truthiness test, not an identity test.
    if args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # --log_dir names the log *file* given to logzero. Attach it before anything
    # else is logged so the whole run is captured; without it, log to the console only.
    if args.log_dir:
        log_parent = os.path.dirname(args.log_dir)
        if log_parent:
            os.makedirs(log_parent, exist_ok=True)
        logzero.logfile(args.log_dir)

    train_dataset = test_dataset = None
    if args.do_train:
        train_dataset = load_dataset(args, data_type='train')
        args.num_labels = train_dataset.num_labels
    if args.do_eval:
        if 'test' in args.mode or 'dev' in args.mode:
            test_dataset = load_dataset(args, data_type='dev')
            args.num_labels = test_dataset.num_labels
        if 'train' in args.mode:  # save features from training instances
            train_dataset = train_dataset if train_dataset is not None else load_dataset(args, data_type='train')
            args.num_labels = train_dataset.num_labels
    if not hasattr(args, 'num_labels'):
        parser.error("nothing to do: pass --do_train, or --do_eval with --mode containing 'train', 'dev' or 'test'")
    logger.info(f"num labels {args.num_labels}")

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    # Set seed
    set_seed(args)

    args.output_mode = "classification"

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
                                          num_labels=args.num_labels, finetuning_task=args.task_name)

    config.adaptive_cutoff = args.adaptive_cutoff
    config.div_value = args.div_value

    # Feature size follows --bottleneck_size if given, otherwise the model's own
    # hidden size (d_model for configs that have it, e.g. XLNet).
    config.bottleneck_size = args.bottleneck_size
    if args.bottleneck_size is not None:
        config.last_hidden_size = args.bottleneck_size
    else:
        config.last_hidden_size = config.d_model if hasattr(config, 'd_model') else config.hidden_size
    logger.info(f"feature size: {config.last_hidden_size}")
    # NOTE(review): fixed for every model type; 32000 matches the XLNet vocabulary - confirm for others.
    config.n_token = 32000

    # hyper param from args
    if 'xlnet' in args.model_type:
        config.n_layer = config.n_layer if args.num_hidden_layers is None else args.num_hidden_layers
        config.dropout = config.dropout if args.hidden_dropout_prob is None else args.hidden_dropout_prob
    else:
        config.num_hidden_layers = config.num_hidden_layers if args.num_hidden_layers is None else args.num_hidden_layers
        config.hidden_dropout_prob = config.hidden_dropout_prob if args.hidden_dropout_prob is None else args.hidden_dropout_prob
        config.attention_probs_dropout_prob = config.attention_probs_dropout_prob if args.attention_probs_dropout_prob is None else args.attention_probs_dropout_prob
        config.positional_encoding = args.positional_encoding

    if args.model_type == 'roberta' and args.use_seq_summary:
        logger.info("use seq summary")
        config.summary_activation = "tanh"
        config.summary_last_dropout = 0.1
        config.summary_type = "first"
        config.summary_use_proj = True

    if 'lightxml' in args.model_type:
        logger.info('config light xml')
        config.feature_layers = args.feature_layers

    logger.info(f" model type: {args.model_type}, model name: {args.model_name_or_path}")
    logger.info(f"num_classes {args.num_labels}, last hidden {config.last_hidden_size}")

    # Training
    if args.do_train:
        if args.finetune:
            # Resume from the last directory (in sorted order) under output_dir.
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/', recursive=True)))
            if not checkpoints:
                raise FileNotFoundError(f"--finetune: no directory found under {args.output_dir}")
            checkpoint_name = checkpoints[-1]
            logger.info(f'finetune {checkpoint_name}')
            model = model_class.from_pretrained(checkpoint_name, from_tf=bool('.ckpt' in args.model_name_or_path),
                                                config=config)
            # NOTE(review): assumes the directory is named 'checkpoint-<step>' (as
            # written with --save_epoch); any other name makes int() raise here.
            start_step = int(checkpoint_name.split('-')[-1])
        else:
            model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
            start_step = 0
        params_1 = count_parameters(model)
        print('the number of params: ', params_1)

        if args.local_rank == 0:
            torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

        model.to(args.device)
        if args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                              output_device=args.local_rank,
                                                              find_unused_parameters=True)
        elif args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        logger.info("Training/evaluation parameters %s", args)

        global_step, tr_loss = train(args, train_dataset, model, model_class, start_step=start_step)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
    elif args.local_rank == 0:
        # The other ranks are blocked in the barrier before config loading; release
        # them even when no model is loaded for training, or they wait forever.
        torch.distributed.barrier()

    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        logger.info("save best-practices")
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save the trained model and configuration using `save_pretrained()`;
        # they can then be reloaded using `from_pretrained()`.
        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

        # Load the trained model back from disk
        model = model_class.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        WEIGHTS_NAME = 'pytorch_model.bin'
        checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
        logger.info("List of all checkpoints: %s", checkpoints)
        checkpoints = _select_checkpoints(checkpoints, args.eval_checkpoints)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            if args.n_gpu > 1:
                model = torch.nn.DataParallel(model)

            checkpoint_name = osp.basename(checkpoint)
            if 'train' in args.mode:
                evaluate(args, train_dataset, model, mode='train', prefix=global_step, save=args.save, checkpoint_name=checkpoint_name)
            if 'test' in args.mode or 'dev' in args.mode:
                evaluate(args, test_dataset, model, mode='dev', prefix=global_step, save=args.save,
                         checkpoint_name=checkpoint_name)

    return results


if __name__ == "__main__":
    # Script entry point: parse the CLI arguments, then train and/or evaluate.
    main()
