import os
import torch
import yaml
import sys
from typing import List, Tuple, Dict, Optional, Any
from collections import namedtuple
from mmengine import Config 

from constants import (
    IMAGE_TOKEN_INDEX,
    IMAGE_TOKEN_LENGTH,
    MODEL_PATHS,
    SYSTEM_MESSAGE,
    INSTRUCTION_TEMPLATE,
    SHIKRA_IMG_START_TOKEN,
    SHIKRA_IMG_END_TOKEN,
)

# Default device for the helpers that consult DEVICE: the GPU when PyTorch can
# see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

def load_model_args_from_yaml(yaml_path):
    """Read the ``ModelArgs`` and ``TrainingArgs`` sections of a YAML file.

    Each section must be a mapping. It is turned into an instance of a
    ``namedtuple`` whose field names are the mapping's keys.

    Args:
        yaml_path: path to the YAML file (decoded as UTF-8).

    Returns:
        ``(model_args, training_args)`` namedtuple instances.

    Raises:
        ValueError: if the file is empty, or either section is missing or
            is not a mapping.
    """
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(yaml_path, "r", encoding="utf-8") as file:
        data = yaml.safe_load(file)

    def _section(name):
        # safe_load returns None for an empty file, so guard before .get().
        section = data.get(name) if isinstance(data, dict) else None
        if not isinstance(section, dict):
            raise ValueError(f"{yaml_path!r} must define a '{name}' mapping")
        return namedtuple(name, section.keys())(**section)

    model_args = _section("ModelArgs")
    training_args = _section("TrainingArgs")
    return model_args, training_args

def load_llava_model(model_path: str):
    """Load a LLaVA checkpoint through the upstream ``llava`` builder.

    Args:
        model_path: checkpoint path; the model name is derived from it.

    Returns:
        dict with ``tokenizer``, ``model``, ``image_processor`` and
        ``llm_model``. The last entry is the same object as ``model``.
    """
    from llava.mm_utils import get_model_name_from_path
    from llava.model.builder import load_pretrained_model

    builder_kwargs = dict(
        model_base=None,
        model_name=get_model_name_from_path(model_path),
        load_8bit=False,
        load_4bit=False,
        device=DEVICE,
        device_map="auto",
    )
    tok, net, img_proc, _context_len = load_pretrained_model(model_path, **builder_kwargs)

    components = {"tokenizer": tok, "model": net, "image_processor": img_proc}
    # The LLaVA wrapper doubles as the language model used for generation.
    components["llm_model"] = net
    return components

def load_minigpt4_model(cfg_path):
    """Build a MiniGPT-4 model and an image processor from a MiniGPT-4 config.

    Args:
        cfg_path: path to the MiniGPT-4 config file.

    Returns:
        dict with:
        - ``tokenizer``: ``model.llama_tokenizer``.
        - ``model``: the MiniGPT-4 module, moved to ``cuda:0``.
        - ``image_processor``: built from ``vis_processor.train`` of the first
          dataset listed in the config.
        - ``llm_model``: ``model.llama_model``.
    """
    from minigpt4.common.config import Config as MiniGPT4ConfigCls
    from minigpt4.common.registry import registry

    class _ArgsShim:
        # Presumably mirrors the argparse namespace MiniGPT-4's Config is
        # normally built from: a config path and no option overrides.
        def __init__(self, cfg_path):
            self.cfg_path = cfg_path
            self.options = None

    cfg = MiniGPT4ConfigCls(_ArgsShim(cfg_path))

    arch_cfg = cfg.model_cfg
    model = (
        registry.get_model_class(arch_cfg.arch)
        .from_config(arch_cfg)
        .to('cuda:0')
    )

    dataset_names = list(cfg.datasets_cfg.keys())
    proc_cfg = cfg.datasets_cfg.get(dataset_names[0]).vis_processor.train
    proc_cls = registry.get_processor_class(proc_cfg.name)
    vis_processor = proc_cls.from_config(proc_cfg)

    return dict(
        tokenizer=model.llama_tokenizer,
        model=model,
        image_processor=vis_processor,
        llm_model=model.llama_model,
    )


