from networks.supsup_model import MultitaskMaskLinear
from networks import supsup_model, l2p_model
import torch
from networks import simcse
import torch.autograd as autograd
import torch.distributed as dist
import numpy as np
from tqdm import tqdm
import os
from utils import utils

# When task information is not provided, it can be inferred.
# Here we use the one-shot task inference algorithm described in the paper.

def compute_subnetowrk_imp(my_model, dataloader,accelerator,args,path=''):
    """Score how strongly each task's subnetwork responds to the data in ``dataloader``.

    This is the gradient-based one-shot task inference step. The model is put
    into one-shot mode with one mixing coefficient ("alpha") per task,
    initialised uniformly to ``1 / args.ntasks``. For every batch, a scalar
    objective selected by ``args.baseline`` is computed, and its gradient with
    respect to the alphas is accumulated. The accumulated gradient is divided
    by the number of attended tokens, averaged across processes with
    ``utils.gather_mean``, and normalised with ``imp_norm``.

    Nothing is returned (a return value was removed because it could be
    inconsistent across nodes). Instead, the main process writes the scores
    to disk:

    * ``<output_dir>subnetwork_imp[_<path>].npy``: all ``ntasks`` scores;
    * ``<output_dir>[<path>_]subnetwork_imp_text``: the scores of tasks
      ``0..args.ft_task`` as tab-separated text;
    * ``<output_dir>../progressive_subnetwork_imp``: an ``ntasks x ntasks``
      text matrix whose row ``args.ft_task`` is updated in place.

    Args:
        my_model: model wrapped by ``accelerator``. Its unwrapped form must
            expose ``.model`` (and ``.teacher`` for the classification
            ``reconstruct`` baseline).
        dataloader: iterable of batch dicts (see ``_subnetwork_loss`` for the
            keys that are read).
        accelerator: object providing ``unwrap_model``, ``wait_for_everyone``
            and ``is_main_process``. Presumably a HuggingFace ``Accelerator``.
        args: namespace with ``ntasks``, ``baseline``, ``task_name``,
            ``classification``, ``generation``, ``output_dir`` and ``ft_task``.
        path: optional tag inserted into the output file names.
    """
    # is_return is removed because it will cause inconsistency for different nodes
    unwrap_model = accelerator.unwrap_model(my_model)
    num_tasks = args.ntasks

    supsup_model.set_model_oneshot(unwrap_model, True)

    # Uniform alphas; the gradient of the objective w.r.t. them is the
    # importance signal. TODO: all layers share this same alpha tensor.
    alphas = (torch.ones(num_tasks, 1, 1) / num_tasks).cuda()
    alphas.requires_grad_(True)
    supsup_model.set_alphas(unwrap_model, alphas)

    subnetwork_imp = torch.zeros(num_tasks).cuda()
    tot_tokens = 0.0

    try:
        for inputs in tqdm(dataloader, desc="Iteration subnetwork imp"):
            batch = _subnetwork_loss(unwrap_model, inputs, args)
            if batch is None:  # this baseline/task combination contributes nothing
                continue
            loss, token_mask = batch
            g, = autograd.grad(loss, alphas)
            subnetwork_imp += g.squeeze().detach()
            tot_tokens += token_mask.float().detach().sum().data
    finally:
        # Leave one-shot mode even if a batch fails, so the model is not left
        # with gradient-tracking mixing coefficients.
        supsup_model.set_model_oneshot(unwrap_model, False)
        alphas.requires_grad_(False)

    # NOTE(review): if no batch contributed, tot_tokens is 0 and this yields NaN.
    subnetwork_imp /= tot_tokens

    accelerator.wait_for_everyone()
    subnetwork_imp = utils.gather_mean(subnetwork_imp)  # mean the gradient across processes
    subnetwork_imp = imp_norm(subnetwork_imp)

    if accelerator.is_main_process:
        _save_subnetwork_imp(subnetwork_imp.detach().cpu().numpy(), args, path)


