import torch
from datasets import load_dataset
import argparse
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from transformers.models import Gemma3PreTrainedModel, Qwen3PreTrainedModel, GPT2PreTrainedModel, LlamaPreTrainedModel
from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
from transformers.models.llama.modeling_llama import LlamaRMSNorm
import math
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np


def _obfuscate_linear_left(linear, pi, obfus_bias=False):
    W_obf = linear.weight.data @ pi.t()
    linear.weight.data = W_obf

    if linear.bias is not None and obfus_bias:
        b_obf = pi @ (linear.bias.data.view(-1, 1))
        b_obf = b_obf.squeeze()
        linear.bias.data = b_obf
    

def _obfuscate_linear_right(linear, pi, obfus_bias=False):
    W_obf = pi.t() @ linear.weight.data
    linear.weight.data = W_obf

    if linear.bias is not None and obfus_bias:
        b_obf = (linear.bias.data.view(1, -1)) @ pi
        b_obf = b_obf.squeeze()
        linear.bias.data = b_obf
        
def _obfuscate_conv1d_left(conv, pi, obfus_bias=False):
    W_obf = pi @ conv.weight.data
    conv.weight.data = W_obf

    if conv.bias is not None and obfus_bias:
        b_obf = pi @ (conv.bias.data.view(-1, 1))
        b_obf = b_obf.squeeze()
        conv.bias.data = b_obf
    

def _obfuscate_conv1d_right(conv, pi, obfus_bias=False):
    W_obf = conv.weight.data @ pi
    conv.weight.data = W_obf

    if conv.bias is not None and obfus_bias:
        b_obf = (conv.bias.data.view(1, -1)) @ pi
        b_obf = b_obf.squeeze()
        conv.bias.data = b_obf

def _obfuscate_gpt2_attn_left(attn, pi):
    if attn.is_cross_attention:
        w_q = attn.q_attn.weight.data
        w_k, w_v = attn.c_attn.weight.data.chunk(2, dim=1)
    else:
        w_q, w_k, w_v = attn.c_attn.weight.data.chunk(3, dim=1)
    
    ob_wq = pi @ w_q
    ob_wk = pi @ w_k
    ob_wv = pi @ w_v

    if attn.is_cross_attention:
        attn.q_attn.weight.data = ob_wq
        attn.c_attn.weight.data = torch.cat([ob_wk, ob_wv], dim=1)
    else:
        attn.c_attn.weight.data = torch.cat([ob_wq, ob_wk, ob_wv], dim=1)

def _obfuscate_model(model):
    """Apply one random hidden-dimension permutation to every decoder layer.

    A single permutation ``perm`` of ``range(model.config.hidden_size)`` is
    drawn and turned into the matrix ``pi = eye[perm]``. For each layer in
    ``model.model.layers``:

    * ``self_attn.{q,k,v}_proj`` and ``mlp.{gate,up}_proj`` get their input
      columns multiplied by ``pi`` (bias untouched);
    * ``self_attn.o_proj`` and ``mlp.down_proj`` get their output rows and
      bias reordered by ``pi.T``.

    Only these projections are modified here.

    NOTE(review): norm weights, embeddings and output heads are not permuted
    in this function. If functional equivalence of the obfuscated model is
    required, confirm they are handled elsewhere.

    Returns:
        A one-element list holding the permutation index tensor.
    """
    dim = model.config.hidden_size
    perm = torch.randperm(dim)
    pi = torch.eye(dim)[perm]
    pi_t = pi.t()

    for layer in model.model.layers:
        attn, mlp = layer.self_attn, layer.mlp

        for proj in (attn.q_proj, attn.k_proj, attn.v_proj):
            _obfuscate_linear_left(proj, pi_t, False)
        _obfuscate_linear_right(attn.o_proj, pi, True)

        for proj in (mlp.gate_proj, mlp.up_proj):
            _obfuscate_linear_left(proj, pi_t, False)
        _obfuscate_linear_right(mlp.down_proj, pi, True)

    return [perm]

def _obfuscate_model_gpt2(model: GPT2PreTrainedModel):
    """Apply one random hidden-dimension permutation to every GPT-2 block.

    A permutation ``perm`` of ``range(model.config.hidden_size)`` is drawn and
    turned into ``pi = eye[perm]``. For each block in ``model.transformer.h``:

    * the attention q/k/v weights and ``mlp.c_fc`` are left-multiplied by
      ``pi.T`` (bias untouched);
    * ``attn.c_proj`` and ``mlp.c_proj`` are right-multiplied by ``pi``,
      bias included.

    Returns:
        A one-element list holding the permutation index tensor.
    """
    dim = model.config.hidden_size
    perm = torch.randperm(dim)
    pi = torch.eye(dim)[perm]

    for block in model.transformer.h:
        _obfuscate_gpt2_attn_left(block.attn, pi.t())
        _obfuscate_conv1d_right(block.attn.c_proj, pi, True)
        _obfuscate_conv1d_left(block.mlp.c_fc, pi.t(), False)
        _obfuscate_conv1d_right(block.mlp.c_proj, pi, True)

    return [perm]


