import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import pickle

from typing import Optional, List 

class RankAllocator(object):
    """
    Rank (budget) allocator for an AdaLoRA-style model, meant to be called every training step.
    Paper: https://openreview.net/pdf?id=lq62uWRJjiY

    Parameters are located purely by name: names containing ``In_Qs`` / ``Out_Qs``
    are treated as the incremental matrices, and names containing ``shared_coeff``
    are the coefficient vectors whose entries are zeroed when they are pruned.

    Args:
        model: the model that we apply AdaLoRA to.
        lora_r (`int`): The initial rank for each incremental matrix.
        target_rank (`int`): The target average rank of incremental matrix.
        init_warmup (`int`): The steps of initial fine-tuning warmup.
        final_warmup (`int`): The steps of final fine-tuning.
        mask_interval (`int`): The time interval between two budget allocations.
        beta1 (`float`): The hyperparameter of EMA for sensitivity smoothing.
        beta2 (`float`): The hyperparameter of EMA for uncertainty quantification.
        total_step (`int`): The total training steps, correctly configured before training.
        target_total_rank (`Optional[int]`): The specified final total rank.
            When None it is derived as `target_rank` times the number of adapted matrices.
        tb_writter (`SummaryWriter`): Tensorboard SummaryWriter.
        tb_writter_loginterval (`int`): The logging interval of SummaryWriter.
        safe_H_num (`int`): Number of highest-scoring entries of each matrix that are
            never masked, whatever the global threshold is.
        save_dir (`Optional[str]`): Directory receiving the pickled importance scores
            (`step_<global_step>.pkl`). It is created on demand; None disables the dump.
    """
    def __init__(
        self, model,
        lora_r:int,
        target_rank:int,
        init_warmup:int,
        final_warmup:int,
        mask_interval:int,
        beta1:float,
        beta2:float,
        total_step:Optional[int]=None,
        target_total_rank:Optional[int]=None,
        tb_writter=None,
        tb_writter_loginterval:int=500,
        safe_H_num:int=5,
        save_dir:Optional[str]="/mnt/data1/zkyjss/DyN/checkpoint/ipt/OPT_125_40_10",
    ):
        self.ave_target_rank = target_rank
        self.target_rank = target_total_rank
        self.lora_init_rank = lora_r
        self.initial_warmup = init_warmup
        self.final_warmup = final_warmup
        self.mask_interval = mask_interval
        self.save_dir = save_dir
        self.beta1 = beta1
        self.beta2 = beta2
        self.total_step = total_step

        self.model = model
        # Per coefficient-parameter boolean mask: True where an entry MAY be masked.
        self.safe_H = {}
        self.safe_H_num = safe_H_num
        # Raw sensitivity and its two exponential moving averages, keyed by parameter name.
        self.ipt = {}
        self.exp_avg_ipt = {}
        self.exp_avg_unc = {}
        self.cat_ipt = {}
        # Number of entries above the masking threshold, keyed by coefficient-parameter name.
        self.rank_pattern = {}
        self.get_lora_param_name()

        self.tb_writter = tb_writter
        self.log_interval = tb_writter_loginterval

        assert (self.beta1<1 and self.beta1>0)
        assert (self.beta2<1 and self.beta2>0)

    def set_total_step(self, total_step:int):
        """Set the total number of training steps; it must exceed both warmups combined."""
        self.total_step = total_step
        assert self.total_step>self.initial_warmup+self.final_warmup

    def get_rank_pattern(self):
        """Return the dict of rank counts recorded by the most recent masking calls."""
        return self.rank_pattern

    def get_lora_param_name(self):
        """
        Scan `self.model` and prepare the bookkeeping used by the budget scheduler.

        Collects the name templates of the adapted matrices (the parameter name with
        `In_Qs` replaced by `%s`), the total initial rank (sum of the first dimension
        of every `In_Qs` parameter, split into attention / non-attention by the
        presence of `self_attn` in the name) and the shapes of all `In_Qs` / `Out_Qs`
        parameters. Also derives `self.target_rank` when it was not given.

        Raises:
            ZeroDivisionError: if the model has no parameter whose name contains `In_Qs`.
        """
        self.name_set = set()
        self.total_rank = 0
        self.total_linear_rank = 0
        self.total_attention_rank = 0
        self.shape_dict = {}
        self.name_set_attention = set()
        self.name_set_linear = set()
        for n,p in self.model.named_parameters():
            if "In_Qs" in n:
                name_mat = n.replace("In_Qs", "%s")
                self.name_set.add(name_mat)
                if "self_attn" in n:
                    self.name_set_attention.add(name_mat)
                    self.total_attention_rank += p.size(0)
                else:
                    self.name_set_linear.add(name_mat)
                    self.total_linear_rank += p.size(0)
                self.total_rank += p.size(0)
                self.shape_dict[n] = p.shape
            if "Out_Qs" in n:
                self.shape_dict[n] = p.shape
        self.name_set = sorted(self.name_set)
        self.attention_ratio = float(self.total_attention_rank) / float(self.total_rank)
        if self.target_rank is None:
            self.target_rank = self.ave_target_rank * len(self.name_set)

    def schedule_threshold(self, step:int):
        """
        Global budget schedule.

        Also stores `step` in `self.global_step`, which the logging code reads.

        Returns:
            `(curr_rank, mask_ind)`: the rank budget for this step and whether
            masking should be applied at this step.
        """
        target_rank = self.target_rank
        initial_warmup = self.initial_warmup
        final_warmup = self.final_warmup
        total_step = self.total_step
        self.global_step = step
        if step <= initial_warmup:
            # Initial warmup: full budget, no masking.
            curr_rank = self.total_rank
            mask_ind = False
        elif step > total_step - final_warmup:
            # Final fine-tuning: the budget stays at the target and masking runs
            # every step, so the same unimportant entries keep being zeroed.
            curr_rank = self.target_rank
            mask_ind = True
        else:
            # Budget decreasing: cubic interpolation from total_rank down to target_rank.
            mul_coeff = 1-(step-initial_warmup)/(total_step-final_warmup-initial_warmup)
            curr_rank = target_rank + (self.total_rank-target_rank)*(mul_coeff**3)
            curr_rank = int(curr_rank)
            mask_ind = step % self.mask_interval == 0
        return curr_rank, mask_ind

    def count_mask(self, model, curr_step):
        """Print, for every parameter whose name contains `_coeff`, how many entries are exactly zero."""
        total_mask = 0
        print("Training steps: ", curr_step)
        for n,p in model.named_parameters():
            if "_coeff" in n:
                num_zeros = (p == 0).sum().item()
                total_mask += num_zeros
                print(n, ": (", num_zeros, ",", p.shape[0], ")")
        print("Total masks: ", total_mask)

    def update_ipt(self, model):
        """
        Update the element-wise sensitivity statistics of every `_Qs` / `_coeff` parameter.

        Sensitivity is `|p * p.grad|`; `exp_avg_ipt` is its EMA (beta1) and
        `exp_avg_unc` is the EMA (beta2) of `|sensitivity - exp_avg_ipt|`.
        Parameters that currently have no gradient are left untouched.
        """
        for n,p in model.named_parameters():
            if not ("_Qs" in n or "_coeff" in n):
                continue
            if n not in self.ipt:
                self.ipt[n] = torch.zeros_like(p)
                self.exp_avg_ipt[n] = torch.zeros_like(p)
                self.exp_avg_unc[n] = torch.zeros_like(p)
            if p.grad is None:
                # Nothing to accumulate (e.g. frozen parameter or no backward pass yet).
                continue
            with torch.no_grad():
                # Calculate sensitivity
                self.ipt[n] = (p * p.grad).abs().detach()
                # Update sensitivity
                self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + \
                                    (1-self.beta1)*self.ipt[n]
                # Update uncertainty
                self.exp_avg_unc[n] = self.beta2 * self.exp_avg_unc[n] + \
                                    (1-self.beta2)*(self.ipt[n]-self.exp_avg_ipt[n]).abs()

    def calculate_score(self, n, p=None, metric="ipt"):
        """
        Return the importance score of parameter `n`.

        `metric="ipt"` multiplies the smoothed sensitivity by the smoothed
        uncertainty; `metric="mag"` uses the magnitude of `p` (which must then be given).

        Raises:
            ValueError: for any other `metric`.
        """
        if metric == "ipt":
            # Combine the sensitivity and uncertainty
            ipt_score = self.exp_avg_ipt[n] * self.exp_avg_unc[n]
        elif metric == "mag":
            ipt_score = p.abs().detach().clone()
        else:
            raise ValueError("Unexcptected Metric: %s"%metric)
        return ipt_score

    def _combine_ipt(self, ipt_E, ipt_AB):
        """Add the coefficient scores `ipt_E` to the matrix scores `ipt_AB` summed over dim 1."""
        ipt_AB = ipt_AB.sum(dim=1, keepdim=False)
        sum_ipt = ipt_E.view(-1) + ipt_AB.view(-1)
        return sum_ipt

    def _combine_ipt_noH(self, ipt_AB):
        """Variant of `_combine_ipt` that ignores the coefficient scores."""
        ipt_AB = ipt_AB.sum(dim=1, keepdim=False)
        sum_ipt = ipt_AB.view(-1)
        return sum_ipt

    def _maskable_entries(self, H_score):
        """Return a bool mask that is False for the `safe_H_num` largest entries of the 1-D `H_score`."""
        num_maskable = H_score.shape[0] - self.safe_H_num
        if num_maskable <= 0:
            # kthvalue needs k >= 1; here every entry is protected.
            return torch.zeros_like(H_score, dtype=torch.bool)
        max_k = torch.kthvalue(H_score, num_maskable)[0].item()
        return H_score <= max_k

    def _dump_importance(self, saved_data, global_step):
        """Pickle `saved_data` to `<save_dir>/step_<global_step>.pkl`, creating the directory if needed."""
        os.makedirs(self.save_dir, exist_ok=True)
        save_path = os.path.join(self.save_dir, f"step_{global_step}.pkl")
        with open(save_path, "wb") as f:
            pickle.dump(saved_data, f)

    def mask_to_target_rank(self, model, curr_rank, global_step):
        """
        Zero the least important `shared_coeff` entries so that about `curr_rank` remain.

        For every adapted matrix the `In_Qs` / `Out_Qs` scores are averaged over
        dims 1 and 2, summed, and added to the `shared_coeff` scores. A single global
        threshold (the `total_rank - curr_rank`-th smallest combined score) decides
        which entries are masked, except that the `safe_H_num` best entries of each
        matrix are always kept.

        Side effects: modifies the `shared_coeff` parameters in place, refreshes
        `self.safe_H` and `self.rank_pattern`, optionally pickles the scores to
        `self.save_dir` (when `global_step` is a multiple of `mask_interval`) and
        logs to the TensorBoard writer.

        Returns:
            The masking threshold as a Python float (`-inf` when nothing has to be masked).
        """
        combine_dict = {}
        singular_dict = {}
        # Calculate the importance score for each sub matrix
        for n,p in model.named_parameters():
            for tag in ("In_Qs", "Out_Qs"):
                if tag in n:
                    ipt_score = self.calculate_score(n, metric="ipt")
                    comb_ipt = torch.mean(ipt_score, dim=(1,2), keepdim=True).squeeze(-1)
                    combine_dict.setdefault(n.replace(tag, "%s"), []).append(comb_ipt)
            if "shared_coeff" in n:
                name_mat = n.replace("shared_coeff", "%s")
                singular_dict[name_mat] = self.calculate_score(n, p=p, metric="ipt")

        # Combine the importance scores
        is_dict = {}
        all_is = []
        saved_data = {}
        should_dump = self.save_dir is not None and global_step % self.mask_interval == 0
        for name_mat, ipt_list in combine_dict.items():
            ipt_E = singular_dict[name_mat]
            ipt_AB = torch.cat(ipt_list, dim=1)
            sum_ipt = self._combine_ipt(ipt_E, ipt_AB)
            name_E = name_mat%"shared_coeff"
            is_dict[name_E] = sum_ipt.view(-1, 1)
            # Minimum rank: the protected set is recomputed on every call, so it
            # follows the current scores.
            # NOTE(review): a compute-once policy may have been intended here - confirm.
            self.safe_H[name_E] = self._maskable_entries(sum_ipt.view(-1))
            all_is.append(sum_ipt.view(-1))
            if should_dump:
                saved_data[name_E] = {
                    "ipt_AB": ipt_AB.clone().detach(),
                    "sum_ipt": sum_ipt.clone().detach(),
                }
        if saved_data:
            # Written once, after all matrices have been collected.
            self._dump_importance(saved_data, global_step)

        # Calculate the masking threshold
        num_to_mask = self.total_rank - curr_rank
        if num_to_mask > 0:
            mask_threshold = torch.kthvalue(torch.cat(all_is), num_to_mask)[0].item()
        else:
            # Budget not below the total rank: no score is <= -inf, so nothing is masked.
            mask_threshold = float("-inf")

        log_now = self.tb_writter is not None and self.global_step%self.log_interval==0
        with torch.no_grad():
            curr_sum_rank = 0
            sum_param = 0
            for n,p in model.named_parameters():
                if "shared_coeff" not in n:
                    continue
                score = is_dict[n].squeeze(1)
                # Mask only entries that are below the threshold AND not protected.
                mask_bool = (score <= mask_threshold) & self.safe_H[n]
                p.data.masked_fill_(mask_bool, 0.0)
                # NOTE(review): counts entries above the threshold only; entries kept
                # solely because of safe_H are not included - confirm this is intended.
                ranknum = (is_dict[n]>mask_threshold).sum().item()
                # Recorded on every masking call so get_rank_pattern() works without a writer.
                self.rank_pattern[n] = ranknum

                if log_now:
                    self.tb_writter.add_scalar("Ranknum/%s"%(n,), ranknum, self.global_step)
                    curr_sum_rank += ranknum
                    sum_param += ranknum*self.shape_dict[n.replace("shared_coeff", "In_Qs")][1]
                    sum_param += ranknum*self.shape_dict[n.replace("shared_coeff", "Out_Qs")][0]

            if log_now:
                self.tb_writter.add_scalar("Budget/total_rank", curr_sum_rank, self.global_step)
                self.tb_writter.add_scalar("Budget/mask_threshold", mask_threshold, self.global_step)
                self.tb_writter.add_scalar("Budget/sum_param", sum_param, self.global_step)

        return mask_threshold


    def update_and_mask(self, model, global_step):
        """
        Per-step entry point: update importance scores, then mask if the schedule says so.

        `self.total_step` must have been set (constructor or `set_total_step`).

        Returns:
            `(curr_rank, mask_threshold)`; `mask_threshold` is None on steps without masking.
        """
        if global_step<self.total_step-self.final_warmup:
            # Update importance scores element-wise;
            # they are frozen during final fine-tuning.
            self.update_ipt(model)
        # Budget schedule
        curr_rank, mask_ind = self.schedule_threshold(global_step)
        if mask_ind:
            # Mask to target budget
            mask_threshold = self.mask_to_target_rank(model, curr_rank, global_step)
        else:
            mask_threshold = None
        self._maybe_tb_writter_log(model)
        return curr_rank, mask_threshold

    def _maybe_tb_writter_log(self, model):
        """
        Log the orthogonality penalty of `lora_A` / `lora_B` parameters at the log interval.

        Reads `self.global_step`, so `schedule_threshold` must have run first.
        NOTE(review): the rest of this class works on `In_Qs` / `Out_Qs` names; if the
        model has no `lora_A` / `lora_B` parameters nothing is logged here - confirm
        the name filter is still the intended one.
        """
        if self.tb_writter is None or self.global_step%self.log_interval!=0:
            return
        with torch.no_grad():
            regu_loss = []
            for n,p in model.named_parameters():
                if "lora_A" in n or "lora_B" in n:
                    mat = p.data.detach().clone()
                    mat_cov = mat @ mat.T if "lora_A" in n else mat.T @ mat
                    I = torch.eye(*mat_cov.size(), out=torch.empty_like(mat_cov))
                    I.requires_grad = False
                    orth_regu = torch.norm(mat_cov-I, p="fro")
                    regu_loss.append(orth_regu.item())
                    self.tb_writter.add_scalar(
                        "Orth_regu_loss/%s"%n, orth_regu.item(), self.global_step
                    )
            if regu_loss:
                # Skip the average when no parameter matched (avoids a division by zero).
                self.tb_writter.add_scalar(
                    "train/orth_regu_loss", sum(regu_loss)/len(regu_loss), self.global_step
                )


