import os, glob, sys, random, pickle, copy, resource, logging
import torch
import torch.multiprocessing as mp
import numpy as np
from datetime import datetime
import yaml
from contextlib import contextmanager

# Lookup table from full identifiers (Hugging Face repo ids, Florence-2
# caption task prompts, dataset names) to the short names used when building
# checkpoint / run names; read through get_short_name().
#
# Model and caption-prompt entries map to a plain string. Dataset entries map
# to a (short_name, task) tuple, and get_short_name() returns only short_name.
#
# NOTE(review): distinct keys can share a short name (e.g. "XXXX-2/lm68m" and
# "JackFram/llama-68m" both map to "llama-68m"), so run names built from them
# cannot be told apart — confirm this is intended.
map_name_task = {
    # Models: repo id -> short name
    "google/t5-small-lm-adapt": "T5lm-small",
    "google/t5-base-lm-adapt": "T5lm-base",
    "google/t5-large-lm-adapt": "T5lm-large",
    "google/t5-xl-lm-adapt": "T5lm-xl",
    "google/t5-xxl-lm-adapt": "T5lm-xxl",

    # llava-hf repo
    "llava-hf/llava-1.5-7b-hf": "llava-llama-7b",
    "llava-hf/llava-1.5-13b-hf": "llava-llama-13b",
    "llava-hf/llava-v1.6-vicuna-7b-hf": "llava-vicuna-7b",
    "llava-hf/llava-v1.6-vicuna-13b-hf": "llava-vicuna-13b",
    "llava-hf/llava-interleave-qwen-0.5b-hf": "llava-qwen-0.5b",
    "llava-hf/llava-interleave-qwen-7b-hf": "llava-qwen-7b",

    # Small (presumably draft) models; "lm68m" shares its short name with
    # "JackFram/llama-68m" below.
    "XXXX-2/lvlm68m": "llava-68m",
    "XXXX-2/lvlm290m": "llava-290m",
    "XXXX-2/lm68m": "llama-68m",
    "XXXX-2/lm290m": "llama-290m",

    "XXXX-2/lvlm68m-pretrain-only": "llava-68m-pto",

    # Variants differing only in the "pool-<N>" setting encoded in the repo id.
    "XXXX-2/lvlm68m-pool-0-ft": "llava-68m-pool-0-ft",
    "XXXX-2/lvlm68m-pool-1-ft": "llava-68m-pool-1-ft",
    "XXXX-2/lvlm68m-pool-4-ft": "llava-68m-pool-4-ft",
    "XXXX-2/lvlm68m-pool-9-ft": "llava-68m-pool-9-ft",
    "XXXX-2/lvlm68m-pool-36-ft": "llava-68m-pool-36-ft",
    "XXXX-2/lvlm68m-pool-144-ft": "llava-68m-pool-144-ft",

    "JackFram/llama-68m": "llama-68m",
    "JackFram/llama-160m": "llama-160m",
    "double7/vicuna-68m": "vicuna-68m",
    "double7/vicuna-160m": "vicuna-160m",

    # Captioning models (the `captioning_model` config value)
    "Salesforce/blip-image-captioning-base": "blip-base",
    "Salesforce/blip-image-captioning-large": "blip-large",
    "Salesforce/blip2-opt-2.7b-coco": "blip2-2.7b",
    "Salesforce/blip2-opt-2.7b": "blip2-2.7b-noft",
    "Salesforce/blip2-opt-6.7b-coco": "blip2-6.7b",
    "microsoft/Florence-2-large": "florence2-0.77b-noft",
    "microsoft/Florence-2-large-ft": "florence2-0.77b",
    "ljnlonoljpiljm/florence-2-large-llava-recap-cc3m": "florence2-0.77b-cc3m",

    # Florence-2 task prompts (the `caption_type` config value) -> name suffix
    "<CAPTION>": "C",
    "<DETAILED_CAPTION>": "DC",
    "<MORE_DETAILED_CAPTION>": "MDC",
    "<OCR>": "OCR",

    # Datasets: name -> (short name, task)
    "cnn_dailymail": ("cnndm", "summarization"),
    "xsum": ("xsum", "summarization"),
    "wmt14": ("wmt", "translation"),

    # NOTE(review): image datasets are also tagged "summarization" — presumably
    # to reuse that task's code path; confirm.
    "LLaVA-Instruct-150K": ("llava-inst", "summarization"),
    "COCO2014": ("coco", "summarization"),
}

