import json
import os
from pathlib import Path

import numpy as np
import gc
import polars as pl
import torch
import torchaudio
from dotenv import load_dotenv
from openai import OpenAI
from torch import Tensor
from torch.nn import functional as F
from torch.utils.data import Dataset

from effects import trim
from torch_edit import (
    add,
    drop,
    high_pass,
    inpaint,
    loop,
    low_pass,
    noise,
    pitch,
    replace,
    speed,
    superres,
    swap,
)

# Populate os.environ from a local .env file (if one exists) at import time;
# TestDataset.gpt_api_call later reads OPENAI_API_KEY from the environment.
load_dotenv()


class TestDataset(Dataset):
    def __init__(
        self,
        data_dir: Path,
        csv_path: Path,
        prompt_path: Path,
        channels: int,
        strip_silence: bool = True,
        use_o3: bool = True,
    ) -> None:
        super().__init__()
        data_dir = data_dir.expanduser()
        self.data_dir = data_dir

        df = pl.read_csv(csv_path)
        df = df.filter(pl.col("n_elements") > 0)
        min_elements = df.get_column("n_elements").min()
        max_elements = df.get_column("n_elements").max()
        assert isinstance(min_elements, int)
        assert isinstance(max_elements, int)
        assert not isinstance(min_elements, bool)
        assert not isinstance(max_elements, bool)
        self.element_dict: dict[int, list[int]] = dict()
        for i in range(min_elements, max_elements + 1):
            filter_df = df.filter(pl.col("n_elements") == i)
            self.element_dict[i] = filter_df.get_column("audiocap_id").to_list()
        self.audio_descs: dict[int, str] = dict()
        for id, desc in df.select("audiocap_id", "input").iter_rows():
            self.audio_descs[id] = desc
        self.sampling_rate: int = 44_100
        self.channels = channels
        if channels not in (1, 2):
            raise NotImplementedError("Only mono & stereo channels supported")

        self.strip_silence = strip_silence
        self.max_len_sec = 40

        self.task_comp = {
            "ADD": [max_elements, 1],
            "DROP": [max_elements, 1],
            "REPLACE": [max_elements, 1, 1],
            "INPAINT": [max_elements],
            "SUPER_RES": [max_elements],
            "DENOISE": [max_elements],
            "PITCH": [max_elements],
            "SPEED": [max_elements],
            "HIGH_PASS": [max_elements],
            "LOW_PASS": [max_elements],
            "SWAP": [1, 1],  # Perhaps more complex?
            "LOOP": [max_elements],
        }

        self.tasks = [
            "ADD",
            "DROP",
            "REPLACE",
            "INPAINT",
            "SUPER_RES",
            "DENOISE",
            "PITCH",
            "SPEED",
            "HIGH_PASS",
            "LOW_PASS",
            "SWAP",
            "LOOP",
        ]
        self.rng = np.random.default_rng()
        self.openai_client: None | OpenAI = None
        self.use_o3 = use_o3

        with open(prompt_path, "r") as f:
            self.prompt = json.load(f)

    def __len__(self) -> int:
        return len(self.audio_descs)

    def gpt_api_call(
        self,
        system_prompt: str,
        user_prompt: str,
        examples: list[dict[str, str]],
        gpt_model: str = "gpt-4.1-mini",
    ) -> str:
        if self.openai_client is None:
            openai_key = os.environ["OPENAI_API_KEY"]
            self.openai_client = OpenAI(api_key=openai_key)
        try:
            completion = self.openai_client.chat.completions.create(
                model=gpt_model,
                messages=[
                    {"role": "developer", "content": system_prompt},
                    *examples,
                    {
                        "role": "user",
                        "content": user_prompt,
                    },
                ],
            )
            command = completion.choices[0].message.content
            assert command is not None, "GPT API returned None"
            return command
        except Exception as e:
            print(f"Error in GPT API call: {e}")
            raise e

    def process_command(self, command: str, process: str) -> str:
        post_system_prompt = self.prompt[process]["system_prompt"]
        examples = self.prompt[process]["examples"]
        example_messages: list[dict[str, str]] = []
        for example in examples:
            example_keys = sorted(list(example["inputs"]))
            user_msg = {
                "role": "user",
                "content": "\n".join(
                    f"{key}: {example['inputs'][key]}" for key in example_keys
                ),
            }
            assistant_msg = {"role": "assistant", "content": example["instruction"]}
            example_messages.extend([user_msg, assistant_msg])

        command = self.gpt_api_call(
            system_prompt=post_system_prompt,
            user_prompt=command,
            examples=example_messages,
            gpt_model="o3-mini" if self.use_o3 else "gpt-4.1-mini",
        )
        if command is None:
            raise ValueError("Failed to get a response from the GPT API.")
        return command

    def command(
        self,
        task: str,
        desc_1: str,
        desc_2: str | None = None,
        desc_3: str | None = None,
        parameter: str | int | float | None = None,
    ) -> str:
        instruction = self.prompt[task]
        system_prompt = instruction["system_prompt"]
        examples = instruction["examples"]
        example_messages: list[dict[str, str]] = []
        for example in examples:
            example_keys = sorted(list(example["inputs"]))
            user_msg = {
                "role": "user",
                "content": "\n".join(
                    f"{key}: {example['inputs'][key]}" for key in example_keys
                ),
            }
            assistant_msg = {"role": "assistant", "content": example["instruction"]}
            example_messages.extend([user_msg, assistant_msg])

        input_captions = f"caption1: {desc_1}"
        if desc_2 is not None:
            input_captions += f"\ncaption2: {desc_2}"
        if desc_3 is not None:
            input_captions += f"\ncaption3: {desc_3}"
        if parameter is not None:
            input_captions += f"\nparameter: {parameter}"

        command = self.gpt_api_call(
            system_prompt=system_prompt,
            user_prompt=input_captions,
            examples=example_messages,
        )

        for process in [
            "POST_PROCESS_vari",
            "POST_PROCESS_minimal",
            # "POST_PROCESS_natural",
        ]:
            if self.rng.random() < 0.5:
                command = self.process_command(command, process)

        return command

    def __getitem__(self, _index: int) -> tuple[Tensor, Tensor, str]:
        while True:
            task = self.rng.choice(self.tasks)
            indices = []
            audio_list = []
            desc_list = []
            print(task)
            for max_comp in self.task_comp[task]:
                candidates = [
                    indice
                    for n_element in self.element_dict
                    if n_element <= max_comp
                    for indice in self.element_dict[n_element]
                ]
                while (indice := self.rng.choice(candidates)) in indices:
                    continue
                indices.append(indice)

                audio, sr = torchaudio.load(self.data_dir / f"{indice}.wav")
                if audio.shape[0] != self.channels:
                    audio = audio.mean(dim=0, keepdim=True)
                    audio = audio.repeat(self.channels, 1)
                if sr != self.sampling_rate:
                    audio = torchaudio.functional.resample(
                        audio, orig_freq=sr, new_freq=self.sampling_rate
                    )
                if self.strip_silence:
                    audio = audio.numpy()
                    audio = trim(audio)
                    audio = torch.from_numpy(audio)
                audio_list.append(audio)
                desc_list.append(self.audio_descs[indice])

            audio_lens: list[int] = [audio.shape[1] for audio in audio_list]
            max_audio_len = max(audio_lens)
            audio_list = [
                F.pad(audio, (0, max_audio_len - audio.shape[1]))
                for audio in audio_list
            ]
            parameter = None

            match task:
                case "ADD":
                    audio_list = [
                        audio[:audio_len]
                        for (audio, audio_len) in zip(audio_list, audio_lens)
                    ]
                    if audio_lens[0] < audio_lens[1]:
                        continue
                    if self.rng.normal() > 0:
                        loc: str | float = self.rng.choice(["start", "middle", "end"])
                    else:
                        max_sec = (audio_lens[0] - audio_lens[1]) / self.sampling_rate
                        loc = round(self.rng.random() * max_sec, 1)
                    parameter = loc
                    input_audio, output_audio = add(
                        *audio_list, loc=loc, base_sr=self.sampling_rate
                    )
                case "DROP":
                    input_audio, output_audio = drop(*audio_list)
                case "REPLACE":
                    input_audio, output_audio = replace(*audio_list)
                case "INPAINT":
                    input_audio, output_audio = inpaint(*audio_list)
                case "SUPER_RES":
                    input_audio, output_audio = superres(
                        *audio_list,
                        base_sr=self.sampling_rate,
                        target_sr=self.sampling_rate // 4,
                    )
                case "DENOISE":
                    input_audio, output_audio = noise(*audio_list, scale=0.01)
                case "PITCH":
                    parameter = 1
                    while parameter == 1:
                        parameter = self.rng.integers(-11, 12)
                    input_audio, output_audio = pitch(
                        *audio_list, base_sr=self.sampling_rate, steps=parameter
                    )
                case "SPEED":
                    parameter = np.exp(self.rng.uniform(-1, 1) * np.log(3))
                    input_audio, output_audio = speed(*audio_list, factor=parameter)
                case "HIGH_PASS":
                    input_audio, output_audio = high_pass(
                        *audio_list, base_sr=self.sampling_rate
                    )
                case "LOW_PASS":
                    input_audio, output_audio = low_pass(
                        *audio_list, base_sr=self.sampling_rate
                    )
                case "SWAP":
                    audio_list = [
                        audio[:audio_len]
                        for (audio, audio_len) in zip(audio_list, audio_lens)
                    ]
                    input_audio, output_audio = swap(*audio_list)
                case "LOOP":
                    max_loop = np.floor(
                        self.max_len_sec / (max_audio_len / self.sampling_rate)
                    )
                    if max_loop > 1:
                        parameter = self.rng.integers(1, max_loop)
                    else:
                        parameter = 1
                    input_audio, output_audio = loop(*audio_list, num_loop=parameter)
                case _:
                    raise NotImplementedError()

            if (
                len(input_audio) / self.sampling_rate <= self.max_len_sec
                and len(output_audio) / self.sampling_rate <= self.max_len_sec
            ):
                break
            else:
                print("Exceeded maximum length!")
                print(
                    f"Task {task} yielded {len(output_audio) / self.sampling_rate} second audio"
                )
                print("Retrying")
                del input_audio, output_audio, audio_list, indices, desc_list
                gc.collect()

        command = self.command(task, *desc_list, parameter=parameter)

        return input_audio, output_audio, command, indices, desc_list
