import sys
from pathlib import Path

project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

import torch
import torch.nn as nn
from datasets import load_dataset
import argparse
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from transformers.models import Gemma3PreTrainedModel, Qwen3PreTrainedModel, GPT2PreTrainedModel, LlamaPreTrainedModel
from torch.utils.data import DataLoader
import math
from math import sqrt
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np


def construct_block_diagonal_householder(vector_list, matrix_size, device="cpu"):
    """
    Build a block-diagonal matrix whose diagonal blocks are Householder reflections.

    The result starts as a ``matrix_size`` x ``matrix_size`` identity. For the
    i-th vector ``v`` the i-th diagonal block is overwritten with
    ``I - 2 * u u^T`` where ``u = v / ||v||_2``. Diagonal positions not covered
    by any block keep their identity values.

    Args:
        vector_list: Sequence of 1-D torch tensors, one per block. They do not
            need to be pre-normalised. The length of the first vector is used
            as the stride between consecutive blocks.
        matrix_size (int): Side length of the returned square matrix.
        device: Device on which the identity matrices are allocated.

    Returns:
        torch.Tensor: The block-diagonal matrix.

    Raises:
        ValueError: If one of the vectors has zero L2 norm.
    """
    result = torch.eye(matrix_size, device=device)

    for block_idx, raw_vec in enumerate(vector_list):
        length = torch.norm(raw_vec, p=2)
        if length == 0:
            raise ValueError("Input vector cannot be a zero vector.")
        unit = raw_vec / length

        # Householder reflection about the hyperplane orthogonal to `unit`.
        reflection = torch.eye(raw_vec.shape[0], device=device) - 2 * torch.outer(unit, unit)

        # Blocks are laid out back to back using the first vector's length.
        stride = vector_list[0].shape[0]
        lo = stride * block_idx
        hi = lo + stride
        result[lo:hi, lo:hi] = reflection

    return result

def create_obfus_matrix(dim, block_size, device="cpu", layernorm=False):
    """
    Compose a random ``dim`` x ``dim`` mixing matrix from several rounds of
    block-diagonal Householder reflections followed by column permutations.

    Each round draws a random permutation of ``range(dim)``, draws
    ``dim // block_size`` Gaussian vectors of length ``block_size``, builds the
    block-diagonal Householder matrix for them, right-multiplies the running
    product by it and finally permutes the columns of the product.

    Args:
        dim (int): Side length of the matrix; must be a multiple of ``block_size``.
        block_size (int): Size of each Householder block; must be at least 2
            (the number of rounds uses a logarithm in base ``block_size``).
        device: Device on which the random tensors are created.
        layernorm (bool): When true, the last component of every random vector
            is replaced by minus the sum of the others, so each vector sums to
            zero. NOTE(review): presumably so the reflections leave the
            all-ones direction untouched for LayerNorm models - confirm.

    Returns:
        tuple: ``(h, v_list, indices_list)`` where ``h`` is the composed
        matrix, ``v_list[r]`` is the list of vectors drawn in round ``r`` and
        ``indices_list[r]`` is the column permutation applied in round ``r``.

    Raises:
        ValueError: If ``block_size`` is smaller than 2 or ``dim`` is not a
            multiple of ``block_size``.
    """
    # block_size 1 or 0 would otherwise surface as an opaque ZeroDivisionError
    # (math.log(dim, 1) or dim % 0), and a negative value as a math domain error.
    if block_size < 2:
        raise ValueError(f"block_size({block_size}) must be at least 2")
    if dim % block_size != 0:
        raise ValueError(f"dim({dim}) must be a multiple of block_size({block_size})")

    h = None
    indices_list = []
    v_list = []
    # NOTE(review): looks like a target probability for the mixing bound - confirm.
    p = 0.99

    num_rounds = math.ceil(
        math.log(dim, block_size) + math.log(math.log(dim**2 / (1 - p)), block_size)
    )
    blocks_per_round = dim // block_size

    for _ in range(num_rounds):
        # The permutation is drawn before the vectors; keep this order so that
        # seeded runs reproduce the same matrices.
        random_indices = torch.randperm(dim, device=device)
        indices_list.append(random_indices)

        vectors = []
        for _block in range(blocks_per_round):
            v = torch.randn(block_size, device=device)
            if layernorm:
                v[-1] = -v[:-1].sum()
            vectors.append(v)
        v_list.append(vectors)

        round_matrix = construct_block_diagonal_householder(vectors, dim, device)
        if h is not None:
            round_matrix = h @ round_matrix
        h = round_matrix[:, random_indices]
    return h, v_list, indices_list

def _obfuscate_linear_left(linear, pi, obfus_bias=False):
    """
    linear(x) = x @ W^T + b
    W' = W @ pi^T
    b' = pi @ b
    linear_obf(x) = x @ pi @ W^T + pi @ b
    """
    W_obf = linear.weight.data @ pi.t()
    linear.weight.data = W_obf

    if linear.bias is not None and obfus_bias:
        b_obf = pi @ (linear.bias.data.view(-1, 1))
        b_obf = b_obf.squeeze()
        linear.bias.data = b_obf
    

