from typing import Dict, Optional
import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig, AutoTokenizer
from peft import LoraConfig, get_peft_model, PeftModel
import os
from src.arguments import ModelArguments, TrainingArguments
from src.model.processor import LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL, \
    backbone2model, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, LLAVA_ONEVISION, LLAVA_QWEN2

from src.arguments import ModelArguments
from src.model.processor import LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL, \
    QWEN2_VL_TOKENSELECTION, backbone2model, GME, VLM_IMAGE_TOKENS, LamRA, COLPALI, INTERN_VL3, LLAVA_ONEVISION
from src.model.vlm_backbone.colpali import ColPali
from src.model.vlm_backbone.gme.gme_inference import GmeQwen2VL
from src.model.vlm_backbone.lamra.lamra_inference import LamRAQwen2VL
from src.model.vlm_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
from src.model.vlm_backbone.llava_next import LlavaNextForConditionalGeneration
from src.model.vlm_backbone.llava_onevision import LlavaOnevisionForConditionalGeneration
from src.model.llava.model import *
from src.model.llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

from peft import PeftConfig
from unittest.mock import patch

class MMEBModel(nn.Module):
    TRANSFORMER_CLS = AutoModelForCausalLM

    def __init__(self,
                 encoder: PreTrainedModel,
                 pooling: str = 'last',
                 normalize: bool = False,
                 temperature: float = 0.02,
                 ):
        super().__init__()
        self.config = encoder.config
        self.encoder = encoder
        self.pooling = pooling
        self.normalize = normalize
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.is_ddp = dist.is_initialized()
        if self.is_ddp:
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    def encode_input(self, input):
        INTERNVIDEO2 = "internvideo2"
        if getattr(self, "model_backbone", None) == INTERNVIDEO2:
            if "input_ids" in input.keys():
                # text side
                text_output = self.encoder.get_text_encoder()(
                    input["input_ids"],
                    attention_mask=input["attention_mask"],
                    return_dict=True,
                    mode="text",
                )
                text_embeds = text_output.last_hidden_state
                pooled_text_embeds = text_embeds[:, 0]
                pooled_output = self.encoder.text_proj(pooled_text_embeds)
                pooled_output /= pooled_output.norm(dim=-1, keepdim=True)
                return pooled_output
            else:
                _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True)
                vfeat = self.encoder.vision_proj(vfeat)
                vfeat /= vfeat.norm(dim=-1, keepdim=True)
                return vfeat
        elif getattr(self, "model_backbone", None) in [GME, LamRA]:
            # pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
            texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]] # we are actually passing video querys so this should not happen
            images = []
            for imgs in input['images']:
                # if multi images are given, select the middle frame only
                if isinstance(imgs, list):
                    imgs = imgs[len(imgs) // 2]
                    assert not isinstance(imgs, list) # make sure we have extracted the middle frame and it is no longer a list
                    images.append(imgs)
                else:
                    images.append(None)
            pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images)
            return pooled_output
        elif getattr(self, "model_backbone", None) == COLPALI:
            pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
            return pooled_output
        elif getattr(self, "model_backbone", None) == INTERN_VL3:
            if('text' in input):
                del input['text']
            if('global_dataset_name' in input):
                del input['global_dataset_name']
            
            # import ipdb; ipdb.set_trace()
            
            input['pixel_values'] = torch.cat(input['pixel_values'], dim=0).to(input['input_ids'].device)

            if(input['pixel_values'].size(0)):
                input['image_flags'] = torch.ones(input['pixel_values'].size(0))
                with patch("builtins.print"):
                    hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True)
            else:
                del input['pixel_values']
                hidden_states = self.encoder.language_model(**input, return_dict=True, output_hidden_states=True)
            
            hidden_states = hidden_states.hidden_states[-1]
            pooled_output = self._pooling(hidden_states, input['attention_mask'])
            return pooled_output
            """
                - num_tokens = num_image_tokens - 1 + text_tokens
                - hidden_states.hidden_states: (num_layers, batch, num_tokens, embed_dim)
                    -> -1 because the first token is IMAGE_TOKEN_INDEX
                - pooled_output: (batch, embed_dim), 
                - image_features: (batch, num_image_tokens, embed_dim), 
                - attention_matrix: list of (batch, num_heads, num_tokens, num_tokens)
            """
        elif getattr(self, "model_backbone", None) in [LLAVA_NEXT, LLAVA_ONEVISION]:
            print("Encoding input for LLAVA model backbone")
            if hasattr(input, 'pixel_values'):
                input['pixel_values'] = input['pixel_values'].squeeze(1)
                input['image_sizes'] = input['image_sizes'].squeeze(1)
            hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True, output_attentions=True)
            # add for image feature
            if hasattr(hidden_states, 'batch_image_embeds'):
                image_features = hidden_states.batch_image_embeds
            else: 
                image_features = None
            output_hidden_states = hidden_states.hidden_states
            last_hidden_state = hidden_states.hidden_states[-1]
            attention_matrix = hidden_states.attentions if hasattr(hidden_states, 'attentions') else None
            pooled_output = self._pooling(last_hidden_state, input['attention_mask'])
            print("len image features:", None if image_features is None else image_features.shape)
            return pooled_output, image_features, attention_matrix, output_hidden_states
        elif getattr(self, "model_backbone", None) in [LLAVA_QWEN2, QWEN2_VL]:
            # print("Encoding input for FastVLM model backbone")
            hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True, output_attentions=True)
            if hasattr(hidden_states, 'batch_image_embeds'):
                image_features = hidden_states.batch_image_embeds
            else: 
                image_features = None
            output_hidden_states = hidden_states.hidden_states
            last_hidden_state = hidden_states.hidden_states[-1]
            attention_matrix = hidden_states.attentions if hasattr(hidden_states, 'attentions') else None
            pooled_output = self._pooling(last_hidden_state, input['attention_mask'])

            return pooled_output, image_features, attention_matrix, output_hidden_states
        else:
            # import ipdb; ipdb.set_trace()
            hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True, output_attentions=True)
            if hasattr(hidden_states, 'batch_image_embeds'):
                image_features = hidden_states.batch_image_embeds
            else: 
                image_features = None
            output_hidden_states = hidden_states.hidden_states
            last_hidden_state = hidden_states.hidden_states[-1]
            attention_matrix = hidden_states.attentions if hasattr(hidden_states, 'attentions') else None
            pooled_output = self._pooling(last_hidden_state, input['attention_mask'])

            all_layers_embeds = torch.stack([self._pooling(hidden_state, input['attention_mask']) 
                                            for hidden_state in hidden_states.hidden_states]).permute(1, 0, 2)
            
            return pooled_output, image_features, attention_matrix, output_hidden_states
        """
            - num_tokens = num_image_tokens - 1 + text_tokens
            - hidden_states.hidden_states: (num_layers, batch, num_tokens, embed_dim)
                -> -1 because the first token is IMAGE_TOKEN_INDEX
            - pooled_output: (batch, embed_dim), 
            - image_features: (batch, num_image_tokens, embed_dim), 
            - attention_matrix: list of (batch, num_heads, num_tokens, num_tokens)
        """
    def _pooling(self, last_hidden_state, attention_mask):
        if self.pooling == 'last' or self.pooling == 'eos':
            left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
            batch_size = last_hidden_state.shape[0]
            if left_padding:
                # Get the vectors at the last position
                reps = last_hidden_state[torch.arange(batch_size), -1, :]
            else:
                # Calculate last 1 position in the original tensor
                max_length = last_hidden_state.size(1)
                invert_mask = (attention_mask == 0).long()
                num_padding_tokens = invert_mask.sum(dim=1)
                eos_indices_positive = max_length - num_padding_tokens - 1
                # Get the vectors at the last 1 position of each attention mask
                reps = last_hidden_state[
                    torch.arange(batch_size, device=last_hidden_state.device), eos_indices_positive]

        else:
            raise NotImplementedError
        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    @classmethod
    def build(cls, model_args: ModelArguments, **kwargs):
        """Build an ``MMEBModel`` from ``model_args.model_name``.

        The backbone name is resolved from the HF config, stored on
        ``model_args.model_backbone`` and on the returned model, and used to
        pick the loading recipe for the base model. Depending on
        ``model_args`` the base model is then wrapped with

        * a pre-trained LoRA adapter (``load_pretrained_lora``), read from
          ``checkpoint_path`` if set, else ``model_name``; an
          ``mm_projector.pth`` file next to the adapter is loaded if present;
        * a freshly initialised LoRA/DoRA adapter (``lora``); or
        * nothing.

        Args:
            model_args: Model options; ``model_backbone`` is written onto it.
            **kwargs: Forwarded to ``from_pretrained`` only for backbones
                without a dedicated branch.

        Returns:
            The constructed model, with ``model_backbone`` set on every path.

        Raises:
            ValueError: For LLaVA backbones when no ``language_model`` module
                matches ``model_args.lora_target_modules``.
        """
        config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)

        model_backbone = get_backbone_name(hf_config=config)
        setattr(model_args, 'model_backbone', model_backbone)
        print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')

        # ---- Load the base model ----
        if model_backbone == PHI3V:
            config._attn_implementation = "eager"
            config.padding_side = "right"
            config.use_cache = False
            base_model = Phi3VForCausalLM.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone == LLAVA_NEXT:
            config.use_cache = False
            config.padding_side = "left"
            base_model = LlavaNextForConditionalGeneration.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone == LLAVA_ONEVISION:
            config._attn_implementation = "eager"
            config.use_cache = False
            config.padding_side = "left"
            base_model = LlavaOnevisionForConditionalGeneration.from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [QWEN2_VL, QWEN2_5_VL]:
            config._attn_implementation = "eager"
            config.padding_side = "left"
            config.use_cache = False
            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
            )
        elif model_backbone in [INTERN_VL3]:
            config._attn_implementation = "eager"
            config.padding_side = "left"
            config.use_cache = False
            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                use_flash_attn=False,
                trust_remote_code=True,
            )
            #! hardcoded. Also check hardcoded values in processor
            base_model.img_context_token_id = 151667
        elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
            config._attn_implementation = "eager"
            config.padding_side = "left"
            config.use_cache = False

            from .utils import parse_layer_type
            # Hard-coded layer counts passed to parse_layer_type for the
            # language and vision towers.
            lm_qwen_layer = 28
            vis_qwen_layer = 32
            lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer)
            vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer)

            base_model = backbone2model[model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                lm_skip_layer=lm_skip_layer,
                vis_skip_layer=vis_skip_layer,
            )
        elif model_backbone in [LLAVA_QWEN2]:
            config._attn_implementation = "eager"
            base_model = LlavaQwen2ForCausalLM.from_pretrained(
                model_args.model_name,
                low_cpu_mem_usage=True,
                torch_dtype=torch.bfloat16,
                config=config,
            )
        else:
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_args.model_name, **kwargs, config=config,
                attn_implementation="eager",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)

        # ---- Optionally wrap the base model with LoRA ----
        if model_args.load_pretrained_lora:
            model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
            print_master(f'Loading Pre-trained LoRA model from {model_name_or_path}')
            lora_config = LoraConfig.from_pretrained(model_name_or_path)

            encoder = PeftModel.from_pretrained(base_model,
                                                model_name_or_path,
                                                config=lora_config,
                                                is_trainable=True)

            projector_path = os.path.join(model_name_or_path, "mm_projector.pth")
            if os.path.exists(projector_path):
                # map_location="cpu" keeps the checkpoint off whatever device it
                # was saved from; load_state_dict copies into the live parameters.
                projector_state = torch.load(projector_path, map_location="cpu")
                if model_args.model_backbone in ["llava_onevision", "llava_next"]:
                    encoder.base_model.model.multi_modal_projector.load_state_dict(projector_state)
                else:
                    encoder.base_model.model.model.mm_projector.load_state_dict(projector_state)
                print("Successfully loading the projector's weight")
        elif model_args.lora:
            print_master(f'Initializing LoRA adapter from {base_model}')
            # Strip whitespace so "q_proj, v_proj" resolves to the same names on
            # both paths below.
            base_targets = [t.strip() for t in model_args.lora_target_modules.split(',')]

            if model_args.model_backbone in ["llava_onevision", "llava_next"]:
                # Explicitly list the full path of every module under
                # `language_model` whose leaf name is one of the targets, so the
                # adapter is attached to those modules only.
                explicit_target_modules = [
                    name
                    for name, _ in base_model.named_modules()
                    if 'language_model' in name and name.split('.')[-1] in base_targets
                ]

                print(f"Found {len(explicit_target_modules)} explicit target modules")
                print(f"First 5 modules: {explicit_target_modules[:5]}")

                if not explicit_target_modules:
                    raise ValueError(f"No modules found with targets {base_targets} in language_model!")

                target_modules = explicit_target_modules  # full module paths
                print(f"Applying LoRA to language_model modules matching: {base_targets}")
            else:
                target_modules = base_targets

            lora_config = LoraConfig(
                r=model_args.lora_r,
                lora_alpha=model_args.lora_alpha,
                target_modules=target_modules,
                lora_dropout=model_args.lora_dropout,
                init_lora_weights="gaussian",
                use_dora=True,
                inference_mode=False
            )
            encoder = get_peft_model(base_model, lora_config)
        else:
            encoder = base_model

        model = cls(
            encoder=encoder,
            pooling=model_args.pooling,
            normalize=model_args.normalize,
            temperature=model_args.temperature
        )
        # Set on every path (the LoRA paths used to skip this), so that
        # encode_input dispatches on the real backbone.
        setattr(model, 'model_backbone', model_backbone)
        return model


    @classmethod
    def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
        """Load an ``MMEBModel`` (and optionally its LoRA adapter) for use.

        The backbone is resolved from the config found at ``checkpoint_path``
        (or ``model_name`` when no checkpoint path is given) and written to
        ``model_args.model_backbone``. When ``model_args.lora`` is set, the
        adapter is loaded from the same location. An ``mm_projector.pth`` file
        is then looked up first locally and, failing that, on the HF hub; the
        hub attempt is best-effort and never raises.

        Args:
            model_args: Model options; ``model_backbone`` is written onto it.
            is_trainable: Forwarded to ``PeftModel.from_pretrained`` and
                ``load_adapter``.
            **kwargs: ``processor`` is required for the GME backbone. The whole
                mapping is forwarded to ``from_pretrained`` for PHI3V and for
                backbones without a dedicated branch.

        Returns:
            The constructed model with ``model_backbone`` set.
        """
        # Loading the base model
        INTERNVIDEO2 = "internvideo2"
        model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        model_backbone = get_backbone_name(hf_config=config)
        setattr(model_args, 'model_backbone', model_backbone)
        print_master(f'Loading backbone [{model_backbone}] from {model_name_or_path}')

        # Fallback when the lookup above produced nothing usable.
        if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
            model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
            setattr(model_args, 'model_backbone', model_backbone)
        print_master(f'Loading backbone [{model_args.model_backbone}] from {model_args.model_name}')

        if model_args.model_backbone in {LLAVA_ONEVISION, LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION}:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config._attn_implementation = "eager"
            config.vision_config._attn_implementation = "eager"
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                model_args.model_name,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                config=config
            )
        elif model_args.model_backbone in [INTERN_VL3]:
            config._attn_implementation = "eager"
            config.padding_side = "left"
            config.use_cache = False
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                model_args.model_name,
                config=config,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                use_flash_attn=False,
                trust_remote_code=True,
            )
            #! hardcoded. Also check hardcoded values in processor
            base_model.img_context_token_id = 151667
        elif model_args.model_backbone == PHI3V:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config,
                                                          torch_dtype=torch.bfloat16, trust_remote_code=True)
            base_model.padding_side = "right"
        elif model_args.model_backbone == INTERNVIDEO2:
            print_master(f'Loading backbone [{model_args.model_backbone}] from {"src/model/vlm_backbone/internvideo2/"}')
            config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/",
                                                trust_remote_code=True)
            base_model = backbone2model[model_args.model_backbone].from_pretrained("src/model/vlm_backbone/internvideo2/", config=config,
                                                                                   trust_remote_code=True)
        elif model_args.model_backbone == GME:
            base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA:
            base_model = LamRAQwen2VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == COLPALI:
            base_model = ColPali.from_pretrained(model_args.model_name)
            setattr(base_model, 'config', config)
        else:
            # Loading external base model from HF
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_name_or_path, **kwargs, config=config,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)
            print(f"Loaded base model from HF: {model_name_or_path}")

        # Building the model on top of the base
        if model_args.lora:
            print_master(f'Loading LoRA from {model_name_or_path}')
            lora_config = LoraConfig.from_pretrained(model_name_or_path)
            lora_model = PeftModel.from_pretrained(base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable)
            lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable)

            def _load_projector(path):
                # map_location="cpu" keeps the checkpoint off whatever device it
                # was saved from; load_state_dict copies into the live parameters.
                state = torch.load(path, map_location="cpu")
                if model_args.model_backbone in ["llava_onevision", "llava_next"]:
                    lora_model.base_model.model.multi_modal_projector.load_state_dict(state)
                else:
                    lora_model.base_model.model.model.mm_projector.load_state_dict(state)

            projector_path = os.path.join(model_name_or_path, "mm_projector.pth")

            if os.path.exists(projector_path):
                _load_projector(projector_path)
                print("Successfully loading the projector's weight from local path")
            else:
                # Best-effort hub lookup: a missing file (or a missing
                # huggingface_hub install) must not abort loading. Success is
                # only reported when the weights were actually loaded.
                try:
                    from huggingface_hub import hf_hub_download
                    projector_path = hf_hub_download(
                        repo_id=model_name_or_path,
                        filename="mm_projector.pth",
                    )
                    _load_projector(projector_path)
                except Exception as err:
                    print(f"No projector weight found in the hub. ({type(err).__name__}: {err})")
                else:
                    print("Successfully loading the projector's weight")

            encoder = lora_model
        else:
            encoder = base_model

        model = cls(
            encoder=encoder,
            pooling=model_args.pooling,
            normalize=model_args.normalize,
            temperature=model_args.temperature
        )

        model.model_backbone = model_args.model_backbone
        return model

    def save(self, output_dir: str):
        self.encoder.save_pretrained(output_dir)

    def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, *args, **kwargs):
        # print(f"qry keys: {qry.keys() if qry else None}, tgt keys: {tgt.keys() if tgt else None}")
        qry_reps = self.encode_input(qry)[0] if qry else None  # (bsz_per_device, dim)
        tgt_reps = self.encode_input(tgt)[0] if tgt else None # (bsz_per_device, dim)

        if qry_reps is None or tgt_reps is None:
            return {"qry_reps": qry_reps, "tgt_reps": tgt_reps}

        if self.is_ddp:
            all_qry_reps = self._dist_gather_tensor(qry_reps)
            all_tgt_reps = self._dist_gather_tensor(tgt_reps)
        else:
            all_qry_reps = qry_reps
            all_tgt_reps = tgt_reps

        scores = self.compute_similarity(all_qry_reps, all_tgt_reps)
        scores = scores.view(all_qry_reps.size(0), -1)
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        target = target * (all_qry_reps.size(0) // all_tgt_reps.size(0))
        loss = self.cross_entropy(scores / self.temperature, target)
        if self.is_ddp:
            loss = loss * self.world_size

        return loss

    def _dist_gather_tensor(self, t: Tensor):
        """Gather ``t`` from every rank and concatenate the pieces along dim 0.

        Uses ``self.world_size`` and ``self.process_rank``, which ``__init__``
        only sets when a process group is initialised.
        """
        local = t.contiguous()
        gathered = [torch.empty_like(local) for _ in range(self.world_size)]
        dist.all_gather(gathered, local)
        # Put this rank's own tensor back in its slot instead of the gathered copy.
        # NOTE(review): presumably so autograd still reaches the local shard — confirm.
        gathered[self.process_rank] = local
        return torch.cat(gathered, dim=0)

    def compute_similarity(self, q_reps, p_reps):
        return torch.matmul(q_reps, p_reps.transpose(0, 1))