import torch, logging
import re


def preview(batch_samples, decoded_answers, postprocess_answer_func):
    """Print a console preview of the first two samples of a batch.

    For each previewed sample the question, the raw decoded answer and the
    post-processed answer are printed, followed by a divider line.

    Args:
        batch_samples: Sliceable sequence of samples; each previewed sample
            is indexed with ``"question"``.
        decoded_answers: Sliceable sequence of raw answers, paired
            positionally with ``batch_samples``.
        postprocess_answer_func: Callable applied to each raw answer to
            obtain the processed answer that is shown.

    Returns:
        None. Output goes to stdout only.
    """
    banner_side = "=" * 40
    print("\n" + banner_side + " Batch Preview " + banner_side)
    divider = "-" * 95
    head_pairs = zip(batch_samples[:2], decoded_answers[:2])
    for number, (entry, raw_answer) in enumerate(head_pairs, start=1):
        print(f"Sample {number}:")
        print(f"🔍 Question: {entry['question']}\n")
        print(f"💭 Raw Answer: {raw_answer}\n")
        processed = postprocess_answer_func(raw_answer)
        print(f"💭 Processed Answer: {processed}")
        print(divider)

def run_inference_qwen25vl_instruct(model, processor, prompt_prefix, prompt_suffix, batch_samples, benchmark_adapter, preview_mode):
    """Run one batch through a Qwen2.5-VL instruct model and return post-processed answers.

    Args:
        model: Object exposing ``generate(**inputs)`` and ``device``; its
            ``dtype`` attribute is used for the input cast when it is a
            floating-point ``torch.dtype`` (falls back to ``torch.float16``).
        processor: Object exposing ``apply_chat_template``, ``batch_decode``
            and ``tokenizer.eos_token_id``.
        prompt_prefix: Text placed before every question.
        prompt_suffix: Text placed after every question.
        batch_samples: Iterable of dict-like samples. Each needs
            ``"question"`` and either ``"images"`` (a list of images) or
            ``"image"`` (a single image). ``"options"`` is optional and may
            be a stringified list of quoted items or a real list/tuple.
        benchmark_adapter: Object whose ``get_postprocess_answer_func()``
            returns the callable applied to every decoded answer.
        preview_mode: When truthy, print a preview of the batch via
            ``preview``; preview failures are logged and ignored.

    Returns:
        list: One post-processed answer per sample, in batch order.

    Raises:
        Exception: Any error raised while preparing inputs or decoding is
            logged and re-raised.
        RuntimeError: Raised by generation; re-raised as-is (non
            "CUDA out of memory" errors are logged first).
    """
    try:
        print("\n🚀 Preparing inputs")
        images_batch = [
            sample["images"] if "images" in sample else [sample["image"]]
            for sample in batch_samples
        ]
        prompts = []
        for sample in batch_samples:
            options = sample.get("options")
            if isinstance(options, (list, tuple)):
                # Already-parsed options: no string parsing needed.
                opts_list = [str(opt) for opt in options]
            elif options and options.strip() not in ("", "[]"):
                # Stringified list: pull out every single- or double-quoted item.
                matches = re.findall(r"'(.*?)'|\"(.*?)\"", options)
                opts_list = [single or double for single, double in matches]
            else:
                opts_list = []
            if opts_list:
                opts_str = " / ".join(f"{chr(65 + i)}. {opt}" for i, opt in enumerate(opts_list))
                q = f"{sample['question']} Options: {opts_str}"
            else:
                q = sample['question']
            prompts.append(f"{prompt_prefix}{q}{prompt_suffix}")
        print("✅ Preparing inputs complete")
    except Exception as e:
        # Re-raise: continuing with a partially built prompt list would
        # silently return fewer answers than there are samples.
        logging.error(f"⛔ Error preparing inputs: {e}")
        raise

    try:
        print("\n🚀 Generating")
        conversations = [
            [
                {
                    "role": "user",
                    "content":
                        [{"type": "image", "image": img} for img in sample_images] +
                        [{"type": "text", "text": prompt}]
                }
            ]
            for sample_images, prompt in zip(images_batch, prompts)
        ]

        # Cast floating inputs to the model's own dtype instead of assuming fp16.
        input_dtype = getattr(model, "dtype", None)
        if not (isinstance(input_dtype, torch.dtype) and input_dtype.is_floating_point):
            input_dtype = torch.float16

        inputs = processor.apply_chat_template(
            conversations,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            padding=True,
            return_tensors="pt"
        ).to(model.device, input_dtype)

        outputs = model.generate(
            **inputs,
            min_length=1,
            max_new_tokens=16,
            do_sample=False,
            temperature=None,
            top_p=None,
            num_beams=1,
            eos_token_id=processor.tokenizer.eos_token_id,
            no_repeat_ngram_size=2,
            remove_invalid_values=True,
            num_return_sequences=1
        )

        print("✅ Generating complete")
    except RuntimeError as e:
        # Out-of-memory errors go to the caller untouched; anything else is
        # logged first, then re-raised so no undefined `outputs` is used below.
        if "CUDA out of memory" not in str(e):
            logging.error(f"⛔ Error generating: {e}")
        raise

    try:
        print("\n🚀 Decoding")
        # NOTE: assumes `generate` returns the prompt tokens followed by the
        # newly generated ones (Hugging Face decoder-only behaviour), so the
        # prompt is sliced off and only the answer is decoded.
        prompt_len = inputs["input_ids"].shape[1]
        decoded = processor.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
        print("✅ Decoding complete")
    except Exception as e:
        logging.error(f"⛔ Error decoding: {e}")
        raise

    if preview_mode:
        # The preview is best-effort: a failure here must not lose the answers.
        try:
            print("\n🚀 Previewing")
            preview(batch_samples, decoded, benchmark_adapter.get_postprocess_answer_func())
            print("✅ Previewing complete")
        except Exception as e:
            logging.error(f"⛔ Error previewing: {e}")

    postprocess = benchmark_adapter.get_postprocess_answer_func()
    return [postprocess(ans) for ans in decoded]