def _obfuscate_linear_right(linear, pi, obfus_bias=False):
    """
    linear(x) = x @ W^T + b
    W' = pi^T @ W
    b' = b @ pi
    linear_obf(x) = x @ W^T @ pi + b @ pi
    """
    W_obf = pi.t() @ linear.weight.data
    linear.weight.data = W_obf

    if linear.bias is not None and obfus_bias:
        b_obf = (linear.bias.data.view(1, -1)) @ pi
        b_obf = b_obf.squeeze()
        linear.bias.data = b_obf

def _obfuscate_conv1d_left(conv, pi, obfus_bias=False):
    """
    conv1d(x) = x @ W + b
    W' = pi @ W
    b' = pi @ b
    """
    W_obf = pi @ conv.weight.data
    conv.weight.data = W_obf

    if conv.bias is not None and obfus_bias:
        b_obf = pi @ (conv.bias.data.view(-1, 1))
        b_obf = b_obf.squeeze()
        conv.bias.data = b_obf
    

def _obfuscate_conv1d_right(conv, pi, obfus_bias=False):
    """
    conv1d(x) = x @ W + b
    W' = W @ pi
    b' = b @ pi
    """
    W_obf = conv.weight.data @ pi
    conv.weight.data = W_obf

    if conv.bias is not None and obfus_bias:
        b_obf = (conv.bias.data.view(1, -1)) @ pi
        b_obf = b_obf.squeeze()
        conv.bias.data = b_obf

def _obfuscate_model(model, block_size, obf_score):
    """
    Obfuscate, in place, a decoder model laid out as ``model.model.layers``
    whose layers expose ``self_attn``, ``mlp``, ``input_layernorm`` and
    ``post_attention_layernorm``.

    Steps:
      1. Every module whose qualified name contains "norm" has its exact-zero
         weights replaced by 5e-8 so the reciprocals taken below are defined.
      2. Layer 0: only ``mlp.down_proj`` is transformed on its output side,
         with its own matrix ``pi0``.
      3. Remaining layers: q/k/v and gate/up projections absorb
         ``D^-1 @ pi^T @ D`` on their input side (``D`` is the diagonal matrix
         of the preceding norm's weight); ``o_proj`` and ``down_proj`` absorb
         ``pi`` on their output side (bias included when present).
      4. If ``obf_score`` is true, ``model.score`` absorbs the same input-side
         transform built from ``model.model.norm``.

    Args:
        model: Model to modify in place.
        block_size (int): Householder block size forwarded to ``create_obfus_matrix``.
        obf_score (bool): Whether to also transform ``model.score``.

    Returns:
        dict: Keys ``"pi"``, ``"pi0"``, ``"v_list"``, ``"v_list0"``,
        ``"indices_list"`` and ``"indices_list0"`` holding the matrices and
        the random draws they were built from.

    Raises:
        ValueError: If ``obf_score`` is true but the model has no ``score``
            layer. This is checked before any weight is modified, so a failed
            call does not leave the model half-obfuscated.
    """
    score = None
    if obf_score:
        if hasattr(model, "score") and model.score is not None:
            score = model.score
        else:
            raise ValueError("The model does not have a score layer.")

    # Why 5e-8? For float16, float32, float64, and bfloat16, 5e-8 will not be truncated to 0.
    for name, m in model.named_modules():
        if "norm" in name:
            assert torch.tensor(5e-8, dtype=m.weight.data.dtype).item() != 0.
            m.weight.data = torch.where(m.weight.data==0., 5e-8, m.weight.data)

    # obfuscate for the first decoder layer
    pi0, v_list0, indices_list0 = create_obfus_matrix(model.config.hidden_size, block_size, device=model.device)

    first_layer = model.model.layers[0]
    _obfuscate_linear_right(first_layer.mlp.down_proj, pi0, True)

    # obfuscate for the remaining layers
    pi, v_list, indices_list = create_obfus_matrix(model.config.hidden_size, block_size, device=model.device)
    pi_t = pi.t()

    for decoderLayer in model.model.layers:
        if decoderLayer is first_layer:
            continue
        attn = decoderLayer.self_attn
        mlp = decoderLayer.mlp

        input_norm_weights = decoderLayer.input_layernorm.weight.data
        D1 = torch.diag(input_norm_weights)
        D1_inv = torch.diag(torch.reciprocal(input_norm_weights))

        post_attn_norm_weights = decoderLayer.post_attention_layernorm.weight.data
        D2 = torch.diag(post_attn_norm_weights)
        D2_inv = torch.diag(torch.reciprocal(post_attn_norm_weights))

        # The input-side transform is the same for every projection fed by
        # the same norm, so build each dense product once per layer.
        attn_in = D1_inv @ pi_t @ D1
        mlp_in = D2_inv @ pi_t @ D2

        # obfuscate the attention layer
        _obfuscate_linear_left(attn.q_proj, attn_in, False)
        _obfuscate_linear_left(attn.k_proj, attn_in, False)
        _obfuscate_linear_left(attn.v_proj, attn_in, False)

        _obfuscate_linear_right(attn.o_proj, pi, True)

        # obfuscate the mlp layer
        _obfuscate_linear_left(mlp.gate_proj, mlp_in, False)
        _obfuscate_linear_left(mlp.up_proj, mlp_in, False)

        _obfuscate_linear_right(mlp.down_proj, pi, True)

    if obf_score:
        norm_weights = model.model.norm.weight.data
        D = torch.diag(norm_weights)
        D_inv = torch.diag(torch.reciprocal(norm_weights))
        _obfuscate_linear_left(score, D_inv @ pi_t @ D, False)

    return {
            "pi": pi,
            "pi0": pi0,
            "v_list": v_list,
            "v_list0": v_list0,
            "indices_list": indices_list,
            "indices_list0": indices_list0
        }

