from flask import Flask, request
import pickle

import os
import torch.multiprocessing as mp
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

import torch
os.environ.setdefault('TOKENIZERS_PARALLELISM', 'false')

from PIL import Image
import base64
import io   
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
import inspect
import re
import os
import numpy as np
import math
from vllm import LLM, SamplingParams
from threading import Lock
import argparse

# Shared vLLM engine; stays None until one is created. All generate() calls
# on it are serialized through _llm_lock.
llm = None
_llm_lock = Lock()

# Default tensor-parallel size: the number of non-empty entries in
# CUDA_VISIBLE_DEVICES when that variable is set, otherwise the number of
# CUDA devices torch reports. Never less than 1.
_visible = os.environ.get('CUDA_VISIBLE_DEVICES')
if not _visible:
    _tp_default = max(1, torch.cuda.device_count())
else:
    _tp_default = max(1, sum(1 for _dev in _visible.split(',') if _dev.strip() != ''))

# Cache slot for the Hugging Face model loaded by _ensure_hf_model_loaded().
_hf_model = None

def _ensure_hf_model_loaded(ckpt):
    """Load the Hugging Face Qwen2.5-VL model once and cache it module-wide.

    The first call loads ``ckpt`` in bfloat16 on CUDA, preferring
    FlashAttention-2. If that load fails for any reason, the load is retried
    without forcing an attention implementation so the library default is
    used. Later calls return the cached model and ignore ``ckpt``.

    Args:
        ckpt: Checkpoint path or hub id passed to ``from_pretrained``.

    Returns:
        The cached ``Qwen2_5_VLForConditionalGeneration`` instance.

    Raises:
        Exception: whatever ``from_pretrained`` raises if the fallback load
            fails as well.
    """
    global _hf_model
    if _hf_model is not None:
        return _hf_model

    common_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "cuda"}
    try:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            ckpt,
            attn_implementation="flash_attention_2",
            **common_kwargs,
        )
    except Exception as exc:
        # Retrying with identical arguments cannot succeed; drop the explicit
        # flash-attention request so the fallback is a real fallback.
        print(f"flash_attention_2 load failed ({exc}); retrying with the default attention implementation")
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            ckpt,
            **common_kwargs,
        )

    if torch.cuda.is_available() and torch.cuda.device_count() == 1:
        model = model.to('cuda:0')

    # Publish to the cache only after a fully successful load.
    _hf_model = model
    return _hf_model

# Chat-template / tokenizer / image processor. Loaded from the fixed hub id
# below (not from the --ckpt argument), with the per-image pixel budget
# bounded by min_pixels and max_pixels.
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    min_pixels=3136,
    max_pixels=224 * 28 * 28,
)

