import os
import argparse
import json
from pathlib import Path
from typing import List, Dict, Any, Optional

import torch
import torchaudio
from tqdm import tqdm
from peft import PeftModel
from voiceldm import VoiceLDMPipeline

def trim_audio_silence(audio_tensor: torch.Tensor, sample_rate: int = 16000, top_db: int = 35, frame_length: int = 2048, hop_length: int = 512, padding_ms: int = 250) -> torch.Tensor:
    """Trim leading and trailing silence from an audio tensor.

    The channels are averaged into one signal, which is split into overlapping
    windows. A window is "active" when its RMS exceeds a threshold ``top_db``
    decibels below the peak absolute sample value of the whole input. Samples
    before the first active window and after the last one are cut, keeping
    ``padding_ms`` milliseconds of margin on each side (clamped to the signal).

    Args:
        audio_tensor: 1-D ``(samples,)`` or 2-D tensor; dimension 0 is averaged as
            channels, so a channels-first layout is assumed.
        sample_rate: Sample rate in Hz; only used to convert ``padding_ms`` to samples.
        top_db: Threshold in dB below the peak amplitude.
        frame_length: Analysis window length in samples. It is clamped to the signal
            length so clips shorter than one window are still analysed.
        hop_length: Step between consecutive analysis windows, in samples.
        padding_ms: Margin kept around the detected active region, in milliseconds.

    Returns:
        - An empty input, unchanged.
        - Otherwise a 2-D tensor, with 1-D inputs promoted to ``(1, samples)``. It is:
          the promoted input itself if it is entirely zero; a zero tensor of shape
          ``(channels, 1)`` if no window is active; or the trimmed slice.
    """
    if audio_tensor.numel() == 0:
        return audio_tensor
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)

    peak = torch.max(torch.abs(audio_tensor))
    if peak == 0:
        return audio_tensor
    threshold = peak * (10 ** (-top_db / 20))

    mono_signal = torch.mean(audio_tensor, dim=0)
    num_samples = mono_signal.shape[0]
    # Tensor.unfold raises when the window is longer than the signal, so clamp it;
    # a clip shorter than one window is then analysed as a single window.
    window = min(frame_length, num_samples)
    frames = mono_signal.unfold(0, window, hop_length)
    rms = torch.sqrt(torch.mean(frames ** 2, dim=1))
    active_frames = torch.where(rms > threshold)[0]
    if active_frames.numel() == 0:
        return torch.zeros_like(audio_tensor[:, :1])

    padding_samples = int(sample_rate * padding_ms / 1000)
    start_sample = max(0, active_frames[0].item() * hop_length - padding_samples)
    end_sample = min(num_samples, active_frames[-1].item() * hop_length + window + padding_samples)
    if start_sample >= end_sample:
        return torch.zeros_like(audio_tensor[:, :1])
    return audio_tensor[:, start_sample:end_sample]

