import torch
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import entropy
from scipy.optimize import linear_sum_assignment as linear_assignment
from sklearn.metrics import confusion_matrix, hamming_loss

def credible_interval(gamma_samples, credibility=0.95, dim=0):
    """
    Return the equal-tailed credible interval of a set of samples.

    The interval is taken along `dim` of `gamma_samples`: the lower bound is
    the (1 - credibility) / 2 quantile and the upper bound is its mirror
    quantile on the other tail.

    Returns:
    - (lower_bound, upper_bound, interval_width), where the width is
      upper_bound - lower_bound.
    """
    # Probability mass left outside the interval on each side.
    tail_mass = (1 - credibility) / 2
    lo = torch.quantile(gamma_samples, tail_mass, dim=dim)
    hi = torch.quantile(gamma_samples, 1 - tail_mass, dim=dim)
    # A wider interval indicates more uncertainty.
    return lo, hi, hi - lo


def compute_uncertainty_stats(uncertainty_measures):
    """
    Summarise each uncertainty tensor with scalar statistics.

    For every (name, tensor) entry the result holds a dict with the mean, the
    standard deviation, 'z_score' (mean + 3 * std, rounded to 2 decimals) and
    the 25th / 50th / 75th / 95th percentiles taken over all elements.
    """
    # Output key -> quantile level; percentile_50 is the median.
    quantile_levels = (
        ('percentile_25', 0.25),
        ('percentile_50', 0.50),
        ('percentile_75', 0.75),
        ('percentile_95', 0.95),
    )

    stats = {}
    for name, values in uncertainty_measures.items():
        mean_val = values.mean().item()
        std_val = values.std().item()
        summary = {
            'mean': mean_val,
            'std': std_val,
            'z_score': round(mean_val + 3 * std_val, 2),
        }
        for label, level in quantile_levels:
            summary[label] = torch.quantile(values, level).item()
        stats[name] = summary
    return stats

def compute_thresholds(stats, method='percentile'):
    """
    Computes thresholds for uncertainty measures.

    Args:
    - stats (dict): per-measure statistics; each value must provide 'mean'
      and 'std' (for 'z_score') or 'percentile_95' (for 'percentile').
    - method (str): 'z_score' -> mean + 2 * std (assumes a roughly Gaussian
      distribution); 'percentile' -> the 95th percentile (more robust).

    Returns:
    - dict mapping each measure name to its threshold.

    Raises:
    - ValueError: if `method` is neither 'z_score' nor 'percentile'.
    """
    # Validate up front so a bad method is reported even when `stats` is empty.
    if method not in ('z_score', 'percentile'):
        raise ValueError("Method must be 'z_score' or 'percentile'")

    if method == 'z_score':
        # NOTE(review): an earlier comment here said "Mean + 3*STD", but the
        # code has always used a multiplier of 2 -- confirm which is intended.
        return {key: values['mean'] + 2 * values['std'] for key, values in stats.items()}

    return {key: values['percentile_95'] for key, values in stats.items()}

def plot_uncertainty_distributions(uncertainty_measures, thresholds):
    """
    Draw one histogram per uncertainty measure with its threshold marked.

    The 'perturbation_variance' and 'feature_dropout_variance' entries are
    skipped. Every other key must also be present in `thresholds`; its value
    is drawn as a dashed vertical line. Each figure is shown immediately.
    """
    skipped = {"perturbation_variance", "feature_dropout_variance"}
    for name, values in uncertainty_measures.items():
        if name in skipped:
            continue

        plt.figure(figsize=(6, 4))
        flat_values = values.cpu().numpy().flatten()
        plt.hist(flat_values, bins=50, alpha=0.7, label=name)
        plt.axvline(
            thresholds[name],
            color='r',
            linestyle='dashed',
            linewidth=2,
            label='Threshold',
        )
        plt.xlabel("Uncertainty Value")
        plt.ylabel("Frequency")
        plt.title(f"Distribution of {name}")
        plt.legend()
        plt.show()
            
def plot_comp_uncertainty(uncertainty_measures_train, uncertainty_measures_test, thresholds):
    """
    Overlay train and test histograms of each uncertainty measure.

    Iterates over the keys of the train dict, skipping
    'perturbation_variance', 'feature_dropout_variance' and 'state_entropy'.
    Each remaining key must exist in the test dict and in `thresholds`. Both
    histograms are density-normalised so the two splits are comparable, and
    the threshold is drawn as a dashed vertical line.
    """
    skipped = {"perturbation_variance", "feature_dropout_variance", "state_entropy"}
    for name, train_tensor in uncertainty_measures_train.items():
        if name in skipped:
            continue

        plt.figure(figsize=(6, 4))

        # Flatten both splits to 1-D numpy arrays for plotting.
        train_flat = train_tensor.cpu().numpy().flatten()
        test_flat = uncertainty_measures_test[name].cpu().numpy().flatten()

        plt.hist(train_flat, bins=50, alpha=0.5, label=f"Train {name}", color='blue', density=True)
        plt.hist(test_flat, bins=50, alpha=0.5, label=f"Test {name}", color='orange', density=True)

        plt.axvline(
            thresholds[name],
            color='r',
            linestyle='dashed',
            linewidth=2,
            label='Threshold',
        )

        plt.xlabel("Uncertainty Value")
        plt.ylabel("Density")
        plt.title(f"Train vs. Test Distribution of {name}")
        plt.legend()
        plt.show()
            
