from pruning_methods.pruning_utils import find_layers
import torch 
import torch.nn as nn
from general_utils.config import config
from general_utils import utils
from pruning_methods.QR.qr_wrapper import WrappedQR
from models.model_utils import get_layer0_inputs
from tqdm import tqdm
from pruning_methods.pruning_utils import load_checkpoint, save_checkpoint
from pruning_methods.QR.compressed_linear import CompressedLinear, CompressedQKV
from pruning_methods.QR.qr_utils import _get_submodules
from pruning_methods.QR.preserved_linear import ProbLinear, ProbQKV
import lm_eval
from lm_eval.models.huggingface import HFLM
from measure_utils.measure_utils import DumpJSON
import os
import math
import time
import wandb
from pruning_methods.QR.scheduler import CosineAnnealingWithWarmupScheduler

class Gaussian_mask:
    """Learnable mean of a per-layer sparsity-ratio distribution.

    ``mean`` holds one sparsity ratio per layer.  It is an ``nn.Parameter``
    with ``requires_grad=False``; its ``.grad`` buffer is filled manually by
    :meth:`update_grad` instead of by autograd.

    Args:
        sparsity: Initial value of every entry of ``mean``.
        layerwise_sparsity_ratios: Stored on the instance only; nothing in
            this class reads it.
        layer_len: Number of entries in ``mean`` (defaults to 32).
    """

    def __init__(self, sparsity=0.5, layerwise_sparsity_ratios=None, layer_len=32):
        self.layerwise_sparsity_ratios = layerwise_sparsity_ratios

        self.layer_len = layer_len
        self.sparsity = sparsity
        self.mean = nn.Parameter(torch.ones(self.layer_len) * sparsity, requires_grad=False)
        self.mean.grad = torch.zeros_like(self.mean)

    def sample_ratio(self, epsilon=0.05, sigma_2=0.05, idx=0):
        """Draw one sparsity ratio per layer from a truncated normal around ``mean``.

        Args:
            epsilon: Samples are clamped to ``[epsilon, 1 - epsilon]``.
            sigma_2: Variance of the underlying normal distribution.
            idx: Iteration counter.  It selects the truncation half-width and
                is also used to reseed the *global* torch RNG, so the same
                ``idx`` always yields the same sample.

        Returns:
            ``(sparse_ratio, grad)`` where ``grad`` equals
            ``(sparse_ratio - mean) / sigma_2``.
        """
        torch.manual_seed(idx)
        # Cubic decay of the truncation half-width gamma: 0.05 -> 0.005.
        if idx < 5:
            gamma = 0.05
        elif idx < 100:
            gamma = 0.05 - (0.05 - 0.005) * ((idx - 5) / 95) ** 3
        else:
            gamma = 0.005

        # Truncated normal sampling via the inverse CDF: u is uniform between
        # the CDF values of -gamma/sigma and +gamma/sigma, so sigma * z lies
        # in [-gamma, gamma].
        mu = self.mean
        sigma = sigma_2 ** 0.5

        normal = torch.distributions.Normal(0, 1)
        lower = normal.cdf(torch.tensor(-(gamma / sigma)))
        upper = normal.cdf(torch.tensor(gamma / sigma))

        # Sample on the device of the mean so the addition below cannot hit a
        # device mismatch.
        u = torch.rand(mu.shape, device=mu.device) * (upper - lower) + lower
        z = normal.icdf(u)
        sparse_ratio = torch.clamp(mu + sigma * z, epsilon, 1 - epsilon)

        grad = -(self.mean - sparse_ratio) / sigma_2

        print("mean", self.mean)
        print("sparse_ratio", sparse_ratio)

        return sparse_ratio, grad

    def update_grad(self, loss_list, grad_list, K):
        """Accumulate ``grad * (loss - mean(loss_list)) / (K - 1)`` into ``mean.grad``.

        Args:
            loss_list: Scalar loss tensors, one per sample.
            grad_list: Gradients returned by :meth:`sample_ratio`, one per sample.
            K: Number of samples; ``K == 1`` raises ``ZeroDivisionError``.
        """
        if not loss_list or not grad_list:
            return
        # The mean loss is the same for every sample, so compute it once.
        baseline = torch.mean(torch.stack(loss_list))
        for loss, grad in zip(loss_list, grad_list):
            self.mean.grad += 1 / (K - 1) * grad * (loss - baseline)

    def constrain(self, sparse_ratio=0.5):
        """Rescale ``mean`` in place so that its average equals ``sparse_ratio``."""
        sparse_ratio_all = self.mean.numel() * sparse_ratio
        self.mean.mul_(sparse_ratio_all / torch.sum(self.mean))
        

