import sys
from pathlib import Path

project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

import torch
from datasets import load_dataset
import argparse
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from transformers.models import Gemma3PreTrainedModel, Qwen3PreTrainedModel, GPT2PreTrainedModel, LlamaPreTrainedModel
from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from math import sqrt
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import math


def _obfuscate_layer(layer):
    """Randomly rescale every column of ``layer.weight`` and then shuffle the columns.

    Column ``j`` (dim 1) is multiplied in place by ``1 + 5 * torch.rand(1).item()``,
    a factor in ``[1, 6)``. The columns are then reordered with
    ``torch.randperm``, and the shuffled copy replaces ``layer.weight.data``.
    """
    weight = layer.weight.data
    n_cols = weight.shape[1]

    # One factor per column, drawn in column order. This is the same sequence
    # of RNG calls as drawing inside the scaling loop.
    factors = [1 + 5 * torch.rand(1).item() for _ in range(n_cols)]
    for col, factor in enumerate(factors):
        weight[:, col] *= factor

    layer.weight.data = weight[:, torch.randperm(n_cols)]
    
def _obfuscate_layer_gpt2(layer):
    """Randomly rescale every row of ``layer.weight`` and then shuffle the rows.

    This is the row-wise (dim 0) counterpart of ``_obfuscate_layer``. Row ``r``
    is multiplied in place by ``1 + 5 * torch.rand(1).item()``, a factor in
    ``[1, 6)``. The rows are then reordered with ``torch.randperm``, and the
    shuffled copy replaces ``layer.weight.data``.
    """
    weight = layer.weight.data
    n_rows = weight.shape[0]

    # Draw the factors in row order. This is the same RNG call sequence as
    # drawing each factor just before scaling its row.
    factors = [1 + 5 * torch.rand(1).item() for _ in range(n_rows)]
    for row_idx, factor in enumerate(factors):
        weight[row_idx] *= factor

    layer.weight.data = weight[torch.randperm(n_rows)]
    
def _obfuscate_attn_gpt2(layer, is_cross_attention):
    """Randomly rescale and shuffle the rows of a GPT-2 attention layer's Q/K/V weights.

    Q, K and V are handled independently. Each row of each matrix is
    multiplied in place by its own factor ``1 + 5 * torch.rand(1).item()``
    (in ``[1, 6)``). Each matrix then gets its own ``torch.randperm`` row
    shuffle, and the shuffled matrices are written back to the layer.

    Args:
        layer: attention module exposing ``c_attn.weight``, and also
            ``q_attn.weight`` when ``is_cross_attention`` is true.
        is_cross_attention: if true, Q is ``layer.q_attn.weight`` and K/V are
            the two halves of ``layer.c_attn.weight`` split along dim 1.
            Otherwise Q/K/V are the three thirds of ``layer.c_attn.weight``
            along dim 1.
    """
    # ``chunk`` returns views, so the in-place row scaling below also writes
    # into the underlying ``c_attn`` weight storage.
    if is_cross_attention:
        w_q = layer.q_attn.weight.data
        w_k, w_v = layer.c_attn.weight.data.chunk(2, dim=1)
    else:
        w_q, w_k, w_v = layer.c_attn.weight.data.chunk(3, dim=1)
    num_rows = w_q.shape[0]

    # Per row, draw the q, k, v factors in that order. This keeps the RNG
    # stream identical to earlier versions, which also collected the factors
    # into lists that were never read.
    for i in range(num_rows):
        ratio_q = 1 + 5 * torch.rand(1).item()
        ratio_k = 1 + 5 * torch.rand(1).item()
        ratio_v = 1 + 5 * torch.rand(1).item()
        w_q[i] *= ratio_q
        w_k[i] *= ratio_k
        w_v[i] *= ratio_v

    permutation_q = torch.randperm(num_rows)
    permutation_k = torch.randperm(num_rows)
    permutation_v = torch.randperm(num_rows)

    ob_w_q = w_q[permutation_q]
    ob_w_k = w_k[permutation_k]
    ob_w_v = w_v[permutation_v]
    if is_cross_attention:
        layer.c_attn.weight.data = torch.cat([ob_w_k, ob_w_v], dim=1)
        layer.q_attn.weight.data = ob_w_q
    else:
        layer.c_attn.weight.data = torch.cat([ob_w_q, ob_w_k, ob_w_v], dim=1)
    
    