def compute_comp_stats(uncertainty_measures_train, uncertainty_measures_test, thresholds, key="kl_div"):
    """
    Counts how many entries of one uncertainty measure lie at or above its
    threshold, for both the train and the test data, and prints a summary.

    Args:
    - uncertainty_measures_train (dict): measure name -> tensor (train data).
    - uncertainty_measures_test (dict): measure name -> tensor (test data).
    - thresholds (dict): measure name -> scalar threshold.
    - key (str): which measure to analyse (default 'kl_div').

    Returns:
    - (train_above_indices, test_above_indices): indices into the flattened
      value arrays where value >= threshold, or (None, None) when `key` is
      missing from any of the three dicts.
    """
    # The key must be present everywhere it is read below; otherwise report
    # it and bail out instead of raising KeyError.
    if (key not in uncertainty_measures_train
            or key not in uncertainty_measures_test
            or key not in thresholds):
        print(f"'{key}' not found in data.")
        return None, None

    threshold = thresholds[key]

    summary = {}
    for split, measures in (("train", uncertainty_measures_train),
                            ("test", uncertainty_measures_test)):
        values = measures[key].cpu().numpy().flatten()
        above_indices = np.where(values >= threshold)[0]
        total = len(values)
        above = len(above_indices)
        below = total - above
        # An empty split has no meaningful percentage; report 0% rather than
        # dividing by zero.
        below_pct = (below / total) * 100 if total else 0.0
        above_pct = (above / total) * 100 if total else 0.0
        summary[split] = (above_indices, below, above, total, below_pct, above_pct)

    train_above_indices, train_below, train_above, train_total, train_below_pct, train_above_pct = summary["train"]
    test_above_indices, test_below, test_above, test_total, test_below_pct, test_above_pct = summary["test"]

    # Print results
    print(f"{key} Uncertainty Stats (Threshold: {threshold}):")
    print(f"Train: Below ({train_below}/{train_total}, {train_below_pct:.2f}%), Above ({train_above}/{train_total}, {train_above_pct:.2f}%)")
    print(f"Test:  Below ({test_below}/{test_total}, {test_below_pct:.2f}%), Above ({test_above}/{test_total}, {test_above_pct:.2f}%)")

    return train_above_indices, test_above_indices
        
        
def evaluate_model_robustness(uncertainty_measures):
    """
    Reduce the perturbation and feature-dropout variances to scalar scores.

    Each score is the mean over all elements of the corresponding tensor.
    The resulting dict is printed and returned.
    """
    # Output name -> key of the source tensor in `uncertainty_measures`.
    score_sources = (
        ('input_noise_sensitivity', 'perturbation_variance'),
        ('dropout_sensitivity', 'feature_dropout_variance'),
    )
    robustness_scores = {
        score_name: uncertainty_measures[source_key].mean().item()
        for score_name, source_key in score_sources
    }
    print(robustness_scores)
    return robustness_scores

def change_shapes(uncertainty_measures_in):
    """
    Return a shallow copy of the measures with their trailing axes aligned.

    - 'gamma_var', 'credible_interval_width', 'perturbation_variance' and
      'feature_dropout_variance' are averaged over their last axis.
    - 'kl_div' gets its first column duplicated and prepended along dim 1,
      making it one column longer.

    The input dict is not modified. Per the original author's notes, every
    output is then expected to have shape (N, T).
    """
    reshaped = uncertainty_measures_in.copy()

    # Average over the last axis (K, per the original notes).
    reshaped['gamma_var'] = reshaped['gamma_var'].mean(-1)

    # Pad the front of kl_div by repeating its first column.
    kl = reshaped['kl_div']
    reshaped['kl_div'] = torch.cat([kl[:, :1], kl], dim=1)

    for name in ('credible_interval_width', 'perturbation_variance', 'feature_dropout_variance'):
        reshaped[name] = reshaped[name].mean(-1)

    return reshaped