def load_shikra_model(yaml_path):
    """Load Shikra and its preprocessor from a YAML model/training config.

    The preprocessor returned by ``mllm`` is extended with ``conv`` settings
    (taken from the YAML ``ModelArgs``) and a ``PlainBoxFormatter`` target
    processor.

    Args:
        yaml_path: YAML file with ``ModelArgs`` and ``TrainingArgs`` sections.

    Returns:
        dict with:
        - ``tokenizer``: ``preprocessor["text"]``.
        - ``model`` and ``llm_model``: the same model object, on CUDA.
        - ``image_processor``: ``preprocessor["image"]``.
        - ``full_preprocessor``: the whole preprocessor dict.
    """
    from mllm.models import load_pretrained
    from mllm.dataset.process_function import PlainBoxFormatter

    model_args, training_args = load_model_args_from_yaml(yaml_path)
    model, preprocessor = load_pretrained(model_args, training_args)

    preprocessor['conv'] = {
        'image_token_len': model_args.image_token_len,
        'sep_image_conv_front': model_args.sep_image_conv_front,
        'use_im_start_end': model_args.mm_use_im_start_end,
    }
    preprocessor['target'] = {'boxes': PlainBoxFormatter()}

    # nn.Module.to() moves parameters in place and returns self, so a single
    # transfer serves both the "model" and "llm_model" entries.
    model = model.to("cuda")

    return {
        "tokenizer": preprocessor["text"],
        "model": model,
        "image_processor": preprocessor["image"],
        "llm_model": model,
        "full_preprocessor": preprocessor,
    }

def load_qwen_vl_model(model_path):
    """Load Qwen-VL using the vendored ``./Qwen_VL`` modeling code.

    ``./Qwen_VL`` (resolved against the current working directory) is
    appended to ``sys.path`` once.

    The tokenizer is set up for batched prompts: left padding, with the
    pad token set to ``eod_id``.

    Args:
        model_path: checkpoint path for both the tokenizer and the model.

    Returns:
        dict with:
        - ``tokenizer``.
        - ``model`` and ``llm_model``: the same model object, in eval mode on CUDA.
        - ``image_processor``: ``model.transformer.visual.image_transform``.
    """
    vendored_dir = os.path.abspath("./Qwen_VL")
    if vendored_dir not in sys.path:
        sys.path.append(vendored_dir)

    from transformers import AutoTokenizer
    from Qwen_VL.modeling_qwen import QWenLMHeadModel

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.padding_side = 'left'
    tokenizer.pad_token_id = tokenizer.eod_id

    qwen = QWenLMHeadModel.from_pretrained(
        model_path, device_map="cuda", trust_remote_code=True
    )
    qwen = qwen.eval()

    return dict(
        tokenizer=tokenizer,
        model=qwen,
        image_processor=qwen.transformer.visual.image_transform,
        llm_model=qwen,
    )

# ==================== Input Preparers ====================

