import torch
import gc
import os
import sys
from itertools import product
from torch.utils.benchmark import Compare

from flash_attn.modules.mha import MHA
import torch.nn.functional as F


def handle_oom(message):
    """Report an out-of-memory event and try to release memory.

    Prints ``message`` and then, in order, runs the Python garbage
    collector, calls ``torch.cuda.empty_cache()`` and calls
    ``torch.cuda.synchronize()``.

    Args:
        message: Text printed to stdout before the cleanup starts.
    """
    print(message)
    # Order matters: collect dead Python objects first so that the memory
    # they held can be released by the subsequent empty_cache() call.
    cleanup_steps = (gc.collect, torch.cuda.empty_cache, torch.cuda.synchronize)
    for step in cleanup_steps:
        step()


class Object:
    """Minimal attribute container built from keyword arguments.

    Every keyword passed to the constructor is stored in the instance
    dictionary under the same name, e.g. ``Object(a=1).a == 1``.
    """

    def __init__(self, **attributes):
        namespace = vars(self)
        for name, value in attributes.items():
            namespace[name] = value


if __name__ == "__main__":
    # The project modules below are imported from SAFARI_PATH, so it has to be
    # on sys.path before the `src.*` imports run.
    SAFARI_PATH = os.environ.get("SAFARI_PATH", "/mnt/safari-internal/")
    sys.path.append(SAFARI_PATH)

    from src.models.sequence.hyena import HyenaOperator, MultiHeadHyenaOperator
    from src.models.sequence.laughing_hyena import (
        LaughingHyenaOperator,
        LaughingHyenaFilter,
    )
    from src.models.sequence.h3 import H3
    from src.models.sequence.simple_lm import SimpleLMHeadModel
    from src.utils.benchmark import benchmark_direction, benchmark_forward

    # Measurements collected by both benchmark sections; rendered by Compare
    # at the end of the script.
    results = []
    b_size_range = [1]
    seqlen_range = [128]
    d_model_range = [768]
    directions = ["forward"]
    n_heads = 1
    n_tokens_range = [256]
    vocab_size = 50257
    n_layers = 12
    # Number of random token ids per sequence in `input_ids`.
    prompt_len = 256

    # ==========================================================
    # Benchmark Forward, Backward, and Forward + Backward Passes
    # ==========================================================
    for b_size, d_model, seqlen, direction in product(
        b_size_range, d_model_range, seqlen_range, directions
    ):
        sub_label = f"[b_size: {b_size}, d_model: {d_model}, seqlen: {seqlen}]"

        recurrent_hyena = LaughingHyenaOperator(
            d_model=d_model,
            num_heads=n_heads,
            order=2,
            num_blocks=1,
            l_max=seqlen,
            return_state=False,
        )
        # NOTE(review): `num_filters=768` and `bsz=1` are hard-coded rather
        # than derived from d_model / b_size -- confirm this is intended
        # before widening the ranges above.
        hyena_filter = LaughingHyenaFilter(
            num_order=2,
            den_order=2,
            num_filters=768,
            heads=16,
            decay_rate=1e-2,
            real_fft=False,
            train_mixer=True,
        )
        a, b, w = hyena_filter.get_params()
        recurrent_hyena.setup_recurrence(a, b, w, bsz=1)
        # The layer must be moved to device after initializing the recurrence,
        # which sets some new parameters and buffers.
        recurrent_hyena = recurrent_hyena.to("cuda")
        layers = [
            (MHA(d_model, d_model // 64, use_flash_attn=False).to("cuda"), "MHA"),
            (MHA(d_model, d_model // 64, use_flash_attn=True).to("cuda"), "Flash MHA"),
            (
                H3(d_model, head_dim=d_model // 64, use_fast_fftconv=False).to("cuda"),
                "H3",
            ),
            (
                H3(d_model, head_dim=d_model // 64, use_fast_fftconv=True).to("cuda"),
                "H3 Fast FFTConv",
            ),
            (
                MultiHeadHyenaOperator(
                    d_model, l_max=seqlen, num_heads=d_model // 64, fused_bias_fc=False
                ).to("cuda"),
                "Hyena",
            ),
            (
                LaughingHyenaOperator(
                    d_model=d_model,
                    num_heads=n_heads,
                    order=2,
                    num_blocks=1,
                    l_max=seqlen,
                    return_state=False,
                ).to("cuda"),
                "Laughing Hyena (conv)",
            ),
            (recurrent_hyena, "Laughing Hyena (recurrent)"),
        ]

        # Allocate the input directly on the GPU so that `x` is a leaf tensor
        # and an allocation failure surfaces as a CUDA OOM error.
        try:
            x = torch.randn(
                b_size, seqlen, d_model, device="cuda", requires_grad=True
            )
        except torch.cuda.OutOfMemoryError:
            handle_oom(
                f"WARNING! ran out of memory allocating the input, skipping {direction} : {sub_label}"
            )
            # Without an input there is nothing to benchmark for this
            # configuration; do not fall through with an undefined/stale `x`.
            continue

        for layer, layer_name in layers:
            try:
                results.append(
                    benchmark_direction(
                        layer,
                        x,
                        direction=direction,
                        desc=layer_name,
                        label=direction,
                        sub_label=sub_label,
                    )
                )
            except torch.cuda.OutOfMemoryError:
                handle_oom(
                    f"WARNING! ran out of memory, skipping {layer_name}, {direction} : {sub_label}"
                )
            except Exception as e:
                # Best effort: one failing layer must not abort the whole
                # sweep, but say which layer/configuration failed.
                print(
                    f"WARNING! {layer_name} failed, {direction} : {sub_label}: "
                    f"{type(e).__name__}: {e}"
                )

    # ==================================
    # Token Generation
    # ==================================
    b_size_range = [1, 8]
    seqlen_range = [2**exp for exp in range(10, 11)]
    for b_size, d_model, seqlen, n_tokens in product(
        b_size_range, d_model_range, seqlen_range, n_tokens_range
    ):
        sub_label = f"[b_size: {b_size}, d_model: {d_model}, seqlen: {seqlen}, n_tokens: {n_tokens}]"
        models = [
            (
                SimpleLMHeadModel(d_model, n_layers, d_model, vocab_size).to("cuda"),
                "MHA",
            ),
            (
                SimpleLMHeadModel(
                    d_model, n_layers, d_model, vocab_size, kv_caching=True
                ).to("cuda"),
                "MHA (KV)",
            ),
            (
                SimpleLMHeadModel(
                    d_model, n_layers, d_model, vocab_size, layer="h3"
                ).to("cuda"),
                "H3 ",
            ),
            (
                SimpleLMHeadModel(
                    d_model, n_layers, d_model, vocab_size, layer="h3", kv_caching=True
                ).to("cuda"),
                "H3 (KV)",
            ),
            (
                SimpleLMHeadModel(
                    d_model,
                    n_layers,
                    d_model,
                    vocab_size,
                    layer="multihyena",
                    fused_bias_fc=False,
                ).to("cuda"),
                "Hyena",
            ),
            (
                SimpleLMHeadModel(
                    d_model, n_layers, d_model, vocab_size, layer="laughing-hyena"
                ).to("cuda"),
                "Laughing Hyena (conv)",
            ),
            (
                SimpleLMHeadModel(
                    d_model,
                    n_layers,
                    d_model,
                    vocab_size,
                    layer="laughing-hyena",
                    recurrence=True,
                ).to("cuda"),
                "Laughing Hyena (recurrent)",
            ),
        ]
        for model, model_name in models:
            model.eval()
            # Fresh random token ids for every model, created on the GPU.
            try:
                input_ids = torch.randint(
                    vocab_size, (b_size, prompt_len), device="cuda"
                )
            except torch.cuda.OutOfMemoryError:
                handle_oom(
                    f"WARNING! ran out of memory allocating input ids, skipping {model_name} : {sub_label}"
                )
                # Do not benchmark with undefined/stale `input_ids`.
                continue

            try:
                results.append(
                    benchmark_forward(
                        model.generate,
                        input_ids,
                        seqlen,
                        n_tokens,
                        desc=model_name,
                        label="Token Generation",
                        sub_label=sub_label,
                    )
                )
            except torch.cuda.OutOfMemoryError:
                handle_oom(
                    f"WARNING! ran out of memory, skipping {model_name}, forward : {sub_label}"
                )
            except Exception as e:
                # Best effort: keep going with the remaining models, but
                # report which model/configuration failed.
                print(
                    f"WARNING! {model_name} failed, token generation : {sub_label}: "
                    f"{type(e).__name__}: {e}"
                )

    compare = Compare(results)

    # Make sure the output directory exists so a finished benchmark run is not
    # lost to a FileNotFoundError when writing the report.
    results_path = os.path.join(os.getcwd(), "assets/benchmarking/results.txt")
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    # The plain-text table is written before colorize() is enabled.
    with open(results_path, mode="w") as f:
        f.write(str(compare))

    compare.colorize()
    compare.print()
