import os
import json
import torch
from tqdm import tqdm
import sys
import re
import base64
import cv2
from PIL import Image
from io import BytesIO
import random
import numpy as np
import pandas as pd
import datetime
import warnings
import logging
# Use generic auto classes for model loading (no unsloth)
from transformers import AutoProcessor, AutoModelForVision2Seq
from qwen_vl_utils import process_vision_info
import argparse

# Set up logging and warnings
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Command-line interface: which slice of the dataset to process, where to write
# results, how many sampled generations to run per persona, and an optional
# similarity index used to pick few-shot examples.
parser = argparse.ArgumentParser()
parser.add_argument("--start", type=int, default=0, help="Starting sample index (inclusive)")
parser.add_argument("--end", type=int, default=None, help="Ending sample index (inclusive). If not provided, process till the last sample")
parser.add_argument("--output_dir", type=str, default=".", help="Directory to save result JSONs")
parser.add_argument("--num_runs", type=int, default=1, help="Number of stochastic inference runs per persona (set to 1 for deterministic single-run behaviour)")
parser.add_argument("--similarity_json", type=str, default=None, help="Path to similarity JSON file (optional)")
args = parser.parse_args()

# An explicit --similarity_json wins; otherwise look for the default file next to this script.
if args.similarity_json is not None:
    SIM_JSON_PATH = args.similarity_json
else:
    SIM_JSON_PATH = os.path.join(os.path.dirname(__file__), "similaritywebaes.json")

# SIMILARITY_DATA stays an empty dict whenever the file is missing or unusable,
# so downstream lookups can always call .get() on it.
SIMILARITY_DATA = {}
if os.path.exists(SIM_JSON_PATH):
    try:
        with open(SIM_JSON_PATH, "r") as f:
            SIMILARITY_DATA = json.load(f)
    except Exception as exc:
        # Best-effort load: keep running without similarity data, but say why.
        logger.warning("Could not load similarity data from %s: %s", SIM_JSON_PATH, exc)
        SIMILARITY_DATA = {}
    if not isinstance(SIMILARITY_DATA, dict):
        # Lookups are keyed by sample index (as a string); anything else is unusable.
        logger.warning(
            "Similarity data in %s is a %s, expected a JSON object; ignoring it",
            SIM_JSON_PATH,
            type(SIMILARITY_DATA).__name__,
        )
        SIMILARITY_DATA = {}
elif args.similarity_json is not None:
    # The user asked for a specific file, so its absence deserves a warning.
    logger.warning("Similarity JSON %s does not exist; continuing without it", SIM_JSON_PATH)
else:
    logger.info("No similarity JSON found at %s; continuing without it", SIM_JSON_PATH)

def get_model_and_processor(model_dir="Qwen/Qwen2.5-VL-72B-Instruct"):
    """
    Load the Qwen model and processor using generic Auto classes (no unsloth).
    Mirrors the approach in multillm_webaes.py.

    Args:
        model_dir: Hub id or local directory passed to ``from_pretrained``.

    Returns:
        A ``(model, processor)`` tuple. The model is loaded with 4-bit
        quantization and ``device_map="auto"``.
    """
    # Imported here because only this function needs it; transformers is
    # already a dependency of this file.
    from transformers import BitsAndBytesConfig

    logger.info(f"Loading model and processor from {model_dir}")

    # Processor handles both vision and text modalities
    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)

    # Passing load_in_4bit directly to from_pretrained is deprecated; the
    # quantization settings belong in an explicit BitsAndBytesConfig.
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)

    # Load Vision-Language model using AutoModelForVision2Seq (official class)
    model = AutoModelForVision2Seq.from_pretrained(
        model_dir,
        torch_dtype="auto",
        device_map="auto",
        quantization_config=quantization_config,
        trust_remote_code=True,
    )

    return model, processor

