import argparse
import torch
import torch.nn as nn
from transformers import (
    GPT2Tokenizer, GPT2LMHeadModel,
    AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, Qwen2ForCausalLM
)
import os
import copy


def count_parameters(model):
    """Return the total number of scalar elements in ``model.parameters()``.

    Args:
        model: an ``nn.Module`` whose parameter tensors are counted.

    Returns:
        int: sum of ``numel()`` over every tensor yielded by
        ``model.parameters()`` (0 for a module without parameters).
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total


def build_model(args):
    """
    Build a new model from a pretrained one with a learngene method.

    A subset of the pretrained transformer blocks (the "gene" layers) is kept
    and combined with ``args.extra_layers`` newly initialized blocks:

    * ``auto-learngene``: the first ``auto_layers`` pretrained blocks,
      followed by the new blocks.
    * ``van-learngene``: the new blocks, followed by the last ``van_layers``
      pretrained blocks.

    New blocks are deep copies of pretrained block 0 re-initialized with
    ``model._init_weights``. After assembly, every submodule ``layer_idx``
    is renumbered to the block's new position. The config's layer count
    (and ``layer_types``, when present) is updated to match the new stack.

    Args:
        args: parsed CLI namespace; reads ``tokenizer_path``, ``model_path``,
            ``model_name``, ``method`` and ``extra_layers``.

    Returns:
        tuple: ``(model, tokenizer)`` -- the modified model and the tokenizer
        loaded from ``args.tokenizer_path``.

    Raises:
        ValueError: if ``args.model_name`` or ``args.method`` is unsupported.
    """
    # ----------------
    # 1. Load the pretrained model & tokenizer
    # ----------------
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    model = AutoModelForCausalLM.from_pretrained(args.model_path)

    # ----------------
    # 2. Per-architecture gene-layer counts and the block container
    # ----------------
    model_name = args.model_name.lower()
    if model_name == "gpt2-xl":
        auto_layers, van_layers = 3, 3
        layer_attr = model.transformer.h
    elif model_name == "llama3-8b":
        auto_layers, van_layers = 1, 1
        layer_attr = model.model.layers
    elif model_name == "qwen3-4b":
        auto_layers, van_layers = 5, 3
        layer_attr = model.model.layers
    else:
        raise ValueError(f"Unsupported model_name: {args.model_name}")
    # Depth comes from the loaded checkpoint instead of a hard-coded constant.
    orig_num_layers = len(layer_attr)

    # ----------------
    # 3. Pick the gene layers. The index ranges reproduce the slices
    #    ``layer_attr[:auto_layers]`` / ``layer_attr[-van_layers:]``.
    # ----------------
    if args.method == "auto-learngene":
        gene_indices = list(range(min(auto_layers, orig_num_layers)))
    elif args.method == "van-learngene":
        gene_indices = list(range(max(orig_num_layers - van_layers, 0),
                                  orig_num_layers))
    else:
        raise ValueError(f"Unsupported method: {args.method}")
    gene_layers = [layer_attr[i] for i in gene_indices]

    # ----------------
    # 4. Freshly initialized blocks: clone block 0 for its structure, then
    #    re-initialize its weights. The loop is empty when extra_layers <= 0.
    # ----------------
    new_layers = []
    for _ in range(args.extra_layers):
        new_block = copy.deepcopy(layer_attr[0])
        new_block.apply(model._init_weights)
        new_layers.append(new_block)

    # ----------------
    # 5. Assemble the new stack. source_indices[i] is the pretrained position
    #    that final block i came from; cloned blocks come from position 0.
    # ----------------
    new_sources = [0] * len(new_layers)
    if args.method == "auto-learngene":
        final_layers = gene_layers + new_layers
        source_indices = gene_indices + new_sources
    else:  # van-learngene
        final_layers = new_layers + gene_layers
        source_indices = new_sources + gene_indices

    # Blocks still carry the layer_idx they were built with; realign them.
    _renumber_layer_indices(final_layers)

    # Swap in the new stack and keep the config consistent with it.
    if model_name == "gpt2-xl":
        model.transformer.h = nn.ModuleList(final_layers)
        model.config.n_layer = len(final_layers)
    else:  # llama3 / qwen
        model.model.layers = nn.ModuleList(final_layers)
        model.config.num_hidden_layers = len(final_layers)
    _remap_layer_types(model.config, source_indices, orig_num_layers)

    # ----------------
    # 6. Report the parameter count
    # ----------------
    total_params = count_parameters(model)
    print(f"New model built with method={args.method}, "
          f"extra_layers={args.extra_layers}")
    print(f"Total parameters: {total_params / 1e6:.2f}M")

    return model, tokenizer


def _renumber_layer_indices(layers):
    """Set ``layer_idx`` to ``i`` on every submodule of ``layers[i]`` that has one."""
    # Cloned blocks all carry layer_idx 0, and van-learngene blocks keep their
    # old (high) positions. In Hugging Face models this attribute is presumably
    # used to address per-layer state such as the KV cache (verify for the
    # installed transformers version), so it must match the new position.
    for new_idx, block in enumerate(layers):
        for module in block.modules():
            if hasattr(module, "layer_idx"):
                module.layer_idx = new_idx


def _remap_layer_types(config, source_indices, orig_num_layers):
    """Reorder ``config.layer_types`` (if present) to follow the new layer stack."""
    # NOTE(review): some recent configs (e.g. Qwen3) keep one ``layer_types``
    # entry per layer. Left untouched, it would no longer match
    # ``num_hidden_layers``. Only remap a per-layer list of the original depth.
    layer_types = getattr(config, "layer_types", None)
    if layer_types is None or len(layer_types) != orig_num_layers:
        return
    config.layer_types = [layer_types[i] for i in source_indices]


def main():
    """Command-line entry point.

    Parses the CLI arguments, builds the learngene model with
    ``build_model``, and writes the model and tokenizer with
    ``save_pretrained`` into ``<save_path>/<model_name>-<method>``.
    """
    parser = argparse.ArgumentParser(description="Build LLM with learngene methods")
    add = parser.add_argument  # shorthand for the option declarations below
    add("--model_name", type=str, required=True,
        choices=["gpt2-xl", "llama3-8b", "qwen3-4b"])
    add("--model_path", type=str, required=True,
        help="Path to pretrained model weights")
    add("--tokenizer_path", type=str, required=True,
        help="Path to tokenizer")
    add("--method", type=str, required=True,
        choices=["auto-learngene", "van-learngene"])
    add("--extra_layers", type=int, default=0,
        help="Number of randomly initialized layers to add")
    add("--save_path", type=str, required=True,
        help="Where to save the new model")
    args = parser.parse_args()

    model, tokenizer = build_model(args)

    # Output goes into a per-(model_name, method) sub-directory of --save_path.
    args.save_path = os.path.join(args.save_path, f"{args.model_name}-{args.method}")
    out_dir = args.save_path

    # ----------------
    # 7. Save the model, then the tokenizer
    # ----------------
    for artifact in (model, tokenizer):
        artifact.save_pretrained(out_dir)
    print(f"New model saved at {out_dir}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
