import time
import heapq
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import re
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
import json

from .eval import eval_ppl
from .data import get_loaders

class LinearLSPUV(torch.nn.Module):
    """
    Linear layer whose weight is stored as a low-rank product plus a sparse term.

    The effective weight is ``U @ V + S`` with ``U`` of shape
    ``(out_features, rank)``, ``V`` of shape ``(rank, in_features)`` and ``S``
    of shape ``(out_features, in_features)``. ``U`` and ``V`` are trainable;
    ``S`` is a parameter created with ``requires_grad=False``. Every tensor is
    allocated with ``torch.empty`` (uninitialised), so the caller must fill
    them in before running a forward pass.
    """
    def __init__(self, in_features, out_features, rank, bias=True):
        super().__init__()
        self.in_features, self.out_features, self.rank = in_features, out_features, rank

        # Registration order (U, V, S, bias) determines parameter ordering.
        self.U = torch.nn.Parameter(torch.empty(out_features, rank))
        self.V = torch.nn.Parameter(torch.empty(rank, in_features))
        self.S = torch.nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)

        bias_param = torch.nn.Parameter(torch.empty(out_features)) if bias else None
        self.register_parameter('bias', bias_param)

    def forward(self, input):
        # Rebuild the dense weight on every call, then apply a standard affine map.
        low_rank = self.U @ self.V
        return F.linear(input, low_rank + self.S, self.bias)

def _plot_layer_singular_values(layer_idx, original_singular_values, singular_values_dict, visualization_dir):
    """Save a PDF comparing each module's singular values before and after RPCA.

    Dashed lines show the spectrum of the original weight and solid lines the
    spectrum of the RPCA low-rank component; both share one colour per
    projection type (looked up by the last component of the module name).
    """
    color_map = {
        'q_proj': 'green',
        'k_proj': 'red',
        'v_proj': 'blue',
        'o_proj': 'orange',
        'gate_proj': 'purple',
        'down_proj': 'brown',
        'up_proj': 'pink',
    }

    plt.figure(figsize=(8, 6), dpi=150)
    for module_name in singular_values_dict:
        module_short_name = module_name.split('.')[-1]
        color = color_map.get(module_short_name, None)
        plt.plot(original_singular_values[module_name], linestyle='--', label=f"{module_short_name} (original)", color=color, linewidth=2.5)
        plt.plot(singular_values_dict[module_name], linestyle='-', label=f"{module_short_name} (RPCA)", color=color, linewidth=2.0)

    plt.title(f"Singular Values of Layer {layer_idx}", fontsize=14, fontweight='bold')
    plt.xlabel("Index", fontsize=12, fontweight='bold')
    plt.ylabel("Singular Value", fontsize=12, fontweight='bold')
    plt.legend(fontsize=10)
    plt.tight_layout()
    figure_path = os.path.join(visualization_dir, f"layer_{layer_idx}_singular_values.pdf")
    plt.savefig(figure_path, format='pdf')
    plt.close()