def _obfuscate_model(model):
    """Obfuscate every decoder layer of a non-GPT-2 model in place.

    Expects ``model.model.layers``, where each layer has ``self_attn`` with
    ``q_proj``/``k_proj``/``v_proj``/``o_proj`` and ``mlp`` with
    ``gate_proj``/``up_proj``/``down_proj``. Each projection is passed to
    ``_obfuscate_layer``.

    The previously unused ``model.config.hidden_size`` lookup has been removed.
    It had no effect except raising AttributeError on configs without that field.
    """
    for decoder_layer in model.model.layers:
        attn = decoder_layer.self_attn
        mlp = decoder_layer.mlp

        # Keep the q, k, v, o, gate, up, down order so the RNG stream is
        # consumed the same way as before.
        for proj in (attn.q_proj, attn.k_proj, attn.v_proj, attn.o_proj):
            _obfuscate_layer(proj)
        for proj in (mlp.gate_proj, mlp.up_proj, mlp.down_proj):
            _obfuscate_layer(proj)
            

def _obfuscate_model_gpt2(model: GPT2PreTrainedModel):
    """Obfuscate every transformer block of a GPT-2 model in place.

    For each block in ``model.transformer.h``, the Q/K/V rows are rescaled
    and shuffled with ``_obfuscate_attn_gpt2``. Then ``attn.c_proj``,
    ``mlp.c_fc`` and ``mlp.c_proj`` are each passed, in that order, to
    ``_obfuscate_layer_gpt2``.

    The previously unused ``model.config.hidden_size`` lookup has been removed.
    """
    for block in model.transformer.h:
        attn = block.attn
        mlp = block.mlp

        _obfuscate_attn_gpt2(attn, attn.is_cross_attention)
        for conv in (attn.c_proj, mlp.c_fc, mlp.c_proj):
            _obfuscate_layer_gpt2(conv)
        

def obfuscate_model(model):
    """Obfuscate ``model``'s decoder weights in place.

    GPT-2 models (``GPT2PreTrainedModel`` instances) go through
    ``_obfuscate_model_gpt2``. Every other model goes through
    ``_obfuscate_model``. Returns None.
    """
    handler = _obfuscate_model_gpt2 if isinstance(model, GPT2PreTrainedModel) else _obfuscate_model
    handler(model)

def row_restore_perm(pre_model_mat, model_mat, threshold=0.0):
    """Reorder the rows of ``model_mat`` to line up with ``pre_model_mat``.

    Row ``i`` of ``model_mat`` is assigned to the reference row with the
    highest cosine similarity. Cosine similarity is unaffected by a positive
    per-row scale, so rows that were both rescaled and shuffled can still be
    matched. The scale itself is left in place.

    Args:
        pre_model_mat: reference 2-D tensor. Expected to have the same shape
            as ``model_mat``.
        model_mat: 2-D tensor whose rows should be put back in reference order.
        threshold: a match is accepted only if its cosine similarity is
            ``>= threshold``.

    Returns:
        ``(perm, success, restored_matrix)``:

        - ``perm[i]`` is the reference row index chosen for row ``i`` of
          ``model_mat`` (NumPy argmax output).
        - ``success`` lists the reference row indices that received an
          accepted match, in assignment order. Duplicates are possible; when
          several rows pick the same target, the last one wins.
        - ``restored_matrix`` is a tensor on ``model_mat``'s device with
          ``model_mat``'s dtype. Every reference row that received no match
          is copied from ``pre_model_mat``.
    """

    def to_numpy(tensor):
        tensor = tensor.detach().cpu()
        # NumPy has no bfloat16, so ``Tensor.numpy()`` would raise. Upcast
        # first; the result is cast back to ``model_mat.dtype`` at the end.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.numpy()

    model_np = to_numpy(model_mat)
    pre_np = to_numpy(pre_model_mat)
    similarity = cosine_similarity(model_np, pre_np)
    perm = np.argmax(similarity, axis=1)

    restored = np.empty_like(model_np)
    success = []
    for i, target in enumerate(perm):
        if similarity[i, target] >= threshold:
            restored[target] = model_np[i]
            success.append(target)

    # A set makes the "was this row filled?" test O(1). Scanning the
    # ``success`` list for every row was quadratic in the row count.
    filled = {int(t) for t in success}
    for i in range(restored.shape[0]):
        if i not in filled:
            restored[i] = pre_np[i]

    restored_matrix = torch.from_numpy(restored).to(device=model_mat.device, dtype=model_mat.dtype)
    return perm, success, restored_matrix

