import logging
import logging.handlers
import os
import sys
import argparse
import re
import requests
import torch

from cons_utils import BaseOptions, prompt, generate_question

from video_chatgpt.constants import LOGDIR
from video_chatgpt.video_conversation import conv_templates, SeparatorStyle
from video_chatgpt.model.utils import KeywordsStoppingCriteria
from video_chatgpt.eval.model_utils import initialize_model, load_video

# prompt["grounding"] = "Please answer when the given event happens in the video. The format should be: 'start - end seconds'. For example, The event 'person turn a light on' happens in the 24.3 - 30.4 seconds. Now I will give you the textual sentence: {event}. Please return its start time and end time."

# prompt = {
#     "grounding": "Please answer when the event '{event}' in the video. The format should be: 'start - end seconds'.",
#     "occurrence": "Could you answer whether the event {event} occurs from {st} to {ed} seconds in the video?",
#     "add_detail": "You should only answer with 'Yes' or 'No'. Do not provide any additional information or explanation.",
#     "co_occurrence": "Could you answer whether the events '{target1}' and '{target2}' occur at the same time in the video?.",
#     "sequential_after": "Does the event '{target1}' occur after the event '{target2}' in the video?.",
#     "sequential_before": "Does the event '{target1}' occur before the event '{target2}' in the video?.",
#     "compositional": "{question} from {st} to {ed} seconds in the video?",
# }


class VideoChatGPT_Options(BaseOptions):
    """Command-line options for the Video-ChatGPT runner, layered on BaseOptions."""

    def initialize(self):
        """Register the base options, then the Video-ChatGPT string options on self.parser."""
        super().initialize()

        # (flag, default) pairs; every option below is a plain string argument.
        string_options = (
            ("--model_path", "/mnt/Video-ChatGPT/LLaVA-7B-Lightening-v1-1/"),
            ("--vision_tower_name", "openai/clip-vit-large-patch14"),
            ("--projection_path", "/mnt/Video-ChatGPT/Video-ChatGPT-7B/video_chatgpt-7B.bin"),
            ("--conv_mode", 'video-chatgpt_v1'),
        )
        for flag, default_value in string_options:
            self.parser.add_argument(flag, type=str, default=default_value)