class rankratio_mask:
    """Learnable mean of a per-layer, per-matrix rank-ratio distribution.

    ``rankratio_mean`` has shape ``(layer_len, matrix_num)``.  It is an
    ``nn.Parameter`` with ``requires_grad=False``; its ``.grad`` buffer is
    filled manually by :meth:`update_grad` instead of by autograd.

    Args:
        layer_len: Number of rows of ``rankratio_mean``.
        matrix_num: Number of columns of ``rankratio_mean``.
        rankratio: Initial value of every entry.
    """

    def __init__(self, layer_len=32, matrix_num=6, rankratio=0.25):
        self.layer_len = layer_len
        self.matrix_num = matrix_num
        rankratio_mean = torch.ones(layer_len, matrix_num) * rankratio
        self.rankratio_mean = nn.Parameter(rankratio_mean, requires_grad=False)

        self.rankratio_mean.grad = torch.zeros_like(self.rankratio_mean)

    def sample_rank_ratio(self, epsilon=0.05, sigma_2=0.02, idx=0):
        """Draw rank ratios from a truncated normal around ``rankratio_mean``.

        Args:
            epsilon: Samples are clamped to ``[epsilon, 1 - epsilon]``.
            sigma_2: Variance of the underlying normal distribution.
            idx: Iteration counter.  It selects the truncation half-width and
                is also used to reseed the *global* torch RNG, so the same
                ``idx`` always yields the same sample.

        Returns:
            ``(rankratio, rankratio_grad)`` where ``rankratio_grad`` equals
            ``(rankratio - rankratio_mean) / sigma_2``.
        """
        torch.manual_seed(idx)
        # 1. Cubic decay of the truncation half-width gamma: 0.05 -> 0.005.
        if idx < 5:
            gamma = 0.05
        elif idx < 100:
            gamma = 0.05 - (0.05 - 0.005) * ((idx - 5) / 95) ** 3
        else:
            gamma = 0.005

        # 2. Truncated normal sampling via the inverse CDF, so that
        #    sigma * z lies in [-gamma, gamma].
        mu = self.rankratio_mean
        sigma = sigma_2 ** 0.5  # standard deviation, not variance

        normal = torch.distributions.Normal(0, 1)
        lower = normal.cdf(torch.tensor(-(gamma / sigma)))
        upper = normal.cdf(torch.tensor(gamma / sigma))

        u = torch.rand(mu.shape, device=mu.device) * (upper - lower) + lower
        z = normal.icdf(u)
        rankratio = torch.clamp(mu + sigma * z, epsilon, 1.0 - epsilon)

        rankratio_grad = -(self.rankratio_mean - rankratio) / sigma_2
        print("rankratio", rankratio)
        return rankratio, rankratio_grad

    def update_grad(self, loss_list, rankratio_grad_list, K):
        """Accumulate ``grad * (loss - mean(loss_list)) / (K - 1)`` into ``rankratio_mean.grad``.

        Args:
            loss_list: Scalar loss tensors, one per sample.
            rankratio_grad_list: Gradients returned by
                :meth:`sample_rank_ratio`, one per sample.
            K: Number of samples; ``K == 1`` raises ``ZeroDivisionError``.
        """
        if not loss_list or not rankratio_grad_list:
            return
        # The mean loss is the same for every sample, so compute it once.
        baseline = torch.mean(torch.stack(loss_list))
        for loss, rankratio_grad in zip(loss_list, rankratio_grad_list):
            self.rankratio_mean.grad += 1 / (K - 1) * rankratio_grad * (loss - baseline)

    def constrain(self, min_ratio=0.05, max_ratio=0.95):
        """Clamp ``rankratio_mean`` in place to ``[min_ratio, max_ratio]``."""
        self.rankratio_mean.clamp_(min=min_ratio, max=max_ratio)

        


def initialize(dist_type='Gaussian', sparsity = 0.5, layerwise_sparsity_ratios=None, matrix_num = 6):
    """Create the two learnable ratio distributions used by the search loop.

    Args:
        dist_type: Accepted for interface compatibility; not used here.
        sparsity: Initial mean sparsity passed to ``Gaussian_mask``.
        layerwise_sparsity_ratios: Passed through to ``Gaussian_mask``.
        matrix_num: Number of matrices per layer, forwarded to
            ``rankratio_mask`` (previously accepted but ignored).

    Returns:
        ``(layerwise_sparse_mask, global_rankratio_mask)``.
    """
    layerwise_sparse_mask = Gaussian_mask(sparsity, layerwise_sparsity_ratios)
    global_rankratio_mask = rankratio_mask(matrix_num=matrix_num)
    return layerwise_sparse_mask, global_rankratio_mask

