import warnings
import functools
import math
import logging
from numpy import isin
import torch
from torch import nn
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional

from transformers import LlamaPreTrainedModel, OPTPreTrainedModel
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer

from utils import (
    HelperState,
    HelperCollectState,
    set_helper_state,
    HELPER_SUPPORT_MODEL_LIST,
    HELPER_SUPPORT_MODEL_TYPES
)

logger = logging.getLogger(__name__)

# Attribute name under which the model stores the list of hook handles registered by
# _add_collect_data_hook; remove_training_hook reads it and resets it to None.
_HELPER_HOOK_KEY = "HelperHook"

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along dimension 1.

    Same result as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) -> (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When ``n_rep == 1`` the input tensor object itself is returned.
    """
    b, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # expand() creates a broadcast view; reshape() materialises the interleaved heads.
        expanded = hidden_states.unsqueeze(2).expand(b, kv_heads, n_rep, seq_len, dim)
        return expanded.reshape(b, kv_heads * n_rep, seq_len, dim)
    return hidden_states
    
def rotate_half(x):
    """Rotate half the hidden dims: with last dim split at ``d // 2`` into (a, b), return cat(-b, a)."""
    half = x.shape[-1] // 2
    lower, upper = x[..., :half], x[..., half:]
    return torch.cat((upper.neg(), lower), dim=-1)
    
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate query and key tensors with Rotary Position Embedding (RoPE).

    Args:
        q (`torch.Tensor`): query tensor.
        k (`torch.Tensor`): key tensor.
        cos (`torch.Tensor`): cosine table of the rotary embedding, indexed by position.
        sin (`torch.Tensor`): sine table of the rotary embedding, indexed by position.
        position_ids (`torch.Tensor`): positions of the tokens in ``q``/``k``; with a KV-cache
            these can be offset positions.
        unsqueeze_dim (`int`, *optional*, defaults to 1): dimension inserted into
            ``cos[position_ids]`` / ``sin[position_ids]`` (shape [batch, seq_len, head_dim]) so they
            broadcast against ``q``/``k``. Use 1 for [batch, heads, seq_len, head_dim] inputs and
            2 for [batch, seq_len, heads, head_dim] inputs.

    Returns:
        `tuple(torch.Tensor)`: ``(q_embed, k_embed)``, each computed as ``t * cos + rotate_half(t) * sin``.
    """
    cos_sel = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_sel = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_sel) + (rotate_half(t) * sin_sel)

    return _rotate(q), _rotate(k)
    
def _add_collect_data_hook(model: HELPER_SUPPORT_MODEL_TYPES,
                               dest: Dict[str, Dict[str, List[torch.Tensor]]],
                               intermediate_size: int, 
                               hidden_size: int) -> int:
    """Register forward hooks that accumulate the input Gram matrix X^T X of each projection.

    For every targeted linear submodule (OPT: fc1/fc2/q_proj/k_proj/v_proj/out_proj;
    otherwise: gate_proj/up_proj/down_proj/q_proj/k_proj/v_proj/o_proj) a float32 CPU
    accumulator ``raw_scaling_diag_matrix_<layer_idx>`` is attached to the module. The
    accumulator is sized ``intermediate_size`` for the MLP down projection (fc2 / down_proj)
    and ``hidden_size`` for every other projection. On every forward pass whose input has
    more than one token, the batch-summed ``x^T @ x`` is added to it in place.

    Args:
        model: model whose submodules are hooked; its helper state is set to Collecting.
        dest: unused; kept for interface compatibility.
        intermediate_size: input width of the MLP down projection.
        hidden_size: input width of all other projections.

    Returns:
        Always ``0`` (kept for interface compatibility; no layer count is tracked).

    The hook handles are stored on ``model`` under ``_HELPER_HOOK_KEY`` so that
    ``remove_training_hook`` can detach them.
    """
    set_helper_state(model, HelperState.Collecting)
    hooks = []

    def forward_hook_get_XXT(layer_idx, name, module, inp, out):
        x = inp[0].detach().float()
        if x.dim() == 2:
            x = x.unsqueeze(0)
        # Single-token inputs are skipped.
        if x.shape[1] <= 1:
            return
        gram = torch.matmul(x.transpose(1, 2), x).sum(dim=0).cpu()
        getattr(module, f'raw_scaling_diag_matrix_{layer_idx}').add_(gram)
        # Drop the large temporaries before asking the CUDA allocator to release cached blocks.
        del x, gram
        torch.cuda.empty_cache()

    is_opt = isinstance(model, OPTPreTrainedModel)
    if is_opt:
        target_suffixes = ("fc1", "fc2", "q_proj", "k_proj", "out_proj", "v_proj")
        intermediate_input_suffix = "fc2"
    else:
        target_suffixes = ("gate_proj", "up_proj", "down_proj", "q_proj", "k_proj", "o_proj", "v_proj")
        intermediate_input_suffix = "down_proj"

    for name, module in model.named_modules():
        parts = name.split(".")
        suffix = parts[-1]
        if suffix not in target_suffixes:
            continue
        # OPT MLP linears sit directly under the layer (``...<idx>.fc1``); all other
        # projections are one level deeper (``...<idx>.self_attn.q_proj`` / ``...<idx>.mlp.up_proj``).
        if is_opt and suffix in ("fc1", "fc2"):
            layer_idx = int(parts[-2])
        else:
            layer_idx = int(parts[-3])
        dim = intermediate_size if suffix == intermediate_input_suffix else hidden_size
        setattr(module, f"raw_scaling_diag_matrix_{layer_idx}", torch.zeros(dim, dim))
        hooks.append(
            module.register_forward_hook(
                functools.partial(forward_hook_get_XXT, layer_idx, name)
            )
        )

    setattr(model, _HELPER_HOOK_KEY, hooks)
    return 0


def add_collect_data_hook(model: HELPER_SUPPORT_MODEL_TYPES,
                      dest: Dict[str, Dict[str, List[torch.Tensor]]], 
                      intermediate_size: int,
                      hidden_size: int) -> int:
    """Public entry point for installing the X^T X collection hooks.

    Checks that ``model`` is an instance of one of ``HELPER_SUPPORT_MODEL_LIST`` and then
    delegates to ``_add_collect_data_hook`` with the same arguments, returning its result.

    Raises:
        NotImplementedError: if ``model`` is not a supported model type.
    """
    supported = isinstance(model, HELPER_SUPPORT_MODEL_LIST)
    if not supported:
        raise NotImplementedError(f"Only support {HELPER_SUPPORT_MODEL_LIST}.")
    return _add_collect_data_hook(model, dest, intermediate_size, hidden_size)


def remove_training_hook(model: HELPER_SUPPORT_MODEL_TYPES):
    """Remove every hook handle stored on ``model`` under ``_HELPER_HOOK_KEY``.

    Idempotent: if the attribute is missing, or already ``None`` (e.g. on a second call),
    there is nothing to remove and no error is raised. The attribute is always reset to
    ``None`` afterwards.
    """
    hooks = getattr(model, _HELPER_HOOK_KEY, None) or []
    for handle in hooks:
        handle.remove()

    setattr(model, _HELPER_HOOK_KEY, None)


def llama_mlp_forward_svd(module, inp, **kwargs):
    """Replacement ``forward`` for a LLaMA MLP that can use low-rank (SVD) projections.

    Computes ``down(act(gate(x)) * up(x))``. For layers listed in
    ``kwargs['pruned_layer_idx_list']``, each of gate/up/down whose ``<name>_proj_use``
    attribute is not positive is replaced by the factorized product
    ``linear(linear(x, <name>_weight_U_top), <name>_weight_SVh_top[, <name>_bias])``.

    Args:
        module: the MLP module this forward is bound to (via ``functools.partial``).
        inp: input hidden states.
        **kwargs: must contain ``layer_idx`` and ``pruned_layer_idx_list``. ``optim_bias`` is
            read only when a low-rank path is taken. Other keys (``module_name``,
            ``use_trunc``, ``use_bias``) are accepted and ignored.

    Returns:
        Output of the down projection.

    Raises:
        NotImplementedError: if ``module.config.pretraining_tp > 1`` (tensor-parallel slicing
            of the factorized weights is not supported).
    """
    if module.config.pretraining_tp > 1:
        raise NotImplementedError

    if kwargs['layer_idx'] not in kwargs['pruned_layer_idx_list']:
        return module.down_proj(module.act_fn(module.gate_proj(inp)) * module.up_proj(inp))

    def project(prefix, x):
        """Apply ``<prefix>_proj`` densely or through its stored low-rank factors."""
        if getattr(module, f'{prefix}_proj_use') > 0:
            return getattr(module, f'{prefix}_proj')(x)
        u_name, svh_name = f'{prefix}_weight_U_top', f'{prefix}_weight_SVh_top'
        # Factors are moved to the activation's device and stored back on the module,
        # so later calls on the same device skip the transfer.
        if x.device != getattr(module, u_name).device:
            setattr(module, u_name, getattr(module, u_name).to(x.device))
        tmp = F.linear(x, getattr(module, u_name))
        if tmp.device != getattr(module, svh_name).device:
            setattr(module, svh_name, getattr(module, svh_name).to(tmp.device))
        if kwargs['optim_bias']:
            bias = getattr(module, f'{prefix}_bias').to(tmp.device)
            return F.linear(tmp, getattr(module, svh_name), bias)
        return F.linear(tmp, getattr(module, svh_name))

    h_gate = module.act_fn(project('gate', inp))
    h_up = project('up', inp)
    if h_gate.device != h_up.device:
        h_gate = h_gate.to(h_up.device)
    return project('down', h_gate * h_up)

def llama_attn_forward_svd(module: torch.nn.Module, 
                           hidden_states: torch.Tensor, 
                           attention_mask: Optional[torch.Tensor] = None, 
                           position_ids: Optional[torch.LongTensor] = None, 
                           past_key_value: Optional[Tuple[torch.Tensor]] = None, 
                           output_attentions: Optional[bool] = False, 
                           use_cache: Optional[bool] = False, 
                           **kwargs):
    """Replacement ``forward`` for LLaMA attention with optional low-rank projections.

    The forward pass runs these steps in order:
      1. q/k/v projections.
      2. Rotary position embedding.
      3. Optional cache update via ``past_key_value.update``.
      4. Key/value head repetition (``repeat_kv``).
      5. Scaled dot-product attention with softmax and dropout.
      6. Output projection.

    When ``pretraining_tp == 1`` and the layer is in ``kwargs['pruned_layer_idx_list']``,
    each of q/k/v/o whose ``<x>_proj_use`` attribute is not positive is computed as
    ``linear(linear(h, <x>_weight_U_top), <x>_weight_SVh_top[, <x>_bias])``. The bias is
    used only when ``kwargs['optim_bias']`` is truthy.

    Args:
        module: the attention module this forward is bound to.
        hidden_states: input, unpacked as ``(bsz, q_len, _)``.
        attention_mask: optional additive mask of size ``(bsz, 1, q_len, kv_seq_len)``.
        position_ids: positions used to index the rotary tables.
        past_key_value: optional cache object with ``get_usable_length`` / ``update``.
        output_attentions: if False, the returned attention weights are ``None``.
        use_cache: unused.
        **kwargs: ``layer_idx``, ``pruned_layer_idx_list`` and ``optim_bias`` are read.
            A ``padding_mask`` key only triggers a deprecation warning.

    Returns:
        ``(attn_output, attn_weights or None, past_key_value)``.

    Raises:
        ValueError: if a cache is given but ``module.layer_idx`` is None, or if the
            attention weights, mask or output have unexpected sizes.
    """
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )

    def low_rank(prefix, x):
        """Compute ``<prefix>`` projection from its stored factors (optionally with bias)."""
        u_name, svh_name = f'{prefix}_weight_U_top', f'{prefix}_weight_SVh_top'
        # Factors are moved to the activation's device and stored back on the module,
        # so later calls on the same device skip the transfer.
        if x.device != getattr(module, u_name).device:
            setattr(module, u_name, getattr(module, u_name).to(x.device))
        tmp = F.linear(x, getattr(module, u_name))
        if tmp.device != getattr(module, svh_name).device:
            setattr(module, svh_name, getattr(module, svh_name).to(tmp.device))
        if kwargs['optim_bias']:
            bias = getattr(module, f'{prefix}_bias').to(tmp.device)
            return F.linear(tmp, getattr(module, svh_name), bias)
        return F.linear(tmp, getattr(module, svh_name))

    def project(prefix, x, pruned):
        """Use the dense ``<prefix>_proj`` unless the layer is pruned and ``<prefix>_proj_use <= 0``."""
        if not pruned or getattr(module, f'{prefix}_proj_use') > 0:
            return getattr(module, f'{prefix}_proj')(x)
        return low_rank(prefix, x)

    bsz, q_len, _ = hidden_states.size()

    if module.config.pretraining_tp > 1:
        key_value_slicing = (module.num_key_value_heads * module.head_dim) // module.config.pretraining_tp
        query_slices = module.q_proj.weight.split(
            (module.num_heads * module.head_dim) // module.config.pretraining_tp, dim=0
        )
        key_slices = module.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = module.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(module.config.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(module.config.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(module.config.pretraining_tp)]
        value_states = torch.cat(value_states, dim=-1)

    else:
        ##### Attn q/k/v decomposition #####
        pruned = kwargs['layer_idx'] in kwargs['pruned_layer_idx_list']
        query_states = project('q', hidden_states, pruned)
        key_states = project('k', hidden_states, pruned)
        value_states = project('v', hidden_states, pruned)

    query_states = query_states.view(bsz, q_len, module.num_heads, module.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, module.num_key_value_heads, module.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, module.num_key_value_heads, module.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if module.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {module.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, module.layer_idx)

    cos, sin = module.rotary_emb(value_states, seq_len=kv_seq_len)
    if cos.device != position_ids.device:
        position_ids = position_ids.to(cos.device)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, module.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, module.num_key_value_groups)
    value_states = repeat_kv(value_states, module.num_key_value_groups)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)

    if attn_weights.size() != (bsz, module.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, module.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        if attn_weights.device != attention_mask.device:
            attn_weights = attn_weights.to(attention_mask.device)
        attn_weights = attn_weights + attention_mask

    # Upcast attention to fp32 for the softmax, then cast back. bfloat16 here would
    # *reduce* precision for fp32/fp16 models.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
    if attn_weights.device != value_states.device:
        attn_weights = attn_weights.to(value_states.device)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, module.num_heads, q_len, module.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, module.num_heads, q_len, module.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, module.hidden_size)

    if module.config.pretraining_tp > 1:
        attn_output = attn_output.split(module.hidden_size // module.config.pretraining_tp, dim=2)
        o_proj_slices = module.o_proj.weight.split(module.hidden_size // module.config.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(module.config.pretraining_tp)])
    else:
        ##### Attn o decomposition #####
        attn_output = project('o', attn_output, kwargs['layer_idx'] in kwargs['pruned_layer_idx_list'])

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def add_inference_hook_to_llama(optim_bias, use_trunc, use_bias, pruned_layer_idx_list, 
                                        model: HELPER_SUPPORT_MODEL_TYPES):
    """Rebind ``forward`` of every LLaMA MLP / attention module to its SVD-aware version.

    Sets the helper state to Inference, then for each ``LlamaMLP`` (or ``LlamaAttention``)
    found in ``model.named_modules()`` replaces ``module.forward`` with a
    ``functools.partial`` of ``llama_mlp_forward_svd`` (or ``llama_attn_forward_svd``).
    The layer index is parsed from the second-to-last dotted component of the module
    name. Nothing is returned.
    """
    set_helper_state(model, HelperState.Inference)

    common = dict(
        pruned_layer_idx_list=pruned_layer_idx_list,
        use_trunc=use_trunc,
        use_bias=use_bias,
        optim_bias=optim_bias,
    )

    for name, module in model.named_modules():
        if isinstance(module, LlamaMLP):
            module.forward = functools.partial(
                llama_mlp_forward_svd,
                module,
                layer_idx=int(name.split(".")[-2]),
                module_name=name,
                **common,
            )
        elif isinstance(module, LlamaAttention):
            module.forward = functools.partial(
                llama_attn_forward_svd,
                module,
                layer_idx=int(name.split(".")[-2]),
                hitter_dict_q=None,
                hitter_dict_k=None,
                **common,
            )

def proj_svd_llm(input, A, B, b=None):
    """Apply a factorized linear map: ``linear(linear(input, A), B, b)``.

    ``A`` is moved to ``input``'s device, and ``B`` / ``b`` to the device of the
    intermediate result, whenever they differ. The moved tensors are only rebound
    locally; nothing is written back to the caller's objects.
    """
    factor_a = A if A.device == input.device else A.to(input.device)
    hidden = F.linear(input, factor_a)
    factor_b = B if B.device == hidden.device else B.to(hidden.device)
    bias = None if b is None else b.to(hidden.device)
    return F.linear(hidden, factor_b, bias)


def opt_attn_forward_svd(module: torch.nn.Module, 
                         hidden_states: torch.Tensor,
                         key_value_states: Optional[torch.Tensor] = None,
                         past_key_value: Optional[Tuple[torch.Tensor]] = None,
                         attention_mask: Optional[torch.Tensor] = None,
                         layer_head_mask: Optional[torch.Tensor] = None,
                         output_attentions: bool = False,
                         **kwargs):
    """Input shape: Batch x Time x Channel

    Replacement ``forward`` for OPT attention with optional low-rank projections.

    For layers in ``kwargs['pruned_layer_idx_list']``, each of ``q_proj``/``k_proj``/
    ``v_proj``/``out_proj`` whose ``<name>_use`` attribute is not positive is computed
    with ``proj_svd_llm(x, <name>_weight_U_top, <name>_weight_SVh_top, bias)``. The bias
    is ``<name>_bias`` when ``kwargs['optim_bias']`` is truthy, else ``None``. All other
    projections use the module's own linear layer.

    Returns:
        ``(attn_output, attn_weights_reshaped or None, past_key_value)``. When
        ``module.is_decoder``, ``past_key_value`` is replaced by ``(key_states, value_states)``.

    Raises:
        ValueError: if the attention weights, attention mask, head mask or output have
            unexpected sizes.
    """

    # if key_value_states are provided this layer is used as a cross-attention layer
    # for the decoder
    is_cross_attention = key_value_states is not None

    bsz, tgt_len, _ = hidden_states.size()

    pruned = kwargs['layer_idx'] in kwargs['pruned_layer_idx_list']

    def project(name, x):
        """Dense ``module.<name>`` unless this layer is pruned and ``<name>_use <= 0``."""
        if not pruned or getattr(module, f'{name}_use') > 0:
            return getattr(module, name)(x)
        bias = getattr(module, f'{name}_bias') if kwargs['optim_bias'] else None
        return proj_svd_llm(
            x, getattr(module, f'{name}_weight_U_top'), getattr(module, f'{name}_weight_SVh_top'), bias
        )

    # get query proj
    query_states = project('q_proj', hidden_states) * module.scaling
    # get key, value proj
    if is_cross_attention and past_key_value is not None:
        # reuse k,v, cross_attentions
        key_states = past_key_value[0]
        value_states = past_key_value[1]
    elif is_cross_attention:
        # cross_attentions
        key_states = module._shape(project('k_proj', key_value_states), -1, bsz)
        value_states = module._shape(project('v_proj', key_value_states), -1, bsz)
    elif past_key_value is not None:
        # reuse k, v, self_attention
        key_states = module._shape(project('k_proj', hidden_states), -1, bsz)
        value_states = module._shape(project('v_proj', hidden_states), -1, bsz)
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    else:
        # self_attention
        key_states = module._shape(project('k_proj', hidden_states), -1, bsz)
        value_states = module._shape(project('v_proj', hidden_states), -1, bsz)

    if module.is_decoder:
        # Save the (possibly concatenated) key/value states so that later calls can reuse
        # them: cross-attention reuses them directly, self-attention concatenates new
        # states onto them.
        past_key_value = (key_states, value_states)

    proj_shape = (bsz * module.num_heads, -1, module.head_dim)
    query_states = module._shape(query_states, tgt_len, bsz).view(*proj_shape)
    key_states = key_states.view(*proj_shape)
    value_states = value_states.view(*proj_shape)

    src_len = key_states.size(1)
    attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

    if attn_weights.size() != (bsz * module.num_heads, tgt_len, src_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz * module.num_heads, tgt_len, src_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, tgt_len, src_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights.view(bsz, module.num_heads, tgt_len, src_len) + attention_mask
        # Clamp so that adding a very negative mask cannot go below the dtype minimum.
        attn_weights = torch.max(
            attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
        )
        attn_weights = attn_weights.view(bsz * module.num_heads, tgt_len, src_len)

    # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
    if attn_weights.dtype == torch.float16:
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
    else:
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    if layer_head_mask is not None:
        if layer_head_mask.size() != (module.num_heads,):
            raise ValueError(
                f"Head mask for a single layer should be of size {(module.num_heads,)}, but is"
                f" {layer_head_mask.size()}"
            )
        attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, module.num_heads, tgt_len, src_len)
        attn_weights = attn_weights.view(bsz * module.num_heads, tgt_len, src_len)

    if output_attentions:
        # Reshape twice and reuse the result so that attn_weights keeps its gradient.
        attn_weights_reshaped = attn_weights.view(bsz, module.num_heads, tgt_len, src_len)
        attn_weights = attn_weights_reshaped.view(bsz * module.num_heads, tgt_len, src_len)
    else:
        attn_weights_reshaped = None

    attn_probs = nn.functional.dropout(attn_weights, p=module.dropout, training=module.training)

    attn_output = torch.bmm(attn_probs, value_states)

    if attn_output.size() != (bsz * module.num_heads, tgt_len, module.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz * module.num_heads, tgt_len, module.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.view(bsz, module.num_heads, tgt_len, module.head_dim)
    attn_output = attn_output.transpose(1, 2)

    # Use module.embed_dim rather than the hidden_states width because attn_output can be
    # partitioned across GPUs when using tensor-parallelism.
    attn_output = attn_output.reshape(bsz, tgt_len, module.embed_dim)

    ##### Attn out decomposition #####
    attn_output = project('out_proj', attn_output)

    return attn_output, attn_weights_reshaped, past_key_value

def _opt_dense_or_svd_proj(module, proj_name, hidden_states, layer_is_pruned, optim_bias):
    """Run projection ``proj_name`` of ``module`` either densely or from its SVD factors.

    The dense sub-module ``module.<proj_name>`` is used unless the layer is
    pruned AND ``module.<proj_name>_use`` is not positive. Otherwise
    ``proj_svd_llm`` is called with ``module.<proj_name>_weight_U_top`` and
    ``module.<proj_name>_weight_SVh_top``, plus ``module.<proj_name>_bias``
    when ``optim_bias`` is true.

    Args:
        module: module owning the projection and (for pruned layers) its
            ``*_use`` / ``*_weight_U_top`` / ``*_weight_SVh_top`` / ``*_bias``
            attributes.
        proj_name: attribute prefix of the projection, e.g. ``"fc1"``.
        hidden_states: input to the projection.
        layer_is_pruned: whether this layer index is in the pruned list.
        optim_bias: whether the SVD path should pass the stored bias.

    Returns:
        The projected hidden states.
    """
    # `or` short-circuits, so `<proj_name>_use` is only looked up for pruned
    # layers; unpruned modules are not required to define it.
    if not layer_is_pruned or getattr(module, f"{proj_name}_use") > 0:
        return getattr(module, proj_name)(hidden_states)
    u_top = getattr(module, f"{proj_name}_weight_U_top")
    svh_top = getattr(module, f"{proj_name}_weight_SVh_top")
    if optim_bias:
        bias = getattr(module, f"{proj_name}_bias")
        return proj_svd_llm(hidden_states, u_top, svh_top, bias)
    return proj_svd_llm(hidden_states, u_top, svh_top)


def opt_decoder_forward_svd(module, 
                                hidden_states: torch.Tensor,
                                attention_mask: Optional[torch.Tensor] = None,
                                layer_head_mask: Optional[torch.Tensor] = None,
                                past_key_value: Optional[Tuple[torch.Tensor]] = None,
                                output_attentions: Optional[bool] = False,
                                use_cache: Optional[bool] = False,
                                **kwargs):
    """Forward pass of an OPT decoder layer whose feed-forward may use SVD factors.

    Intended to be bound to a layer with ``functools.partial`` (the layer is the
    first positional argument). Runs self-attention and the feed-forward block
    with residual connections; ``fc1``/``fc2`` are computed densely or via
    ``proj_svd_llm`` depending on whether the layer is pruned.

    Args:
        module: the decoder layer this forward is bound to. It must provide
            ``do_layer_norm_before``, ``self_attn_layer_norm``, ``self_attn``,
            ``dropout``, ``training``, ``final_layer_norm``, ``fc1``, ``fc2``
            and ``activation_fn``.
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
            `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
        layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
            `(encoder_attention_heads,)`.
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
            (see `past_key_values`).
        **kwargs: must contain ``layer_idx`` and ``pruned_layer_idx_list``
            (the layer is pruned when ``layer_idx`` is in that list).
            ``optim_bias`` (default ``False``) makes the SVD path pass the
            stored ``fc*_bias``. Other keys (e.g. ``module_name``) are ignored.

    Returns:
        ``(hidden_states,)``, followed by the self-attention weights when
        ``output_attentions`` is true and by the present key/value when
        ``use_cache`` is true.
    """
    residual = hidden_states

    # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
    if module.do_layer_norm_before:
        hidden_states = module.self_attn_layer_norm(hidden_states)

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = module.self_attn(
        hidden_states=hidden_states,
        past_key_value=past_key_value,
        attention_mask=attention_mask,
        layer_head_mask=layer_head_mask,
        output_attentions=output_attentions,
    )
    hidden_states = nn.functional.dropout(hidden_states, p=module.dropout, training=module.training)
    hidden_states = residual + hidden_states

    # 350m applies layer norm AFTER attention
    if not module.do_layer_norm_before:
        hidden_states = module.self_attn_layer_norm(hidden_states)

    # Fully Connected: flatten to 2-D, restored with `.view` after the residual add.
    hidden_states_shape = hidden_states.shape
    hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
    residual = hidden_states

    # Pre-LN variants also normalize BEFORE the feed-forward block.
    if module.do_layer_norm_before:
        hidden_states = module.final_layer_norm(hidden_states)

    layer_is_pruned = kwargs['layer_idx'] in kwargs['pruned_layer_idx_list']
    optim_bias = kwargs.get('optim_bias', False)

    hidden_states = _opt_dense_or_svd_proj(module, "fc1", hidden_states, layer_is_pruned, optim_bias)
    hidden_states = module.activation_fn(hidden_states)
    hidden_states = _opt_dense_or_svd_proj(module, "fc2", hidden_states, layer_is_pruned, optim_bias)
    hidden_states = nn.functional.dropout(hidden_states, p=module.dropout, training=module.training)

    hidden_states = (residual + hidden_states).view(hidden_states_shape)

    # Post-LN variants (e.g. 350m) normalize AFTER the feed-forward block.
    if not module.do_layer_norm_before:
        hidden_states = module.final_layer_norm(hidden_states)

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    if use_cache:
        outputs += (present_key_value,)

    return outputs


def add_inference_hook_to_opt(optim_bias, use_trunc, use_bias, pruned_layer_idx_list,
                              model: HELPER_SUPPORT_MODEL_TYPES):
    """Patch an OPT model in place so its layers can run on SVD factors.

    First sets the helper state of ``model`` to ``HelperState.Inference``.
    Then every ``OPTDecoderLayer`` gets ``opt_decoder_forward_svd``, and every
    ``OPTAttention`` gets ``opt_attn_forward_svd``, as its instance-level
    ``forward``. Each is bound via ``functools.partial`` to the module itself,
    its layer index, ``pruned_layer_idx_list`` and ``optim_bias``.

    Args:
        optim_bias: forwarded to the patched forwards; controls whether the SVD
            path passes the module's stored ``*_bias`` tensor.
        use_trunc: not used by this function; accepted so it shares the
            argument list that ``add_inference_hook`` passes to every
            per-architecture installer.
        use_bias: not used by this function; accepted for the same reason.
        pruned_layer_idx_list: decoder-layer indices eligible for the SVD path.
        model: the OPT model to patch; its modules are modified in place.

    Returns:
        None.
    """
    set_helper_state(model, HelperState.Inference)

    for name, module in model.named_modules():
        # Assigning the instance attribute shadows the class-level `forward`,
        # so calling the module dispatches to the bound partial.
        if isinstance(module, OPTDecoderLayer):
            # A decoder layer's qualified name is expected to end in its
            # integer index (e.g. "...layers.3").
            module.forward = functools.partial(
                opt_decoder_forward_svd,
                module,
                layer_idx=int(name.split(".")[-1]),
                pruned_layer_idx_list=pruned_layer_idx_list,
                module_name=name,
                optim_bias=optim_bias,
            )
        elif isinstance(module, OPTAttention):
            # The owning layer's index is expected one level up
            # (e.g. "...layers.3.self_attn").
            module.forward = functools.partial(
                opt_attn_forward_svd,
                module,
                layer_idx=int(name.split(".")[-2]),
                pruned_layer_idx_list=pruned_layer_idx_list,
                optim_bias=optim_bias,
            )


def add_inference_hook(optim_bias, use_trunc, use_bias, pruned_layer_idx_list, 
                       model: HELPER_SUPPORT_MODEL_TYPES):
    """Install the SVD inference forwards appropriate for ``model``'s architecture.

    Llama models are delegated to ``add_inference_hook_to_llama`` and OPT
    models to ``add_inference_hook_to_opt``, with the same arguments in the
    same order. A model that passes the ``HELPER_SUPPORT_MODEL_LIST`` check
    but is neither Llama nor OPT is left untouched.

    Raises:
        NotImplementedError: if ``model`` is not an instance of any type in
            ``HELPER_SUPPORT_MODEL_LIST``.
    """
    if not isinstance(model, HELPER_SUPPORT_MODEL_LIST):
        raise NotImplementedError(f"Only support {HELPER_SUPPORT_MODEL_LIST}.")

    installer_args = (optim_bias, use_trunc, use_bias, pruned_layer_idx_list, model)
    if isinstance(model, LlamaPreTrainedModel):
        add_inference_hook_to_llama(*installer_args)
    elif isinstance(model, OPTPreTrainedModel):
        add_inference_hook_to_opt(*installer_args)

