# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import gc
import sys
import copy
import math
from typing import Optional, Tuple
from omegaconf import DictConfig

import torch
import torch.nn.functional as F
from torch import nn
from Pruner.mask.mask_module import OurMask
from Pruner.mask.utils import initialize_wanda, initialize_FLAP, initialize_random
from transformers.pytorch_utils import find_pruneable_heads_and_indices
from typing import List

# Module-level device name. NOTE(review): nothing in this file's visible code
# reads this constant — confirm whether external modules import it.
DEVICE = 'cpu'

class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    ``forward`` optionally multiplies the result by a ``hidden_z`` mask;
    ``prune_params`` permanently removes the channels whose mask entry is zero.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create the norm with a gain vector of ``dim`` ones and stabiliser ``eps``."""
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def prune_params(self, hidden_z):
        """Scale the gain by ``hidden_z`` and keep only the non-zero-mask channels.

        Args:
            hidden_z: mask with one entry per channel. Any shape whose flattened
                length equals the current number of channels is accepted
                (e.g. ``(dim,)`` or ``(1, dim)``).
        """
        # Flatten first so the channel indices come from the channel axis even
        # when the mask carries leading singleton dimensions.
        hidden_z = hidden_z.reshape(-1)
        remaining_index = torch.where(~hidden_z.eq(0))[0]
        pruned = self.weight.data.mul(hidden_z)[remaining_index]
        # Multiplying by the mask may promote the dtype; keep the parameter's own.
        self.weight = torch.nn.Parameter(pruned.to(dtype=self.weight.dtype))

    def _norm(self, x):
        """Divide ``x`` by the root of its mean square over the last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x, hidden_z=None):
        """Normalise ``x`` in float32, cast back, apply the gain and optional mask."""
        output = self._norm(x.float()).type_as(x)
        output = output * self.weight
        if hidden_z is not None:
            output = output.mul(hidden_z)
        return output

    def instantation_forward(self, x):
        """Mask-free forward pass used by the instantiated (pruned) model."""
        return self.forward(x)