def _to_pil(entry):
    """Convert a supported image representation into a PIL image.

    Accepted inputs:
        * ``PIL.Image.Image`` - returned unchanged (no mode conversion).
        * ``bytes`` / ``bytearray`` - encoded image data, or base64 text of it.
        * ``str`` - a ``data:image...;base64,`` URI, an existing file path, or
          a bare base64 string (tried in that order).
        * ``numpy.ndarray`` - uint8 array shaped HxW, HxWx1, HxWx3 or HxWx4.

    Everything except the pass-through case is converted to RGB.

    Args:
        entry: The image in one of the representations above.

    Returns:
        A ``PIL.Image.Image``.

    Raises:
        ValueError: if the input type is unsupported, the bytes / base64
            string cannot be decoded as an image, or a NumPy array has the
            wrong dtype or shape.
        Exception: decoding errors for data URIs and file paths propagate
            unchanged from ``base64`` / PIL.
    """
    if isinstance(entry, Image.Image):
        return entry

    if isinstance(entry, (bytes, bytearray)):
        raw = bytes(entry)
        try:
            return Image.open(io.BytesIO(raw)).convert("RGB")
        except Exception:
            # Not directly decodable; the payload may be base64 text instead.
            try:
                decoded = base64.b64decode(raw, validate=False)
                return Image.open(io.BytesIO(decoded)).convert("RGB")
            except Exception as exc:
                raise ValueError(
                    "Invalid image bytes provided (not PNG/JPEG or base64-encoded)."
                ) from exc

    if isinstance(entry, str):
        s = entry.strip()
        if s.startswith("data:image") and ";base64," in s:
            payload = s.split(",", 1)[1]
            return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")
        # NOTE(review): when `entry` comes from a network request this lets
        # the sender make the server open any local file path - confirm that
        # callers are trusted.
        if os.path.isfile(s):
            # Context manager closes the file handle once the RGB copy exists.
            with Image.open(s) as img:
                return img.convert("RGB")
        try:
            decoded = base64.b64decode(s, validate=False)
            return Image.open(io.BytesIO(decoded)).convert("RGB")
        except Exception as exc:
            raise ValueError(
                "String image must be a data URI, base64 string, or existing file path."
            ) from exc

    if isinstance(entry, np.ndarray):
        # Validation errors are raised directly so the caller sees the
        # specific reason instead of the generic "unsupported type" message.
        if entry.dtype != np.uint8:
            raise ValueError("NumPy image array must be uint8.")
        if entry.ndim == 3 and entry.shape[2] == 1:
            # Image.fromarray does not accept HxWx1; drop the channel axis.
            entry = entry[:, :, 0]
        elif not (entry.ndim == 2 or (entry.ndim == 3 and entry.shape[2] in (3, 4))):
            raise ValueError("NumPy image array must be HxW or HxWx[1|3|4].")
        return Image.fromarray(entry).convert("RGB")

    raise ValueError("Unsupported image type. Provide raw PNG/JPEG bytes, base64 string, file path, PIL Image, or uint8 numpy array.")


# Example of the prompt for vLLM.
# The literal is single-quoted because the text itself contains double quotes
# (a double-quoted literal here is a SyntaxError).
# NOTE: the JSON braces are not doubled, so fill the placeholder with
# .replace("{prompt}", ...) rather than str.format().
PAIRED_THINKING_TEMPLATE = 'User prompt: {prompt} Which image is better given the prompt? Analyze aesthetics, composition, prompt alignment and other factors. Provide your reasoning in <think>...</think> tags and the final JSON answer in <answer>{"preferred":"second"}</answer> or {"preferred":"first"}.'

