import math
import os
root_dir = os.path.join(os.getcwd(), "..")
import sys
import random
sys.path.append(root_dir)
from vtimellm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from vtimellm.conversation import conv_templates, SeparatorStyle
from vtimellm.model.builder import load_pretrained_model, load_lora
from vtimellm.utils import disable_torch_init
from vtimellm.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria, VideoExtractor
from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer, BartForConditionalGeneration, BartTokenizer
from easydict import EasyDict as edict
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    from PIL import Image
    BICUBIC = Image.BICUBIC
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize
import numpy as np
import clip
import torch
import json, os, re
from tqdm import tqdm
from cons_utils import load_jsonl, save_jsonl, save_json, load_json

# Question templates, keyed by probe type. Each one is completed with
# `str.format` before being handed to `inference`.
PROMPT = {
    # One positional placeholder: the query sentence to localise.
    'grounding': 'During which frames can we see {}?',
    # No placeholders.
    'captioning': (
        'Could you please describe the events in the video in detail? '
        'Be specific about the activities of individuals, their surroundings, '
        'and interactions with others.'
    ),
    # Named placeholders: `target` (query sentence), `st` / `ed` (frame span).
    'occurrence_detection': (
        "Can we see the event '{target}' occurring during frames "
        "{st} to {ed} in the video?"
    ),
    # Named placeholders: `target1` / `target2` (two query sentences).
    'co_occurrence_detection': (
        "Can we see the events '{target1}' and '{target2}' occurring "
        "at the same time in the video?"
    ),
}

def inference(model, tokenizer, context_length, image, args, inp, conv=None, first=False, return_conv=False):
    """Ask the model one question about a video and return its decoded reply.

    Args:
        model: Object whose ``generate`` method is called with the token ids,
            an ``images`` tensor and the sampling options below.
        tokenizer: Used to tokenize the prompt and to decode the generated ids.
        context_length: Accepted for signature compatibility; not used here.
        image: Tensor passed to ``model.generate`` as ``images`` after a
            leading dimension is added and it is moved to CUDA.
        args: Only ``args.temperature`` is read.
        inp: The user question (plain text).
        conv: An existing conversation to continue. When ``None`` a fresh copy
            of the ``'v1'`` template is created and the turn is treated as the
            first one.
        first: Whether this is the first turn, in which case the image
            placeholder token is put in front of the message.
        return_conv: When true, the updated conversation is returned as well.

    Returns:
        The stripped reply string, or ``(reply, conv)`` if ``return_conv``.
    """
    # A missing conversation always means "first turn", whatever `first` says.
    is_first_turn = first or conv is None
    if conv is None:
        conv = conv_templates['v1'].copy()

    user_role = conv.roles[0]
    assistant_role = conv.roles[1]

    # NOTE(review): the role name is embedded in the message text in addition
    # to being passed as the message's role; kept exactly as is.
    message = f"{user_role}: {inp}"
    if is_first_turn:
        # Only the first message carries the image placeholder.
        message = DEFAULT_IMAGE_TOKEN + '\n' + message
    conv.append_message(user_role, message)

    # Empty assistant slot: the prompt ends where the model should start writing.
    conv.append_message(assistant_role, None)
    prompt = conv.get_prompt()

    token_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    input_ids = token_ids.unsqueeze(0).cuda()

    # Two-separator styles stop on `sep2`, every other style stops on `sep`.
    if conv.sep_style == SeparatorStyle.TWO:
        stop_str = conv.sep2
    else:
        stop_str = conv.sep
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    generation_options = dict(
        images=image[None,].cuda(),
        do_sample=True,
        temperature=args.temperature,
        max_new_tokens=256,
        repetition_penalty=1.0,
        length_penalty=1,
        use_cache=True,
        stopping_criteria=[stopping_criteria],
    )
    with torch.inference_mode():
        output_ids = model.generate(input_ids, **generation_options)

    # Drop the prompt tokens and decode only what was generated.
    prompt_length = input_ids.shape[1]
    outputs = tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True).strip()

    # Fill the assistant slot so the conversation can be continued.
    conv.messages[-1][-1] = outputs

    return (outputs, conv) if return_conv else outputs


