import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from vision_datasets.common import get_dataloader, maybe_dictionarize, get_dataloader_shuffle
from eval import eval_single_dataset
from modeling import ImageClassifier
from heads import get_classification_head
import utils


class Localizer(nn.Module):
    """Learn a per-element mask over the ``finetuned - pretrained`` deltas.

    One real-valued tensor per name in ``trainable_params`` is stored in
    ``self.mask``.  ``sigmoid(mask)`` is the fraction of the delta
    ``finetuned - pretrained`` that ``interpolate_model`` adds to
    ``self.model``; ``reset_model`` copies the pretrained values back.
    ``train_graft`` updates the mask with plain gradient steps.

    Args:
        trainable_params: mapping from parameter name to tensor; its
            iteration order defines the order of every per-parameter list
            kept by this object.
        model: module whose parameters are modified in place.
            NOTE(review): ``reset_model`` assumes these parameters start at
            the pretrained values -- confirm against the caller.
        pretrained_model: frozen reference model (kept on CPU).
        finetuned_model: frozen finetuned model (kept on CPU).
        dataset_name: forwarded to ``get_classification_head``.
        args: forwarded to ``get_classification_head`` and
            ``eval_single_dataset``.
        graft_args: object providing ``sparsity_level``, ``sigmoid_bias``,
            ``learning_rate``, ``num_train_epochs``,
            ``gradient_accumulation_steps`` and ``l1_strength``.
    """

    def __init__(self, trainable_params, model, pretrained_model, finetuned_model, dataset_name, args, graft_args):
        super().__init__()

        self.params = trainable_params
        self.model = model
        self.pretrained_model = pretrained_model
        self.finetuned_model = finetuned_model
        self.args = args
        self.graft_args = graft_args
        self.classifier_head = get_classification_head(self.args, dataset_name)

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model.to(self.device)
        self.classifier_head.to(self.device)
        # The two reference models are only read from, so they stay on CPU.
        self.pretrained_model.to("cpu")
        self.finetuned_model.to("cpu")
        self.model.eval()
        self.finetuned_model.eval()
        self.pretrained_model.eval()
        for param in self.pretrained_model.parameters():
            param.requires_grad = False
        for param in self.finetuned_model.parameters():
            param.requires_grad = False

        self.create_binary_masks()
        self.mask = self.create_basepatch()

    def create_binary_masks(self):
        """Record trainable names, their element count and the deltas.

        Sets ``trainable_name``, ``trainable_parameters`` (random tensors
        shaped like each trainable parameter, on ``self.device``),
        ``num_params`` and ``grad_directions`` (``finetuned - pretrained``
        per name, detached).

        Raises:
            KeyError: if a trainable name is missing from the pretrained or
                finetuned model.
        """
        self.trainable_name = list(self.params)
        self.trainable_parameters = [
            torch.rand_like(self.params[name].data, device=self.device, requires_grad=False)
            for name in self.trainable_name
        ]
        self.num_params = sum(p.numel() for p in self.trainable_parameters)

        # Build the name -> tensor maps once instead of rescanning
        # named_parameters() for every trainable name.
        pretrained = dict(self.pretrained_model.named_parameters())
        finetuned = dict(self.finetuned_model.named_parameters())
        self.grad_directions = [
            (finetuned[name] - pretrained[name]).detach()
            for name in self.trainable_name
        ]

    def reset_model(self):
        """Copy the pretrained values back into the trainable parameters.

        ``copy_`` is used rather than ``p += pretrained - p`` because the
        latter is not bit-exact in floating point and lets the model drift
        over repeated interpolate/reset cycles.  Names absent from
        ``self.model`` are skipped.
        """
        pretrained = dict(self.pretrained_model.named_parameters())
        current = dict(self.model.named_parameters())
        with torch.no_grad():
            for name in self.trainable_name:
                param = current.get(name)
                if param is None:
                    continue
                param.copy_(pretrained[name])

    def create_basepatch(self):
        """Build the initial mask from the largest-magnitude deltas.

        With ``k = int(sparsity_level * num_params)``, the threshold is the
        k-th largest ``|finetuned - pretrained|`` over all trainable
        elements.  Elements strictly above it are set to ``+sigmoid_bias``
        and the rest to ``-sigmoid_bias``.  When ``k`` is 0 nothing is
        selected (the original code raised ``IndexError`` here).

        Returns:
            list of tensors shaped like the trainable parameters.
        """
        num_params = self.num_params
        bias = self.graft_args.sigmoid_bias
        k = int(self.graft_args.sparsity_level * num_params)

        if k > 0:
            # Streaming top-k: only k values are kept between tensors.
            best_top = np.zeros(k)
            for delta in self.grad_directions:
                magnitudes = np.absolute(delta.detach().cpu().numpy().ravel())
                merged = np.concatenate([magnitudes, best_top])
                best_top = -np.sort(-merged)[:k]
            threshold = best_top.min()
        else:
            threshold = float("inf")

        basepatch = [torch.zeros_like(p, requires_grad=False) for p in self.trainable_parameters]

        for delta, patch in zip(self.grad_directions, basepatch):
            magnitude = torch.absolute(delta)
            patch[magnitude > threshold] = bias
            patch[magnitude <= threshold] = -bias

        sigmoid = torch.nn.Sigmoid()
        print ('Total parameters in my stitch: ', sum([ torch.sum(torch.round(sigmoid(p))*torch.round(sigmoid(p))) / (1. * num_params) for p in basepatch ]))

        return basepatch

    def interpolate_model(self, round_=True, mask=None, return_mask=False):
        """Add ``sigmoid(self.mask) * (finetuned - pretrained)`` to the model.

        Args:
            round_: round the sigmoid output to 0/1 before applying it, and
                print the selected count and proportion.
            mask: ignored by this class; kept for interface compatibility.
            return_mask: when true, return the applied fractions.

        Returns:
            ``(fractions, proportion)`` when ``return_mask`` is true, where
            ``proportion`` is the sum of the fractions divided by
            ``num_params``; otherwise ``None``.
        """
        sigmoid = torch.nn.Sigmoid()
        pretrained = dict(self.pretrained_model.named_parameters())
        finetuned = dict(self.finetuned_model.named_parameters())
        current = dict(self.model.named_parameters())

        n_graft_params = 0
        binary_mask = []

        with torch.no_grad():
            for counter, name in enumerate(self.trainable_name):
                pretensor = pretrained[name].to(self.device)
                finetensor = finetuned[name].to(self.device)
                param = current.get(name)
                if param is None:
                    continue
                frac = sigmoid(self.mask[counter])
                if round_:
                    frac = torch.round(frac)
                n_graft_params += torch.sum(frac)
                frac = frac.to(self.device)
                param += frac * (finetensor - pretensor)
                binary_mask.append(frac)

        if round_:
            print(n_graft_params)
            print ('Proportion in my graft: ', n_graft_params / self.num_params)

        if return_mask:
            return binary_mask, n_graft_params / self.num_params

    def eval_graft(self, dataloader, dataset_name):
        """Return ``{'top1': accuracy}`` of the current model on ``dataloader``.

        A fresh classification head is obtained for ``dataset_name`` and
        wrapped with ``self.model`` in an ``ImageClassifier``.
        """
        classification_head = get_classification_head(self.args, dataset_name)
        model = ImageClassifier(self.model, classification_head)
        model.eval()

        correct, n = 0., 0.
        with torch.no_grad():
            for data in tqdm(dataloader):
                data = maybe_dictionarize(data)
                x = data['images'].to(self.device)
                y = data['labels'].to(self.device)

                logits = utils.get_logits(x, model)
                pred = logits.argmax(dim=1, keepdim=True).to(self.device)

                correct += pred.eq(y.view_as(pred)).sum().item()
                n += y.size(0)

            top1 = correct / n

        metrics = {'top1': top1}
        print(f'Grafting on {dataset_name}. Accuracy: {100*top1:.2f}%')

        return metrics

    def _accumulate_gradient(self, dataloader, loss_fct, lr):
        """Average the lr-scaled loss gradient w.r.t. the mask fractions.

        Because each parameter is ``pretrained + frac * delta``, the
        gradient w.r.t. ``frac`` is ``param.grad * delta``.  At most
        ``gradient_accumulation_steps`` batches are used.  Gradients are
        gathered in ``trainable_name`` order so they line up with
        ``self.mask`` and ``self.grad_directions``.

        Raises:
            KeyError: if a trainable name is not a parameter of the model.
            RuntimeError: if a trainable parameter received no gradient.
        """
        named = dict(self.model.named_parameters())
        tracked = [named[name] for name in self.trainable_name]
        max_steps = self.graft_args.gradient_accumulation_steps

        total_grad, steps = [], 0
        for data in tqdm(dataloader):
            data = maybe_dictionarize(data)
            x = data['images'].to(self.device)
            y = data['labels'].to(self.device)
            features = self.model(x)
            outputs = self.classifier_head(features)
            loss = loss_fct(outputs, y)
            loss.backward()

            missing = [name for name, p in zip(self.trainable_name, tracked) if p.grad is None]
            if missing:
                self.model.zero_grad()
                raise RuntimeError(f"No gradient for trainable parameters: {missing}")

            grad = [p.grad.detach().clone() for p in tracked]
            self.model.zero_grad()
            grad = [g * delta.to(self.device) for g, delta in zip(grad, self.grad_directions)]

            if steps == 0:
                total_grad = [lr * g for g in grad]
            else:
                total_grad = [acc + lr * g for acc, g in zip(total_grad, grad)]
            steps += 1
            # Restrict the number of batches used per epoch.
            if steps >= max_steps:
                break

        return [acc / (1. * steps) for acc in total_grad]

    def train_graft(self, dataloader, dataset_name,):
        """Train the mask and evaluate the final rounded mask.

        Each epoch applies the soft mask, accumulates the gradient, resets
        the model and takes one step on ``self.mask`` that includes an
        ``l1_strength`` term.  Every fifth epoch (except epoch 0) the
        rounded mask is evaluated on ``dataloader``.

        Returns:
            ``(mask, proportion, test)``: the rounded mask tensors, the
            proportion returned by ``interpolate_model`` and the ``"top1"``
            value from ``eval_single_dataset``.
        """
        loss_fct = torch.nn.CrossEntropyLoss()
        sigmoid = torch.nn.Sigmoid()
        lr = self.graft_args.learning_rate

        for epoch in tqdm(range(self.graft_args.num_train_epochs), 'Training the mask'):
            print("Epoch: ", epoch)

            self.interpolate_model(round_=False)
            try:
                total_grad = self._accumulate_gradient(dataloader, loss_fct, lr)
            finally:
                # Always undo the soft interpolation, even on failure.
                self.reset_model()

            # Take the gradient step (chain rule through the sigmoid).
            with torch.no_grad():
                for p, g in zip(self.mask, total_grad):
                    s = sigmoid(p)
                    p -= g * s * (1 - s) + self.graft_args.l1_strength * s * (1 - s)

            # Periodic evaluation of the current rounded mask.
            if epoch % 5 == 0 and epoch != 0:
                self.interpolate_model(round_=True, return_mask=True)
                self.eval_graft(dataloader, dataset_name)
                self.reset_model()

        mask, proportion = self.interpolate_model(round_=True, return_mask=True)
        test = eval_single_dataset(self.model, dataset_name, self.args)["top1"]
        self.reset_model()

        return mask, proportion, test



