import argparse
import os
import numpy as np
import json
import argparse

import os
import glob
from hallucination_detection import hallucination_detection, revise_obj_ls, get_all_obj_ls, get_gt_obj_ls_ls
# from hallucination_loss import ObjectLoss

def get_parser():
    """Parse the command line and return the resulting namespace.

    Every option is a string option; the table below lists each flag with
    its default in the order it is registered. Hyphenated flags
    (``--model-base``, ``--image-folder``) are exposed by argparse as
    ``model_base`` / ``image_folder`` on the returned namespace.

    Returns:
        argparse.Namespace: the values parsed from ``sys.argv``.
    """
    option_defaults = (
        ("--model_dir", "./checkpoints/llava-llama-2-13b-chat-lightning-preview"),
        ("--model-base", None),
        ("--image-folder", "../POPE/data/minival2014/minival2014"),
        ("--save_dir", "../POPE/llava_eval_results/pope"),
        ("--question_dir", "../POPE/llava_qa/question"),
        ("--question_file", "I1_sub240_control.json"),
        ("--answer_dir", "../POPE/llava_qa/answer"),
        ("--answer_file", "I1_sub240_control_cfg1.0.jsonl"),
        ("--label_dir", "../POPE/llava_qa/label"),
        ("--label_file", ''),
        ("--eval_dir", None),
        ("--model", "llava"),
    )

    cli = argparse.ArgumentParser()
    for flag, fallback in option_defaults:
        cli.add_argument(flag, type=str, default=fallback)

    return cli.parse_args()

 
def _pope_prediction(answer):
    """Map one free-text answer to a prediction: 1 (yes), 0 (no), -1 (empty)."""
    if len(answer) == 0:
        return -1
    # Only the text before the first '.' is inspected; the answer counts as
    # "no" when one of the negation tokens appears as a whitespace-separated word.
    words = answer.split('.')[0].split()
    return 0 if any(neg_word in words for neg_word in ('No', 'not', 'no')) else 1


def pope(args):
    """Score yes/no answers against POPE labels and append the metrics to a file.

    Reads from ``args``:
        label_dir, label_file, question_file: the label file is
            ``label_dir/label_file`` when ``label_file`` is non-empty,
            otherwise ``label_dir/<question_file with '.json' replaced by
            '_label.json'>``. Its first line must be a JSON array of objects
            that each carry a ``'label'`` key.
        answer_dir, answer_file, model: for the models "llava", "mPLUG-Owl2"
            and "llava2" the answer file is read as JSON lines with a
            ``'text'`` key; for any other model it is read as one JSON array
            of objects with an ``'answer'`` key.
        save_file (optional), save_dir: results are appended as one JSON line
            to ``save_file``; when that attribute is missing or empty,
            ``save_dir/pope_eval.json`` is used instead.

    Empty answers are excluded from the confusion counts. If the number of
    labels and answers differs, a message is printed and only the common
    prefix is scored.

    Returns:
        dict: the metrics that were written to the results file.
    """
    label_file = getattr(args, 'label_file', '') or args.question_file.replace('.json', '_label.json')
    with open(os.path.join(args.label_dir, label_file), 'r') as f:
        labels_ls = json.loads(f.readline())
    labels = [item['label'] for item in labels_ls]

    answer_file = os.path.join(args.answer_dir, args.answer_file)
    with open(answer_file, 'r') as f:
        if args.model in ["llava", "mPLUG-Owl2", "llava2"]:
            answers = [json.loads(line)['text'] for line in f]
        else:
            answers_ls = json.load(f)
            print(len(answers_ls))
            answers = [item['answer'] for item in answers_ls]

    if len(labels) != len(answers):
        print(f"== Error answers are not enough: {answer_file} ==")

    pred_list = [_pope_prediction(answer) for answer in answers]
    label_list = [0 if label == 'no' else 1 for label in labels]

    # Pair each label with its prediction (zip stops at the shorter list) and
    # drop the pairs whose answer was empty.
    scored = [(label, pred) for label, pred in zip(label_list, pred_list) if pred != -1]

    # Counting the four cells directly keeps the result well-defined even
    # when only one class occurs, where a data-derived confusion matrix
    # would not be 2x2.
    TP = scored.count((1, 1))
    FP = scored.count((0, 1))
    FN = scored.count((1, 0))
    TN = scored.count((0, 0))
    total = TP + TN + FP + FN

    yes_ratio = (TP + FP) / total if total > 0 else 0
    # Mean answer length in characters, over all answers (empty ones included).
    avg_len = round(sum(len(answer) for answer in answers) / len(answers) if len(answers) > 0 else 0, 2)

    eval_results = {
        "answer_file": answer_file,
        "questions_num": len(answers),
        "length_response": avg_len,
        "overall_metrics": {
            "Accuracy": round((TP + TN) / total if total > 0 else 0, 4),
            "Yes_ratio": round(yes_ratio, 4),
            "TP": TP,
            "FP": FP,
            "FN": FN,
            "TN": TN,
            "Precision": round(TP / (TP + FP) if TP + FP > 0 else 0, 4),
            "Recall": round(TP / (TP + FN) if TP + FN > 0 else 0, 4),
            "Specificity": round(TN / (TN + FP) if TN + FP > 0 else 0, 4),
            "F1": round(2 * TP / (2 * TP + FP + FN) if 2 * TP + FP + FN > 0 else 0, 4),
        }
    }

    print(f"{'Accuracy':<10}{'Yes_ratio':<10}")
    print(f"{eval_results['overall_metrics']['Accuracy']:<10}{eval_results['overall_metrics']['Yes_ratio']:<10}")

    # args.save_file is only populated by save_file_check(); fall back to the
    # same default location so this function also works when called directly.
    save_file = getattr(args, 'save_file', None)
    if not save_file:
        save_file = os.path.join(args.save_dir, 'pope_eval.json')
    os.makedirs(os.path.dirname(save_file) or '.', exist_ok=True)

    with open(save_file, 'a+') as f:
        f.write(json.dumps(eval_results) + '\n')

    print(f"== Save eval results at {save_file} ==\n")
    return eval_results


