'''
name:    run_exp.py
author:  Alaa Maalouf

researchers:
    Alaa Maalouf
    Harry Lang
    Daniela Rus
    Dan Feldman
    
usage example:
python3  src/run_exp.py --model_name_or_path textattack/distilbert-base-uncased-RTE  --task_name RTE --do_eval --data_dir $PATHTOGLUEDATA/glue_data/RTE/ --max_seq_length 128 --per_device_train_batch_size 32 --learning_rate 2e-5  --no_cuda   --overwrite_output_dir --output_dir /tmp/rte_distil  --do_train --num_train_epochs 0
'''

""" Acknowledgement: This code is based on the huggingface library with modifications to fit our goal"""

import sys
from os.path import dirname
from scipy.linalg import svd
import os
sys.path.insert(0,os.path.join(os.getcwd(), 'src','transformers-master/src'))
print(sys.path)
import factor
import dataclasses
import logging
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional
import torch
import copy
import numpy as np
import math

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import (HfArgumentParser,Trainer,TrainingArguments,glue_compute_metrics,glue_output_modes,glue_tasks_num_labels)

from transformers.modeling_distilbert_messi import DistilBertForSequenceClassification_ranked
from transformers.modeling_albert_messi import AlbertForSequenceClassification
from transformers.modeling_roberta_messi import RobertaForSequenceClassification_ranked
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Experiment configuration: module-level switches read by main() and by the
# compression engines defined below.
# ---------------------------------------------------------------------------

# Run the single-SVD compression of the word embedding (svd_compression).
do_svd_compression = False
# Run the k-SVD compression of the word embedding (k_svd_compression).
do_k_svd_compression = True

# After each compression, train the factorized model for `fine_tune_epochs`.
fine_tune = True
# Also evaluate the original-architecture model whose embedding was replaced
# by the low-rank reconstruction.
check_low_rank_one_matrix = False
fine_tune_epochs = 2
# When True, dump the factorized embedding tensors / model to disk.
debug = False
# Directories searched, in order, for previously saved factorizations under
# "<dir>/saved_GV/". Kept as a list because k_svd_compression.update iterates
# over it; iterating a plain string would walk its characters, not paths.
load_old_from = ["."]
# Passed as `steps` / `NUM_INIT_FOR_EM` to the factor module's EM routines.
EM_STEPS = 15
NUM_INIT_FOR_EM = 15

# Clustering methods to evaluate; "split" is also handled by
# k_svd_compression.update.
clustering_methods = ["messi", "messi_kmeans"]
# Wider sweep used previously:
# ranks = [584, 517, 450, 384, 317, 250, 184, 117]
# k_vals = [7, 5, 3]

ranks = [384]
k_vals = [5]
################################
@dataclass
class ModelArguments:
    """
    Command-line options selecting the pretrained model, its config and its
    tokenizer (parsed together with the data/training arguments in main()).
    """

    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier from huggingface.co/models")
    )
    config_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained config name or path if not the same as model_name"),
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained tokenizer name or path if not the same as model_name"),
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(help="Where do you want to store the pretrained models downloaded from s3"),
    )
    load_G: Optional[int] = field(
        default=None,
        metadata=dict(help="load saved G"),
    )
    
def countNonZeroWeights(model):
    """Return how many entries across all of ``model``'s parameters are non-zero."""
    return sum(
        weight.nonzero().size(0)
        for weight in model.parameters()
        if weight is not None
    )
def number_of_parameters(model, only_trainable=True):
    """Count the scalar entries in ``model``'s parameters.

    When ``only_trainable`` is true, parameters with ``requires_grad == False``
    are excluded from the count.
    """
    params = model.parameters()
    if only_trainable:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)

