from __future__ import annotations

from typing import List, Dict, Tuple, Any
import os
import sys

import torch
from PIL import Image
from rich.console import Console

from .base import BaseVLM

# Project helpers: append the parent of this file's directory to sys.path
# before importing `utils` (presumably that is where the top-level `utils`
# module lives -- confirm against the repository layout).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

# Module-level rich console used by this backend for printing.
console = Console()


class PixtralVLM(BaseVLM):
    """
    mistral-community/pixtral-12b backend.

    Goals:
    - True batching: one encode pass and one generate call per batch.
    - Right-side trimming: calculate_right_padding_length is placed and used
      exactly as in qwen.py.
    """

    def __init__(self, model, tokenizer=None, processor=None, device: str = "cuda"):
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        # 1) 选主：processor.tokenizer 存在就用它（因为 processor 负责 tokenize）
        if self.processor is not None and getattr(self.processor, "tokenizer", None) is not None:
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = tokenizer  # 外部传入

        # 2) 强制绑定：让 processor.tokenizer 与 self.tokenizer 指向同一对象
        if self.processor is not None and hasattr(self.processor, "tokenizer"):
            self.processor.tokenizer = self.tokenizer

        # R-4B 是 decoder-only 生成，左 padding 更常见；但不要强行把 pad_token 设成 eos（有些模型 pad_token_id != eos）
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
            if pad_id is not None:
                try:
                    tok = self.tokenizer.convert_ids_to_tokens(int(pad_id))
                    # 验证 round-trip，避免设了一个“看起来像 token 但不对应 pad_id”的字符串
                    if tok is not None and self.tokenizer.convert_tokens_to_ids(tok) == int(pad_id):
                        self.tokenizer.pad_token = tok
                except Exception:
                    self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                self.tokenizer.pad_token = self.tokenizer.eos_token

        if self.processor is not None and hasattr(self.processor, "tokenizer"):
            self.processor.tokenizer.padding_side = "left"
            if self.processor.tokenizer.pad_token is None:
                pad_id = getattr(self.processor.tokenizer, "pad_token_id", None)
                if pad_id is not None:
                    try:
                        tok = self.processor.tokenizer.convert_ids_to_tokens(
                            self.processor.tokenizer.pad_token_id
                        )
                        if tok is not None and self.processor.tokenizer.convert_tokens_to_ids(tok) == int(pad_id):
                            self.processor.tokenizer.pad_token = tok
                    except Exception:
                        self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token
                else:
                    self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token

    def _build_conversation(self, item: dict) -> Tuple[List[Dict[str, Any]], int]:
        """
        Build a single-turn user conversation for ``item``: all images first,
        then the text prompt.

        Args:
            item: Sample record; the prompt is obtained with ``parse_input`` and
                the image path(s) with ``get_image_path``.

        Returns:
            ``(conversation, num_images)`` where ``conversation`` is a
            one-element list holding the user message.

        Raises:
            ValueError: If the item has no image paths, or every path is None.
        """
        prompt = parse_input(item)
        image_paths = _normalize_to_list(get_image_path(item))

        if not image_paths:
            raise ValueError("PixtralVLM 目前仅支持多模态输入（需要至少一张图片）")

        images: List[Image.Image] = []
        for path in image_paths:
            if path is None:
                continue
            # Image.open is lazy and keeps the file open. convert() loads the
            # pixels into a new image, so the handle can be released right away.
            with Image.open(path) as raw:
                images.append(raw.convert("RGB"))

        if not images:
            raise ValueError("PixtralVLM: 未成功加载任何图片，请检查 image_path 是否正确")

        content: List[Dict[str, Any]] = [{"type": "image", "image": img} for img in images]
        # Carry the prompt under both "text" and "content" so that different
        # chat-template versions can find it.
        content.append({"type": "text", "text": prompt, "content": prompt})

        conversation = [{"role": "user", "content": content}]
        return conversation, len(images)

    @staticmethod
    def _move_to_device(inputs: Dict[str, Any], device: torch.device, model_dtype: torch.dtype) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                if v.dtype in (torch.float16, torch.bfloat16, torch.float32):
                    out[k] = v.to(device=device, dtype=model_dtype)
                else:
                    out[k] = v.to(device=device)
            else:
                out[k] = v
        return out

    @staticmethod
    def _stack_and_left_pad(encoded_list: List[Dict[str, Any]], pad_token_id: int) -> Dict[str, Any]:
        """
        兜底：单样本 encode 后再手工 left-pad 到同一长度，然后 cat 成 batch（确保仍是一趟 generate）。
        """
        keys = set()
        for e in encoded_list:
            keys |= set(e.keys())

        max_len = 0
        for e in encoded_list:
            if "input_ids" in e:
                max_len = max(max_len, int(e["input_ids"].shape[1]))

        batch: Dict[str, Any] = {}
        for k in keys:
            vals = [e.get(k, None) for e in encoded_list]
            if all(v is None for v in vals):
                continue

            first = next(v for v in vals if v is not None)
            if not isinstance(first, torch.Tensor):
                batch[k] = vals
                continue

            if k in ("input_ids", "attention_mask", "token_type_ids", "position_ids"):
                padded = []
                for v in vals:
                    if v is None:
                        continue
                    cur_len = int(v.shape[1])
                    pad_len = max_len - cur_len
                    if pad_len <= 0:
                        padded.append(v)
                        continue

                    pad_val = pad_token_id if k == "input_ids" else 0
                    pad_tensor = torch.full((v.shape[0], pad_len), pad_val, dtype=v.dtype, device=v.device)
                    padded.append(torch.cat([pad_tensor, v], dim=1))
                batch[k] = torch.cat(padded, dim=0)
            else:
                to_cat = [v for v in vals if v is not None]
                batch[k] = torch.cat(to_cat, dim=0)

        return batch

    def calculate_right_padding_length(self, total_sequence) -> int:
        """
        Count how many trailing tokens of ``total_sequence`` should be trimmed
        as right-side padding.

        The pad id is taken from ``model.generation_config`` first (the value
        generate actually uses) and from the tokenizer otherwise. EOS ids are
        the union of those on the generation config and on the tokenizer.

        Rules:
        - Trailing pad tokens are counted.
        - If at least one pad was found and the token just before the pad run is
          not an EOS, one token fewer is returned (the pad id coincides with
          EOS, so the last "pad" is really the terminating EOS).
        - Otherwise the run of EOS tokens before the pads is also counted,
          except for one EOS that is kept.

        Args:
            total_sequence: Token ids as a tensor or an indexable sequence.

        Returns:
            Number of tokens to cut from the right end.
        """
        tokens = total_sequence.tolist() if isinstance(total_sequence, torch.Tensor) else total_sequence

        # Prefer the pad id that generate() is actually using.
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        config_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tokenizer_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = list(set(config_eos + tokenizer_eos))

        # Walk back over the trailing run of pad tokens.
        last_index = len(tokens) - 1
        cursor = last_index
        while cursor >= 0 and tokens[cursor] == pad_id:
            cursor -= 1
        trailing_pads = last_index - cursor

        if trailing_pads > 0 and cursor >= 0 and tokens[cursor] not in eos_ids:
            # Pads exist and pad == eos made us swallow the real EOS: give one back.
            return trailing_pads - 1

        # Continue over the run of EOS tokens; keep one EOS, the rest is padding.
        eos_run = 0
        while cursor >= 0 and tokens[cursor] in eos_ids:
            eos_run += 1
            cursor -= 1

        return trailing_pads + max(0, eos_run - 1)

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """
        Encode ``items`` as one batch, run a single ``generate`` call and decode.

        Args:
            items: Sample records; each is turned into a conversation by
                ``_build_conversation``. All items must have the same number
                of images.
            max_new_tokens: Generation budget; always overrides any value in
                ``gen_cfg``.
            gen_cfg: Extra keyword arguments for ``model.generate``.
                ``do_sample=False`` and ``temperature=0`` are applied as
                defaults when absent.
            oom_estimate: When True, the encoded inputs are passed through
                ``utils.apply_prefill_extra_tokens`` before generation.
            bs_estimate_gen_cfg: Only read when ``oom_estimate`` is True, and
                only for the keys ``_prefill_extra_tokens`` and
                ``_prefill_token_id``. The caller's dict is not modified.

        Returns:
            ``(texts, right_pad_lens, hit_limit_flags)``, one entry per item:
            the decoded continuation, the number of trailing tokens trimmed,
            and whether the sample used the full budget without ending on EOS.

        Raises:
            ValueError: If the items do not all have the same image count
                (this includes an empty ``items`` list).
        """
        # 1) Build the conversations.
        conversations: List[List[Dict[str, Any]]] = []
        nimgs_list: List[int] = []
        for it in items:
            conv, nimgs = self._build_conversation(it)
            conversations.append(conv)
            nimgs_list.append(nimgs)

        # Same rule as the other VLM backends: one image count per batch.
        if len(set(nimgs_list)) != 1:
            raise ValueError(f"PixtralVLM 同一 batch 需要相同图片数量，否则很可能无法正确 batch 推理: {nimgs_list}")

        # 2) Encode the whole batch.
        inputs = self._encode_conversations(conversations)

        # 3) Move to the model's device / dtype.
        try:
            device = getattr(self.model, "device", None) or next(self.model.parameters()).device
        except Exception:
            device = torch.device(self.device)

        try:
            model_dtype = getattr(self.model, "dtype", None) or next(self.model.parameters()).dtype
        except Exception:
            model_dtype = torch.float16

        inputs = self._move_to_device(dict(inputs), device=device, model_dtype=model_dtype)

        # 4) Generation config (kept aligned with qwen.py).
        gen_conf = dict(gen_cfg or {})
        gen_conf.setdefault("do_sample", False)
        gen_conf.setdefault("temperature", 0)
        gen_conf["max_new_tokens"] = max_new_tokens

        # pad_token_id: if the model has none, fill it in and sync it to both
        # model.config and generation_config (same as qwen.py).
        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)

        if pad_id is not None:
            self.model.config.pad_token_id = pad_id
            self.model.generation_config.pad_token_id = pad_id

        # EOS ids used for the hit-limit check (same as qwen.py).
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = list(set(gc_eos + tok_eos))

        # ============================ memory estimation ============================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
        # ===========================================================================

        # 5) One batched generate call.
        with torch.no_grad():
            output_ids = self.model.generate(**inputs, **gen_conf)

        # 6) Trim, classify and decode each sequence.
        input_ids = inputs["input_ids"]
        batch_size = input_ids.size(0)
        # The prompt length is shared by the whole (padded) batch.
        input_len = input_ids.size(1)

        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limit_flags: List[bool] = []

        for i in range(batch_size):
            # Placement and usage identical to qwen.py.
            total_len = output_ids[i].size(0)

            cut = self.calculate_right_padding_length(output_ids[i])
            right_pad_lens.append(cut)
            output_len = total_len - input_len - cut

            generated_ids = output_ids[i][input_len:-cut] if cut > 0 else output_ids[i][input_len:]

            # Hit-limit check (same as qwen.py): the budget was used up and
            # the last kept token is not an EOS.
            ended_with_eos = False
            if output_len > 0:
                last_token_id = int(generated_ids[-1].item())
                if last_token_id in eos_ids:
                    ended_with_eos = True

            hit_limit = (output_len >= max_new_tokens) and (not ended_with_eos)
            hit_limit_flags.append(hit_limit)

            output_text = self.tokenizer.decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            outputs.append(output_text)

        # Model text is arbitrary: escape it so bracketed fragments such as
        # "[/x]" are printed literally instead of being parsed as rich markup
        # (which can raise MarkupError or silently drop text).
        from rich.markup import escape
        for text in outputs:
            console.print("\n[yellow]output without prompt: ", escape(text))

        return outputs, right_pad_lens, hit_limit_flags

    def _encode_conversations(self, conversations: List[List[Dict[str, Any]]]) -> Any:
        """Encode a batch of conversations, trying progressively more conservative processor paths.

        Order: (1) ``apply_chat_template(tokenize=True)`` on the whole batch;
        (2) render prompt strings and call ``processor(text=..., images=...)``
        on the batch; (3) encode sample by sample and left-pad/concatenate
        with ``_stack_and_left_pad``.
        """
        # Fast path: let the chat template tokenize and pad everything at once.
        try:
            inputs = self.processor.apply_chat_template(
                conversations,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                padding=True,
                return_tensors="pt",
            )
        except Exception:
            # Deliberate best effort: not every processor/template version
            # supports this call, so fall through to the paths below.
            inputs = None

        if inputs is not None:
            return inputs

        # Fallback: prompt strings + processor(text=..., images=...).
        prompts: List[str] = []
        images_batch: List[Any] = []
        for conv in conversations:
            try:
                prompt = self.processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
            except Exception:
                prompt = self.processor.apply_chat_template(conv)
            prompts.append(prompt)

            imgs = [
                part["image"]
                for part in conv[0]["content"]
                if isinstance(part, dict) and part.get("type") == "image" and "image" in part
            ]
            images_batch.append(imgs if len(imgs) > 1 else imgs[0])

        try:
            return self.processor(text=prompts, images=images_batch, padding=True, return_tensors="pt")
        except Exception:
            # Last resort: encode one sample at a time, then left-pad and
            # concatenate so generation is still a single call.
            encoded_list = [
                self.processor(text=p, images=img, return_tensors="pt")
                for p, img in zip(prompts, images_batch)
            ]
            # getattr's default only covers a missing attribute; the attribute
            # can also exist and be None, which int() cannot handle.
            pad_id = getattr(self.tokenizer, "pad_token_id", None)
            if pad_id is None:
                pad_id = getattr(self.tokenizer, "eos_token_id", None)
            if pad_id is None:
                pad_id = 0
            return self._stack_and_left_pad(encoded_list, pad_token_id=int(pad_id))

    def generate_one(self, item, max_new_tokens, gen_cfg):
        outs, pads, hits = self.generate_batch([item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={})
        return outs[0], pads[0], hits[0]
