"""Tradeoff metrics computation and logging."""

import os
import numpy as np
import torch
import time
from typing import List, Dict, Any, Optional
from scipy.special import expit

# Confidence thresholds (tau) swept to trace the time-accuracy tradeoff curve:
# for each tau, a group "exits" at the earliest token whose probability is >= tau
# (see compute_tradeoff_group). Kept in increasing order.
THRESHOLDS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.98, 0.99, 1.0]

def compute_tradeoff_group(probs_group, labels_group, times_group, thresholds=None):
    """
    Compute the early-exit time and reward of a single group for each threshold.

    The completions of a group are pooled and scanned in order of increasing
    time (position id). For each threshold ``tau`` the group exits at the first
    pooled token whose probability is ``>= tau`` and takes that token's label as
    its reward. Tokens that share the same time are visited in completion order.
    If no token reaches ``tau``, the group exits at the latest time. Its reward
    is then the final label of the completion with the highest final probability.

    Args:
        probs_group: Sequence of per-completion probability arrays. Every
            completion must be non-empty (its last element is read).
        labels_group: Sequence of per-completion label arrays, aligned
            element-wise with ``probs_group``.
        times_group: Sequence of per-completion time (position id) sequences,
            aligned element-wise with ``probs_group``.
        thresholds: Thresholds to evaluate. Defaults to the module-level
            ``THRESHOLDS``.

    Returns:
        Dict with ``"rewards"`` and ``"times"`` arrays, one entry per threshold.
    """
    if thresholds is None:
        thresholds = THRESHOLDS

    # Fallback reward when no token is confident enough: the final label of the
    # completion whose final probability is highest.
    final_probs = np.array([completion_probs[-1] for completion_probs in probs_group])
    fallback_reward = labels_group[int(np.argmax(final_probs))][-1]

    flat_probs = np.concatenate(probs_group).flatten()
    flat_labels = np.concatenate(labels_group).flatten()
    flat_times = np.concatenate(times_group).flatten()
    # Stable sort: tokens at the same time keep completion order. The default
    # quicksort does not guarantee this for larger arrays.
    order = np.argsort(flat_times, kind="stable")
    probs, labels, times = flat_probs[order], flat_labels[order], flat_times[order]

    threshold_times = []
    threshold_rewards = []
    for tau in thresholds:
        reached = probs >= tau
        if reached.any():
            # argmax on a boolean array returns the index of the first True.
            first = int(np.argmax(reached))
            threshold_times.append(times[first])
            threshold_rewards.append(labels[first])
        else:
            threshold_times.append(times[-1])
            threshold_rewards.append(fallback_reward)

    return {
        "rewards": np.array(threshold_rewards),
        "times": np.array(threshold_times),
    }

def compute_tradeoff_n(probs_per_group, labels_per_group, position_ids_per_group):
    """
    Average per-group tradeoff curves over all groups and integrate the mean curve.

    Each group is scored with ``compute_tradeoff_group``. The per-threshold
    rewards and exit times are then averaged across groups. The area under the
    mean reward-vs-time curve is computed with the trapezoidal rule.

    Args:
        probs_per_group: One entry per group; each is a sequence of
            per-completion probability arrays.
        labels_per_group: Per-group label arrays, aligned with ``probs_per_group``.
        position_ids_per_group: Per-group position ids (used as time), aligned
            with ``probs_per_group``.

    Returns:
        Dict with ``"thresholds"`` (a copy of ``THRESHOLDS``), ``"rewards"`` and
        ``"times"`` (means over groups, one entry per threshold), and ``"auc"``.

    Raises:
        ValueError: If no groups are supplied.
    """
    group_rewards = []
    group_times = []
    for probs_group, labels_group, times_group in zip(
        probs_per_group, labels_per_group, position_ids_per_group
    ):
        tradeoff_results = compute_tradeoff_group(probs_group, labels_group, times_group)
        group_rewards.append(tradeoff_results["rewards"])
        group_times.append(tradeoff_results["times"])

    if not group_rewards:
        raise ValueError("compute_tradeoff_n needs at least one group")

    threshold_rewards_mean = np.stack(group_rewards, axis=0).mean(axis=0)
    threshold_times_mean = np.stack(group_times, axis=0).mean(axis=0)

    # np.trapezoid is the NumPy >= 2.0 name; np.trapz is its older (now
    # deprecated) spelling. Fall back so NumPy 1.x also works. The times are
    # not uniformly spaced, so they are passed explicitly as the x samples.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    tradeoff_auc = trapezoid(threshold_rewards_mean, threshold_times_mean)

    return {
        # Copy so callers cannot mutate the shared module-level constant.
        "thresholds": list(THRESHOLDS),
        "rewards": threshold_rewards_mean,
        "times": threshold_times_mean,
        "auc": tradeoff_auc,
    }
    