def prepare_llava_inputs(template, query, image_tensor, tokenizer):
    """Build LLaVA-1.5 input ids with a single image placeholder token.

    Args:
        template: prompt template containing ``<question>`` and ``<ImageHere>``.
        query: list of question strings, one per batch row.
        image_tensor: preprocessed image batch; it is cast to fp16 and moved
            to ``DEVICE``.
        tokenizer: LLaVA tokenizer (callable, exposes ``bos_token_id``).

    Returns:
        ``(prompts, input_ids, img_start_idx, img_end_idx, kwargs)``, where:
        - ``input_ids`` is ``[BOS] + before + [IMAGE_TOKEN_INDEX] + after``.
        - ``img_end_idx`` is ``img_start_idx`` plus the configured image token
          length.
        - ``kwargs`` holds ``images``.

    Raises:
        ValueError: if ``query`` is empty or a prompt lacks ``<ImageHere>``.
    """
    if not query:
        raise ValueError("query must contain at least one question")
    qu = [template.replace("<question>", q) for q in query]
    # Same validation as the MiniGPT-4 preparer; without it a missing
    # placeholder surfaces as an opaque IndexError on chunk[1] below.
    for q in qu:
        if "<ImageHere>" not in q:
            raise ValueError("Template must contain '<ImageHere>'")
    batch_size = len(qu)
    chunks = [q.split("<ImageHere>") for q in qu]
    chunk_before = [chunk[0] for chunk in chunks]
    chunk_after = [chunk[1] for chunk in chunks]
    token_before = tokenizer(chunk_before, return_tensors="pt", padding="longest", add_special_tokens=False).to(DEVICE).input_ids
    token_after = tokenizer(chunk_after, return_tensors="pt", padding="longest", add_special_tokens=False).to(DEVICE).input_ids
    bos = torch.ones([batch_size, 1], dtype=torch.int64, device=DEVICE) * tokenizer.bos_token_id
    # +1 skips the BOS token. [img_start_idx, img_end_idx) is presumably where
    # LLaVA splices the image features in place of the placeholder.
    img_start_idx = len(token_before[0]) + 1
    image_token_length = IMAGE_TOKEN_LENGTH
    if isinstance(image_token_length, dict):
        image_token_length = image_token_length.get("llava-1.5", 576)
    img_end_idx = img_start_idx + image_token_length
    image_token = torch.ones([batch_size, 1], dtype=torch.int64, device=DEVICE) * IMAGE_TOKEN_INDEX
    input_ids = torch.cat([bos, token_before, image_token, token_after], dim=1)
    kwargs = {"images": image_tensor.half().to(DEVICE)}
    return qu, input_ids, img_start_idx, img_end_idx, kwargs

def _split_embeds_and_mask(output):
    """Normalize an ``embeds`` or ``(embeds, mask, ...)`` value into a pair.

    When no mask is present, an all-ones ``torch.long`` mask covering every
    dim of ``embeds`` except the last is created on ``embeds.device``.
    """
    if isinstance(output, tuple):
        embeds = output[0]
        if len(output) > 1:
            return embeds, output[1]
    else:
        embeds = output
    return embeds, torch.ones(embeds.shape[:-1], dtype=torch.long, device=embeds.device)