def get_short_name(name_obj: str):
    """Translate a full model/dataset/prompt identifier into its short name.

    Identifiers missing from ``map_name_task`` are returned untouched. Dataset
    entries are stored as ``(short_name, task)`` tuples; for those only the
    short name is returned.
    """
    try:
        entry = map_name_task[name_obj]
    except KeyError:
        # Unknown identifier: use it verbatim in names.
        return name_obj

    if isinstance(entry, tuple):
        short_name, _task = entry
        return short_name
    return entry


def get_cascade_name(_config):
    """Build the drafting part of a run name when ``_config['drafting']`` is a list.

    The drafting strategy names are sorted. Then ``caption`` is expanded with
    the captioning model's short name (plus the caption task for Florence-2
    models) and ``image-pool`` with ``<target_dim_image_pooling>d-<image_pool_type>``.
    The cascade options come after them: the cascade rule, the mm-weight policy
    (only when the rule is ``'mm-weight'``), ``<mm_weight_k>x`` when
    ``mm_weight_k`` is not None, and finally the literal ``cascade``.

    Returns:
        str: all parts joined with ``-``.
    """
    drafting = _config['drafting']
    parts = sorted(drafting)

    if 'caption' in drafting:
        slot = parts.index('caption')
        captioner = _config['captioning_model']
        caption_part = f"caption-{get_short_name(captioner)}"
        # 'lorence-2' matches both "Florence-2" and "florence-2" repo ids.
        if 'lorence-2' in captioner:
            caption_part += f"-{get_short_name(_config['caption_type'])}"
        parts[slot] = caption_part

    if 'image-pool' in drafting:
        slot = parts.index('image-pool')
        parts[slot] = (
            f"image-pool-{_config['target_dim_image_pooling']}d-{_config['image_pool_type']}"
        )

    rule = _config['cascade_rule']
    options = [rule]
    if rule == 'mm-weight':
        options.append(str(_config['mm_weight_policy']))

    if _config['mm_weight_k'] is not None:
        options.append(f"{_config['mm_weight_k']}x")

    options.append('cascade')

    return '-'.join(parts + options)

def get_ckpt_name(_config, is_phase_2=True, dataset=None):
    """Compose the underscore-separated checkpoint / run name for ``_config``.

    Args:
        _config: experiment config mapping. Reads ``drafting``,
            ``image_top_k_attention``, ``decoding``, ``drf``, ``tgt``,
            ``dataset``, ``max_target_length``, ``max_chunk_length``,
            ``temperature``, ``drf_dtype``, ``tgt_dtype``, ``seed``,
            ``is_time_factorized``, ``is_tgt_text_only``, ``tiny_data``, plus
            the caption / image-pool keys depending on ``drafting``.
        is_phase_2: accepted for API compatibility; not used here.
        dataset: optional dataset name that overrides ``_config['dataset']``
            (ignored when falsy).

    Returns:
        str: ``<decoding>_<drf>_<tgt>_<drafting>-drafting_<dataset>_mtl-<n>_gamma-<n>_t<T>_fp<a>-<b>``
        followed by ``[_time-factorized][_text-verify]_<seed>[_tiny_data]``.
    """
    ckpt_dataset = dataset or _config['dataset']

    drafting = _config['drafting']
    if isinstance(drafting, list):
        # Several drafting strategies combined in a cascade.
        name_drafting = get_cascade_name(_config)
    elif drafting == 'caption':
        captioner = _config['captioning_model']
        name_drafting = f"{drafting}-{get_short_name(captioner)}"
        # 'lorence-2' matches both "Florence-2" and "florence-2" repo ids.
        if 'lorence-2' in captioner:
            name_drafting += f"-{get_short_name(_config['caption_type'])}"
    elif drafting == 'image-pool':
        name_drafting = (
            f"{drafting}-{_config['target_dim_image_pooling']}d-{_config['image_pool_type']}"
        )
    else:
        name_drafting = f"{drafting}"

    top_k = _config['image_top_k_attention']
    if top_k:
        name_drafting += f"-top-{top_k}"
    name_drafting = f"{name_drafting}-drafting"

    # NOTE(review): int() truncates, so e.g. temperature 0.5 and 0 both give "t0".
    # NOTE(review): [-2:] keeps only the trailing two characters; if dtypes are
    # strings such as 'float16' and 'bfloat16', both become "16" — confirm.
    factors = [
        _config['decoding'],
        get_short_name(_config['drf']),
        get_short_name(_config['tgt']),
        name_drafting,
        get_short_name(ckpt_dataset),
        f"mtl-{_config['max_target_length']}",
        f"gamma-{_config['max_chunk_length']}",
        f"t{int(_config['temperature'])}",
        f"fp{_config['drf_dtype'][-2:]}-{_config['tgt_dtype'][-2:]}",
        _config['seed'],
    ]

    # Optional markers go immediately before the seed (the last entry), in this order.
    markers = []
    if _config['is_time_factorized']:
        markers.append('time-factorized')
    if _config['is_tgt_text_only']:
        markers.append('text-verify')
    factors[-1:-1] = markers

    # The tiny-data tag, by contrast, comes after the seed.
    if _config['tiny_data']:
        factors.append('tiny_data')

    return "_".join(str(factor) for factor in factors)

