from flask import Flask, request
import pickle
import os
from threading import Lock

# Set before `transformers` is imported further down in this file.
# NOTE(review): presumably this turns off Hugging Face tokenizers' internal
# parallelism (and its fork warning) — confirm before moving or removing.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
import torch.nn as nn
from PIL import Image
import io
import numpy as np
import random
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
import argparse


# Process-wide inference state. MODEL and PROCESSOR stay None until main()
# loads them; the scoring functions read them as globals.
MODEL = None
PROCESSOR = None
# Inference device: "cuda" when torch reports CUDA as available, else "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Held by the request handler around every scoring call so that only one
# request uses MODEL at a time (the Flask server is started with threaded=True).
MODEL_LOCK = Lock()


class BinaryClassificationHead(torch.nn.Module):
    """
    Small MLP that maps a feature vector to two class logits.

    Copied from qwen_pref_sft.py. The position of each layer inside
    ``classifier`` determines the parameter names in the state dict
    (``classifier.1.*`` and ``classifier.4.*``), so the layer order must not
    change if previously saved heads are to keep loading.

    Args:
        hidden_size: Width of the incoming features and of the hidden layer.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Dropout -> Linear -> GELU -> Dropout -> Linear(2), in that order.
        layers = [
            torch.nn.Dropout(p=0.1),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.GELU(),
            torch.nn.Dropout(p=0.1),
            torch.nn.Linear(hidden_size, 2),
        ]
        self.classifier = torch.nn.Sequential(*layers)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Return the two logits produced by ``classifier`` for ``features``."""
        logits = self.classifier(features)
        return logits

def _to_pil(entry):
    """
    Coerce ``entry`` into a PIL image.

    A PIL image is handed back untouched. ``bytes`` / ``bytearray`` payloads
    are opened with PIL and converted to RGB.

    Raises:
        ValueError: If ``entry`` is neither a PIL image nor a bytes-like
            payload of the two accepted types.
    """
    if isinstance(entry, Image.Image):
        return entry
    if not isinstance(entry, (bytes, bytearray)):
        raise ValueError("Unsupported image type. Provide raw PNG/JPEG bytes.")

    buffer = io.BytesIO(bytes(entry))
    decoded = Image.open(buffer)
    return decoded.convert("RGB")



def compute_binary_scores(images_bytes, prompts):
    """
    Computes reward scores for single images using a binary classification head.

    Each image is paired with the prompt at the same index and sent through
    MODEL together with a fixed yes/no question. The last layer's hidden state
    at the final sequence position is fed to ``MODEL.binary_head``; the score
    is the raw logit at index 1 of the head's output (no softmax is applied).

    Args:
        images_bytes: Sequence of images, each either a PIL image or encoded
            image bytes (anything ``_to_pil`` accepts).
        prompts: Sequence of prompt strings, one per image. Prompts beyond the
            number of images are ignored.

    Returns:
        A list with one float per image, in input order.

    Raises:
        RuntimeError: If MODEL has no ``binary_head`` attribute.
        ValueError: If there are fewer prompts than images, or an image entry
            has an unsupported type.
    """
    if not hasattr(MODEL, "binary_head"):
        raise RuntimeError("Model does not have a binary_head, but 'binary' mode was requested.")

    images = [_to_pil(b) for b in images_bytes]

    # Fail early with a clear message: without this check a missing prompt
    # would leave an empty batch to be handed to the processor.
    if len(prompts) < len(images):
        raise ValueError(
            f"Expected at least one prompt per image, got {len(prompts)} prompts for {len(images)} images."
        )

    question = "Is this image the winner? Reply with 'yes' or 'no'."
    scores = []

    # Samples are processed one at a time. The feature is read from the last
    # sequence position, so batching would need padding on the left.
    # NOTE(review): confirm the processor's padding side before batching.
    for img, p in zip(images, prompts):
        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": p},
                {"type": "image", "image": img},
                {"type": "text", "text": question},
            ],
        }
        text = PROCESSOR.apply_chat_template([message], tokenize=False, add_generation_prompt=True)

        inputs = PROCESSOR(
            text=[text],
            images=[[img]],
            padding=True,
            return_tensors="pt",
        ).to(DEVICE)

        with torch.no_grad():
            outputs = MODEL(output_hidden_states=True, **inputs)
            # Last layer, last position -> one feature vector per sample.
            features = outputs.hidden_states[-1][:, -1, :]
            cls_logits = MODEL.binary_head(features)

        scores.extend(cls_logits[:, 1].tolist())

    return scores


