import argparse
import glob
import json
import os
import time
import uuid
from pathlib import Path

import numpy as np
import torchaudio as ta
import multiprocessing as mp

from dataset import TestDataset

def main():
    """Parse CLI arguments, fan generation out over worker processes, then merge.

    The requested number of generations (``--max-gens``) is split as evenly as
    possible across ``--num-workers`` processes, each running
    ``generate_edits``. Once every worker has been joined, the per-worker
    metadata files in the output directory are combined by
    ``merge_json_files``.

    Exits with status 2 (via ``argparse``) when ``--num-workers`` is smaller
    than 1 or ``--max-gens`` is negative.
    """
    p = argparse.ArgumentParser("Generate manual audio edits")
    p.add_argument("--input-path", required=True)
    p.add_argument("--output-path", required=True)
    p.add_argument("--csv-path", required=True)
    p.add_argument("--prompt-path", required=True)
    p.add_argument("--num-workers", type=int, default=mp.cpu_count())
    p.add_argument("--max-gens", type=int, default=30_000)
    args = p.parse_args()

    # Validate before touching the filesystem: a zero worker count would
    # otherwise surface as a ZeroDivisionError from divmod() below.
    if args.num_workers < 1:
        p.error("--num-workers must be at least 1")
    if args.max_gens < 0:
        p.error("--max-gens must be non-negative")

    os.makedirs(args.output_path, exist_ok=True)

    # Spread the remainder one-per-worker over the first `remainder` workers
    # so that no single worker receives the whole surplus.
    base, remainder = divmod(args.max_gens, args.num_workers)
    chunks = [
        base + (1 if wid < remainder else 0)
        for wid in range(args.num_workers)
    ]

    procs = [
        mp.Process(
            target=generate_edits,
            args=(
                args.input_path,
                args.csv_path,
                args.prompt_path,
                args.output_path,
                wid,
                num_gens,
            ),
        )
        for wid, num_gens in enumerate(chunks)
        # Do not spawn a process (and build a dataset) for an empty share.
        if num_gens > 0
    ]
    for pr in procs:
        pr.start()
    for pr in procs:
        pr.join()

    # Surface crashed workers; whatever they checkpointed is still merged.
    failed = [pr.name for pr in procs if pr.exitcode != 0]
    if failed:
        print(
            f"Warning: {len(failed)} worker process(es) exited abnormally: "
            + ", ".join(failed)
        )

    merge_json_files(args.output_path)


def generate_edits(input_path: str, csv_path: str, prompt_path: str, output_path: str, worker_id: int, num_gens: int, max_consecutive_failures: int = 100):
    """Generate ``num_gens`` audio edit samples and write them under ``output_path``.

    For every sample a sub-directory named after a fresh UUID (suffixed with
    the source audio indices) is created, holding ``input.wav`` and
    ``output.wav``. A per-worker metadata file
    ``samples_worker_<worker_id>.json`` is rewritten after each sample so that
    completed work survives an interruption.

    Args:
        input_path: Directory passed to ``TestDataset`` as ``data_dir``.
        csv_path: Path passed to ``TestDataset`` as ``csv_path``.
        prompt_path: Path passed to ``TestDataset`` as ``prompt_path``.
        output_path: Directory receiving the audio folders and metadata file.
        worker_id: Index of this worker; used in the metadata file name and logs.
        num_gens: Number of samples this worker must produce.
        max_consecutive_failures: How many times in a row fetching a sample may
            raise ``RuntimeError`` before the worker gives up.

    Raises:
        RuntimeError: If fetching a sample fails ``max_consecutive_failures``
            times in a row.
    """
    print(f"Worker {worker_id}: generating {num_gens} edits...")
    dataset = TestDataset(data_dir=Path(input_path), csv_path=Path(csv_path), prompt_path=Path(prompt_path), channels=2, use_o3=True)

    metadata_path = os.path.join(output_path, f"samples_worker_{worker_id}.json")
    metadata = {}
    times = []
    consecutive_failures = 0
    i = 0

    while i < num_gens:
        start_time = time.time()
        try:
            input_audio, output_audio, command, audio_indices, desc_list = dataset[i]
        except RuntimeError as e:
            # The same index is retried (presumably TestDataset samples
            # randomly, so a retry can succeed -- TODO confirm). The retry
            # count is bounded so a persistent failure cannot spin forever.
            consecutive_failures += 1
            print(f"Error at index {i}: {e}")
            if consecutive_failures >= max_consecutive_failures:
                raise RuntimeError(
                    f"Worker {worker_id}: giving up at index {i} after "
                    f"{consecutive_failures} consecutive failures"
                ) from e
            continue
        consecutive_failures = 0

        print(f"Audio Indices, Command, Shapes: {audio_indices}, {command}, ({input_audio.shape} -> {output_audio.shape})")

        audio_id = uuid.uuid4().hex + "_ids_" + "_".join([str(x) for x in audio_indices])
        audio_output_path = os.path.join(output_path, audio_id)

        # Save the audio first: metadata must never reference files that
        # failed to be written.
        os.makedirs(audio_output_path, exist_ok=True)
        ta.save(uri=os.path.join(audio_output_path, "input.wav"), src=input_audio, sample_rate=44100)
        ta.save(uri=os.path.join(audio_output_path, "output.wav"), src=output_audio, sample_rate=44100)

        metadata[audio_id] = {
            "instruction": command,
            "audio_indices": [int(x) for x in audio_indices],
            "audio_descriptions": desc_list,
        }

        # Rewritten in full after every sample; acts as a checkpoint.
        with open(metadata_path, "w") as json_file:
            json.dump(metadata, json_file, indent=4)

        print(f"Metadata entry for {audio_id} saved.")
        i += 1
        times.append(time.time() - start_time)
        print("Average time for one sample", np.mean(times))


def merge_json_files(output_path: str):
    """Merge all ``samples_worker_*.json`` files in ``output_path`` into ``samples.json``.

    Worker files are read in sorted path order and combined with
    ``dict.update``, so on duplicate keys the later file wins. The worker
    files are removed only after ``samples.json`` has been written, so a
    malformed worker file or a failed write never destroys already-generated
    metadata.

    Args:
        output_path: Directory containing the per-worker metadata files; the
            merged ``samples.json`` is written to the same directory.

    Raises:
        json.JSONDecodeError: If a worker file does not contain valid JSON.
        OSError: If a file cannot be read, written or removed.
    """
    # glob() returns paths in arbitrary order; sort for a reproducible merge.
    worker_files = sorted(glob.glob(os.path.join(output_path, "samples_worker_*.json")))

    merged = {}
    for fp in worker_files:
        with open(fp, "r") as f:
            merged.update(json.load(f))

    with open(os.path.join(output_path, "samples.json"), "w") as f:
        json.dump(merged, f, indent=4)

    # Clean up only once the merged file exists on disk.
    for fp in worker_files:
        os.remove(fp)

    print(f"Merged {len(merged)} entries into samples.json")


# Run only when executed as a script. The guard also keeps main() from being
# re-run if a multiprocessing child process re-imports this module.
if __name__ == "__main__":
    main()