import os
import re
import json
import numpy as np
from PIL import Image
from tqdm import tqdm
from datasets import load_dataset
from vllm import LLM, SamplingParams

from src.models import get_model
from src.utils.parser_utils import get_parser
from src.prompt import qa_prompt, qa_context_prompt, qa_image_prompt, qa_blend_prompt

def clean_answer(options, answer):
    """Extract the selected option label from a raw model answer.

    The answer is first searched for a capital letter ``A``-``D`` that either
    stands alone as a word or is directly followed by whitespace or a colon.
    If no such letter is found, the label of the first option whose text
    occurs verbatim in the answer is returned instead.

    Args:
        options: Mapping of option label to option text; only consulted when
            no letter can be found in the answer.
        answer: The answer string, or a list/tuple whose first element is the
            answer string.

    Returns:
        The matched letter or option label, or ``None`` when nothing matches
        or when ``answer`` is empty or not a string.
    """
    if isinstance(answer, (list, tuple)):
        if not answer:
            # Nothing was generated, so there is nothing to parse.
            return None
        answer = answer[0]
    if not isinstance(answer, str):
        # e.g. ``None`` for a missing generation; report "no answer found"
        # instead of letting ``re.search`` raise a TypeError.
        return None

    match = re.search(r"\b[A-D]\b|[A-D](?=\s|:)", answer)
    if match is not None:
        return match.group()

    for option, content in options.items():
        # An empty option text is a substring of every answer, so it must not
        # count as a match.
        if content and content in answer:
            return option
    return None
    
def get_answer_and_prob(options, pred):
    """Return the option letter chosen by a prediction and its score.

    Args:
        options: Mapping of option label to option text, forwarded to
            ``clean_answer``.
        pred: Indexable pair where ``pred[0]`` is the raw answer and
            ``pred[1]`` is a sequence of scores indexed by option position
            (index 0 is ``A``, index 1 is ``B``, ...).

    Returns:
        ``(answer, prob)``. When a label can be parsed from the raw answer,
        ``prob`` is the score at that letter's position; otherwise the
        highest-scoring position is used for both the letter and the score.
    """
    letter = clean_answer(options, pred[0])
    scores = pred[1]
    if letter is not None:
        return letter, scores[ord(letter) - ord("A")]

    # No parseable label in the text: fall back to the arg-max option.
    best_prob = max(scores)
    best_letter = chr(ord("A") + np.argmax(scores))
    return best_letter, best_prob

def _load_viquae_dataset(dataset_name):
    """Load the dataset variant selected by substrings of ``dataset_name``."""
    if "mc" in dataset_name:
        if "cleaned" in dataset_name:
            path = "data/viquae/cleaned_dataset_mc.json"
        else:
            path = "data/viquae/multiple_choice_data.json"
        with open(path, "r") as fin:
            return json.load(fin)
    if "full" in dataset_name:
        splits = load_dataset("PaulLerner/viquae_dataset")
        return [row for name in ("train", "validation", "test") for row in splits[name]]
    if "clean" in dataset_name:
        with open("data/viquae/cleaned_dataset.json", "r") as fin:
            return json.load(fin)
    return load_dataset("PaulLerner/viquae_dataset")["train"]


def _load_predictions(path):
    """Merge a file of one-JSON-object-per-line predictions into one dict."""
    preds = {}
    with open(path, "r") as fin:
        for line in fin:
            if line.strip():  # tolerate blank / trailing empty lines
                preds.update(json.loads(line))
    return preds


def _append_result(output_path, data_id, answer):
    """Append ``{data_id: answer}`` as one JSON line to ``output_path``."""
    # Re-opened per sample so results already written survive a crash.
    with open(output_path, "a") as fout:
        fout.write(f"{json.dumps({data_id: answer})}\n")


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it is missing."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _build_question_text(question, choices):
    """Format the question, followed by its options when there are any."""
    if choices is None:
        return f"Question:\n{question}"
    choices_text = "".join(f"{name}: {content}\n" for name, content in choices.items())
    return f"Question:\n{question}\nOption:\n{choices_text}"


def _add_conflict_prompt(text, mode, text_answer, text_prob, visual_answer, visual_prob):
    """Return ``text`` with the conflict hint selected by ``mode`` attached."""
    if mode == "fixed":
        return "Mind that you may have knowledge conflict in this question\n" + text
    if mode == "answer":
        conflict_prompt = f"Your textual parametric knowledge is: {text_answer} and your visual parametric knowledge is: {visual_answer}."
        return text + conflict_prompt
    if mode == "prob":
        conflict_prompt = f"\nCurrently, you are having some conflicts in your memory. Your textual memory is: {text_answer} with a probability of {text_prob} and your visual memory is: {visual_answer} with a probability of {visual_prob}. You should decide which memory you are more inclined to believe is true or you can choose not to believe in any of them. You can follow the following rules: \nIf the question is about date or location, textual memory is more reliable.\nIf the question is about name or color, visual memory is more reliable. Mind that you have to give an answer. "
        return text + conflict_prompt
    # No --conflict_prompt given: leave the question untouched.
    return text


