from typing import Tuple
import numpy as np
from sklearn import metrics
import matplotlib.pyplot as plt
from typing import Union
from pathlib import Path
import re
import seaborn as sns


def extract_values_from_string(s):
    """Parse run metadata out of an experiment identifier string.

    The adaptation name is the token (up to the next ``_``) that follows the
    first ``dp_`` marker when one is present, otherwise the leading
    ``_``-separated token of ``s``. ``lr``, ``eps``, ``epochs`` and ``model``
    are taken from the first ``<key>_<value>`` occurrence found by regex and
    are left as strings (or ``None`` when absent).

    Args:
        s: The identifier string to parse.

    Returns:
        A dict with keys ``id``, ``adaptation``, ``lr``, ``eps``, ``epochs``,
        ``model`` and ``dp`` (``True`` when the substring ``"dp"`` occurs
        anywhere in ``s``).
    """
    def _first_capture(pattern):
        # Return group 1 of the first match, or None when nothing matches.
        found = re.search(pattern, s)
        return None if found is None else found.group(1)

    if 'dp_' in s:
        # Only the segment between the first and second "dp_" is considered.
        segment = s.split('dp_')[1]  # After dp_
    else:
        segment = s
    adaptation = segment.split('_')[0]

    field_patterns = {
        'lr': r'lr_([\de\.-]+)',
        'eps': r'eps_([\de\.-]+)',
        'epochs': r'epochs_([\d]+)',
        'model': r'model_([\w-]+)',
    }
    parsed = {field: _first_capture(pattern) for field, pattern in field_patterns.items()}

    return {'id': s, 'adaptation': adaptation, **parsed, "dp": "dp" in s}
    

def plot_points_with_labels(
    train_points, 
    test_points, 
    out_dir: Union[str, Path],
    name: str    
    ):
    """Draw train and test values as a 1-D strip plot, save it, then show it.

    All points are placed on the x-axis (y = 0) and coloured by origin using
    the 'coolwarm' palette: train points get hue 0, test points get hue 1.
    The figure is written to ``<out_dir>/points_<name>.pdf``.

    Args:
        train_points: 1-D array-like of values from the training set.
        test_points: 1-D array-like of values from the test set.
        out_dir: Directory the PDF is written into.
        name: Suffix used in the output filename.
    """
    # Build the output path once so the saved file and the reported path agree.
    out_path = Path(out_dir) / f"points_{name}.pdf"

    # Clear whatever is on the current figure, then open a fresh wide, short one.
    plt.clf()
    plt.figure(figsize=(10, 2))

    points = np.concatenate([train_points, test_points])
    # Hue labels: 0 for train points, 1 for test points.
    labels = np.concatenate([np.zeros_like(train_points), np.ones_like(test_points)])
    # Plot points on x-axis (y=0), colorizing based on the label (hue)
    sns.scatterplot(x=points, y=np.zeros_like(points), hue=labels, palette='coolwarm', s=100, legend=False)

    plt.title("Ratios")
    plt.xlabel("Ratio")
    plt.yticks([])  # Remove y-axis ticks since y is constant

    plt.savefig(out_path)
    print(f"Points plot saved at {out_path}")
    plt.show()
    
    
def calculate_auc(
    thresh_train, 
    thresh_test):
    """Return the ROC AUC of separating training scores from test scores.

    Training scores are labelled as the positive class (1) and test scores
    as the negative class (0); the concatenated scores are passed to
    ``sklearn.metrics.roc_auc_score``.

    Args:
        thresh_train: 1-D array-like of scores for training samples.
        thresh_test: 1-D array-like of scores for test samples.

    Returns:
        The area under the ROC curve.
    """
    ground_truth = np.concatenate([np.ones_like(thresh_train), np.zeros_like(thresh_test)])
    all_scores = np.concatenate([thresh_train, thresh_test])
    return metrics.roc_auc_score(ground_truth, all_scores)


