import types

import torch

# Compatibility shim: when the installed torch lacks `torch.compiler` or
# `torch.compiler.is_compiling`, install stand-ins so code that calls
# `torch.compiler.is_compiling()` gets False instead of an AttributeError.
# `types` has to be imported here; without it the first branch raised
# NameError on exactly the torch builds the shim is meant to patch.
if not hasattr(torch, "compiler"):
    torch.compiler = types.SimpleNamespace()
if not hasattr(torch.compiler, "is_compiling"):
    torch.compiler.is_compiling = lambda: False
from transformers import AutoProcessor, Glm4vForConditionalGeneration
import cv2
import os

# Model identifier string for a GLM-4.1V checkpoint.
# NOTE(review): nothing in the visible part of this file reads MODEL_PATH,
# `model`, or the module-level `processor` — confirm whether they are still needed.
MODEL_PATH = "THUDM/GLM-4.1V-9B-Thinking"
model = None
processor = None


from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

# The default range for the number of visual tokens per image in the model is 4-16384.
# You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost.
# min_pixels = 256*28*28
# max_pixels = 1280*28*28
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-72B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)

from cosmos_reason1_utils.script import init_script

# Runs at import time, before the vllm / transformers imports that follow.
# NOTE(review): presumably this has to execute first to prepare the process
# for those libraries — confirm in cosmos_reason1_utils.script before
# moving it or hoisting the later imports above it.
init_script()

import argparse
import collections
import pathlib
import textwrap

import qwen_vl_utils
import transformers
import vllm
import yaml
from rich import print
from rich.pretty import pprint

from cosmos_reason1_utils.text import (
    PromptConfig,
    create_conversation,
    extract_tagged_text,
)
from cosmos_reason1_utils.vision import (
    VisionConfig,
    overlay_text_on_tensor,
    save_tensor,
)

# Directory prefix under which prompt and config files are looked up.
ROOT = "cosmos-reason1"
# Horizontal rule printed between sections of console output.
SEPARATOR = 20 * "-"

def pprint_dict(d: dict, name: str):
    """Pretty print a dictionary under the heading *name*.

    The mapping is converted to a one-off named tuple whose type name is
    *name* and whose fields are the keys of *d*, so the printed output is
    labelled with *name*. The keys are used as named-tuple field names and
    as keyword arguments, so they must be valid for both.
    """
    record_type = collections.namedtuple(name, d.keys())
    record = record_type(**d)
    pprint(record, expand_all=True)

# Lazily-built inference bundle, filled in by init_glm():
# (llm, processor, system_prompt, vision_kwargs, sampling_params).
cosmos_r = None


def _load_yaml_file(path):
    """Parse the YAML file at *path* and close the handle afterwards."""
    with open(path, "rb") as fh:
        return yaml.safe_load(fh)


def _read_text_file(path):
    """Return the whole text content of *path* and close the handle afterwards."""
    with open(path) as fh:
        return fh.read()


