# -*- coding:utf-8 -*- 
import argparse
import glob
import logging
import os
import random
import copy
import math
import json
import numpy as np
import torch
from torch.nn import CrossEntropyLoss, KLDivLoss, NLLLoss, BCEWithLogitsLoss
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
import torch.nn.functional as F
from tqdm import tqdm, trange
import sys
import pickle as pkl

from transformers import (
    WEIGHTS_NAME,
    AdamW,
    RobertaConfig,
    RobertaForTokenClassification,
    RobertaTokenizer,
    get_linear_schedule_with_warmup,
)

from models.modeling_roberta_debias_bin import RobertaForTokenClassification_Modified
from utils.data_utils import load_and_cache_examples, get_labels, tag_to_id
from utils.model_utils import mask_tokens, mask_bitokens, soft_frequency, opt_grad, get_hard_label, _update_mean_model_variables
from utils.eval import evaluate_ori
from utils.config import config
from utils.loss_utils import NegEntropy, GCELoss, WorstCaseEstimationLoss

logger = logging.getLogger(__name__)

# Backbone name for each student/teacher slot.
MODEL_NAMES = {
    "student1": "Roberta",
    "student2": "DistilRoberta",
    "teacher1": "Roberta",
    "teacher2": "DistilRoberta",
}
# (config class, model class, tokenizer class) for each student slot; the
# teachers are also built from these student entries in `initialize`.
MODEL_CLASSES = {
    slot: (RobertaConfig, RobertaForTokenClassification_Modified, RobertaTokenizer)
    for slot in ("student1", "student2")
}

