#!/usr/bin/env python3
# build_assets.py

import os, csv, json, ast, re, warnings
from pathlib import Path
import sys

from graphviz import Digraph
import cairosvg
from tqdm import tqdm
import argparse

import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2         # CosyVoice-2
from cosyvoice.utils.file_utils import load_wav

# CLI: the relation type selects which reasoning CSV is read and where assets go.
p = argparse.ArgumentParser(
    description=(
        "Render triplet graphs (PNG) and synthesize speech (WAV) for every row "
        "of a reasoning dataset, and write an index CSV of the generated assets"
    )
)
p.add_argument(
    "--type",
    choices=["independent", "alternative", "complementary", "contradictory", "equivalent", "entailment"],
    default="independent",
    help="reasoning relation subset to process; selects the input CSV and the output directories",
)
args = p.parse_args()

TYPE = args.type
# ────────────────────────── paths ────────────────────────── #
# NOTE(review): BASE_DIR and the model/prompt paths below are placeholders and
# must be edited before running.
BASE_DIR   = "/path/to/dataset"
SRC_CSV    = os.path.join(BASE_DIR, f"reasoning_meta/reasoning_{TYPE}_dataset.csv")

ASSET_DIR = os.path.join(BASE_DIR, f"assets/{TYPE}")
IMG_DIR    = os.path.join(ASSET_DIR, "images")
WAV_DIR    = os.path.join(ASSET_DIR, "speech")
OUT_CSV    = os.path.join(BASE_DIR, f"assets/multimodal_datasets_{TYPE}.csv")

# Creating IMG_DIR with parents=True also creates BASE_DIR/assets, which is where OUT_CSV lives.
Path(IMG_DIR).mkdir(parents=True, exist_ok=True)
Path(WAV_DIR).mkdir(parents=True, exist_ok=True)

# ────────────────────────── load TTS ─────────────────────── #
# The model and the zero-shot prompt audio are loaded once at import time and
# shared by every call to write_cosyvoice_wav().
print("Loading CosyVoice-2…")
cosyvoice = CosyVoice2("/path/to/CosyVoice2-0.5B", load_jit=False, load_trt=False, load_vllm=False, fp16=False)
prompt_speech_16k = load_wav('/path/to/zero_shot_prompt.wav', 16000)

print("✓ TTS model loaded")

# ─────────────────── utility helpers ─────────────────────── #
# Trailing run of sentence punctuation, possibly interleaved with whitespace.
_TRAILING_PUNCT_RE = re.compile(r"[\s.?!,;:]+$")


def strip_punct(text: str) -> str:
    """Return *text* without surrounding whitespace and trailing punctuation.

    A trailing run of ``. ? ! , ; :`` is removed together with any whitespace
    mixed into it, so ``"Hello. "`` and ``"Hello ."`` both become ``"Hello"``.
    Punctuation that is not at the end of the string is kept.
    """
    # Whitespace must be part of the trailing run: the previous pattern only
    # matched punctuation directly at the end, so "Hello. " kept its period.
    return _TRAILING_PUNCT_RE.sub("", text).strip()

def svg2png(svg_path, png_path):
    """Rasterise the SVG file at *svg_path* and write the PNG to *png_path*."""
    source, destination = svg_path, png_path
    cairosvg.svg2png(write_to=destination, url=source)

def draw_theory_graph(idx: "str | int", triplets, img_dir: str, modality_id: int) -> str:
    """Render a list of triplets as a directed graph and save it as a PNG.

    Each distinct subject/object becomes one node labelled with its text, and
    each triplet becomes an edge labelled with its relation. Graphviz renders an
    SVG first, which is rasterised with ``svg2png``. The intermediate SVG is
    always deleted, even if rasterisation fails.

    Args:
        idx: Row identifier used in the filename and graph comment (the caller
            passes the CSV ``id`` value, which is a string).
        triplets: Iterable of mappings with ``"subject"``, ``"relation"`` and
            ``"object"`` keys.
        img_dir: Directory the image is written into.
        modality_id: Modality number used in the filename.

    Returns:
        Path of the written ``fact_{idx}_modality_{modality_id}.png`` file.

    Raises:
        KeyError: if a triplet lacks one of the required keys.
    """
    dot = Digraph(comment=f"Facts {idx}")

    # Graphviz reads edge endpoints as "node[:port[:compass]]", so entity text
    # containing ':' would be misparsed. Use synthetic node IDs instead, and
    # keep the entity text only as the visible label.
    node_ids = {}

    def _node_id(name) -> str:
        # Declare each distinct entity once and return its synthetic ID.
        label = str(name)
        if label not in node_ids:
            node_ids[label] = f"n{len(node_ids)}"
            dot.node(node_ids[label], label)
        return node_ids[label]

    for e in triplets:
        subj_id = _node_id(e["subject"])
        obj_id = _node_id(e["object"])
        dot.edge(subj_id, obj_id, label=str(e["relation"]))

    base = os.path.join(img_dir, f"fact_{idx}_modality_{modality_id}")
    # render() returns the path of the file it produced (base + ".svg");
    # cleanup=True removes the intermediate DOT source.
    svg_path = dot.render(base, format="svg", view=False, cleanup=True)
    png_path = base + ".png"
    try:
        svg2png(svg_path, png_path)
    finally:
        os.remove(svg_path)
    return png_path

