import sys
from pathlib import Path

project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

import torch
from datasets import load_dataset
import argparse
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from transformers.models import Gemma3PreTrainedModel, Qwen3PreTrainedModel, GPT2PreTrainedModel, LlamaPreTrainedModel
from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
from transformers.models.llama.modeling_llama import LlamaRMSNorm
import random
from math import sqrt
import math


def _obfuscate_layer(layer):
    scale = 1 + 5 * torch.rand(1).item()
    W_obf = scale * layer.weight.data
    layer.weight.data = W_obf
    
def _obfuscate_layer_attn(attn):
    if attn.is_cross_attention:
        w_q = attn.q_attn.weight.data
        w_k, w_v = attn.c_attn.weight.data.chunk(2, dim=1)
    else:
        w_q, w_k, w_v = attn.c_attn.weight.data.chunk(3, dim=1)
    scale_q = 1 + 5 * torch.rand(1).item()
    scale_k = 1 + 5 * torch.rand(1).item()
    scale_v = 1 + 5 * torch.rand(1).item()
    if random.random() < 0.2:
        w_q *= scale_q
    if random.random() < 0.2:
        w_k *= scale_k
    if random.random() < 0.2:
        w_v *= scale_v
    
    if attn.is_cross_attention:
        attn.q_attn.weight.data = w_q
        attn.c_attn.weight.data = torch.cat([w_k, w_v], dim=1)
    else:
        attn.c_attn.weight.data = torch.cat([w_q, w_k, w_v], dim=1)


def _obfuscate_model(model):
    """Randomly rescale projection weights across ``model.model.layers``.

    For each decoder layer, the attention projections (``q_proj``, ``k_proj``,
    ``v_proj``, ``o_proj`` on ``self_attn``) and the MLP projections
    (``gate_proj``, ``up_proj``, ``down_proj`` on ``mlp``) are visited in that
    order. Each is handed to ``_obfuscate_layer`` independently with
    probability 0.2. A projection attribute is only looked up when it is
    selected.
    """
    attn_names = ("q_proj", "k_proj", "v_proj", "o_proj")
    mlp_names = ("gate_proj", "up_proj", "down_proj")

    for layer in model.model.layers:
        attn = layer.self_attn
        mlp = layer.mlp
        schedule = [(attn, name) for name in attn_names]
        schedule += [(mlp, name) for name in mlp_names]
        for owner, name in schedule:
            if random.random() < 0.2:
                _obfuscate_layer(getattr(owner, name))
            

def _obfuscate_model_gpt2(model: GPT2PreTrainedModel):
    """Randomly rescale attention and MLP weights in every GPT-2 block.

    For each block in ``model.transformer.h``:
      * the q/k/v slices of the attention weight are handled by
        ``_obfuscate_layer_attn`` (each slice scaled with probability 0.2);
      * ``attn.c_proj``, ``mlp.c_fc`` and ``mlp.c_proj`` are each passed to
        ``_obfuscate_layer`` independently with probability 0.2.

    The model is modified in place; nothing is returned.
    """
    for block in model.transformer.h:
        attn = block.attn
        mlp = block.mlp

        # Attention: fused q/k/v weights, then the output projection.
        _obfuscate_layer_attn(attn)
        if random.random() < 0.2:
            _obfuscate_layer(attn.c_proj)

        # MLP: input and output projections.
        if random.random() < 0.2:
            _obfuscate_layer(mlp.c_fc)
        if random.random() < 0.2:
            _obfuscate_layer(mlp.c_proj)
        

def obfuscate_model(model):
    """Randomly rescale a model's projection weights in place.

    Instances of ``GPT2PreTrainedModel`` are routed to
    ``_obfuscate_model_gpt2``. Every other model goes to ``_obfuscate_model``,
    which expects a ``model.model.layers`` stack. Returns None.
    """
    is_gpt2 = isinstance(model, GPT2PreTrainedModel)
    handler = _obfuscate_model_gpt2 if is_gpt2 else _obfuscate_model
    handler(model)


def fix_factor(num, mini=1.0, max=6.0):
    """Clamp a scale estimate ``num`` into ``[mini, max]``.

    NaN input returns 1.0, regardless of ``mini``/``max``. The lower bound is
    tested first, so if ``mini > max`` a value below ``mini`` yields ``mini``.

    Note: the ``max`` parameter shadows the builtin inside this function. The
    name is kept so existing keyword callers keep working.
    """
    if math.isnan(num):
        return 1.0
    upper = max
    if num < mini:
        return mini
    return upper if num > upper else num