class AudioGenerator:
    """Generate audio clips with a VoiceLDM pipeline, optionally merging a LoRA adapter.

    ``args`` is the namespace produced by ``parse_args``. The generator reads the
    model, sampling, output and manifest options from it.
    """

    # Sample rate used both for silence trimming and when writing WAV files.
    # NOTE(review): assumed to match the VoiceLDM pipeline's output rate — confirm.
    SAMPLE_RATE = 16000

    def __init__(self, args: argparse.Namespace, device: torch.device):
        self.args = args
        self.device = device
        self.pipe = self._load_pipeline()

    def _load_pipeline(self) -> VoiceLDMPipeline:
        """Build the VoiceLDM pipeline; if ``--lora_path`` is set, merge the LoRA weights into its UNet."""
        pipe = VoiceLDMPipeline(self.args.model_config, self.args.ckpt_path, self.device)
        if self.args.lora_path:
            base_unet = pipe.model.unet.to(self.device)
            lora_unet = PeftModel.from_pretrained(base_unet, self.args.lora_path)
            # Fold the adapter into the base weights and drop the PEFT wrapper.
            pipe.model.unet = lora_unet.merge_and_unload()
        return pipe

    def generate(self, item_data: Dict[str, Any], gen_index: int = 0) -> Optional[Dict[str, Any]]:
        """Generate (or reuse) one audio file for a manifest item.

        Args:
            item_data: Manifest item. It must provide a non-empty ``'description'``
                and ``'prompt'``, and an ``'id'`` that is not ``None``/``""``. All
                of its keys are copied into the returned entry.
            gen_index: Index of this sample among ``--num_generations`` samples for
                the item. It offsets the seed and, when more than one sample is
                requested, is appended to the filename.

        Returns:
            The manifest entry when the file was generated or already existed on
            disk. The entry holds the item's fields plus ``'audio_path'`` and any
            optional model paths. Returns ``None`` when the item is invalid or
            generation failed; the reason is printed.
        """
        if not isinstance(item_data, dict):
            # Warn once per item, not once per generation index.
            if gen_index == 0:
                print(f"\n[WARN] Skipping manifest item that is not a JSON object: {item_data!r}")
            return None

        desc_prompt = item_data.get('description')
        cont_prompt = item_data.get('prompt')
        file_id = item_data.get('id')
        # Prompts must be non-empty. The id only needs to be present, so falsy but
        # valid ids such as 0 are accepted.
        if not desc_prompt or not cont_prompt or file_id is None or file_id == "":
            if gen_index == 0:
                print(f"\n[WARN] Skipping item with missing 'description', 'prompt' or 'id' (id={file_id!r}).")
            return None

        # Path(...).stem drops any directory part and extension-like suffix from the id.
        stem = Path(f"gen_{file_id}").stem
        index_suffix = f"_{gen_index}" if self.args.num_generations > 1 else ""
        output_filename = f"{stem}{index_suffix}.wav"
        save_path = Path(self.args.output_dir) / output_filename

        manifest_entry = {**item_data, 'audio_path': str(save_path.absolute())}
        if self.args.add_lora_path_to_manifest and self.args.lora_path:
            manifest_entry['Lora_model_path'] = self.args.lora_path
        if self.args.add_base_model_path_to_manifest:
            manifest_entry['Base_model_path'] = self.args.ckpt_path

        # Resume support: an existing output file is reused instead of regenerated.
        if save_path.exists():
            return manifest_entry

        seed = self.args.seed + gen_index if self.args.seed is not None else None
        gen_args = {
            'desc_prompt': desc_prompt,
            'cont_prompt': cont_prompt,
            'num_inference_steps': self.args.num_inference_steps,
            'audio_length_in_s': self.args.audio_length_in_s,
            'guidance_scale': self.args.guidance_scale,
            'desc_guidance_scale': self.args.desc_guidance_scale,
            'cont_guidance_scale': self.args.cont_guidance_scale,
            'seed': seed,
        }

        try:
            with torch.no_grad():
                audio = self.pipe(**gen_args)
            if self.args.trim_silence:
                audio = trim_audio_silence(audio, sample_rate=self.SAMPLE_RATE)
            torchaudio.save(save_path, src=audio.cpu(), sample_rate=self.SAMPLE_RATE)
        except Exception as e:
            # Best-effort batch mode: report the failure and let the caller move on.
            print(f"\n[ERROR on {self.device}] Failed to generate {output_filename}: {e}")
            return None
        finally:
            # Release cached GPU memory after every attempt, including failed ones
            # (e.g. out-of-memory), so the next item starts from a clean cache.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        return manifest_entry

