import torch
import copy
import torch.utils.checkpoint
from torch import nn

from .utils import expand_linear, expand_embedding, expand_norm, init_weight_ex
from .modeling_gpt_neo import (GPTNeoModel, GPTNeoSelfAttention, GPTNeoMLP, GPTNeoBlock,
                               GPTNeoForCausalLM, GPTNeoModel)

# Per-head attention width. Not referenced in the visible part of this module;
# presumably consumed elsewhere in the package -- TODO confirm before removing.
CONSTANT_ATTENTION_HEAD_SIZE = 64
# Value passed as ``val`` to ``set_mask`` / ``set_mask_head`` / ``set_mask_ffn``
# by the grow_* helpers below right after the expanded weights are loaded.
CONSTANT_MIN_MASK_VAL = 0.

def grow_embedding(old_module, new_module, target_dim, args):
    """Load width-expanded embeddings and final layer norm into ``new_module``.

    Expanded versions of ``old_module``'s ``wte``, ``wpe`` and ``ln_f`` are
    produced by the ``expand_embedding`` / ``expand_norm`` helpers and then
    strictly loaded into the same-named sub-modules of ``new_module``, so
    ``new_module`` must already be built with matching parameter shapes.

    Args:
        old_module: source ``GPTNeoModel`` (read only here).
        new_module: destination ``GPTNeoModel``; mutated in place.
        target_dim: hidden size handed to the expand helpers and to
            ``new_module.set_mask``.
        args: opaque options object forwarded unchanged to the expand helpers.

    Side effects on ``new_module``: ``in_growth`` is set to ``True``,
    ``old_hidden_size`` / ``hidden_size`` are updated, and ``set_mask`` is
    called with ``CONSTANT_MIN_MASK_VAL``.
    """
    assert isinstance(old_module, GPTNeoModel) and isinstance(new_module, GPTNeoModel)

    new_module.in_growth = True

    # Build all three expanded modules first, then load them, so nothing in
    # new_module is overwritten unless every expansion succeeded.
    old_cfg = old_module.config
    grown_wte = expand_embedding(old_module.wte, target_dim, old_cfg.vocab_size, args)
    grown_wpe = expand_embedding(old_module.wpe, target_dim, old_cfg.max_position_embeddings, args)
    grown_ln_f = expand_norm(old_module.ln_f, target_dim, old_cfg.layer_norm_epsilon, args)

    load_plan = (
        (new_module.wte, grown_wte),
        (new_module.wpe, grown_wpe),
        (new_module.ln_f, grown_ln_f),
    )
    for destination, expanded in load_plan:
        destination.load_state_dict(expanded.state_dict(), strict=True)

    # Record the width transition before resetting the growth mask.
    new_module.old_hidden_size = old_module.hidden_size
    new_module.hidden_size = target_dim

    new_module.set_mask(target_dim, val=CONSTANT_MIN_MASK_VAL)


def grow_dim_self_att(old_module, new_module, target_size, args):
    """Load width-expanded q/k/v/out projections into ``new_module``.

    Each of ``q_proj``, ``k_proj``, ``v_proj`` and ``out_proj`` of
    ``old_module`` is expanded to a ``target_size`` x ``target_size`` linear
    layer via ``expand_linear`` and strictly loaded into the same-named layer
    of ``new_module``.

    Args:
        old_module: source ``GPTNeoSelfAttention``.
        new_module: destination ``GPTNeoSelfAttention``; mutated in place.
        target_size: new hidden size; must exceed ``old_module.hidden_size``.
        args: opaque options object forwarded to ``expand_linear``.

    Raises:
        AssertionError: on a type mismatch or if ``target_size`` does not
            exceed the old hidden size. Note that ``in_grow_dim`` has already
            been set on ``new_module`` when the size check runs.
    """
    assert isinstance(old_module, GPTNeoSelfAttention) and isinstance(new_module, GPTNeoSelfAttention)
    new_module.in_grow_dim = True
    assert target_size > old_module.hidden_size

    projection_names = ("q_proj", "k_proj", "v_proj", "out_proj")

    # Expand every projection before touching new_module.
    expanded_layers = [
        expand_linear(getattr(old_module, name), target_size, target_size, args)
        for name in projection_names
    ]
    for name, expanded in zip(projection_names, expanded_layers):
        getattr(new_module, name).load_state_dict(expanded.state_dict(), strict=True)

    new_module.old_hidden_size = old_module.hidden_size
    new_module.hidden_size = target_size
    new_module.set_mask(target_size, val=CONSTANT_MIN_MASK_VAL)


