import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm


class Localizer(nn.Module):
    """Learn a sparse mask over the task vector ``finetuned - pretrained``.

    For every trainable parameter a real-valued score tensor is kept in
    ``self.mask``.  ``sigmoid(score)`` is the fraction of the task vector that
    is added to the working model at that entry; rounding that fraction gives
    a 0/1 mask.  Scores are updated by hand in ``train_graft`` (no optimizer).
    """

    def __init__(self, trainable_params, model, pretrained_model, finetuned_model, graft_args, run_name):
        """Store the models, freeze what must not train and build the initial mask.

        Args:
            trainable_params: Mapping from parameter name to tensor.  It is
                iterated for names, indexed by name, and ``.data`` is read
                from each value.
            model: Working model.  Moved to ``self.device`` and put in train
                mode.  NOTE(review): the grafting arithmetic assumes its
                trainable weights start equal to the pretrained ones - confirm
                against the caller.
            pretrained_model: Reference model; kept on CPU, eval mode, no grads.
            finetuned_model: Fine-tuned model; kept on CPU, eval mode, no grads.
            graft_args: Object providing ``sparsity``, ``learning_rate``,
                ``num_train_epochs``, ``l1_strength`` and
                ``gradient_accumulation_steps``.
            run_name: String embedded in the file name of saved masks.
        """
        super(Localizer, self).__init__()

        self.params = trainable_params
        self.model = model
        self.pretrained_model = pretrained_model
        self.finetuned_model = finetuned_model
        self.graft_args = graft_args
        self.run_name = run_name

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model.to(self.device)
        self.pretrained_model.to("cpu")
        self.finetuned_model.to("cpu")
        self.model.train()
        self.finetuned_model.eval()
        self.pretrained_model.eval()

        # Parameters of the working model that never receive gradients.
        frozen = {'model.embed_tokens.weight', 'model.norm.weight', 'lm_head.weight'}
        for n, p in self.model.named_parameters():
            if n in frozen:
                p.requires_grad = False
        for param in self.pretrained_model.parameters():
            param.requires_grad = False
        for param in self.finetuned_model.parameters():
            param.requires_grad = False

        self.create_binary_masks()
        self.mask = self.create_basepatch()

    @staticmethod
    def _param_lookup(module):
        """Return a ``{name: parameter}`` dict for ``module``."""
        return dict(module.named_parameters())

    def create_binary_masks(self):
        """Record trainable names/shapes and the per-parameter task vectors.

        Sets ``self.trainable_name``, ``self.trainable_parameters``,
        ``self.num_params`` and ``self.grad_directions`` (all lists follow the
        iteration order of ``self.params``).

        Raises:
            KeyError: If a trainable name is missing from the pretrained or
                fine-tuned model.
        """
        self.trainable_name = list(self.params)

        # CPU tensors with the shape of each trainable parameter.  Within this
        # class only their shape / element count is used.  ``rand_like`` is
        # kept so the global RNG stream is consumed exactly as before.
        self.trainable_parameters = [
            torch.rand_like(self.params[name].data, device="cpu", requires_grad=False)
            for name in self.trainable_name
        ]
        self.num_params = sum(p.numel() for p in self.trainable_parameters)

        pretrained = self._param_lookup(self.pretrained_model)
        finetuned = self._param_lookup(self.finetuned_model)
        # Task vector per parameter: finetuned - pretrained.
        self.grad_directions = [
            (finetuned[name] - pretrained[name]).detach()
            for name in self.trainable_name
        ]

    def reset_model(self):
        """Overwrite the working model's trainable parameters with the pretrained values.

        ``copy_`` is used so the restore is exact and handles the CPU->device
        transfer; ``p += (pre - p)`` is not exact in floating point.

        Raises:
            KeyError: If a trainable name is missing from either model.
        """
        pretrained = self._param_lookup(self.pretrained_model)
        current = self._param_lookup(self.model)
        with torch.no_grad():
            for name in self.trainable_name:
                current[name].copy_(pretrained[name])

    def create_basepatch(self):
        """Initialise the mask scores from the largest task-vector magnitudes.

        ``k = int(sparsity * total_elements)`` magnitudes are taken with
        ``torch.topk``; the smallest of them is the threshold.  Entries whose
        magnitude is strictly greater than the threshold get score ``3.``
        (sigmoid ~ 0.95), all others ``-3.`` (sigmoid ~ 0.05).  Because the
        comparison is strict, the entry sitting exactly at the threshold is
        not selected.  When ``k`` is 0 nothing is selected.

        Returns:
            list[torch.Tensor]: bfloat16 CPU score tensors, one per trainable
            parameter.  The same list is stored as ``self.basepatch``.
        """
        print("Creating initialization with largest movement")

        num_params = self.num_params
        sparsity_level = self.graft_args.sparsity
        consider = self.grad_directions

        if consider:
            abs_tv = torch.cat([torch.abs(p).reshape(-1) for p in consider])
        else:
            abs_tv = torch.empty(0)

        # Clamp so that sparsity > 1 cannot ask topk for more elements than exist.
        k = min(int(sparsity_level * abs_tv.numel()), abs_tv.numel())
        if k > 0:
            values, _ = torch.topk(abs_tv, k)
            threshold = values.min()
        else:
            # Nothing to select: every finite magnitude is <= inf.
            threshold = float("inf")

        basepatch = [
            torch.zeros_like(p, requires_grad=True, dtype=torch.bfloat16).to("cpu")
            for p in self.trainable_parameters
        ]

        with torch.no_grad():
            for p, q in zip(consider, basepatch):
                magnitude = torch.absolute(p)
                q[magnitude > threshold] = 3.
                q[magnitude <= threshold] = -3.

        sigmoid = torch.nn.Sigmoid()
        # The rounded sigmoid is 0 or 1, so its sum is the number of selected entries.
        print('Total parameters in my stitch: ', sum([torch.sum(torch.round(sigmoid(p))) / (1. * num_params) for p in basepatch]))

        self.basepatch = basepatch

        return basepatch

    def interpolate_model(self, round_=True, mask=None, return_mask=False):
        """Add ``sigmoid(score) * (finetuned - pretrained)`` to the working model in place.

        The model is moved to CPU for the update and back to ``self.device``
        afterwards.  The update is additive, so the model must hold the
        pretrained weights beforehand (see ``reset_model``).

        Args:
            round_: If true, round the sigmoid to 0/1 before grafting and
                print the graft statistics.
            mask: Accepted for interface compatibility; currently unused
                (``self.mask`` is always used).
            return_mask: If true, return the applied fractions.

        Returns:
            ``(fractions, proportion)`` when ``return_mask`` is true, where
            ``fractions`` is a list of CPU tensors (one per trainable
            parameter) and ``proportion`` is a 0-dim float64 tensor equal to
            ``sum(fractions) / self.num_params``.  ``None`` otherwise.

        Raises:
            KeyError: If a trainable name is missing from any of the models.
        """
        sigmoid = torch.nn.Sigmoid()

        # Accumulate in float64: summing in the mask dtype (bfloat16) loses
        # precision as soon as counts exceed a few hundred.
        n_graft_params = torch.zeros((), dtype=torch.float64)
        binary_mask = []

        self.model.to("cpu")

        pretrained = self._param_lookup(self.pretrained_model)
        finetuned = self._param_lookup(self.finetuned_model)
        current = self._param_lookup(self.model)

        with torch.no_grad():
            for counter, name in enumerate(self.trainable_name):
                frac = sigmoid(self.mask[counter])
                if round_:
                    frac = torch.round(frac)
                frac = frac.to("cpu")
                n_graft_params += torch.sum(frac, dtype=torch.float64)
                current[name].add_(frac * (finetuned[name] - pretrained[name]))
                binary_mask.append(frac)

        self.model.to(self.device)

        if round_:
            print(n_graft_params)
            print('Proportion in my graft: ', n_graft_params / self.num_params)

        if return_mask:
            return binary_mask, n_graft_params / self.num_params

    def train_graft(self, dataloader, mask_dir="/data/common/mergekit/masks"):
        """Train the mask scores and return the final rounded mask.

        Each epoch grafts the soft mask into the model, runs at most
        ``graft_args.gradient_accumulation_steps`` batches, and after every
        batch updates the scores with
        ``score -= (lr * grad * direction + l1_strength) * sigmoid'(score)``.
        The model weights are not re-interpolated between batches of one epoch.
        At the end of the epoch the rounded mask is computed and the model is
        reset to the pretrained weights so the next epoch starts clean.

        Args:
            dataloader: Iterable of batches indexable by ``'input_ids'``,
                ``'attention_mask'`` and ``'labels'``.  The first element of
                the model output is used as the loss.
            mask_dir: Directory where the mask is saved every 10th epoch.  It
                is created if missing.

        Returns:
            ``(mask, proportion)`` as returned by
            ``interpolate_model(round_=True, return_mask=True)``.  With zero
            epochs this is the rounded initial mask.

        Raises:
            AttributeError: If a trainable parameter has no gradient after
                ``backward`` (its name is printed first).
        """
        import os

        sigmoid = torch.nn.Sigmoid()
        device = self.device
        lr = self.graft_args.learning_rate

        mask, proportion = None, None

        for epoch in tqdm(range(self.graft_args.num_train_epochs), 'Training the mask'):
            print("Epoch: ", epoch)
            steps_done = 0
            self.interpolate_model(round_=False)
            live = self._param_lookup(self.model)

            for batch in dataloader:
                loss = self.model(
                    input_ids=batch['input_ids'].to(device),
                    attention_mask=batch['attention_mask'].to(device),
                    labels=batch["labels"].to(device),
                )[0]

                print(loss)
                loss.backward()

                # Report trainable parameters that did not receive a gradient.
                for name in self.trainable_name:
                    if live[name].grad is None:
                        print(name)

                steps_done += 1

                # Collected in trainable_name order so they line up with
                # self.grad_directions and self.mask.
                grads = [live[name].grad.detach().clone() for name in self.trainable_name]
                self.model.zero_grad()
                # Chain rule through p = pre + sigmoid(score) * direction.
                scaled = [lr * (g * d.to(device)) for g, d in zip(grads, self.grad_directions)]
                del grads
                with torch.no_grad():
                    for score, g in zip(self.mask, scaled):
                        s = sigmoid(score.to(device))
                        score -= (g * s * (1 - s) + self.graft_args.l1_strength * s * (1 - s)).cpu()
                del scaled

                # Restrict the number of batches per epoch.
                if steps_done >= self.graft_args.gradient_accumulation_steps:
                    break

            self.reset_model()
            mask, proportion = self.interpolate_model(round_=True, return_mask=True)
            # Remove the rounded graft again; otherwise the next epoch's soft
            # graft would be added on top of it.
            self.reset_model()

            if (epoch + 1) % 10 == 0:
                os.makedirs(mask_dir, exist_ok=True)
                torch.save(mask, os.path.join(mask_dir, f"mask_{self.run_name}_epoch_{epoch}_{self.graft_args.sparsity}_{self.graft_args.learning_rate}.pt"))

        if mask is None:
            # No epoch ran: report the rounded initial mask instead of failing.
            mask, proportion = self.interpolate_model(round_=True, return_mask=True)
            self.reset_model()

        return mask, proportion



