import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
import copy
from sklearn.metrics import r2_score
# Seed NumPy's global RNG so the random draws made in this file (tie-breaking
# noise added to probabilities, Monte Carlo sampling) repeat between runs.
np.random.seed(42)

# Constants
# mu, alpha and beta are substituted into the names of the .npy files read by
# main() and are printed in the plot titles as the Hawkes process parameters.
mu = 0.001
alpha = 0.004
beta = 0.5
# Feature switches read by main() and save_joint_ovr():
# partial_hawkes           -> load and evaluate the "*_kalman.npy" partial-probability files
# model_auc                -> load and evaluate the Set Sequence model predictions
# partial_hawkes_no_kalman -> additionally build the "no Kalman" comparison curve
partial_hawkes = True
model_auc = True
partial_hawkes_no_kalman = True

# Reshape data for easier computation
def reshape_for_auc(data):
    """Flatten all leading axes of ``data`` into one and keep the last axis.

    The returned array is 2-D with shape ``(-1, data.shape[-1])``, i.e. one
    row per entry of the leading axes and one column per entry of the last
    axis.
    """
    last_axis_size = data.shape[-1]
    return data.reshape((-1, last_axis_size))


def mc_compute_ovr_auc(y_true, y_pred, n_classes, n_samples):
    """
    Estimate the one-vs-rest AUC of every class by Monte Carlo sampling.

    For each class ``c``, ``n_samples`` rows whose true class is ``c`` and
    ``n_samples`` rows whose true class is not ``c`` are drawn with
    replacement (using NumPy's global RNG) and paired up. The estimate is the
    fraction of pairs in which the class-``c`` probability of the in-class
    row is strictly greater than that of the out-of-class row.

    Parameters
    ----------
    y_true : np.ndarray, shape (n_rows, n_classes)
        One-hot true labels; converted to class indices with ``argmax``.
    y_pred : np.ndarray, shape (n_rows, n_classes)
        Predicted score/probability per class.
    n_classes : int
        Number of classes to iterate over.
    n_samples : int
        Number of pairs drawn per class.

    Returns
    -------
    list of float
        One AUC estimate per class. A class that has no rows, or no rows
        outside of it, cannot be sampled and gets ``np.nan`` instead of
        raising from ``np.random.choice``.
    """
    # Convert y_true from one-hot to scalar class indices.
    class_ids = np.argmax(y_true, axis=1)

    auc_values = []
    for cls in range(n_classes):
        in_class = np.flatnonzero(class_ids == cls)
        out_of_class = np.flatnonzero(class_ids != cls)

        # np.random.choice raises on an empty population, so an absent class
        # (or a class that covers every row) is reported as undefined.
        if in_class.size == 0 or out_of_class.size == 0:
            auc_values.append(np.nan)
            continue

        # Draw order (in-class first, then out-of-class) is kept fixed so the
        # random stream is consumed the same way on every run.
        in_draw = np.random.choice(in_class, n_samples, replace=True)
        out_draw = np.random.choice(out_of_class, n_samples, replace=True)

        # Vectorised pairwise comparison of the class-`cls` probabilities.
        wins = np.count_nonzero(y_pred[in_draw, cls] > y_pred[out_draw, cls])
        auc_values.append(wins / n_samples)

    return auc_values

def compute_binary_auc(labels: np.ndarray, scores: np.ndarray) -> float:
    """
    Compute the AUC for a binary classification problem using only NumPy.

    The AUC is obtained from the Mann-Whitney U statistic: scores are ranked
    in ascending order, tied scores share their average rank, and the AUC is
    the normalised rank sum of the positives. A positive/negative pair with
    identical scores therefore contributes 0.5, so the result does not depend
    on the order in which tied samples appear in the input.

    Parameters
    ----------
    labels : np.ndarray, shape (n_samples,)
        Binary ground truth labels: 1 for positive, anything else is treated
        as negative.
    scores : np.ndarray, shape (n_samples,)
        Predicted scores or probabilities for the positive class.

    Returns
    -------
    float
        The AUC (Area Under the ROC Curve) value between 0.0 and 1.0.
        Returns 0.5 when there are no positives or no negatives.
    """
    # Ensure inputs are 1-D arrays.
    labels = np.asarray(labels).ravel()
    scores = np.asarray(scores).ravel()

    # Count total positives/negatives.
    is_positive = labels == 1
    pos_count = int(np.count_nonzero(is_positive))
    neg_count = labels.size - pos_count

    # Edge case: if there's no positive or no negative, define AUC = 0.5.
    if pos_count == 0 or neg_count == 0:
        return 0.5

    # Rank the scores (1-based, ascending). Each group of equal scores gets
    # the mean of the ranks it occupies: start + (size + 1) / 2.
    order = np.argsort(scores, kind="mergesort")
    _, group_start, group_size = np.unique(
        scores[order], return_index=True, return_counts=True
    )
    mean_rank = group_start + (group_size + 1) / 2.0
    ranks = np.empty(scores.size, dtype=float)
    ranks[order] = np.repeat(mean_rank, group_size)

    # U = (rank sum of positives) - (smallest possible rank sum); dividing by
    # the number of positive/negative pairs normalises into [0, 1].
    u_statistic = ranks[is_positive].sum() - pos_count * (pos_count + 1) / 2.0
    return float(u_statistic / (pos_count * neg_count))