class Stitcher(nn.Module):
    """Merge several finetuned models into one model using per-model masks.

    For every finetuned model and every name in ``trainable_params``,
    ``mask * (finetuned - pretrained)`` is added in place to ``model``.

    Args:
        trainable_params: mapping whose keys are the parameter names to
            merge; its iteration order indexes each model's mask list.
        model: module updated in place and returned by
            ``interpolate_models``.
        pretrained_model: frozen reference model.
        finetuned_models: sequence of finetuned models.
        masks: one list of mask tensors per finetuned model.
        use_avg_mask: when true, replace ``masks`` with the weights from
            ``get_average_masks``.
    """

    def __init__(self, trainable_params, model, pretrained_model, finetuned_models, masks, use_avg_mask=True):
        super().__init__()
        self.params = trainable_params
        self.pretrained_model = pretrained_model
        self.finetuned_models = finetuned_models
        self.model = model

        self.masks = masks
        if use_avg_mask:
            self.masks = self.get_average_masks()

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        self.pretrained_model.eval()
        self.model.eval()
        for param in self.pretrained_model.parameters():
            param.requires_grad = False
        for param in self.model.parameters():
            param.requires_grad = False
        for finetuned_model in self.finetuned_models:
            finetuned_model.eval()
            for param in finetuned_model.parameters():
                param.requires_grad = False

    def get_average_masks(self):
        """Turn the masks into averaging weights.

        For each model's mask, every element is incremented once for each
        other model whose mask is also non-zero there, then replaced by its
        reciprocal (0 stays 0).  An element selected by ``c`` models
        therefore gets weight ``1 / c`` in each of them.

        Returns:
            list (one entry per model) of lists of weight tensors.
        """

        def reciprocal_with_zero(tensor):
            # Reciprocal that maps 0 to 0 instead of inf.
            is_zero = tensor == 0
            return torch.reciprocal(tensor).masked_fill(is_zero, 0)

        output_masks = []
        for i, own in enumerate(self.masks):
            counts = list(own)
            for j, other in enumerate(self.masks):
                if i == j:
                    continue
                for k in range(len(own)):
                    counts[k] = counts[k] + torch.logical_and(own[k], other[k])
            output_masks.append([reciprocal_with_zero(c) for c in counts])

        return output_masks

    def get_sum_masks(self):
        """Return, for every model, the element-wise logical OR of all masks."""
        output_masks = []
        for i, own in enumerate(self.masks):
            merged = list(own)
            for j, other in enumerate(self.masks):
                if i == j:
                    continue
                for k in range(len(own)):
                    merged[k] = torch.logical_or(merged[k], other[k])
            output_masks.append(merged)

        return output_masks

    def interpolate_models(self):
        """Add every model's masked delta to ``self.model`` and return it.

        For each finetuned model this prints the number of non-zero mask
        entries applied, divided by the total number of elements in
        ``self.model`` (0.0 if the model has no parameters).  Names missing
        from ``self.model`` are skipped.

        Raises:
            KeyError: if a trainable name is missing from the pretrained or
                a finetuned model.
        """
        trainable_name = list(self.params)

        self.pretrained_model.to(self.device)
        self.model.to(self.device)
        pretrained = dict(self.pretrained_model.named_parameters())
        target = dict(self.model.named_parameters())
        n_total_params = sum(p.numel() for p in target.values())

        for finetuned_model, mask in zip(self.finetuned_models, self.masks):
            finetuned_model.to(self.device)
            finetuned = dict(finetuned_model.named_parameters())
            n_graft_params = 0

            with torch.no_grad():
                for counter, name in enumerate(trainable_name):
                    param = target.get(name)
                    if param is None:
                        continue
                    pretensor = pretrained[name].to(self.device)
                    finetensor = finetuned[name].to(self.device)
                    # .to() is not in-place: the moved tensor must be used.
                    weight = mask[counter].to(self.device)
                    param += weight * (finetensor - pretensor)
                    n_graft_params += int((weight != 0).sum().item())

            proportion = n_graft_params / n_total_params if n_total_params else 0.0
            print ('Proportion in my graft: ', proportion)
        return self.model
                        