def inference_batch(image_urls, prompts, sys_prompts, model, processor, max_new_tokens=512):
    """
    Run one sampled generation per (image, prompt, system prompt) triple.

    Args:
        image_urls: One image reference per request, used as the ``image``
            field of the user message.
        prompts: One user text prompt per request.
        sys_prompts: One system prompt per request.
        model: Generation model exposing ``generate`` and ``device``.
        processor: Processor exposing ``apply_chat_template``, ``batch_decode``
            and a ``tokenizer`` attribute.
        max_new_tokens: Upper bound on generated tokens per request.

    Returns:
        A list of decoded completions (prompt tokens stripped), one per
        request, in input order. Returns ``[]`` when there is nothing to run.
        The three input sequences are zipped, so extra items in a longer
        sequence are ignored.
    """
    image_urls = list(image_urls)
    prompts = list(prompts)
    sys_prompts = list(sys_prompts)

    # zip() below would yield no requests; calling the processor with empty
    # batches fails, so short-circuit instead.
    if not (image_urls and prompts and sys_prompts):
        return []

    # Prepare the messages in the required format for batch inference
    messages_batch = [
        [
            {"role": "system", "content": sys_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_url},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        for image_url, prompt, sys_prompt in zip(image_urls, prompts, sys_prompts)
    ]

    # Prepare the input for the processor
    texts = [
        processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,  # Disable Qwen "thinking" tokens (consistent with multillm_webaes.py)
        )
        for messages in messages_batch
    ]

    image_inputs, video_inputs = process_vision_info(messages_batch)

    # Decoder-only generation continues from the last position of every row,
    # so rows of different length must be padded on the LEFT; with right
    # padding the shorter rows would continue from pad tokens. The previous
    # setting is restored so the shared processor is left untouched.
    tokenizer = processor.tokenizer
    previous_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        inputs = processor(
            text=texts,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
    finally:
        tokenizer.padding_side = previous_padding_side

    inputs = inputs.to(model.device)

    # Perform batch inference
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.85,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id
        )

    # generate() returns prompt + completion; drop the (padded) prompt part.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return output_texts

def frame_to_data_url(frame_bgr):
    """
    Encode an OpenCV BGR frame as a 256x256 JPEG ``data:`` URL.

    Returns the data URL string, or None when the frame is missing/empty or
    any step of the conversion raises (the error is printed).
    """
    try:
        # Nothing to encode for a missing or zero-sized frame
        if frame_bgr is None or frame_bgr.size == 0:
            return None

        # OpenCV keeps channels in BGR order; PIL expects RGB
        rgb_pixels = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        thumbnail = Image.fromarray(rgb_pixels).resize((256, 256), Image.LANCZOS)

        # Serialize to JPEG in memory and base64-encode the bytes
        jpeg_buffer = BytesIO()
        thumbnail.save(jpeg_buffer, format="JPEG")
        encoded_jpeg = base64.b64encode(jpeg_buffer.getvalue()).decode('utf-8')

        return f"data:image/jpeg;base64,{encoded_jpeg}"
    except Exception as e:
        print(f"Error in frame_to_data_url: {e}")
        return None

# Every persona system prompt has the same three paragraphs:
#   1. a persona-specific introduction,
#   2. a shared task sentence ending in persona-specific judging criteria,
#   3. a shared response-format block.
_PERSONA_TASK_PREFIX = (
    "You are given 5 example website screenshots and how much everyone liked them (on a 0–10 scale). "
    "You're now shown a new website screenshot. "
    "Your task is to judge how much you **like** this website based on "
)

_PERSONA_RESPONSE_FORMAT = (
    "Return your response in this exact format:\n"
    "Answer: [0–10] ← You must include this numerical score.\n"
    "Reason: [Why this website does or doesn't appeal to you visually and emotionally in minimal words]"
)

