import types
import torch
import torch.nn as nn

from transformers import (
    GPT2LMHeadModel,
    LlamaForCausalLM,
)
from . import register_model_initialization
from . import model_utils


def freeze_module_params(m: torch.nn.Module):
    """Disable gradient tracking for every parameter of ``m``.

    Args:
        m: Module whose parameters should be frozen. ``None`` is accepted
            and silently ignored.
    """
    if m is None:
        return
    # `parameters` is a method and must be called to obtain the iterator.
    for p in m.parameters():
        p.requires_grad = False


@register_model_initialization("gpt2", "config1")
def gpt2_config1(model: GPT2LMHeadModel):
    """Freeze every parameter except those whose name contains ``c_attn``.

    The model is modified in place: ``model_utils.compute_gpt2`` is bound to
    it as the ``compute`` method, and the same instance is returned.
    """
    for param_name, param in model.named_parameters():
        if "c_attn" in param_name:
            # Parameters with "c_attn" in their name are left untouched.
            continue
        param.requires_grad = False

    bound_compute = types.MethodType(model_utils.compute_gpt2, model)
    model.compute = bound_compute
    return model


@register_model_initialization("gpt2", "config2")
def gpt2_config2(model: GPT2LMHeadModel):
    """Rebuild a GPT-2 model with a 2048-entry position-embedding table.

    Steps:
      1. Set ``n_ctx`` and ``max_position_embeddings`` to 2048 on the model's
         config. The config object is mutated in place.
      2. Build a fresh ``GPT2LMHeadModel`` from that config.
      3. Replace the input model's ``transformer.wpe`` with a larger
         embedding whose leading rows are copied from the existing table.
      4. Load the input model's state dict into the fresh model.
      5. Freeze every parameter whose name does not contain ``c_attn``.
      6. Bind ``model_utils.compute_gpt2_with_positions`` as ``compute``.

    Args:
        model: Source model. Its config and ``transformer.wpe`` are modified.

    Returns:
        The newly constructed model, not the instance passed in.
    """
    target_positions = 2048

    # `config` aliases `model.config`, so these writes are visible on both.
    config = model.config
    config.n_ctx = target_positions
    config.max_position_embeddings = target_positions

    _model = GPT2LMHeadModel(config)

    with torch.no_grad():
        old_weight = model.transformer.wpe.weight
        # Take the row count from the existing table instead of assuming 1024.
        n_pretrained = old_weight.shape[0]
        wpe = nn.Embedding(config.n_ctx, config.n_embd)
        # Rows past `n_pretrained` keep nn.Embedding's default initialization.
        wpe.weight[:n_pretrained] = old_weight
        model.transformer.wpe = wpe

    _model.load_state_dict(model.state_dict())
    model = _model

    for name, p in model.named_parameters():
        # The "wpe" clause is already implied by the first test; it is kept
        # to make explicit that the position table stays frozen.
        if "c_attn" not in name or "wpe" in name:
            p.requires_grad = False

    model.compute = types.MethodType(model_utils.compute_gpt2_with_positions, model)
    return model


@register_model_initialization("meta-llama/Llama-2-7b-hf", "config1")
def llama2_config1(model: LlamaForCausalLM):
    """Freeze every parameter whose name has neither ``q_proj`` nor ``k_proj``.

    The model is modified in place and the same instance is returned.
    """
    for param_name, param in model.named_parameters():
        keep_trainable = "q_proj" in param_name or "k_proj" in param_name
        if not keep_trainable:
            param.requires_grad = False
    return model


@register_model_initialization("meta-llama/Llama-2-7b-hf", "config2")
def llama2_config2(_model: LlamaForCausalLM):
    """Rebuild the model with ``max_position_embeddings`` set to 8192.

    The input model's config is mutated in place and used to construct a new
    ``LlamaForCausalLM``, which then receives the input model's state dict.
    Every parameter whose name has neither ``q_proj`` nor ``k_proj`` is
    frozen, and ``model_utils.compute_llama2_with_positions`` is bound to the
    new model as ``compute``.

    Returns:
        The newly constructed model.
    """
    _model.config.max_position_embeddings = 8192

    rebuilt = LlamaForCausalLM(_model.config)
    rebuilt.load_state_dict(_model.state_dict())

    for param_name, param in rebuilt.named_parameters():
        if "q_proj" in param_name or "k_proj" in param_name:
            continue
        param.requires_grad = False

    rebuilt.compute = types.MethodType(
        model_utils.compute_llama2_with_positions, rebuilt
    )
    return rebuilt

