from model import GPT, GPTConfig

# Named model-size presets. Each value is a dict of keyword arguments that the
# driver below forwards to GPTConfig (alongside vocab_size).
# Rows are (name, n_layer, n_head, n_embd); dict insertion order follows the
# row order, and n_layer * n_embd**2 grows monotonically down the table, so
# iterating the dict visits presets from smallest to largest.
model_configs = {
    preset_name: dict(n_layer=layers, n_head=heads, n_embd=width)
    for preset_name, layers, heads, width in (
        ("6xs2", 4, 4, 64),
        ("6xs", 3, 4, 80),
        ("5xs", 3, 4, 96),
        ("5xs2", 3, 4, 112),
        ("5xs1", 3, 4, 128),
        ("4xs", 4, 4, 128),
        ("4xs2", 4, 4, 144),
        ("4xs1", 4, 4, 192),
        ("3xs", 5, 4, 192),
        ("3xs2", 5, 4, 224),
        ("3xs1", 5, 4, 256),
        ("xxs", 6, 4, 256),
        ("xxs3", 6, 4, 288),
        ("xxs2", 6, 8, 352),
        ("xxs4", 6, 8, 416),
        ("xxs1", 6, 8, 448),
        ("xs", 7, 8, 448),
        ("xs3", 8, 8, 448),
        ("xs2", 8, 8, 512),
        ("xs1", 9, 8, 512),
        ("s", 10, 8, 512),
        ("s3", 10, 8, 704),
        ("s2", 10, 12, 768),
        ("s1", 11, 12, 768),
        ("base", 12, 12, 768),
        ("m5", 12, 16, 896),
        ("m4", 12, 16, 1024),
        ("m3", 13, 16, 1024),
        ("m2", 14, 16, 1024),
        ("m1", 15, 16, 1024),
        ("m", 16, 16, 1024),
        # Larger presets, currently disabled:
        # ("l", 24, 16, 1536),
        # ("xl", 32, 32, 2048),
        # ("xxl", 40, 32, 3072),
    )
}
if __name__ == "__main__":
    # Build every preset in `model_configs`, record its parameter counts as
    # reported by GPT.get_params (with and without the LM head), print the
    # records and write them to model_size.csv.

    # Imported here (not at module top) so that importing `model_configs`
    # from this file does not require pandas. Done before the loop so a
    # missing dependency fails fast instead of after every model is built.
    import pandas as pd

    # Vocabulary size used for every model; presumably the GPT-2 vocabulary
    # (50257) padded up to a multiple of 128 -- TODO confirm.
    num_vocab = 50304
    result = []
    for model_size, config in model_configs.items():
        model = GPT(GPTConfig(vocab_size=num_vocab, **config))
        result.append(
            {
                "name": model_size,
                "params": model.get_params(include_lm_head=False),
                "params_with_lm_head": model.get_params(include_lm_head=True),
            }
        )
        # Drop the reference now so this model can be released before the
        # next (larger) one is constructed; otherwise the old model stays
        # alive during the next GPT(...) call and peak memory roughly doubles.
        del model

    print(result)

    df = pd.DataFrame(result)
    df.to_csv("model_size.csv", index=False)
