import json
import os
import pickle
import random
import re
import shutil
from pathlib import Path

import fire
import yaml

from utils.general import read_jsonl_to_mapping, transform_gen_fn_to_id

# Short task code -> name of the sub-directory of ``infer_dir`` that holds
# the generated ``*.wav`` files for that task.
task_name_mapping = dict(
    tts="text_to_speech",
    svs="singing_voice_synthesis",
    tta="text_to_audio",
    se="speech_enhancement",
    ttm="text_to_music",
    sr="audio_super_resolution",
    v2a="video_to_audio",
)

# Short task code -> metric names written into the "metrics" field of every
# exported sample of that task.
task_to_metrics = dict(
    tts=["MOS", "SMOS"],
    svs=["MOS", "SMOS"],
    tta=["OVL", "REL"],
    ttm=["OVL", "REL"],
    sr=["MOS"],
    se=["MOS"],
    v2a=["OVL", "SYNC"],
)

# Ground-truth dataset table: short task code -> list of dataset entries.
# Each entry has these fields:
#   content / audio     -- jsonl files mapping a sample id to the prompt
#                          value / the reference audio path.
#   content_col / audio_col
#                       -- the column to read from each of those files.
#   base_content_path / base_audio_path (optional)
#                       -- prefix joined onto every content / audio value.
# Values of the form ${env:NAME,"default"} are placeholders. This table
# stores them as plain strings; whoever consumes it must expand them.
gt_data_dict = dict(
    tts=[
        dict(
            content="data/libritts/test/phoneme.jsonl",
            audio="data/libritts/test/audio.jsonl",
            base_content_path='${env:DATA_BASE,"/path/to/datasets"}',
            # NOTE(review): hard-coded cluster-specific path, unlike the
            # placeholder used for base_content_path -- confirm for other hosts.
            base_audio_path="/cpfs02/shared/speechllm/data",
            content_col="audio",
            audio_col="audio",
        ),
    ],
    svs=[
        dict(
            content="data/m4singer/test/midi.jsonl",
            audio="data/m4singer/test/audio.jsonl",
            content_col="midi",
            audio_col="audio",
        ),
    ],
    tta=[
        dict(
            content="data/audiocaps_v2/test/caption.jsonl",
            audio="data/audiocaps_v2/test/audio.jsonl",
            content_col="caption",
            audio_col="audio",
        ),
    ],
    se=[
        dict(
            content='${env:VOICEBANK_DEMAND_CAPTION,"/path/to/voicebank_demand/test_metadata_caption.jsonl"}',
            audio='${env:VOICEBANK_DEMAND_AUDIO,"/path/to/voicebank_demand/test_metadata_audio.jsonl"}',
            base_content_path='${env:VOICEBANK_DEMAND_BASE,"/path/to/voicebank_demand"}',
            base_audio_path='${env:VOICEBANK_DEMAND_BASE,"/path/to/voicebank_demand"}',
            content_col="InputPath",
            audio_col="WavPath",
        ),
    ],
    ttm=[
        dict(
            content="data/music_caps/caption.jsonl",
            audio="data/music_caps/audio.jsonl",
            content_col="caption",
            audio_col="audio",
        ),
    ],
    sr=[
        # esc
        dict(
            content='${env:ESC_SR_CONTENT,"/path/to/esc_sr/test/content.jsonl"}',
            base_content_path='${env:ESC_SR_BASE,"/path/to/esc_sr/test"}',
            audio='${env:ESC_SR_AUDIO,"/path/to/esc_sr/test/audio.jsonl"}',
            base_audio_path='${env:ESC_SR_BASE,"/path/to/esc_sr/test"}',
            content_col="caption",
            audio_col="audio",
        ),
        # vctk
        dict(
            content='${env:VCTK_SR_CONTENT,"/path/to/vctk_sr/test/content.jsonl"}',
            base_content_path='${env:VCTK_SR_BASE,"/path/to/vctk_sr"}',
            audio='${env:VCTK_SR_AUDIO,"/path/to/vctk_sr/test/audio.jsonl"}',
            base_audio_path='${env:VCTK_SR_BASE,"/path/to/vctk_sr"}',
            content_col="caption",
            audio_col="audio",
        ),
        # musdb
        dict(
            content='${env:MUSDB_SR_CONTENT,"/path/to/musdb_sr/test/content.jsonl"}',
            base_content_path='${env:MUSDB_SR_BASE,"/path/to/musdb_sr"}',
            audio='${env:MUSDB_SR_AUDIO,"/path/to/musdb_sr/test/audio.jsonl"}',
            base_audio_path='${env:MUSDB_SR_BASE,"/path/to/musdb_sr"}',
            content_col="caption",
            audio_col="audio",
        ),
    ],
    v2a=[
        # NOTE(review): content and audio read the same file and column;
        # confirm this is the intended source of the video prompt.
        dict(
            content="data/visual_sound/clip/test/audio.jsonl",
            audio="data/visual_sound/clip/test/audio.jsonl",
            content_col="audio",
            audio_col="audio",
        ),
    ],
)