def init_glm(*, verbose=True, reasoning=True):
    """Load the Cosmos-Reason1 model and its configs into the global ``cosmos_r``.

    Despite the name, this loads ``nvidia/Cosmos-Reason1-7B`` through vLLM
    together with its Hugging Face processor. Prompt, vision, and sampling
    configs are read from YAML files under ``ROOT``, and a system prompt is
    assembled from the add-on text files.

    Args:
        verbose: When true, pretty-print the vision and sampling configs.
        reasoning: When true, append the reasoning add-on to the system
            prompt unless the configured prompt already contains ``<think>``.

    Raises:
        ValueError: If the reasoning add-on is requested but the configured
            system prompt already contains tagged output-format text.

    Side effects:
        Sets the module-level ``cosmos_r`` to the tuple
        ``(llm, processor, system_prompt, vision_kwargs, sampling_params)``
        and prints the assembled system prompt.
    """
    global cosmos_r

    prompt_path = f"{ROOT}/prompts/question.yaml"
    vision_config_path = f"{ROOT}/configs/vision_config.yaml"
    sampling_params_path = f"{ROOT}/configs/sampling_params.yaml"

    # Load configs
    prompt_config = PromptConfig.model_validate(_load_yaml_file(prompt_path))
    vision_kwargs = _load_yaml_file(vision_config_path)
    # Validation only: the raw dict is what gets stored and passed on later.
    # NOTE(review): presumably model_validate raises on a malformed config.
    VisionConfig.model_validate(vision_kwargs)
    sampling_kwargs = _load_yaml_file(sampling_params_path)
    sampling_params = vllm.SamplingParams(**sampling_kwargs)
    if verbose:
        pprint_dict(vision_kwargs, "VisionConfig")
        pprint_dict(sampling_kwargs, "SamplingParams")

    # Create the system prompt. Treat a missing prompt as empty so the
    # "<think>" membership test below cannot fail on None.
    configured_prompt = prompt_config.system_prompt or ""
    system_prompts = [_read_text_file(f"{ROOT}/prompts/addons/english.txt")]
    if configured_prompt:
        system_prompts.append(configured_prompt)
    if reasoning and "<think>" not in configured_prompt:
        if extract_tagged_text(configured_prompt)[0]:
            raise ValueError(
                "Prompt already contains output format. Cannot add reasoning."
            )
        system_prompts.append(_read_text_file(f"{ROOT}/prompts/addons/reasoning.txt"))
    system_prompt = "\n\n".join(map(str.rstrip, system_prompts))

    print(SEPARATOR)
    print("System:")
    print(textwrap.indent(system_prompt.rstrip(), "  "))
    print(SEPARATOR)

    # Create model
    llm = vllm.LLM(
        model="nvidia/Cosmos-Reason1-7B",
        limit_mm_per_prompt={"image": 120, "video": 1},
        enforce_eager=True,
    )

    # Processor used later to render the chat template
    processor: transformers.Qwen2_5_VLProcessor = (
        transformers.AutoProcessor.from_pretrained("nvidia/Cosmos-Reason1-7B")
    )

    cosmos_r = llm, processor, system_prompt, vision_kwargs, sampling_params


import re
def extract_answer(text: str):
    """Return the verdict found inside the first ``<answer>`` tag of *text*.

    The tag match is case-insensitive and may span several lines. Any markup
    nested inside the answer is removed before searching for the verdict.

    Returns:
        ``"natural"`` or ``"unnatural"`` (always lower-case) for the first
        such whole word in the answer, or ``None`` when there is no
        ``<answer>...</answer>`` pair or it contains neither word.
    """
    answer_match = re.search(
        r"<answer>(.*?)</answer>", text, flags=re.DOTALL | re.IGNORECASE
    )
    if answer_match is None:
        return None

    # Strip nested tags so only the plain answer text is inspected.
    plain_answer = re.sub(r"<[^>]+>", "", answer_match.group(1))
    verdict_match = re.search(
        r"\b(natural|unnatural)\b", plain_answer, flags=re.IGNORECASE
    )
    if verdict_match is None:
        return None
    return verdict_match.group(1).lower()


def _save_glm_frame(gen_video_path, fid, out_dir, box=None):
    """Copy frame *fid* into *out_dir*, optionally drawing a green box on it.

    Returns the path that was written, so callers reuse the exact file name
    (including its real extension) instead of rebuilding it.
    """
    src_path = gen_video_path[fid]
    frame = cv2.imread(src_path)
    if frame is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"Could not read frame {fid} from {src_path!r}")
    if box is not None:
        x1, y1, x2, y2 = box
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
    # splitext keeps extensions of any length (".jpeg"), unlike a [-4:] slice.
    suffix = os.path.splitext(src_path)[1]
    dst_path = os.path.join(out_dir, f"{fid:05d}{suffix}")
    cv2.imwrite(dst_path, frame)
    return dst_path


