import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import wandb
from typing import List, Union, Optional
def plot_score_dis(
    trans_score: Union[List[float], np.ndarray],
    emb_2_score: Union[List[float], np.ndarray],
    emb_1_score: Union[List[float], np.ndarray],
    wandb_log: bool = True,
    wandb_key: Optional[str] = "score_distribution",
    shared_bins: bool = True,
) -> None:
    """Overlay histograms (with KDE curves) of three similarity-score sets.

    Each input is flattened to 1-D and drawn on a single axes so the three
    query/target pairings can be compared. The figure is optionally logged to
    Weights & Biases, then shown, and is always closed afterwards.

    Args:
        trans_score: Scores plotted as 'Query[E2] -> Transform Target[E1]'.
        emb_2_score: Scores plotted as 'Query[E2] -> Target[E2]'.
        emb_1_score: Scores plotted as 'Query[E2] -> Target[E1]'.
        wandb_log: If True, log the figure via ``wandb.log``.
        wandb_key: Key the figure is logged under; ``None`` falls back to
            ``"score_distribution"``.
        shared_bins: If True (default), all three histograms use the same bin
            edges, so their bar heights are directly comparable. If False,
            each histogram chooses its own "auto" bins.

    Raises:
        Any exception from plotting or ``wandb.log`` is propagated; the figure
        is closed regardless.
    """
    series = [
        (np.asarray(trans_score).flatten(), 'Query[E2] -> Transform Target[E1]', 'blue'),
        (np.asarray(emb_2_score).flatten(), 'Query[E2] -> Target[E2]', 'red'),
        (np.asarray(emb_1_score).flatten(), 'Query[E2] -> Target[E1]', 'green'),
    ]

    # Per-series "auto" bins give each histogram a different bin width, which
    # makes overlaid counts misleading; derive one set of edges from all data.
    bins = "auto"
    if shared_bins:
        pooled = np.concatenate([data for data, _, _ in series]).astype(float)
        pooled = pooled[np.isfinite(pooled)]  # seaborn drops nulls; ignore inf too
        if pooled.size:
            bins = np.histogram_bin_edges(pooled, bins="auto")

    # Explicit figure/axes instead of pyplot's implicit "current" figure, so the
    # right figure is drawn on, logged, and closed.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for data, label, color in series:
            sns.histplot(
                data, bins=bins, kde=True, label=label, color=color, alpha=0.6, ax=ax
            )
        ax.set_xlabel('Similarity Score')
        ax.set_ylabel('Frequency')
        ax.set_title('Distribution of Similarity Scores')
        ax.legend()
        if wandb_log:
            key = wandb_key if wandb_key is not None else "score_distribution"
            # Log before show(): some backends clear the figure once shown.
            wandb.log({key: wandb.Image(fig)})
        plt.show()
    finally:
        # Always release the figure, even if plotting or logging failed.
        plt.close(fig)
