

from transformers.models.llama.modeling_llama import LlamaConfig, LlamaDecoderLayer, LlamaRotaryEmbedding,LlamaRMSNorm,LlamaForCausalLM
import transformers
import tokenizers

from transformers.activations import ACT2FN
from typing import List, Optional, Tuple, Union
from transformers.cache_utils import DynamicCache

import time
import os
import math
import warnings
import shutil
import torch.distributed as dist
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from torch.cuda.amp import autocast
import wandb
import torch.nn as nn
from dataclasses import dataclass, field
import argparse
def parse_args(argv=None):
    """Parse the command-line options of the model-cutting script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is what a plain
            ``parse_args()`` call has always done.

    Returns:
        argparse.Namespace with ``model_path`` (str) and ``num_ee_block``
        (int, defaults to 4).

    Raises:
        SystemExit: raised by argparse when ``--model_path`` is missing or
            ``--num_ee_block`` is not a positive integer.
    """
    parser = argparse.ArgumentParser(
        description='Split the decoder layers of a causal LM into separately saved blocks'
    )

    # Without a path there is nothing to load, so reject the call up front
    # instead of passing None on to the model loader.
    parser.add_argument('--model_path', type=str, required=True,
                        help='Path to model/config file')

    parser.add_argument('--num_ee_block', type=int, default=4,
                        help='Number of blocks to split the decoder layers into')

    args = parser.parse_args(argv)
    # The value is used as a divisor and as a block count by the caller.
    if args.num_ee_block < 1:
        parser.error('--num_ee_block must be a positive integer')
    return args

# Read the checkpoint location from the command line.
args = parse_args()
model_path = args.model_path

device = torch.device('cuda')

# The configuration is loaded on its own first so that it can be adjusted
# before it is handed to the model loader below.
config = transformers.AutoConfig.from_pretrained(model_path)

# When the checkpoint's position limit is below 2048, enable linear RoPE
# scaling with the smallest whole-number factor that reaches 2048.
orig_ctx_len = getattr(config, "max_position_embeddings", None)
if orig_ctx_len and orig_ctx_len < 2048:
    scaling_factor = float(math.ceil(2048 / orig_ctx_len))
    config.rope_scaling = dict(type="linear", factor=scaling_factor)
config.use_cache = False

model = transformers.AutoModelForCausalLM.from_pretrained(model_path, config=config)
print(model)
# Cut the model according to the exit positions: the decoder stack is split
# into `num_ee_block` contiguous blocks, and every block as well as the shared
# modules (rotary embedding, final norm, LM head, token embedding) is written
# to its own file under `save_path`.
num_layers = config.num_hidden_layers
print("***Number of layers in total: ", num_layers, "***")
num_ee_block = args.num_ee_block
# Fail early with a clear message: zero would divide by zero below, a negative
# value would save no decoder block at all, and a value above the layer count
# would produce empty blocks.
if not 1 <= num_ee_block <= num_layers:
    raise ValueError(
        f"num_ee_block must be between 1 and {num_layers}, got {num_ee_block}"
    )
block_len = num_layers // num_ee_block
left = num_layers % num_ee_block
save_path = model_path+f"-{num_ee_block}cut/"
# torch.save does not create missing parent directories, so make sure the
# output directory exists before the first file is written.
os.makedirs(save_path, exist_ok=True)
torch.save(model.model.rotary_emb.state_dict(), save_path+"rotary_emb.pth")
torch.save(model.model.norm.state_dict(), save_path+"norm.pth")
torch.save(model.lm_head.state_dict(), save_path+"lmhead.pth")
torch.save(model.model.embed_tokens.state_dict(), save_path+"embed_tokens.pth")

# Every block starts with `block_len` layers; the `left` remaining layers are
# handed out one each to the trailing blocks.
layer_list = [block_len] * num_ee_block
for i in range(left):
    layer_list[-i-1] += 1
print(layer_list)

# Index (in the full model) of the first layer of the block being processed.
layer_offset = 0
for i, block_size in enumerate(layer_list):
    # Fresh layers, indexed from 0 inside the block, receive a copy of the
    # weights of the corresponding layers of the full model.
    decoder_layers = nn.ModuleList([LlamaDecoderLayer(config, j) for j in range(block_size)])
    for j, layer in enumerate(decoder_layers):
        layer.load_state_dict(model.model.layers[layer_offset + j].state_dict())
    layer_offset += block_size
    torch.save(decoder_layers.state_dict(), save_path+f"decoder_layers{i}.pth")
    print(f"successfully saved decoder_layer{i}", layer_offset)