def obfuscate_model(model):
    """Permute ``model``'s hidden dimension in place; return the permutation list.

    GPT-2 models are routed to the ``model.transformer.h`` / ``Conv1D`` variant.
    Every other model is assumed to expose ``model.model.layers`` with
    ``self_attn.{q,k,v,o}_proj`` and ``mlp.{gate,up,down}_proj`` modules.
    """
    is_gpt2 = isinstance(model, GPT2PreTrainedModel)
    obfuscator = _obfuscate_model_gpt2 if is_gpt2 else _obfuscate_model
    return obfuscator(model)

def row_restore_perm(pre_model_mat, model_mat, threshold=0.0):
    """Estimate and undo a row permutation of ``model_mat``.

    Each row of ``model_mat`` is matched to the row of ``pre_model_mat`` with
    the highest cosine similarity.

    Args:
        pre_model_mat: reference 2-D tensor whose rows are the match targets.
        model_mat: 2-D tensor whose rows are to be matched (same shape
            assumed).
        threshold: minimum cosine similarity for a match to be accepted.

    Returns:
        ``(perm, success, restored)``:

        * ``perm[i]`` is the reference row index matched to ``model_mat``
          row ``i`` (NumPy int array).
        * ``success`` lists the accepted target indices in source-row order.
        * ``restored`` is a tensor on ``model_mat``'s device and dtype. Each
          accepted row is placed at its matched index; every index not
          filled that way is copied from ``pre_model_mat``. If two source
          rows match the same target, the later one wins.
    """
    def _to_numpy(t):
        # NumPy has no bfloat16, so upcast it (exactly) to float32.
        # detach() allows tensors that require grad.
        t = t.detach().cpu()
        if t.dtype == torch.bfloat16:
            t = t.float()
        return t.numpy()

    model_np = _to_numpy(model_mat)
    pre_np = _to_numpy(pre_model_mat)
    similarity = cosine_similarity(model_np, pre_np)
    perm = np.argmax(similarity, axis=1)

    restored = np.empty_like(model_np)
    success = []
    for src in range(model_np.shape[0]):
        dst = perm[src]
        if similarity[src, dst] >= threshold:
            restored[dst] = model_np[src]
            success.append(dst)

    # A set makes this pass O(n) instead of O(n^2) list membership.
    matched = {int(d) for d in success}
    for idx in range(restored.shape[0]):
        if idx not in matched:
            restored[idx] = pre_np[idx]

    restored_t = torch.from_numpy(restored).to(device=model_mat.device, dtype=model_mat.dtype)
    return perm, success, restored_t

def col_restore_perm(pre_model_mat, model_mat, threshold=0.0, perm = None):
    """Estimate (or use) and undo a column permutation of ``model_mat``.

    Each column of ``model_mat`` is matched to the column of
    ``pre_model_mat`` with the highest cosine similarity, unless ``perm`` is
    supplied, in which case it is used as-is.

    Args:
        pre_model_mat: reference 2-D tensor whose columns are the targets.
        model_mat: 2-D tensor whose columns are to be matched.
        threshold: minimum cosine similarity for a match to be accepted.
        perm: optional precomputed mapping; ``perm[j]`` is the reference
            column for ``model_mat`` column ``j``.

    Returns:
        ``(perm, success, restored)``:

        * ``perm`` is the mapping used.
        * ``success`` lists the accepted target column indices in order.
        * ``restored`` is a tensor on ``model_mat``'s device and dtype. Each
          accepted column is placed at its matched index; every other column
          is copied from ``pre_model_mat``. If two source columns match the
          same target, the later one wins.
    """
    def _to_numpy(t):
        # NumPy has no bfloat16, so upcast it (exactly) to float32.
        # detach() allows tensors that require grad.
        t = t.detach().cpu()
        if t.dtype == torch.bfloat16:
            t = t.float()
        return t.numpy()

    model_np = _to_numpy(model_mat)
    pre_np = _to_numpy(pre_model_mat)
    similarity = cosine_similarity(model_np.T, pre_np.T)
    if perm is None:
        perm = np.argmax(similarity, axis=1)

    restored = np.empty_like(model_np)
    success = []
    for src in range(model_np.shape[1]):
        dst = perm[src]
        if similarity[src, dst] >= threshold:
            restored[:, dst] = model_np[:, src]
            success.append(dst)

    # int() keys make membership work for NumPy, Python and 0-d tensor indices.
    # The set keeps this pass O(n).
    matched = {int(d) for d in success}
    for idx in range(restored.shape[1]):
        if idx not in matched:
            restored[:, idx] = pre_np[:, idx]

    restored_t = torch.from_numpy(restored).to(device=model_mat.device, dtype=model_mat.dtype)
    return perm, success, restored_t
    

