import os
import argparse
import json
from pathlib import Path
import torch
import torchaudio
from tqdm import tqdm
from peft import PeftModel
from voiceldm import VoiceLDMPipeline

def trim_audio_silence(audio_tensor, sample_rate=16000, top_db=35, frame_length=2048, hop_length=512, padding_ms=250):
    """Trim leading and trailing silence from an audio tensor.

    The signal is averaged over dim 0 (treated as channels), cut into
    overlapping frames, and a frame counts as active when its RMS exceeds a
    threshold placed ``top_db`` decibels below the peak absolute amplitude.
    The result spans from the first to the last active frame, widened by
    ``padding_ms`` on both sides and clamped to the signal bounds.

    Args:
        audio_tensor: 1-D ``(samples,)`` or 2-D ``(channels, samples)`` tensor.
        sample_rate: Sample rate in Hz; only used to convert ``padding_ms``
            into a number of samples.
        top_db: How far below the peak (in dB) the activity threshold sits.
        frame_length: Number of samples per analysis frame.
        hop_length: Number of samples between the starts of adjacent frames.
        padding_ms: Context kept before and after the active region, in ms.

    Returns:
        The trimmed audio as a 2-D tensor (1-D input is promoted to
        ``(1, samples)``). Empty input is returned unchanged. All-zero audio
        and audio shorter than ``frame_length`` are returned untrimmed. When
        no frame is active, a single zero sample per channel is returned.
    """
    if audio_tensor.numel() == 0:
        return audio_tensor
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)

    max_abs_val = torch.max(torch.abs(audio_tensor))
    if max_abs_val == 0:
        return audio_tensor

    mono_signal = torch.mean(audio_tensor, dim=0)
    if mono_signal.shape[0] < frame_length:
        # Tensor.unfold raises when the signal is shorter than one frame, so
        # there is nothing to analyse: keep such short clips as they are.
        return audio_tensor

    # Amplitude threshold relative to the peak (dB -> linear ratio).
    threshold = max_abs_val * (10 ** (-top_db / 20))
    frames = mono_signal.unfold(0, frame_length, hop_length)
    rms = torch.sqrt(torch.mean(frames**2, dim=1))
    active_frames = torch.where(rms > threshold)[0]

    if len(active_frames) == 0:
        return torch.zeros_like(audio_tensor[:, :1])

    padding_samples = int(sample_rate * padding_ms / 1000)
    first_active = active_frames[0].item()
    last_active = active_frames[-1].item()
    start_sample = max(0, first_active * hop_length - padding_samples)
    # The last active frame extends frame_length samples past its start.
    end_sample = min(audio_tensor.shape[1], last_active * hop_length + frame_length + padding_samples)

    if start_sample >= end_sample:
        return torch.zeros_like(audio_tensor[:, :1])
    return audio_tensor[:, start_sample:end_sample]

def load_and_merge_model(args, device):
    """Create a VoiceLDM pipeline whose UNet has the given LoRA adapters merged in.

    Every path in ``args.lora_paths`` is loaded as an adapter named
    ``lora_<index>``. The adapters are combined into one ``merged_lora``
    adapter with ``add_weighted_adapter``, which is activated and then
    folded into the UNet through ``merge_and_unload``.

    Args:
        args: Namespace providing ``model_config``, ``ckpt_path``,
            ``lora_paths``, ``lora_weights``, ``combination_type`` and
            ``density``.
        device: Device handed to ``VoiceLDMPipeline``.

    Returns:
        The pipeline, with ``pipe.model.unet`` replaced by the merged UNet.

    Raises:
        ValueError: If no LoRA path is given or the number of weights does
            not match the number of paths. Checked before anything is
            loaded so a bad invocation fails immediately.
    """
    lora_paths = list(args.lora_paths or [])
    lora_weights = list(args.lora_weights or [])
    if not lora_paths:
        raise ValueError("At least one LoRA adapter path is required.")
    if len(lora_paths) != len(lora_weights):
        raise ValueError(
            f"Got {len(lora_paths)} LoRA path(s) but {len(lora_weights)} weight(s); the counts must match."
        )

    pipe = VoiceLDMPipeline(model_config=args.model_config, ckpt_path=args.ckpt_path, device=device)
    base_unet = pipe.model.unet

    adapter_names = [f"lora_{i}" for i in range(len(lora_paths))]

    # The first adapter creates the PeftModel wrapper; the rest attach to it.
    peft_unet = PeftModel.from_pretrained(base_unet, lora_paths[0], adapter_name=adapter_names[0])
    for adapter_name, lora_path in zip(adapter_names[1:], lora_paths[1:]):
        peft_unet.load_adapter(lora_path, adapter_name=adapter_name)

    peft_unet.add_weighted_adapter(
        adapters=adapter_names,
        weights=lora_weights,
        adapter_name="merged_lora",
        combination_type=args.combination_type,
        density=args.density,
    )
    peft_unet.set_adapter("merged_lora")

    pipe.model.unet = peft_unet.merge_and_unload()
    return pipe