def _restore_scale(obfuscated, reference):
    """Return ``obfuscated`` divided by its estimated scale relative to ``reference``.

    The scale is estimated as ``sqrt(var(obfuscated) / var(reference))`` and
    clamped by ``fix_factor``. When the reference variance is zero or NaN the
    ratio is undefined. In that case NaN is passed on so ``fix_factor``
    returns 1.0 and the weight is left unscaled, instead of raising
    ``ZeroDivisionError`` from a bare float division.

    A new tensor is returned; neither argument is modified.
    """
    obfuscated_var = torch.var(obfuscated).item()
    reference_var = torch.var(reference).item()
    # `reference_var > 0` is False for both 0.0 and NaN.
    ratio = obfuscated_var / reference_var if reference_var > 0 else math.nan
    return obfuscated / fix_factor(sqrt(ratio))


def attack(model, pre_model):
    """Try to undo random per-weight scaling of ``model`` using ``pre_model``.

    Every targeted weight of ``model`` is divided by a factor estimated from
    the ratio of its standard deviation to that of the matching weight in
    ``pre_model`` (see ``_restore_scale``).

    * ``GPT2PreTrainedModel``: per block, the q/k/v slices of the fused
      attention weight (q from ``q_attn`` for cross-attention), ``attn.c_proj``,
      ``mlp.c_fc`` and ``mlp.c_proj``.
    * Otherwise: per layer in ``model.model.layers``, ``self_attn``'s
      q/k/v/o projections and ``mlp``'s gate/up/down projections.

    Args:
        model: model whose weights are restored in place.
        pre_model: reference model, only read. Assumed to share ``model``'s
            architecture; ``zip`` stops at the shorter layer list.

    Returns:
        ``model`` (the same object, mutated).
    """
    if isinstance(model, GPT2PreTrainedModel):
        for block, pre_block in zip(model.transformer.h, pre_model.transformer.h):
            attn, mlp = block.attn, block.mlp
            pre_attn, pre_mlp = pre_block.attn, pre_block.mlp

            # q/k/v are column chunks of c_attn; with cross-attention, q has
            # its own q_attn module and c_attn holds only k/v.
            if attn.is_cross_attention:
                pre_wq = pre_attn.q_attn.weight.data
                pre_wk, pre_wv = pre_attn.c_attn.weight.data.chunk(2, dim=1)
                ob_wq = attn.q_attn.weight.data
                ob_wk, ob_wv = attn.c_attn.weight.data.chunk(2, dim=1)
            else:
                pre_wq, pre_wk, pre_wv = pre_attn.c_attn.weight.data.chunk(3, dim=1)
                ob_wq, ob_wk, ob_wv = attn.c_attn.weight.data.chunk(3, dim=1)

            # Each slice gets its own factor, mirroring the per-slice scaling
            # done by _obfuscate_layer_attn.
            restore_wq = _restore_scale(ob_wq, pre_wq)
            restore_wk = _restore_scale(ob_wk, pre_wk)
            restore_wv = _restore_scale(ob_wv, pre_wv)

            if attn.is_cross_attention:
                attn.q_attn.weight.data = restore_wq
                attn.c_attn.weight.data = torch.cat([restore_wk, restore_wv], dim=1)
            else:
                attn.c_attn.weight.data = torch.cat(
                    [restore_wq, restore_wk, restore_wv], dim=1
                )

            pairs = (
                (attn.c_proj, pre_attn.c_proj),
                (mlp.c_fc, pre_mlp.c_fc),
                (mlp.c_proj, pre_mlp.c_proj),
            )
            for target, reference in pairs:
                target.weight.data = _restore_scale(
                    target.weight.data, reference.weight.data
                )

    else:
        attn_names = ("q_proj", "k_proj", "v_proj", "o_proj")
        mlp_names = ("gate_proj", "up_proj", "down_proj")
        for layer, pre_layer in zip(model.model.layers, pre_model.model.layers):
            groups = (
                (layer.self_attn, pre_layer.self_attn, attn_names),
                (layer.mlp, pre_layer.mlp, mlp_names),
            )
            for module, pre_module, names in groups:
                for name in names:
                    target = getattr(module, name)
                    reference = getattr(pre_module, name)
                    target.weight.data = _restore_scale(
                        target.weight.data, reference.weight.data
                    )

    return model