def glm_reasoning(gen_video_path, first_fids, first_box, last_fids, last_box, next_fids, glm_dir):
    """Ask the model whether a boxed object's disappearance looks natural.

    Six frames are written to *glm_dir* (named by zero-padded frame id); the
    "first" and "last" frames get a green box drawn on them. Three of them —
    ``first_fids[1]``, ``last_fids[-1]`` and ``next_fids[0]`` — are sent to
    the model together with a fixed question. The model bundle is created on
    first use via ``init_glm()``.

    Args:
        gen_video_path: Indexable by frame id, yielding the image file path
            of that frame.
        first_fids: Frame ids; entries 0 and 1 are used.
        first_box: Boxes ``(x1, y1, x2, y2)`` for ``first_fids[0]`` and
            ``first_fids[1]``.
        last_fids: Frame ids; entries 0 and -1 are used.
        last_box: Boxes for ``last_fids[0]`` (entry 0) and ``last_fids[-1]``
            (entry 1).
        next_fids: Frame ids; entries 0 and 1 are used, written without a box.
        glm_dir: Output directory for the frames; created if missing.

    Returns:
        ``False`` when the model's first ``<answer>`` contains "unnatural"
        (any letter case), ``True`` otherwise.

    Raises:
        OSError: If a frame image cannot be read.
        ValueError: If the model returns no completion or no answer tag.
    """
    os.makedirs(glm_dir, exist_ok=True)

    # Write order is kept as-is: if two frame ids coincide, the later write
    # overwrites the earlier file.
    _save_glm_frame(gen_video_path, first_fids[0], glm_dir, first_box[0])
    first_path = _save_glm_frame(gen_video_path, first_fids[1], glm_dir, first_box[1])
    _save_glm_frame(gen_video_path, last_fids[0], glm_dir, last_box[0])
    last_path = _save_glm_frame(gen_video_path, last_fids[-1], glm_dir, last_box[1])
    next_path = _save_glm_frame(gen_video_path, next_fids[0], glm_dir)
    _save_glm_frame(gen_video_path, next_fids[1], glm_dir)

    # Build the model bundle lazily on the first call.
    if cosmos_r is None:
        init_glm()
    llm, processor, system_prompt, vision_kwargs, sampling_params = cosmos_r

    question = '''Given frames around the moment the same green-boxed object disappears, classify the disappearance as Natural (e.g., occlusion, leaving the field of view) or Unnatural (e.g., abrupt/non-physical disappearance). \
        Base your decision on visual continuity, motion continuity, and the object’s interactions with surrounding vehicles and the environment.'''

    # Use the paths actually written, so each image keeps its own extension.
    images = [first_path, last_path, next_path]

    conversation = create_conversation(
        system_prompt=system_prompt,
        user_prompt=question.rstrip(),
        images=images,
        videos=[],
        vision_kwargs=vision_kwargs,
    )

    prompt = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs, video_kwargs = qwen_vl_utils.process_vision_info(
        conversation, return_video_kwargs=True
    )

    # Run inference
    mm_data = {}
    if image_inputs is not None:
        mm_data["image"] = image_inputs
    if video_inputs is not None:
        mm_data["video"] = video_inputs
    llm_inputs = {
        "prompt": prompt,
        "multi_modal_data": mm_data,
        "mm_processor_kwargs": video_kwargs,
    }
    outputs = llm.generate([llm_inputs], sampling_params=sampling_params)

    # Print every completion; the last one is the one that gets parsed.
    output_text = None
    print(SEPARATOR)
    for output in outputs[0].outputs:
        output_text = output.text
        print("Assistant:")
        print(textwrap.indent(output_text.rstrip(), "  "))
    print(SEPARATOR)
    if output_text is None:
        raise ValueError("Model returned no completions.")

    result, _ = extract_tagged_text(output_text)
    if result:
        pprint_dict(result, "Result")

    try:
        answer_text = result["answer"][0]
    except (KeyError, IndexError) as err:
        raise ValueError(
            "Model output contains no <answer> tag; cannot classify."
        ) from err

    # Case-insensitive so a lower-case "unnatural" is not read as natural.
    return "unnatural" not in answer_text.lower()