def prune_prob(model_adapter, sparsity, layerwise_sparsity_ratios, global_rank_ratios, calib_loader,prune_hyperparams, checkpoint_path, prune_method, idx):
    """Compress every layer of the model, one layer at a time.

    For each layer the function (1) runs the calibration batches through the
    layer with forward hooks that feed each sub-module's inputs/outputs to a
    ``WrappedQR`` collector, (2) compresses the layer's sub-modules with
    ``first_iterate`` (sub-modules are plain ``nn.Linear``) or
    ``inner_iterate`` (otherwise), and (3) re-runs the compressed layer to
    produce the inputs for the next layer.

    Args:
        model_adapter: Project adapter exposing ``model`` and ``get_layers()``.
        sparsity: Uniform sparsity used when ``layerwise_sparsity_ratios`` is None.
        layerwise_sparsity_ratios: Optional per-layer sparsity ratios; must
            have one entry per layer.
        global_rank_ratios: Indexed by layer; each entry is handed to the
            iterate functions as the layer's rank ratios.
        calib_loader: Iterable of calibration batches.
        prune_hyperparams: Mapping; ``'unscaled'`` and ``'num_iters'`` are read.
        checkpoint_path: Unused in this function.
        prune_method: Forwarded as ``compress_type`` ("PCA" or "WANDA").
        idx: Unused in this function.

    Note:
        No ``torch.no_grad()`` is entered here; the caller in this file is
        decorated with it.
    """
    model_adapter.model.to('cpu')
    model_adapter.model.eval()
    use_cache = model_adapter.model.config.use_cache
    model_adapter.model.config.use_cache = False

    try:
        unscaled = prune_hyperparams['unscaled']
        num_iters = prune_hyperparams['num_iters']

        # Only the positional/keyword arguments of layer 0 are needed; the raw
        # input batch is dropped so it does not stay alive during pruning.
        args, kwargs = [], []
        for batch in calib_loader:
            _, args_batch, kwargs_batch = get_layer0_inputs(model_adapter, batch)
            args.append(args_batch)
            kwargs.append(kwargs_batch)
        pruned_args = args  # alias: entries are overwritten after each layer

        layers = model_adapter.get_layers()
        for layer_idx, layer_adapter in enumerate(tqdm(layers, unit="layer", desc="Pruning using global OATS")):
            # ===== Setup Transformer Block Sparsity for prob =====
            if layerwise_sparsity_ratios is not None:
                assert(len(layerwise_sparsity_ratios) == len(layers))
                dense_alloc = 1 - layerwise_sparsity_ratios[layer_idx]
                print("Pruning Layer: " + str(layer_idx) + " using prob to sparse ratio: " + str(dense_alloc))
            else:
                dense_alloc = 1 - sparsity
            layer_rank_ratio = global_rank_ratios[layer_idx]

            # ========== Setup hooks and wrap layers ==============================
            subset = find_layers(layer_adapter.layer)
            wrapped_layers = {name: WrappedQR(module) for name, module in subset.items()}

            def add_batch(name):
                def tmp(_, inp, out):
                    wrapped_layers[name].add_batch(inp[0].data, out.data)
                return tmp

            handles = [subset[name].register_forward_hook(add_batch(name)) for name in wrapped_layers]

            # =========== Precompute quantities =================================
            try:
                layer_adapter.layer.to(config.device)
                for layer_args_batch, layer_kwargs_batch in zip(pruned_args, kwargs):
                    layer_args_batch, layer_kwargs_batch = utils.map_tensors(
                        [layer_args_batch, layer_kwargs_batch], device=config.device
                    )
                    # Output is discarded; the pass only exists to fire the hooks.
                    layer_adapter.layer(*layer_args_batch, **layer_kwargs_batch)
            finally:
                # Always detach the hooks, even if a forward pass failed.
                for h in handles:
                    h.remove()

            # =========== Prune the layer =======================================
            # Exact class comparison on the first sub-module decides whether
            # the layer is still uncompressed.
            if subset[list(subset.keys())[0]].__class__ == nn.Linear:
                first_iterate(layer_adapter, subset, wrapped_layers, prune_hyperparams, dense_alloc, layer_rank_ratio, unscaled, num_iters, compress_type=prune_method)
            else:
                inner_iterate(layer_adapter, subset, wrapped_layers, prune_hyperparams, dense_alloc, layer_rank_ratio, unscaled, num_iters, compress_type=prune_method)

            # ============== Recalculate outputs with pruned weight ====================
            pruned_outs = []

            layer_adapter.layer.to(config.device)
            for layer_args_batch, layer_kwargs_batch in zip(pruned_args, kwargs):
                layer_args_batch, layer_kwargs_batch = utils.map_tensors(
                    [layer_args_batch, layer_kwargs_batch], device=config.device
                )
                out = layer_adapter.layer(*layer_args_batch, **layer_kwargs_batch)
                if isinstance(out, tuple):
                    out = out[layer_adapter.hidden_states_output_position]
                pruned_outs.append(out.cpu())

            for batch_idx, pruned_out in enumerate(pruned_outs):
                pruned_args[batch_idx] = layer_adapter.get_updated_args(
                    pruned_out.cpu(),
                    pruned_args[batch_idx],
                )

            layer_adapter.layer.to('cpu')

            model_adapter.model.config.use_cache = False

            # Run GC and cleanup GPU memory
            utils.cleanup_memory()
    finally:
        # Restore the caller's setting even when pruning aborts midway.
        model_adapter.model.config.use_cache = use_cache

    torch.cuda.empty_cache()