class VideoChatGPT:
    def __init__(self, args):
        self.model, self.vision_tower, self.tokenizer, self.image_processor, self.video_token_len = self.load_video_chatgpt_model(args)
        self.conv_mode = "video-chatgpt_v1"
        self.DEFAULT_VIDEO_TOKEN = "<video>"
        self.DEFAULT_VIDEO_PATCH_TOKEN = "<vid_patch>"
        self.DEFAULT_VID_START_TOKEN = "<vid_start>"
        self.DEFAULT_VID_END_TOKEN = "<vid_end>"
        self.debug = args.debug

    def load_video_chatgpt_model(self, args):
        """
        Load the Video-ChatGPT components through initialize_model.

        Parameters:
        args: Options object; `model_path` and `projection_path` are read.

        Returns:
        tuple: (model, vision_tower, tokenizer, image_processor, video_token_len).
        """
        components = initialize_model(args.model_path, args.projection_path)

        # Unpacking enforces that exactly five components came back.
        llm, tower, tok, processor, n_video_tokens = components
        return llm, tower, tok, processor, n_video_tokens

    def get_spatio_temporal_features_torch(self, features):
        """
        Computes spatio-temporal features from given features.

        Parameters:
        features (torch.Tensor): Input features to process.

        Returns:
        torch.Tensor: Spatio-temporal features.
        """

        # Extract the dimensions of the features
        t, s, c = features.shape

        # Compute temporal tokens as the mean along the time axis
        temporal_tokens = torch.mean(features, dim=1)

        # Padding size calculation
        padding_size = 100 - t

        # Pad temporal tokens if necessary
        if padding_size > 0:
            padding = torch.zeros(padding_size, c, device=features.device)
            temporal_tokens = torch.cat((temporal_tokens, padding), dim=0)

        # Compute spatial tokens as the mean along the spatial axis
        spatial_tokens = torch.mean(features, dim=0)

        # Concatenate temporal and spatial tokens and cast to half precision
        concat_tokens = torch.cat([temporal_tokens, spatial_tokens], dim=0).half()

        return concat_tokens


    def video_chatgpt_infer(self, video_spatio_temporal_features, question, conv=None, add_detail=None):
        """
        Run inference using the Video-ChatGPT model.

        Parameters:
        video_spatio_temporal_features (torch.Tensor): Video features; a batch
            dimension is added with unsqueeze(0) before generation.
        question (str): The question string.
        conv: Ignored. A fresh copy of conv_templates[self.conv_mode] is always
            used; the parameter is kept only for backward compatibility.
        add_detail (str, optional): Extra instruction text appended to the
            conversation's system prompt when truthy.

        Returns:
        str: The decoded answer with a trailing stop string and surrounding
            whitespace removed.
        """

        # Prepare question string for the model: the question, a newline, then
        # the video patch placeholders (optionally wrapped in start/end tokens).
        video_placeholder = self.DEFAULT_VIDEO_PATCH_TOKEN * self.video_token_len
        if self.model.get_model().vision_config.use_vid_start_end:
            video_placeholder = self.DEFAULT_VID_START_TOKEN + video_placeholder + self.DEFAULT_VID_END_TOKEN
        qs = question + '\n' + video_placeholder

        # Prepare conversation prompt from a fresh template copy
        conversation = conv_templates[self.conv_mode].copy()
        if add_detail:
            conversation.system += add_detail
        conversation.append_message(conversation.roles[0], qs)
        conversation.append_message(conversation.roles[1], None)
        prompt_text = conversation.get_prompt()

        # Tokenize the prompt
        inputs = self.tokenizer([prompt_text])

        # Move inputs to GPU
        input_ids = torch.as_tensor(inputs.input_ids).cuda()

        # Define stopping criteria for generation
        if conversation.sep_style != SeparatorStyle.TWO:
            stop_str = conversation.sep
        else:
            stop_str = conversation.sep2
        stopping_criteria = KeywordsStoppingCriteria([stop_str], self.tokenizer, input_ids)

        # Run model inference. Decoding is greedy (do_sample=False), so no
        # sampling temperature is passed: it would have no effect.
        self.model.eval()
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                video_spatio_temporal_features=video_spatio_temporal_features.unsqueeze(0),
                do_sample=False,
                max_new_tokens=1024,
                stopping_criteria=[stopping_criteria])

        # Check if the echoed prompt part of the output matches the input
        prompt_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :prompt_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

        # Decode only the newly generated tokens
        outputs = self.tokenizer.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)[0]

        # Clean output string. The stop string is removed as a whole suffix:
        # str.rstrip(stop_str) would treat it as a character set and also eat
        # legitimate trailing characters of the answer (e.g. the final "s" of
        # "seconds" when the stop string is "</s>").
        outputs = outputs.strip()
        if stop_str and outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]

        return outputs.strip()

    def load_video_features(self, video_path):
        """
        Load a video and encode it into spatio-temporal features.

        Parameters:
        video_path: Location of the video, passed straight to load_video.

        Returns:
        tuple: (features, None), where features is the output of
            get_spatio_temporal_features_torch. The second element is always None.
        """
        frames = load_video(video_path)
        processed = self.image_processor.preprocess(frames, return_tensors='pt')

        # Half precision first, then onto the GPU
        pixel_values = processed['pixel_values'].half().cuda()

        # The vision tower runs without gradient tracking
        with torch.no_grad():
            tower_out = self.vision_tower(pixel_values, output_hidden_states=True)
            # Second-to-last hidden layer (as in LLaVA), dropping the first
            # token of every frame (presumably the CLS token)
            per_frame_features = tower_out.hidden_states[-2][:, 1:]

        return self.get_spatio_temporal_features_torch(per_frame_features), None

    def extract_time_video_chatgpt(self, sentence):
        """
        Parse 'minutes:seconds' timestamps from a sentence.

        Parameters:
        sentence (str): Text to scan for timestamps.

        Returns:
        list: The two timestamps converted to seconds when exactly two are
            found, otherwise [0, 0].
        """
        # Each match is a (minutes, seconds) pair of digit strings
        stamps_in_seconds = [
            int(mm) * 60 + int(ss)
            for mm, ss in re.findall(r'\b(\d+):(\d+)\b', sentence)
        ]

        # Anything other than exactly two timestamps is a failed parse
        if len(stamps_in_seconds) == 2:
            return stamps_in_seconds
        return [0, 0]

    def run(self, task, video_features, query, duration, st=None, ed=None, msg=None):
        """
        Build the question for a task, query the model and post-process the answer.

        Parameters:
        task (str): Task name; "grounding" and "occurrence" have dedicated
            return formats.
        video_features: Features forwarded to video_chatgpt_infer.
        query: Forwarded to generate_question.
        duration: Forwarded to generate_question.
        st, ed: Optional values forwarded to generate_question.
        msg: Not used by this method.

        Returns:
        For "grounding", the list produced by extract_time_video_chatgpt;
        for "occurrence", [choice, answer]; otherwise the raw answer.
        """
        # Question wording is delegated to the shared generate_question helper.
        question, add_detail, choice = generate_question(task, prompt, query, duration, st, ed)
        answer = self.video_chatgpt_infer(video_features, question, add_detail=add_detail)

        if self.debug:
            print("Question:" + question)
            print("Answer:" + answer)
            print("")

        # Task-specific post-processing of the raw answer
        if task == "grounding":
            return self.extract_time_video_chatgpt(answer)
        if task == "occurrence":
            return [choice, answer]
        return answer