# (persona key, introduction paragraph, judging criteria) in the order personas are queried.
_PERSONA_PROFILES = [
    (
        "18-24_female",
        "You are a woman aged 18–24. You're fluent in digital aesthetics, raised on platforms like TikTok and Instagram. You notice instantly if something has a vibe—bold colors, expressive fonts, emotional tone, or modern, fun design. Websites that are cluttered, generic, or try-hard are less likely to appeal to you.",
        "its visual design, layout, color scheme, and content",
    ),
    (
        "18-24_male",
        "You are a man aged 18–24. You're used to fast-scroll content and visual punch—memes, Twitch, TikTok, YouTube. You like websites that grab attention fast: bold layouts, smart design, or a bit of edge. If a website feels outdated, cluttered, or boring, it loses your interest quickly.",
        "visuals, usability, and vibe",
    ),
    (
        "25-34_female",
        "You are a woman aged 25–34. You appreciate modern, polished websites that feel aligned with your lifestyle—whether it's wellness, creativity, relationships, or career. You like clean layouts, elegant color palettes, and visuals that are both pretty and purposeful.",
        "design, clarity, aesthetics, and content",
    ),
    (
        "25-34_male",
        "You are a man aged 25–34. You value strong, clear, and modern visuals. You're likely to appreciate websites that are bold but not messy—clean grids, high contrast, sharp fonts, and relevant content (fitness, tech, ambition, money).",
        "its layout, visual punch, and message",
    ),
    (
        "35-44_female",
        "You are a woman aged 35–44. You're drawn to websites that are intentional, emotionally intelligent, and visually clean. Family, meaning, and beauty in simplicity appeal to you more than trend-driven clutter.",
        "its design, clarity, and emotional tone",
    ),
    (
        "35-44_male",
        "You are a man aged 35–44. You like websites that are grounded, practical, and cleanly designed. Strong layouts, good use of space, and purpose-driven content grab your attention more than visual noise.",
        "structure, relevance, and visual balance",
    ),
    (
        "45-54_female",
        "You are a woman aged 45–54. You like websites that are calm, clear, and visually composed. Design that feels warm, thoughtful, and emotionally grounded appeals more than flashy visuals or trendy noise.",
        "clarity, emotional tone, and visual presentation",
    ),
    (
        "45-54_male",
        "You are a man aged 45–54. You prefer websites that are easy to navigate, focused, and visually grounded. You're drawn to sites that reflect purpose and clarity over trend or flash.",
        "usability, structure, and message",
    ),
    (
        "55+_female",
        "You are a woman aged 55 or older. You appreciate websites that feel meaningful, visually calm, and easy to understand. Gentle color palettes, clear fonts, and emotionally warm content make a big difference.",
        "design simplicity and emotional tone",
    ),
    (
        "55+_male",
        "You are a man aged 55 or older. You value websites that are straightforward, honest, and easy to engage with. Flashy or cluttered pages can feel frustrating, while clear structure and meaningful content feel worthwhile.",
        "clarity, usefulness, and visual comfort",
    ),
]


def _build_persona_prompt(intro, criteria):
    """Join the persona introduction, the task sentence and the response format."""
    task_paragraph = _PERSONA_TASK_PREFIX + criteria + "."
    return "\n\n".join([intro, task_paragraph, _PERSONA_RESPONSE_FORMAT])


# Maps "<age range>_<gender>" to the system prompt used for that persona.
persona_prompts = {
    persona_key: _build_persona_prompt(intro, criteria)
    for persona_key, intro, criteria in _PERSONA_PROFILES
}

def batch_verbalize(batch_data, batch_size=16):
    """
    Run sampled multi-image generation over a list of requests.

    Uses the module-level ``model`` and ``processor``.

    Args:
        batch_data: List of ``(prompt, sys_prompt, images)`` tuples, where
            ``images`` is a list of ``(image_url, extra)`` pairs. Images are
            attached to the user message in list order, before the text;
            pairs whose ``image_url`` is falsy are skipped and ``extra`` is
            ignored.
        batch_size: Number of requests sent through the model at once.

    Returns:
        One string per request, in input order. If an inference pass fails,
        every request of that pass gets an
        ``"Error during model inference: ..."`` placeholder instead of raising;
        a failure outside the per-pass handling yields
        ``"Error during batch processing: ..."`` for every request.
    """
    try:
        all_results = []
        for i in tqdm(range(0, len(batch_data), batch_size), desc="Model Inference"):
            current_batch_data = batch_data[i:i+batch_size]

            messages_for_batch = []
            for prompt, sys_prompt, images in current_batch_data:
                content = [
                    {"type": "image", "image": img_url}
                    for img_url, _ in images
                    if img_url
                ]
                content.append({"type": "text", "text": prompt})

                messages_for_batch.append([
                    {"role": "system", "content": sys_prompt},
                    {"role": "user", "content": content}
                ])

            try:
                # Preparation for batch inference
                texts = [
                    processor.apply_chat_template(
                        msg,
                        tokenize=False,
                        add_generation_prompt=True,
                        enable_thinking=False,  # Disable Qwen "thinking" tokens (consistent with multillm_webaes.py)
                    )
                    for msg in messages_for_batch
                ]

                image_inputs, video_inputs = process_vision_info(messages_for_batch)

                # Decoder-only generation continues from the last position of
                # every row, so rows of different length must be padded on the
                # LEFT; with right padding the shorter rows would continue
                # from pad tokens. Restore the previous setting afterwards so
                # the shared processor is left untouched.
                tokenizer = processor.tokenizer
                previous_padding_side = tokenizer.padding_side
                tokenizer.padding_side = "left"
                try:
                    inputs = processor(
                        text=texts,
                        images=image_inputs,
                        videos=video_inputs,
                        padding=True,
                        return_tensors="pt",
                    )
                finally:
                    tokenizer.padding_side = previous_padding_side
                inputs = inputs.to(model.device)

                # Batch Inference
                with torch.no_grad():
                    generated_ids = model.generate(
                        **inputs,
                        max_new_tokens=2000,
                        do_sample=True,
                        temperature=0.85,
                        top_p=0.9,
                        pad_token_id=tokenizer.eos_token_id
                    )

                # generate() returns prompt + completion; drop the (padded) prompt part.
                generated_ids_trimmed = [
                    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                ]
                output_texts = processor.batch_decode(
                    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )
                all_results.extend(output_texts)

            except Exception as e:
                print(f"Error during a batch inference pass: {e}")
                # Add error placeholders for the failed batch
                all_results.extend([f"Error during model inference: {str(e)}"] * len(current_batch_data))

        return all_results

    except Exception as e:
        print(f"Error in batch_verbalize function: {e}")
        return [f"Error during batch processing: {str(e)}"] * len(batch_data)

