import os
import json
import argparse
import torch
import laion_clap
import numpy as np
import multiprocessing
from tqdm import tqdm

def parse_args():
    """Build the command-line interface and return the parsed namespace.

    Options:
        --num_samples: audio samples per prompt (int, default 5).
        --json_path:   input JSON file (required).
        --output_dir:  directory the output JSON files are written to (required).
    """
    cli = argparse.ArgumentParser(
        description="Labelling clap score for crpo dataset"
    )
    option_specs = (
        ("--num_samples", dict(type=int, default=5,
                               help="Number of audio samples per prompt")),
        ("--json_path", dict(type=str, required=True,
                             help="Path to input JSON file")),
        ("--output_dir", dict(type=str, required=True,
                              help="Directory to save the final JSON with CLAP scores")),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    parsed = cli.parse_args()
    return parsed

@torch.no_grad()
def compute_clap(model, audio_files, text_data):
    """Return the CLAP score matrix between audio files and text prompts.

    Entry ``[i, j]`` is the dot product of the embedding of ``audio_files[i]``
    and the embedding of ``text_data[j]``, both obtained from ``model`` as
    tensors. Runs without autograd tracking.
    """
    embeddings_audio = model.get_audio_embedding_from_filelist(
        x=audio_files, use_tensor=True
    )
    embeddings_text = model.get_text_embedding(text_data, use_tensor=True)
    # (num_audio, dim) x (dim, num_text) -> (num_audio, num_text)
    return torch.matmul(embeddings_audio, embeddings_text.T)

def process_chunk(args, chunk, gpu_id, return_dict, process_id):
    """
    Score one chunk of the dataset on a single GPU.

    Loads a CLAP model on ``cuda:{gpu_id}``. For every item in ``chunk``, it scores
    the item's first ``args.num_samples`` audio files (``'path'``) against the
    first sample's ``'captions'``. The rounded score is stored in place under
    ``'clap_score'`` in each of those sample dicts.

    Items whose first sample already carries ``'clap_score'`` are skipped, so a
    previous output file can be fed back in to resume. A failure on one item
    is printed and leaves only that item unscored; the rest of the chunk keeps
    going.

    Args:
        args: Parsed CLI namespace; only ``args.num_samples`` is read.
        chunk: List of items, each presumably a list of sample dicts.
        gpu_id: Index of the CUDA device to run on.
        return_dict: Shared mapping that receives the processed chunk.
        process_id: Key under which the chunk is stored in ``return_dict``.

    The (possibly partially scored) chunk is always written to
    ``return_dict[process_id]``, even if model setup fails, so input items are
    never silently dropped from the aggregated output.
    """
    try:
        device = f"cuda:{gpu_id}"
        torch.cuda.set_device(device)
        print(f"Process {process_id}: Using device {device}")

        # Build the model directly on this worker's GPU. Without ``device``,
        # CLAP_Module auto-selects a device (cuda:0 when CUDA is available), so
        # every worker would first allocate a model copy on GPU 0.
        model = laion_clap.CLAP_Module(enable_fusion=False, device=device)
        model.to(device)
        model.load_ckpt()
        model.eval()

        for j, item in enumerate(tqdm(chunk, desc=f"GPU {gpu_id}")):
            # Skip items already scored by a previous run.
            if item and 'clap_score' in item[0]:
                continue

            try:
                # Input gathering lives inside the try so a malformed item
                # (too few samples, missing keys, empty list) only skips itself.
                audio_files = [item[i]['path'] for i in range(args.num_samples)]
                # Score every sample against the first sample's caption.
                text_data = [item[0]['captions']]
                clap_scores = compute_clap(model, audio_files, text_data)
            except Exception as e:
                print(f"Error processing item index {j} on GPU {gpu_id}: {e}")
                continue

            # One row per audio file; the single caption is column 0.
            for k in range(args.num_samples):
                item[k]['clap_score'] = np.round(clap_scores[k].item(), 3)

        return_dict[process_id] = chunk
        print(f"Process {process_id}: Completed processing on GPU {gpu_id}")
    except Exception as e:
        print(f"Process {process_id}: Error on GPU {gpu_id}: {e}")
        # Return what we have instead of an empty list: unscored items are kept
        # (and skipped-if-scored on a rerun) rather than vanishing from output.
        return_dict[process_id] = chunk

def split_into_chunks(data, num_chunks):
    """
    Split ``data`` into ``num_chunks`` contiguous, near-equal slices.

    The first ``len(data) % num_chunks`` slices each get one extra element, so
    slice sizes differ by at most one. Order is preserved, so concatenating the
    slices reproduces ``data``. When ``len(data) < num_chunks``, the trailing
    slices are empty.

    Raises:
        ValueError: if ``num_chunks`` is less than 1.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")

    base, extra = divmod(len(data), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        # Spread the remainder over the first `extra` chunks instead of
        # piling it all onto the last one.
        end = start + base + (1 if i < extra else 0)
        chunks.append(data[start:end])
        start = end
    return chunks

def _build_crpo_dataset(combined_data):
    """Pair each item's highest- and lowest-scoring sample into a CRPO record.

    Only samples carrying ``'clap_score'`` are considered. Items with no scored
    sample (e.g. their CLAP computation failed) are skipped.

    Returns:
        A tuple ``(records, skipped)``. ``records`` holds dicts with
        ``captions``, ``duration``, ``chosen`` and ``reject`` keys. ``skipped``
        is the number of items that had no scored sample.
    """
    records = []
    skipped = 0
    for item in combined_data:
        scored = [sample for sample in item if 'clap_score' in sample]
        if not scored:
            skipped += 1
            continue
        chosen = max(scored, key=lambda sample: sample['clap_score'])
        reject = min(scored, key=lambda sample: sample['clap_score'])
        records.append({"captions": chosen['captions'],
                        "duration": chosen['duration'],
                        "chosen": chosen['path'],
                        "reject": reject['path']})
    return records, skipped


def main():
    """Score all audio samples with CLAP across every visible GPU and build a CRPO set.

    Reads ``--json_path``. Judging by the key accesses, this is presumably a list of
    items, each a list of sample dicts with ``path``, ``captions`` and
    ``duration``. The data is split across ``torch.cuda.device_count()`` GPUs, and
    each chunk is scored in a spawned ``process_chunk`` worker.

    Writes to ``--output_dir``:
      * ``clap_scores.json``: all items, with ``clap_score`` on scored samples.
      * ``train.json``: one record per scored item, pairing its best
        (``chosen``) and worst (``reject``) sample paths.

    Raises:
        SystemExit: if no CUDA device is visible.
    """
    args = parse_args()

    with open(args.json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise SystemExit("No CUDA GPUs found; CLAP scoring requires at least one GPU.")

    print(f"Found {num_gpus} GPUs. Splitting data into {num_gpus} chunks.")
    chunks = split_into_chunks(data, num_gpus)

    os.makedirs(args.output_dir, exist_ok=True)

    combined_data = []
    # Context manager guarantees the manager server process is shut down.
    with multiprocessing.Manager() as manager:
        return_dict = manager.dict()
        processes = []

        for i in range(num_gpus):
            if not chunks[i]:
                # No work for this GPU; don't spawn a worker that loads a model.
                continue
            p = multiprocessing.Process(
                target=process_chunk,
                args=(args, chunks[i], i, return_dict, i)
            )
            processes.append(p)
            p.start()
            print(f"Started process {i} on GPU {i}")

        for p in processes:
            p.join()
            print(f"Process {p.pid} has finished.")

        # Reassemble in GPU order, which matches the original data order.
        for i in range(num_gpus):
            result = return_dict.get(i)
            if not result:
                if chunks[i]:
                    # The worker died or reported nothing: keep its items
                    # unscored instead of losing them from the output.
                    print(f"Warning: no results from GPU {i}; keeping its "
                          f"{len(chunks[i])} items unscored.")
                result = chunks[i]
            combined_data.extend(result)

    output_file = os.path.join(args.output_dir, "clap_scores.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(combined_data, f)
    print(f"All CLAP scores have been computed and saved to {output_file}")

    crpo_dataset, skipped = _build_crpo_dataset(combined_data)
    if skipped:
        print(f"Skipped {skipped} items without CLAP scores when building train.json")

    with open(os.path.join(args.output_dir, "train.json"), 'w', encoding='utf-8') as f:
        json.dump(crpo_dataset, f)


if __name__ == "__main__":
    # CUDA cannot be re-initialised in forked children, so workers are spawned.
    multiprocessing.set_start_method(method="spawn")
    main()