def create_app(ckpt):
    """Create the Flask reward server for pairwise image preference scoring.

    Args:
        ckpt: Model checkpoint handed to vLLM if the shared module-level
            engine (``llm``) has not been created yet when a request arrives.

    Returns:
        A Flask app exposing a single ``POST /`` endpoint (``reward``).
    """
    app = Flask(__name__)

    @app.route('/', methods=['POST'])
    def reward():
        """Score the images in a pickled request and return pickled scores.

        The request body is a pickled dict with keys ``"images"`` and
        ``"prompts"`` and an optional ``"mode"`` (default ``"naive"``). Only
        ``"paired_thinking"`` is implemented; every other mode raises
        ``NotImplementedError``. The response is
        ``pickle.dumps({"outputs": {"scores": scores}})``.
        """
        # SECURITY(review): pickle.loads on a request body executes arbitrary
        # code chosen by the sender. Expose this endpoint to trusted clients
        # only, or migrate the wire format away from pickle.
        data = pickle.loads(request.get_data())
        images_bytes = data["images"]
        prompts = data["prompts"]
        mode = data.get("mode", "naive")
        num_anchors = 10

        if mode == "paired":
            raise NotImplementedError("Paired mode is not implemented for vLLM")
        elif mode == "paired_thinking":
            print("NUM ANCHORS IS ", num_anchors)
            scores = compute_paired_thinking_scores_vllm(images_bytes, prompts, num_anchors, ckpt)
            return pickle.dumps({"outputs": {"scores": scores}})
        elif mode == "paired_fake_thinking":
            raise NotImplementedError("Paired fake thinking mode is not implemented for vLLM")
        else:
            raise NotImplementedError("Naive mode is not implemented for vLLM")

    def compute_paired_thinking_scores_vllm(images_bytes, prompts, num_anchors=12, ckpt=None):
        """Score images by pairwise comparison: generate reasoning, then read log-probs.

        Each of the first ``min(num_anchors, n_images)`` images is compared,
        as the "first" image, against every other image. For each pair the
        model first generates its reasoning; the conversation is then cut
        right before the answer value and a one-token continuation yields
        P("second" preferred). The anchor receives ``1 - p``, the other image
        receives ``p``, and an image's final score is the mean over all
        comparisons it took part in.

        Args:
            images_bytes: Images in any representation ``_to_pil`` accepts;
                each is resized to 512x512.
            prompts: Sequence of prompts; only ``prompts[0]`` is used. A
                generic question is substituted when it is empty.
            num_anchors: Maximum number of leading images used as anchors.
            ckpt: Checkpoint used to create the vLLM engine if none exists.

        Returns:
            A list with one score per input image, in input order.

        Raises:
            ValueError: fewer than two images were supplied.
            RuntimeError: no pairs could be built, generation or log-prob
                extraction failed, or a scoring prompt could not be built.
        """
        import time

        global llm

        images = [_to_pil(b).resize((512, 512)) for b in images_bytes]

        # Doubled backslashes keep the exact runtime text of the original
        # prompt (which contains literal backslashes) without relying on
        # invalid escape sequences.
        SYSTEM_PROMPT = (
            "The user has two images and a textual prompt. You need to reason carefully and produce an answer with reasoning in \\{<think>\\}...\\{</think>\\} where you should choose the best image."
        )

        prompt = prompts[0] if prompts else "Which image is better?"

        n_images = len(images)
        if n_images < 2:
            raise ValueError("compute_paired_thinking_scores_vllm requires at least two images to compare.")

        actual_anchors = min(num_anchors, n_images)
        pairs = [
            (anchor_idx, compare_idx)
            for anchor_idx in range(actual_anchors)
            for compare_idx in range(n_images)
            if compare_idx != anchor_idx
        ]
        if not pairs:
            raise RuntimeError("No comparison pairs could be constructed from inputs.")

        start_time = time.time()
        print(f"Processing {len(pairs)} comparisons with vLLM")

        # Lazily create the shared engine. The check is repeated under the
        # lock so that concurrent first requests (the server runs threaded)
        # cannot both start loading a model.
        if llm is None:
            with _llm_lock:
                if llm is None:
                    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
                    if visible:
                        tp_default = max(1, len([d for d in visible.split(',') if d.strip() != '']))
                    else:
                        tp_default = max(1, torch.cuda.device_count())
                    llm = LLM(
                        model=ckpt,
                        dtype="bfloat16",
                        tensor_parallel_size=int(os.environ.get('TP_SIZE', tp_default)),
                        trust_remote_code=True,
                        gpu_memory_utilization=float(os.environ.get('VLLM_GPU_UTIL', '0.95')),
                    )

        # Ids of the first token of each possible answer value.
        tokenizer = processor.tokenizer
        first_ids_head = tokenizer.encode("first", add_special_tokens=False) or [tokenizer.convert_tokens_to_ids("first")]
        second_ids_head = tokenizer.encode("second", add_special_tokens=False) or [tokenizer.convert_tokens_to_ids("second")]
        first_head_id = first_ids_head[0]
        second_head_id = second_ids_head[0]

        # The text prompt is the same for every pair (only the attached
        # images differ), so the chat template is rendered once.
        user_content = [
            {"type": "image"},
            {"type": "image"},
            {
                "type": "text",
                "text": (
                    f"User prompt: {prompt}\n\n"
                    "Which image is better (first or second) given the prompt? Analyze aesthetics, composition, prompt alignment, and other relevant factors. "
                    "Provide your reasoning in <think>...</think> tags, "
                    f'and the final JSON answer in <answer>{{"preferred":"first"}}</answer> or <answer>{{"preferred":"second"}}</answer> format.'
                ),
            },
        ]
        conv = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]
        gen_prompt_text = processor.apply_chat_template(conv, add_generation_prompt=True, tokenize=False)
        gen_requests = [
            {
                "prompt": gen_prompt_text,
                "multi_modal_data": {"image": [images[anchor_idx], images[compare_idx]]},
            }
            for anchor_idx, compare_idx in pairs
        ]

        think_tokens = int(os.environ.get('THINK_TOKENS', '512'))
        max_batch = max(1, int(os.environ.get('VLLM_MAX_BATCH', '16')))

        def _generate_in_batches(requests, params):
            """Run llm.generate over `requests` in chunks of `max_batch`, one lock hold per chunk."""
            outputs = []
            for start in range(0, len(requests), max_batch):
                with _llm_lock:
                    outputs.extend(llm.generate(requests[start:start + max_batch], params))
            return outputs

        # Pass 1: let the model write its reasoning and answer.
        params_gen = SamplingParams(max_tokens=think_tokens, temperature=0.0, seed=42)
        try:
            gen_outputs = _generate_in_batches(gen_requests, params_gen)
        except Exception as e:
            raise RuntimeError(f"Batched generation failed: {e}") from e

        # Pass 2 inputs: the conversation cut right before the answer value.
        # Outputs are matched to pairs by position.
        scoring_requests = []
        for idx, (anchor_idx, compare_idx) in enumerate(pairs):
            try:
                generated_text = gen_outputs[idx].outputs[0].text if idx < len(gen_outputs) else ""
            except Exception as e:
                raise RuntimeError(f"Missing generated text for comparison index {idx}: {e}") from e
            if not generated_text:
                raise RuntimeError(f"Empty generated text for comparison index {idx}.")
            full_conversation_text = processor.apply_chat_template(
                conv + [{"role": "assistant", "content": generated_text}],
                add_generation_prompt=False,
                tokenize=False,
            )
            text_for_scoring = _build_scoring_prompt(full_conversation_text, generated_text)
            if not text_for_scoring:
                raise RuntimeError("Could not construct scoring prompt anchored at <answer> for vLLM path.")
            scoring_requests.append({
                "prompt": text_for_scoring,
                "multi_modal_data": {"image": [images[anchor_idx], images[compare_idx]]},
            })

        if not scoring_requests:
            raise RuntimeError("No scoring prompts could be constructed; cannot compute scores.")

        # Pass 2: one token with top-20 log-probs gives P(first) vs P(second).
        params_lp = SamplingParams(max_tokens=1, temperature=0.0, logprobs=20, top_p=1.0, seed=123)
        try:
            lp_outputs = _generate_in_batches(scoring_requests, params_lp)
            probs_list = [
                _p_second_from_logprobs(item.outputs[0].logprobs[0], first_head_id, second_head_id)
                for item in lp_outputs
            ]
        except Exception as e:
            raise RuntimeError(f"Batched logprob failed: {e}") from e

        image_scores = np.zeros(n_images)
        comparison_counts = np.zeros(n_images)
        for i, (anchor_idx, compare_idx) in enumerate(pairs):
            if i >= len(probs_list) or probs_list[i] is None:
                raise RuntimeError(f"Missing probability for comparison index {i}.")
            prob_second_better = probs_list[i]
            # The anchor is shown as "first", the compared image as "second".
            image_scores[anchor_idx] += (1.0 - prob_second_better)
            comparison_counts[anchor_idx] += 1
            image_scores[compare_idx] += prob_second_better
            comparison_counts[compare_idx] += 1

        final_scores = []
        for i in range(n_images):
            if comparison_counts[i] <= 0:
                raise RuntimeError(f"No valid comparisons computed for image index {i} (vLLM path).")
            final_scores.append(image_scores[i] / comparison_counts[i])

        end_time = time.time()
        print(f"vLLM processing completed in {end_time - start_time:.2f}s for {len(pairs)} comparisons")

        return final_scores

    return app


