#!/usr/bin/env python3
import argparse
import base64
import json
import os
import time
import openai
import concurrent.futures
from tqdm import tqdm

def parse_args(argv=None):
    """Parse the command-line options of the captioning script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) keeps the original
            behavior of reading ``sys.argv[1:]``; passing an explicit list
            makes the parser usable from tests and other callers.

    Returns:
        An ``argparse.Namespace`` with the attributes ``debug``, ``ak``,
        ``input_file``, ``output_file`` and ``num_splits``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Run in debug mode with only first few entries'
    )
    parser.add_argument(
        '--ak',
        type=str,
        default="",
        help='API key for gemini'
    )
    parser.add_argument(
        '--input_file',
        type=str,
        default="/path/to/your/input.txt",
        help='Path to the file with video paths'
    )
    parser.add_argument(
        '--output_file',
        type=str,
        default="/path/to/your/output.csv",
        help='Base path to the output file (split files will be derived from this name)'
    )
    parser.add_argument(
        '--num_splits',
        type=int,
        default=1,
        help='Number of splits (and parallel API calls) to use'
    )
    return parser.parse_args(argv)

def encode_image(image_path):
    """Return the base64 text representation of the file at ``image_path``.

    The file is opened in binary mode and read completely into memory, so the
    same helper works for images and for video files.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

def gemini_chat(contents, args):
    """Send a single user message through the gemini API and return the reply.

    Args:
        contents: The message payload. A list is forwarded unchanged as
            structured content parts (e.g. ``image_url`` / ``text`` dicts);
            any other value is converted with ``str`` and sent as plain text.
        args: Object exposing an ``ak`` attribute with the API key.

    Returns:
        The text of the first completion choice, or ``''`` when every attempt
        failed or the reply carried no text.
    """
    base_url = "/path/to/your/gemini/api"
    api_version = "/version"
    model_name = "/model/name"
    max_tokens = 4096  # Valid range: [0, 4096]
    max_attempts = 5
    retry_delay = 5  # seconds to wait between failed attempts

    client = openai.AzureOpenAI(
        azure_endpoint=base_url,
        api_version=api_version,
        api_key=args.ak,
    )

    # Stringifying a list of content parts would send its Python repr as plain
    # text, so the media part would never reach the model as media.
    message_content = contents if isinstance(contents, list) else str(contents)

    for attempt in range(1, max_attempts + 1):
        try:
            completion = client.chat.completions.create(
                model=model_name,
                messages=[
                    {
                        "role": "user",
                        "content": message_content
                    }
                ],
                max_tokens=max_tokens,
                extra_headers={"X-TT-LOGID": "${your_logid}"},  # Ensure you pass the required x-tt-logid header.
            )
            reply = completion.choices[0].message.content
            # Keep the return type a string even if the reply has no text.
            return reply if reply is not None else ''
        except Exception as e:
            print(f"Attempt {attempt}: {e}")
            # No point in sleeping after the last attempt; we are giving up.
            if attempt < max_attempts:
                time.sleep(retry_delay)
    return ''

def gemini_TAL_PE(video_uri, args, time_padding=0.0):
    """
    For a given video file or URL, returns a caption generated by gemini.

    A ``video_uri`` that starts with ``http://`` or ``https://`` is passed on
    as-is. Anything else is treated as a local path: the file is read,
    base64-encoded and wrapped in a ``data:video/webm`` URI.

    Args:
        video_uri: Local file path or http(s) URL of the video.
        args: Parsed options forwarded to ``gemini_chat`` (provides the API key).
        time_padding: Currently unused; kept so existing callers keep working.

    Returns:
        The caption text returned by ``gemini_chat``, or ``''`` if the local
        file could not be read or every request attempt raised.
    """
    # Decide by URL scheme rather than by substring: a local path that merely
    # contains "http" (e.g. "/data/http_clips/a.mp4") must still be read from disk.
    if video_uri.lower().startswith(("http://", "https://")):
        video_url = video_uri
    else:
        try:
            # NOTE(review): the MIME type is fixed to video/webm whatever the
            # file's real container is -- confirm the endpoint accepts that.
            video_url = f"data:video/webm;base64,{encode_image(video_uri)}"
        except OSError as e:
            # An unreadable file yields an empty caption instead of an
            # exception that would abort the caller's whole batch.
            print(f"Could not read video '{video_uri}': {e}")
            return ''

    contents = [
        {
            "type": "image_url",
            "image_url": {
                "url": video_url
            }
        },
        {
            "type": "text",
            "text": (
                "I require you to offer detailed descriptions regarding the actions and expressions of the principal "
                "characters in a human video. In case there are multiple main individuals, kindly describe them one by one, "
                "in descending order of their character prominence. You merely need to respond with the content acquired "
                "from the visual screen and do not describe the audio content and simply return it with JSON. Begin now. "
                "Please ONLY return the content in JSON format:{ \"Environment\": \"xxx\", \"1. Appearance and 2. Posture of ID1\": \"xxx\", "
                "\"Detailed Motion including 1. Detailed Body and Detailed Hands Action and Detailed Gestures, 2. Head and Detailed Facial Action, "
                "and 3. Detailed Emotion of ID1\": \"xxx\", \"1. Appearance and 2. Posture of ID2\": \"xxx\", ...} "
                "Interactions with other IDs should not be ignored. "
                "Please note that although the content is separated in JSON, when assembled in order, it still forms natural sentences."
            )
        }
    ]
    # Up to three tries in case the chat helper itself raises.
    for _ in range(3):
        try:
            res_str = gemini_chat(contents, args)
            return res_str or ''
        except Exception as e:
            print(e)
            continue
    return ''