def compute_multiclass_auc_ovo_macro(true_labels: np.ndarray,
                                     predicted_probs: np.ndarray) -> float:
    """
    Macro-averaged one-vs-one AUC over all class pairs ``(i, j)`` with ``i < j``.

    For every pair, only the samples whose true class is ``i`` or ``j`` are
    kept; class ``i`` is treated as the positive class and the predicted
    probability of class ``i`` is used as the score. The binary AUC of each
    sub-problem comes from ``compute_binary_auc`` and the results are averaged.

    Parameters
    ----------
    true_labels : np.ndarray, shape (n_samples, n_classes)
        One-hot-encoded true labels, e.g. ``[1, 0, 0]`` for class 0.
    predicted_probs : np.ndarray, shape (n_samples, n_classes)
        Predicted probability of each class, one row per sample.

    Returns
    -------
    float
        Mean of the pairwise AUCs, or 0.5 when there is no class pair.

    Raises
    ------
    ValueError
        If ``predicted_probs`` does not have the same shape as ``true_labels``.
    """
    n_samples, n_classes = true_labels.shape

    # Both inputs must describe the same samples and classes.
    if predicted_probs.shape != (n_samples, n_classes):
        raise ValueError("Shapes of true_labels and predicted_probs must match.")

    pairwise_aucs = []
    for positive_cls in range(n_classes):
        is_positive = true_labels[:, positive_cls] == 1
        for negative_cls in range(positive_cls + 1, n_classes):
            # Keep only the samples that belong to one of the two classes.
            in_pair = is_positive | (true_labels[:, negative_cls] == 1)

            # Binary target: 1 for the positive class, 0 for the other one.
            pair_targets = is_positive[in_pair].astype(int)
            pair_scores = predicted_probs[in_pair, positive_cls]

            pairwise_aucs.append(compute_binary_auc(pair_targets, pair_scores))

    # Without any pair (fewer than two classes) fall back to chance level.
    if not pairwise_aucs:
        return 0.5

    return np.mean(pairwise_aucs)

# Compute AUC for each start state
def compute_auc_by_start_state(y_true_scalar, y_probs, y_true_reshaped):
    """Compute per-class OvR AUCs grouped by the state of the preceding row.

    For each start state ``s`` the rows whose label is ``s`` are located, the
    row that directly follows each of them is taken as the "next step", and
    ``roc_auc_score(..., average=None)`` is evaluated on the labels and
    probabilities of those next-step rows. The rows used for all successful
    start states are then pooled for one joint OvR AUC.

    A tiny Gaussian noise (std 1e-6, NumPy global RNG) is added to the
    probabilities before every AUC call, as in the rest of this file.

    NOTE(review): "next row" is simply index + 1 in the flattened arrays;
    this presumably relies on consecutive rows being consecutive time steps
    of the same unit - confirm against how the arrays are reshaped.

    Parameters
    ----------
    y_true_scalar : np.ndarray, shape (n_rows,)
        Class index of every row.
    y_probs : np.ndarray, shape (n_rows, n_classes)
        Predicted probability of every class for every row.
    y_true_reshaped : np.ndarray, shape (n_rows, n_classes)
        Label matrix passed to ``roc_auc_score`` (one-hot in ``main``).

    Returns
    -------
    tuple
        ``(auc_matrix, auc_joint_ovr)``. ``auc_matrix`` has one row per start
        state and one column per class; rows that could not be evaluated are
        NaN. ``auc_joint_ovr`` holds one AUC per class for the pooled rows and
        is all-NaN when no start state could be evaluated.
    """
    num_classes = y_probs.shape[-1]
    # Start state 2 is never evaluated (the plots label it "Default"). The
    # check is explicit instead of an `assert` so it still applies under -O.
    skipped_start_state = 2

    auc_scores = []
    joint_ovr_true_labels = []
    joint_ovr_pred_probs = []
    for start_state in range(num_classes):
        # Rows that directly follow a row in `start_state`.
        start_state_indices = np.where(y_true_scalar == start_state)[0]
        end_state_indices = start_state_indices + 1  # Shift to the next step
        valid_indices = end_state_indices[end_state_indices < len(y_true_scalar)]

        if len(valid_indices) == 0 or start_state == skipped_start_state:
            auc_scores.append([np.nan] * num_classes)
            continue

        # True labels (next step) and probabilities for these rows.
        true_labels_next_step = y_true_reshaped[valid_indices, :]
        predicted_probs_next_step = y_probs[valid_indices, :]
        noisy_probs = predicted_probs_next_step + np.random.normal(
            0, 1e-6, predicted_probs_next_step.shape
        )

        try:
            auc = roc_auc_score(
                true_labels_next_step,
                noisy_probs,
                multi_class='ovr',
                average=None
            )
        except ValueError:
            # roc_auc_score cannot score these rows (e.g. a class is missing
            # among them); record the row as undefined and keep it out of the
            # pooled problem.
            auc_scores.append([np.nan] * num_classes)
            continue

        auc_scores.append(auc)
        joint_ovr_pred_probs.append(predicted_probs_next_step)
        joint_ovr_true_labels.append(true_labels_next_step)

    if joint_ovr_pred_probs:
        joint_ovr_pred_probs = np.concatenate(joint_ovr_pred_probs, axis=0)
        joint_ovr_true_labels = np.concatenate(joint_ovr_true_labels, axis=0)
        # AUC for the joint OvR problem over all pooled rows.
        auc_joint_ovr = roc_auc_score(
            joint_ovr_true_labels,
            joint_ovr_pred_probs + np.random.normal(0, 1e-6, joint_ovr_pred_probs.shape),
            multi_class='ovr',
            average=None
        )
    else:
        # Nothing could be evaluated: np.concatenate would raise on an empty
        # list, so report an undefined joint AUC instead.
        auc_joint_ovr = np.full(num_classes, np.nan)
    print(f"Joint OvR AUC: {auc_joint_ovr}")

    return np.array(auc_scores), auc_joint_ovr