def grow_head_num(old_module, new_module, target_head_num, args):
    """Update head-count bookkeeping on ``new_module`` for a head-number growth.

    No weights are copied here; only attributes of ``new_module`` are set and
    its ``set_mask_head`` is called.

    Args:
        old_module: source ``GPTNeoSelfAttention``; only ``num_heads`` is read.
        new_module: destination ``GPTNeoSelfAttention``; mutated in place.
        target_head_num: new total number of heads; must exceed
            ``old_module.num_heads``.
        args: unused in this function; kept for a uniform signature with the
            other grow_* helpers.

    Raises:
        AssertionError: on a type mismatch or if no heads would be added.
            ``in_grow_head`` has already been set when the count check runs.
    """
    assert isinstance(old_module, GPTNeoSelfAttention) and isinstance(new_module, GPTNeoSelfAttention)
    new_module.in_grow_head = True

    added_heads = target_head_num - old_module.num_heads
    assert added_heads > 0

    new_module.num_heads = target_head_num
    new_module.new_head_count = added_heads

    # head_dim is read from new_module, i.e. the per-head width is whatever
    # the destination module was built with.
    per_head_width = new_module.head_dim
    new_module.all_head_size = target_head_num * per_head_width
    new_module.set_mask_head(target_head_num, val=CONSTANT_MIN_MASK_VAL)


def grow_dim_intermediate(old_module, new_module, target_intermediate_size, target_input_size, args):
    """Load expanded ``c_fc`` / ``c_proj`` weights into ``new_module``.

    ``c_fc`` is expanded with ``(target_intermediate_size, target_input_size)``
    and ``c_proj`` with the two sizes swapped, both via ``expand_linear``; the
    results are strictly loaded into ``new_module``.

    Args:
        old_module: source ``GPTNeoMLP``.
        new_module: destination ``GPTNeoMLP``; mutated in place.
        target_intermediate_size: new FFN intermediate size.
        target_input_size: new hidden (input/output) size.
        args: opaque options object forwarded to ``expand_linear``.

    The ``in_grow_ffn`` / ``in_grow_dim`` flags are only ever set to ``True``
    here, never cleared; the bookkeeping at the end is driven by the flag
    values found on ``new_module``, which may therefore also reflect state set
    outside this function.
    """
    assert isinstance(old_module, GPTNeoMLP) and isinstance(new_module, GPTNeoMLP)

    if old_module.intermediate_size < target_intermediate_size:
        new_module.in_grow_ffn = True
    if old_module.hidden_size < target_input_size:
        new_module.in_grow_dim = True

    # Expand both projections first, then load them into the destination.
    expanded_fc = expand_linear(old_module.c_fc, target_intermediate_size, target_input_size, args)
    expanded_proj = expand_linear(old_module.c_proj, target_input_size, target_intermediate_size, args)
    for destination, expanded in ((new_module.c_fc, expanded_fc), (new_module.c_proj, expanded_proj)):
        destination.load_state_dict(expanded.state_dict(), strict=True)

    if new_module.in_grow_ffn:
        new_module.old_intermediate_size = old_module.intermediate_size
        new_module.intermediate_size = target_intermediate_size
        new_module.set_mask_ffn(target_intermediate_size, val=CONSTANT_MIN_MASK_VAL)

    if new_module.in_grow_dim:
        new_module.old_hidden_size = old_module.hidden_size
        new_module.hidden_size = target_input_size
        new_module.set_mask(target_input_size, val=CONSTANT_MIN_MASK_VAL)

