import os
import pathlib
import warnings
from typing import Literal

import torch
import transformers
from peft import LoraConfig, PeftModel, TaskType, get_peft_model

from hip_attn.models.gemma.modeling_gemma2 import Gemma2Config, Gemma2ForCausalLM
from hip_attn.models.modeling_llama import LlamaConfig, LlamaForCausalLM
from hip_attn.models.qwen.modeling_qwen2 import Qwen2Config, Qwen2ForCausalLM
from hip_attn.v1_3.attention import HiPAttentionArgs, ScanStage
from hip_research.main.eval_args import ArgsType, eval_args
from hip_research.main.jobs.bench_single_layer import job_bench_single_layer
from hip_research.main.jobs.booksum import job_booksum
from hip_research.main.jobs.ga import job_ga
from hip_research.main.jobs.greedy_replace import job_greedy_replace
from hip_research.main.jobs.merge_lora import job_merge_lora
from hip_research.main.jobs.mmlu import job_mmlu
from hip_research.main.jobs.passkey import job_passkey
from hip_research.main.jobs.ppl import job_ppl
from hip_research.main.jobs.sample_diag import job_sample_diag
from hip_research.main.jobs.stream import job_stream
from hip_research.main.jobs.stream_demo import job_stream_demo
from hip_research.models.sglang_model import SglangModel
from hip_research.utils.seed import seed

# Short model aliases -> model ids/paths passed to ``from_pretrained``.
# ``load_model`` and ``load_model_hip13`` look up ``args.model`` here and fall
# back to using ``args.model`` verbatim. ``load_vllm_model`` and
# ``load_sglang_model`` look it up after removing the "vllm_" / "sglang_" part.
MODELS = {
    "llama1b": "princeton-nlp/Sheared-LLaMA-1.3B",
    "llama3b": "princeton-nlp/Sheared-LLaMA-2.7B",
    "llama7b": "meta-llama/Llama-2-7b-chat-hf",
    "llama32k": "togethercomputer/LLaMA-2-7B-32K",
    "llama32k_instruct": "togethercomputer/Llama-2-7B-32K-Instruct",
    "llama13b": "meta-llama/Llama-2-13b-hf",
    "llama13b_32k": "Yukang/Llama-2-13b-longlora-32k-ft",
    "llama13b_32k_instruct": "Yukang/Llama-2-13b-chat-longlora-32k-sft",
    "llama3_8b_1m": "gradientai/Llama-3-8B-Instruct-Gradient-1048k",
    "llama3.1_8b": "meta-llama/Meta-Llama-3.1-8B",
    "llama3.1_8b_instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "llama3.2_1b": "meta-llama/Llama-3.2-1B",
    "llama3.2_3b": "meta-llama/Llama-3.2-3B",
    "llama3.2_3b_instruct": "meta-llama/Llama-3.2-3B-Instruct",
    "llama3.2_1b_instruct": "meta-llama/Llama-3.2-1B-Instruct",
    "qwen14b": "Qwen/Qwen1.5-14B-Chat",
    "qwen7b": "Qwen/Qwen1.5-7B-Chat",
    "qwen1.5b": "Qwen/Qwen1.5-1.8B-Chat",
    "qwen0.5b": "Qwen/Qwen1.5-0.5B-Chat",
    "qwen2.5_1.5b_instruct": "Qwen/Qwen2.5-1.5B-Instruct",
    "qwen2.5_3b_instruct": "Qwen/Qwen2.5-3B-Instruct",
    "qwen2.5_7b_instruct": "Qwen/Qwen2.5-7B-Instruct",
    "gemma2_2b": "google/gemma-2-2b",
    "gemma2_9b": "google/gemma-2-9b",
    "gemma2_2b_it": "google/gemma-2-2b-it",
    "gemma2_9b_it": "google/gemma-2-9b-it",
    "exaone3.5_7.8b_instruct": "LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct",
}