def _subnetwork_loss(unwrap_model, inputs, args):
    """Return ``(objective, token_mask)`` for one batch, or ``None`` if nothing applies.

    ``token_mask`` is the attention mask whose sum is added to the token count.
    Baselines are matched by substring against ``args.baseline``, checked in
    the order ``entropy``, ``ce``, ``reconstruct``.
    """
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    if 'entropy' in args.baseline:
        # Uses the model's own head.
        outputs = unwrap_model.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        logits = outputs.logits
        # NOTE(review): softmax over dim=1 assumes logits are (batch, classes);
        # for token-level logits dim=1 would be the sequence axis - confirm.
        entropy = -(logits.softmax(dim=1) * logits.log_softmax(dim=1)).sum(1).mean()
        return entropy, attention_mask

    if 'ce' in args.baseline:
        if args.task_name in args.classification:
            labels = inputs['cls_labels']
        else:
            labels = inputs['labels']
        outputs = unwrap_model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            labels=labels,
            task=inputs['task'],
        )
        return outputs.loss, attention_mask

    if 'reconstruct' in args.baseline:
        attention_mask_gen = inputs["attention_mask_gen"]
        gen_kwargs = dict(
            input_ids=inputs["input_ids_gen"],
            labels=inputs["labels_gen"],
            attention_mask=attention_mask_gen,
            output_hidden_states=True,
        )
        if args.task_name in args.classification:
            outputs = unwrap_model.teacher(**gen_kwargs)
        elif args.task_name in args.generation:
            outputs = unwrap_model.model(**gen_kwargs)
        else:
            return None
        return outputs.loss, attention_mask_gen

    return None


def _save_subnetwork_imp(imp, args, path):
    """Write the normalised importance vector ``imp`` (a numpy array) to the output files."""
    # Paths are built by plain string concatenation, so args.output_dir is
    # assumed to end with a path separator.
    if path == '':
        subnetwork_imp_path = args.output_dir + 'subnetwork_imp.npy'
        subnetwork_imp_text_path = args.output_dir + "subnetwork_imp_text"
    else:
        subnetwork_imp_path = args.output_dir + 'subnetwork_imp_' + path + '.npy'
        subnetwork_imp_text_path = args.output_dir + path + "_subnetwork_imp_text"

    n_seen = args.ft_task + 1  # number of entries 0..ft_task that are recorded

    progressive_subnetwork_imp_path = args.output_dir + "../progressive_subnetwork_imp"
    if os.path.exists(progressive_subnetwork_imp_path):
        # ndmin=2 keeps the matrix 2-D even when ntasks == 1. Otherwise a 1x1
        # file would load as a 0-d array and the row indexing below would fail.
        progressive_subnetwork_imp = np.loadtxt(progressive_subnetwork_imp_path, ndmin=2)
    else:
        progressive_subnetwork_imp = np.zeros((args.ntasks, args.ntasks), dtype=np.float32)

    progressive_subnetwork_imp[args.ft_task, :n_seen] = imp[:n_seen]
    np.savetxt(progressive_subnetwork_imp_path, progressive_subnetwork_imp, '%.4f', delimiter='\t')

    np.save(subnetwork_imp_path, imp)

    print('save subnetwork_imp: ', imp)
    print('save path: ', subnetwork_imp_text_path)
    np.savetxt(subnetwork_imp_text_path, imp[:n_seen], '%.4f', delimiter='\t')

def imp_norm(imp):
    """Map raw importance scores to ``|tanh(z)|``, where ``z`` is each entry's z-score.

    The mean and the (unbiased) standard deviation are taken over all elements
    of ``imp`` together, not per row or per layer.

    Args:
        imp: tensor of raw importance scores.

    Returns:
        Tensor of the same shape with values in ``[0, 1]``. When the scores
        have no spread (fewer than two elements, or all values equal), the
        z-scores are undefined. In that case zeros (the limit ``|tanh(0)|``)
        are returned instead of NaN.
    """
    # TODO: for 2-D inputs this may need to be done for each layer separately.
    if imp.numel() < 2:
        # The unbiased std of a single value is NaN; there is no spread to normalise.
        return torch.zeros_like(imp)
    std = imp.std()
    if std == 0:
        return torch.zeros_like(imp)
    return torch.tanh((imp - imp.mean()) / std).abs()
