import torch
from scipy.stats import norm
import matplotlib.pyplot as plt
import numpy as np
import os
from multiguide.helpers import PROJECT_ROOT
import wandb
def compute_picp(pred_value, pred_log_var, y, confidence=0.95):
    '''
        Flag, element-wise, whether each target lies inside its Gaussian
        prediction interval.

        The interval is pred_value +/- z * sigma, with
        sigma = exp(0.5 * pred_log_var) and z the two-sided standard-normal
        quantile for `confidence` (about 1.96 at 0.95). Both interval ends
        are inclusive.

        Note that the boolean mask itself is returned, not a proportion;
        averaging the mask yields the coverage fraction.
    '''
    # The network predicts log-variance, so halve it before exponentiating
    # to obtain the standard deviation.
    sigma = torch.exp(pred_log_var * 0.5)
    half_width = norm.ppf((1 + confidence) / 2) * sigma

    above_lower = y >= pred_value - half_width
    below_upper = y <= pred_value + half_width
    return above_lower & below_upper

def compute_miw(pred_log_var, confidence=0.95):
    '''
        Width of the two-sided Gaussian prediction interval, per element.

        The width is 2 * z * sigma, with sigma = exp(0.5 * pred_log_var) and
        z the two-sided standard-normal quantile for `confidence`.

        The per-element widths are returned as-is; no averaging is done here.
    '''
    # Log-variance -> standard deviation.
    sigma = torch.exp(pred_log_var * 0.5)
    return 2 * norm.ppf((1 + confidence) / 2) * sigma