def _obfuscate_model_gpt2(model, block_size, obf_score):
    """
    Obfuscate, in place, a GPT-2 style model laid out as
    ``model.transformer.h`` with Conv1D-style projections and
    ``ln_1`` / ``ln_2`` / ``ln_f`` norms.

    Steps:
      1. Every module whose qualified name contains "ln_" has its exact-zero
         weights replaced by 5e-8 so the reciprocals taken below are defined.
      2. Block 0: only ``mlp.c_proj`` is transformed on its output side, with
         its own matrix ``pi0``.
      3. Remaining blocks: the ``ln_1`` / ``ln_2`` biases (when present) are
         mapped through ``D^-1 @ pi @ D``; ``c_attn`` (and ``q_attn`` for
         cross-attention) and ``c_fc`` absorb ``D^-1 @ pi^T @ D`` on their
         input side; both ``c_proj`` layers absorb ``pi`` on their output side.
         ``pi`` is drawn with ``layernorm=True``.
      4. The ``ln_f`` bias (when present) is mapped the same way, regardless
         of ``obf_score``.
      5. If ``obf_score`` is true, ``model.score`` absorbs the input-side
         transform built from ``ln_f``.

    Args:
        model: Model to modify in place.
        block_size (int): Householder block size forwarded to ``create_obfus_matrix``.
        obf_score (bool): Whether to also transform ``model.score``.

    Returns:
        dict: Keys ``"pi"``, ``"pi0"``, ``"v_list"``, ``"v_list0"``,
        ``"indices_list"`` and ``"indices_list0"``.

    Raises:
        ValueError: If ``obf_score`` is true but the model has no ``score``
            layer. This is checked before any weight is modified, so a failed
            call does not leave the model half-obfuscated.
    """
    score = None
    if obf_score:
        if hasattr(model, "score") and model.score is not None:
            score = model.score
        else:
            raise ValueError("The model does not have a score layer.")

    # Why 5e-8? For float16, float32, float64, and bfloat16, 5e-8 will not be truncated to 0.
    for name, m in model.named_modules():
        if "ln_" in name:
            assert torch.tensor(5e-8, dtype=m.weight.data.dtype).item() != 0.
            m.weight.data = torch.where(m.weight.data==0., 5e-8, m.weight.data)

    # obfuscate for the first decoder layer
    pi0, v_list0, indices_list0 = create_obfus_matrix(model.config.hidden_size, block_size, device=model.device)
    first_block = model.transformer.h[0]

    _obfuscate_conv1d_right(first_block.mlp.c_proj, pi0, True)

    # obfuscate for the remaining layers
    pi, v_list, indices_list = create_obfus_matrix(model.config.hidden_size, block_size, device=model.device, layernorm=True)
    pi_t = pi.t()

    for decoderLayer in model.transformer.h:
        if decoderLayer is first_block:
            continue
        attn = decoderLayer.attn
        mlp = decoderLayer.mlp

        ln_1 = decoderLayer.ln_1
        ln_2 = decoderLayer.ln_2

        ln_1_weights = ln_1.weight.data
        D1 = torch.diag(ln_1_weights)
        D1_inv = torch.diag(torch.reciprocal(ln_1_weights))
        if hasattr(ln_1, 'bias') and ln_1.bias is not None:
            ln_1.bias.data = ((ln_1.bias.data.view(1, -1)) @ D1_inv @ pi @ D1).squeeze()

        ln_2_weights = ln_2.weight.data
        D2 = torch.diag(ln_2_weights)
        D2_inv = torch.diag(torch.reciprocal(ln_2_weights))
        if hasattr(ln_2, 'bias') and ln_2.bias is not None:
            ln_2.bias.data = ((ln_2.bias.data.view(1, -1)) @ D2_inv @ pi @ D2).squeeze()

        # obfuscate the attention layer; the dense product is shared by
        # c_attn and (for cross-attention) q_attn, so build it once.
        attn_in = D1_inv @ pi_t @ D1
        _obfuscate_conv1d_left(attn.c_attn, attn_in, False)
        if attn.is_cross_attention:
            _obfuscate_conv1d_left(attn.q_attn, attn_in, False)

        _obfuscate_conv1d_right(attn.c_proj, pi, True)

        # obfuscate the mlp layer
        _obfuscate_conv1d_left(mlp.c_fc, D2_inv @ pi_t @ D2, False)

        _obfuscate_conv1d_right(mlp.c_proj, pi, True)

    norm = model.transformer.ln_f
    norm_weights = norm.weight.data
    D = torch.diag(norm_weights)
    D_inv = torch.diag(torch.reciprocal(norm_weights))
    if hasattr(norm, 'bias') and norm.bias is not None:
        norm.bias.data = ((norm.bias.data.view(1, -1)) @ D_inv @ pi @ D).squeeze()

    if obf_score:
        _obfuscate_linear_left(score, D_inv @ pi_t @ D, False)

    return {
            "pi": pi,
            "pi0": pi0,
            "v_list": v_list,
            "v_list0": v_list0,
            "indices_list": indices_list,
            "indices_list0": indices_list0
        }

