# Load model directly
import os
import shutil
import torch

from transformers import (AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoProcessor, LlavaForConditionalGeneration, AutoModelForPreTraining, AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration, Blip2Processor, BlipForConditionalGeneration, BlipProcessor)

import os
import logging
import torch

def load_pretrained_model(config, role: str):
    """Load the pretrained model configured for ``role``.

    The model name is read from ``config[role]`` and the precision from
    ``config[role + '_dtype']``: ``'fp32'`` selects float32, ``'fp16'``
    selects float16 and any other value selects bfloat16.

    Args:
        config: Mapping holding the model names, the per-role dtype strings
            and the ``is_drf_text_only`` / ``is_drf_from_mllm`` /
            ``is_tgt_text_only`` flags read below.
        role: Config key of the model to load (``'drf'``, ``'tgt'`` and
            ``'captioning_model'`` get special handling).

    Returns:
        The loaded model, or its ``language_model`` attribute when the role
        is flagged as text-only in ``config``.
    """
    model_name = config[role]
    precision = config[role + '_dtype']
    if precision == 'fp32':
        torch_dtype = torch.float32
    elif precision == 'fp16':
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.bfloat16

    model_cls = _get_pretrained_classes(model_name, 'model')
    checkpoint = _get_checkpoint_path(config, model_name, model_cls)
    model = _load_model(model_cls, checkpoint, model_name, config, torch_dtype)

    # Captioning models are always returned whole.
    if role == 'captioning_model':
        return model

    # Text-only draft / target: hand back only the wrapped language model.
    text_only = (
        (role == 'drf' and config['is_drf_text_only'] and config['is_drf_from_mllm'])
        or (role == 'tgt' and config['is_tgt_text_only'])
    )
    return model.language_model if text_only else model

def load_tokenizer(_config, max_target_length, role: str = 'drf'):
    """Load the tokenizer (or processor) for ``role`` and its special token ids.

    Args:
        _config: Mapping with the model names plus the ``is_drf_from_mllm``
            and ``decoding`` entries read below.
        max_target_length: Not used by this function.
        role: Config key of the model whose tokenizer is loaded.

    Returns:
        ``(tokenizer, eos_token_id, pad_token_id)``. Both ids are ``None``
        when ``role`` is ``'drf'`` and ``_config['is_drf_from_mllm']`` is falsy.
    """
    model_name = _config[role]
    tokenizer_cls = _get_pretrained_classes(model_name, 'tokenizer')

    # Florence-2 captioners ('microsoft/Florence') are loaded with remote code enabled.
    load_kwargs = {}
    if role == 'captioning_model' and 'lorence-2' in model_name:
        load_kwargs['trust_remote_code'] = True
    tokenizer = tokenizer_cls.from_pretrained(model_name, **load_kwargs)

    # Draft model that does not come from an MLLM: no special ids, no validation.
    drf_from_mllm = _config['is_drf_from_mllm']
    if role == 'drf' and not drf_from_mllm:
        return tokenizer, None, None

    eos_token_id, pad_token_id = _get_special_token_ids(tokenizer, tokenizer_cls)

    # For 'sd' decoding, compare the vocab size against the target model's tokenizer.
    if role != 'captioning_model' and _config['decoding'] == 'sd':
        tgt_tokenizer = tokenizer_cls.from_pretrained(_config['tgt'])
        _validate_vocab_size(tokenizer, tgt_tokenizer, tokenizer_cls)

    return tokenizer, eos_token_id, pad_token_id

def load_image_processor(_config, role: str):
    """Return the ``image_processor`` of the processor loaded for ``role``'s model.

    Args:
        _config: Mapping from which the model name ``_config[role]`` is read.
        role: Config key of the model whose image processor is wanted.

    Returns:
        The ``image_processor`` attribute of the loaded processor.

    Raises:
        ValueError: Propagated from ``_get_pretrained_classes`` when no image
            processor class is registered for the model.
    """
    # NOTE: a former early ``return None`` for the case where
    # ``_config['is_drf_from_mllm']`` is falsy (no image processor for T5,
    # JackFram/llama) is disabled.
    model_name = _config[role]
    processor_cls = _get_pretrained_classes(model_name, 'image_processor')
    processor = processor_cls.from_pretrained(model_name)
    return processor.image_processor

