import torch
from transformers.models import GPT2PreTrainedModel


def _obfuscate_layer(layer):
    """Scramble ``layer.weight`` along its first dimension, in place.

    The weight ``w`` is replaced by ``(B.T @ w)[permutation]``, where ``B`` is a
    block-diagonal matrix assembled from random 2x2 blocks and ``permutation``
    is a random ordering of the rows.

    Args:
        layer: Object exposing its weight tensor as ``layer.weight.data``
            (expected to be a 2-D matrix).

    Returns:
        tuple: ``(blocks, permutation)`` -- the list of random 2x2 blocks
        (created on the weight's device and in the weight's dtype) and the
        index tensor produced by ``torch.randperm`` that was used to reorder
        the rows.

    Raises:
        AssertionError: If the number of rows is not divisible by the block
            size.
    """
    w = layer.weight.data
    num_rows = w.shape[0]

    block_size = 2
    assert num_rows % block_size == 0, "num_rows must be divisible by block_size"
    num_blocks = num_rows // block_size

    # Draw the blocks in the weight's own dtype: with the default dtype the
    # matmul below fails for any weight that is not float32 (fp16, bf16, ...).
    blocks = [
        torch.randn((block_size, block_size), device=w.device, dtype=w.dtype)
        for _ in range(num_blocks)
    ]

    # Mix the rows pairwise, then shuffle them.
    w = torch.block_diag(*blocks).T @ w

    permutation = torch.randperm(num_rows)
    layer.weight.data = w[permutation]

    return blocks, permutation
    
def _obfuscate_layer_gpt2(layer):
    """Scramble ``layer.weight`` along its second dimension, in place.

    The weight ``w`` is replaced by ``(w @ B)[:, permutation]``, where ``B`` is
    a block-diagonal matrix assembled from random 2x2 blocks and
    ``permutation`` is a random ordering of the columns.

    NOTE(review): presumably meant for GPT-2 layers whose weight is stored
    transposed relative to ``nn.Linear`` -- confirm against the caller.

    Args:
        layer: Object exposing its weight tensor as ``layer.weight.data``
            (expected to be a 2-D matrix).

    Returns:
        tuple: ``(blocks, permutation)`` -- the list of random 2x2 blocks
        (created on the weight's device and in the weight's dtype) and the
        index tensor produced by ``torch.randperm`` that was used to reorder
        the columns.

    Raises:
        AssertionError: If the number of columns is not divisible by the
            block size.
    """
    w = layer.weight.data
    num_cols = w.shape[1]

    block_size = 2
    assert num_cols % block_size == 0, "num_cols must be divisible by block_size"
    num_blocks = num_cols // block_size

    # Draw the blocks in the weight's own dtype: with the default dtype the
    # matmul below fails for any weight that is not float32 (fp16, bf16, ...).
    blocks = [
        torch.randn((block_size, block_size), device=w.device, dtype=w.dtype)
        for _ in range(num_blocks)
    ]

    # Mix the columns pairwise, then shuffle them.
    w = w @ torch.block_diag(*blocks)

    permutation = torch.randperm(num_cols)
    layer.weight.data = w[:, permutation]

    return blocks, permutation
    
