import torch
from datasets import load_dataset
import argparse
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers.models import Gemma3PreTrainedModel, Qwen3PreTrainedModel, GPT2PreTrainedModel, LlamaPreTrainedModel
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from math import sqrt
import math


def _obfuscate_layer(layer):
    """Rescale ``layer.weight`` by a single random scalar.

    The scalar is ``1 + 5 * u`` with ``u`` taken from one ``torch.rand(1)``
    draw.  The product is a new tensor that replaces ``layer.weight.data``;
    nothing is returned.
    """
    factor = 5 * torch.rand(1).item() + 1
    layer.weight.data = factor * layer.weight.data

def _obfuscate_layer_attn(attn):
    """Scale the query, key and value weight slices of ``attn`` independently.

    When ``attn.is_cross_attention`` is true the query weight is read from
    ``attn.q_attn`` and ``attn.c_attn.weight.data`` is split into two chunks
    (key, value) along dim 1; otherwise ``attn.c_attn.weight.data`` is split
    into three chunks (query, key, value) along dim 1.  Each slice gets its
    own factor ``1 + 5 * torch.rand(1).item()``, and the scaled slices are
    written back in the same layout.
    """
    cross = attn.is_cross_attention
    if cross:
        query = attn.q_attn.weight.data
        key, value = attn.c_attn.weight.data.chunk(2, dim=1)
    else:
        query, key, value = attn.c_attn.weight.data.chunk(3, dim=1)

    # One factor per slice, drawn in q, k, v order; the multiply is in place
    # on the chunk views.
    for part in (query, key, value):
        part.mul_(1 + 5 * torch.rand(1).item())

    if cross:
        attn.q_attn.weight.data = query
        attn.c_attn.weight.data = torch.cat([key, value], dim=1)
    else:
        attn.c_attn.weight.data = torch.cat([query, key, value], dim=1)

def _obfuscate_model(model):
    """Obfuscate every decoder layer under ``model.model.layers`` in place.

    For each layer, ``_obfuscate_layer`` is applied to the attention
    projections (``q_proj``, ``k_proj``, ``v_proj``, ``o_proj``) and then to
    the MLP projections (``gate_proj``, ``up_proj``, ``down_proj``), in that
    order.  Nothing is returned.
    """
    for decoder_layer in model.model.layers:
        attn = decoder_layer.self_attn
        mlp = decoder_layer.mlp

        # The order is kept fixed so the sequence of random draws per layer
        # is q, k, v, o, gate, up, down.
        projections = (
            attn.q_proj,
            attn.k_proj,
            attn.v_proj,
            attn.o_proj,
            mlp.gate_proj,
            mlp.up_proj,
            mlp.down_proj,
        )
        for projection in projections:
            _obfuscate_layer(projection)
            

def _obfuscate_model_gpt2(model: GPT2PreTrainedModel):
    """Obfuscate every block under ``model.transformer.h`` in place.

    For each block, the attention q/k/v weights are handled by
    ``_obfuscate_layer_attn``; then ``_obfuscate_layer`` is applied to
    ``attn.c_proj``, ``mlp.c_fc`` and ``mlp.c_proj``, in that order.
    Nothing is returned.
    """
    for block in model.transformer.h:
        attn = block.attn
        mlp = block.mlp

        # Attention: q/k/v slices first, then the output projection.
        _obfuscate_layer_attn(attn)
        _obfuscate_layer(attn.c_proj)

        # MLP projections.
        _obfuscate_layer(mlp.c_fc)
        _obfuscate_layer(mlp.c_proj)
        

def obfuscate_model(model):
    """Obfuscate ``model`` in place using the routine matching its type.

    Instances of ``GPT2PreTrainedModel`` go through ``_obfuscate_model_gpt2``;
    every other model goes through ``_obfuscate_model``.
    """
    is_gpt2 = isinstance(model, GPT2PreTrainedModel)
    handler = _obfuscate_model_gpt2 if is_gpt2 else _obfuscate_model
    handler(model)