# Legacy "vllm_"-prefixed aliases, used only by ``load_vllm_model`` and only
# when the alias with "vllm_" removed is not found in ``MODELS``.
# Entries starting with "./" are local filesystem paths.
OBSOLATED_VLLM_MODELS = {
    "vllm_llama32k": "togethercomputer/LLaMA-2-7B-32K",
    "vllm_llama32k_instruct": "togethercomputer/Llama-2-7B-32K-Instruct",
    "vllm_llama128k": "NousResearch/Yarn-Llama-2-7b-128k",
    "vllm_llama13b_128k": "NousResearch/Yarn-Llama-2-13b-128k",
    "vllm_llama13b_32k": "Yukang/Llama-2-13b-longlora-32k-ft",
    "vllm_llama13b_32k_instruct": "Yukang/Llama-2-13b-chat-longlora-32k-sft",
    "vllm_llama100k": "Yukang/Llama-2-7b-longlora-100k-ft",
    "vllm_llama1b": "princeton-nlp/Sheared-LLaMA-1.3B",
    "vllm_llama7b": "meta-llama/Llama-2-7b-hf",
    "vllm_llama13b": "meta-llama/Llama-2-13b-hf",
    # 'vllm_qwen14b': 'Qwen/Qwen1.5-14B-Chat-GPTQ-Int4',
    "vllm_qwen14b_local": "./Qwen1.5-14B-Chat-GPTQ-Int4",
    "vllm_qwen14b_int8_local": "./Qwen1.5-14B-Chat-GPTQ-Int8",
    "vllm_qwen14b_noquant_local": "./Qwen1.5-14B-Chat",
    "vllm_qwen7b": "Qwen/Qwen1.5-7B-Chat-GPTQ-Int4",
    "vllm_qwen7b_pt": "Qwen/Qwen1.5-7B",
    "vllm_qwen14b": "Qwen/Qwen1.5-14B-Chat",
    "vllm_qwen14b_gptq": "Qwen/Qwen1.5-14B-Chat-GPTQ-Int4",
    "vllm_qwen0.5b": "Qwen/Qwen1.5-0.5B-Chat",
    "vllm_pythia70m": "EleutherAI/pythia-70m",
    "vllm_yi6b": "01-ai/Yi-6B-200K",
    "vllm_yi34b": "brucethemoose/Yi-34B-200K-RPMerge",
    "vllm_mixtral8x7b": "TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ",
    "vllm_gemma2b": "google/gemma-2b-it",
    "vllm_gemma7b": "google/gemma-7b-it",
    "vllm_luxia21.4b": "saltlux/luxia-21.4b-alignment-v1.1",
    "vllm_llama3_8b": "unsloth/llama-3-8b-Instruct",
    "vllm_yi1.5_9b_32k": "01-ai/Yi-1.5-9B-32K",
    "vllm_llama3.1_8b_instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "vllm_llama3.1_8b_instruct_awq": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4",
}


def load_vllm_model(args: ArgsType):
    """Build a vLLM engine and its tokenizer for ``args.model``.

    Model id resolution, in order:
      1. ``MODELS`` keyed by ``args.model`` with "vllm_" removed;
      2. ``OBSOLATED_VLLM_MODELS`` keyed by the raw ``args.model``;
      3. ``args.model`` with "vllm_" removed, used verbatim.

    ``args.stride`` (must be positive) sets the max model length, the CUDA
    graph capture length and the max batched tokens. ``args.checkpoint`` must
    be ``None``. Engine options read from the environment: ``KV_CACHE_DTYPE``
    (default "fp8_e5m2"), ``MEM_UTIL`` (default "0.9") and ``ENFORCE_EAGER``
    ("1" enables it). Tensor parallelism spans every visible CUDA device.
    A warning is issued when ``HIP_K`` (default "512") differs from ``args.k``.

    Returns:
        ``(llm, tokenizer, "cuda:0")``. If the tokenizer lacks a BOS or EOS
        token, the missing one is copied from the other.
    """
    from vllm import LLM

    env_hip_k = os.getenv("HIP_K", "512")
    if int(env_hip_k) != args.k:
        warnings.warn(
            f'WARN!!! your command line argument of hip_k is {args.k} but environment variable is {env_hip_k}. OS environment is higher priority.'
        )

    device = "cuda:0"
    short_name = args.model.replace("vllm_", "")
    if short_name in MODELS:
        model_id = MODELS[short_name]
    else:
        model_id = OBSOLATED_VLLM_MODELS.get(args.model, short_name)
    print(f"Loading model {model_id}")

    assert args.checkpoint is None

    max_len = args.stride
    assert max_len > 0
    llm = LLM(
        model_id,
        max_num_seqs=args.batch_size,
        max_seq_len_to_capture=max_len,
        max_model_len=max_len,
        swap_space=0,
        kv_cache_dtype=os.getenv("KV_CACHE_DTYPE", "fp8_e5m2"),
        dtype="half",
        gpu_memory_utilization=float(os.getenv("MEM_UTIL", "0.9")),
        tensor_parallel_size=torch.cuda.device_count(),
        enforce_eager=os.environ.get("ENFORCE_EAGER", "0") == "1",
        trust_remote_code=True,
        max_num_batched_tokens=max_len,
        enable_chunked_prefill=False,
    )

    tok = transformers.AutoTokenizer.from_pretrained(model_id)
    # Fill in whichever of BOS/EOS is missing from the other one.
    if tok.bos_token_id is None:
        tok.bos_token, tok.bos_token_id = tok.eos_token, tok.eos_token_id
    if tok.eos_token_id is None:
        tok.eos_token, tok.eos_token_id = tok.bos_token, tok.bos_token_id

    return llm, tok, device


