import copy
import gc
import json
import os
import time

import hydra
import torch
import tqdm
from omegaconf import DictConfig

from src.model import load_model_and_tokenizer

def _timed_forward(model, **forward_kwargs):
    """Run one forward pass and return ``(outputs, elapsed_seconds)``.

    The device is synchronized immediately before the clock starts and again
    before it stops, so the measured interval covers the whole forward call
    rather than just the time needed to enqueue its kernels.
    """
    torch.cuda.synchronize()
    started = time.perf_counter()
    outputs = model(**forward_kwargs)
    torch.cuda.synchronize()
    return outputs, time.perf_counter() - started


@hydra.main(version_base=None, config_path="src/conf", config_name="latency_measure")
def main(cfg: DictConfig):
    """Measure average prefill and per-token decode latency for each batch size.

    Reads ``model_name_or_path``, ``quantizer``, ``batch_sizes``,
    ``prompt_length``, ``prefill_repeat``, ``generation_repeat`` and
    ``generation_length`` from ``cfg``.

    For every batch size, two numbers are printed (in seconds):
      1. the mean time of one prefill forward pass over the synthetic prompt;
      2. the mean time of one single-token decode step that reuses the KV
         cache produced by the prefill.

    Nothing is returned; results are reported via ``print`` only.
    """
    model, tokenizer = load_model_and_tokenizer(cfg.model_name_or_path, cfg.quantizer)
    # Keep every tensor on the model's own device instead of the default
    # 'cuda' device, which may differ when the model lives on another GPU.
    device = model.device

    with torch.no_grad():
        for batch_size in cfg.batch_sizes:
            # Synthetic prompt "t,t,...,t"; every sample in the batch is the
            # same string, so no padding is needed.
            # NOTE(review): position_ids below assumes this tokenizes to
            # exactly cfg.prompt_length tokens -- confirm for the tokenizer used.
            prompt = ('t,' * (cfg.prompt_length // 2))[:-1]
            inputs = tokenizer([prompt] * batch_size, return_tensors="pt").to(device)
            input_ids = inputs['input_ids']
            position_ids = torch.arange(
                start=0, end=cfg.prompt_length, dtype=torch.long, device=device
            ).repeat(batch_size, 1)
            torch.cuda.reset_peak_memory_stats()

            prefill_time = 0
            for _ in tqdm.tqdm(range(cfg.prefill_repeat)):
                outputs, elapsed = _timed_forward(
                    model, input_ids=input_ids, position_ids=position_ids, use_cache=True
                )
                prefill_time += elapsed
            print(prefill_time / cfg.prefill_repeat)

            start_cache = outputs.past_key_values
            # The same token is fed at every decode step (only latency matters,
            # not the generated text), so slice it once.
            last_token = input_ids[:, -1].unsqueeze(1)  # (B, 1)

            decode_time = 0
            for _ in tqdm.tqdm(range(cfg.generation_repeat)):
                # Start every repeat from a private copy of the prefill cache.
                # Cache objects that are updated in place (e.g. Hugging Face
                # DynamicCache) would otherwise keep growing across repeats and
                # later repeats would be timed against a longer context.
                # The copy is made outside the timed region; it temporarily
                # holds a second copy of the cache in memory.
                cache = copy.deepcopy(start_cache)
                for j in range(cfg.generation_length):
                    step_position_ids = torch.full(
                        (batch_size, 1), cfg.prompt_length + j, dtype=torch.long, device=device
                    )  # (B, 1)
                    outputs, elapsed = _timed_forward(
                        model,
                        input_ids=last_token,
                        position_ids=step_position_ids,
                        past_key_values=cache,
                        use_cache=True,
                    )
                    decode_time += elapsed
                    cache = outputs.past_key_values
            print(decode_time / (cfg.generation_repeat * cfg.generation_length))


if __name__ == "__main__":
    # main() reports its measurements via print() and has no return statement,
    # so wrapping the call in print() only emitted a stray "None".
    main()