def _first_not_none(mapping, *keys):
    """Return the value of the first key in ``keys`` whose value is not None."""
    for key in keys:
        value = mapping.get(key)
        if value is not None:
            return value
    return None


def process_generation_task(pipe, task_info, args):
    """Generate and save ``args.num_generations`` audio files for one task.

    Prompts are read from ``task_info`` under either the batch-file keys
    (``description`` / ``prompt``) or the command-line names
    (``desc_prompt`` / ``cont_prompt``), so a task built from the parsed
    CLI namespace works as well as a batch entry. Tasks with neither a
    content prompt nor an audio prompt are skipped with a notice.

    Output files are written to ``args.output_dir`` as
    ``<stem>[_<index>].wav`` at 16 kHz; files that already exist are not
    regenerated. A failure while generating or saving one file is printed
    and does not stop the remaining generations.

    Args:
        pipe: Callable pipeline exposing a ``device`` attribute; its return
            value is assumed to be a tensor accepted by ``torchaudio.save``.
        task_info: Dict describing the task (prompts, optional
            ``file_name`` and ``id``).
        args: Namespace with the generation settings and ``output_dir``.
    """
    desc_prompt = _first_not_none(task_info, 'description', 'desc_prompt')
    cont_prompt = _first_not_none(task_info, 'prompt', 'cont_prompt')
    audio_prompt = task_info.get('audio_prompt')

    # `or` (not a .get default) so that an explicit None/empty file_name,
    # as produced by an unset --file_name, still falls back to the default.
    file_name = task_info.get('file_name') or f"gen_{task_info.get('id', 'unnamed')}"
    base_filename = Path(file_name)

    if not (cont_prompt or audio_prompt):
        print(f"\nSkipping task '{base_filename.stem}': no content prompt or audio prompt provided.")
        return

    for i in range(args.num_generations):
        suffix = f"_{i}" if args.num_generations > 1 else ""
        save_path = Path(args.output_dir) / f"{base_filename.stem}{suffix}.wav"

        if save_path.exists():
            continue

        generation_args = {
            'desc_prompt': desc_prompt,
            'cont_prompt': cont_prompt,
            'audio_prompt': audio_prompt,
            'num_inference_steps': args.num_inference_steps,
            'audio_length_in_s': args.audio_length_in_s,
            'guidance_scale': args.guidance_scale,
            'desc_guidance_scale': args.desc_guidance_scale,
            'cont_guidance_scale': args.cont_guidance_scale,
            # Offset the seed per sample so repeated generations differ
            # while staying reproducible.
            'seed': args.seed + i if args.seed is not None else None
        }

        try:
            with torch.no_grad():
                audio_out = pipe(**generation_args)

            if args.trim_silence:
                audio_out = trim_audio_silence(audio_out)

            torchaudio.save(save_path, src=audio_out.cpu(), sample_rate=16000)

        except Exception as e:
            # Best-effort batch run: report the failure and keep going.
            print(f"\nERROR generating file {save_path.name}: {e}")
        finally:
            if pipe.device.type == 'cuda':
                torch.cuda.empty_cache()