def pope_evaluation(args):
    """Run object-level hallucination detection on every answer, then the POPE scoring.

    For each (answer, question) pair the response text is passed to
    ``hallucination_detection`` together with the ground-truth objects of the
    question's image, and the returned TP/FP/FN/TN object lists are fed to
    ``ObjectLoss``. A record is flagged as a mistake when it has at least one
    FP or FN object. Finally ``pope(args)`` is called for the yes/no metrics.

    NOTE(review): the exact semantics of ``hallucination_detection``,
    ``revise_obj_ls`` and ``ObjectLoss`` are defined outside this file;
    ``ObjectLoss`` is assumed to return a list of numbers (it is summed) --
    confirm against ``hallucination_loss``.

    Args:
        args: namespace providing ``question_dir``, ``question_file``,
            ``answer_dir``, ``answer_file`` and ``model`` (plus whatever
            ``pope`` reads).

    Returns:
        list[dict]: one record per answer with the detection lists and losses.

    Raises:
        ValueError: if no ground-truth entry exists for a question's image.
        AssertionError: if an answer's ``question_id`` does not match the
            ``id`` of the question at the same position.
    """
    # The module-level import of ObjectLoss is commented out at the top of
    # this file, so it is resolved here where it is actually needed.
    from hallucination_loss import ObjectLoss

    # Same model list as pope(): these models store one JSON object per line
    # with the response under 'text'; other models store a JSON array with
    # the response under 'answer'.
    line_per_record = args.model in ["llava", "mPLUG-Owl2", "llava2"]

    question_file = os.path.join(args.question_dir, args.question_file)
    with open(question_file, 'r') as f:
        questions = [json.loads(q) for q in f][0]

    answer_file = os.path.join(args.answer_dir, args.answer_file)
    with open(answer_file, 'r') as f:
        if line_per_record:
            answers = [json.loads(q) for q in f]
        else:
            answers = json.load(f)

    gt_obj_ls_ls = get_gt_obj_ls_ls()
    all_obj_ls, _ = get_all_obj_ls()

    results = []
    for answer, question in zip(answers, questions):
        response = answer['text'] if line_per_record else answer['answer']
        assert answer['question_id'] == question['id'], \
            f"answer {answer['question_id']} does not match question {question['id']}"

        # First ground-truth entry for this image; fail loudly instead of
        # silently reusing the previous image's objects when none matches.
        gt_entry = next((entry for entry in gt_obj_ls_ls if entry["image"] == question["image"]), None)
        if gt_entry is None:
            raise ValueError(f"no ground-truth objects found for image {question['image']!r}")
        gt_obj_ls = revise_obj_ls(gt_entry["objects"])

        TP_obj_ls, FP_obj_ls, FN_obj_ls, TN_obj_ls = hallucination_detection(response, gt_obj_ls, all_obj_ls)

        loss_ls = ObjectLoss(gt_obj_ls, TP_obj_ls, FP_obj_ls, FN_obj_ls, TN_obj_ls)

        mistake = len(FP_obj_ls) > 0 or len(FN_obj_ls) > 0
        results.append({
            "id": answer['question_id'],
            "image": question["image"],
            "mistake": mistake,
            "total_loss": sum(loss_ls),
            "loss_ls": loss_ls,
            "response": response,
            "FP_obj_ls": FP_obj_ls,
            "FN_obj_ls": FN_obj_ls,
            "TP_obj_ls": TP_obj_ls,
            "TN_obj_ls": TN_obj_ls,
        })

    # pope() accepts only the namespace; the per-answer records are returned
    # to the caller instead of being passed along.
    pope(args)
    return results
    


    # args = get_parser()
    # if args.cfg_ls is not None:
    #     for cfg in args.cfg_ls:
    #         args.answer_file = args.answer_file[:-13] + f"-cfg{cfg}.jsonl"
    #         run(args)
    # else:
    #     run(args)