def prepare_minigpt4_inputs(template, query, image, model):
    """Build MiniGPT-4 ``inputs_embeds`` / ``attention_mask`` for generation.

    Args:
        template: prompt template containing ``<question>`` and ``<ImageHere>``.
        query: list of question strings, one per batch row.
        image: preprocessed image batch tensor; moved to CUDA.
        model: MiniGPT-4 model; ``encode_img``, ``prompt_wrap``,
            ``embed_tokens`` and ``llama_tokenizer`` are used.

    Returns:
        ``(prompts, None, img_start_idx, img_end_idx, kwargs)``, where:
        - ``kwargs`` holds ``inputs_embeds`` and ``attention_mask``. A BOS
          embedding is prepended to both.
        - ``img_start_idx`` is 1 (for BOS) plus the token length of the text
          before ``<ImageHere>`` in the first prompt.

    Raises:
        ValueError: if ``query`` is empty or a prompt lacks ``<ImageHere>``.
    """
    if not query:
        raise ValueError("query must contain at least one question")
    qu = [template.replace("<question>", q) for q in query]
    for q in qu:
        if "<ImageHere>" not in q:
            raise ValueError("Template must contain '<ImageHere>'")
    with torch.no_grad():
        image = image.to("cuda")
        img_embeds, atts_img = _split_embeds_and_mask(model.encode_img(image))
    inputs_embeds, attention_mask = _split_embeds_and_mask(
        model.prompt_wrap(img_embeds=img_embeds, atts_img=atts_img, prompts=qu)
    )
    bos = torch.ones([len(query), 1], dtype=torch.int64, device=inputs_embeds.device) * model.llama_tokenizer.bos_token_id
    bos_embeds = model.embed_tokens(bos)
    atts_bos = torch.ones([len(query), 1], dtype=torch.long, device=inputs_embeds.device)
    inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
    attention_mask = torch.cat([atts_bos, attention_mask], dim=1)
    # Only the first prompt's prefix is measured; this assumes every row shares
    # the same text before the image.
    chunk_before = qu[0].split("<ImageHere>", 1)[0]
    token_before = model.llama_tokenizer(chunk_before, return_tensors="pt", add_special_tokens=False).input_ids
    img_start_idx = token_before.shape[1] + 1
    # Accept IMAGE_TOKEN_LENGTH as a dict or a scalar, as the LLaVA preparer
    # does. When no "minigpt4" entry exists, fall back to the per-image
    # embedding count. This assumes img_embeds is (batch, tokens, dim); the
    # mask built from shape[:-1] above implies the last dim is features.
    image_token_length = IMAGE_TOKEN_LENGTH
    if isinstance(image_token_length, dict) and "minigpt4" in image_token_length:
        image_token_length = image_token_length["minigpt4"]
    else:
        image_token_length = img_embeds.shape[-2]
    img_end_idx = img_start_idx + image_token_length
    kwargs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
    return qu, None, img_start_idx, img_end_idx, kwargs

def prepare_qwen_vl_inputs(template, query, image_path, tokenizer, image_tensor):
    """Prepare Qwen-VL inputs with the image referenced inline by path.

    ``template`` is accepted for signature parity with the other preparers but
    is not used. Each prompt is ``<img>{image_path}</img>{question} Answer:``.

    Args:
        template: unused.
        query: list of question strings, one per batch row.
        image_path: image path embedded between ``<img>`` and ``</img>``.
        tokenizer: Qwen-VL tokenizer.
        image_tensor: forwarded unchanged as ``kwargs["images"]``.

    Returns:
        ``(prompts, input_ids, img_start_idx, img_end_idx, kwargs)``.
        The span is located in row 0 of ``input_ids`` only:
        - ``img_start_idx`` is the position right after an ``<img>`` token.
        - ``img_end_idx`` is the position of the ``</img>`` that follows it.
        - If no matching ``<img>``/``</img>`` pair is found, both are 0.

    Raises:
        ValueError: if ``query`` is empty.
    """
    if not query:
        raise ValueError("query must contain at least one question")
    prompts = [f'<img>{image_path}</img>{q} Answer:' for q in query]

    inputs = tokenizer(prompts, return_tensors='pt', padding='longest')
    input_ids = inputs.input_ids.to("cuda")
    attention_mask = inputs.attention_mask.to("cuda")

    img_start_id = tokenizer.convert_tokens_to_ids("<img>")
    img_end_id = tokenizer.convert_tokens_to_ids("</img>")

    # NOTE(review): positions come from row 0 only; with left padding, other
    # rows may be shifted by their own padding.
    img_start_idx = None
    img_end_idx = None
    for i, token_id in enumerate(input_ids[0].tolist()):
        if token_id == img_start_id:
            img_start_idx = i + 1
        elif token_id == img_end_id and img_start_idx is not None:
            img_end_idx = i
            break

    # An unmatched <img> used to yield (start, None). Treat any incomplete
    # pair like a missing one so callers always get two ints.
    if img_start_idx is None or img_end_idx is None:
        img_start_idx = 0
        img_end_idx = 0

    kwargs = {
        "images": image_tensor,
        "attention_mask": attention_mask,
    }

    return prompts, input_ids, img_start_idx, img_end_idx, kwargs

