import torch
import transformers
import torch.nn as nn
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
import math 
from transformers import get_constant_schedule_with_warmup, get_constant_schedule, get_linear_schedule_with_warmup
from datasets import load_metric   
import evaluate 


class graft_Trainer_new(nn.Module):
    def __init__(self, model_trainer):
        """Wrap an existing trainer so its model can be grafted.

        Args:
            model_trainer: object exposing ``model``, ``args`` and ``params``
                attributes plus a ``select_trainable_parameters()`` method.
        """
        super().__init__()
        self.trainer = model_trainer
        self.model = model_trainer.model
        self.args = model_trainer.args

        # ``params`` is read only after the trainer has been asked to select
        # its trainable parameters.
        model_trainer.select_trainable_parameters()
        self.params = model_trainer.params
        

    ########################################################################################################################
    #We need to store pre-trained and fine-tuned model weights as well (Inefficient, need to think more on this)
    ########################################################################################################################
    def augment_models(self, pretrained_model, finetuned_model, model_args, device):
        """Attach the two reference models, the model arguments and the device.

        Args:
            pretrained_model: model holding the pretrained weights.
            finetuned_model: model holding the fine-tuned weights.
            model_args: configuration object stored as ``self.model_args``.
            device: device stored as ``self.device``.
        """
        # Assigned in a fixed order: pretrained model first, then fine-tuned.
        attachments = (
            ("pretrained_model", pretrained_model),
            ("finetuned_model", finetuned_model),
            ("device", device),
            ("model_args", model_args),
        )
        for attribute, value in attachments:
            setattr(self, attribute, value)
    
    ########################################################################################################################
    #The following function initializes the mask for grafting optimization
    ########################################################################################################################
    def create_binary_masks(self):
        """Initialise the mask scores and record the fine-tuning displacement.

        Sets the following attributes:
            trainable_name: names from ``self.params``, in iteration order.
            trainable_parameters: one ``torch.rand_like`` tensor per trainable
                parameter, created on ``self.device`` without gradients.
            num_params: total number of scalar entries in those tensors.
            grad_directions: detached ``finetuned - pretrained`` tensor for
                every trainable name, in the same order.

        Raises:
            KeyError: if a trainable name is absent from the pretrained or the
                fine-tuned model.
        """
        self.trainable_name = list(self.params)
        # S
        self.trainable_parameters = [
            torch.rand_like(self.params[name].data, device=self.device, requires_grad=False)
            for name in self.trainable_name
        ]
        self.num_params = sum(p.numel() for p in self.trainable_parameters)

        # Look names up directly: a missing name now fails loudly instead of
        # silently reusing the tensor found for the previous name.
        pretrained = dict(self.pretrained_model.named_parameters())
        finetuned = dict(self.finetuned_model.named_parameters())
        self.grad_directions = [
            (finetuned[name] - pretrained[name]).detach()
            for name in self.trainable_name
        ]
    ########################################################################################################################

    
    ########################################################################################################################
    #The following function resets the model to pretrained model weights
    ########################################################################################################################   
    def reset_model(self):
        """Restore every trainable parameter of ``self.model`` to its pretrained value.

        Parameters whose names are not in ``self.trainable_name`` are left
        untouched, and trainable names that do not exist in ``self.model`` are
        skipped.

        Raises:
            KeyError: if a trainable name present in ``self.model`` is absent
                from the pretrained model.
        """
        pretrained = dict(self.pretrained_model.named_parameters())
        current = dict(self.model.named_parameters())

        with torch.no_grad():
            for name in self.trainable_name:
                target = current.get(name)
                if target is None:
                    continue
                # copy_ is an exact overwrite; ``p += (pre - p)`` is subject to
                # floating-point rounding and lets the weights drift.
                target.copy_(pretrained[name])
    ########################################################################################################################
    
    
    ########################################################################################################################
    #The following function gets the grafted model with a given mask (or the current trainable parameters)
    ########################################################################################################################
    def interpolate_model(self, round_=False, mask=None):
        """Add the masked fine-tuning displacement onto the current model weights.

        For every trainable parameter ``p`` the update is
        ``p += frac * (finetuned - pretrained)``. The update is additive, so the
        result equals the grafted model only when ``self.model`` currently holds
        the pretrained weights.

        Args:
            round_: when true, ``frac`` is rounded before use and the rounded
                tensors are collected and returned.
            mask: optional list of tensors indexed like ``self.trainable_name``.
                When given, ``frac = basepatch + (1 - 2 * basepatch) * mask``;
                otherwise ``frac = sigmoid(self.mask)``.
                NOTE(review): the explicit-mask formula reads as if
                ``self.basepatch`` holds 0/1 values, while the
                "highest_movement" initialisation of this class stores +/-3
                scores -- confirm before relying on this path.

        Returns:
            ``None`` when ``round_`` is false; otherwise a tuple
            ``(proportion, binary_mask)`` where ``proportion`` is the sum of
            the rounded fractions divided by ``self.num_params``.

        Raises:
            KeyError: if a trainable name present in ``self.model`` is absent
                from the pretrained or the fine-tuned model.
        """
        sigmoid = torch.nn.Sigmoid()

        pretrained = dict(self.pretrained_model.named_parameters())
        finetuned = dict(self.finetuned_model.named_parameters())
        current = dict(self.model.named_parameters())

        n_graft_params = 0
        binary_mask = []

        # gamma = gamma_base + (1 - 2 * gamma_base) * sigmoid(S)
        with torch.no_grad():
            for counter, name in enumerate(self.trainable_name):
                p = current.get(name)
                if p is None:
                    # Names unknown to the live model are skipped.
                    continue
                pretensor = pretrained[name].to(self.device)
                finetensor = finetuned[name].to(self.device)

                if mask is not None:
                    frac = self.basepatch[counter] + (1. - 2. * self.basepatch[counter]) * mask[counter]
                else:
                    frac = sigmoid(self.mask[counter])
                if round_:
                    frac = torch.round(frac)
                    binary_mask.append(frac)
                n_graft_params += torch.sum(frac)

                p += frac * (finetensor - pretensor)

        if round_:
            print(n_graft_params)
            print ('Proportion in my graft: ', n_graft_params / self.num_params)
            return n_graft_params / self.num_params, binary_mask
                        
    ########################################################################################################################                   
    

    ########################################################################################################################
    #This function creates the basepatch used for initializing the mask for optimization!
    #If mask_path == "highest_movement", we simply pick the parameters that have moved the most during training
    ########################################################################################################################
    def create_basepatch(self):
        """Build the initial mask scores and store them as ``basepatch`` and ``mask``.

        Two sources are supported, selected by ``self.model_args.mask_path``:

        * any value other than ``"highest_movement"`` is treated as a path and
          loaded with ``torch.load``. If the largest loaded value exceeds 1 the
          tensors are taken to be raw scores and passed through
          ``sigmoid(p - sigmoid_bias)``; everything is then clipped to
          ``[0, 1]`` and rounded.
        * ``"highest_movement"`` scores each entry by the magnitude of
          ``self.grad_directions``: entries strictly above the k-th largest
          magnitude get ``3.`` and all others ``-3.``, with
          ``k = int(sparsity_level * total_entries)``.

        ``self.basepatch`` and ``self.mask`` are bound to the same list object,
        so in-place updates made through one are visible through the other.
        """
        sigmoid = torch.nn.Sigmoid()
        sigmoid_bias = self.args.sigmoid_bias
        num_params = self.num_params
        mask_path = self.model_args.mask_path
        sparsity_level = self.model_args.sparsity_level

        #If mask is already stored somewhere, I simply load it!
        if mask_path != "highest_movement":
            basepatch = torch.load(mask_path, map_location=self.device)

            total = max([ torch.amax(p) for p in basepatch ])
            #if the max value is greater than 1., it means we have received masks without sigmoid
            if total > 1.:
                # Fixed: this used to assign into basepatch[mask_counter] with
                # an undefined index and raised NameError.
                basepatch = [ sigmoid( p - sigmoid_bias ) for p in basepatch ]

            basepatch = [ torch.round( torch.clip (p, 0., 1.) ) for p in basepatch ]
            print ('Total parameters in my graft: ', sum([ torch.sum(p*p) / (1. * num_params) for p in basepatch ]))

        else:
            consider = self.grad_directions

            abs_tv = torch.cat([ torch.abs(p).reshape(-1) for p in consider ])
            k = int(sparsity_level * abs_tv.numel())

            # The smallest of the k largest magnitudes is the cut-off.
            values, _ = torch.topk(abs_tv, k)
            threshold = values.min()

            basepatch = [torch.zeros_like(p, requires_grad=False) for p in self.trainable_parameters]

            for p, q in zip(consider, basepatch):
                moved = torch.absolute(p)
                q[moved > threshold] = 3.
                q[moved <= threshold] = -3.

            selected = [ torch.round(sigmoid(p)) for p in basepatch ]
            print ('Total parameters in my stitch: ', sum([ torch.sum(s*s) / (1. * num_params) for s in selected ]))

        self.basepatch = basepatch
        self.mask = basepatch
    ########################################################################################################################
    
    
    
    ######################################################################################################################## 
    #For debugging, I re-defined evaluation here!
    ########################################################################################################################   
    def evaluate(self, dataloader, task_name, mode='dev'):
        """Run the current model over ``dataloader`` and accumulate a metric.

        F1 is used for the ``qqp`` and ``mrpc`` tasks and accuracy for every
        other task. The model is switched to eval mode and is not switched back.

        Args:
            dataloader: iterable of batches indexed by string key
                (``input_ids``, ``attention_mask``, ``labels`` and, for prompt
                models, ``mask_pos``).
            task_name: task identifier, compared case-insensitively.
            mode: when ``'train'``, iteration stops once
                ``self.args.gradient_accumulation_steps`` batches were scored.

        Returns:
            The metric object with every scored batch added; callers invoke
            ``.compute()`` on it.
        """
        if task_name.lower() in [ 'qqp', 'mrpc' ]:
            metric = evaluate.load("f1")
        else:
            metric = evaluate.load("accuracy")

        self.model.eval()
        device = self.device
        batches_seen = 0

        for batch in dataloader:
            with torch.no_grad():
                if 'prompt' in self.model_args.few_shot_type:
                    _, outputs = self.model(
                        input_ids=batch['input_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device),
                        mask_pos=batch["mask_pos"].to(device),
                        labels=batch["labels"].to(device),
                    )
                elif 'finetune' in self.model_args.few_shot_type:
                    if self.model_args.use_CLS_linearhead == 1:
                        _, outputs = self.model(
                            input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                            labels=batch["labels"].to(device),
                        )
                    else:
                        outputs = self.model(
                            input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                        ).logits

            # The highest-scoring class is the prediction.
            metric.add_batch(predictions=torch.argmax(outputs, dim=-1), references=batch["labels"])
            batches_seen += 1
            if mode == 'train' and batches_seen >= self.args.gradient_accumulation_steps:
                break

        return metric
    ########################################################################################################################

    
    ########################################################################################################################
    #Main function that trains our graft!
    #We do not use an optimizer to train the graft, but compute the gradient w.r.t. the mask ourselves
    ########################################################################################################################    
    def train_graft(self,
                    train_dataloader,
                    valid_dataloader,
                    eval_dataset,
                    autoregressive,
                    task_name,
                    train_dataset=None,
                    ):
        """Train the mask scores in ``self.mask`` by manual gradient descent.

        Every epoch:
          1. grafts the current soft mask into the model (``interpolate_model``);
          2. back-propagates the task loss over at most
             ``self.args.gradient_accumulation_steps`` batches and multiplies
             each weight gradient elementwise with the matching entry of
             ``self.grad_directions``; the results are scaled by the learning
             rate and averaged over the batches;
          3. restores the pretrained weights (``reset_model``);
          4. updates the mask scores in place, multiplying by the sigmoid
             derivative ``sigmoid(s) * (1 - sigmoid(s))``;
          5. grafts the rounded mask, scores it on train and validation data
             and restores the pretrained weights again.

        Args:
            train_dataloader: iterable of training batches.
            valid_dataloader: iterable of validation batches (used when
                ``autoregressive`` is false).
            eval_dataset: dataset passed to ``self.trainer.evaluate`` when
                ``autoregressive`` is true.
            autoregressive: selects ``self.trainer.evaluate`` (true) or
                ``self.evaluate`` (false) for the per-epoch scoring.
            task_name: task identifier; ``qqp``/``mrpc`` are scored with F1,
                everything else with accuracy.
            train_dataset: training dataset for the autoregressive scoring.
                Defaults to ``train_dataloader.dataset``.

        Raises:
            ValueError: if ``few_shot_type`` matches no supported mode, or if
                ``autoregressive`` is true and no training dataset is available.
            RuntimeError: if a trainable parameter received no gradient.
            KeyError: if a trainable name is absent from ``self.model``.
        """
        sigmoid = torch.nn.Sigmoid()
        device = self.device
        lr = self.args.learning_rate
        l1_strength = 0

        if autoregressive and train_dataset is None:
            # Fixed: ``train_dataset`` used to be an undefined name here.
            train_dataset = getattr(train_dataloader, "dataset", None)
            if train_dataset is None:
                raise ValueError("autoregressive evaluation needs `train_dataset` "
                                 "or a dataloader exposing `.dataset`")

        # Gradients are gathered by name so that they line up with
        # self.grad_directions, which is ordered like self.trainable_name.
        model_params = dict(self.model.named_parameters())

        for _ in tqdm( range(int(self.args.num_train_epochs)), 'Training the mask' ):
            total_grad = []

            first_batch = 0
            self.interpolate_model(round_=False)

            for batch in train_dataloader:
                few_shot_type = self.model_args.few_shot_type
                if 'prompt' in few_shot_type:
                    loss, outputs = self.model(input_ids=batch['input_ids'].to(device),
                                               attention_mask=batch['attention_mask'].to(device),
                                               mask_pos=batch["mask_pos"].to(device),
                                               labels=batch["labels"].to(device),
                                               )

                elif 'finetune' in few_shot_type and self.model_args.use_CLS_linearhead == 1:
                    loss, outputs = self.model(input_ids=batch['input_ids'].to(device),
                                               attention_mask=batch['attention_mask'].to(device),
                                               labels=batch["labels"].to(device),
                                               )

                elif 'finetune' in few_shot_type:
                    loss = self.model(input_ids=batch['input_ids'].to(device),
                                      attention_mask=batch['attention_mask'].to(device),
                                      labels=batch['labels'].to(device),
                                      ).loss

                elif 'autoregressive' in few_shot_type:
                    input_ids = batch["input_ids"].to(device)
                    option_ids = batch["label_word_list"].to(device)

                    attention_mask = batch["attention_mask"].to(device)
                    token_type_ids = batch["token_type_ids"].to(device)
                    labels = batch["labels"].to(device)

                    #computing gradients for the slow weights!
                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                    logits = outputs.logits.contiguous()

                    # Keep the logits at positions flagged by token_type_ids,
                    # then restrict each row to its candidate label words.
                    indices = torch.where(token_type_ids[..., 1:] == 1)
                    logits = logits[indices]
                    logits = torch.stack([ logits[i, option_ids[i]] for i in range(len(input_ids)) ], 0)

                    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
                    loss = torch.mean(loss_fct(logits, labels.view(-1)))

                else:
                    raise ValueError(f"Unsupported few_shot_type: {few_shot_type!r}")

                loss.backward()

                missing = [ n for n in self.trainable_name if model_params[n].grad is None ]
                if missing:
                    raise RuntimeError(f"No gradient was computed for trainable parameters: {missing}")

                grad = [ model_params[n].grad.detach().clone() for n in self.trainable_name ]
                self.model.zero_grad()
                grad = [ g * d.to(device) for (g, d) in zip(grad, self.grad_directions) ]

                if first_batch == 0:
                    total_grad = [ lr * g for g in grad ]
                else:
                    total_grad = [ acc + lr * g for (acc, g) in zip(total_grad, grad) ]
                first_batch += 1
                #restrict the number of loops
                if first_batch >= self.args.gradient_accumulation_steps:
                    break

            total_grad = [ g / (1. * first_batch) for g in total_grad ]
            self.reset_model()

            #Take the gradient step
            with torch.no_grad():
                for p, g in zip(self.mask, total_grad):
                    s = sigmoid(p)
                    p -= g * s * (1 - s) + l1_strength * s * (1 - s)

            ######### Evaluation of current mask ###########
            self.interpolate_model(round_=True)
            key = "f1" if task_name.lower() in [ 'qqp', 'mrpc' ] else "accuracy"

            if autoregressive:
                tr  = self.trainer.evaluate(train_dataset).compute()[key]
                val = self.trainer.evaluate(eval_dataset).compute()[key]
            else:
                tr  = self.evaluate(train_dataloader, task_name, mode='train').compute()[key]
                val = self.evaluate(valid_dataloader, task_name).compute()[key]
                print(tr, val, val+tr)

            # NOTE: saving the mask with the best train + validation score is
            # currently disabled; the trained scores stay in self.mask.
            # NOTE(review): self.evaluate leaves the model in eval mode and
            # nothing here switches it back before the next epoch -- confirm
            # this is intended.

            self.reset_model()

    ########################################################################################################################
    
        

class graft_Trainer(nn.Module):
    def __init__(self, model_trainer):
        """Wrap an existing trainer so its model can be grafted.

        Args:
            model_trainer: object exposing ``model``, ``args`` and ``params``
                attributes plus a ``select_trainable_parameters()`` method.
        """
        super().__init__()
        self.trainer = model_trainer
        self.model = model_trainer.model
        self.args = model_trainer.args

        # ``params`` is read only after the trainer has been asked to select
        # its trainable parameters.
        model_trainer.select_trainable_parameters()
        self.params = model_trainer.params
        

    ########################################################################################################################
    #We need to store pre-trained and fine-tuned model weights as well (Inefficient, need to think more on this)
    ########################################################################################################################
    def augment_models(self, pretrained_model, finetuned_model, model_args, device):
        """Attach the two reference models, the model arguments and the device.

        Args:
            pretrained_model: model holding the pretrained weights.
            finetuned_model: model holding the fine-tuned weights.
            model_args: configuration object stored as ``self.model_args``.
            device: device stored as ``self.device``.
        """
        # Assigned in a fixed order: pretrained model first, then fine-tuned.
        attachments = (
            ("pretrained_model", pretrained_model),
            ("finetuned_model", finetuned_model),
            ("device", device),
            ("model_args", model_args),
        )
        for attribute, value in attachments:
            setattr(self, attribute, value)
    
    ########################################################################################################################
    #The following function initializes the mask for grafting optimization
    ########################################################################################################################
    def create_binary_masks(self):
        """Initialise the mask scores and record the fine-tuning displacement.

        Sets the following attributes:
            trainable_name: names from ``self.params``, in iteration order.
            trainable_parameters: one ``torch.rand_like`` tensor per trainable
                parameter, created on ``self.device`` without gradients.
            num_params: total number of scalar entries in those tensors.
            grad_directions: detached ``finetuned - pretrained`` tensor for
                every trainable name, in the same order.

        Raises:
            KeyError: if a trainable name is absent from the pretrained or the
                fine-tuned model.
        """
        self.trainable_name = list(self.params)
        # S
        self.trainable_parameters = [
            torch.rand_like(self.params[name].data, device=self.device, requires_grad=False)
            for name in self.trainable_name
        ]
        self.num_params = sum(p.numel() for p in self.trainable_parameters)

        # Look names up directly: a missing name now fails loudly instead of
        # silently reusing the tensor found for the previous name.
        pretrained = dict(self.pretrained_model.named_parameters())
        finetuned = dict(self.finetuned_model.named_parameters())
        self.grad_directions = [
            (finetuned[name] - pretrained[name]).detach()
            for name in self.trainable_name
        ]
    ########################################################################################################################

    
    ########################################################################################################################
    #The following function resets the model to pretrained model weights
    ########################################################################################################################   
    def reset_model(self):
        """Restore every trainable parameter of ``self.model`` to its pretrained value.

        Parameters whose names are not in ``self.trainable_name`` are left
        untouched, and trainable names that do not exist in ``self.model`` are
        skipped.

        Raises:
            KeyError: if a trainable name present in ``self.model`` is absent
                from the pretrained model.
        """
        pretrained = dict(self.pretrained_model.named_parameters())
        current = dict(self.model.named_parameters())

        with torch.no_grad():
            for name in self.trainable_name:
                target = current.get(name)
                if target is None:
                    continue
                # copy_ is an exact overwrite; ``p += (pre - p)`` is subject to
                # floating-point rounding and lets the weights drift.
                target.copy_(pretrained[name])
    ########################################################################################################################
    
    
    ########################################################################################################################
    #The following function gets the grafted model with a given mask (or the current trainable parameters)
    ########################################################################################################################
    def interpolate_model(self, round_=False, mask=None):
        """Add the masked fine-tuning displacement onto the current model weights.

        For every trainable parameter ``p`` the update is
        ``p += frac * (finetuned - pretrained)`` with
        ``frac = basepatch + (1 - 2 * basepatch) * m``, where ``m`` is the
        supplied ``mask`` entry or, by default,
        ``sigmoid(trainable_parameters - sigmoid_bias)``. The update is
        additive, so the result equals the grafted model only when
        ``self.model`` currently holds the pretrained weights.

        Args:
            round_: when true, ``frac`` is rounded before use and the rounded
                tensors are collected and returned.
            mask: optional list of tensors indexed like ``self.trainable_name``
                that replaces the sigmoid of the trainable scores.

        Returns:
            ``None`` when ``round_`` is false; otherwise a tuple
            ``(proportion, binary_mask)`` where ``proportion`` is the sum of
            the rounded fractions divided by ``self.num_params``.

        Raises:
            KeyError: if a trainable name present in ``self.model`` is absent
                from the pretrained or the fine-tuned model.
        """
        sigmoid = torch.nn.Sigmoid()
        sigmoid_bias = self.args.sigmoid_bias

        pretrained = dict(self.pretrained_model.named_parameters())
        finetuned = dict(self.finetuned_model.named_parameters())
        current = dict(self.model.named_parameters())

        n_graft_params = 0
        binary_mask = []

        # gamma = gamma_base + (1 - 2 * gamma_base) * sigmoid(S)
        with torch.no_grad():
            for counter, name in enumerate(self.trainable_name):
                p = current.get(name)
                if p is None:
                    # Names unknown to the live model are skipped.
                    continue
                pretensor = pretrained[name].to(self.device)
                finetensor = finetuned[name].to(self.device)

                base = self.basepatch[counter]
                if mask is not None:
                    frac = base + (1. - 2. * base) * mask[counter]
                else:
                    frac = base + (1. - 2. * base) * sigmoid(self.trainable_parameters[counter] - sigmoid_bias)
                if round_:
                    frac = torch.round(frac)
                    binary_mask.append(frac)
                n_graft_params += torch.sum(frac)

                p += frac * (finetensor - pretensor)

        if round_:
            print(n_graft_params)
            print ('Proportion in my graft: ', n_graft_params / self.num_params)
            return n_graft_params / self.num_params, binary_mask
                        
    ########################################################################################################################                   
    

    ########################################################################################################################
    #This function creates the basepatch used for initializing the mask for optimization!
    #If mask_path == "highest_movement", we simply pick the parameters that have moved the most during training
    ########################################################################################################################
    def create_basepatch(self):
        """Build the binary base patch and store it as ``self.basepatch``.

        Two sources are supported, selected by ``self.model_args.mask_path``:

        * any value other than ``"highest_movement"`` is treated as a path and
          loaded with ``torch.load``. If the largest loaded value exceeds 1 the
          tensors are taken to be raw scores and passed through
          ``sigmoid(p - sigmoid_bias)``; everything is then clipped to
          ``[0, 1]`` and rounded.
        * ``"highest_movement"`` marks with ``1.`` every entry whose magnitude
          in ``self.grad_directions`` is strictly above the k-th largest
          magnitude, with ``k = int(sparsity_level * total_entries)``; all
          other entries stay ``0``.
        """
        sigmoid = torch.nn.Sigmoid()
        sigmoid_bias = self.args.sigmoid_bias
        num_params = self.num_params
        mask_path = self.model_args.mask_path
        sparsity_level = self.model_args.sparsity_level

        #If mask is already stored somewhere, I simply load it!
        if mask_path != "highest_movement":
            basepatch = torch.load(mask_path, map_location=self.device)

            total = max([ torch.amax(p) for p in basepatch ])
            #if the max value is greater than 1., it means we have received masks without sigmoid
            if total > 1.:
                # Fixed: this used to assign into basepatch[mask_counter] with
                # an undefined index and raised NameError.
                basepatch = [ sigmoid( p - sigmoid_bias ) for p in basepatch ]

            basepatch = [ torch.round( torch.clip (p, 0., 1.) ) for p in basepatch ]
            print ('Total parameters in my graft: ', sum([ torch.sum(p*p) / (1. * num_params) for p in basepatch ]))

        else:
            consider = self.grad_directions

            abs_tv = torch.cat([ torch.abs(p).reshape(-1) for p in consider ])
            k = int(sparsity_level * abs_tv.numel())

            # The smallest of the k largest magnitudes is the cut-off.
            values, _ = torch.topk(abs_tv, k)
            threshold = values.min()

            basepatch = [torch.zeros_like(p, requires_grad=False) for p in self.trainable_parameters]

            for p, q in zip(consider, basepatch):
                q[torch.absolute(p) > threshold] = 1.

            print ('Total parameters in my stitch: ', sum([ torch.sum(p*p) / (1. * num_params) for p in basepatch ]))

        self.basepatch = basepatch
    ########################################################################################################################
    
    
    
    ######################################################################################################################## 
    #For debugging, I re-defined evaluation here!
    ########################################################################################################################   
    def evaluate(self, dataloader, task_name, mode='dev'):
        """Run the current model over ``dataloader`` and accumulate a metric.

        F1 is used for the ``qqp`` and ``mrpc`` tasks and accuracy for every
        other task. The model is switched to eval mode and is not switched back.

        Args:
            dataloader: iterable of batches indexed by string key
                (``input_ids``, ``attention_mask``, ``labels`` and, for prompt
                models, ``mask_pos``).
            task_name: task identifier, compared case-insensitively.
            mode: when ``'train'``, iteration stops once
                ``self.args.gradient_accumulation_steps`` batches were scored.

        Returns:
            The metric object with every scored batch added; callers invoke
            ``.compute()`` on it.
        """
        if task_name.lower() in [ 'qqp', 'mrpc' ]:
            metric = evaluate.load("f1")
        else:
            metric = evaluate.load("accuracy")

        self.model.eval()
        device = self.device
        batches_seen = 0

        for batch in dataloader:
            with torch.no_grad():
                if 'prompt' in self.model_args.few_shot_type:
                    _, outputs = self.model(
                        input_ids=batch['input_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device),
                        mask_pos=batch["mask_pos"].to(device),
                        labels=batch["labels"].to(device),
                    )
                elif 'finetune' in self.model_args.few_shot_type:
                    if self.model_args.use_CLS_linearhead == 1:
                        _, outputs = self.model(
                            input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                            labels=batch["labels"].to(device),
                        )
                    else:
                        outputs = self.model(
                            input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                        ).logits

            # The highest-scoring class is the prediction.
            metric.add_batch(predictions=torch.argmax(outputs, dim=-1), references=batch["labels"])
            batches_seen += 1
            if mode == 'train' and batches_seen >= self.args.gradient_accumulation_steps:
                break

        return metric
    ########################################################################################################################

    
    ########################################################################################################################
    #Main function that trains our graft!
    #We do not use an optimizer to train the graft, but compute the gradient w.r.t. the mask ourselves
    ########################################################################################################################    
    def train_graft(self,
                    train_dataloader,
                    valid_dataloader,
                    eval_dataset,
                    autoregressive,
                    task_name,
                    train_dataset=None,
                    ):
        """Train the scores in ``self.trainable_parameters`` by manual gradient descent.

        Every epoch:
          1. grafts the current soft mask into the model (``interpolate_model``);
          2. back-propagates the task loss over at most
             ``self.args.gradient_accumulation_steps`` batches and multiplies
             each weight gradient elementwise with the matching entry of
             ``self.grad_directions``; the results are scaled by the learning
             rate and averaged over the batches;
          3. restores the pretrained weights (``reset_model``);
          4. updates the scores in place, multiplying by ``(1 - 2 * basepatch)``
             and by the derivative of ``sigmoid(score - sigmoid_bias)``;
          5. grafts the rounded mask, scores it on train and validation data
             and restores the pretrained weights again.

        Args:
            train_dataloader: iterable of training batches.
            valid_dataloader: iterable of validation batches (used when
                ``autoregressive`` is false).
            eval_dataset: dataset passed to ``self.trainer.evaluate`` when
                ``autoregressive`` is true.
            autoregressive: selects ``self.trainer.evaluate`` (true) or
                ``self.evaluate`` (false) for the per-epoch scoring.
            task_name: task identifier; ``qqp``/``mrpc`` are scored with F1,
                everything else with accuracy.
            train_dataset: training dataset for the autoregressive scoring.
                Defaults to ``train_dataloader.dataset``.

        Raises:
            ValueError: if ``few_shot_type`` matches no supported mode, or if
                ``autoregressive`` is true and no training dataset is available.
            RuntimeError: if a trainable parameter received no gradient.
            KeyError: if a trainable name is absent from ``self.model``.
        """
        sigmoid = torch.nn.Sigmoid()
        device = self.device
        lr = self.args.learning_rate
        sigmoid_bias = self.args.sigmoid_bias

        if autoregressive and train_dataset is None:
            # Fixed: ``train_dataset`` used to be an undefined name here.
            train_dataset = getattr(train_dataloader, "dataset", None)
            if train_dataset is None:
                raise ValueError("autoregressive evaluation needs `train_dataset` "
                                 "or a dataloader exposing `.dataset`")

        # Gradients are gathered by name so that they line up with
        # self.grad_directions, which is ordered like self.trainable_name.
        model_params = dict(self.model.named_parameters())

        for _ in tqdm( range(int(self.args.num_train_epochs)), 'Training the mask' ):
            total_grad = []

            first_batch = 0
            self.interpolate_model()

            for batch in train_dataloader:
                few_shot_type = self.model_args.few_shot_type
                if 'prompt' in few_shot_type:
                    loss, outputs = self.model(input_ids=batch['input_ids'].to(device),
                                               attention_mask=batch['attention_mask'].to(device),
                                               mask_pos=batch["mask_pos"].to(device),
                                               labels=batch["labels"].to(device),
                                               )

                elif 'finetune' in few_shot_type and self.model_args.use_CLS_linearhead == 1:
                    loss, outputs = self.model(input_ids=batch['input_ids'].to(device),
                                               attention_mask=batch['attention_mask'].to(device),
                                               labels=batch["labels"].to(device),
                                               )

                elif 'finetune' in few_shot_type:
                    loss = self.model(input_ids=batch['input_ids'].to(device),
                                      attention_mask=batch['attention_mask'].to(device),
                                      labels=batch['labels'].to(device),
                                      ).loss

                elif 'autoregressive' in few_shot_type:
                    input_ids = batch["input_ids"].to(device)
                    option_ids = batch["label_word_list"].to(device)

                    attention_mask = batch["attention_mask"].to(device)
                    token_type_ids = batch["token_type_ids"].to(device)
                    labels = batch["labels"].to(device)

                    #computing gradients for the slow weights!
                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                    logits = outputs.logits.contiguous()

                    # Keep the logits at positions flagged by token_type_ids,
                    # then restrict each row to its candidate label words.
                    indices = torch.where(token_type_ids[..., 1:] == 1)
                    logits = logits[indices]
                    logits = torch.stack([ logits[i, option_ids[i]] for i in range(len(input_ids)) ], 0)

                    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
                    loss = torch.mean(loss_fct(logits, labels.view(-1)))

                else:
                    raise ValueError(f"Unsupported few_shot_type: {few_shot_type!r}")

                loss.backward()

                missing = [ n for n in self.trainable_name if model_params[n].grad is None ]
                if missing:
                    raise RuntimeError(f"No gradient was computed for trainable parameters: {missing}")

                grad = [ model_params[n].grad.detach().clone() for n in self.trainable_name ]
                self.model.zero_grad()
                grad = [ g * d.to(device) for (g, d) in zip(grad, self.grad_directions) ]

                if first_batch == 0:
                    total_grad = [ lr * g for g in grad ]
                else:
                    total_grad = [ acc + lr * g for (acc, g) in zip(total_grad, grad) ]
                first_batch += 1
                #restrict the number of loops
                if first_batch >= self.args.gradient_accumulation_steps:
                    break

            total_grad = [ g / (1. * first_batch) for g in total_grad ]
            self.reset_model()

            #Take the gradient step
            with torch.no_grad():
                for p, g, s in zip(self.trainable_parameters, total_grad, self.basepatch):
                    gate = sigmoid(p - sigmoid_bias)
                    p -= (1. - 2.*s) * g * gate * (1. - gate)

            ######### Evaluation of current mask ###########
            self.interpolate_model(round_=True)
            key = "f1" if task_name.lower() in [ 'qqp', 'mrpc' ] else "accuracy"

            if autoregressive:
                tr  = self.trainer.evaluate(train_dataset).compute()[key]
                val = self.trainer.evaluate(eval_dataset).compute()[key]
            else:
                tr  = self.evaluate(train_dataloader, task_name, mode='train').compute()[key]
                val = self.evaluate(valid_dataloader, task_name).compute()[key]
                print(val)

            # NOTE: saving the scores with the best train + validation result
            # is currently disabled; the trained scores stay in
            # self.trainable_parameters.
            # NOTE(review): self.evaluate leaves the model in eval mode and
            # nothing here switches it back before the next epoch -- confirm
            # this is intended.

            self.reset_model()

    ########################################################################################################################
    

class graft_Trainer_new_disjoint(nn.Module):
    def __init__(self, model_trainer):
        
        super(graft_Trainer_new_disjoint, self).__init__()
        self.trainer = model_trainer
        self.model   = self.trainer.model
        self.args    = self.trainer.args
        
        self.trainer.select_trainable_parameters()
        self.params  = self.trainer.params
    
    def create_union_mask(self, task_name):
        all_tasks = ["SST-2", "cr", "mr", "mpqa", "trec", "subj", "QNLI", "SNLI", "MNLI", "RTE", "MRPC", "QQP"]
        prev_masks = []

        if task_name.lower() == all_tasks[0].lower():
            self.union_mask = [torch.zeros_like(p, requires_grad=False).cpu() for p in self.trainable_parameters]
        else:
            for task in all_tasks:
                if task.lower() == task_name.lower():
                    break
                cur_mask = torch.load(f"/data/common/lm-bff/mask_path/disjoint/1e-2/mask_{task.lower()}-prompt-64-0-roberta-base-2-2e-5", map_location=torch.device('cpu'))
                prev_masks.append(cur_mask)
            
            self.union_mask = self.get_union_mask(prev_masks)


    def get_union_mask(self, masks):
        """Combine several per-parameter mask lists with an element-wise OR.

        Args:
            masks: Non-empty sequence of mask lists; every list holds one
                tensor per parameter, in the same order.

        Returns:
            A list with one tensor per entry of ``masks[0]``. With a single
            mask list the original tensors are returned as they are; with
            more, each entry is the ``torch.logical_or`` of the corresponding
            tensors.
        """
        union_mask = []
        first, rest = masks[0], masks[1:]
        for idx in range(len(first)):
            merged = first[idx]
            for other in rest:
                merged = torch.logical_or(merged, other[idx])
            union_mask.append(merged)
        return union_mask


    ########################################################################################################################
    #We need to store pre-trained and fine-tuned model weights as well (Inefficient, need to think more on this)
    ########################################################################################################################
    def augment_models(self, pretrained_model, finetuned_model, model_args, device):   
        self.pretrained_model = pretrained_model
        self.finetuned_model  = finetuned_model
        self.device = device
        self.model_args = model_args
    
    ########################################################################################################################
    #The following function initializes the mask for grafting optimization
    ########################################################################################################################
    def create_binary_masks(self):
        
        self.trainable_name = []
        # S
        self.trainable_parameters = []
       
        for n in self.params: 
            self.trainable_name += [n]
            p = self.params[n]
            self.trainable_parameters += [ torch.rand_like( p.data, device=self.device, requires_grad=False) ] 
        
        self.num_params = sum([p.numel() for p in self.trainable_parameters])  

        self.grad_directions = []
        for counter in range(len(self.trainable_name)):
            for pre_n, pre_p in self.pretrained_model.named_parameters():
                if pre_n == self.trainable_name[counter]: pretensor = pre_p


            for fine_n, fine_p in self.finetuned_model.named_parameters():
                if fine_n == self.trainable_name[counter]: finetensor = fine_p
                    
                    
            self.grad_directions += [ (finetensor - pretensor).detach() ]        
    ########################################################################################################################

    
    ########################################################################################################################
    #The following function resets the model to pretrained model weights
    ########################################################################################################################   
    def reset_model(self):
        sigmoid = torch.nn.Sigmoid()
        for counter in range(len(self.trainable_name)):
            for pre_n, pre_p in self.pretrained_model.named_parameters():
                if pre_n == self.trainable_name[counter]: pretensor = pre_p.to(self.device)



            with torch.no_grad():   
                for n, p in self.model.named_parameters():    
                    if n == self.trainable_name[counter]: 
                    #    frac = sigmoid(trainable_parameters[counter] - sigmoid_bias)
                        p += ( pretensor - p )
    ########################################################################################################################
    
    
    ########################################################################################################################
    #The following function gets the grafted model with a given mask (or the current trainable parameters)
    ########################################################################################################################
    def interpolate_model(self, round_=False, mask=None):  
        sigmoid = torch.nn.Sigmoid()

        n_graft_params, n_total_params = 0, 0
        binary_mask = []

        for counter in range(len(self.trainable_name)):
            for pre_n, pre_p in self.pretrained_model.named_parameters():
                if pre_n == self.trainable_name[counter]: 
                    pretensor = pre_p.to(self.device)
                    # break

            for fine_n, fine_p in self.finetuned_model.named_parameters():
                if fine_n == self.trainable_name[counter]: 
                    finetensor = fine_p.to(self.device)
                    # break

            # gamma = gamma_base + (1 - 2 * gamma_base) * sigmoid(S)
            with torch.no_grad():            
                for n, p in self.model.named_parameters():  
                    n_total_params += p.numel()
                    if n == self.trainable_name[counter]: 
                        if mask is not None:
                            frac = self.basepatch[counter] + (1. - 2. * self.basepatch[counter]) * mask[counter]
                        else:    
                            frac = sigmoid(self.mask[counter])
                        if round_:
                            frac = torch.round(frac)
                            binary_mask.append(frac)
                        n_graft_params += torch.sum(frac)
                        
                        p += frac * ( finetensor - pretensor ) 
        
        if round_:
            print(n_graft_params)
            print ('Proportion in my graft: ', n_graft_params / self.num_params)
            return n_graft_params / self.num_params, binary_mask
                        
    ########################################################################################################################                   
    

    ########################################################################################################################
    #This function creates the basepatch used for initializing the mask for optimization!
    #If mask_path == "highest_movement", we simply pick the parameters that have moved the most during training
    ########################################################################################################################
    def create_basepatch(self):
        sigmoid = torch.nn.Sigmoid()
        sigmoid_bias = self.args.sigmoid_bias
        num_params = self.num_params
        mask_path = self.model_args.mask_path 
        sparsity_level =  self.model_args.sparsity_level
        
        #If mask is already stored somewhere, I simply load it!
        if mask_path != "highest_movement":
            basepatch = torch.load(mask_path, map_location=self.device)

            
            total = max([ torch.amax(p) for p in basepatch ])
            #if the max value is greater than 1., it means we have received masks without sigmoid
            if total > 1.:
                basepatch[mask_counter] = [ sigmoid( p - sigmoid_bias ) for p in basepatch ]
            
            basepatch = [ torch.round( torch.clip (p, 0., 1.) )  for p in basepatch ]
            print ('Total parameters in my graft: ', sum([ torch.sum(p*p) / (1. * num_params) for p in basepatch ]))
            
        elif mask_path == "highest_movement":

            threshold = int(sparsity_level * num_params)  
            consider = self.grad_directions

            abs_tv = []
            for (p, m) in zip(consider, self.union_mask):
                masked_p = p * torch.logical_not(m)
                # masked_p = p * (1 - m) # enforce disjoint initialization
                abs_tv.append(torch.abs(masked_p).view(-1))

            abs_tv = torch.cat(abs_tv)
            k = int(sparsity_level * abs_tv.numel())  # 1% of the total number of elements

            # Get the k largest values; returns values and their indices
            values, indices = torch.topk(abs_tv.view(-1), k)
            threshold = values.min()

            basepatch = [torch.zeros_like(p, requires_grad=False) for p in self.trainable_parameters]
            
            print(self.args.sigmoid_bias)
            for p, q in zip(consider, basepatch):
                q[torch.absolute(p) > threshold] = self.args.sigmoid_bias
                q[torch.absolute(p) <= threshold] = -self.args.sigmoid_bias

            print ('Total parameters in my stitch: ', sum([ torch.sum(torch.round(torch.nn.Sigmoid()(p))*torch.round(torch.nn.Sigmoid()(p))) / (1. * num_params) for p in basepatch ]))
           
        else:
            raise NotImplementedError("Not Implemented!")
            
        self.basepatch = basepatch
        self.mask = basepatch
    ########################################################################################################################
    
    
    
    ######################################################################################################################## 
    #For debugging, I re-defined evaluation here!
    ########################################################################################################################   
    def evaluate(self, dataloader, task_name, mode='dev'):
        """Run ``self.model`` over ``dataloader`` and accumulate a metric.

        The F1 metric is used for the ``qqp`` and ``mrpc`` tasks and accuracy
        for every other task. The model is switched to eval mode and is not
        switched back.

        Args:
            dataloader: Iterable of batches (mappings with ``input_ids``,
                ``attention_mask``, ``labels`` and, for prompt models,
                ``mask_pos``).
            task_name: Task identifier, compared case-insensitively.
            mode: With ``'train'`` at most
                ``self.args.gradient_accumulation_steps`` batches are scored.

        Returns:
            The metric object with all batches added; call ``compute()`` on
            it to obtain the score.
        """
        # metric = load_metric(..., trust_remote_code=True) was used previously.
        metric_id = "f1" if task_name.lower() in ('qqp', 'mrpc') else "accuracy"
        metric = evaluate.load(metric_id)

        self.model.eval()
        device = self.device

        for seen, batch in enumerate(dataloader, start=1):
            few_shot_type = self.model_args.few_shot_type
            with torch.no_grad():
                if 'prompt' in few_shot_type:
                    _, outputs = self.model(
                        input_ids=batch['input_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device),
                        mask_pos=batch["mask_pos"].to(device),
                        labels=batch["labels"].to(device),
                    )
                elif 'finetune' in few_shot_type:
                    if self.model_args.use_CLS_linearhead == 1:
                        _, outputs = self.model(
                            input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                            labels=batch["labels"].to(device),
                        )
                    else:
                        outputs = self.model(
                            input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                        ).logits

            predictions = torch.argmax(outputs, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])

            # On the training split only a prefix of the data is scored.
            if mode == 'train' and seen >= self.args.gradient_accumulation_steps:
                break

        return metric
    ########################################################################################################################

    
    ########################################################################################################################
    #Main function that trains our graft!
    #We do not use an optimizer to train the graft; we compute the gradient w.r.t. the mask ourselves
    ########################################################################################################################    
    def train_graft(self,
                    train_dataloader,
                    valid_dataloader,
                    eval_dataset,
                    autoregressive,
                    task_name,
                    train_dataset=None,
                    ):
        """Optimise the mask logits in ``self.mask`` with manual gradient steps.

        Per epoch:
            1. Graft the model with the current soft mask
               (``interpolate_model(round_=False)``).
            2. Accumulate, over at most
               ``self.args.gradient_accumulation_steps`` batches, the loss
               gradient w.r.t. the model weights multiplied by
               ``grad_directions`` and the learning rate, then average it.
            3. Reset the model and update every mask logit ``p`` with
               ``grad * sigmoid(p) * (1 - sigmoid(p))``, skipping entries
               that are set in ``self.union_mask``.
            4. Graft with the rounded mask, evaluate on the training and
               validation data, and reset the model again.

        Args:
            train_dataloader: Batches used for the gradient and, in the
                non-autoregressive case, for the training score.
            valid_dataloader: Batches used for the validation score in the
                non-autoregressive case.
            eval_dataset: Dataset passed to ``self.trainer.evaluate`` in the
                autoregressive case.
            autoregressive: Selects which evaluation path is used.
            task_name: Task identifier; ``qqp`` and ``mrpc`` are scored with
                F1, everything else with accuracy.
            train_dataset: Dataset passed to ``self.trainer.evaluate`` for the
                training score in the autoregressive case. The previous
                version referenced this name without defining it.

        Raises:
            ValueError: If ``autoregressive`` is true and ``train_dataset``
                is not given.
            RuntimeError: If a trainable parameter has no gradient after the
                backward pass.
        """
        if autoregressive and train_dataset is None:
            raise ValueError("train_dataset is required when autoregressive is set")

        sigmoid = torch.nn.Sigmoid()
        device = self.device
        lr = self.args.learning_rate
        l1_strength = 0
        key = "f1" if task_name.lower() in ('qqp', 'mrpc') else "accuracy"

        for _ in tqdm( range(int(self.args.num_train_epochs)), 'Training the mask' ):
            total_grad = []

            first_batch = 0
            self.interpolate_model(round_=False)

            for batch in train_dataloader:
                if 'prompt' in self.model_args.few_shot_type :
                    loss, outputs = self.model(input_ids=batch['input_ids'].to(device),
                                               attention_mask=batch['attention_mask'].to(device),
                                               mask_pos=batch["mask_pos"].to(device),
                                               labels=batch["labels"].to(device),
                                              )

                elif ('finetune' in self.model_args.few_shot_type and  self.model_args.use_CLS_linearhead == 1) :
                    loss, outputs = self.model(input_ids=batch['input_ids'].to(device),
                                               attention_mask=batch['attention_mask'].to(device),
                                               labels=batch["labels"].to(device),
                                              )

                elif 'finetune' in self.model_args.few_shot_type :
                    loss = self.model(input_ids=batch['input_ids'].to(device),
                                      attention_mask=batch['attention_mask'].to(device),
                                      labels=batch['labels'].to(device),
                                     ).loss

                elif 'autoregressive' in self.model_args.few_shot_type :
                    input_ids=batch["input_ids"].to(device)
                    option_ids=batch["label_word_list"].to(device)

                    attention_mask=batch["attention_mask"].to(device)
                    token_type_ids=batch["token_type_ids"].to(device)
                    labels=batch["labels"].to(device)

                    #computing gradients for the slow weights!
                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                    logits  = outputs.logits.contiguous()

                    indices = torch.where(token_type_ids[..., 1:] == 1)
                    logits = logits[indices]
                    nlogits = []
                    for i in range(len(input_ids)):
                        nlogits += [ logits[i, option_ids[i]] ]
                    logits = torch.stack(nlogits, 0)

                    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
                    loss = torch.mean(loss_fct(logits, labels.view(-1)))

                loss.backward()

                # Collect gradients by name so they line up with
                # trainable_name / grad_directions / mask, whatever order the
                # model enumerates its parameters in.
                grads_by_name = {n: p.grad for n, p in self.model.named_parameters()}
                missing = [n for n in self.trainable_name if grads_by_name.get(n) is None]
                if missing:
                    raise RuntimeError(f"no gradient for trainable parameters: {missing}")

                # The product is a new tensor, so zero_grad() below cannot
                # invalidate it.
                grad = [ grads_by_name[n] * d.to(device)
                         for (n, d) in zip(self.trainable_name, self.grad_directions) ]
                self.model.zero_grad()

                if first_batch == 0:
                    total_grad = [lr * g for g in grad]
                else:
                    total_grad = [ p + lr * g for (p, g) in zip( total_grad, grad ) ]
                first_batch += 1
                #restrict the number of loops
                if first_batch >= self.args.gradient_accumulation_steps:
                    break

            total_grad = [ p / (1. * first_batch) for p in total_grad ]
            self.reset_model()

            #Take the gradient step
            with torch.no_grad():
                for p, g, m in zip(self.mask, total_grad, self.union_mask):
                    m = m.to(device)
                    # Entries claimed by earlier masks (m set) are frozen.
                    p -= (g * sigmoid(p) * (1 - sigmoid(p)) + l1_strength * sigmoid(p) * (1 - sigmoid(p))) * torch.logical_not(m)

            ######### Evaluation of current mask ###########
            self.interpolate_model(round_=True)

            if autoregressive:
                tr  = self.trainer.evaluate(train_dataset).compute()[key]
                val = self.trainer.evaluate(eval_dataset).compute()[key]
            else:
                tr  = self.evaluate(train_dataloader, task_name, mode='train').compute()[key]
                val = self.evaluate(valid_dataloader, task_name).compute()[key]
                print(tr, val, val+tr)

            # NOTE: saving the mask with the best train + validation score is
            # currently disabled; the scores above are only reported.

            self.reset_model()

    ########################################################################################################################


class graft_Trainer_disjoint(nn.Module):
    def __init__(self, model_trainer):
        
        super(graft_Trainer_disjoint, self).__init__()
        self.trainer = model_trainer
        self.model   = self.trainer.model
        self.args    = self.trainer.args
        
        self.trainer.select_trainable_parameters()
        self.params  = self.trainer.params


    def create_union_mask(self, task_name, desired_sparsity):
        # all_tasks = ["SST-2", "cr", "mr", "mpqa", "trec", "subj", "QNLI", "SNLI", "MNLI", "RTE", "MRPC", "QQP"]
        all_tasks = ["trec", "subj", "QNLI", "SNLI", "MNLI", "RTE", "MRPC", "QQP", "SST-2", "cr", "mr", "mpqa"]
        prev_masks = []

        if task_name.lower() == all_tasks[0].lower():
            self.union_mask = [torch.zeros_like(p, requires_grad=False).cpu() for p in self.trainable_parameters]
        else:
            for task in all_tasks:
                if task.lower() == task_name.lower():
                    break
                cur_mask = torch.load(f"/data/common/lm-bff/mask_path/disjoint/reverse_{desired_sparsity}/mask_{task.lower()}-prompt-64-0-roberta-base-2-2e-5", map_location=torch.device('cpu'))
                prev_masks.append(cur_mask)
            
            self.union_mask = self.get_union_mask(prev_masks)


    def get_union_mask(self, masks):
        """Return the element-wise logical OR of several per-parameter mask lists.

        Args:
            masks: Non-empty sequence of mask lists, each holding one tensor
                per parameter in the same order.

        Returns:
            A list with one tensor per entry of ``masks[0]``. If only one
            mask list is given its tensors are returned unchanged; otherwise
            each entry is the ``torch.logical_or`` of the matching tensors.
        """
        n_tensors = len(masks[0])
        union_mask = [masks[0][i] for i in range(n_tensors)]
        for later in masks[1:]:
            for i in range(n_tensors):
                union_mask[i] = torch.logical_or(union_mask[i], later[i])
        return union_mask
        

    ########################################################################################################################
    #We need to store pre-trained and fine-tuned model weights as well (Inefficient, need to think more on this)
    ########################################################################################################################
    def augment_models(self, pretrained_model, finetuned_model, model_args, device):   
        self.pretrained_model = pretrained_model
        self.finetuned_model  = finetuned_model
        self.device = device
        self.model_args = model_args
    
    ########################################################################################################################
    #The following function initializes the mask for grafting optimization
    ########################################################################################################################
    def create_binary_masks(self):
        
        self.trainable_name = []
        # S
        self.trainable_parameters = []
       
        for n in self.params: 
            self.trainable_name += [n]
            p = self.params[n]
            self.trainable_parameters += [ torch.rand_like( p.data, device=self.device, requires_grad=False) ] 
        
        
        self.num_params = sum([p.numel() for p in self.trainable_parameters])  

        self.grad_directions = []
        for counter in range(len(self.trainable_name)):
            for pre_n, pre_p in self.pretrained_model.named_parameters():
                if pre_n == self.trainable_name[counter]: pretensor = pre_p


            for fine_n, fine_p in self.finetuned_model.named_parameters():
                if fine_n == self.trainable_name[counter]: finetensor = fine_p
                    
                    
            self.grad_directions += [ (finetensor - pretensor).detach() ]        
    ########################################################################################################################

    
    ########################################################################################################################
    #The following function resets the model to pretrained model weights
    ########################################################################################################################   
    def reset_model(self):
        sigmoid = torch.nn.Sigmoid()
        for counter in range(len(self.trainable_name)):
            for pre_n, pre_p in self.pretrained_model.named_parameters():
                if pre_n == self.trainable_name[counter]: pretensor = pre_p.to(self.device)



            with torch.no_grad():   
                for n, p in self.model.named_parameters():    
                    if n == self.trainable_name[counter]: 
                    #    frac = sigmoid(trainable_parameters[counter] - sigmoid_bias)
                        p += ( pretensor - p )
    ########################################################################################################################
    
    
    ########################################################################################################################
    #The following function gets the grafted model with a given mask (or the current trainable parameters)
    ########################################################################################################################
    def interpolate_model(self, round_=False, mask=None):  
        sigmoid = torch.nn.Sigmoid()
        sigmoid_bias = self.args.sigmoid_bias

        n_graft_params, n_total_params = 0, 0
        binary_mask = []

        for counter in range(len(self.trainable_name)):
            for pre_n, pre_p in self.pretrained_model.named_parameters():
                if pre_n == self.trainable_name[counter]: 
                    pretensor = pre_p.to(self.device)
                    # break

            for fine_n, fine_p in self.finetuned_model.named_parameters():
                if fine_n == self.trainable_name[counter]: 
                    finetensor = fine_p.to(self.device)
                    # break

            # gamma = gamma_base + (1 - 2 * gamma_base) * sigmoid(S)
            with torch.no_grad():            
                for n, p in self.model.named_parameters():  
                    n_total_params += p.numel()
                    if n == self.trainable_name[counter]: 
                        if mask is not None:
                            frac = self.basepatch[counter] + (1. - 2. * self.basepatch[counter]) * mask[counter]
                        else:    
                            frac = self.basepatch[counter] + (1. - 2. * self.basepatch[counter]) * sigmoid(self.trainable_parameters[counter] - sigmoid_bias) 
                        if round_:
                            frac = torch.round(frac)
                            binary_mask.append(frac)
                        n_graft_params += torch.sum(frac)
                        
                        p += frac * ( finetensor - pretensor ) 
        
        if round_:
            print(n_graft_params)
            print ('Proportion in my graft: ', n_graft_params / self.num_params)
            return n_graft_params / self.num_params, binary_mask
                        
    ########################################################################################################################                   
    

    ########################################################################################################################
    #This function creates the basepatch used for initializing the mask for optimization!
    #If mask_path == "highest_movement", we simply pick the parameters that have moved the most during training
    ########################################################################################################################
    def create_basepatch(self):
        sigmoid = torch.nn.Sigmoid()
        sigmoid_bias = self.args.sigmoid_bias
        num_params = self.num_params
        mask_path = self.model_args.mask_path 
        sparsity_level =  self.model_args.sparsity_level
        
        #If mask is already stored somewhere, I simply load it!
        if mask_path != "highest_movement":
            basepatch = torch.load(mask_path, map_location=self.device)

            
            total = max([ torch.amax(p) for p in basepatch ])
            #if the max value is greater than 1., it means we have received masks without sigmoid
            if total > 1.:
                basepatch[mask_counter] = [ sigmoid( p - sigmoid_bias ) for p in basepatch ]
            
            basepatch = [ torch.round( torch.clip (p, 0., 1.) )  for p in basepatch ]
            print ('Total parameters in my graft: ', sum([ torch.sum(p*p) / (1. * num_params) for p in basepatch ]))
            
        elif mask_path == "highest_movement":

            threshold = int(sparsity_level * num_params)  
            consider = self.grad_directions

            abs_tv = []
            for p, m in zip(consider, self.union_mask):
                masked_p = p * torch.logical_not(m)
                # masked_p = p * (1 - m) # enforce disjoint initialization
                abs_tv.append(torch.abs(masked_p).view(-1))

            abs_tv = torch.cat(abs_tv)
            k = int(sparsity_level * abs_tv.numel())  # 1% of the total number of elements

            # Get the k largest values; returns values and their indices
            values, indices = torch.topk(abs_tv.view(-1), k)
            threshold = values.min()

            basepatch = [torch.zeros_like(p, requires_grad=False) for p in self.trainable_parameters]

            for p, q in zip(consider, basepatch):
                q[torch.absolute(p) > threshold] = 1.

            print ('Total parameters in my stitch: ', sum([ torch.sum(p*p) / (1. * num_params) for p in basepatch ]))
        else:
            raise NotImplementedError("Not Implemented!")
            
        self.basepatch = basepatch
    ########################################################################################################################
    
    
    
    ######################################################################################################################## 
    #For debugging, I re-defined evaluation here!
    ########################################################################################################################   
    def evaluate(self, dataloader, task_name, mode='dev'):
        """Score ``self.model`` on ``dataloader`` and return the filled metric.

        The F1 metric is used for the ``qqp`` and ``mrpc`` tasks and accuracy
        for every other task. The model is switched to eval mode and is not
        switched back.

        Args:
            dataloader: Iterable of batches (mappings with ``input_ids``,
                ``attention_mask``, ``labels`` and, for prompt models,
                ``mask_pos``).
            task_name: Task identifier, compared case-insensitively.
            mode: With ``'train'`` at most
                ``self.args.gradient_accumulation_steps`` batches are scored.

        Returns:
            The metric object with all batches added; call ``compute()`` on
            it to obtain the score.
        """
        # metric = load_metric(..., trust_remote_code=True) was used previously.
        uses_f1 = task_name.lower() in ('qqp', 'mrpc')
        metric = evaluate.load("f1") if uses_f1 else evaluate.load("accuracy")

        self.model.eval()
        device = self.device

        batches_done = 0
        for batch in dataloader:
            few_shot_type = self.model_args.few_shot_type
            with torch.no_grad():
                if 'prompt' in few_shot_type:
                    _, outputs = self.model(
                        input_ids=batch['input_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device),
                        mask_pos=batch["mask_pos"].to(device),
                        labels=batch["labels"].to(device),
                    )
                elif 'finetune' in few_shot_type:
                    if self.model_args.use_CLS_linearhead == 1:
                        _, outputs = self.model(
                            input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                            labels=batch["labels"].to(device),
                        )
                    else:
                        outputs = self.model(
                            input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device),
                        ).logits

            predictions = torch.argmax(outputs, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])

            batches_done += 1
            # On the training split only a prefix of the data is scored.
            if mode == 'train' and batches_done >= self.args.gradient_accumulation_steps:
                break

        return metric
    ########################################################################################################################

    
    ########################################################################################################################
    #Main function that trains our graft!
    #We do not use an optimizer to train the graft; we compute the gradient w.r.t. the mask ourselves
    ########################################################################################################################    
    def train_graft(self,
                    train_dataloader,
                    valid_dataloader,
                    eval_dataset,
                    autoregressive,
                    task_name,
                    train_dataset=None,
                    ):
        """Train the graft mask with hand-computed gradient steps (no optimizer).

        For each of ``int(self.args.num_train_epochs)`` epochs this method:

        1. calls ``self.interpolate_model()`` and back-propagates the task loss
           through ``self.model`` on at most
           ``self.args.gradient_accumulation_steps`` batches of
           ``train_dataloader``;
        2. multiplies every collected gradient element-wise by the matching
           entry of ``self.grad_directions``, scales it by the learning rate
           and averages it over the batches that were processed;
        3. calls ``self.reset_model()`` and updates ``self.trainable_parameters``
           in place, leaving entries flagged in ``self.union_mask`` untouched;
        4. calls ``self.interpolate_model(round_=True)``, evaluates on the
           train and validation data, then calls ``self.reset_model()`` again.

        Args:
            train_dataloader: Iterable of batches (dicts of tensors). Which
                keys are read depends on ``self.model_args.few_shot_type``.
            valid_dataloader: Passed to ``self.evaluate`` when
                ``autoregressive`` is falsy.
            eval_dataset: Passed to ``self.trainer.evaluate`` when
                ``autoregressive`` is truthy.
            autoregressive: Selects ``self.trainer.evaluate`` (truthy) or
                ``self.evaluate`` (falsy) for the per-epoch evaluation.
            task_name: Task identifier; ``'qqp'`` and ``'mrpc'``
                (case-insensitive) are scored with ``"f1"``, everything else
                with ``"accuracy"``.
            train_dataset: Dataset handed to ``self.trainer.evaluate`` for the
                train score when ``autoregressive`` is truthy. Defaults to
                ``train_dataloader.dataset``.

        Raises:
            ValueError: If ``few_shot_type`` matches none of the supported
                modes, or if ``autoregressive`` is truthy and no train dataset
                can be determined.
            RuntimeError: If a trainable parameter received no gradient.
        """
        # Only referenced by the disabled checkpointing at the end of the epoch
        # loop; kept so that block can be re-enabled without further changes.
        baseline = 0.
        checkpoint_location = self.model_args.checkpoint_location

        sigmoid = torch.nn.Sigmoid()
        # Built once instead of once per batch (autoregressive branch only).
        per_example_loss = torch.nn.CrossEntropyLoss(reduction="none")

        device = self.device
        lr = self.args.learning_rate
        sigmoid_bias = self.args.sigmoid_bias
        few_shot_type = self.model_args.few_shot_type
        max_batches = self.args.gradient_accumulation_steps
        # Set membership keeps the per-batch parameter filter cheap.
        trainable_names = set(self.trainable_name)
        key = "f1" if task_name.lower() in ('qqp', 'mrpc') else "accuracy"

        if autoregressive and train_dataset is None:
            # The train score needs a dataset rather than a dataloader; recover
            # it from the loader when the caller did not pass one explicitly.
            train_dataset = getattr(train_dataloader, "dataset", None)
            if train_dataset is None:
                raise ValueError(
                    "autoregressive evaluation needs `train_dataset` or a "
                    "`train_dataloader` exposing a `.dataset` attribute"
                )

        for _ in tqdm(range(int(self.args.num_train_epochs)), 'Training the mask'):
            total_grad = []
            num_batches = 0
            self.interpolate_model()

            for batch in train_dataloader:
                if 'prompt' in few_shot_type:
                    loss, _ = self.model(input_ids=batch['input_ids'].to(device),
                                         attention_mask=batch['attention_mask'].to(device),
                                         mask_pos=batch["mask_pos"].to(device),
                                         labels=batch["labels"].to(device),
                                         )

                elif 'finetune' in few_shot_type and self.model_args.use_CLS_linearhead == 1:
                    loss, _ = self.model(input_ids=batch['input_ids'].to(device),
                                         attention_mask=batch['attention_mask'].to(device),
                                         labels=batch["labels"].to(device),
                                         )

                elif 'finetune' in few_shot_type:
                    loss = self.model(input_ids=batch['input_ids'].to(device),
                                      attention_mask=batch['attention_mask'].to(device),
                                      labels=batch['labels'].to(device),
                                      ).loss

                elif 'autoregressive' in few_shot_type:
                    input_ids = batch["input_ids"].to(device)
                    option_ids = batch["label_word_list"].to(device)
                    attention_mask = batch["attention_mask"].to(device)
                    token_type_ids = batch["token_type_ids"].to(device)
                    labels = batch["labels"].to(device)

                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                    logits = outputs.logits.contiguous()

                    # Keep the logits at positions whose *next* token has
                    # token_type_id == 1.
                    # NOTE(review): the per-example indexing below presumably
                    # relies on exactly one such position per example - confirm.
                    indices = torch.where(token_type_ids[..., 1:] == 1)
                    logits = logits[indices]
                    # Restrict every example to the logits of its own label words.
                    logits = torch.stack(
                        [logits[i, option_ids[i]] for i in range(len(input_ids))], 0
                    )
                    loss = torch.mean(per_example_loss(logits, labels.view(-1)))

                else:
                    # Without this guard `loss` would be unbound on the first
                    # batch, or silently stale on later ones.
                    raise ValueError(
                        f"Unsupported few_shot_type for graft training: {few_shot_type!r}"
                    )

                loss.backward()

                trainable = [(n, p) for n, p in self.model.named_parameters()
                             if n in trainable_names]
                missing = [n for n, p in trainable if p.grad is None]
                if missing:
                    raise RuntimeError(
                        f"No gradient was computed for trainable parameters: {missing}"
                    )

                # The product is a fresh tensor, so it survives zero_grad().
                grad = [p.grad.detach() * d.to(device)
                        for (_, p), d in zip(trainable, self.grad_directions)]
                self.model.zero_grad()

                if num_batches == 0:
                    total_grad = [lr * g for g in grad]
                else:
                    total_grad = [acc + lr * g for acc, g in zip(total_grad, grad)]
                num_batches += 1
                # Restrict the number of batches used per epoch.
                if num_batches >= max_batches:
                    break

            # With an empty dataloader `total_grad` is empty, so no division
            # happens and the update below is a no-op.
            total_grad = [g / float(num_batches) for g in total_grad]
            self.reset_model()

            # Take the gradient step.
            with torch.no_grad():
                for p, g, s, m in zip(self.trainable_parameters, total_grad,
                                      self.basepatch, self.union_mask):
                    gate = sigmoid(p - sigmoid_bias)
                    # gate * (1 - gate) is the derivative of the sigmoid, and
                    # (1 - 2s) flips the sign wherever s == 1. Entries where the
                    # union mask is set are left unchanged.
                    p -= ((1. - 2. * s) * g * gate * (1. - gate)) * torch.logical_not(m.to(device))

            ######### Evaluation of current mask ###########
            self.interpolate_model(round_=True)

            if autoregressive:
                tr = self.trainer.evaluate(train_dataset).compute()[key]
                val = self.trainer.evaluate(eval_dataset).compute()[key]
            else:
                tr = self.evaluate(train_dataloader, task_name, mode='train').compute()[key]
                val = self.evaluate(valid_dataloader, task_name).compute()[key]
                print(val)

            # Combined train + validation score, intended for picking the best mask.
            bs_compare = val + tr

            # NOTE: saving the best mask is currently disabled, so the score
            # above is computed but not acted upon.
            # if bs_compare > baseline:
            #     torch.save(self.trainable_parameters, checkpoint_location)
            #     baseline = bs_compare

            self.reset_model()

    ########################################################################################################################