def parse_args() -> argparse.Namespace:
    """Declare the command-line options and parse them from ``sys.argv``.

    ``--ckpt_path`` is mandatory. Every other option is optional; those without
    an explicit default parse to ``None``, or ``False`` for the flag switches.
    """
    # Each entry: (flag names, keyword arguments for add_argument), in help order.
    option_specs = [
        (("--batch_file",), dict(type=str, help="Path to a JSON manifest for batch generation.")),
        (("--output_dir",), dict(type=str, default="./outputs", help="Directory to save generated audio.")),
        (("--output_json_manifest",), dict(type=str, help="Path to save the output JSON manifest with results.")),
        (("--desc_prompt", "-d"), dict(type=str, help="Description prompt (single generation).")),
        (("--cont_prompt", "-c"), dict(type=str, help="Content prompt (single generation).")),
        (("--file_name",), dict(type=str, help="Filename for the generated audio (single generation).")),
        (("--ckpt_path",), dict(type=str, required=True, help="Path to the VoiceLDM checkpoint file.")),
        (("--lora_path",), dict(type=str, help="Path to the trained LoRA adapter directory.")),
        (("--model_config",), dict(type=str, default="m", choices=["m", "s"], help="Model config: 'm' or 's'.")),
        (("--device",), dict(type=str, default="auto", help="Device to use ('cuda', 'cpu', 'auto').")),
        (("--num_generations",), dict(type=int, default=1, help="Number of samples to generate per prompt.")),
        (("--num_inference_steps",), dict(type=int, default=250, help="Number of DDIM inference steps.")),
        (("--audio_length_in_s",), dict(type=float, default=10.0, help="Duration of audio to generate in seconds.")),
        (("--guidance_scale",), dict(type=float, default=5.0, help="Guidance weight for single CFG.")),
        (("--desc_guidance_scale",), dict(type=float, default=5.0, help="Description guidance weight for dual CFG.")),
        (("--cont_guidance_scale",), dict(type=float, default=7.0, help="Content guidance weight for dual CFG.")),
        (("--seed",), dict(type=int, help="Random seed for deterministic generation.")),
        (("--trim_silence",), dict(action="store_true", help="Trim silence from start/end of audio.")),
        (("--add_lora_path_to_manifest",), dict(action="store_true", help="Include LoRA path in manifest.")),
        (("--add_base_model_path_to_manifest",), dict(action="store_true", help="Include base model path in manifest.")),
    ]

    cli = argparse.ArgumentParser(
        description="Generate audio using VoiceLDM.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()

def _resolve_device(requested: str) -> torch.device:
    """Map the ``--device`` option to a ``torch.device``; ``'auto'`` prefers CUDA when available."""
    requested = (requested or "auto").strip().lower()
    if requested == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if requested.startswith("cuda") and not torch.cuda.is_available():
        # Keep the historical behaviour of falling back to CPU, but say so.
        print(f"Warning: device '{requested}' requested but CUDA is unavailable; falling back to CPU.")
        return torch.device("cpu")
    try:
        return torch.device(requested)
    except RuntimeError:
        print(f"Warning: unrecognised device '{requested}'; using automatic device selection.")
        return _resolve_device("auto")


def _load_batch_tasks(batch_file: str) -> Optional[List[Dict[str, Any]]]:
    """Read a JSON batch manifest and return its list of items, or ``None`` after printing an error."""
    try:
        with open(batch_file, 'r', encoding='utf-8') as f:
            tasks = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors.
        print(f"Error: Could not read batch file '{batch_file}'.\n{e}")
        return None
    if not isinstance(tasks, list):
        print(f"Error: Batch file '{batch_file}' must contain a JSON list of items, got {type(tasks).__name__}.")
        return None
    return tasks


def _write_manifest(manifest_path: str, results: List[Dict[str, Any]]) -> None:
    """Write ``results``, sorted by ``'audio_path'``, as an indented UTF-8 JSON list."""
    print(f"Saving manifest to {manifest_path}...")
    Path(manifest_path).parent.mkdir(parents=True, exist_ok=True)
    with open(manifest_path, 'w', encoding='utf-8') as f:
        json.dump(sorted(results, key=lambda entry: entry['audio_path']), f, indent=4, ensure_ascii=False)
    print("Manifest saved.")


def main():
    """Command-line entry point: generate audio for a batch manifest or a single prompt pair.

    Inputs (batch file contents or the ``-d``/``-c`` prompts) are validated before
    the model is loaded, so argument mistakes fail fast. On such input errors the
    process exits with status 1.
    """
    args = parse_args()

    tasks: List[Dict[str, Any]] = []
    single_item: Optional[Dict[str, Any]] = None
    if args.batch_file:
        loaded = _load_batch_tasks(args.batch_file)
        if loaded is None:
            raise SystemExit(1)
        tasks = loaded
    else:
        if not args.desc_prompt or not args.cont_prompt:
            print("Error: --desc_prompt (-d) and --cont_prompt (-c) are required.")
            raise SystemExit(1)
        base_name = args.file_name or f"{args.desc_prompt[:20]}-{args.cont_prompt[:20]}"
        single_item = {'description': args.desc_prompt, 'prompt': args.cont_prompt, 'id': Path(base_name).stem}

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    device = _resolve_device(args.device)
    print(f"Using device: {device}")
    generator = AudioGenerator(args, device)

    all_results = []
    if single_item is None:
        for item in tqdm(tasks, desc=f"Generating on {device}"):
            for i in range(args.num_generations):
                result = generator.generate(item, gen_index=i)
                if result is not None:
                    all_results.append(result)
    else:
        for i in tqdm(range(args.num_generations), desc="Generating Sample(s)"):
            result = generator.generate(single_item, gen_index=i)
            if result is not None:
                all_results.append(result)

    if args.output_json_manifest and all_results:
        _write_manifest(args.output_json_manifest, all_results)

if __name__ == "__main__":
    main()