def prepare_shikra_inputs(model_args, query, image, preprocessor):
    """Build Shikra model inputs for one image and the first question.

    Args:
        model_args: Shikra model-args config passed to ``prepare_interactive``.
        query: list of questions; only ``query[0]`` is used.
        image: PIL image, or a ``torch.Tensor`` converted with ``ToPILImage``.
            It is expanded to a square before being set on the conversation.
        preprocessor: full Shikra preprocessor dict.

    Returns:
        ``([query[0]], input_ids, img_start_idx, img_end_idx, kwargs)``, where:
        - ``input_ids`` is 2-D and on CUDA.
        - ``[img_start_idx, img_end_idx)`` covers the tokens strictly between
          the first ``SHIKRA_IMG_START_TOKEN`` and the first
          ``SHIKRA_IMG_END_TOKEN``. This is the same convention as the Qwen-VL
          preparer. Both are 0 if either marker is absent.
        - ``kwargs`` holds ``images`` and, when produced, ``attention_mask``.

    Raises:
        ValueError: if ``query`` is empty.
    """
    if not query:
        raise ValueError("query must contain at least one question")

    from mllm.dataset.builder import prepare_interactive
    from mllm.dataset.utils.transform import expand2square

    ds = prepare_interactive(model_args, preprocessor)

    if isinstance(image, torch.Tensor):
        import torchvision.transforms as T
        to_pil = T.ToPILImage()
        image = to_pil(image)

    image_expanded = expand2square(image)
    ds.set_image(image_expanded)

    ds.append_message(role=ds.roles[0], message=query[0], boxes=[], boxes_seq=[])

    model_inputs = ds.to_model_input()
    model_inputs['images'] = model_inputs['images'].to(device='cuda')
    model_inputs['input_ids'] = model_inputs['input_ids'].to(device='cuda')

    input_ids = model_inputs['input_ids']

    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)

    img_start_idx = 0
    img_end_idx = 0

    start_cols = torch.where(input_ids == SHIKRA_IMG_START_TOKEN)[1]
    end_cols = torch.where(input_ids == SHIKRA_IMG_END_TOKEN)[1]
    if start_cols.numel() > 0 and end_cols.numel() > 0:
        # +1 skips the start marker itself, so the span excludes both markers
        # like the Qwen-VL preparer. This assumes SHIKRA_IMG_START_TOKEN is
        # the id of the marker that precedes the image tokens.
        img_start_idx = start_cols[0].item() + 1
        img_end_idx = end_cols[0].item()

    kwargs = {
        "images": model_inputs['images'],
    }

    if "attention_mask" in model_inputs:
        kwargs["attention_mask"] = model_inputs["attention_mask"].to(device='cuda')

    return [query[0]], input_ids, img_start_idx, img_end_idx, kwargs

# ==================== Model Manager ====================

