import gc
import math
import os
from typing import Any, Dict, Literal, Optional, Union

import datasets
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.utils.data import Subset
from transformers import (DataCollatorForSeq2Seq, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments)
from transformers.loss.loss_utils import ForCausalLMLoss
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

import wandb
from hip_attn.v1_3.attention import HiPAttentionArgs, ScanStage
from hip_attn.v1_3.models.llama import (LlamaAttention, LlamaConfig,
                                        LlamaForCausalLM)
from hip_research.dataset.openwebtext import OpenWebTextDataset
from hip_research.dataset.pg19 import PG19Dataset
from hip_research.dataset.pg19_long_qa import PG19LongQA
from hip_research.main.jobs.ppl import PplArgs, job_ppl
from hip_research.main.long_eval_experimental import (get_hip_config,
                                                      init_model, parse_args)
from hip_research.utils.long_train import Config, get_logger

log = get_logger()

TRITON_DEBUG = os.getenv("TRITON_DEBUG", "0") == "1"
WANDB_DISABLED = os.getenv("WANDB_MODE", "none") == "disabled"


def cuda_gc():
    """Flush pending GPU work and release unused memory.

    Runs, in order: a CUDA synchronize, a Python garbage-collection pass,
    and a release of PyTorch's cached CUDA allocator blocks.
    """
    # Order matters: synchronize first so no kernel is still using tensors
    # that the collector is about to drop, then hand the freed blocks back.
    cleanup_steps = (
        torch.cuda.synchronize,
        gc.collect,
        torch.cuda.empty_cache,
    )
    for step in cleanup_steps:
        step()


def _strip_to_attention_only(layer):
    """Replace every sub-module around the attention op with ``nn.Identity``.

    The MLP, both layer norms and the q/k/v/o projections become identities so
    that a forward pass through ``layer`` is dominated by the attention kernel.
    The layer is modified in place, moved to the current CUDA device and
    returned.
    """
    layer.mlp = nn.Identity()
    layer.input_layernorm = nn.Identity()
    layer.post_attention_layernorm = nn.Identity()
    for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
        setattr(layer.self_attn, proj_name, nn.Identity())
    return layer.cuda()


def _time_layer_forward(
    layer, x, position_ids, position_embeddings, n_warmup, n_measured
):
    """Run ``layer`` ``n_warmup + n_measured`` times and time the last runs.

    Returns a list with one CUDA-event elapsed time (milliseconds) per
    measured iteration; warmup iterations are executed but not recorded.
    """
    times = []
    for i in range(n_warmup + n_measured):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        start_event.record()
        layer(
            x,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
        )
        end_event.record()

        # elapsed_time() is only valid once both events have completed.
        torch.cuda.synchronize()

        if i >= n_warmup:
            times.append(start_event.elapsed_time(end_event))
    return times