def apply_scaling(freqs: torch.Tensor):
    """Rescale rotary frequencies by wavelength band.

    Frequencies whose wavelength is shorter than ``8192 / 4`` are kept,
    those longer than ``8192 / 1`` are divided by 8, and the band in between
    is linearly blended between the two.

    Args:
        freqs: 1-D tensor of rotary frequencies.

    Returns:
        A new tensor with the same shape, dtype and device as ``freqs``.
    """
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    long_wavelen = old_context_len / low_freq_factor
    short_wavelen = old_context_len / high_freq_factor

    # Element-wise evaluation of all three branches, then selection.
    wavelen = 2 * math.pi / freqs
    blend = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    interpolated = (1 - blend) * freqs / scale_factor + blend * freqs
    rescaled = torch.where(wavelen > long_wavelen, freqs / scale_factor, interpolated)
    return torch.where(wavelen < short_wavelen, freqs, rescaled)


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
    """Build the table of unit-magnitude complex rotary factors.

    Args:
        dim: per-head dimension; ``dim // 2`` frequencies are generated.
        end: number of positions (rows) in the table.
        theta: base of the geometric frequency progression.
        use_scaled: when true, pass the frequencies through ``apply_scaling``.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` holding ``exp(i * pos * freq)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    if use_scaled:
        inv_freq = apply_scaling(inv_freq)
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # polar(1, angle) -> cos(angle) + i*sin(angle), stored as complex64
    return torch.polar(torch.ones_like(angles), angles)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the complex factors in ``freqs_cis``.

    Adjacent pairs along the last axis of ``xq``/``xk`` are treated as complex
    numbers, multiplied by ``freqs_cis`` (broadcast via
    ``reshape_for_broadcast``) and converted back to real pairs.

    Returns:
        ``(xq_out, xk_out)`` with the dtypes of ``xq`` and ``xk`` respectively.
    """

    def _pairs_to_complex(t: torch.Tensor) -> torch.Tensor:
        # Computation is done in float32; the last axis is split into (…, 2).
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _pairs_to_complex(xq)
    k_complex = _pairs_to_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent of ``torch.repeat_interleave(x, dim=2, repeats=n_rep)``.

    ``x`` is ``(bs, slen, n_kv_heads, head_dim)``; each head is repeated
    ``n_rep`` times consecutively along axis 2. When ``n_rep == 1`` the input
    tensor itself is returned.
    """
    batch, seq, kv_heads, hdim = x.shape
    if n_rep == 1:
        return x
    widened = x.unsqueeze(3).expand(batch, seq, kv_heads, n_rep, hdim)
    return widened.reshape(batch, seq, kv_heads * n_rep, hdim)


class Attention(nn.Module):
    """Multi-head attention with grouped KV heads and optional pruning masks.

    ``forward`` applies the ``head_z`` / ``head_layer_z`` / ``hidden_z`` masks
    multiplicatively; ``prune_params`` physically removes the masked rows and
    columns from the projections. When every head is pruned the projections
    are set to ``None`` and both forward passes return zeros.
    """

    def __init__(self, cfg: DictConfig):
        """Build the q/k/v/o projections from ``cfg`` (``n_heads``, ``n_kv_heads``,
        ``d_model``, ``init_device``)."""
        super().__init__()
        self.cfg = cfg
        self.n_kv_heads = cfg.n_kv_heads
        model_parallel_size = 1
        self.n_local_heads = cfg.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        self.pruned_heads = set()

        self.wq = nn.Linear(cfg.d_model, cfg.n_heads * self.head_dim, device=cfg.init_device, bias=False)
        self.wk = nn.Linear(cfg.d_model, self.n_kv_heads * self.head_dim, device=cfg.init_device, bias=False)
        self.wv = nn.Linear(cfg.d_model, self.n_kv_heads * self.head_dim, device=cfg.init_device, bias=False)
        self.wo = nn.Linear(cfg.n_heads * self.head_dim, cfg.d_model, device=cfg.init_device, bias=False)
        self.wo._is_residual = True

    def _prune_projections(self, q_index, kv_index, dim):
        """Prune wq/wk/wv along ``dim`` and wo along the opposite axis, keeping wq's dtype."""
        # The pruning helper may hand back a default-dtype layer, so the dtype
        # of the projections before pruning is restored explicitly.
        dtype = self.wq.weight.dtype
        self.wk = prune_linear_layer_direct_to_gpu(self.wk, kv_index, dim=dim).to(dtype=dtype)
        self.wq = prune_linear_layer_direct_to_gpu(self.wq, q_index, dim=dim).to(dtype=dtype)
        self.wv = prune_linear_layer_direct_to_gpu(self.wv, kv_index, dim=dim).to(dtype=dtype)
        self.wo = prune_linear_layer_direct_to_gpu(self.wo, q_index, dim=1 - dim).to(dtype=dtype)
        torch.cuda.empty_cache()

    def prune_params(self, zs_block):
        """Physically prune the projections according to the masks in ``zs_block``.

        Args:
            zs_block: dict that may contain ``"head_z"``, ``"head_layer_z"`` and
                ``"hidden_z"``; every key is optional.
        """
        if self.wq is None:
            # Every head was removed by an earlier call; nothing is left to prune.
            return

        head_z = zs_block["head_z"].squeeze() if "head_z" in zs_block else None
        head_layer_z = zs_block["head_layer_z"].squeeze() if "head_layer_z" in zs_block else None
        hidden_z = zs_block["hidden_z"].squeeze() if "hidden_z" in zs_block else None

        if hidden_z is not None:
            remaining_index = torch.where(~hidden_z.eq(0))[0]
            print(f"    Head hidden: {len(hidden_z)} -> {len(remaining_index)}")
            # Input features of q/k/v and output features of wo shrink together.
            self._prune_projections(remaining_index, remaining_index, dim=1)

        to_prune_heads = self.turn_head_z(head_z, head_layer_z)
        if len(to_prune_heads) == 0:
            print(f"    Heads: {self.n_kv_heads} -> {self.n_kv_heads}")
            return

        heads, index = find_pruneable_heads_and_indices(
            to_prune_heads, self.n_kv_heads, self.head_dim, self.pruned_heads
        )
        if len(index) == 0:
            # No KV head survives: drop the projections entirely.
            self.wq = None
            self.wk = None
            self.wv = None
            self.wo = None
            torch.cuda.empty_cache()
        else:
            qk_index = index
            if self.n_rep != 1:
                # With grouped KV heads each pruned KV head removes its whole
                # group of n_rep query heads, so the query index is built separately.
                qk_index = find_pruneable_gqa_indices(
                    to_prune_heads, self.n_kv_heads * self.n_rep, self.head_dim, self.n_rep
                )
            self._prune_projections(qk_index, index, dim=0)
        print(f"    Heads: {self.n_kv_heads} -> {self.n_kv_heads - len(heads)}")

        self.n_kv_heads = self.n_kv_heads - len(heads)
        self.n_local_heads = self.n_kv_heads * self.n_rep
        self.pruned_heads = self.pruned_heads.union(heads)

    def turn_head_z(self, head_z, head_layer_z):
        """Return the list of KV-head indices whose combined mask is zero.

        ``head_z`` may be ``None`` (no per-head mask); in that case heads are
        pruned only if ``head_layer_z`` zeroes the whole layer. With both masks
        absent the result is an empty list.
        """
        if head_z is None:
            if head_layer_z is None:
                return []
            head_z = torch.ones(
                self.n_kv_heads, dtype=head_layer_z.dtype, device=head_layer_z.device
            )
        else:
            head_z = head_z.reshape(-1).clone()
        if head_layer_z is not None:
            head_z *= head_layer_z.reshape(-1)
        return torch.where(head_z == 0)[0].view(-1).tolist()

    def _attend(self, x, freqs_cis, mask):
        """Compute softmax attention and return ``(bs, n_local_heads, seqlen, head_dim)``."""
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask
        # Softmax is evaluated in float32 and cast back to the query dtype.
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        return torch.matmul(scores, values)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
        head_z, head_layer_z, hidden_z
    ):
        """Masked attention pass.

        ``start_pos`` is accepted for interface compatibility and is not used
        here. Each of ``head_z``, ``head_layer_z`` and ``hidden_z`` may be
        ``None`` to skip that mask.
        """
        if self.wq is None:
            # All heads were pruned, so the attention contribution is zero.
            return torch.zeros_like(x)

        bsz, seqlen, _ = x.shape
        output = self._attend(x, freqs_cis, mask)
        if head_z is not None:
            if self.n_kv_heads != self.cfg.n_heads:
                # One mask entry per KV head: expand to one entry per query head.
                head_z = head_z.reshape([1, -1, 1, 1])
                head_z = torch.repeat_interleave(head_z, self.n_rep, dim=1)
            output *= head_z.reshape([1, output.size(1), 1, 1])
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        output = self.wo(output)
        if head_layer_z is not None:
            output *= head_layer_z
        if hidden_z is not None:
            output *= hidden_z
        return output

    def instantation_forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """Mask-free attention pass for the instantiated (pruned) model.

        ``start_pos`` is accepted for interface compatibility and is not used here.
        """
        if self.wq is None:
            # All heads were pruned, so the attention contribution is zero.
            return torch.zeros_like(x)

        bsz, seqlen, _ = x.shape
        output = self._attend(x, freqs_cis, mask)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)