import pandas as pd
import re

def safe_load_image(image_path):
    """
    Load an image as an OpenCV-style BGR array, trying cv2 first and PIL second.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The image as a BGR numpy array, or None when neither loader succeeds.
        Loader errors are printed, never raised.
    """
    try:
        # First try with cv2
        image = cv2.imread(image_path)
        if image is not None and image.size > 0:
            return image
    except Exception as e:
        print(f"cv2.imread failed for {image_path}: {e}")

    try:
        # Fallback: try with PIL and convert to cv2 format.
        # The context manager releases the file handle even when decoding
        # fails; convert() returns an independent in-memory copy.
        with Image.open(image_path) as pil_image:
            rgb_image = pil_image.convert('RGB')
        return cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)
    except Exception as e:
        print(f"PIL fallback failed for {image_path}: {e}")

    return None

# CSV listing the samples to score; the script cannot continue without it.
test_filename = "website-aesthetics-datasets/rating-based-dataset/preprocess/test_list.csv"
try:
    df = pd.read_csv(test_filename)
    print(f"Loaded dataset with {df.shape[0]} samples")
except Exception as load_error:
    # Report the failure and stop with a non-zero exit status.
    print(f"Error loading dataset: {load_error}")
    sys.exit(1)

def extract_score_from_response(resp):
    """
    Pull a 0-10 score out of a model response.

    Strategies are tried in order:
      1. A number following an ``Answer:`` label (the requested format).
      2. A number following a ``score:`` or ``rating:`` label.
      3. The last standalone number in the text that lies within [0, 10].

    Labels are matched case-insensitively, and the number may be wrapped in
    brackets, parentheses or markdown asterisks (e.g. ``Answer: [7]`` or
    ``**Answer:** 7``), since the prompt's format line itself shows brackets.
    Labelled numbers are clamped to [0, 10].

    Args:
        resp: Raw response text.

    Returns:
        The score as a float, or None when ``resp`` is not a string or holds
        no usable number.
    """
    if not isinstance(resp, str):
        return None

    number = r'(\d+(?:\.\d+)?)'
    # Whitespace / markdown emphasis / opening brackets allowed around the label and number.
    label_suffix = r'[\s\*]*:'
    number_prefix = r'[\s\*\[\(]*'
    labelled_patterns = (
        r'Answer' + label_suffix + number_prefix + number,
        r'(?:score|rating)' + label_suffix + number_prefix + number,
    )

    for pattern in labelled_patterns:
        match = re.search(pattern, resp, re.IGNORECASE)
        if match:
            return max(0.0, min(10.0, float(match.group(1))))

    # Last resort: the final standalone number that could plausibly be a score.
    candidates = [float(token) for token in re.findall(r'\b(\d+(?:\.\d+)?)\b', resp)]
    plausible = [value for value in candidates if 0 <= value <= 10]
    return plausible[-1] if plausible else None

def _dataset_image_path(fname):
    """Map an image name from the CSV / similarity data to its path in the dataset image folder."""
    return 'website-aesthetics-datasets/rating-based-dataset/images/' + fname.replace('_resized', '').lstrip('/')


