from __future__ import annotations

from typing import List, Dict, Tuple, Optional
import math

import torch
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

from .base import BaseVLM
from rich.console import Console

# Reuse the project's own parse_input / get_image_path / _normalize_to_list helpers (from utils).
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import parse_input, get_image_path, _normalize_to_list  # noqa: E402

# Shared rich console; RistrettoVLM.generate_batch prints each decoded output through it.
console = Console()

def _build_transform(input_size: int):
    """Return the per-tile preprocessing pipeline used for Ristretto image inputs.

    Each tile is forced to RGB, resized to ``input_size x input_size`` with bicubic
    interpolation, converted to a tensor, then normalized with mean = std = 0.5 per
    channel (the normalization this file uses for Ristretto / InternVL-family encoders).
    """

    def ensure_rgb(picture):
        # Convert only when needed; RGB images pass through untouched.
        return picture if picture.mode == "RGB" else picture.convert("RGB")

    half = (0.5, 0.5, 0.5)
    pipeline = [
        T.Lambda(ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(half, half),
    ]
    return T.Compose(pipeline)


def _find_closest_aspect_ratio(aspect_ratio: float, target_ratios, width: int, height: int, image_size: int):
    best_ratio = (1, 1)
    best_diff = float("inf")
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        diff = abs(aspect_ratio - target_aspect_ratio)
        if diff < best_diff:
            best_diff = diff
            best_ratio = ratio
        elif diff == best_diff:
            # tie-break：更偏向大图（这套逻辑是 InternVL/类似实现里常见的）
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def _dynamic_preprocess(
    image: Image.Image,
    min_num: int,
    max_num: int,
    image_size: int,
    use_thumbnail: bool,
) -> List[Image.Image]:
    """Split ``image`` into a grid of square tiles (InternVL-style dynamic resolution).

    Candidate grids are every ``(cols, rows)`` whose tile count lies in
    ``[min_num, max_num]``, ordered by tile count. ``_find_closest_aspect_ratio`` picks one.
    The image is resized to exactly ``(image_size * cols, image_size * rows)`` and cut into
    ``cols * rows`` tiles in row-major order. When ``use_thumbnail`` is set and more than
    one tile was produced, a whole-image ``image_size x image_size`` thumbnail is appended last.

    Args:
        image: source PIL image (its height must be non-zero).
        min_num: minimum number of tiles.
        max_num: maximum number of tiles.
        image_size: side length of each square tile, in pixels.
        use_thumbnail: whether to append a global thumbnail for multi-tile grids.

    Returns:
        The tile images, plus the optional thumbnail.
    """
    src_w, src_h = image.size

    # Every grid with a tile count in [min_num, max_num], then ordered by tile count.
    grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    )
    grids = sorted(grids, key=lambda g: g[0] * g[1])

    best = _find_closest_aspect_ratio(src_w / src_h, grids, src_w, src_h, image_size)
    n_cols, n_rows = best[0], best[1]

    canvas = image.resize((image_size * n_cols, image_size * n_rows), resample=Image.BICUBIC)

    # Row-major tiling: left-to-right within a row, rows top-to-bottom.
    tiles = [
        canvas.crop((col * image_size, row * image_size, (col + 1) * image_size, (row + 1) * image_size))
        for row in range(n_rows)
        for col in range(n_cols)
    ]

    if use_thumbnail and n_cols * n_rows != 1:
        tiles.append(image.resize((image_size, image_size), resample=Image.BICUBIC))

    return tiles


def _load_image_as_pixel_values(
    image: Image.Image,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    device: str,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, int]:
    """Tile ``image`` and stack the normalized tiles into one tensor.

    Returns:
        ``(pixel_values, num_patches)``. ``pixel_values`` stacks one
        ``[3, input_size, input_size]`` tensor per tile along dim 0, moved to ``device`` and
        cast to ``dtype``. ``num_patches`` is the number of stacked tiles.
    """
    tiles = _dynamic_preprocess(
        image=image,
        min_num=min_num,
        max_num=max_num,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )
    to_tensor = _build_transform(input_size)
    stacked = torch.stack(list(map(to_tensor, tiles)), dim=0)
    stacked = stacked.to(device=device, dtype=dtype)
    return stacked, stacked.shape[0]