def evaluate(config, model, tokenizer) -> Dict[str, float]:
    """Benchmark the forward latency of one decoder layer per attention backend.

    The middle layer of ``model.model.layers`` is stripped down to its
    attention op (all other sub-modules become ``nn.Identity``) and run on
    random inputs for context lengths 2**15 .. 2**20. For every
    (context length, attention implementation, postfix) combination the
    backend is selected through the ``ATTENTION_IMPLEMENTATION`` and
    ``USE_ATTN_POSTFIX`` environment variables plus
    ``model.config._attn_implementation``, and the mean latency over the
    measured iterations is printed as a CSV row. The full CSV table is printed
    again at the end.

    Note: ``model`` is modified in place (``no_lm_head`` is set, the middle
    layer is stripped and moved to CUDA) and the two environment variables
    are left at the last benchmarked values.

    Args:
        config: Currently unused by this function.
        model: Model exposing ``model.layers``, ``config._attn_implementation``
            and ``eval()``.
        tokenizer: Currently unused by this function (inputs are random).

    Returns:
        Mapping from ``"<ctx_len>,<attn>,<postfix>"`` to the mean forward
        latency in milliseconds.

    Raises:
        ValueError: If an attention implementation name is not recognized.
    """
    model.no_lm_head = True
    model.eval()

    cuda_gc()

    layer = _strip_to_attention_only(
        model.model.layers[len(model.model.layers) // 2]
    )

    # Collect the sub-modules that were just replaced by identities.
    gc.collect()
    torch.cuda.empty_cache()

    # Hard-coded input sizes: the feature width of the layer input and of the
    # two position-embedding tensors.
    # NOTE(review): presumably hidden size / head dim of the target model --
    # confirm before benchmarking a different architecture.
    hidden_dim = 4096
    pos_emb_dim = 128

    # Untimed iterations so kernel compilation/autotuning is not measured.
    n_warmup = 3

    # TODO:
    # - make inputs with tokenizer (inputs are random tensors for now)

    hip_postfixes = [
        f"recompute_dense-window_{window}-diff_1-w_{w}-decode_dense"
        for window in (2048, 0)
        for w in (64, 128, 256, 512, 1024, 2048)
    ]
    hip_postfixes += [
        f"recompute_dense-window_{window}-diff_1-w_64-decode_dense-JUST_RETURN"
        for window in (0, 2048, 4096)
    ]

    # attention implementation (env var value) -> postfixes to run
    env_vars = {
        "flash_attention_2": ["none"],
        "hip_attention": hip_postfixes,
    }

    outputs = ['ctx,attn,postfix,latency']
    results: Dict[str, float] = {}

    x = position_embeddings = position_ids = None

    with torch.no_grad():
        for ctx_len in [2**i for i in range(15, 21)]:
            # Drop the previous (smaller) inputs before allocating the next
            # ones so both sets are never resident at the same time.
            x = position_embeddings = position_ids = None
            cuda_gc()

            x = torch.randn(
                1, ctx_len, hidden_dim, dtype=torch.bfloat16, device="cuda:0"
            )
            position_embeddings = (
                torch.randn(
                    1, ctx_len, pos_emb_dim, device=x.device, dtype=x.dtype
                ),
                torch.randn(
                    1, ctx_len, pos_emb_dim, device=x.device, dtype=x.dtype
                ),
            )
            position_ids = torch.arange(ctx_len, device=x.device).unsqueeze(0)

            torch.cuda.synchronize()

            for attn, postfixes in env_vars.items():
                for postfix in postfixes:
                    cuda_gc()
                    os.environ["ATTENTION_IMPLEMENTATION"] = attn
                    os.environ["USE_ATTN_POSTFIX"] = postfix

                    prev_attn = model.config._attn_implementation
                    try:
                        if "hip" in attn:
                            model.config._attn_implementation = "recompute_dense"
                        elif "flash_attention_2" in attn:
                            model.config._attn_implementation = "flash_attention_2"
                        elif "minference" in attn:
                            print("minference")
                        else:
                            raise ValueError(
                                f"unknown attention implementation: {attn!r}"
                            )

                        # FlashAttention-2 gets a single measured run; every
                        # other backend gets four.
                        n_measured = 1 if 'flash_attention_2' in attn else 4
                        times = _time_layer_forward(
                            layer,
                            x,
                            position_ids,
                            position_embeddings,
                            n_warmup,
                            n_measured,
                        )
                    finally:
                        # Restore the config even if the forward pass failed
                        # (e.g. out of memory at a long context length).
                        model.config._attn_implementation = prev_attn

                    latency = sum(times) / len(times)
                    key = f'{ctx_len},{attn},{postfix}'
                    line = f'{key},{latency}'
                    print(line)
                    outputs.append(line)
                    results[key] = latency

        print()
        print('-' * 40)
        print()

        print('\n'.join(outputs))

    return results

def main():
    """Script entry point: build the model from CLI args and run the benchmark."""
    config = parse_args()
    model, tokenizer = init_model(config=config)

    # Use the EOS token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    # Move the whole model to CPU before benchmarking; evaluate() moves the
    # single layer it benchmarks to CUDA itself.
    model = model.cpu()
    evaluate(config, model, tokenizer)


if __name__ == "__main__":
    main()