def _image_file_to_data_url(img_path):
    """Return a data URL for the image at ``img_path``, or None if it is missing or unreadable."""
    if not os.path.exists(img_path):
        return None
    img = safe_load_image(img_path)
    if img is None:
        return None
    return frame_to_data_url(img)


def _try_make_example(fname, score, used_paths):
    """Build one few-shot example as ``(score_line, (data_url, score))``, or None if unusable."""
    img_path = _dataset_image_path(fname)
    if img_path in used_paths:
        return None
    img_url = _image_file_to_data_url(img_path)
    if img_url is None:
        return None
    # Format before registering the path so a bad score leaves no trace behind.
    score_line = f"Score: {score:.1f}" if score is not None else "Score: N/A"
    used_paths.add(img_path)
    return score_line, (img_url, score)


def prepare_sample_data(df, sample_indices):
    """
    Load each requested sample and assemble its few-shot prompt and images.

    For every index the target screenshot is loaded, then up to 5 scored
    examples are gathered: first from ``SIMILARITY_DATA`` (entry keyed by the
    index as a string), then from randomly chosen other rows of ``df``.

    Args:
        df: DataFrame with at least an ``image`` column; ``mean_score`` is
            read for randomly sampled examples.
        sample_indices: Positional row indices to prepare.

    Returns:
        ``(batch_samples, skipped_samples)``. Each prepared sample is a dict
        with ``index``, ``value`` (the row as a dict), ``prompt``,
        ``example_images`` (``(data_url, score)`` pairs with the target image
        last and ``score=None``) and ``valid_examples``. Each skipped entry is
        ``(index, reason)``.
    """
    max_examples = 5
    batch_samples = []
    skipped_samples = []

    print("Preparing sample data...")
    for i in tqdm(sample_indices, desc="Loading samples"):
        try:
            d = df.iloc[i]
            value = d.to_dict()
            image_path = _dataset_image_path(d['image'])

            # Check if image file exists
            if not os.path.exists(image_path):
                skipped_samples.append((i, f"Image file does not exist: {image_path}"))
                continue

            # Use safe image loading
            image = safe_load_image(image_path)

            # Check if main image loaded successfully
            if image is None:
                skipped_samples.append((i, f"Could not load image: {image_path}"))
                continue

            image_url = frame_to_data_url(image)
            if image_url is None:
                skipped_samples.append((i, f"Could not process image: {image_path}"))
                continue

            example_lines = []
            example_images = []
            # Seeded with the target so it can never appear as its own scored
            # example, and so no example is shown twice.
            used_paths = {image_path}

            # First attempt similarity-based retrieval
            similar_list = SIMILARITY_DATA.get(str(i), {}).get("similar_images", [])
            for sim in similar_list:
                if len(example_images) >= max_examples:
                    break
                try:
                    example = _try_make_example(sim["image"], sim.get("mean_score", None), used_paths)
                except Exception:
                    # Malformed entry: best-effort, move on to the next one.
                    continue
                if example is not None:
                    example_lines.append(example[0])
                    example_images.append(example[1])

            # Random sampling fallback if fewer than max_examples examples found
            if len(example_images) < max_examples:
                other_indices = [idx for idx in range(df.shape[0]) if idx != i]
                random.shuffle(other_indices)
                for idx in other_indices:
                    if len(example_images) >= max_examples:
                        break
                    try:
                        row = df.iloc[idx]
                        example = _try_make_example(row['image'], row['mean_score'], used_paths)
                    except Exception:
                        continue
                    if example is not None:
                        example_lines.append(example[0])
                        example_images.append(example[1])

            valid_examples = len(example_images)

            # Add the current image as the last one
            example_images.append((image_url, None))
            examples_text = "\n".join(example_lines)

            # Create the user prompt based on whether we have examples
            if valid_examples > 0:
                prompt = (
                    f"Given the images below, the first {valid_examples} are example website screenshots with their likeability scores (on a 0-10 scale, see the list below). The last image is the one you should score. \n"
                    "\n"
                    "Carefully analyze the last website screenshot and provide a score between 0 to 10 based on how much people would like the website's visual design, layout, colors, typography, and overall aesthetic appeal.\n"
                    "\n"
                    f"Here are {valid_examples} example likeability scores (in order):\n"
                    f"{examples_text}\n"
                    "\n"
                    "Please evaluate the final website screenshot and provide your assessment."
                )
            else:
                # Without this branch `prompt` would be unbound for the first
                # sample, or silently reuse the previous sample's prompt.
                prompt = (
                    "Carefully analyze the website screenshot and provide a score between 0 to 10 based on how much people would like the website's visual design, layout, colors, typography, and overall aesthetic appeal.\n"
                    "\n"
                    "Please evaluate the website screenshot and provide your assessment."
                )

            batch_samples.append({
                'index': i,
                'value': value,
                'prompt': prompt,
                'example_images': example_images,
                'valid_examples': valid_examples
            })

        except Exception as e:
            skipped_samples.append((i, f"Unexpected error: {str(e)}"))
            continue

    return batch_samples, skipped_samples