@register_model_initialization("meta-llama/Llama-2-7b-hf", "config4096")
def llama2_config4096(_model: LlamaForCausalLM):
    """Rebuild the model from its existing, unmodified config.

    A new ``LlamaForCausalLM`` is constructed from ``_model.config`` and
    receives the input model's state dict. Every parameter whose name has
    neither ``q_proj`` nor ``k_proj`` is frozen, and
    ``model_utils.compute_llama2_with_positions`` is bound to the new model
    as ``compute``.

    Returns:
        The newly constructed model.
    """
    rebuilt = LlamaForCausalLM(_model.config)
    rebuilt.load_state_dict(_model.state_dict())

    for param_name, param in rebuilt.named_parameters():
        if "q_proj" in param_name or "k_proj" in param_name:
            continue
        param.requires_grad = False

    rebuilt.compute = types.MethodType(
        model_utils.compute_llama2_with_positions, rebuilt
    )
    return rebuilt

@register_model_initialization("meta-llama/Llama-2-7b-hf", "config16384")
def llama2_config16384(_model: LlamaForCausalLM):
    """Rebuild the model with ``max_position_embeddings`` set to 16384.

    The input model's config is mutated in place and used to construct a new
    ``LlamaForCausalLM``, which then receives the input model's state dict.
    Every parameter whose name has neither ``q_proj`` nor ``k_proj`` is
    frozen, and ``model_utils.compute_llama2_with_positions`` is bound to the
    new model as ``compute``.

    Returns:
        The newly constructed model.
    """
    _model.config.max_position_embeddings = 16384

    rebuilt = LlamaForCausalLM(_model.config)
    rebuilt.load_state_dict(_model.state_dict())

    for param_name, param in rebuilt.named_parameters():
        if "q_proj" in param_name or "k_proj" in param_name:
            continue
        param.requires_grad = False

    rebuilt.compute = types.MethodType(
        model_utils.compute_llama2_with_positions, rebuilt
    )
    return rebuilt

@register_model_initialization("meta-llama/Llama-2-7b-hf", "config65536")
def llama2_config65536(_model: LlamaForCausalLM):
    """Rebuild the model with ``max_position_embeddings`` set to 65536.

    The input model's config is mutated in place and used to construct a new
    ``LlamaForCausalLM``, which then receives the input model's state dict.
    Every parameter whose name has neither ``q_proj`` nor ``k_proj`` is
    frozen, and ``model_utils.compute_llama2_with_positions`` is bound to the
    new model as ``compute``.

    Returns:
        The newly constructed model.
    """
    _model.config.max_position_embeddings = 65536

    rebuilt = LlamaForCausalLM(_model.config)
    rebuilt.load_state_dict(_model.state_dict())

    for param_name, param in rebuilt.named_parameters():
        if "q_proj" in param_name or "k_proj" in param_name:
            continue
        param.requires_grad = False

    rebuilt.compute = types.MethodType(
        model_utils.compute_llama2_with_positions, rebuilt
    )
    return rebuilt


# NOTE(review): this function reuses the Python name of the "config65536"
# initializer defined earlier in this module, so this later definition
# rebinds the module-level name. Consider renaming it (for example to
# `llama2_config65536test`) once callers have been checked.
@register_model_initialization("meta-llama/Llama-2-7b-hf", "config65536test")
def llama2_config65536(_model: LlamaForCausalLM):
    """Rebuild the model with 65536 positions and the test ``compute`` method.

    The input model's config is mutated in place
    (``max_position_embeddings = 65536``) and used to construct a new
    ``LlamaForCausalLM``, which then receives the input model's state dict.
    Every parameter whose name has neither ``q_proj`` nor ``k_proj`` is
    frozen, and ``model_utils.compute_llama2_test`` is bound to the new model
    as ``compute``.

    Returns:
        The newly constructed model.
    """
    _model.config.max_position_embeddings = 65536

    rebuilt = LlamaForCausalLM(_model.config)
    rebuilt.load_state_dict(_model.state_dict())

    for param_name, param in rebuilt.named_parameters():
        if "q_proj" in param_name or "k_proj" in param_name:
            continue
        param.requires_grad = False

    rebuilt.compute = types.MethodType(model_utils.compute_llama2_test, rebuilt)
    return rebuilt