def uncertainty_calibration_plot(config, predictions, variances, targets, errors, plot_name, wandb_panel, run=None, num_bins=10):
    '''
        Plot binned predicted variance against binned observed error and
        return the mean absolute gap between the two.

        Samples are ordered by predicted variance and split into `num_bins`
        contiguous groups. Every sample is assigned to a group (any remainder
        is spread over the first groups) and empty groups are skipped, so no
        NaN bin means are produced when there are fewer samples than bins.

        Args:
            config: object exposing `general.experiment_name`; used to build
                the output directory under PROJECT_ROOT/experiments.
            predictions: accepted for interface compatibility; not used.
            variances: sequence of tensors, concatenated along dim 0.
            targets: accepted for interface compatibility; not used.
            errors: sequence of tensors, concatenated along dim 0; plotted on
                the axis labelled 'Observed MSE'.
            plot_name: file stem of the saved PNG and suffix of the logged key.
            wandb_panel: prefix of the logged key.
            run: optional object with a `log` method; when not None the saved
                image is logged under '<wandb_panel>/<plot_name>'.
            num_bins: number of groups to split the sorted samples into.

        Returns:
            Mean absolute difference between per-bin mean error and per-bin
            mean variance, or NaN when there are no samples.
    '''
    variances = torch.cat(variances, dim=0).cpu().numpy().flatten()
    errors = torch.cat(errors, dim=0).cpu().numpy().flatten()

    # Sort by predicted variance, then bin the sorted indices.
    # array_split keeps every sample, unlike a fixed len // num_bins stride
    # which silently drops the trailing (highest-variance) remainder.
    order = np.argsort(variances)
    index_bins = [idx for idx in np.array_split(order, num_bins) if idx.size > 0]
    mean_errors = np.array([errors[idx].mean() for idx in index_bins])
    mean_variances = np.array([variances[idx].mean() for idx in index_bins])

    output_path = os.path.join(PROJECT_ROOT,
                               'experiments',
                               config.general.experiment_name,
                               'figures',
                               f'{plot_name}.png')
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.scatter(mean_variances, mean_errors, marker='o')
        if mean_variances.size > 0:
            # Ideal calibration: observed error equals predicted variance.
            min_val = mean_variances.min()
            max_val = mean_variances.max()
            ax.plot([min_val, max_val], [min_val, max_val], 'r--', label='Ideal Calibration')
        ax.set_xlabel('Predicted Variance')
        ax.set_ylabel('Observed MSE')
        ax.set_title('Uncertainty Calibration Plot')
        ax.grid(True)
        ax.legend()
        fig.savefig(output_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    if run is not None:
        image = wandb.Image(output_path)
        run.log({f'{wandb_panel}/{plot_name}': image})

    if mean_variances.size == 0:
        return float('nan')

    # Calibration error: mean distance of the binned points from the ideal line.
    calibration_error = np.mean(np.abs(mean_errors - mean_variances))

    return calibration_error

def analyze_uncertainty_by_length(config, errors, variances, lengths, plot_name, wandb_panel, run=None):
    '''
        Scatter predicted variance and error against sequence length, save
        the figure, and return both correlations with length.

        Args:
            config: object exposing `general.experiment_name`; used to build
                the output directory under PROJECT_ROOT/experiments.
            errors: sequence of tensors, concatenated along dim 0; plotted on
                the axis labelled 'Squared Error'.
            variances: sequence of tensors, concatenated along dim 0.
            lengths: sequence of tensors, concatenated along dim 0.
            plot_name: file stem of the saved PNG and suffix of the logged key.
            wandb_panel: prefix of the logged key.
            run: optional object with a `log` method; when not None the saved
                image is logged under '<wandb_panel>/<plot_name>'.

        Returns:
            Tuple (uncertainty_length_corr, error_length_corr) taken from
            np.corrcoef of lengths against variances and against errors.
    '''
    # Convert to flat numpy arrays
    lengths = torch.cat(lengths, dim=0).cpu().numpy().flatten()
    uncertainties = torch.cat(variances, dim=0).cpu().numpy().flatten()
    errors = torch.cat(errors, dim=0).cpu().numpy().flatten()

    output_path = os.path.join(PROJECT_ROOT,
                               'experiments',
                               config.general.experiment_name,
                               'figures',
                               f'{plot_name}.png')
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    fig, (ax_uncertainty, ax_error) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        ax_uncertainty.scatter(lengths, uncertainties, alpha=0.5)
        ax_uncertainty.set_xlabel('Sequence Length')
        ax_uncertainty.set_ylabel('Predicted Variance')
        ax_uncertainty.set_title('Uncertainty vs. Sequence Length')

        ax_error.scatter(lengths, errors, alpha=0.5)
        ax_error.set_xlabel('Sequence Length')
        ax_error.set_ylabel('Squared Error')
        ax_error.set_title('Error vs. Sequence Length')

        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    if run is not None:
        image = wandb.Image(output_path)
        run.log({f'{wandb_panel}/{plot_name}': image})

    # Correlation of each quantity with sequence length
    uncertainty_length_corr = np.corrcoef(lengths, uncertainties)[0, 1]
    error_length_corr = np.corrcoef(lengths, errors)[0, 1]

    return uncertainty_length_corr, error_length_corr

def ece_in_bin(pred_value, variance, y, num_bins=10, threshold=0.1):
    '''
        Summarise the first non-empty confidence bin for the given predictions.

        Confidence is derived from the variance as
        1 / (1 + variance / threshold), and the range [0, 1] is cut into
        `num_bins` equal-width bins. Bins are scanned from lowest to highest
        confidence and the function returns as soon as one contains a sample.
        All bins are half-open [lower, upper) except the last, which also
        includes its upper edge so that a confidence of exactly 1.0
        (zero variance) is not lost.

        Args:
            pred_value: predicted values.
            variance: predicted variances, same layout as `pred_value`.
            y: target values.
            num_bins: number of equal-width confidence bins.
            threshold: variance scale in the confidence mapping; a variance
                equal to `threshold` maps to a confidence of 0.5.

        Returns:
            Tuple (mean squared error in the bin, bin centre, number of
            samples in the bin) for the first non-empty bin, or None when no
            sample falls into any bin (e.g. NaN variance).
    '''
    confidence_bins = np.linspace(0, 1, num_bins + 1)
    confidence = 1 / (1 + variance / threshold)

    # Calculate error
    error = (pred_value - y) ** 2

    last_bin = num_bins - 1
    for i in range(num_bins):
        lower = confidence_bins[i]
        upper = confidence_bins[i + 1]

        # Mask for samples in this bin; the last bin is closed on the right.
        if i == last_bin:
            in_bin = (confidence >= lower) & (confidence <= upper)
        else:
            in_bin = (confidence >= lower) & (confidence < upper)

        if in_bin.sum() > 0:
            return error[in_bin].mean().item(), (lower + upper) / 2, in_bin.sum().item()

    return None

def expected_calibration_error(error, variance, num_bins=10, threshold=0.1):
    '''
        Collect per-bin calibration inputs for the given error and variance.

        Confidence is derived from the variance as
        1 / (1 + variance / threshold), and the range [0, 1] is cut into
        `num_bins` equal-width bins. All bins are half-open [lower, upper)
        except the last, which also includes its upper edge so that a
        confidence of exactly 1.0 (zero variance) is not dropped.

        For every bin that contains at least one sample, three entries are
        appended: the `error` argument as passed in, the bin centre as a 0-d
        numpy array, and the boolean membership mask for that bin.

        NOTE(review): `error` is appended unmasked and the mask stands in for
        a count. That matches use with a single prediction per call; confirm
        against callers before passing multi-sample batches.

        Args:
            error: error value(s) for the prediction(s).
            variance: predicted variance(s).
            num_bins: number of equal-width confidence bins.
            threshold: variance scale in the confidence mapping; a variance
                equal to `threshold` maps to a confidence of 0.5.

        Returns:
            Tuple (accuracies, confidences, bin_counts) of equal-length lists,
            one entry per non-empty bin.
    '''
    confidence_bins = np.linspace(0, 1, num_bins + 1)
    accuracies = []
    confidences = []
    bin_counts = []

    confidence = 1 / (1 + variance / threshold)

    last_bin = num_bins - 1
    for i in range(num_bins):
        lower = confidence_bins[i]
        upper = confidence_bins[i + 1]
        # Mask for samples in this bin; the last bin is closed on the right.
        if i == last_bin:
            in_bin = (confidence >= lower) & (confidence <= upper)
        else:
            in_bin = (confidence >= lower) & (confidence < upper)
        if in_bin.sum() > 0:
            accuracies.append(error)
            confidences.append(np.array((lower + upper) / 2))
            bin_counts.append(in_bin)

    # return values for a single pred_value
    return accuracies, confidences, bin_counts

def calculate_ece(confidences, accuracies, bin_counts):
    '''
        Weighted calibration gap across bins.

        Each bin contributes |confidence - (1 - accuracy)|, weighted by its
        share of the total count. The three sequences are indexed in
        parallel, one entry per bin.
    '''
    counts = np.array(bin_counts)
    n_total = counts.sum()

    ece = 0
    for idx, n_in_bin in enumerate(counts):
        weight = n_in_bin / n_total
        gap = abs(confidences[idx] - (1 - accuracies[idx]))
        ece += weight * gap

    return ece