import torch
import random
import numpy as np
import transformers
import os

from sklearn.metrics import roc_curve, auc, confusion_matrix, precision_recall_curve
import matplotlib.pyplot as plt
import pickle
from pathlib import Path
import shap


def sanity_check():
    """Verify that the environment variables this module expects are defined.

    Raises:
        ValueError: If ``HF_HOME`` or ``WANDB_API_TOKEN`` is not present in
            the process environment. A variable set to an empty string still
            counts as present.
    """
    for variable in ("HF_HOME", "WANDB_API_TOKEN"):
        if os.getenv(variable) is None:
            raise ValueError("Please set HF_HOME and WANDB_API_TOKEN in your .env file")
    

def save_results_to_pickle(save_path, results):
    """
    Pickle ``results`` to ``save_path`` and print a confirmation message.

    Missing parent directories of ``save_path`` are created first. An existing
    file at ``save_path`` is overwritten (the file is opened in ``"wb"`` mode).
    """
    destination = Path(save_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    with destination.open("wb") as handle:
        pickle.dump(results, handle)
    print(f"Results saved to {save_path}")


def plot_shap_values(
    shap_values, X_data, feature_names, title=None, out_file=None
):
    """Draw a bar-type SHAP summary plot and either save it or display it.

    Args:
        shap_values: Passed through as the first argument of
            ``shap.summary_plot``.
        X_data: Passed through as the second argument of
            ``shap.summary_plot``.
        feature_names: Forwarded as ``feature_names`` to ``shap.summary_plot``.
        title: Optional plot title; skipped when falsy.
        out_file: Optional output path. When truthy, its parent directories
            are created and the figure is written there; otherwise the figure
            is shown with ``plt.show()``.
    """
    plt.figure(figsize=(8, 10))
    shap.summary_plot(
        shap_values,
        X_data,
        feature_names=feature_names,
        plot_type="bar",
        show=False,  # keep the figure open so it can be titled and saved below
    )
    if title:
        plt.title(title)
    plt.tight_layout()

    if not out_file:
        plt.show()
    else:
        target = Path(out_file)
        target.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(target)

    # Close the current figure in both branches.
    plt.close()


def get_tpr_metric(y_true, y_pred_proba, fpr_budget):
    """Return the true-positive rate read off the ROC curve at a target FPR.

    Args:
        y_true: Ground-truth labels, forwarded to ``roc_curve``.
        y_pred_proba: Predicted scores, forwarded to ``roc_curve``.
        fpr_budget: Target false-positive rate; it is divided by 100 before
            the lookup, so it is expressed in percent.

    Returns:
        The value ``np.interp`` yields for the target FPR over the ROC
        curve's (fpr, tpr) points.
    """
    false_pos_rates, true_pos_rates, _unused_thresholds = roc_curve(
        y_true, y_pred_proba
    )
    target_fpr = fpr_budget / 100
    return np.interp(target_fpr, false_pos_rates, true_pos_rates)


def plot_tpr_fpr_curve(roc_data, fpr_budget, title):
    """Plot one ROC curve per fold and mark the TPR reached at the FPR budget.

    Args:
        roc_data: Iterable of mappings, each providing ``'fpr'``, ``'tpr'``
            and ``'auc'`` entries for one fold.
        fpr_budget: Target false-positive rate; divided by 100 before use, so
            it is expressed in percent.
        title: Title placed on the figure.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    budget_fraction = fpr_budget / 100

    plt.figure(figsize=(8, 6))
    for fold_number, fold in enumerate(roc_data, start=1):
        fold_fpr, fold_tpr, fold_auc = fold['fpr'], fold['tpr'], fold['auc']
        plt.plot(fold_fpr, fold_tpr, label=f'Fold {fold_number} (AUC = {fold_auc:.2f})')

        # TPR on this fold's curve at the budgeted FPR.
        tpr_at_budget = np.interp(budget_fraction, fold_fpr, fold_tpr)

        # Marker on the curve, marker on the x-axis, and dashed guide lines
        # joining the operating point to both axes.
        plt.plot(
            budget_fraction,
            tpr_at_budget,
            marker="x",
            markersize=10,
            markeredgecolor="red",
            markerfacecolor="green",
        )
        plt.plot(
            budget_fraction,
            0,
            marker="o",
            markersize=5,
            markerfacecolor="red",
            markeredgecolor="red",
        )
        plt.vlines(budget_fraction, 0, tpr_at_budget, color='r', linestyles='dashed')
        plt.hlines(tpr_at_budget, 0, budget_fraction, color='r', linestyles='dashed')

    # Diagonal reference line.
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Random Guess')
    plt.title(title)
    plt.xlabel('False Positive Rate (FPR)')
    plt.ylabel('True Positive Rate (TPR)')
    plt.legend(loc='lower right')
    plt.show()

def get_confusion_matrix(y_true, y_pred):
    """Build a confusion matrix after thresholding scores at the best-F1 cut.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted scores; compared element-wise against the chosen
            threshold and cast to ``int`` to obtain hard predictions.

    Returns:
        The result of ``confusion_matrix(y_true, hard_predictions)``.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred)

    # The 1e-9 term keeps the division defined when precision + recall == 0.
    numerator = 2 * (precision * recall)
    denominator = precision + recall + 1e-9
    f1_scores = numerator / denominator

    # NOTE(review): thresholds is indexed with an argmax taken over f1_scores;
    # confirm against the scikit-learn docs that the two arrays line up.
    best_index = np.argmax(f1_scores)
    best_threshold = thresholds[best_index]

    hard_predictions = (y_pred >= best_threshold).astype(int)
    return confusion_matrix(y_true, hard_predictions)


def get_roc_auc(y_true, y_pred_proba):
    """Compute the ROC curve and the area under it.

    Args:
        y_true: Ground-truth labels, forwarded to ``roc_curve``.
        y_pred_proba: Predicted scores, forwarded to ``roc_curve``.

    Returns:
        A ``(auc, fpr, tpr)`` tuple: the area under the curve followed by the
        false-positive and true-positive rate arrays from ``roc_curve``.
    """
    false_pos_rates, true_pos_rates, _unused_thresholds = roc_curve(
        y_true, y_pred_proba
    )
    area_under_curve = auc(false_pos_rates, true_pos_rates)
    return area_under_curve, false_pos_rates, true_pos_rates

def set_seed(seed: int):
    """
    Seed every random number generator this module touches.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU, plus all CUDA devices
    when CUDA is available), switches cuDNN to deterministic mode with
    benchmarking disabled on CUDA machines, exports ``PYTHONHASHSEED``, and
    finally calls ``transformers.set_seed``.
    """
    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        # Current device first, then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Ask cuDNN for deterministic algorithms and turn off autotuning.
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    # NOTE(review): presumably this only influences interpreters launched
    # after this point, not the hash seed of the running process — confirm.
    os.environ["PYTHONHASHSEED"] = str(seed)
    transformers.set_seed(seed)
    