def col_restore_perm(pre_model_mat, model_mat, threshold=0.0, perm=None):
    """Reorder the columns of ``model_mat`` to line up with ``pre_model_mat``.

    This is the column-wise counterpart of ``row_restore_perm``. Column ``i``
    of ``model_mat`` is assigned to the reference column ``perm[i]``. When
    ``perm`` is None, that is the reference column with the highest cosine
    similarity. A positive per-column scale does not affect the match.

    Args:
        pre_model_mat: reference 2-D tensor. Expected to have the same shape
            as ``model_mat``.
        model_mat: 2-D tensor whose columns should be put back in reference order.
        threshold: an assignment is accepted only if the cosine similarity
            between column ``i`` and reference column ``perm[i]`` is ``>= threshold``.
        perm: optional precomputed column assignment. If None, it is computed
            by argmax over cosine similarity.

    Returns:
        ``(perm, success, restored_matrix)``:

        - ``success`` lists the reference column indices that received an
          accepted match, in assignment order. Duplicates are possible; the
          last assignment wins.
        - ``restored_matrix`` is a tensor on ``model_mat``'s device with
          ``model_mat``'s dtype. Every reference column that received no
          match is copied from ``pre_model_mat``.
    """

    def to_numpy(tensor):
        tensor = tensor.detach().cpu()
        # NumPy has no bfloat16, so ``Tensor.numpy()`` would raise. Upcast
        # first; the result is cast back to ``model_mat.dtype`` at the end.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.numpy()

    model_np = to_numpy(model_mat)
    pre_np = to_numpy(pre_model_mat)
    similarity = cosine_similarity(model_np.T, pre_np.T)
    if perm is None:
        perm = np.argmax(similarity, axis=1)

    restored = np.empty_like(model_np)
    success = []
    for i in range(model_np.shape[1]):
        target = perm[i]
        if similarity[i, target] >= threshold:
            restored[:, target] = model_np[:, i]
            success.append(target)

    # Set lookup instead of a linear scan of ``success`` per column. The
    # ``int()`` keeps caller-supplied ``perm`` values (NumPy ints, 0-d
    # tensors, Python ints) comparable by value.
    filled = {int(t) for t in success}
    for i in range(restored.shape[1]):
        if i not in filled:
            restored[:, i] = pre_np[:, i]

    restored_matrix = torch.from_numpy(restored).to(device=model_mat.device, dtype=model_mat.dtype)
    return perm, success, restored_matrix
    

def fix_factor(num, mini=1.0, max=6.0):
    """Clamp a scale estimate ``num`` into ``[mini, max]``.

    NaN is mapped to ``1.0`` (no rescaling). Values below ``mini`` return
    ``mini`` and values above ``max`` return ``max``, as the exact objects
    passed in. Anything else is returned unchanged.

    Note: the ``max`` parameter shadows the builtin inside this function. It
    keeps its name because callers may pass it by keyword.
    """
    upper = max
    if math.isnan(num):
        return 1.0
    return mini if num < mini else (upper if num > upper else num)
    
def fix_layer(obf_weight, pre_weight):
    """Undo column shuffling and per-column scaling of ``obf_weight`` in place.

    Columns are first matched back to ``pre_weight``'s column order with
    ``col_restore_perm``. Each column is then divided by an estimated scale.
    Scaling a column by ``c`` multiplies its variance by ``c**2``, so the
    estimate is ``sqrt(var(restored) / var(reference))``, clamped by
    ``fix_factor`` to ``[1, 6]``. That is the range of factors
    ``_obfuscate_layer`` draws. The result replaces ``obf_weight.data``.

    A reference column with zero (or NaN) variance gives no usable estimate.
    Previously that raised ZeroDivisionError; now the column is left unscaled,
    the same fallback ``fix_factor`` applies to NaN.
    """
    ob_w = obf_weight.data
    pre_w = pre_weight.data
    _, _, restore_w = col_restore_perm(pre_w, ob_w)
    for i in range(ob_w.shape[1]):
        pre_var = torch.var(pre_w[:, i]).item()
        if pre_var > 0:
            ratio = fix_factor(sqrt(torch.var(restore_w[:, i]).item() / pre_var))
        else:
            ratio = 1.0
        restore_w[:, i] /= ratio
    obf_weight.data = restore_w
    
def fix_layer_gpt2(obf_weight, pre_weight):
    """Undo row shuffling and per-row scaling of ``obf_weight`` in place.

    This is the row-wise counterpart of ``fix_layer``, for the GPT-2 weights
    obfuscated by ``_obfuscate_layer_gpt2``. Rows are matched back with
    ``row_restore_perm``. Each row is then divided by
    ``fix_factor(sqrt(var(restored_row) / var(reference_row)))``, and the
    result replaces ``obf_weight.data``.

    A reference row with zero (or NaN) variance is left unscaled instead of
    raising ZeroDivisionError.
    """
    ob_w = obf_weight.data
    pre_w = pre_weight.data
    _, _, restore_w = row_restore_perm(pre_w, ob_w)
    for i in range(ob_w.shape[0]):
        pre_var = torch.var(pre_w[i]).item()
        if pre_var > 0:
            ratio = fix_factor(sqrt(torch.var(restore_w[i]).item() / pre_var))
        else:
            ratio = 1.0
        restore_w[i] /= ratio
    obf_weight.data = restore_w

