import sys
sys.path.append('..')

import torch
from transformers import TextStreamer

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

from PIL import Image
from tqdm import tqdm
import argparse

from utils.data import get_data, show_data
from utils.external_retrieval import get_matching_urls, get_summary
from utils.compile_results import add_result
from utils.prompts import initial_prompt, round1_prompt, debate_prompt, web_access_prompt

def get_final_prediction(num_models, model_responses):
    """Combine the per-model verdicts into a single majority-vote prediction.

    Args:
        num_models: Number of models; entries ``0 .. num_models - 1`` of
            ``model_responses`` are read.
        model_responses: Mapping from model index to a dict holding the
            model's ``"falsified"`` verdict and its ``"output"`` text.

    Returns:
        A ``(final_pred, final_outputs)`` tuple. ``final_pred`` is ``True``
        or ``False`` when that verdict strictly outnumbers both other
        tallies, and the string ``"Unsure"`` otherwise (including ties).
        ``final_outputs`` maps ``"model_<i>"`` to that model's output.
    """
    tally = {"true": 0, "false": 0, "unsure": 0}
    final_outputs = {}

    for idx in range(num_models):
        response = model_responses[idx]
        verdict = response["falsified"]
        # Equality (not identity) is used on purpose so the comparison
        # semantics stay the same; anything that is neither True nor False
        # (e.g. "Unsure" or an empty placeholder) is an unsure vote.
        if verdict == True:  # noqa: E712
            bucket = "true"
        elif verdict == False:  # noqa: E712
            bucket = "false"
        else:
            bucket = "unsure"
        tally[bucket] += 1
        final_outputs[f"model_{idx}"] = response["output"]

    # A verdict wins only with a strict plurality over both alternatives.
    if tally["true"] > max(tally["false"], tally["unsure"]):
        return True, final_outputs
    if tally["false"] > max(tally["true"], tally["unsure"]):
        return False, final_outputs
    return "Unsure", final_outputs

def get_conv_and_roles(model_name, conv_mode, num_models=None):
    """Create a fresh conversation and a role pair for every debating model.

    Args:
        model_name: Model name; only checked (case-insensitively) for the
            substring ``"mpt"``.
        conv_mode: Key into ``conv_templates`` selecting the template to copy.
        num_models: Number of conversations to create. When omitted, the
            value is read from the module-level ``args.num_models`` that the
            command-line entry point sets, so existing two-argument callers
            keep working. Pass it explicitly when calling this function from
            code that imports the module, where that global does not exist.

    Returns:
        A ``(conv, roles)`` tuple of two lists with ``num_models`` entries
        each: independent copies of the selected template, and for each one
        either ``('user', 'assistant')`` (MPT models) or the template's own
        ``roles``.
    """
    if num_models is None:
        # Backward-compatible fallback to the CLI namespace.
        num_models = args.num_models

    # Every model gets its own copy so the conversations evolve separately.
    conv = [conv_templates[conv_mode].copy() for _ in range(num_models)]

    if "mpt" in model_name.lower():
        roles = [('user', 'assistant') for _ in conv]
    else:
        roles = [conversation.roles for conversation in conv]
    return conv, roles


def generate_output(i, conv, models, image_tensor, temperature, image_size, max_new_tokens):
    """Generate model ``i``'s reply to its current conversation prompt.

    Args:
        i: Index of the model/conversation to run.
        conv: List of conversation objects; ``conv[i].get_prompt()`` supplies
            the prompt text.
        models: List of dicts with at least ``'tokenizer'`` and ``'model'``.
        image_tensor: Pre-processed image input forwarded as ``images``.
        temperature: Sampling temperature; sampling is enabled only when it
            is greater than zero.
        image_size: Size of the source image, forwarded as a one-element
            ``image_sizes`` list.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded first sequence returned by ``generate``, with
        surrounding whitespace stripped.
    """
    entry = models[i]
    tokenizer = entry['tokenizer']
    model = entry['model']

    prompt = conv[i].get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    # Add a batch dimension and move the ids to the model's device.
    input_ids = input_ids.unsqueeze(0).to(model.device)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=[image_size],
            do_sample=temperature > 0,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            use_cache=True)

    return tokenizer.decode(output_ids[0]).strip()

def _select_conv_mode(model_name):
    """Return the conversation-template key matching the model name."""
    name = model_name.lower()
    # Order matters: "v1.6-34b" must be tested before the broader "v1".
    if "llama-2" in name:
        return "llava_llama_2"
    if "mistral" in name:
        return "mistral_instruct"
    if "v1.6-34b" in name:
        return "chatml_direct"
    if "v1" in name:
        return "llava_v1"
    if "mpt" in name:
        return "mpt"
    return "llava_v0"


def _answer_with_search_result(i, conv, roles, models, image_tensor, image_size, search_result, args):
    """Show the retrieved summary to model ``i`` and return its new reply."""
    inp = web_access_prompt(roles[i][0], search_result)
    conv[i].append_message(conv[i].roles[0], inp)
    conv[i].append_message(conv[i].roles[1], None)
    outputs = generate_output(i, conv, models, image_tensor, args.temperature, image_size, args.max_new_tokens)
    conv[i].messages[-1][-1] = outputs
    return outputs