# Save AUC matrix with values
def save_auc_matrix_with_values(auc_matrix, method_name, output_path):
    """Draw ``auc_matrix`` as an annotated heat map and save it to ``output_path``.

    The title contains ``method_name`` and the module-level ``mu``, ``alpha``
    and ``beta``. The x axis carries three fixed labels and the y axis two;
    the last row of the matrix is drawn but gets neither a tick label nor
    value annotations. NaN cells are left without text.
    """
    plt.figure(figsize=(10, 8))
    plt.imshow(auc_matrix, interpolation='nearest', cmap='viridis')
    plt.colorbar(label='AUC')

    heading = f"AUC Matrix - {method_name}, Contagion Hawkes Process \n Parameters: mu = {mu}, alpha = {alpha}, beta = {beta}"
    plt.title(heading)
    plt.xlabel("End State")
    plt.ylabel("Start State")

    n_cols = auc_matrix.shape[1]
    labelled_rows = range(auc_matrix.shape[0] - 1)
    plt.xticks(range(n_cols), labels=["Class 0", "Class 1", "Default"])
    plt.yticks(labelled_rows, labels=["Class 0", "Class 1"])

    # Write the numeric value into every non-NaN cell of the labelled rows.
    for row in labelled_rows:
        for col in range(n_cols):
            cell = auc_matrix[row, col]
            if np.isnan(cell):
                continue
            # Light text on the low-AUC cells, dark text otherwise.
            text_color = 'black'
            if cell < 0.5:
                text_color = 'white'
            plt.text(col, row, f"{cell:.3f}", ha='center', va='center', color=text_color)

    plt.savefig(output_path)
    plt.close()

def save_joint_ovr(auc_matrix, output_path, N):
    """Save a heat map of the joint one-vs-rest AUCs, one row per method.

    ``auc_matrix`` is converted to an array (rows = methods, columns = the
    three OvR problems). The row labels depend on the module-level flags
    ``model_auc`` and ``partial_hawkes``; ``N`` is only used inside the label
    of the partial-probabilities row. Every non-NaN cell is annotated with
    its value and the figure is written to ``output_path``.
    """
    auc_matrix = np.array(auc_matrix)
    n_rows, n_cols = auc_matrix.shape[0], auc_matrix.shape[1]

    plt.figure(figsize=(10, 8))
    plt.imshow(auc_matrix, interpolation='nearest', cmap='viridis')
    plt.colorbar(label='AUC')
    plt.title(f"AUC One Vs Rest with Contagion Hawkes Process \n Parameters  mu = {mu}, alpha = {alpha}, beta = {beta}")

    column_labels = [
        "Class 0 vs (Class 1 and Defualt)",
        "Class 1 vs (Class 0 and Default)",
        "Default vs (Class 0 and Class 1)",
    ]
    plt.xticks(range(n_cols), labels=column_labels, rotation=45)

    # The two baseline rows are always present; optional rows follow in the
    # order "model" then "partial probabilities".
    row_labels = ["Non Contagion Probabilities", "True Transition Probabilities"]
    if model_auc:
        row_labels.append("Set Sequence Model")
    if partial_hawkes:
        row_labels.append("Partial True Transition Probabilities N ={}".format(N))
    plt.yticks(range(n_rows), labels=row_labels)

    # Annotate matrix with numerical values
    for row in range(n_rows):
        for col in range(n_cols):
            cell = auc_matrix[row, col]
            if np.isnan(cell):
                continue
            text_color = 'white' if cell < 0.5 else 'black'
            plt.text(col, row, f"{cell:.3f}", ha='center', va='center', color=text_color)

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()