def compute_orth_regu(model, regu_weight=0.1):
    """
    Compute the orthogonal regularization penalty over the adapter matrices of `model`.

    For every parameter whose name contains `In_Qs` or `Out_Qs`, the Frobenius norm
    of (Gram matrix - identity) is accumulated; the result is the mean of those
    norms scaled by `regu_weight`.

    Args:
        model: object exposing `named_parameters()`.
        regu_weight: multiplier applied to the averaged penalty.

    Returns:
        The weighted mean penalty (a tensor when at least one parameter matched),
        or `0.0` when the model has no matching parameter.
    """
    regu_loss, num_param = 0., 0
    for n,p in model.named_parameters():
        if "In_Qs" in n or "Out_Qs" in n:
            # NOTE(review): the "lora_A" test cannot match unless the parameter name
            # also contains that substring, so `p.T @ p` is used in practice -
            # confirm whether "In_Qs" was meant here.
            para_cov = p @ p.T if "lora_A" in n else p.T @ p
            I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov))
            I.requires_grad = False
            regu_loss += torch.norm(para_cov-I, p="fro")
            num_param += 1
    if num_param == 0:
        # No adapter parameters: return a zero penalty instead of dividing by zero.
        return regu_weight*regu_loss
    return regu_weight*regu_loss/num_param