# Text the model is expected to emit right before the answer value, in the
# order the variants are tried.
_ANSWER_PREFIXES = (
    '</think>\n<answer>{"preferred": "',
    '</think>\n\n<answer>{"preferred": "',
)
_ANSWER_ANCHOR = "<answer>"
_ANSWER_FALLBACK_SUFFIX = '<answer>{"preferred": "'


def _build_scoring_prompt(conversation_text, generated_text):
    """Cut a rendered conversation right before the answer value.

    Args:
        conversation_text: The chat-template rendering of the conversation
            including the assistant reply.
        generated_text: The assistant reply exactly as it was generated.

    Returns:
        The prompt ending in ``{"preferred": "`` so that the next token is
        the answer value, or ``None`` if no such prompt can be built.
    """
    # Preferred case: the reply closes its reasoning and opens the answer in
    # one of the expected formats; keep everything up to and including that.
    for prefix in _ANSWER_PREFIXES:
        head, found, _ = conversation_text.partition(prefix)
        if found:
            return head + prefix

    # Fallback: anchor on the last "<answer>" tag, but only inside the
    # assistant reply. The user instruction itself contains "<answer>", so an
    # unrestricted search would cut the prompt inside the user turn whenever
    # the reply has no answer tag (for example when it hit the token limit).
    reply_start = conversation_text.rfind(generated_text)
    if reply_start == -1:
        return None
    anchor_at = conversation_text.rfind(_ANSWER_ANCHOR, reply_start)
    if anchor_at != -1:
        return conversation_text[:anchor_at] + _ANSWER_FALLBACK_SUFFIX

    # The reply never reached an answer: keep all of its reasoning and force
    # the answer open right after it (closing the think block if still open).
    reply_end = reply_start + len(generated_text)
    closing = _ANSWER_FALLBACK_SUFFIX if "</think>" in generated_text else _ANSWER_PREFIXES[0]
    return conversation_text[:reply_end] + closing


