"""Main compute_metrics function for trainer evaluation."""

from deepspeed.runtime.lr_schedules import WARMUP_LOG_RATE
import torch
import numpy as np
import os
import hashlib
from scipy.special import expit, softmax, logit
from typing import List, Dict, Optional, Tuple

from .basic_metrics import compute_roc_auc, compute_brier_score, compute_ece_score, compute_argmax_accuracy, compute_nll, compute_roc_auc_per_prompt
from .tradeoff import compute_tradeoff_metrics

def apply_weighted_voting(probs, labels, eval_ds, weighted_voting_group_size):
    """Replace each completion's probability with its bin's share of the group's probability mass.

    Rows are split into consecutive groups of ``weighted_voting_group_size``.
    Inside a group, a completion's new score is the sum of the probabilities
    of all completions in the same bin, divided by the sum of the
    probabilities of the whole group.

    Args:
        probs: Array shaped like ``labels`` holding one probability per position.
        labels: Label array whose first axis indexes completions; positions
            equal to -100 are ignored. Each row must contain exactly one
            position that is not -100.
        eval_ds: Object whose ``"bin_idx"`` entry yields one sequence per
            completion; the last element of each sequence is used as that
            completion's bin id.
        weighted_voting_group_size: Number of consecutive completions that
            vote together. Must be positive and divide the number of rows.

    Returns:
        A new array shaped like ``probs``: the voted probability (clipped to
        [0, 1]) at every valid position and 0 everywhere else. ``probs`` is
        not modified.

    Raises:
        ValueError: If the group size is None, non-positive, or does not
            divide the number of rows; if ``eval_ds["bin_idx"]`` does not
            have one entry per row; or if the number of valid label
            positions differs from the number of rows.
    """
    if weighted_voting_group_size is None:
        raise ValueError("weighted_voting_group_size must be specified when using weighted voting")

    group_size = weighted_voting_group_size
    num_rows = labels.shape[0]
    if group_size <= 0 or num_rows % group_size != 0:
        raise ValueError(
            f"weighted_voting_group_size ({group_size}) must be positive and divide "
            f"the number of completions ({num_rows})"
        )
    num_groups = num_rows // group_size

    # Only the last bin id recorded for each instance is used.
    bin_idx = np.array([idx_list[-1] for idx_list in eval_ds["bin_idx"]])
    if bin_idx.shape[0] != num_rows:
        raise ValueError(
            f"eval_ds['bin_idx'] has {bin_idx.shape[0]} entries but labels has {num_rows} rows"
        )

    valid_pos = labels != -100
    probs_valid = probs[valid_pos]  # boolean indexing copies, so `probs` stays untouched
    if probs_valid.shape[0] != num_rows:
        raise ValueError(
            f"expected exactly one valid label per completion ({num_rows} in total), "
            f"found {probs_valid.shape[0]}"
        )

    probs_grouped = probs_valid.reshape(num_groups, group_size)
    bin_idx_grouped = bin_idx.reshape(num_groups, group_size)

    aggregated_probs = np.zeros_like(probs_grouped)
    for g in range(num_groups):
        group_probs = probs_grouped[g]  # [G]
        total = group_probs.sum()
        if total == 0:
            # No probability mass to share out: keep zeros rather than emit 0/0 = NaN.
            continue
        group_bins = bin_idx_grouped[g]  # [G]
        same_bin = (group_bins[:, None] == group_bins[None, :]).astype(float)  # [G, G]
        aggregated_probs[g] = np.einsum("ab,b->a", same_bin, group_probs) / total

    # Clip to absorb floating-point overshoot just outside [0, 1].
    aggregated_probs = np.clip(aggregated_probs.reshape(num_rows), 0.0, 1.0)

    # Scatter the voted values back to the valid positions of a probs-shaped array.
    voted = np.zeros_like(probs)
    voted[valid_pos] = aggregated_probs
    return voted

def _select_valid_per_row(values, mask):
    """Return one 1-D array per row of ``values`` holding the entries where ``mask`` is True."""
    # Indexed (not zipped) on purpose: extra rows in `values` raise instead of being dropped.
    return [values[i][mask[i]] for i in range(len(values))]


def _concat_consecutive(chunks, n):
    """Concatenate every run of ``n`` consecutive arrays in ``chunks`` into a single array."""
    return [np.concatenate(chunks[i:i + n]) for i in range(0, len(chunks), n)]


