import gc
import json
import torch
import time
import tqdm
import hydra
import os
from omegaconf import DictConfig

from src.model import load_model_and_tokenizer

@hydra.main(version_base=None, config_path="src/conf", config_name="spd_mem_bench_config")
def main(cfg: DictConfig) -> None:
    """Benchmark generation throughput and peak GPU memory per batch size.

    For every entry in ``cfg.batch_size`` a batch of identical synthetic
    prompts is tokenized, ``model.generate`` is called ``cfg.num_repeats``
    times, and the average wall time per call plus the peak CUDA memory
    allocated are recorded. Results are printed and written as JSON into
    ``cfg.output_dir``.

    Config fields read: ``model_name_or_path``, ``quantizer`` (including
    ``quantizer.save_postfix``), ``prompt_length``, ``output_length``,
    ``num_repeats``, ``batch_size`` (iterable) and ``output_dir``.

    A batch size whose measurement raises (e.g. out of memory) is reported
    and left out of the results instead of aborting the whole run.
    """
    model, tokenizer = load_model_and_tokenizer(cfg.model_name_or_path, cfg.quantizer)

    prompt_length = cfg.prompt_length
    output_length = cfg.output_length
    num_repeats = cfg.num_repeats

    # The prompt is the same for every sample and batch size, so build it once.
    # 't,' repeated, minus the trailing comma; the real tokenized length is
    # printed below from input_ids (presumably close to prompt_length -- this
    # depends on the tokenizer).
    prompt = ('t,' * (prompt_length // 2))[:-1]

    throughputs = []
    mems = []
    # Keyed by batch size at measurement time, so a skipped batch size can
    # never shift later measurements onto the wrong key.
    results = {}
    for batch_size in cfg.batch_size:
        gc.collect()
        torch.cuda.empty_cache()
        context = [prompt] * batch_size
        inputs = tokenizer(context, return_tensors="pt").to('cuda')
        input_ids = inputs['input_ids']
        print(f"bs: {batch_size}, seqlen: {input_ids.shape[1]}+{output_length}\nmodel:{cfg.model_name_or_path}")
        torch.cuda.reset_peak_memory_stats()
        try:
            with torch.no_grad():
                gc.collect()
                # Synchronize on both sides of the timed region so that queued
                # asynchronous CUDA work is included in the measurement.
                torch.cuda.synchronize()
                st = time.perf_counter()
                for _ in tqdm.tqdm(range(num_repeats)):
                    model.generate(**inputs, max_new_tokens=output_length)
                torch.cuda.synchronize()
                # Average time of one generate() call, in milliseconds.
                used_time = (time.perf_counter() - st) / num_repeats * 1000
            used_mem = torch.cuda.max_memory_allocated() / 1024 ** 3
        except Exception as exc:
            # Best-effort benchmark: report the failure and move on to the
            # next batch size rather than swallowing it silently.
            print(f'bs: {batch_size} failed, skipping: {exc!r}')
            continue

        print(f'used time: {used_time} ms')
        print(f'peak mem: {used_mem} GB')
        # Tokens per second, assuming every call generates output_length
        # tokens per sample (generation stopping early would overstate this).
        throughput = (batch_size * output_length) / (used_time / 1000)
        throughputs.append(throughput)
        mems.append(used_mem)
        results[batch_size] = {
            "throughput": throughput,
            "peak_memory": used_mem
        }
        torch.cuda.empty_cache()

    os.makedirs(cfg.output_dir, exist_ok=True)
    model_tag = cfg.model_name_or_path.split('/')[-1]
    result_path = os.path.join(
        cfg.output_dir,
        f"results_{model_tag}_{cfg.quantizer.save_postfix}.txt",
    )
    with open(result_path, "w") as f:
        json.dump(results, f)

    print(throughputs, mems)


if __name__ == "__main__":
    # main() has no return statement, so wrapping the call in print() only
    # emitted a stray "None" after the benchmark output.
    main()