def fix_layer_gpt_attn(attn, pre_attn):
    """Restore a GPT-2 attention module's Q/K/V and ``c_proj`` weights in place.

    Q, K and V are split out the same way ``_obfuscate_attn_gpt2`` splits
    them: ``q_attn`` plus the two halves of ``c_attn`` for cross-attention,
    otherwise the three thirds of ``c_attn`` along dim 1. Each is un-shuffled
    row-wise against ``pre_attn``, and each row is divided by its estimated
    scale. The results are written back, then ``c_proj`` is fixed with
    ``fix_layer_gpt2``.
    """

    def restore_rows(pre_w, ob_w):
        # Un-shuffle the rows of ``ob_w`` against ``pre_w``, then divide out the
        # per-row scale, estimated as sqrt(var(restored) / var(reference)).
        _, _, restored = row_restore_perm(pre_w, ob_w)
        for i in range(ob_w.shape[0]):
            pre_var = torch.var(pre_w[i]).item()
            if pre_var > 0:
                ratio = fix_factor(sqrt(torch.var(restored[i]).item() / pre_var))
            else:
                # A constant (or NaN) reference row gives no usable estimate
                # and would otherwise divide by zero, so keep the row unscaled.
                ratio = 1.0
            restored[i] /= ratio
        return restored

    if attn.is_cross_attention:
        pre_wq = pre_attn.q_attn.weight.data
        ob_wq = attn.q_attn.weight.data
        pre_wk, pre_wv = pre_attn.c_attn.weight.data.chunk(2, dim=1)
        ob_wk, ob_wv = attn.c_attn.weight.data.chunk(2, dim=1)
    else:
        pre_wq, pre_wk, pre_wv = pre_attn.c_attn.weight.data.chunk(3, dim=1)
        ob_wq, ob_wk, ob_wv = attn.c_attn.weight.data.chunk(3, dim=1)

    restore_wq = restore_rows(pre_wq, ob_wq)
    restore_wk = restore_rows(pre_wk, ob_wk)
    restore_wv = restore_rows(pre_wv, ob_wv)

    if attn.is_cross_attention:
        attn.q_attn.weight.data = restore_wq
        attn.c_attn.weight.data = torch.cat([restore_wk, restore_wv], dim=1)
    else:
        attn.c_attn.weight.data = torch.cat([restore_wq, restore_wk, restore_wv], dim=1)

    fix_layer_gpt2(attn.c_proj.weight, pre_attn.c_proj.weight)

def attack(model, pre_model):
    """Try to invert the weight obfuscation of ``model`` using ``pre_model`` as reference.

    Layers of ``model`` are paired with the corresponding layers of
    ``pre_model``, and each obfuscated weight is un-shuffled and un-scaled
    against its reference counterpart. ``pre_model`` is presumably the
    unobfuscated reference model; confirm at call sites.

    - GPT-2 (``GPT2PreTrainedModel``): per block, attention via
      ``fix_layer_gpt_attn``, then ``mlp.c_fc`` and ``mlp.c_proj`` via
      ``fix_layer_gpt2`` (row-wise).
    - Other models: per decoder layer, ``self_attn.{q,k,v,o}_proj`` then
      ``mlp.{gate,up,down}_proj`` via ``fix_layer`` (column-wise).

    Weights are replaced in place. Returns ``model``.
    """
    if not isinstance(model, GPT2PreTrainedModel):
        for layer, pre_layer in zip(model.model.layers, pre_model.model.layers):
            for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
                fix_layer(getattr(layer.self_attn, name).weight, getattr(pre_layer.self_attn, name).weight)
            for name in ("gate_proj", "up_proj", "down_proj"):
                fix_layer(getattr(layer.mlp, name).weight, getattr(pre_layer.mlp, name).weight)
        return model

    for block, pre_block in zip(model.transformer.h, pre_model.transformer.h):
        fix_layer_gpt_attn(block.attn, pre_block.attn)
        fix_layer_gpt2(block.mlp.c_fc.weight, pre_block.mlp.c_fc.weight)
        fix_layer_gpt2(block.mlp.c_proj.weight, pre_block.mlp.c_proj.weight)
    return model