def get_image_escape_token_num(model_name: str):
    """Return the pair of image-escape token counts used for ``model_name``'s family.

    The model family is chosen by substring match on ``model_name``. What the
    two integers mean is defined by the callers (TODO confirm).

    Raises:
        NotImplementedError: if ``model_name`` belongs to no known family.
    """
    llava_llama_markers = ("XXXX-2/lvlm", "XXXX-2/lm", "llava-hf/llava-1.5-7b-hf")
    for marker in llava_llama_markers:
        if marker in model_name:
            return 2, 3

    if "llava-hf/llava-interleave-qwen" in model_name:
        return 1, 1

    raise NotImplementedError(f"get_image_escape_token_num not implemented for {model_name}")


def get_caption_prefix_ids(model_name: str):
    """Return the caption-prefix token ids for ``model_name``'s tokenizer family.

    The model family is chosen by substring match on ``model_name``. A new
    CPU ``LongTensor`` is returned on every call.

    NOTE(review): in each list the last two ids look like ':' followed by a
    space, so the prefix is presumably "<word>: " — confirm against the
    tokenizer.

    Raises:
        NotImplementedError: if ``model_name`` belongs to no known family.
    """
    if any(x in model_name for x in ["XXXX-2/lvlm", "XXXX-2/lm", "llava-hf/llava-1.5-7b-hf"]):
        return torch.LongTensor([1967, 29901, 29871])
    elif "llava-hf/llava-interleave-qwen" in model_name:
        return torch.LongTensor([1805, 25, 220])
    else:
        # Report this function's own name (previously a copy-paste of
        # get_image_escape_token_num's message, which misdirected debugging).
        raise NotImplementedError(f"get_caption_prefix_ids not implemented for {model_name}")

def get_pseudo_image_text_token_ids(model_name: str):
    """Return the token ids of a textual stand-in for the image placeholder.

    The model family is chosen by substring match on ``model_name``, checked
    in table order. A new CPU ``LongTensor`` is returned on every call.
    NOTE(review): the ids appear to spell "<image>" in each tokenizer — confirm.

    Raises:
        NotImplementedError: if ``model_name`` belongs to no known family.
    """
    family_table = (
        (("XXXX-2/lvlm", "XXXX-2/lm", "llava-hf/llava-1.5-7b-hf"), [529, 3027, 29958]),
        (("llava-hf/llava-interleave-qwen",), [27, 1805, 29]),
    )
    for markers, token_ids in family_table:
        if any(marker in model_name for marker in markers):
            return torch.LongTensor(token_ids)
    raise NotImplementedError(f"get_pseudo_image_text_token_ids not implemented for {model_name}")
    
