import os
import time
import json
import torch
import pickle
import argparse
from tqdm import tqdm
from datetime import datetime
from model_list import get_LVLM_class
from match_tools import select_mc_option
from bench_utils import get_analysis_prompt, get_choice_prompt, get_white_bytes


def _save_solutions(solutions, output_file):
    """Write ``solutions`` to ``output_file`` as JSON via a temp file.

    The data is first written next to the target and then moved over it with
    ``os.replace``, so an interruption during a checkpoint cannot leave a
    truncated JSON file that would break the resume logic on the next run.
    """
    tmp_file = f"{output_file}.tmp"
    with open(tmp_file, "w") as f:
        json.dump(solutions, f, indent=4)
    os.replace(tmp_file, output_file)


def _format_duration(delta):
    """Render a ``timedelta`` without its fractional-second part."""
    # str(timedelta) only has a ".ffffff" suffix when microseconds are
    # non-zero, so splitting is safer than slicing a fixed number of chars.
    return str(delta).split(".")[0]


def main(args):
    """Run the two-step (analysis, then choice) benchmark and save the results.

    For every selected dataset item and every task, the model is queried once
    for a free-form analysis and once more for a final conclusion, which is
    mapped to a multiple-choice letter. Results are checkpointed every 10
    iterations and written once more at the end.

    Args:
        args: Parsed command-line namespace. This function reads ``bench_dir``,
            ``mode``, ``chunk_idx``, ``chunk_size``, ``slice``, ``lvlm_path``,
            ``modality`` and ``task``; the whole namespace is also forwarded
            to ``get_LVLM_class`` and to the model constructor.

    Raises:
        ValueError: If ``mode`` is not "chunk" and ``slice`` is empty.
        NotImplementedError: If ``modality`` is not one of "real",
            "synthetic", "triple" or "empty".
    """
    ### Loading datasets
    dataset_path = os.path.join(args.bench_dir, "data_test.pkl")
    # NOTE(security): pickle.load can execute arbitrary code contained in the
    # file; only point --bench_dir at data you trust.
    with open(dataset_path, 'rb') as file:
        dataset = pickle.load(file)
    idx_list = list(range(len(dataset)))
    if args.mode == "chunk":
        chunk_start = args.chunk_idx * args.chunk_size
        idx_list = idx_list[chunk_start:chunk_start + args.chunk_size]
    else:
        if not args.slice.strip():
            raise ValueError(
                "--slice must be given when --mode is not 'chunk' "
                "(e.g. '0:10' or '[3, 7, 42]')"
            )
        # NOTE(security): --slice is evaluated with eval(), so it can run
        # arbitrary Python. Kept for backward compatibility with existing
        # invocations; never pass untrusted input here.
        if ":" in args.slice:
            idx_list = eval(f"idx_list[{args.slice}]")
        else:
            selected = eval(args.slice)
            # A bare integer such as "5" selects that single item.
            idx_list = [selected] if isinstance(selected, int) else list(selected)

    ### Creating output directories
    # Strip trailing slashes so "path/to/model/" still yields "model" instead
    # of an empty name (which would turn the chunk dir into an absolute path).
    model_name = args.lvlm_path.rstrip('/').split('/')[-1]
    output_dir = os.path.join(os.getcwd(), "../output/benchmark")
    os.makedirs(output_dir, exist_ok=True)
    if args.mode == "chunk":
        chunk_dir = os.path.join(output_dir, f"{model_name}/{args.modality}/chunks")
        os.makedirs(chunk_dir, exist_ok=True)
        output_file = os.path.join(chunk_dir, f"chunk_{args.chunk_idx}.json")
    else:
        output_file = os.path.join(output_dir, f"{model_name}_{args.modality}_{args.slice.replace(':', '_')}_{time.strftime('%m%d_%H%M%S')}.json")

    ### Loading existing output files
    # Resume support: items whose index is not greater than the last saved
    # index are skipped in the main loop below.
    solutions = []
    last_idx = -1
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            solutions = json.load(f)
        if solutions:
            last_idx = solutions[-1]['idx']

    ### Loading models
    load_start_time = datetime.now()
    LVLM = get_LVLM_class(args)
    lvlm_model = LVLM(args)
    load_end_time = datetime.now()
    print(f"Loading duration: {_format_duration(load_end_time - load_start_time)}")

    ### Running models
    tasks = ["Recognition", "Understanding", "Grounding", "Reasoning"] if args.task == "all" else [args.task]
    # For the "triple" modality the data only goes into the prompt; the model
    # is called with None as its first argument.
    pass_data_to_model = args.modality in ["real", "synthetic", "empty"]
    for cnt, idx in tqdm(enumerate(idx_list), total=len(idx_list), desc="Running models"):
        # Periodic checkpoint so an interrupted run can be resumed.
        if cnt % 10 == 0:
            _save_solutions(solutions, output_file)
        if idx <= last_idx:
            continue

        item = dataset[idx]
        QAs = item["QAs"]
        if args.modality == "real":
            modality_data = item["real_bytes"]
        elif args.modality == "synthetic":
            modality_data = item["syn_bytes"]
        elif args.modality == "triple":
            modality_data = item["triples"]
        elif args.modality == "empty":
            modality_data = get_white_bytes(item["real_bytes"])
        else:
            raise NotImplementedError(f"Unsupported modality: {args.modality!r}")
        model_input = modality_data if pass_data_to_model else None

        sol = {}
        for task in tasks:
            # Step 1: Getting analysis
            analysis_prompt = get_analysis_prompt(modality_data, QAs[task])
            analysis = lvlm_model.run(model_input, analysis_prompt)

            # Step 2: Getting choice
            choice_prompt = get_choice_prompt(analysis_prompt, analysis)
            conclusion = lvlm_model.run(model_input, choice_prompt)
            choice_idx = select_mc_option(conclusion, list(QAs[task]["Options"].values()))
            # Option index 0 maps to 'A', 1 to 'B', and so on.
            choice = chr(ord('A') + choice_idx)

            sol[task] = {'analysis': analysis, 'conclusion': conclusion, 'choice': choice}
        solutions.append({'idx': idx, 'img_url': item["img_url"], 'solution': sol})

    run_end_time = datetime.now()
    print(f"Running duration: {_format_duration(run_end_time - load_end_time)}")

    ### Saving results
    _save_solutions(solutions, output_file)


