from typing import List, Dict, Tuple
import os
import sys
import copy

import torch
import torchvision.transforms as T
from PIL import Image
from rich.console import Console
from torchvision.transforms.functional import InterpolationMode

from .base import BaseVLM

# Make the project root importable so the top-level ``utils`` module resolves.
# PROJECT_ROOT is the parent of the directory that contains this file.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Must come after the sys.path tweak above, hence the E402 suppression.
from utils import _normalize_to_list, parse_input, get_image_path  # noqa: E402

# Shared rich console used for diagnostic output in this module.
console = Console()

# Special tokens used to build the image placeholder that replaces "<image>"
# in each prompt (see QTuneInternVLVLM._build_queries_and_eos).
IMG_START_TOKEN = "<img>"
IMG_END_TOKEN = "</img>"
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"

# Per-channel mean / std handed to T.Normalize in build_transform
# (presumably the standard ImageNet RGB statistics, as the names suggest).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size: int):
    """Build the preprocessing pipeline applied to every image tile.

    The pipeline converts the image to RGB when it is in any other mode,
    resizes it to ``input_size`` x ``input_size`` with bicubic interpolation,
    turns it into a tensor and normalizes it with ``IMAGENET_MEAN`` /
    ``IMAGENET_STD``.

    Args:
        input_size: Side length of the square output.

    Returns:
        A ``T.Compose`` that applies the steps above in order.
    """

    def _ensure_rgb(img):
        # RGB images pass through untouched; everything else is converted.
        if img.mode == "RGB":
            return img
        return img.convert("RGB")

    target_hw = (input_size, input_size)
    pipeline = [
        T.Lambda(_ensure_rgb),
        T.Resize(target_hw, interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(pipeline)


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid whose aspect ratio is closest to the image's.

    Candidates are visited in the order given. A candidate that is strictly
    closer always wins. A candidate that is exactly as close as the current
    best replaces it only when the image area exceeds half the area that
    candidate's grid would cover (``0.5 * image_size**2 * cols * rows``).

    Args:
        aspect_ratio: Width / height of the source image.
        target_ratios: Iterable of ``(cols, rows)`` candidates.
        width: Source image width.
        height: Source image height.
        image_size: Side length of one square tile.

    Returns:
        The selected ``(cols, rows)`` candidate, or ``(1, 1)`` when
        ``target_ratios`` is empty.
    """
    pixel_area = width * height
    chosen = (1, 1)
    smallest_gap = float("inf")

    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)

        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue

        # Exact tie: prefer this grid only if the image is large enough for it.
        if gap == smallest_gap and pixel_area > 0.5 * image_size * image_size * cols * rows:
            chosen = candidate

    return chosen


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=True):
    """Split an image into a grid of square tiles matched to its aspect ratio.

    Every ``(cols, rows)`` grid whose tile count lies in
    ``[min_num, max_num]`` is a candidate; the one selected by
    ``find_closest_aspect_ratio`` is used. The image is resized to exactly
    cover that grid and cropped into ``image_size`` x ``image_size`` tiles in
    row-major order.

    Args:
        image: Source image exposing ``size``, ``resize`` and ``crop``
            (a PIL image in this module).
        min_num: Minimum number of tiles in a candidate grid.
        max_num: Maximum number of tiles in a candidate grid.
        image_size: Side length of each square tile.
        use_thumbnail: When true and more than one tile was produced, append
            the whole image resized to a single tile.

    Returns:
        List of tile images, optionally followed by the thumbnail.
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    # Built with the same nested iteration as before so the set (and therefore
    # the order of equal-area candidates after the stable sort) is unchanged.
    grid_candidates = {
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    }
    ordered_candidates = list(grid_candidates)
    ordered_candidates.sort(key=lambda grid: grid[0] * grid[1])

    best_grid = find_closest_aspect_ratio(
        src_ratio, ordered_candidates, src_w, src_h, image_size
    )
    grid_cols, grid_rows = best_grid[0], best_grid[1]

    canvas_w = image_size * grid_cols
    canvas_h = image_size * grid_rows
    canvas = image.resize((canvas_w, canvas_h))

    tiles_per_row = canvas_w // image_size
    tiles = []
    for idx in range(grid_cols * grid_rows):
        row_idx, col_idx = divmod(idx, tiles_per_row)
        left = col_idx * image_size
        top = row_idx * image_size
        tiles.append(canvas.crop((left, top, left + image_size, top + image_size)))

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles


def load_image(image_file: str, input_size=448, min_num=1, max_num=12, use_thumbnail=True) -> torch.Tensor:
    """Load an image from disk and return its normalized tiles as one tensor.

    Args:
        image_file: Path (or file object) accepted by ``Image.open``.
        input_size: Side length of each square tile.
        min_num: Minimum number of tiles for ``dynamic_preprocess``.
        max_num: Maximum number of tiles for ``dynamic_preprocess``.
        use_thumbnail: Forwarded to ``dynamic_preprocess``.

    Returns:
        Stacked tile tensor; the first dimension is the number of tiles
        (thumbnail included when one is produced).
    """
    # Open under a context manager so the underlying file handle is released
    # deterministically instead of whenever the image object is collected.
    # ``convert`` returns an independent image, so it stays valid afterwards.
    with Image.open(image_file) as raw_image:
        image = raw_image.convert("RGB")

    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(
        image,
        min_num=min_num,
        max_num=max_num,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )
    return torch.stack([transform(tile) for tile in tiles])  # [num_tiles, 3, H, W]


class QTuneInternVLVLM(BaseVLM):
    """
    Batched generation for QTuneVL1.5-2B (InternVLChatModel) by calling
    ``model.generate`` directly:

    - Prompts are rebuilt the way ``batch_chat`` does (IMG_CONTEXT placeholders,
      ``template.sep`` used as the EOS token).
    - ``pixel_values`` come from dynamic tiling: one tensor of shape
      ``[sum_tiles, 3, H, W]`` plus a per-sample ``num_patches_list``.
    - ``right_pad_len`` / ``hit_limit`` are computed on the sequences returned
      by ``model.generate`` using ``calculate_right_padding_length``. No prompt
      prefix is sliced off here, so this presumably relies on the model
      returning only newly generated tokens — TODO confirm.
    """

    def __init__(self, model, tokenizer, processor=None, device="cuda"):
        """Set up left padding and read the tiling settings from the model config.

        Args:
            model: The wrapped chat model; its ``config`` (if any) supplies
                ``force_image_size``, ``max_dynamic_patch``,
                ``min_dynamic_patch`` and ``use_thumbnail``.
            tokenizer: Tokenizer paired with ``model``.
            processor: Forwarded unchanged to ``BaseVLM``.
            device: Forwarded unchanged to ``BaseVLM``.
        """
        super().__init__(model=model, tokenizer=tokenizer, processor=processor, device=device)

        # Decoder-only model: force left padding.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if getattr(self.tokenizer, "pad_token_id", None) is None and getattr(self.tokenizer, "eos_token_id", None) is not None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # Missing or falsy config values fall back to 448 / 12 / 1.
        cfg = getattr(self.model, "config", None)
        self.image_size = int(getattr(cfg, "force_image_size", 448) or 448)
        self.max_num_tiles = int(getattr(cfg, "max_dynamic_patch", 12) or 12)
        self.min_num_tiles = int(getattr(cfg, "min_dynamic_patch", 1) or 1)
        self.use_thumbnail = bool(getattr(cfg, "use_thumbnail", True))

    def _model_dtype(self) -> torch.dtype:
        """Return the dtype of the model's first parameter (bfloat16 if it has none)."""
        try:
            return next(self.model.parameters()).dtype
        except StopIteration:
            return torch.bfloat16

    def _ensure_image_prefix(self, q: str) -> str:
        """Return ``q`` with ``"<image>\\n"`` prepended unless it already contains ``"<image>"``.

        ``None`` or an empty string is treated as an empty question.
        """
        q = q or ""
        return q if "<image>" in q else "<image>\n" + q

    def _build_pixels_and_patches(self, items: List[dict]) -> Tuple[torch.Tensor, List[int]]:
        """Load and tile the first image of every item.

        Args:
            items: Samples from which ``get_image_path`` extracts image paths.

        Returns:
            ``(pixel_values, num_patches_list)`` where ``pixel_values`` is the
            concatenation of all samples' tiles along dim 0, moved to
            ``self.device`` with the model dtype, and ``num_patches_list``
            holds the tile count of each sample.

        Raises:
            ValueError: If an item has no image path.
        """
        all_pixels: List[torch.Tensor] = []
        num_patches_list: List[int] = []

        for item in items:
            image_paths = _normalize_to_list(get_image_path(item))
            if not image_paths:
                raise ValueError("QTuneVLVLM 需要每条样本至少一张图片（image_path）")
            # Only the first image is used; callers are expected to pass single-image samples.
            img_path = image_paths[0]

            pixels = load_image(
                img_path,
                input_size=self.image_size,
                min_num=self.min_num_tiles,
                max_num=self.max_num_tiles,
                use_thumbnail=self.use_thumbnail,
            )  # [num_tiles, 3, H, W]
            all_pixels.append(pixels)
            num_patches_list.append(int(pixels.shape[0]))

        pixel_values = torch.cat(all_pixels, dim=0).to(self.device, dtype=self._model_dtype())
        return pixel_values, num_patches_list

    def _build_queries_and_eos(self, questions: List[str], num_patches_list: List[int]) -> Tuple[List[str], int, str]:
        """Render one prompt per question and derive the EOS token from the template.

        Args:
            questions: Question texts; ``"<image>"`` is prepended when missing.
            num_patches_list: Tile count per question, paired positionally
                with ``questions``.

        Returns:
            ``(queries, eos_token_id, template_sep)`` where ``template_sep`` is
            the stripped ``template.sep`` and ``eos_token_id`` is its token id.

        Raises:
            RuntimeError: If no EOS id could be derived (for example when
                ``questions`` is empty).
        """
        queries: List[str] = []

        # Prefer the model's own conv_template when it has one.
        base_template = getattr(self.model, "conv_template", None)
        if base_template is None:
            # Fallback: use get_conv_template from the model's sibling `conversation` module.
            base_pkg = self.model.__module__.rsplit(".", 1)[0]
            conv_mod = __import__(f"{base_pkg}.conversation", fromlist=["get_conv_template"])
            get_conv_template = getattr(conv_mod, "get_conv_template")
            base_template = get_conv_template(self.model.template)

        template_sep = None
        eos_token_id = None

        for q, num_patches in zip(questions, num_patches_list):
            q = self._ensure_image_prefix(q)

            # Work on a private copy so the shared template is never mutated.
            template = copy.deepcopy(base_template)
            template.system_message = getattr(self.model, "system_message", template.system_message)
            template.messages = []
            template.append_message(template.roles[0], q)
            template.append_message(template.roles[1], None)

            query = template.get_prompt()
            # Expand the first "<image>" into num_image_token context tokens per tile.
            image_tokens = IMG_START_TOKEN + (IMG_CONTEXT_TOKEN * self.model.num_image_token * int(num_patches)) + IMG_END_TOKEN
            query = query.replace("<image>", image_tokens, 1)

            queries.append(query)

            # Every copy shares the same separator, so resolve its id only once.
            if eos_token_id is None:
                template_sep = template.sep.strip()
                eos_token_id = self.tokenizer.convert_tokens_to_ids(template_sep)

        if eos_token_id is None or template_sep is None:
            raise RuntimeError("无法从 conv_template 推导 eos_token_id/template_sep")

        return queries, eos_token_id, template_sep

    def generate_batch(
        self,
        items: List[dict],
        max_new_tokens: int,
        gen_cfg: Dict,
        oom_estimate: bool,
        bs_estimate_gen_cfg: Dict
    ) -> Tuple[List[str], List[int], List[bool]]:
        """Generate answers for a batch of single-image samples.

        Side effects: sets ``tokenizer.padding_side``, ``model.img_context_token_id``
        and (best effort) the eos/pad ids on ``model.generation_config`` and
        ``model.config``; prints every decoded output to the console.

        Args:
            items: Samples; ``parse_input`` supplies the question text and
                ``get_image_path`` the image path.
            max_new_tokens: Generation length limit; overrides any value in ``gen_cfg``.
            gen_cfg: Extra keyword arguments for ``model.generate`` (may be ``None``).
                ``do_sample`` defaults to ``False``; ``pad_token_id`` and
                ``eos_token_id`` are always overwritten.
            oom_estimate: When true, the tokenized inputs are passed through
                ``apply_prefill_extra_tokens`` before generation (memory-estimation mode).
            bs_estimate_gen_cfg: Only read when ``oom_estimate`` is true; supplies
                ``_prefill_extra_tokens`` and ``_prefill_token_id``.

        Returns:
            ``(outputs, right_pad_lens, hit_limit_flags)``, one entry per item:
            the decoded text, the value from ``calculate_right_padding_length``,
            and whether the trimmed sequence reached ``max_new_tokens`` without
            ending in an EOS id.
        """
        batch_size = len(items)
        if batch_size == 0:
            return [], [], []

        # Local import (same style as the `utils` import below): lets arbitrary
        # text be printed through rich without being parsed as console markup.
        from rich.markup import escape

        # 1) questions
        questions = [self._ensure_image_prefix(parse_input(it)) for it in items]

        # 2) pixel_values + num_patches_list (dynamic tiling)
        pixel_values, num_patches_list = self._build_pixels_and_patches(items)

        # 3) prompts + eos (template separator)
        queries, eos_token_id, _template_sep = self._build_queries_and_eos(questions, num_patches_list)

        # 4) tokenize with left padding
        self.tokenizer.padding_side = "left"
        model_inputs = self.tokenizer(queries, return_tensors="pt", padding=True)
        input_ids = model_inputs["input_ids"].to(self.device)
        attention_mask = model_inputs["attention_mask"].to(self.device)

        # 5) img_context_token_id must be set; per the original author,
        #    model.generate asserts otherwise.
        self.model.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

        # 6) generation settings, passed as keyword arguments
        generation_config = dict(gen_cfg) if gen_cfg is not None else {}
        generation_config["max_new_tokens"] = max_new_tokens
        generation_config.setdefault("do_sample", False)

        # Align pad/eos: the template separator acts as EOS.
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "eos_token_id", None)
        generation_config["pad_token_id"] = pad_id
        generation_config["eos_token_id"] = eos_token_id

        # Mirror the ids onto the model so calculate_right_padding_length reads
        # the same eos/pad values that generation used.
        try:
            self.model.generation_config.eos_token_id = eos_token_id
            self.model.config.eos_token_id = eos_token_id
            self.model.generation_config.pad_token_id = pad_id
            self.model.config.pad_token_id = pad_id
        except Exception as exc:
            # Still best effort, but no longer silent: a failed sync means the
            # right-padding computation may use stale ids.
            console.print(
                "[yellow]QTuneVL: failed to sync eos/pad token ids to model config:",
                escape(repr(exc)),
            )

        # EOS ids used for the hit_limit decision.
        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = {x for x in (gc_eos + tok_eos + [eos_token_id]) if x is not None}

        # ---------------------- GPU memory estimation ----------------------
        if oom_estimate:
            from utils import apply_prefill_extra_tokens
            bs_estimate_gen_cfg = dict(bs_estimate_gen_cfg or {})
            prefill_extra_tokens = int(bs_estimate_gen_cfg.pop("_prefill_extra_tokens", 0) or 0)
            prefill_token_id = bs_estimate_gen_cfg.pop("_prefill_token_id", None)
            batch_model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
            batch_model_inputs = apply_prefill_extra_tokens(
                batch_size=len(items),
                inputs=batch_model_inputs,
                prefill_extra_tokens=prefill_extra_tokens,
                tokenizer=self.tokenizer,
                prefill_token_id=prefill_token_id,
            )
            input_ids, attention_mask = batch_model_inputs["input_ids"], batch_model_inputs["attention_mask"]
        # -------------------------------------------------------------------

        with torch.no_grad():
            out = self.model.generate(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                **generation_config,
            )

        # Also accept the return_dict_in_generate form.
        sequences = out.sequences if hasattr(out, "sequences") else out

        # 7) Compute right padding / hit_limit and decode. The returned rows are
        #    used as-is (nothing is sliced off the front) — presumably they hold
        #    only new tokens; TODO confirm against the model's generate().
        outputs: List[str] = []
        right_pad_lens: List[int] = []
        hit_limit_flags: List[bool] = []

        for i in range(batch_size):
            seq_new = sequences[i]

            cut = self.calculate_right_padding_length(seq_new)
            right_pad_lens.append(cut)

            # After trimming, at most one terminating token remains at the end.
            seq_trim = seq_new[:-cut] if cut > 0 else seq_new
            out_len = int(seq_trim.shape[0])

            last_id = int(seq_trim[-1].item()) if out_len > 0 else None
            ended_with_eos = last_id is not None and last_id in eos_ids

            hit_limit_flags.append(out_len >= max_new_tokens and (not ended_with_eos))

            # Drop a trailing template-separator EOS so the text carries no stop token.
            if last_id is not None and eos_token_id is not None and last_id == int(eos_token_id):
                seq_decode = seq_trim[:-1]
            else:
                seq_decode = seq_trim

            outputs.append(self.tokenizer.decode(
                seq_decode,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            ))

        for o in outputs:
            # Model text is arbitrary: escape it so brackets such as "[/INST]"
            # are printed literally instead of being parsed as rich markup.
            console.print("\n[yellow]output without prompt:", escape(o))

        return outputs, right_pad_lens, hit_limit_flags

    def generate_one(self, item, max_new_tokens, gen_cfg):
        """Run ``generate_batch`` on a single item and return its ``(output, right_pad_len, hit_limit)``."""
        outputs, right_pads, hit_limits = self.generate_batch([item], max_new_tokens, gen_cfg, oom_estimate=False, bs_estimate_gen_cfg={})
        return outputs[0], right_pads[0], hit_limits[0]

    def calculate_right_padding_length(self, total_sequence) -> int:
        """Return how many trailing tokens to cut so one terminator remains.

        The pad id comes from ``model.generation_config`` (falling back to the
        tokenizer); EOS ids are the union of the generation config's and the
        tokenizer's.

        Steps:
        1. Count the trailing run of tokens equal to the pad id.
        2. If pads were counted and the token before them is not an EOS id,
           return that count minus one (one trailing pad is kept — presumably
           for the case where pad and EOS share an id).
        3. Otherwise also count the run of EOS ids just before the padding and
           add all but one of them to the result.

        Args:
            total_sequence: Token ids as a tensor or a list.

        Returns:
            Number of tokens to remove from the right end (0 for an empty sequence).
        """
        if isinstance(total_sequence, torch.Tensor):
            total_sequence = total_sequence.tolist()
        right_pad_len = 0

        pad_id = getattr(self.model.generation_config, "pad_token_id", None)
        if pad_id is None:
            pad_id = getattr(self.tokenizer, "pad_token_id", None)

        gc_eos = _normalize_to_list(getattr(self.model.generation_config, "eos_token_id", None))
        tok_eos = _normalize_to_list(getattr(self.tokenizer, "eos_token_id", None))
        eos_ids = set(gc_eos + tok_eos)

        # Step 1: walk back over trailing pad tokens.
        j = len(total_sequence) - 1
        while j >= 0 and total_sequence[j] == pad_id:
            right_pad_len += 1
            j -= 1

        # Step 2: padding directly follows a non-EOS token -> keep one pad.
        if j >= 0 and total_sequence[j] not in eos_ids and right_pad_len > 0:
            return right_pad_len - 1

        # Step 3: collapse a run of EOS tokens down to a single one.
        eos_count = 0
        i = j
        while i >= 0 and total_sequence[i] in eos_ids:
            eos_count += 1
            i -= 1
        right_pad_len += max(0, eos_count - 1)

        return right_pad_len
