import torch
# import bitsandbytes as bnb
import numpy as np
from jax import grad,vmap
from tqdm.notebook import tqdm
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import (
    LlamaForCausalLM, 
    LlamaTokenizer
)

from data.serialize import serialize_arr, SerializerSettings

# Fallback special-token strings. get_tokenizer() assigns each of these to the
# tokenizer only when the loaded tokenizer has no value for the corresponding
# attribute (eos_token / bos_token / unk_token). Setting one of them to None
# disables that particular fallback.
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

import pynvml

def get_free_gpu():
    """Return the NVML index of the GPU with the most free memory.

    Every device reported by NVML is queried for its free memory and the
    index with the largest value is returned. On a tie the lowest index
    wins. If NVML reports no devices, 0 is returned.

    Returns:
        int: Index of the selected device, as enumerated by NVML.

    Raises:
        Any exception raised by the underlying pynvml calls is propagated.
        The NVML session is still shut down in that case.
    """
    # NOTE(review): the returned index is NVML's enumeration; it may not match
    # torch's CUDA ordinal when CUDA_VISIBLE_DEVICES / CUDA_DEVICE_ORDER are
    # set -- confirm against the deployment environment.
    pynvml.nvmlInit()
    try:
        def free_bytes(index):
            handle = pynvml.nvmlDeviceGetHandleByIndex(index)
            return pynvml.nvmlDeviceGetMemoryInfo(handle).free

        device_count = pynvml.nvmlDeviceGetCount()
        # max() keeps the first maximum it sees, so ties resolve to the lowest
        # index; default=0 covers the "no devices" case.
        return max(range(device_count), key=free_bytes, default=0)
    finally:
        # Always release NVML, even if one of the queries above failed.
        pynvml.nvmlShutdown()

def llama2_model_string():
    """Return the path handed to ``from_pretrained`` for the tokenizer and model."""
    return "/mnt/data/llm_models/xxx"



def get_tokenizer(model):
    """Load the tokenizer and make sure its core special tokens are set.

    The tokenizer is loaded from ``llama2_model_string()`` with the slow
    implementation and remote code enabled. Any of eos/bos/unk that the
    tokenizer lacks is set to the matching ``DEFAULT_*_TOKEN`` constant and
    then registered in the vocabulary on a best-effort basis. If no pad
    token is defined afterwards, the eos token is reused as the pad token.

    Args:
        model: Not used by this function; the path always comes from
            ``llama2_model_string()``.

    Returns:
        The configured tokenizer.
    """
    tok = AutoTokenizer.from_pretrained(
        llama2_model_string(),
        use_fast=False,
        trust_remote_code=True,
    )

    # (tokenizer attribute, fallback value), processed in this order.
    fallbacks = (
        ("eos_token", DEFAULT_EOS_TOKEN),
        ("bos_token", DEFAULT_BOS_TOKEN),
        ("unk_token", DEFAULT_UNK_TOKEN),
    )
    pending = []
    for attr, default in fallbacks:
        if getattr(tok, attr) is not None or default is None:
            continue
        setattr(tok, attr, default)
        pending.append(default)

    if pending:
        # Best effort: a failure to extend the vocabulary is reported, not raised.
        try:
            added = tok.add_tokens(pending, special_tokens=True)
            if added is not None:
                print(f"Added {added} special tokens")
        except Exception as err:
            print(f"Warning: Could not add special tokens to vocabulary: {err}")

    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    return tok



def get_model_and_tokenizer(model):
    """Load the causal LM and its tokenizer onto the best available device.

    When CUDA is available, the GPU with the most free memory (according to
    ``get_free_gpu()``) is selected; if that index is not visible to torch,
    device 0 is used instead. Without CUDA the model is placed on the CPU.
    The model is loaded in float16 from ``llama2_model_string()`` and put
    into eval mode.

    Args:
        model: Forwarded to ``get_tokenizer``; the checkpoint path itself
            comes from ``llama2_model_string()``.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    if torch.cuda.is_available():
        # Only touch NVML once we know a CUDA device exists.
        target_gpu = get_free_gpu()
        if target_gpu >= torch.cuda.device_count():
            # The NVML index is not addressable from torch (for example when
            # device visibility is restricted); use the first visible device.
            target_gpu = 0
        torch.cuda.set_device(target_gpu)
        torch.cuda.empty_cache()
        device = f"cuda:{target_gpu}"
    else:
        print("Warning: CUDA is not available; loading the model on CPU")
        device = "cpu"

    tokenizer = get_tokenizer(model)
    # Separate local name so the `model` argument is not shadowed.
    llm = AutoModelForCausalLM.from_pretrained(
        llama2_model_string(),
        device_map={"": device},
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )

    # get_tokenizer() may have added special tokens. Grow the embedding matrix
    # only if the tokenizer now has more entries than the model has rows, so
    # the new ids cannot index out of range.
    try:
        embeddings = llm.get_input_embeddings()
    except NotImplementedError:
        embeddings = None
    if embeddings is not None and len(tokenizer) > embeddings.weight.shape[0]:
        llm.resize_token_embeddings(len(tokenizer))

    llm.eval()

    return llm, tokenizer




