import torch
from transformers.models import GPT2PreTrainedModel


def _obfuscate_layer(layer):
    w = layer.weight.data
    num_rows = w.shape[0]
    device = w.device
    coeff = torch.randint(0,5,(w.shape[0],), device=device)
    mask = torch.matmul(w.T, coeff.float())

    ratio_mask = (torch.randint(0, 11, (num_rows,), device=device)-5).float()
    ratio_w = torch.randint(1, 3, (num_rows,), device=device).float()
    for i in range(num_rows):
        w[i] *= ratio_w[i]
        mask_i = mask * ratio_mask[i]
        w[i] += mask_i

    permutation = torch.randperm(num_rows)
    w = w[permutation]
    
    layer.weight.data = w
    
    return mask, ratio_mask, ratio_w, permutation
    
def _obfuscate_layer_gpt2(layer):
    w = layer.weight.data
    num_cols = w.shape[1]
    device = w.device
    coeff = torch.randint(0,5,(w.shape[1],), device=device)
    mask = torch.matmul(w, coeff.float())
    
    ratio_mask = (torch.randint(0, 11, (num_cols,), device=device)-5).float()
    ratio_w = torch.randint(1, 3, (num_cols,), device=device).float()
    for i in range(num_cols):
        mask_i = mask * ratio_mask[i]
        w[:,i] *= ratio_w[i]
        w[:,i] += mask_i

    permutation = torch.randperm(num_cols)
    ob_w = w[:,permutation]
    layer.weight.data = ob_w

    return mask, ratio_mask, ratio_w, permutation
    
def _obfuscate_attn_gpt2(layer, is_cross_attention):
    if is_cross_attention:
        w_q = layer.q_attn.weight.data
        w_k, w_v = layer.c_attn.weight.data.chunk(2, dim=1)
    else:
        w_q, w_k, w_v = layer.c_attn.weight.data.chunk(3, dim=1)
    num_cols = w_q.shape[1]
    device = w_q.device
    coeff_q, coeff_k, coeff_v = torch.randint(0,5,(w_q.shape[1],), device=device), torch.randint(0,5,(w_k.shape[1],), device=device), torch.randint(0,5,(w_v.shape[1],), device=device)
    mask_q, mask_k, mask_v = torch.matmul(w_q, coeff_q.float()), torch.matmul(w_k, coeff_k.float()), torch.matmul(w_v, coeff_v.float())

    ratio_mask_q, ratio_mask_k, ratio_mask_v = (torch.randint(0, 11, (num_cols,), device=device)-5).float(), (torch.randint(0, 11, (num_cols,), device=device)-5).float(), (torch.randint(0, 11, (num_cols,), device=device)-5).float()
    ratio_w_q, ratio_w_k, ratio_w_v = torch.randint(1, 3, (num_cols,), device=device).float(), torch.randint(1, 3, (num_cols,), device=device).float(), torch.randint(1, 3, (num_cols,), device=device).float()
    for i in range(num_cols):
        mask_qi = mask_q*ratio_mask_q[i]
        mask_ki = mask_k*ratio_mask_k[i]
        mask_vi = mask_v*ratio_mask_v[i]
        w_q[:,i] *= ratio_w_q[i]
        w_k[:,i] *= ratio_w_k[i]
        w_v[:,i] *= ratio_w_v[i]
        w_q[:,i] += mask_qi
        w_k[:,i] += mask_ki
        w_v[:,i] += mask_vi

    permutation_q = torch.randperm(num_cols)
    permutation_k = torch.randperm(num_cols)
    permutation_v = torch.randperm(num_cols)

    ob_w_q = w_q[:,permutation_q]
    ob_w_k = w_k[:,permutation_k]
    ob_w_v = w_v[:,permutation_v]
    if is_cross_attention:
        layer.c_attn.weight.data = torch.cat([ob_w_k, ob_w_v], dim=1)
        layer.q_attn.weight.data = ob_w_q
    else:
        layer.c_attn.weight.data = torch.cat([ob_w_q, ob_w_k, ob_w_v], dim=1)
        
    return {
        "q_proj": {"mask": mask_q, "ratio_mask": ratio_mask_q,"ratio_w": ratio_w_q, "permutation": permutation_q},
        "k_proj": {"mask": mask_k, "ratio_mask": ratio_mask_k,"ratio_w": ratio_w_k, "permutation": permutation_k},
        "v_proj": {"mask": mask_v, "ratio_mask": ratio_mask_v,"ratio_w": ratio_w_v, "permutation": permutation_v}
    }
    