def get_iou(gt_frames, outputs, gt_process=False):
    """Parse a predicted frame span from the model's answer and score it.

    The answer is searched for the first ``"NN to NN"`` or ``"NN and NN"``
    pattern, where each ``NN`` is exactly two digits.

    Args:
        gt_frames: Ground-truth ``[start, end]`` on the same frame scale as
            the numbers in ``outputs``.
        outputs: Free-form text produced by the model.
        gt_process: Accepted for signature compatibility; not used here.

    Returns:
        ``(iou, start_frame, end_frame)``. ``iou`` is rounded to two decimals.
        The frames are the parsed numbers as floats. ``(0, 0, 0)`` is returned
        when no span is found. When the union of the two spans is empty the
        IoU is ``0`` and the parsed frames are still returned.
    """
    matches = re.search(r"(\d{2}) (to|and) (\d{2})", outputs)
    if not matches:
        return 0, 0, 0

    # Keep the parsed frame numbers untouched for the return value: dividing
    # by 100 and multiplying back is not exact in floating point
    # (e.g. 29 / 100 * 100 == 28.999999999999996).
    pred_start_frame = float(matches.group(1))
    pred_end_frame = float(matches.group(3))

    # The IoU itself is computed on the /100 scale, as before.
    from_number = pred_start_frame / 100
    to_number = pred_end_frame / 100
    gt_start, gt_end = gt_frames[0] / 100, gt_frames[1] / 100

    intersection = max(0, min(to_number, gt_end) - max(from_number, gt_start))
    union = max(to_number, gt_end) - min(from_number, gt_start)

    try:
        iou = intersection / union
    except ZeroDivisionError as err:
        # Both spans collapse to the same single point. Report no overlap,
        # but return the frames in the same unit as the regular path.
        print(err)
        return 0, pred_start_frame, pred_end_frame

    return round(iou, 2), pred_start_frame, pred_end_frame


def frame_to_second(moment, duration):
    """Convert a ``[start, end]`` span on a 0-100 scale into seconds.

    Each bound is divided by 100, multiplied by ``duration`` and rounded to
    three decimals.

    Args:
        moment: Indexable holding the start bound at 0 and the end bound at 1.
        duration: Length of the video, in the unit the result should have.

    Returns:
        A two-element list ``[start, end]``.
    """
    def scaled(frame):
        return round(frame / 100 * duration, 3)

    return [scaled(moment[0]), scaled(moment[1])]


def eval_consistency(args):
    """Probe the model on every annotated video and save its answers.

    For each video the function asks, per query sentence:
      * a grounding question (answer parsed into a frame span and scored),
      * an occurrence question on the *predicted* span ("consistency"),
      * an occurrence question on the *ground-truth* span ("correctness"),
    and, for every pair of query sentences, a co-occurrence question.

    Results are collected in a dict keyed by video id and written with
    ``save_json`` to ``args.output_path`` every 100 videos and once at the end.

    Args:
        args: Namespace providing ``stage2``, ``stage3``, ``clip_path``,
            ``anno_path``, ``video_root``, ``output_path``, ``overwrite``,
            ``debug`` and whatever ``load_pretrained_model`` / ``inference``
            read from it.
    """
    disable_torch_init()
    tokenizer, model, context_len = load_pretrained_model(args, args.stage2, args.stage3)
    model = model.cuda()
    model = model.to(torch.float16)

    clip_model, _ = clip.load(args.clip_path)
    clip_model.eval()
    clip_model = clip_model.cuda()
    video_loader = VideoExtractor(N=100)

    # NOTE(review): the normalisation constants look like CLIP's image
    # preprocessing statistics; confirm against the CLIP release in use.
    transform = Compose([
        Resize(224, interpolation=BICUBIC),
        CenterCrop(224),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])

    # Mapping of video id -> annotation (iterated with `.items()` below).
    annotations = load_json(args.anno_path)
    print(f"{len(os.listdir(args.video_root))} are available.")

    results = {}

    # With --overwrite and an existing output file, reuse the stored results
    # and only process the videos that are still missing.
    # The file is written by `save_json` as a dict keyed by video id, so it is
    # read back with `load_json`, and `annotations` must stay a dict for the
    # loop below.
    if os.path.exists(args.output_path) and args.overwrite:
        results = load_json(args.output_path)
        check_vid = set(results)
        annotations = {vid: data for vid, data in annotations.items() if vid not in check_vid}
        print(f"Load {args.output_path} results and skip {len(check_vid)} videos since we have already the results.")
    else:
        print("Run the code from scratch.")

    for i, (vid, data) in tqdm(enumerate(annotations.items()), total=len(annotations), desc="Evaluating.."):
        video_path = os.path.join(args.video_root, f"{vid}.mp4")
        duration = data['duration']

        # Fall back to an .mkv file with the same id before giving up.
        if not os.path.exists(video_path):
            video_path = os.path.join(args.video_root, f"{vid}.mkv")
            if not os.path.exists(video_path):
                print(f"pass {vid} since it does not exist.")
                continue

        try:
            _, images = video_loader.extract({'id': None, 'video': video_path})
        except Exception as exc:
            # Best effort: an unreadable video is skipped, but the reason is
            # reported and Ctrl-C / SystemExit still stop the run.
            print(f"pass the video {video_path} ({exc!r})")
            continue
        # presumably images is <N, 3, H, W> with values in 0-255; verify
        # against VideoExtractor.
        images = transform(images / 255.0)
        images = images.to(torch.float16)

        with torch.no_grad():
            features = clip_model.encode_image(images.to('cuda'))

        pred_moments = []
        pred_consistency_answers = []
        pred_correctness_answers = []
        ious = []

        for query, timestamp in zip(data['sentences'], data['timestamps']):
            # Ground-truth span rescaled to the 0-100 frame scale.
            gt_frames = [int(timestamp[0] / duration * 100), int(timestamp[1] / duration * 100)]

            # Grounding
            moment_prediction = inference(model, tokenizer, context_len, features, args, inp=PROMPT['grounding'].format(query))
            iou_val, st_frame, ed_frame = get_iou(gt_frames, moment_prediction)
            pred_frames = [st_frame, ed_frame]

            # Consistency: does the model confirm its own predicted span?
            answer_consistency = inference(model, tokenizer, context_len, features, args, inp=PROMPT['occurrence_detection'].format(target=query, st=pred_frames[0], ed=pred_frames[1]))

            # Correctness: does the model confirm the ground-truth span?
            answer_correctness = inference(model, tokenizer, context_len, features, args, inp=PROMPT['occurrence_detection'].format(target=query, st=gt_frames[0], ed=gt_frames[1]))

            pred_moments.append(frame_to_second(pred_frames, duration))
            ious.append(iou_val)
            pred_consistency_answers.append(answer_consistency)
            pred_correctness_answers.append(answer_correctness)

        # Co-occurrence question for every unordered pair of query sentences.
        n_sent = len(data['sentences'])
        neg_co_occurence_answers = []
        for idx in range(n_sent - 1):
            for idx2 in range(idx + 1, n_sent):
                answer_co_occurence = inference(model, tokenizer, context_len, features, args, inp=PROMPT['co_occurrence_detection'].format(target1=data['sentences'][idx], target2=data['sentences'][idx2]))
                neg_co_occurence_answers.append(edict(query_indices=[idx, idx2], answer=answer_co_occurence))

        # Save the results.
        result = edict(
            meta=edict(
                vid=vid,
                query=data['sentences'],
                timestamp=data['timestamps'],
                duration=duration
            ),
            prediction=edict(
                pred_moments=pred_moments,
                iou=ious,
                occurence_detection=edict(
                    consistency=pred_consistency_answers,
                    correctness=pred_correctness_answers,
                ),
                co_occurence_detection=edict(
                    neg_co_occurence_answers=neg_co_occurence_answers,
                )
            ),
        )
        results[vid] = result

        # Debug mode stops after the first processed video.
        if args.debug:
            print(results)
            break

        # Periodic checkpoint so a crash does not lose everything.
        if i % 100 == 0:
            save_json(results, args.output_path)
            print(f"{len(results)} results are saved.")

    save_json(results, args.output_path)
    print(f"{len(results)} results are saved.")