class Localizer_og(nn.Module):
    """Mask localizer that trains a correction on top of a fixed base patch.

    ``self.basepatch`` is a 0/1 tensor per trainable parameter marking the
    largest-magnitude ``finetuned - pretrained`` deltas.  The fraction of
    the delta applied is
    ``basepatch + (1 - 2 * basepatch) * sigmoid(S - sigmoid_bias)``, where
    ``S`` is ``self.trainable_parameters`` (randomly initialised).

    Args:
        trainable_params: mapping from parameter name to tensor; its
            iteration order defines the order of every per-parameter list.
        model: module whose parameters are modified in place.
            NOTE(review): ``reset_model`` assumes these parameters start at
            the pretrained values -- confirm against the caller.
        pretrained_model: frozen reference model (kept on CPU).
        finetuned_model: frozen finetuned model (kept on CPU).
        dataset_name: forwarded to ``get_classification_head``.
        args: forwarded to ``get_classification_head`` and
            ``eval_single_dataset``.
        graft_args: object providing ``sparsity_level``, ``sigmoid_bias``,
            ``learning_rate``, ``num_train_epochs`` and
            ``gradient_accumulation_steps``.
    """

    def __init__(self, trainable_params, model, pretrained_model, finetuned_model, dataset_name, args, graft_args):
        super().__init__()

        self.params = trainable_params
        self.model = model
        self.pretrained_model = pretrained_model
        self.finetuned_model = finetuned_model
        self.args = args
        self.graft_args = graft_args
        self.classifier_head = get_classification_head(self.args, dataset_name)

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model.to(self.device)
        self.classifier_head.to(self.device)
        # The two reference models are only read from, so they stay on CPU.
        self.pretrained_model.to("cpu")
        self.finetuned_model.to("cpu")
        self.model.eval()
        self.finetuned_model.eval()
        self.pretrained_model.eval()
        for param in self.pretrained_model.parameters():
            param.requires_grad = False
        for param in self.finetuned_model.parameters():
            param.requires_grad = False

        self.create_binary_masks()
        self.create_basepatch()

    def create_binary_masks(self):
        """Record trainable names, the trainable scores ``S`` and the deltas.

        Sets ``trainable_name``, ``trainable_parameters`` (uniform random
        tensors on ``self.device``), ``num_params`` and ``grad_directions``
        (``finetuned - pretrained`` per name, detached).

        Raises:
            KeyError: if a trainable name is missing from the pretrained or
                finetuned model.
        """
        self.trainable_name = list(self.params)
        self.trainable_parameters = [
            torch.rand_like(self.params[name].data, device=self.device, requires_grad=False)
            for name in self.trainable_name
        ]
        self.num_params = sum(p.numel() for p in self.trainable_parameters)

        # Build the name -> tensor maps once instead of rescanning
        # named_parameters() for every trainable name.
        pretrained = dict(self.pretrained_model.named_parameters())
        finetuned = dict(self.finetuned_model.named_parameters())
        self.grad_directions = [
            (finetuned[name] - pretrained[name]).detach()
            for name in self.trainable_name
        ]

    def reset_model(self):
        """Copy the pretrained values back into the trainable parameters.

        ``copy_`` is used rather than ``p += pretrained - p`` because the
        latter is not bit-exact in floating point.  Names absent from
        ``self.model`` are skipped.
        """
        pretrained = dict(self.pretrained_model.named_parameters())
        current = dict(self.model.named_parameters())
        with torch.no_grad():
            for name in self.trainable_name:
                param = current.get(name)
                if param is None:
                    continue
                param.copy_(pretrained[name])

    def create_basepatch(self):
        """Set ``self.basepatch`` to 1 at the largest-magnitude deltas.

        With ``k = int(sparsity_level * num_params)``, elements whose
        ``|finetuned - pretrained|`` is strictly above the k-th largest
        magnitude are set to 1 and the rest stay 0.  When ``k`` is 0 the
        patch is all zeros (the original code raised ``IndexError`` here).
        """
        num_params = self.num_params
        k = int(self.graft_args.sparsity_level * num_params)

        if k > 0:
            # Streaming top-k: only k values are kept between tensors.
            best_top = np.zeros(k)
            for delta in self.grad_directions:
                magnitudes = np.absolute(delta.detach().cpu().numpy().ravel())
                merged = np.concatenate([magnitudes, best_top])
                best_top = -np.sort(-merged)[:k]
            threshold = best_top.min()
        else:
            threshold = float("inf")

        basepatch = [torch.zeros_like(p, requires_grad=False) for p in self.trainable_parameters]

        for delta, patch in zip(self.grad_directions, basepatch):
            patch[torch.absolute(delta) > threshold] = 1.

        print ('Total parameters in my stitch: ', sum([ torch.sum(p*p) / (1. * num_params) for p in basepatch ]))

        self.basepatch = basepatch

    def interpolate_model(self, round_=True, mask=None, return_mask=False):
        """Add ``frac * (finetuned - pretrained)`` to the trainable parameters.

        ``frac = basepatch + (1 - 2 * basepatch) * m`` where ``m`` is
        ``mask[i]`` when ``mask`` is given, otherwise
        ``sigmoid(trainable_parameters[i] - sigmoid_bias)``.

        Args:
            round_: round ``frac`` to 0/1 before applying it, and print the
                selected count and proportion.
            mask: optional list of tensors replacing the sigmoid term.
            return_mask: when true, return the applied fractions.

        Returns:
            ``(fractions, proportion)`` when ``return_mask`` is true, where
            ``proportion`` is the sum of the fractions divided by
            ``num_params``; otherwise ``None``.
        """
        sigmoid = torch.nn.Sigmoid()
        pretrained = dict(self.pretrained_model.named_parameters())
        finetuned = dict(self.finetuned_model.named_parameters())
        current = dict(self.model.named_parameters())

        n_graft_params = 0
        binary_mask = []

        with torch.no_grad():
            for counter, name in enumerate(self.trainable_name):
                pretensor = pretrained[name].to(self.device)
                finetensor = finetuned[name].to(self.device)
                # The parameter itself is updated in place.  The original
                # rebound it to p.to(device), which would discard the
                # update whenever that call made a copy.
                param = current.get(name)
                if param is None:
                    continue
                base = self.basepatch[counter]
                if mask is not None:
                    gate = mask[counter]
                else:
                    gate = sigmoid(self.trainable_parameters[counter] - self.graft_args.sigmoid_bias)
                frac = base + (1. - 2. * base) * gate
                if round_:
                    frac = torch.round(frac)
                n_graft_params += torch.sum(frac)
                frac = frac.to(self.device)
                param += frac * (finetensor - pretensor)
                binary_mask.append(frac)

        if round_:
            print(n_graft_params)
            print ('Proportion in my graft: ', n_graft_params / self.num_params)

        if return_mask:
            return binary_mask, n_graft_params / self.num_params

    def eval_graft(self, dataloader, dataset_name):
        """Return ``{'top1': accuracy}`` of the current model on ``dataloader``.

        A fresh classification head is obtained for ``dataset_name`` and
        wrapped with ``self.model`` in an ``ImageClassifier``.
        """
        classification_head = get_classification_head(self.args, dataset_name)
        model = ImageClassifier(self.model, classification_head)
        model.eval()

        correct, n = 0., 0.
        with torch.no_grad():
            for data in tqdm(dataloader):
                data = maybe_dictionarize(data)
                x = data['images'].to(self.device)
                y = data['labels'].to(self.device)

                logits = utils.get_logits(x, model)
                pred = logits.argmax(dim=1, keepdim=True).to(self.device)

                correct += pred.eq(y.view_as(pred)).sum().item()
                n += y.size(0)

            top1 = correct / n

        metrics = {'top1': top1}
        print(f'Grafting on {dataset_name}. Accuracy: {100*top1:.2f}%')

        return metrics

    def _accumulate_gradient(self, dataloader, loss_fct, lr):
        """Average the lr-scaled loss gradient w.r.t. the applied fractions.

        Because each parameter is ``pretrained + frac * delta``, the
        gradient w.r.t. ``frac`` is ``param.grad * delta``.  At most
        ``gradient_accumulation_steps`` batches are used.  Gradients are
        gathered in ``trainable_name`` order so they line up with
        ``self.trainable_parameters`` and ``self.grad_directions``.

        Raises:
            KeyError: if a trainable name is not a parameter of the model.
            RuntimeError: if a trainable parameter received no gradient.
        """
        named = dict(self.model.named_parameters())
        tracked = [named[name] for name in self.trainable_name]
        max_steps = self.graft_args.gradient_accumulation_steps

        total_grad, steps = [], 0
        for data in tqdm(dataloader):
            data = maybe_dictionarize(data)
            x = data['images'].to(self.device)
            y = data['labels'].to(self.device)
            features = self.model(x)
            outputs = self.classifier_head(features)
            loss = loss_fct(outputs, y)
            loss.backward()

            missing = [name for name, p in zip(self.trainable_name, tracked) if p.grad is None]
            if missing:
                self.model.zero_grad()
                raise RuntimeError(f"No gradient for trainable parameters: {missing}")

            grad = [p.grad.detach().clone() for p in tracked]
            self.model.zero_grad()
            grad = [g * delta.to(self.device) for g, delta in zip(grad, self.grad_directions)]

            if steps == 0:
                total_grad = [lr * g for g in grad]
            else:
                total_grad = [acc + lr * g for acc, g in zip(total_grad, grad)]
            steps += 1
            # Restrict the number of batches used per epoch.
            if steps >= max_steps:
                break

        return [acc / (1. * steps) for acc in total_grad]

    def train_graft(self, dataloader, dataset_name,):
        """Train the scores ``S`` and evaluate the final rounded mask.

        With ``num_train_epochs == 0`` the loop body never runs and the
        initial mask is evaluated.  That case now returns the same 3-tuple
        as every other case; the original returned only
        ``(mask, proportion)`` and dropped the evaluation result.

        Returns:
            ``(mask, proportion, test)``: the rounded mask tensors, the
            proportion returned by ``interpolate_model`` and the ``"top1"``
            value from ``eval_single_dataset``.
        """
        loss_fct = torch.nn.CrossEntropyLoss()
        sigmoid = torch.nn.Sigmoid()
        lr = self.graft_args.learning_rate
        sigmoid_bias = self.graft_args.sigmoid_bias

        for epoch in tqdm(range(self.graft_args.num_train_epochs), 'Training the mask'):
            print("Epoch: ", epoch)

            self.interpolate_model(round_=False)
            try:
                total_grad = self._accumulate_gradient(dataloader, loss_fct, lr)
            finally:
                # Always undo the soft interpolation, even on failure.
                self.reset_model()

            # Gradient step on S; (1 - 2 * s) is the derivative of the
            # applied fraction w.r.t. the sigmoid term.
            with torch.no_grad():
                for p, g, s in zip(self.trainable_parameters, total_grad, self.basepatch):
                    gate = sigmoid(p - sigmoid_bias)
                    p -= ( (1. - 2.*s) * g * gate * (1. - gate) )

            # Periodic evaluation of the current rounded mask.
            if epoch % 5 == 0 and epoch != 0:
                self.interpolate_model(round_=True, return_mask=True)
                self.eval_graft(dataloader, dataset_name)
                self.reset_model()

        mask, proportion = self.interpolate_model(round_=True, return_mask=True)
        test = eval_single_dataset(self.model, dataset_name, self.args)["top1"]
        self.reset_model()

        return mask, proportion, test


