from datetime import datetime
import torch 
import copy

# Device used by the "gpt2" branches of the ablation helpers in this module;
# hard-coded to the first CUDA GPU.
device = torch.device("cuda:0")


def timestamp(label):
    """Print ``label`` prefixed with the current local time as ``HH:MM:SS``.

    Args:
        label: Message to print. Anything that is not already a ``str`` is
            converted with ``str()`` first.

    Returns:
        None. The line is written to stdout and flushed immediately.
    """
    clock = datetime.now().strftime("%H:%M:%S")
    text = label if isinstance(label, str) else str(label)
    print(clock + ": " + text, flush=True)


def get_remove_mlp(model, model_name):
    """Return a deep copy of ``model`` whose MLP projection weights are zero.

    The input model is never modified. The weights of the copy are zeroed in
    place, so each parameter keeps its original dtype, device and
    ``requires_grad`` flag (allocating fresh ``torch.zeros`` tensors would
    silently force float32/CPU and break half-precision or GPU models).

    Only ``.weight`` is zeroed; any bias on these projections is left as is.
    NOTE(review): confirm that keeping the biases is intended.

    Args:
        model: Model to ablate. For a "gpt2*" name it must expose
            ``transformer.h[i].mlp.{c_fc, c_proj}``; for "gemma-7b" or
            "Llama-2-7b-hf" it must expose
            ``model.layers[i].mlp.{gate_proj, up_proj, down_proj}``.
        model_name: Name used to select the layer layout.

    Returns:
        The ablated copy. For a "gpt2*" name the copy is also moved to the
        module-level ``device``. For any other unrecognised name the copy is
        returned unchanged.
    """
    cloned_model = copy.deepcopy(model)

    if model_name.startswith("gpt2"):
        # no_grad is required to modify leaf parameters in place.
        with torch.no_grad():
            for layer in cloned_model.transformer.h:
                for proj in (layer.mlp.c_fc, layer.mlp.c_proj):
                    proj.weight.zero_()
        cloned_model.to(device)

    elif model_name in ("gemma-7b", "Llama-2-7b-hf"):
        with torch.no_grad():
            for layer in cloned_model.model.layers:
                for proj in (
                    layer.mlp.gate_proj,
                    layer.mlp.up_proj,
                    layer.mlp.down_proj,
                ):
                    proj.weight.zero_()

    return cloned_model



def get_remove_mha(model, model_name):
    """Return a deep copy of ``model`` whose attention projection weights are zero.

    The input model is never modified. The weights of the copy are zeroed in
    place, so each parameter keeps its original dtype, device and
    ``requires_grad`` flag (allocating fresh ``torch.zeros`` tensors would
    silently force float32/CPU and break half-precision or GPU models).

    Only ``.weight`` is zeroed; any bias on these projections is left as is.
    NOTE(review): confirm that keeping the biases is intended.

    Args:
        model: Model to ablate. For a "gpt2*" name it must expose
            ``transformer.h[i].attn.{c_attn, c_proj}``; for "gemma-7b" or
            "Llama-2-7b-hf" it must expose
            ``model.layers[i].self_attn.{q_proj, k_proj, v_proj, o_proj}``.
        model_name: Name used to select the layer layout.

    Returns:
        The ablated copy. For a "gpt2*" name the copy is also moved to the
        module-level ``device``. For any other unrecognised name the copy is
        returned unchanged.
    """
    cloned_model = copy.deepcopy(model)

    if model_name.startswith("gpt2"):
        # no_grad is required to modify leaf parameters in place.
        with torch.no_grad():
            for layer in cloned_model.transformer.h:
                for proj in (layer.attn.c_attn, layer.attn.c_proj):
                    proj.weight.zero_()
        cloned_model.to(device)

    elif model_name in ("gemma-7b", "Llama-2-7b-hf"):
        with torch.no_grad():
            for layer in cloned_model.model.layers:
                attn = layer.self_attn
                for proj in (attn.q_proj, attn.k_proj, attn.v_proj, attn.o_proj):
                    proj.weight.zero_()

    return cloned_model


def get_remove_ln(model, model_name):
    """Return a deep copy of ``model`` with per-layer norms replaced by identity.

    The input model is left untouched. Only the two normalisation modules on
    each layer are swapped for ``torch.nn.Identity``; nothing else on the
    copy is changed.

    Args:
        model: Model to ablate. For a "gpt2*" name it must expose
            ``transformer.h``; for "gemma-7b" or "Llama-2-7b-hf" it must
            expose ``model.layers``.
        model_name: Name used to select the layer layout.

    Returns:
        The ablated copy. For a "gpt2*" name the copy is also moved to the
        module-level ``device``. For any other unrecognised name the copy is
        returned unchanged.
    """
    ablated = copy.deepcopy(model)

    if model_name[:4] == "gpt2":
        for block in ablated.transformer.h:
            for norm_name in ("ln_1", "ln_2"):
                setattr(block, norm_name, torch.nn.Identity())
        ablated.to(device)
    elif model_name in ("gemma-7b", "Llama-2-7b-hf"):
        for block in ablated.model.layers:
            for norm_name in ("input_layernorm", "post_attention_layernorm"):
                setattr(block, norm_name, torch.nn.Identity())

    return ablated