def prune_rpca(args, model, tokenizer, device):
    """Replace selected ``nn.Linear`` modules with RPCA-based ``LinearLSPUV`` modules.

    For every numeric layer index found in the names of the model's
    ``nn.Linear`` modules, each linear module whose name contains one of the
    comma-separated keys in ``args.linear_layers`` is decomposed as
    ``W ~= L + S`` with ``RPCA_gpu``. ``L`` is truncated to the singular values
    ``>= args.singular_value_threshold`` and stored as ``U @ V``; ``S`` is
    optionally pruned further to ``args.target_sparsity``. The module is then
    swapped in place for a ``LinearLSPUV``.

    Side effects under ``args.save``:
      * ``visualization/layer_<i>_singular_values.pdf``
      * ``decomposition_info/layer_<i>_decomposition_info.json``
      * ``rpca_results/*.pkl`` (when ``args.save_rpca``)
      * ``eval_metrics_after_each_layer.json``

    Args:
        args: Namespace read for ``linear_layers``, ``save``, ``load_rpca``,
            ``save_rpca``, ``rpca_mu``, ``rpca_lambda``, ``tau_multiplier``,
            ``use_nuclear_norm``, ``rpca_tol``, ``rpca_max_iter``,
            ``singular_value_threshold``, ``enforce_target_sparsity``,
            ``target_sparsity``, ``zeroing_method``, ``eval_after_each_layer``
            and ``finetune``.
        model: Model whose linear modules are replaced in place. It is NOT
            moved; each new module is placed on its original weight's device.
        tokenizer: Passed through to ``eval_ppl`` and ``finetune_model``.
        device: Device on which the SVD / RPCA computations run.

    Raises:
        ValueError: If ``args.zeroing_method`` is neither ``'random'`` nor
            ``'magnitude'`` when extra zeroing is required.
    """
    # Do not move the model to device here

    # Layers to process
    layers_to_process = args.linear_layers.split(",")

    # Create directories to save visualizations and decomposition info
    visualization_dir = os.path.join(args.save, 'visualization')
    os.makedirs(visualization_dir, exist_ok=True)

    decomposition_info_dir = os.path.join(args.save, 'decomposition_info')
    os.makedirs(decomposition_info_dir, exist_ok=True)

    rpca_results_dir = os.path.join(args.save, 'rpca_results')
    os.makedirs(rpca_results_dir, exist_ok=True)

    # Initialize variable to store evaluation metrics
    eval_metrics = {}

    # Collect all layer indices (first ".<number>." group in each Linear's name)
    layer_indices = set()
    pattern = re.compile(r'\.(\d+)\.')
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            matches = pattern.findall(name)
            if matches:
                layer_indices.add(int(matches[0]))
    layer_indices = sorted(layer_indices)
    total_layers = len(layer_indices)

    for idx, layer_idx in enumerate(layer_indices):
        layer_tag = f'.{layer_idx}.'
        print(f"Processing layer {layer_idx} ({idx + 1}/{total_layers})")

        # Snapshot the matching modules first: the model is mutated below, and
        # already-replaced modules are no longer nn.Linear so they are skipped.
        layer_modules = [
            (name, module)
            for name, module in model.named_modules()
            if isinstance(module, torch.nn.Linear)
            and any(key in name for key in layers_to_process)
            and layer_tag in name
        ]

        if not layer_modules:
            # Nothing to decompose: avoid NaN averages and an empty plot.
            print(f"No matching linear modules in layer {layer_idx}; skipping.")
            continue

        singular_values_dict = {}        # For plotting singular values after RPCA
        original_singular_values = {}    # For plotting original singular values before RPCA
        decomposition_info = {}          # To store decomposition error, rank, sparsity, truncated error

        layer_sparsity_list = []
        layer_rank_list = []
        layer_truncated_error_list = []
        layer_decomposition_error_list = []

        for name, module in tqdm(layer_modules, desc=f"Processing modules in layer {layer_idx}"):
            weight = module.weight.data.clone()
            bias = module.bias
            weight_device = weight.device
            weight_dtype = weight.dtype

            # Move weight to device and convert to float32
            weight = weight.to(device).to(torch.float32)

            # Compute original singular values before RPCA (only the spectrum is kept)
            with torch.no_grad():
                _, sigma_orig, _ = torch.linalg.svd(weight, full_matrices=False)
                original_singular_values[name] = sigma_orig.cpu().numpy()

            rpca_result_path = os.path.join(rpca_results_dir, f"layer_{layer_idx}_{name.replace('.', '_')}_rpca.pkl")

            # Stays None when the decomposition comes from disk, so lambda/mu
            # are only reported when RPCA actually ran for this module.
            rpca = None

            if args.load_rpca and os.path.exists(rpca_result_path):
                # Load RPCA results
                # NOTE(security): pickle.load can execute arbitrary code; only
                # load result files produced by a trusted run of this script.
                with open(rpca_result_path, 'rb') as f:
                    L, S = pickle.load(f)
                print(f"Loaded RPCA results for {name} from {rpca_result_path}")
            else:
                # Perform RPCA on the weight matrix
                rpca = RPCA_gpu(
                    weight,
                    mu=args.rpca_mu,
                    lambda_=args.rpca_lambda,
                    tau_multiplier=args.tau_multiplier,  # Use tau_multiplier
                    use_nuclear_norm=args.use_nuclear_norm
                )

                # Perform RPCA
                L, S = rpca.fit(
                    tol=args.rpca_tol,
                    max_iter=args.rpca_max_iter,
                    verbose=True
                )

                if args.save_rpca:
                    # Save RPCA results
                    with open(rpca_result_path, 'wb') as f:
                        pickle.dump((L.detach().cpu(), S.detach().cpu()), f)
                    print(f"Saved RPCA results for {name} to {rpca_result_path}")

            # Move L and S back to device
            L = L.to(device)
            S = S.to(device)

            # Store singular values for plotting
            with torch.no_grad():
                U_svd_full, sigma_svd_full, Vh_svd_full = torch.linalg.svd(L, full_matrices=False)
                singular_values_dict[name] = sigma_svd_full.cpu().numpy()
                decomposition_error = torch.norm(weight - L - S, p='fro') / torch.norm(weight, p='fro')

            # Determine rank of L based on singular value threshold
            singular_value_threshold = args.singular_value_threshold
            rank_L = (sigma_svd_full >= singular_value_threshold).sum().item()

            # Truncate U, sigma, Vh based on rank_L
            U_svd = U_svd_full[:, :rank_L]
            sigma_svd = sigma_svd_full[:rank_L]
            Vh_svd = Vh_svd_full[:rank_L, :]

            # Split sqrt(sigma) evenly between the two factors so U @ V == L_truncated
            sqrt_sigma = torch.sqrt(sigma_svd)
            U = U_svd * sqrt_sigma.unsqueeze(0)
            V = sqrt_sigma.unsqueeze(1) * Vh_svd

            # Enforce target sparsity on S if specified
            # Compute current sparsity of S
            sparsity_S = float((S == 0).sum().item()) / S.numel()

            if args.enforce_target_sparsity and sparsity_S < args.target_sparsity:
                # Need to further prune S to reach target sparsity
                num_elements_to_zero = int((args.target_sparsity - sparsity_S) * S.numel())
                if num_elements_to_zero > 0:
                    # Get indices of non-zero elements
                    non_zero_indices = torch.nonzero(S != 0, as_tuple=False)
                    if args.zeroing_method == 'random':
                        # Randomly select elements to set to zero
                        selected_indices = non_zero_indices[torch.randperm(non_zero_indices.size(0))[:num_elements_to_zero]]
                    elif args.zeroing_method == 'magnitude':
                        # Zero out elements with smallest magnitudes (ties at the
                        # threshold are zeroed too)
                        non_zero_values = torch.abs(S[S != 0])
                        threshold = torch.topk(non_zero_values, num_elements_to_zero, largest=False).values.max()
                        selected_indices = torch.nonzero(torch.abs(S) <= threshold, as_tuple=False)
                    else:
                        raise ValueError(f"Unknown zeroing method: {args.zeroing_method}")

                    # Set selected elements to zero
                    S[selected_indices[:, 0], selected_indices[:, 1]] = 0

                    # Update sparsity_S
                    sparsity_S = float((S == 0).sum().item()) / S.numel()

            # Create new module
            new_module = LinearLSPUV(module.in_features, module.out_features, rank_L, bias is not None)
            new_module.U.data = U.to(device=weight_device, dtype=weight_dtype)
            new_module.V.data = V.to(device=weight_device, dtype=weight_dtype)
            new_module.S.data = S.to(device=weight_device, dtype=weight_dtype)
            if bias is not None:
                new_module.bias.data = bias.data.to(device=weight_device, dtype=weight_dtype)

            # Ensure S does not update during fine-tuning
            new_module.S.requires_grad = False

            # Move new_module to the correct device
            new_module.to(weight_device)

            # Replace module in the model
            parent_module = model
            name_parts = name.split('.')
            for n in name_parts[:-1]:
                parent_module = getattr(parent_module, n)
            setattr(parent_module, name_parts[-1], new_module)

            # Compute errors
            with torch.no_grad():
                # The stored S may live on another device / dtype than the
                # float32 factors on `device`; bring it over before adding.
                stored_S = new_module.S.data.to(device=device, dtype=torch.float32)
                reconstructed_weight = U @ V + stored_S
                truncated_error = torch.norm(weight - reconstructed_weight, p='fro') / torch.norm(weight, p='fro')

            # Store decomposition info
            decomposition_info[name] = {
                'decomposition_error': float(decomposition_error.cpu()),
                'rank_L': rank_L,
                'sparsity_S': sparsity_S,
                'truncated_error': float(truncated_error.cpu()),
                'lambda': rpca.lambda_ if rpca is not None else None,
                'mu': rpca.mu if rpca is not None else None
            }

            # Collect sparsity and rank for this module
            layer_sparsity_list.append(sparsity_S)
            layer_rank_list.append(rank_L)
            layer_truncated_error_list.append(truncated_error.cpu())
            layer_decomposition_error_list.append(decomposition_error.cpu())

            # Print module decomposition info
            print(f"Module: {name}")
            print(f"  Decomposition Error: {decomposition_error:.6e}")
            print(f"  Rank of L: {rank_L}")
            print(f"  Sparsity of S: {sparsity_S:.4f}")
            print(f"  Truncated Error: {truncated_error:.6e}")
            if rpca is not None:
                print(f"  Lambda: {rpca.lambda_:.6e}")
                print(f"  Mu: {rpca.mu:.6e}")

        # Plot singular values after processing the layer
        _plot_layer_singular_values(layer_idx, original_singular_values, singular_values_dict, visualization_dir)

        # Compute average sparsity, rank, truncated error for the layer
        avg_sparsity = np.mean(layer_sparsity_list)
        avg_rank = np.mean(layer_rank_list)
        avg_truncated_error = torch.mean(torch.tensor(layer_truncated_error_list)).item()
        avg_decomposition_error = torch.mean(torch.tensor(layer_decomposition_error_list)).item()

        # Print layer summary
        print(f"Layer {layer_idx} average sparsity: {avg_sparsity:.4f}")
        print(f"Layer {layer_idx} average rank: {avg_rank:.2f}")
        print(f"Layer {layer_idx} average truncated error: {avg_truncated_error:.6e}")
        print(f"Layer {layer_idx} average decomposition error: {avg_decomposition_error:.6e}")

        # Save decomposition info for the layer
        decomposition_info_path = os.path.join(decomposition_info_dir, f"layer_{layer_idx}_decomposition_info.json")
        with open(decomposition_info_path, 'w') as f:
            json.dump(decomposition_info, f, indent=4)

        # Evaluate the model after processing the layer
        if args.eval_after_each_layer:
            print(f"Evaluating the model after processing layer {layer_idx}...")
            model.eval()
            ppl_test = eval_ppl(args, model, tokenizer, device)
            print(f"Perplexity after processing layer {layer_idx}: {ppl_test}")
            eval_metrics[f"layer_{layer_idx}"] = {
                'perplexity': float(ppl_test),
                'average_sparsity': float(avg_sparsity),
                'average_rank': float(avg_rank),
                'average_truncated_error': float(avg_truncated_error),
                'average_decomposition_error': float(avg_decomposition_error)
            }

    # Save evaluation metrics (an empty dict when per-layer evaluation is off)
    eval_metrics_path = os.path.join(args.save, 'eval_metrics_after_each_layer.json')
    with open(eval_metrics_path, 'w') as f:
        json.dump(eval_metrics, f, indent=4)

    # Fine-tune after RPCA if specified
    if args.finetune:
        print("Starting fine-tuning after RPCA...")
        finetune_model(args, model, tokenizer, device)