def get_model_name(model_args):
    """Derive the short architecture name (e.g. ``"distilbert"``) from the model path.

    The last ``/``-separated component of ``model_args.model_name_or_path`` is
    taken and cut at its first ``'-'``. Hub ids such as
    ``"textattack/distilbert-base-uncased-RTE"`` and local paths such as
    ``"/tmp/models/roberta-base"`` therefore both work. The result is used as
    the prefix of state-dict keys (``"<name>.embeddings.word_embeddings.weight"``)
    and in substring checks like ``'distilbert' in name``.
    """
    # Using the *last* component (instead of split('/')[1]) keeps absolute or
    # multi-level paths from yielding a directory name such as "tmp".
    last_component = model_args.model_name_or_path.rstrip('/').split('/')[-1]
    return last_component.split('-')[0]

def check_if_overwrite_previus_training_dir(training_args):
    """Refuse to train into a non-empty output directory unless overwriting is allowed.

    Raises:
        ValueError: if ``output_dir`` exists and is non-empty, ``do_train`` is
            set, and ``overwrite_output_dir`` is not.
    """
    out_dir = training_args.output_dir
    dir_has_content = os.path.exists(out_dir) and os.listdir(out_dir)
    would_clobber = training_args.do_train and not training_args.overwrite_output_dir
    if dir_has_content and would_clobber:
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )


def logging_setup(training_args):
    """Configure root logging and report the process / device setup.

    Processes with ``local_rank`` -1 or 0 log at INFO; every other rank logs
    at WARN.
    """
    is_main_process = training_args.local_rank in [-1, 0]
    chosen_level = logging.INFO if is_main_process else logging.WARN
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=chosen_level,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    
def get_task_info(data_args):
    """Look up the label count and output mode of the GLUE task ``data_args.task_name``.

    Returns:
        ``(num_labels, output_mode)``.
    Raises:
        ValueError: if the task name is not a known GLUE task.
    """
    task = data_args.task_name
    try:
        return glue_tasks_num_labels[task], glue_output_modes[task]
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))
    
def get_config_and_tokenizer(model_args, data_args, num_labels):
    """Load the model config (set up for the GLUE task) and the tokenizer.

    ``config_name`` / ``tokenizer_name`` take precedence over
    ``model_name_or_path`` when they are set.

    Returns:
        ``(config, tokenizer)``.
    """
    config_source = model_args.config_name or model_args.model_name_or_path
    config = AutoConfig.from_pretrained(
        config_source,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )

    tokenizer_source = model_args.tokenizer_name or model_args.model_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, cache_dir=model_args.cache_dir)

    return config, tokenizer
    
def build_model(model_args, config):
    """Load the pretrained sequence-classification model named by ``model_args``.

    A path containing ``".ckpt"`` is loaded as a TensorFlow checkpoint.
    """
    source = model_args.model_name_or_path
    is_tf_checkpoint = bool(".ckpt" in source)
    return AutoModelForSequenceClassification.from_pretrained(
        source,
        from_tf=is_tf_checkpoint,
        config=config,
        cache_dir=model_args.cache_dir,
    )

def Get_datasets(training_args, model_args, data_args, tokenizer):
    """Build the GLUE train / dev / test datasets that the run actually needs.

    Each split is created only when its flag (``do_train`` / ``do_eval`` /
    ``do_predict``) is set; otherwise ``None`` is returned in its place.

    Returns:
        ``(train_dataset, eval_dataset, test_dataset)``.
    """
    def _load(wanted, **mode_kwargs):
        # One GlueDataset per requested split; `mode` is omitted for train.
        if not wanted:
            return None
        return GlueDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir, **mode_kwargs)

    train_dataset = _load(training_args.do_train)
    eval_dataset = _load(training_args.do_eval, mode="dev")
    test_dataset = _load(training_args.do_predict, mode="test")
    return train_dataset, eval_dataset, test_dataset