def _obfuscate_model_gemma3(model, block_size, obf_score):
    """
    Obfuscate, in place, a Gemma3-style model laid out as ``model.model.layers``.

    Gemma3's RMSNorm.weight contains -1 elements (its forward function first adds 1.0 before computation);
    we replace them with the number closest to -1.

    It can be observed that float16 and bfloat16 have poor precision.
    The most recommended approach is to modify Gemma3's RMSNorm forward function:
    avoid multiplying by (1. + RMSNorm.weight).

    Instead, set RMSNorm.weight = RMSNorm.weight + 1.,
    then replace any zero values with the number closest to zero
    (since floating-point numbers can represent values near zero more precisely),
    and then perform the multiplication.

    Steps performed here:
      1. Patch -1 norm weights as described above (float32 / float64 only).
      2. Attach ``o_proj_branchs`` and ``down_proj_branchs`` module lists to
         the model, holding randomly transformed copies of the ``o_proj`` /
         ``down_proj`` layers. They are not referenced again in this function.
      3. Layer 0: ``mlp.down_proj`` absorbs ``D0 @ pi0`` on its output side
         (``D0`` is the diagonal of ``1 + post_feedforward_layernorm.weight``)
         and that norm's weight is then zeroed.
      4. Remaining layers: q/k/v and gate/up absorb ``D^-1 @ pi^T @ D`` on
         their input side; ``o_proj`` / ``down_proj`` absorb ``D_post @ pi``
         on their output side and the corresponding post-norm weights are
         zeroed.
      5. If ``obf_score`` is true, ``model.score`` absorbs the input-side
         transform built from ``model.model.norm``.

    Args:
        model: Model to modify in place.
        block_size (int): Householder block size forwarded to ``create_obfus_matrix``.
        obf_score (bool): Whether to also transform ``model.score``.

    Returns:
        dict: Keys ``"pi"``, ``"pi0"``, ``"v_list"``, ``"v_list0"``,
        ``"indices_list"`` and ``"indices_list0"``.

    Raises:
        ValueError: If ``obf_score`` is true but the model has no ``score``
            layer (checked before anything is modified), if a norm weight is
            stored in float16 / bfloat16, or if an inverse norm scale
            contains NaN.
    """
    score = None
    if obf_score:
        if hasattr(model, "score") and model.score is not None:
            score = model.score
        else:
            raise ValueError("The model does not have a score layer.")

    for name, m in model.named_modules():
        if "norm" in name:
            if m.weight.data.dtype == torch.float32:
                assert (1. + torch.tensor(-1.00000011920928955078125, dtype=torch.float32)).item() != 0.
                m.weight.data = torch.where(m.weight.data==-1., -1.00000011920928955078125, m.weight.data)
            elif m.weight.data.dtype == torch.float64:
                assert (1. + torch.tensor(-1.0000000000000002220446049250313, dtype=torch.float64)).item() != 0.
                m.weight.data = torch.where(m.weight.data==-1, -1.0000000000000002220446049250313, m.weight.data)
            elif m.weight.data.dtype == torch.float16:
                # m.weight.data = torch.where(m.weight.data==-1., -1.0009765625, m.weight.data)
                raise ValueError("The precision of float16 for numbers near 1.0 is too poor; please use another method.")
            elif m.weight.data.dtype == torch.bfloat16:
                # m.weight.data = torch.where(m.weight.data==-1., -1.0078125, m.weight.data)
                raise ValueError("The precision of bfloat16 for numbers near 1.0 is too poor; please use another method.")

    def obf_linear(linear):
        """Return a fresh nn.Linear whose weight is ``Q @ W`` for a random
        ``Q`` built from 2x2 Householder blocks and row/column permutations."""
        n = linear.out_features
        dev = linear.weight.device
        Q = torch.eye(n, device=dev)
        for _ in range(math.ceil(math.log(n, 2) + math.log(100, 2))):
            Q0 = torch.eye(n, device=dev)
            # Pair up consecutive indices; a trailing unpaired index (odd n)
            # keeps its identity entry.
            for i in range(0, n - 1, 2):
                v = torch.randn(2, device=dev)
                v = v / torch.norm(v, p=2)
                H = torch.eye(2, device=dev) - 2 * torch.outer(v, v)
                Q0[i:i+2, i:i+2] = H

            Q0 = Q0[:, torch.randperm(n, device=dev)]
            Q0 = Q0[torch.randperm(n, device=dev)]
            Q = Q0 @ Q

        new_linear = nn.Linear(linear.in_features, linear.out_features, bias=linear.bias is not None)
        new_linear.weight.data = Q @ linear.weight.data
        if linear.bias is not None:
            new_linear.bias.data = (linear.bias.data.view(1, -1) @ Q).squeeze()
        return new_linear

    def init_branchs(m):
        """Attach the transformed o_proj / down_proj copies to the model."""
        m.o_proj_branchs = nn.ModuleList([
            None if i == 0 else obf_linear(m.model.layers[i].self_attn.o_proj)
            for i in range(m.model.config.num_hidden_layers)
        ])

        m.down_proj_branchs = nn.ModuleList([
            obf_linear(m.model.layers[i].mlp.down_proj)
            for i in range(m.model.config.num_hidden_layers)
        ])

    init_branchs(model)

    # obfuscate for the first decoder layer
    pi0, v_list0, indices_list0 = create_obfus_matrix(model.config.hidden_size, block_size, device=model.device)

    first_layer = model.model.layers[0]
    post_feedforward_layernorm = first_layer.post_feedforward_layernorm

    D0 = torch.diag(1.0 + post_feedforward_layernorm.weight.data)

    # Fold the post-norm scale into down_proj, then neutralise the norm
    # weight (the forward pass multiplies by 1 + weight).
    _obfuscate_linear_right(first_layer.mlp.down_proj, D0 @ pi0, True)
    post_feedforward_layernorm.weight.data = torch.zeros_like(post_feedforward_layernorm.weight.data)

    # obfuscate for the remaining layers
    pi, v_list, indices_list = create_obfus_matrix(model.config.hidden_size, block_size, device=model.device)
    pi_t = pi.t()

    for decoderLayer in model.model.layers:
        if decoderLayer is first_layer:
            continue
        attn = decoderLayer.self_attn
        mlp = decoderLayer.mlp
        post_attention_layernorm = decoderLayer.post_attention_layernorm
        post_feedforward_layernorm = decoderLayer.post_feedforward_layernorm

        input_norm_weights = decoderLayer.input_layernorm.weight.data + 1.0
        D1 = torch.diag(input_norm_weights)
        D1_inv = torch.diag(torch.reciprocal(input_norm_weights))
        if torch.any(torch.isnan(D1_inv)):
            raise ValueError("Inverse of the input_layernorm scale contains NaN.")

        D2 = torch.diag(1.0 + post_attention_layernorm.weight.data)

        pre_ffn_norm_weights = decoderLayer.pre_feedforward_layernorm.weight.data + 1.0
        D3 = torch.diag(pre_ffn_norm_weights)
        D3_inv = torch.diag(torch.reciprocal(pre_ffn_norm_weights))
        if torch.any(torch.isnan(D3_inv)):
            raise ValueError("Inverse of the pre_feedforward_layernorm scale contains NaN.")

        D4 = torch.diag(1.0 + post_feedforward_layernorm.weight.data)

        # The input-side transform is shared by every projection fed by the
        # same norm, so build each dense product once per layer.
        attn_in = D1_inv @ pi_t @ D1
        mlp_in = D3_inv @ pi_t @ D3

        # obfuscate the attention layer
        _obfuscate_linear_left(attn.q_proj, attn_in, False)
        _obfuscate_linear_left(attn.k_proj, attn_in, False)
        _obfuscate_linear_left(attn.v_proj, attn_in, False)

        _obfuscate_linear_right(attn.o_proj, D2 @ pi, True)
        post_attention_layernorm.weight.data = torch.zeros_like(post_attention_layernorm.weight.data)

        # obfuscate the mlp layer
        _obfuscate_linear_left(mlp.gate_proj, mlp_in, False)
        _obfuscate_linear_left(mlp.up_proj, mlp_in, False)

        _obfuscate_linear_right(mlp.down_proj, D4 @ pi, True)
        post_feedforward_layernorm.weight.data = torch.zeros_like(post_feedforward_layernorm.weight.data)

    if obf_score:
        norm_weights = model.model.norm.weight.data + 1.0
        D = torch.diag(norm_weights)
        D_inv = torch.diag(torch.reciprocal(norm_weights))
        _obfuscate_linear_left(score, D_inv @ pi_t @ D, False)

    return {
            "pi": pi,
            "pi0": pi0,
            "v_list": v_list,
            "v_list0": v_list0,
            "indices_list": indices_list,
            "indices_list0": indices_list0
        }

