from __future__ import annotations

import os
import sys
import traceback
from typing import Dict, List, Tuple, Any

import torch
from PIL import Image
from rich.console import Console
from transformers import GenerationConfig

from .base import BaseVLM

# Import shared helpers from the top-level `utils` module.
# The parent of this file's directory (two dirname() calls above this file) is
# appended to sys.path so that `utils` can be resolved from here.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

# Module-level rich console used for diagnostic / error output in this module.
console = Console()


class MolmoVLM(BaseVLM):
    """
    Adapter for the Molmo family (allenai/Molmo-7B-*-0924).

    - Uses the officially recommended path: ``processor.process(...)`` followed by
      ``model.generate_from_batch(...)``.
    - ``generate_batch`` does true batched inference: every sample is processed on its
      own, then manually left-padded and stacked, and a single generate call covers
      the whole batch.
    - Only ``processor.tokenizer`` is used, so tokenizers are never mixed.
    - Right-padding length is computed by ``calculate_right_padding_length``.
    """

    def __init__(self, model, processor, device: str = "cuda"):
        """
        Args:
            model: Molmo model; must expose ``generate_from_batch`` and ``generation_config``.
            processor: Molmo processor; must expose ``process`` and normally ``tokenizer``.
            device: fallback device for inputs when the model's device cannot be read.
        """
        tokenizer = getattr(processor, "tokenizer", None)
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        # Single tokenizer source: whenever the processor carries a tokenizer, use it.
        if tokenizer is not None:
            self.tokenizer = tokenizer

        # Force left padding (avoids HF warnings and matches the manual left-pad done in
        # _stack_and_pad). The processor's tokenizer is normally the very same object as
        # self.tokenizer; it is only configured separately when it differs.
        self._configure_tokenizer(self.tokenizer)
        proc_tokenizer = getattr(self.processor, "tokenizer", None)
        if proc_tokenizer is not self.tokenizer:
            self._configure_tokenizer(proc_tokenizer)

    @staticmethod
    def _configure_tokenizer(tok) -> None:
        """Set left padding and use EOS as pad token when none is defined (no-op for None)."""
        if tok is None:
            return
        tok.padding_side = "left"
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

    @staticmethod
    def _pad_left_1d(seqs: List[torch.Tensor], pad_token_id: int) -> torch.Tensor:
        """
        Left-pad 1-D tensors to a common length ([L_i] -> [B, L_max]).

        The output uses the device and dtype of the first sequence. An empty sequence
        produces a row made entirely of ``pad_token_id``.
        """
        max_len = max(int(s.numel()) for s in seqs)
        out = torch.full(
            (len(seqs), max_len), pad_token_id, device=seqs[0].device, dtype=seqs[0].dtype
        )
        for i, s in enumerate(seqs):
            # Explicit start index: `out[i, -n:]` would select the whole row when n == 0.
            out[i, max_len - int(s.numel()):] = s
        return out

    @staticmethod
    def _infer_input_device(model, fallback: str) -> torch.device:
        """
        Pick the device inputs should live on.

        Returns ``model.device`` if present, else the device of the first parameter,
        else ``torch.device(fallback)``. With device_map='auto' the inputs usually need to
        go to ``model.device`` (typically the first GPU).
        """
        dev = getattr(model, "device", None)
        if dev is not None:
            return dev
        try:
            return next(model.parameters()).device
        except Exception:
            return torch.device(fallback)

    def _process_one(self, item: dict) -> Tuple[Dict[str, Any], int]:
        """
        Run ``processor.process`` on a single sample.

        Returns:
            feats: processor output with every value converted to a tensor and a leading
                dim of size 1 (if any) squeezed away.
            input_len: length of ``input_ids`` (no padding).

        Raises:
            ValueError: if the item has no image or more than one image, or if the
                processor output lacks ``input_ids`` / yields a non-1-D ``input_ids``.
        """
        try:
            prompt = parse_input(item)
            image_paths = _normalize_to_list(get_image_path(item))

            if not image_paths:
                raise ValueError("MolmoVLM 目前仅支持多模态输入（需要至少一张图片）")
            if len(image_paths) != 1:
                raise ValueError("MolmoVLM 当前实现仅支持单图输入（你的数据集保证单图）")

            # Close the underlying file immediately; convert() returns an independent copy.
            with Image.open(image_paths[0]) as raw_img:
                img = raw_img.convert("RGB")

            # Official Molmo usage: processor.process(images=[PIL], text=...)
            try:
                feats = self.processor.process(images=[img], text=prompt, return_tensors="pt")
            except TypeError:
                # Some processor versions do not accept return_tensors; retry without it.
                feats = self.processor.process(images=[img], text=prompt)
            except Exception as e:
                console.print(f"[red]MolmoVLM: processor.process 失败: {e}[/red]")
                console.print(f"[red]详细错误: {traceback.format_exc()}[/red]")
                raise

            if "input_ids" not in feats:
                raise ValueError(f"MolmoVLM: processor.process 输出缺少 input_ids，keys={list(feats.keys())}")

            # Normalize every field from [1, ...] to [...].
            # NOTE(review): a genuine leading dim of size 1 is squeezed as well — confirm the
            # processor never emits such a field without a batch dim.
            for k, v in list(feats.items()):
                if not isinstance(v, torch.Tensor):
                    v = torch.as_tensor(v)
                if v.ndim >= 1 and v.shape[0] == 1:
                    v = v.squeeze(0)
                feats[k] = v

            ids = feats["input_ids"]
            if ids.ndim != 1:
                raise ValueError(f"MolmoVLM: input_ids 维度异常: {tuple(ids.shape)}")
            return feats, int(ids.numel())

        except Exception as e:
            console.print(f"[red]MolmoVLM: _process_one 失败: {e}[/red]")
            console.print(f"[red]详细错误: {traceback.format_exc()}[/red]")
            raise

    @staticmethod
    def _pad_left_last_dim(v: torch.Tensor, pad_left: int, pad_value: int = 0) -> torch.Tensor:
        """Left-pad ``v`` by ``pad_left`` elements along its last dim only."""
        if pad_left <= 0:
            return v
        # torch.nn.functional.pad lists (left, right) pairs starting from the last dim.
        pad = [pad_left, 0] + [0, 0] * (v.ndim - 1)
        return torch.nn.functional.pad(v, pad=pad, value=pad_value)

    def _stack_and_pad(
        self,
        processed: List[Tuple[Dict[str, Any], int]],
        device: torch.device,
    ) -> Dict[str, torch.Tensor]:
        """
        Build batched model inputs from per-sample features (Molmo-specific rules):

        - ``input_ids`` is left-padded with -1 (Molmo's ``generate_from_batch`` derives its
          mask/positions from ``input_ids != -1``).
        - ``attention_mask`` is NOT passed (otherwise Molmo expects
          mask_len == seq_len + max_new_tokens and asserts).
        - ``image_input_idx`` (if present) is padded with -1, and every entry >= 0 is
          shifted by the sample's left-pad amount so it still points at the same token.
        - Any other field whose last dim equals the sample's original length L_i is
          left-padded with 0 along that dim.
        - A key that is None in any sample is dropped from the batch.

        Args:
            processed: ``(features, input_len)`` pairs as returned by ``_process_one``.
            device: device every output tensor is moved to.

        Returns:
            Dict of batched tensors keyed like the processor output.
        """
        try:
            feats_list = [feats for feats, _ in processed]
            max_len = max(length for _, length in processed)

            # 1) input_ids: left-pad with -1 (Molmo convention, see docstring).
            seqs = [f["input_ids"].to(device) for f in feats_list]  # each [L_i]
            batch: Dict[str, torch.Tensor] = {"input_ids": self._pad_left_1d(seqs, pad_token_id=-1)}

            # 2) All other keys (attention_mask deliberately skipped).
            keys = [k for k in feats_list[0].keys() if k not in ("input_ids", "attention_mask")]
            for k in keys:
                tensors: List[torch.Tensor] = []
                for feat, Li in processed:
                    v = feat.get(k, None)
                    if v is None:
                        # Missing in one sample -> drop the key for the whole batch.
                        tensors = []
                        break
                    if not isinstance(v, torch.Tensor):
                        v = torch.as_tensor(v)

                    pad_left = max_len - Li
                    if k == "image_input_idx":
                        # Pad with -1 and shift valid indices past the new left padding.
                        if v.ndim >= 1 and v.shape[-1] == Li and pad_left > 0:
                            v = self._pad_left_last_dim(v, pad_left, pad_value=-1)
                        if pad_left > 0:
                            v = torch.where(v >= 0, v + pad_left, v)
                    elif v.ndim >= 1 and v.shape[-1] == Li and pad_left > 0:
                        v = self._pad_left_last_dim(v, pad_left, pad_value=0)

                    tensors.append(v.to(device))

                if not tensors:
                    continue

                # Merge: prefer stack; if every tensor still has a leading dim of 1, cat.
                # NOTE(review): per-sample tensors of different shapes (presumably e.g. images
                # yielding different crop counts — confirm) make both fail; error re-raised.
                try:
                    batch[k] = torch.stack(tensors, dim=0)
                except Exception:
                    if all(t.ndim >= 1 and t.shape[0] == 1 for t in tensors):
                        batch[k] = torch.cat(tensors, dim=0)
                    else:
                        raise

            return batch

        except Exception as e:
            console.print(f"[red]MolmoVLM: _stack_and_pad 失败: {e}[/red]")
            console.print(f"[red]详细错误: {traceback.format_exc()}[/red]")
            raise

    def _eos_token_ids(self) -> set:
        """Union of EOS ids from the model's generation_config and the tokenizer (None removed)."""
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        return {e for e in (gc_eos + tok_eos) if e is not None}

    def _prepare_generation_config(self, max_new_tokens: int, gen_cfg: Dict) -> GenerationConfig:
        """
        Build the GenerationConfig used for generation.

        - Defaults: ``max_new_tokens`` from the argument, greedy decoding, and
          ``stop_strings="<|endoftext|>"`` (as in the Molmo README). Values already in
          ``gen_cfg`` take precedence.
        - Fills ``eos_token_id`` / ``pad_token_id`` so right-pad computation matches
          generation. Side effect: the resolved pad id is written to
          ``model.config`` and ``model.generation_config``.
        """
        gen_cfg = dict(gen_cfg or {})
        gen_cfg.setdefault("max_new_tokens", max_new_tokens)
        gen_cfg.setdefault("do_sample", False)
        gen_cfg.setdefault("stop_strings", "<|endoftext|>")

        eos_ids = list(self._eos_token_ids())
        if eos_ids:
            gen_cfg.setdefault("eos_token_id", eos_ids[0] if len(eos_ids) == 1 else eos_ids)

        # pad_token_id: generation_config -> tokenizer pad -> tokenizer eos.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        if pad_id is not None:
            gen_cfg.setdefault("pad_token_id", pad_id)
            self.model.config.pad_token_id = pad_id
            self.model.generation_config.pad_token_id = pad_id

        return GenerationConfig(**gen_cfg)

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Generate for a batch of items in a single ``generate_from_batch`` call.

        Args:
            items: dataset items (each must reference exactly one image).
            max_new_tokens: default generation budget (``gen_cfg`` may override it).
            gen_cfg: extra GenerationConfig fields.
            oom_estimate: if True, extend the inputs via ``utils.apply_prefill_extra_tokens``
                (driven by the private ``_prefill_*`` keys of ``bs_estimate_gen_cfg``).
            bs_estimate_gen_cfg: settings used only when ``oom_estimate`` is True.

        Returns:
            outputs: decoded, stripped generations (prompt excluded).
            right_pad_lens: trailing tokens dropped per sample (pad + surplus EOS).
            hit_limits: True when a sample used the whole budget without ending on EOS.
        """
        try:
            assert len(items) > 0, "generate_batch: items 不能为空"
            batch_size = len(items)

            # 1) Per-sample processing.
            processed: List[Tuple[Dict[str, Any], int]] = [self._process_one(item) for item in items]

            # 2) Generation config; a pad token id must be resolvable (validated only).
            generation_config = self._prepare_generation_config(max_new_tokens=max_new_tokens, gen_cfg=gen_cfg)
            pad_token_id = getattr(generation_config, "pad_token_id", None)
            if pad_token_id is None:
                pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
            if pad_token_id is None:
                raise ValueError("MolmoVLM: 无法确定 pad_token_id")

            # 3) Stack + left-pad (image_input_idx is shifted accordingly).
            device = self._infer_input_device(self.model, self.device)
            inputs = self._stack_and_pad(processed, device=device)

            # Floating-point inputs (e.g. images) must match the vision patch-embedding dtype.
            try:
                target_dtype = self.model.model.vision_backbone.image_vit.patch_embedding.weight.dtype
            except Exception:
                # Fallback: dtype of the first floating-point parameter.
                target_dtype = next(
                    (p.dtype for p in self.model.parameters() if p.is_floating_point()),
                    None
                )
            if target_dtype is not None:
                for k, v in list(inputs.items()):
                    if isinstance(v, torch.Tensor) and v.is_floating_point() and v.dtype != target_dtype:
                        inputs[k] = v.to(dtype=target_dtype)

            # ============================ memory estimation ============================
            if oom_estimate:
                from utils import apply_prefill_extra_tokens
                bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
                prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
                prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
                inputs = apply_prefill_extra_tokens(
                    batch_size=len(items),
                    inputs=inputs,
                    prefill_extra_tokens=prefill_extra_tokens,
                    tokenizer=self.tokenizer,
                    prefill_token_id=prefill_token_id,
                )
            # ===========================================================================

            # 4) generate_from_batch (official Molmo API).
            with torch.no_grad():
                out_ids = self.model.generate_from_batch(
                    inputs,
                    generation_config,
                    tokenizer=self.tokenizer,
                )

            # 5) Drop the (padded) prompt prefix.
            prefix_len_padded = inputs["input_ids"].shape[1]
            new_tokens_all = out_ids[:, prefix_len_padded:]  # [B, T_out_max]

            # 6) Per sample: right pad, hit_limit, token ids to decode.
            # The effective budget is the config's value: gen_cfg may override the argument.
            token_limit = getattr(generation_config, "max_new_tokens", None)
            if token_limit is None:
                token_limit = max_new_tokens
            eos_ids = self._eos_token_ids()

            right_pad_lens: List[int] = []
            hit_limits: List[bool] = []
            decoded_token_ids: List[List[int]] = []
            for i in range(batch_size):
                seq = new_tokens_all[i]

                cut = self.calculate_right_padding_length(seq)
                right_pad_lens.append(cut)
                if cut > 0:
                    seq = seq[:-cut]

                out_len = int(seq.shape[0])
                ended_with_eos = out_len > 0 and int(seq[-1].item()) in eos_ids
                hit_limits.append(out_len >= token_limit and not ended_with_eos)

                decoded_token_ids.append(seq.detach().cpu().tolist())

            outputs = [
                text.strip()
                for text in self.tokenizer.batch_decode(
                    decoded_token_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False,
                )
            ]

            for out in outputs:
                console.print("\n[yellow]output without prompt: ", out)

            return outputs, right_pad_lens, hit_limits

        except Exception as e:
            console.print(f"[red]MolmoVLM: generate_batch 失败: {e}[/red]")
            console.print(f"[red]详细错误: {traceback.format_exc()}[/red]")
            raise

    def generate_one(
        self,
        item: dict,
        max_new_tokens: int,
        gen_cfg: Dict,
    ) -> Tuple[str, int, bool]:
        """Single-item convenience wrapper around ``generate_batch`` (no OOM estimation)."""
        outputs, right_pad_lens, hit_limits = self.generate_batch(
            [item],
            max_new_tokens=max_new_tokens,
            gen_cfg=gen_cfg,
            oom_estimate=False,
            bs_estimate_gen_cfg={}
        )
        return outputs[0], right_pad_lens[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Count how many trailing tokens of a generated sequence are padding.

        All trailing pad tokens are counted. Of the EOS tokens ending the real content,
        exactly one is kept and the rest are counted as padding. When the pad id is itself
        an EOS id, the run of pads also swallowed the real EOS, so one token is given back.

        Args:
            total_sequence: 1-D tensor or list of generated token ids (prompt excluded).

        Returns:
            Number of tokens to drop from the right.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad id that generate() actually uses.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        eos_ids = self._eos_token_ids()

        # 1) Trailing pad tokens are always dropped.
        j = len(total_sequence) - 1
        right_pad_len = 0
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        # 2) pad is an EOS id and no other EOS precedes the run: the run's first token was
        #    the real EOS, so keep it. Only valid when pad is an EOS id; otherwise a
        #    genuine pad token would be kept.
        if right_pad_len > 0 and pad_id in eos_ids and (j < 0 or total_sequence[j] not in eos_ids):
            return right_pad_len - 1

        # 3) Of the remaining trailing EOS tokens keep one; the rest count as padding.
        eos_count = 0
        while j >= 0 and total_sequence[j] in eos_ids:
            eos_count += 1
            j -= 1
        return right_pad_len + max(0, eos_count - 1)
