import os
import torch
import random
import numpy as np
import pandas as pd

from collections import defaultdict
from types import SimpleNamespace
from torchvision import transforms
from PIL import Image
from .read_utils import transform_image_path
from modules.prompt import (qa_prompt, qa_context_prompt, qa_image_prompt, 
qa_blend_prompt, paraphrase, generate_unrelated_distraction_prompts, RandomImageIterator, blip_example_format, llava_example_format)


# Three-element image normalization statistics (mean, then standard deviation).
# NOTE(review): these values appear to be the per-channel RGB statistics used by
# CLIP-style image encoders — confirm against the vision backbone in use.
image_mean = [
    0.48145466,
    0.4578275,
    0.40821073,
]
image_std = [
    0.26862954,
    0.26130258,
    0.27577711,
]

def text_aug_for_ImageHeavy(batch, dataset_nickname,
                            samples=('related_text', 'origin', 'unrelated_text')):
    """Build text-perturbed variants of each multiple-choice VQA item.

    For every augmentation name in ``samples`` (outer loop) and every item in
    ``batch`` (inner loop) a prompt ``"Question:\\n{question}\\nOption:\\n{choices}"``
    is built and paired with the item's image opened from
    ``data/{dataset_nickname}/images``.

    Supported augmentations:
      * ``'related_text'``   -- prepend one randomly chosen value of
        ``data["facts"]`` plus a space to the question.
      * ``'origin'``         -- the unmodified question.
      * ``'unrelated_text'`` -- prepend ``generate_unrelated_distraction_prompts()``
        to the question (no separator is inserted).
      * ``'GCG'``            -- append the per-``id`` suffix read from
        ``data/{dataset_nickname}/perturb_GCG_K5L16.jsonl`` to the full prompt.
      * ``'TextFooler'``     -- not implemented.

    Args:
        batch: iterable of dicts with keys ``"question"``, ``"multiple_choices"``
            (mapping option name -> option text) and ``"image"`` (file name);
            ``"facts"`` (mapping) is needed for ``'related_text'`` and ``"id"``
            for ``'GCG'``.
        dataset_nickname: dataset directory name under ``data/``.
        samples: augmentation names to generate, in order. The default is the
            set the function has always produced.

    Returns:
        ``defaultdict(list)`` mapping augmentation name -> list of
        ``{"text": str, "image": Image}`` dicts, one per batch item.

    Raises:
        ValueError: an unknown augmentation name is requested, or an item has
            no non-empty ``"facts"`` for ``'related_text'``.
        NotImplementedError: ``'TextFooler'`` is requested.
    """
    known = ('related_text', 'origin', 'unrelated_text', 'GCG', 'TextFooler')
    for sample in samples:
        if sample not in known:
            raise ValueError(f"unknown augmentation: {sample!r}")
        if sample == 'TextFooler':
            raise NotImplementedError("TextFooler augmentation is not implemented")

    # The GCG suffix table is only needed for the 'GCG' augmentation. Reading it
    # unconditionally cost a file read per call and failed when the file was
    # absent, even though 'GCG' was never generated by default.
    id_gcg_map = {}
    if 'GCG' in samples:
        gcg_adv_data = pd.read_json(f"data/{dataset_nickname}/perturb_GCG_K5L16.jsonl", lines=True)
        id_gcg_map = dict(zip(gcg_adv_data['id'], gcg_adv_data['perturbed_GCG_suffix']))

    image_dir = f"data/{dataset_nickname}/images"
    aug_contexts = defaultdict(list)
    # Loop order (augmentation outer, item inner) is kept so the sequence of
    # random draws matches the previous implementation.
    for sample in samples:
        for data in batch:
            question = data["question"]
            suffix = ""
            if sample == 'related_text':
                interfere_facts = list((data.get("facts") or {}).values())
                if not interfere_facts:
                    raise ValueError(
                        f"item {data.get('id')!r} has no 'facts' for the 'related_text' augmentation"
                    )
                question = np.random.choice(interfere_facts) + ' ' + question
            elif sample == 'unrelated_text':
                question = generate_unrelated_distraction_prompts() + question
            elif sample == 'GCG':
                suffix = id_gcg_map[data["id"]]

            choices_text = "".join(
                f"{c_name}: {c_content}\n" for c_name, c_content in data["multiple_choices"].items()
            )
            text = f"Question:\n{question}\nOption:\n{choices_text}" + suffix
            # Opened per augmentation so contexts never share an Image object.
            image = Image.open(os.path.join(image_dir, data["image"]))
            aug_contexts[sample].append({"text": text, "image": image})
    return aug_contexts