import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_auc_score

import matplotlib.pyplot as plt

def plot_kalman_model_comparison(y_true, ground_truth_probs, data, kalman_indices, save_path,
                                 model_indices=None):
    """
    Generate and save a figure with three horizontal subplots (left to right):
      1. AUC One vs Rest for Rare Event Class.
      2. Correlation between ground truth and predictions.
      3. R² (Coefficient of Determination) for Rare Event Probabilities.

    The metrics of every entry in ``data`` come from ``get_metrics``. The
    figure is always saved as a PDF file.

    Side effect: ``plt.rcParams`` (font size, line width) is updated globally
    and stays changed for figures created afterwards.

    Parameters:
      y_true: array-like
          Ground truth labels.
      ground_truth_probs: array-like
          Ground truth probabilities.
      data: dict
          Dictionary of model results. Keys (e.g. 'Kalman Filter') map to prediction data.
      kalman_indices: array-like
          Values for the x-axis (# Units for Inference / Non-hidden units).
      save_path: str
          File path (with or without .pdf extension) to save the resulting high-resolution figure.
      model_indices: array-like, optional
          X-axis values for every series other than 'Kalman Filter'. When
          omitted, ``kalman_indices`` is used for all series. The caller in
          this file passes this keyword, which previously raised TypeError.
    """
    # Ensure the figure is saved as a PDF file.
    if not save_path.lower().endswith('.pdf'):
        save_path += '.pdf'

    # Update global font settings and line width.
    plt.rcParams.update({'font.size': 20, 'lines.linewidth': 2.3})
    xlabel = "# Units for Inference / Non-hidden units"

    # Create a figure with 3 subplots in one row that share the x-axis.
    fig, axes = plt.subplots(1, 3, figsize=(24, 7), sharex=True)

    # Iterate over the data for each model.
    for key, value in data.items():
        metric_dict = get_metrics(y_true, ground_truth_probs, value)
        # Extract metrics.
        aucs  = metric_dict["auc"]
        corrs = metric_dict["corr"]
        r2s   = metric_dict["r2"]

        is_kalman = key == 'Kalman Filter'
        # Use a square marker for the Kalman Filter and circles for other models.
        marker = 's' if is_kalman else 'o'
        # The Kalman series always uses kalman_indices; other series use
        # model_indices when the caller supplied them.
        if is_kalman or model_indices is None:
            x_values = kalman_indices
        else:
            x_values = model_indices

        # Left subplot: AUC.
        axes[0].plot(x_values, aucs, label=key, marker=marker, markersize=8)
        # Middle subplot: Correlation.
        axes[1].plot(x_values, corrs, label=key, marker=marker, markersize=8)
        # Right subplot: R².
        axes[2].plot(x_values, r2s, label=key, marker=marker, markersize=8)

    # Customize the left subplot: AUC.
    axes[0].set_title("AUC One vs Rest for Rare Event Class", fontsize=24)
    axes[0].set_ylabel("AUC", fontsize=22)
    axes[0].grid(True, linestyle='-')
    axes[0].tick_params(axis='both', labelsize=18)

    # Customize the middle subplot: Correlation.
    axes[1].set_title("Correlation With Ground Truth Probabilities", fontsize=24)
    axes[1].set_ylabel("Correlation", fontsize=22)
    axes[1].grid(True, linestyle='-')
    axes[1].tick_params(axis='both', labelsize=18)

    # Customize the right subplot: R².
    axes[2].set_title("R² of Rare Event Probabilities", fontsize=24)
    axes[2].set_ylabel("R² Score", fontsize=22)
    axes[2].grid(True, linestyle='-')
    axes[2].tick_params(axis='both', labelsize=18)

    # Set a common x-axis label for all subplots.
    for ax in axes:
        ax.set_xlabel(xlabel, fontsize=22)

    # Create a common legend below all subplots (every subplot holds the same
    # series, so the handles of the first one are enough).
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', ncol=len(labels), fontsize=20,
               bbox_to_anchor=(0.5, -0.02))

    # Adjust layout to ensure the legend and subplots don't overlap.
    fig.tight_layout(rect=[0, 0.05, 1, 1])

    # Save the figure as a PDF with high resolution.
    fig.savefig(save_path, format='pdf', dpi=300, bbox_inches='tight')
    plt.close(fig)


