import os
import json
import torch
from tqdm import tqdm
import sys
import re
import base64
import cv2
from PIL import Image
from io import BytesIO
import random
import numpy as np
import pandas as pd
import datetime
import warnings
import logging
from transformers import MllamaForConditionalGeneration, AutoProcessor
import argparse

# Set up logging and warnings.
# NOTE: filterwarnings("ignore") with no category silences every Python
# `warnings` message process-wide, including library deprecation warnings.
warnings.filterwarnings("ignore")
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ---------------- CLI ARGUMENTS & GLOBALS ----------------
parser = argparse.ArgumentParser()
parser.add_argument("--start", type=int, default=0, help="Starting sample index (inclusive)")
parser.add_argument("--end", type=int, default=None, help="Ending sample index (inclusive). If not provided, process till the last sample")
parser.add_argument("--output_dir", type=str, default=".", help="Directory to save result JSONs")
# Default matches the 10 sampled runs per sample this script performs; generation
# always samples (temperature > 0), so even a single run is not deterministic.
parser.add_argument("--num_runs", type=int, default=10, help="Number of sampled inference runs per sample whose extracted scores are averaged (default: 10)")
args = parser.parse_args()

# Indices are positional rows of the dataframe and keys into the similarity map:
# a negative start would silently index from the end of the dataframe.
if args.start < 0:
    parser.error("--start must be >= 0")
if args.end is not None and args.end < args.start:
    parser.error("--end must be >= --start")

# Load similarity mapping (for few-shot selection)
SIMILARITY_JSON_PATH = "similaritywebaes.json"
with open(SIMILARITY_JSON_PATH, "r", encoding="utf-8") as f:
    SIMILARITY_DATA = json.load(f)

def get_model_and_processor(model_dir="meta-llama/Llama-3.2-90B-Vision", *, load_in_4bit=True):
    """
    Load the Mllama vision-language model and its processor (standard HuggingFace load, no unsloth).

    Args:
        model_dir: Hub repo id or local path. Note the default is the *base*
            checkpoint; callers wanting the chat/Instruct model must pass it.
        load_in_4bit: quantize the weights to 4 bit via bitsandbytes (default,
            as before). Set False to load unquantized with ``torch_dtype="auto"``.

    Returns:
        (model, processor): the model, placed with ``device_map="auto"``, and
        the matching ``AutoProcessor``.
    """
    # Local import so the file-level import block stays untouched.
    from transformers import BitsAndBytesConfig

    logger.info(f"Loading model and processor from {model_dir}")
    processor = AutoProcessor.from_pretrained(model_dir)
    # Passing ``load_in_4bit=True`` straight to ``from_pretrained`` is deprecated;
    # an explicit BitsAndBytesConfig is the supported, equivalent form.
    quantization_config = BitsAndBytesConfig(load_in_4bit=True) if load_in_4bit else None
    model = MllamaForConditionalGeneration.from_pretrained(
        model_dir,
        torch_dtype="auto",
        device_map="auto",
        quantization_config=quantization_config,
    )
    return model, processor

def _data_url_to_pil(data_url):
    """Decode a base64 data-URL image into a PIL.Image"""
    try:
        header, encoded = data_url.split(",", 1)
        img_bytes = base64.b64decode(encoded)
        return Image.open(BytesIO(img_bytes)).convert("RGB")
    except Exception:
        return None