def obfuscate_model(model, block_size=2, obf_score=False):
    """
    Obfuscate ``model`` in place with the routine matching its architecture.

    GPT-2 models and Gemma3 models are routed to their dedicated
    implementations; every other model goes through the generic one.

    Args:
        model: Model to obfuscate.
        block_size (int): Householder block size used for the random matrices.
        obf_score (bool): Whether the ``score`` head is transformed as well.

    Returns:
        dict: Whatever the selected routine returns (the matrices and the
        random draws they were built from).
    """
    if isinstance(model, GPT2PreTrainedModel):
        handler = _obfuscate_model_gpt2
    elif isinstance(model, Gemma3PreTrainedModel):
        handler = _obfuscate_model_gemma3
    else:
        handler = _obfuscate_model
    return handler(model, block_size, obf_score)
    

def row_restore_perm(pre_model_mat, model_mat, threshold=0.0):
    """
    Match every row of ``model_mat`` to its most similar row of
    ``pre_model_mat`` and rebuild a matrix with the rows moved accordingly.

    For row ``i`` of ``model_mat`` the target index ``perm[i]`` is the row of
    ``pre_model_mat`` with the highest cosine similarity. When that similarity
    is at least ``threshold`` the row is written to position ``perm[i]`` of
    the result (a later row overwrites an earlier one mapped to the same
    position). Positions that received no row are filled from ``pre_model_mat``.

    Args:
        pre_model_mat (torch.Tensor): Reference 2-D matrix.
        model_mat (torch.Tensor): 2-D matrix whose rows are to be re-ordered.
        threshold (float): Minimum cosine similarity for a match to be accepted.

    Returns:
        tuple: ``(perm, success, restored_matrix)`` where ``perm`` is a numpy
        index array, ``success`` is the list of target positions that were
        filled from ``model_mat`` and ``restored_matrix`` is a torch tensor on
        ``model_mat``'s device.
    """
    model_mat_cpu = model_mat.cpu().numpy()
    pre_model_mat_cpu = pre_model_mat.cpu().numpy()
    similarity_matrix = cosine_similarity(model_mat_cpu, pre_model_mat_cpu)
    perm = np.argmax(similarity_matrix, axis=1)
    restored_matrix = np.empty_like(model_mat_cpu)
    success = []
    for i in range(len(model_mat_cpu)):
        target = perm[i]
        if similarity_matrix[i, target] >= threshold:
            restored_matrix[target] = model_mat_cpu[i]
            success.append(target)
    # Set lookup keeps the fill pass linear; testing membership in the
    # `success` list made it quadratic in the number of rows.
    filled = {int(target) for target in success}
    for i in range(len(restored_matrix)):
        if i not in filled:
            restored_matrix[i] = pre_model_mat_cpu[i]
    restored_matrix = torch.from_numpy(restored_matrix).to(model_mat.device)
    return perm, success, restored_matrix