def image_aug_for_TextHeavy(batch, dataset_nickname, model_nickname=None, image_size=336):
    """Pair each item's unchanged prompt with replaced/synthetic images.

    For each augmentation in ``('random', 'switch', 'full_black', 'full_white')``
    (outer loop) and each item in ``batch`` (inner loop), the prompt
    ``"Question:\\n{question}\\nOption:\\n{choices}"`` is paired with:

      * ``'random'``     -- an RGB uint8 image of uniform noise in [0, 255],
      * ``'switch'``     -- the image at the path returned by
        ``image_loader.get_random()`` (presumably an unrelated image — confirm
        against ``RandomImageIterator``),
      * ``'full_black'`` -- a solid black RGB image,
      * ``'full_white'`` -- a solid white RGB image.

    Args:
        batch: iterable of dicts with ``"question"`` and ``"multiple_choices"``
            (mapping option name -> option text).
        dataset_nickname: not used by this function; kept so all augmentation
            builders share a signature.
        model_nickname: not used by this function; kept for compatibility.
        image_size: side length in pixels of the generated square images
            (``'random'``, ``'full_black'``, ``'full_white'``). The default of
            336 preserves the previously hard-coded size.

    Returns:
        ``defaultdict(list)`` mapping augmentation name -> list of
        ``{"text": str, "image": Image}`` dicts, one per batch item.
    """
    aug_contexts = defaultdict(list)
    image_loader = RandomImageIterator("image_data_loader...")

    # The prompt text does not depend on the augmentation, so build it once per
    # item instead of once per (augmentation, item) pair.
    texts = []
    for data in batch:
        choices_text = "".join(
            f"{c_name}: {c_content}\n" for c_name, c_content in data["multiple_choices"].items()
        )
        texts.append(f"Question:\n{data['question']}\nOption:\n{choices_text}")

    size = (image_size, image_size)
    # Augmentation outer / item inner, as before, so random draws happen in the
    # same order.
    for sample in ('random', 'switch', 'full_black', 'full_white'):
        for text in texts:
            if sample == 'random':
                noise = np.random.randint(0, 256, (image_size, image_size, 3), dtype=np.uint8)
                image = Image.fromarray(noise)
            elif sample == 'switch':
                image = Image.open(image_loader.get_random())
            elif sample == 'full_black':
                image = Image.new('RGB', size, color='black')
            else:  # 'full_white'
                image = Image.new('RGB', size, color='white')
            aug_contexts[sample].append({"text": text, "image": image})
    return aug_contexts

def vqa_origin_format(batch, dataset_nickname):
    """Wrap every item as an unperturbed ('origin') prompt/image pair.

    Args:
        batch: iterable of dicts with ``"question"``, ``"multiple_choices"``
            (mapping option name -> option text) and ``"image"``.
        dataset_nickname: passed, together with ``item["image"]``, to
            ``transform_image_path`` to obtain the image file path.

    Returns:
        ``defaultdict(list)`` whose ``'origin'`` entry holds one
        ``{"text": str, "image": Image}`` dict per item (the key is absent when
        ``batch`` is empty).
    """
    formatted = defaultdict(list)
    for item in batch:
        option_lines = [
            f"{name}: {content}\n"
            for name, content in item["multiple_choices"].items()
        ]
        prompt = "Question:\n{}\nOption:\n{}".format(item["question"], "".join(option_lines))

        picture = Image.open(transform_image_path(dataset_nickname, item["image"]))
        formatted['origin'].append({"text": prompt, "image": picture})
    return formatted

