import json
import os
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import torch
from datasets import load_from_disk
from tqdm import tqdm
import argparse
from wer import compute_wer


def calculate_average_cos_sim(sample_pair, models, batch_size=16, device=None):
    """
    Score each (ASR text, original text) pair by cosine similarity, averaged over models.

    Args:
        sample_pair (list): Sequence of (ASR text, original text) pairs.
        models (list): Preloaded embedding models exposing
            ``encode(texts, convert_to_tensor=True, device=...)``.
        batch_size (int): Number of pairs encoded per batch.
        device (str | None): Forwarded unchanged to ``model.encode``. ``None``
            defers to ``encode``'s own default (presumably the device the model
            already lives on -- confirm against the installed
            sentence-transformers version).

    Returns:
        list: One float per input pair, in input order: the mean over ``models``
        of the cosine similarity between the pair's two embeddings. Empty when
        either ``sample_pair`` or ``models`` is empty.
    """
    # Nothing to score: skip the progress bar and the encoders entirely.
    if len(sample_pair) == 0 or len(models) == 0:
        return []

    results = []
    batch_starts = range(0, len(sample_pair), batch_size)

    for start in tqdm(batch_starts, desc="Calculating Cosine Similarities"):
        batch_pairs = sample_pair[start:start + batch_size]
        # Transpose the pairs into two parallel lists of texts.
        asr_texts, original_texts = (list(texts) for texts in zip(*batch_pairs))

        per_model_scores = []
        for model in models:
            asr_embeddings = model.encode(asr_texts, convert_to_tensor=True, device=device)
            original_embeddings = model.encode(original_texts, convert_to_tensor=True, device=device)

            # cos_sim returns the full batch x batch matrix; only the diagonal
            # (pair i against pair i) is wanted here.
            pair_scores = cos_sim(asr_embeddings, original_embeddings).diagonal().tolist()
            per_model_scores.append(pair_scores)

        # zip(*...) regroups the scores per pair so they can be averaged across models.
        results.extend(sum(scores) / len(models) for scores in zip(*per_model_scores))

    return results


def calculate_vote_sim_rate(cos_similarities, threshold=0.9):
    """
    Return the fraction of similarity scores that reach ``threshold``.

    Args:
        cos_similarities (list): Sized collection of similarity scores.
        threshold (float): Inclusive lower bound a score must meet to count.

    Returns:
        float: ``count(sim >= threshold) / len(cos_similarities)``, or ``0.0``
        when ``cos_similarities`` is empty (instead of dividing by zero).
    """
    total = len(cos_similarities)
    if total == 0:
        return 0.0
    passing = sum(1 for sim in cos_similarities if sim >= threshold)
    return passing / total


def build_aligned_dataset(asr_records, original_pairs):
    """
    Pair original queries with their ASR transcripts by wav file name.

    Args:
        asr_records (iterable of dict): ASR results; each needs a ``'file'`` path
            and a ``'text'`` transcript. Only the final path component of
            ``'file'`` is used as the matching tag.
        original_pairs (iterable of tuple): ``(file_name, query)`` tuples in the
            order the output should follow.

    Returns:
        list of dict: One entry per original pair whose file name has an ASR
        record, with keys ``text``, ``original_text``, ``file_name`` and a
        ``cos_sim`` placeholder set to ``None``. Originals without an ASR
        record are skipped.
    """
    # Dict lookup instead of list.index / `in` on a list. setdefault keeps the
    # FIRST transcript for a duplicated tag, which is what list.index returned.
    text_by_tag = {}
    for record in asr_records:
        text_by_tag.setdefault(record['file'].split("/")[-1], record['text'])

    return [
        {"text": text_by_tag[tag], "original_text": query, "file_name": tag, "cos_sim": None}
        for tag, query in original_pairs
        if tag in text_by_tag
    ]


if __name__ == '__main__':
    # Run configuration: the trailing index picks one entry from each candidate list.
    data_name = ["drop", 'narrativeqa', 'quoref', 'ropes', 'squad1.1', 'squad2.0', 'tatqa'][0]
    asr_model = ["whisper", "parakeet", 'canary'][2]
    rewrite_llm = [
        "Phi-3-small-8k-instruct",
        "Meta-Llama-3-8B-Instruct",
        "Qwen2-7B-Instruct",
        "train_set",
        "stage_2",
    ]
    embedding_model = ["gte-large-en-v1.5", "mxbai-embed-large-v1", "stella_en_400M_v5"]
    # Short tag per embedding model (text before the first '-' / '_'), used in output names.
    embedding_model_name = [m.split("-")[0].split("_")[0] for m in embedding_model]
    print(embedding_model_name)
    original_dataset_path = f"../local_dataset/{data_name}/train_set"
    asr_result_file_folder = f"../data/original_check/{data_name}"
    save_path = f"../data/sim_result/{data_name}"
    print(rewrite_llm)
    print(f"<{data_name}>")
    print(embedding_model)

    # Only score the ASR result files that actually exist on disk.
    target_file_path = [f"{data_name}_check_{m}_{asr_model}.jsonl" for m in rewrite_llm]
    target_file_path = [f for f in target_file_path if os.path.isfile(f"{asr_result_file_folder}/{f}")]

    # Fall back to CPU rather than crashing when no GPU is present.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    models = [
        SentenceTransformer("../local_model/" + path, trust_remote_code=True).to(device).eval().requires_grad_(False)
        for path in embedding_model
    ]
    original_dataset = load_from_disk(original_dataset_path)

    # The original tags and queries do not depend on the ASR file or the model,
    # so extract them once instead of on every loop iteration.
    original_pairs = [
        (f"{d['id']}_{d['speaker_index']}.wav", d['query'])
        for d in tqdm(original_dataset)
    ]

    # open(..., 'w') does not create missing directories.
    os.makedirs(save_path, exist_ok=True)

    for f in target_file_path:
        # Load and align each ASR file once, then score it with every model.
        with open(os.path.join(asr_result_file_folder, f), 'r', encoding='utf-8') as file:
            # Skip blank lines (e.g. a trailing newline), which json.loads rejects.
            asr_dataset = [json.loads(line) for line in tqdm(file) if line.strip()]

        dataset = build_aligned_dataset(asr_dataset, original_pairs)
        one_sample_pair = [[d['text'], d['original_text']] for d in dataset]

        for idx, model in enumerate(models):
            results = calculate_average_cos_sim(one_sample_pair, [model])
            # NOTE: this key used to be misspelled 'con_sim', which left the
            # 'cos_sim' placeholder as None; any reader of these files must use 'cos_sim'.
            for d, score in zip(dataset, results):
                d['cos_sim'] = score

            out_name = f.replace('.jsonl', '_' + embedding_model_name[idx] + '.jsonl')
            # NOTE: despite the .jsonl suffix the file holds one JSON array;
            # kept as-is so existing readers keep working.
            with open(f"{save_path}/{out_name}", 'w', encoding='utf-8') as file:
                json.dump(dataset, file)