class RPCA_gpu:
    """ Low-rank and sparse matrix decomposition via RPCA with CUDA capabilities

    Splits a 2-D tensor ``D`` into ``L + S`` by alternating a singular-value
    soft-threshold (update of ``L``), an element-wise soft-threshold (update of
    ``S``) and a multiplier update (``Y``). All tensors are created on
    ``D.device``.
    """
    def __init__(self, D, mu=None, lambda_=None, tau_multiplier=1.0, use_nuclear_norm=True):
        """
        Args:
            D: 2-D tensor to decompose.
            mu: Penalty parameter. When falsy (``None`` or ``0``) it defaults
                to ``numel(D) / (4 * sum(D ** 2))``.
            lambda_: Weight of the sparse term. When falsy it defaults to
                ``1 / sqrt(max(D.shape))``.
            tau_multiplier: Factor applied to both shrinkage thresholds.
            use_nuclear_norm: When False the ``L`` update skips the
                singular-value thresholding step.
        """
        self.D = D
        self.device = D.device
        self.use_nuclear_norm = use_nuclear_norm
        self.norm_D = torch.norm(D, p='fro').item()
        self.m, self.n = D.shape

        # Parameter initialization
        # Use standard mu and lambda if not provided
        self.mu = mu or (np.prod(self.D.shape) / (4 * self.norm_p(self.D, 2))).item()
        self.mu_inv = 1 / self.mu
        self.lambda_ = lambda_ or 1 / np.sqrt(np.max(self.D.shape))
        self.tau_multiplier = tau_multiplier  # Multiplier for tau

    @staticmethod
    def norm_p(M, p):
        """Return ``sum(|M| ** p)`` as a 0-dim tensor (no root is taken)."""
        return torch.sum(torch.pow(M.abs(), p))

    @staticmethod
    def shrink(M, tau):
        """Element-wise soft-threshold: shrink magnitudes by ``tau``, clamping at zero."""
        return torch.sign(M) * F.relu(torch.abs(M) - tau)

    def svd_threshold(self, M, tau):
        """Soft-threshold the singular values of ``M`` by ``tau * tau_multiplier``."""
        tau = tau * self.tau_multiplier  # Apply tau multiplier
        U, s, Vh = torch.linalg.svd(M, full_matrices=False)
        s_threshold = F.relu(s - tau)
        # Scale U's columns directly instead of materialising a dense diagonal
        # matrix and paying for an extra matmul.
        return (U * s_threshold) @ Vh

    def fit(self, tol=1e-7, max_iter=1000, verbose=False):
        """Run the iteration and return the best ``(L, S)`` pair observed.

        Args:
            tol: Stop once ``||D - L - S||_F / ||D||_F`` drops below this value.
            max_iter: Maximum number of iterations.
            verbose: Print the error and rank of ``L`` on selected iterations.

        Returns:
            Tuple ``(L, S)`` with the smallest relative residual seen; also
            stored as ``self.L`` and ``self.S``. For an all-zero ``D`` or
            ``max_iter <= 0`` both are zero tensors shaped like ``D``.
        """
        S = torch.zeros_like(self.D)
        Y = torch.zeros_like(self.D)
        L = torch.zeros_like(self.D)
        mu_inv = self.mu_inv

        if self.norm_D == 0:
            # The relative residual would be 0 / 0; an all-zero matrix is
            # trivially decomposed as 0 + 0.
            self.L, self.S = L, S
            return self.L, self.S

        converged = False

        # Track best L and S with minimum error; start from the zero iterate so
        # a result is always available even when no iteration runs.
        best_L = L
        best_S = S
        min_error = float('inf')

        for i in range(max_iter):
            # Update L
            if self.use_nuclear_norm:
                L = self.svd_threshold(self.D - S + mu_inv * Y, mu_inv)
            else:
                L = self.D - S + mu_inv * Y

            # Update S
            S = self.shrink(self.D - L + mu_inv * Y, mu_inv * self.lambda_ * self.tau_multiplier)  # Apply tau_multiplier

            # Update Y
            Z = self.D - L - S
            Y = Y + self.mu * Z

            # Check convergence
            err = torch.norm(Z, p='fro').item() / self.norm_D

            # Track the best L and S with the minimum error so far
            if err < min_error:
                min_error = err
                best_L = L.clone()
                best_S = S.clone()

            if verbose and ((i % 50) == 0 or i == 1 or i >= max_iter - 1 or err <= tol):
                rank_L = torch.linalg.matrix_rank(L).item()
                print(f'Iteration: {i+1}; Error: {err:0.4e}; Rank of L: {rank_L}')

            if err < tol:
                converged = True
                break

        if not converged:
            print("RPCA did not converge within the maximum number of iterations.")

        # Return the best L and S with the minimum observed error
        self.L = best_L
        self.S = best_S
        return self.L, self.S

