import multiprocessing as mp
import time
import torch
import argparse


def clear_l2(size_mb=128, *, device="cuda"):
    """Evict the GPU L2 cache by overwriting a large scratch buffer.

    Intended to be called before each timed iteration so the measured work
    starts with a cold cache instead of reusing data left behind by the
    previous iteration.

    Args:
        size_mb: Size of the scratch buffer in MiB. It should comfortably
            exceed the device's L2 capacity.
        device: Device on which the scratch buffer is allocated.

    Raises:
        ValueError: If ``size_mb`` is negative.
    """
    if size_mb < 0:
        raise ValueError(f"size_mb must be non-negative, got {size_mb}")
    # A single in-place write touches every byte of the buffer, which is enough
    # to evict L2, and needs only one allocation (randn followed by `a + 1.0`
    # needed two full-size buffers).
    buf = torch.empty(size_mb * 1024 * 1024, dtype=torch.uint8, device=device)
    buf.zero_()
    # Deliberately no torch.cuda.empty_cache(): releasing the caching
    # allocator's free blocks would force the next (timed) iteration to
    # re-allocate its tensors through cudaMalloc, inflating the measurement.
    if buf.is_cuda:
        # Ensure the flush has finished before the caller starts its timer.
        torch.cuda.synchronize(buf.device)
    del buf


def default_input_dim(model_name):
    """Return the default input dimension for ``model_name``.

    Mamba variants default to 928 and KLA variants to 960. Any other name
    falls back to 256.
    """
    if "mamba" in model_name:
        return 928
    if "kla" in model_name:
        return 960
    return 256


def parse_args(argv=None):
    """Parse benchmark command-line arguments.

    ``--input-dim`` defaults to a per-model value (see ``default_input_dim``).
    An explicitly passed value is honored rather than silently overwritten.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace`` with ``input_dim`` always set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--warmup-iter", type=int, default=2, help="Number of warmup iterations"
    )
    parser.add_argument(
        "--test-iter", type=int, default=5, help="Number of test iterations"
    )
    parser.add_argument(
        "--model", type=str, default="mamba_fused", help="Model to benchmark"
    )
    parser.add_argument(
        "--batch-size", type=int, default=4, help="Batch size for benchmarking"
    )
    parser.add_argument(
        "--seq-len", type=int, default=256, help="Sequence length for benchmarking"
    )
    parser.add_argument(
        "--input-dim",
        type=int,
        default=None,
        help=(
            "Input dimension for benchmarking "
            "(default: 928 for mamba models, 960 for kla models, 256 otherwise)"
        ),
    )
    parser.add_argument(
        "--forward-only", action="store_true", help="Whether to benchmark forward only"
    )
    args = parser.parse_args(argv)

    if args.test_iter < 1:
        # The average below divides by test_iter.
        parser.error("--test-iter must be at least 1")
    if args.input_dim is None:
        args.input_dim = default_input_dim(args.model)
    return args


def build_model(name, input_dim):
    """Instantiate the benchmark model selected by ``name``.

    Returns:
        ``(model, qk_dim, v_dim)``. ``qk_dim``/``v_dim`` are the values passed
        to the model constructor, or 0 when the model takes no such argument.

    Raises:
        ValueError: If ``name`` is not a known model.
    """
    # Imports are local so only the selected implementation must be importable.
    if name == "kla_recurrent":
        from kla_recurrent import KLABlock as ModelClass

        return ModelClass(dim=input_dim, d_state=32), 0, 0
    if name == "kla_torch":
        from kla_torch import KLABlock as ModelClass

        return ModelClass(dim=input_dim, d_state=32), 0, 0
    if name == "kla_triton":
        from kla_triton import KLABlock as ModelClass

        return ModelClass(dim=input_dim, d_state=32), 0, 0
    if name == "mamba_fused":
        from mamba_fused import Mamba as ModelClass

        return ModelClass(dim=input_dim), 0, 0
    if name == "mamba_torch":
        from mamba_torch import Mamba as ModelClass

        return ModelClass(d_model=input_dim), 0, 0
    if name in ("mamba_triton_fused", "mamba_triton"):
        if name == "mamba_triton_fused":
            from mamba_triton_fused import Mamba as ModelClass
        else:
            from mamba_triton import Mamba as ModelClass

        qk_dim, v_dim = 16, 2 * input_dim
        model = ModelClass(input_dim=input_dim, qk_dim=qk_dim, v_dim=v_dim)
        return model, qk_dim, v_dim
    raise ValueError(f"Unknown model: {name}")


def time_model(model, input_tensor, *, warmup_iter, test_iter, forward_only):
    """Return the mean wall-clock seconds of one timed benchmark iteration.

    In forward-only mode the model is put in eval mode and run with gradients
    disabled. Otherwise it is in train mode and each iteration is a forward
    pass followed by ``output.sum().backward()``. The L2 flush and
    ``model.zero_grad()`` happen outside the timed region.

    Raises:
        ValueError: If ``test_iter`` is less than 1.
    """
    if test_iter < 1:
        raise ValueError(f"test_iter must be at least 1, got {test_iter}")
    train = not forward_only
    model.train(train)  # train(False) is equivalent to eval()

    def run_step():
        output = model(input_tensor)
        if train:
            output.sum().backward()

    with torch.set_grad_enabled(train):
        # Warm-up runs first so one-time costs (e.g. lazy compilation after
        # model.compile()) are not included in the timed iterations.
        for _ in range(warmup_iter):
            if train:
                model.zero_grad()
            run_step()

        total_time = 0.0
        for _ in range(test_iter):
            clear_l2()
            if train:
                model.zero_grad()
            # Drain any queued GPU work so it is not billed to this iteration.
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            run_step()
            torch.cuda.synchronize()
            total_time += time.perf_counter() - start_time
    return total_time / test_iter


def main(argv=None):
    """Run the benchmark and print one CSV row of results."""
    args = parse_args(argv)
    model, qk_dim, v_dim = build_model(args.model, args.input_dim)
    model.to("cuda")
    model.compile()

    input_tensor = torch.randn(
        args.batch_size, args.seq_len, args.input_dim, device="cuda"
    )
    avg_time = time_model(
        model,
        input_tensor,
        warmup_iter=args.warmup_iter,
        test_iter=args.test_iter,
        forward_only=args.forward_only,
    )
    # CSV columns: model,batch_size,seq_len,input_dim,qk_dim,v_dim,forward_only,time_ms
    print(
        f"{args.model}, {args.batch_size}, {args.seq_len}, {args.input_dim}, "
        f"{qk_dim}, {v_dim}, {args.forward_only}, {avg_time * 1000:.2f}"
    )


if __name__ == "__main__":
    main()