def inference_batch(image_urls, prompts, sys_prompts, model, processor, max_new_tokens=512):
    """Generate one sampled response per (image, prompt, persona) triple via the chat template.

    Each sample is rendered as an empty system turn plus a single user turn
    holding one image followed by ``sys_prompt + prompt``.

    Args:
        image_urls (List[str]): base64 data-URL strings, one image per sample.
            Entries that fail to decode are replaced by a blank white 256x256 image.
        prompts (List[str]): user prompts.
        sys_prompts (List[str]): persona/system prompts, prepended verbatim to
            the matching user prompt.
        model: generation model exposing ``generate`` and ``device``.
        processor: processor exposing ``tokenizer``, ``__call__`` and ``batch_decode``.
        max_new_tokens (int): generation budget per sample.
    Returns:
        List[str]: newly generated text (prompt tokens stripped) per sample, in
        input order; an empty list for empty input.
    Raises:
        ValueError: if the three input sequences differ in length.
    """
    image_urls, prompts, sys_prompts = list(image_urls), list(prompts), list(sys_prompts)
    if not (len(image_urls) == len(prompts) == len(sys_prompts)):
        raise ValueError(
            f"inference_batch: got {len(image_urls)} images, {len(prompts)} prompts "
            f"and {len(sys_prompts)} system prompts; they must match"
        )
    if not image_urls:
        return []

    messages_batch = [
        [
            # The persona travels in the user text only; the system turn stays empty.
            # NOTE(review): the Llama 3.2 Vision chat template appears to reject
            # non-empty system messages when an image is present — confirm against
            # the tokenizer's template before putting text here.
            {"role": "system", "content": ""},
            {"role": "user", "content": [
                {"type": "image", "image": image_url},
                {"type": "text", "text": sys_prompt + prompt},
            ]},
        ]
        for image_url, prompt, sys_prompt in zip(image_urls, prompts, sys_prompts)
    ]

    texts = [
        processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        for messages in messages_batch
    ]

    # One image per sample, nested as [[img0], [img1], ...] so the processor pairs
    # image k with text k; a flat list can be read as several images of one sample.
    pil_images = []
    for image_url in image_urls:
        pil_img = _data_url_to_pil(image_url)
        if pil_img is None:
            pil_img = Image.new("RGB", (256, 256), "white")
        pil_images.append([pil_img])

    # Batched decoder-only generation needs left padding; with right padding the
    # pad tokens would sit between a shorter prompt and its generated tokens.
    processor.tokenizer.padding_side = "left"

    inputs = processor(
        text=texts,
        images=pil_images,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.85,
            top_p=0.9,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

def frame_to_data_url(frame_bgr):
    """Encode an OpenCV BGR frame as a 256x256 JPEG ``data:`` URL.

    The frame is converted BGR -> RGB, resized to exactly 256x256 with Lanczos
    resampling (aspect ratio is not preserved) and JPEG-encoded. Returns None
    for a missing or empty frame, or when any step raises (the error is printed).
    """
    try:
        if frame_bgr is None or frame_bgr.size == 0:
            return None

        rgb_pixels = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        thumbnail = Image.fromarray(rgb_pixels).resize((256, 256), Image.LANCZOS)

        with BytesIO() as jpeg_buffer:
            thumbnail.save(jpeg_buffer, format="JPEG")
            encoded = base64.b64encode(jpeg_buffer.getvalue()).decode('utf-8')

        return f"data:image/jpeg;base64,{encoded}"
    except Exception as e:
        print(f"Error in frame_to_data_url: {e}")
        return None

# Legacy Qwen/Qwen2.5-VL-7B-Instruct loading code, kept for reference only (unused; it requests attn_implementation="eager", not Flash Attention 2):
# print("Loading Qwen/Qwen2.5-VL-7B-Instruct model with Flash Attention 2...")
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen2.5-VL-7B-Instruct",
#     torch_dtype=torch.bfloat16,  # FA2 needs bfloat16 or float16
#     attn_implementation="eager",
#     device_map="auto",
# )
# print("Model loaded.")
#
# # default processor
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
# print("Processor loaded.")



# ---------------------------------------------------------------------------
# Helper to extract vision inputs (images/videos) from chat-template messages
# ---------------------------------------------------------------------------
def process_vision_info(messages_batch):
    """Split the image and video payloads out of a batch of chat conversations.

    Args:
        messages_batch: iterable of conversations, each a sequence of turn
            dicts. Only turns whose ``"content"`` is a list are inspected, and
            only dict entries inside it; entries typed ``"image"``/``"video"``
            contribute their ``"image"``/``"video"`` value (``None`` if absent).

    Returns:
        tuple[list, list]: ``(image_inputs, video_inputs)``, each holding one
        list per conversation, in the order the payloads appear.
    """
    def _dict_entries(conversation):
        # Flatten every list-valued "content" into its dict entries, in order.
        for turn in conversation:
            content = turn.get("content", [])
            if isinstance(content, list):
                yield from (entry for entry in content if isinstance(entry, dict))

    image_inputs, video_inputs = [], []
    for conversation in messages_batch:
        entries = list(_dict_entries(conversation))
        image_inputs.append([e.get("image") for e in entries if e.get("type") == "image"])
        video_inputs.append([e.get("video") for e in entries if e.get("type") == "video"])
    return image_inputs, video_inputs

def batch_verbalize(batch_data, batch_size=16, max_new_tokens=1200):
    """Run chat-templated, single-image inference over ``batch_data`` in mini-batches.

    Uses the module-level ``model`` and ``processor``.

    Args:
        batch_data: sequence of ``(prompt, sys_prompt, img_tuples)`` triples. The
            first element of the last tuple in ``img_tuples`` is the target image
            data URL; a missing or undecodable image becomes a white 256x256
            placeholder. A non-empty ``sys_prompt`` is prepended verbatim to
            ``prompt`` in the user turn.
        batch_size: samples per ``model.generate`` call.
        max_new_tokens: generation budget per sample.

    Returns:
        list[str]: one decoded response per element of ``batch_data``, in order.
        When a mini-batch fails, each of its entries is ``"Error: <exception>"``.
    """
    all_results = []
    # Batched decoder-only generation needs left padding; with right padding the
    # pad tokens would sit between a shorter prompt and its generated tokens.
    processor.tokenizer.padding_side = "left"

    for start in tqdm(range(0, len(batch_data), batch_size), desc="Model Inference"):
        current_batch = batch_data[start:start + batch_size]

        try:
            texts = []
            pil_images_batch = []
            for prompt, sys_prompt, img_tuples in current_batch:
                # The last image in the tuple list is the target.
                data_url = img_tuples[-1][0] if img_tuples else None
                messages = [
                    # System turn stays empty; the persona goes into the user text.
                    {"role": "system", "content": ""},
                    {"role": "user", "content": [
                        {"type": "image", "image": data_url},
                        {"type": "text", "text": (sys_prompt + prompt) if sys_prompt else prompt},
                    ]},
                ]
                texts.append(processor.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                ))

                pil = _data_url_to_pil(data_url) if data_url else None
                if pil is None:
                    pil = Image.new("RGB", (256, 256), "white")
                # Nested per sample so image k is paired with text k.
                pil_images_batch.append([pil])

            inputs = processor(
                text=texts,
                images=pil_images_batch,
                padding=True,
                return_tensors="pt",
            ).to(model.device)

            with torch.no_grad():
                gen_ids = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=0.85,
                    top_p=0.9,
                    pad_token_id=processor.tokenizer.eos_token_id,
                )

            # generate() returns prompt + continuation; keep only the continuation.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, gen_ids)
            ]
            all_results.extend(processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            ))

        except Exception as e:
            print(f"Error during batch inference: {e}")
            all_results.extend([f"Error: {e}"] * len(current_batch))

    return all_results