def first_iterate(layer_adapter, subset, wrapped_layers, prune_hyperparams, dense_alloc, layer_rank_ratio, unscaled, num_iters, compress_type="PCA"):
    """Compress a layer whose sub-modules are still plain linear modules.

    Each weight matrix is decomposed into low-rank factors plus a sparse
    remainder, and the original sub-module is replaced inside
    ``layer_adapter.layer`` by a ``ProbLinear`` (or ``ProbQKV`` for a fused
    QKV projection) carrying those components.

    For a ``d_out x d_in`` matrix with rank ratio ``r`` the target rank is
    ``int(r * dense_alloc * d_out * d_in / (d_out + d_in))``.  With "PCA" the
    sparse part is pruned to sparsity ``1 - (1 - r) * dense_alloc``; with
    "WANDA" the sparsity passed on is ``1 - dense_alloc``.

    Args:
        layer_adapter: Project adapter exposing ``layer``, ``qkv_names`` and
            ``get_qkv_partition()``.
        subset: Dict name -> sub-module.  The fused QKV entry, if any, is
            deleted from this dict as a side effect.
        wrapped_layers: Dict name -> collector exposing ``scaler_row``.
        prune_hyperparams: Unused in this function.
        dense_alloc: Dense budget of the layer (callers in this file pass
            ``1 - sparsity``).
        layer_rank_ratio: Indexable rank ratios; entries 0-2 are used for
            q/k/v when the layer has a fused QKV projection.
        unscaled: If true, the activation scaling is replaced by ones.
        num_iters: Iteration count forwarded to the compression routine.
        compress_type: "PCA" (``altern_ls``) or "WANDA" (``wanda``).

    Raises:
        ValueError: If ``compress_type`` is not "PCA" or "WANDA".
    """
    if compress_type not in ("PCA", "WANDA"):
        raise ValueError(f"Unsupported compress_type: {compress_type!r} (expected 'PCA' or 'WANDA')")

    def activation_scale(name):
        # Row vector (1, d_in) scaling the columns of the weight.
        diag_approx = wrapped_layers[name].scaler_row.clone().reshape((1,-1)).to(config.device)
        if unscaled:
            diag_approx = torch.ones_like(diag_approx).float().to(config.device)
        return diag_approx

    def rank_for(rank_ratio, d_out, d_in):
        return int(rank_ratio  * dense_alloc * (d_out*d_in)/(d_out + d_in))

    def compress(matrix, diag_approx, target_rank, rank_ratio):
        if compress_type == "PCA":
            unstruct_sparse = 1.0 - (1.0-rank_ratio)*dense_alloc
            return altern_ls(matrix, diag_approx, num_iters, target_rank, unstruct_sparse)
        return wanda(matrix, diag_approx, num_iters, target_rank, 1.0-dense_alloc)

    matrix_idx = 0
    if len(layer_adapter.qkv_names) == 1:
        # Fused QKV projection: rank ratios 0..2 belong to q, k and v.
        matrix_idx += 3
        qkv_name = layer_adapter.qkv_names[0]
        qkv_linear = subset[qkv_name]
        diag_approx = activation_scale(qkv_name)

        qkv_weight = qkv_linear.weight.data.clone().detach().float()
        partition = layer_adapter.get_qkv_partition()
        q_end, k_end = partition[0], partition[1]
        projections = [
            qkv_weight[:q_end, :].clone().detach().float(),
            qkv_weight[q_end:k_end, :].clone().detach().float(),
            qkv_weight[k_end:, :].clone().detach().float(),
        ]

        # One rank per projection, computed once and used both for the module
        # constructor and for the decomposition so the two can never disagree.
        ranks = [rank_for(layer_rank_ratio[i], *proj.shape) for i, proj in enumerate(projections)]

        prob_module = ProbQKV(qkv_weight.shape[1],
                              ranks[0], projections[0].shape[0],
                              ranks[1], projections[1].shape[0],
                              ranks[2], projections[2].shape[0],
                              qkv_weight,
                              bias = qkv_linear.bias is not None, dtype=config.dtype)

        for qkv_idx, (prefix, qkv_mat) in enumerate(zip(("q", "k", "v"), projections)):
            lrc_V, lrc_U, sparse_comp = compress(
                qkv_mat.to(config.device), diag_approx, ranks[qkv_idx], layer_rank_ratio[qkv_idx]
            )
            # Overwrite the storage of prob_module.{q,k,v}_{V,U,S}.
            getattr(prob_module, prefix + "_V").data = lrc_V.clone().to(config.dtype)
            getattr(prob_module, prefix + "_U").data = lrc_U.clone().to(config.dtype)
            getattr(prob_module, prefix + "_S").data = sparse_comp.clone().to(config.dtype)

        if qkv_linear.bias is not None:
            prob_module.bias.data = qkv_linear.bias.data.detach().clone()

        parent, _, target_name = _get_submodules(layer_adapter.layer, qkv_name)
        setattr(parent, target_name, prob_module)

        # Keep the fused projection out of the generic loop below.
        del subset[qkv_name]

    # Pruning the rest of the layers
    for name in subset:
        linear = subset[name]
        diag_approx = activation_scale(name)
        orig_weight = linear.weight.data.clone().detach().float().to(config.device)
        d_out, d_in = orig_weight.shape
        rank_ratio = layer_rank_ratio[matrix_idx]
        target_rank = rank_for(rank_ratio, d_out, d_in)
        matrix_idx += 1

        lrc_V, lrc_U, sparse_comp = compress(orig_weight, diag_approx, target_rank, rank_ratio)

        prob_module = ProbLinear(d_in, target_rank, d_out,orig_weight, bias=linear.bias is not None, dtype=config.dtype)
        prob_module.V.data = lrc_V.clone().to(config.dtype)
        prob_module.U.data = lrc_U.clone().to(config.dtype)
        prob_module.S.data = sparse_comp.clone().to(config.dtype)
        if linear.bias is not None:
            prob_module.bias.data = linear.bias.data.detach().clone()

        parent, _, target_name = _get_submodules(layer_adapter.layer, name)
        setattr(parent, target_name, prob_module)

