import torch
import torch.nn as nn
import argparse
from tqdm import tqdm
from transformers.models import GPT2PreTrainedModel, Gemma3PreTrainedModel
import math
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np


def construct_block_diagonal_householder(vector_list, matrix_size, device="cpu"):
    """
    Build a square matrix whose leading diagonal blocks are Householder reflections.

    Block ``i`` is ``H_i = I - 2 * u_i u_i^T`` with ``u_i = v_i / ||v_i||_2`` and
    ``v_i = vector_list[i]``; its side length is ``len(v_i)``. Blocks are laid out
    one after another along the diagonal starting at the top-left corner. Any
    trailing rows/columns not covered by a block are left as identity.

    Args:
        vector_list (sequence of torch.Tensor): 1-D vectors, one per block. They do
            not need to be normalized; each is scaled to unit length internally.
        matrix_size (int): Side length of the returned square matrix.
        device: Device on which the result is allocated.

    Returns:
        torch.Tensor: A ``matrix_size x matrix_size`` matrix in torch's default
        floating dtype. It is orthogonal because every block is.

    Raises:
        ValueError: If a vector is all zeros, or if the blocks do not fit inside
            ``matrix_size``.
    """
    block_diag_matrix = torch.eye(matrix_size, device=device)

    offset = 0
    for vec in vector_list:
        v_norm = torch.norm(vec, p=2)
        if v_norm == 0:
            raise ValueError("Input vector cannot be a zero vector.")
        size = vec.shape[0]
        if offset + size > matrix_size:
            raise ValueError(
                f"Householder blocks of total size {offset + size} exceed matrix_size ({matrix_size})."
            )
        u = (vec / v_norm).to(device)

        # Householder reflection: H = I - 2 u u^T (symmetric and orthogonal).
        identity = torch.eye(size, device=device)
        block_diag_matrix[offset:offset + size, offset:offset + size] = identity - 2 * torch.outer(u, u)
        offset += size

    return block_diag_matrix