import pandas as pd
import re

def safe_load_image(image_path):
    """Load an image as an OpenCV-style BGR array, trying cv2 first and PIL second.

    ``cv2.imread`` reports most failures by returning None; PIL is the fallback
    for files OpenCV cannot decode. Failures are printed, never raised.

    Returns:
        The BGR image array, or None if both loaders fail.
    """
    try:
        image = cv2.imread(image_path)
        if image is not None and image.size > 0:
            return image
    except Exception as e:
        print(f"cv2.imread failed for {image_path}: {e}")

    try:
        # Context manager releases the file handle Pillow opens lazily (it is not
        # always closed after loading, e.g. for multi-frame images).
        with Image.open(image_path) as pil_image:
            rgb_image = pil_image.convert('RGB')
        return cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)
    except Exception as e:
        print(f"PIL fallback failed for {image_path}: {e}")

    return None

# Evaluation split; nothing downstream can run without it, so exit on failure.
test_filename = "website-aesthetics-datasets/rating-based-dataset/preprocess/test_list.csv"
try:
    df = pd.read_csv(test_filename)
except Exception as load_error:
    print(f"Error loading dataset: {load_error}")
    sys.exit(1)
else:
    print(f"Loaded dataset with {len(df)} samples")

# "answer" followed (within 10 non-digit chars, so punctuation/markdown is fine) by a number.
_ANSWER_RE = re.compile(r'answer[^0-9]{0,10}(\d+(?:\.\d+)?)', re.IGNORECASE)
# "score:" / "rating:" labels, tolerating markdown emphasis around the colon
# (e.g. "**Score:** 8", "Score**: 8").
_SCORE_RE = re.compile(r'(?:score|rating)[*_\s]*:[*_\s]*(\d+(?:\.\d+)?)', re.IGNORECASE)
# "/10" and "out of 10" denominators, which must not be mistaken for the score.
_DENOMINATOR_RE = re.compile(r'(?:/|\bout of)\s*10(?:\.0+)?\b', re.IGNORECASE)
_NUMBER_RE = re.compile(r'\b(\d+(?:\.\d+)?)\b')