class ModelManager:
    """Load one supported multimodal LLM and drive prompting, generation and decoding.

    ``model_name`` (case-insensitive) selects the backend from
    ``SUPPORTED_MODELS``. ``model_path`` overrides the entry in ``MODEL_PATHS``.
    After each call to :meth:`prepare_inputs`, ``img_start_idx`` and
    ``img_end_idx`` hold the image-token span reported by the backend-specific
    preparer.
    """

    SUPPORTED_MODELS = ["llava-1.5", "minigpt4", "shikra", "qwen-vl"]
    # Shikra config used when ``model_path`` is not itself a YAML file.
    # NOTE(review): placeholder path; point it at a real config before use.
    _SHIKRA_DEFAULT_YAML = "path/to/shikra/config.yml"

    def __init__(self, model_name: str, model_path: Optional[str] = None):
        """Validate the model name, resolve its path and load the model eagerly.

        Raises:
            ValueError: if ``model_name`` is not supported, or no path was
                given and ``MODEL_PATHS`` has none for it.
        """
        self.model_name = model_name.lower()
        if self.model_name not in self.SUPPORTED_MODELS:
            raise ValueError(f"Unknown model: {self.model_name}")

        self.model_path = model_path or MODEL_PATHS.get(self.model_name)
        # Fail here with a clear message instead of deep inside a loader
        # (or on ``None.endswith`` for Shikra).
        if not self.model_path:
            raise ValueError(
                f"No model path for {self.model_name}: pass model_path or add it to MODEL_PATHS"
            )

        self.model_args = None
        if self.model_name == "shikra":
            self.model_args = self._create_shikra_model_args(self._shikra_yaml_path())

        self.tokenizer = None
        self.model = None
        self.image_processor = None
        self.llm_model = None
        self.full_preprocessor = None
        self.img_start_idx = None
        self.img_end_idx = None

        self._load_model()

    def _shikra_yaml_path(self) -> str:
        """Return ``model_path`` if it names a YAML file, else the default Shikra config."""
        if self.model_path.endswith((".yml", ".yaml")):
            return self.model_path
        return self._SHIKRA_DEFAULT_YAML

    def _create_shikra_model_args(self, yaml_path):
        """Build the mmengine ``Config`` of Shikra model args.

        Values come from the YAML ``ModelArgs``. Optional fields fall back to
        the defaults written below, and the processor/conversation settings
        are fixed.
        """
        base_model_args, _ = load_model_args_from_yaml(yaml_path)
        model_args_dict = dict(
            type='shikra',
            version='v1',
            model_name_or_path=base_model_args.model_name_or_path,
            model_max_length=base_model_args.model_max_length,
            vision_tower=getattr(base_model_args, 'vision_tower', "openai/clip-vit-large-patch14-336"),
            mm_vision_select_layer=getattr(base_model_args, 'mm_vision_select_layer', -2),
            pretrain_mm_mlp_adapter=getattr(base_model_args, 'pretrain_mm_mlp_adapter', None),
            tune_mm_mlp_adapter=False,
            freeze_backbone=True,
            freeze_mm_mlp_adapter=False,
            sep_image_conv_front=getattr(base_model_args, 'sep_image_conv_front', False),
            image_token_len=getattr(base_model_args, 'image_token_len', 256),
            mm_use_im_start_end=getattr(base_model_args, 'mm_use_im_start_end', True),
            target_processor=dict(
                boxes=dict(type='PlainBoxFormatter'),
            ),
            process_func_args=dict(
                conv=dict(type='ShikraConvProcess'),
                target=dict(type='BoxFormatProcess'),
                text=dict(type='ShikraTextProcess'),
                image=dict(type='ShikraImageProcessor'),
            ),
            conv_args=dict(
                conv_template='vicuna_v1.1',
                transforms=dict(type='Expand2square'),
                tokenize_kwargs=dict(truncation_size=None),
            ),
        )
        return Config(model_args_dict)

    def _load_model(self):
        """Dispatch to the backend loader and store its components on ``self``."""
        print(f"Loading {self.model_name} from {self.model_path}...")

        if self.model_name == "llava-1.5":
            components = load_llava_model(self.model_path)
        elif self.model_name == "minigpt4":
            components = load_minigpt4_model(self.model_path)
        elif self.model_name == "shikra":
            components = load_shikra_model(self._shikra_yaml_path())
            self.full_preprocessor = components.get("full_preprocessor")
        elif self.model_name == "qwen-vl":
            components = load_qwen_vl_model(self.model_path)
        else:
            raise ValueError(f"Unknown model: {self.model_name}")

        self.tokenizer = components["tokenizer"]
        self.model = components["model"]
        self.image_processor = components["image_processor"]
        self.llm_model = components["llm_model"]
        print(f"Model {self.model_name} loaded successfully!")

    def construct_template(self) -> str:
        """Return ``"<system> <instruction>"`` for this model, or just the instruction.

        ``SYSTEM_MESSAGE`` and ``INSTRUCTION_TEMPLATE`` may each be a per-model
        dict or a single shared value. The instruction defaults to
        ``"<question>"``.
        """
        system = SYSTEM_MESSAGE.get(self.model_name, "") if isinstance(SYSTEM_MESSAGE, dict) else SYSTEM_MESSAGE
        instruction_map = INSTRUCTION_TEMPLATE if isinstance(INSTRUCTION_TEMPLATE, dict) else {self.model_name: INSTRUCTION_TEMPLATE}
        instruction = instruction_map.get(self.model_name, "<question>")
        return f"{system} {instruction}" if system else instruction

    def preprocess_image(self, image):
        """Turn an image into the backend's expected input.

        - Shikra: the image is returned unchanged; its preparer handles it.
        - Without an image processor: the image is also returned unchanged.
        - Other backends: PIL-like images (anything with ``convert``) are
          converted to RGB first, then processed.
        """
        if self.model_name == "shikra": return image
        if self.image_processor is None: return image
        if self.model_name == "llava-1.5":
            if hasattr(image, 'convert'): image = image.convert('RGB')
            pixel_values = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0].unsqueeze(0)
            return pixel_values.to(DEVICE)
        elif self.model_name == "minigpt4":
            if hasattr(image, 'convert'): image = image.convert('RGB')
            return self.image_processor(image).unsqueeze(0)
        elif self.model_name == "qwen-vl":
            if hasattr(image, 'convert'): image = image.convert('RGB')

            image_tensor = self.image_processor(image).unsqueeze(0).to(self.model.device)

            # Match the model's parameter dtype (e.g. fp16/bf16 checkpoints).
            if hasattr(self.model, 'dtype') and image_tensor.dtype != self.model.dtype:
                image_tensor = image_tensor.to(dtype=self.model.dtype)
            return image_tensor
        return image

    def prepare_inputs(self, query: List[str], image, image_path: Optional[str] = None):
        """Build backend inputs and record the image-token span.

        Returns:
            ``(questions, input_ids, kwargs)``. ``input_ids`` is ``None`` for
            MiniGPT-4, which uses ``kwargs["inputs_embeds"]`` instead.

        Raises:
            ValueError: for Qwen-VL without ``image_path``, or an unknown model.
        """
        template = self.construct_template()

        if self.model_name == "llava-1.5":
            questions, input_ids, img_start, img_end, kwargs = prepare_llava_inputs(template, query, image, self.tokenizer)
        elif self.model_name == "minigpt4":
            questions, input_ids, img_start, img_end, kwargs = prepare_minigpt4_inputs(template, query, image, self.model)
        elif self.model_name == "shikra":
            questions, input_ids, img_start, img_end, kwargs = prepare_shikra_inputs(self.model_args, query, image, self.full_preprocessor)
        elif self.model_name == "qwen-vl":
            if image_path is None: raise ValueError("Qwen-VL requires image_path")
            questions, input_ids, img_start, img_end, kwargs = prepare_qwen_vl_inputs(template, query, image_path, self.tokenizer, image)
        else:
            raise ValueError(f"Unknown model: {self.model_name}")

        self.img_start_idx = img_start
        self.img_end_idx = img_end
        return questions, input_ids, kwargs

    def decode(self, output_ids: torch.Tensor, input_ids: Optional[torch.Tensor] = None) -> List[str]:
        """Convert generated token ids to one answer string per batch row.

        - LLaVA: ``IMAGE_TOKEN_INDEX`` ids are replaced by 0 before decoding,
          and only the text after the last ``"ASSISTANT:"`` is kept.
        - MiniGPT-4: keeps the text before the first ``"###"`` and after the
          last ``"Assistant:"``.
        - Shikra and Qwen-VL: when ``input_ids`` is given, the prompt prefix of
          that length is dropped. Qwen-VL returns one ``""`` per row if nothing
          was generated beyond the prompt.
        """
        if self.model_name == "llava-1.5":
            output_ids = output_ids.clone()
            output_ids[output_ids == IMAGE_TOKEN_INDEX] = torch.tensor(
                0, dtype=output_ids.dtype, device=output_ids.device
            )
            output_text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            output_text = [text.split("ASSISTANT:")[-1].strip() for text in output_text]
        elif self.model_name == "minigpt4":
            output_text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            output_text = [text.split("###")[0].split("Assistant:")[-1].strip() for text in output_text]
        elif self.model_name == "shikra":
            if input_ids is not None:
                input_token_len = input_ids.shape[-1]
                output_ids = output_ids[:, input_token_len:]
            output_text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        elif self.model_name == "qwen-vl":
            if input_ids is not None:
                input_token_len = input_ids.shape[-1]
                if output_ids.shape[1] <= input_token_len:
                    # Nothing new was generated. Keep the one-string-per-row
                    # contract instead of always returning a single "".
                    return [""] * output_ids.shape[0]
                output_ids = output_ids[:, input_token_len:]
            output_text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            output_text = [text.strip() for text in output_text]
        else:
            output_text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return output_text

    @torch.no_grad()
    def generate(self, input_ids: Optional[torch.Tensor], max_new_tokens: int = 512, **kwargs):
        """Run the backend's ``generate`` with backend-specific defaults.

        ``do_sample`` and ``num_beams`` from ``kwargs`` override the defaults.
        - LLaVA, MiniGPT-4 and Shikra default to greedy decoding.
        - Qwen-VL defaults to sampling with 5 beams and at least 20 new tokens.

        Other ``kwargs`` are read only for the backend inputs: ``images``,
        ``inputs_embeds`` and ``attention_mask``.

        Raises:
            ValueError: for LLaVA without ``input_ids`` or ``images``, or an
                unknown model.
        """
        generation_config = {"max_new_tokens": max_new_tokens}

        user_do_sample = kwargs.pop("do_sample", None)
        user_num_beams = kwargs.pop("num_beams", None)
        if user_do_sample is not None:
            generation_config["do_sample"] = user_do_sample
        if user_num_beams is not None:
            generation_config["num_beams"] = user_num_beams

        if self.model_name == "llava-1.5":
            if input_ids is None:
                raise ValueError("LLaVA generate requires input_ids (got None)")
            images = kwargs.pop("images", None)
            if images is None:
                raise ValueError("LLaVA generate requires images tensor (got None)")
            generation_config.setdefault("do_sample", False)
            generation_config.setdefault("num_beams", 1)
            outputs = self.model.generate(
                inputs=input_ids,
                images=images,
                **generation_config,
            )

        elif self.model_name == "minigpt4":
            generation_config.setdefault("do_sample", False)
            generation_config.setdefault("num_beams", 1)
            outputs = self.llm_model.generate(
                inputs_embeds=kwargs.get("inputs_embeds"),
                attention_mask=kwargs.get("attention_mask"),
                **generation_config
            )

        elif self.model_name == "shikra":
            generation_config.setdefault("do_sample", False)
            generation_config.setdefault("num_beams", 1)
            outputs = self.model.generate(
                input_ids=input_ids,
                images=kwargs.get('images'),
                **generation_config
            )

        elif self.model_name == "qwen-vl":
            generation_config.setdefault("do_sample", True)
            generation_config.setdefault("temperature", 1.0)
            generation_config.setdefault("top_p", 1.0)
            generation_config.setdefault("num_beams", 5)
            generation_config.setdefault("min_new_tokens", 20)
            generation_config.setdefault("length_penalty", 1)
            outputs = self.model.generate(
                input_ids=input_ids,
                images=kwargs.get('images'),
                attention_mask=kwargs.get('attention_mask'),
                **generation_config)
        else:
            raise ValueError(f"Unknown model: {self.model_name}")
        return outputs