def run_inference_llava15(model, processor, prompt_prefix, prompt_suffix, batch_samples, benchmark_adapter, preview_mode):
    """Run one batch through a LLaVA-1.5 model and return post-processed answers.

    Only one image per sample is sent to the model: when a sample carries
    several images a warning is logged and the last one is used.

    Args:
        model: Object exposing ``generate(**inputs)`` and ``device``; its
            ``dtype`` attribute is used for the input cast when it is a
            floating-point ``torch.dtype`` (falls back to ``torch.float16``).
        processor: Object exposing ``apply_chat_template``, ``batch_decode``
            and ``tokenizer.eos_token_id``.
        prompt_prefix: Text placed before every question.
        prompt_suffix: Text placed after every question.
        batch_samples: Iterable of dict-like samples. Each needs
            ``"question"`` and either ``"images"`` (a list of images) or
            ``"image"`` (a single image). ``"options"`` is optional and may
            be a stringified list of quoted items or a real list/tuple.
        benchmark_adapter: Object whose ``get_postprocess_answer_func()``
            returns the callable applied to every decoded answer.
        preview_mode: When truthy, print a preview of the batch via
            ``preview``; preview failures are logged and ignored.

    Returns:
        list: One post-processed answer per sample, in batch order. A decoded
        text without the ``"ASSISTANT: "`` marker yields an empty raw answer.

    Raises:
        Exception: Any error raised while preparing inputs or decoding is
            logged and re-raised.
        RuntimeError: Raised by generation; re-raised as-is (non
            "CUDA out of memory" errors are logged first).
    """
    try:
        print("\n🚀 Preparing inputs")
        images = [
            sample["images"] if "images" in sample else [sample["image"]]
            for sample in batch_samples
        ]
        prompts = []
        for sample in batch_samples:
            options = sample.get("options")
            if isinstance(options, (list, tuple)):
                # Already-parsed options: no string parsing needed.
                opts_list = [str(opt) for opt in options]
            elif options and options.strip() not in ("", "[]"):
                # Stringified list: pull out every single- or double-quoted item.
                matches = re.findall(r"'(.*?)'|\"(.*?)\"", options)
                opts_list = [single or double for single, double in matches]
            else:
                opts_list = []
            if opts_list:
                opts_str = " / ".join(f"{chr(65 + i)}. {opt}" for i, opt in enumerate(opts_list))
                q = f"{sample['question']} Options: {opts_str}"
            else:
                q = sample['question']
            prompts.append(f"{prompt_prefix}{q}{prompt_suffix}")
        print("✅ Preparing inputs complete")
    except Exception as e:
        # Re-raise: continuing with a partially built prompt list would
        # silently return fewer answers than there are samples.
        logging.error(f"⛔ Error preparing inputs: {e}")
        raise

    try:
        print("\n🚀 Generating")
        conversations = []
        for sample_images, prompt in zip(images, prompts):
            if len(sample_images) > 1:
                logging.warning("⚠️ More than one image found for sample, which is not supported by normal llava-1.5 series model, using the last one")
            conversations.append([
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": sample_images[-1]},
                        {"type": "text", "text": prompt}
                    ]
                }
            ])

        # Cast floating inputs to the model's own dtype instead of assuming fp16.
        input_dtype = getattr(model, "dtype", None)
        if not (isinstance(input_dtype, torch.dtype) and input_dtype.is_floating_point):
            input_dtype = torch.float16

        inputs = processor.apply_chat_template(
            conversations,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            padding=True,
            return_tensors="pt"
        ).to(model.device, input_dtype)

        outputs = model.generate(
            **inputs,
            min_length=1,
            max_new_tokens=16,
            do_sample=False,
            temperature=None,
            top_p=None,
            num_beams=1,
            eos_token_id=processor.tokenizer.eos_token_id,
            no_repeat_ngram_size=2,
            remove_invalid_values=True,
            num_return_sequences=1
        )
        print("✅ Generating complete")
    except RuntimeError as e:
        # Out-of-memory errors go to the caller untouched; anything else is
        # logged first, then re-raised so no undefined `outputs` is used below.
        if "CUDA out of memory" not in str(e):
            logging.error(f"⛔ Error generating: {e}")
        raise

    try:
        print("\n🚀 Decoding")
        raw_texts = processor.batch_decode(outputs, skip_special_tokens=True)
        # Keep only what follows the last "ASSISTANT: " marker; a text
        # without the marker counts as an empty answer.
        decoded = [
            text.split("ASSISTANT: ")[-1].strip() if "ASSISTANT: " in text else ""
            for text in raw_texts
        ]
        print("✅ Decoding complete")
    except Exception as e:
        logging.error(f"⛔ Error decoding: {e}")
        raise

    if preview_mode:
        # The preview is best-effort: a failure here must not lose the answers.
        try:
            print("\n🚀 Previewing")
            preview(batch_samples, decoded, benchmark_adapter.get_postprocess_answer_func())
            print("✅ Previewing complete")
        except Exception as e:
            logging.error(f"⛔ Error previewing: {e}")

    postprocess = benchmark_adapter.get_postprocess_answer_func()
    return [postprocess(ans) for ans in decoded]