def _get_pretrained_classes(model_name, class_type: str):
    """Resolve the class used to load ``model_name`` for the given ``class_type``.

    The first registry row whose key is a substring of ``model_name`` wins.
    Names matching no row fall back to ``AutoModelForCausalLM`` /
    ``AutoTokenizer`` with no image processor.

    Args:
        model_name: Model identifier string searched for the registry keys.
        class_type: ``'model'``, ``'tokenizer'`` or ``'image_processor'``.

    Returns:
        The class registered for ``class_type``.

    Raises:
        ValueError: If ``class_type`` is unknown, or no class of that type is
            registered for the matched row.
    """
    # Columns: name substring, model class, tokenizer class, image-processor class.
    registry = (
        ('llava-hf', LlavaForConditionalGeneration, AutoProcessor, AutoProcessor),
        ('XXXX-2', LlavaForConditionalGeneration, AutoProcessor, AutoProcessor),
        ('InternVL2', AutoModel, AutoTokenizer, AutoProcessor),
        ('google/t5', AutoModelForSeq2SeqLM, AutoTokenizer, None),
        ('JackFram/llama', AutoModelForCausalLM, AutoProcessor, AutoProcessor),
        ('double7/vicuna', AutoModelForCausalLM, AutoProcessor, AutoProcessor),
        ('Salesforce/blip2-opt', Blip2ForConditionalGeneration, Blip2Processor, Blip2Processor),
        ('Salesforce/blip-', BlipForConditionalGeneration, BlipProcessor, BlipProcessor),
        ('microsoft/Florence', AutoModelForCausalLM, AutoProcessor, AutoProcessor),
        ('ljnlonoljpiljm/florence', AutoModelForCausalLM, AutoProcessor, AutoProcessor),
    )
    # Used for model names that match none of the rows above.
    fallback = (AutoModelForCausalLM, AutoTokenizer, None)
    column_of = {'model': 0, 'tokenizer': 1, 'image_processor': 2}

    # First matching row wins, in registry order.
    classes = next(
        (row[1:] for row in registry if row[0] in model_name),
        fallback,
    )

    column = column_of.get(class_type)
    resolved = None if column is None else classes[column]
    if resolved is None:
        raise ValueError(f"Class type '{class_type}' not supported for model '{model_name}'.")
    return resolved

def _get_checkpoint_path(config, model_name, ModelClass):
    if config['ckpt_dir'] is not None:
        logging.info(f"Loading selected model checkpoint from {config['ckpt_dir']}...")
        return f"{config['ckpt_dir']}/{model_name}"
    else:
        local_path = f"{config['root']}/data/MSD/checkpoint/{model_name}"
        if not os.path.exists(local_path) and "lorence-2" not in model_name:
            _download_and_save_model(ModelClass, model_name, local_path)
        return local_path

def _download_and_save_model(ModelClass, model_name, local_path):
    logging.info(f"No model found in local path! Downloading & Saving {model_name}...")
    model = ModelClass.from_pretrained(model_name)
    model.save_pretrained(local_path)

