import os
import json
import torch
from PIL import Image
from tqdm import tqdm
from datasets import load_dataset

from src.models import get_model
from src.utils.parser_utils import get_parser
from src.prompt import qa_prompt, qa_context_prompt, qa_image_prompt, qa_blend_prompt

# Paraphrased instructions used by the "prompt" uncertainty method: main()
# assigns each entry in turn to `model.prompt` and records one answer per
# entry for every sample.
# NOTE(review): the last entry ends mid-sentence ("...or further") and looks
# truncated. Confirm before editing: the text is sent to the model, so any
# change alters the experiment's inputs.
PROMPTS = [
    "As a specialist in answering queries, please provide the direct answer to the question without any additional explanation or follow-up questions.",
    "Please use your expertise in question answering to give the answer directly. Avoid offering any explanations or further inquiries.",
    "Relying on your knowledge of answering questions, output the answer succinctly without including any extra information or asking further questions.",
    "With your proficiency in question answering, provide just the answer to the given question. Do not include any explanations or follow-up questions.",
    "As a question-answering expert, your task is to give the answer only. Refrain from offering any explanations or asking subsequent questions.",
    "Given your expertise in providing answers, please offer the answer directly to the question, and do not include further explanations or additional questions.",
    "Use your proficiency in answering questions to supply only the answer. Avoid explaining or posing any further questions.",
    "Utilizing your skill in question answering, please deliver the answer right away. No explanations or further questions should be included.",
    "With your experience in giving direct answers, provide the answer to the question without elaboration or posing additional questions.",
    "As an adept in question answering, please produce only the answer to the question and omit any explanations or further",
]

def enable_dropout(model):
    """Put every dropout sub-module of *model* into training mode.

    A sub-module counts as dropout when its class name begins with
    ``Dropout``; all other sub-modules are left untouched, so dropout can
    stay active at test time without switching the whole model.
    """
    dropout_layers = (
        layer
        for layer in model.modules()
        if layer.__class__.__name__.startswith('Dropout')
    )
    for layer in dropout_layers:
        layer.train()

def monte_carlo_predictions(model, inputs, n_samples):
    """Run *n_samples* gradient-free forward passes and concatenate the logits.

    The model is switched to training mode first (and left in it) so that
    dropout stays active across the passes.

    Args:
        model: Callable as ``model(**inputs)``; the result must expose
            ``.logits``.
        inputs: Keyword arguments forwarded to every forward pass.
        n_samples: Number of forward passes to run.

    Returns:
        A tensor with one leading entry per pass: each pass's logits gain a
        new first dimension and the passes are concatenated along it.
    """
    model.train()  # training mode keeps dropout layers stochastic
    with torch.no_grad():
        sampled_logits = [
            model(**inputs).logits.unsqueeze(0) for _ in range(n_samples)
        ]
    return torch.cat(sampled_logits)

def _read_jsonl_dict(path):
    """Merge the JSON object found on each line of *path* into one dict."""
    merged = {}
    with open(path, "r") as fin:
        for line in fin:
            merged.update(json.loads(line))
    return merged


def _load_viquae_dataset(dataset_name):
    """Load the samples selected by the substrings present in *dataset_name*.

    "mc" selects the local multiple-choice JSON files ("cleaned" picks the
    cleaned variant). Otherwise "full" concatenates the train, validation and
    test splits returned by ``load_dataset``, "clean" reads the local cleaned
    JSON file, and any other name falls back to the train split.
    """
    if "mc" in dataset_name:
        if "cleaned" in dataset_name:
            path = "data/viquae/cleaned_dataset_mc.json"
        else:
            path = "data/viquae/multiple_choice_data.json"
        with open(path, "r") as fin:
            return json.load(fin)
    if "full" in dataset_name:
        splits = load_dataset("PaulLerner/viquae_dataset")
        return [
            sample
            for split_name in ("train", "validation", "test")
            for sample in splits[split_name]
        ]
    if "clean" in dataset_name:
        with open("data/viquae/cleaned_dataset.json", "r") as fin:
            return json.load(fin)
    return load_dataset("PaulLerner/viquae_dataset")["train"]


def _resolve_mode(dataset_name):
    """Return the input mode encoded in *dataset_name*.

    Keywords are checked in a fixed order: "textual" first, then "visual",
    "recognize", "blend", "blank" and "pad".

    Raises:
        ValueError: If the name contains none of the mode keywords.
    """
    if "textual" in dataset_name:
        reorganized = (
            "caption" in dataset_name
            and "prompted" not in dataset_name
            and "reorganized" in dataset_name
        )
        return "textual_reorganized" if reorganized else "textual"
    for candidate in ("visual", "recognize", "blend", "blank", "pad"):
        if candidate in dataset_name:
            return candidate
    raise ValueError(
        f"dataset name {dataset_name!r} selects no input mode; it must contain "
        "one of: textual, visual, recognize, blend, blank, pad"
    )