class FeedForward(nn.Module):
    """Gated (SiLU) feed-forward block with optional pruning masks.

    ``w1`` is the gate projection, ``w3`` the up projection and ``w2`` the down
    projection. When every intermediate unit is pruned the projections are set
    to ``None`` and both forward passes return zeros.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        cfg: DictConfig
    ):
        """Size the intermediate layer and create the three projections.

        The intermediate width is ``2/3 * hidden_dim``, optionally scaled by
        ``ffn_dim_multiplier``, then rounded up to a multiple of ``multiple_of``.
        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.hidden_dim = hidden_dim

        self.w1 = nn.Linear(dim, hidden_dim, bias=False, device=cfg.init_device) # gate
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, device=cfg.init_device) # down
        self.w3 = nn.Linear(dim, hidden_dim, bias=False, device=cfg.init_device) # up

    def _prune_weights(self, index, dim):
        """Prune w3/w1 along ``dim`` and w2 along the opposite axis, keeping w3's dtype."""
        # The pruning helper may hand back a default-dtype layer, so the dtype
        # of the projections before pruning is restored explicitly.
        dtype = self.w3.weight.dtype
        self.w3 = prune_linear_layer_direct_to_gpu(self.w3, index, dim=dim).to(dtype=dtype)
        self.w1 = prune_linear_layer_direct_to_gpu(self.w1, index, dim=dim).to(dtype=dtype)
        self.w2 = prune_linear_layer_direct_to_gpu(self.w2, index, dim=1 - dim).to(dtype=dtype)
        torch.cuda.empty_cache()

    def prune_params(self, zs_block):
        """Physically prune the projections according to the masks in ``zs_block``.

        Args:
            zs_block: dict that may contain ``"intermediate_z"``, ``"mlp_z"`` and
                ``"hidden_z"``; every key is optional.
        """
        if self.w3 is None:
            # Fully pruned by an earlier call; nothing is left to prune.
            return

        intermediate_z = zs_block.get("intermediate_z", None)
        mlp_z = zs_block.get("mlp_z", None)
        hidden_z = zs_block.get("hidden_z", None)

        if hidden_z is not None:
            # Flatten so the indices come from the channel axis even when the
            # mask carries singleton dimensions.
            hidden_z = hidden_z.reshape(-1)
            remaining_index = torch.where(~hidden_z.eq(0))[0]
            print(f"    FFN hidden dim: {len(hidden_z)} -> {len(remaining_index)}")
            self._prune_weights(remaining_index, dim=1)

        keep_dim = self.turn_mlp_z(intermediate_z, mlp_z)
        device = self.w3.weight.device
        if len(keep_dim) == self.w3.weight.shape[0]:
            print(f"    FFN intermediate dim: {self.hidden_dim} -> {len(keep_dim)}")
            return

        if len(keep_dim) == 0:
            # No intermediate unit survives: drop the projections entirely.
            self.w3 = None
            self.w2 = None
            self.w1 = None
            torch.cuda.empty_cache()
        else:
            keep_dim_index = torch.tensor(keep_dim).long().to(device)
            self._prune_weights(keep_dim_index, dim=0)
        print(f"    FFN intermediate dim: {self.hidden_dim} -> {len(keep_dim)}")

    def turn_mlp_z(self, intermediate_z, mlp_z):
        """Return the list of intermediate-unit indices whose combined mask is non-zero.

        ``intermediate_z`` may be ``None`` (no per-unit mask); in that case all
        units are kept unless ``mlp_z`` zeroes the whole block.
        """
        if intermediate_z is None:
            width = self.w3.weight.shape[0]
            if mlp_z is None:
                return list(range(width))
            intermediate_z_layer = torch.ones(width, dtype=mlp_z.dtype, device=mlp_z.device)
        else:
            intermediate_z_layer = intermediate_z.reshape(-1).clone()
        if mlp_z is not None:
            intermediate_z_layer *= mlp_z.reshape(-1)
        return torch.where(intermediate_z_layer != 0)[0].tolist()

    def forward(self, x, intermediate_z, mlp_z, hidden_z):
        """Masked pass: ``w2(silu(w1 x) * (w3 x * intermediate_z)) * mlp_z * hidden_z``.

        Any mask may be ``None`` to skip it.
        """
        if self.w1 is None:
            # Every intermediate unit was pruned, so the block contributes zero.
            return torch.zeros_like(x)
        gate = F.silu(self.w1(x))
        up_v = self.w3(x)
        if intermediate_z is not None:
            up_v *= intermediate_z
        down_v = self.w2(gate * up_v)
        if mlp_z is not None:
            down_v *= mlp_z
        if hidden_z is not None:
            down_v *= hidden_z
        return down_v

    def instantation_forward(self, x):
        """Mask-free pass for the instantiated (pruned) model."""
        return self.forward(x, None, None, None)