def _load_model(ModelClass, local_path, model_name, _config, _dtype):
    if 'llava-hf' in model_name or 'XXXX-2' in model_name:
        return ModelClass.from_pretrained(
            local_path,
            torch_dtype=_dtype,
            low_cpu_mem_usage=True,
            # attn_implementation="sdpa", # Todo
            )
    elif 'Salesforce/blip2' in model_name:
        return ModelClass.from_pretrained(
            local_path, 
            # load_in_8bit=True, 
            # device_map={"": 0}, 
            torch_dtype=torch.float16
        )
    elif 'Salesforce/blip-' in model_name:
        return ModelClass.from_pretrained(
            local_path, 
            torch_dtype=torch.float16
        )
    elif 'lorence-2' in model_name: # 'microsoft/Florence'
        return  ModelClass.from_pretrained(
            # local_path, 
            model_name,
            torch_dtype=torch.float16, 
            trust_remote_code=True,
        )
         
    # elif 'InternVL2' in model_name:
    #     override_mllm()
    #     return ModelClass.from_pretrained(
    #         local_path,
    #         torch_dtype=torch.bfloat16,
    #         trust_remote_code=True,
    #     )
    # elif 'google/t5' in model_name:
    #     return ModelClass.from_pretrained(
    #         local_path,
    #         torch_dtype=_dtype,
    #         trust_remote_code=True
    #     )
    else:
        raise ValueError(f"Invalid model name: {model_name}")

def _get_special_token_ids(tokenizer, tokenizer_class):
    """Return ``(eos_token_id, pad_token_id)`` for ``tokenizer``.

    For the processor classes (``AutoProcessor``, ``Blip2Processor``,
    ``BlipProcessor``) the ids are read from the nested ``tokenizer``
    attribute; otherwise they are read from ``tokenizer`` itself.

    Args:
        tokenizer: Loaded tokenizer or processor instance.
        tokenizer_class: Class that was used to load ``tokenizer``.

    Returns:
        Tuple ``(eos_token_id, pad_token_id)``.
    """
    is_processor = tokenizer_class in [AutoProcessor, Blip2Processor, BlipProcessor]
    id_source = tokenizer.tokenizer if is_processor else tokenizer
    return id_source.eos_token_id, id_source.pad_token_id

def _validate_vocab_size(drf_tokenizer, tgt_tokenizer, tokenizer_class):
    """Check that the draft and target tokenizers have the same vocabulary size.

    Only ``AutoTokenizer`` and ``AutoProcessor`` are checked; for any other
    ``tokenizer_class`` the function returns without comparing anything.

    Args:
        drf_tokenizer: Draft-model tokenizer, or a processor wrapping one in
            its ``tokenizer`` attribute.
        tgt_tokenizer: Target-model tokenizer / processor of the same kind.
        tokenizer_class: Class that was used to load both tokenizers.

    Raises:
        AssertionError: If the two ``vocab_size`` values differ.
    """
    if tokenizer_class == AutoTokenizer:
        drf_size = drf_tokenizer.vocab_size
        tgt_size = tgt_tokenizer.vocab_size
    elif tokenizer_class == AutoProcessor:
        # Processors expose the text tokenizer as a nested attribute.
        drf_size = drf_tokenizer.tokenizer.vocab_size
        tgt_size = tgt_tokenizer.tokenizer.vocab_size
    else:
        # No comparison is defined for other tokenizer classes.
        return

    if drf_size != tgt_size:
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type is kept for existing callers.
        raise AssertionError("Vocab mismatch between drf and tgt models")

def override_mllm():
    """Overwrite the cached InternVL2-2B ``modeling_internvl_chat.py`` with the custom copy.

    The file is copied from the hard-coded custom-files directory into the
    hash-named directory of the Hugging Face ``transformers_modules`` cache.

    Raises:
        AssertionError: If the destination cache directory does not exist.
    """
    internvl2_hash = "26287f7a0814001a6f17d6cbb639711d96f67788"
    module_file = "modeling_internvl_chat.py"
    custom_dir = "/XXXX-5/home-XXXX-3/workspace/MLLMSD/mllmsd/utils/custom_files"
    cache_dir = f"/root/.cache/huggingface/modules/transformers_modules/OpenGVLab/InternVL2-2B/{internvl2_hash}"

    assert os.path.exists(cache_dir), f"{cache_dir} does not exist."

    shutil.copyfile(
        os.path.join(custom_dir, module_file),
        os.path.join(cache_dir, module_file),
    )