"""Batch AudioBox music aesthetics evaluation script"""
import argparse
import sys
import time
import subprocess
import json
import tempfile
import os
from typing import Optional, List
from pathlib import Path
from tqdm import tqdm
from loguru import logger

sys.path.insert(0, str(Path(__file__).parent))
from utils import (collect_audio_files, create_batch_results_template, 
                   finalize_batch_results, print_batch_summary,
                   save_batch_results, get_video_id, load_batch_results)


def prepare_jsonl_input(audio_files: list) -> tuple[str, list[str]]:
    """Write a temporary JSONL manifest describing the given audio files.

    Each manifest row is a JSON object with the absolute ``path`` of the file
    and a ``metadata.filename`` field holding its base name.

    Args:
        audio_files: Audio file paths (anything accepted by ``pathlib.Path``).

    Returns:
        A tuple ``(manifest_path, filenames)`` where ``manifest_path`` is the
        path of the temporary ``.jsonl`` file and ``filenames`` lists the base
        names in the same order as the manifest rows. The caller is
        responsible for deleting the manifest.
    """
    filenames = []
    rows = []
    for audio_file in audio_files:
        audio_path = Path(audio_file)
        filenames.append(audio_path.name)
        entry = {
            "path": str(audio_path.absolute()),
            "metadata": {
                "filename": audio_path.name
            }
        }
        rows.append(json.dumps(entry, ensure_ascii=False) + '\n')

    # delete=False: the file must outlive this function so the caller can hand
    # it to another process.
    manifest = tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False, encoding='utf-8')
    try:
        # The context manager guarantees the handle is closed even if a write fails.
        with manifest:
            manifest.writelines(rows)
    except BaseException:
        # Do not leave a partially written manifest behind on failure.
        Path(manifest.name).unlink(missing_ok=True)
        raise

    return manifest.name, filenames


def run_audiobox_evaluation(jsonl_file: str, file_order: List[str], checkpoint_path: Optional[str] = None,
                           model_batch_size: int = 10) -> dict:
    """Run the ``audio-aes`` command on a JSONL manifest and collect its scores.

    Args:
        jsonl_file: Path of the JSONL manifest passed to ``audio-aes``.
        file_order: Names used as result keys, in the same order as the
            manifest rows.
        checkpoint_path: Optional checkpoint file. It is forwarded with
            ``--ckpt`` only when the path exists; otherwise a warning is
            logged and the command runs without it.
        model_batch_size: Value forwarded to ``--batch-size``.

    Returns:
        A dict mapping each name in ``file_order`` to the JSON object parsed
        from the corresponding stdout line. An empty dict is returned when the
        command is missing, exits non-zero, times out, or yields no parseable
        score objects.

    NOTE(review): results are keyed by the names in ``file_order``; duplicate
    names overwrite each other. Verify callers pass unique names.
    """
    try:
        cmd = ["audio-aes", str(jsonl_file), "--batch-size", str(model_batch_size)]
        if checkpoint_path:
            ckpt = Path(checkpoint_path).expanduser().resolve()
            if ckpt.exists():
                cmd.extend(["--ckpt", str(ckpt)])
            else:
                logger.warning(f"Checkpoint not found: {ckpt}")

        env = os.environ.copy()
        env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

        result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env)

        if result.returncode != 0:
            logger.warning(f"audio-aes failed with return code {result.returncode}")
            logger.debug(f"stderr: {result.stderr}")
            return {}

        # Parse the JSONL output: one JSON object per non-empty stdout line.
        parsed_scores = []
        for line in result.stdout.strip().split('\n'):
            if not line.strip():
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError as e:
                logger.debug(f"Error parsing line: {e}")
                continue
            # Bare numbers, strings or lists are valid JSON but not score
            # records; keeping them would shift the order-based alignment and
            # break consumers that expect a dict.
            if not isinstance(entry, dict):
                logger.debug(f"Skipping non-object JSON line: {line!r}")
                continue
            parsed_scores.append(entry)

        if not parsed_scores:
            logger.warning("audio-aes stdout is empty or parsing failed")
            return {}

        if len(parsed_scores) != len(file_order):
            logger.warning(f"Output entries ({len(parsed_scores)}) != input audio count ({len(file_order)}), results will be aligned by order")

        # zip() stops at the shorter sequence, so surplus names or entries are dropped.
        return dict(zip(file_order, parsed_scores))

    except FileNotFoundError:
        logger.error("audio-aes command not found. Make sure audiobox-aesthetics is installed.")
        return {}
    except subprocess.TimeoutExpired:
        logger.error("audiobox evaluation timeout")
        return {}
    except Exception as e:
        # Boundary handler: any other failure is reported as "no scores".
        logger.error(f"Error running audiobox: {e}")
        return {}