class TransformerBlock(nn.Module):
    """One pre-norm decoder layer: attention and feed-forward, each with a residual."""

    def __init__(self, layer_id: int, cfg: DictConfig):
        """Create the attention, feed-forward and the two RMSNorm sub-modules from ``cfg``."""
        super().__init__()
        d_model = cfg.d_model
        self.n_heads = cfg.n_heads
        self.dim = d_model
        self.head_dim = d_model // cfg.n_heads
        self.attention = Attention(cfg)
        self.feed_forward = FeedForward(
            dim=d_model,
            hidden_dim=4 * d_model,
            multiple_of=cfg.multiple_of,
            ffn_dim_multiplier=getattr(cfg, "ffn_dim_multiplier", None),
            cfg=cfg,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(d_model, eps=cfg.rms_norm_eps)
        self.ffn_norm = RMSNorm(d_model, eps=cfg.rms_norm_eps)

    def prune_params(self, zs_block):
        """Prune attention and feed-forward, then both norms when a hidden mask is given."""
        self.attention.prune_params(zs_block)
        self.feed_forward.prune_params(zs_block)
        if 'hidden_z' not in zs_block:
            return
        hidden_mask = zs_block['hidden_z']
        for norm in (self.attention_norm, self.ffn_norm):
            norm.prune_params(hidden_mask)
        torch.cuda.empty_cache()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
        head_z: Optional[torch.Tensor] = None,
        head_layer_z: Optional[torch.Tensor] = None,
        intermediate_z: Optional[torch.Tensor] = None,
        mlp_z: Optional[torch.Tensor] = None,
        hidden_z: Optional[torch.Tensor] = None,
    ):
        """Masked pass: the pruning masks are forwarded to the sub-modules."""
        attn_input = self.attention_norm(x, hidden_z)
        attn_out = self.attention(
            attn_input, start_pos, freqs_cis, mask, head_z, head_layer_z, hidden_z
        )
        h = x + attn_out
        ffn_input = self.ffn_norm(h, hidden_z)
        ffn_out = self.feed_forward(ffn_input, intermediate_z, mlp_z, hidden_z)
        return h + ffn_out

    def instantation_forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """Mask-free pass through the sub-modules' ``instantation_forward`` methods."""
        attn_input = self.attention_norm.instantation_forward(x)
        h = x + self.attention.instantation_forward(attn_input, start_pos, freqs_cis, mask)
        ffn_input = self.ffn_norm.instantation_forward(h)
        return h + self.feed_forward.instantation_forward(ffn_input)


