import json
import argparse
import sys
from generate import call_to_open_api
from collections import defaultdict
from functools import reduce
import logging
from typing import List
import os
import numpy as np

from prompts_for_llava_evaluator_batch import sys_msg, rule_msg


# Maximum number of predicted state predicates packed into a single judging
# prompt for one ground-truth predicate (see the batching loop in
# eval_one_video).
BATCH = 10

def safe_mkdir(path):
    """Create directory *path* (and any missing parents) if it does not exist.

    Uses ``os.makedirs(..., exist_ok=True)`` so that the call is free of the
    check-then-create race and also works when intermediate directories are
    missing. Raises ``FileExistsError`` if *path* exists but is not a
    directory.
    """
    os.makedirs(path, exist_ok=True)

def eval_one_video(episode: str, prompt_version: str, model: str) -> dict:
    """Score one episode's predicted actions and state predicates with an LLM judge.

    Reads the predictions from
    ``raw_data/<prompt_version>/<prompt_version>_<episode>_output.json`` and the
    annotations from ``eval_data/annotation/<episode>_ann.json``, builds one
    judging prompt per action plus batched prompts (``BATCH`` predictions per
    prompt) for every ground-truth state predicate that has no exact textual
    match among the predictions, sends them through ``call_to_open_api`` and
    counts the replies containing ``[[yes]]``.

    Args:
        episode: Episode identifier used in the input/output file names.
        prompt_version: Prompt version; selects the sub-directory and, when it
            contains ``"v1b"`` or ``"v0"``, the windowed prediction format.
        model: Model name forwarded to ``call_to_open_api``. Any name
            containing ``"gpt-4"`` triggers an interactive cost warning.

    Returns:
        A dict with the keys ``num_correct_actions``, ``num_gt_actions``,
        ``action_accuracy`` (percentage, 0-100), ``num_correct_states``,
        ``num_predicted_states``, ``num_gt_states``, ``states_precision``,
        ``states_recall`` and ``states_f1`` (the last three are fractions in
        [0, 1]).

    Side effects:
        Prints every prompt and a summary to stdout, writes the annotated
        predictions to ``eval_data/<prompt_version>/<prompt_version>_<episode>_eval.json``
        and passes a ``..._eval_temp.json`` path to ``call_to_open_api``.

    Raises:
        KeyError: If a prediction id has no matching ground-truth annotation.
    """
    if 'gpt-4' in model:
        print("Evaluation is a O(N x M1 x M2) number of API calls, "
                "where N is the number of action annotations, "
                "M1 is the max number of state-predicates predicted per action, "
                "and M2 is the max number of predicates in ground truth per action. "
                "GPT-4 is an expensive model - HIGHLY SUGGEST to terminate this script here "
                "and use --model gpt-3.5-turbo.", file=sys.stderr)
        input("You have been warned. Press ENTER to continue")

    safe_mkdir(os.path.join("eval_data/", prompt_version))

    prediction_file = os.path.join("raw_data", prompt_version, f"{prompt_version}_{episode}_output.json")
    ground_truth_file = os.path.join("eval_data/annotation", f"{episode}_ann.json")

    with open(prediction_file, 'r') as fin:
        pd_json = json.load(fin)

    with open(ground_truth_file, 'r') as fin:
        gt_json = json.load(fin)

    if "v1b" in prompt_version or "v0" in prompt_version:
        # Window format: keys look like "<int>_...", each holding a list of
        # per-step predictions. Flatten them in ascending window order.
        window_id_list = list(pd_json.keys())
        window_sort_index = np.argsort([int(window_id.split("_")[0]) for window_id in window_id_list])

        predictions = {}
        for idx in window_sort_index:
            for pred in pd_json[window_id_list[idx]]:
                predictions[pred['id']] = {
                    "prediction": {
                        "start_states": pred['prediction_data']['states'],
                        "action": pred['prediction_data']['action'],
                    },
                    "meta_data": pred['meta_data']
                }
    else:
        predictions = pd_json

    ground_truths = {gt['id']: gt for gt in gt_json}

    prompts_json = {}
    prompts_json['header'] = sys_msg
    prompts_json['example'] = {'conversations':
        {"rules": [
            {
                "from": "human",
                "value": rule_msg
            }
        ]}
    }
    prompts_json['prompts'] = []
    prompts = prompts_json['prompts']

    num_retrieved = defaultdict(int)  # predicted predicates per timestep
    num_relevant = defaultdict(int)  # ground-truth predicates per timestep
    matched_gt = defaultdict(set)  # indices of ground-truth predicates judged matched, per timestep

    for step_id, pred_data in predictions.items():
        pred = pred_data['prediction']
        p_act = pred['action']
        pred_data['ground_truth'] = ground_truths[step_id]
        gt_act = pred_data['meta_data']['action']
        prompts.append({"id": f"{step_id}-act", "conversations": [{'from': 'human', 'value': f"PREDICTION: {p_act} ; GROUND TRUTH: {gt_act}"}]})

        p_st = pred['start_states']
        gt_st: List[str] = pred_data['ground_truth']['states']
        # Drop ALL empty annotations (list.remove only drops the first one).
        # Done in place so the filtered list is what ends up in the output JSON.
        gt_st[:] = [state for state in gt_st if state != ""]

        pn = len(p_st)
        num_retrieved[step_id] = pn
        num_relevant[step_id] = len(gt_st)

        # Built once per timestep instead of once per ground-truth predicate.
        stripped_predictions = {p.strip() for p in p_st}

        for i, gt in enumerate(gt_st):
            if gt.strip() in stripped_predictions:
                # Exact textual match: no need to ask the judge.
                matched_gt[step_id].add(i)
                continue
            for jdx in range(0, pn, BATCH):
                conv = "".join(
                    f"PREDICTION: {p} ; GROUND TRUTH: {gt}\n"
                    for p in p_st[jdx:jdx + BATCH]
                )
                prompts.append({"id": f"{step_id}-{i}-{jdx}", "conversations": [{'from': 'human', 'value': conv}]})

    for k in prompts_json["prompts"]:
        print(k)

    temp_output_fp = os.path.join("eval_data/", prompt_version, f"{prompt_version}_{episode}_eval_temp.json")
    outputs = call_to_open_api(prompts_json, temp_output_fp, model, logging.WARN)

    num_correct_actions = 0
    num_actions = 0
    # NOTE(review): assumes each output echoes the prompt id and carries exactly
    # two conversation turns (request, reply) - confirm against call_to_open_api.
    for out in outputs:
        pred_id = out['id']
        _, gpt = out['conversations']
        verdict = gpt['value']
        is_yes = '[[yes]]' in verdict

        if pred_id.endswith('-act'):
            # Split from the right so ids that themselves contain '-' or the
            # substring "act" are still parsed correctly.
            step_id = pred_id.rsplit('-', 1)[0]
            predictions[step_id].setdefault('eval', {})['action'] = verdict
            num_actions += 1
            if is_yes:
                num_correct_actions += 1
            continue

        step_id, i, j = pred_id.rsplit('-', 2)
        i = int(i)
        j = int(j)
        predictions[step_id].setdefault('eval', {})[f"state-{i}-{j}:{j+BATCH}"] = verdict
        if is_yes:
            matched_gt[step_id].add(i)

    for step_id, data in predictions.items():
        if 'eval' not in data:
            continue
        evl = data['eval']
        # The action verdict may be missing if only state replies came back.
        evl['act_score'] = int('[[yes]]' in evl.get('action', ''))

        # NOTE(review): this counts matched ground-truth predicates, so the
        # per-timestep "precision" can exceed 100% when several ground-truth
        # predicates match the same prediction.
        num_states_correct = len(matched_gt[step_id])

        if num_retrieved[step_id] == 0:
            evl['state_precision'] = f"100%"
        else:
            evl['state_precision'] = f"{100 * num_states_correct / num_retrieved[step_id]}%"

        if num_relevant[step_id] == 0:
            evl['state_recall'] = f"100%"
        else:
            evl['state_recall'] = f"{100 * num_states_correct / num_relevant[step_id]}%"

    total_correct = sum(len(matched) for matched in matched_gt.values())
    total_predicted = sum(num_retrieved.values())
    total_gt = sum(num_relevant.values())

    # Empty denominators follow the per-timestep convention above (vacuously
    # perfect) instead of raising ZeroDivisionError.
    precision = total_correct / total_predicted if total_predicted else 1.0
    recall = total_correct / total_gt if total_gt else 1.0

    if precision == 0 or recall == 0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)

    # No judged actions at all is reported as 0% rather than crashing.
    action_accuracy = 100 * num_correct_actions / num_actions if num_actions else 0.0

    eval_result_dict = {
        "num_correct_actions": num_correct_actions,
        "num_gt_actions": num_actions,
        "action_accuracy": action_accuracy,
        "num_correct_states": total_correct,
        "num_predicted_states": total_predicted,
        "num_gt_states": total_gt,
        "states_precision": precision,
        "states_recall": recall,
        "states_f1": f1
    }

    # make printing easier
    ed = eval_result_dict

    print(f"correct - {ed['num_correct_actions']} / {ed['num_gt_actions']}, accuracy in num actions = {ed['action_accuracy']:.2f}%")
    print(f"correct states - {ed['num_correct_states']}")
    print(f"number of predicted states - {ed['num_predicted_states']}")
    print(f"number of ground truth states - {ed['num_gt_states']}")
    # Precision/recall/F1 are fractions in [0, 1]; the '%' format type scales
    # them by 100 so the printed percentage is correct.
    print(f"precision in predicates = {ed['states_precision']:.2%}")
    print(f"recall in predicates = {ed['states_recall']:.2%}")
    print(f"F1 in predicates = {ed['states_f1']:.2%}")

    output_fp = os.path.join("eval_data/", prompt_version, f"{prompt_version}_{episode}_eval.json")

    with open(output_fp, 'w') as fout:
        json.dump(predictions, fout, indent=4)

    return eval_result_dict

if __name__ == "__main__":
    # CLI entry point: evaluate a single episode for one prompt version.
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--episode', type=str, required=True)
    parser.add_argument('-v', '--prompt_version', type=str, required=True)
    parser.add_argument('--model', type=str, default='gpt-3.5-turbo')

    args = parser.parse_args()

    # eval_one_video takes (episode, prompt_version, model) as separate
    # arguments; passing the Namespace itself raises TypeError.
    eval_one_video(args.episode, args.prompt_version, args.model)