def col_restore_perm(pre_model_mat, model_mat, threshold=0.0, perm = None):
    """
    Column-wise counterpart of ``row_restore_perm``.

    For column ``i`` of ``model_mat`` the target index ``perm[i]`` is, unless
    ``perm`` is supplied by the caller, the column of ``pre_model_mat`` with
    the highest cosine similarity. When the similarity between column ``i``
    and column ``perm[i]`` is at least ``threshold`` the column is written to
    position ``perm[i]`` of the result (a later column overwrites an earlier
    one mapped to the same position). Positions that received no column are
    filled from ``pre_model_mat``.

    Args:
        pre_model_mat (torch.Tensor): Reference 2-D matrix.
        model_mat (torch.Tensor): 2-D matrix whose columns are to be re-ordered.
        threshold (float): Minimum cosine similarity for a match to be accepted.
        perm: Optional precomputed target index per column; when ``None`` it
            is derived from the similarity matrix.

    Returns:
        tuple: ``(perm, success, restored_matrix)`` where ``success`` is the
        list of target positions that were filled from ``model_mat`` and
        ``restored_matrix`` is a torch tensor on ``model_mat``'s device.
    """
    model_mat_cpu = model_mat.cpu().numpy()
    pre_model_mat_cpu = pre_model_mat.cpu().numpy()
    similarity_matrix = cosine_similarity(model_mat_cpu.T, pre_model_mat_cpu.T)
    if perm is None:
        perm = np.argmax(similarity_matrix, axis=1)
    restored_matrix = np.empty_like(model_mat_cpu)
    success = []
    for i in range(model_mat_cpu.shape[1]):
        target = perm[i]
        if similarity_matrix[i, target] >= threshold:
            restored_matrix[:, target] = model_mat_cpu[:, i]
            success.append(target)
    # Set lookup keeps the fill pass linear; testing membership in the
    # `success` list made it quadratic in the number of columns.
    filled = {int(target) for target in success}
    for i in range(len(restored_matrix[0])):
        if i not in filled:
            restored_matrix[:, i] = pre_model_mat_cpu[:, i]
    restored_matrix = torch.from_numpy(restored_matrix).to(model_mat.device)
    return perm, success, restored_matrix
    

def fix_factor(num, mini=1.0, max=6.0):
    """
    Return ``num`` unchanged.

    ``mini`` and ``max`` are accepted but currently ignored.
    NOTE(review): the parameter names suggest a clamp to ``[mini, max]`` may
    have been intended - confirm before relying on either behaviour.
    """
    return num