def compute_tradeoff_metrics(
    probs, labels, num_completions_per_prompt, eval_ds, plot_save_dir
) -> Dict[str, Any]:
    """
    Compute tradeoff AUCs for every power-of-two group size and save their plots.

    Positions whose label is -100 are dropped from every completion. Consecutive
    completions are then chunked into groups of size n, for each
    n in {1, 2, 4, ..., num_completions_per_prompt}. Each grouping is scored with
    ``compute_tradeoff_n`` and plotted with ``plot_tradeoff``.

    NOTE(review): chunking consecutive completions assumes the completions of a
    prompt are stored contiguously. Confirm with the caller.

    Args:
        probs: Per-completion probability rows, indexed in step with ``labels``.
        labels: Per-completion label array; -100 marks positions to ignore.
        num_completions_per_prompt: Must be a power of two (asserted).
        eval_ds: Mapping whose ``"labels"`` entry holds per-completion label
            sequences. The positions of the non-(-100) entries serve as time.
        plot_save_dir: Directory the tradeoff plots are written to.

    Returns:
        Dict mapping ``"tradeoff_{n}_auc"`` to the AUC for each group size n.
    """
    print("Starting tradeoff metrics computation...")

    valid = labels != -100
    completion_probs = [row[valid[idx]] for idx, row in enumerate(probs)]
    completion_labels = [row[valid[idx]].astype(int) for idx, row in enumerate(labels)]

    position_ids = [
        [pos for pos, tok in enumerate(seq) if tok != -100]
        for seq in eval_ds["labels"]
    ]
    assert len(position_ids[0]) == len(completion_probs[0]) == len(completion_labels[0])

    def _chunks(items, size):
        # Consecutive, non-overlapping slices of ``items``; the last may be shorter.
        return [items[start:start + size] for start in range(0, len(items), size)]

    largest_exponent = int(np.log2(num_completions_per_prompt))
    group_sizes = [2 ** exponent for exponent in range(largest_exponent + 1)]
    assert num_completions_per_prompt in group_sizes

    metrics = {}
    for n in group_sizes:
        tradeoff = compute_tradeoff_n(
            _chunks(completion_probs, n),
            _chunks(completion_labels, n),
            _chunks(position_ids, n),
        )
        metrics[f"tradeoff_{n}_auc"] = tradeoff["auc"]
        plot_tradeoff(tradeoff, f"tradeoff_{n}", plot_save_dir)

    return metrics

def plot_tradeoff(results, table_name, plot_save_dir):
    """
    Plot the time-accuracy tradeoff curve and save it as a PNG.

    Each threshold's (time, reward) point is drawn as a scatter point coloured
    by its threshold, the points are joined by a line, and every point is
    annotated with its threshold value. The figure is written to
    ``<plot_save_dir>/<table_name>.png``; the directory is created if needed.

    Args:
        results: Dict with equal-length ``"times"``, ``"rewards"`` and
            ``"thresholds"`` sequences, e.g. the output of ``compute_tradeoff_n``.
        table_name: Used in the plot title and as the output file stem.
        plot_save_dir: Directory the PNG is saved into.
    """
    import matplotlib.pyplot as plt

    times = results["times"]
    accuracies = results["rewards"]  # rewards are the accuracy values
    thresholds = results["thresholds"]

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # Time vs accuracy, coloured by threshold, with a connecting line.
        scatter = ax.scatter(times, accuracies, c=thresholds,
                             cmap='viridis', s=50, alpha=0.7)
        ax.plot(times, accuracies, 'b-', alpha=0.5, linewidth=1)

        cbar = fig.colorbar(scatter, ax=ax)
        cbar.set_label('Threshold', rotation=270, labelpad=15)

        ax.set_xlabel('Time (tokens)')
        ax.set_ylabel('Accuracy')
        ax.set_title(f'Time-Accuracy Tradeoff: {table_name}')
        ax.grid(True, alpha=0.3)

        # Annotate every point with its threshold. ':g' keeps 0.95 / 0.98 / 0.99
        # distinct; ':.1f' rendered them as 0.9 / 1.0 / 1.0.
        for tau, x, y in zip(thresholds, times, accuracies):
            ax.annotate(f'τ={tau:g}', (x, y),
                        xytext=(5, 5), textcoords='offset points',
                        fontsize=8, alpha=0.8)

        fig.tight_layout()
        os.makedirs(plot_save_dir, exist_ok=True)
        save_path = os.path.join(plot_save_dir, f"{table_name}.png")
        fig.savefig(save_path)
        print(f"Saved plot to {save_path}")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    