def row_restore_perm(pre_model_mat, model_mat, threshold=0.0):
    """Reorder rows of ``model_mat`` to line up with ``pre_model_mat``.

    Every row of ``model_mat`` is matched to the row of ``pre_model_mat``
    with which it has the highest cosine similarity.  A row whose best
    similarity is at least ``threshold`` is copied to the matched position of
    the result; if several rows match the same position, the one with the
    highest row index wins.  Positions that receive no row are filled with
    the row of ``pre_model_mat`` at the same index.

    Args:
        pre_model_mat: 2-D reference tensor.
        model_mat: 2-D tensor whose rows are to be re-ordered.
        threshold: minimum cosine similarity for a match to be accepted.

    Returns:
        A tuple ``(perm, success, restored_matrix)`` where ``perm[i]`` is the
        index of the reference row best matching row ``i`` of ``model_mat``,
        ``success`` lists the target indices of accepted matches in row
        order, and ``restored_matrix`` is a tensor on ``model_mat.device``.
    """
    model_mat_cpu = model_mat.cpu().numpy()
    pre_model_mat_cpu = pre_model_mat.cpu().numpy()
    similarity = cosine_similarity(model_mat_cpu, pre_model_mat_cpu)
    perm = np.argmax(similarity, axis=1)

    restored = np.empty_like(model_mat_cpu)
    # Boolean mask of target rows already written: O(1) bookkeeping instead
    # of scanning the ``success`` list for every row.
    filled = np.zeros(len(restored), dtype=bool)
    success = []
    # Sequential loop on purpose: with duplicate targets the last matching
    # row must win, which fancy-index assignment does not guarantee.
    for i, target in enumerate(perm):
        if similarity[i, target] >= threshold:
            restored[target] = model_mat_cpu[i]
            filled[target] = True
            success.append(target)

    # Rows nobody mapped onto fall back to the reference matrix.
    missing = np.flatnonzero(~filled)
    restored[missing] = pre_model_mat_cpu[missing]

    restored_matrix = torch.from_numpy(restored).to(model_mat.device)
    return perm, success, restored_matrix

def col_restore_perm(pre_model_mat, model_mat, threshold=0.0, perm = None):
    """Reorder columns of ``model_mat`` to line up with ``pre_model_mat``.

    Column-wise counterpart of row matching: cosine similarity is computed
    between the columns of ``model_mat`` and the columns of
    ``pre_model_mat``.  Column ``i`` is copied to position ``perm[i]`` of the
    result when ``similarity[i, perm[i]] >= threshold``; if several columns
    map to the same position, the one with the highest column index wins.
    Positions that receive no column are filled with the column of
    ``pre_model_mat`` at the same index.

    Args:
        pre_model_mat: 2-D reference tensor.
        model_mat: 2-D tensor whose columns are to be re-ordered.
        threshold: minimum cosine similarity for a match to be accepted.
        perm: optional precomputed mapping from column index of ``model_mat``
            to target column index.  When ``None`` it is the arg-max of the
            similarity matrix along axis 1.

    Returns:
        A tuple ``(perm, success, restored_matrix)`` where ``success`` lists
        the target indices of accepted matches in column order and
        ``restored_matrix`` is a tensor on ``model_mat.device``.
    """
    model_mat_cpu = model_mat.cpu().numpy()
    pre_model_mat_cpu = pre_model_mat.cpu().numpy()
    similarity = cosine_similarity(model_mat_cpu.T, pre_model_mat_cpu.T)
    if perm is None:
        perm = np.argmax(similarity, axis=1)

    restored = np.empty_like(model_mat_cpu)
    n_cols = restored.shape[1]
    # Boolean mask of target columns already written: O(1) bookkeeping
    # instead of scanning the ``success`` list for every column.
    filled = np.zeros(n_cols, dtype=bool)
    success = []
    # Sequential loop on purpose: with duplicate targets the last matching
    # column must win, which fancy-index assignment does not guarantee.
    for i in range(model_mat_cpu.shape[1]):
        target = perm[i]
        if similarity[i, target] >= threshold:
            restored[:, target] = model_mat_cpu[:, i]
            filled[target] = True
            success.append(target)

    # Columns nobody mapped onto fall back to the reference matrix.
    missing = np.flatnonzero(~filled)
    restored[:, missing] = pre_model_mat_cpu[:, missing]

    restored_matrix = torch.from_numpy(restored).to(model_mat.device)
    return perm, success, restored_matrix
    

def fix_factor(num, mini=1.0, max=6.0):
    """Clamp ``num`` between ``mini`` and ``max``.

    A NaN input yields 1.0 regardless of the bounds.  Otherwise values below
    ``mini`` return ``mini`` (checked first), values above ``max`` return
    ``max``, and anything else is returned unchanged.
    """
    if math.isnan(num):
        return 1.0
    if num < mini:
        return mini
    return max if num > max else num