def _run_post_hoc(args, dataset, output_path):
    """Re-ask every question, hinting at text/visual answer disagreements."""
    # Imported here because the module's import block does not bring in torch.
    import torch

    # NOTE(review): these two files live directly under outputs/analysis/,
    # while the "prob" method reads its textual file from
    # outputs/analysis/viquae/ - confirm both locations are intended.
    text_preds = _load_predictions("outputs/analysis/viquae_mc_textual_caption_reorganized_llava_T0.0.txt.score")
    visual_preds = _load_predictions("outputs/analysis/viquae_mc_visual_llava_T0.0.txt.score")

    model = LLM(
        model=args.model_name,
        gpu_memory_utilization=0.8,
        tensor_parallel_size=torch.cuda.device_count(),
    )
    _ensure_parent_dir(output_path)

    is_multiple_choice = "mc" in args.dataset
    for data in tqdm(dataset):
        data_id = data["id"]
        # Non multiple-choice samples have no option table; an empty mapping
        # lets the answer parser rely on letters alone.
        choices = data["multiple_choices"] if is_multiple_choice else {}
        text = _build_question_text(data["input"], choices if is_multiple_choice else None)

        text_pred = text_preds.get(data_id)
        visual_pred = visual_preds.get(data_id)
        if text_pred is None or visual_pred is None:
            # Same policy as the "prob" method: samples without both
            # predictions are skipped.
            continue
        text_answer, text_prob = get_answer_and_prob(choices, text_pred)
        visual_answer, visual_prob = get_answer_and_prob(choices, visual_pred)
        if text_answer != visual_answer:
            text = _add_conflict_prompt(
                text, args.conflict_prompt, text_answer, text_prob, visual_answer, visual_prob
            )

        # Context manager closes the image file handle after each sample.
        with Image.open(os.path.join("data/viquae/images", data["image"])) as image:
            context = {"text": text, "image": image}
            # NOTE(review): the text=/image= keywords look like the project
            # wrapper's chat interface (get_model) rather than vllm.LLM.chat,
            # which appears to expect a `messages` list - confirm which model
            # object is intended. args.temperature is also not passed on here.
            answer = model.chat(**context)
        _append_result(output_path, data_id, answer)


def _run_prob(args, dataset, output_path):
    """Pick between text and visual predictions from their scores alone."""
    text_preds = _load_predictions("outputs/analysis/viquae/viquae_mc_textual_caption_reorganized_llava_T0.0.txt.score")
    visual_preds = _load_predictions("outputs/analysis/elicit_viquae_mc_llava_visual_KL-1_alpha0.0_beta0.txt.score")
    _ensure_parent_dir(output_path)

    for data in tqdm(dataset):
        data_id = data["id"]
        text_pred = text_preds.get(data_id)
        visual_pred = visual_preds.get(data_id)
        if text_pred is None or visual_pred is None:
            continue
        text_prob = text_pred[1]
        visual_prob = visual_pred[1]
        # Plain ints keep the index arithmetic and chr() independent of numpy
        # scalar types.
        text_index = int(np.argmax(text_prob))
        visual_index = int(np.argmax(visual_prob))

        if text_index == visual_index:
            chosen = text_index
        elif args.prob_method == "compare":
            # Margin by which each source prefers its own top option over the
            # other source's top option; ties go to the visual prediction.
            text_margin = text_prob[text_index] - text_prob[visual_index]
            visual_margin = visual_prob[visual_index] - visual_prob[text_index]
            chosen = text_index if text_margin > visual_margin else visual_index
        elif args.prob_method == "max":
            # Ties go to the visual prediction.
            chosen = text_index if text_prob[text_index] > visual_prob[visual_index] else visual_index
        else:
            # Disagreement but no --prob_method given: nothing is written.
            continue

        # The letter is built outside the f-string: reusing double quotes
        # inside a replacement field only parses on Python 3.12+.
        _append_result(output_path, data_id, chr(ord("A") + chosen))


def main():
    """Run the knowledge-conflict experiment chosen on the command line.

    Extends the project parser with ``--method`` (``post_hoc`` or ``prob``),
    ``--conflict_prompt`` and ``--prob_method``. It also reads ``dataset``,
    ``model_name``, ``output_dir``, ``temperature`` and ``greedy`` from the
    parsed arguments; those are presumably defined by ``get_parser()``.

    * ``post_hoc``: queries the model for each sample, adding a conflict hint
      to the question when the stored text and visual answers differ.
    * ``prob``: chooses between the stored text and visual predictions using
      their scores, without running a model.

    Results are appended to a file under ``output_dir`` as one JSON object
    per line. Nothing is run when ``--method`` is omitted.
    """
    parser = get_parser()
    parser.add_argument("--method", choices=["post_hoc", "prob"])
    parser.add_argument("--conflict_prompt", choices=["fixed", "answer", "prob"])
    parser.add_argument("--prob_method", choices=["max", "compare"])
    args = parser.parse_args()
    if args.greedy:
        args.temperature = 0.0

    dataset = _load_viquae_dataset(args.dataset)

    output_path = os.path.join(args.output_dir, f"{args.dataset}_{args.model_name}_CD_{args.method}_{args.conflict_prompt}_{args.prob_method}_T{args.temperature}.txt")

    if args.method == "post_hoc":
        _run_post_hoc(args, dataset, output_path)
    elif args.method == "prob":
        _run_prob(args, dataset, output_path)

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main() 