def finetune_model(args, model, tokenizer, device):
    """Fine-tune only the low-rank factors (``U`` / ``V``) with a next-token loss.

    Training data comes from ``get_loaders('wikitext2', ...)``. Every parameter
    whose final name component is exactly ``U`` or ``V`` is made trainable and
    all others are frozen. Optimisation uses AdamW with a linear schedule and
    no warm-up.

    Args:
        args: Namespace read for ``train_samples``, ``seed``,
            ``learning_rate``, ``num_train_epochs`` and ``save_model``.
        model: Causal LM exposing ``seqlen`` and returning an object with
            ``.logits``; it is not moved to ``device``.
        tokenizer: Passed to ``get_loaders`` and saved with the model.
        device: Device the input token ids are moved to.

    Raises:
        ValueError: If the model has no parameter named ``U`` or ``V``.
    """
    from torch.optim import AdamW
    from transformers import get_scheduler

    # Get training data
    trainloader, _ = get_loaders(
        'wikitext2', nsamples=args.train_samples, seed=args.seed, seqlen=model.seqlen, tokenizer=tokenizer
    )

    # Prepare optimizer and scheduler
    # Only update U and V parameters. Match on the final name component: a
    # plain substring test would also unfreeze any unrelated parameter whose
    # qualified name merely contains an upper-case "U" or "V".
    parameters_to_update = []
    for name, param in model.named_parameters():
        is_low_rank_factor = name.rsplit('.', 1)[-1] in ('U', 'V')
        param.requires_grad = is_low_rank_factor
        if is_low_rank_factor:
            parameters_to_update.append(param)

    if not parameters_to_update:
        raise ValueError("No 'U' or 'V' parameters found to fine-tune; run the RPCA replacement first.")

    num_batches = len(trainloader)
    optimizer = AdamW(parameters_to_update, lr=args.learning_rate)
    num_training_steps = args.num_train_epochs * num_batches
    lr_scheduler = get_scheduler(
        name='linear', optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )

    # Note: If the model is spread across multiple devices, do not move it
    # model.to(device)
    model.train()
    loss_fct = torch.nn.CrossEntropyLoss()

    for epoch in range(args.num_train_epochs):
        print(f"Starting epoch {epoch+1}/{args.num_train_epochs}")
        total_loss = 0
        for step, (inputs, _) in enumerate(trainloader):
            inputs = inputs.reshape(-1, model.seqlen).to(device)
            outputs = model(inputs)
            logits = outputs.logits
            # Predict token t+1 from position t.
            shift_logits = logits[:, :-1, :].contiguous()
            # reshape (not view): the sliced labels are non-contiguous as soon
            # as the batch holds more than one sequence.
            shift_labels = inputs[:, 1:].reshape(-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            total_loss += loss.item()
            if step % 50 == 0:
                print(f"Step {step}, Loss: {loss.item()}")
        # An empty loader has no meaningful average; report NaN instead of crashing.
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        print(f"Epoch {epoch+1} finished, Average Loss: {avg_loss}")

    # Save the fine-tuned model
    if args.save_model:
        model.save_pretrained(args.save_model)
        tokenizer.save_pretrained(args.save_model)
        torch.save(model.state_dict(), os.path.join(args.save_model, 'pytorch_model.bin'))
        print(f"Fine-tuned model saved to {args.save_model}")

#################################################################################


def find_layers(module, layers=[nn.Linear], name=''):
    """
    Collect every sub-module whose exact type is listed in ``layers``.

    The module tree is walked depth-first in ``named_children`` order. A
    module whose type matches is recorded under its dotted path and is not
    descended into. Matching uses ``type(...)``, so subclasses of a listed
    type do not match.

    Args:
        module (nn.Module): Root of the tree to search.
        layers (list): Layer types to look for.
        name (str): Dotted-path prefix used for the root module.

    Returns:
        dict: Mapping from dotted path to matching module, in traversal order.
    """
    found = {}
    # Explicit stack; children are pushed in reverse so they pop in
    # declaration order.
    pending = [(name, module)]
    while pending:
        path, node = pending.pop()
        if type(node) in layers:
            found[path] = node
            continue
        children = [
            (path + '.' + child_name if path != '' else child_name, child)
            for child_name, child in node.named_children()
        ]
        pending.extend(reversed(children))
    return found

def check_sparsity(model):
    """Print per-layer sparsity and return the overall fraction of zero weights.

    Only the ``weight`` tensors of modules returned by ``find_layers`` (exact
    ``nn.Linear`` by default) under ``model.model.layers`` are counted.
    ``model.config.use_cache`` is disabled while counting and always restored.

    Args:
        model: Model exposing ``config.use_cache`` and ``model.layers``.

    Returns:
        float: Zero-valued weights divided by total weights, or ``0.0`` when
        no linear weights were found.
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False

    try:
        count = 0
        total_params = 0
        for i, layer in enumerate(model.model.layers):
            subset = find_layers(layer)

            sub_count = 0
            sub_params = 0
            for name in subset:
                W = subset[name].weight.data
                # Count zeros once and reuse it for both the layer and model totals.
                sub_count += (W == 0).sum().item()
                sub_params += W.numel()

            count += sub_count
            total_params += sub_params

            if sub_params == 0:
                # e.g. every nn.Linear in this layer was replaced by another module type.
                print(f"layer {i} has no linear weights to measure")
            else:
                print(f"layer {i} sparsity {float(sub_count)/sub_params:.6f}")
    finally:
        model.config.use_cache = use_cache

    return float(count) / total_params if total_params else 0.0

def prepare_calibration_input(model, dataloader, device):
    """Capture the inputs that reach the first decoder layer for each calibration batch.

    The first entry of ``model.model.layers`` is temporarily wrapped in a
    module that stores its positional input and the ``attention_mask`` /
    ``position_ids`` keyword arguments, then aborts the forward pass. The
    wrapper is always removed and ``model.config.use_cache`` always restored,
    even when the model raises.

    Args:
        model: Model exposing ``config``, ``seqlen``, ``model.layers`` and,
            optionally, ``hf_device_map``.
        dataloader: Iterable of batches; ``batch[0]`` is fed to the model.
        device: Device for the capture buffer and inputs; overridden by
            ``hf_device_map["model.embed_tokens"]`` when that entry exists.

    Returns:
        Tuple ``(inps, outs, attention_mask, position_ids)``. ``inps`` has
        ``max(128, len(dataloader))`` rows (128 when the loader has no
        length); ``outs`` is a zero tensor shaped like ``inps``; the last two
        are the values seen on the final captured batch.
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    # Models loaded without a device map may not carry this attribute.
    device_map = getattr(model, "hf_device_map", None) or {}
    if "model.embed_tokens" in device_map:
        device = device_map["model.embed_tokens"]

    # Keep the historical 128-row buffer, but grow it when the loader is known
    # to be larger so capturing cannot index past the end.
    try:
        num_batches = len(dataloader)
    except TypeError:
        num_batches = 0
    capacity = max(128, num_batches)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros((capacity, model.seqlen, model.config.hidden_size), dtype=dtype, device=device)
    inps.requires_grad = False
    cache = {'i': 0, 'attention_mask': None, "position_ids": None}

    class _InputCaptured(ValueError):
        """Raised by Catcher to stop the forward pass once the input is stored."""

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_ids'] = kwargs['position_ids']
            raise _InputCaptured

    first_layer = layers[0]
    layers[0] = Catcher(first_layer)
    try:
        for batch in dataloader:
            try:
                model(batch[0].to(device))
            except _InputCaptured:
                # Expected: only the signal raised by Catcher is swallowed, so
                # genuine ValueErrors from the model still surface.
                pass
    finally:
        layers[0] = first_layer
        model.config.use_cache = use_cache

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    return inps, outs, attention_mask, position_ids

def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before):
    """Build a pruning mask whose per-row threshold is set by a cumulative-sum budget.

    For each row, the budget is ``alpha * sum_before``. The threshold is the
    largest sorted value whose running sum (``tmp_metric``) stays within that
    budget, and every entry of ``W_metric`` at or below it is masked.

    Args:
        alpha: Fraction of each row's total metric used as the budget.
        sort_res: Result of sorting ``W_metric`` along dim 1; ``sort_res[0]``
            holds the sorted values.
        W_metric: 2-D importance scores.
        tmp_metric: Cumulative sum of ``sort_res[0]`` along dim 1.
        sum_before: Per-row sum of ``W_metric``.

    Returns:
        Tuple ``(W_mask, cur_sparsity)``: the boolean mask and the fraction of
        masked entries as a 0-dim tensor.
    """
    row_budget = (sum_before * alpha).reshape((-1, 1))
    within_budget = tmp_metric <= row_budget
    # Column (in sorted order) of the last value that still fits the budget.
    last_kept = within_budget.sum(dim=1, keepdim=True) - 1
    thres = torch.gather(sort_res[0], dim=1, index=last_kept)
    W_mask = W_metric <= thres
    cur_sparsity = W_mask.sum() / W_mask.numel()
    return W_mask, cur_sparsity