if __name__ == '__main__':
    # Command-line options as (flag, type, default, help) rows; a help of
    # None is argparse's own default, i.e. no help text for that option.
    cli_options = (
        ('--bench_dir', str, "../../Datasets/Benchmark", None),
        ('--lvlm_path', str, "debug", None),
        ('--API_KEY', str, "", None),
        ('--max_new_tokens', int, 512, None),
        ('--modality', str, "real", "[real, synthetic, triple, empty]"),
        ('--task', str, "all", "[Recognition, Understanding, Grounding, Reasoning]"),
        ('--mode', str, "debug", "[debug, chunk]"),
        ('--slice', str, "", None),
        ('--chunk_size', int, 750, None),
        ('--chunk_idx', int, 0, None),
    )
    parser = argparse.ArgumentParser(description='benchmark')
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()

    # Echo the run configuration before starting so logs are self-describing.
    time_format = '%Y-%m-%d %H:%M:%S'
    banner = [
        "########## Information ##########",
        f"Model: {args.lvlm_path}",
        f"Number of GPUs: {torch.cuda.device_count()}",
        f"Modality: {args.modality}",
        f"Task: {args.task}",
        f"Mode: {args.mode}",
        f"Slice: {args.slice}",
        f"Chunk size: {args.chunk_size}",
        f"Chunk idx: {args.chunk_idx}",
        f"Starting time: {datetime.now().strftime(time_format)}",
    ]
    for line in banner:
        print(line)

    main(args)
    print(f"Ending time: {datetime.now().strftime(time_format)}")