def starting_print_and_get_shape(model, model_name, task_name, ranks, k_vals):
    """Print a run banner and return the shape of the word-embedding matrix.

    Args:
        model: model whose ``state_dict()`` contains
            ``<model_name>.embeddings.word_embeddings.weight``.
        model_name: state-dict prefix / short model name.
        task_name, ranks, k_vals: only printed.

    Returns:
        The embedding weight's shape as a tuple of ints.
    """
    banner = '**************************************************************************************************'
    for _ in range(4):
        print(banner)
    weights = model.state_dict()[model_name + '.embeddings.word_embeddings.weight']
    # Only the shape is needed, so read it from the tensor instead of copying
    # the whole embedding matrix to host memory through numpy.
    embedding_martrix_shape = tuple(weights.shape)
    print("->Model name:{}\n->Task name:{}\n->Compression ranks:{}\n-> k values:{}\n->Ebedding matrix shape:{}.".format(
        model_name, task_name, ranks, k_vals, embedding_martrix_shape))
    return embedding_martrix_shape
def make_dirs_if_not_exists(dirs_to_make):
    """Create every directory in ``dirs_to_make`` that does not exist yet.

    ``os.makedirs(..., exist_ok=True)`` also creates missing parent
    directories, and does not fail when another process creates the same
    directory between the check and the creation. A path that exists as a
    regular file raises ``FileExistsError``.
    """
    for dir_to_make in dirs_to_make:
        os.makedirs(dir_to_make, exist_ok=True)
