#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM, \
                         DPTImageProcessor, DPTForDepthEstimation

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class StableLayerNorm(nn.Module):
    """Layer normalization over the last dimension, evaluated in float32.

    The mean and (biased) variance are computed on a float32 copy of the
    input.  The variance is floored at ``eps`` and ``eps`` is added once
    more under the square root.  The normalized values are cast back to the
    input dtype before the optional affine transform is applied.

    Args:
        normalized_shape: Shape of the learnable ``weight`` and ``bias``
            tensors created when ``elementwise_affine`` is true.
        eps: Variance floor and additive stabilizer.
        elementwise_affine: Whether to learn a per-element scale and shift.
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if not elementwise_affine:
            # Register the names anyway so ``weight``/``bias`` exist as None.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
            return

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, input):
        """Normalize ``input`` along ``dim=-1`` and apply the affine terms.

        Args:
            input: Tensor whose last dimension is normalized.

        Returns:
            Tensor with the same shape as ``input``.
        """
        x32 = input.float()
        mu = x32.mean(dim=-1, keepdim=True)
        variance = x32.var(dim=-1, keepdim=True, unbiased=False)
        variance = torch.clamp(variance, min=self.eps)
        # Statistics are float32; only the final result returns to input dtype.
        out = ((x32 - mu) / torch.sqrt(variance + self.eps)).to(input.dtype)
        if not self.elementwise_affine:
            return out
        return out * self.weight + self.bias


class StableSoftmax(nn.Module):
    """Softmax computed in float32 with clamping safeguards.

    The input is shifted by its per-slice maximum, the shifted logits are
    clamped to ``[-50, 50]``, and both the normalizer and the resulting
    probabilities are floored at ``1e-8``.  The probabilities are then
    re-normalized so they sum to one, and cast back to the input dtype.

    Args:
        dim: Dimension along which the softmax is taken.
    """

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        """Return the clamped softmax of ``input`` along ``self.dim``."""
        axis = self.dim
        logits = input.float()

        # Shift by the maximum, then bound the exponent range.
        peak = logits.max(dim=axis, keepdim=True)[0]
        bounded = torch.clamp(logits - peak, min=-50.0, max=50.0)

        weights = torch.exp(bounded)
        total = torch.clamp(weights.sum(dim=axis, keepdim=True), min=1e-8)

        # Floor every probability, then renormalize after the flooring.
        probs = torch.clamp(weights / total, min=1e-8, max=1.0)
        probs = probs / probs.sum(dim=axis, keepdim=True)

        return probs.to(input.dtype)
from ..relational_embedding import RelationalPositionalEncoding
from ..custom_attention import RelationalAttention



class SpatialStreamEncoderLayer(nn.Module):
    """Post-norm encoder layer: self-attention followed by a GELU MLP.

    Each sub-layer output passes through dropout, is added to its input and
    is then normalized with ``StableLayerNorm``.

    Args:
        hidden_size: Token feature width.
        num_heads: Number of attention heads.
        mlp_dim: Width of the hidden MLP layer.
        attn_type: ``'relational'`` uses ``RelationalAttention`` and needs a
            ``relational_bias`` at call time; ``'vanilla'`` uses
            ``nn.MultiheadAttention`` with ``batch_first=True``.
        dropout_prob: Dropout probability applied to both sub-layer outputs.

    Raises:
        ValueError: If ``attn_type`` is not one of the two supported values.
    """

    def __init__(self, hidden_size, num_heads, mlp_dim, attn_type='relational', dropout_prob=0.1):
        super().__init__()
        self.attn_type = attn_type

        if attn_type == 'vanilla':
            self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        elif attn_type == 'relational':
            self.self_attn = RelationalAttention(hidden_size, num_heads)
        else:
            raise ValueError(f"Unknown attn_type: {attn_type}")

        self.layer_norm1 = StableLayerNorm(hidden_size)
        self.layer_norm2 = StableLayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, hidden_size),
        )
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, relational_bias=None):
        """Apply attention and MLP sub-layers to ``hidden_states``.

        Args:
            hidden_states: Input tokens.
            relational_bias: Bias forwarded to ``RelationalAttention``;
                required when ``attn_type == 'relational'``, ignored
                otherwise.

        Returns:
            The transformed tokens after the second layer norm.
        """
        compute_dtype = hidden_states.dtype

        if self.attn_type == 'relational':
            assert relational_bias is not None, "Relational bias must be provided for relational attention"
            attended = self.self_attn(hidden_states, relational_bias=relational_bias)
        else:
            # nn.MultiheadAttention returns (output, weights); keep the output.
            attended = self.self_attn(query=hidden_states, key=hidden_states, value=hidden_states)[0]

        post_attn = self.layer_norm1(hidden_states + self.dropout(attended))

        # Bring the MLP output back to the dtype the layer was called with.
        feed_forward = self.mlp(post_attn).to(compute_dtype)

        return self.layer_norm2(post_attn + self.dropout(feed_forward))


class CrossAttentionFusion(nn.Module):
    """Fuse spatial tokens with semantic tokens through cross-attention.

    Spatial tokens are the queries; semantic tokens supply keys and values.
    The attention output is added to the spatial tokens, normalized, and
    projected to ``output_dim`` by a two-layer GELU MLP.

    Args:
        spatial_dim: Feature width of the query (spatial) tokens; this is the
            attention ``embed_dim`` and must be divisible by ``num_heads``.
        semantic_dim: Feature width of the key/value (semantic) tokens.  It
            is passed as ``kdim``/``vdim`` and needs no divisibility.
        output_dim: Feature width of the fused output tokens.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``spatial_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, spatial_dim, semantic_dim, output_dim, num_heads):
        super().__init__()
        # nn.MultiheadAttention splits ``embed_dim`` across heads, so the
        # query width is the one that has to divide evenly.
        if spatial_dim % num_heads != 0:
            raise ValueError(f"spatial_dim ({spatial_dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=spatial_dim,
            kdim=semantic_dim,
            vdim=semantic_dim,
            num_heads=num_heads,
            batch_first=True
        )
        self.layer_norm1 = nn.LayerNorm(spatial_dim)
        # Not used in forward(); kept so its parameters stay in the
        # state_dict and existing checkpoints still match this module's keys.
        self.layer_norm2 = nn.LayerNorm(spatial_dim)
        self.mlp = nn.Sequential(
            nn.Linear(spatial_dim, spatial_dim * 4),
            nn.GELU(),
            nn.Linear(spatial_dim * 4, output_dim),
        )

    def forward(self, spatial_tokens, semantic_tokens):
        """Return the fused tokens.

        Args:
            spatial_tokens: Query tokens with feature width ``spatial_dim``
                (batch-first).
            semantic_tokens: Key/value tokens with feature width
                ``semantic_dim`` (batch-first).

        Returns:
            One fused token per spatial token, with feature width
            ``output_dim``.
        """
        attn_output, _ = self.cross_attention(
            query=spatial_tokens,
            key=semantic_tokens,
            value=semantic_tokens
        )
        hidden_states = self.layer_norm1(spatial_tokens + attn_output)
        fused_tokens = self.mlp(hidden_states)
        return fused_tokens