def write_cosyvoice_wav(idx: "str | int", text: str, wav_dir: str, modality_id: int) -> str:
    """Synthesize *text* with the module-level CosyVoice-2 model and save it as WAV.

    Uses zero-shot voice cloning from the module-level ``prompt_speech_16k``.
    The Chinese string below is passed as the prompt text (presumably the
    transcript of the prompt audio; confirm it matches the prompt wav).

    Args:
        idx: Row identifier used in the filename (the caller passes the CSV
            ``id`` string).
        text: Text to synthesize.
        wav_dir: Directory the WAV is written into.
        modality_id: Modality number used in the filename.

    Returns:
        Path of the written ``fact_{idx}_modality_{modality_id}.wav`` file.

    Raises:
        RuntimeError: if the model yields no audio. Without this, a path that
            was never written (or is stale from an earlier run) would be returned.
    """
    import torch  # only needed to join the per-segment waveforms

    wav_path = os.path.join(wav_dir, f"fact_{idx}_modality_{modality_id}.wav")
    # inference_zero_shot may yield several results (one per text segment)
    # even with stream=False. The old code kept only the first one, which
    # truncated longer texts, so every segment is collected here.
    segments = [
        chunk['tts_speech']
        for chunk in cosyvoice.inference_zero_shot(text, '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)
    ]
    if not segments:
        raise RuntimeError(f"CosyVoice produced no audio for {wav_path!r} (text={text!r})")
    # torchaudio.save treats the tensor as (channels, frames), so join along time.
    torchaudio.save(wav_path, torch.cat(segments, dim=-1), cosyvoice.sample_rate)
    return wav_path

# ─────────────────────────── main ────────────────────────── #
def build_multimodal():
    """Generate image and speech assets for every row of SRC_CSV and index them.

    For each row and each modality k in 1..3:
      * ``modality{k}_triplet`` is parsed as JSON and rendered to a PNG in IMG_DIR;
      * ``modality{k}_text`` has trailing punctuation stripped and is
        synthesized to a WAV in WAV_DIR.
    Rows whose triplet cells are missing or are not valid JSON are skipped with
    a warning. The asset paths and cleaned texts are written to OUT_CSV. The
    header is always written, even when no row survives.
    """
    modality_ids = (1, 2, 3)
    # Column order: id, all images, all wavs, all texts.
    fieldnames = (
        ["subgraph_id"]
        + [f"modality{k}_img" for k in modality_ids]
        + [f"modality{k}_wav" for k in modality_ids]
        + [f"modality{k}_txt" for k in modality_ids]
    )
    recs = []

    with open(SRC_CSV, newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))

    for row in tqdm(rows, desc="building theory assets"):
        sg_id = row["id"]
        try:
            triplets = {k: json.loads(row[f"modality{k}_triplet"]) for k in modality_ids}
        except (KeyError, TypeError, ValueError) as e:
            # KeyError: column absent; TypeError: cell is None (short row);
            # ValueError (incl. json.JSONDecodeError): malformed JSON.
            warnings.warn(f"Bad triplet JSON for {sg_id}: {e}")
            continue
        texts = {k: strip_punct(row[f"modality{k}_text"]) for k in modality_ids}

        rec = {"subgraph_id": sg_id}
        # 1)  images
        for k in modality_ids:
            rec[f"modality{k}_img"] = draw_theory_graph(sg_id, triplets[k], IMG_DIR, k)
        # 2)  audio
        for k in modality_ids:
            rec[f"modality{k}_wav"] = write_cosyvoice_wav(sg_id, texts[k], WAV_DIR, k)
        for k in modality_ids:
            rec[f"modality{k}_txt"] = texts[k]
        recs.append(rec)

    if not recs:
        warnings.warn(f"No usable rows in {SRC_CSV}; writing header only to {OUT_CSV}")

    with open(OUT_CSV, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(recs)

    print(f"✓ multimodal THEORY dataset saved →  {OUT_CSV}")

# -------------------------------------------------------------------- #
if __name__ == "__main__":
    # Script entry point. Note that CLI parsing and TTS model loading have
    # already happened at module level, before this guard is reached.
    build_multimodal()