def inner_iterate(layer_adapter, subset, wrapped_layers, prune_hyperparams, dense_alloc, layer_rank_ratio, unscaled, num_iters, compress_type="PCA"):
    """Re-compress a layer whose sub-modules were already replaced.

    For every pruning pass after the first, ``subset`` holds ``ProbLinear`` /
    ``ProbQKV`` modules, so the decomposition is recomputed from each module's
    ``weight`` and pushed into the existing module through ``adjust_rank``.

    For a ``d_out x d_in`` matrix with rank ratio ``r`` the target rank is
    ``int(r * dense_alloc * d_out * d_in / (d_out + d_in))``.  With "PCA" the
    sparse part is pruned to sparsity ``1 - (1 - r) * dense_alloc``; with
    "WANDA" the sparsity passed on is ``1 - dense_alloc``.

    Args:
        layer_adapter: Project adapter exposing ``layer``, ``qkv_names`` and
            ``get_qkv_partition()``.
        subset: Dict name -> already-compressed sub-module.
        wrapped_layers: Dict name -> collector exposing ``scaler_row``.
        prune_hyperparams: Unused in this function.
        dense_alloc: Dense budget of the layer (callers in this file pass
            ``1 - sparsity``).
        layer_rank_ratio: Indexable rank ratios; entries 0-2 are used for
            q/k/v when the layer has a fused QKV projection.
        unscaled: If true, the activation scaling is replaced by ones.
        num_iters: Iteration count forwarded to the compression routine.
        compress_type: "PCA" (``altern_ls``) or "WANDA" (``wanda``).

    Raises:
        ValueError: If ``compress_type`` is not "PCA" or "WANDA".
    """
    if compress_type not in ("PCA", "WANDA"):
        raise ValueError(f"Unsupported compress_type: {compress_type!r} (expected 'PCA' or 'WANDA')")

    def activation_scale(name):
        # Row vector (1, d_in) scaling the columns of the weight.
        diag_approx = wrapped_layers[name].scaler_row.clone().reshape((1,-1)).to(config.device)
        if unscaled:
            diag_approx = torch.ones_like(diag_approx).float().to(config.device)
        return diag_approx

    def decompose(matrix, diag_approx, rank_ratio):
        d_out, d_in = matrix.shape
        target_rank = int(rank_ratio  * dense_alloc * (d_out*d_in)/(d_out + d_in))
        if compress_type == "PCA":
            unstruct_sparse = 1.0 - (1.0-rank_ratio)*dense_alloc
            factors = altern_ls(matrix, diag_approx, num_iters, target_rank, unstruct_sparse)
        else:
            factors = wanda(matrix, diag_approx, num_iters, target_rank, 1.0-dense_alloc)
        return (target_rank, *factors)

    has_fused_qkv = len(layer_adapter.qkv_names) == 1

    matrix_idx = 0
    if has_fused_qkv:
        # Fused QKV projection: rank ratios 0..2 belong to q, k and v.
        matrix_idx += 3
        qkv_name = layer_adapter.qkv_names[0]
        prob_module = subset[qkv_name]
        diag_approx = activation_scale(qkv_name)

        qkv_weight = prob_module.weight.data.clone().detach().float()
        partition = layer_adapter.get_qkv_partition()
        q_end, k_end = partition[0], partition[1]
        projections = (
            ('q_proj', qkv_weight[:q_end, :].clone().detach().float()),
            ('k_proj', qkv_weight[q_end:k_end, :].clone().detach().float()),
            ('v_proj', qkv_weight[k_end:, :].clone().detach().float()),
        )

        for qkv_idx, (proj_name, qkv_mat) in enumerate(projections):
            target_rank, lrc_V, lrc_U, sparse_comp = decompose(
                qkv_mat.to(config.device), diag_approx, layer_rank_ratio[qkv_idx]
            )
            prob_module.adjust_rank(proj_name, target_rank, lrc_V, lrc_U, sparse_comp, config.device, config.dtype)

    # Pruning the rest of the layers
    for name in subset:
        if has_fused_qkv and name in layer_adapter.qkv_names:
            continue
        diag_approx = activation_scale(name)
        orig_weight = subset[name].weight.data.clone().detach().float().to(config.device)
        target_rank, lrc_V, lrc_U, sparse_comp = decompose(
            orig_weight, diag_approx, layer_rank_ratio[matrix_idx]
        )
        matrix_idx += 1
        subset[name].adjust_rank(target_rank, lrc_V, lrc_U, sparse_comp, config.device, config.dtype)

    # Re-attach every (updated) module to layer_adapter.layer.
    for name in subset:
        parent, _, target_name = _get_submodules(layer_adapter.layer, name)
        setattr(parent, target_name, subset[name])


