# Full-scale value of the layer mask: grow_ops.set_grow divides it by the
# number of growth steps to obtain the per-step layer-mask increment.
MAX_MASK_VAL = 1.0
import torch
import copy
import sys
sys.path.append("..")
from accelerate.utils.other import extract_model_from_parallel
from .copy_weight import (grow_embedding, grow_dim_self_att, grow_head_num, grow_dim_block_ln,
                          grow_dim_intermediate, grow_block, stack_block, grow_dim_lm_head, vanilla_copy, copy_att_ffn)
from .utils import transfer_states, count_parameters

class grow_ops(object):
    """Bookkeeping and control for growing a small model into a larger one.

    The class copies weights from a source model into a target model
    (``set_grow``), transfers optimizer state (``grow_opt``), advances the
    growth masks step by step (``increase_mask``) and finally switches the
    masks off (``end_grow``).

    All methods expect models that expose ``model.transformer.h`` (a
    sequence of blocks) and ``model.transformer.wte``; blocks are accessed
    through ``block.attn.attention`` and ``block.mlp``.
    """

    # Names of the per-module mask attributes that must follow the model's device.
    _MASK_ATTRS = ("grow_mask_head", "grow_mask_vec", "grow_mask_vec_ffn")

    def __init__(self, model):
        """Record the model config and reset all growth state.

        Args:
            model: The (possibly parallel-wrapped) model; only ``model.config``
                is read.
        """
        model = extract_model_from_parallel(model)
        self.config = model.config
        self.available_to_grow = True
        self.curr_mode = None

        # Current mask size passed to the ``set_mask*`` methods.
        self.step_size = 0

        self.temp_mask = 0
        # Current layer-mask value and its per-step increment.
        self.layer_mask = 0
        self.layer_mask_step = 0

    @staticmethod
    def _block_index(param_name):
        """Return the block index encoded in a parameter name, or ``None``.

        A name belongs to block ``i`` when it contains the path components
        ``h`` followed by ``i`` (e.g. ``transformer.h.3.mlp.c_fc.weight``).
        Matching whole components avoids confusing ``h.1`` with ``h.10`` or
        the letter "h" in words such as ``weight``.
        """
        parts = param_name.split(".")
        for pos in range(len(parts) - 1):
            if parts[pos] == "h" and parts[pos + 1].isdigit():
                return int(parts[pos + 1])
        return None

    def mask_to_gpu(self, model):
        """Move every non-``None`` growth mask to the embedding's device.

        The target device is the one holding ``model.transformer.wte.weight``
        (despite the method name, this is not necessarily a GPU).
        """
        model = extract_model_from_parallel(model)
        device = model.transformer.wte.weight.device
        for module in model.modules():
            for attr in self._MASK_ATTRS:
                mask = getattr(module, attr, None)
                if mask is not None:
                    setattr(module, attr, mask.to(device))

    def set_grow(self, m1, m2, mode, target, steps, args):
        """Start a growth phase from source model ``m1`` to target model ``m2``.

        ``m1`` is frozen, ``m2`` is initialised according to ``mode`` and made
        trainable, and the instance is marked as busy growing.

        Args:
            m1: Source model (may be parallel-wrapped).
            m2: Target model (may be parallel-wrapped). Source block ``i`` is
                mapped onto target block ``2 * i``.
            mode: ``"hidden_size"``, ``"heads"`` or ``"layers"``. Any other
                value performs no weight copy but still updates the state.
            target: Forwarded unchanged to the ``grow_*`` helpers.
            steps: Number of steps over which the layer mask is ramped; the
                per-step increment is ``MAX_MASK_VAL / steps``.
            args: Forwarded unchanged to the ``grow_*`` helpers.

        Raises:
            ZeroDivisionError: If ``steps`` is zero.
        """
        m1 = extract_model_from_parallel(m1)
        # Unwrap the target too, so a wrapped m2 exposes ``.transformer``
        # (consistent with grow_opt).
        m2 = extract_model_from_parallel(m2)

        m1.requires_grad_(False)

        # NOTE(review): hard-coded starting mask size; presumably the source
        # model's hidden size -- confirm before reusing with other models.
        self.step_size = 768
        self.layer_mask_step = MAX_MASK_VAL / steps

        if mode == "hidden_size":
            grow_embedding(m1.transformer, m2.transformer, target, args)
            for idx, src_block in enumerate(m1.transformer.h):
                dst_block = m2.transformer.h[2 * idx]
                grow_dim_self_att(src_block.attn.attention, dst_block.attn.attention, target, args)
                grow_dim_intermediate(src_block.mlp, dst_block.mlp,
                                      dst_block.mlp.intermediate_size, target, args)
                grow_dim_block_ln(src_block, dst_block, target, args)
                # Even target blocks receive the source weights; the odd
                # block that follows is flagged as newly inserted.
                dst_block.is_new_block = False
                m2.transformer.h[2 * idx + 1].is_new_block = True
            grow_dim_lm_head(m1, m2, target, args)

        elif mode == "heads":
            for idx, src_block in enumerate(m1.transformer.h):
                grow_head_num(src_block.attn.attention,
                              m2.transformer.h[2 * idx].attn.attention, target, args)

        elif mode == "layers":
            m2.transformer.in_growth_layer = True
            m2.transformer.set_mask_layer(0)
            # Keep the tracked value in sync with the mask just pushed to the
            # model, so a repeated growth phase ramps up from zero again.
            self.layer_mask = 0
            for idx, block in enumerate(m2.transformer.h):
                if idx % 2 == 0:
                    block.is_new_block = False
                    continue
                # Odd blocks are initialised from the preceding block and
                # have both of their layer norms zeroed.
                with torch.no_grad():
                    copy_att_ffn(m2.transformer.h[idx - 1], block)
                    block.ln_1.weight[:] = 0
                    block.ln_1.bias[:] = 0
                    block.ln_2.weight[:] = 0
                    block.ln_2.bias[:] = 0
                block.is_new_block = True

        m2.requires_grad_(True)
        self.available_to_grow = False
        self.curr_mode = mode

    def grow_opt(self, m1, m2, opt1, opt2, args):
        """Transfer optimizer state from ``m1``/``opt1`` to ``m2``/``opt2``.

        Parameters outside the block stack are paired positionally. Inside
        the stack, the parameters of source block ``i`` are paired
        positionally with those of target block ``2 * i``. Each pair is handed
        to ``transfer_states`` as two ``[name, parameter]`` lists.

        Args:
            m1: Source model (may be parallel-wrapped).
            m2: Target model (may be parallel-wrapped).
            opt1: Optimizer of the source model.
            opt2: Optimizer of the target model.
            args: Unused; kept for interface compatibility.
        """
        m1 = extract_model_from_parallel(m1)
        m2 = extract_model_from_parallel(m2)

        def split_params(model):
            """Split named parameters into non-block ones and per-block lists."""
            outside, per_block = [], {}
            for name, param in model.named_parameters():
                block_idx = self._block_index(name)
                if block_idx is None:
                    outside.append([name, param])
                else:
                    per_block.setdefault(block_idx, []).append([name, param])
            return outside, per_block

        outside_1, blocks_1 = split_params(m1)
        outside_2, blocks_2 = split_params(m2)

        assert len(outside_1) == len(outside_2)
        for p1, p2 in zip(outside_1, outside_2):
            transfer_states(p1, p2, opt1, opt2)

        # Transfer every source block, not just the last one.
        for idx in range(len(m1.transformer.h)):
            for p1, p2 in zip(blocks_1.get(idx, []), blocks_2.get(2 * idx, [])):
                transfer_states(p1, p2, opt1, opt2)

    def end_grow(self, model):
        """Finish the growth phase: clear all masks and growth flags.

        Blocks flagged ``is_new_block`` only have that flag cleared; every
        other block additionally has its growth flags switched off and its
        masks reset via the ``set_mask*`` methods.
        """
        model = extract_model_from_parallel(model)
        # NOTE(review): hard-coded final mask size; presumably the target
        # hidden size -- confirm before reusing with other models.
        self.step_size = 1024
        self.temp_mask = 0.

        transformer = model.transformer
        transformer.in_growth = False
        transformer.set_mask(None, self.step_size)
        transformer.in_growth_layer = True
        transformer.set_mask_layer(1)

        for block in transformer.h:
            if block.is_new_block:
                block.is_new_block = False
                continue
            attention = block.attn.attention
            mlp = block.mlp
            attention.in_grow_dim = False
            attention.set_mask(None, self.step_size)
            mlp.in_grow_dim = False
            mlp.set_mask(None, self.step_size)
            mlp.set_mask_ffn(None, self.step_size)
            block.in_grow_dim = False
            block.set_mask(None, self.step_size)
            mlp.in_grow_fnn = False
            attention.in_grow_head = False
            attention.set_mask_head(None, self.step_size)

            block.is_new_block = False

        transformer.in_growth_layer = False
        transformer.set_mask_layer(1)

        self.available_to_grow = True
        self.curr_mode = None

    def increase_mask(self, model, per_step_grow_dim):
        """Advance the growth masks by one step.

        ``step_size`` grows by ``per_step_grow_dim`` and the layer mask by
        ``layer_mask_step``; both new values are pushed to the transformer,
        and the new ``step_size`` to every block not flagged ``is_new_block``.
        """
        model = extract_model_from_parallel(model)
        self.step_size += per_step_grow_dim
        new_val = self.step_size
        model.transformer.set_mask(None, new_val)

        self.layer_mask += self.layer_mask_step
        model.transformer.set_mask_layer(self.layer_mask)

        for block in model.transformer.h:
            if block.is_new_block:
                continue
            block.attn.attention.set_mask(None, new_val)
            block.mlp.set_mask(None, new_val)
            block.set_mask(None, new_val)
            block.mlp.set_mask_ffn(None, new_val)
            block.attn.attention.set_mask_head(None, new_val)

    def print_all_masks(self, model):
        """Print the transformer mask and the per-block mask sums.

        Nothing is printed unless the embedding weight lives on ``cuda:0``.
        Masks that are ``None`` are reported as ``None``.
        """
        def mask_sum(tensor):
            return tensor.sum().item() if tensor is not None else None

        model = extract_model_from_parallel(model)
        if model.transformer.wte.weight.device == torch.device(type="cuda", index=0):
            top_mask = model.transformer.grow_mask_vec
            all_masks = [[top_mask, top_mask.size() if top_mask is not None else None]]
            for layer in model.transformer.h:
                all_masks.append([mask_sum(layer.attn.attention.grow_mask_head),
                                  mask_sum(layer.attn.attention.grow_mask_vec),
                                  mask_sum(layer.mlp.grow_mask_vec_ffn),
                                  mask_sum(layer.mlp.grow_mask_vec),
                                  mask_sum(layer.grow_mask_vec)])
            print(all_masks)

    def count_parameters(self, model):
        """Return the result of the module-level ``count_parameters(model)``."""
        return count_parameters(model)