def prune_magnitude(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    """Zero the smallest-magnitude weights of every linear layer, in place.

    With ``prune_n != 0`` the ``prune_n`` smallest entries of each group of
    ``prune_m`` consecutive input columns are zeroed per row (N:M sparsity).
    Otherwise every weight whose magnitude is at or below the value at
    position ``int(numel * args.sparsity_ratio)`` of the sorted magnitudes is
    zeroed.

    Args:
        args: Namespace read for ``sparsity_ratio``.
        model: Model exposing ``model.layers``.
        tokenizer: Unused; kept for a uniform pruning-function signature.
        device: Device on which the global sort for the threshold is done.
        prune_n: N of N:M sparsity (0 selects unstructured pruning).
        prune_m: M of N:M sparsity.
    """
    for layer in model.model.layers:
        subset = find_layers(layer)

        for name in subset:
            W = subset[name].weight.data
            W_metric = torch.abs(W)
            if prune_n != 0:
                W_mask = torch.zeros_like(W, dtype=torch.bool)
                # Walk the columns one group of prune_m at a time.
                for ii in range(0, W_metric.shape[1], prune_m):
                    tmp = W_metric[:, ii:(ii + prune_m)].float()
                    W_mask.scatter_(1, ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)
            else:
                # Sort on the caller-supplied device instead of a hard-coded
                # .cuda(), so the function also works on CPU-only hosts.
                sorted_metric = torch.sort(W_metric.flatten().to(device))[0]
                # Clamp so sparsity_ratio == 1.0 does not index past the end.
                cut = min(int(W.numel() * args.sparsity_ratio), W.numel() - 1)
                thresh = sorted_metric[cut].to(W_metric.device)
                W_mask = (W_metric <= thresh)

            W[W_mask] = 0

def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    """Prune linear layers with an activation-weighted magnitude metric, layer by layer.

    Calibration samples from the ``c4`` loader are pushed through each decoder
    layer while forward hooks feed per-module statistics into ``WrappedGPT``.
    Each weight is scored as ``|W| * sqrt(scaler_row)`` and the lowest-scoring
    weights are zeroed in place. The layer is then re-run so the next layer
    sees the pruned outputs.

    Args:
        args: Namespace read for ``nsamples``, ``seed``, ``sparsity_ratio``
            and ``use_variant``.
        model: Model exposing ``config``, ``seqlen``, ``model.layers`` and,
            optionally, ``hf_device_map``.
        tokenizer: Tokenizer passed to ``get_loaders``.
        device: Default device for the calibration inputs.
        prune_n: N of N:M sparsity (0 selects per-row unstructured pruning).
        prune_m: M of N:M sparsity.
    """
    # NOTE(review): `WrappedGPT` is not imported in the part of the file
    # visible here; confirm it is in scope before calling this function.
    use_cache = model.config.use_cache
    model.config.use_cache = False

    print("loading calibdation data")
    dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
    print("dataset loading complete")
    with torch.no_grad():
        inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device)

    device_map = getattr(model, "hf_device_map", None) or {}
    layers = model.model.layers
    for i, layer in enumerate(layers):
        subset = find_layers(layer)

        layer_key = f"model.layers.{i}"
        if layer_key in device_map:   ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs;
            dev = device_map[layer_key]
            inps, outs = inps.to(dev), outs.to(dev)
            # Either captured kwarg may be None; move only what exists, and
            # keep the mask on the same device as the other tensors.
            if attention_mask is not None:
                attention_mask = attention_mask.to(dev)
            if position_ids is not None:
                position_ids = position_ids.to(dev)

        wrapped_layers = {name: WrappedGPT(subset[name]) for name in subset}

        def add_batch(name):
            # Factory so each hook is bound to its own module name.
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = [subset[name].register_forward_hook(add_batch(name)) for name in wrapped_layers]
        try:
            for j in range(args.nsamples):
                with torch.no_grad():
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        finally:
            # Never leave statistics hooks attached, even if a forward pass fails.
            for h in handles:
                h.remove()

        for name in subset:
            print(f"pruning layer {i} name {name}")
            W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))

            W_mask = (torch.zeros_like(W_metric) == 1)  ## initialize a mask to be all False
            if prune_n != 0:
                # structured n:m sparsity
                for ii in range(0, W_metric.shape[1], prune_m):
                    tmp = W_metric[:,ii:(ii+prune_m)].float()
                    W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True)
            else:
                sort_res = torch.sort(W_metric, dim=-1, stable=True)

                if args.use_variant:
                    # wanda variant: bisect on alpha until the achieved
                    # sparsity is within 0.001 of the target or the bracket
                    # [alpha_hist[0], alpha_hist[1]] becomes narrower than 0.001
                    tmp_metric = torch.cumsum(sort_res[0], dim=1)
                    sum_before = W_metric.sum(dim=1)

                    alpha = 0.4
                    alpha_hist = [0., 0.8]
                    W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before)
                    while (torch.abs(cur_sparsity - args.sparsity_ratio)>0.001) and (alpha_hist[1]-alpha_hist[0]>=0.001):
                        if cur_sparsity > args.sparsity_ratio:
                            alpha_new = (alpha + alpha_hist[0]) / 2.0
                            alpha_hist[1] = alpha
                        else:
                            alpha_new = (alpha + alpha_hist[1]) / 2.0
                            alpha_hist[0] = alpha

                        alpha = alpha_new
                        W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before)
                    print(f"alpha found {alpha} sparsity {cur_sparsity:.6f}")
                else:
                    # unstructured pruning: drop the lowest-scoring fraction of each row
                    indices = sort_res[1][:,:int(W_metric.shape[1]*args.sparsity_ratio)]
                    W_mask.scatter_(1, indices, True)

            subset[name].weight.data[W_mask] = 0  ## set weights to zero

        # Recompute this layer's outputs with the pruned weights; they become
        # the next layer's inputs.
        for j in range(args.nsamples):
            with torch.no_grad():
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        inps, outs = outs, inps

    model.config.use_cache = use_cache
    torch.cuda.empty_cache()


