from typing import List, Optional, Tuple

import torch
from transformers import GenerationConfig

from ..actions import (
    Qwen2ForActionOFT,
    Qwen2ForV3RActionOFT,
    Qwen3ForActionOFT,
    Qwen3ForV3RActionOFT,
)
from .configuration_internvl_chat import (
    InternVL3RActionOFTConfig,
    InternVLActionOFTConfig,
)
from .modeling_internvl3r_chat import InternVL3RChatModel
from .modeling_internvl_chat import InternVLChatModel


class InternVLActionOFT(InternVLChatModel):
    """InternVL chat model whose language model is a Qwen2/Qwen3 action-OFT variant.

    The language-model class is selected from ``config.llm_config.architectures[0]``.
    Plain ``*ForCausalLM`` architectures are also accepted and are instantiated
    as the matching ``*ForActionOFT`` class.
    """

    config_class = InternVLActionOFTConfig

    def __init__(
        self,
        config: InternVLActionOFTConfig,
        vision_model=None,
        use_flash_attn: bool = True,
    ):
        """Build the action-OFT language model and initialise the chat model.

        Args:
            config: Model config; ``config.llm_config.architectures[0]`` selects
                the language-model class.
            vision_model: Optional pre-built vision model, forwarded to
                ``InternVLChatModel.__init__``.
            use_flash_attn: Forwarded to ``InternVLChatModel.__init__``.

        Raises:
            ValueError: If the LLM architecture is not a supported Qwen2/Qwen3 one.
        """
        architecture: str = config.llm_config.architectures[0]
        if architecture in ("Qwen2ForActionOFT", "Qwen2ForCausalLM"):
            language_model = Qwen2ForActionOFT(config.llm_config)
        elif architecture in ("Qwen3ForActionOFT", "Qwen3ForCausalLM"):
            language_model = Qwen3ForActionOFT(config.llm_config)
        else:
            raise ValueError(f"Unsupported architecture: {architecture}")

        super().__init__(config, vision_model, language_model, use_flash_attn)

        # generate() compares input ids against this id to build action_mask.
        self.action_token_id = self.language_model.action_token_id

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate from a (possibly multimodal) prompt with the action-OFT LLM.

        Prompt ids are embedded with the language model's input embeddings.
        When ``pixel_values`` is given:

        - Every position whose id equals ``self.img_context_token_id`` is
          overwritten, in order, with the vision embeddings.
        - A boolean ``action_mask`` marking positions equal to
          ``self.action_token_id`` is passed to ``language_model.generate``.

        Args:
            pixel_values: Images encoded via ``self.extract_feature``. When
                ``None`` the prompt is treated as text-only: ``visual_features``
                is ignored and ``action_mask`` is ``None``.
            input_ids: Prompt token ids. The embedding output is unpacked as
                ``(B, N, C)``, so this is expected to be ``(B, N)``.
            attention_mask: Forwarded to ``language_model.generate``.
            visual_features: Precomputed vision embeddings used instead of
                calling ``extract_feature`` (only when ``pixel_values`` is set).
            generation_config: Forwarded to ``language_model.generate``.
            output_hidden_states: Forwarded to ``language_model.generate``.
            **generate_kwargs: Forwarded to ``language_model.generate``.

        Returns:
            Whatever ``self.language_model.generate`` returns.

        Raises:
            ValueError: If ``img_context_token_id`` is unset; if ``pixel_values``
                is given but the prompt has no image-context tokens; or if the
                number of image-context tokens differs from the number of
                vision embeddings.
        """
        # Explicit raise instead of assert so the check survives `python -O`.
        if self.img_context_token_id is None:
            raise ValueError(
                "img_context_token_id must be set before calling generate()."
            )

        input_embeds = self.language_model.get_input_embeddings()(input_ids)
        action_mask = None

        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values)

            B, N, C = input_embeds.shape
            flat_ids = input_ids.reshape(B * N)
            selected = flat_ids == self.img_context_token_id

            n_slots = int(selected.sum())
            if n_slots == 0:
                raise ValueError(
                    "pixel_values were given but input_ids contain no "
                    "image-context tokens."
                )
            # Match device AND dtype: boolean-mask assignment requires the
            # value dtype to equal the destination dtype.
            flat_vit = vit_embeds.reshape(-1, C).to(
                device=input_embeds.device, dtype=input_embeds.dtype
            )
            # Without this check a single vision row would silently broadcast
            # into every slot, and other mismatches fail with an opaque error.
            if flat_vit.shape[0] != n_slots:
                raise ValueError(
                    f"Number of image-context tokens ({n_slots}) does not match "
                    f"number of vision embeddings ({flat_vit.shape[0]})."
                )

            input_embeds = input_embeds.reshape(B * N, C)
            input_embeds[selected] = flat_vit
            input_embeds = input_embeds.reshape(B, N, C)

            action_mask = (flat_ids == self.action_token_id).view(B, N)

        return self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            use_cache=True,
            action_mask=action_mask,
            **generate_kwargs,
        )


class InternVL3RActionOFT(InternVL3RChatModel):
    """InternVL-3R chat model with a Qwen2/Qwen3 V3R action-OFT language model.

    Besides the language model, it builds a condition model and, when
    ``config.vision_dpt_config`` is set, a depth (DPT) head that is attached
    to the language model as ``dpt_head``.
    """

    config_class = InternVL3RActionOFTConfig

    def __init__(
        self,
        config: InternVL3RActionOFTConfig,
        vision_model=None,
        use_flash_attn: bool = True,
    ):
        """Build the V3R action-OFT language model, condition model and DPT head.

        Args:
            config: Model config.
                - ``config.llm_config.architectures[0]`` selects the
                  language-model class.
                - ``config.vision_dpt_config`` (may be ``None``) configures an
                  optional depth head.
            vision_model: Optional pre-built vision model, forwarded to the
                base chat-model initialiser.
            use_flash_attn: Forwarded to the base chat-model initialiser.

        Raises:
            ValueError: If the LLM or DPT architecture is unsupported.
        """
        architecture: str = config.llm_config.architectures[0]
        if architecture in ("Qwen2ForV3RActionOFT", "Qwen2ForV3RCausalLM"):
            language_model = Qwen2ForV3RActionOFT(config.llm_config)
        elif architecture in ("Qwen3ForV3RActionOFT", "Qwen3ForV3RCausalLM"):
            language_model = Qwen3ForV3RActionOFT(config.llm_config)
        else:
            raise ValueError(f"Unsupported architecture: {architecture}")

        # Deliberately skip InternVL3RChatModel.__init__ and run the next
        # __init__ in the MRO with the language model built above.
        # NOTE(review): presumably this avoids InternVL3RChatModel building its
        # own language model; the condition model is built explicitly below.
        super(InternVL3RChatModel, self).__init__(
            config, vision_model, language_model, use_flash_attn
        )

        self._build_condition_model(config)

        # generate() reads self.action_token_id, and InternVL3RChatModel.__init__
        # is skipped above. Mirror InternVLActionOFT and take the id from the
        # language model when it exposes one.
        if hasattr(self.language_model, "action_token_id"):
            self.action_token_id = self.language_model.action_token_id

        vision_dpt_config = config.vision_dpt_config
        if vision_dpt_config is not None:
            dpt_architecture = vision_dpt_config.dpt_config.architectures[0]
            if dpt_architecture != "DA3Model":
                raise ValueError(f"Unsupported dpt architecture: {dpt_architecture}")

            from ..da3 import DA3DepthHead

            dpt_head = DA3DepthHead(vision_dpt_config.dpt_config)

            def _select_dpt_inputs(_module, inputs):
                """Forward pre-hook rewriting the head's positional inputs.

                - Keeps only the feature maps listed in ``layer_ids``, each
                  wrapped in a single-element list.
                - Halves the two size arguments.
                - Passes any remaining arguments through unchanged.

                NOTE(review): the reason for the halving is not visible here;
                confirm against DA3DepthHead.
                """
                # Select the feature maps to be used
                return (
                    [[inputs[0][i]] for i in vision_dpt_config.layer_ids],
                    int(0.5 * inputs[1]),  # Image width
                    int(0.5 * inputs[2]),  # Image height
                    *inputs[3:],
                )

            dpt_head.register_forward_pre_hook(_select_dpt_inputs)
            self.language_model.dpt_head = dpt_head

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate from a (possibly multimodal) prompt with the V3R action-OFT LLM.

        Prompt ids are embedded with the language model's input embeddings.
        When ``pixel_values`` is given:

        - Image-context positions are overwritten, in order, with the vision
          embeddings.
        - The following extra inputs are passed to ``language_model.generate``:
          ``frame_num``, ``vision_embeds``, ``visual_pos_masks`` (image-context
          positions), ``vision_height`` / ``vision_width`` (last two dims of
          ``pixel_values``), and ``action_mask`` (positions equal to
          ``self.action_token_id``).

        Args:
            pixel_values: Images encoded via ``self.extract_feature``. When
                ``None`` all extra vision inputs and ``action_mask`` are
                ``None``, and ``visual_features`` is ignored.
            input_ids: Prompt token ids. The embedding output is unpacked as
                ``(B, N, C)``, so this is expected to be ``(B, N)``.
            attention_mask: Forwarded to ``language_model.generate``.
            visual_features: Precomputed features used as both
                ``vision_embeds`` and the injected embeddings, instead of
                calling ``extract_feature`` (only when ``pixel_values`` is set).
            generation_config: Forwarded to ``language_model.generate``.
            output_hidden_states: Forwarded to ``language_model.generate``.
            **generate_kwargs: Forwarded to ``language_model.generate``.

        Returns:
            Whatever ``self.language_model.generate`` returns.

        Raises:
            ValueError: If ``img_context_token_id`` is unset; if ``pixel_values``
                is given but the prompt has no image-context tokens; or if the
                number of image-context tokens differs from the number of
                vision embeddings.
        """
        # Explicit raise instead of assert so the check survives `python -O`.
        if self.img_context_token_id is None:
            raise ValueError(
                "img_context_token_id must be set before calling generate()."
            )

        # Extra inputs for the vision-3R branch; stay None for text-only prompts.
        frame_num = None
        vision_embeds = None
        vision_pos_masks = None
        vision_height = None
        vision_width = None
        action_mask = None

        input_embeds = self.language_model.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            B, N, C = input_embeds.shape
            vision_height, vision_width = pixel_values.shape[-2:]

            if visual_features is not None:
                vision_embeds = vit_embeds = visual_features
            else:
                vision_embeds, vit_embeds = self.extract_feature(pixel_values, B)

            flat_ids = input_ids.reshape(B * N)
            selected = flat_ids == self.img_context_token_id

            n_slots = int(selected.sum())
            if n_slots == 0:
                raise ValueError(
                    "pixel_values were given but input_ids contain no "
                    "image-context tokens."
                )
            # Match device AND dtype: boolean-mask assignment requires the
            # value dtype to equal the destination dtype.
            flat_vit = vit_embeds.reshape(-1, C).to(
                device=input_embeds.device, dtype=input_embeds.dtype
            )
            # Without this check a single vision row would silently broadcast
            # into every slot, and other mismatches fail with an opaque error.
            if flat_vit.shape[0] != n_slots:
                raise ValueError(
                    f"Number of image-context tokens ({n_slots}) does not match "
                    f"number of vision embeddings ({flat_vit.shape[0]})."
                )

            input_embeds = input_embeds.reshape(B * N, C)
            input_embeds[selected] = flat_vit
            input_embeds = input_embeds.reshape(B, N, C)

            # NOTE(review): assumes vit_embeds' first dim is B * frames; confirm.
            frame_num = vit_embeds.shape[0] // B
            vision_pos_masks = selected.view(B, N)
            action_mask = (flat_ids == self.action_token_id).view(B, N)

        return self.language_model.generate(
            inputs_embeds=input_embeds,
            frame_num=frame_num,
            vision_embeds=vision_embeds,
            visual_pos_masks=vision_pos_masks,
            vision_height=vision_height,
            vision_width=vision_width,
            action_mask=action_mask,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            use_cache=True,
            **generate_kwargs,
        )