def avg(l):
    """Return the arithmetic mean of the values in ``l``.

    ``l`` may be any sized collection (list, tuple, tensor, array, ...) or any
    finite iterable such as a generator. Unsized iterables are materialized
    once because ``len()`` is needed; sized inputs are used as-is.

    Raises:
        ZeroDivisionError: if ``l`` is empty.
    """
    values = l if hasattr(l, "__len__") else list(l)
    return sum(values) / len(values)

def _save(model, optimizer, lr_scheduler, metric, save_dir, config):
    """Write a training checkpoint into ``save_dir``, creating the directory if needed.

    Writes:
        * the model, via ``model.save_pretrained(..., safe_serialization=True)``
          with its current ``state_dict()``;
        * ``optimizers.pt``: optimizer and LR-scheduler ``state_dict()``s;
        * ``metric.pt``: ``metric.state_dict()``;
        * ``config_sacred.yaml``: ``config`` dumped as block-style YAML.
    """
    os.makedirs(save_dir, exist_ok=True)

    model.save_pretrained(save_dir, state_dict=model.state_dict(), safe_serialization=True)

    training_state = {
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': lr_scheduler.state_dict(),
    }
    torch.save(training_state, os.path.join(save_dir, "optimizers.pt"))

    torch.save(metric.state_dict(), os.path.join(save_dir, "metric.pt"))

    config_path = os.path.join(save_dir, 'config_sacred.yaml')
    with open(config_path, 'w') as config_file:
        yaml.dump(config, config_file, default_flow_style=False)

def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``.

    Also switches torch.multiprocessing tensor sharing to the 'file_system'
    strategy and makes cuDNN deterministic (benchmark autotuning disabled).
    """
    mp.set_sharing_strategy('file_system')

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


@contextmanager
def noop_context():
    """Context manager that does nothing; ``with noop_context() as x`` binds ``None``."""
    yield None

@contextmanager
def patch_function(target_object, function_name, custom_function, *args, **kwargs):
    """
    Temporarily replace ``target_object.<function_name>`` with ``custom_function``.

    The replacement is active only inside the ``with`` block and is undone on
    exit, even if the block raises.

    Args:
        target_object: The object or module where the function is located (e.g., torch.nn.functional).
        function_name: The name of the function to patch as a string (e.g., 'scaled_dot_product_attention').
        custom_function: The replacement callable. It must allow attribute
            assignment (plain Python functions do).
        *args, **kwargs: Extra configuration for the custom function. They are
            not passed at call time. Instead they are attached to it as
            ``custom_function.args`` (tuple) and ``custom_function.kwargs``
            (dict), overwriting any existing attributes of those names; they
            stay set after the block exits.

    Yields:
        None

    Raises:
        AttributeError: if ``target_object`` has no attribute ``function_name``.

    Example::

        with patch_function(F, 'scaled_dot_product_attention',
                            scaled_dot_product_attention_top_k_image_tokens):
            # Attach whatever the custom function needs to read.
            scaled_dot_product_attention_top_k_image_tokens.model_instance = llama_model

            # Only calls made inside the block see the patched function.
            input_ids = torch.randint(0, llama_model.config.vocab_size, (1, 128))
            outputs = llama_model(input_ids)
            print(outputs)
    """
    # Fail early (AttributeError) if the name does not exist at all.
    original_function = getattr(target_object, function_name)

    # Decide how to undo the patch exactly:
    # - the name is stored directly on the target (module globals, class body,
    #   instance __dict__): restore the raw stored object. Reading it from
    #   __dict__ keeps wrappers such as staticmethod/classmethod intact, which
    #   getattr() would unwrap.
    # - the name is only reachable indirectly (e.g. a method inherited from the
    #   instance's class): delete our override rather than leaving a copy.
    # - the target has no __dict__: fall back to setattr with getattr's value.
    namespace = getattr(target_object, "__dict__", None)
    owns_attribute = namespace is None or function_name in namespace
    if namespace is not None and owns_attribute:
        original_function = namespace[function_name]

    # Expose the extra configuration to the custom function.
    custom_function.args = args
    custom_function.kwargs = kwargs

    setattr(target_object, function_name, custom_function)
    try:
        yield
    finally:
        if owns_attribute:
            setattr(target_object, function_name, original_function)
        else:
            delattr(target_object, function_name)