class Localizer_disjoint(nn.Module):
    """Mask localizer that avoids elements selected by previous masks.

    Works like a sigmoid-mask localizer (``sigmoid(self.mask)`` is the
    fraction of ``finetuned - pretrained`` added to the model), but
    elements set in any of ``prev_masks`` are excluded from the initial
    selection and are never changed by the mask update in ``train_graft``.

    Args:
        trainable_params: mapping from parameter name to tensor; its
            iteration order defines the order of every per-parameter list.
        prev_masks: list of earlier masks, each a list of tensors in the
            same order as ``trainable_params``.  May be empty, in which
            case nothing is excluded.
        model: module whose parameters are modified in place.
            NOTE(review): ``reset_model`` assumes these parameters start at
            the pretrained values -- confirm against the caller.
        pretrained_model: frozen reference model (kept on CPU).
        finetuned_model: frozen finetuned model (kept on CPU).
        dataset_name: forwarded to ``get_classification_head``.
        args: forwarded to ``get_classification_head`` and
            ``eval_single_dataset``.
        graft_args: object providing ``sparsity_level``, ``sigmoid_bias``,
            ``learning_rate``, ``num_train_epochs``,
            ``gradient_accumulation_steps`` and ``l1_strength``.
    """

    def __init__(self, trainable_params, prev_masks, model, pretrained_model, finetuned_model, dataset_name, args, graft_args,):
        super().__init__()

        self.params = trainable_params
        self.model = model
        self.pretrained_model = pretrained_model
        self.finetuned_model = finetuned_model
        self.args = args
        self.graft_args = graft_args
        self.classifier_head = get_classification_head(self.args, dataset_name)

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model.to(self.device)
        self.classifier_head.to(self.device)
        # The two reference models are only read from, so they stay on CPU.
        self.pretrained_model.to("cpu")
        self.finetuned_model.to("cpu")
        self.model.eval()
        self.finetuned_model.eval()
        self.pretrained_model.eval()
        for param in self.pretrained_model.parameters():
            param.requires_grad = False
        for param in self.finetuned_model.parameters():
            param.requires_grad = False

        self.union_mask = self.get_union_mask(prev_masks)

        self.create_binary_masks()
        if not self.union_mask:
            # No previous masks: nothing is excluded.
            self.union_mask = [torch.zeros_like(d, dtype=torch.bool) for d in self.grad_directions]
        self.mask = self.create_basepatch()

    def get_union_mask(self, masks):
        """Return the element-wise logical OR of ``masks`` as bool tensors.

        Args:
            masks: list of masks, each a list of tensors; non-zero means
                selected.

        Returns:
            list of bool tensors, one per tensor position; empty when
            ``masks`` is empty.
        """
        if not masks:
            return []

        union_mask = []
        for i in range(len(masks[0])):
            merged = masks[0][i].bool()
            for other in masks[1:]:
                merged = torch.logical_or(merged, other[i])
            union_mask.append(merged)

        return union_mask

    def create_binary_masks(self):
        """Record trainable names, their element count and the deltas.

        Sets ``trainable_name``, ``trainable_parameters`` (random tensors
        shaped like each trainable parameter, on ``self.device``),
        ``num_params`` and ``grad_directions`` (``finetuned - pretrained``
        per name, detached).

        Raises:
            KeyError: if a trainable name is missing from the pretrained or
                finetuned model.
        """
        self.trainable_name = list(self.params)
        self.trainable_parameters = [
            torch.rand_like(self.params[name].data, device=self.device, requires_grad=False)
            for name in self.trainable_name
        ]
        self.num_params = sum(p.numel() for p in self.trainable_parameters)

        # Build the name -> tensor maps once instead of rescanning
        # named_parameters() for every trainable name.
        pretrained = dict(self.pretrained_model.named_parameters())
        finetuned = dict(self.finetuned_model.named_parameters())
        self.grad_directions = [
            (finetuned[name] - pretrained[name]).detach()
            for name in self.trainable_name
        ]

    def reset_model(self):
        """Copy the pretrained values back into the trainable parameters.

        ``copy_`` is used rather than ``p += pretrained - p`` because the
        latter is not bit-exact in floating point.  Names absent from
        ``self.model`` are skipped.
        """
        pretrained = dict(self.pretrained_model.named_parameters())
        current = dict(self.model.named_parameters())
        with torch.no_grad():
            for name in self.trainable_name:
                param = current.get(name)
                if param is None:
                    continue
                param.copy_(pretrained[name])

    def _free_positions(self, claimed, like):
        """Return a 0/1 tensor (dtype/device of ``like``) that is 1 where ``claimed`` is unset."""
        # ``1 - m`` is rejected by PyTorch for bool tensors, so invert
        # logically and then cast.
        return torch.logical_not(claimed.bool()).to(device=like.device, dtype=like.dtype)

    def create_basepatch(self):
        """Build the initial mask from the largest unclaimed deltas.

        Magnitudes ``|finetuned - pretrained|`` are zeroed wherever
        ``self.union_mask`` is set.  With ``k = int(sparsity_level * n)``
        over the ``n`` elements, the threshold is the k-th largest masked
        magnitude.  Elements whose masked magnitude is strictly above it
        are set to ``+sigmoid_bias`` and all others, including every
        claimed element, to ``-sigmoid_bias``.  The original compared the
        unmasked magnitudes, so claimed elements could still be selected.

        Returns:
            list of tensors shaped like the trainable parameters.
        """
        num_params = self.num_params
        bias = self.graft_args.sigmoid_bias

        masked_magnitudes = [
            torch.abs(delta * self._free_positions(claimed, delta))
            for delta, claimed in zip(self.grad_directions, self.union_mask)
        ]

        total = sum(m.numel() for m in masked_magnitudes)
        k = min(int(self.graft_args.sparsity_level * total), total)
        if k > 0:
            flat = torch.cat([m.reshape(-1) for m in masked_magnitudes])
            values, _ = torch.topk(flat, k)
            threshold = values.min()
        else:
            threshold = float("inf")

        basepatch = [torch.zeros_like(p, requires_grad=False) for p in self.trainable_parameters]

        for magnitude, patch in zip(masked_magnitudes, basepatch):
            patch[magnitude > threshold] = bias
            patch[magnitude <= threshold] = -bias

        sigmoid = torch.nn.Sigmoid()
        print ('Total parameters in my stitch: ', sum([ torch.sum(torch.round(sigmoid(p))*torch.round(sigmoid(p))) / (1. * num_params) for p in basepatch ]))

        return basepatch

    def interpolate_model(self, round_=True, mask=None, return_mask=False):
        """Add ``sigmoid(self.mask) * (finetuned - pretrained)`` to the model.

        Args:
            round_: round the sigmoid output to 0/1 before applying it, and
                print the selected count and proportion.
            mask: ignored by this class; kept for interface compatibility.
            return_mask: when true, return the applied fractions.

        Returns:
            ``(fractions, proportion)`` when ``return_mask`` is true, where
            ``proportion`` is the sum of the fractions divided by
            ``num_params``; otherwise ``None``.
        """
        sigmoid = torch.nn.Sigmoid()
        pretrained = dict(self.pretrained_model.named_parameters())
        finetuned = dict(self.finetuned_model.named_parameters())
        current = dict(self.model.named_parameters())

        n_graft_params = 0
        binary_mask = []

        with torch.no_grad():
            for counter, name in enumerate(self.trainable_name):
                pretensor = pretrained[name].to(self.device)
                finetensor = finetuned[name].to(self.device)
                param = current.get(name)
                if param is None:
                    continue
                frac = sigmoid(self.mask[counter])
                if round_:
                    frac = torch.round(frac)
                n_graft_params += torch.sum(frac)
                frac = frac.to(self.device)
                param += frac * (finetensor - pretensor)
                binary_mask.append(frac)

        if round_:
            print(n_graft_params)
            print ('Proportion in my graft: ', n_graft_params / self.num_params)

        if return_mask:
            return binary_mask, n_graft_params / self.num_params

    def eval_graft(self, dataloader, dataset_name):
        """Return ``{'top1': accuracy}`` of the current model on ``dataloader``.

        A fresh classification head is obtained for ``dataset_name`` and
        wrapped with ``self.model`` in an ``ImageClassifier``.
        """
        classification_head = get_classification_head(self.args, dataset_name)
        model = ImageClassifier(self.model, classification_head)
        model.eval()

        correct, n = 0., 0.
        with torch.no_grad():
            for data in tqdm(dataloader):
                data = maybe_dictionarize(data)
                x = data['images'].to(self.device)
                y = data['labels'].to(self.device)

                logits = utils.get_logits(x, model)
                pred = logits.argmax(dim=1, keepdim=True).to(self.device)

                correct += pred.eq(y.view_as(pred)).sum().item()
                n += y.size(0)

            top1 = correct / n

        metrics = {'top1': top1}
        print(f'Grafting on {dataset_name}. Accuracy: {100*top1:.2f}%')

        return metrics

    def _accumulate_gradient(self, dataloader, loss_fct, lr):
        """Average the lr-scaled loss gradient w.r.t. the mask fractions.

        Because each parameter is ``pretrained + frac * delta``, the
        gradient w.r.t. ``frac`` is ``param.grad * delta``.  At most
        ``gradient_accumulation_steps`` batches are used.  Gradients are
        gathered in ``trainable_name`` order so they line up with
        ``self.mask`` and ``self.grad_directions``.

        Raises:
            KeyError: if a trainable name is not a parameter of the model.
            RuntimeError: if a trainable parameter received no gradient.
        """
        named = dict(self.model.named_parameters())
        tracked = [named[name] for name in self.trainable_name]
        max_steps = self.graft_args.gradient_accumulation_steps

        total_grad, steps = [], 0
        for data in tqdm(dataloader):
            data = maybe_dictionarize(data)
            x = data['images'].to(self.device)
            y = data['labels'].to(self.device)
            features = self.model(x)
            outputs = self.classifier_head(features)
            loss = loss_fct(outputs, y)
            loss.backward()

            missing = [name for name, p in zip(self.trainable_name, tracked) if p.grad is None]
            if missing:
                self.model.zero_grad()
                raise RuntimeError(f"No gradient for trainable parameters: {missing}")

            grad = [p.grad.detach().clone() for p in tracked]
            self.model.zero_grad()
            grad = [g * delta.to(self.device) for g, delta in zip(grad, self.grad_directions)]

            if steps == 0:
                total_grad = [lr * g for g in grad]
            else:
                total_grad = [acc + lr * g for acc, g in zip(total_grad, grad)]
            steps += 1
            # Restrict the number of batches used per epoch.
            if steps >= max_steps:
                break

        return [acc / (1. * steps) for acc in total_grad]

    def train_graft(self, dataloader, dataset_name,):
        """Train the mask on unclaimed elements and evaluate the result.

        Each epoch applies the soft mask, accumulates the gradient, resets
        the model and takes one step on ``self.mask`` that includes an
        ``l1_strength`` term.  The step is multiplied by 0 wherever
        ``self.union_mask`` is set, so those entries keep their initial
        value.  Every fifth epoch (except epoch 0) the rounded mask is
        evaluated on ``dataloader``.

        Returns:
            ``(mask, proportion, test)``: the rounded mask tensors, the
            proportion returned by ``interpolate_model`` and the ``"top1"``
            value from ``eval_single_dataset``.
        """
        loss_fct = torch.nn.CrossEntropyLoss()
        sigmoid = torch.nn.Sigmoid()
        lr = self.graft_args.learning_rate

        for epoch in tqdm(range(self.graft_args.num_train_epochs), 'Training the mask'):
            print("Epoch: ", epoch)

            self.interpolate_model(round_=False)
            try:
                total_grad = self._accumulate_gradient(dataloader, loss_fct, lr)
            finally:
                # Always undo the soft interpolation, even on failure.
                self.reset_model()

            # Gradient step.  The original unpacked this 3-way zip into two
            # names and subtracted the whole union-mask list from 1, so it
            # could never run.
            with torch.no_grad():
                for p, g, claimed in zip(self.mask, total_grad, self.union_mask):
                    s = sigmoid(p)
                    free = self._free_positions(claimed, p)
                    p -= (g * s * (1 - s) + self.graft_args.l1_strength * s * (1 - s)) * free

            # Periodic evaluation of the current rounded mask.
            if epoch % 5 == 0 and epoch != 0:
                self.interpolate_model(round_=True, return_mask=True)
                self.eval_graft(dataloader, dataset_name)
                self.reset_model()

        mask, proportion = self.interpolate_model(round_=True, return_mask=True)
        test = eval_single_dataset(self.model, dataset_name, self.args)["top1"]
        self.reset_model()

        return mask, proportion, test