"""
Script to benchmark latency and throughput of generation workloads for different language models.
"""

import sys
import gc
import os
from itertools import product

import torch
from torch.utils.benchmark import Compare

if __name__ == "__main__":
    # The project sources live outside this script's directory; make the
    # `src.*` packages importable before pulling anything from them.
    SAFARI_PATH = os.environ.get("SAFARI_PATH", "/mnt/safari-internal/")
    sys.path.append(SAFARI_PATH)

    from src.models.sequence.simple_lm import SimpleLMHeadModel
    from src.utils.benchmark import benchmark_direction, benchmark_forward
    from src.utils.profiling import handle_oom

    # Measurements collected across the sweep; rendered by `Compare` at the end.
    results = []

    # ----- Sweep configuration -----
    d_model_range = [768]
    directions = ["forward"]  # not referenced by the token-generation sweep below
    n_heads = 1  # not referenced by the token-generation sweep below
    n_tokens_range = [256]
    vocab_size = 50257
    n_layers = 12
    # Length of the random prompt passed to `model.generate`.
    prompt_len = 256

    # ==================================
    # Token Generation
    # ==================================
    b_size_range = [2**x for x in range(0, 9)]
    seqlen_range = [2**x for x in range(10, 11)]
    for b_size, d_model, seqlen, n_tokens in product(
        b_size_range, d_model_range, seqlen_range, n_tokens_range
    ):
        # Models are rebuilt for every configuration. Uncomment entries to
        # include them in the comparison.
        models = [
            # (
            #     SimpleLMHeadModel(d_model, n_layers, d_model, vocab_size).to("cuda"),
            #     "MHA",
            # ),
            # (
            #     SimpleLMHeadModel(
            #         d_model, n_layers, d_model, vocab_size, kv_caching=True
            #     ).to("cuda"),
            #     "MHA (KV)",
            # ),
            # (
            #     SimpleLMHeadModel(
            #         d_model, n_layers, d_model, vocab_size, layer="laughing-hyena"
            #     ).to("cuda"),
            #     "Laughing Hyena (conv)",
            # ),
            (
                SimpleLMHeadModel(
                    d_model,
                    n_layers,
                    d_model,
                    vocab_size,
                    layer="laughing-hyena",
                    recurrence=True,
                ).to("cuda"),
                "Laughing Hyena (recurrent)",
            ),
        ]
        for model, model_name in models:
            print(f"Running {model_name} with batch size {b_size}")
            sub_label = f"[b_size: {b_size}, d_model: {d_model}, seqlen: {seqlen}, n_tokens: {n_tokens}]"

            model.eval()
            if model_name == "Laughing Hyena (recurrent)":
                filter_cfg = {
                    "num_order": 2,
                    "den_order": 2,
                    "num_filters": 768,
                    "heads": 16,
                    "decay_rate": 1e-2,
                    "real_fft": False,
                    "train_mixer": True,
                }

                model.setup_recurrence(b_size, filter_cfg)

            try:
                input_ids = torch.randint(vocab_size, (b_size, prompt_len)).to("cuda")
            except torch.cuda.OutOfMemoryError:
                gc.collect()
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                # Without a prompt for this batch size there is nothing valid to
                # benchmark. Falling through would either hit an undefined
                # `input_ids` or silently reuse the previous (smaller) batch and
                # record its timing under the current label.
                print(
                    f"WARNING! Ran out of memory allocating the prompt, skipping {model_name} {sub_label}"
                )
                continue

            try:
                results.append(
                    benchmark_forward(
                        model.generate,
                        input_ids,
                        seqlen,
                        n_tokens,
                        desc=model_name,
                        repeats=1,
                        label="Token Generation",
                        sub_label=sub_label,
                    )
                )
            except torch.cuda.OutOfMemoryError:
                handle_oom(
                    f"WARNING! Ran out of memory, skipping {model_name}, foward : [b_size: {b_size}, d_model: {d_model}, seqlen: {seqlen}]"
                )
            except Exception as e:
                # Best effort: one failing configuration must not abort the
                # sweep, but report which one failed and why.
                print(f"ERROR! {model_name} {sub_label} failed: {e!r}")

    compare = Compare(results)

    # Create the output directory up front so a missing folder cannot discard
    # the results of the whole sweep.
    out_dir = os.path.join(os.getcwd(), "assets/benchmarking")
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "results_minimal.txt"), mode="w") as f:
        f.write(str(compare))

    compare.colorize()
    compare.print()