def _obfuscate_model(model):
    """Obfuscate the projection weights of every decoder layer, in place.

    Walks ``model.model.layers``. For each layer, passes
    ``self_attn.{q,k,v,o}_proj`` and then ``mlp.{gate,up,down}_proj`` to
    ``_obfuscate_layer`` (row-wise obfuscation of each ``weight``).

    Args:
        model: object exposing ``model.model.layers``. Each layer must have
            ``self_attn`` and ``mlp`` submodules carrying the projections
            named above.

    Returns:
        dict keyed by ``"q_proj"``, ``"k_proj"``, ``"v_proj"``, ``"o_proj"``,
        ``"gate_proj"``, ``"up_proj"`` and ``"down_proj"``. Each value is a
        list with one ``{"mask", "ratio_mask", "ratio_w", "permutation"}``
        record per decoder layer, in layer order.
    """
    # (parent submodule, projection) in processing order. The order fixes the
    # sequence of random draws, so keep it stable.
    projections = (
        ("self_attn", "q_proj"),
        ("self_attn", "k_proj"),
        ("self_attn", "v_proj"),
        ("self_attn", "o_proj"),
        ("mlp", "gate_proj"),
        ("mlp", "up_proj"),
        ("mlp", "down_proj"),
    )
    obf_param = {proj_name: [] for _, proj_name in projections}

    for decoder_layer in model.model.layers:
        parents = {
            "self_attn": decoder_layer.self_attn,
            "mlp": decoder_layer.mlp,
        }
        for parent_name, proj_name in projections:
            proj = getattr(parents[parent_name], proj_name)
            mask, ratio_mask, ratio_w, permutation = _obfuscate_layer(proj)
            obf_param[proj_name].append({
                "mask": mask,
                "ratio_mask": ratio_mask,
                "ratio_w": ratio_w,
                "permutation": permutation,
            })
    return obf_param
            

def _obfuscate_model_gpt2(model: GPT2PreTrainedModel):
    """Obfuscate every GPT-2 block in ``model.transformer.h``, in place.

    Per block, in this order:

    1. Attention Q/K/V via ``_obfuscate_attn_gpt2``.
    2. The attention output projection ``attn.c_proj`` via
       ``_obfuscate_layer_gpt2``.
    3. The MLP's ``c_fc`` and then ``c_proj``, also via
       ``_obfuscate_layer_gpt2``.

    Returns:
        dict keyed by ``"q_proj"``, ``"k_proj"``, ``"v_proj"``, ``"o_proj"``,
        ``"c_fc"`` and ``"c_proj"``. Each value is a list with one
        ``{"mask", "ratio_mask", "ratio_w", "permutation"}`` record per block,
        in block order.
    """
    def as_record(mask, ratio_mask, ratio_w, permutation):
        # Pack a helper's 4-tuple into the record layout shared by every key.
        return {
            "mask": mask,
            "ratio_mask": ratio_mask,
            "ratio_w": ratio_w,
            "permutation": permutation,
        }

    result = {key: [] for key in ("q_proj", "k_proj", "v_proj", "o_proj", "c_fc", "c_proj")}

    for gpt2_block in model.transformer.h:
        attention = gpt2_block.attn
        feed_forward = gpt2_block.mlp

        qkv_records = _obfuscate_attn_gpt2(attention, attention.is_cross_attention)
        out_record = as_record(*_obfuscate_layer_gpt2(attention.c_proj))
        fc_record = as_record(*_obfuscate_layer_gpt2(feed_forward.c_fc))
        proj_record = as_record(*_obfuscate_layer_gpt2(feed_forward.c_proj))

        for key in ("q_proj", "k_proj", "v_proj"):
            result[key].append(qkv_records[key])
        result["o_proj"].append(out_record)
        result["c_fc"].append(fc_record)
        result["c_proj"].append(proj_record)

    return result
        

def obfuscate_model(model):
    """Obfuscate ``model``'s projection weights in place and return the parameters used.

    Routing:

    - ``GPT2PreTrainedModel`` instances go through ``_obfuscate_model_gpt2``.
    - Any other model goes through ``_obfuscate_model``, which expects a
      ``model.model.layers`` layout.

    Returns:
        The per-projection parameter dict produced by the selected helper.
    """
    is_gpt2 = isinstance(model, GPT2PreTrainedModel)
    obfuscate = _obfuscate_model_gpt2 if is_gpt2 else _obfuscate_model
    return obfuscate(model)