NUM_RUNS = args.num_runs

# Load model and processor
logger.info("Loading model and processor...")
model, processor = get_model_and_processor()
logger.info("Model and processor loaded successfully!")

response_dict = []
processed_count = 0
skipped_count = 0
error_count = 0

# Prepare list of sample indices according to CLI slice (--end is inclusive)
if args.end is not None:
    all_indices = list(range(args.start, min(args.end + 1, len(df))))
else:
    all_indices = list(range(args.start, len(df)))

print(f"Starting processing of {len(all_indices)} samples with batch processing and {len(persona_prompts)} personas...")

batch_samples, skipped_samples = prepare_sample_data(df, all_indices)
skipped_count = len(skipped_samples)
print(f"Prepared {len(batch_samples)} valid samples, skipped {len(skipped_samples)}")

# The output location does not change between samples; make sure it exists up front.
os.makedirs(args.output_dir, exist_ok=True)
output_filename = os.path.join(
    args.output_dir,
    f'results_qwen_persona_static_web_aes_ten_slice_{args.start}_{args.end if args.end is not None else "end"}.json'
)

persona_names = list(persona_prompts)
runs_per_persona = max(NUM_RUNS, 0)

for sample_idx, sample in enumerate(tqdm(batch_samples, desc="Samples")):
    try:
        # One request per (persona, run). Each request carries ALL images of
        # the sample (scored examples first, target last): the user prompt
        # refers to the example screenshots, so the model has to see them.
        requests = [
            (sample['prompt'], persona_prompts[persona_name], sample['example_images'])
            for persona_name in persona_names
            for _ in range(runs_per_persona)
        ]
        # batch_size=1 keeps a single multi-image request in flight at a time.
        responses = batch_verbalize(requests, batch_size=1)

        persona_results = {}
        all_persona_means = []
        for persona_pos, persona_name in enumerate(persona_names):
            first = persona_pos * runs_per_persona
            run_responses = list(responses[first:first + runs_per_persona])

            predictions = []
            for resp in run_responses:
                # batch_verbalize reports failures as "Error ..." placeholder strings.
                if resp.startswith("Error"):
                    continue
                score_val = extract_score_from_response(resp)
                if score_val is not None:
                    predictions.append(score_val)

            # Plain floats keep the JSON dump independent of numpy scalar types.
            mean_pred = float(np.mean(predictions)) if predictions else None
            if mean_pred is not None:
                all_persona_means.append(mean_pred)

            persona_results[persona_name] = {
                "all_responses": run_responses,
                "all_predictions": predictions,
                "mean_prediction": mean_pred,
                "num_valid_predictions": len(predictions),
            }

        overall_mean = float(np.mean(all_persona_means)) if all_persona_means else None

        # Attach results to sample value
        sample['value'].update({
            "persona_responses": persona_results,
            "overall_mean_prediction": overall_mean,
            "num_personas": len(persona_prompts),
            "valid_persona_predictions": len(all_persona_means)
        })
        response_dict.append(sample['value'])
        processed_count += 1

        # Incremental save: write to a temporary file and swap it in, so an
        # interrupted run never leaves a truncated results file behind.
        tmp_filename = output_filename + ".tmp"
        with open(tmp_filename, 'w') as f_out:
            json.dump(response_dict, f_out, indent=4)
        os.replace(tmp_filename, output_filename)
        print(f"💾 Saved progress after sample {sample_idx + 1}/{len(batch_samples)} -> {output_filename}")

    except Exception as e:
        error_count += 1
        print(f"[ERROR] Failed processing sample idx {sample_idx}: {e}")
        continue

print("All requested samples processed.")
print(f"Total processed: {len(response_dict)} | Skipped: {len(skipped_samples)}")