class Transformer(nn.Module):
    """Decoder-only transformer whose forward passes accept structured pruning masks."""

    def __init__(self, cfg: DictConfig):
        """Build embeddings, ``cfg.n_layers`` blocks, the final norm, the output head
        and the rotary table."""
        super().__init__()
        self.cfg = cfg
        self.vocab_size = cfg.vocab_size
        self.n_layers = cfg.n_layers

        self.tok_embeddings = nn.Embedding(cfg.vocab_size, cfg.d_model, device=cfg.init_device)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(cfg.n_layers):
            self.layers.append(TransformerBlock(layer_id, cfg))

        self.norm = RMSNorm(cfg.d_model, eps=cfg.rms_norm_eps)
        self.output = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False, device=cfg.init_device)

        # The table covers twice cfg.max_seq_len positions so sequence lengths
        # can vary during training or fine-tuning. The base and the optional
        # frequency scaling are selected from the model name.
        self.freqs_cis = precompute_freqs_cis(
            cfg.d_model // cfg.n_heads,
            cfg.max_seq_len * 2,
            500000.0 if 'llama3' in cfg.name else 10000.0,
            '3.1' in cfg.name or '3.2' in cfg.name,
        )

    def prune_params(self, zs=None):
        """Physically prune embeddings, output head, final norm and every block.

        Args:
            zs: dict of masks; ``"hidden_z"`` is shared by all layers, every other
                key is indexed per layer. ``None`` means there is nothing to prune.
        """
        if zs is None:
            return
        if "hidden_z" in zs:
            # Flatten so the indices come from the channel axis even when the
            # mask carries singleton dimensions.
            hidden_z = zs["hidden_z"].reshape(-1)
            remaining_index = torch.where(~hidden_z.eq(0))[0]
            self.norm.prune_params(hidden_z)

            # Multiplying by the mask can promote the dtype, so the original
            # dtypes are recorded first and restored afterwards.
            emb_dtype = self.tok_embeddings.weight.dtype
            masked_emb = self.tok_embeddings.weight.data.mul(hidden_z)
            self.tok_embeddings.weight = nn.parameter.Parameter(
                masked_emb.index_select(1, remaining_index).to(dtype=emb_dtype).clone()
            )
            self.tok_embeddings.embedding_dim = len(remaining_index)

            out_dtype = self.output.weight.dtype
            self.output.weight.data = self.output.weight.data.mul(hidden_z)
            self.output = prune_linear_layer_direct_to_gpu(
                self.output, remaining_index, dim=1
            ).to(dtype=out_dtype)
            torch.cuda.empty_cache()
        for i, block in enumerate(self.layers):
            zs_block = self.get_zs_block(zs, i)
            block.prune_params(zs_block)
            torch.cuda.empty_cache()

    def get_zs_block(self, zs, block_idx):
        """Return the masks for layer ``block_idx``: ``hidden_z`` as is, the rest indexed."""
        zs_block = {}
        if zs is not None:
            for key in zs:
                if key == "hidden_z":
                    zs_block["hidden_z"] = zs["hidden_z"]
                else:
                    zs_block[key] = zs[key][block_idx]
        return zs_block

    @staticmethod
    def _build_causal_mask(seqlen, start_pos, device, like):
        """Return the additive causal mask, or ``None`` for a single-token sequence.

        The mask is ``(seqlen, start_pos + seqlen)``: ``start_pos`` leading zero
        columns followed by a square block with ``-inf`` above the diagonal,
        cast to the dtype/device of ``like``.
        """
        if seqlen <= 1:
            return None
        upper = torch.triu(
            torch.full((seqlen, seqlen), float("-inf"), device=device), diagonal=1
        )
        prefix = torch.zeros((seqlen, start_pos), device=device)
        return torch.hstack([prefix, upper]).type_as(like)

    def _rope_slice(self, device, start_pos, seqlen):
        """Move the rotary table to ``device`` and return the rows for this sequence."""
        self.freqs_cis = self.freqs_cis.to(device)
        return self.freqs_cis[start_pos : start_pos + seqlen]

    def _masked_pass(self, tokens, start_pos, zs):
        """Run all layers with the masks in ``zs`` and return float logits."""
        if zs is None:
            zs = {}
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        if "hidden_z" in zs:
            h = h.mul(zs["hidden_z"])
        freqs_cis = self._rope_slice(h.device, start_pos, seqlen)
        mask = self._build_causal_mask(seqlen, start_pos, tokens.device, h)

        for b_idx, layer in enumerate(self.layers):
            zs_block = self.get_zs_block(zs, b_idx)
            h = layer(h, start_pos, freqs_cis, mask, **zs_block)
        h = self.norm(h, hidden_z=None)
        return self.output(h).float()

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int, mask_module: OurMask, ppl_during_train=False):
        """Masked forward pass with masks sampled from ``mask_module``.

        Returns:
            ``{"logits", "grads"}`` normally, or ``{"logits"}`` only when
            ``ppl_during_train`` is true. With ``mask_module=None`` no masks are
            applied and ``"grads"`` is ``None``.
        """
        zs, grads = {}, None
        if ppl_during_train:
            zs, _ = mask_module(ppl_during_train=ppl_during_train)
        elif mask_module is not None:
            zs, grads = mask_module()

        output = self._masked_pass(tokens, start_pos, zs)
        if ppl_during_train:
            return {"logits": output}
        return {"logits": output, "grads": grads}

    @torch.inference_mode()
    def instantation_forward(self, tokens: torch.Tensor, start_pos: int, mask_module, outdated_zs):
        """Mask-free forward pass of the pruned model.

        When both ``mask_module`` and ``outdated_zs`` are given, a ``"grads"``
        dict (one entry per pruning module) is computed from the stored masks
        and the current scores and returned alongside the logits.
        """
        def calculate_grad(zs, score):
            # NOTE(review): Masked_Llama.sim_forward uses the same expression
            # without the square root — confirm which form is intended.
            return (zs - score) / torch.sqrt((score + 1e-8) * (1 - score + 1e-8))

        with_grads = mask_module is not None and outdated_zs is not None
        grads = None
        if with_grads:
            grads = {}
            for pruning_module in mask_module.pruning_modules:
                mask = mask_module.masks[pruning_module]
                # The 'layer' module stores its mask under 'head_layer_z'.
                key = 'head_layer_z' if pruning_module == 'layer' else f'{pruning_module}_z'
                zs = outdated_zs[key].reshape(mask.mask_shape)
                grads[f"{pruning_module}_grad"] = calculate_grad(zs, mask.score.data)

        _, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        freqs_cis = self._rope_slice(h.device, start_pos, seqlen)
        mask = self._build_causal_mask(seqlen, start_pos, tokens.device, h)

        for layer in self.layers:
            h = layer.instantation_forward(h, start_pos, freqs_cis, mask)
        h = self.norm.instantation_forward(h)
        output = self.output(h).float()

        if with_grads:
            return {"logits": output, "grads": grads}
        return {"logits": output}

    @torch.inference_mode()
    def instantation_test_forward(self, tokens: torch.Tensor, start_pos: int, zs):
        """Masked forward pass using the caller-supplied masks ``zs``; returns ``{"logits"}``."""
        return {"logits": self._masked_pass(tokens, start_pos, zs)}


