import os
import json
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from datasets import load_dataset
from src.utils.parser_utils import get_parser

def _read_json(path):
    """Return the parsed contents of the JSON file at *path* (read as UTF-8)."""
    with open(path, "r", encoding="utf-8") as fin:
        return json.load(fin)


def _load_dataset(dataset_name, model_nickname):
    """Load the examples selected by the ``--dataset`` argument.

    The dataset family is picked by substring: ``"viquae"`` or ``"infoseek"``.
    Further substrings choose the concrete source:

    * ``"mc"`` + ``"cleaned"`` -> ``data/<family>/cleaned_dataset_mc_<model>.json``
    * ``"mc"`` alone           -> the family's default multiple-choice JSON file
    * viquae + ``"full"``      -> train + validation + test splits from the Hub
    * viquae + ``"clean"``     -> ``data/viquae/cleaned_dataset.json``
    * viquae otherwise         -> the Hub ``train`` split
    * infoseek otherwise       -> ``data/infoseek/infoseek_val_with_entity.json``

    Args:
        dataset_name: Value of ``--dataset``.
        model_nickname: Last path component of ``--model_name``. It is only
            used to locate the per-model cleaned multiple-choice files.

    Returns:
        ``(dataset_nickname, dataset)``. ``dataset_nickname`` is ``"viquae"``
        or ``"infoseek"``. ``dataset`` is a list or Hub split of example
        records.

    Raises:
        ValueError: if ``dataset_name`` names neither supported family.
    """
    if "viquae" in dataset_name:
        if "mc" in dataset_name:
            if "cleaned" in dataset_name:
                path = f"data/viquae/cleaned_dataset_mc_{model_nickname}.json"
            else:
                path = "data/viquae/multiple_choice_data.json"
            return "viquae", _read_json(path)
        # "full" is checked before "clean" so the order matches the CLI's
        # original precedence (note that "cleaned" also contains "clean").
        if "full" in dataset_name:
            splits = load_dataset("PaulLerner/viquae_dataset")
            merged = [
                d
                for split_name in ("train", "validation", "test")
                for d in splits[split_name]
            ]
            return "viquae", merged
        if "clean" in dataset_name:
            return "viquae", _read_json("data/viquae/cleaned_dataset.json")
        return "viquae", load_dataset("PaulLerner/viquae_dataset")["train"]

    if "infoseek" in dataset_name:
        if "mc" in dataset_name:
            if "cleaned" in dataset_name:
                path = f"data/infoseek/cleaned_dataset_mc_{model_nickname}.json"
            else:
                path = "data/infoseek/sampled_val_mc.json"
        else:
            path = "data/infoseek/infoseek_val_with_entity.json"
        return "infoseek", _read_json(path)

    raise ValueError(
        f"Unsupported --dataset {dataset_name!r}: expected a name containing "
        "'viquae' or 'infoseek'."
    )


def _load_scores(path):
    """Merge a JSON-lines score file into one ``{data_id: record}`` dict.

    Each non-blank line is parsed as a JSON object and merged with
    ``dict.update``, so a later line overrides an earlier one with the same
    key. Blank lines are skipped; without that, a trailing empty line would
    make ``json.loads`` raise.
    """
    scores = {}
    with open(path, "r", encoding="utf-8") as fin:
        for line in fin:
            if line.strip():
                scores.update(json.loads(line))
    return scores


def _pick_answer(text_scores, visual_scores, method):
    """Choose an option letter from paired textual and visual option scores.

    Both score vectors go through a softmax over their last dimension and are
    then combined according to *method*:

    * ``"compare"``: take the argmax of whichever distribution has the higher
      peak probability. Ties go to the visual distribution.
    * ``"shift"``: take the argmax of ``visual_prob - text_prob``. This is the
      option whose probability grows most from the textual to the visual
      prediction.

    Index 0 maps to ``"A"``, 1 to ``"B"``, and so on. Ties inside an argmax
    resolve to the lowest index (``np.argmax`` semantics).

    Raises:
        ValueError: if *method* is neither ``"compare"`` nor ``"shift"``.
    """
    softmax = torch.nn.functional.softmax
    # An explicit dim avoids PyTorch's implicit-dim deprecation warning.
    # NOTE(review): this assumes each score vector is 1-D (one value per
    # option), where dim=-1 is the axis the implicit rule picked; confirm.
    text_prob = np.array(softmax(torch.tensor(text_scores), dim=-1).tolist())
    visual_prob = np.array(softmax(torch.tensor(visual_scores), dim=-1).tolist())

    if method == "compare":
        chosen = text_prob if text_prob.max() > visual_prob.max() else visual_prob
        index = np.argmax(chosen)
    elif method == "shift":
        index = np.argmax(visual_prob - text_prob)
    else:
        raise ValueError(f"Unknown --method {method!r}")
    return chr(ord("A") + int(index))


def main():
    """Ensemble textual and visual option scores into one answer per example.

    The script:

    1. Reads the ``*_textual_T0.0.txt.score`` and ``*_visual_T0.0.txt.score``
       files for ``--dataset`` under ``outputs/analysis/<family>/<model>/``.
    2. Picks an option letter for each example with ``--method`` (see
       ``_pick_answer``).
    3. Appends ``{data_id: letter}`` JSON lines to
       ``<output_dir>/<family>/<model>/<dataset>_prob_<method>.txt``.

    Examples missing from either score file are skipped.
    """
    parser = get_parser()
    # Without a method there is no way to combine the scores, so require it
    # up front instead of failing on the first example.
    parser.add_argument("--method", choices=["compare", "shift"], required=True)
    args = parser.parse_args()
    if args.greedy:
        args.temperature = 0.0

    model_nickname = args.model_name.split("/")[-1]
    dataset_nickname, dataset = _load_dataset(args.dataset, model_nickname)

    score_prefix = f"outputs/analysis/{dataset_nickname}/{model_nickname}/{args.dataset}"
    text_preds = _load_scores(f"{score_prefix}_textual_T0.0.txt.score")
    visual_preds = _load_scores(f"{score_prefix}_visual_T0.0.txt.score")

    output_dir = os.path.join(args.output_dir, dataset_nickname, model_nickname)
    # makedirs also creates args.output_dir itself if it is missing.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{args.dataset}_prob_{args.method}.txt")

    # ViQuAE records are keyed by "id", InfoSeek records by "data_id".
    id_key = "id" if dataset_nickname == "viquae" else "data_id"

    # Append mode: a rerun adds lines rather than truncating earlier output.
    with open(output_path, "a", encoding="utf-8") as fout:
        # Wrapping the iterable keeps the progress bar accurate even when
        # examples are skipped.
        for data in tqdm(dataset):
            data_id = data[id_key]
            text_pred = text_preds.get(data_id)
            visual_pred = visual_preds.get(data_id)
            if text_pred is None or visual_pred is None:
                continue

            # NOTE(review): element [1] of each score record is used as the
            # per-option score vector; confirm against the score writer.
            answer = _pick_answer(text_pred[1], visual_pred[1], args.method)
            fout.write(f"{json.dumps({data_id: answer})}\n")

# Run the ensembling CLI only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()