def _load_processed_videos(output_filename, split_id):
    """Return the set of "video" values already recorded in a JSON Lines file."""
    processed = set()
    if not os.path.exists(output_filename):
        return processed
    try:
        with open(output_filename, 'r', encoding='utf-8') as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    # Blank lines are not records; skip them quietly.
                    continue
                try:
                    record = json.loads(line)
                    if "video" in record:
                        processed.add(record["video"])
                except (ValueError, TypeError) as e:
                    # A malformed line must not prevent resuming from the rest.
                    print(f"Split {split_id}: Error reading a line: {e}")
    except (OSError, ValueError) as e:
        print(f"Split {split_id}: Error loading {output_filename}: {e}")
    return processed


def process_split(split_id, video_list, args, output_file_base):
    """
    Process a list of videos belonging to one split.

    The output file is ``<output_file_base without extension>_split<split_id>.json``
    in JSON Lines format. Videos already recorded there are skipped. For each
    remaining video that exists on disk, gemini_TAL_PE is called and the result
    is appended as a JSON record (keys "video" and "caption") in append mode
    ('a'), so progress is saved incrementally.

    Args:
        split_id: 1-based index of this split; used in the output file name,
            in log messages and for the progress-bar position.
        video_list: Video paths assigned to this split.
        args: Parsed options forwarded to gemini_TAL_PE.
        output_file_base: Base output path from which the split file name is derived.

    NOTE: a record is written even when the caption is empty, so such videos
    count as processed on later runs.
    """
    base, _ = os.path.splitext(output_file_base)
    output_filename = f"{base}_split{split_id}.json"
    processed_set = _load_processed_videos(output_filename, split_id)

    for listed_video in tqdm(video_list, desc=f"Split {split_id}", unit="video", position=split_id-1, leave=True):
        video = listed_video
        # If the video file does not exist, try removing ".mp4" extension and checking again.
        if not os.path.exists(video):
            video = video.replace(".mp4", "")

        # Records store the resolved path, so compare both spellings. Checking
        # only the listed path would re-caption fallback videos on every resume.
        if listed_video in processed_set or video in processed_set:
            print(f"Split {split_id}: Video '{listed_video}' already processed, skipping.")
            continue
        if not os.path.exists(video):
            continue

        caption = gemini_TAL_PE(video, args)
        caption = caption.strip() if caption else ""
        record = {"video": video, "caption": caption}
        try:
            with open(output_filename, 'a', encoding='utf-8') as fout:
                fout.write(json.dumps(record, ensure_ascii=False) + "\n")
        except (OSError, ValueError) as e:
            print(f"Split {split_id}: Error writing to '{output_filename}': {e}")
        else:
            # Remember the video so a duplicate path later in this run is skipped.
            processed_set.add(video)

# Optional unittests
import unittest

class TestOpenAI(unittest.TestCase):
    """Smoke tests for the gemini helpers.

    These call the real helper functions, so they exercise the configured
    endpoint; they only check that a string comes back.
    """

    @staticmethod
    def _default_args():
        """Build the options the helpers read, without touching sys.argv.

        Calling parse_args() here would parse the test runner's own command
        line and exit on its unknown arguments.
        """
        return argparse.Namespace(ak="")

    def test_chat(self):
        caption = gemini_chat("Hello, world!", args=self._default_args())
        self.assertIsInstance(caption, str)

    def test_streaming_chat(self):
        import tempfile

        # A local path is read from disk before the request is sent, so the
        # dummy video has to exist.
        with tempfile.TemporaryDirectory() as tmp_dir:
            video_path = os.path.join(tmp_dir, "dummy_video_path")
            with open(video_path, "wb") as dummy:
                dummy.write(b"\x00")
            result = gemini_TAL_PE(video_path, args=self._default_args())
        self.assertIsInstance(result, str)

def main():
    """Caption every video listed in the input file, one worker thread per split.

    Reads the video paths (one per line, blank lines ignored) from
    ``--input_file``, limits them to the first 20 in ``--debug`` mode,
    distributes them round-robin over ``--num_splits`` splits and runs
    ``process_split`` for each non-empty split in a thread pool.

    Raises:
        SystemExit: With status 1 if the input file cannot be read or
            ``--num_splits`` is smaller than 1.
    """
    args = parse_args()

    try:
        with open(args.input_file, 'r', encoding='utf-8') as fin:
            video_paths = [line.strip() for line in fin if line.strip()]
    except (OSError, UnicodeError) as e:
        print(f"Error reading the input file '{args.input_file}': {e}")
        # The bare exit() helper comes from the site module and is not
        # guaranteed to exist; SystemExit always is.
        raise SystemExit(1)

    if args.debug:
        video_paths = video_paths[:20]

    total_videos = len(video_paths)
    print(f"Total videos to process: {total_videos}")

    num_splits = args.num_splits
    if num_splits < 1:
        print("The number of splits (--num_splits) must be at least 1")
        raise SystemExit(1)

    # Round-robin distribution: split i gets items i, i + num_splits, ...
    splits = [video_paths[i::num_splits] for i in range(num_splits)]

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_splits) as executor:
        # Map each future to its split id so failures can be attributed.
        futures = {}
        for idx, video_list in enumerate(splits, start=1):
            if not video_list:
                print(f"Split {idx}: No videos to process.")
                continue
            future = executor.submit(process_split, idx, video_list, args, args.output_file)
            futures[future] = idx

        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as exc:
                print(f"Split {futures[future]}: Generated an exception: {exc}")

    print("All splits have been processed.")


if __name__ == "__main__":
    main()