def _pairwise_answer_probs(prompt, img_a, img_b, answer_ids):
    """
    Run one ordered comparison (``img_a`` shown first, ``img_b`` second).

    Returns a ``(prob_first, prob_second)`` tuple: the softmax over just the
    two next-token logits selected by ``answer_ids``.
    """
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image", "image": img_a},
                {"type": "image", "image": img_b},
                {"type": "text", "text": "Which image matches the prompt better? Reply with 'first' or 'second'."},
            ],
        }
    ]
    text = PROCESSOR.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    inputs = PROCESSOR(
        text=[text],
        images=[[img_a, img_b]],
        padding=True,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        outputs = MODEL(**inputs)
        # Logits for the token that would follow the generation prompt.
        next_token_logits = outputs.logits[:, -1, :]
        pair_logits = next_token_logits[:, list(answer_ids)]
        pair_probs = torch.softmax(pair_logits, dim=-1)

    return pair_probs[0, 0].item(), pair_probs[0, 1].item()


def compute_pairwise_scores(images_bytes, prompts, num_anchors):
    """
    Computes reward scores using pairwise comparison against anchors.

    The first ``min(num_anchors, n_images)`` images act as anchors. Every image
    is compared with every anchor other than itself, in both presentation
    orders. For each comparison the image shown first is credited with the
    probability of the first answer token and the image shown second with the
    probability of the second; an image's score is the mean of its credits.

    Args:
        images_bytes: Sequence of images, each either a PIL image or encoded
            image bytes (anything ``_to_pil`` accepts).
        prompts: Sequence of prompts. Only ``prompts[0]`` is used; when the
            sequence is empty a generic question is used instead.
        num_anchors: Maximum number of leading images to use as anchors.

    Returns:
        A list with one score per image, in input order. With fewer than two
        images every score is 0.5; an image that took part in no comparison
        (e.g. ``num_anchors <= 0``) also gets 0.5. Computed scores are NumPy
        float scalars.
    """
    images = [_to_pil(b) for b in images_bytes]
    prompt = prompts[0] if prompts else "Which image is better?"
    n_images = len(images)

    if n_images < 2:
        return [0.5] * n_images

    # Vocabulary ids whose logits are compared.
    # NOTE(review): presumably the tokenizer's ids for 'first' and 'second' —
    # verify against the tokenizer actually loaded.
    answer_ids = (3896, 5569)

    image_scores = np.zeros(n_images)
    comparison_counts = np.zeros(n_images)

    anchor_indices = range(min(num_anchors, n_images))

    # When both images of a pair are anchors, the same ordered pair is
    # scheduled twice (once from each image's turn in the outer loop). The
    # forward pass is run once and its result reused; the pair is still
    # counted every time it is scheduled, so the weighting of the final mean
    # is unchanged.
    # NOTE(review): this assumes a repeated forward pass on identical inputs
    # yields identical probabilities — confirm if bit-exact parity matters.
    probs_by_pair = {}

    for i in range(n_images):
        for anchor_idx in anchor_indices:
            if i == anchor_idx:
                continue

            for first_idx, second_idx in ((i, anchor_idx), (anchor_idx, i)):
                key = (first_idx, second_idx)
                if key not in probs_by_pair:
                    probs_by_pair[key] = _pairwise_answer_probs(
                        prompt, images[first_idx], images[second_idx], answer_ids
                    )
                prob_first, prob_second = probs_by_pair[key]

                image_scores[first_idx] += prob_first
                comparison_counts[first_idx] += 1

                image_scores[second_idx] += prob_second
                comparison_counts[second_idx] += 1

    final_scores = [
        (image_scores[i] / comparison_counts[i]) if comparison_counts[i] > 0 else 0.5
        for i in range(n_images)
    ]

    return final_scores



def create_app():
    """
    Build the Flask application exposing the reward endpoint.

    ``POST /`` expects a pickled dict in the request body with:
        * ``"images"``: images to score (passed to the scoring functions).
        * ``"prompts"``: prompts for those images.
        * ``"mode"`` (optional): ``"binary"`` or ``"pairwise"``. Defaults to
          ``"pairwise"``.
        * ``"num_anchors"`` (optional, pairwise only): defaults to 10.

    On success the response body is a pickled
    ``{"outputs": {"scores": [...]}}``. A body that cannot be decoded, lacks
    the required keys, or names an unknown mode yields a pickled
    ``{"error": ...}`` with HTTP status 400.

    Returns:
        The configured Flask app.
    """
    app = Flask(__name__)

    @app.route('/', methods=['POST'])
    def reward():
        # SECURITY: unpickling a request body lets any client that can reach
        # this endpoint execute arbitrary code in this process. Expose the
        # server only to trusted callers.
        # NOTE(review): consider a non-executable wire format.
        try:
            data = pickle.loads(request.get_data())
        except Exception as exc:  # unpickling can raise many exception types
            return pickle.dumps({"error": f"Could not decode request body: {exc}"}), 400

        if not isinstance(data, dict) or "images" not in data or "prompts" not in data:
            return pickle.dumps({"error": "Request must be a pickled dict with 'images' and 'prompts'."}), 400

        images_bytes = data["images"]
        prompts = data["prompts"]

        # Honor the mode the caller asked for. "pairwise" is the fallback so
        # that requests which never set a mode keep getting pairwise scores.
        mode = data.get("mode", "pairwise")
        if mode not in ("binary", "pairwise"):
            # Rejected before taking the lock so bad requests never wait on it.
            return pickle.dumps({"error": f"Invalid mode: {mode}"}), 400

        # Only one request may use the model at a time.
        with MODEL_LOCK:
            if mode == "binary":
                scores = compute_binary_scores(images_bytes, prompts)
            else:
                num_anchors = data.get("num_anchors", 10)
                scores = compute_pairwise_scores(images_bytes, prompts, num_anchors)

        return pickle.dumps({"outputs": {"scores": scores}})

    return app

def main():
    """
    Command-line entry point: load the model and serve the reward endpoint.

    Parses the CLI arguments, loads the processor and the model into the
    module-level ``PROCESSOR`` / ``MODEL`` globals, attaches a
    ``BinaryClassificationHead`` as ``MODEL.binary_head`` when
    ``binary_head.pt`` exists inside ``--model_path``, moves the model to
    ``DEVICE`` in eval mode, and starts the Flask server.

    CLI arguments:
        --model_path: Path to the finetuned model (required).
        --processor_path: Name or path of the processor to load. Defaults to
            "Qwen/Qwen2.5-VL-7B-Instruct".
        --host: Interface to bind. Defaults to '0.0.0.0'.
        --port: Port to bind. Defaults to 5001.

    Raises:
        ValueError: If a binary head is present but the model config exposes
            no ``hidden_size`` (directly or under ``text_config``).
    """
    global MODEL, PROCESSOR

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path to the finetuned Qwen-VL preference model.")
    parser.add_argument(
        "--processor_path",
        type=str,
        default="Qwen/Qwen2.5-VL-7B-Instruct",
        help="Name or path of the processor to load.",
    )
    parser.add_argument("--host", type=str, default='0.0.0.0')
    parser.add_argument("--port", type=int, default=5001)
    args = parser.parse_args()

    print("Loading processor...")
    PROCESSOR = AutoProcessor.from_pretrained(
        args.processor_path,
        trust_remote_code=True,
        max_pixels=224*28*28
    )

    print(f"Loading model from {args.model_path}...")
    MODEL = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    # The binary head is optional: it is attached only when its weights were
    # saved next to the model.
    binary_head_path = os.path.join(args.model_path, "binary_head.pt")
    if os.path.exists(binary_head_path):
        print("Found binary_head.pt, loading it.")
        hidden_size = getattr(MODEL.config, "hidden_size", None)
        if hidden_size is None and hasattr(MODEL.config, "text_config"):
            hidden_size = getattr(MODEL.config.text_config, "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Could not determine hidden size for binary head.")

        MODEL.binary_head = BinaryClassificationHead(hidden_size)
        # The file is consumed as a state dict, so restrict unpickling to
        # tensor data instead of allowing arbitrary pickled objects.
        head_state = torch.load(binary_head_path, map_location="cpu", weights_only=True)
        MODEL.binary_head.load_state_dict(head_state)
        # Match the head's dtype to the model's.
        MODEL.binary_head.to(MODEL.dtype)
        print("Binary head loaded.")

    # The head (if any) is an attribute of MODEL, so it moves with it.
    MODEL.to(DEVICE)
    MODEL.eval()

    print(f"Model loaded on device: {DEVICE}")

    app = create_app()
    app.run(host=args.host, port=args.port, debug=False, threaded=True)

# Start the server only when this file is executed as a script.
if __name__ == '__main__':
    main()
