from typing import List, Dict, Tuple
from .base import BaseVLM
from PIL import Image
import torch
from rich.console import Console

# Import shared helpers (_normalize_to_list, parse_input, get_image_path) from the `utils` module.
import sys
import os
# Append the directory two levels above this file (the parent of this package's
# directory) to sys.path so that the top-level `utils` module can be imported.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path

# Module-level rich console; QwenVLVLM.generate_batch prints decoded outputs with it.
console = Console()


class QwenVLVLM(BaseVLM):
    """
    Qwen-VL family wrapper: batched generation modeled on the logic in phi.py.

    Uses "true batch inference": the whole batch goes through one processor call
    and one ``model.generate`` call. This keeps the per-sample length statistics
    returned by ``generate_batch`` aligned with ``forward_batch``.

    Only multimodal (image + text) items are supported.
    """

    def __init__(self, model, tokenizer, processor=None, device="cuda", model_type="qwen"):
        """
        Args:
            model: generation model; must expose ``generate``, ``config`` and
                ``generation_config``.
            tokenizer: tokenizer used when ``processor`` does not carry one.
            processor: multimodal processor; required by ``generate_batch``.
            device: device that processed inputs are moved to before generation.
            model_type: ``"llama"`` wraps each image in its own list for the
                processor call; any other value passes a flat list of images.
        """
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        self.model_type = model_type
        # Use a single tokenizer source: prefer processor.tokenizer when it exists,
        # so ids produced by the processor are decoded with the same vocabulary.
        if processor is not None and getattr(processor, "tokenizer", None) is not None:
            self.tokenizer = processor.tokenizer

        # Force left padding. generate_batch slices the prompt off every row at the
        # same column, which only works when all prompts end at the same position.
        # This also avoids the right-padding warning.
        self._configure_left_padding(self.tokenizer)

        # If the processor carries its own tokenizer, configure it too.
        # getattr(None, ...) yields None, so this is a no-op without a processor,
        # and a processor whose tokenizer is None is skipped instead of crashing.
        processor_tokenizer = getattr(self.processor, "tokenizer", None)
        if processor_tokenizer is not None:
            self._configure_left_padding(processor_tokenizer)

    @staticmethod
    def _configure_left_padding(tok) -> None:
        """Set left padding on *tok*; fall back to its EOS token when no pad token is set."""
        tok.padding_side = "left"
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

    def _collect_eos_ids(self) -> list:
        """Return the union of EOS ids from the model's generation_config and the tokenizer (unordered)."""
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        return list(set(gc_eos + tok_eos))

    def _build_inputs(self, items: List[dict]):
        """
        Build the padded processor batch for *items*: chat-templated prompts plus RGB images.

        Raises:
            ValueError: if any item has no image, or *items* is empty.
        """
        # Build one single-turn user conversation per item (official-example format).
        # All items are validated before any image file is opened.
        conversations = []
        image_paths = []
        for item in items:
            prompt = parse_input(item)
            image_path = get_image_path(item)
            if not image_path:
                raise ValueError("QwenVLVLM 目前仅支持多模态输入")
            conversations.append(
                [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image"},
                            {"type": "text", "text": prompt},
                        ],
                    }
                ]
            )
            image_paths.append(image_path)

        if not image_paths:
            # Empty batch: there is nothing multimodal to process.
            raise ValueError("QwenVLVLM 目前仅支持多模态输入")

        texts = [
            self.processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
            for conv in conversations
        ]

        # Images are in the same order as texts. Each text has exactly one image placeholder.
        images = []
        for path in image_paths:
            # convert() returns a fully loaded copy, so the source file can be
            # closed right away instead of leaking the handle.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))

        if self.model_type == "llama":
            images = [[img] for img in images]

        try:
            # Ask the processor for left padding explicitly where supported.
            return self.processor(
                text=texts,
                images=images,
                padding=True,
                padding_side="left",
                return_tensors="pt",
            )
        except TypeError:
            # Older processors reject `padding_side`; the tokenizers were already
            # configured for left padding in __init__.
            return self.processor(
                text=texts,
                images=images,
                padding=True,
                return_tensors="pt",
            )

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Generate completions for a batch of image+text items with one ``generate`` call.

        Args:
            items: raw dataset items. The prompt comes from ``parse_input(item)`` and
                the image path from ``get_image_path(item)``.
            max_new_tokens: generation budget per sample; also the threshold for
                the hit-limit flag. Always takes precedence over a
                ``max_new_tokens`` key in ``gen_cfg``.
            gen_cfg: extra keyword arguments for ``model.generate`` (may be None).
                Its keys override the default ``do_sample=False``.
            oom_estimate: when True, inputs are passed through
                ``utils.apply_prefill_extra_tokens`` (presumably to add extra
                prefill tokens for memory / batch-size probing).
            bs_estimate_gen_cfg: read only when ``oom_estimate`` is True. Supplies
                ``_prefill_extra_tokens`` and ``_prefill_token_id``.

        Returns:
            ``(outputs, right_pad_lens, hit_limit_flags)``, with one entry per item:
            - the decoded generated text (prompt removed, special tokens skipped);
            - the number of trailing tokens trimmed as padding;
            - whether the output reached ``max_new_tokens`` without ending in EOS.

        Raises:
            AssertionError: if no processor is configured.
            ValueError: if any item has no image.
        """
        assert self.processor is not None, "QwenVLVLM 需要 AutoProcessor"

        inputs = self._build_inputs(items)
        inputs = inputs.to(self.device)

        # ================================ Memory estimation ==================================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            est_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(est_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = est_cfg.pop("_prefill_token_id", None)
            # NOTE(review): any remaining keys in est_cfg are not used for generation
            # below (gen_cfg is); confirm that is intended.
            inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # =====================================================================================

        # Merge as a dict literal rather than dict(..., **gen_cfg). The call form
        # raises TypeError when gen_cfg repeats `do_sample` or `max_new_tokens`.
        gen_conf = {"do_sample": False, **(gen_cfg or {}), "max_new_tokens": max_new_tokens}

        eos_ids = self._collect_eos_ids()
        if eos_ids and all(e is not None for e in eos_ids):
            gen_conf["eos_token_id"] = eos_ids[0] if len(eos_ids) == 1 else eos_ids

        # Pad id: generation_config, then tokenizer pad, then tokenizer EOS.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        if pad_id is not None:
            gen_conf["pad_token_id"] = pad_id
            # Persist it so calculate_right_padding_length sees the same pad id
            # that generate used.
            self.model.config.pad_token_id = pad_id
            self.model.generation_config.pad_token_id = pad_id

        with torch.no_grad():
            output_ids = self.model.generate(**inputs, **gen_conf)

        input_ids = inputs["input_ids"]
        batch_size = input_ids.size(0)
        # With left padding, every prompt occupies the same leading columns.
        input_len = input_ids.size(1)

        outputs = []
        right_pad_lens = []
        hit_limit_flags = []
        for row in output_ids[:batch_size]:
            cut = self.calculate_right_padding_length(row)
            right_pad_lens.append(cut)

            # Generated part only: drop the prompt prefix and the trailing padding.
            generated_ids = row[input_len:row.size(0) - cut]
            output_len = generated_ids.size(0)

            ended_with_eos = output_len > 0 and int(generated_ids[-1].item()) in eos_ids
            hit_limit_flags.append(output_len >= max_new_tokens and not ended_with_eos)

            outputs.append(
                self.tokenizer.decode(
                    generated_ids,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
            )

        for o in outputs:
            console.print("\n[yellow]output without prompt: ", o)

        return outputs, right_pad_lens, hit_limit_flags

    def generate_one(self, item, max_new_tokens, gen_cfg):
        """Single-item convenience wrapper: return ``(text, right_pad_len, hit_limit)`` for *item*."""
        outputs, right_pads, hit_limits = self.generate_batch(
            [item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={}
        )
        return outputs[0], right_pads[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Count the trailing tokens of one generated row that are padding, not output.

        The run of trailing pad tokens is counted. One EOS is always left in place
        as the real terminator; any further repeated trailing EOS tokens are also
        counted as padding.

        Args:
            total_sequence: token ids of one prompt+generation row (tensor or list).

        Returns:
            Number of tokens to trim from the end of the row.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        # Prefer the pad id that generate actually used (generate_batch stores it
        # in generation_config).
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        eos_ids = self._collect_eos_ids()

        # Count the run of trailing pad tokens.
        right_pad_len = 0
        j = len(total_sequence) - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        # If pad is itself an EOS id, the loop above also swallowed the real EOS that
        # ended generation; give one back. When pad is NOT an EOS id (e.g. a row
        # stopped by stop_strings and was then padded), every pad is padding.
        if right_pad_len > 0 and pad_id in eos_ids and j >= 0 and total_sequence[j] not in eos_ids:
            return right_pad_len - 1

        # Count remaining trailing EOS tokens; keep one, treat the rest as padding.
        eos_count = 0
        while j >= 0 and total_sequence[j] in eos_ids:
            eos_count += 1
            j -= 1

        return right_pad_len + max(0, eos_count - 1)