def extract_score_from_response(resp):
    """Extract a 0-10 score from a model response using progressively looser patterns.

    1. The first "answer ... <number>" match, clamped to [0, 10].
    2. Otherwise the *last* "score:"/"rating:" label (the final verdict usually
       comes last), clamped to [0, 10].
    3. Otherwise the last standalone number in [0, 10], after dropping "/10"
       and "out of 10" denominators (so "8/10" yields 8, not 10).

    Returns:
        float or None: the extracted score, or None for empty/None input or
        when nothing matches.
    """
    if not resp:
        return None

    answer_match = _ANSWER_RE.search(resp)
    if answer_match:
        return max(0.0, min(10.0, float(answer_match.group(1))))

    labelled = _SCORE_RE.findall(resp)
    if labelled:
        return max(0.0, min(10.0, float(labelled[-1])))

    # Last resort: any number that could be a score.
    cleaned = _DENOMINATOR_RE.sub(' ', resp)
    valid_scores = [float(n) for n in _NUMBER_RE.findall(cleaned) if 0 <= float(n) <= 10]
    return valid_scores[-1] if valid_scores else None

_IMAGES_DIR = 'website-aesthetics-datasets/rating-based-dataset/images/'
_MAX_EXAMPLES = 5


def _image_path(fname):
    """Map a dataset ``image`` entry to its file path (drops '_resized' and any leading '/')."""
    return _IMAGES_DIR + fname.replace('_resized', '').lstrip('/')


def _example_data_url(img_path):
    """Return a data URL for an example image, or None if it is missing or unreadable."""
    if not os.path.exists(img_path):
        return None
    img = safe_load_image(img_path)
    if img is None:
        return None
    return frame_to_data_url(img)


def _collect_examples(df, i, target_path):
    """Pick up to ``_MAX_EXAMPLES`` scored example images for sample ``i``.

    Neighbours listed in ``SIMILARITY_DATA[str(i)]["similar_images"]`` are tried
    first; remaining slots are filled from the other rows of ``df`` in random
    order. The target's own file and files already chosen are skipped, so the
    target can never appear as its own scored example and no example repeats.

    Returns:
        (example_lines, example_images): "Score: X.X" lines and the matching
        ``(data_url, score)`` tuples, in selection order.
    """
    example_lines = []
    example_images = []
    used_paths = {target_path}

    def _try_add(fname, score):
        # Append one example if its image loads; exceptions propagate to the caller's loop.
        img_path = _image_path(fname)
        if img_path in used_paths:
            return
        img_url = _example_data_url(img_path)
        if img_url is None:
            return
        line = "Score: N/A" if score is None else f"Score: {score:.1f}"
        example_lines.append(line)
        example_images.append((img_url, score))
        used_paths.add(img_path)

    # Similarity-based retrieval first.
    for sim in SIMILARITY_DATA.get(str(i), {}).get("similar_images", []):
        if len(example_images) >= _MAX_EXAMPLES:
            break
        try:
            _try_add(sim["image"], sim.get("mean_score", None))
        except Exception:
            continue

    # Fallback: random other rows if fewer than _MAX_EXAMPLES similar examples loaded.
    if len(example_images) < _MAX_EXAMPLES:
        other_indices = [j for j in range(df.shape[0]) if j != i]
        random.shuffle(other_indices)
        for idx in other_indices:
            if len(example_images) >= _MAX_EXAMPLES:
                break
            try:
                row = df.iloc[idx]
                _try_add(row['image'], row['mean_score'])
            except Exception:
                continue

    return example_lines, example_images