def altern_ls(weight, diag_approx, num_iters, target_rank, unstruct_sparse, prune_level = "row", prune_n=0, prune_m=0):
    """Alternating low-rank + sparse decomposition of a column-scaled weight.

    The weight is scaled column-wise by ``sqrt(diag_approx)``.  Each iteration
    refits the low-rank part ``U @ V`` to ``scaled_weight - sparse`` with one
    QR step, then sets ``sparse = scaled_weight - U @ V`` and zeroes its
    smallest-magnitude entries.  The scaling is undone on ``V`` and on the
    sparse part before returning.

    Args:
        weight: ``(d_out, d_in)`` matrix.
        diag_approx: Tensor broadcastable against ``weight`` (callers in this
            file pass shape ``(1, d_in)``), on the same device as ``weight``.
        num_iters: Number of alternating iterations.
        target_rank: Rank of the low-rank factors.
        unstruct_sparse: Fraction of entries (per row, globally, or of whole
            rows, depending on ``prune_level``) zeroed in the sparse part.
        prune_level: "row", "global" or "structured"; ignored when
            ``prune_n != 0``.
        prune_n, prune_m: When ``prune_n != 0``, the ``prune_m - prune_n``
            smallest entries of every group of ``prune_m`` columns are zeroed.

    Returns:
        ``(V, U, sparse)`` with shapes ``(target_rank, d_in)``,
        ``(d_out, target_rank)`` and ``(d_out, d_in)``.

    Raises:
        ValueError: If ``diag_approx`` contains NaN or ``prune_level`` is unknown.
    """
    if diag_approx.isnan().any():
        print("Outliers have NaN. Exiting!")
        raise ValueError("diag_approx contains NaN")
    if prune_n == 0 and prune_level not in ("row", "global", "structured"):
        # Previously `assert ValueError`, which never fired and silently
        # skipped pruning.
        raise ValueError(f"Unknown prune_level: {prune_level!r}")

    col_scale = torch.sqrt(diag_approx)
    scaled_weight = weight * col_scale  # d_out x d_in
    d_out, d_in = weight.shape

    # Work tensors follow the input's device/dtype rather than a global one.
    device, dtype = scaled_weight.device, scaled_weight.dtype
    sparse_component = torch.zeros_like(scaled_weight)
    V = torch.zeros((target_rank, d_in), device=device, dtype=dtype)
    # Defined up front so num_iters == 0 returns zero factors instead of
    # failing on an unbound name.
    U = torch.zeros((d_out, target_rank), device=device, dtype=dtype)

    for _ in range(num_iters):
        # Low-rank step: orthonormal basis of B @ V^T, then project B onto it.
        B = scaled_weight - sparse_component
        Q, _ = torch.linalg.qr(B @ V.t())  # reduced QR, same as torch.qr default
        U = Q
        V = Q.t() @ B
        sparse_component = scaled_weight - U @ V

        # Sparse step: mark the smallest-magnitude residual entries for removal.
        W_metric = torch.abs(sparse_component)
        W_mask = torch.zeros_like(W_metric, dtype=torch.bool)

        if prune_n != 0:
            print("Applying N:M Sparsity")
            # structured n:m sparsity
            for ii in range(0, W_metric.shape[1], prune_m):
                tmp = W_metric[:, ii:(ii + prune_m)].float()
                W_mask.scatter_(1, ii + torch.topk(tmp, prune_m - prune_n, dim=1, largest=False)[1], True)
        elif prune_level == "row":
            # unstructured pruning, same budget for every row
            order = torch.sort(W_metric, dim=-1, stable=True)[1]
            W_mask.scatter_(1, order[:, :int(W_metric.shape[1] * unstruct_sparse)], True)
        elif prune_level == "global":
            order = torch.sort(torch.flatten(W_metric), stable=True)[1]
            flat_mask = torch.flatten(W_mask)
            flat_mask[order[:int(W_metric.numel() * unstruct_sparse)]] = True
            W_mask = torch.unflatten(flat_mask, 0, (W_metric.shape[0], W_metric.shape[1]))
        else:
            # structured pruning: rank rows by the sum of absolute values and
            # drop whole rows.
            row_metric = torch.sum(W_metric, dim=1)
            order = torch.sort(row_metric, dim=-1, stable=True)[1]
            rows = order[:int(W_metric.shape[0] * unstruct_sparse)]
            W_mask.scatter_(0, rows.unsqueeze(1).expand(-1, W_metric.shape[1]), True)

        sparse_component[W_mask] = 0

    # Undo the column scaling on the factors that carry the input dimension.
    inv_scale = 1 / col_scale
    low_rank_compressed_V = V.detach().clone() * inv_scale
    low_rank_compressed_U = U.detach().clone()
    sparse_comp = sparse_component * inv_scale

    return low_rank_compressed_V, low_rank_compressed_U, sparse_comp


