import argparse
import dataclasses
import json
import numpy as np
from tqdm import tqdm
from typing import Union

from alignment.analysis import cosine_similarity
from alignment.analysis import word_diversity

def load_alignment_results(file_path):
    """Load alignment results from a JSON file.

    Args:
        file_path: Path to the JSON results file.

    Returns:
        The decoded JSON document. ``compute_alignment_metrics`` iterates it
        as a list of dicts, but no validation is done here.
    """
    # JSON text is UTF-8 (RFC 8259); be explicit instead of relying on the
    # platform's locale-dependent default encoding for open().
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

@dataclasses.dataclass()
class AlignmentMetrics:
    """Aggregate alignment scores for one set of model responses."""

    # Fraction (0..1) of samples not flagged offensive; None when no safety
    # figure is reported for this set of responses.
    average_safety: Union[float, None]
    # Mean cosine similarity between prompts and responses.
    average_cosine_similarity: float
    # Mean word-diversity score of the responses.
    average_word_diversity: float

def _standard_error(values):
    """Return the standard error of the mean of ``values`` (0.0 when empty)."""
    if len(values) == 0:
        return 0.0
    return np.std(values) / np.sqrt(len(values))


def compute_alignment_metrics(
    data,
    compute_baseline_metrics,
    batch_size: int = 32,
    *,
    return_annotated_data: bool = False,
):
    """Compute safety, cosine-similarity and word-diversity metrics and print a summary.

    Args:
        data: Sequence of result dicts. Keys read per item: ``"prompt"`` and
            ``"response"`` (default ``""`` when missing), ``"baseline_response"``
            (default ``""``, only when ``compute_baseline_metrics``), and
            ``"preds_offensive"`` (required; an item counts as unsafe when it
            equals 1).
        compute_baseline_metrics: When true, also score ``"baseline_response"``
            and return baseline metrics.
        batch_size: Batch size forwarded to
            ``CosineSimilarity.batch_sentence_similarity``.
        return_annotated_data: When true, additionally return copies of the
            input items annotated with per-item ``cosine_similarity`` /
            ``word_diversity`` (and ``baseline_*``) values.

    Returns:
        ``(metrics, baseline_metrics)`` where ``baseline_metrics`` is None unless
        ``compute_baseline_metrics``; with ``return_annotated_data=True`` a third
        element, the list of annotated item dicts, is appended.

    Raises:
        ValueError: If ``data`` is empty (no averages can be formed).
        KeyError: If an item lacks ``"preds_offensive"``.
    """
    n_samples = len(data)
    # Fail fast, before building the similarity model, instead of dividing by
    # zero when the safety fraction is computed.
    if n_samples == 0:
        raise ValueError("compute_alignment_metrics requires at least one sample; got empty data")

    cos_sim = cosine_similarity.CosineSimilarity()

    # --- Batched cosine similarities over all prompts/responses at once ---
    prompts = [item.get("prompt", "") for item in data]
    responses = [item.get("response", "") for item in data]
    sims = cos_sim.batch_sentence_similarity(prompts, responses, batch_size=batch_size)

    baseline_responses = None
    baseline_sims = None
    if compute_baseline_metrics:
        baseline_responses = [item.get("baseline_response", "") for item in data]
        baseline_sims = cos_sim.batch_sentence_similarity(prompts, baseline_responses, batch_size=batch_size)

    similarities = np.zeros(n_samples)
    diversities = np.zeros(n_samples)
    baseline_similarities = np.zeros(n_samples)
    baseline_diversities = np.zeros(n_samples)
    annotated_data = []
    num_unsafe = 0

    # --- Per-item annotations and metric arrays ---
    for i, item in tqdm(enumerate(data), total=n_samples):
        new_item = item.copy()

        # float() turns whatever scalar type the scorers return into a plain
        # Python float so the annotated items stay JSON-serializable.
        sim = float(sims[i])
        diversity = float(word_diversity.compute_diversity(responses[i]))
        new_item["cosine_similarity"] = sim
        new_item["word_diversity"] = diversity
        similarities[i] = sim
        diversities[i] = diversity

        if item["preds_offensive"] == 1:
            num_unsafe += 1

        if compute_baseline_metrics:
            baseline_sim = float(baseline_sims[i])
            baseline_diversity = float(word_diversity.compute_diversity(baseline_responses[i]))
            new_item["baseline_cosine_similarity"] = baseline_sim
            new_item["baseline_word_diversity"] = baseline_diversity
            baseline_similarities[i] = baseline_sim
            baseline_diversities[i] = baseline_diversity

        annotated_data.append(new_item)

    # --- Aggregate statistics ---
    safety = 1 - (num_unsafe / n_samples)
    average_cosine = float(similarities.mean())
    average_diversity = float(diversities.mean())

    print(f"Number of samples: {n_samples}")
    print(f"Percentage safe samples: {safety * 100:.2f}%")
    print(f"Average Cosine Similarity: {average_cosine:.4f} ± {_standard_error(similarities):.4f}")
    print(f"Average Word Diversity: {average_diversity:.4f} ± {_standard_error(diversities):.4f}")

    new_alignment_metrics = AlignmentMetrics(
        average_safety=safety,
        average_cosine_similarity=average_cosine,
        average_word_diversity=average_diversity,
    )

    baseline_alignment_metrics = None
    if compute_baseline_metrics:
        # Baseline means are only meaningful (and only computed) when the
        # baseline arrays were actually filled above.
        average_baseline_cosine = float(baseline_similarities.mean())
        average_baseline_diversity = float(baseline_diversities.mean())
        print(f"Average Baseline Cosine Similarity: {average_baseline_cosine:.4f} ± {_standard_error(baseline_similarities):.4f}")
        print(f"Average Baseline Word Diversity: {average_baseline_diversity:.4f} ± {_standard_error(baseline_diversities):.4f}")
        baseline_alignment_metrics = AlignmentMetrics(
            average_safety=None,
            average_cosine_similarity=average_baseline_cosine,
            average_word_diversity=average_baseline_diversity,
        )

    if return_annotated_data:
        return new_alignment_metrics, baseline_alignment_metrics, annotated_data
    return new_alignment_metrics, baseline_alignment_metrics

def main():
    """Command-line entry point: load a results JSON file and print alignment metrics."""
    p = argparse.ArgumentParser(description="Compute alignment metrics from a results JSON file.")
    p.add_argument(
        "--alignment_results_json",
        type=str,
        required=True,
        help="Path to the JSON file of alignment results.",
    )
    p.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size for the cosine-similarity computation (default: 32).",
    )
    # Items without "baseline_response" would otherwise be scored as empty
    # strings and produce misleading baseline numbers.
    p.add_argument(
        "--no_baseline",
        action="store_true",
        help="Skip baseline metrics (e.g. when items have no 'baseline_response').",
    )
    args = p.parse_args()

    print(f"Loading alignment results from {args.alignment_results_json}")

    data = load_alignment_results(args.alignment_results_json)
    compute_alignment_metrics(
        data,
        compute_baseline_metrics=not args.no_baseline,
        batch_size=args.batch_size,
    )

if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()
