import argparse
import json
import os
import time
import uuid
from datetime import datetime

import numpy as np
from jsonschema import validate, exceptions
import typing as t

from openai import OpenAI, OpenAIError
from tqdm import tqdm

from audio_utils import json_schemas, schemas as s
from prompt_generation import prompts


class TextGenerator:
    """Generates synthetic samples from input captions via OpenAI chat completions.

    Captions are sent to the model in groups of ``batch_size``. Every request
    carries the system prompt, the few-shot examples (grouped with the same
    batch size) and the captions of the current batch. Validated responses are
    appended to ``output_path`` as JSON lines.
    """

    def __init__(
            self,
            model: str,
            temperature: float,
            batch_size: int,
            input_captions: list[s.InputCaption],
            output_path: str
    ):
        """
        Store the configuration and pre-build the few-shot chat messages.

        Args:
            model: OpenAI model name passed to the chat completion endpoint.
            temperature: Sampling temperature passed to the endpoint.
            batch_size: Number of captions sent per request. The number of
                few-shot examples must be a multiple of this value.
            input_captions: Captions to process. Each item is indexed with
                ``"caption"`` and ``"metadata"`` by this class.
            output_path: JSONL file that results are appended to.

        Raises:
            ValueError: If ``batch_size`` is not positive or does not evenly
                divide the number of few-shot examples.
        """
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        self.model = model
        self.temperature = temperature
        self.batch_size = batch_size
        self.input_captions = input_captions
        self.output_path = output_path

        self.system_prompt = prompts.system_prompt_audio

        # prepare inputs according to batch size
        few_shots = list(prompts.few_shots_audio)
        if len(few_shots) % batch_size != 0:
            raise ValueError(
                f"Number of few-shot examples ({len(few_shots)}) must be a multiple of batch_size ({batch_size})"
            )
        # Plain slicing yields the same equally sized groups as the previous
        # np.array_split call, and also copes with an empty few-shot list.
        self.few_shots = [few_shots[i:i + batch_size] for i in range(0, len(few_shots), batch_size)]

        # Each few-shot group becomes one user turn (inputs joined by newlines)
        # followed by one assistant turn holding the expected JSON answer.
        self.batched_few_shots = []
        for sublist in self.few_shots:
            self.batched_few_shots.extend([
                {"role": "user", "content": "\n".join([x["input"] for x in sublist])},
                {"role": "assistant", "content": json.dumps({"samples": sublist})},
            ])

        self.client = OpenAI()

    def generate(self):
        """
        Main entry point to generate the synthetic dataset. Currently delegates to
        prompt-based generation; can be extended to enable OpenAI batching.
        """
        self._create_prompt_file()

    def _create_prompt_file(self):
        """Creates a synthetic prompt file based on the provided captions."""
        def prompt_processor(caption_batch: list[s.InputCaption], custom_id: str) -> str:
            # custom_id is part of the caption_processor signature but is not
            # needed for direct (non-batched) completion requests.
            samples = self._send_completion_request(input_captions=caption_batch)
            if samples is None:
                print(f"Error: Samples for captions {caption_batch} is empty. Skipping...")
                return ""
            return "\n".join([json.dumps(sample) for sample in samples])

        self._process_captions(caption_processor=prompt_processor)

    def _process_captions(self, caption_processor: t.Callable[[list[s.InputCaption], str], str]):
        """
        Helper to process input captions in chunks. Invokes a caption_processor on each chunk.

        Args:
            caption_processor: Callable receiving a batch of captions and a fresh
                hex id; returns the text to append to the output file, or an
                empty string to write nothing for that batch.
        """
        for i in tqdm(range(0, len(self.input_captions), self.batch_size)):
            input_caption_batch = self.input_captions[i:i+self.batch_size]
            custom_id = uuid.uuid4().hex
            output_str = caption_processor(input_caption_batch, custom_id)
            if output_str:
                # Append after every batch so finished work survives an interruption.
                with open(self.output_path, "a") as jsonl_file:
                    jsonl_file.write(output_str + "\n")

    def _send_completion_request(self, input_captions: list[s.InputCaption]) -> t.Optional[list[dict]]:
        """
        Sends a single chat completion request to the LLM using the prepared configuration.

        The request is retried with exponential backoff (capped at ``max_delay``
        seconds) whenever the client raises ``OpenAIError``.

        Returns:
            The prepared samples, or None if the response could not be parsed or validated.

        Raises:
            RuntimeError: If the request still fails after ``max_retries`` retries.
        """
        llm_config = self._get_llm_config(input_captions=input_captions)
        metadata_list = [x["metadata"] for x in input_captions]

        retries = 0
        max_retries = 10
        base_delay = 1
        max_delay = 30
        while True:
            try:
                # The API call must be inside the try block; otherwise the
                # backoff logic below can never be triggered.
                completion = self.client.chat.completions.create(**llm_config)
            except OpenAIError as e:
                retries += 1
                if retries > max_retries:
                    raise RuntimeError("Exceeded max retries for OpenAI API") from e

                wait_time = min(base_delay * (2 ** (retries - 1)), max_delay)
                print(f"Error from OpenAI encountered {e}. Retrying {retries}/{max_retries} in {wait_time:.1f}s...")
                time.sleep(wait_time)
            else:
                return self._prepare_completion(completion=completion.model_dump(), metadata_list=metadata_list)

    def _prepare_completion(self, completion: dict, metadata_list: list[dict]) -> t.Optional[list[dict]]:
        """
        Parses LLM response and validates it against a JSON schema.

        Args:
            completion: Chat completion converted to a plain dict.
            metadata_list: Metadata of the captions in the request, in request order.

        Returns:
            One result dict per (sample, metadata) pair, or None when the
            response has no content, is not valid JSON, or fails schema validation.
        """
        content = completion["choices"][0]["message"]["content"]
        if content is None:
            # json.loads(None) would raise TypeError; treat a missing message
            # body like any other invalid response.
            print("Invalid OpenAI response for batch: message content is empty")
            return None

        try:
            response = json.loads(content)
            validate(response, json_schemas.response_schema)
        except json.JSONDecodeError as e:
            print("Invalid JSON", e)
            return None
        except exceptions.ValidationError as e:
            print("Invalid OpenAI response for batch", e)
            print(content)
            return None

        # prepare metadata
        samples = response["samples"]
        if len(samples) != len(metadata_list):
            # Best effort: report the mismatch and keep the pairs zip() can form.
            print("Something went wrong...")
            print(f"{len(samples)} != {len(metadata_list)}, {completion}")

        model_name = completion["model"]
        created = datetime.fromtimestamp(completion["created"]).isoformat()
        return [
            {
                "data": sample,
                "metadata": {
                    "uid": uuid.uuid4().hex,
                    "model": model_name,
                    "created": created,
                    "input_caption": {
                        **meta
                    }
                }
            }
            for sample, meta in zip(samples, metadata_list)
        ]

    def _get_llm_config(self, input_captions: list[s.InputCaption]) -> dict:
        """
        Builds the config dictionary for the LLM request, using system prompts and few-shots.

        The captions of the batch are joined by newlines into a single user
        message; the response is constrained to the project's response schema.
        """
        return {
            "model": self.model,
            "temperature": self.temperature,
            "messages": [
                {"role": "system", "content": self.system_prompt},
                *self.batched_few_shots,
                {"role": "user", "content": "\n".join([x["caption"] for x in input_captions])},
            ],
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "audio_response",
                    "strict": True,
                    "schema": json_schemas.response_schema
                }
            }
        }


