import time

import torch
from transformers import LlamaConfig, LlamaModel

class MyNet(torch.nn.Module):
    """Thin wrapper around a small ``LlamaModel`` built from an inline config.

    The backbone is created from a ``LlamaConfig`` (no pretrained weights are
    loaded here) and requests the ``flash_attention_2`` attention
    implementation. Inputs are passed as embeddings rather than token ids.
    """

    def __init__(self):
        """Build the Llama backbone and register it as ``self.model``."""
        super().__init__()

        # NOTE(review): max_position_embeddings (11) is much smaller than the
        # 512-step inputs this script feeds the model -- confirm this is
        # intentional for the installed transformers version.
        backbone_settings = dict(
            hidden_size=512,
            num_hidden_layers=6,
            num_attention_heads=2,
            max_position_embeddings=11,
            intermediate_size=400,
        )

        self.model = LlamaModel._from_config(
            LlamaConfig(**backbone_settings),
            attn_implementation="flash_attention_2",
        )

    def forward(self, x):
        """Run the backbone on ``x``, passed as ``inputs_embeds``.

        Returns whatever the wrapped model returns, unmodified.
        """
        backbone_output = self.model(inputs_embeds=x)
        return backbone_output

# Build the model once and move its parameters to the GPU.
model = MyNet()
model.to('cuda')

# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# scaler = torch.cuda.amp.GradScaler(enabled=True)

# Benchmark dimensions; the throughput figure below is derived from these so
# the tensor shape and the reported rate cannot drift apart.
batch_size, seq_len, hidden_size = 16, 512, 512
warmup_iters, timed_iters = 10, 20

# Random embeddings fed to the model as ``inputs_embeds``.
x = torch.randn(batch_size, seq_len, hidden_size).to('cuda')

with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
    # Warm-up passes: keep one-time costs (kernel selection, allocator growth)
    # out of the timed region.
    for _ in range(warmup_iters):
        output = model(x)

    # CUDA kernels are launched asynchronously, so the host clock must only be
    # read after the device has drained its queue. Without these two
    # synchronisation points the timer measures launch overhead, not compute.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(timed_iters):
        output = model(x)
    torch.cuda.synchronize()
    end = time.perf_counter()

    # Rate is (batch * sequence positions * iterations) per second.
    print(f"FPS: {(batch_size * seq_len * timed_iters) / (end - start)}")

# Scalar stand-in loss over the last forward pass; only useful if the
# commented-out optimizer/scaler steps are re-enabled.
loss = output.last_hidden_state.sum()

# scaler.scale(loss).backward()
# scaler.unscale_(optimizer)
# scaler.step(optimizer)