def load_sglang_model(args: ArgsType):
    """Create an ``SglangModel`` bound to ``args.endpoint`` plus its tokenizer.

    ``args.model`` with "sglang_" removed must be a key of ``MODELS``;
    otherwise an ``AssertionError`` lists the available aliases. If the
    tokenizer lacks a BOS or EOS token, the missing one is copied from the other.

    Returns:
        ``(sglang_model, tokenizer, torch.device("cpu"))``.
    """
    alias = args.model.replace("sglang_", "")
    assert alias in MODELS, f"Available Models: {list(MODELS.keys())}"

    tok = transformers.AutoTokenizer.from_pretrained(MODELS[alias])
    if tok.bos_token_id is None:
        tok.bos_token, tok.bos_token_id = tok.eos_token, tok.eos_token_id
    if tok.eos_token_id is None:
        tok.eos_token, tok.eos_token_id = tok.bos_token, tok.bos_token_id

    return SglangModel(args.endpoint, tok), tok, torch.device("cpu")


class _SimpleTokenPooler(torch.nn.Module):
    """Parameter-free token pooler: returns ``x.mean(dim=dim)``.

    Installed as ``token_pooler_fn`` on every HiP attention module configured
    by ``load_model``. ``i_stage`` is accepted for interface compatibility and
    is unused.
    """

    def forward(
        self,
        x: torch.Tensor,
        dim: int,
        i_stage: int,
        tensor_type: Literal["query", "key", "value"],
    ):
        assert tensor_type in ["query", "key", "value"]
        return x.mean(dim=dim)


class _SimpleOutputUnpooler(torch.nn.Module):
    """Repeat each element ``rate`` times along ``pooling_dim`` (default 2)."""

    def __init__(self, dim=2):
        super().__init__()
        self.pooling_dim = dim

    def forward(self, x: torch.Tensor, rate: int):
        return x.repeat_interleave(rate, dim=self.pooling_dim)