def main(
    infer_dir: str,
    output_yaml: str,
    output_audio_dir: str,
    sample_num: int = 10
):
    task_to_samples = {}
    for task, data_list in gt_data_dict.items():
        task_data_list = []
        for data in data_list:
            base_content_path = None
            if "base_content_path" in data:
                base_content_path = data["base_content_path"]
            if "base_audio_path" in data:
                base_audio_path = data["base_audio_path"]
            if task == "se":
                aid_to_content = read_jsonl_to_mapping(
                    data["content"], "UUID", data["content_col"]
                )
                aid_to_audio = read_jsonl_to_mapping(
                    data["audio"], "UUID", data["audio_col"]
                )
            else:
                aid_to_content = read_jsonl_to_mapping(
                    data["content"], "audio_id", data["content_col"]
                )
                aid_to_audio = read_jsonl_to_mapping(
                    data["audio"], "audio_id", data["audio_col"]
                )
            for aid in aid_to_content.keys():
                if base_content_path:
                    aid_to_content[aid] = str(
                        Path(base_content_path) / aid_to_content[aid]
                    )
                if base_audio_path:
                    aid_to_audio[aid] = str(
                        Path(base_audio_path) / aid_to_audio[aid]
                    )

                task_data_list.append({
                    "audio_id": aid,
                    "content": aid_to_content[aid],
                    "audio": aid_to_audio[aid],
                })
        task_to_samples[task] = task_data_list

    with open("data/libritts/test/ref_transcription.json", "r") as f:
        ref_transcriptions = json.load(f)

    generated_aid_sample_mapping = {}
    for task in task_to_samples.keys():
        for file in (Path(infer_dir) / task_name_mapping[task]).glob("*.wav"):
            aid = transform_gen_fn_to_id(file, task)
            generated_aid_sample_mapping[aid] = file

    output_data = []
    (Path(output_audio_dir) /
     "uniflow_audio").mkdir(parents=True, exist_ok=True)
    (Path(output_audio_dir) / "gt").mkdir(parents=True, exist_ok=True)
    (Path(output_audio_dir) / "video").mkdir(parents=True, exist_ok=True)
    for task, samples in task_to_samples.items():
        random.shuffle(samples)
        samples = samples[:sample_num]

        for sample in samples:
            if task == "tts":
                content = ref_transcriptions[sample["audio_id"].lstrip("0")]
            elif task == "svs":
                with open(sample["content"], "rb") as f:
                    content = pickle.load(f)
                    content = content[sample["audio_id"]]["text"]
            else:
                content = sample["content"]

            if task in ["se", "sr"]:
                shutil.copy(
                    content,
                    Path(output_audio_dir) /
                    f"gt/{sample['audio_id']}_input.wav"
                )
                content = f"gt/{sample['audio_id']}_input.wav"
            elif task == "v2a":
                shutil.copy(
                    content,
                    Path(output_audio_dir) / f"video/{sample['audio_id']}"
                )
                content = f"video/{sample['audio_id']}"

            shutil.copy(
                generated_aid_sample_mapping[sample["audio_id"]],
                Path(output_audio_dir) /
                f"uniflow_audio/{sample['audio_id']}.wav"
            )
            systems = [{
                "system_id": "A",
                "system_name": "UniFlow-Audio",
                "audio_path": f"uniflow_audio/{sample['audio_id']}.wav",
            }]

            shutil.copy(
                sample["audio"],
                Path(output_audio_dir) / f"gt/{sample['audio_id']}.wav"
            )
            output_data.append({
                "sample_id": sample["audio_id"],
                "task_type": task,
                "metrics": task_to_metrics[task],
                "systems": systems,
                "gt_audio_path": f"gt/{sample['audio_id']}.wav",
                "prompt": content,
                "show_gt_audio": True,
                "show_prompt": True,
                "show_gt_audio_mel": True if task == "sr" else False,
                "show_prompt_mel": True if task == "sr" else False,
            })

    Path(output_yaml).parent.mkdir(parents=True, exist_ok=True)
    with open(output_yaml, "w") as f:
        yaml.dump({"samples": output_data}, f, allow_unicode=True)


if __name__ == "__main__":
    # Expose main's parameters as command-line flags via python-fire.
    fire.Fire(component=main)