class Masked_Llama(nn.Module):
    """Wrapper pairing a ``Transformer`` with an optional ``OurMask`` mask module."""

    def __init__(self, cfg):
        """Build the transformer, load weights from ``cfg.path`` and create the mask
        module when ``cfg.mask`` is set."""
        super().__init__()
        self.cfg = cfg
        self.model = Transformer(cfg)

        # map_location="cpu" lets a checkpoint saved from GPU tensors load on
        # any host; load_state_dict copies the values onto the model's device.
        # NOTE(review): torch.load unpickles the file — only load trusted
        # checkpoints (consider weights_only=True where the torch version allows).
        state_dict = torch.load(cfg.path, map_location="cpu")
        self.model.load_state_dict(state_dict, strict=False)

        self.mask_module = None
        if getattr(self.cfg, "mask", None) is not None:
            self.mask_module = OurMask(cfg, device=cfg.init_device)

    def prune_params(self, zs=None):
        """Delegate physical pruning to the wrapped transformer."""
        self.model.prune_params(zs)

    def initialize_score(self, method, tokenizer, device, model_name):
        """Initialise mask scores with ``'wanda'``, ``'flap'`` or ``'random'``.

        ``'mean'`` and any unrecognised method leave the scores untouched.
        """
        if method == 'wanda':
            initialize_wanda(self.model, tokenizer, device, self.mask_module, model_name)
        elif method == 'flap':
            initialize_FLAP(self.model, tokenizer, device, self.mask_module, model_name)
        elif method == 'random':
            initialize_random(self.cfg, self.mask_module)
        elif method == 'mean':
            return

    def _sample_zs(self, device, test_mask, test_batch):
        """Sample masks from the mask module.

        With ``test_mask`` false a single sample is returned. Otherwise five
        samples are drawn, each is scored by the loss on ``test_batch``, the
        losses are printed and the lowest-loss sample is returned.
        """
        if not test_mask:
            zs, _ = self.mask_module()
            return zs

        test_batch = test_batch.to(device)
        sample_num = 5
        zs_cache = []
        zs_loss = []
        for _ in range(sample_num):
            zs_cache.append(self.mask_module()[0])
            output = self.model.instantation_test_forward(tokens=test_batch, start_pos=0, zs=zs_cache[-1])
            zs_loss.append(self.loss(output, test_batch))
        print(zs_loss)
        return zs_cache[zs_loss.index(min(zs_loss))]

    def instantation(self, origin_model, device, test_mask, test_batch): # TODO DEBUG
        """Replace the model by a bfloat16 copy of ``origin_model`` and prune it.

        Returns:
            ``(deep copy of the mask module's masks, zs used for pruning)``.
        """
        print('================== START INSTANTATION! ==================')

        self.model = copy.deepcopy(origin_model)
        torch.cuda.empty_cache()
        self.model = self.model.to(device=device, dtype=torch.bfloat16)

        zs = self._sample_zs(device, test_mask, test_batch)
        self.model.prune_params(zs)

        print('================== INSTANTATION OVER! ==================')
        return copy.deepcopy(self.mask_module.masks), zs

    def sim_instantation(self, device, test_mask, test_batch):
        """Sample masks like ``instantation`` but leave the model unpruned.

        Returns:
            ``(deep copy of the mask module's masks, sampled zs)``.
        """
        zs = self._sample_zs(device, test_mask, test_batch)
        return copy.deepcopy(self.mask_module.masks), zs

    def get_targets(self, batch):
        """Return next-token targets: ``batch`` shifted left by one, last column ``-100``."""
        targets = torch.roll(batch, shifts=-1, dims=1)
        targets[:, -1] = -100
        return targets

    def forward(self, input_ids, instantation_model, outdated_zs=None): #TODO : DEBUG
        """Run the pruned-model pass when ``instantation_model`` is true, else the masked pass."""
        if instantation_model:
            return self.model.instantation_forward(
                tokens=input_ids, start_pos=0, mask_module=self.mask_module, outdated_zs=outdated_zs
            )
        return self.model.forward(tokens=input_ids, start_pos=0, mask_module=self.mask_module)

    def sim_forward(self, input_ids, outdated_zs):
        """Masked pass with the given ``outdated_zs``; adds a ``"grads"`` dict to the output."""
        output = self.model.instantation_test_forward(tokens=input_ids, start_pos=0, zs=outdated_zs)

        def calculate_grad(zs, score):
            # NOTE(review): Transformer.instantation_forward uses the same
            # expression with a square root in the denominator — confirm which
            # form is intended.
            return (zs - score) / ((score + 1e-8) * (1 - score + 1e-8))

        grads = {}
        for pruning_module in self.mask_module.pruning_modules:
            mask = self.mask_module.masks[pruning_module]
            # The 'layer' module stores its mask under 'head_layer_z'.
            key = 'head_layer_z' if pruning_module == 'layer' else f'{pruning_module}_z'
            zs = outdated_zs[key].reshape(mask.mask_shape)
            grads[f"{pruning_module}_grad"] = calculate_grad(zs, mask.score.data)

        output["grads"] = grads
        return output

    def loss(self, outputs, batch):
        """Cross-entropy between ``outputs["logits"]`` and the shifted ``batch`` targets."""
        logits = outputs["logits"]
        targets = self.get_targets(batch)
        return F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-100,
        )