def construct_icl_prompt_for_image_heavy_with_answer(batch, dataset_nickname, model_nickname):
    """Construct ICL-style prompts for image-heavy tasks, including answers.

    For each variant in ``('related_text', 'origin', 'unrelated_text')`` (outer
    loop) and each item in ``batch`` (inner loop), three renderings of the
    item's question are built:

      * origin         -- ``"Question:\\n{question}\\nOption:\\n{choices}\\n"``
      * related_text   -- the same, with one randomly chosen ``data["facts"]``
        value and a space prepended to the question. If the item has no facts,
        the plain question is used.
      * unrelated_text -- the same, with ``generate_unrelated_distraction_prompts()``
        prepended to the question (no separator is inserted).

    The query ``text`` is the variant being generated. ``icl_context`` holds the
    other two renderings formatted as examples whose answer is the item's own
    ``multiple_choices_answer``.

    Args:
        batch: iterable of dicts with ``"question"``, ``"multiple_choices"``
            (mapping option name -> option text), ``"multiple_choices_answer"``
            and ``"image"``; ``"facts"`` (mapping) is optional.
        dataset_nickname: images are opened from ``data/{dataset_nickname}/images``,
            falling back to ``data/{dataset_nickname}/test`` when the file is
            not found there.
        model_nickname: selects the example template; must contain ``"llava"``
            or ``"blip"``.

    Returns:
        ``defaultdict(list)`` mapping variant name -> list of
        ``{'text', 'image', 'icl_context'}`` dicts, one per batch item.

    Raises:
        ValueError: ``model_nickname`` matches no known template.
    """
    if "llava" in model_nickname:
        example_format = "Example:\nUSER: {text}ASSISTANT: {answer}\n"
    elif "blip" in model_nickname:
        # Was "{user_input}", which never matched the ``text=`` keyword passed to
        # ``.format`` below, so every BLIP call raised KeyError.
        example_format = "Example:\n{text}Answer: {answer}\n"
    else:
        # Previously fell through, leaving ``example_format`` unbound.
        raise ValueError(
            f"unsupported model_nickname {model_nickname!r}: expected it to contain 'llava' or 'blip'"
        )

    # For each query variant: the other two variants used as ICL examples, in
    # prompt order (same order as the original per-branch lists).
    example_order = {
        'related_text': ('origin', 'unrelated_text'),
        'origin': ('related_text', 'unrelated_text'),
        'unrelated_text': ('origin', 'related_text'),
    }

    aug_contexts = defaultdict(list)
    for sample in ('related_text', 'origin', 'unrelated_text'):
        for data in batch:
            question = data["question"]
            gold_answer = data["multiple_choices_answer"]
            choices_text = "".join(
                f"{c_name}: {c_content}\n" for c_name, c_content in data["multiple_choices"].items()
            )

            # Only a missing file triggers the fallback directory; other errors
            # (corrupt image, KeyboardInterrupt, ...) are no longer swallowed.
            try:
                image = Image.open(os.path.join(f"data/{dataset_nickname}/images", data["image"]))
            except FileNotFoundError:
                image = Image.open(os.path.join(f"data/{dataset_nickname}/test", data["image"]))

            interfere_facts = list((data.get("facts") or {}).values())
            if interfere_facts:
                related_question = np.random.choice(interfere_facts) + ' ' + question
            else:
                # No fact to inject. The old code left the variable unbound or
                # silently reused the previous item's related question here.
                related_question = question
            unrelated_question = generate_unrelated_distraction_prompts() + question

            variants = {
                'origin': f"Question:\n{question}\nOption:\n{choices_text}\n",
                'related_text': f"Question:\n{related_question}\nOption:\n{choices_text}\n",
                'unrelated_text': f"Question:\n{unrelated_question}\nOption:\n{choices_text}\n",
            }
            icl_context = ''.join(
                example_format.format(text=variants[name], answer=gold_answer)
                for name in example_order[sample]
            )
            aug_contexts[sample].append(
                {'text': variants[sample], 'image': image, 'icl_context': icl_context}
            )

    return aug_contexts