def grow_dim_block_ln(old_module, new_module, target_size, args):
    """Load width-expanded ``ln_1`` / ``ln_2`` layer norms into ``new_module``.

    Args:
        old_module: source ``GPTNeoBlock``.
        new_module: destination ``GPTNeoBlock``; mutated in place.
        target_size: new hidden size; must exceed ``old_module.hidden_size``.
        args: opaque options object forwarded to ``expand_norm``.

    Raises:
        AssertionError: on a type mismatch or if ``target_size`` does not
            exceed the old hidden size.
    """
    assert isinstance(old_module, GPTNeoBlock) and isinstance(new_module, GPTNeoBlock)
    assert target_size > old_module.hidden_size

    norm_names = ("ln_1", "ln_2")
    eps = old_module.config.layer_norm_epsilon

    # Expand both norms before loading either into the destination block.
    expanded_norms = [
        expand_norm(getattr(old_module, name), target_size, eps, args)
        for name in norm_names
    ]
    for name, expanded in zip(norm_names, expanded_norms):
        getattr(new_module, name).load_state_dict(expanded.state_dict(), strict=True)

    # NOTE(review): unlike the other grow_dim_* helpers, hidden_size is not
    # updated on new_module here -- confirm whether that is intentional.
    new_module.old_hidden_size = old_module.hidden_size
    new_module.set_mask(target_size, val=CONSTANT_MIN_MASK_VAL)
    new_module.in_grow_dim = True



def grow_dim_lm_head(old_module, new_module, target_hidden_size, args):
    """Load a width-expanded LM head into ``new_module``.

    ``old_module.lm_head`` is expanded via ``expand_linear`` using
    ``new_module.config.vocab_size`` and ``target_hidden_size`` and then
    strictly loaded into ``new_module.lm_head``.

    Args:
        old_module: source ``GPTNeoForCausalLM``.
        new_module: destination ``GPTNeoForCausalLM``; mutated in place.
        target_hidden_size: new hidden size passed to ``expand_linear``.
        args: opaque options object forwarded to ``expand_linear``.
    """
    assert isinstance(old_module, GPTNeoForCausalLM) and isinstance(new_module, GPTNeoForCausalLM)

    expanded_head = expand_linear(
        old_module.lm_head,
        new_module.config.vocab_size,
        target_hidden_size,
        args,
    )
    expanded_state = expanded_head.state_dict()
    new_module.lm_head.load_state_dict(expanded_state, strict=True)

def copy_att_ffn(l1, l2):
    """Copy attention and MLP linear layers from block ``l1`` into block ``l2``.

    The ``q_proj`` / ``k_proj`` / ``v_proj`` / ``out_proj`` layers under
    ``attn.attention`` and the ``c_proj`` / ``c_fc`` layers under ``mlp`` are
    strictly loaded from ``l1`` into ``l2``. Afterwards the biases of
    ``l2``'s ``out_proj``, ``c_proj`` and ``c_fc`` are set to zero.

    Args:
        l1: source block; not modified.
        l2: destination block; its parameters are overwritten in place.

    Raises:
        RuntimeError: from ``load_state_dict`` if the layers' state dicts do
            not match in keys or shapes.
    """
    src_attention, dst_attention = l1.attn.attention, l2.attn.attention
    layer_pairs = (
        (src_attention.q_proj, dst_attention.q_proj),
        (src_attention.k_proj, dst_attention.k_proj),
        (src_attention.v_proj, dst_attention.v_proj),
        (src_attention.out_proj, dst_attention.out_proj),
        (l1.mlp.c_proj, l2.mlp.c_proj),
        (l1.mlp.c_fc, l2.mlp.c_fc),
    )
    # load_state_dict copies values into the destination's existing
    # parameters, so the source tensors are never aliased and no extra deep
    # copy of the state dict is needed.
    for source, destination in layer_pairs:
        destination.load_state_dict(source.state_dict(), strict=True)

    # In-place writes to a leaf parameter that requires grad raise a
    # RuntimeError unless autograd tracking is disabled, so zero the biases
    # under no_grad. Layers built without a bias are skipped.
    with torch.no_grad():
        for layer in (dst_attention.out_proj, l2.mlp.c_proj, l2.mlp.c_fc):
            if layer.bias is not None:
                layer.bias.zero_()

def vanilla_copy(old_module, new_module):
    """Copy every parameter and buffer of ``old_module`` into ``new_module``.

    Both modules must be of exactly the same class; the state dict of
    ``old_module`` is loaded strictly into ``new_module``.

    Raises:
        AssertionError: if the two modules are not of the same type.
        RuntimeError: from ``load_state_dict`` on mismatched keys or shapes.
    """
    old_cls, new_cls = type(old_module), type(new_module)
    assert old_cls is new_cls
    source_state = old_module.state_dict()
    new_module.load_state_dict(source_state, strict=True)