def process_audio_file(audio_path: str, scores_dict: dict) -> dict:
    """Build the result record for one audio file from a score mapping.

    Args:
        audio_path: Path of the audio file.
        scores_dict: Mapping from audio file base name to its score entry.

    Returns:
        A dict with the keys ``audio_id``, ``audio_path``, ``success``,
        ``error``, ``scores`` and ``processing_time``. ``success`` is False
        when no usable score entry exists for the file's base name.
    """
    audio_path = Path(audio_path)
    video_id = get_video_id(str(audio_path))

    def _failure(message: str) -> dict:
        # Record shape shared by every failure case.
        return {
            "audio_id": video_id,
            "audio_path": str(audio_path),
            "success": False,
            "error": message,
            "scores": {},
            "processing_time": 0
        }

    score_entry = scores_dict.get(audio_path.name)
    if score_entry is None:
        return _failure("Score not found in audiobox output")
    if not isinstance(score_entry, dict):
        # Guard against malformed entries instead of raising AttributeError,
        # which would abort the caller's whole batch.
        return _failure("Invalid score entry in audiobox output")

    # Pick the first usable candidate. Test against None explicitly so that a
    # legitimate score of 0 is not mistaken for a missing value.
    primary_score = None
    for key in ("PQ", "primary", "aesthetics_score", "score"):
        candidate = score_entry.get(key)
        if candidate is None:
            continue
        try:
            primary_score = float(candidate)
        except (TypeError, ValueError):
            continue
        break

    # Keep every numeric field of the entry, normalised to float.
    formatted_scores = {
        key: float(value)
        for key, value in score_entry.items()
        if isinstance(value, (int, float))
    }

    if primary_score is not None:
        # setdefault: never overwrite a value the tool reported itself.
        formatted_scores.setdefault("primary", primary_score)
        formatted_scores.setdefault("aesthetics_score", primary_score)

    return {
        "audio_id": video_id,
        "audio_path": str(audio_path),
        "success": True,
        "error": None,
        "scores": formatted_scores,
        "processing_time": 0
    }