def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    ``--data_root_path`` and ``--video_root`` are required; every other
    option has a default. ``--overwrite`` and ``--debug`` are boolean flags.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    import argparse

    # (flag, add_argument keyword arguments), in registration order.
    option_table = (
        ("--dset_name", dict(type=str, choices=["charades", "activitynet"], default="activitynet")),
        ("--data_root_path", dict(type=str, required=True)),
        ("--clip_path", dict(type=str, default="vtimellm/ViT-L-14.pt")),
        ("--model_base", dict(type=str, default="vidllm/ckpt/vicuna-7b-v1.5")),
        ("--pretrain_mm_mlp_adapter", dict(type=str, default="vtimellm/vtimellm-vicuna-v1-5-7b-stage1/mm_projector.bin")),
        ("--stage2", dict(type=str, default="vtimellm/vtimellm-vicuna-v1-5-7b-stage2")),
        ("--stage3", dict(type=str, default="vtimellm/vtimellm-vicuna-v1-5-7b-stage3")),
        ("--video_root", dict(type=str, required=True)),
        ("--temperature", dict(type=float, default=0.05)),
        ("--anno_path", dict(type=str, default="annotations/")),
        ("--overwrite", dict(action='store_true')),
        ("--debug", dict(action='store_true')),
    )

    parser = argparse.ArgumentParser(description="Demo")
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()

# Usage: python evaluate_consistency.py --data_root_path /mnt/ --anno_path filtered_activitynet_test.json --video_root /data/video_datasets/activitynet/ --debug

if __name__ == "__main__":
    args = parse_args()

    # Model and checkpoint locations are given relative to --data_root_path.
    for option in ("clip_path", "model_base", "pretrain_mm_mlp_adapter", "stage2", "stage3"):
        setattr(args, option, os.path.join(args.data_root_path, getattr(args, option)))

    # Debug runs write to their own file so they never clobber real results.
    if args.debug:
        args.output_path = f"debug_vtime_{args.dset_name}.json"
    else:
        args.output_path = f"vtime_{args.dset_name}.json"

    print(args)
    eval_consistency(args)