def _p_second_from_logprobs(lp_dict, first_id, second_id):
    """Return P(second) / (P(first) + P(second)) from one position's log-probs.

    Args:
        lp_dict: Mapping for a single generated position. Entries may be
            keyed by integer token id with either a numeric log-prob or an
            object exposing ``logprob`` as value, or carry objects exposing
            both ``id`` and ``logprob``.
        first_id: Token id that stands for the answer "first".
        second_id: Token id that stands for the answer "second".

    Raises:
        RuntimeError: neither token is present with non-zero probability.
    """
    p_first = 0.0
    p_second = 0.0
    for key, value in lp_dict.items():
        if hasattr(value, 'id') and hasattr(value, 'logprob'):
            tok_id = value.id
            lp = value.logprob
        elif isinstance(key, int):
            tok_id = key
            lp = value if isinstance(value, (int, float)) else getattr(value, 'logprob', None)
        else:
            tok_id = getattr(value, 'id', None)
            lp = getattr(value, 'logprob', None)
        if tok_id is None or lp is None:
            continue
        if tok_id == first_id:
            p_first = math.exp(lp)
        elif tok_id == second_id:
            p_second = math.exp(lp)
    if p_first <= 0 and p_second <= 0:
        raise RuntimeError("Logprobs did not contain tokens for 'first' or 'second'.")
    # A token missing from the returned candidates gets a tiny floor so the
    # ratio stays defined.
    eps = 1e-12
    p_first = p_first if p_first > 0 else eps
    p_second = p_second if p_second > 0 else eps
    return p_second / (p_first + p_second)

if __name__ == '__main__':
    # Script entry point: parse CLI options, create the shared vLLM engine
    # up front, then serve the Flask app.
    parser = argparse.ArgumentParser(description="Run the vLLM reward server.")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run on.")
    parser.add_argument("--port", type=int, default=5000, help="Port to run on.")
    parser.add_argument("--ckpt", type=str, required=True, help="Path to model checkpoint.")
    args = parser.parse_args()

    # This code already runs at module scope, so the assignment below rebinds
    # the module-level `llm` directly. A `global llm` statement here is a
    # SyntaxError because `llm` is assigned earlier in the module.
    llm = LLM(
        model=args.ckpt,
        dtype="bfloat16",
        # Defaults to a single GPU; TP_SIZE overrides it, the same variable
        # the lazy engine creation inside create_app() reads.
        tensor_parallel_size=int(os.environ.get('TP_SIZE', '1')),
        trust_remote_code=True,
        gpu_memory_utilization=float(os.environ.get('VLLM_GPU_UTIL', '0.95')),
    )

    app = create_app(args.ckpt)
    # threaded=True lets Flask accept concurrent requests; the reloader is
    # disabled so the process (and the loaded model) is not started twice.
    app.run(host=args.host, port=args.port, debug=False, threaded=True, use_reloader=False)