def wanda(weight, diag_approx, num_iters, target_rank, unstruct_sparse, prune_level = "row", prune_n=0, prune_m=0):
    """Single-shot magnitude pruning of a column-scaled weight (no low-rank part).

    The weight is scaled column-wise by ``sqrt(diag_approx)``, its
    smallest-magnitude entries are zeroed, and the scaling is undone.  The
    returned low-rank factors are all zeros; they exist so the return value
    has the same layout as ``altern_ls``.

    Args:
        weight: ``(d_out, d_in)`` matrix (not modified).
        diag_approx: Tensor broadcastable against ``weight`` (callers in this
            file pass shape ``(1, d_in)``), on the same device as ``weight``.
        num_iters: Unused; kept for signature parity with ``altern_ls``.
        target_rank: Rank of the (zero) low-rank factors.
        unstruct_sparse: Fraction of entries (per row, globally, or of whole
            rows, depending on ``prune_level``) that are zeroed.
        prune_level: "row", "global" or "structured"; ignored when
            ``prune_n != 0``.
        prune_n, prune_m: When ``prune_n != 0``, the ``prune_m - prune_n``
            smallest entries of every group of ``prune_m`` columns are zeroed.

    Returns:
        ``(V, U, sparse)`` with shapes ``(target_rank, d_in)``,
        ``(d_out, target_rank)`` and ``(d_out, d_in)``.

    Raises:
        ValueError: If ``diag_approx`` contains NaN or ``prune_level`` is unknown.
    """
    if diag_approx.isnan().any():
        print("Outliers have NaN. Exiting!")
        raise ValueError("diag_approx contains NaN")
    if prune_n == 0 and prune_level not in ("row", "global", "structured"):
        # Previously `assert ValueError`, which never fired and silently
        # skipped pruning.
        raise ValueError(f"Unknown prune_level: {prune_level!r}")

    col_scale = torch.sqrt(diag_approx)
    # New tensor, so the in-place masking below never touches `weight`.
    sparse_component = weight * col_scale  # d_out x d_in
    d_out, d_in = weight.shape

    # Zero factors follow the input's device/dtype rather than a global one.
    device, dtype = sparse_component.device, sparse_component.dtype
    V = torch.zeros((target_rank, d_in), device=device, dtype=dtype)
    U = torch.zeros((d_out, target_rank), device=device, dtype=dtype)

    # Prune the weight
    W_metric = torch.abs(sparse_component)
    W_mask = torch.zeros_like(W_metric, dtype=torch.bool)

    if prune_n != 0:
        print("Applying N:M Sparsity")
        # structured n:m sparsity
        for ii in range(0, W_metric.shape[1], prune_m):
            tmp = W_metric[:, ii:(ii + prune_m)].float()
            W_mask.scatter_(1, ii + torch.topk(tmp, prune_m - prune_n, dim=1, largest=False)[1], True)
    elif prune_level == "row":
        # unstructured pruning, same budget for every row
        order = torch.sort(W_metric, dim=-1, stable=True)[1]
        W_mask.scatter_(1, order[:, :int(W_metric.shape[1] * unstruct_sparse)], True)
    elif prune_level == "global":
        order = torch.sort(torch.flatten(W_metric), stable=True)[1]
        flat_mask = torch.flatten(W_mask)
        flat_mask[order[:int(W_metric.numel() * unstruct_sparse)]] = True
        W_mask = torch.unflatten(flat_mask, 0, (W_metric.shape[0], W_metric.shape[1]))
    else:
        # structured pruning: rank rows by the sum of absolute values and
        # drop whole rows.
        row_metric = torch.sum(W_metric, dim=1)
        order = torch.sort(row_metric, dim=-1, stable=True)[1]
        rows = order[:int(W_metric.shape[0] * unstruct_sparse)]
        W_mask.scatter_(0, rows.unsqueeze(1).expand(-1, W_metric.shape[1]), True)

    sparse_component[W_mask] = 0

    # Undo the column scaling.
    inv_scale = 1 / col_scale
    low_rank_compressed_V = V.detach().clone() * inv_scale
    low_rank_compressed_U = U.detach().clone()
    sparse_comp = sparse_component * inv_scale

    return low_rank_compressed_V, low_rank_compressed_U, sparse_comp