def plot_kalman_model_comparison_old(y_true, ground_truth_probs,data, kalman_indices, save_path):
    """
    Save a 2x2 grid of line plots comparing every entry of ``data``.

    Panels (in subplot order): correlation with the ground-truth
    probabilities, normalized mean absolute error, one-vs-rest AUC of the
    rare-event class, and R². All curves are drawn against
    ``kalman_indices``; the metrics come from ``get_metrics``. The font size
    in ``plt.rcParams`` is changed globally as a side effect.
    """
    plt.figure(figsize=(16, 12))
    plt.rcParams.update({'font.size': 16})
    xlabel = "# Units for Inference / Non-hidden units"

    # Subplot position -> key of the metric drawn in that panel.
    panel_metrics = ((1, "corr"), (2, "rel_abs_err"), (3, "auc"), (4, "r2"))

    # One curve per data entry in each of the four panels.
    for key, value in data.items():
        metric_dict = get_metrics(y_true, ground_truth_probs, value)
        marker = 's' if key == 'Kalman Filter' else 'o'
        for position, metric_name in panel_metrics:
            plt.subplot(2, 2, position)
            plt.plot(kalman_indices, metric_dict[metric_name], label=key, marker=marker)

    # Panels 1-3 share the same decoration recipe: (position, y label, title).
    decorated_panels = (
        (1, 'Correlation', 'Correlation With Ground Truth Probabilities'),
        (2, 'Normalized Mean Absolute Error', 'Normalized Mean Absolute Error'),
        (3, 'AUC Rare Event class vs Rest', 'AUC One vs Rest for Rare Event Class'),
    )
    for position, y_label, heading in decorated_panels:
        plt.subplot(2, 2, position)
        plt.xlabel(xlabel)
        plt.ylabel(y_label)
        plt.title(heading)
        plt.legend(loc='best', fontsize=14)
        plt.grid(True, linestyle='--', alpha=0.5)

    # Panel 4: R² (Coefficient of Determination); its legend keeps the
    # default font size.
    plt.subplot(2, 2, 4)
    plt.xlabel(xlabel)
    plt.ylabel(r'$R^2$ Score')
    plt.title(r'$R^2$ of Rare Event Probabilities')
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.legend(loc='best')

    # Tight layout, then save the figure in high resolution.
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.close()


def get_metrics(y_true,y_true_probs, y_pred_probs):
    """Score a sequence of probability arrays against labels and true probabilities.

    ``y_pred_probs`` is a sequence whose elements are 2-D probability arrays;
    ``y_true`` holds the true labels and ``y_true_probs`` the ground-truth
    probabilities. For each element, column 2 is compared with column 2 of
    ``y_true_probs`` (R², Pearson correlation, mean absolute error divided by
    the mean true probability) and the OvR AUC of class 2 is computed from
    ``y_true`` after adding a small Gaussian noise to the predictions.

    Returns a dict with the keys "r2", "corr", "rel_abs_err" and "auc", each
    mapping to a list with one value per element of ``y_pred_probs``.
    """
    idx = 2
    results = {"r2": [], "corr": [], "rel_abs_err": [], "auc": []}

    for pred in y_pred_probs:
        truth_col = y_true_probs[:, idx]
        pred_col = pred[:, idx]

        r2 = r2_score(truth_col, pred_col)
        corr = np.corrcoef(truth_col, pred_col)[0, 1]
        rel_abs_err = np.mean(np.abs(truth_col - pred_col)) / np.mean(truth_col)

        # Per-class OvR AUCs; only the entry of class `idx` is kept.
        jittered = pred + np.random.normal(0, 1e-6, pred.shape)
        per_class_auc = roc_auc_score(
            y_true,
            jittered,
            multi_class='ovr',
            average=None
        )
        auc = per_class_auc[idx]

        results["r2"].append(r2)
        results["corr"].append(corr)
        results["rel_abs_err"].append(rel_abs_err)
        results["auc"].append(auc)

    return results

def get_model_predictions(name,params, nr_units_for_inference, pretext):
    """Load the saved model predictions for every entry of ``nr_units_for_inference``.

    The file names are built from ``pretext``, the "mu"/"alpha"/"beta"
    entries of ``params``, the unit count and ``name``. The last entry along
    axis 2 of each loaded array is dropped and the result is flattened with
    ``reshape_for_auc``. Returns the arrays as a list, in input order.
    """
    mu, alpha, beta = (params[key] for key in ("mu", "alpha", "beta"))
    file_template = "data/all_set_seq_model_preds_mu_{}_alpha_{}_beta_{}_n_{}_{}.npy"

    predictions = []
    for n in nr_units_for_inference:
        loaded = np.load(pretext + file_template.format(mu, alpha, beta, n, name))
        trimmed = loaded[:, :, :-1, :]
        predictions.append(reshape_for_auc(trimmed))
    return predictions