def save_file_check(args):
    """Prepare the results file under ``args.save_dir`` and record its path.

    Creates ``args.save_dir`` if needed. When ``pope_eval.json`` already
    exists there, its lines are parsed as JSON, appended as a single JSON
    list (one line) to ``pope_eval_old.json`` in the same directory, and the
    existing results file is deleted. The results path is then stored on
    ``args.save_file``.

    Args:
        args: namespace with a ``save_dir`` attribute; gains ``save_file``.
    """
    eval_name = "pope_eval.json"
    target = os.path.join(args.save_dir, eval_name)
    os.makedirs(args.save_dir, exist_ok=True)

    if os.path.exists(target):
        archive_name = eval_name.replace('.json', '_old.json')
        archive_path = os.path.join(args.save_dir, archive_name)

        # Collect every record of the previous run...
        previous = []
        with open(target, 'r') as src:
            for line in src:
                previous.append(json.loads(line))

        # ...and append them to the archive as one JSON list on its own line.
        with open(archive_path, 'a+') as dst:
            dst.write(json.dumps(previous) + '\n')
        print(f"== Save old eval results at {archive_name} ==")

        # The archived results file is removed so the new run starts empty.
        os.remove(target)

    args.save_file = target
                    

# Script entry point: evaluate a single answer file, or every *.jsonl / *.json
# file found in --eval_dir when that option is given.
if __name__ == '__main__':
    args = get_parser()

    if args.eval_dir is None:
        print(f"== Evaluating POPE of {args.answer_file} ==")
        # pope() writes to args.save_file, which only save_file_check() sets;
        # without this call the single-file mode fails when saving results.
        save_file_check(args)
        pope(args)

    else:
        print(f"== Evaluating all files in {args.eval_dir} ==")
        save_file_check(args)

        # glob returns files in arbitrary order; sort each group so the
        # results file is written in a reproducible order (.jsonl files
        # first, then .json files, as before).
        eval_files = (
            sorted(glob.glob(os.path.join(args.eval_dir, '*.jsonl')))
            + sorted(glob.glob(os.path.join(args.eval_dir, '*.json')))
        )
        # Every file in the directory is scored against the same label file.
        args.answer_dir = args.eval_dir
        for eval_file in eval_files:
            print(f"== Evaluating POPE of {eval_file} ==")
            args.answer_file = os.path.basename(eval_file)
            pope(args)