class RistrettoVLM(BaseVLM):
    """
    Wrapper for LiAutoAD/Ristretto-3B (loaded with trust_remote_code).

    - The remote code provides ``model.generate(pixel_values=..., input_ids=..., attention_mask=...)``.
    - It also provides ``model.batch_chat(...)``. This wrapper calls ``generate`` directly instead,
      so it keeps the raw generated token ids needed for cut / right_pad_len bookkeeping.
    """

    def __init__(self, model, tokenizer, processor=None, device: str = "cuda", **kwargs):
        """Bind model/tokenizer/processor and read the image-tiling parameters from the config.

        Args:
            model: the remote-code Ristretto model.
            tokenizer: tokenizer used when ``processor`` has no tokenizer of its own.
            processor: optional processor; its ``tokenizer`` (if present) takes precedence.
            device: device that image tensors and token ids are moved to.
            **kwargs: accepted for interface compatibility; unused here.
        """
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        # 1) Pick the primary tokenizer: prefer processor.tokenizer when present (the processor
        #    owns tokenization); otherwise use the tokenizer passed in.
        if self.processor is not None and getattr(self.processor, "tokenizer", None) is not None:
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = tokenizer

        # 2) Force processor.tokenizer and self.tokenizer to be the same object.
        if self.processor is not None and hasattr(self.processor, "tokenizer"):
            self.processor.tokenizer = self.tokenizer

        # Batched decoder-only generation needs left padding and a usable pad token.
        # Do not blindly alias pad to eos when the tokenizer already knows a distinct pad id.
        self._ensure_left_padding(self.tokenizer)
        proc_tok = getattr(self.processor, "tokenizer", None)
        if proc_tok is not None and proc_tok is not self.tokenizer:
            # Normally identical after the rebinding above; configured separately just in case.
            self._ensure_left_padding(proc_tok)

        # Model dtype, taken from the first parameter; falls back to bfloat16.
        try:
            self._dtype = next(self.model.parameters()).dtype
        except Exception:
            self._dtype = torch.bfloat16

        # Dynamic-tiling parameters from the model config
        # (the defaults apply when a field is missing or falsy).
        cfg = getattr(self.model, "config", None)
        self._input_size = int(getattr(cfg, "force_image_size", 384) or 384)
        self._min_patch = int(getattr(cfg, "min_dynamic_patch", 1) or 1)
        self._max_patch = int(getattr(cfg, "max_dynamic_patch", 12) or 12)
        self._use_thumbnail = bool(getattr(cfg, "use_thumbnail", True))

        # Image marker tokens. generate_batch sets model.img_context_token_id from the
        # <IMG_CONTEXT> id before calling the remote generate.
        self._img_context_token = "<IMG_CONTEXT>"
        self._img_start_token = "<img>"
        self._img_end_token = "</img>"

    @staticmethod
    def _ensure_left_padding(tok) -> None:
        """Set left padding on ``tok`` and make sure it has a pad token.

        If no pad token string is set but a pad id is reported, recover the matching token
        string, verified by an id -> token -> id round-trip. In every other case (no pad id,
        lookup error, failed round-trip), fall back to the eos token so that
        ``padding=True`` batching cannot fail for lack of a pad token.
        """
        tok.padding_side = "left"
        if tok.pad_token is not None:
            return
        pad_id = getattr(tok, "pad_token_id", None)
        if pad_id is not None:
            try:
                candidate = tok.convert_ids_to_tokens(int(pad_id))
                if candidate is not None and tok.convert_tokens_to_ids(candidate) == int(pad_id):
                    tok.pad_token = candidate
                    return
            except Exception:
                # Deliberate best-effort: any lookup failure falls through to the eos fallback.
                pass
        tok.pad_token = tok.eos_token

    def generate_one(self, item, max_new_tokens, gen_cfg):
        """Generate for a single item; returns ``(text, right_pad_len, hit_limit)``."""
        outputs, right_pads, hit_limits = self.generate_batch([item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={})
        return outputs[0], right_pads[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence, pad_token_id=None, eos_token_ids=None) -> int:
        """Count how many trailing tokens of a generated sequence are padding.

        First, the trailing run of pad ids is counted. If that run is non-empty and the token
        before it is not an eos id, one pad is kept (the pad id is presumably doubling as eos,
        so the last "pad" is the real terminator). Otherwise, the eos run that precedes the
        pads is counted too, and exactly one eos is kept.

        Args:
            total_sequence: generated token ids (list, or tensor convertible via ``tolist``).
            pad_token_id: pad id that generate actually used. When None, falls back to
                ``model.generation_config.pad_token_id``, then ``tokenizer.pad_token_id``.
            eos_token_ids: extra eos id(s), int or list (e.g. those passed to generate). They
                are merged with the generation-config and tokenizer eos ids.

        Returns:
            The number of tokens to strip from the right end.
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()

        gen_config = getattr(self.model, "generation_config", None)
        pad_id = pad_token_id
        if pad_id is None:
            pad_id = getattr(gen_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        eos_ids = set(_normalize_to_list(eos_token_ids))
        eos_ids.update(_normalize_to_list(getattr(gen_config, "eos_token_id", None)))
        eos_ids.update(_normalize_to_list(getattr(self.tokenizer, "eos_token_id", None)))

        # Walk back over the trailing run of pad tokens.
        j = len(total_sequence) - 1
        pad_run = 0
        while j >= 0 and total_sequence[j] == pad_id:
            pad_run += 1
            j -= 1

        if pad_run > 0 and j >= 0 and total_sequence[j] not in eos_ids:
            # Padding exists but is not preceded by eos: keep one "pad" as the terminator.
            return pad_run - 1

        # Then walk back over the trailing run of eos tokens: keep one, count the rest as padding.
        eos_run = 0
        while j >= 0 and total_sequence[j] in eos_ids:
            eos_run += 1
            j -= 1
        return pad_run + max(0, eos_run - 1)

    def _load_item_image(self, item: Dict) -> Optional[Image.Image]:
        """Return the item's image, or None for a text-only item.

        An in-memory PIL image under ``item["image"]`` is used as-is. Otherwise the file at
        ``get_image_path(item)`` (if any) is fully loaded into an independent copy, so the
        file handle is closed deterministically.
        """
        img = item.get("image", None)
        if isinstance(img, Image.Image):
            return img
        path = get_image_path(item)
        if not path:
            return None
        with Image.open(path) as opened:
            return opened.copy()

    def _build_query(self, question: str, patches: List[int], has_pixel_values: bool) -> str:
        """Build one prompt string, following the remote batch_chat prompt construction.

        - Prepends ``<image>\\n`` when the sample has image tiles and no placeholder yet.
        - Wraps the question in a copy of ``model.conv_template`` (if the model has one), so
          the model's own template is not mutated.
        - Expands each ``<image>`` to ``<img>`` + ``<IMG_CONTEXT>`` * (num_image_token * tiles)
          + ``</img>``.
        """
        q = question
        if has_pixel_values and "<image>" not in q and len(patches) > 0:
            q = "<image>\n" + q

        conv = getattr(self.model, "conv_template", None)
        if conv is not None:
            template = conv.copy()
            template.system_message = getattr(self.model, "system_message", template.system_message)
            template.append_message(template.roles[0], q)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()
        else:
            # Fallback: no template available, use the raw question.
            query = q

        if patches:
            num_image_token = int(getattr(self.model, "num_image_token", 256))
            for n_p in patches:
                image_tokens = self._img_start_token + (self._img_context_token * (num_image_token * n_p)) + self._img_end_token
                query = query.replace("<image>", image_tokens, 1)
        return query

    @torch.no_grad()
    def generate_batch(
        self,
        items: List[Dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Run one batched ``model.generate`` call over ``items``.

        Args:
            items: input samples. Text comes from ``parse_input(item)``. The image comes from
                ``item["image"]`` (a PIL image) or ``get_image_path(item)``. Items without an
                image are sent as text only.
            max_new_tokens: generation budget, used unless ``gen_cfg`` sets its own
                ``max_new_tokens``.
            gen_cfg: extra keyword arguments forwarded to ``model.generate``.
            oom_estimate: when true, extend the inputs via ``utils.apply_prefill_extra_tokens``
                (memory-estimation mode).
            bs_estimate_gen_cfg: estimation-mode options (``_prefill_extra_tokens``,
                ``_prefill_token_id``).

        Returns:
            ``(texts, right_pad_lens, hit_limit_flags)``, one entry per item.
        """
        # 1) Text prompts.
        questions = [parse_input(it) for it in items]

        # 2) Images -> per-item tile tensors (concatenated later) and per-item tile counts.
        pixel_values_list: List[torch.Tensor] = []
        num_patches_list: List[List[int]] = []
        for it in items:
            img = self._load_item_image(it)
            if img is None:
                # Text-only sample: no image tokens are inserted for it.
                num_patches_list.append([])
                continue
            pv, n_patches = _load_image_as_pixel_values(
                image=img,
                input_size=self._input_size,
                min_num=self._min_patch,
                max_num=self._max_patch,
                use_thumbnail=self._use_thumbnail,
                device=self.device,
                dtype=self._dtype,
            )
            pixel_values_list.append(pv)
            num_patches_list.append([n_patches])

        # [sum_patches, 3, H, W], or None when the whole batch is text-only.
        pixel_values = torch.cat(pixel_values_list, dim=0) if pixel_values_list else None

        # 3) Prompts. The remote generate needs img_context_token_id set beforehand.
        self.model.img_context_token_id = self.tokenizer.convert_tokens_to_ids(self._img_context_token)
        has_pixels = pixel_values is not None
        queries: List[str] = [
            self._build_query(q, patches, has_pixels)
            for q, patches in zip(questions, num_patches_list)
        ]

        # 4) Tokenize with left padding.
        self.tokenizer.padding_side = "left"
        model_inputs = self.tokenizer(queries, return_tensors="pt", padding=True)
        input_ids = model_inputs["input_ids"].to(self.device)
        attention_mask = model_inputs["attention_mask"].to(self.device)

        # 5) Generation kwargs: values in gen_cfg win; defaults fill in the rest.
        gcfg = dict(gen_cfg or {})
        gcfg.setdefault("max_new_tokens", max_new_tokens)
        gcfg.setdefault("do_sample", False)
        gcfg.setdefault("temperature", None)
        if "pad_token_id" not in gcfg:
            gcfg["pad_token_id"] = getattr(self.tokenizer, "pad_token_id", None)
        if "eos_token_id" not in gcfg:
            gcfg["eos_token_id"] = getattr(self.tokenizer, "eos_token_id", None)
        # The budget actually passed to generate; hit_limit must be judged against it.
        token_limit = gcfg["max_new_tokens"]
        if token_limit is None:
            token_limit = max_new_tokens

        # ============================ Memory (OOM) estimation ============================
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            batch_model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
            batch_model_inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=batch_model_inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
            input_ids, attention_mask = batch_model_inputs["input_ids"], batch_model_inputs["attention_mask"]
        # =================================================================================

        # 6) One real batched generate call.
        output_ids = self.model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gcfg
        )

        # 7) right_pad_len / hit_limit / decode, using the pad/eos ids generate was called with.
        used_pad_id = gcfg.get("pad_token_id")
        used_eos_ids = _normalize_to_list(gcfg.get("eos_token_id"))
        stop_ids = set(used_eos_ids)
        stop_ids.update(_normalize_to_list(getattr(self.tokenizer, "eos_token_id", None)))

        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limit_flags: List[bool] = []

        for b in range(output_ids.size(0)):
            seq = output_ids[b].tolist()
            cut = self.calculate_right_padding_length(
                seq, pad_token_id=used_pad_id, eos_token_ids=used_eos_ids
            )
            right_pad_lens.append(cut)

            gen_ids = seq[: len(seq) - cut]

            # hit_limit: the full budget was consumed and generation did not stop on eos.
            ended_with_eos = bool(gen_ids) and gen_ids[-1] in stop_ids
            hit_limit_flags.append(len(gen_ids) >= token_limit and not ended_with_eos)

            outputs.append(self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip())

        for out in outputs:
            console.print("\n[yellow]output without prompt: ", out)

        return outputs, right_pad_lens, hit_limit_flags