def prune_linear_layer_direct_to_gpu(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
    """Return a new ``nn.Linear`` that keeps only the entries in ``index`` along ``dim``.

    Adapted from ``transformers.pytorch_utils.prune_linear_layer``. The new
    layer is created directly on the source layer's device and in its dtype,
    so no intermediate default-dtype or CPU copy of the weights is made.

    Args:
        layer: the layer to prune; it is not modified.
        index: indices to keep along ``dim``.
        dim: ``0`` prunes output features (weight rows and bias entries),
            ``1`` prunes input features (weight columns; the bias is kept whole).

    Returns:
        A new ``nn.Linear`` whose parameters have ``requires_grad=True``.
    """
    weight = layer.weight
    index = index.to(weight.device)
    has_bias = layer.bias is not None

    new_size = list(weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(
        new_size[1],
        new_size[0],
        bias=has_bias,
        device=weight.device,
        dtype=weight.dtype,
    )
    # Copy outside autograd so the fresh parameters remain leaf tensors.
    with torch.no_grad():
        new_layer.weight.copy_(weight.index_select(dim, index))
        if has_bias:
            new_layer.bias.copy_(layer.bias if dim == 1 else layer.bias[index])
    return new_layer


def find_pruneable_gqa_indices(
    heads: List[int], n_heads: int, head_size: int, n_rep: int
):
    """Return the flat indices of query features that survive KV-head pruning.

    Adapted from ``transformers.pytorch_utils.find_pruneable_heads_and_indices``
    for grouped-query attention: each KV head ``i`` in ``heads`` removes the
    query heads ``i * n_rep`` .. ``(i + 1) * n_rep - 1``.

    Args:
        heads: KV-head indices to prune.
        n_heads: total number of query heads.
        head_size: features per head.
        n_rep: number of query heads per KV head.

    Returns:
        1-D ``int64`` tensor of kept positions in ``[0, n_heads * head_size)``.
    """
    dropped_q_heads = {kv * n_rep + offset for kv in heads for offset in range(n_rep)}
    keep = torch.ones(n_heads, head_size, dtype=torch.bool)
    for q_head in dropped_q_heads:
        keep[q_head] = False
    index: torch.LongTensor = torch.nonzero(keep.view(-1), as_tuple=True)[0].long()
    return index