@torch.no_grad()
def prune_QR_rank(model_adapter, tokenizer, sparsity, layerwise_sparsity_ratios, train_loader,train_loader_prob, prune_hyperparams, checkpoint_path):
    """Search per-layer sparsity and per-matrix rank ratios by sampling.

    For every batch of ``train_loader_prob`` the function draws ``K`` sets of
    ratios from the two learnable distributions, compresses the model with
    each set via ``prune_prob`` (calibrated on ``train_loader``), and records
    the model loss on the batch.  The distribution means are then updated by
    an optimizer step on gradients accumulated as
    ``sample_grad * (loss - mean_loss) / (K - 1)``, and constrained back to
    their valid range.  Loss and (every fifth sample) WikiText perplexity are
    logged to Weights & Biases.

    Args:
        model_adapter: Project adapter exposing ``model``.
        tokenizer: Tokenizer passed to the perplexity evaluation.
        sparsity: Target average sparsity.
        layerwise_sparsity_ratios: Forwarded to ``initialize``.
        train_loader: Calibration batches for ``prune_prob``.
        train_loader_prob: Batches on which the loss is measured; each batch
            is passed to the model as keyword arguments and the output's
            ``'loss'`` entry is used.
        prune_hyperparams: Mapping; ``'num_iters'`` is read here and the whole
            mapping is forwarded to ``prune_prob`` and to wandb.
        checkpoint_path: Forwarded to ``prune_prob``.
    """
    # NOTE(review): hard-coded absolute path; `results` is not used further in
    # this function and is kept only in case constructing it has side effects.
    results_path = "/home/Leo/Robust_PCA/oats_code/prune_results"
    results = DumpJSON(read_path=(results_path+'.json'),
                    write_path=(results_path+'.json'))

    method = 'Gaussian'
    lr = 2e-3
    prune_method = 'PCA'

    name = f"num_iters{prune_hyperparams['num_iters']}_lr{lr}_sparsity{sparsity}_prune_method{prune_method}"
    wandb.init(
        project="Robust_PCA",
        notes="PG_sparse_alloc",
        config=prune_hyperparams,
        name=name,
    )
    model_adapter.model.eval()
    use_cache = model_adapter.model.config.use_cache
    model_adapter.model.config.use_cache = False

    try:
        K = 2  # number of ratio samples drawn per batch
        layerwise_sparse_mask, global_rankratio_mask = initialize(dist_type=method, sparsity=sparsity,layerwise_sparsity_ratios=layerwise_sparsity_ratios, matrix_num=6)
        if method == 'Gaussian':
            # `torch.optim.Adamw` does not exist; the class is `AdamW`.
            optimizer = torch.optim.AdamW([
                {'params': layerwise_sparse_mask.mean, 'lr': lr},
                {'params': global_rankratio_mask.rankratio_mean, 'lr': lr},
            ])

        idx = 0  # running count of samples drawn so far
        for epoch in range(1):
            for batch in train_loader_prob:
                loss_list = []
                sparse_grad_list = []
                rankratio_grad_list = []
                ppl_list = []
                batch = utils.map_tensors(batch, device=config.device)
                for _ in range(K):
                    idx += 1
                    print(f"Iteration:{idx}")
                    if method == 'Gaussian':
                        sparse_ratio, sparse_grad = layerwise_sparse_mask.sample_ratio(idx=idx)
                        rank_ratio, rankratio_grad = global_rankratio_mask.sample_rank_ratio( idx=idx)

                    sparse_grad_list.append(sparse_grad)
                    rankratio_grad_list.append(rankratio_grad)

                    prune_prob(model_adapter, sparsity, sparse_ratio, rank_ratio, train_loader, prune_hyperparams, checkpoint_path, prune_method, idx)
                    model_adapter.model.to(config.device)
                    if idx % 5 == 0:
                        ppl_eval(model_adapter, tokenizer, ppl_list)
                    loss = model_adapter.model(**batch)['loss']
                    loss_list.append(loss.to('cpu'))
                loss_mean = torch.mean(torch.stack(loss_list))
                wandb.log({'loss': loss_mean.item()})

                # The .grad buffers are filled by hand, so they must be zeroed
                # in place rather than set to None.
                optimizer.zero_grad(set_to_none=False)
                layerwise_sparse_mask.update_grad(loss_list, sparse_grad_list, K)
                global_rankratio_mask.update_grad(loss_list, rankratio_grad_list, K)
                if method == 'Gaussian':
                    torch.nn.utils.clip_grad_norm_(layerwise_sparse_mask.mean, max_norm=3.0)
                    torch.nn.utils.clip_grad_norm_(global_rankratio_mask.rankratio_mean, max_norm=3.0)

                optimizer.step()
                layerwise_sparse_mask.constrain(sparse_ratio=sparsity)
                global_rankratio_mask.constrain()
    finally:
        # Restore the caller's setting even when the search aborts midway.
        model_adapter.model.config.use_cache = use_cache

    torch.cuda.empty_cache()

def ppl_eval(model_adapter, tokenizer, ppl_list):
    """Evaluate WikiText word perplexity with lm-eval, log it, and record it.

    Args:
        model_adapter: Project adapter exposing ``model``.
        tokenizer: Tokenizer handed to the ``HFLM`` wrapper.
        ppl_list: List that receives the perplexity as a 0-dim tensor
            (mutated in place).

    The full lm-eval result dict is printed and the perplexity is logged to
    Weights & Biases under ``ppl_wikitext``.
    """
    eval_batch_size = "auto"
    hflm = HFLM(pretrained=model_adapter.model, tokenizer=tokenizer, batch_size=eval_batch_size)
    with torch.no_grad():
        ppl_results = lm_eval.simple_evaluate(
            hflm, tasks=["wikitext"], num_fewshot=None, batch_size=eval_batch_size
        )['results']
        print(ppl_results)
        ppl_test = torch.tensor(ppl_results['wikitext']['word_perplexity,none'])
        wandb.log({'ppl_wikitext': ppl_test.item()})
        ppl_list.append(ppl_test)

        
    
    
    
        