class LlavaConfig(LlamaConfig):
    """LLaMA configuration extended with dual-stream image-encoding options.

    The standard LLaMA arguments are forwarded unchanged to ``LlamaConfig``.
    The extra fields are read elsewhere in this module:

    * ``spatial_stream_depth`` - number of ``SpatialStreamEncoderLayer``
      instances built for the spatial stream.
    * ``num_relation_bins_per_axis`` - forwarded to
      ``RelationalPositionalEncoding`` as ``num_bins_per_axis``.
    * ``use_dual_stream_encoding`` - enables the dual-stream branch of
      ``encode_images`` and the post-load module initialization.
    * ``mm_semantic_select_layer`` - index of the vision-tower hidden state
      used as the semantic stream.
    * ``spatial_stream_attn_type`` - ``"relational"`` or ``"vanilla"``.
    * ``fusion_type`` - ``"cross_attention"`` or ``"simple_add"``.
    """

    model_type = "llava_llama"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        spatial_stream_depth=4,
        num_relation_bins_per_axis=5,
        use_dual_stream_encoding=False,
        mm_semantic_select_layer=-1,
        spatial_stream_attn_type="relational",
        fusion_type="cross_attention",
        **kwargs,
    ):
        # Collect the plain LLaMA options and hand them to the parent config.
        llama_options = dict(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
        )
        super().__init__(**llama_options, **kwargs)

        # Dual-stream specific settings.
        self.spatial_stream_depth = spatial_stream_depth
        self.num_relation_bins_per_axis = num_relation_bins_per_axis
        self.use_dual_stream_encoding = use_dual_stream_encoding
        self.mm_semantic_select_layer = mm_semantic_select_layer
        self.spatial_stream_attn_type = spatial_stream_attn_type
        self.fusion_type = fusion_type


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """LLaMA backbone with LLaVA vision modules and an optional depth tower."""

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        super().__init__(config)

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Initialize the vision modules and, if requested, a frozen depth tower.

        After delegating to the parent implementation, ``model_args`` is
        checked for a ``depth_tower`` path.  When present, a
        ``DPTForDepthEstimation`` model and its ``DPTImageProcessor`` are
        loaded from that path and stored as ``depth_tower`` and
        ``depth_processor``.

        Args:
            model_args: Object carrying the vision settings; ``depth_tower``
                is optional and read with ``getattr``.
            fsdp: Passed through to the parent implementation.
        """
        super().initialize_vision_modules(model_args, fsdp)

        depth_tower_path = getattr(model_args, 'depth_tower', None)
        if depth_tower_path is None:
            return

        # NOTE(review): the path is not written back to ``self.config`` here;
        # confirm it is persisted elsewhere if reloading relies on
        # ``config.depth_tower``.
        self.depth_tower = DPTForDepthEstimation.from_pretrained(depth_tower_path)
        tower = self.depth_tower
        tower.to(device=self.device, dtype=self.dtype)
        self.depth_processor = DPTImageProcessor.from_pretrained(depth_tower_path)
        # The depth tower is used as a fixed feature extractor.
        tower.requires_grad_(False)
            
        

class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaConfig

    def __init__(self, config):
        """Build the LLaVA backbone and the language-model head from ``config``."""
        # The MRO lookup starts after LlamaForCausalLM, so that class's own
        # __init__ is bypassed; the attributes it would set are created below.
        super(LlamaForCausalLM, self).__init__(config)

        hidden_size, vocab_size = config.hidden_size, config.vocab_size

        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = vocab_size
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.post_init()
    
    def get_model(self):
        """Return the wrapped ``LlavaLlamaModel`` backbone stored in ``self.model``."""
        return self.model

    def get_depth_tower(self):
        """Return the backbone's ``depth_tower``, or ``None`` if it has none."""
        backbone = self.get_model()
        return getattr(backbone, 'depth_tower', None)

    def get_depth_processor(self):
        """Return the backbone's ``depth_processor``, or ``None`` if it has none."""
        backbone = self.get_model()
        return getattr(backbone, 'depth_processor', None)

    def _init_dual_stream_modules(self, config):
        """Create the dual-stream sub-modules and the patch-coordinate buffer.

        All new modules are moved to the device and dtype of the token
        embedding weight.  The following attributes are created:

        * ``depth_feature_projector`` - lifts a scalar depth value to the
          vision feature width.
        * ``final_fusion_merger`` - maps ``2 * hidden_size`` features back to
          ``hidden_size``.
        * ``depth_visual_fusion_projector`` - maps concatenated visual and
          depth features (``2 * mm_hidden_size``) back to ``mm_hidden_size``.
        * ``spatial_norm`` and ``spatial_stream_processor`` - the spatial
          stream itself.
        * ``relational_encoder`` - only for ``'relational'`` attention.
        * ``fusion_projector`` or ``spatial_stream_projector`` - selected by
          ``fusion_type``.
        * ``patch_coordinates`` - buffer holding a (y, x) grid in ``[-1, 1]``.

        Args:
            config: Model configuration providing ``hidden_size``,
                ``spatial_stream_depth``, ``num_relation_bins_per_axis`` and
                the optional attention and fusion type switches.

        Raises:
            ValueError: If ``fusion_type`` is not ``'cross_attention'`` or
                ``'simple_add'``.
        """
        vision_dim = self.get_model().config.mm_hidden_size
        llm_dim = config.hidden_size
        head_count = getattr(config, 'num_attention_heads', 32)

        attn_type = getattr(config, 'spatial_stream_attn_type', 'relational')
        fusion_type = getattr(config, 'fusion_type', 'cross_attention')

        embed_weight = self.model.embed_tokens.weight
        placement = dict(device=embed_weight.device, dtype=embed_weight.dtype)

        # NOTE: the construction order below is kept fixed because it decides
        # both the random-initialization sequence and the order in which the
        # parameters are registered on the model.
        self.depth_feature_projector = nn.Sequential(
            nn.Linear(1, vision_dim),
            nn.GELU(),
            StableLayerNorm(vision_dim),
            nn.Linear(vision_dim, vision_dim),
        ).to(**placement)
        self.depth_feature_projector.requires_grad_(True)

        self.final_fusion_merger = nn.Sequential(
            StableLayerNorm(llm_dim * 2),
            nn.Linear(llm_dim * 2, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        ).to(**placement)
        self.final_fusion_merger.requires_grad_(True)

        # Input is the concatenation of visual and projected depth features.
        concat_dim = vision_dim * 2
        self.depth_visual_fusion_projector = nn.Sequential(
            StableLayerNorm(concat_dim),
            nn.Linear(concat_dim, vision_dim),
            nn.GELU(),
            StableLayerNorm(vision_dim),
            nn.Linear(vision_dim, vision_dim),
        ).to(**placement)
        self.depth_visual_fusion_projector.requires_grad_(True)

        self.spatial_norm = StableLayerNorm(vision_dim).to(**placement)

        if attn_type == 'relational':
            self.relational_encoder = RelationalPositionalEncoding(
                embedding_dim=head_count,
                num_bins_per_axis=config.num_relation_bins_per_axis,
                num_dims=3,
            ).to(**placement)

        encoder_layers = [
            SpatialStreamEncoderLayer(
                hidden_size=vision_dim,
                num_heads=head_count,
                mlp_dim=vision_dim * 4,
                attn_type=attn_type,
            )
            for _ in range(config.spatial_stream_depth)
        ]
        self.spatial_stream_processor = nn.ModuleList(encoder_layers).to(**placement)

        if fusion_type == 'cross_attention':
            self.fusion_projector = CrossAttentionFusion(
                spatial_dim=vision_dim,
                semantic_dim=llm_dim,
                output_dim=llm_dim,
                num_heads=head_count,
            ).to(**placement)
        elif fusion_type == 'simple_add':
            self.spatial_stream_projector = nn.Linear(vision_dim, llm_dim).to(**placement)
        else:
            raise ValueError(f"Unknown fusion_type: {fusion_type}")

        # Regular patch grid: one (y, x) pair per patch, row-major, in [-1, 1].
        side = self.get_model().get_vision_tower().num_patches_per_side
        axis = torch.linspace(-1.0, 1.0, side, device=self.device, dtype=self.dtype)
        grid_y, grid_x = torch.meshgrid(axis, axis, indexing='ij')
        self.register_buffer('patch_coordinates', torch.stack((grid_y, grid_x), dim=-1).view(-1, 2))

    def get_patch_coordinates(self) -> torch.Tensor:
        """Return the ``patch_coordinates`` buffer.

        The buffer is registered by ``_init_dual_stream_modules``; accessing
        it before that method has run raises ``AttributeError``.
        """
        return self.patch_coordinates

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """Encode images into LLM-space tokens, optionally via the dual-stream path.

        When ``config.use_dual_stream_encoding`` is false the parent
        implementation is used unchanged.  Otherwise:

        1. Two vision-tower hidden states are selected: the spatial stream
           from ``mm_vision_select_layer`` and the semantic stream from
           ``mm_semantic_select_layer``.
        2. If a depth tower is attached, a per-patch depth value is
           predicted, min-max normalized per image, projected, and fused
           with the spatial features.  With relational attention, (y, x,
           depth) coordinates are turned into an attention bias.
        3. The spatial tokens run through ``spatial_stream_processor``.
        4. The semantic tokens pass through ``mm_projector``.  Both streams
           are fused, concatenated, and merged by ``final_fusion_merger``.

        Without a depth tower no relational bias is produced, so the layers
        receive ``relational_bias=None``; ``SpatialStreamEncoderLayer``
        asserts on that when it uses relational attention.

        Args:
            images: Image batch fed to the vision tower.  Presumably already
                normalized with the vision tower's ``image_mean`` and
                ``image_std``, since those statistics are used to undo the
                normalization before depth estimation.

        Returns:
            The fused image tokens produced by ``final_fusion_merger``.

        Raises:
            ValueError: If ``config.fusion_type`` is neither
                ``'cross_attention'`` nor ``'simple_add'``.
        """
        if not getattr(self.config, 'use_dual_stream_encoding', False):
            return super(LlavaLlamaForCausalLM, self).encode_images(images)

        vision_tower = self.get_model().get_vision_tower()
        all_hidden_states = vision_tower(images, output_all_hidden_states=True)
        mid_layer_features = all_hidden_states[self.model.config.mm_vision_select_layer]
        semantic_raw_features = all_hidden_states[getattr(self.model.config, 'mm_semantic_select_layer', -1)]
        mid_layer_features = mid_layer_features.to(self.dtype)

        depth_tower = self.get_depth_tower()
        relational_bias = None

        if depth_tower is not None:
            predicted_depth = self._predict_depth_map(images, vision_tower, depth_tower)

            # Resample the depth map to one value per vision patch.
            num_patches_per_side = vision_tower.num_patches_per_side
            resampled_depth = F.interpolate(
                predicted_depth.unsqueeze(1),
                size=(num_patches_per_side, num_patches_per_side),
                mode='bicubic',
                align_corners=False
            ).squeeze(1).flatten(start_dim=1)
            resampled_depth = resampled_depth.to(device=self.device, dtype=self.dtype)

            # Per-image min-max normalization; a flat depth map maps to zeros.
            v_min = resampled_depth.min(dim=1, keepdim=True)[0]
            v_max = resampled_depth.max(dim=1, keepdim=True)[0]
            denominator = v_max - v_min
            resampled_depth_normalized = torch.where(
                denominator > 1e-6,
                (resampled_depth - v_min) / denominator,
                torch.zeros_like(resampled_depth)
            )

            resampled_depth_expanded = resampled_depth_normalized.unsqueeze(-1).to(device=self.device, dtype=self.dtype)
            projected_depth_features = self.depth_feature_projector(resampled_depth_expanded)

            concatenated_features = torch.cat([mid_layer_features, projected_depth_features], dim=-1)
            fused_mid_layer_features = self.depth_visual_fusion_projector(concatenated_features)
            spatial_input = self.spatial_norm(fused_mid_layer_features)

            attn_type = getattr(self.config, 'spatial_stream_attn_type', 'relational')
            if attn_type == 'relational':
                # (y, x) comes from the fixed patch grid; depth is rescaled
                # from [0, 1] to [-1, 1] to match the grid's range.
                xy_coords = self.get_patch_coordinates().unsqueeze(0).expand(spatial_input.shape[0], -1, -1)
                depth_coords_neg1_to_1 = (resampled_depth_normalized.unsqueeze(-1) * 2.0) - 1.0
                xyz_coords = torch.cat([xy_coords, depth_coords_neg1_to_1], dim=-1)
                relational_bias = self.relational_encoder(xyz_coords).to(spatial_input.device)
        else:
            spatial_input = self.spatial_norm(mid_layer_features)

        spatial_tokens = spatial_input
        for layer in self.spatial_stream_processor:
            spatial_tokens = layer(spatial_tokens, relational_bias=relational_bias)

        semantic_raw_features = semantic_raw_features.to(self.dtype)
        semantic_tokens = self.get_model().mm_projector(semantic_raw_features)

        fusion_type = getattr(self.config, 'fusion_type', 'cross_attention')
        if fusion_type == 'cross_attention':
            context_aware_spatial_tokens = self.fusion_projector(
                spatial_tokens=spatial_tokens,
                semantic_tokens=semantic_tokens
            )
        elif fusion_type == 'simple_add':
            context_aware_spatial_tokens = self.spatial_stream_projector(spatial_tokens)
        else:
            raise ValueError(f"Unknown fusion_type: {fusion_type}")

        final_concatenated_tokens = torch.cat([context_aware_spatial_tokens, semantic_tokens], dim=-1)
        fused_tokens = self.final_fusion_merger(final_concatenated_tokens)

        return fused_tokens

    def _predict_depth_map(self, images, vision_tower, depth_tower):
        """Run the depth tower in float32 and return its depth map in model dtype.

        The images are de-normalized with the vision tower's image statistics,
        clamped to ``[0, 1]``, re-processed by the depth processor, and fed to
        the depth tower without gradient tracking.  The tower is temporarily
        cast to float32 when it is stored in another dtype; the original
        dtype is restored even if the forward pass raises.
        """
        with torch.no_grad():
            depth_processor = self.get_depth_processor()
            image_processor = vision_tower.image_processor
            mean = torch.tensor(image_processor.image_mean, device=images.device, dtype=images.dtype).view(1, -1, 1, 1)
            std = torch.tensor(image_processor.image_std, device=images.device, dtype=images.dtype).view(1, -1, 1, 1)
            images_denormalized = torch.clamp(images * std + mean, 0, 1)

            # Values are already in [0, 1], hence do_rescale=False.
            depth_inputs = depth_processor(images=images_denormalized.to(torch.float32), return_tensors="pt", do_rescale=False)
            depth_inputs = {k: v.to(device=self.device, dtype=torch.float32) for k, v in depth_inputs.items()}

            original_dtype = next(depth_tower.parameters()).dtype
            needs_cast = original_dtype != torch.float32
            if needs_cast:
                depth_tower.to(device=self.device, dtype=torch.float32)
            try:
                predicted_depth = depth_tower(**depth_inputs).predicted_depth
            finally:
                # Never leave the shared tower in float32 after a failure.
                if needs_cast:
                    depth_tower.to(device=self.device, dtype=original_dtype)

            return predicted_depth.to(device=self.device, dtype=self.dtype)
     
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model, then initialize its multimodal components.

        After the parent ``from_pretrained`` returns, and only when the
        loaded config has an ``mm_vision_tower`` attribute, this:

        1. calls ``initialize_vision_modules`` with arguments rebuilt from
           the config,
        2. moves the vision and depth towers to the model's device and dtype,
        3. creates the dual-stream modules when
           ``config.use_dual_stream_encoding`` is set and marks their
           parameters trainable.

        NOTE(review): the dual-stream modules are created after the parent
        loader has returned, so they are not present on the model while it
        loads weights.  Confirm that their trained weights are loaded by a
        separate step.

        Args:
            pretrained_model_name_or_path: Forwarded to the parent loader.
            *model_args: Forwarded to the parent loader.
            **kwargs: Forwarded to the parent loader.

        Returns:
            The loaded model instance.
        """
        model = super(LlavaLlamaForCausalLM, cls).from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs
        )
        if not hasattr(model.config, 'mm_vision_tower'):
            return model

        print("Post-loading initialization of multimodal components...")

        class MockModelArgs:
            """Attribute bag mirroring the arguments ``initialize_vision_modules`` reads."""

            def __init__(self, config):
                self.vision_tower = config.mm_vision_tower
                self.depth_tower = getattr(config, 'depth_tower', None)
                self.mm_vision_select_layer = getattr(config, 'mm_vision_select_layer', -2)
                self.mm_vision_select_feature = getattr(config, 'mm_vision_select_feature', 'patch')
                self.mm_projector_type = getattr(config, 'mm_projector_type', 'mlp2x_gelu')
                self.mm_patch_merge_type = getattr(config, 'mm_patch_merge_type', 'flat')
                self.pretrain_mm_mlp_adapter = None

        model_args_mock = MockModelArgs(model.config)
        model.get_model().initialize_vision_modules(model_args=model_args_mock)

        # Compare against None explicitly: truthiness of a module object
        # depends on whether it defines __len__/__bool__.
        vision_tower = model.get_vision_tower()
        if vision_tower is not None:
            vision_tower.to(device=model.device, dtype=model.dtype)

        depth_tower = model.get_depth_tower()
        if depth_tower is not None:
            depth_tower.to(device=model.device, dtype=model.dtype)

        if getattr(model.config, 'use_dual_stream_encoding', False):
            model._init_dual_stream_modules(model.config)

            # Only a subset of these exists, depending on fusion_type.
            dual_stream_modules = (
                'depth_feature_projector',
                'depth_visual_fusion_projector',
                'spatial_norm',
                'spatial_stream_processor',
                'final_fusion_merger',
                'fusion_projector',
                'spatial_stream_projector',
            )
            for module_name in dual_stream_modules:
                module = getattr(model, module_name, None)
                if module is not None:
                    module.requires_grad_(True)

        return model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the causal LM, merging image inputs into the embeddings first.

        When ``inputs_embeds`` is not supplied, the text and image inputs go
        through ``prepare_inputs_labels_for_multimodal``.  Its six return
        values replace the corresponding arguments before the parent
        ``forward`` is called.  ``images`` and ``image_sizes`` are consumed
        here and are not passed to the parent.
        """
        if inputs_embeds is None:
            prepared = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )
            input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels = prepared

        llama_inputs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return super().forward(**llama_inputs)

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate text, optionally conditioned on images.

        The prompt is always handed to the parent ``generate`` as
        ``inputs_embeds``.  With images these come from
        ``prepare_inputs_labels_for_multimodal``; without images they are
        the plain token embeddings of ``inputs``.

        Raises:
            NotImplementedError: If the caller passes ``inputs_embeds``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is None:
            inputs_embeds = self.get_model().embed_tokens(inputs)
        else:
            prepared = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
            # Only positions, mask and embeddings are needed from here on.
            _, position_ids, attention_mask, _, inputs_embeds, _ = prepared

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Build per-step generation inputs, carrying the image arguments along.

        ``images`` and ``image_sizes`` are removed from ``kwargs`` before the
        parent implementation runs.  Each one that is not ``None`` is then
        added to the returned dict.
        """
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)

        model_inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )

        for key, value in (("images", images), ("image_sizes", image_sizes)):
            if value is not None:
                model_inputs[key] = value
        return model_inputs

# Import-time registration: map the "llava_llama" model type to LlavaConfig,
# and LlavaConfig to LlavaLlamaForCausalLM, in the transformers Auto* classes.
AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