def _load_captions(dataset_name):
    """Load the caption table matching the caption variant in *dataset_name*.

    "prompted" wins over "reorganized"; the reorganized variant wraps each
    stored value in a "This is an image of ..." sentence.
    """
    if "prompted" in dataset_name:
        return _read_jsonl_dict("outputs/caption_prompted_viquae_mc_llava_T0.0.txt")
    if "reorganized" in dataset_name:
        recognized = _read_jsonl_dict(
            "outputs/analysis/viquae/viquae_recognize_llava_T0.0.txt"
        )
        return {key: f"This is an image of {value}." for key, value in recognized.items()}
    return _read_jsonl_dict("outputs/caption_viquae_mc_llava_T0.0.txt")


def _build_context(model, mode, text, data, is_scored):
    """Build the keyword arguments for one ``model.chat`` call.

    Textual modes pass only the text; "visual", "blend" and "recognize" open
    the sample image; "blank" and "pad" pass a white 336x336 image. "pad"
    additionally sets ``model.mode`` to "zero_padding".
    """
    if "textual" in mode:
        context = {"text": text}
    elif mode == "recognize":
        image = Image.open(os.path.join("data/viquae/images", data["image"]))
        context = {"text": "What/Who is in the image? Do not describe details. Just give a named entity, e.g. Jackie Chan.", "image": image}
    elif mode in ("visual", "blend"):
        image = Image.open(os.path.join("data/viquae/images", data["image"]))
        context = {"text": text, "image": image}
    else:
        # Only "blank" and "pad" remain once the mode has been validated.
        if mode == "pad":
            model.mode = "zero_padding"
        context = {"text": text, "image": Image.new('RGB', (336, 336), color=(255, 255, 255))}
    context["is_scored"] = is_scored
    return context


def _enable_attention_dropout(model):
    """Switch the wrapped model to train mode with attention dropout 0.1."""
    model.model.train()
    model.model.training = True
    for layer in model.model.language_model.model.layers:
        layer.self_attn.attention_dropout = 0.1


def main():
    """Collect repeated answers per sample for uncertainty estimation.

    Parses the command line (the project parser plus ``--uncertainty_method``
    and ``--is_scored``), loads the dataset, captions and model selected by
    substrings of ``args.dataset``, and for every sample appends one JSON line
    ``{sample_id: [answers, ...]}`` to the output file:

    * ``prompt``: one answer per instruction in ``PROMPTS``.
    * ``dropout``: ten answers with the model in train mode and attention
      dropout set to 0.1.

    Invalid configurations are rejected through the argument parser before
    any data or model is loaded: a missing uncertainty method, a dataset name
    that selects no input mode, or a caption dataset outside textual mode.
    """
    parser = get_parser()
    # Required: without a method the sample loop would run and write nothing.
    parser.add_argument(
        "--uncertainty_method", choices=["prompt", "dropout"], required=True
    )
    parser.add_argument("--is_scored", action="store_true")
    args = parser.parse_args()

    if args.greedy:
        args.temperature = 0.0

    # Validate the dataset name up front so a typo fails with a usage message
    # instead of an unbound-variable crash after the model has been loaded.
    try:
        mode = _resolve_mode(args.dataset)
    except ValueError as err:
        parser.error(str(err))
    use_captions = "caption" in args.dataset
    if use_captions and "textual" not in mode:
        parser.error(
            f"caption datasets are only supported in textual mode; got {args.dataset!r}"
        )

    # "recognize" builds the model with an empty prompt; every other mode
    # uses the shared QA prompt.
    prompt = "" if mode == "recognize" else qa_prompt

    dataset = _load_viquae_dataset(args.dataset)
    captions = _load_captions(args.dataset) if use_captions else None

    output_path = os.path.join(args.output_dir, f"{args.dataset}_{args.model_name}_{args.uncertainty_method}_T{args.temperature}.txt.all")
    if args.is_scored:
        output_path += ".score"

    model = get_model(args)(args, prompt=prompt)
    if "textual" in mode:
        model.remode("text")

    for data in tqdm(dataset):
        data_id = data["id"]
        if mode in ("visual", "textual_reorganized"):
            question = data["input"]
        else:
            question = data["original_question"]

        if "mc" in args.dataset:
            choices_text = "".join(
                f"{c_name}: {c_content}\n"
                for c_name, c_content in data["multiple_choices"].items()
            )
            text = f"Question:\n{question}\nOption:\n{choices_text}"
        else:
            text = f"Question:\n{question}"

        if captions is not None:
            caption = captions.get(data_id)
            if caption is None:
                caption = ""
            text = caption + "\n" + text

        answers = []
        if args.uncertainty_method == "prompt":
            for instruction in PROMPTS:
                model.prompt = instruction
                context = _build_context(model, mode, text, data, args.is_scored)
                answers.append(model.chat(**context))
        else:
            # "dropout" is the only other value the parser accepts.
            # Re-applied for every sample. NOTE(review): hoist out of the loop
            # only after confirming model.chat never switches the model back
            # to eval mode. The file's enable_dropout helper is not used here;
            # the whole wrapped model is put in train mode instead.
            _enable_attention_dropout(model)
            for _ in range(10):
                context = _build_context(model, mode, text, data, args.is_scored)
                answers.append(model.chat(**context))

        # Append per sample so finished results survive an interrupted run.
        with open(output_path, "a+") as fout:
            fout.write(f"{json.dumps({data_id: answers})}\n")

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()