def _obfuscate_attn_gpt2(layer, is_cross_attention):
    """Scramble the query/key/value projection weights of an attention module.

    The fused ``c_attn`` weight is split along dim 1 into its projections
    (query/key/value, or key/value when ``is_cross_attention`` is true, in
    which case the query weight is read from ``layer.q_attn``). Each
    projection ``w`` is replaced by ``(w @ B)[:, permutation]`` with its own
    random block-diagonal ``B`` (2x2 blocks) and its own random column
    permutation. The results are written back to ``layer.c_attn.weight.data``
    (and ``layer.q_attn.weight.data`` for cross attention).

    Args:
        layer: Attention module exposing ``c_attn.weight.data`` and, for cross
            attention, ``q_attn.weight.data``.
        is_cross_attention: Whether the query projection lives in a separate
            ``q_attn`` sub-module.

    Returns:
        dict: Keys ``"q_proj"``, ``"k_proj"`` and ``"v_proj"``; each value is
        a dict with ``"blocks"`` (list of 2x2 tensors) and ``"permutation"``
        (index tensor from ``torch.randperm``).

    Raises:
        AssertionError: If a projection's column count is not divisible by
            the block size.
    """
    if is_cross_attention:
        w_q = layer.q_attn.weight.data
        w_k, w_v = layer.c_attn.weight.data.chunk(2, dim=1)
    else:
        w_q, w_k, w_v = layer.c_attn.weight.data.chunk(3, dim=1)

    block_size = 2
    names = ("q_proj", "k_proj", "v_proj")
    weights = (w_q, w_k, w_v)

    for w in weights:
        assert w.shape[1] % block_size == 0, "num_cols must be divisible by block_size"

    # All blocks are drawn before any permutation (q, k, v order in both
    # phases) so the random stream is consumed in a fixed, documented order.
    # Blocks use each weight's dtype: default-dtype blocks make the matmul
    # fail for non-float32 weights.
    all_blocks = [
        [
            torch.randn((block_size, block_size), device=w.device, dtype=w.dtype)
            for _ in range(w.shape[1] // block_size)
        ]
        for w in weights
    ]
    mixed = [w @ torch.block_diag(*blocks) for w, blocks in zip(weights, all_blocks)]

    # Each permutation is sized from its own projection's column count.
    permutations = [torch.randperm(w.shape[1]) for w in weights]
    ob_w_q, ob_w_k, ob_w_v = [m[:, p] for m, p in zip(mixed, permutations)]

    if is_cross_attention:
        layer.c_attn.weight.data = torch.cat([ob_w_k, ob_w_v], dim=1)
        layer.q_attn.weight.data = ob_w_q
    else:
        layer.c_attn.weight.data = torch.cat([ob_w_q, ob_w_k, ob_w_v], dim=1)

    return {
        name: {"blocks": blocks, "permutation": permutation}
        for name, blocks, permutation in zip(names, all_blocks, permutations)
    }
    
def _obfuscate_model(model):
    """Obfuscate every projection weight of each layer in ``model.model.layers``.

    For each decoder layer, ``_obfuscate_layer`` is applied to
    ``self_attn.{q,k,v,o}_proj`` and then ``mlp.{gate,up,down}_proj``, in that
    order, modifying the weights in place.

    Args:
        model: Model exposing its decoder layers as ``model.model.layers``;
            each layer must provide ``self_attn`` and ``mlp`` sub-modules with
            the projections listed above.

    Returns:
        dict: Maps each projection name (``"q_proj"``, ``"k_proj"``,
        ``"v_proj"``, ``"o_proj"``, ``"gate_proj"``, ``"up_proj"``,
        ``"down_proj"``) to a list with one entry per decoder layer; each
        entry is ``{"blocks": ..., "permutation": ...}`` as returned by
        ``_obfuscate_layer``.
    """
    attn_names = ("q_proj", "k_proj", "v_proj", "o_proj")
    mlp_names = ("gate_proj", "up_proj", "down_proj")

    obf_param = {name: [] for name in attn_names + mlp_names}

    for decoder_layer in model.model.layers:
        # Attention projections first, then the MLP ones, so the random
        # stream is consumed in a fixed order across layers.
        groups = (
            (decoder_layer.self_attn, attn_names),
            (decoder_layer.mlp, mlp_names),
        )
        for parent, names in groups:
            for name in names:
                blocks, permutation = _obfuscate_layer(getattr(parent, name))
                obf_param[name].append({
                    "blocks": blocks,
                    "permutation": permutation
                })
    return obf_param
            

def _obfuscate_model_gpt2(model: GPT2PreTrainedModel):
    """Obfuscate the attention and MLP weights of every block in ``model.transformer.h``.

    For each block, the fused attention projections are handled by
    ``_obfuscate_attn_gpt2``; ``attn.c_proj``, ``mlp.c_fc`` and ``mlp.c_proj``
    are then passed to ``_obfuscate_layer_gpt2``, in that order.

    Args:
        model: Model exposing its blocks as ``model.transformer.h``; each block
            must provide ``attn`` (with ``is_cross_attention`` and ``c_proj``)
            and ``mlp`` (with ``c_fc`` and ``c_proj``).

    Returns:
        dict: Keys ``"q_proj"``, ``"k_proj"``, ``"v_proj"``, ``"o_proj"``,
        ``"c_fc"`` and ``"c_proj"``, each mapping to a list with one entry per
        block. ``"o_proj"`` holds the parameters of ``attn.c_proj`` and
        ``"c_proj"`` those of ``mlp.c_proj``.
    """
    qkv_keys = ("q_proj", "k_proj", "v_proj")
    dense_keys = ("o_proj", "c_fc", "c_proj")
    obf_param = {key: [] for key in qkv_keys + dense_keys}

    for block in model.transformer.h:
        attention, feed_forward = block.attn, block.mlp

        # Attention projections go first, followed by the three dense layers.
        qkv = _obfuscate_attn_gpt2(attention, attention.is_cross_attention)
        dense = {
            "o_proj": _obfuscate_layer_gpt2(attention.c_proj),
            "c_fc": _obfuscate_layer_gpt2(feed_forward.c_fc),
            "c_proj": _obfuscate_layer_gpt2(feed_forward.c_proj),
        }

        for key in qkv_keys:
            obf_param[key].append(qkv[key])
        for key in dense_keys:
            blocks, permutation = dense[key]
            obf_param[key].append({"blocks": blocks, "permutation": permutation})

    return obf_param
        

def obfuscate_model(model):
    """Obfuscate ``model`` in place and return the parameters that were used.

    Instances of ``GPT2PreTrainedModel`` are routed to
    ``_obfuscate_model_gpt2``; every other model goes through
    ``_obfuscate_model``.

    Args:
        model: The model whose weights should be obfuscated.

    Returns:
        The dictionary produced by the selected helper.
    """
    handler = (
        _obfuscate_model_gpt2
        if isinstance(model, GPT2PreTrainedModel)
        else _obfuscate_model
    )
    return handler(model)