def _build_prompt(num_examples, examples_text):
    """Return the user prompt for the target screenshot.

    With examples, this is the few-shot wording listing their scores. Without
    examples, it is a zero-shot variant; previously no prompt was defined in
    that case.

    NOTE(review): the inference code in this file sends only the final entry of
    ``example_images`` (the target) to the model, so the examples reach it only
    as the score list, although the few-shot wording refers to them as images.
    Confirm this is intended.
    """
    if num_examples == 0:
        return (
            "Carefully analyze the website screenshot and provide a score between 0 to 10 "
            "based on how much people would like the website's visual design, layout, colors, "
            "typography, and overall aesthetic appeal.\n"
            "\n"
            "Please evaluate the website screenshot and provide your assessment."
        )
    return (
        f"Given the images below, the first {num_examples} are example website screenshots "
        "with their likeability scores (on a 0-10 scale, see the list below). "
        "The last image is the one you should score. \n"
        "\n"
        "Carefully analyze the last website screenshot and provide a score between 0 to 10 "
        "based on how much people would like the website's visual design, layout, colors, "
        "typography, and overall aesthetic appeal.\n"
        "\n"
        f"Here are {num_examples} example likeability scores (in order):\n"
        f"{examples_text}\n"
        "\n"
        "Please evaluate the final website screenshot and provide your assessment."
    )


def prepare_sample_data(df, sample_indices):
    """Prepare prompts and image payloads for the requested dataset rows.

    For each positional index ``i`` the target screenshot is loaded and encoded
    as a data URL, up to ``_MAX_EXAMPLES`` scored examples are collected, and the
    user prompt is built.

    Args:
        df: dataset frame; this code reads its ``image`` and ``mean_score`` columns.
        sample_indices: positional row indices into ``df``.

    Returns:
        (batch_samples, skipped_samples). ``batch_samples`` is a list of dicts
        with keys ``index``, ``value`` (the row as a dict), ``prompt``,
        ``example_images`` (``(data_url, score)`` tuples, with the target last
        and scored None) and ``valid_examples``. ``skipped_samples`` is a list
        of ``(index, reason)``.
    """
    batch_samples = []
    skipped_samples = []

    print("Preparing sample data...")
    for i in tqdm(sample_indices, desc="Loading samples"):
        try:
            d = df.iloc[i]
            value = d.to_dict()
            image_path = _image_path(d['image'])

            if not os.path.exists(image_path):
                skipped_samples.append((i, f"Image file does not exist: {image_path}"))
                continue

            image = safe_load_image(image_path)
            if image is None:
                skipped_samples.append((i, f"Could not load image: {image_path}"))
                continue

            image_url = frame_to_data_url(image)
            if image_url is None:
                skipped_samples.append((i, f"Could not process image: {image_path}"))
                continue

            example_lines, example_images = _collect_examples(df, i, image_path)
            valid_examples = len(example_images)
            # The image to score goes last, without a score.
            example_images.append((image_url, None))

            batch_samples.append({
                'index': i,
                'value': value,
                'prompt': _build_prompt(valid_examples, "\n".join(example_lines)),
                'example_images': example_images,
                'valid_examples': valid_examples
            })

        except Exception as e:
            skipped_samples.append((i, f"Unexpected error: {str(e)}"))

    return batch_samples, skipped_samples