def create_obfus_matrix(dim, block_size, device="cpu", layernorm=False, p=0.99):
    """
    Build a random orthogonal ``dim x dim`` obfuscation matrix.

    Each round draws a random permutation ``P_k`` of ``range(dim)`` and
    ``dim // block_size`` Gaussian vectors of length ``block_size``. The vectors are
    turned into a block-diagonal Householder matrix ``B_k``, and the running
    product is updated as

        h = B_1[:, P_1]            (first round)
        h = (h @ B_k)[:, P_k]      (later rounds)

    so ``h`` is a product of orthogonal matrices and is itself orthogonal.
    The number of rounds is ``ceil(log_b(dim) + log_b(ln(dim**2 / (1 - p))))``
    with ``b = block_size``.

    Args:
        dim (int): Side length of the matrix. Must be a positive multiple of
            ``block_size``.
        block_size (int): Side length of each Householder block. Must be at least 2.
        device: Device for every generated tensor.
        layernorm (bool): If True, the last entry of each vector is set to minus
            the sum of the others, so every vector sums to zero. Each reflection,
            and each permutation, then maps the all-ones vector to itself. The
            result therefore commutes with LayerNorm's mean subtraction.
        p (float): Constant in the round-count formula. It defaults to 0.99, the
            previously hard-coded value, and must satisfy ``0 < p < 1``.
            Presumably a target success probability — confirm against the
            method's derivation.

    Returns:
        tuple: ``(h, v_list, indices_list)``. ``h`` is the matrix, in torch's
        default floating dtype. ``v_list[k]`` holds the Householder vectors of
        round ``k``, and ``indices_list[k]`` holds that round's column permutation.

    Raises:
        ValueError: If ``block_size < 2``, if ``dim`` is not a positive multiple
            of ``block_size``, or if ``p`` lies outside ``(0, 1)``.
    """
    # block_size 1 would make math.log(dim, 1) divide by zero below.
    if block_size < 2:
        raise ValueError(f"block_size({block_size}) must be at least 2")
    if dim <= 0:
        raise ValueError(f"dim({dim}) must be positive")
    if dim % block_size != 0:
        raise ValueError(f"dim({dim}) must be a multiple of block_size({block_size})")
    if not 0 < p < 1:
        raise ValueError(f"p({p}) must lie strictly between 0 and 1")

    rounds = math.ceil(math.log(dim, block_size) + math.log(math.log(dim ** 2 / (1 - p)), block_size))

    h = None
    indices_list = []
    v_list = []
    for _ in range(rounds):
        # Draw the permutation before the vectors: this keeps the RNG call order
        # (and hence seeded results) unchanged.
        random_indices = torch.randperm(dim, device=device)
        vectors = []
        for _ in range(dim // block_size):
            v = torch.randn(block_size, device=device)
            if layernorm:
                # Zero-sum vector => the reflection fixes the all-ones vector.
                v[-1] = -v[:-1].sum()
            vectors.append(v)
        indices_list.append(random_indices)
        v_list.append(vectors)

        block = construct_block_diagonal_householder(vectors, dim, device)
        h = block if h is None else h @ block
        h = h[:, random_indices]
    return h, v_list, indices_list

def _obfuscate_linear_left(linear, pi, obfus_bias=False):
    """
    linear(x) = x @ W^T + b
    W' = W @ pi^T
    b' = pi @ b
    linear_obf(x) = x @ pi @ W^T + pi @ b
    """
    W_obf = linear.weight.data @ pi.t()
    linear.weight.data = W_obf

    if linear.bias is not None and obfus_bias:
        b_obf = pi @ (linear.bias.data.view(-1, 1))
        b_obf = b_obf.squeeze()
        linear.bias.data = b_obf
    

def _obfuscate_linear_right(linear, pi, obfus_bias=False):
    """
    linear(x) = x @ W^T + b
    W' = pi^T @ W
    b' = b @ pi
    linear_obf(x) = x @ W^T @ pi + b @ pi
    """
    W_obf = pi.t() @ linear.weight.data
    linear.weight.data = W_obf

    if linear.bias is not None and obfus_bias:
        b_obf = (linear.bias.data.view(1, -1)) @ pi
        b_obf = b_obf.squeeze()
        linear.bias.data = b_obf

def _obfuscate_conv1d_left(conv, pi, obfus_bias=False):
    """
    conv1d(x) = x @ W + b
    W' = pi @ W
    b' = pi @ b
    """
    W_obf = pi @ conv.weight.data
    conv.weight.data = W_obf

    if conv.bias is not None and obfus_bias:
        b_obf = pi @ (conv.bias.data.view(-1, 1))
        b_obf = b_obf.squeeze()
        conv.bias.data = b_obf
    

def _obfuscate_conv1d_right(conv, pi, obfus_bias=False):
    """
    conv1d(x) = x @ W + b
    W' = W @ pi
    b' = b @ pi
    """
    W_obf = conv.weight.data @ pi
    conv.weight.data = W_obf

    if conv.bias is not None and obfus_bias:
        b_obf = (conv.bias.data.view(1, -1)) @ pi
        b_obf = b_obf.squeeze()
        conv.bias.data = b_obf

def _obfuscate_model(model, block_size, obf_score):
    """
    Obfuscate a Llama-style decoder model in place with random orthogonal matrices.

    Uses ``model.config.hidden_size``, ``model.device`` and ``model.model.layers``.
    Each layer provides ``self_attn.{q,k,v,o}_proj``,
    ``mlp.{gate,up,down}_proj``, ``input_layernorm`` and
    ``post_attention_layernorm``.

    * Layer 0: only ``mlp.down_proj`` is touched. Its output and bias are rotated
      by ``pi0``.
    * Layers 1..N-1 share one matrix ``pi``. Let ``w`` be the weight of the norm
      feeding a projection. The inputs of q/k/v (behind ``input_layernorm``) and
      of gate/up (behind ``post_attention_layernorm``) absorb
      ``diag(1/w) @ pi^T @ diag(w)``. The outputs of ``o_proj`` and
      ``down_proj`` (including bias) are rotated by ``pi``.
    * With ``obf_score``, the input of ``model.score`` absorbs the same transform
      with respect to ``model.model.norm``.

    Norm weights that are exactly 0 are first replaced by 5e-8, so ``1/w`` is
    finite. Transforms are built in the wider of (norm dtype, ``pi`` dtype) and
    then cast to each target weight's dtype/device.

    Returns:
        dict: ``pi``/``pi0`` (the matrices), ``v_list``/``v_list0`` (Householder
        vectors per round) and ``indices_list``/``indices_list0`` (column
        permutations per round).

    Raises:
        ValueError: If ``obf_score`` is set but the model has no ``score`` layer,
            or if a transform is not finite in a target weight's dtype.
    """
    # Why 5e-8? For float16, float32, float64, and bfloat16, 5e-8 will not be truncated to 0.
    for name, m in model.named_modules():
        weight = getattr(m, "weight", None)
        if "norm" not in name or weight is None:
            continue
        assert torch.tensor(5e-8, dtype=weight.data.dtype).item() != 0.
        m.weight.data = torch.where(weight.data == 0., 5e-8, weight.data)

    def _like_weight(mat, module):
        # Match the target weight's dtype/device so the helpers never mix dtypes;
        # refuse to silently store inf/nan (e.g. huge 1/w entries in fp16).
        out = mat.to(device=module.weight.device, dtype=module.weight.dtype)
        if not torch.isfinite(out).all():
            raise ValueError(f"Obfuscation matrix is not finite in {out.dtype}; use a wider model dtype.")
        return out

    def _norm_conjugate(norm_weight, mat):
        # diag(1/w) @ mat @ diag(w), evaluated in the wider of the two dtypes.
        work_dtype = torch.promote_types(norm_weight.dtype, mat.dtype)
        w = norm_weight.to(work_dtype)
        mat = mat.to(device=w.device, dtype=work_dtype)
        return torch.diag(torch.reciprocal(w)) @ mat @ torch.diag(w)

    hidden_size = model.config.hidden_size
    layers = model.model.layers

    # Layer 0: rotate only the MLP output.
    pi0, v_list0, indices_list0 = create_obfus_matrix(hidden_size, block_size, device=model.device)
    down0 = layers[0].mlp.down_proj
    _obfuscate_linear_right(down0, _like_weight(pi0, down0), True)

    # Layers 1..N-1 share a single matrix pi.
    pi, v_list, indices_list = create_obfus_matrix(hidden_size, block_size, device=model.device)

    for layer in layers[1:]:
        attn, mlp = layer.self_attn, layer.mlp
        attn_in = _norm_conjugate(layer.input_layernorm.weight.data, pi.t())
        mlp_in = _norm_conjugate(layer.post_attention_layernorm.weight.data, pi.t())

        # Attention: inputs absorb the norm-conjugated pi^T; output rotated by pi.
        for proj in (attn.q_proj, attn.k_proj, attn.v_proj):
            _obfuscate_linear_left(proj, _like_weight(attn_in, proj), False)
        _obfuscate_linear_right(attn.o_proj, _like_weight(pi, attn.o_proj), True)

        # MLP: same pattern behind post_attention_layernorm.
        for proj in (mlp.gate_proj, mlp.up_proj):
            _obfuscate_linear_left(proj, _like_weight(mlp_in, proj), False)
        _obfuscate_linear_right(mlp.down_proj, _like_weight(pi, mlp.down_proj), True)

    if obf_score:
        score = getattr(model, "score", None)
        if score is None:
            raise ValueError("The model does not have a score layer.")
        score_in = _norm_conjugate(model.model.norm.weight.data, pi.t())
        _obfuscate_linear_left(score, _like_weight(score_in, score), False)

    return {
            "pi": pi,
            "pi0": pi0,
            "v_list": v_list,
            "v_list0": v_list0,
            "indices_list": indices_list,
            "indices_list0": indices_list0
        }

def _obfuscate_model_gpt2(model, block_size, obf_score):
    """
    Obfuscate a GPT-2 model (``model.transformer``) in place.

    * Layer 0: only the output and bias of ``mlp.c_proj`` are rotated by ``pi0``.
    * Layers 1..N-1 share one matrix ``pi``, generated with ``layernorm=True`` so
      that it fixes the all-ones vector. Let ``w`` be the preceding LayerNorm
      weight.
      - The inputs of ``attn.c_attn`` (plus ``attn.q_attn`` when
        ``attn.is_cross_attention``) and of ``mlp.c_fc`` absorb
        ``diag(1/w) @ pi^T @ diag(w)``.
      - The LayerNorm biases become ``b @ diag(1/w) @ pi @ diag(w)``.
      - The outputs of ``attn.c_proj`` and ``mlp.c_proj`` (including bias) are
        rotated by ``pi``.
    * ``transformer.ln_f``'s bias is always transformed the same way. With
      ``obf_score``, the input of ``model.score`` is transformed as well.

    LayerNorm weights that are exactly 0 are first replaced by 5e-8, so ``1/w``
    is finite. Transforms are built in the wider of the involved dtypes and then
    cast to each target parameter's dtype/device.

    Returns:
        dict: ``pi``/``pi0``, ``v_list``/``v_list0`` and
        ``indices_list``/``indices_list0``, as produced by ``create_obfus_matrix``.

    Raises:
        ValueError: If ``obf_score`` is set but the model has no ``score`` layer,
            or if a transform is not finite in a target weight's dtype.
    """
    # Why 5e-8? For float16, float32, float64, and bfloat16, 5e-8 will not be truncated to 0.
    for name, m in model.named_modules():
        weight = getattr(m, "weight", None)
        if "ln_" not in name or weight is None:
            continue
        assert torch.tensor(5e-8, dtype=weight.data.dtype).item() != 0.
        m.weight.data = torch.where(weight.data == 0., 5e-8, weight.data)

    def _like_weight(mat, module):
        # Match the target weight's dtype/device so the helpers never mix dtypes;
        # refuse to silently store inf/nan (e.g. huge 1/w entries in fp16).
        out = mat.to(device=module.weight.device, dtype=module.weight.dtype)
        if not torch.isfinite(out).all():
            raise ValueError(f"Obfuscation matrix is not finite in {out.dtype}; use a wider model dtype.")
        return out

    def _norm_conjugate(norm_weight, mat):
        # diag(1/w) @ mat @ diag(w), evaluated in the wider of the two dtypes.
        work_dtype = torch.promote_types(norm_weight.dtype, mat.dtype)
        w = norm_weight.to(work_dtype)
        mat = mat.to(device=w.device, dtype=work_dtype)
        return torch.diag(torch.reciprocal(w)) @ mat @ torch.diag(w)

    def _update_ln_bias(ln, ln_weights, mat):
        # b' = b @ diag(1/w) @ mat @ diag(w), applied left to right, if ln has a bias.
        bias = getattr(ln, "bias", None)
        if bias is None:
            return
        work_dtype = torch.promote_types(torch.promote_types(bias.dtype, ln_weights.dtype), mat.dtype)
        w = ln_weights.to(work_dtype)
        row = bias.data.to(work_dtype).view(1, -1)
        mat = mat.to(device=w.device, dtype=work_dtype)
        new_bias = row @ torch.diag(torch.reciprocal(w)) @ mat @ torch.diag(w)
        ln.bias.data = new_bias.view(-1).to(bias.dtype)

    hidden_size = model.config.hidden_size
    blocks = model.transformer.h

    # Layer 0: rotate only the MLP output.
    # NOTE(review): pi0 is generated without layernorm=True, unlike pi. Confirm
    # whether this layer-0 rotation also needs to fix the all-ones vector.
    pi0, v_list0, indices_list0 = create_obfus_matrix(hidden_size, block_size, device=model.device)
    c_proj0 = blocks[0].mlp.c_proj
    _obfuscate_conv1d_right(c_proj0, _like_weight(pi0, c_proj0), True)

    # Layers 1..N-1 share a single LayerNorm-compatible matrix pi.
    pi, v_list, indices_list = create_obfus_matrix(hidden_size, block_size, device=model.device, layernorm=True)

    for block in blocks[1:]:
        attn, mlp = block.attn, block.mlp
        ln_1_weights = block.ln_1.weight.data
        ln_2_weights = block.ln_2.weight.data

        _update_ln_bias(block.ln_1, ln_1_weights, pi)
        _update_ln_bias(block.ln_2, ln_2_weights, pi)

        attn_in = _norm_conjugate(ln_1_weights, pi.t())
        mlp_in = _norm_conjugate(ln_2_weights, pi.t())

        # Attention. NOTE(review): only `block.attn` is handled. Any separate
        # cross-attention submodule on the block is left untouched — confirm.
        _obfuscate_conv1d_left(attn.c_attn, _like_weight(attn_in, attn.c_attn), False)
        if attn.is_cross_attention:
            _obfuscate_conv1d_left(attn.q_attn, _like_weight(attn_in, attn.q_attn), False)
        _obfuscate_conv1d_right(attn.c_proj, _like_weight(pi, attn.c_proj), True)

        # MLP.
        _obfuscate_conv1d_left(mlp.c_fc, _like_weight(mlp_in, mlp.c_fc), False)
        _obfuscate_conv1d_right(mlp.c_proj, _like_weight(pi, mlp.c_proj), True)

    ln_f = model.transformer.ln_f
    ln_f_weights = ln_f.weight.data
    _update_ln_bias(ln_f, ln_f_weights, pi)

    if obf_score:
        score = getattr(model, "score", None)
        if score is None:
            raise ValueError("The model does not have a score layer.")
        score_in = _norm_conjugate(ln_f_weights, pi.t())
        _obfuscate_linear_left(score, _like_weight(score_in, score), False)

    return {
            "pi": pi,
            "pi0": pi0,
            "v_list": v_list,
            "v_list0": v_list0,
            "indices_list": indices_list,
            "indices_list0": indices_list0
        }

def _obfuscate_model_gemma3(model, block_size, obf_score):
    """
    Obfuscate a Gemma3 decoder model in place.

    Gemma3's RMSNorm scales by ``(1 + weight)``: its forward adds 1.0 to the
    weight first. A weight of exactly -1 would therefore give a zero scale and a
    singular ``diag(1 + weight)``. Such entries are replaced by the adjacent
    representable value just below -1. Only float32 and float64 are supported.

    float16 and bfloat16 are rejected because their precision near 1.0 is too
    poor. The recommended alternative is to change Gemma3's RMSNorm forward so it
    does not multiply by ``(1 + weight)``. Instead, store ``weight + 1``
    directly, replace any exact zeros with the representable value closest to
    zero (floats are much denser near zero), and then multiply.

    Steps performed:
    * Attach ``model.o_proj_branchs`` and ``model.down_proj_branchs``. These are
      ``nn.ModuleList`` copies of each layer's ``o_proj``/``down_proj`` (index 0
      of ``o_proj_branchs`` is None), made before obfuscation. Each copy's output
      is multiplied by its own random orthogonal ``Q^T``.
    * Layer 0: the ``(1 + w)`` scale of ``post_feedforward_layernorm`` is folded
      into ``mlp.down_proj`` together with a rotation by ``pi0``. That norm
      weight is then set to 0, giving scale 1.
    * Layers 1..N-1 share ``pi``.
      - The inputs of q/k/v (behind ``input_layernorm``) and of gate/up (behind
        ``pre_feedforward_layernorm``) absorb ``diag(1/d) @ pi^T @ diag(d)``
        with ``d = 1 + w``.
      - The outputs of ``o_proj``/``down_proj`` absorb ``diag(1 + w) @ pi`` from
        the following post-norm, whose weight is then zeroed.
    * With ``obf_score``, the input of ``model.score`` is transformed with
      respect to ``model.model.norm``.

    Returns:
        dict: ``pi``/``pi0``, ``v_list``/``v_list0`` and
        ``indices_list``/``indices_list0``, as produced by ``create_obfus_matrix``.

    Raises:
        ValueError: For float16/bfloat16 norm weights, if ``obf_score`` is set
            but the model has no ``score`` layer, or if a transform is not finite
            in a target weight's dtype.
    """
    # Replace RMSNorm weights equal to -1 (zero effective scale) with the
    # neighbouring representable value just below -1.
    for name, m in model.named_modules():
        weight = getattr(m, "weight", None)
        if "norm" not in name or weight is None:
            continue
        if weight.data.dtype == torch.float32:
            assert (1. + torch.tensor(-1.00000011920928955078125, dtype=torch.float32)).item() != 0.
            m.weight.data = torch.where(weight.data == -1., -1.00000011920928955078125, weight.data)
        elif weight.data.dtype == torch.float64:
            assert (1. + torch.tensor(-1.0000000000000002220446049250313, dtype=torch.float64)).item() != 0.
            m.weight.data = torch.where(weight.data == -1, -1.0000000000000002220446049250313, weight.data)
        elif weight.data.dtype == torch.float16:
            raise ValueError("The precision of float16 for numbers near 1.0 is too poor; please use another method.")
        elif weight.data.dtype == torch.bfloat16:
            raise ValueError("The precision of bfloat16 for numbers near 1.0 is too poor; please use another method.")

    def _like_weight(mat, module):
        # Match the target weight's dtype/device so the helpers never mix dtypes;
        # refuse to silently store inf/nan.
        out = mat.to(device=module.weight.device, dtype=module.weight.dtype)
        if not torch.isfinite(out).all():
            raise ValueError(f"Obfuscation matrix is not finite in {out.dtype}; use a wider model dtype.")
        return out

    def _norm_conjugate(scale, mat):
        # diag(1/d) @ mat @ diag(d), evaluated in the wider of the two dtypes.
        work_dtype = torch.promote_types(scale.dtype, mat.dtype)
        d = scale.to(work_dtype)
        mat = mat.to(device=d.device, dtype=work_dtype)
        return torch.diag(torch.reciprocal(d)) @ mat @ torch.diag(d)

    def _scale_rows(norm_weight, mat):
        # diag(1 + w) @ mat: folds a post-norm's (1 + w) scale into an output rotation.
        scale = 1.0 + norm_weight
        work_dtype = torch.promote_types(scale.dtype, mat.dtype)
        scale = scale.to(work_dtype)
        return torch.diag(scale) @ mat.to(device=scale.device, dtype=work_dtype)

    def obf_linear(linear):
        """Return a new nn.Linear computing ``linear(x) @ Q^T`` for a random orthogonal ``Q``."""
        weight = linear.weight.data
        n = linear.out_features
        device, dtype = weight.device, weight.dtype
        Q = torch.eye(n, device=device, dtype=dtype)
        for _ in range(math.ceil(math.log(n, 2) + math.log(100, 2))):
            Q0 = torch.eye(n, device=device, dtype=dtype)
            # 2x2 Householder reflections on consecutive index pairs; with odd n
            # the last index is left fixed.
            for i in range(0, n - 1, 2):
                v = torch.randn(2, device=device, dtype=dtype)
                v = v / torch.norm(v, p=2)
                Q0[i:i + 2, i:i + 2] = torch.eye(2, device=device, dtype=dtype) - 2 * torch.outer(v, v)

            Q0 = Q0[:, torch.randperm(n, device=device)]
            Q0 = Q0[torch.randperm(n, device=device)]
            Q = Q0 @ Q

        new_linear = nn.Linear(linear.in_features, n, bias=linear.bias is not None)
        new_linear.weight.data = Q @ weight
        if linear.bias is not None:
            # The weight yields x @ W^T @ Q^T, so the bias must become b @ Q^T == Q @ b.
            bias = linear.bias.data
            new_linear.bias.data = (Q @ bias.to(dtype)).to(bias.dtype)
        return new_linear

    def init_branchs(m):
        m.o_proj_branchs = nn.ModuleList([
            None if i == 0 else obf_linear(m.model.layers[i].self_attn.o_proj)
            for i in range(m.model.config.num_hidden_layers)
        ])

        m.down_proj_branchs = nn.ModuleList([
            obf_linear(m.model.layers[i].mlp.down_proj)
            for i in range(m.model.config.num_hidden_layers)
        ])

    init_branchs(model)

    hidden_size = model.config.hidden_size
    layers = model.model.layers

    # NOTE(review): folding a *post*-norm's diag(1 + w) into the preceding
    # projection and then zeroing w is only exact if that scale commutes with the
    # normalization. Assume RMSNorm computes x / rms(x) * (1 + w). The new output
    # is then m D pi / rms(m D), while the original rotated output is
    # m D pi / rms(m). They differ by the per-token factor rms(m) / rms(m D)
    # unless every w is 0. Confirm numerically.

    # Layer 0: fold post_feedforward_layernorm's scale into down_proj, rotate by
    # pi0, then reset that norm's weight to 0 (scale 1 + 0 = 1).
    pi0, v_list0, indices_list0 = create_obfus_matrix(hidden_size, block_size, device=model.device)
    layer0 = layers[0]
    post_ffn_norm0 = layer0.post_feedforward_layernorm
    down0 = layer0.mlp.down_proj
    _obfuscate_linear_right(down0, _like_weight(_scale_rows(post_ffn_norm0.weight.data, pi0), down0), True)
    post_ffn_norm0.weight.data = torch.zeros_like(post_ffn_norm0.weight.data)

    # Layers 1..N-1 share a single matrix pi.
    pi, v_list, indices_list = create_obfus_matrix(hidden_size, block_size, device=model.device)

    for layer in layers[1:]:
        attn, mlp = layer.self_attn, layer.mlp
        post_attn_norm = layer.post_attention_layernorm
        post_ffn_norm = layer.post_feedforward_layernorm

        # All scales are read before any norm weight is zeroed below.
        attn_in = _norm_conjugate(layer.input_layernorm.weight.data + 1.0, pi.t())
        attn_out = _scale_rows(post_attn_norm.weight.data, pi)
        mlp_in = _norm_conjugate(layer.pre_feedforward_layernorm.weight.data + 1.0, pi.t())
        mlp_out = _scale_rows(post_ffn_norm.weight.data, pi)

        for proj in (attn.q_proj, attn.k_proj, attn.v_proj):
            _obfuscate_linear_left(proj, _like_weight(attn_in, proj), False)
        _obfuscate_linear_right(attn.o_proj, _like_weight(attn_out, attn.o_proj), True)
        post_attn_norm.weight.data = torch.zeros_like(post_attn_norm.weight.data)

        for proj in (mlp.gate_proj, mlp.up_proj):
            _obfuscate_linear_left(proj, _like_weight(mlp_in, proj), False)
        _obfuscate_linear_right(mlp.down_proj, _like_weight(mlp_out, mlp.down_proj), True)
        post_ffn_norm.weight.data = torch.zeros_like(post_ffn_norm.weight.data)

    if obf_score:
        score = getattr(model, "score", None)
        if score is None:
            raise ValueError("The model does not have a score layer.")
        score_in = _norm_conjugate(model.model.norm.weight.data + 1.0, pi.t())
        _obfuscate_linear_left(score, _like_weight(score_in, score), False)

    return {
            "pi": pi,
            "pi0": pi0,
            "v_list": v_list,
            "v_list0": v_list0,
            "indices_list": indices_list,
            "indices_list0": indices_list0
        }

def obfuscate_model(model, block_size=2, obf_score=False):
    """
    Obfuscate ``model`` in place using the routine that matches its architecture.

    GPT-2 and Gemma3 models, detected by ``isinstance`` against their
    ``PreTrainedModel`` base classes and checked in that order, get dedicated
    routines. Every other model goes through the generic Llama-style routine.

    Args:
        model: The model to modify in place.
        block_size (int): Householder block size used for the obfuscation matrices.
        obf_score (bool): Also obfuscate the input side of ``model.score``.

    Returns:
        dict: Whatever the selected routine returns.
    """
    handlers = (
        (GPT2PreTrainedModel, _obfuscate_model_gpt2),
        (Gemma3PreTrainedModel, _obfuscate_model_gemma3),
    )
    for base_cls, handler in handlers:
        if isinstance(model, base_cls):
            return handler(model, block_size, obf_score)
    return _obfuscate_model(model, block_size, obf_score)