from networks.supsup_model import MultitaskMaskLinear
from networks import supsup_model
import torch
from networks import simcse
from networks.my_baselines import fisher_model, hat_model
import math
import numpy as np





# Baseline-name substrings whose training loss receives the mask regularizer
# computed by ``hat_model.hat_loss_compute`` (unless ``args.is_cat`` is set).
_HAT_STYLE_BASELINES = (
    'adapter_hat',
    'adapter_cat',
    'adapter_bcl',
    'adapter_ctr',
    'transformer_hat',
    'adapter_classic',
)


def _configure_supsup(my_model, task, specific_task):
    """Select which SupSup mask(s) the next forward pass of ``my_model.model`` uses.

    ``'mtl'`` variants route the specific mask by ``task`` and always use shared
    mask 0; ``'ncl'`` variants always use specific mask 0; every other variant
    uses ``specific_task`` as the specific mask selector.
    """
    baseline = my_model.args.baseline
    if 'mtl' in baseline:
        supsup_model.set_model_sim(my_model.model, 'both')
        supsup_model.set_model_specific_task(my_model, task)
        # Shared mask is always index 0: it holds knowledge shared across all tasks.
        supsup_model.set_model_share_task(my_model, 0)
    elif 'ncl' in baseline:
        supsup_model.set_model_sim(my_model.model, 'specific')
        # A single specific mask is reused regardless of the task.
        supsup_model.set_model_specific_task(my_model, 0)
    else:
        supsup_model.set_model_sim(my_model.model, 'specific')
        supsup_model.set_model_specific_task(my_model, specific_task)


def run_forward(input_ids,attention_mask,labels_gen,input_ids_gen,attention_mask_gen,task,labels,my_model,self_fisher,masks=None, mask_pre=None,):
    """Configure the baseline-specific sub-network and run one forward pass.

    Args:
        input_ids, attention_mask, labels: inputs forwarded to the model.
        labels_gen, input_ids_gen, attention_mask_gen: inputs for the second
            forward pass, used only by the ``'pool'`` + ``'reconstruct'`` baseline.
        task: task-id tensor (presumably one id per example — confirm with
            caller). At evaluation it is repeated ``args.num_beams`` times per
            row, mirroring beam-search expansion.
        my_model: wrapper exposing ``.model``, ``.teacher``, ``.args``,
            ``.training`` and (for ``'pool'``) ``.inferred_subnetwork``.
        self_fisher: Fisher information passed to the EWC penalty; the penalty
            is skipped when it is ``None``.
        masks, mask_pre: current/previous masks for the HAT-style regularizer.

    Returns:
        ``(loss, logits)`` in training mode. In evaluation mode only the
        sub-network routing is configured and ``(None, None)`` is returned.

    Raises:
        ValueError: in training, if a ``'pool'`` baseline names neither a
            cross-entropy (``'entropy'``/``'ce'``) nor a ``'reconstruct'`` loss.
    """
    baseline = my_model.args.baseline

    if not my_model.training:
        # NOTE(review): the original author noted this branch "must be" taken for
        # 'l2p' and left an open TODO about why 'pool' is not in training mode here.
        expanded_return_idx = (
            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, my_model.args.num_beams).view(-1).to(
                input_ids.device)
        )  # same expansion as beam search
        task = task.index_select(0, expanded_return_idx)

        if 'supsup' in baseline:
            # Only the 'ggg' variant routes by the true task id at evaluation;
            # the others treat the id as unknown ('None').
            _configure_supsup(my_model, task, task if 'ggg' in baseline else 'None')

        if 'pool' in baseline:
            supsup_model.set_model_sim(my_model.model, 'specific')
            # Careful once a separate evaluation set is used: see "finetune_baseline.py".
            inferred_subnetwork = task
            supsup_model.set_model_specific_task(my_model, inferred_subnetwork)
            print('inferred_subnetwork_eval: ', inferred_subnetwork)

        return None, None

    # ---- training ----------------------------------------------------------
    if 'supsup' in baseline:
        _configure_supsup(my_model, task, task)

    if 'pool' in baseline:
        use_entropy = 'entropy' in baseline or 'ce' in baseline
        use_reconstruct = 'reconstruct' in baseline
        if not (use_entropy or use_reconstruct):
            # Previously this fell through and failed later with an UnboundLocalError.
            raise ValueError(
                "pool baseline %r must specify a loss ('entropy'/'ce' or 'reconstruct')" % (baseline,))

        # TODO: many inferred sub-networks are wrong; cause not yet understood.
        inferred_subnetwork = my_model.inferred_subnetwork
        print('inferred_subnetwork: ', inferred_subnetwork)

        supsup_model.set_model_sim(my_model.model, 'specific')
        supsup_model.set_model_specific_task(my_model, inferred_subnetwork)
        outputs = my_model.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask,
                                 output_hidden_states=True)
        logits = outputs.logits

        if use_entropy:
            loss = outputs.loss  # TODO: consider extra regularization for dissimilar tasks
        else:
            # Second pass through the pool adapter of the inferred task on the generation inputs.
            supsup_model.set_model_sim(my_model.model, 'pool')
            supsup_model.set_model_pool_task(my_model, inferred_subnetwork)
            outputs_pool = my_model.model(input_ids=input_ids_gen, labels=labels_gen,
                                          attention_mask=attention_mask_gen, output_hidden_states=True)
            loss = outputs.loss + outputs_pool.loss
    else:
        # Reference runs go through the teacher network instead of the student.
        network = my_model.teacher if my_model.args.is_reference else my_model.model
        outputs = network(input_ids=input_ids, labels=labels, attention_mask=attention_mask,
                          output_hidden_states=True)
        loss = outputs.loss
        logits = outputs.logits

    # Out-of-place additions so the model's own ``outputs.loss`` tensor is not mutated.
    if 'ewc' in baseline and my_model.training and self_fisher is not None:
        loss = loss + fisher_model.ewc_loss_compute(my_model, self_fisher)
    elif (any(name in baseline for name in _HAT_STYLE_BASELINES)
            and my_model.training and not my_model.args.is_cat):
        loss = loss + hat_model.hat_loss_compute(masks, mask_pre, my_model.args)

    return loss, logits