def _rescale_to_reference(obfuscated, reference):
    """Return ``obfuscated`` divided by its estimated scale versus ``reference``.

    The scale estimate is ``sqrt(var(obfuscated) / var(reference))``, passed
    through ``fix_factor`` with its default bounds.  A zero-variance
    reference makes the ratio undefined; it is treated as NaN so that
    ``fix_factor`` returns 1.0 and the weight is left unscaled instead of
    raising ``ZeroDivisionError``.
    """
    reference_var = torch.var(reference).item()
    if reference_var == 0:
        ratio = float("nan")
    else:
        ratio = torch.var(obfuscated).item() / reference_var
    return obfuscated / fix_factor(sqrt(ratio))


def _restore_module_scale(module, pre_module):
    """Replace ``module.weight.data`` with its rescaled version."""
    module.weight.data = _rescale_to_reference(
        module.weight.data, pre_module.weight.data
    )


def attack(model, pre_model):
    """Undo per-weight scalar scaling of ``model`` using ``pre_model``.

    For each pair of corresponding layers, every targeted weight of
    ``model`` is divided by an estimate of the scalar it was multiplied by.
    The estimate is the square root of the ratio between the variance of the
    weight in ``model`` and the variance of the same weight in ``pre_model``,
    clamped by ``fix_factor``.

    ``GPT2PreTrainedModel`` instances are walked through ``transformer.h``,
    where q/k/v are slices of ``c_attn`` (plus ``q_attn`` for
    cross-attention).  All other models are walked through ``model.layers``
    with separate ``q_proj``/``k_proj``/``v_proj``/``o_proj`` and
    ``gate_proj``/``up_proj``/``down_proj`` modules.

    Args:
        model: the model to restore; modified in place.
        pre_model: reference model with the same layer structure.

    Returns:
        ``model``, after its weights have been rescaled.
    """
    if isinstance(model, GPT2PreTrainedModel):
        layer_pairs = zip(model.transformer.h, pre_model.transformer.h)
        for layer, pre_layer in layer_pairs:
            attn, pre_attn = layer.attn, pre_layer.attn
            mlp, pre_mlp = layer.mlp, pre_layer.mlp

            cross = attn.is_cross_attention
            if cross:
                ob_wq = attn.q_attn.weight.data
                ob_wk, ob_wv = attn.c_attn.weight.data.chunk(2, dim=1)
                pre_wq = pre_attn.q_attn.weight.data
                pre_wk, pre_wv = pre_attn.c_attn.weight.data.chunk(2, dim=1)
            else:
                ob_wq, ob_wk, ob_wv = attn.c_attn.weight.data.chunk(3, dim=1)
                pre_wq, pre_wk, pre_wv = pre_attn.c_attn.weight.data.chunk(
                    3, dim=1
                )

            # q, k and v were scaled independently, so each slice gets its
            # own estimate.
            restore_wq = _rescale_to_reference(ob_wq, pre_wq)
            restore_wk = _rescale_to_reference(ob_wk, pre_wk)
            restore_wv = _rescale_to_reference(ob_wv, pre_wv)

            if cross:
                attn.q_attn.weight.data = restore_wq
                attn.c_attn.weight.data = torch.cat(
                    [restore_wk, restore_wv], dim=1
                )
            else:
                attn.c_attn.weight.data = torch.cat(
                    [restore_wq, restore_wk, restore_wv], dim=1
                )

            _restore_module_scale(attn.c_proj, pre_attn.c_proj)
            _restore_module_scale(mlp.c_fc, pre_mlp.c_fc)
            _restore_module_scale(mlp.c_proj, pre_mlp.c_proj)

    else:
        layer_pairs = zip(model.model.layers, pre_model.model.layers)
        for layer, pre_layer in layer_pairs:
            attn, pre_attn = layer.self_attn, pre_layer.self_attn
            mlp, pre_mlp = layer.mlp, pre_layer.mlp

            for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
                _restore_module_scale(
                    getattr(attn, name), getattr(pre_attn, name)
                )
            for name in ("gate_proj", "up_proj", "down_proj"):
                _restore_module_scale(
                    getattr(mlp, name), getattr(pre_mlp, name)
                )

    return model