class Stitcher(nn.Module):
    """Merge several fine-tuned models into one by grafting masked task vectors.

    For every fine-tuned model ``i`` and trainable parameter ``k`` the value
    ``masks[i][k] * (finetuned_i - pretrained)`` is added to the target model.
    """

    def __init__(self, trainable_params, model, pretrained_model, finetuned_models, masks, use_avg_mask=True):
        """Store the models and masks and disable gradients everywhere.

        Args:
            trainable_params: Iterable of parameter names (iterating a mapping
                yields its keys) to be grafted.
            model: Target model that receives the grafts.
                NOTE(review): assumed to start from the pretrained weights -
                confirm against the caller.
            pretrained_model: Reference model used to form task vectors.
            finetuned_models: Sequence of fine-tuned models, aligned with ``masks``.
            masks: One list of tensors per fine-tuned model, each list aligned
                with the order of ``trainable_params``.
            use_avg_mask: If true, replace the masks by the overlap-averaged
                weights from ``get_average_masks``.
        """
        super(Stitcher, self).__init__()
        self.params = trainable_params
        self.pretrained_model = pretrained_model
        self.finetuned_models = finetuned_models
        self.model = model

        self.masks = masks
        if use_avg_mask:
            self.masks = self.get_average_masks()

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        self.pretrained_model.eval()
        self.model.eval()
        for param in self.pretrained_model.parameters():
            param.requires_grad = False
        for param in self.model.parameters():
            param.requires_grad = False
        for finetuned_model in self.finetuned_models:
            finetuned_model.eval()
            for param in finetuned_model.parameters():
                param.requires_grad = False

    def get_average_masks(self):
        """Turn the masks into averaging weights over overlapping entries.

        For model ``i`` the denominator at each entry is
        ``mask_i + sum_{j != i} (mask_i AND mask_j)``; the result is its
        reciprocal, with 0 wherever the denominator is 0.  For 0/1 masks this
        is ``1 / (number of masks selecting the entry)`` where ``mask_i``
        selects it and 0 elsewhere.

        Returns:
            list[list[torch.Tensor]]: New lists; the input masks are not modified.
        """

        def reciprocal_with_zero(tensor):
            # 1 / tensor, with the entries that were 0 mapped to 0 instead of inf.
            zero_entries = tensor == 0
            return torch.reciprocal(tensor).masked_fill(zero_entries, 0)

        output_masks = []
        for i, own in enumerate(self.masks):
            counts = list(own)  # shallow copy; entries are rebound, never edited in place
            for j, other in enumerate(self.masks):
                if i == j:
                    continue
                for k in range(len(own)):
                    counts[k] = counts[k] + torch.logical_and(own[k], other[k])

            output_masks.append([reciprocal_with_zero(c) for c in counts])

        return output_masks

    def get_sum_masks(self):
        """Return, for every model, the element-wise logical OR of all masks.

        Returns:
            list[list[torch.Tensor]]: New lists; the input masks are not
            modified.  With a single mask the entries are returned unchanged.
        """
        output_masks = []
        for i, own in enumerate(self.masks):
            union = list(own)
            for j, other in enumerate(self.masks):
                if i == j:
                    continue
                for k in range(len(own)):
                    union[k] = torch.logical_or(union[k], other[k])
            output_masks.append(union)

        return output_masks

    def interpolate_models(self):
        """Graft every fine-tuned model's masked task vector into ``self.model``.

        All models are moved to ``self.device``.  For each fine-tuned model
        the share of trainable entries with a non-zero mask weight is printed.
        The mask lists themselves are left untouched.

        Returns:
            The updated ``self.model``.

        Raises:
            KeyError: If a trainable name is missing from any of the models.
        """
        trainable_name = list(self.params)

        self.pretrained_model.to(self.device)
        self.model.to(self.device)

        pretrained = dict(self.pretrained_model.named_parameters())
        target = dict(self.model.named_parameters())
        n_total_params = sum(target[name].numel() for name in trainable_name)

        for finetuned_model, mask in zip(self.finetuned_models, self.masks):
            finetuned_model.to(self.device)
            finetuned = dict(finetuned_model.named_parameters())

            n_graft_params = 0
            with torch.no_grad():
                for counter, name in enumerate(trainable_name):
                    weight = mask[counter].to(self.device)
                    target[name].add_(weight * (finetuned[name] - pretrained[name]))
                    n_graft_params += int(torch.count_nonzero(weight))

            # Guard against an empty set of trainable parameters.
            proportion = n_graft_params / n_total_params if n_total_params else 0.0
            print('Proportion in my graft: ', proportion)
        return self.model