def _strip_model_prefix(state_dict):
    """Return a new dict whose keys have a single leading "model." removed.

    Keys without that prefix are kept unchanged.
    """
    # NOTE: str.strip("model.") would remove any of the characters
    # "m o d e l ." from BOTH ends (e.g. "model.layers.0" -> "ayers.0"),
    # so remove the prefix explicitly instead.
    prefix = "model."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_lora_checkpoint(model, args):
    """Attach LoRA adapters to ``model`` and load ``args.checkpoint``; return the new model.

    A directory is loaded as a PEFT checkpoint via ``PeftModel.from_pretrained``.
    Any other path is read with ``torch.load``. If the result has a
    "state_dict" entry, that entry is used. Leading "model." prefixes are then
    removed from the keys, and the weights are loaded non-strictly into a
    freshly created LoRA model. A load failure raises no error; it is reported
    with ``warnings.warn``.
    """
    if pathlib.Path(args.checkpoint).is_dir():
        print(f"Loading peft model from {args.checkpoint}")
        return PeftModel.from_pretrained(model, args.checkpoint)

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=True,
        r=args.lora_r,
        lora_alpha=args.lora_r // 2,
        lora_dropout=0.0,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        modules_to_save=[
            "tree_avgpool_scaler",
            "input_layernorm",
            "post_attention_layernorm",
        ],
    )
    model = get_peft_model(model, peft_config)

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(args.checkpoint, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    state_dict = _strip_model_prefix(state_dict)
    try:
        result = model.load_state_dict(state_dict, strict=False)
        print("load result", result)
    except RuntimeError as e:
        # strict=False still raises on e.g. shape mismatches. Keep the load
        # best-effort, but never fail silently.
        warnings.warn(f"Failed to load checkpoint {args.checkpoint}: {e}")
    return model


def load_model(args):
    """Load a causal LM and its tokenizer, configured for ``args.method``.

    ``args.model`` values starting with "vllm" or "sglang" are delegated to
    ``load_vllm_model`` / ``load_sglang_model``. Otherwise:

    * ``args.model`` is resolved through ``MODELS``, falling back to the raw
      value.
    * The model is loaded on "cuda:0" using the project's Qwen2, Gemma2 or
      Llama implementation, chosen from the alias (or the H2O / stock
      transformers Llama class for the "h2o"/"h2o_stream"/"tova" methods).
    * Weights are 4-bit NF4 unless ``args.no_quantize``. The dtype is bf16
      when supported, else fp16, or fp32 when ``FORCE_FP32=1``.
    * Every module exposing ``attention_method`` gets ``HiPAttentionArgs``.
      The first three such modules use 'fa2' instead of ``args.method``.
    * If ``args.method != "none"`` and ``args.checkpoint`` is set, LoRA
      weights are loaded from the checkpoint.

    Returns:
        ``(model, tokenizer, device)`` with the model in eval mode.
    """
    if args.model.startswith("vllm"):
        return load_vllm_model(args)
    if args.model.startswith("sglang"):
        return load_sglang_model(args)

    device = "cuda:0"
    model_id = MODELS.get(args.model, args.model)

    if args.model.startswith("qwen"):
        config = Qwen2Config.from_pretrained(model_id)
        ModelClass = Qwen2ForCausalLM
    elif "gemma" in args.model:
        config = Gemma2Config.from_pretrained(model_id)
        ModelClass = Gemma2ForCausalLM
    else:
        config = LlamaConfig.from_pretrained(model_id)
        ModelClass = LlamaForCausalLM
    config._attn_implementation = config.attn_implementation = "sdpa"

    print(f"Loading model {model_id} {ModelClass} {type(config)}")

    infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    if os.getenv("FORCE_FP32", "0") == "1":
        infer_dtype = torch.float32

    if args.method in ["h2o", "h2o_stream"]:
        from hip_research.models.h2o.h2o_llama import H2OLlamaForCausalLM

        ModelClass = H2OLlamaForCausalLM

        if args.method == "h2o_stream":
            args.h2o_streaming = True
        config.attention_method = args.method
        config.hh_size = args.k // 2
        config.recent_size = args.k // 2
        config._attn_implementation = config.attn_implementation = "eager"
        config.h2o_shift_q_pos = args.h2o_shift_q_pos
        config.h2o_streaming = args.h2o_streaming
        config.reduction_for_gqa = args.h2o_reduce_for_gqa
        config.tree_dense_layers = list(range(args.dense_layers))
        config.tree_k = args.k
        config.is_decoding = args.job in ["stream", "passkey"]  # TODO ga?

    if args.method == "tova":
        from transformers.models.llama.modeling_llama import (
            LlamaForCausalLM as OriginalLlamaForCausalLM,
        )

        ModelClass = OriginalLlamaForCausalLM

    model = ModelClass.from_pretrained(
        model_id,
        config=config,
        device_map={"": device},
        quantization_config=(
            transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                llm_int8_skip_modules=[
                    "tree_avgpool_scaler",
                    "lm_head",
                ],
                bnb_4bit_compute_dtype=infer_dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            if not args.no_quantize
            else None
        ),
        torch_dtype=infer_dtype,
        trust_remote_code=True,
    )

    if args.method == "tova":
        from hip_research.models.tova.convert_tova import enable_tova_caching

        enable_tova_caching(model)

    layer_idx = 0
    for m in model.modules():
        if not hasattr(m, "attention_method"):
            continue

        token_pooler_fn = _SimpleTokenPooler()
        output_unpooler_fn = _SimpleOutputUnpooler()
        # NOTE: [BSZ, N_Q, HEAD, N_GATES]. All gates start near -100 and the
        # last one is set to +100, so after sigmoid the last gate is ~1 and
        # the others ~0.
        random_gate_probs = (
            torch.randn(
                (1, 1, 1, 3 + 1),
                device=device,
                dtype=torch.bfloat16,
            )
            - 100.0
        )
        random_gate_probs[..., -1] = 100

        m.args = args
        m.hip_attn_args = HiPAttentionArgs(
            block_size_k=64,
            sliding_window_size=1024,
            sink_token_size=256,
            using_extend=True,
            need_apply_rope=True,
            second_stage_k=2048,
            stages=[
                ScanStage(
                    stage_block_size_q=64,
                    stage_block_stride_q=4,
                    stage_chunk_size=128,
                    stage_k=None,
                    stage_stride=1,
                ),
                ScanStage(
                    stage_block_size_q=64,
                    stage_block_stride_q=4,
                    stage_chunk_size=32,
                    stage_k=32768,
                    stage_stride=1,
                ),
                ScanStage(
                    stage_block_size_q=64,
                    stage_block_stride_q=1,
                    stage_chunk_size=8,
                    stage_k=8192,
                    stage_stride=1,
                ),
            ],
            block_sparse_block_size_q=64,
            model_context_length=131072,
            scan_extend_backend=("streaming" if layer_idx < 3 else "relative"),
            sa_extend_backend="streaming",
            require_stage_caches=False,
            require_cache_statistics=False,
            enable_hip_tune=False,
            token_pooler_fn=token_pooler_fn,
            output_unpooler_fn=output_unpooler_fn,
            gate_probs=random_gate_probs.sigmoid().to(torch.bfloat16),
        )
        layer_idx += 1
        # The first three attention modules use 'fa2' instead of args.method.
        m.attention_method = args.method if layer_idx > 3 else 'fa2'
        m.tree_dense_layers = list(range(len(model.model.layers)))

        # TODO: decide whether the legacy tree_* attributes (tree_k,
        # tree_block_size_q/k, tree_rope_method, tree_enable_sparq, ...) are
        # still needed; most appear to be covered by HiPAttentionArgs/stages.

    if args.method != "none" and args.checkpoint is not None:
        model = _load_lora_checkpoint(model, args)
        print("lora checkpoint loaded from", args.checkpoint)
    elif args.method != "none":
        for m in model.modules():
            if hasattr(m, "attention_method"):
                m.tree_using_context_avg = False

    model = model.eval()

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

    return model, tokenizer, device


def load_model_hip13(args: ArgsType):
    """Load a Llama model wired for HiP attention v1.3, plus its tokenizer.

    ``args.model`` is resolved through ``MODELS``, falling back to the raw
    value. The model is loaded in bfloat16, quantized to 4-bit NF4 unless
    ``args.no_quantize``. ``args.pooler_method`` / ``args.pooler_config`` are
    copied onto the config, and every ``LlamaAttention`` module receives
    ``args`` and a ``HiPAttentionArgs``. When ``args.lora_r > 0``, LoRA
    adapters are attached and ``args.checkpoint`` (if given) is loaded as the
    "default" adapter.

    Returns:
        ``(model, tokenizer, "cuda:0")`` with the model in eval mode and
        moved to "cuda:0".
    """
    device = 'cuda:0'

    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    from hip_attn.v1_3.models.llama import (
        LlamaAttention,
        LlamaForCausalLM,
        LlamaConfig,
    )

    # Register an identity placeholder so that "hip_attention" is a known
    # attn_implementation name. Presumably the hip_attn LlamaAttention does not
    # dispatch through this entry -- TODO confirm.
    ALL_ATTENTION_FUNCTIONS.update({"hip_attention": (lambda x: x)})

    model_id = MODELS.get(args.model, args.model)

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
    model_config = LlamaConfig.from_pretrained(
        model_id,
        attn_implementation="hip_attention",
    )
    model_config.pooler_method = args.pooler_method
    model_config.pooler_config = args.pooler_config
    infer_dtype = torch.bfloat16
    model = LlamaForCausalLM.from_pretrained(
        model_id,
        config=model_config,
        quantization_config=(
            transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                llm_int8_skip_modules=[
                    "tree_avgpool_scaler",
                    "lm_head",
                ],
                bnb_4bit_compute_dtype=infer_dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            if not args.no_quantize
            else None
        ),
        torch_dtype=infer_dtype,
    )

    layer_idx = 0
    for m in model.modules():
        if not isinstance(m, LlamaAttention):
            continue
        m.args = args

        # Built fresh for every layer so that no ScanStage object is shared.
        stages = [
            ScanStage(
                stage_block_size_q=64,
                stage_block_stride_q=4,
                stage_chunk_size=128,
                stage_k=None,
                stage_stride=1,
            ),
            ScanStage(
                stage_block_size_q=64,
                stage_block_stride_q=4,
                stage_chunk_size=32,
                stage_k=32768,
                stage_stride=1,
            ),
            ScanStage(
                stage_block_size_q=64,
                stage_block_stride_q=4,
                stage_chunk_size=8,
                stage_k=8192,
                stage_stride=1,
            ),
        ]

        m.hip_attn_args = HiPAttentionArgs(
            sliding_window_size=1024,
            sink_token_size=256,
            using_extend=True,
            need_apply_rope=True,
            second_stage_k=2048,
            stages=stages,
            model_context_length=65536,
            # The first three attention layers use the "streaming" extend backend.
            scan_extend_backend=("streaming" if layer_idx < 3 else "relative"),
            sa_extend_backend="streaming",
            block_sparse_block_size_q=stages[-1].stage_block_size_q,
            enable_hip_tune=False,
            disable_gate_prob=args.checkpoint is None,
            hip_tune_first_stage_apply_rope=False,
        )

        layer_idx += 1

    if args.lora_r > 0:
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=True,
            r=args.lora_r,
            lora_alpha=args.lora_r // 2,
            lora_dropout=0.15,
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
            ],
            modules_to_save=[
                "token_pooler_fn",
                "output_unpooler_fn",
                "gate_prob_estimater",
                "input_layernorm",
                "post_attention_layernorm",
            ],
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

        if args.checkpoint is not None:
            model.load_adapter(args.checkpoint, adapter_name='default')
            print('checkpoint loaded', args.checkpoint)

    # get_peft_model may add modules (e.g. LoRA dropout with p=0.15) in
    # training mode. Switch everything to eval mode, matching load_model.
    model.eval()
    model.to(device)

    return model, tokenizer, device


# Maps ``args.job`` to its entry point. ``main`` calls the selected entry as
# ``job(args, model, tokenizer, device)``.
JOBS = {
    "ppl": job_ppl,
    "stream": job_stream,
    "mmlu": job_mmlu,
    "bench_single_layer": job_bench_single_layer,
    "booksum": job_booksum,
    "merge_lora": job_merge_lora,
    "stream_demo": job_stream_demo,
    "greedy_replace": job_greedy_replace,
    "passkey": job_passkey,
    "ga": job_ga,
    "sample_diag": job_sample_diag,
}


def main():
    """CLI entry point: parse args, seed RNGs, load the model, and run the job.

    ``args.method == 'hip13'`` selects ``load_model_hip13``; any other method
    uses ``load_model``. ``args.job`` must be a key of ``JOBS``.
    """
    args = eval_args()
    seed(seed=args.seed)

    assert args.job in JOBS.keys()

    loader = load_model_hip13 if args.method == 'hip13' else load_model
    model, tokenizer, device = loader(args)

    run_job = JOBS[args.job]
    run_job(args, model, tokenizer, device)


if __name__ == "__main__":
    main()