def get_y_true(params, pretext):
    """Load the label array for the given parameters and flatten it.

    The file name is built from ``pretext`` and the "mu"/"alpha"/"beta"
    entries of ``params``. The first entry along axis 2 is dropped before the
    array is passed through ``reshape_for_auc``.
    """
    mu, alpha, beta = (params[key] for key in ("mu", "alpha", "beta"))
    label_file = pretext + "data/y_label_mu_{}_alpha_{}_beta_{}.npy".format(mu, alpha, beta)
    labels = np.load(label_file)  # True transitions
    return reshape_for_auc(labels[:, :, 1:, :])

def get_true_probs(params, pretext):
    """Load the ground-truth probability array for the given parameters and flatten it.

    The file name is built from ``pretext`` and the "mu"/"alpha"/"beta"
    entries of ``params``. The first entry along axis 2 is dropped before the
    array is passed through ``reshape_for_auc``.
    """
    mu, alpha, beta = (params[key] for key in ("mu", "alpha", "beta"))
    probs_file = pretext + "data/y_true_probs_mu_{}_alpha_{}_beta_{}.npy".format(mu, alpha, beta)
    true_probs = np.load(probs_file)  # Ground Truth probabilities
    return reshape_for_auc(true_probs[:, :, 1:, :])

def get_kalman_predictions(params, nr_obs_partial_hawkes, pretext, kalman=True):
    """Load the partial true-probability arrays for every entry of ``nr_obs_partial_hawkes``.

    With ``kalman`` truthy the files ending in ``_kalman.npy`` are read,
    otherwise the files without that suffix. The first entry along axis 2 of
    each array is dropped and the result is flattened with
    ``reshape_for_auc``. Returns the arrays as a list, in input order.
    """
    mu, alpha, beta = (params[key] for key in ("mu", "alpha", "beta"))

    predictions = []
    for N in nr_obs_partial_hawkes:
        if kalman:
            file_name = "data/all_y_partial_true_probs_N_{}_mu_{}_alpha_{}_beta_{}_kalman.npy".format(N, mu, alpha, beta)
        else:
            file_name = "data/all_y_partial_true_probs_N_{}_mu_{}_alpha_{}_beta_{}.npy".format(N, mu, alpha, beta)
        loaded = np.load(pretext + file_name)
        predictions.append(reshape_for_auc(loaded[:, :, 1:, :]))
    return predictions