def adaptive_threshold_update(thresholds, new_uncertainty_stats, alpha=0.1):
    """
    Blend existing thresholds with the 95th percentile of newer statistics.

    Each threshold becomes an exponential moving average:
    (1 - alpha) * old_threshold + alpha * new 'percentile_95'.

    Args:
    - thresholds (dict): current threshold per measure.
    - new_uncertainty_stats (dict): per-measure statistics from the latest
      data; must hold a 'percentile_95' entry for every key in `thresholds`.
    - alpha (float): update factor (0.1 = slow update, 0.5 = fast adaptation).

    Returns:
    - dict with the updated threshold for every key in `thresholds`.
    """
    return {
        name: (1 - alpha) * current + alpha * new_uncertainty_stats[name]['percentile_95']
        for name, current in thresholds.items()
    }

def compute_ece(true_labels, pred_probs, z_lens, num_bins=10, mapping=None):
    """
    Compute the Expected Calibration Error (ECE) and draw a reliability diagram.

    Parameters:
    - true_labels: per-sequence true class labels; row i is truncated to its
      first z_lens[i] entries before use.
    - pred_probs: predicted class probabilities with the class axis last. The
      max / argmax over that axis give the confidence and the predicted class.
    - z_lens: number of valid leading entries in each sequence.
    - num_bins: number of equal-width confidence bins spanning [0, 1].
    - mapping: optional dict {true_label: predicted_label} used to relabel the
      ground truth before comparing it with the predictions. When None, it is
      derived by a maximum-weight assignment on the confusion matrix.

    Returns:
    - ece: Expected Calibration Error, i.e. the bin-weighted mean of
      |confidence - accuracy| over the non-empty bins.

    Side effects: prints the ECE and the bin weights, and shows the plot.
    """
    if not isinstance(true_labels, np.ndarray):
        true_labels = np.array(true_labels)
    if not isinstance(pred_probs, np.ndarray):
        pred_probs = np.array(pred_probs)

    max_prob = np.max(pred_probs, axis=-1)
    max_indices = np.argmax(pred_probs, axis=-1)

    # Drop the padded tail of every sequence and flatten to 1-D vectors.
    true_vec = np.concatenate([seq[:z_lens[i]] for i, seq in enumerate(true_labels)])
    pred_vec = np.concatenate([seq[:z_lens[i]] for i, seq in enumerate(max_prob)])
    pred_ind = np.concatenate([seq[:z_lens[i]] for i, seq in enumerate(max_indices)])

    if mapping is None:
        mapping = {}
        if true_vec.size:
            # Size the confusion matrix to cover every true and predicted
            # class, so the assignment can be indexed by any true label.
            n_classes = max(pred_probs.shape[-1], int(true_vec.max()) + 1)
            # cm[i, j] = number of samples of true class i predicted as class j.
            cm = confusion_matrix(true_vec, pred_ind, labels=np.arange(n_classes))
            # The matrix is square, so rows come back in order and
            # col_ind[i] is the predicted class matched to true class i.
            _, col_ind = linear_assignment(cm, maximize=True)
            mapping = {
                int(label): int(col_ind[int(label)])
                for label in np.unique(true_vec)
            }

    # Relabel against the ORIGINAL labels: masking on the array being
    # rewritten would chain mappings (e.g. {0: 1, 1: 0} would send all to 0).
    z_true_mapped = np.copy(true_vec)
    for gt_z, pred_z in mapping.items():
        z_true_mapped[true_vec == gt_z] = pred_z

    bin_boundaries = np.linspace(0, 1, num_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    ece = 0.0
    vec_bin_acc, vec_bin_conf, vec_bin_w = [], [], []
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        # Bins are half-open on the left: (lower, upper].
        bin_mask = (pred_vec > bin_lower) & (pred_vec <= bin_upper)
        if np.sum(bin_mask) == 0:
            continue
        bin_accuracy = np.mean(z_true_mapped[bin_mask] == pred_ind[bin_mask])
        bin_confidence = np.mean(pred_vec[bin_mask])
        bin_weight = np.sum(bin_mask) / len(pred_vec)
        vec_bin_acc.append(bin_accuracy)
        vec_bin_conf.append(bin_confidence)
        vec_bin_w.append(bin_weight)
        ece += bin_weight * abs(bin_confidence - bin_accuracy)

    print(f'ECE: {ece}')
    print(vec_bin_w)
    plt.figure(figsize=(6, 6))
    plt.plot([0, 1], [0, 1], "k--", label="Perfectly Calibrated")  # Ideal line
    # One bar per non-empty bin, positioned at the bin's mean confidence.
    plt.bar(vec_bin_conf, vec_bin_acc, width=0.1, color="royalblue", alpha=0.8, edgecolor="black", label="Observed Accuracy")
    plt.xlabel("Confidence")
    plt.ylabel("Accuracy")
    plt.title(f"Reliability Diagram (ECE: {ece:.4f})")
    plt.legend()
    plt.grid(True)
    plt.show()

    return ece
    