def main():
    """Command-line entry point: score audio files with AudioBox in chunks.

    Collects audio files from ``--input``, skips files already recorded in an
    existing ``--output`` file, evaluates the remaining files chunk by chunk
    and saves the accumulated results after every chunk so that an interrupted
    run can be resumed.
    """
    parser = argparse.ArgumentParser(description="Batch AudioBox music aesthetics evaluation")
    parser.add_argument("--input", type=str, required=True,
                       help="Input audio file or folder path")
    parser.add_argument("--output", type=str, default="audiobox_results.json",
                       help="Output JSON file path")
    parser.add_argument("--chunk-size", type=int, default=100,
                       help="Number of audio files per batch (default 100, can be lowered to avoid OOM)")
    parser.add_argument("--model-batch-size", type=int, default=10,
                       help="Model internal batch size (default 10, can be lowered to avoid OOM)")
    default_ckpt = Path("models/facebook_audiobox_aesthetics/checkpoint.pt")
    parser.add_argument("--checkpoint", type=str, default=str(default_ckpt),
                       help="Audiobox model checkpoint path")

    args = parser.parse_args()

    # A non-positive chunk size would break the chunk arithmetic below
    # (division by zero for 0); reject it up front with a clear message.
    if args.chunk_size < 1:
        parser.error("--chunk-size must be a positive integer")
    if args.model_batch_size < 1:
        parser.error("--model-batch-size must be a positive integer")

    logger.info(f"Collecting audio files from: {args.input}")
    audio_files = collect_audio_files(args.input)

    if not audio_files:
        logger.error("No audio files found!")
        return

    logger.info(f"Found {len(audio_files)} audio file(s)")
    logger.info(f"Processing in batches, {args.chunk_size} files per batch, model batch size: {args.model_batch_size}")

    output_path = Path(args.output)
    total_audio_files = len(audio_files)

    def _new_batch_results():
        # Fresh result container for this metric.
        return create_batch_results_template(
            metric_name="AudioBox",
            metric_type="audio_quality",
            total_videos=total_audio_files,
            device="cpu",
            parameters={}
        )

    if output_path.exists():
        logger.info(f"Detected existing result file {args.output}, will load and continue processing")
        batch_results = load_batch_results(args.output) or _new_batch_results()
        # Make sure the list exists before results are appended to it later.
        existing_results = batch_results.setdefault("results", [])
        # Result records store str(Path(...)); normalise both sides the same
        # way so inputs such as "./a.wav" or Path objects are recognised as
        # already processed.
        processed_paths = {
            str(Path(record["audio_path"]))
            for record in existing_results
            if record.get("audio_path")
        }
        audio_files = [f for f in audio_files if str(Path(f)) not in processed_paths]
        batch_results["config"]["batch_size"] = len(processed_paths) + len(audio_files)
        batch_results["summary"]["total"] = len(processed_paths) + len(audio_files)
        logger.info(f"Processed {len(processed_paths)} files, {len(audio_files)} files remaining")
    else:
        batch_results = _new_batch_results()

    if not audio_files:
        logger.info("All files have been processed")
        finalize_batch_results(batch_results)
        save_batch_results(batch_results, args.output)
        print_batch_summary(batch_results)
        return

    # Ceiling division: the last chunk may be smaller than chunk_size.
    total_chunks = (len(audio_files) + args.chunk_size - 1) // args.chunk_size
    logger.info(f"Total {total_chunks} batches to process")

    for chunk_idx in range(total_chunks):
        start_idx = chunk_idx * args.chunk_size
        end_idx = min(start_idx + args.chunk_size, len(audio_files))
        chunk_files = audio_files[start_idx:end_idx]

        logger.info(f"\nProcessing batch {chunk_idx + 1}/{total_chunks} (files {start_idx + 1}-{end_idx}/{len(audio_files)})")

        jsonl_file, audio_filenames = prepare_jsonl_input(chunk_files)

        try:
            scores_dict = run_audiobox_evaluation(
                jsonl_file,
                audio_filenames,
                args.checkpoint,
                model_batch_size=args.model_batch_size
            )

            if not scores_dict:
                logger.warning(f"Batch {chunk_idx + 1} AudioBox returned no scores")

            # Collect into a local list first: if one file fails midway, no
            # partial results have been committed, so the failure branch
            # cannot create duplicate records for the same file.
            chunk_results = [process_audio_file(audio_path, scores_dict) for audio_path in chunk_files]
        except Exception as e:
            logger.error(f"Batch {chunk_idx + 1} processing failed: {e}")
            chunk_succeeded = False
            chunk_results = [
                {
                    "audio_id": get_video_id(str(audio_path)),
                    # Same normalisation as successful records, so resuming
                    # matches these entries too.
                    "audio_path": str(Path(audio_path)),
                    "success": False,
                    "error": str(e),
                    "scores": {},
                    "processing_time": 0
                }
                for audio_path in chunk_files
            ]
        else:
            chunk_succeeded = True
        finally:
            # Best-effort removal of the temporary manifest.
            try:
                Path(jsonl_file).unlink()
            except OSError as cleanup_error:
                logger.debug(f"Could not remove temporary file {jsonl_file}: {cleanup_error}")

        # Commit exactly one record per file, then persist progress.
        batch_results["results"].extend(chunk_results)
        save_batch_results(batch_results, args.output)
        if chunk_succeeded:
            logger.info(f"Batch {chunk_idx + 1} completed, results saved")

        # Short pause between chunks, kept from the original flow
        # (presumably to let the subprocess release resources — TODO confirm).
        time.sleep(1)

    finalize_batch_results(batch_results)

    save_batch_results(batch_results, args.output)

    print_batch_summary(batch_results)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()