def parse_args():
    """Parse and validate the command-line options of the generation script.

    Options are declared in a table and registered in order, covering I/O,
    model/LoRA selection, prompts, generation control and post-processing.

    Returns:
        The parsed ``argparse.Namespace``. The parser exits with an error
        when the number of ``--lora_paths`` and ``--lora_weights`` differ.
    """
    cli = argparse.ArgumentParser(description="Generate audio with merged LoRA adapters using VoiceLDM.")

    option_table = [
        # --- I/O ---
        (("--batch_file",), {"type": str, "help": "Path to a JSON file for batch processing."}),
        (("--output_dir",), {"type": str, "default": "./outputs_merged", "help": "Directory to save the generated audio."}),
        (("--file_name",), {"type": str, "help": "Filename for the generated audio in single mode."}),
        # --- Model & LoRA ---
        (("--ckpt_path",), {"type": str, "required": True, "help": "Path to the base VoiceLDM model checkpoint."}),
        (("--lora_paths",), {"type": str, "nargs": "+", "required": True, "help": "List of paths to LoRA adapter directories."}),
        (("--lora_weights",), {"type": float, "nargs": "+", "required": True, "help": "List of weights for each lora_path."}),
        (("--combination_type",), {
            "type": str,
            "default": "linear",
            "choices": ["linear", "svd", "ties", "dare_ties", "dare_linear", "cat"],
            "help": "Method for merging LoRA adapters.",
        }),
        (("--density",), {"type": float, "default": 0.5, "help": "Density parameter for 'ties' and 'dare' methods."}),
        (("--model_config",), {
            "type": str,
            "default": "m",
            "choices": ["m", "s"],
            "help": "VoiceLDM model configuration size ('m' or 's').",
        }),
        # --- Prompts ---
        (("--desc_prompt", "-d"), {"type": str, "help": "Descriptive prompt for style."}),
        (("--cont_prompt", "-c"), {"type": str, "help": "Content prompt (the transcription to be spoken)."}),
        (("--audio_prompt", "-a"), {"type": str, "help": "Path to an audio file for style prompt."}),
        # --- Generation control ---
        (("--num_generations",), {"type": int, "default": 1, "help": "Number of samples to generate per prompt."}),
        (("--num_inference_steps",), {"type": int, "default": 250, "help": "Number of DDIM inference steps."}),
        (("--audio_length_in_s",), {"type": float, "default": 10.0, "help": "Target audio length in seconds."}),
        (("--guidance_scale",), {"type": float, "help": "Guidance for single (audio) prompt."}),
        (("--desc_guidance_scale",), {"type": float, "default": 5.0, "help": "Guidance for the description prompt."}),
        (("--cont_guidance_scale",), {"type": float, "default": 7.0, "help": "Guidance for the content prompt."}),
        (("--seed",), {"type": int, "help": "Random seed for reproducible generation."}),
        # --- Post-processing ---
        (("--trim_silence",), {"action": "store_true", "help": "If set, trim silence from the audio."}),
    ]
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)

    parsed = cli.parse_args()
    # Each adapter path needs exactly one merge weight.
    if len(parsed.lora_weights) != len(parsed.lora_paths):
        cli.error("The number of --lora_paths must equal the number of --lora_weights.")
    return parsed

def _collect_tasks(args):
    """Return the list of tasks to run, or None when the batch file is unusable.

    Without ``--batch_file`` a single task is built from the prompt options,
    using the same keys a batch entry uses (``description``, ``prompt``,
    ``audio_prompt``, ``file_name``) so the task processor reads them.
    """
    if not args.batch_file:
        single_task = {
            'description': args.desc_prompt,
            'prompt': args.cont_prompt,
            'audio_prompt': args.audio_prompt,
        }
        # Leave the key out when unset so the default file name applies.
        if args.file_name:
            single_task['file_name'] = args.file_name
        return [single_task]

    try:
        with open(args.batch_file, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers both invalid JSON and undecodable bytes.
        print(f"Error reading batch file '{args.batch_file}': {e}")
        return None

    if isinstance(loaded, dict):
        # A lone task object is accepted as a one-element batch.
        return [loaded]
    if not isinstance(loaded, list):
        print(f"Error reading batch file '{args.batch_file}': expected a JSON list of task objects.")
        return None
    return loaded


def main():
    """Entry point: parse options, gather tasks, load the model and generate audio.

    Tasks are gathered before the model is loaded so that a missing or
    malformed batch file is reported without first paying for the
    checkpoint load and adapter merge. Failures are printed and end the
    run early; batch entries that are not JSON objects are skipped.
    """
    args = parse_args()
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    tasks = _collect_tasks(args)
    if tasks is None:
        return
    if not tasks:
        print("No tasks to process.")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    try:
        pipe = load_and_merge_model(args, device)
    except Exception as e:
        print(f"Failed to load and merge model: {e}")
        return

    for task in tqdm(tasks, desc="Generating Audio"):
        if not isinstance(task, dict):
            print(f"\nSkipping malformed task entry (expected an object): {task!r}")
            continue
        process_generation_task(pipe, task, args)

    print(f"\nGeneration complete. Outputs are in '{args.output_dir}'.")

# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()