def rescale_weight(ob_weight, pre_weight, is_row=False):
    """
    Rescale ``ob_weight`` in place, one row or one column at a time.

    Each row (``is_row=True``) or column (default) of ``ob_weight.data`` is
    divided by ``sqrt(var(ob_slice) / var(pre_slice))``, where ``pre_slice``
    is the same row/column of ``pre_weight.data``; the ratio is passed
    through ``fix_factor`` first.

    Args:
        ob_weight: Tensor-like object whose ``.data`` is modified in place.
        pre_weight: Tensor-like reference with the same 2-D layout.
        is_row (bool): Work on rows when true, on columns otherwise.

    Returns:
        The same ``ob_weight`` object that was passed in.
    """
    axis = 0 if is_row else 1
    ob_data = ob_weight.data
    pre_data = pre_weight.data

    for idx in range(ob_data.shape[axis]):
        # `select` yields a view, so the in-place division below writes
        # straight into the storage of `ob_weight`.
        ob_slice = ob_data.select(axis, idx)
        pre_slice = pre_data.select(axis, idx)
        variance_ratio = torch.var(ob_slice).item() / torch.var(pre_slice).item()
        ratio_q = fix_factor(math.sqrt(variance_ratio))
        ob_slice /= ratio_q
    return ob_weight

def _attack_gpt2(model, pre_model, obf_score):
    """Undo the obfuscation of a GPT-2 style model in place (see ``attack``)."""
    inv_perm = None
    first_block = model.transformer.h[0]
    for decoderLayer, pre_decoderLayer in zip(model.transformer.h, pre_model.transformer.h):
        attn = decoderLayer.attn
        mlp = decoderLayer.mlp

        if decoderLayer is first_block:
            # Only mlp.c_proj is touched in the first block.
            perm, _, _ = col_restore_perm(pre_decoderLayer.mlp.c_proj.weight.data, mlp.c_proj.weight.data)
            inv_perm = torch.argsort(torch.tensor(perm))
            mlp.c_proj.weight.data = mlp.c_proj.weight.data[:, inv_perm]
            continue

        ln_1 = decoderLayer.ln_1
        ln_2 = decoderLayer.ln_2

        D1 = torch.diag(ln_1.weight.data)
        D1_inv = torch.diag(1.0 / (ln_1.weight.data))

        D2 = torch.diag(ln_2.weight.data)
        D2_inv = torch.diag(1.0 / (ln_2.weight.data))

        if attn.is_cross_attention:
            ob_wq = attn.q_attn.weight.data
            ob_wk, ob_wv = attn.c_attn.weight.data.chunk(2, dim=1)
        else:
            ob_wq, ob_wk, ob_wv = attn.c_attn.weight.data.chunk(3, dim=1)

        # The column matching of the attention output projection is reused
        # for every other tensor of this block.
        perm, _, _ = col_restore_perm(pre_decoderLayer.attn.c_proj.weight.data, attn.c_proj.weight.data)
        inv_perm = torch.argsort(torch.tensor(perm))

        if hasattr(ln_1, 'bias') and ln_1.bias is not None:
            ln_1.bias.data = (((ln_1.bias.data.view(1, -1)) @ D1_inv)[:, inv_perm] @ D1).squeeze()
        if hasattr(ln_2, 'bias') and ln_2.bias is not None:
            ln_2.bias.data = (((ln_2.bias.data.view(1, -1)) @ D2_inv)[:, inv_perm] @ D2).squeeze()

        restore_wq = D1_inv @ (D1 @ ob_wq)[inv_perm]
        restore_wk = D1_inv @ (D1 @ ob_wk)[inv_perm]
        restore_wv = D1_inv @ (D1 @ ob_wv)[inv_perm]
        if attn.is_cross_attention:
            attn.q_attn.weight.data = restore_wq
            attn.c_attn.weight.data = torch.cat([restore_wk, restore_wv], dim=1)
        else:
            attn.c_attn.weight.data = torch.cat([restore_wq, restore_wk, restore_wv], dim=1)
        attn.c_proj.weight.data = attn.c_proj.weight.data[:, inv_perm]

        if attn.c_proj.bias is not None:
            attn.c_proj.bias.data = attn.c_proj.bias.data[inv_perm]

        mlp.c_fc.weight.data = D2_inv @ (D2 @ mlp.c_fc.weight.data)[inv_perm]

        mlp.c_proj.weight.data = mlp.c_proj.weight.data[:, inv_perm]
        if mlp.c_proj.bias is not None:
            mlp.c_proj.bias.data = mlp.c_proj.bias.data[inv_perm]

    norm = model.transformer.ln_f
    norm_weights = norm.weight.data
    D = torch.diag(norm_weights)
    D_inv = torch.diag(torch.reciprocal(norm_weights))
    if hasattr(norm, 'bias') and norm.bias is not None:
        # The bias is viewed as a (1, n) row, so the permutation has to index
        # its columns (indexing rows would go out of range).
        norm.bias.data = (((norm.bias.data.view(1, -1)) @ D_inv)[:, inv_perm] @ D).squeeze()
    if obf_score:
        # The score head sits behind ln_f, so its scale (not the last
        # block's ln_1 scale) must be divided back out.
        model.score.weight.data = (model.score.weight.data @ D)[:, inv_perm] @ D_inv


