import torch
from tqdm import tqdm
from easydict import EasyDict as edict
import argparse
import time
import os
import random
import numpy as np
from cons_utils import load_jsonl, save_jsonl, save_json, load_json, get_iou, BaseOptions, load_logger

logger = load_logger()


def set_seed(seed, use_cuda=True):
    """Seed the Python, NumPy and PyTorch RNGs for reproducible runs.

    Args:
        seed: Integer seed applied to every generator.
        use_cuda: When true, also seed all CUDA devices via
            ``torch.cuda.manual_seed_all``.
    """
    # Same order as before: stdlib, NumPy, then the torch CPU generator.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if use_cuda:
        torch.cuda.manual_seed_all(seed)

    print("Set seed ", seed)


def main(args, model):
    """Run the grounding / consistency evaluation over the test split.

    For every entry of the test annotation file whose ``<vid>.mp4`` exists under
    ``args.video_root``, each query sentence is probed through ``model.run``:

    * grounding of the query itself (IoU against the clipped ground truth);
    * grounding of the aligned rephrasings (``"A"`` annotations), scored against
      the reference moment;
    * occurrence questions for the query and for its aligned (``"A"``) and
      misaligned (``"M"``) variants, asked about the reference moment;
    * compositional probes (``"C"`` annotations, ``"Y"``/``"N"`` lists).

    The reference moment is the ground truth when ``args.correctness`` is set,
    otherwise the model's own grounding prediction for the query.

    Results are written with ``save_jsonl`` to ``args.output_path``. This happens
    whenever the enumeration index is a multiple of 50 (including the first
    entry), and once more at the end.

    Args:
        args: Parsed options. Reads ``topn``, ``test_path``, ``video_root``,
            ``correctness``, ``debug`` and ``output_path``.
        model: Wrapper exposing ``load_video_features(path) -> (features, msg)``
            and ``run(task=..., ...)``.

    Raises:
        AssertionError: If the sentence, timestamp and prediction counts of a
            video disagree. Raised explicitly, so it also fires under ``python -O``.
    """
    from functools import partial

    topn = args.topn
    # temporal_reasoning_tasks = ["co_occurrence", "sequential_after", "sequential_before"]
    results = []

    test_data = load_json(args.test_path)
    # Strip only the final extension, so ids containing dots are kept intact.
    # Use a set for O(1) membership instead of rebuilding a key list per file.
    test_vids = set(test_data)
    target_vid_list = [os.path.splitext(file)[0] for file in os.listdir(args.video_root)]
    target_vid_list = [vid for vid in target_vid_list if vid in test_vids]
    print(f"Total {len(target_vid_list)} videos in {args.video_root}")

    # The loop visits every test_data entry (missing videos are skipped inside),
    # so the progress-bar total must be len(test_data).
    for n_data, (vid, data) in tqdm(enumerate(test_data.items()), total=len(test_data), desc="Evaluating.."):
        duration = data['duration']
        video_path = os.path.join(args.video_root, f"{vid}.mp4")

        # Load video frame features.
        if not os.path.exists(video_path):
            print(f"Video {vid} not found")
            continue
        video_features, msg = model.load_video_features(video_path)

        # Keyword arguments shared by every model.run call for this video.
        run = partial(model.run, video_features=video_features, duration=duration, msg=msg)

        # Per-query prediction accumulators.
        pred_moment_list = []
        iou_list = []

        pred_moments_aligned_list = []
        aligned_iou_list = []

        occur_ans_list = []
        aligned_occur_ans_list = []
        misaligned_occur_ans_list = []

        comp_ans_list = []

        for i, (query, gt_moment) in enumerate(zip(data['sentences'], data['timestamps'])):
            # Clip the ground-truth span to the video duration.
            gt_moment = [min(gt_moment[0], duration), min(gt_moment[1], duration)]
            annotation = data['consistency_annotations'][i]
            aligned_sentences = annotation["A"][:topn]
            misaligned_sentences = annotation["M"][:topn]
            compositional_info = annotation["C"]

            # --- Grounding ---
            pred_moment = run(task="grounding", query=query)
            # Follow-up probes are anchored on the GT moment in correctness mode,
            # otherwise on the model's own prediction.
            standard_moment = gt_moment if args.correctness else pred_moment
            st, ed = standard_moment[0], standard_moment[1]

            pred_moment_list.append(pred_moment)
            iou_list.append(get_iou(gt_moment, pred_moment))  # The format of pred_moment should be the "seconds".

            # --- Consistent moment detection: ground each aligned rephrasing ---
            pred_moments_aligned = [run(task="grounding", query=sent) for sent in aligned_sentences]
            pred_moments_aligned_list.append(pred_moments_aligned)
            aligned_iou_list.append([get_iou(standard_moment, moment) for moment in pred_moments_aligned])

            # --- Event occurrence detection inside the reference moment ---
            occur_ans_list.append(run(task="occurrence", query=query, st=st, ed=ed))
            aligned_occur_ans_list.append(
                [run(task="occurrence", query=sent, st=st, ed=ed) for sent in aligned_sentences]
            )
            misaligned_occur_ans_list.append(
                [run(task="occurrence", query=sent, st=st, ed=ed) for sent in misaligned_sentences]
            )

            # --- Compositional understanding ---
            y_comp_q, n_comp_q = compositional_info["Y"][:topn], compositional_info["N"][:topn]
            y_comp_ans = [run(task="compositional", query=comp_q, st=st, ed=ed) for comp_q in y_comp_q]
            n_comp_ans = [run(task="compositional", query=comp_q, st=st, ed=ed) for comp_q in n_comp_q]
            # Flat list: two entries per query, ["pos", answers] then ["neg", answers].
            comp_ans_list.append(["pos", y_comp_ans])
            comp_ans_list.append(["neg", n_comp_ans])

        # zip() silently truncates, so verify the counts explicitly; unlike
        # `assert`, this check survives `python -O`.
        if not (len(data['sentences']) == len(data['timestamps']) == len(pred_moment_list)):
            logger.error("The number of predictions are different.")
            raise AssertionError(f"The number of predictions are different for video {vid}.")

        # Temporal reasoning (currently disabled):
        # n_sent = len(data['sentences'])
        # query_indices = []
        # for idx in range(n_sent - 1):
        #     for idx2 in range(idx + 1, n_sent):
        #         query_indices.append([idx, idx2])
        #
        # query_indices = random.choice(query_indices)
        # query_pair = [data['sentences'][query_indices[0]], data['sentences'][query_indices[1]]]
        # temporal_reasoning_ans_list = [model.run(task=task, video_features=video_features, query=query_pair, duration=duration, msg=msg) for task in temporal_reasoning_tasks]

        # Save the results.
        result = edict(
            meta=edict(
                vid=vid,
                query=data['sentences'],
                timestamp=data['timestamps'],
                duration=data['duration']
            ),
            prediction=edict(
                pred_moments=pred_moment_list,
                ious=iou_list
            ),
            con_moment_prediction=edict(
                pred_moments=pred_moments_aligned_list,
                ious=aligned_iou_list,
            ),
            occurrence=edict(
                answer=occur_ans_list,
                aligned=aligned_occur_ans_list,
                misaligned=misaligned_occur_ans_list,
            ),
            compositional=edict(
                answer=comp_ans_list
            ),
            # temporal=edict(
            #     answer=temporal_reasoning_ans_list,
            #     indices=query_indices,
            # )
        )

        results.append(result)

        if args.debug and n_data == 1:
            break

        # Periodic checkpoint; the whole accumulated list is rewritten each time.
        if n_data % 50 == 0:
            logger.info(f"{len(results)} results are saved")
            save_jsonl(results, args.output_path)

    logger.info(f"{len(results)} results are saved")
    save_jsonl(results, args.output_path)
    logger.info("Done.")


