import argparse
import torch
from utils import *
from llavavid.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llavavid.conversation import conv_templates, SeparatorStyle
from llavavid.model.builder import load_pretrained_model
from llavavid.utils import disable_torch_init
from llavavid.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
import random
from constants import *
import json
import os
import math
from tqdm import tqdm
from decord import VideoReader, cpu
from transformers import AutoConfig
import time

import numpy as np

def split_list(lst, n):
    """Split a list into at most n (roughly) equal-sized contiguous chunks.

    The chunk size is ``ceil(len(lst) / n)``, so every chunk except possibly
    the last has the same length. Because the size is rounded up, fewer than
    ``n`` chunks may be returned (e.g. 5 items with ``n=4`` gives 3 chunks).

    Args:
        lst: Sequence to split; it is sliced, not modified.
        n: Desired number of chunks; must be at least 1.

    Returns:
        A list of slices of ``lst`` in their original order. An empty input
        yields an empty list.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    if len(lst) == 0:
        # A zero chunk size would make range() fail with an unrelated error.
        return []
    chunk_size = math.ceil(len(lst) / n)  # round up so no more than n chunks are produced
    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the k-th chunk produced by ``split_list(lst, n)``.

    Args:
        lst: Sequence to partition.
        n: Number of chunks requested from ``split_list``.
        k: Index of the chunk to return.

    Returns:
        The slice of ``lst`` at position ``k`` in the chunk list.
    """
    return split_list(lst, n)[k]


def parse_args():
    """
    Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the parsed options. Flags spelled with
        hyphens (e.g. ``--model-path``) are exposed with underscores
        (``model_path``).
    """

    def _as_bool(text):
        # Only the word "true" (any casing) maps to True; anything else is False.
        return str(text).lower() == 'true'

    cli = argparse.ArgumentParser()

    # Mandatory input/output locations.
    cli.add_argument("--video_path", required=True, help="Path to the video files.")
    cli.add_argument("--output_dir", required=True, help="Directory to save the model results JSON.")
    cli.add_argument("--output_name", required=True, help="Name of the file for storing results JSON.")

    # Optional settings as (flag, converter, default), registered in this order.
    optional_settings = (
        ("--model-path", str, "facebook/opt-350m"),
        ("--model-base", str, None),
        ("--conv-mode", str, None),
        ("--chunk-idx", int, 0),
        ("--mm_resampler_type", str, "spatial_pool"),
        ("--mm_spatial_pool_stride", int, 4),
        ("--mm_spatial_pool_out_channels", int, 1024),
        ("--mm_spatial_pool_mode", str, "average"),
        ("--image_aspect_ratio", str, "anyres"),
        ("--image_grid_pinpoints", str, "[(224, 448), (224, 672), (224, 896), (448, 448), (448, 224), (672, 224), (896, 224)]"),
        ("--mm_patch_merge_type", str, "spatial_unpad"),
        ("--overwrite", _as_bool, True),
        ("--for_get_frames_num", int, 4),
        ("--load_8bit", _as_bool, False),
    )
    for flag, converter, fallback in optional_settings:
        cli.add_argument(flag, type=converter, default=fallback)

    return cli.parse_args()


def load_video(video_path, args):
    """Decode a fixed number of evenly spaced frames from a video file.

    Args:
        video_path: Path of the video handed to decord's ``VideoReader``.
        args: Parsed options; ``args.for_get_frames_num`` is the number of
            frame indices sampled between the first and the last frame.

    Returns:
        The sampled frames converted with ``asnumpy()``.
        (presumably shaped (frames, height, width, channels) -- TODO confirm)
    """
    reader = VideoReader(video_path, ctx=cpu(0))
    last_frame = len(reader) - 1
    # Evenly spaced integer indices covering the whole clip, end points included.
    sampled_indices = np.linspace(0, last_frame, args.for_get_frames_num, dtype=int)
    return reader.get_batch(sampled_indices.tolist()).asnumpy()