def main():
    """Run the experiment: optionally fine-tune a GLUE model, then compress its word embedding.

    1. Parse ``ModelArguments`` / ``DataTrainingArguments`` / ``TrainingArguments``
       from the command line (``--help`` lists them).
    2. Load config, tokenizer, model and datasets; train when ``--do_train`` is
       given together with a non-zero ``--num_train_epochs``.
    3. Evaluate the uncompressed model and ``np.save`` its score.
    4. If ``do_svd_compression``: for each rank in ``ranks``, factorize the
       embedding with a single SVD, evaluate, and optionally fine-tune.
    5. If ``do_k_svd_compression``: for every (rank, k, clustering method),
       factorize with ``k_svd_compression``, evaluate, and optionally fine-tune.

    Scores are accumulated in lists and re-saved with ``np.save`` after each run.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    pretrained_model_name = get_model_name(model_args)
    # Refuse to clobber a previous training run unless explicitly allowed.
    check_if_overwrite_previus_training_dir(training_args)
    logging_setup(training_args)
    # Seeding (set_seed(training_args.seed)) is deliberately skipped.
    num_labels, output_mode = get_task_info(data_args)
    config, tokenizer = get_config_and_tokenizer(model_args, data_args, num_labels)
    model = build_model(model_args, config)
    train_dataset, eval_dataset, test_dataset = Get_datasets(training_args, model_args, data_args, tokenizer)
    # "models" receives torch.save'd models; "saved_GV" is where
    # k_svd_compression.update looks for previously saved factorizations.
    make_dirs_if_not_exists(["models", "saved_GV"])

    def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
        def compute_metrics_fn(p: EvalPrediction):
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            else:
                # Previously fell through to an UnboundLocalError on `preds`.
                raise ValueError("Unsupported output mode: %s" % output_mode)
            return glue_compute_metrics(task_name, preds, p.label_ids)
        return compute_metrics_fn

    matrix_shape = starting_print_and_get_shape(model, pretrained_model_name, data_args.task_name, ranks, k_vals)

    # Initialize our Trainer
    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,
                      eval_dataset=eval_dataset, compute_metrics=build_compute_metrics_fn(data_args.task_name),)

    # Training
    if training_args.do_train and training_args.num_train_epochs:
        # Pre-training evaluation: logged and written to the eval file only;
        # the returned metrics are not used further.
        eval(eval_dataset, data_args, trainer, tokenizer, model_args, build_compute_metrics_fn, training_args, {})
        trainer.train(model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None)
        trainer.save_model()
        torch.save(trainer.model, 'models/full')
        # Re-save the tokenizer next to the model so the directory is self-contained.
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # The fine-tuning steps below overwrite training_args.num_train_epochs with
    # fine_tune_epochs. Remember the original value so the non-fine-tuned result
    # files keep one stable name for the whole sweep.
    base_epochs = training_args.num_train_epochs

    acc_origin = []
    acc_svd = []
    acc_svd_fine_tunned = []
    acc_k_factorization = []
    acc_k_factorization_fine_tunned = []
    acc_k_means = []
    acc_k_means_fine_tunned = []
    acc_k_chuncked = []
    acc_k_chuncked_fine_tunned = []

    def _result_slots(method):
        """Return (scores, file tag, fine-tuned scores, fine-tuned file tag) for a clustering method."""
        if method == 'messi_kmeans':
            return acc_k_means, '_acc_kmeans_SVDs_', acc_k_means_fine_tunned, 'fine_tuned_acc_kmeans_SVDs_'
        if method == 'split':
            return acc_k_chuncked, '_acc_chuncked_SVDs_', acc_k_chuncked_fine_tunned, 'fine_tuned_acc_chuncked_SVDs_'
        return (acc_k_factorization, '_acc_Multiple_SVDs_',
                acc_k_factorization_fine_tunned, 'fine_tuned_acc_Multiple_SVDs_')

    eval_results = eval(eval_dataset, data_args, trainer, tokenizer, model_args, build_compute_metrics_fn, training_args, {})

    # GLUE tasks report different headline metrics; use the first one present.
    term_of_interest = next(
        (key for key in ('eval_acc', 'eval_mcc', 'eval_corr') if key in eval_results),
        'eval_mnli/acc',
    )
    acc_origin.append(eval_results[term_of_interest])
    np.save(pretrained_model_name + '_acc_origin_' + data_args.task_name + '_' + str(base_epochs), acc_origin)

    if do_svd_compression:
        svd_compression_engine = svd_compression(copy.deepcopy(trainer.model), training_args,
                                                 train_dataset, eval_dataset,
                                                 build_compute_metrics_fn(data_args.task_name),
                                                 model_name=pretrained_model_name)
        for rank in ranks:
            eval_results = {}
            print('=========SVD rank: {}========='.format(rank))
            svd_compression_engine.update(rank, config, build_compute_metrics_fn(data_args.task_name),
                                          training_args, eval_dataset, train_dataset)
            if check_low_rank_one_matrix:
                eval_results = eval(eval_dataset, data_args, svd_compression_engine.trainer, tokenizer, model_args,
                                    build_compute_metrics_fn, training_args, eval_results)

            eval_results = eval(eval_dataset, data_args, svd_compression_engine.factorized_trainer, tokenizer,
                                model_args, build_compute_metrics_fn, training_args, eval_results)
            acc_svd.append(eval_results[term_of_interest])
            np.save(pretrained_model_name + '_acc_svd_' + data_args.task_name + "_" + str(base_epochs), acc_svd)

            if fine_tune:
                # Trainer reads num_train_epochs from the shared args object at train() time.
                training_args.num_train_epochs = fine_tune_epochs
                svd_compression_engine.factorized_trainer.train()
                eval_results = eval(eval_dataset, data_args, svd_compression_engine.factorized_trainer, tokenizer,
                                    model_args, build_compute_metrics_fn, training_args, eval_results)
                acc_svd_fine_tunned.append(eval_results[term_of_interest])
                np.save(pretrained_model_name + 'fine_tuned_acc_svd_' + data_args.task_name + '_' + str(
                    training_args.num_train_epochs), acc_svd_fine_tunned)

    if do_k_svd_compression:
        for rank in ranks:
            for k in k_vals:
                for method in clustering_methods:
                    engine = k_svd_compression(copy.deepcopy(trainer.model), training_args, train_dataset, eval_dataset,
                                               build_compute_metrics_fn(data_args.task_name), data_args, model_args,
                                               model_name=pretrained_model_name, k=k, rank=rank, config=config)
                    # Fraction of embedding entries saved by k (cols x rank) blocks plus a (rows x rank) matrix.
                    rate = 1 - ((matrix_shape[1] * k * rank + matrix_shape[0] * rank) / (matrix_shape[1] * matrix_shape[0]))
                    eval_results = {}
                    print('=========K -SVD rank: {}, k:{}, rate:{},clustering_method={} ========='.format(rank, k, rate, method))

                    engine.update(k, rank, data_args, clustering_method=method)

                    if check_low_rank_one_matrix:
                        eval_results = eval(eval_dataset, data_args, engine.trainer, tokenizer, model_args,
                                            build_compute_metrics_fn, training_args, eval_results)

                    eval_results = eval(eval_dataset, data_args, engine.factorized_trainer, tokenizer, model_args,
                                        build_compute_metrics_fn, training_args, eval_results)
                    scores, tag, ft_scores, ft_tag = _result_slots(method)
                    scores.append(eval_results[term_of_interest])
                    np.save(pretrained_model_name + tag + data_args.task_name + '_' + str(base_epochs), scores)

                    if fine_tune:
                        training_args.num_train_epochs = fine_tune_epochs
                        engine.factorized_trainer.train(zero_idxes=engine.zero_idxes, rank=(k - 1) * rank,
                                                        model_name=pretrained_model_name,
                                                        zero_idxes_in=None, rank_in=None)
                        eval_results = eval(eval_dataset, data_args, engine.factorized_trainer, tokenizer, model_args,
                                            build_compute_metrics_fn, training_args, eval_results)
                        ft_scores.append(eval_results[term_of_interest])
                        np.save(pretrained_model_name + ft_tag + data_args.task_name + '_' + str(
                            training_args.num_train_epochs), ft_scores)

                        nonzero_factorized = countNonZeroWeights(engine.factorized_trainer.model)
                        nonzero_reference = countNonZeroWeights(engine.trainer.model)
                        print(nonzero_factorized, nonzero_reference)
                        # Difference in non-zero weights between the reference model and
                        # the factorized one (previously subtracted the factorized count
                        # from itself and always printed 0).
                        print(nonzero_reference - nonzero_factorized)
                        if debug:
                            factorized_state = engine.factorized_trainer.model.state_dict()
                            x = factorized_state[pretrained_model_name + '.embeddings.nn.bias'].data.cpu().detach().numpy()
                            np.save("{}_{}_{}_bias".format(k, rank, fine_tune_epochs), x)
                            x = factorized_state[pretrained_model_name + '.embeddings.word_embeddings.weight'].data.cpu().detach().numpy()
                            np.save("{}_{}_{}_embedd".format(k, rank, fine_tune_epochs), x)
                            x = factorized_state[pretrained_model_name + '.embeddings.nn.weight'].data.cpu().detach().numpy()
                            np.save("{}_{}_{}_nn".format(k, rank, fine_tune_epochs), x)
                            torch.save(engine.factorized_trainer.model,
                                       "models/{}_{}_{}_{}_model".format(k, rank, fine_tune_epochs, eval_results[term_of_interest]))
                 

def get_embedding_layer_name(model_name):
    """Return the state-dict key of the word-embedding weight for ``model_name``.

    Names containing ``'xlnet'`` map to ``'transformer.word_embedding.weight'``;
    all others to ``'<model_name>.embeddings.word_embeddings.weight'``.
    """
    return ('transformer.word_embedding.weight' if 'xlnet' in model_name
            else model_name + '.embeddings.word_embeddings.weight')

def _mp_fn(index):
    """Per-process entry point for ``xla_spawn`` (TPUs); ``index`` is ignored."""
    return main()

class svd_compression:
    """Truncated-SVD compression of a model's word-embedding matrix.

    ``__init__`` wraps ``model`` in ``self.trainer`` and computes the SVD of its
    word-embedding weight once. ``update(rank, ...)`` builds a ``*_ranked``
    model whose embedding is a ``(rows, rank)`` matrix followed by the linear
    layer ``embeddings.nn``, and wraps it in ``self.factorized_trainer``.
    """

    def __init__(self, model, training_args, train_dataset, eval_dataset, compute_metrics, model_name):
        self.trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,
                               eval_dataset=eval_dataset, compute_metrics=compute_metrics,)
        self.model_name = model_name
        self.embedding_layer_name = get_embedding_layer_name(model_name)
        self.embedding_layer_weights = self.trainer.model.state_dict()[self.embedding_layer_name].data.cpu().detach().numpy()
        # Thin SVD: U is (rows, r), D is (r,), V is (r, cols) with r = min(rows, cols).
        # scipy's default full SVD would also build a (rows x rows) U, which for a
        # large vocabulary costs gigabytes, although only the first `rank`
        # columns are ever used.
        U, D, V = svd(self.embedding_layer_weights, full_matrices=False)
        self.U = U
        self.D = D
        self.V = V

    def fastdot(self, A, B, groups):
        """Compute ``A @ B`` in ``groups`` row-chunks of ``A``.

        Returns a float64 array of shape ``(A.shape[0], B.shape[1])``. Chunk
        bounds are spread so every row is covered even when ``A.shape[0]`` is
        not a multiple of ``groups`` (or is smaller than it). Previously the
        trailing rows of the ``np.empty`` result were left uninitialised.
        """
        n_rows = A.shape[0]
        groups = max(1, int(groups))
        res = np.empty((n_rows, B.shape[1]))
        for i in range(groups):
            start = (i * n_rows) // groups
            end = ((i + 1) * n_rows) // groups
            if end > start:
                res[start:end, :] = np.dot(A[start:end, :], B)
        return res

    def update(self, rank, config, compute_metrics, training_args, eval_dataset, train_dataset):
        """Build the rank-``rank`` factorized model and load the factors.

        Side effects:
          * ``self.model_factorized`` / ``self.factorized_trainer`` are (re)created,
            and every state-dict entry except the word embedding is copied from
            ``self.trainer.model``.
          * ``self.trainer.model``'s embedding is overwritten with ``u_new @ v_new``.
          * The factorized model gets ``u_new`` as its embedding, ``v_new.T`` as
            ``embeddings.nn.weight`` and a frozen zero ``embeddings.nn.bias``.

        Raises:
            ValueError: if ``model_name`` names neither DistilBERT nor RoBERTa.
        """
        if 'distilbert' in self.model_name:
            self.model_factorized = DistilBertForSequenceClassification_ranked(config=config, rank=rank,)
        elif 'roberta' in self.model_name:
            self.model_factorized = RobertaForSequenceClassification_ranked(config=config, rank=rank,)
        else:
            raise ValueError("SVD compression supports only distilbert/roberta models, got %r" % self.model_name)

        source_state = self.trainer.model.state_dict()
        target_state = self.model_factorized.state_dict()
        for key, value in source_state.items():
            if key == self.embedding_layer_name:
                continue
            target_state[key].data.copy_(value)

        self.factorized_trainer = Trainer(model=self.model_factorized, args=training_args, train_dataset=train_dataset,
                                          eval_dataset=eval_dataset, compute_metrics=compute_metrics,)

        # Split sqrt(singular values) evenly between the two factors:
        #   u_new = U[:, :rank] * sqrt(D[:rank])          -> (rows, rank)
        #   v_new = sqrt(D[:rank])[:, None] * V[:rank]    -> (rank, cols)
        sqrt_d = np.sqrt(self.D[:rank])
        u_new = torch.from_numpy(np.ascontiguousarray(self.U[:, :rank] * sqrt_d)).float()
        v_new = torch.from_numpy(np.ascontiguousarray(sqrt_d[:, None] * self.V[:rank, :])).float()

        low_rank_full_layer = u_new @ v_new
        self.trainer.model.state_dict()[self.embedding_layer_name].data.copy_(low_rank_full_layer)

        # keep_vars=True yields the live Parameters. The default state_dict()
        # returns detached tensors, so setting requires_grad on them (as the
        # old code did) had no effect and the bias was never frozen.
        factorized_params = self.factorized_trainer.model.state_dict(keep_vars=True)
        factorized_params[self.embedding_layer_name].data.copy_(u_new)
        factorized_params[self.model_name + '.embeddings.nn.weight'].data.copy_(v_new.T)
        nn_bias = factorized_params[self.model_name + '.embeddings.nn.bias']
        nn_bias.data.zero_()
        nn_bias.requires_grad = False
class k_svd_compression:
    """k-SVD compression of a model's word-embedding matrix.

    ``__init__`` wraps ``model`` in ``self.trainer`` and copies every
    non-embedding weight into a factorized architecture, chosen as follows:
    DistilBERT / RoBERTa ranked models with rank ``k*rank``, otherwise the
    ALBERT variant with ``k_embed=k, j_embed=rank``.
    ``update`` computes (or loads) the factorization and writes it into both
    models.
    """

    def __init__(self, model, training_args, train_dataset, eval_dataset, compute_metrics,
                 data_args, model_args, model_name, k, rank, config):
        self.trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,
                               eval_dataset=eval_dataset, compute_metrics=compute_metrics,)
        self.embedding_layer_name = get_embedding_layer_name(model_name)
        self.model_name = model_name

        self.embedding_layer_weights = self.trainer.model.state_dict()[self.embedding_layer_name].data.cpu().detach().numpy()
        if 'distilbert' in self.model_name:
            self.model_factorized = DistilBertForSequenceClassification_ranked(config=config, rank=k * rank)
        elif 'roberta' in self.model_name:
            self.model_factorized = RobertaForSequenceClassification_ranked(config=config, rank=k * rank)
        else:
            self.model_factorized = AlbertForSequenceClassification(config=config, k_embed=k, j_embed=rank,
                                                                    k_hid=None, j_hid=None)
        source_state = self.trainer.model.state_dict()
        target_state = self.model_factorized.state_dict()
        for key, value in source_state.items():
            if key == self.embedding_layer_name:
                continue
            target_state[key].data.copy_(value)
        self.factorized_trainer = Trainer(model=self.model_factorized, args=training_args, train_dataset=train_dataset,
                                          eval_dataset=eval_dataset, compute_metrics=compute_metrics, )

    def getzeros(self, U_ranked):
        """Return, for each row of ``U_ranked``, the list of column indices that are zero.

        Not used by ``update``, which relies on ``factor.getZeros``. The
        previous body referenced undefined ``k``/``rank`` and could not run.
        """
        return [np.flatnonzero(np.asarray(row) == 0).tolist() for row in U_ranked]

    @staticmethod
    def _find_saved_dir(stem):
        """Return the first directory in ``load_old_from`` holding ``<dir><stem>zero_idxes.npy``, else None."""
        candidates = load_old_from
        if not candidates:
            return None
        if isinstance(candidates, str):
            # A single directory given as a string must not be iterated
            # character by character.
            candidates = [candidates]
        for dir_check in candidates:
            if os.path.exists(dir_check + stem + "zero_idxes.npy"):
                return dir_check
        return None

    def update(self, k, rank, data_args, clustering_method='messi'):
        """Factorize the embedding (``k`` parts of rank ``rank``) and load it into both models.

        If ``<dir>/saved_GV/<model>_<task>_<k>_<rank>_zero_idxes.npy`` exists in a
        directory of the module-level ``load_old_from``, the saved ``U`` / ``G``
        factors and zero indices are loaded. Otherwise they are computed with the
        ``factor`` module: ``'messi_kmeans'``, ``'split'``, or MESSI for any other
        ``clustering_method``.

        Side effects: sets ``self.zero_idxes``; overwrites ``self.trainer.model``'s
        embedding with ``U @ V``; loads ``U`` / ``V.T`` and a frozen zero bias into
        the factorized model.
        """
        stem = "/saved_GV/{}_{}_{}_{}_".format(self.model_name, data_args.task_name, k, rank)
        load_dir = self._find_saved_dir(stem)

        if load_dir is not None:
            U_ranked = np.load(load_dir + stem + "U.npy")
            V_ranked = np.load(load_dir + stem + "G.npy")
            self.zero_idxes = np.load(load_dir + stem + "zero_idxes.npy")
        elif clustering_method == 'messi_kmeans':
            U_ranked, V_ranked, self.zero_idxes = factor.kmeans_factorization(
                self.embedding_layer_weights, j=rank, k=k, steps=EM_STEPS, NUM_INIT_FOR_EM=NUM_INIT_FOR_EM)
        elif clustering_method == 'split':
            U_ranked, V_ranked, self.zero_idxes = factor.chuncked_factorization(
                self.embedding_layer_weights, j=rank, k=k, randomly=True)
        else:  # MESSI
            l_norm = 2
            partition, listU, listV = factor.raw(self.embedding_layer_weights, j=rank, k=k, steps=EM_STEPS,
                                                 NUM_INIT_FOR_EM=NUM_INIT_FOR_EM, l_norm=l_norm)
            self.zero_idxes = factor.getZeros(partition, rank, k)
            U_ranked, V_ranked = factor.stitch(partition, listU, listV)

        # Applied for both loaded and freshly computed factors. Previously this
        # step sat inside the "compute" branch, so loaded factors were never
        # written into the models.
        U_ranked = torch.from_numpy(np.asarray(U_ranked)).float()
        V_ranked = torch.from_numpy(np.asarray(V_ranked)).float()

        low_rank_full_layer = U_ranked @ V_ranked
        self.trainer.model.state_dict()[self.embedding_layer_name].data.copy_(low_rank_full_layer)

        # keep_vars=True yields the live Parameters. Setting requires_grad on the
        # detached tensors of a plain state_dict() had no effect.
        factorized_params = self.factorized_trainer.model.state_dict(keep_vars=True)
        factorized_params[self.embedding_layer_name].data.copy_(U_ranked)
        factorized_params[self.model_name + '.embeddings.nn.weight'].data.copy_(V_ranked.T)
        nn_bias = factorized_params[self.model_name + '.embeddings.nn.bias']
        nn_bias.data.zero_()
        nn_bias.requires_grad = False
      
         
                        
def eval(eval_dataset, data_args, trainer, tokenizer, model_args, build_compute_metrics_fn, training_args, eval_results):
    """Evaluate ``trainer`` on the dev set(s) and merge the metrics into ``eval_results``.

    For MNLI the mismatched dev set ("mnli-mm") is evaluated too. Each dataset's
    metrics are logged and, on the world-master process, written to
    ``<output_dir>/eval_results_<task>.txt``. ``eval_results`` is updated in
    place and also returned.
    """
    logger.info("*** Evaluate ***")

    datasets_to_eval = [eval_dataset]
    if data_args.task_name == "mnli":
        # MNLI is scored on both the matched and the mismatched dev sets.
        mismatched_args = dataclasses.replace(data_args, task_name="mnli-mm")
        datasets_to_eval.append(
            GlueDataset(mismatched_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
        )

    for dataset in datasets_to_eval:
        task = dataset.args.task_name
        trainer.compute_metrics = build_compute_metrics_fn(task)
        metrics = trainer.evaluate(eval_dataset=dataset)

        results_path = os.path.join(training_args.output_dir, f"eval_results_{task}.txt")
        if trainer.is_world_master():
            with open(results_path, "w") as writer:
                logger.info("***** Eval results {} *****".format(task))
                for key, value in metrics.items():
                    logger.info("  %s = %s", key, value)
                    writer.write("%s = %s\n" % (key, value))

        eval_results.update(metrics)

    return eval_results


# Script entry point: run the experiment only when executed directly.
if __name__ == "__main__":
    main()
