import os
import subprocess
import time
import torch
from torch.nn import functional as F
from transformers import (
    LlamaTokenizerFast,
    LlamaTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

# Set before any tokenizer is created.
# NOTE(review): presumably silences the Hugging Face tokenizers fork/parallelism
# warning - confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Backend selection: prefer Apple MPS, then CUDA, otherwise fall back to the CPU.
# Import `device` from this module so every caller targets the same backend.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available() else "cpu"
)


def _mps_supports_bf16():
    """Return True when sysctl reports hw.optional.arm.FEAT_BF16 as 1."""
    try:
        # Query the single key directly (argv list, no shell) instead of
        # dumping every sysctl entry and grepping through it.
        probe = subprocess.run(
            ["sysctl", "-n", "hw.optional.arm.FEAT_BF16"],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except OSError:
        # sysctl could not be executed: treat the feature as unavailable.
        return False
    # An unknown key produces no stdout, which also falls through to False.
    return probe.stdout.decode("utf-8").strip() == "1"


# README: adjust the dtype selection below to match your hardware.
# - NVIDIA GPUs from the Ampere microarchitecture onwards support BF16
#   (better precision than IEEE FP16).
# - Apple Silicon from M2 onwards also supports BF16 (CPU and GPU).
# - Do not attempt FP16 on CPU: it is not supported for GEMM.
if device.type == "cuda":
    # A compute-capability major version of 8 or higher selects BF16.
    if torch.cuda.get_device_capability(0)[0] >= 8:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
elif device.type == "mps":
    dtype = torch.bfloat16 if _mps_supports_bf16() else torch.float16
else:
    # Default to FP32 on CPU, because PyTorch doesn't support HGEMM on any
    # CPU architecture.
    dtype = torch.float32


def get_gpu_name():
    """Return a short label for the active accelerator.

    On CUDA this is the name reported for device 0 with spaces replaced by
    underscores. On MPS it is the fixed placeholder ``"m1_gpu"``. On any
    other backend the function returns ``None``.
    """
    backend = device.type
    if backend == "mps":
        # TODO: Systematically get the device name
        return "m1_gpu"
    if backend == "cuda":
        return "_".join(torch.cuda.get_device_name(0).split(" "))
    return None


def touch():
    """Synchronize the active backend so that timing measurements are accurate.

    Calls the MPS or CUDA synchronize routine depending on ``device.type``;
    does nothing on any other backend.
    """
    backend = device.type
    if backend == "cuda":
        torch.cuda.synchronize()
        return
    if backend == "mps":
        torch.mps.synchronize()


def torch_timer():
    """Return a ``time.perf_counter()`` timestamp taken after a backend sync.

    The MPS/CUDA synchronization happens first (via ``touch``) and the clock
    is read afterwards.
    """
    touch()
    return time.perf_counter()


def warmup():
    """Warm up the device with ten 2048x2048 matrix multiplications.

    Both operands are random tensors created on ``device`` with ``dtype``.
    The backend is synchronized before and after the multiplication loop.
    """
    side = 2048
    rounds = 10
    lhs = torch.randn(side, side, device=device, dtype=dtype)
    rhs = torch.randn(side, side, device=device, dtype=dtype)
    touch()
    for _ in range(rounds):
        # Only the work matters; the product is discarded immediately.
        torch.matmul(lhs, rhs)
    touch()


def max_fn(x):
    """Return the positive part of ``x`` normalized to sum to one.

    Non-positive entries are replaced by zero and the result is divided by
    its sum along the last dimension, so each row of a batched input is
    normalized independently. For 1-D input or a single row this equals
    dividing by the sum over the whole tensor.

    Note: a row without any positive entry divides 0 by 0 and yields NaN.
    """
    positive_part = torch.where(x > 0, x, 0)
    # keepdim=True keeps the reduced axis so the division broadcasts per row.
    return positive_part / positive_part.sum(dim=-1, keepdim=True)


def sample(p, deterministic=True):
    """Pick an index from the distribution ``p``.

    Args:
        p: Tensor of (possibly unnormalized) probabilities.
        deterministic: When true, return the index of the largest entry;
            otherwise draw one index per row with ``torch.multinomial``.

    Returns:
        An integer index tensor. For 2-D ``p`` the shape is ``(rows, 1)`` in
        both modes. For any other rank the deterministic path returns the
        argmax over the flattened tensor with shape ``(1, 1)``.
    """
    if not deterministic:
        return torch.multinomial(p, 1)
    if p.dim() == 2:
        # Per-row argmax so batched input matches the multinomial path's shape.
        return torch.argmax(p, dim=-1, keepdim=True)
    return torch.argmax(p).reshape(1, 1)


def norm_logits(logits, temperature, eps=1e-10):
    """Convert logits to probabilities with a softmax over dimension 1.

    Args:
        logits: Tensor of raw scores; the softmax is taken along dim 1.
        temperature: Divisor applied to the logits before the softmax. A
            value of exactly 0.0 skips the scaling altogether.
        eps: Small constant added to ``temperature`` before dividing.

    Returns:
        Tensor of the same shape as ``logits`` holding the softmax output.
    """
    scaled = logits if temperature == 0.0 else logits / (temperature + eps)
    return F.softmax(scaled, dim=1)


def load_models(
    draft_model_path,
    target_model_path,
    load_one=False,
    use_fa=False,
    use_4bit=False,
    use_8bit=False,
):
    """Load the tokenizer, the target model and, optionally, the draft model.

    Args:
        draft_model_path: Checkpoint passed to ``from_pretrained`` for the
            draft model. Unused when ``load_one`` is true.
        target_model_path: Checkpoint passed to ``from_pretrained`` for the
            target model; the tokenizer is loaded from the same path.
        load_one: When true, load only the target model and return ``None``
            in place of the draft model.
        use_fa: Request the ``flash_attention_2`` attention implementation
            for the target model.
        use_4bit: Load the target model with a 4-bit bitsandbytes config.
        use_8bit: Load the target model with an 8-bit bitsandbytes config
            (``llm_int8_threshold=200.0``).

    Returns:
        Tuple ``(tokenizer, draft_model, target_model)``. Both models are put
        in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(target_model_path)

    # Options that apply to the target model only.
    opt_args = {"torch_dtype": dtype}
    if use_8bit:
        opt_args["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True, llm_int8_threshold=200.0
        )
        opt_args["torch_dtype"] = torch.float32
    if use_4bit:
        # Checked last, so 4-bit takes precedence when both flags are set.
        # BitsAndBytesConfig has no 4-bit threshold option, so none is passed.
        opt_args["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
        opt_args["torch_dtype"] = torch.float32
    if use_fa:
        opt_args["attn_implementation"] = "flash_attention_2"

    # NOTE(review): trust_remote_code=True lets the checkpoint run its own
    # Python code while loading; only pass model paths you trust.
    common_args = {
        "device_map": device,
        "low_cpu_mem_usage": True,
        "trust_remote_code": True,
    }

    target_model = AutoModelForCausalLM.from_pretrained(
        target_model_path,
        **common_args,
        **opt_args,
    ).eval()

    if load_one:
        return tokenizer, None, target_model

    # The draft model is loaded without opt_args: the dtype, quantization and
    # attention options above are applied to the target model only.
    draft_model = AutoModelForCausalLM.from_pretrained(
        draft_model_path,
        **common_args,
    ).eval()

    if device.type == "cuda":
        # Report the GPU memory currently allocated by PyTorch.
        print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    return tokenizer, draft_model, target_model