# -----------------------------------------------------------------------------
# Runtime configuration (no explicit multi-sample batching)
# -----------------------------------------------------------------------------
# We process one (image, prompt) pair at a time; GPU parallelism is handled by
# model/device_map rather than by stacking samples. NUM_RUNS sampled
# completions are generated per sample and their extracted scores averaged;
# the count comes from --num_runs.
NUM_RUNS = args.num_runs
if NUM_RUNS < 1:
    parser.error("--num_runs must be >= 1")
# Empty persona: nothing is added to the user prompt.
GENERIC_SYSTEM_PROMPT = ""

# Load model and processor
logger.info("Loading model and processor...")
model, processor = get_model_and_processor("meta-llama/Llama-3.2-90B-Vision-Instruct")
logger.info("Model and processor loaded successfully!")

# Accumulated result records (one dict per successfully processed sample) and run counters.
response_dict = []
processed_count = 0
skipped_count = 0
error_count = 0

print(f"Starting processing of {len(df)} samples...")

# ------------------------------------ MAIN PROCESSING ------------------------------------
# Prepare list of sample indices according to CLI slice (--start/--end are inclusive
# positional row indices; --end is clamped to the last row).
if args.end is not None:
    all_indices = list(range(args.start, min(args.end + 1, len(df))))
else:
    all_indices = list(range(args.start, len(df)))

batch_samples, skipped_samples = prepare_sample_data(df, all_indices)
print(f"Prepared {len(batch_samples)} valid samples, skipped {len(skipped_samples)}")

# The output path is identical for every sample: build it, and create its directory,
# once up front so a missing --output_dir does not make every incremental save fail.
# NOTE(review): the file name says "ten" regardless of NUM_RUNS; it is kept unchanged
# so existing consumers of these files keep matching.
os.makedirs(args.output_dir, exist_ok=True)
output_filename = os.path.join(
    args.output_dir,
    f'results_llama_nopersona_static_web_aes_ten_slice_{args.start}_{args.end if args.end is not None else "end"}.json'
)
tmp_output_filename = output_filename + ".tmp"

for sample_idx, sample in enumerate(tqdm(batch_samples, desc="Samples")):
    try:
        # The image to score is always the last entry.
        image_url = sample['example_images'][-1][0]
        prompt_text = sample['prompt']

        all_responses = []
        all_predictions = []

        for run_idx in range(NUM_RUNS):
            # Isolate failures per run so one bad generation does not discard the
            # sample's other runs; a failed run is recorded as an "Error: ..." response.
            try:
                resp = inference_batch(
                    [image_url],
                    [prompt_text],
                    [GENERIC_SYSTEM_PROMPT],
                    model,
                    processor,
                )[0]
            except Exception as e:
                print(f"[ERROR] Run {run_idx} failed for sample {sample['index']}: {e}")
                resp = f"Error: {e}"

            all_responses.append(resp)

            if not resp.startswith("Error"):
                score_val = extract_score_from_response(resp)
                if score_val is not None:
                    all_predictions.append(score_val)

        # Plain float rather than numpy.float64 for the JSON record.
        mean_prediction = float(np.mean(all_predictions)) if all_predictions else None

        # Attach results to sample value
        sample['value'].update({
            "responses": all_responses,
            "predictions": all_predictions,
            "mean_prediction": mean_prediction,
            "num_runs": NUM_RUNS,
            "num_valid_predictions": len(all_predictions)
        })
        response_dict.append(sample['value'])

        # Incremental save: write a temp file, then atomically swap it in, so a crash
        # mid-write cannot truncate the progress saved so far.
        with open(tmp_output_filename, 'w') as f_out:
            json.dump(response_dict, f_out, indent=4)
        os.replace(tmp_output_filename, output_filename)
        print(f"💾 Saved progress after sample {sample_idx + 1}/{len(batch_samples)} -> {output_filename}")

    except Exception as e:
        error_count += 1
        print(f"[ERROR] Failed processing sample idx {sample_idx} (dataset index {sample.get('index')}): {e}")
        continue

print("✅ All requested samples processed.")
print(f"Total processed: {len(response_dict)} | Skipped: {len(skipped_samples)}")
if error_count:
    print(f"Failed: {error_count}")