def run_inference_llama32vision_instruct(model, processor, prompt_prefix, prompt_suffix, batch_samples, benchmark_adapter, preview_mode):
    """Run one batch through a Llama-3.2-Vision instruct model and return post-processed answers.

    Prompts are passed to the processor as plain text (no chat template is
    applied here), together with one list of images per sample.

    Args:
        model: Object exposing ``generate(**inputs)`` and ``device``.
        processor: Callable accepting ``text=`` and ``images=`` that returns
            a mapping with ``input_ids`` and a ``to(device)`` method; it must
            also expose ``batch_decode`` and ``tokenizer.eos_token_id``.
        prompt_prefix: Text placed before every question.
        prompt_suffix: Text placed after every question.
        batch_samples: Iterable of dict-like samples. Each needs
            ``"question"`` and either ``"images"`` (a list of images) or
            ``"image"`` (a single image). ``"options"`` is optional and may
            be a stringified list of quoted items or a real list/tuple.
        benchmark_adapter: Object whose ``get_postprocess_answer_func()``
            returns the callable applied to every decoded answer.
        preview_mode: When truthy, print a preview of the batch via
            ``preview``; preview failures are logged and ignored.

    Returns:
        list: One post-processed answer per sample, in batch order.

    Raises:
        Exception: Any error raised while preparing inputs or decoding is
            logged and re-raised.
        RuntimeError: Raised by generation; re-raised as-is (non
            "CUDA out of memory" errors are logged first).
    """
    try:
        print("\n🚀 Preparing inputs")
        images = [
            sample["images"] if "images" in sample else [sample["image"]]
            for sample in batch_samples
        ]
        prompts = []
        for sample in batch_samples:
            options = sample.get("options")
            if isinstance(options, (list, tuple)):
                # Already-parsed options: no string parsing needed.
                opts_list = [str(opt) for opt in options]
            elif options and options.strip() not in ("", "[]"):
                # Stringified list: pull out every single- or double-quoted item.
                matches = re.findall(r"'(.*?)'|\"(.*?)\"", options)
                opts_list = [single or double for single, double in matches]
            else:
                opts_list = []
            if opts_list:
                opts_str = " / ".join(f"{chr(65 + i)}. {opt}" for i, opt in enumerate(opts_list))
                q = f"{sample['question']} Options: {opts_str}"
            else:
                q = sample['question']
            prompts.append(f"{prompt_prefix}{q}{prompt_suffix}")
        print("✅ Preparing inputs complete")
    except Exception as e:
        # Re-raise: continuing with a partially built prompt list would
        # hand the processor mismatched text/image batches.
        logging.error(f"⛔ Error preparing inputs: {e}")
        raise

    try:
        print("\n🚀 Generating")

        inputs = processor(
            text=prompts,
            images=images,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
            return_token_type_ids=False
        ).to(model.device)
        print("processor outputs:", {k: v.shape for k, v in inputs.items()})

        outputs = model.generate(
            **inputs,
            min_length=1,
            max_new_tokens=16,
            do_sample=False,
            temperature=None,
            top_p=None,
            num_beams=1,
            eos_token_id=processor.tokenizer.eos_token_id,
            no_repeat_ngram_size=2,
            remove_invalid_values=True,
            num_return_sequences=1
        )
        print("✅ Generating complete")
    except RuntimeError as e:
        # Out-of-memory errors go to the caller untouched; anything else is
        # logged first, then re-raised so no undefined `outputs` is used below.
        if "CUDA out of memory" not in str(e):
            logging.error(f"⛔ Error generating: {e}")
        raise

    try:
        print("\n🚀 Decoding")
        # NOTE: assumes `generate` returns the prompt tokens followed by the
        # newly generated ones (Hugging Face decoder-only behaviour), so the
        # prompt is sliced off and only the answer is decoded.
        prompt_len = inputs["input_ids"].shape[1]
        decoded = processor.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
        print("✅ Decoding complete")
    except Exception as e:
        logging.error(f"⛔ Error decoding: {e}")
        raise

    if preview_mode:
        # The preview is best-effort: a failure here must not lose the answers.
        try:
            print("\n🚀 Previewing")
            preview(batch_samples, decoded, benchmark_adapter.get_postprocess_answer_func())
            print("✅ Previewing complete")
        except Exception as e:
            logging.error(f"⛔ Error previewing: {e}")

    postprocess = benchmark_adapter.get_postprocess_answer_func()
    return [postprocess(ans) for ans in decoded]