import logging
import math

import numpy as np
import os
import torch
from tqdm.auto import tqdm
from transformers import (
    MODEL_MAPPING,
    AdamW,
    get_scheduler,
    Adafactor
)

logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
from tensorboardX import SummaryWriter
from utils import utils
import shutil
# from approaches.my_optimizer import MyAdamW
import torch.distributed as dist
from datasets import load_dataset, load_metric
import nltk
from networks import fisher_model, supsup_model, oneshot_model
from networks.buffer import Buffer as Buffer
from sklearn.metrics import f1_score
import evaluate

class Appr(object):

    def __init__(self, config, args):
        super().__init__()
        self.args = args
        self.config = config
        self.tanh = torch.nn.Tanh()
        self.sigmoid = torch.nn.Sigmoid()

        utils.lookfor_baseline_variable(self,args)


    def postprocess_text(self, preds, labels):

        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

    def get_warmup_steps(self, num_training_steps: int):
        """
        Get number of steps used for a linear warmup.
        """
        warmup_steps = self.args.num_warmup_steps if self.args.num_warmup_steps > 0 else math.ceil(num_training_steps * self.args.warmup_ratio)
        return warmup_steps

    def train(self, model, train_loader, train_dataset, dev_loader, test_loaders, train_pool_loader, accelerator):


        # Optimizer
        # Split weights in two groups, one with weight decay and the other not.

        no_decay = ["bias", "LayerNorm.weight"]
        special_lr = ['prompt', 'adapter', 'classifier']
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if
                           not any(nd in n for nd in no_decay) and p.requires_grad and not any(
                               nd in n for nd in special_lr)],
                "weight_decay": self.args.weight_decay,
                "lr": self.args.learning_rate
            },
            {
                "params": [p for n, p in model.named_parameters() if
                           any(nd in n for nd in no_decay) and p.requires_grad and not any(
                               nd in n for nd in special_lr)],
                "weight_decay": 0.0,
                "lr": self.args.learning_rate
            },
            {
                "params": [p for n, p in model.named_parameters() if p.requires_grad and 'prompt' in n],
                "weight_decay": 0.0,
                "lr": self.args.prompt_lr,  # must use a higher lr
            },
            {
                "params": [p for n, p in model.named_parameters() if
                           p.requires_grad and 'adapter' in n],
                "weight_decay": 0.0,
                "lr": self.args.adapter_lr,  # must use a higher lr
            },
            {
                "params": [p for n, p in model.named_parameters() if p.requires_grad and 'classifier' in n],
                "weight_decay": 0.0,
                "lr": self.args.classifier_lr,  # must use a higher lr
            }
        ]


        optimizer = AdamW(optimizer_grouped_parameters)

        # Scheduler and math around the number of training steps.
        num_update_steps_per_epoch = math.ceil(len(train_loader) / self.args.gradient_accumulation_steps)
        if self.args.max_train_steps is None:
            self.args.max_train_steps = self.args.num_train_epochs * num_update_steps_per_epoch
        else:
            self.args.num_train_epochs = math.ceil(self.args.max_train_steps / num_update_steps_per_epoch)

        if self.args.warmup_ratio:
            self.args.num_warmup_steps=self.get_warmup_steps(self.args.max_train_steps)
            print('self.args.num_warmup_steps: ',self.args.num_warmup_steps)

        lr_scheduler = get_scheduler(
            name=self.args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=self.args.num_warmup_steps,
            num_training_steps=self.args.max_train_steps,
        )


        # Prepare everything with our `accelerator`.
        model, optimizer, train_loader, train_pool_loader, dev_loader, lr_scheduler = accelerator.prepare(model, optimizer, train_loader,train_pool_loader,dev_loader,lr_scheduler)


        # before training ***********************************************************************************************

        #TODO: consider separate the baselines

        #need to adapt everything to MTL. There is no separate head, so no changes are neeeded here


        final_imp_path = self.args.output_dir + '../final_imp_' + str(self.args.seed)
        if os.path.exists(final_imp_path):
            eval_imp = np.loadtxt(final_imp_path)
        else:
            eval_imp = np.zeros((self.args.ntasks, 2), dtype=np.float64)


        if self.args.ft_task > 0:

            pre_subnetwork_imps = []
            for saved_output_dir_id,saved_output_dir in enumerate(self.args.saved_output_dir): # used the imp at that time
                pre_subnetwork_imp = torch.Tensor(np.load(os.path.join(saved_output_dir, "subnetwork_imp.npy"))).cuda()
                pre_subnetwork_imps.append(pre_subnetwork_imp[saved_output_dir_id])
            pre_subnetwork_imp = torch.stack(pre_subnetwork_imps)
            # pre_subnetwork_imp, _ = pre_subnetwork_imp.max(0)  # take a max, so that block all imp nurons for all previous tasks;
            if accelerator.is_main_process:
                print('saved_subnetwork_imp: ', pre_subnetwork_imp)
                np.savetxt(self.args.output_dir + "saved_subnetwork_imp_text",pre_subnetwork_imp.detach().cpu().numpy()[:self.args.ft_task], '%.4f', delimiter='\t')

            accelerator.wait_for_everyone()
            oneshot_model.compute_subnetowrk_imp(model, train_pool_loader, accelerator, self.args, path='pre')

            accelerator.wait_for_everyone()

            cur_subnetwork_imp = torch.Tensor(np.load(os.path.join(self.args.output_dir, "subnetwork_imp_pre.npy"))).cuda()
            cur_subnetwork_imp = cur_subnetwork_imp[:self.args.ft_task]
            if accelerator.is_main_process:
                print('cur_subnetwork_imp: ', cur_subnetwork_imp)


            max_cur_ids = (cur_subnetwork_imp == torch.max(cur_subnetwork_imp)).nonzero()

            if accelerator.is_main_process:
                print('max_cur_ids: ', max_cur_ids)

            # max_cur_id = cur_subnetwork_imp.argmax() # we do not use argmax because there could multiple maximal

            # if self.args.task_name  in self.args.asc_datasets and  'SemEval' not in self.args.task_name and \
            #         self.args.all_tasks[max_cur_id.item()] in self.args.asc_datasets and 'SemEval' not in self.args.all_tasks[max_cur_id.item()]:
            #     #TODO: loosing the standard
            #     valid_imp = True
            # else:

            if not self.args.disable_impt_comparison:
                for max_cur_id in max_cur_ids:
                    valid_imp = (cur_subnetwork_imp[max_cur_id]-pre_subnetwork_imp[max_cur_id]) > self.args.eps
                    if valid_imp:
                        break
            else:
                max_cur_id = max_cur_ids[0]
                valid_imp = True
                print('no compare')

            if accelerator.is_main_process:
                print('cur_subnetwork_imp[max_cur_id]: ', cur_subnetwork_imp[max_cur_id])
                print('pre_subnetwork_imp[max_cur_id]: ', pre_subnetwork_imp[max_cur_id])
                print('diff: ', (cur_subnetwork_imp[max_cur_id]-pre_subnetwork_imp[max_cur_id]))
                print('eps: ',  self.args.eps)

                print('max_cur_id: ', max_cur_id)

            if valid_imp: # yes, can share
                accelerator.unwrap_model(model).inferred_subnetwork = max_cur_id.long().cpu().item()
                eval_imp[self.args.ft_task][0] = max_cur_id.long().cpu().item()

            else:
                accelerator.unwrap_model(model).inferred_subnetwork = self.args.ft_task
                eval_imp[self.args.ft_task][0] = self.args.ft_task

            np.savetxt(final_imp_path, eval_imp, '%.4f', delimiter='\t')

        else:
            accelerator.unwrap_model(model).inferred_subnetwork = self.args.ft_task

        # before training ***********************************************************************************************

        metric = utils.load_my_metric(self.args)

        if not self.args.eval_checkpoint:


            # Train!
            total_batch_size = self.args.per_device_train_batch_size * accelerator.num_processes * self.args.gradient_accumulation_steps

            logger.info("***** Running training *****")
            logger.info(f"  Num examples = {len(train_dataset)}")
            logger.info(f"  Num Epochs = {self.args.num_train_epochs}, Pre-trained Model = {self.args.model_name_or_path}")
            logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size}, batch size per device Pool= {self.args.per_device_train_pool_batch_size}")
            logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
            logger.info(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps}, Lamb = {self.args.lamb}")
            logger.info(f"  Learning Rate = {self.args.learning_rate}, Prompt Learning Rate = {self.args.prompt_lr} Adapter Learning Rate = {self.args.adapter_lr}, Classifier Learning Rate = {self.args.classifier_lr}, Warmup Num = {self.args.num_warmup_steps}")
            logger.info(f"  Total optimization steps = {self.args.max_train_steps}, NTokens={self.args.n_tokens}")
            logger.info(f"  Seq ID = {self.args.idrandom}, Task id = {self.args.ft_task}, Task Name = {self.args.task_name}, Num task = {self.args.ntasks}")

            # Only show the progress bar once on each machine.
            progress_bar = tqdm(range(self.args.max_train_steps), disable=not accelerator.is_local_main_process)
            completed_steps = 0
            starting_epoch = 0


            best_model = utils.get_model(model)
            best_main = -np.inf
            patience = self.args.patient
            global_step = 0  # This will be used by CLMOE if we choose 'auto_encoder' as the route type.


            if accelerator.is_main_process:
                tensorboard_file = os.path.join(self.args.output_dir, str(self.args.task_name) + '_log')
                print('tensorboard_file: ', tensorboard_file)

                if os.path.isdir(tensorboard_file):
                    shutil.rmtree(tensorboard_file)
                writer = utils.setup_writer(tensorboard_file)


                #TODO: remove old output
                # delete previous model
                for saved_output_dir in self.args.saved_output_dir[:-3]:  # we need -2 so that we can load model
                    for item in os.listdir(saved_output_dir):
                        print('item: ', item)
                        if item.endswith(".bin") or item.endswith(".json"):
                            os.remove(saved_output_dir + item)

            if not self.args.eval_only:
                try:
                    for epoch in range(starting_epoch, self.args.num_train_epochs):
                        model.train()
                        if self.args.with_tracking:
                            total_loss = 0
                        for step, batch in enumerate(train_loader):

                            outputs = model(batch)

                            loss = outputs.loss
                            # We keep track of the loss at each epoch
                            if self.args.with_tracking:
                                total_loss += loss.detach().float()
                            loss = loss / self.args.gradient_accumulation_steps
                            accelerator.backward(loss)

                            if accelerator.is_main_process and epoch < 1 and step < 1:
                                for n,p in model.named_parameters():
                                    if p.grad is not None:
                                        print('n,p： ',n,p.size())
                                    # else:
                                    #     print('n--,p： ',n,p.size())


                            if step % self.args.gradient_accumulation_steps == 0 or step == len(train_loader) - 1:
                                optimizer.step()
                                global_step += 1
                                lr_scheduler.step()
                                optimizer.zero_grad()
                                progress_bar.update(1)
                                completed_steps += 1
                                progress_bar.set_description(
                                    'Train Iter (Epoch=%3d,loss=%5.3f)' % ((epoch, loss.item())))  # show the loss, mean while

                                if accelerator.is_main_process:
                                    utils.log_loss(writer, scalar_value=loss.item(), global_step=global_step)
                                    if outputs.sum_loss is not None: utils.log_loss(writer, loss_name=' summerization loss', scalar_value=outputs.sum_loss.item(), global_step=global_step)
                                    if outputs.contrast_loss is not None: utils.log_loss(writer, loss_name=' contrast loss', scalar_value=outputs.contrast_loss.item(), global_step=global_step)

                        #
                        #     break
                        # break
                        if completed_steps >= self.args.max_train_steps:
                            break

                        # test_loader = test_loaders[self.args.ft_task]
                        # test_loader = accelerator.prepare(test_loader)
                        # self.eval(model, test_loader, metric,accelerator)

                        if (self.args.task_name in self.args.generation or self.args.task_name in self.args.ner_datasets) and epoch % 2 == 0: # no need to test for everyone
                            supsup_model.model_weight_copy(model,accelerator.unwrap_model(model).inferred_subnetwork,self.args.ft_task)

                            results = self.eval(model, dev_loader, metric, accelerator,eval_t=self.args.ft_task)
                            dev_main = utils.lookfor_main_metric(results,self.args)

                            if epoch < self.args.num_train_epochs and best_main < dev_main:  # data is too small, we need to at least run some epoch
                                best_main = dev_main
                                best_model = utils.get_model(model)
                                if accelerator.is_main_process: print(
                                    "*Epoch {}, dev rouge1 = {:.4f}".format(epoch, dev_main))
                                patience = self.args.patient  # reset
                            else:
                                if accelerator.is_main_process: print(
                                    "Epoch {}, dev rouge1 = {:.4f}".format(epoch, dev_main))
                                patience -= 1
                                if patience <= 0: break

                    if (self.args.task_name in self.args.generation or self.args.task_name in self.args.ner_datasets):
                        utils.set_model_(model, best_model)


                except KeyboardInterrupt:  # even if contro-C, I still want to save model
                    return

                # after training ***********************************************************************************************

        supsup_model.model_weight_copy(model,accelerator.unwrap_model(model).inferred_subnetwork,self.args.ft_task)
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:  # onlyh discriminator is saved. I don't need anything about geenrator
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.model.save_pretrained(self.args.output_dir)
            self.args.tokenizer.save_pretrained(self.args.output_dir)
            if 'adapter' in self.args.baseline:
                unwrapped_model.model.save_adapter(self.args.output_dir, 'adapter')
            # if 'generative' in self.args.baseline:
            #     unwrapped_model.teacher.save_pretrained(self.args.output_dir_gen)

        oneshot_model.compute_subnetowrk_imp(model, train_pool_loader,accelerator,self.args)


        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        for eval_t in range(self.args.ft_task + 1):
            self.args.ori_task_name = self.args.task_name

            if 'mix' in self.args.baseline:
                self.args.task_name = self.args.all_tasks[eval_t] #self.args.task_name has chaned
                self.args = utils.update_hyparameter_for_mix_pool(self.args) # must be pool
                metric = utils.load_my_metric(self.args)
                unwrapped_model.model.args = self.args # updated
                print('self.args.task_name_eval: ',self.args.task_name)

            else:
                self.args.task_name = self.args.all_tasks[eval_t] #self.args.task_name has chaned
                metric = utils.load_my_metric(self.args)
                unwrapped_model.model.args = self.args # updated
                print('self.args.task_name_eval: ',self.args.task_name)


            if ('one' in self.args.baseline) and eval_t != self.args.ft_task:
                continue  # for one, I only care about forward results

            if self.args.only_eval_current_task and eval_t != self.args.ft_task:
                continue

            pred_file = os.path.join(self.args.output_dir.replace(self.args.ori_task_name,self.args.all_tasks[eval_t]), self.args.all_tasks[eval_t]+str(self.args.ft_task) + '_pred')
            target_file = os.path.join(self.args.output_dir, self.args.all_tasks[eval_t] + '_target')

            os.makedirs(self.args.output_dir.replace(self.args.ori_task_name,self.args.all_tasks[eval_t]), exist_ok=True)
            if os.path.exists(pred_file) and accelerator.is_main_process:
                os.remove(pred_file)
            if os.path.exists(target_file) and accelerator.is_main_process:
                os.remove(target_file)

            accelerator.wait_for_everyone()


            test_loader = test_loaders[eval_t]
            test_loader = accelerator.prepare(test_loader)
            results = self.eval(model, test_loader, metric, accelerator, eval_t, pred_file, target_file)
            # micro_f1, macro_f1, accuracy, total_loss / total_num

            #TODO: separate bleu and F1 for different datasets

            if accelerator.is_main_process:
                utils.write_result(results,eval_t,self.args)

        return

        # after training ***********************************************************************************************



    def eval(self, model, dataloader, metric, accelerator, eval_t=None, pred_file=None, target_file=None):
        model.eval()
        if self.args.val_max_target_length is None:
            self.args.val_max_target_length = self.args.max_target_length

        gen_kwargs = {
            "max_length": self.args.val_max_target_length if self.args is not None else self.config.max_length,
            "num_beams": self.args.num_beams,
            "min_length": self.args.val_min_target_length,
            "no_repeat_ngram_size": self.args.no_repeat_ngram_size,
        }

        samples_seen = 0

        progress_bar = tqdm(range(len(dataloader)), disable=not accelerator.is_local_main_process)

        label_list = []
        prediction_list = []
        total_loss = 0
        total_num = 0
        ppl_sigmoid = 0
        for step, batch in enumerate(dataloader):
            with torch.no_grad():

                if  self.args.task_name in self.args.asc_datasets or self.args.task_name in self.args.five_large_datasets:
                    outputs = model(batch)
                    real_b = batch["input_ids"].size(0)
                    loss = outputs.loss
                    outp = outputs.logits

                    if 'mtl' in self.args.baseline or 'comb' in self.args.baseline:
                        pred = []
                        for out in outp:
                            pred.append(out.max(1)[1])  # out has different size
                        pred = torch.stack(pred).squeeze()
                    else:
                        pred = outp.max(1)[1]

                    predictions = accelerator.gather(pred)
                    references = accelerator.gather(batch['cls_labels'])

                    if accelerator.is_main_process and 'ner' not in self.args.sequence_file:
                        print('predictions: ',predictions.cpu().numpy().tolist())
                        print('references: ',references.cpu().numpy().tolist())

                    total_loss += loss.data.cpu().numpy().item() * real_b
                    total_num += real_b
                    label_list += references.cpu().numpy().tolist()
                    prediction_list += predictions.cpu().numpy().tolist()

                    progress_bar.update(1)
                    # break


                    if 'generative' in self.args.baseline:
                        ppl_sigmoid += self.sigmoid(outputs.ppl)


                elif self.args.task_name in self.args.ner_datasets:
                    outputs = model(batch)
                    predictions = outputs.logits.argmax(dim=-1)
                    labels = batch["cls_labels"]
                    if not self.args.pad_to_max_length:  # necessary to pad predictions and labels for being gathered
                        predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)
                        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
                    predictions_gathered, labels_gathered = accelerator.gather((predictions, labels))
                    # If we are in a multiprocess environment, the last batch has duplicates
                    if accelerator.num_processes > 1:
                        if step == len(dataloader) - 1:
                            predictions_gathered = predictions_gathered[: len(dataloader.dataset) - samples_seen]
                            labels_gathered = labels_gathered[: len(dataloader.dataset) - samples_seen]
                        else:
                            samples_seen += labels_gathered.shape[0]

                    # unwrapped_model = accelerator.unwrap_model(model)
                    preds, refs = self.get_labels(predictions_gathered, labels_gathered,eval_t)

                    metric.add_batch(
                        predictions=preds,
                        references=refs,
                    )  # predictions and preferences are expected to be a nested list of labels, not label_ids


                else: # generation

                    model(batch)

                    generated_tokens = accelerator.unwrap_model(model).model.generate(
                        batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        **gen_kwargs,
                    )

                    generated_tokens = accelerator.pad_across_processes(
                        generated_tokens, dim=1, pad_index=self.args.tokenizer.pad_token_id
                    )
                    labels = batch["labels"]

                    if not self.args.pad_to_max_length:
                        # If we did not pad to max length, we need to pad the labels too
                        labels = accelerator.pad_across_processes(batch["labels"], dim=1,
                                                                  pad_index=self.args.tokenizer.pad_token_id)

                    generated_tokens, labels = accelerator.gather((generated_tokens, labels))  # gather is a must
                    generated_tokens = generated_tokens.cpu().numpy()
                    labels = labels.cpu().numpy()

                    if self.args.ignore_pad_token_for_loss:
                        # Replace -100 in the labels as we can't decode them.
                        labels = np.where(labels != -100, labels, self.args.tokenizer.pad_token_id)
                    if isinstance(generated_tokens, tuple):
                        generated_tokens = generated_tokens[0]
                    decoded_preds = self.args.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                    decoded_labels = self.args.tokenizer.batch_decode(labels, skip_special_tokens=True)



                    decoded_preds, decoded_labels = self.postprocess_text(decoded_preds, decoded_labels)


                    # If we are in a multiprocess environment, the last batch has duplicates
                    if accelerator.num_processes > 1:
                        if step == len(dataloader) - 1:
                            decoded_preds = decoded_preds[: len(dataloader.dataset) - samples_seen]
                            decoded_labels = decoded_labels[: len(dataloader.dataset) - samples_seen]
                        else:
                            samples_seen += len(decoded_labels)



                    if accelerator.is_main_process and pred_file is not None and target_file is not None:
                        with open(pred_file, 'a') as f_pred_file, open(target_file, 'a') as f_target_file:
                            for decoded_pred in decoded_preds:
                                f_pred_file.writelines(decoded_pred.replace('\n',' ') + '\n')

                            for decoded_label in decoded_labels:
                                f_target_file.writelines(decoded_label.replace('\n',' ') + '\n')

                    if self.args.task_name in self.args.dialogue_datasets:
                        decoded_labels = [[label] for label in decoded_labels]


                    metric.add_batch(
                        predictions=decoded_preds,
                        references=decoded_labels,
                    )

                    progress_bar.update(1)
                    progress_bar.set_description('ROUGE Computation')  # show the loss, mean while

                # break

        if  self.args.task_name in self.args.asc_datasets or self.args.task_name in self.args.five_large_datasets:
            micro_f1 = f1_score(label_list, prediction_list, average='micro')
            macro_f1 = f1_score(label_list, prediction_list, average='macro')
            accuracy = sum([float(label_list[i] == prediction_list[i]) for i in range(len(label_list))]) * 1.0 / len(
                prediction_list)

            if accelerator.is_local_main_process:
                print('macro_f1: ', macro_f1)
                if 'generative' in self.args.baseline and eval_t is not None:
                    np.savetxt(os.path.join(self.args.output_dir + '/../', 'ppl_sigmoid' + str(eval_t)),
                               ppl_sigmoid.cpu().numpy(), '%.4f', delimiter='\t')

            results = {'micro_f1':micro_f1, 'macro_f1':macro_f1, 'accuracy':accuracy,'loss':total_loss / total_num}


            return results

        elif self.args.task_name in self.args.dialogue_datasets:
            result = metric.compute()
            logger.info(result)
            return result

        elif self.args.task_name in self.args.ner_datasets:
            eval_metric = self.compute_metrics(metric)
            print('eval_metric: ',eval_metric)
            return eval_metric
            #F1 is as important as macro F1, so that we can compare macro f1 directly
            # #TODO: different from mixed and individual


        else:

            result = metric.compute(use_stemmer=True)
            # Extract a few results from ROUGE
            result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
            result = {k: round(v, 4) for k, v in result.items()}
            logger.info(result)

            return result


    def get_labels(self,predictions, references,eval_t): #TODO: need to change to "task", if we do MTL
        # Transform predictions and references tensos to numpy arrays
        y_pred = predictions.detach().cpu().clone().numpy()
        y_true = references.detach().cpu().clone().numpy()

        # Remove ignored index (special tokens)
        true_predictions = [
            [self.label_list_dict[self.args.task_name][p] for (p, l) in zip(pred, gold_label) if l != -100]
            for pred, gold_label in zip(y_pred, y_true)
        ]
        true_labels = [
            [self.label_list_dict[self.args.task_name][l] for (p, l) in zip(pred, gold_label) if l != -100]
            for pred, gold_label in zip(y_pred, y_true)
        ]
        return true_predictions, true_labels

    def compute_metrics(self,metric):
        results = metric.compute()
        if self.args.return_entity_level_metrics:
            # Unpack nested dictionaries
            final_results = {}
            for key, value in results.items():
                if isinstance(value, dict):
                    for n, v in value.items():
                        final_results[f"{key}_{n}"] = v
                else:
                    final_results[key] = value
            return final_results
        else:
            return {
                "precision": results["overall_precision"],
                "recall": results["overall_recall"],
                "f1": results["overall_f1"],
                "accuracy": results["overall_accuracy"],
            }