# Print whole tensors instead of "..."-summarised ones.
torch.set_printoptions(profile="full")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seed(args):
    """Seed Python's, NumPy's and torch's RNGs from ``args.seed``.

    When ``args.n_gpu > 0`` every CUDA device is seeded as well.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

def _load_token_classifier(args, prefix, num_labels):
    """Load the config + model described by ``args.<prefix>_*`` and move it to ``args.device``.

    ``prefix`` is a key of ``MODEL_CLASSES`` ("student1" or "student2"). The config
    is read from ``args.<prefix>_config_name`` when that is set, otherwise from
    ``args.<prefix>_model_name_or_path``.
    """
    config_class, model_class, _ = MODEL_CLASSES[prefix]
    name_or_path = getattr(args, prefix + "_model_name_or_path")
    config_name = getattr(args, prefix + "_config_name")
    cache_dir = args.cache_dir if args.cache_dir else None
    model_config = config_class.from_pretrained(
        config_name if config_name else name_or_path,
        num_labels=num_labels,
        cache_dir=cache_dir,
    )
    model = model_class.from_pretrained(
        name_or_path,
        from_tf=bool(".ckpt" in name_or_path),
        config=model_config,
        cache_dir=cache_dir,
    )
    model.to(args.device)
    return model


def _build_optimizer_and_scheduler(args, model, t_total):
    """Create AdamW (no weight decay on biases / LayerNorm weights) plus a linear warmup schedule."""
    no_decay = ["bias", "LayerNorm.weight"]
    decayed, not_decayed = [], []
    for param_name, param in model.named_parameters():
        bucket = not_decayed if any(nd in param_name for nd in no_decay) else decayed
        bucket.append(param)
    optimizer = AdamW(
        [
            {"params": decayed, "weight_decay": args.weight_decay},
            {"params": not_decayed, "weight_decay": 0.0},
        ],
        lr=args.learning_rate,
        eps=args.adam_epsilon,
        betas=(args.adam_beta1, args.adam_beta2),
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    return optimizer, scheduler


def initialize(args, t_total, num_labels, epoch):
    """Build both students, both teachers and the students' optimizers/schedulers.

    Teachers 1/2 are loaded from the same ``student1``/``student2`` arguments as
    the students. Only the students get an optimizer; the teachers' parameters
    are detached. Models are wrapped for apex fp16, DataParallel and
    DistributedDataParallel as configured in ``args``. ``epoch`` is unused.

    Returns:
        (model_s1, model_s2, model_t1, model_t2,
         optimizer_s1, scheduler_s1, optimizer_s2, scheduler_s2)
    """
    # Keep the s1, s2, t1, t2 load order so any randomly initialised weights
    # draw from the RNG in a fixed sequence.
    model_s1 = _load_token_classifier(args, "student1", num_labels)
    model_s2 = _load_token_classifier(args, "student2", num_labels)
    model_t1 = _load_token_classifier(args, "student1", num_labels)
    model_t2 = _load_token_classifier(args, "student2", num_labels)

    optimizer_s1, scheduler_s1 = _build_optimizer_and_scheduler(args, model_s1, t_total)
    optimizer_s2, scheduler_s2 = _build_optimizer_and_scheduler(args, model_s2, t_total)

    models = [model_s1, model_s2, model_t1, model_t2]
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        models, [optimizer_s1, optimizer_s2] = amp.initialize(
            models, [optimizer_s1, optimizer_s2], opt_level=args.fp16_opt_level)

    # Multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        models = [torch.nn.DataParallel(m) for m in models]

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        models = [
            torch.nn.parallel.DistributedDataParallel(
                m, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
            )
            for m in models
        ]

    for m in models:
        m.zero_grad()
    model_s1, model_s2, model_t1, model_t2 = models

    # The teachers have no optimizer here; detach their parameters.
    for teacher in (model_t1, model_t2):
        for param in teacher.parameters():
            param.detach_()
    return model_s1, model_s2, model_t1, model_t2, optimizer_s1, scheduler_s1, optimizer_s2, scheduler_s2

def save_model(args, epoch, tors, model):
    """Save ``model`` to ``<args.output_dir><tors>/<tors>_ep<epoch>``.

    NOTE: ``tors`` is appended to ``args.output_dir`` by plain string
    concatenation (no separator). DataParallel/DDP wrappers are unwrapped
    through their ``module`` attribute before saving.
    """
    target_dir = os.path.join(args.output_dir + tors, tors + "_ep" + str(epoch))
    logger.info("Saving model checkpoint to %s", target_dir)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    unwrapped = getattr(model, "module", model)
    unwrapped.save_pretrained(target_dir)

def load_model(args, epoch, tors):
    """Load ``RobertaForTokenClassification_Modified`` from ``<args.output_dir><tors>/<tors>_ep<epoch>``.

    This is the same directory layout that ``save_model`` writes; the model is
    moved to ``args.device`` before being returned.
    """
    checkpoint_dir = os.path.join(args.output_dir + tors, tors + "_ep" + str(epoch))
    loaded = RobertaForTokenClassification_Modified.from_pretrained(checkpoint_dir)
    loaded.to(args.device)
    return loaded

def _save_checkpoint(args, model, tokenizer, tors, checkpoint_name):
    """Save ``model`` (unwrapped from DataParallel/DDP) and ``tokenizer`` to ``<args.output_dir><tors>/<checkpoint_name>``."""
    path = os.path.join(args.output_dir + tors, checkpoint_name)
    logger.info("Saving model checkpoint to %s", path)
    # exist_ok avoids the exists()/makedirs() race.
    os.makedirs(path, exist_ok=True)
    model_to_save = model.module if hasattr(model, "module") else model
    model_to_save.save_pretrained(path)
    tokenizer.save_pretrained(path)


def validation(args, model, tokenizer, labels, pad_token_label_id, best_dev, best_test,
                  global_step, t_total, epoch, tors):
    """Evaluate ``model`` on the dev and test splits and checkpoint improvements.

    Judging by how its result is unpacked, ``evaluate_ori`` returns the updated
    best score and an "improved" flag for each split. On the main process
    (``args.local_rank`` in -1/0), model and tokenizer are handled as follows:

    * they are saved to ``<args.output_dir><tors>/checkpoint-best-1`` when the
      dev flag is set;
    * they are saved to ``.../checkpoint-best-2`` when the test flag is set.

    Returns:
        ``(best_dev, best_test, dev_improved)``.
    """
    progress = "[Step {}/{} | Epoch {}/{}]".format(global_step, t_total, epoch, args.num_train_epochs)
    _, _, best_dev, dev_improved = evaluate_ori(
        args, model, tokenizer, labels, pad_token_label_id, best_dev, mode="dev",
        logger=logger, prefix="dev " + progress, verbose=False)
    _, _, best_test, test_improved = evaluate_ori(
        args, model, tokenizer, labels, pad_token_label_id, best_test, mode="test",
        logger=logger, prefix="test " + progress, verbose=False)
    if args.local_rank in [-1, 0]:
        if dev_improved:
            _save_checkpoint(args, model, tokenizer, tors, "checkpoint-best-1")
        if test_improved:
            _save_checkpoint(args, model, tokenizer, tors, "checkpoint-best-2")
    return best_dev, best_test, dev_improved

def random_sampler(args, label_, prob=None):
    """Return a keep-mask that drops about 20% of the label-0 tokens.

    Args:
        args: only ``args.device`` is read, and only when ``prob`` is None.
        label_: integer label tensor; it is not modified. Positions equal to
            -100 are always masked out.
        prob: optional logits with one trailing class dimension, indexable by
            the boolean mask ``label_ == 0``. When given, each label-0 token is
            dropped with weight ``1 - max softmax probability``, so uncertain
            tokens are dropped preferentially. Otherwise the weights come from
            ``torch.rand``.

    Returns:
        Boolean tensor shaped like ``label_``; True where the token is kept.
    """
    label = label_.clone()
    mask = label == 0
    non_entity = label[mask]
    size = non_entity.size(0)
    if prob is not None:
        # detach(): prob may be a non-leaf model output. copy.deepcopy() raises
        # on such tensors, and only the values are needed for sampling.
        weights = 1 - torch.softmax(prob.detach(), dim=-1)[mask].max(dim=-1)[0]
    else:
        weights = torch.rand(size).to(args.device)
    num_samples = int(0.2 * size)
    # multinomial without replacement cannot draw more samples than there are
    # non-zero weights, e.g. tokens predicted with probability exactly 1.0.
    num_samples = min(num_samples, int((weights > 0).sum().item()))
    if num_samples <= 0:
        return label != -100

    select_ids = torch.multinomial(weights, num_samples)
    non_entity[select_ids] = -100
    label[mask] = non_entity
    return label != -100

def initial_mask(args, batch):
    """Draw two independent random keep-masks over the label tensor ``batch``.

    Both masks come from ``random_sampler`` with ``prob=None``. Datasets listed
    in ``exempt`` would get ``(None, None)`` instead. That list is empty, so
    every dataset is masked.
    """
    exempt = []
    if args.dataset in exempt:
        return None, None
    first = random_sampler(args, batch, prob=None)
    second = random_sampler(args, batch, prob=None)
    return first, second

def get_teacher(args, model_t1, model_t2, t_model1, t_model2, dev_is_updated1, dev_is_updated2, batch=True):
    """Return (possibly refreshed) deep copies of the two teacher models.

    For "conll03"/"wikigold" with ``batch`` truthy, a snapshot is replaced only
    when its ``dev_is_updated*`` flag is set. In every other case both
    snapshots are replaced with fresh deep copies of ``model_t1``/``model_t2``.
    """
    refresh_all = not (args.dataset in ["conll03", "wikigold"] and batch)
    if refresh_all or dev_is_updated1:
        t_model1 = copy.deepcopy(model_t1)
    if refresh_all or dev_is_updated2:
        t_model2 = copy.deepcopy(model_t2)
    return t_model1, t_model2

def _split_bin_type_labels(token_labels):
    """Split integer token labels into binary and entity-type targets.

    Label 0 is the non-entity class. Returns ``(bin_labels, type_labels, type_pos)``:

    * ``bin_labels``: copy of ``token_labels`` with every label > 0 set to 1.
    * ``type_labels``: ``token_labels - 1`` with every negative value set to -100
      (CrossEntropyLoss's default ignore_index).
    * ``type_pos``: boolean mask ``token_labels > 0``.
    """
    bin_labels = token_labels.clone()
    bin_labels[token_labels > 0] = 1
    type_labels = token_labels - 1
    type_labels[type_labels < 0] = -100
    return bin_labels, type_labels, token_labels > 0


def _selected_loss(loss_fct, logits, targets, sel, binary=False):
    """Apply ``loss_fct`` to the rows of ``logits``/``targets`` selected by ``sel``.

    With ``binary=True`` both sides are flattened and the targets cast to float,
    as BCEWithLogitsLoss needs. An empty selection yields a graph-connected
    zero rather than calling a mean-reduced loss on empty tensors, which would
    return NaN and poison the backward pass.
    """
    if not bool(sel.any()):
        return logits.sum() * 0.0
    logits, targets = logits[sel], targets[sel]
    if binary:
        logits, targets = logits.view(-1), targets.view(-1).float()
    return loss_fct(logits, targets)


def train(args, train_dataset, tokenizer, labels, pad_token_label_id):
    """Train two students and their mean-teacher copies.

    Each step both students (``model_s1``, ``model_s2``) are trained on a
    selection of tokens. After every optimizer step the teachers (``model_t1``,
    ``model_t2``) are updated from them via ``_update_mean_model_variables``.

    * Before ``args.begin_epoch`` the targets are the dataset labels
      (``batch[3]``). A random subset of non-entity tokens is dropped via
      ``initial_mask``.
    * From ``args.begin_epoch`` on, the targets come from teacher snapshots
      (``t_model1``/``t_model2``, refreshed by ``get_teacher`` after validation).
      Only tokens where a snapshot agrees with the pseudo labels are kept; in
      "soft" mode ``mask_bitokens`` narrows the selection further.
    * From ``args.begin_coguess`` on, once at least one
      ``args.self_learning_period`` has elapsed, each branch's pseudo labels
      come from a periodically refreshed copy of the *other* branch's snapshot.

    Returns:
        ``(global_step, average loss per optimizer step,
        (t1_best_dev, t1_best_test, t2_best_dev, t2_best_test))``.
    """
    num_labels = len(labels)
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    model_s1, model_s2, model_t1, model_t2, optimizer_s1, scheduler_s1, optimizer_s2, scheduler_s2 = initialize(
        args, t_total, num_labels, 0)

    if args.fp16:
        # `initialize` imports apex only into its own scope; without this import
        # the fp16 backward/clipping below fails with NameError.
        from apex import amp

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0

    tr_loss = 0.0
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    s1_best_dev, s1_best_test = [0, 0, 0], [0, 0, 0]
    s2_best_dev, s2_best_test = [0, 0, 0], [0, 0, 0]
    t1_best_dev, t1_best_test = [0, 0, 0], [0, 0, 0]
    t2_best_dev, t2_best_test = [0, 0, 0], [0, 0, 0]

    # Producers of co-guessing pseudo labels; replaced by eval-mode copies of the
    # teacher snapshots every `args.self_learning_period` steps.
    self_learning_teacher_model1 = model_s1
    self_learning_teacher_model2 = model_s2

    # Teacher snapshots used to filter training tokens (refreshed by get_teacher).
    t_model1 = copy.deepcopy(model_s1)
    t_model2 = copy.deepcopy(model_s2)

    w = 1.0  # weight of the worst-case-estimation loss terms
    begin_global_step = len(train_dataloader) * args.begin_epoch // args.gradient_accumulation_steps
    for epoch in train_iterator:
        if isinstance(train_sampler, DistributedSampler):
            # Without this every epoch reuses the same distributed shuffle.
            train_sampler.set_epoch(epoch)
        if epoch >= args.begin_epoch:
            w = args.wce_weight

        for step, batch in enumerate(train_dataloader):
            model_s1.train()
            model_s2.train()
            model_t1.train()
            model_t2.train()

            batch = tuple(t.to(args.device) for t in batch)
            valid_pos = batch[2]
            if epoch >= args.begin_epoch:
                delta = global_step - begin_global_step
                co_guess = epoch >= args.begin_coguess and delta // args.self_learning_period > 0
                if co_guess:
                    if delta % args.self_learning_period == 0:
                        self_learning_teacher_model1 = copy.deepcopy(t_model1)
                        self_learning_teacher_model1.eval()
                        self_learning_teacher_model2 = copy.deepcopy(t_model2)
                        self_learning_teacher_model2.eval()
                    inputs = {"input_ids": batch[0], "attention_mask": batch[1], "valid_pos": batch[2]}
                    with torch.no_grad():
                        _, _, _, _, outputs1 = self_learning_teacher_model1(**inputs)
                        _, _, _, _, outputs2 = self_learning_teacher_model2(**inputs)
                    # Co-guessing: each branch is supervised by the other branch.
                    pseudo_labels1 = torch.argmax(outputs2, axis=-1)
                    pseudo_labels2 = torch.argmax(outputs1, axis=-1)
                else:
                    pseudo_labels1 = batch[3][valid_pos > 0]
                    pseudo_labels2 = batch[3][valid_pos > 0]

                bin_pseudo_labels1, type_pseudo_labels1, type_pos1 = _split_bin_type_labels(pseudo_labels1)
                bin_pseudo_labels2, type_pseudo_labels2, type_pos2 = _split_bin_type_labels(pseudo_labels2)

                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "valid_pos": batch[2]}
                with torch.no_grad():
                    type_logits1, bin_logits1, type_pred_labels1, bin_pred_labels1, logits1 = t_model1(**inputs)
                    type_logits2, bin_logits2, type_pred_labels2, bin_pred_labels2, logits2 = t_model2(**inputs)
                    pred_labels1 = torch.argmax(logits1, dim=-1)
                    pred_labels2 = torch.argmax(logits2, dim=-1)

                    # Keep only tokens where the teacher snapshot agrees with the pseudo labels.
                    bin_label_mask1 = (bin_pred_labels1 == bin_pseudo_labels1)
                    bin_label_mask2 = (bin_pred_labels2 == bin_pseudo_labels2)
                    type_label_mask1 = (type_pred_labels1 == type_pseudo_labels1) & type_pos1
                    type_label_mask2 = (type_pred_labels2 == type_pseudo_labels2) & type_pos2

                logits1 = soft_frequency(logits=type_logits1, power=2)
                logits2 = soft_frequency(logits=type_logits2, power=2)

                # Only the "soft" mode produces these extra masks. Initialise them so
                # other modes (e.g. "hard") no longer raise NameError below.
                bin_label_mask1_ = bin_label_mask2_ = None
                type_label_mask1_ = type_label_mask2_ = None
                if args.self_learning_label_mode == "hard":
                    pred_labels1, label_mask1_ = mask_tokens(args, batch[3], pred_labels1, pad_token_label_id, pred_logits=logits1)
                    pred_labels2, label_mask2_ = mask_tokens(args, batch[3], pred_labels2, pad_token_label_id, pred_logits=logits2)
                elif args.self_learning_label_mode == "soft":
                    type_pred_labels1, type_label_mask1_, bin_label_mask1_ = mask_bitokens(args, logits1, bin_logits1, batch[3])
                    type_pred_labels2, type_label_mask2_, bin_label_mask2_ = mask_bitokens(args, logits2, bin_logits2, batch[3])
                if bin_label_mask1_ is not None:
                    bin_label_mask1 = bin_label_mask1 & bin_label_mask1_
                if bin_label_mask2_ is not None:
                    bin_label_mask2 = bin_label_mask2 & bin_label_mask2_
                if type_label_mask1_ is not None:
                    type_label_mask1 = type_label_mask1 & type_label_mask1_
                if type_label_mask2_ is not None:
                    type_label_mask2 = type_label_mask2 & type_label_mask2_

                if co_guess:
                    # While co-guessing, also require agreement with the peer branch's masks.
                    if bin_label_mask2_ is not None:
                        bin_label_mask1 = bin_label_mask1 & bin_label_mask2_
                    if bin_label_mask1_ is not None:
                        bin_label_mask2 = bin_label_mask2 & bin_label_mask1_
                    if type_label_mask2_ is not None:
                        type_label_mask1 = type_label_mask1 & type_label_mask2_
                    if type_label_mask1_ is not None:
                        type_label_mask2 = type_label_mask2 & type_label_mask1_
            else:
                # Warm-up: the dataset labels are the targets for both branches.
                token_labels = batch[3][valid_pos > 0]
                bin_pred_labels1, type_pred_labels1, type_pos1 = _split_bin_type_labels(token_labels)
                bin_pred_labels2, type_pred_labels2, type_pos2 = _split_bin_type_labels(token_labels)

                label_mask1, label_mask2 = initial_mask(args, token_labels)
                bin_label_mask1, bin_label_mask2 = label_mask1, label_mask2
                type_label_mask1, type_label_mask2 = label_mask1 & type_pos1, label_mask2 & type_pos2

            student_inputs = {"input_ids": batch[0], "attention_mask": batch[1], "valid_pos": batch[2], "train": True}
            type_logits1, bin_logits1, type_logits_adv1, bin_logits_adv1, _, _, _, _, _ = model_s1(**student_inputs)
            type_logits2, bin_logits2, type_logits_adv2, bin_logits_adv2, _, _, _, _, _ = model_s2(**student_inputs)

            # If filtering removed every entity token, fall back to all entity positions.
            if type_label_mask1.sum().item() == 0:
                type_label_mask1 = type_pos1
            if type_label_mask2.sum().item() == 0:
                type_label_mask2 = type_pos2

            bin_filtered_sel1 = bin_label_mask1.view(-1)
            bin_filtered_sel2 = bin_label_mask2.view(-1)
            type_filtered_sel1 = type_label_mask1.view(-1)
            type_filtered_sel2 = type_label_mask2.view(-1)

            bin_idx_unchosen1 = bin_filtered_sel1.logical_not()
            bin_idx_unchosen2 = bin_filtered_sel2.logical_not()
            type_idx_unchosen1 = type_filtered_sel1.logical_not()
            type_idx_unchosen2 = type_filtered_sel2.logical_not()

            bin_loss_fct = BCEWithLogitsLoss()
            if epoch >= args.begin_epoch and type_pred_labels1.shape == type_logits1.shape:
                # Soft (distribution) targets: KL between student log-probs and the
                # renormalised target distribution.
                type_loss_fct = KLDivLoss(reduction='sum')
                type_logits1 = F.log_softmax(type_logits1, dim=-1)
                type_logits2 = F.log_softmax(type_logits2, dim=-1)
                type_pred_labels1 = type_pred_labels1 / type_pred_labels1.sum(dim=-1, keepdim=True)
                type_pred_labels2 = type_pred_labels2 / type_pred_labels2.sum(dim=-1, keepdim=True)
            else:
                type_loss_fct = CrossEntropyLoss()
            bin_wce_loss = WorstCaseEstimationLoss(2).to(args.device)
            type_wce_loss = WorstCaseEstimationLoss(2).to(args.device)

            loss1 = (_selected_loss(bin_loss_fct, bin_logits1, bin_pred_labels1, bin_filtered_sel1, binary=True)
                     + _selected_loss(type_loss_fct, type_logits1, type_pred_labels1, type_filtered_sel1))
            loss2 = (_selected_loss(bin_loss_fct, bin_logits2, bin_pred_labels2, bin_filtered_sel2, binary=True)
                     + _selected_loss(type_loss_fct, type_logits2, type_pred_labels2, type_filtered_sel2))

            if bin_filtered_sel1.sum().item() and bin_idx_unchosen1.sum().item():
                loss1 = loss1 + w * bin_wce_loss(bin_logits1[bin_filtered_sel1], bin_logits_adv1[bin_filtered_sel1], bin_logits1[bin_idx_unchosen1], bin_logits_adv1[bin_idx_unchosen1])
            if bin_filtered_sel2.sum().item() and bin_idx_unchosen2.sum().item():
                loss2 = loss2 + w * bin_wce_loss(bin_logits2[bin_filtered_sel2], bin_logits_adv2[bin_filtered_sel2], bin_logits2[bin_idx_unchosen2], bin_logits_adv2[bin_idx_unchosen2])
            if type_filtered_sel1.sum().item() and type_idx_unchosen1.sum().item():
                loss1 = loss1 + w * type_wce_loss(type_logits1[type_filtered_sel1], type_logits_adv1[type_filtered_sel1], type_logits1[type_idx_unchosen1], type_logits_adv1[type_idx_unchosen1])
            if type_filtered_sel2.sum().item() and type_idx_unchosen2.sum().item():
                loss2 = loss2 + w * type_wce_loss(type_logits2[type_filtered_sel2], type_logits_adv2[type_filtered_sel2], type_logits2[type_idx_unchosen2], type_logits_adv2[type_idx_unchosen2])

            if args.n_gpu > 1:
                loss1 = loss1.mean()  # mean() to average on multi-gpu parallel training
                loss2 = loss2.mean()
            if args.gradient_accumulation_steps > 1:
                loss1 = loss1 / args.gradient_accumulation_steps
                loss2 = loss2 / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss1, optimizer_s1) as scaled_loss1:
                    scaled_loss1.backward()
                with amp.scale_loss(loss2, optimizer_s2) as scaled_loss2:
                    scaled_loss2.backward()
            else:
                # Sum *after* the scaling above so gradient accumulation also scales
                # the gradients. The students are separately loaded models, so this
                # equals two separate backward passes.
                (loss1 + loss2).backward()

            tr_loss += loss1.item() + loss2.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_s1), args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_s2), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model_s1.parameters(), args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(model_s2.parameters(), args.max_grad_norm)

                optimizer_s1.step()
                scheduler_s1.step()
                optimizer_s2.step()
                scheduler_s2.step()
                model_s1.zero_grad()
                model_s2.zero_grad()
                global_step += 1

                _update_mean_model_variables(model_s1, model_t1, args.mean_alpha, global_step)
                _update_mean_model_variables(model_s2, model_t2, args.mean_alpha, global_step)
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    if args.evaluate_during_training:
                        logger.info("***** Student1 combined Entropy loss : %.4f *****", loss1.item())
                        logger.info("##### Student1 #####")
                        s1_best_dev, s1_best_test, _ = validation(args, model_s1, tokenizer, labels, pad_token_label_id, \
                            s1_best_dev, s1_best_test, global_step, t_total, epoch, "student1")
                        logger.info("##### Teacher1 #####")
                        t1_best_dev, t1_best_test, dev_is_updated1 = validation(args, model_t1, tokenizer, labels, pad_token_label_id, \
                            t1_best_dev, t1_best_test, global_step, t_total, epoch, "teacher1")
                        logger.info("***** Student2 combined Entropy loss : %.4f *****", loss2.item())
                        logger.info("##### Student2 #####")
                        s2_best_dev, s2_best_test, _ = validation(args, model_s2, tokenizer, labels, pad_token_label_id, \
                            s2_best_dev, s2_best_test, global_step, t_total, epoch, "student2")
                        logger.info("##### Teacher2 #####")
                        t2_best_dev, t2_best_test, dev_is_updated2 = validation(args, model_t2, tokenizer, labels, pad_token_label_id, \
                            t2_best_dev, t2_best_test, global_step, t_total, epoch, "teacher2")
                        t_model1, t_model2 = get_teacher(args, model_t1, model_t2, t_model1, t_model2, dev_is_updated1, dev_is_updated2)

            if args.max_steps > 0 and global_step > args.max_steps:
                # A DataLoader has no close(); just leave the epoch.
                break

        logger.info("***** Epoch : %d *****", epoch)
        logger.info("##### Student1 #####")
        s1_best_dev, s1_best_test, _ = validation(args, model_s1, tokenizer, labels, pad_token_label_id, \
            s1_best_dev, s1_best_test, global_step, t_total, epoch, "student1")
        logger.info("##### Teacher1 #####")
        t1_best_dev, t1_best_test, dev_is_updated1 = validation(args, model_t1, tokenizer, labels, pad_token_label_id, \
            t1_best_dev, t1_best_test, global_step, t_total, epoch, "teacher1")
        logger.info("##### Student2 #####")
        s2_best_dev, s2_best_test, _ = validation(args, model_s2, tokenizer, labels, pad_token_label_id, \
            s2_best_dev, s2_best_test, global_step, t_total, epoch, "student2")
        logger.info("##### Teacher2 #####")
        t2_best_dev, t2_best_test, dev_is_updated2 = validation(args, model_t2, tokenizer, labels, pad_token_label_id, \
            t2_best_dev, t2_best_test, global_step, t_total, epoch, "teacher2")
        t_model1, t_model2 = get_teacher(args, model_t1, model_t2, t_model1, t_model2, dev_is_updated1, dev_is_updated2, True)

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

        if (epoch + 1) % 5 == 0:
            save_model(args, epoch + 1, "student1", model_s1)
            save_model(args, epoch + 1, "student2", model_s2)
            save_model(args, epoch + 1, "teacher1", model_t1)
            save_model(args, epoch + 1, "teacher2", model_t2)

    results = (t1_best_dev, t1_best_test, t2_best_dev, t2_best_test)
    # max(..., 1): avoid ZeroDivisionError when no optimizer step was taken.
    return global_step, tr_loss / max(global_step, 1), results

def main():
    """Script entry point.

    Reads the config, sets up device, distributed and logging, and loads the
    tokenizer. Then loads the training data and calls ``train`` when
    ``args.do_train`` is set.
    """
    args = config()
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Every rank opens <output_dir>/log.txt below, so the directory must exist for
    # all of them, not only for ranks -1/0. exist_ok makes concurrent creation safe.
    os.makedirs(args.output_dir, exist_ok=True)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s -   %(message)s", "%m/%d/%Y %H:%M:%S")
    logging_fh = logging.FileHandler(os.path.join(args.output_dir, 'log.txt'))
    logging_fh.setLevel(logging.DEBUG)
    logging_fh.setFormatter(formatter)
    logger.addHandler(logging_fh)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    set_seed(args)
    labels = get_labels(args.data_dir, args.dataset)
    pad_token_label_id = CrossEntropyLoss().ignore_index

    # Non-zero ranks wait here until rank 0 has loaded the tokenizer.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    tokenizer = RobertaTokenizer.from_pretrained(
        args.tokenizer_name,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    if args.local_rank == 0:
        torch.distributed.barrier()
    logger.info("Training/evaluation parameters %s", args)

    if args.do_train:
        train_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode="train")
        global_step, tr_loss, _ = train(args, train_dataset, tokenizer, labels, pad_token_label_id)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
