import pandas as pd
import os
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Directory that receives every generated figure.
save_path = "Visualize_plots"
os.makedirs(save_path, exist_ok=True)

# Axis range used for raw-yield plots (histogram range and x-limits).
# Yields are presumably percentages — confirm. Values outside this range are
# not drawn in the original-scale histogram, so they are reported instead.
YIELD_RANGE = (0, 100)

# Each entry is used as a directory expected to contain "handled_data.csv"
# with an "output" column holding the yields.
# NOTE(review): the buchwald entries end in ".csv" yet are joined as directory
# names; this only works if the directories on disk carry that suffix — confirm.
DATASETS = [
    "suzuki_50",
    "arylation",
    "buchwald_Cc1ccc(Nc2ccccn2)cc1.csv",
    "buchwald_CCc1ccc(Nc2ccc(C)cc2)cc1.csv",
    "buchwald_COc1ccc(Nc2ccc(C)cc2)cc1.csv",
]


def _print_summary(name, yields):
    """Print count, extremes, mean, median, std and quartiles of *yields*.

    Also warns when some values fall outside ``YIELD_RANGE``, because those
    values would otherwise vanish silently from the original-scale histogram.
    """
    print(f"=== 基础统计信息 {name} ===")
    print(f"总数据量: {len(yields)}")
    print(f"最大值: {yields.max()}")
    print(f"最小值: {yields.min()}")
    print(f"均值: {yields.mean():.2f}")
    print(f"中位数: {yields.median():.2f}")
    print(f"标准差: {yields.std():.2f}")
    print(f"25% 分位数: {yields.quantile(0.25):.2f}")
    print(f"75% 分位数: {yields.quantile(0.75):.2f}")

    low, high = YIELD_RANGE
    n_outside = int(((yields < low) | (yields > high)).sum())
    if n_outside:
        print(
            f"Warning: {n_outside} value(s) fall outside {YIELD_RANGE} "
            "and are not drawn in the original-scale histogram"
        )


def _plot_histograms(name, yields):
    """Save side-by-side histograms of raw and min-max normalized yields."""
    # MinMaxScaler expects a 2-D (n_samples, n_features) array, hence the reshape.
    scaler = MinMaxScaler()
    yields_normalized = scaler.fit_transform(yields.values.reshape(-1, 1)).flatten()

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Original data on a fixed axis so datasets are visually comparable.
    axes[0].hist(yields, bins=30, color='skyblue', edgecolor='black', range=YIELD_RANGE)
    axes[0].set_title("Original")
    axes[0].set_xlabel("Yield")
    axes[0].set_ylabel("Frequency")
    axes[0].set_xlim(*YIELD_RANGE)

    # Normalized data (rescaled to [0, 1] by the scaler above).
    axes[1].hist(yields_normalized, bins=30, color='salmon', edgecolor='black')
    axes[1].set_title("Normalized")
    axes[1].set_xlabel("Normalized Yield")
    axes[1].set_ylabel("Frequency")

    fig.tight_layout()
    fig.savefig(os.path.join(save_path, f"Norm_{name}.png"))
    plt.close(fig)


def _plot_box_violin(name, yields):
    """Save a box plot and a violin plot of the yields side by side."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    sns.boxplot(y=yields, ax=axes[0], color='lightblue')
    axes[0].set_title("BOX")
    axes[0].set_ylabel("Yield")

    sns.violinplot(y=yields, ax=axes[1], color='lightcoral')
    axes[1].set_title("Violin")
    axes[1].set_ylabel("Yield")

    fig.tight_layout()
    fig.savefig(os.path.join(save_path, f"BV_{name}.png"))
    plt.close(fig)


def _plot_kde(name, yields):
    """Save a filled kernel-density estimate of the yields."""
    fig, ax = plt.subplots(figsize=(8, 5))
    # `fill` replaces the deprecated `shade` argument (seaborn >= 0.11).
    sns.kdeplot(yields, fill=True, color='purple', ax=ax)
    ax.set_title("KDE")
    ax.set_xlabel("Yield")
    ax.set_ylabel("Density")
    ax.set_xlim(*YIELD_RANGE)
    fig.tight_layout()
    fig.savefig(os.path.join(save_path, f"KDE_{name}.png"))
    plt.close(fig)


for DATA_NAME in DATASETS:
    csv_path = os.path.join(DATA_NAME, "handled_data.csv")
    yields = pd.read_csv(csv_path)['output']

    _print_summary(DATA_NAME, yields)
    _plot_histograms(DATA_NAME, yields)
    _plot_box_violin(DATA_NAME, yields)
    _plot_kde(DATA_NAME, yields)