import openai
import json
import os
import logging
from google import genai
import pandas as pd
import time
import requests
from google.genai.errors import ClientError
from openai import OpenAI
from typing import Literal
import httpcore
import httpx
import jsonschema

from functools import wraps

# Logging: use the project-wide shared logger
from prompt_generate.util import logger
from util.interface import CausalModel


class Retry():
    """Decorator that retries a method on transient API / network errors.

    The decorated callable must be a method: its first positional argument is
    the instance, and that instance must provide ``switch_api_key()``. After
    three ``ClientError``\\ s without an intervening key switch, the decorator
    calls ``switch_api_key()`` to rotate to the next API key.

    Retried errors:
      * ``google.genai.errors.ClientError`` (counted towards key switching),
      * ``requests.exceptions.ConnectionError``, ``httpcore.ConnectError``,
        ``httpx.ConnectError``,
      * ``openai.APIConnectionError``.
    Any other exception is logged and re-raised immediately.

    :param retries: maximum number of attempts.
    :param delay: seconds to wait between two attempts.
    :raises Exception: when every attempt failed; the last retryable error is
        attached as ``__cause__``.
    """

    def __init__(self, retries=100, delay=10):
        self.retries = retries
        self.delay = delay

    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not args:
                raise ValueError(f"Function {func.__name__} must have at least one argument (the 'self' instance).")
            instance = args[0]
            # Every Python value is an ``object``; only the presence of
            # ``switch_api_key`` is a meaningful requirement.
            if not hasattr(instance, "switch_api_key"):
                raise ValueError(f"Function {func.__name__} must be a method of a class with 'switch_api_key' method.")

            consecutive_client_errors = 0  # ClientErrors since the last key switch
            last_error = None
            for attempt in range(1, self.retries + 1):
                try:
                    return func(*args, **kwargs)
                except ClientError as e:
                    last_error = e
                    logger.warning(f"ClientError occurred during function {func.__name__}: {e}")
                    consecutive_client_errors += 1
                    if consecutive_client_errors >= 3:
                        logger.warning(
                            f"Encountered {consecutive_client_errors} consecutive ClientErrors. Switching API key...")
                        instance.switch_api_key()
                        consecutive_client_errors = 0  # Reset counter after switching API key
                except (requests.exceptions.ConnectionError, httpcore.ConnectError, httpx.ConnectError) as e:
                    # Usually a temporary network problem (e.g. too frequent requests); retry after a pause.
                    last_error = e
                    logger.warning(
                        f"ConnectError occurred at {func.__name__} running: {e}. Retrying in {self.delay} seconds...")
                except openai.APIConnectionError as e:
                    last_error = e
                    logger.error(f"API Connection Error: {e}")
                except Exception as e:
                    # Anything else is not considered transient: surface it immediately.
                    logger.error(f"Unexpected error occurred at {func.__name__} running: {e}")
                    raise
                # Only reached after a retryable failure; no point waiting after the last attempt.
                if attempt < self.retries:
                    time.sleep(self.delay)
            raise Exception(f"Failed to run {func.__name__} after {self.retries} attempts.") from last_error

        return wrapper


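# Class that drives the video question-answering task (prepare each video, ask the probe questions, save the answers)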
class VideoQuestionGenerator:
    """Ask a vision-language model (Google Gemini or OpenAI) probe questions about generated videos.

    For every scenario listed in ``<database_path>/scenario2.csv`` the generator loads
    ``<dataset_path>/samples/<scenario>.json``, walks the video folder
    ``<database_path>/<video_model>/<scenario_id>``, asks every probe question for each
    video and writes a ``{factor: answer}`` JSON file next to the video.
    """

    # File extensions (lower-case) treated as videos by process_video().
    VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

    def __init__(self, api: str, vllm: str, video_model: str, database_path: str, dataset_path: str, api_keys: list,
                 feedback_times: int = 3, start_api_key_index: int = 0, one_by_one: bool = False,
                 force_regen: bool = False, suffix: str = "", index: list[int] | None = None, frame_interval: int = 10,
                 debug: bool = False):
        """
        :param api: API provider, ``"openai"`` or ``"google"``.
        :param vllm: name of the vision-language model to query.
        :param video_model: name of the video-generation model; sub-folder of ``database_path``.
        :param database_path: folder containing ``scenario2.csv`` and the generated videos.
        :param dataset_path: folder containing ``samples/<scenario>.json``.
        :param api_keys: Google API keys; ignored for OpenAI (which reads ``OPENAI_API_KEY``).
        :param feedback_times: stored on the instance; not read anywhere in this class.
        :param start_api_key_index: index in ``api_keys`` of the first key to use.
        :param one_by_one: ask each question in a separate request instead of all at once.
        :param force_regen: re-ask even if an answer file already exists.
        :param suffix: suffix for answer files (a leading ``_`` is added if missing).
        :param index: optional list of scenario ids to restrict processing to.
        :param frame_interval: send every ``frame_interval``-th frame to OpenAI.
        :param debug: also save raw model responses next to the video.
        """
        self.api = api
        self.vllm = vllm
        self.video_model = video_model
        self.api_keys = api_keys  # a list of api keys, only used by google api
        self.api_key_index = start_api_key_index  # the default index of the api key in the list
        self.set_client(self.api)
        self.database_path = database_path
        self.dataset_path = dataset_path
        self.feedback_times = feedback_times
        self.one_by_one = one_by_one
        self.force_regen = force_regen
        if suffix and not suffix.startswith("_"):
            self.suffix = "_" + suffix
        else:
            self.suffix = suffix
        self.index = index
        self.frame_interval = frame_interval
        self.debug = debug

    def __getstate__(self):
        """Pickle support for ``run(num_workers > 1)``.

        API clients typically hold live HTTP connection pools / SSL contexts that cannot
        be pickled, so the client is dropped and rebuilt in the receiving process.
        """
        state = self.__dict__.copy()
        state.pop("client", None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Recreate the client (with the current key index) inside the worker process.
        self.set_client(self.api)

    def set_client(self, api_name: str):
        """Create ``self.client`` for the given API provider.

        :raises ValueError: if ``api_name`` is neither ``"openai"`` nor ``"google"``.
        """
        if api_name == "openai":
            logger.warning(
                "Using OpenAI API. The api_key parameter will be ignored. Please provide the API key as environment variable OPENAI_API_KEY.")
            self.client = OpenAI()
        elif api_name == "google":
            self.client = genai.Client(api_key=self.api_keys[self.api_key_index])
        else:
            raise ValueError(f"API name {api_name} not recognized.")

    def save_debug_response(self, video_path, response_dict, idx: int | None = None):
        """Dump a raw model response next to the video.

        The file name is the answer file name with ``_debug`` (or ``_debug_<idx>`` for
        one-by-one questions) inserted before the ``.json`` extension.
        """
        # Split off the extension instead of str.replace, which would also rewrite any
        # '.json' occurring in a directory name.
        base, ext = os.path.splitext(self._get_save_path(video_path))
        tag = "_debug" if idx is None else f"_debug_{idx}"
        save_file = f"{base}{tag}{ext}"
        with open(save_file, 'w') as f:
            json.dump(response_dict, f, indent=4)
        logger.info(f"Debug response saved to {save_file}")

    def load_scenario_csv(self):
        """Read the scenario table ``<database_path>/scenario2.csv`` into a DataFrame."""
        return pd.read_csv(os.path.join(self.database_path, "scenario2.csv"))

    def generate_questions(self, scenario) -> tuple[list[dict[str, str]], list[str]]:
        """Load the probe questions and the inverse topological factor order for a scenario.

        The sample file is ``<dataset_path>/samples/<prefix>.json`` where ``<prefix>`` is
        the part of ``scenario`` before the first ``%``.

        :return: ``(probes, factors)`` -- ``probes`` is ``samples["probes"]["factor_question_pairs"]``
            (items with ``'factor'`` and ``'question'`` keys, as used by ``ask_question``);
            ``factors`` is roots followed by topologically sorted non-roots, reversed.
        :raises ValueError: if the sample file does not exist.
        """
        sample_file_path = os.path.join(self.dataset_path, "samples", f"{scenario.split('%')[0]}.json")
        if not os.path.exists(sample_file_path):
            raise ValueError(f"Sample file for scenario {scenario} not found at {sample_file_path}")
        with open(sample_file_path, "r") as f:
            samples = json.load(f)
        probes = samples["probes"]["factor_question_pairs"]
        causal_model = CausalModel(roots=samples["roots"], non_roots=samples["non_roots"], rules=samples["rules"],
                                   scenario=samples["scenario"])
        topo_sorted_factors: list[str] = samples["roots"] + causal_model.topo_sorted_non_roots()
        topo_sorted_factors.reverse()
        return probes, topo_sorted_factors

    def switch_api_key(self):
        """Rotate to the next Google API key (wrapping around) and rebuild the client.

        Only meaningful for the Google API, where it is used to spread requests over
        several keys to avoid rate limits.

        :raises RuntimeError: if the configured API is not ``"google"``.
        """
        # An explicit check instead of `assert`, which is stripped under `python -O`.
        if self.api != "google":
            raise RuntimeError("Only Google API requires switching API keys.")
        self.api_key_index = (self.api_key_index + 1) % len(self.api_keys)
        self.client = genai.Client(api_key=self.api_keys[self.api_key_index])
        logger.info(f"Switched to API key {self.api_key_index + 1}.")

    @Retry(retries=100, delay=10)
    def upload_video(self, video_path: str, wait_times: int = 100, delay: int = 10):
        """Upload a video to the Gemini Files API and wait until its state is ``ACTIVE``.

        :param video_path: local path of the video.
        :param wait_times: maximum number of extra status polls before giving up.
        :param delay: seconds to wait between two status polls.
        :return: the ``ACTIVE`` file object returned by ``client.files.get``.
        :raises ValueError: if the file reaches state ``FAILED`` or the SDK's upload
            method accepts neither ``file`` nor ``path``.
        :raises Exception: if the file is still not ``ACTIVE`` after ``wait_times`` polls.

        Connection and client errors are retried by the ``Retry`` decorator.
        """
        from inspect import signature

        logger.info(f"Uploading video from {video_path}...")

        # Different versions of the google-genai SDK name the upload argument differently.
        upload_params = signature(self.client.files.upload).parameters
        if 'file' in upload_params:
            video_file = self.client.files.upload(file=video_path)
        elif 'path' in upload_params:
            video_file = self.client.files.upload(path=video_path)
        else:
            raise ValueError("The upload method does not accept 'file' or 'path' as parameters.")

        logger.info(f"Uploading video {video_path} completed. Waiting for video to be ready...")

        # Poll the file state; the counter is advanced by the loop so the timeout really fires.
        for polls in range(wait_times + 1):
            video_file = self.client.files.get(name=video_file.name)
            if video_file.state.name == "ACTIVE":
                logger.info(f"Video {video_path} is ready for use (State: ACTIVE)")
                return video_file
            if video_file.state.name == "FAILED":
                raise ValueError(f"Video upload failed for {video_path}. State: {video_file.state.name}")
            if polls == wait_times:
                break
            if polls > 0:
                logger.info(
                    f"Video {video_path} is still processing (State: {video_file.state.name}). Total wait time: {polls * delay} seconds.")
            time.sleep(delay)
        raise Exception(f"Failed to get video after {wait_times * delay} seconds: {video_path}")

    def _generate_google(self, video_file, prompt, response_schema):
        """Query Gemini with the uploaded video and the prompt, forcing JSON output that
        follows ``response_schema``; return the decoded JSON."""
        response = self.client.models.generate_content(
            model=self.vllm,
            contents=[video_file, prompt],
            config=genai.types.GenerateContentConfig(
                response_mime_type="application/json",
                response_schema=response_schema
            )
        )
        return json.loads(response.text)

    def _openai_messages(self, prompt, frames) -> list[dict]:
        """Build OpenAI chat messages: the prompt as system message and every
        ``frame_interval``-th frame as user content."""
        return [{
            "role": "system",
            "content": prompt,
        }, {
            "role": "user",
            "content": ["The video frames:",
                        *({"image": frame, "resize": 768} for frame in frames[0::self.frame_interval])],
        }]

    @Retry(retries=100, delay=10)
    def ask_question_one_by_one(self, video_file, video_file_path, prompt, response_schema) -> dict:
        """Ask a single question and return the decoded JSON response.

        :param video_file: uploaded Gemini file object (google) or list of frames (openai).
        :param video_file_path: local path, used to re-upload the video if needed.
        :param prompt: prompt text for this question.
        :param response_schema: JSON schema the answer must follow.
        :raises ValueError: if OpenAI refuses, or the API is not recognized.
        """
        # NOTE(review): `video_file.state` is the snapshot taken when the file was uploaded
        # (upload_video only returns ACTIVE files) and is never refreshed here, so this
        # re-upload branch may never trigger even after an API key switch.
        if self.api == 'google' and video_file.state.name != "ACTIVE":
            # if the video is not active, (could be because the api key is switched), re-upload the video
            logger.warning(f"The video {video_file_path} is not active. Re-uploading the video...")
            video_file = self.upload_video(video_file_path)

        if self.api == "google":
            return self._generate_google(video_file, prompt, response_schema)
        if self.api == "openai":
            # "json_schema" is the structured-output response format type; with plain
            # `create()` the JSON comes back as message content (there is no `.parsed`).
            completion = self.client.chat.completions.create(
                model=self.vllm,
                messages=self._openai_messages(prompt, video_file),
                response_format={"type": "json_schema",
                                 "json_schema": {"name": "response_schema", "schema": response_schema}}
            )
            message = completion.choices[0].message
            if message.refusal:
                raise ValueError(f"OpenAI refused to generate completions: {message.refusal}")
            return json.loads(message.content)
        raise ValueError(f"API {self.api} not recognized.")

    @staticmethod
    def _fill_missing_answers(responses: dict, questions: list[dict[str, str]]) -> None:
        """Set ``'nan'`` (and log an error) for every question that received no answer."""
        for q in questions:
            if q['question'] not in responses:
                responses[q['question']] = 'nan'
                logger.error(f"No response found for question: {q['question']}. Setting answer to 'nan'.")

    @Retry(retries=100, delay=10)
    def ask_question(self, video_file, video_file_path, questions: list[dict[str, str]],
                     inverse_topo_list: list[str]) -> dict[str, Literal[True, False, 'nan']]:
        """
        Ask all probe questions about one video and collect the answers.

        The whole call is retried by ``Retry`` on client/connection errors.

        :param video_file: for google, the file object returned by ``upload_video``; for openai,
            a list of base64-encoded frames, of which every ``frame_interval``-th is sent.
        :param video_file_path: local path of the video (used for re-uploads and debug files).
        :param questions: probe items, each a dict with ``'factor'`` and ``'question'`` keys.
        :param inverse_topo_list: factors in inverse topological order; questions are asked in
            this order and every factor must have a question.
        :return: mapping question text -> answer; unanswered questions map to ``'nan'``.
        """
        from answer_retrieve.prompt import get_prompt
        from answer_retrieve.schema import SchemaWriter

        factor_to_question = {q['factor']: q['question'] for q in questions}
        question_list: list[str] = [factor_to_question[factor] for factor in inverse_topo_list]
        schema_writer = SchemaWriter(question_list, api=self.api)
        responses: dict[str, Literal[True, False, 'nan']] = {}

        if self.one_by_one:
            # Iterate over question_list -- the list SchemaWriter was built from -- so that the
            # index passed to separate_schema() refers to the same question as the prompt.
            for qid, question in enumerate(question_list):
                prompt = get_prompt(question)
                response_schema = schema_writer.separate_schema(qid)
                response_dict = self.ask_question_one_by_one(video_file, video_file_path, prompt, response_schema)
                if self.debug:
                    self.save_debug_response(video_file_path, response_dict, idx=qid)
                responses[question] = response_dict['answer']
            self._fill_missing_answers(responses, questions)
            return responses

        prompt = get_prompt(question_list)
        response_schema = schema_writer.together_schema()
        if self.api == "google":
            qas = self._generate_google(video_file, prompt, response_schema)
        elif self.api == "openai":
            completion = self.client.beta.chat.completions.parse(
                model=self.vllm,
                messages=self._openai_messages(prompt, video_file),
                response_format={"type": "json_schema",
                                 "json_schema": {"name": "qas_and_explanations", "strict": True,
                                                 "schema": response_schema}}
            )
            message = completion.choices[0].message
            logger.debug(f"OpenAI response message: {message}")
            if message.refusal:
                raise ValueError(f"OpenAI refused to generate completions: {message.refusal}")
            # With a raw-dict response_format `parsed` is not subscriptable; fall back to the text.
            try:
                qas = message.parsed["qas"]
            except TypeError:
                qas = json.loads(message.content)["qas"]
        else:  # Should not reach here
            raise ValueError(f"API {self.api} not recognized.")

        if self.debug:
            self.save_debug_response(video_file_path, qas)
        # qas is a list of dicts, each containing 'question', 'explanation' and 'answer'.
        for item in qas:
            responses[item['question']] = item['answer']
        self._fill_missing_answers(responses, questions)
        return responses

    def _get_save_path(self, video_file_path: str) -> str:
        """Return ``<video dir>/<video name without extension><suffix>.json``."""
        video_folder = os.path.dirname(video_file_path)
        video_name = os.path.splitext(os.path.basename(video_file_path))[0]
        return os.path.join(video_folder, f"{video_name}{self.suffix}.json")

    def save_answer(self, video_file_path, response_dict, questions):
        """Write ``{factor: answer}`` for every question to the video's answer file."""
        factor_answers = {q['factor']: response_dict[q['question']] for q in questions}
        answer_file: str = self._get_save_path(video_file_path)
        with open(answer_file, 'w') as f:
            json.dump(factor_answers, f, indent=4)
        logger.info(f"Answers saved to {answer_file}")

    def _process_video_google(self, video_file_path: str, questions: list[dict[str, str]],
                              inverse_topo_factors: list[str]) -> None:
        """Upload one video to Gemini, ask the questions, save the answers and delete the upload."""
        video_file_obj = self.upload_video(video_file_path)
        try:
            logger.info(f"Asking questions for video {video_file_path}...")
            response_dict = self.ask_question(video_file_obj, video_file_path, questions, inverse_topo_factors)
            # The answer file is stored in the same directory as the video.
            self.save_answer(video_file_path, response_dict, questions)
            logger.info(f"Saved answers for {video_file_path}")
        finally:
            # Remove the remote copy even when asking failed, so uploads do not pile up.
            self.delete_video(video_file_obj)
            logger.info(f"Deleted uploaded copy of {video_file_path}")

    def _process_video_openai(self, video_file_path: str, questions: list[dict[str, str]],
                              inverse_topo_factors: list[str]) -> None:
        """Extract frames from one video, ask the questions via OpenAI and save the answers."""
        from answer_retrieve.video_utils import video2frames
        frames = video2frames(video_file_path)

        logger.info(f"Asking questions for video {video_file_path}...")
        response_dict = self.ask_question(frames, video_file_path, questions, inverse_topo_factors)
        # Without this the OpenAI run never produced an answer file.
        self.save_answer(video_file_path, response_dict, questions)
        logger.info(f"Saved answers for {video_file_path}")

    def process_video(self, scenario_id: int, scenario: str) -> None:
        """Answer the scenario's questions for every video under its folder (recursively).

        Videos that already have an answer file are skipped unless ``force_regen`` is set.

        :raises ValueError: if the scenario's video folder does not exist or the API is unknown.
        """
        video_folder = os.path.join(self.database_path, self.video_model, str(scenario_id))
        if not os.path.exists(video_folder):
            raise ValueError(f"Video folder for scenario_id {scenario_id} not found.")
        questions, inverse_topo_factors = self.generate_questions(scenario)

        for root, dirs, files in os.walk(video_folder):
            dirs.sort()  # deterministic traversal order
            for file in sorted(files):
                if not file.lower().endswith(self.VIDEO_EXTENSIONS):
                    continue

                video_file_path = os.path.join(root, file)
                if os.path.exists(self._get_save_path(video_file_path)) and not self.force_regen:
                    logger.warning(f"Answers already exist for video {file}. Skipping...")
                    continue

                if self.api == "google":
                    self._process_video_google(video_file_path, questions, inverse_topo_factors)
                elif self.api == "openai":
                    self._process_video_openai(video_file_path, questions, inverse_topo_factors)
                else:
                    raise ValueError(f"API {self.api} not recognized.")

    @Retry(retries=100, delay=10)
    def delete_video(self, video_file_obj):
        """Delete an uploaded file from the Gemini Files API (retried by ``Retry``)."""
        self.client.files.delete(name=video_file_obj.name)
        logger.info(f"Successfully deleted video {video_file_obj.name}.")

    def run(self, num_workers: int = 1):
        """Process every scenario (or only those in ``self.index``).

        With ``num_workers > 1`` scenarios are distributed over a ``multiprocessing.Pool``;
        the instance is pickled to the workers without its API client (see ``__getstate__``).
        """
        import multiprocessing
        scenario_df = self.load_scenario_csv()
        wanted = None if self.index is None else set(self.index)
        candidates = [(idx, sce) for idx, sce in zip(scenario_df["scenario_id"], scenario_df['scenario'])
                      if wanted is None or idx in wanted]
        if num_workers == 1:
            for scenario_id, scenario in candidates:
                self.process_video(scenario_id, scenario)
        else:
            with multiprocessing.Pool(processes=num_workers) as pool:
                results = [pool.apply_async(self.process_video, (scenario_id, scenario))
                           for scenario_id, scenario in candidates]
                for result in results:
                    result.get()  # Wait for all processes; re-raises worker exceptions.
        logger.info("All videos processed successfully.")


if __name__ == "__main__":
    # Parse command-line arguments.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-a", "--api", type=str, required=True, help="The name of video reader api provider.",
                        choices=['openai', 'google'])
    parser.add_argument('-v', '--vllm', type=str, required=True, help='The name of the VLLM model.')
    parser.add_argument("-m", "--video_model", type=str, required=True, help="The model name (e.g., videocrafter2).")
    parser.add_argument("-d", "--database", type=str, required=True, help="The path to the database folder.")
    parser.add_argument("-s", "--dataset", type=str, required=True, help="The path to the dataset folder.")
    parser.add_argument("-i", "--idx", type=int, nargs="*",
                        help="The scenario ID(s) to process. If not provided, process all scenarios.")
    parser.add_argument("-f", "--force_regen", action="store_true",
                        help="Force regenerate answers even if they already exist.")
    parser.add_argument("--suffix", type=str, default="", help='Suffix for the output file')
    parser.add_argument("--one_by_one", action="store_true", help="Ask questions one by one.")
    parser.add_argument("--frame_interval", type=int, default=10, help="The interval between frames to send to OpenAI.")
    parser.add_argument("--workers", type=int, default=1, help="Number of workers to use for multiprocessing.")
    parser.add_argument("--debug", action="store_true", help="Save debug responses to separate files.")
    parser.add_argument("-k", "--api_keys", type=str, nargs="*", default=None,
                        help="Google API key(s), rotated on repeated client errors. If omitted, the "
                             "comma-separated GOOGLE_API_KEYS environment variable is used (preferred: keys "
                             "given on the command line can leak via shell history and process listings).")
    parser.add_argument("--start_key_index", type=int, default=0,
                        help="Index of the Google API key to start with.")

    args = parser.parse_args()

    # A list of keys is required by VideoQuestionGenerator; the OpenAI API ignores it.
    api_keys = args.api_keys
    if api_keys is None:
        api_keys = [key.strip() for key in os.environ.get("GOOGLE_API_KEYS", "").split(",") if key.strip()]
    if args.api == "google":
        if not api_keys:
            parser.error("The google API needs at least one key: pass --api_keys or set GOOGLE_API_KEYS.")
        if not 0 <= args.start_key_index < len(api_keys):
            parser.error(f"--start_key_index must be between 0 and {len(api_keys) - 1}.")

    # Run the video question-answering task.
    generator = VideoQuestionGenerator(
        api=args.api,
        vllm=args.vllm,
        video_model=args.video_model,
        database_path=args.database,
        dataset_path=args.dataset,
        index=args.idx,
        frame_interval=args.frame_interval,
        api_keys=api_keys,
        start_api_key_index=args.start_key_index,
        one_by_one=args.one_by_one,
        force_regen=args.force_regen,
        suffix=args.suffix,
        debug=args.debug,
    )
    generator.run(num_workers=args.workers)