#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
#  adalora+pissa
#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import LoRALayer
from typing import Optional, List
import numpy as np
import matplotlib.pyplot as plt

class SVDLinear(nn.Linear, LoRALayer):
    """Linear layer with an SVD-style low-rank adapter (AdaLoRA).

    The adapted weight is ``W + B @ (A * E) * scaling / (ranknum + 1e-5)`` with
    ``lora_A`` of shape (r, in_features), ``lora_E`` of shape (r, 1) and
    ``lora_B`` of shape (out_features, r). When ``r > 0`` the base weight is
    frozen. The live rank is implied by the parameter shapes, so ``lora_A`` /
    ``lora_E`` / ``lora_B`` may be replaced by smaller or larger tensors after
    construction; ``ranknum`` then has to be kept in sync by whoever does that,
    because it divides the update.

    Every forward pass also stores the detached input in ``last_input`` and
    folds the per-feature mean squared input into the ``scaler_A`` buffer.
    ``scaler_B`` and ``scaler_E`` are registered but not updated here.
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.,
            fan_in_fan_out: bool = False,
            merge_weights: bool = True,
            **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.last_input = None
        self.fan_in_fan_out = fan_in_fan_out
        self.register_buffer('scaler_A', torch.zeros(self.in_features))
        self.register_buffer('scaler_B', torch.zeros(self.out_features))
        self.register_buffer('scaler_E', torch.zeros(1))

        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r, in_features))
            )
            self.lora_E = nn.Parameter(
                self.weight.new_zeros(r, 1)
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_features, r))
            )
            self.ranknum = nn.Parameter(
                self.weight.new_zeros(1), requires_grad=False
            )
            self.ranknum.data.fill_(float(self.r))
            self.scaling = self.lora_alpha if self.lora_alpha > 0 else float(self.r)
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            self.ranknum.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def update_scaler(self, x: torch.Tensor):
        """Fold ``x`` into the running mean-square statistic ``scaler_A``.

        All leading dimensions of ``x`` are flattened, so any input of shape
        (..., in_features) is accepted. The per-feature mean of ``x ** 2`` is
        blended in with weight 0.1 (EMA decay 0.9).
        """
        with torch.no_grad():
            flat = x.reshape(-1, x.shape[-1])  # (num_tokens, in_features)
            val_A = flat.pow(2).mean(dim=0)  # (in_features,)
            self.scaler_A = 0.9 * self.scaler_A + 0.1 * val_A

    def reset_parameters(self):
        """Reset the dense layer; A and B ~ N(0, 0.02^2), E (singular values) = 0."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before the LoRA parameters exist.
        if hasattr(self, 'lora_A'):
            nn.init.zeros_(self.lora_E)
            nn.init.normal_(self.lora_A, mean=0.0, std=0.02)
            nn.init.normal_(self.lora_B, mean=0.0, std=0.02)

    def _maybe_transpose(self, w):
        # The base weight is stored transposed when fan_in_fan_out=True (see __init__).
        return w.T if self.fan_in_fan_out else w

    def _svd_delta_weight(self):
        """Return the scaled update B @ (A * E), laid out like ``self.weight``."""
        return self._maybe_transpose(
            self.lora_B @ (self.lora_A * self.lora_E)
        ) * self.scaling / (self.ranknum + 1e-5)

    def train(self, mode: bool = True):
        """Set training mode and make sure the LoRA update is not merged into ``weight``.

        Returns ``self``, matching ``nn.Module.train``.
        """
        nn.Linear.train(self, mode)
        if self.merge_weights and self.merged:
            if self.r > 0:
                self.weight.data -= self._svd_delta_weight()
            self.merged = False
        return self

    def eval(self):
        """Switch to eval mode and merge the LoRA update into ``weight``.

        Note: ``nn.Module.eval`` on a *parent* module calls ``train(False)`` on its
        children, which only unmerges; merging happens when ``eval()`` is called on
        this layer directly. Returns ``self``, matching ``nn.Module.eval``.
        """
        nn.Linear.eval(self)
        if self.merge_weights and not self.merged:
            if self.r > 0:
                self.weight.data += self._svd_delta_weight()
            self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the base linear map plus, when unmerged, the low-rank update."""
        self.last_input = x.detach()
        self.update_scaler(x)

        result = F.linear(x, self._maybe_transpose(self.weight), bias=self.bias)
        if self.r <= 0 or self.merged:
            return result
        # A layer pruned down to rank 0 contributes nothing.
        if (
                self.lora_A.shape[0] == 0
                or self.lora_B.shape[1] == 0
                or self.lora_E.shape[0] == 0
        ):
            return result
        result += (
                          self.lora_dropout(x) @ (self.lora_A * self.lora_E).T @ self.lora_B.T
                  ) * self.scaling / (self.ranknum + 1e-5)
        return result


class RankAllocator(object):
    """
    The RankAllocator for AdaLoRA Model that will be called every training step.
    Paper: https://openreview.net/pdf?id=lq62uWRJjiY

    Rather than masking singular values, this variant physically removes and
    re-grows rank components. At every mask step the lowest-scoring components
    (at most ``max_rank_moves``, never taking a layer below rank 1) are pruned.
    The same number of layers with the highest mean importance then each
    receive one new component, so the total rank is preserved.

    Args:
        model: the model that we apply AdaLoRA to.
        lora_r (`int`): The initial rank for each incremental matrix.
        target_rank (`int`): The target average rank of incremental matrix.
        init_warmup (`int`): The steps of initial fine-tuning warmup.
        final_warmup (`int`): The step of final fine-tuning.
        mask_interval (`int`): The time interval between two budget allocations.
        beta1 (`float`): The hyperparameter of EMA for sensitivity smoothing.
        beta2 (`float`): The hyperparameter of EMA for uncertainty quantification.
        total_step (`int`): The total training steps, correctly configured before training.
        target_total_rank (`Optional[int]`): The specified final total rank.
        tb_writter (`SummaryWriter`): Tensorboard SummaryWriter.
        tb_writter_loginterval (`int`): The logging interval of SummaryWriter.
        grad_log_interval (`int`): Step interval for per-layer gradient-norm scalars.
        vis_interval (`int`): Step interval for gradient / correlation figures.
        ema_beta (`float`): Stored as ``self.ema_beta``; not read within this class.
        max_rank_moves (`int`): Maximum number of rank components moved per mask step.
    """

    def __init__(
            self, model,
            lora_r: int,
            target_rank: int,
            init_warmup: int,
            final_warmup: int,
            mask_interval: int,
            beta1: float,
            beta2: float,
            total_step: Optional[int] = None,
            target_total_rank: Optional[int] = None,
            tb_writter=None,
            tb_writter_loginterval: int = 500,
            grad_log_interval: int = 100,
            vis_interval: int = 1000,
            ema_beta: float = 0.9,
            max_rank_moves: int = 12,
    ):
        self.layer_importance = {}
        self.ave_target_rank = target_rank
        self.target_rank = target_total_rank
        self.lora_init_rank = lora_r
        self.initial_warmup = init_warmup
        self.final_warmup = final_warmup
        self.mask_interval = mask_interval
        self.beta1 = beta1
        self.beta2 = beta2
        self.total_step = total_step
        self.ema_beta = ema_beta
        self.ema_importance = {}
        self.importance_buffer = {}
        self.max_rank_moves = max_rank_moves
        # Updated by schedule_threshold(); initialised so logging never hits AttributeError.
        self.global_step = 0

        self.budget_pool = 0
        self.layer_ranks = {}
        self.model = model
        self.ipt = {}
        self.exp_avg_ipt = {}
        self.exp_avg_unc = {}
        self.cat_ipt = {}
        self.rank_pattern = {}
        self.get_lora_param_name()

        self.tb_writter = tb_writter
        self.log_interval = tb_writter_loginterval

        # param_name -> score of the most important component pruned last time.
        self.last_prune_thresholds = {}

        assert (self.beta1 < 1 and self.beta1 > 0)
        assert (self.beta2 < 1 and self.beta2 > 0)

        self.grad_log_interval = grad_log_interval
        self.vis_interval = vis_interval
        self.grad_stats = {
            'grad_E': [], 'grad_A': [], 'grad_B': [],
            'mean_E': [], 'var_E': [],
            'mean_A': [], 'var_A': [],
            'mean_B': [], 'var_B': [],
            'corr_E_acc': []
        }
        self.alpha_history = []
        self.beta_history = []
        self.gamma_history = []
        self.val_accuracies = []

    # ------------------------------------------------------------------ diagnostics

    def _record_grad_stats(self, grad_E, grad_A, grad_B):
        """Append one layer's normalised gradient norms (scalar tensors) to ``grad_stats``."""
        self.grad_stats['grad_E'].append(grad_E.detach().cpu().item())
        self.grad_stats['grad_A'].append(grad_A.detach().cpu().item())
        self.grad_stats['grad_B'].append(grad_B.detach().cpu().item())

    def _update_statistics(self, window_size=100):
        """Append the mean/variance of the last ``window_size`` norms for E, A and B."""
        for key in ("E", "A", "B"):
            history = self.grad_stats[f"grad_{key}"]
            if len(history) >= window_size:
                window = torch.tensor(history[-window_size:])
                self.grad_stats[f"mean_{key}"].append(window.mean().item())
                self.grad_stats[f"var_{key}"].append(window.var().item())

    def _log_correlations(self):
        """Compute rolling correlations between E-gradient norms and validation accuracy.

        Does nothing until at least 10 accuracies were recorded. The last
        ``min_len`` entries of both series are paired index-by-index, and a
        Pearson correlation is computed for every window of 100 consecutive
        pairs. Results are appended to ``grad_stats['corr_E_acc']`` and, with a
        writer attached, plotted to TensorBoard.

        NOTE(review): ``grad_E`` gets one entry per layer per mask step while
        ``val_accuracies`` gets one per ``update_validation_metrics`` call, so the
        pairing is only meaningful if callers keep them aligned.
        """
        if len(self.val_accuracies) < 10:
            return

        min_len = min(len(self.grad_stats['grad_E']), len(self.val_accuracies))
        grad_E = torch.tensor(self.grad_stats['grad_E'][-min_len:])
        acc = torch.tensor(self.val_accuracies[-min_len:])

        window_size = 100
        corr_values = []
        for start in range(min_len - window_size + 1):
            stop = start + window_size
            pair = torch.stack([grad_E[start:stop], acc[start:stop]])
            corr_values.append(torch.corrcoef(pair)[0, 1].item())

        self.grad_stats['corr_E_acc'].extend(corr_values)

        if self.tb_writter and len(corr_values) > 0:
            fig, ax = plt.subplots(figsize=(10, 4))
            ax.plot(range(self.global_step - len(corr_values) + 1, self.global_step + 1), corr_values,
                    label="Correlation (E vs Acc)")
            ax.axhline(0, color='gray', linestyle='--', linewidth=0.8)
            ax.set_xlabel("Training Step")
            ax.set_ylabel("Correlation")
            ax.set_title("Gradient-Validation Accuracy Correlation Over Time")
            ax.legend()
            ax.grid(True)

            self.tb_writter.add_figure("Correlation/E_Accuracy_Trend", fig, self.global_step)
            plt.close(fig)

    def _visualize(self):
        """Send gradient-trend, E-statistics and alpha/beta/gamma heatmap figures to TensorBoard."""
        if not self.tb_writter:
            return

        fig, axs = plt.subplots(3, figsize=(12, 8))
        axs[0].plot(self.grad_stats['grad_E'], label='Core Matrix (E)')
        axs[1].plot(self.grad_stats['grad_A'], label='Factor A')
        axs[2].plot(self.grad_stats['grad_B'], label='Factor B')
        for ax in axs:
            ax.legend()
        self.tb_writter.add_figure("Gradient_Trends", fig, self.global_step)
        plt.close(fig)

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(self.grad_stats['mean_E'], label='E Mean')
        ax.plot(self.grad_stats['var_E'], label='E Variance', linestyle='--')
        ax.set_title("Core Matrix Gradient Statistics")
        ax.legend()
        self.tb_writter.add_figure("Gradient_Stats/E", fig, self.global_step)
        plt.close(fig)

        if len(self.alpha_history) >= 100:
            fig, ax = plt.subplots(figsize=(10, 4))
            data = np.array([
                self.alpha_history[-100:],
                self.beta_history[-100:],
                self.gamma_history[-100:]
            ])
            im = ax.imshow(data, cmap='viridis', aspect='auto')

            ax.set_xticks(np.arange(0, 100, 10))
            ax.set_yticks([0, 1, 2])
            ax.set_yticklabels(['Alpha', 'Beta', 'Gamma'])
            plt.colorbar(im)

            self.tb_writter.add_figure(
                "Weight_Distribution_Heatmap",
                fig,
                self.global_step
            )
            plt.close(fig)

    # ------------------------------------------------------------------ setup

    def set_total_step(self, total_step: int):
        # Set total step number
        self.total_step = total_step
        assert self.total_step > self.initial_warmup + self.final_warmup

    def get_rank_pattern(self):
        # Return rank pattern
        return self.rank_pattern

    def get_lora_param_name(self):
        """Collect LoRA parameter groups and initialise per-layer rank bookkeeping.

        Each ``...lora_A`` parameter defines a group ``...%s``; ``layer_ranks`` is
        keyed by the group's ``...lora_E`` name and starts at ``lora_r``. When no
        total target was given it defaults to ``target_rank * number_of_groups``.
        """
        self.name_set = set()
        self.total_rank = 0
        self.shape_dict = {}
        self.layer_ranks = {}

        for n, p in self.model.named_parameters():
            if "lora_A" in n:
                name_mat = n.replace("lora_A", "%s")
                self.name_set.add(name_mat)
                self.layer_ranks[name_mat % "lora_E"] = self.lora_init_rank
                self.total_rank += p.size(0)
                self.shape_dict[n] = p.shape
            if "lora_B" in n:
                self.shape_dict[n] = p.shape

        self.name_set = list(sorted(self.name_set))
        if self.target_rank is None:
            self.target_rank = self.ave_target_rank * len(self.name_set)

    # ------------------------------------------------------------------ scoring helpers

    @staticmethod
    def _normalized_grad_norm(param):
        """Return ``||grad||_2 / sqrt(numel)`` of ``param`` (a zero scalar if it has no grad)."""
        grad = param.grad
        if grad is None:
            return param.new_zeros(())
        return torch.norm(grad, p=2) / math.sqrt(grad.numel() + 1e-8)

    @staticmethod
    def _grad_mixing_weights(grad_E, grad_A, grad_B):
        """Return each norm's share of their sum, or zeros when the sum is zero."""
        total_grad = grad_E + grad_A + grad_B
        if total_grad == 0:
            return 0.0, 0.0, 0.0
        return grad_E / total_grad, grad_A / total_grad, grad_B / total_grad

    @staticmethod
    def _activation_metrics(model):
        """Return per-component activation-weighted magnitudes for each SVDLinear layer.

        For a layer with rank components i the metric is
        ``|E_i| * sum_j |A_ij| * sqrt(scaler_A_j) * ||B_:i||_2``. Layers that have
        no LoRA parameters or have not run a forward pass yet are skipped.
        Keys are ``f"{module_name}.lora_E"``.
        """
        metrics = {}
        with torch.no_grad():
            for module_name, module in model.named_modules():
                if not isinstance(module, SVDLinear):
                    continue
                if getattr(module, "lora_A", None) is None or getattr(module, "last_input", None) is None:
                    continue
                sqrt_scaler_A = torch.sqrt(module.scaler_A).reshape(1, -1)
                W_metric_A = (module.lora_A.abs() * sqrt_scaler_A).sum(dim=1)
                b_col_norms = torch.norm(module.lora_B, p=2, dim=0)  # (r,)
                metrics[f"{module_name}.lora_E"] = module.lora_E.abs().view(-1) * W_metric_A * b_col_norms
        return metrics

    @staticmethod
    def _blend_with_activation_metric(sum_ipt, wanda_metric, progress):
        """Mix z-scored gradient importance with a z-scored activation metric.

        The activation weight is ``0.3 * progress``. The blend is mapped back
        onto the gradient score's mean/std scale.
        """
        grad_weight = 1.0 - 0.3 * progress
        wanda_weight = 0.3 * progress

        grad_mean, grad_std = sum_ipt.mean(), sum_ipt.std()
        wanda_metric = wanda_metric.to(sum_ipt.device)
        wanda_mean, wanda_std = wanda_metric.mean(), wanda_metric.std()

        norm_grad = (sum_ipt - grad_mean) / (grad_std + 1e-8)
        norm_wanda = (wanda_metric - wanda_mean) / (wanda_std + 1e-8)

        blended = grad_weight * norm_grad + wanda_weight * norm_wanda
        return blended * (grad_std + 1e-8) + grad_mean

    def _compute_layer_importance(self, model, step):
        """Score every rank component of every LoRA layer that currently has gradients.

        The per-component score is
        ``alpha * ipt(E) + (1 - alpha) * (rowmean ipt(A) + colmean ipt(B))``, where
        ``alpha`` is E's share of the three normalised gradient norms. When an
        activation metric is available (and rank > 1), it is blended in with
        weight ``0.3 * step / total_step``. Gradient norms are also recorded and
        logged, and figures are produced once per call at ``vis_interval`` steps.

        Returns:
            (layer_importance, all_is): the mean score per ``...lora_E`` name, and
            a list of ``(score, lora_E name, component index)`` tuples.
        """
        w_metric_dict = self._activation_metrics(model)

        named = {}
        combine_dict = {}
        singular_dict = {}
        for n, p in model.named_parameters():
            if "lora_" in n:
                named[n] = p
            if "lora_A" in n:
                comb_ipt = torch.mean(self.calculate_score(n, metric="ipt"), dim=1, keepdim=True)
                combine_dict.setdefault(n.replace("lora_A", "%s"), []).append(comb_ipt)
            if "lora_B" in n:
                comb_ipt = torch.mean(self.calculate_score(n, metric="ipt"), dim=0).view(-1, 1)
                combine_dict.setdefault(n.replace("lora_B", "%s"), []).append(comb_ipt)
            if "lora_E" in n:
                singular_dict[n.replace("lora_E", "%s")] = self.calculate_score(n, p=p, metric="ipt")

        total_step = self.total_step
        progress = step / total_step if total_step else 0.0
        log_grads = bool(self.tb_writter) and self.global_step % self.grad_log_interval == 0

        all_is = []
        processed_any = False
        for name_mat, ab_scores in combine_dict.items():
            name_E = name_mat % "lora_E"
            name_A = name_mat % "lora_A"
            name_B = name_mat % "lora_B"
            p_E, p_A, p_B = named[name_E], named[name_A], named[name_B]
            if p_A.grad is None or p_B.grad is None or p_E.grad is None:
                continue
            processed_any = True

            ipt_E = singular_dict[name_mat]
            ipt_AB = torch.cat(ab_scores, dim=1)

            grad_E = self._normalized_grad_norm(p_E)
            grad_A = self._normalized_grad_norm(p_A)
            grad_B = self._normalized_grad_norm(p_B)

            if log_grads:
                self.tb_writter.add_scalar(f"Gradients/E/{name_E}", grad_E, self.global_step)
                self.tb_writter.add_scalar(f"Gradients/A/{name_A}", grad_A, self.global_step)
                self.tb_writter.add_scalar(f"Gradients/B/{name_B}", grad_B, self.global_step)

            self._record_grad_stats(grad_E, grad_A, grad_B)
            self._update_statistics(window_size=100)

            alpha, beta, gamma = self._grad_mixing_weights(grad_E, grad_A, grad_B)
            # Feed the alpha/beta/gamma heatmap drawn by _visualize().
            self.alpha_history.append(float(alpha))
            self.beta_history.append(float(beta))
            self.gamma_history.append(float(gamma))

            sum_ipt = self._combine_ipt(ipt_E, ipt_AB, alpha, beta, gamma)

            # std() of a single element is NaN, so rank-1 layers keep the raw score.
            if name_E in w_metric_dict and sum_ipt.numel() > 1:
                sum_ipt = self._blend_with_activation_metric(sum_ipt, w_metric_dict[name_E], progress)

            all_is.extend(
                (importance_value.item(), name_E, element_idx)
                for element_idx, importance_value in enumerate(sum_ipt)
            )
            self.layer_importance[name_E] = torch.mean(sum_ipt).item()

        # Draw figures once per step, not once per layer.
        if processed_any and self.global_step % self.vis_interval == 0:
            self._visualize()
            self._log_correlations()

        return self.layer_importance, all_is

    # ------------------------------------------------------------------ rank surgery

    def _slice_stats(self, name: str, keep_mask: torch.Tensor, dim: int):
        """Keep only the ``keep_mask`` rank slots (along ``dim``) of ``name``'s stored stats."""
        index = (keep_mask,) if dim == 0 else (slice(None), keep_mask)
        if name in self.exp_avg_ipt:
            self.exp_avg_ipt[name] = self.exp_avg_ipt[name][index]
            self.exp_avg_unc[name] = self.exp_avg_unc[name][index]
        if name in self.ipt:
            self.ipt[name] = self.ipt[name][index]

    def _pad_stats(self, name: str, new_size: int, dim: int):
        """Grow ``name``'s stored stats to ``new_size`` along ``dim``.

        New sensitivity-EMA and uncertainty-EMA slots are both filled with the
        mean of the first existing slice of the sensitivity EMA (0.0 if none).
        New instantaneous-ipt slots are zero; update_ipt overwrites them. Keeping
        ``ipt`` in shape matters, because update_ipt resets all three stores when
        ``ipt``'s shape disagrees with the parameter.
        """
        def pad(tensor, fill):
            extra = new_size - tensor.size(dim)
            if extra <= 0:
                return tensor
            shape = list(tensor.shape)
            shape[dim] = extra
            return torch.cat([tensor, tensor.new_full(shape, fill)], dim=dim)

        if name in self.exp_avg_ipt:
            old_ema = self.exp_avg_ipt[name]
            if old_ema.size(dim) > 0:
                first = old_ema[0] if dim == 0 else old_ema[:, 0]
                fill = first.mean().item()
            else:
                fill = 0.0
            self.exp_avg_ipt[name] = pad(old_ema, fill)
            self.exp_avg_unc[name] = pad(self.exp_avg_unc[name], fill)
        if name in self.ipt:
            self.ipt[name] = pad(self.ipt[name], 0.0)

    @staticmethod
    def _sync_ranknum(module, rank: int):
        # ranknum divides the LoRA update in SVDLinear, so it must follow the live rank.
        ranknum = getattr(module, "ranknum", None)
        if ranknum is not None:
            ranknum.data.fill_(rank)

    def _prune_ranks(self, param_name: str, prune_num: int):
        """Remove the ``prune_num`` least important rank components of one layer.

        Components are scored as ``alpha*ipt(E) + beta*rowmean ipt(A) +
        gamma*colmean ipt(B)`` with gradient-share weights. New A/E/B parameters
        replace the old ones on the owning module (gradients are sliced along),
        stored stats are sliced, and ``ranknum`` is updated.

        Returns:
            ``((old_A, new_A), (old_B, new_B), (old_E, new_E))``; old and new are
            the same objects when nothing was pruned.
        """
        a_name = param_name.replace("lora_E", "lora_A")
        b_name = param_name.replace("lora_E", "lora_B")
        e_name = param_name

        param_A = self.model.get_parameter(a_name)  # (r, in)
        param_B = self.model.get_parameter(b_name)  # (out, r)
        param_E = self.model.get_parameter(e_name)  # (r, 1)
        current_rank = param_A.shape[0]
        prune_num = min(prune_num, current_rank)
        if current_rank == 0 or prune_num <= 0:
            return (param_A, param_A), (param_B, param_B), (param_E, param_E)

        ipt_E = self.calculate_score(e_name)
        ipt_A = self.calculate_score(a_name)
        ipt_B = self.calculate_score(b_name)

        alpha, beta, gamma = self._grad_mixing_weights(
            self._normalized_grad_norm(param_E),
            self._normalized_grad_norm(param_A),
            self._normalized_grad_norm(param_B),
        )

        row_ipt_A = ipt_A.mean(dim=1, keepdim=True)  # (r, 1)
        row_ipt_B = ipt_B.mean(dim=0, keepdim=True).T  # (r, 1)
        assert row_ipt_A.shape[0] == current_rank, f"ipt_A The length after aggregation is incorrect."
        assert row_ipt_B.shape[0] == current_rank, f"ipt_B The length after aggregation is incorrect."
        combined_ipt = (alpha * ipt_E + beta * row_ipt_A + gamma * row_ipt_B).view(-1)
        assert combined_ipt.shape[0] == current_rank, \
            f"combined_ipt length {combined_ipt.shape[0]} compared with the current rank {current_rank} mismatch"

        # topk of the negated score yields ascending importance; the last index is
        # the most important component that still gets pruned.
        _, prune_indices = torch.topk(-combined_ipt, prune_num)
        keep_mask = torch.ones_like(combined_ipt, dtype=torch.bool)
        keep_mask[prune_indices] = False
        self.last_prune_thresholds[param_name] = combined_ipt[prune_indices[-1]].item()

        self._slice_stats(a_name, keep_mask, dim=0)
        self._slice_stats(b_name, keep_mask, dim=1)
        self._slice_stats(e_name, keep_mask, dim=0)

        new_A = nn.Parameter(param_A.data[keep_mask, :])
        new_B = nn.Parameter(param_B.data[:, keep_mask])
        new_E = nn.Parameter(param_E.data[keep_mask, :])

        if param_A.grad is not None:
            new_A.grad = param_A.grad[keep_mask, :].detach().clone()
        if param_B.grad is not None:
            new_B.grad = param_B.grad[:, keep_mask].detach().clone()
        if param_E.grad is not None:
            new_E.grad = param_E.grad[keep_mask, :].detach().clone()

        module = self.model.get_submodule(param_name.rsplit(".", 1)[0])
        module.lora_A = new_A
        module.lora_B = new_B
        module.lora_E = new_E
        self._sync_ranknum(module, new_A.size(0))
        return (param_A, new_A), (param_B, new_B), (param_E, new_E)

    def _expand_ranks(self, param_name: str, expand_num: int):
        """Append ``expand_num`` rank components to one layer.

        New components are a copy of the first existing component scaled by 0.1
        (identical to each other when ``expand_num > 1``). For an empty layer
        they are drawn from N(0, 0.02^2). New parameters replace the old ones on
        the owning module with zero gradients for the added slots. Stored stats
        are padded and ``ranknum`` is updated.

        Returns:
            ``((old_A, new_A), (old_B, new_B), (old_E, new_E))``.
        """
        a_name = param_name.replace("lora_E", "lora_A")
        b_name = param_name.replace("lora_E", "lora_B")
        e_name = param_name

        param_A = self.model.get_parameter(a_name)  # (r, in)
        param_B = self.model.get_parameter(b_name)  # (out, r)
        param_E = self.model.get_parameter(e_name)  # (r, 1)

        if param_A.shape[0] > 0:
            new_A = param_A[0].detach().clone().expand(expand_num, -1) * 0.1
            new_B = param_B[:, 0].detach().clone().unsqueeze(1).expand(-1, expand_num) * 0.1
            new_E = param_E[0].detach().clone().expand(expand_num, -1) * 0.1
        else:
            init_scale = 0.02
            new_A = torch.randn((expand_num, param_A.shape[1]),
                                device=param_A.device, dtype=param_A.dtype) * init_scale
            new_B = torch.randn((param_B.shape[0], expand_num),
                                device=param_B.device, dtype=param_B.dtype) * init_scale
            new_E = torch.randn((expand_num, 1),
                                device=param_E.device, dtype=param_E.dtype) * init_scale

        self._pad_stats(a_name, new_size=param_A.shape[0] + expand_num, dim=0)
        self._pad_stats(b_name, new_size=param_B.shape[1] + expand_num, dim=1)
        self._pad_stats(e_name, new_size=param_E.shape[0] + expand_num, dim=0)

        expanded_A = nn.Parameter(torch.cat([param_A.data, new_A], dim=0))
        expanded_B = nn.Parameter(torch.cat([param_B.data, new_B], dim=1))
        expanded_E = nn.Parameter(torch.cat([param_E.data, new_E], dim=0))

        if param_A.grad is not None:
            expanded_A.grad = torch.cat([param_A.grad, torch.zeros_like(new_A)], dim=0)
        if param_B.grad is not None:
            expanded_B.grad = torch.cat([param_B.grad, torch.zeros_like(new_B)], dim=1)
        if param_E.grad is not None:
            expanded_E.grad = torch.cat([param_E.grad, torch.zeros_like(new_E)], dim=0)

        module = self.model.get_submodule(param_name.rsplit(".", 1)[0])
        module.lora_A = expanded_A
        module.lora_B = expanded_B
        module.lora_E = expanded_E
        self._sync_ranknum(module, expanded_A.size(0))
        return (param_A, expanded_A), (param_B, expanded_B), (param_E, expanded_E)

    def adjust_ranks(self, step):
        """Move rank components from low-importance to high-importance layers.

        Does nothing (returns ``[]``) unless at least 4 layers have rank > 1. The
        number of moves is capped by ``max_rank_moves`` and by the number of
        eligible receiving layers, so every pruned component is matched by
        exactly one expansion. All LoRA gradients are cleared afterwards and
        ``rank_pattern`` is refreshed.

        Returns:
            List of ``(old_param, new_param)`` pairs for optimizer bookkeeping.
        """
        self.layer_importance = {}
        self.layer_importance, all_is = self._compute_layer_importance(self.model, step)

        # Layers that can still give up a component (ranks never go below 1).
        all_triples = [
            (n, self.layer_importance.get(n, 0))
            for n in self.layer_ranks
            if self.layer_ranks[n] > 1
        ]
        if len(all_triples) < 4:
            return []

        # Each receiving layer gains at most one component, so never prune more
        # than there are receivers; otherwise the total rank would drift.
        max_moves = min(self.max_rank_moves, len(all_triples))

        prune_candidates = []
        temp_ranks = self.layer_ranks.copy()
        for item in sorted(all_is, key=lambda x: x[0]):
            if len(prune_candidates) >= max_moves:
                break
            candidate = item[1]
            current_rank = temp_ranks.get(candidate, 0)
            if current_rank > 1:
                prune_candidates.append(item)
                temp_ranks[candidate] = current_rank - 1

        updated_params = []
        num_pruned = 0
        for _, candidate, _ in prune_candidates:
            if self.layer_ranks[candidate] >= 1:
                updated_params.extend(self._prune_ranks(candidate, prune_num=1))
                self.layer_ranks[candidate] -= 1
                num_pruned += 1

        expand_candidates = sorted(all_triples, key=lambda x: -x[1])[:num_pruned]
        for candidate, _ in expand_candidates:
            updated_params.extend(self._expand_ranks(candidate, expand_num=1))
            self.layer_ranks[candidate] += 1

        total_rank = sum(self.layer_ranks.values())
        assert total_rank == self.target_rank, f"Rank mismatch {total_rank} vs {self.target_rank}"

        for n, p in self.model.named_parameters():
            if "lora_" in n and p.grad is not None:
                p.grad = None

        # Keep the public rank pattern current even without a TensorBoard writer.
        self.rank_pattern.update(self.layer_ranks)
        if self.tb_writter:
            for n, rank in self.layer_ranks.items():
                self.tb_writter.add_scalar(f"Ranknum/{n}", rank, self.global_step)

        return updated_params

    # ------------------------------------------------------------------ per-step API

    def schedule_threshold(self, step: int):
        """Record ``step`` as ``global_step`` and return whether ranks are adjusted now.

        True only strictly after the initial warm-up, at or before
        ``total_step - final_warmup``, and on multiples of ``mask_interval``.
        """
        self.global_step = step
        if step <= self.initial_warmup:
            return False
        if step > self.total_step - self.final_warmup:
            return False
        return step % self.mask_interval == 0

    def update_ipt(self, model):
        """Update sensitivity ``|p * grad|``, its EMA and the uncertainty EMA.

        Stats whose stored shape no longer matches the parameter are discarded
        and restarted from zero. Parameters without a gradient keep their state.
        """
        for n, p in model.named_parameters():
            if "lora_" not in n:
                continue
            if n in self.ipt and self.ipt[n].shape != p.shape:
                for store in (self.ipt, self.exp_avg_ipt, self.exp_avg_unc):
                    store.pop(n, None)
            if n not in self.ipt:
                self.ipt[n] = torch.zeros_like(p)
                self.exp_avg_ipt[n] = torch.zeros_like(p)
                self.exp_avg_unc[n] = torch.zeros_like(p)
            if p.grad is None:
                continue
            with torch.no_grad():
                # Calculate sensitivity
                self.ipt[n] = (p * p.grad).abs().detach()
                # Update sensitivity
                self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + \
                                      (1 - self.beta1) * self.ipt[n]
                # Update uncertainty
                self.exp_avg_unc[n] = self.beta2 * self.exp_avg_unc[n] + \
                                      (1 - self.beta2) * (self.ipt[n] - self.exp_avg_ipt[n]).abs()

    def calculate_score(self, n, p=None, metric="ipt"):
        """Return an element-wise score for parameter ``n``.

        ``"ipt"``: sensitivity EMA times uncertainty EMA; ``"mag"``: ``|p|``.
        Raises ValueError for any other metric.
        """
        if metric == "ipt":
            ipt_score = self.exp_avg_ipt[n] * self.exp_avg_unc[n]
        elif metric == "mag":
            ipt_score = p.abs().detach().clone()
        else:
            raise ValueError("Unexcptected Metric: %s" % metric)
        return ipt_score

    def _combine_ipt(self, ipt_E, ipt_AB, alpha, beta, gamma):
        """Return ``alpha * ipt_E + (1 - alpha) * ipt_AB.sum(dim=1)`` flattened to (r,).

        ``beta`` and ``gamma`` are accepted for signature symmetry but unused.
        """
        weight_E = alpha
        weight_AB = 1 - weight_E
        ipt_AB = ipt_AB.sum(dim=1, keepdim=False)
        return weight_E * ipt_E.view(-1) + weight_AB * ipt_AB.view(-1)

    def update_validation_metrics(self, accuracy):
        """Record a validation accuracy for the gradient/accuracy correlation plots."""
        self.val_accuracies.append(accuracy)

    def update_and_mask(self, model, global_step):
        """Per-step entry point: update scores, adjust ranks on mask steps, log.

        Returns:
            ``(updated_params, mask_ind)``; ``updated_params`` is None on steps
            where no adjustment was scheduled.
        """
        if global_step < self.total_step - self.final_warmup:
            # Update importance scores element-wise
            self.update_ipt(model)
        mask_ind = self.schedule_threshold(global_step)
        updated_params = self.adjust_ranks(global_step) if mask_ind else None
        self._maybe_tb_writter_log(model)
        return updated_params, mask_ind

    def _maybe_tb_writter_log(self, model):
        """Log per-factor and mean orthogonality residuals every ``log_interval`` steps."""
        if self.tb_writter is None or self.global_step % self.log_interval != 0:
            return
        with torch.no_grad():
            regu_loss = []
            for n, p in model.named_parameters():
                if "lora_A" in n or "lora_B" in n:
                    mat = p.data.detach().clone()
                    mat_cov = mat @ mat.T if "lora_A" in n else mat.T @ mat
                    I = torch.eye(*mat_cov.size(), out=torch.empty_like(mat_cov))
                    orth_regu = torch.norm(mat_cov - I, p="fro").item()
                    regu_loss.append(orth_regu)
                    self.tb_writter.add_scalar(
                        "Orth_regu_loss/%s" % n, orth_regu, self.global_step
                    )
            if regu_loss:
                self.tb_writter.add_scalar(
                    "train/orth_regu_loss", sum(regu_loss) / len(regu_loss), self.global_step
                )


def compute_orth_regu(model, regu_weight=0.1):
    """Return the weighted mean orthogonality penalty over LoRA factors in ``model``.

    For every parameter whose name contains ``lora_A`` the penalty is
    ``||A @ A.T - I||_F``; for ``lora_B`` it is ``||B.T @ B - I||_F``. The mean
    over all matched parameters is multiplied by ``regu_weight``. Returns 0.0
    when ``model`` has no such parameters.
    """
    regu_loss, num_param = 0., 0
    for n, p in model.named_parameters():
        if "lora_A" in n or "lora_B" in n:
            para_cov = p @ p.T if "lora_A" in n else p.T @ p
            I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov))
            I.requires_grad = False
            regu_loss += torch.norm(para_cov - I, p="fro")
            num_param += 1
    if num_param == 0:
        # Nothing to regularise; avoid ZeroDivisionError.
        return 0.0
    return regu_weight * regu_loss / num_param