def _load_processed_uids(output_dir: str, output_filename: str) -> set:
    """Collects the input-caption uids already present in earlier output files.

    Reads both the merged file ``{output_filename}.jsonl`` and every
    ``{output_filename}_chunk*.jsonl`` file in ``output_dir``. Blank lines are
    ignored; malformed lines (e.g. a line truncated by an interrupted run) are
    reported and skipped.
    """
    merged_name = f"{output_filename}.jsonl"
    chunk_prefix = f"{output_filename}_chunk"
    processed_uids = set()
    for name in sorted(os.listdir(output_dir)):
        is_chunk_file = name.startswith(chunk_prefix) and name.endswith(".jsonl")
        if name != merged_name and not is_chunk_file:
            continue
        path = os.path.join(output_dir, name)
        if not os.path.isfile(path):
            continue
        with open(path, "r") as jsonl_file:
            for line in jsonl_file:
                if not line.strip():
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    print(f"Warning: skipping malformed line in {name}")
                    continue
                processed_uids.add(record["metadata"]["input_caption"]["uid"])
    return processed_uids


def main():
    """Main entry point for prompt generation.

    Loads the input captions, selects the requested index range and chunk,
    drops captions that already appear in earlier output files, and runs the
    generator on the remainder.
    """
    parser = argparse.ArgumentParser(description="Generate text dataset based on provided input captions")

    # General arguments
    parser.add_argument("--input-path", required=True, type=str, help="Path to input dataset")
    parser.add_argument("--output-dir", required=True, type=str, help="Path to output directory")
    parser.add_argument("--n-chunks", required=False, default=1, type=int, help="Total chunks")
    parser.add_argument("--curr-chunk", required=False, default=0, type=int, help="Current chunk")
    parser.add_argument("--start-index", required=False, default=0, type=int, help="Start index of input dataset")
    parser.add_argument("--end-index", required=False, default=-1, type=int, help="End index of input dataset")

    # Model Config
    parser.add_argument("--model", required=False, default="gpt-4o", type=str, help="OpenAI model name")
    parser.add_argument("--temperature", required=False, default=0.7, type=float, help="Temperature of LLM")
    parser.add_argument("--batch-size", required=False, default=1, type=int, help="How many prompts to generate per LLM call")

    args = parser.parse_args()
    if args.n_chunks < 1:
        parser.error("--n-chunks must be at least 1")
    if not 0 <= args.curr_chunk < args.n_chunks:
        parser.error("--curr-chunk must satisfy 0 <= curr-chunk < n-chunks")

    with open(args.input_path, "r") as jsonl_file:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        input_captions: list[s.InputCaption] = [json.loads(line) for line in jsonl_file if line.strip()]
    # The default end index of -1 means "until the end"; used directly as a
    # slice bound it would silently drop the last caption.
    end_index = None if args.end_index == -1 else args.end_index
    input_captions = input_captions[args.start_index:end_index]
    print(f"Generating {len(input_captions)} total prompts...")

    os.makedirs(args.output_dir, exist_ok=True)
    output_filename = f"audio_{args.model}_t{args.temperature}_bs{args.batch_size}"

    # Chunk before filtering so that a given chunk index always owns the same
    # captions, no matter how much work has already been completed.
    input_captions = np.array_split(input_captions, args.n_chunks)[args.curr_chunk].tolist()
    print(f"Working on a subset of {len(input_captions)} prompts after chunking...")

    # filter out processed ids; includes this script's own chunk output files
    # so that a restarted chunk resumes instead of appending duplicates
    processed_uids = _load_processed_uids(args.output_dir, output_filename)
    if not processed_uids:
        print("Starting from scratch...")
    input_captions = [x for x in input_captions if x["metadata"]["uid"] not in processed_uids]
    print(f"Working on a subset of {len(input_captions)} prompts after filtering...")

    text_generator = TextGenerator(
        model=args.model,
        temperature=args.temperature,
        batch_size=args.batch_size,
        input_captions=input_captions,
        output_path=os.path.join(args.output_dir, f"{output_filename}_chunk{args.curr_chunk}.jsonl"),
    )
    text_generator.generate()


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()