from networks.supsup_model import MultitaskMaskLinear
from networks import supsup_model
import torch
from networks import simcse
import torch.autograd as autograd
from networks.my_baselines import fisher_model, hat_model
import math
import numpy as np


def run_forward(input_ids,attention_mask,labels_gen,input_ids_gen,attention_mask_gen,task,cls_labels,my_model,self_fisher,masks=None, mask_pre=None):
    """Run one forward pass and return the (possibly regularised) loss and logits.

    The behaviour is driven by substring checks on ``my_model.args.baseline``:

    * ``'supsup'``: the supsup mask mode / task ids are configured first
      (see ``_configure_supsup_masks``).
    * ``'pool'``: a subnetwork is chosen (``my_model.inferred_subnetwork``
      while training, ``task`` otherwise), ``my_model.model`` is run, and the
      loss is either the model loss (``'entropy'`` / ``'ce'``) or the model
      loss plus the loss of ``my_model.teacher`` on the ``*_gen`` inputs
      (``'reconstruct'``).
    * otherwise: ``my_model.teacher`` is run when ``args.is_reference`` is
      truthy, else ``my_model.model``.

    While training, an EWC penalty (``'ewc'`` baselines, only when
    ``self_fisher`` is not ``None``) or a HAT penalty (adapter/transformer HAT
    style baselines, only when ``args.is_cat`` is falsy) is added to the loss.

    Args:
        input_ids: Forwarded as ``input_ids=`` to the network.
        attention_mask: Forwarded as ``attention_mask=`` to the network.
        labels_gen: Forwarded as ``labels=`` to the teacher in the
            pool + reconstruct path only.
        input_ids_gen: Forwarded as ``input_ids=`` to the teacher in the
            pool + reconstruct path only.
        attention_mask_gen: Forwarded as ``attention_mask=`` to the teacher in
            the pool + reconstruct path only.
        task: Forwarded as ``task=`` to the network and used to select supsup
            masks. ``task.repeat(2)`` is called for supsup ``'forward'``
            baselines, so it is presumably a tensor there (confirm with caller).
        cls_labels: Forwarded as ``labels=`` to the network.
        my_model: Wrapper exposing ``args``, ``model``, ``teacher``,
            ``training`` and (for pool baselines) ``inferred_subnetwork``.
        self_fisher: Passed to ``fisher_model.ewc_loss_compute``; ``None``
            disables the EWC penalty.
        masks: Passed to ``hat_model.hat_loss_compute``.
        mask_pre: Passed to ``hat_model.hat_loss_compute``.

    Returns:
        Tuple ``(loss, logits)``.

    Raises:
        ValueError: If a ``'pool'`` baseline names none of ``'entropy'``,
            ``'ce'`` or ``'reconstruct'``, so no loss is defined for it.
    """
    args = my_model.args
    baseline = args.baseline

    if 'supsup' in baseline:
        _configure_supsup_masks(my_model, task)

    if 'pool' in baseline:
        if my_model.training:
            inferred_subnetwork = my_model.inferred_subnetwork
            print('inferred_subnetwork: ', inferred_subnetwork)
        else:
            # TODO: the ground-truth task id is used at eval time because
            # there is no evaluation set to infer the subnetwork from.
            inferred_subnetwork = task
            print('inferred_subnetwork_eval: ', inferred_subnetwork)

        # NOTE: tasks may have different numbers of classes; be careful here.
        supsup_model.set_model_sim(my_model.model, 'specific')
        supsup_model.set_model_specific_task(my_model, inferred_subnetwork)
        outputs = my_model.model(
            input_ids=input_ids,
            labels=cls_labels,
            attention_mask=attention_mask,
            output_hidden_states=True,
            task=task,
        )
        logits = outputs.logits

        if 'entropy' in baseline or 'ce' in baseline:
            # TODO: consider additional regularization for dissimilar tasks
            loss = outputs.loss
        elif 'reconstruct' in baseline:
            supsup_model.set_model_sim(my_model.model, 'pool')
            # use the pool adapter of the inferred subnetwork
            supsup_model.set_model_pool_task(my_model, inferred_subnetwork)
            outputs_pool = my_model.teacher(
                input_ids=input_ids_gen,
                labels=labels_gen,
                attention_mask=attention_mask_gen,
                output_hidden_states=True,
            )
            # TODO: consider additional regularization for dissimilar tasks
            loss = outputs.loss + outputs_pool.loss
        else:
            # Previously this fell through with ``loss`` unbound and failed
            # later with an opaque UnboundLocalError.
            raise ValueError(
                "pool baseline %r defines no loss: expected 'entropy', 'ce' "
                "or 'reconstruct' in args.baseline" % (baseline,)
            )
    else:
        # Only the selected network attribute is looked up.
        network = my_model.teacher if args.is_reference else my_model.model
        outputs = network(
            input_ids=input_ids,
            labels=cls_labels,
            attention_mask=attention_mask,
            output_hidden_states=True,
            task=task,
        )
        loss = outputs.loss
        logits = outputs.logits

    hat_style_keys = (
        'adapter_hat',
        'adapter_cat',
        'adapter_bcl',
        'adapter_ctr',
        'transformer_hat',
        'adapter_classic',
    )

    # Penalties are added out-of-place so ``outputs.loss`` is not mutated.
    if 'ewc' in baseline and my_model.training and self_fisher is not None:
        # EWC penalty: training only, and only once a Fisher estimate exists.
        loss = loss + fisher_model.ewc_loss_compute(my_model, self_fisher)
    elif (any(key in baseline for key in hat_style_keys)
            and my_model.training and not args.is_cat):
        # HAT penalty: not needed when testing.
        loss = loss + hat_model.hat_loss_compute(masks, mask_pre, args)

    return loss, logits


def _configure_supsup_masks(my_model, task):
    """Set the supsup mask mode and task ids according to ``args.baseline``."""
    baseline = my_model.args.baseline

    if 'mtl' in baseline:
        supsup_model.set_model_sim(my_model.model, 'both')
        supsup_model.set_model_specific_task(my_model, task)
        # always share slot 0: shared knowledge across all tasks
        supsup_model.set_model_share_task(my_model, 0)
    elif 'ncl' in baseline:
        supsup_model.set_model_sim(my_model.model, 'specific')
        # always use the same specific slot
        supsup_model.set_model_specific_task(my_model, 0)
    else:
        supsup_model.set_model_sim(my_model.model, 'specific')
        if 'forward' in baseline:
            # 'forward' baselines pass the task ids duplicated via repeat(2)
            supsup_model.set_model_specific_task(my_model, task.repeat(2))
        else:
            supsup_model.set_model_specific_task(my_model, task)