def get_optimal_alpha_gamma_auc(auc_matrix, alpha_values, gamma_values) -> Tuple[float, float, float]:
    """Pick the (alpha, gamma) pair with the highest AUC from a sweep grid.

    ``auc_matrix`` is indexed as ``auc_matrix[gamma_index, alpha_index]``:
    rows correspond to ``gamma_values`` and columns to ``alpha_values``.
    NaN cells are ignored when searching for the maximum. If every cell is
    NaN, the first cell is returned, matching plain ``np.argmax``. Ties
    resolve to the first maximum in row-major order.

    Args:
        auc_matrix: 2-D array-like of AUC scores.
        alpha_values: Sequence of alpha values, one per column.
        gamma_values: Sequence of gamma values, one per row.

    Returns:
        ``(best_alpha, best_gamma, best_auc)``.

    Raises:
        ValueError: If ``auc_matrix`` is empty.
    """
    auc_matrix = np.asarray(auc_matrix)
    # np.argmax treats NaN as the maximum; mask NaNs so a failed configuration
    # is never reported as the optimum.
    searchable = np.where(np.isnan(auc_matrix), -np.inf, auc_matrix)
    max_j, max_i = np.unravel_index(np.argmax(searchable), auc_matrix.shape)
    return alpha_values[max_i], gamma_values[max_j], auc_matrix[max_j, max_i]
   
   
def calculate_confusion_matrix(
    thresh_train, 
    thresh_test
)-> Tuple[int, int, int, int]:
    """Confusion-matrix counts at the ROC threshold maximising TPR - FPR.

    Training scores are the positive class (1) and test scores the negative
    class (0). The decision threshold is the ``sklearn.metrics.roc_curve``
    threshold with the largest ``tpr - fpr`` (Youden's J statistic). A score
    at or above that threshold is predicted positive.

    Args:
        thresh_train: 1-D array-like of scores for training samples.
        thresh_test: 1-D array-like of scores for test samples.

    Returns:
        ``(tn, fp, fn, tp)`` counts with labels ordered ``[0, 1]``.
    """
    scores = np.concatenate([thresh_train, thresh_test])
    truth = np.concatenate([np.ones_like(thresh_train), np.zeros_like(thresh_test)])

    false_pos_rate, true_pos_rate, candidate_cuts = metrics.roc_curve(y_true=truth, y_score=scores)
    # Youden's J: choose the cut-off where TPR - FPR is largest (first on ties).
    best_cut = candidate_cuts[np.argmax(true_pos_rate - false_pos_rate)]

    predicted = (scores >= best_cut).astype(int)
    cm = metrics.confusion_matrix(truth, predicted, labels=[0, 1])
    true_neg, false_pos, false_neg, true_pos = cm.ravel()
    return true_neg, false_pos, false_neg, true_pos


def plot_roc_thresholding(
    thresh_train, 
    thresh_test, 
    out_dir: Union[str, Path],
    name: str):
    """Plot the ROC curve for train-vs-test score separation and return its AUC.

    Training scores are treated as the positive class (1) and test scores as
    the negative class (0). The current figure is cleared, the y = x chance
    diagonal and the ROC curve are drawn, and the figure is saved to
    ``<out_dir>/thresholding_plot_test_<name>.pdf``.

    Args:
        thresh_train: 1-D array-like of scores for training samples.
        thresh_test: 1-D array-like of scores for test samples.
        out_dir: Directory the PDF is written into.
        name: Suffix used in the output filename.

    Returns:
        The ROC AUC as computed by ``calculate_auc``.
    """
    # Build the output path once so the saved file and the reported path agree.
    out_path = Path(out_dir) / f"thresholding_plot_test_{name}.pdf"

    y_true = np.concatenate([np.ones_like(thresh_train), np.zeros_like(thresh_test)])
    y_score = np.concatenate([thresh_train, thresh_test])
    fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
    auc = calculate_auc(thresh_train, thresh_test)

    plt.clf()
    plt.plot([0, 1], [0, 1])  # chance-level reference line
    plt.plot(fpr, tpr)
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')

    plt.savefig(out_path)
    print(f"ROC plot saved at {out_path}")
    return auc