def _build_overwrite_config(args):
    """Assemble the config overrides handed to ``load_pretrained_model``."""
    overwrite_config = {
        "mm_resampler_type": args.mm_resampler_type,
        "mm_spatial_pool_stride": args.mm_spatial_pool_stride,
        "mm_spatial_pool_out_channels": args.mm_spatial_pool_out_channels,
        "mm_spatial_pool_mode": args.mm_spatial_pool_mode,
        "patchify_video_feature": False,
    }

    cfg_pretrained = AutoConfig.from_pretrained(args.model_path)

    # Rough lower bound on the sequence length: pooled visual tokens per frame
    # times the number of sampled frames, plus ~1000 text tokens ("from bo's
    # report" in the original note).
    # NOTE(review): the 16 / 24 grid sizes are presumably the per-side patch
    # counts of the 224px / larger vision towers -- confirm.
    # NOTE(review): the estimate counts the frames of a single clip although
    # four clips are passed to the model -- confirm this is intended.
    grid_side = 16 if "224" in cfg_pretrained.mm_vision_tower else 24
    least_token_number = args.for_get_frames_num * (grid_side // args.mm_spatial_pool_stride) ** 2 + 1000

    scaling_factor = math.ceil(least_token_number / 4096)
    if scaling_factor >= 2:
        model_id = cfg_pretrained._name_or_path.lower()
        # Linear RoPE scaling is only requested for non-Mistral "7b" checkpoints.
        if "mistral" not in model_id and "7b" in model_id:
            print(float(scaling_factor))
            overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"}
        overwrite_config["max_sequence_length"] = 4096 * scaling_factor
        overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor

    return overwrite_config


def _select_demos(index, test_file, test_files):
    """Pick three [video_path, preference] demonstrations and shuffle their order."""
    video_path_1 = test_files[index + 1]["camera"]
    preference_1 = test_files[index + 1]["preference"]

    video_path_2 = test_files[index + 2]["camera"]
    preference_2 = test_files[index + 2]["preference"]

    video_path_3, preference_3 = get_same_demo(test_file["preference"], test_files)

    # Shuffle the video logs together with their preferences.
    demos = [[video_path_1, preference_1], [video_path_2, preference_2], [video_path_3, preference_3]]
    random.shuffle(demos)
    return demos


def _load_video_tensors(video_paths, image_processor, args):
    """Load and preprocess each clip in ``video_paths``; return the tensors in order."""
    tensors = []
    for path in video_paths:
        # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
        if not os.path.exists(path):
            raise FileNotFoundError(f"Video file {path} does not exist.")
        frames = load_video(path, args)
        pixel_values = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"]
        tensors.append(pixel_values.half().cuda())
    return tensors


def _build_instruction(preference_1, preference_2, preference_3):
    """Compose the text prompt listing the candidate and the demonstrated preferences."""
    possible_preferences = Sequence_Preferences['name']
    # The prompt wording is kept verbatim (typos included) so results stay
    # comparable with earlier runs.
    parts = [
        "You are a robot assistant that can help summarize the host's preference.",
        f"possible preferences are: \n{parse_concat(possible_preferences, replace=', ')}.",
        f"The preference of the first video is: {preference_1}.",
        f"The preference of the second video is: {preference_2}.",
        f"The preference of the third video is: {preference_3}.",
        "Please summarize the preference from the last video.",
        "Quesiton: What's the user's preference? Choose from the preference listed before:",
    ]
    return "".join(parts)


def run_inference(args):
    """
    Evaluate few-shot preference prediction on the entries of ``train.json``.

    Each line of ``<args.video_path>/train.json`` is parsed as a JSON object
    with a ``"camera"`` (video path) and a ``"preference"`` key. The entries
    are shuffled; for every entry except the last two, three demonstration
    videos with their preferences are selected, the model is asked for the
    preference of the entry's own video, and one JSON line with the keys
    ``Q``, ``video_name``, ``pred`` and ``gt`` is appended to
    ``<args.output_dir>/<args.output_name>.json``. A running tally is printed.

    Args:
        args: Command-line arguments.

    Raises:
        FileNotFoundError: If any demonstration or query video is missing.
    """
    # Initialize the model
    model_name = get_model_name_from_path(args.model_path)
    if args.overwrite:
        overwrite_config = _build_overwrite_config(args)
        tokenizer, model, image_processor, _context_len = load_pretrained_model(args.model_path, args.model_base, model_name, load_8bit=args.load_8bit, overwrite_config=overwrite_config)
    else:
        # Honor --load_8bit here as well; it used to be dropped on this path.
        tokenizer, model, image_processor, _context_len = load_pretrained_model(args.model_path, args.model_base, model_name, load_8bit=args.load_8bit)

    # Create the output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)
    answers_file = os.path.join(args.output_dir, f"{args.output_name}.json")

    train_file = f"{args.video_path}/train.json"
    with open(train_file, "r") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing to parse them.
        test_files = [json.loads(line) for line in f if line.strip()]
    random.shuffle(test_files)

    wrong = 0
    evaluated = 0

    # The context manager closes the answers file even if inference fails midway.
    with open(answers_file, "w") as ans_file:
        # The last two entries only serve as demonstrations for earlier ones.
        for n, test_file in enumerate(tqdm(test_files[:-2])):
            demos = _select_demos(n, test_file, test_files)
            (video_path_1, preference_1), (video_path_2, preference_2), (video_path_3, preference_3) = demos

            # Load every clip from its own path: the three demonstrations first,
            # the query video last.
            videos = _load_video_tensors([video_path_1, video_path_2, video_path_3, test_file["camera"]], image_processor, args)

            instruction = _build_instruction(preference_1, preference_2, preference_3)

            qs = instruction
            if model.config.mm_use_im_start_end:
                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

            conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
            attention_masks = input_ids.ne(tokenizer.pad_token_id).long().cuda()

            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

            with torch.inference_mode():
                model.update_prompt([[instruction]])
                start_time = time.time()
                output_ids = model.generate(inputs=input_ids, images=videos, attention_mask=attention_masks, modalities="video", do_sample=True, temperature=0.05, max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria])
                end_time = time.time()
                print(f"Time taken for inference: {end_time - start_time} seconds")

            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            if outputs.endswith(stop_str):
                outputs = outputs[: -len(stop_str)]
            outputs = outputs.strip()

            # A prediction counts as correct when every word of the ground truth
            # occurs (case-insensitively) in the first sentence of the output.
            answer = outputs.split(".")[0].lower()
            gt = test_file["preference"]
            if not all(keyword.lower() in answer for keyword in gt.split(" ")):
                wrong += 1
                print(f"False: {answer} vs {gt}")
            evaluated = n + 1
            print(f"True: {evaluated - wrong}/{evaluated}")

            sample_set = {
                "Q": instruction,
                "video_name": test_file["camera"],
                "pred": outputs,
                "gt": gt,
            }
            ans_file.write(json.dumps(sample_set) + "\n")
            ans_file.flush()

    # Report against the number of samples actually scored; this also works
    # when the loop body never ran.
    print(f"True: {evaluated - wrong}/{evaluated}")


# Script entry point: parse the command-line options and run the evaluation.
if __name__ == "__main__":
    args = parse_args()
    run_inference(args)