if __name__ == "__main__":
    # How to run this file:
    #   python evaluate_consistency.py --model_type Video-ChatGPT --dset_name charades|activitynet [--debug]
    #
    # BaseOptions is parsed first only to choose the model type and the seed.
    # The selected model's own options class then parses the command line again.
    base_options = BaseOptions().parse()
    set_seed(base_options.seed)

    if base_options.model_type == "Video-ChatGPT":
        from video_chatgpt.utils import VideoChatGPT, VideoChatGPT_Options
        args = VideoChatGPT_Options().parse(visualize=True)
        model = VideoChatGPT(args)

    elif base_options.model_type == "VTimeLLM":
        from vtimellm.utils import VTimeLLM, VTimeLLM_Options
        args = VTimeLLM_Options().parse(visualize=True)
        model = VTimeLLM(args)

    elif base_options.model_type == "Video-LLaMA":
        from video_llama.utils import VideoLLaMA, VideoLLaMA_Options
        args = VideoLLaMA_Options().parse(visualize=True)
        model = VideoLLaMA(args)

    elif base_options.model_type == "TimeChat":
        from timechat.utils import TimeChat, TimeChat_Options
        args = TimeChat_Options().parse(visualize=True)
        model = TimeChat(args)
    else:
        raise NotImplementedError(f"Unsupported model_type: {base_options.model_type}")

    if args.correctness:
        logger.info("Measuring Correctness")
        output_dir = f"correctness_results/{args.model_type}"
    else:
        logger.info("Measuring Consistency")
        output_dir = f"consistency_results/{args.model_type}"

    # Use makedirs, not mkdir: the parent "*_results" directory may not exist yet.
    # exist_ok also removes the check-then-create race.
    os.makedirs(output_dir, exist_ok=True)

    cur_time = time.strftime("%Y_%m_%d_%H_%M_%S")
    if args.exp_id is None:
        args.exp_id = args.model_type

    args.output_path = f"{output_dir}/{args.exp_id}_{args.dset_name}_{cur_time}.jsonl"
    if args.fine_tuned:
        args.output_path = f"{output_dir}/fine_tuned_{args.exp_id}_{args.dset_name}_{cur_time}.jsonl"

    # Debug runs get a separate "debug_" file unless args.overwrite is set.
    # This takes precedence over the fine_tuned name above.
    if args.debug and not args.overwrite:
        logger.debug("Debug Mode")
        args.output_path = f"{output_dir}/debug_{args.dset_name}_{args.exp_id}_{cur_time}.jsonl"

    main(args, model)