def _attack_norm_scaled_decoder(model, pre_model, obf_score, mlp_norm_attr, weight_offset, check_nan):
    """
    Undo the obfuscation of a ``model.model.layers`` style decoder in place.

    ``mlp_norm_attr`` names the norm feeding the MLP, ``weight_offset`` is
    added to every norm weight before use (``None`` for no offset) and
    ``check_nan`` makes a NaN in an inverse scale raise ``ValueError``.
    """
    def scale_matrices(norm_module):
        # Diagonal scale of a norm and its inverse.
        weights = norm_module.weight.data
        if weight_offset is not None:
            weights = weights + weight_offset
        inverse = torch.diag(torch.reciprocal(weights))
        if check_nan and torch.any(torch.isnan(inverse)):
            raise ValueError("Inverse of a norm scale contains NaN.")
        return torch.diag(weights), inverse

    inv_perm = None
    first_layer = model.model.layers[0]
    for decoderLayer, pre_decoderLayer in zip(model.model.layers, pre_model.model.layers):
        attn = decoderLayer.self_attn
        mlp = decoderLayer.mlp

        if decoderLayer is first_layer:
            # Only mlp.down_proj is touched in the first layer.
            ob_down = mlp.down_proj.weight.data
            pre_down = pre_decoderLayer.mlp.down_proj.weight.data

            perm, _, _ = row_restore_perm(pre_down, ob_down)
            inv_perm = torch.argsort(torch.tensor(perm))

            mlp.down_proj.weight.data = ob_down[inv_perm]
            continue

        D1, D1_inv = scale_matrices(decoderLayer.input_layernorm)
        D2, D2_inv = scale_matrices(getattr(decoderLayer, mlp_norm_attr))

        pre_wo = pre_decoderLayer.self_attn.o_proj.weight.data

        ob_wq = attn.q_proj.weight.data @ D1
        ob_wk = attn.k_proj.weight.data @ D1
        ob_wv = attn.v_proj.weight.data @ D1
        ob_wo = attn.o_proj.weight.data
        ob_gate = mlp.gate_proj.weight.data @ D2
        ob_up = mlp.up_proj.weight.data @ D2
        ob_down = mlp.down_proj.weight.data

        # The row matching of the attention output projection is reused for
        # every other tensor of this layer.
        perm, _, _ = row_restore_perm(pre_wo, ob_wo)
        inv_perm = torch.argsort(torch.tensor(perm))

        attn.q_proj.weight.data = ob_wq[:, inv_perm] @ D1_inv
        attn.k_proj.weight.data = ob_wk[:, inv_perm] @ D1_inv
        attn.v_proj.weight.data = ob_wv[:, inv_perm] @ D1_inv
        attn.o_proj.weight.data = ob_wo[inv_perm]
        if attn.o_proj.bias is not None:
            attn.o_proj.bias.data = attn.o_proj.bias.data[inv_perm]

        mlp.gate_proj.weight.data = ob_gate[:, inv_perm] @ D2_inv
        mlp.up_proj.weight.data = ob_up[:, inv_perm] @ D2_inv
        mlp.down_proj.weight.data = ob_down[inv_perm]
        if mlp.down_proj.bias is not None:
            mlp.down_proj.bias.data = mlp.down_proj.bias.data[inv_perm]

    if obf_score:
        # The score head sits behind the final norm, so that norm's scale
        # (not the last layer's input norm scale) must be divided back out.
        D, D_inv = scale_matrices(model.model.norm)
        model.score.weight.data = (model.score.weight.data @ D)[:, inv_perm] @ D_inv


def attack(model, pre_model, obf_score=False):
    """
    Try to undo the obfuscation of ``model`` in place, using ``pre_model`` as
    a reference.

    For every layer an index assignment is recovered by matching rows (or,
    for GPT-2, columns) of the obfuscated output projection against the
    reference one by cosine similarity; the argsort of that assignment is
    then applied as an index permutation to the layer's weights and biases,
    with the norm scales multiplied in and divided out around it. The first
    layer only has its MLP output projection re-ordered.

    Args:
        model: Obfuscated model; modified in place.
        pre_model: Reference model with the same layer layout.
        obf_score (bool): Whether ``model.score`` was obfuscated too and
            should be restored with the final norm's scale.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        ValueError: For Gemma3 models, if an inverse norm scale contains NaN.
    """
    if isinstance(model, GPT2PreTrainedModel):
        _attack_gpt2(model, pre_model, obf_score)
    elif isinstance(model, Gemma3PreTrainedModel):
        # Gemma3's norms multiply by (1 + weight) and the MLP is fed by
        # pre_feedforward_layernorm.
        _attack_norm_scaled_decoder(
            model, pre_model, obf_score,
            mlp_norm_attr="pre_feedforward_layernorm",
            weight_offset=1.0,
            check_nan=True,
        )
    else:
        _attack_norm_scaled_decoder(
            model, pre_model, obf_score,
            mlp_norm_attr="post_attention_layernorm",
            weight_offset=None,
            check_nan=False,
        )

    return model