def fix_factor(num, mini=1.0, max=6.0):
    """Clamp ``num`` to the closed range ``[mini, max]``.

    A NaN input returns the literal ``1.0``, regardless of ``mini`` and
    ``max``. The ``max`` parameter shadows the builtin inside this function;
    its name is kept so keyword callers keep working.
    """
    if math.isnan(num):
        return 1.0
    if num < mini:
        return mini
    if num > max:
        return max
    return num

def _attack_restore_gpt2_block(block, pre_block):
    """Undo one GPT-2 block's hidden-dimension permutation in place.

    The row permutation is estimated from the query weight with
    ``row_restore_perm``; its argsort is then applied to every tensor the GPT-2
    obfuscation reordered.
    """
    attn, mlp = block.attn, block.mlp

    if attn.is_cross_attention:
        ref_q = pre_block.attn.q_attn.weight.data
        obf_q = attn.q_attn.weight.data
        obf_k, obf_v = attn.c_attn.weight.data.chunk(2, dim=1)
    else:
        ref_q = pre_block.attn.c_attn.weight.data.chunk(3, dim=1)[0]
        obf_q, obf_k, obf_v = attn.c_attn.weight.data.chunk(3, dim=1)

    est_perm, _, fixed_q = row_restore_perm(ref_q, obf_q)
    undo = torch.argsort(torch.tensor(est_perm))

    # The query weight uses row_restore_perm's result, where rows below the
    # threshold fall back to the reference. k/v use the pure inverse.
    fixed_k, fixed_v = obf_k[undo], obf_v[undo]
    if attn.is_cross_attention:
        attn.q_attn.weight.data = fixed_q
        attn.c_attn.weight.data = torch.cat([fixed_k, fixed_v], dim=1)
    else:
        attn.c_attn.weight.data = torch.cat([fixed_q, fixed_k, fixed_v], dim=1)

    # Output-side Conv1D layers were reordered along columns (and bias).
    attn.c_proj.weight.data = attn.c_proj.weight.data[:, undo]
    if attn.c_proj.bias is not None:
        attn.c_proj.bias.data = attn.c_proj.bias.data[undo]

    mlp.c_fc.weight.data = mlp.c_fc.weight.data[undo]
    mlp.c_proj.weight.data = mlp.c_proj.weight.data[:, undo]
    if mlp.c_proj.bias is not None:
        mlp.c_proj.bias.data = mlp.c_proj.bias.data[undo]


def _attack_restore_decoder_layer(layer, pre_layer):
    """Undo one decoder layer's hidden-dimension permutation in place.

    The column permutation is estimated from ``self_attn.q_proj`` with
    ``col_restore_perm``. Its argsort reorders the input columns of q/k/v/gate/up
    and the output rows (and bias) of o_proj/down_proj.
    """
    attn, mlp = layer.self_attn, layer.mlp
    in_projs = (attn.q_proj, attn.k_proj, attn.v_proj, mlp.gate_proj, mlp.up_proj)
    out_projs = (attn.o_proj, mlp.down_proj)

    # Snapshot every weight before any write.
    in_weights = [proj.weight.data for proj in in_projs]
    out_weights = [proj.weight.data for proj in out_projs]

    ref_q = pre_layer.self_attn.q_proj.weight.data
    est_perm, _, _ = col_restore_perm(ref_q, in_weights[0])
    undo = torch.argsort(torch.tensor(est_perm))

    for proj, weight in zip(in_projs, in_weights):
        proj.weight.data = weight[:, undo]
    for proj, weight in zip(out_projs, out_weights):
        proj.weight.data = weight[undo, :]
        if proj.bias is not None:
            proj.bias.data = proj.bias.data[undo]


def attack(model, pre_model):
    """Recover ``model``'s original weight layout using ``pre_model`` as reference.

    For each layer pair, the hidden-dimension permutation is estimated by
    cosine-matching the query-projection weights of ``model`` against those of
    ``pre_model``. The inverse (argsort of the estimate) is applied to the
    layer in place. Each layer is estimated independently.

    NOTE(review): if the estimated mapping is not a bijection (two entries
    matched to the same reference index), argsort does not invert the true
    permutation and the layer is silently mis-restored.

    Returns:
        ``model`` itself, modified in place.
    """
    if isinstance(model, GPT2PreTrainedModel):
        layer_pairs = zip(model.transformer.h, pre_model.transformer.h)
        restore_layer = _attack_restore_gpt2_block
    else:
        layer_pairs = zip(model.model.layers, pre_model.model.layers)
        restore_layer = _attack_restore_decoder_layer

    for layer, pre_layer in layer_pairs:
        restore_layer(layer, pre_layer)

    return model