def create_compute_metrics(num_completions_per_prompt, eval_temps="1.0", compute_tradeoff=False, eval_ds=None, save_dir="", group_softmax=False, sum_group_softmax=False, group_size=None, wm=False, save_pred=False, eval_completion_path="", eval_balance_difficulty=False, eval_num_prompts=None, weighted_voting_group_sizes=None, last_rollout_only=False, eval_precomputed_path=None, eval_select_difficulty=None, dataset_seed=42, num_difficulty_bins=8, model_name=None, tokenizer_name=None):
    """Create a compute_metrics function for trainer evaluation.

    Args:
        num_completions_per_prompt: Number of consecutive completions that
            belong to one prompt. Must be a power of two and divide the
            number of evaluated completions.
        eval_temps: Comma-separated string of strictly positive temperature
            values for calibration evaluation.
        compute_tradeoff: If True, merge the output of
            ``compute_tradeoff_metrics`` into the result.
        eval_ds: Evaluation dataset. Its ``"bin_idx"`` entry is read for
            group softmax and weighted voting, and it is forwarded to
            ``compute_tradeoff_metrics``.
        save_dir: Base directory for tradeoff plots (``tradeoff/``) and saved
            predictions (``predictions/``). Empty disables both outputs.
        group_softmax: If True (and ``wm`` is False), probabilities come from
            a softmax over per-bin logits plus one abstain logit per group
            instead of a per-completion sigmoid.
        sum_group_softmax: With group softmax, sum the logits of completions
            sharing a bin instead of averaging them.
        group_size: Number of completions per group for group softmax.
        wm: If True, the sigmoid path is used even when ``group_softmax`` is
            set. Also stored in the saved parameters.
        save_pred: If True and ``save_dir`` is set, pickle ``eval_pred`` and
            the reconstruction parameters.
        eval_completion_path, eval_precomputed_path: Source paths; the
            precomputed path takes precedence when hashing the output folder
            name. Both are stored in the saved parameters.
        eval_balance_difficulty, eval_num_prompts: Stored in the saved
            parameters and used in the prediction folder name.
        weighted_voting_group_sizes: List of group sizes for weighted voting
            metrics (default: [4, 16, 64]). Sizes that do not divide
            ``num_completions_per_prompt`` are skipped.
        last_rollout_only: Weighted voting metrics are only computed when
            this is True and ``eval_ds`` is given.
        eval_select_difficulty, dataset_seed, num_difficulty_bins,
        model_name, tokenizer_name: Only stored in the saved parameters.

    Returns:
        A compute_metrics function that can be passed to trainer initialization.
        It takes an object with ``predictions`` and ``label_ids`` attributes
        and returns a dict of metric name -> value.

    Raises:
        ValueError: If any temperature in ``eval_temps`` is not a strictly
            positive number.
    """
    # Parse temperature values
    temps = [float(t.strip()) for t in eval_temps.split(",")]
    invalid_temps = [t for t in temps if not t > 0]  # `not t > 0` also rejects NaN
    if invalid_temps:
        raise ValueError(f"eval_temps must be strictly positive, got {invalid_temps}")

    # Set default weighted voting group sizes
    if weighted_voting_group_sizes is None:
        weighted_voting_group_sizes = [4, 16, 64]

    def compute_metrics(eval_pred):
        """Compute calibration, best-of-n and optional weighted-voting/tradeoff metrics.

        Raises:
            ValueError: If no label differs from -100, if the number of
                completions is not a multiple of ``num_completions_per_prompt``,
                or if ``num_completions_per_prompt`` is not a power of two.
        """
        logits = eval_pred.predictions
        labels = eval_pred.label_ids

        def group_softmax_probs(temp):
            """Softmax over per-bin aggregated logits plus the group's abstain logit."""
            if group_size is None:
                raise ValueError("group_size must be specified when group_softmax is enabled")
            if eval_ds is None:
                raise ValueError("eval_ds is required when group_softmax is enabled")

            num_rows, seq_len = labels.shape
            num_groups = num_rows // group_size
            # Each group contributes group_size completion rows followed by one abstain row.
            raw_logits = logits.squeeze(-1)
            assert raw_logits.shape[0] == num_rows + num_groups
            assert raw_logits.shape[1] == seq_len
            raw_logits = raw_logits.reshape(num_groups, group_size + 1, seq_len)
            class_logits = raw_logits[:, :group_size].reshape(num_groups, group_size * seq_len)
            abstain_logits = raw_logits[:, group_size]  # [N, T]

            bins = np.array([idx_list[-1] for idx_list in eval_ds["bin_idx"]][:num_rows])
            bins = bins.reshape(num_groups, group_size)  # [N, G]
            grouped_labels = labels.reshape(num_groups, group_size * seq_len)
            grouped_probs = np.zeros_like(grouped_labels, dtype=float)

            for g in range(num_groups):
                valid_pos = grouped_labels[g] != -100
                member_logits = class_logits[g][valid_pos]  # one logit per completion
                abstain_logit = abstain_logits[g][0]  # first position of the abstain row
                assert len(member_logits) == group_size

                same_bin = (bins[g][:, None] == bins[g][None, :]).astype(float)  # [G, G]
                if not sum_group_softmax:
                    # Average instead of sum: normalise each row by its bin size.
                    same_bin = same_bin / same_bin.sum(axis=-1, keepdims=True)
                member_bin_logits = np.einsum("ab,b->a", same_bin, member_logits)  # [G]

                # One logit per distinct bin, in order of first appearance.
                logit_by_bin = {}
                for i, idx in enumerate(bins[g]):
                    logit_by_bin.setdefault(idx.item(), member_bin_logits[i])
                # NOTE(review): the abstain logit is keyed by -1, so a real bin id of -1
                # would be overwritten here -- confirm bin ids are never -1.
                logit_by_bin[-1] = abstain_logit

                # Temperature is applied AFTER aggregation to preserve ranking.
                scaled_logits = np.stack(list(logit_by_bin.values())) / temp
                prob_by_bin = dict(zip(logit_by_bin.keys(), softmax(scaled_logits)))
                grouped_probs[g][valid_pos] = np.array([prob_by_bin[idx.item()] for idx in bins[g]])

            return grouped_probs.reshape(num_rows, seq_len)

        def compute_probs_for_temperature(temp):
            """Compute probabilities for a given temperature."""
            if not group_softmax or wm:
                # Single logit case: temperature-scaled, numerically stable sigmoid.
                return expit(logits.squeeze(-1) / temp)
            return group_softmax_probs(temp)

        mask = labels != -100
        if mask.sum().item() == 0:
            raise ValueError("No valid labels found")

        # Base probabilities (temperature 1.0) feed best-of-n and tradeoff metrics.
        probs = compute_probs_for_temperature(1.0)
        probs_per_completion = _select_valid_per_row(probs, mask)
        labels_per_completion = [row.astype(int) for row in _select_valid_per_row(labels, mask)]

        if len(probs_per_completion) % num_completions_per_prompt != 0:
            raise ValueError(
                f"number of completions ({len(probs_per_completion)}) is not a multiple of "
                f"num_completions_per_prompt ({num_completions_per_prompt})"
            )
        n_values = [2**i for i in range(int(np.log2(num_completions_per_prompt)) + 1)]
        if num_completions_per_prompt not in n_values:
            raise ValueError(
                f"num_completions_per_prompt ({num_completions_per_prompt}) must be a power of two"
            )

        labels_per_prompt = _concat_consecutive(labels_per_completion, num_completions_per_prompt)
        labels_flat = np.concatenate(labels_per_completion)
        has_both_classes = np.unique(labels_flat).size == 2

        # Weighted voting needs bin ids from eval_ds and is only run on last-rollout data.
        if eval_ds is not None and last_rollout_only:
            active_wm_sizes = [s for s in weighted_voting_group_sizes if num_completions_per_prompt % s == 0]
        else:
            active_wm_sizes = []

        def calibration_metrics(flat_probs, prefix, temp_str):
            """ROC AUC (NaN unless both classes are present), Brier, ECE and NLL."""
            if has_both_classes:
                roc_auc = compute_roc_auc(labels_flat, flat_probs)
            else:
                roc_auc = float("nan")
            return {
                f"{prefix}roc_auc_{temp_str}": roc_auc,
                f"{prefix}brier_{temp_str}": compute_brier_score(labels_flat, flat_probs),
                f"{prefix}ece_{temp_str}": compute_ece_score(labels_flat, flat_probs),
                f"{prefix}nll_{temp_str}": compute_nll(labels_flat, flat_probs),
            }

        # Compute metrics for each temperature
        result = {}
        for temp in temps:
            temp_probs = compute_probs_for_temperature(temp)
            temp_probs_per_completion = _select_valid_per_row(temp_probs, mask)
            temp_probs_per_prompt = _concat_consecutive(temp_probs_per_completion, num_completions_per_prompt)
            temp_probs_flat = np.concatenate(temp_probs_per_completion)

            temp_str = str(temp).replace('.', '_')
            result.update(calibration_metrics(temp_probs_flat, "", temp_str))
            result[f"roc_auc_per_prompt_{temp_str}"] = compute_roc_auc_per_prompt(labels_per_prompt, temp_probs_per_prompt)

            # Voting is applied on top of the probabilities already computed for this
            # temperature, so they are not recomputed for every group size.
            for wm_group_size in active_wm_sizes:
                wm_probs = apply_weighted_voting(temp_probs, labels, eval_ds, wm_group_size)
                wm_probs_flat = np.concatenate(_select_valid_per_row(wm_probs, mask))
                result.update(calibration_metrics(wm_probs_flat, f"wm_{wm_group_size}_", temp_str))

        # Best-of-n metrics using original probabilities (temperature 1.0)
        for n in n_values:
            result[f"best_of_{n}"] = compute_argmax_accuracy(
                _concat_consecutive(labels_per_completion, n),
                _concat_consecutive(probs_per_completion, n),
            )

        # Weighted voting best-of-n metrics (temperature 1.0)
        for wm_group_size in active_wm_sizes:
            wm_probs = apply_weighted_voting(probs, labels, eval_ds, wm_group_size)
            wm_probs_per_completion = _select_valid_per_row(wm_probs, mask)
            for n in n_values:
                if n <= wm_group_size:  # Only compute for valid n values
                    result[f"wm_{wm_group_size}_best_of_{n}"] = compute_argmax_accuracy(
                        _concat_consecutive(labels_per_completion, n),
                        _concat_consecutive(wm_probs_per_completion, n),
                    )

        if compute_tradeoff:
            plot_save_dir = os.path.join(save_dir, "tradeoff") if save_dir else ""
            result.update(compute_tradeoff_metrics(probs, labels, num_completions_per_prompt, eval_ds, plot_save_dir))

        if save_pred and save_dir:
            import pickle

            # Folder name is keyed by a hash of the eval path (precomputed path wins);
            # fall back to "" so a missing path cannot crash on None.encode().
            eval_path = eval_precomputed_path or eval_completion_path or ""
            eval_hash = hashlib.md5(eval_path.encode()).hexdigest()[:8]
            folder_name = f"{eval_hash}_{eval_balance_difficulty}_{eval_num_prompts}"
            pred_save_dir = os.path.join(save_dir, "predictions", folder_name)
            os.makedirs(pred_save_dir, exist_ok=True)

            # Save predictions
            with open(os.path.join(pred_save_dir, "eval_pred.pkl"), "wb") as f:
                pickle.dump(eval_pred, f)

            # Save key parameters for dataset reconstruction and metric recomputation
            params = {
                "num_completions_per_prompt": num_completions_per_prompt,
                "wm": wm,
                "group_size": group_size,

                # Dataset reconstruction parameters
                "eval_completion_path": eval_completion_path if not eval_precomputed_path else None,
                "eval_precomputed_path": eval_precomputed_path if eval_precomputed_path else None,
                "eval_num_prompts": eval_num_prompts,
                "eval_balance_difficulty": eval_balance_difficulty,
                "eval_select_difficulty": eval_select_difficulty,
                "dataset_seed": dataset_seed,
                "num_difficulty_bins": num_difficulty_bins,

                # Only needed for non-precomputed reconstruction
                "model_name": model_name if not eval_precomputed_path else None,
                "tokenizer_name": tokenizer_name if not eval_precomputed_path else None,
            }
            with open(os.path.join(pred_save_dir, "params.pkl"), "wb") as f:
                pickle.dump(params, f)

            print(f"Predictions saved to {pred_save_dir}")

        return result

    return compute_metrics