def main():
    """Load the simulation outputs from ./data, compute all AUC comparisons and save the figures.

    Steps:
      1. Load the labels, the ground-truth probabilities and, depending on the
         module-level flags, the Set Sequence model predictions and the
         partial (Kalman / no Kalman) probabilities.
      2. Flatten everything with ``reshape_for_auc``.
      3. Save the metric comparison figure, the AUC-vs-units curves, the
         per-start-state AUC matrices and the joint OvR matrix.

    All file names are derived from the module-level ``mu``, ``alpha`` and
    ``beta``. Nothing is returned.
    """
    # Reference values noted from an earlier run:
    #h = [150, 100, 50, 20, 10, 5, 2, 1]
    #AUC(3)_True = [0.74, 0.71, 0.64, 0.58, 0.56, 0.55, 0.54, 0.53]
    #AUC(3)_No_Contagion = [0.52,0.52, 0.52,0.52, 0.52, 0.52, 0.53]

    y_true = np.load("./data/y_label_mu_{}_alpha_{}_beta_{}.npy".format(mu,alpha,beta))  # True transitions
    # Drop the first entry along axis 2 (model predictions drop the last one
    # instead - presumably to align predictions with next-step labels).
    y_true = y_true[:,:,1:,:]
    # The classes live on the LAST axis (see the argmax over axis -1 below),
    # and class 2 is the default class, so defaults are counted there. The
    # previous index [:,:,2,:] selected one step of axis 2 instead.
    nr_defualts = np.sum(y_true[:,:,:,2])
    print("TOTAL DEFAULTS: ", nr_defualts)
    nr_units_for_inference = [1,2,5,10, 50,100,200, 500,1000]
    nr_obs_partial_hawkes = [1,2,5,10, 50,100,200, 500,1000]
    set_seq_vec = []
    idx = 3
    # Number of observed units of the partial-probability run that is shown in
    # the AUC matrix and the joint OvR figure. Defined once here so the load
    # loops below cannot overwrite it and so it exists for every flag setting.
    N = nr_obs_partial_hawkes[idx]
    if model_auc:
        for u in nr_units_for_inference:
            y_pred_probs_method3 = np.load("./data/all_set_seq_model_preds_mu_{}_alpha_{}_beta_{}_n_{}.npy".format(mu,alpha,beta,u)) # Set Seq model
            y_pred_probs_method3 = y_pred_probs_method3 + np.random.normal(0, 1e-6, y_pred_probs_method3.shape)
            y_pred_probs_method3 = y_pred_probs_method3[:,:,:-1,:]
            set_seq_vec.append(y_pred_probs_method3)
        # The run with the most units is used for the AUC matrix of the model.
        y_pred_probs_method3 = set_seq_vec[-1]

    y_pred_probs_method2 = np.load("./data/y_true_probs_mu_{}_alpha_{}_beta_{}.npy".format(mu,alpha,beta))  # Ground Truth probabilities
    y_pred_probs_method2 = y_pred_probs_method2[:,:,1:,:]
    # NOTE(review): the marginal ("non contagion") predictions are not loaded;
    # method 1 is an alias of the ground-truth probabilities, so both methods
    # yield the same AUCs up to the tie-breaking noise.
    y_pred_probs_method1 = y_pred_probs_method2
    if partial_hawkes:
        true_probs_limited_data_kalman = []
        for n_obs in nr_obs_partial_hawkes:
            temp = np.load("./data/all_y_partial_true_probs_N_{}_mu_{}_alpha_{}_beta_{}_kalman.npy".format(n_obs,mu,alpha,beta))
            temp = temp[:,:,1:,:]
            # Partial True probabilities (corrupted)
            true_probs_limited_data_kalman.append(temp)
        y_pred_partial_true_probs =copy.deepcopy(true_probs_limited_data_kalman[idx])

    if partial_hawkes_no_kalman:
        true_probs_limited_data = []
        for n_obs in nr_obs_partial_hawkes:
            # Files WITHOUT the "_kalman" suffix (same naming as
            # get_kalman_predictions(kalman=False)); loading the "_kalman"
            # files here made the "no Kalman" curve a copy of the Kalman one.
            temp = np.load("./data/all_y_partial_true_probs_N_{}_mu_{}_alpha_{}_beta_{}.npy".format(n_obs,mu,alpha,beta))
            temp = reshape_for_auc(temp[:,:,1:,:])
            # Partial True probabilities (corrupted)
            true_probs_limited_data.append(temp)

    y_true_reshaped = reshape_for_auc(y_true)
    y_pred_probs_method1_reshaped = reshape_for_auc(y_pred_probs_method1)
    y_pred_probs_method2_reshaped = reshape_for_auc(y_pred_probs_method2)
    if model_auc:
        y_pred_probs_method3_reshaped = reshape_for_auc(y_pred_probs_method3)
        for i in range(len(set_seq_vec)):
            set_seq_vec[i] = reshape_for_auc(set_seq_vec[i])
    if partial_hawkes:
        for i in range(len(true_probs_limited_data_kalman)):
            true_probs_limited_data_kalman[i] = reshape_for_auc(true_probs_limited_data_kalman[i])
        y_pred_partial_true_probs_reshaped = reshape_for_auc(y_pred_partial_true_probs)

    # Only series that were actually loaded are compared; an empty or missing
    # series cannot be plotted against the x-axis values.
    data = {}
    if partial_hawkes:
        data["Kalman Filter"] = true_probs_limited_data_kalman
    if model_auc:
        data["Set-Seq Model"] = set_seq_vec

    if data:
        # Both series use the same x-axis values in this script, so only the
        # parameters the plotting function declares are passed.
        plot_kalman_model_comparison(
            y_true = y_true_reshaped,
            ground_truth_probs = y_pred_probs_method2_reshaped,
            data = data,
            kalman_indices=nr_obs_partial_hawkes,
            save_path = "./data/baseline_v_model_corr_abs_err_auc.jpg")

    # Convert one-hot labels to scalar labels
    y_true_scalar = np.argmax(y_true_reshaped, axis=-1)

    if model_auc and partial_hawkes:
        print("Computing AUC for Set Sequence Model")
        auc_vec = []
        true_probs_auc_vec = []
        for loss in set_seq_vec:
            auc1, auc_joint_ovr = compute_auc_by_start_state(y_true_scalar, loss, y_true_reshaped)
            auc_vec.append(auc_joint_ovr[2])
        for loss2 in true_probs_limited_data_kalman:
            auc1, auc_joint_ovr = compute_auc_by_start_state(y_true_scalar, loss2, y_true_reshaped)
            true_probs_auc_vec.append(auc_joint_ovr[2])
        print("AUC Set Sequence Model: ", auc_vec)
        print("AUC True Transition Probabilities Kalman: ", true_probs_auc_vec)
        plt.plot(nr_obs_partial_hawkes, true_probs_auc_vec, label="True Transition Probabilities with limited data")
        plt.plot(nr_units_for_inference, auc_vec, label = "Set Sequence Model inference on limited data")
        plt.xlabel("Number of Units for Inference / Number of Units to compute True Transition Probabilities \n Out of 1000 total units")
        plt.ylabel("AUC Default class vs Rest")
        plt.xscale("log")
        plt.legend()
        plt.savefig("./data/auc_vs_units_for_inference_mu_{}_alpha_{}_beta_{}_kalman.jpg".format(mu,alpha,beta))
        plt.close()

    if model_auc and partial_hawkes and partial_hawkes_no_kalman:
        print("Computing AUC for Set Sequence Model")
        true_probs_no_kalman_auc_vec = []
        for loss2 in true_probs_limited_data:
            auc1, auc_joint_ovr = compute_auc_by_start_state(y_true_scalar, loss2, y_true_reshaped)
            true_probs_no_kalman_auc_vec.append(auc_joint_ovr[2])
        print("AUC True Transition Probabilities no Kalman: ", true_probs_no_kalman_auc_vec)

        print("AUC Set Sequence Model: ", auc_vec)
        print("AUC True Transition Probabilities: ", true_probs_auc_vec)
        plt.plot(nr_obs_partial_hawkes, true_probs_auc_vec, label="Kalman Filter Transition Probabilities")
        plt.plot(nr_obs_partial_hawkes, true_probs_no_kalman_auc_vec, label="No Kalman Transition Probabilities")
        plt.plot(nr_units_for_inference, auc_vec, label = "Set-Sequence Model Inference")
        plt.title("AUC Default class vs Rest for Set Sequence model and true \n transition probabilities as a function of # units for inference")
        plt.xlabel("Number of Units for Inference / # Units \n to compute True Transition Probabilities")
        plt.ylabel("AUC Default class vs Rest")
        plt.ylim(bottom=0.5)
        plt.xscale("log")
        plt.legend()
        plt.tight_layout()
        plt.savefig("./data/auc_vs_units_for_inference_mu_{}_alpha_{}_beta_{}_kalman_and_no_kalman.jpg".format(mu,alpha,beta))
        plt.close()

    # Compute AUC matrices for both methods (method 2 first, then method 1;
    # the order fixes how the global RNG is consumed).
    auc_matrix_method2, auc_joint_ovr2 = compute_auc_by_start_state(y_true_scalar, y_pred_probs_method2_reshaped, y_true_reshaped) # True transition probabilities
    auc_matrix_method1, auc_joint_ovr1 = compute_auc_by_start_state(y_true_scalar, y_pred_probs_method1_reshaped, y_true_reshaped) # Non contagion modeling
    if model_auc:
        auc_matrix_method3, auc_joint_ovr3 = compute_auc_by_start_state(y_true_scalar, y_pred_probs_method3_reshaped, y_true_reshaped) # Set Sequence Model
        save_auc_matrix_with_values(auc_matrix_method3, "AUC For Set Sequence Model", "./data/auc_matrix_method3_scalar1.jpg")
    if partial_hawkes:
        auc_matrix_partial_true_probs, auc_joint_ovr_partial_true_probs = compute_auc_by_start_state(y_true_scalar, y_pred_partial_true_probs_reshaped, y_true_reshaped) # Partial True transition probabilities
        save_auc_matrix_with_values(auc_matrix_partial_true_probs, "AUC For Partial Transition Probabilities, N = {} samples observed".format(N), "./data/auc_matrix_partial_probs_mu_{}_alpha_{}_beta_{}.jpg".format(mu,alpha,beta))

    # Rows of the joint OvR figure, in the order save_joint_ovr labels them.
    auc_joint_ovr = [auc_joint_ovr1, auc_joint_ovr2]
    if model_auc:
        auc_joint_ovr.append(auc_joint_ovr3)
    if partial_hawkes:
        auc_joint_ovr.append(auc_joint_ovr_partial_true_probs)

    # Save matrices as images
    save_auc_matrix_with_values(auc_matrix_method1, "AUC For Method Not Contagion Modeling", "./data/auc_matrix_non_contagion_mu_{}_alpha_{}_beta_{}.jpg".format(mu,alpha,beta))
    save_auc_matrix_with_values(auc_matrix_method2, "AUC For True Transition Probabilities", "./data/auc_matrix_true_probs_mu_{}_alpha_{}_beta_{}.jpg".format(mu,alpha,beta))
    save_joint_ovr(auc_joint_ovr, "./data/auc_ovr_mu_{}_alpha_{}_beta_{}_N_{}.jpg".format(mu,alpha,beta,N),N)

if __name__ == "__main__":
    # Script entry point. cProfile/pstats are imported and a Profile object is
    # created, but profiling is switched off: the enable()/disable() calls and
    # the stats report are commented out, so only main() actually runs.
    # Uncomment those lines to profile main().
    import cProfile
    import pstats

    # Create a profiler
    profiler = cProfile.Profile()
    #profiler.enable()
    
    # Run your main function
    main()  
    
    #profiler.disable()
    
    # Sort and display the stats
    #stats = pstats.Stats(profiler)
    #stats.sort_stats("cumtime").print_stats(30)