@torch.no_grad()
def prune_sparsegpt(args, model, tokenizer, dev, prune_n=0, prune_m=0):
    """Prune every linear layer with SparseGPT, one decoder layer at a time.

    Calibration inputs from the ``c4`` loader are captured at the first
    decoder layer. For each layer, forward hooks feed activations into a
    ``SparseGPT`` object per linear module, ``fasterprune`` is called on each,
    and the layer is re-run so the next layer sees the pruned outputs.

    Args:
        args: Namespace read for ``nsamples``, ``seed`` and ``sparsity_ratio``.
        model: Model exposing ``config``, ``seqlen``, ``model.layers`` and,
            optionally, ``hf_device_map``.
        tokenizer: Tokenizer passed to ``get_loaders``.
        dev: Default device; overridden per module by ``hf_device_map``.
        prune_n: N of N:M sparsity, forwarded to ``fasterprune``.
        prune_m: M of N:M sparsity, forwarded to ``fasterprune``.
    """
    ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
    # NOTE(review): `SparseGPT` is not imported in the part of the file
    # visible here; confirm it is in scope before calling this function.
    print('Starting ...')
    dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    device_map = getattr(model, "hf_device_map", None) or {}
    if "model.embed_tokens" in device_map:
        dev = device_map["model.embed_tokens"]

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None, "position_ids": None}

    class _InputCaptured(ValueError):
        """Raised by Catcher to stop the forward pass once the input is stored."""

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_ids'] = kwargs['position_ids']
            raise _InputCaptured

    first_layer = layers[0]
    layers[0] = Catcher(first_layer)
    try:
        for batch in dataloader:
            try:
                model(batch[0].to(dev))
            except _InputCaptured:
                # Expected: only the Catcher signal is swallowed.
                pass
    finally:
        # Always unwrap the first layer, even if the model raised.
        layers[0] = first_layer
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    print('Ready.')

    for i in range(len(layers)):
        layer = layers[i]
        layer_key = f"model.layers.{i}"
        if layer_key in device_map:
            dev = device_map[layer_key]
            print(f"layer {i} device {dev}")
            inps, outs = inps.to(dev), outs.to(dev)
            # Either captured kwarg may be None; move only what exists.
            if attention_mask is not None:
                attention_mask = attention_mask.to(dev)
            if position_ids is not None:
                position_ids = position_ids.to(dev)

        subset = find_layers(layer)

        gpts = {name: SparseGPT(subset[name]) for name in subset}

        def add_batch(name):
            # Factory so each hook is bound to its own module name.
            def tmp(_, inp, out):
                gpts[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = [subset[name].register_forward_hook(add_batch(name)) for name in gpts]
        try:
            for j in range(args.nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        finally:
            # Never leave statistics hooks attached, even if a forward pass fails.
            for h in handles:
                h.remove()

        for name in gpts:
            print(i, name)
            print('Pruning ...')

            gpts[name].fasterprune(args.sparsity_ratio, prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)
            gpts[name].free()

        # Recompute this layer's outputs with the pruned weights; they become
        # the next layer's inputs.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]

        layers[i] = layer
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache
    torch.cuda.empty_cache()



@torch.no_grad()
def prune_ablate(args, model, tokenizer, dev, prune_n=0, prune_m=0):
    """Prune every linear layer with ``AblateGPT``, one decoder layer at a time.

    Same pipeline as the SparseGPT routine in this file, except that the mask
    handed to ``fasterprune`` depends on ``args.prune_method``:
    ``"ablate_wanda_seq"`` uses ``get_wanda_mask``, ``"ablate_mag_seq"`` uses
    ``get_mag_mask``, and any method containing ``"iter"`` passes ``None``.

    Args:
        args: Namespace read for ``nsamples``, ``seed``, ``sparsity_ratio``
            and ``prune_method``; also passed whole to ``fasterprune``.
        model: Model exposing ``config``, ``seqlen``, ``model.layers`` and,
            optionally, ``hf_device_map``.
        tokenizer: Tokenizer passed to ``get_loaders``.
        dev: Default device; overridden per module by ``hf_device_map``.
        prune_n: N of N:M sparsity.
        prune_m: M of N:M sparsity.

    Raises:
        ValueError: If ``args.prune_method`` is not one of the supported
            ablation methods.
    """
    ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
    # NOTE(review): `AblateGPT` is not imported in the part of the file
    # visible here; confirm it is in scope before calling this function.
    print('Starting ...')
    dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    device_map = getattr(model, "hf_device_map", None) or {}
    if "model.embed_tokens" in device_map:
        dev = device_map["model.embed_tokens"]

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None, "position_ids": None}

    class _InputCaptured(ValueError):
        """Raised by Catcher to stop the forward pass once the input is stored."""

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_ids'] = kwargs['position_ids']
            raise _InputCaptured

    first_layer = layers[0]
    layers[0] = Catcher(first_layer)
    try:
        for batch in dataloader:
            try:
                model(batch[0].to(dev))
            except _InputCaptured:
                # Expected: only the Catcher signal is swallowed.
                pass
    finally:
        # Always unwrap the first layer, even if the model raised.
        layers[0] = first_layer
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    print('Ready.')

    for i in range(len(layers)):
        layer = layers[i]
        layer_key = f"model.layers.{i}"
        if layer_key in device_map:
            dev = device_map[layer_key]
            print(f"layer {i} device {dev}")
            inps, outs = inps.to(dev), outs.to(dev)
            # Either captured kwarg may be None; move only what exists.
            if attention_mask is not None:
                attention_mask = attention_mask.to(dev)
            if position_ids is not None:
                position_ids = position_ids.to(dev)

        subset = find_layers(layer)

        gpts = {name: AblateGPT(subset[name]) for name in subset}

        def add_batch(name):
            # Factory so each hook is bound to its own module name.
            def tmp(_, inp, out):
                gpts[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = [subset[name].register_forward_hook(add_batch(name)) for name in gpts]
        try:
            for j in range(args.nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        finally:
            # Never leave statistics hooks attached, even if a forward pass fails.
            for h in handles:
                h.remove()

        for name in gpts:
            print(i, name)
            print('Pruning ...')

            if args.prune_method == "ablate_wanda_seq":
                prune_mask = gpts[name].get_wanda_mask(args.sparsity_ratio, prune_n, prune_m)
            elif args.prune_method == "ablate_mag_seq":
                prune_mask = gpts[name].get_mag_mask(args.sparsity_ratio, prune_n, prune_m)
            elif "iter" in args.prune_method:
                prune_mask = None
            else:
                # Without this branch prune_mask would be unbound on the first
                # module, or silently reuse a stale mask later.
                raise ValueError(f"Unsupported prune_method for prune_ablate: {args.prune_method}")

            gpts[name].fasterprune(args, args.sparsity_ratio, mask=prune_mask, prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)
            gpts[name].free()

        # Recompute this layer's outputs with the pruned weights; they become
        # the next layer's inputs.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]

        layers[i] = layer
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache
    torch.cuda.empty_cache()