def main(args):
    """Run the multi-model debate over a range of samples and store results.

    For every sample index in ``[args.start_idx, args.end_idx)`` the models
    answer an initial prompt, then exchange answers over further rounds.
    A reply containing ``"UNSURE"`` triggers a follow-up turn that includes
    an external-retrieval summary. The per-model verdicts are merged with
    ``get_final_prediction`` and written through ``add_result``.

    Args:
        args: Namespace providing ``model_path``, ``num_models``,
            ``num_rounds``, ``max_new_tokens``, ``temperature``,
            ``result_file``, ``start_idx``, ``end_idx``, ``load_8bit`` and
            ``load_4bit``.
    """
    disable_torch_init()
    model_name = get_model_name_from_path(args.model_path)

    # One separately loaded model instance per debate participant.
    models = []
    for _ in range(args.num_models):
        tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, model_base=None, model_name=model_name, load_8bit=args.load_8bit, load_4bit=args.load_4bit, device_map="auto")
        models.append({"tokenizer":tokenizer, "model":model, "image_processor":image_processor, "context_len":context_len})

    conv_mode = _select_conv_mode(model_name)
    last_model = args.num_models - 1

    print("Running inference now!")
    for data_idx in tqdm(range(args.start_idx, args.end_idx)):
        search_result = ""
        search_done = False
        conv, roles = get_conv_and_roles(model_name, conv_mode)
        image, caption, img_path, annotation = get_data(data_idx)
        image_size = image.size
        # Similar operation in model_worker.py
        image_tensor = process_images([image], models[0]['image_processor'], models[0]['model'].config)
        device = models[0]['model'].device
        if isinstance(image_tensor, list):
            image_tensor = [tensor.to(device, dtype=torch.float16) for tensor in image_tensor]
        else:
            image_tensor = image_tensor.to(device, dtype=torch.float16)

        # Model 0's round-0 answer, kept because its conversation has already
        # moved on by the time the last model takes its round-1 turn.
        first_answer = ""
        model_responses = {i: {"falsified": "", "output": ""} for i in range(args.num_models)}

        for round_idx in range(args.num_rounds + 1):
            for i in range(args.num_models):
                neighbour = (i + 1) % args.num_models
                if round_idx == 0:
                    inp = initial_prompt(roles[i][0], caption)
                elif round_idx == 1:
                    if i == last_model:
                        # The neighbour (model 0) has already spoken this
                        # round, so use its saved round-0 answer instead.
                        inp = round1_prompt(roles[i][0], first_answer)
                    else:
                        inp = round1_prompt(roles[i][0], conv[neighbour].messages[-1][-1])
                else:
                    inp = debate_prompt(roles[i][0], conv[neighbour].messages[-1][-1])

                if round_idx == 0:
                    # Only each model's first message carries the image
                    # placeholder, regardless of how many models take part.
                    if models[i]['model'].config.mm_use_im_start_end:
                        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
                    else:
                        inp = DEFAULT_IMAGE_TOKEN + '\n' + inp

                conv[i].append_message(conv[i].roles[0], inp)
                conv[i].append_message(conv[i].roles[1], None)
                outputs = generate_output(i, conv, models, image_tensor, args.temperature, image_size, args.max_new_tokens)
                conv[i].messages[-1][-1] = outputs

                if "UNSURE" in outputs:
                    if not search_done:
                        # Retrieval is best-effort: a failure must not abort
                        # the run, but it is reported instead of being
                        # hidden, and Ctrl-C is no longer swallowed.
                        try:
                            matching_urls = get_matching_urls(data_idx)
                            search_result = get_summary(matching_urls)
                            print("************ search results: ", search_result)
                            outputs = _answer_with_search_result(i, conv, roles, models, image_tensor, image_size, search_result, args)
                            search_done = True
                        except Exception as exc:
                            print(f"External retrieval step failed for sample {data_idx}: {exc!r}")
                    else:
                        # Reuse the summary fetched earlier for this sample.
                        outputs = _answer_with_search_result(i, conv, roles, models, image_tensor, image_size, search_result, args)

                if i == 0 and round_idx == 0:
                    first_answer = outputs

                # NOTE(review): the loop runs rounds 0..num_rounds, yet the
                # verdict is taken at round num_rounds - 1, so the last
                # round's replies are never used. Looks like an off-by-one;
                # left unchanged until the intended round count is confirmed.
                if round_idx == args.num_rounds - 1:
                    # Final answer from the model. "YES" is checked first,
                    # then "NO", then "UNSURE" (plain substring matches).
                    if "YES" in outputs:
                        model_responses[i]["falsified"] = True
                        model_responses[i]["output"] = outputs
                    elif "NO" in outputs:
                        model_responses[i]["falsified"] = False
                        model_responses[i]["output"] = outputs
                    elif "UNSURE" in outputs:
                        model_responses[i]["falsified"] = "Unsure"
                        model_responses[i]["output"] = outputs

        annotation['falsified'], annotation["output"] = get_final_prediction(args.num_models, model_responses)
        add_result(args.result_file, annotation)


def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as ``True``. This converter parses the text instead.

    Args:
        value: A bool (returned unchanged) or a string such as ``"true"``,
            ``"False"``, ``"1"`` or ``"no"`` (case-insensitive).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="liuhaotian/llava-v1.6-34b")
    parser.add_argument("--num_models", type=int, default=2)
    parser.add_argument("--num_rounds", type=int, default=3)
    parser.add_argument("--max_new_tokens", type=int, default=512)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--result_file", type=str, default="results.json")
    parser.add_argument("--start_idx", type=int, default=0)
    parser.add_argument("--end_idx", type=int, default=1000)
    # Accept both a bare flag ("--load_8bit") and an explicit value
    # ("--load_8bit False"); the latter now really disables the option.
    parser.add_argument("--load_8bit", type=str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--load_4bit", type=str2bool, nargs="?", const=True, default=False)
    # `args` stays a module-level name because get_conv_and_roles reads it.
    args = parser.parse_args()
    main(args)
