import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as mlines

# Model display name -> method family. Keys use the display spellings
# ('LSTM-AE', 'LSTM-VAE', 'HS-Tree') that the rename step maps raw names onto.
model_type_map = {
    model: family
    for family, models in (
        ("One-class", ("HS-Tree", "CBLOF", "LOF", "IForest", "DeepSVDD")),
        ("Reconstruction", ("LUAD", "USAD", "PCA", "OmniAnomaly", "LSTM-VAE", "LSTM-AE")),
        ("Statistical", ("ABOD", "DAGMM", "HBOS", "LODA", "Hotelling")),
    )
    for model in models
}

# Raw inputs: FLOP counts keyed by (dataset, model_name), and detection metrics
# (presumably one row per run — several rows per pair are aggregated below).
model_flops = pd.read_csv('./analysis/results/model_flops.csv')
model_performance = pd.read_csv('./analysis/results/model_performance.csv')

# One row per (dataset, model_name): mean and (sample) std of AUROC and AUC-PR.
performance_stats = (
    model_performance
    .groupby(['dataset', 'model_name'])
    .agg(
        auroc_mean=('auroc', 'mean'),
        auroc_std=('auroc', 'std'),
        aucpr_mean=('aucpr', 'mean'),
        aucpr_std=('aucpr', 'std'),
    )
    .reset_index()
)

# Two-decimal text copies of each statistic for tables; missing (NaN) values become "".
performance_stats = performance_stats.assign(**{
    f'{stat}_str': performance_stats[stat].map(lambda v: "" if pd.isna(v) else f"{v:.2f}")
    for stat in ('auroc_mean', 'auroc_std', 'aucpr_mean', 'aucpr_std')
})

# Left-join FLOP counts onto the performance summary: every (dataset, model_name)
# pair with results is kept, with NaN FLOP columns where model_flops has no match.
merged_result = performance_stats.merge(model_flops, how='left', on=['dataset', 'model_name'])

# Map raw model identifiers to display names, then tag each model with its family
# (names absent from model_type_map get NaN).
# NOTE(review): the join above runs on the raw names; if model_flops already uses
# the display spellings, those rows will not match — confirm both CSVs agree.
rename_map = {'lstmAE': 'LSTM-AE', 'lstmVAE': 'LSTM-VAE', 'HSTree': 'HS-Tree'}
merged_result = merged_result.assign(model_name=merged_result['model_name'].replace(rename_map))
merged_result['type'] = merged_result['model_name'].map(model_type_map)

# Display unit for FLOP columns and its divisor. Change UNIT to 'kFLOPs' or 'MFLOPs'
# to rescale; the new column names follow the unit.
UNIT = 'GFLOPs'
SCALE = {'kFLOPs': 1e3, 'MFLOPs': 1e6, 'GFLOPs': 1e9}[UNIT]


def _format_flops(value):
    """Return ``value`` formatted with two decimals, or "" when it is missing."""
    return "" if pd.isna(value) else f"{value:.2f}"


# Numeric copies in the display unit; anything that does not parse as a number becomes NaN.
merged_result['train_' + UNIT] = pd.to_numeric(merged_result['train_flops'], errors='coerce') / SCALE
merged_result['inference_' + UNIT] = pd.to_numeric(merged_result['inference_flops'], errors='coerce') / SCALE
merged_result['DL_per_flops_' + UNIT] = pd.to_numeric(merged_result['DL_per_flops'], errors='coerce') / SCALE

# Text copies for tables. Missing values are "" in all three columns, matching the
# performance *_str columns, so none of them ends up a mix of NaN floats and strings.
merged_result['train_' + UNIT + '_str'] = merged_result['train_' + UNIT].map(_format_flops)
merged_result['inference_' + UNIT + '_str'] = merged_result['inference_' + UNIT].map(_format_flops)
merged_result['DL_per_flops_' + UNIT + '_str'] = merged_result['DL_per_flops_' + UNIT].map(_format_flops)

def min_max_scaling(series, invert=False):
    """Linearly rescale ``series`` onto [0, 1].

    With ``invert=False`` the minimum maps to 0 and the maximum to 1. With
    ``invert=True`` the mapping is flipped (maximum -> 0, minimum -> 1), which turns
    a lower-is-better quantity such as FLOPs into a higher-is-better score.

    NaN entries stay NaN. If all non-NaN values are equal the denominator is 0 and
    the result is NaN rather than an exception.
    """
    lo = series.min()
    hi = series.max()
    span = hi - lo
    if invert:
        return (hi - series) / span
    return (series - lo) / span


def compute_scores(df, w=0.5, *, score_col="auroc_mean", flops_col="train_flops"):
    """Add per-dataset accuracy/cost trade-off scores to ``df``.

    Within each ``dataset`` group, ``score_col`` is min-max scaled to ``A_scaled``
    (higher is better). ``flops_col`` is inversely min-max scaled to ``F_scaled``
    (cheaper is better). These are combined with weight ``w`` on accuracy:

    - ``weighted_sum``   = w*A + (1-w)*F
    - ``geometric_mean`` = A**w * F**(1-w)
    - ``harmonic_mean``  = 1 / (w/A + (1-w)/F); this is 0 when A or F is 0
      (division by zero yields inf)
    - ``efficiency_ratio`` = A / F; inf for the most expensive model (F == 0),
      NaN if A and F are both 0
    - ``efficiency_ratio_log[_tr|_inf]`` = ``score_col`` divided by the natural
      log of ``total_flops`` / ``train_flops`` / ``inference_flops``. This assumes
      FLOP counts > 1; log(1) = 0 would divide by zero.

    Args:
        df: Table with ``dataset``, ``score_col``, ``flops_col``, ``total_flops``,
            ``train_flops`` and ``inference_flops`` columns.
        w: Weight on the accuracy term, in [0, 1].
        score_col: Accuracy column to scale (default ``auroc_mean``).
        flops_col: Cost column to scale (default ``train_flops``).

    Returns:
        A new DataFrame with the score columns appended. Rows are grouped by
        dataset (in groupby order) and keep their original index labels. Rows
        whose ``dataset`` is NaN are dropped by the groupby.
    """
    score_columns = ("A_scaled", "F_scaled", "weighted_sum", "geometric_mean",
                     "harmonic_mean", "efficiency_ratio", "efficiency_ratio_log",
                     "efficiency_ratio_log_tr", "efficiency_ratio_log_inf")

    results = []
    for _, group in df.groupby("dataset"):
        a_scaled = min_max_scaling(group[score_col], invert=False)
        f_scaled = min_max_scaling(group[flops_col], invert=True)
        score = group[score_col]

        # assign() returns a copy, so the caller's frame is never modified.
        results.append(group.assign(
            A_scaled=a_scaled,
            F_scaled=f_scaled,
            weighted_sum=w * a_scaled + (1 - w) * f_scaled,
            geometric_mean=(a_scaled ** w) * (f_scaled ** (1 - w)),
            harmonic_mean=1 / ((w / a_scaled) + ((1 - w) / f_scaled)),
            efficiency_ratio=a_scaled / f_scaled,
            efficiency_ratio_log=score / np.log(group["total_flops"]),
            efficiency_ratio_log_tr=score / np.log(group["train_flops"]),
            efficiency_ratio_log_inf=score / np.log(group["inference_flops"]),
        ))

    if not results:
        # pd.concat([]) raises ValueError; return an empty frame with the same schema instead.
        return df.iloc[:0].assign(**dict.fromkeys(score_columns, np.nan))
    return pd.concat(results)

# Append per-dataset trade-off scores (equal weight on scaled AUROC and scaled
# train-FLOP efficiency), then persist the full table without the index.
merged_result = compute_scores(df=merged_result, w=0.5)
merged_result.to_csv(path_or_buf='./analysis/results/model_merged_results.csv', index=False)

def plot_FLOPs(dataframe, dataset_name, save_path='./analysis/figures/FLOPs'):
    """Draw a horizontal bar chart of per-model training FLOPs for one dataset.

    Rows of ``dataframe`` whose ``dataset`` equals ``dataset_name`` are sorted by
    ``train_flops`` in descending order and drawn on a log-scaled x-axis. The
    x-range is [1, max ``train_flops`` over the *whole* ``dataframe``], so charts
    for different datasets share the same upper bound.

    When ``save_path`` is truthy, the figure is written to
    ``{save_path}/train_flops_{dataset_name}.pdf`` (creating the directory if
    needed) and closed. Otherwise ``plt.show()`` is called.
    """
    subset = dataframe.loc[dataframe['dataset'] == dataset_name].sort_values('train_flops', ascending=False)
    bar_count = len(subset)

    plt.figure(figsize=(10, 0.3 * bar_count + 4))
    plt.grid(axis='x', zorder=0)
    plt.barh(subset['model_name'], subset['train_flops'], color='tab:blue', zorder=3)

    plt.title(f'Train FLOPs (descending) — {dataset_name}', fontsize=24)
    plt.xlabel('Train FLOPs', fontsize=20)
    plt.xscale('log')
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(fontsize=18)

    # Upper bound is the maximum over all datasets, not just this one.
    plt.xlim(1, dataframe['train_flops'].max())

    plt.tight_layout()
    if not save_path:
        plt.show()
        return
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(f'{save_path}/train_flops_{dataset_name}.pdf', dpi=600, bbox_inches='tight', format='pdf')
    plt.close()

def plot_auroc(dataframe, dataset_name, save_path='./analysis/figures/AUROC'):
    """Draw a horizontal bar chart of mean AUROC per model for one dataset.

    Rows of ``dataframe`` whose ``dataset`` equals ``dataset_name`` are sorted by
    ``auroc_mean`` in ascending order. The x-axis is fixed to [0, 1].

    When ``save_path`` is truthy, the figure is written to
    ``{save_path}/auroc_{dataset_name}.pdf`` (creating the directory if needed)
    and closed. Otherwise ``plt.show()`` is called.
    """
    rows = dataframe.loc[dataframe['dataset'] == dataset_name]
    rows = rows.sort_values('auroc_mean', ascending=True)

    plt.figure(figsize=(10, 0.3 * len(rows) + 4))
    plt.grid(axis='x', zorder=0)

    plt.barh(rows['model_name'], rows['auroc_mean'], color='tab:blue', zorder=3)

    plt.title(f'AUROC (ascending) — {dataset_name}', fontsize=24)
    plt.xlabel('AUROC', fontsize=20)
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(fontsize=18)
    plt.xlim(0, 1)

    plt.tight_layout()
    if not save_path:
        plt.show()
        return
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(f'{save_path}/auroc_{dataset_name}.pdf', dpi=600, bbox_inches='tight', format='pdf')
    plt.close()

# Hand-placed annotation offsets (dx, dy), in points, per dataset and model, used by
# plot_scatter to place each model's label relative to its marker.
_SCATTER_LABEL_OFFSETS = {
    'PSM':  {'ABOD': (-70, 30), 'CBLOF': (-70, 70), 'DAGMM': (-50, -30), 'DeepSVDD': (-70, 120), 'HBOS': (-25, 50),
             'HS-Tree': (20, -40), 'Hotelling': (10, 70), 'IForest': (25, -60), 'LODA': (-30, -50), 'LOF': (-50, 0),
             'LUAD': (-60, -40), 'OmniAnomaly': (0, -150), 'PCA': (-40, 40), 'USAD': (0, -50), 'LSTM-AE': (0, 100), 'LSTM-VAE': (-60, -30)},
    'MSL':  {'ABOD': (-70, 30), 'CBLOF': (-70, 70), 'DAGMM': (-50, 40), 'DeepSVDD': (-70, -60), 'HBOS': (-15, 60),
             'HS-Tree': (-20, 80), 'Hotelling': (70, -25), 'IForest': (25, -50), 'LODA': (65, -30), 'LOF': (-50, 0),
             'LUAD': (-60, -40), 'OmniAnomaly': (-30, -140), 'PCA': (-20, 40), 'USAD': (0, -50), 'LSTM-AE': (-45, 50), 'LSTM-VAE': (-40, -60)},
    'SMAP': {'ABOD': (-30, 50), 'CBLOF': (-8, 70), 'DAGMM': (-50, 30), 'DeepSVDD': (-70, 60), 'HBOS': (5, 50),
             'HS-Tree': (-10, 60), 'Hotelling': (10, -60), 'IForest': (10, 50), 'LODA': (60, 30), 'LOF': (-50, 0),
             'LUAD': (-60, 0), 'OmniAnomaly': (-150, -30), 'PCA': (70, 0), 'USAD': (-50, -50), 'LSTM-AE': (10, 200), 'LSTM-VAE': (-70, -40)},
    'SMD':  {'ABOD': (-15, 40), 'CBLOF': (-15, 90), 'DAGMM': (-50, -30), 'DeepSVDD': (45, -110), 'HBOS': (30, 50),
             'HS-Tree': (60, -70), 'Hotelling': (10, 70), 'IForest': (25, 50), 'LODA': (10, -80), 'LOF': (-20, -40),
             'LUAD': (-10, 60), 'OmniAnomaly': (-120, 90), 'PCA': (0, -60), 'USAD': (30, -50), 'LSTM-AE': (-5, 100), 'LSTM-VAE': (-65, -30)},
    'SWaT': {'ABOD': (-15, -40), 'CBLOF': (-70, 70), 'DAGMM': (-50, -30), 'DeepSVDD': (-90, 10), 'HBOS': (40, -50),
             'HS-Tree': (20, -35), 'Hotelling': (10, -70), 'IForest': (30, 60), 'LODA': (40, -60), 'LOF': (-10, 40),
             'LUAD': (-60, -40), 'OmniAnomaly': (30, 95), 'PCA': (60, 50), 'USAD': (0, -50), 'LSTM-AE': (5, -120), 'LSTM-VAE': (-40, -90)},
    'WADI': {'ABOD': (10, -40), 'CBLOF': (-70, -30), 'DAGMM': (-50, -30), 'DeepSVDD': (-70, -50), 'HBOS': (-10, 80),
             'HS-Tree': (20, -50), 'Hotelling': (70, 30), 'IForest': (50, 50), 'LODA': (-30, 50), 'LOF': (-10, 40),
             'LUAD': (0, 40), 'OmniAnomaly': (-30, 140), 'PCA': (50, -50), 'USAD': (35, -60), 'LSTM-AE': (40, 90), 'LSTM-VAE': (-60, 90)},
}

# Offset used for any dataset or model missing from _SCATTER_LABEL_OFFSETS.
_DEFAULT_LABEL_OFFSET = (0, 40)


def plot_scatter(dataframe, dataset_name, x_col='total_flops', y_col='auroc_mean',
                 save_path='./analysis/figures/scatter', *, x_label=None, y_label=None):
    """Scatter one dataset's models by cost (x) against detection quality (y).

    Each model gets a marker chosen by its ``type`` column: circle for One-class,
    square for Reconstruction, triangle for Statistical, and circle for anything
    else. Each marker carries an arrow annotation with the model name. The five
    models with the largest ``y_col`` are coloured orange and the rest blue.

    Axis limits are fixed: y in [0.3, 1] and a log x-axis running from 10**14.5
    down to 10**6.5 (reversed). Reference lines are drawn at x = 10**10.5 and
    y = 0.65. These limits were chosen for the default columns.

    Args:
        dataframe: Table with ``dataset``, ``model_name``, ``type``, ``x_col`` and ``y_col`` columns.
        dataset_name: Value of ``dataset`` to plot.
        x_col: Column for the x-axis.
        y_col: Column for the y-axis; also used to pick the top five.
        save_path: Directory to write ``scatter_{dataset_name}.pdf`` into. If falsy,
            the figure is shown instead.
        x_label: X-axis title. Defaults to 'Total FLOPs' for ``total_flops``,
            otherwise the column name.
        y_label: Y-axis title. Defaults to 'AUROC' for ``auroc_mean``, otherwise the
            column name.
    """
    data = dataframe[dataframe['dataset'] == dataset_name]
    top5 = set(data.nlargest(5, y_col)['model_name'])

    # Unknown datasets/models use a default offset instead of raising KeyError.
    offsets = _SCATTER_LABEL_OFFSETS.get(dataset_name, {})

    if x_label is None:
        x_label = 'Total FLOPs' if x_col == 'total_flops' else x_col
    if y_label is None:
        y_label = 'AUROC' if y_col == 'auroc_mean' else y_col

    type_marker = {'One-class': 'o',        # circle
                   'Reconstruction': 's',   # square
                   'Statistical': '^'}      # triangle

    plt.figure(figsize=(10, 8))
    ax = plt.gca()
    for _, row in data.iterrows():
        color = 'tab:orange' if row['model_name'] in top5 else 'tab:blue'
        marker = type_marker.get(row['type'], 'o')

        # No per-point legend label: the legend is built from proxy handles below.
        ax.scatter(row[x_col], row[y_col],
                   s=450, c=color, edgecolors='k',
                   marker=marker, zorder=30)
        dx, dy = offsets.get(row['model_name'], _DEFAULT_LABEL_OFFSET)
        ax.annotate(row['model_name'],
                    xy=(row[x_col], row[y_col]),
                    xytext=(dx, dy),
                    textcoords='offset points',
                    ha='center', va='center',
                    fontsize=22,
                    arrowprops=dict(arrowstyle='->',
                                    shrinkA=0, shrinkB=8,
                                    lw=2.5,
                                    color='0.1',
                                    relpos=(0.5, 0.8)),
                    zorder=40)

    plt.ylim(0.3, 1)
    plt.xlim(10**14.5, 10**6.5)  # reversed: cheaper models appear further right
    plt.xscale('log')
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.tick_params(axis='x', which='minor', bottom=False, top=False)

    plt.xlabel(x_label, fontdict={'fontsize': 24})
    plt.ylabel(y_label, fontdict={'fontsize': 24})
    plt.axvline(x=10**10.5, c='k', zorder=10)
    plt.axhline(y=0.65, c='k', zorder=10)
    plt.grid(True, zorder=0)

    # Hollow black markers, one per model family, independent of the plotted colours.
    legend_proxies = [
        mlines.Line2D([], [], linestyle='None',
                      marker=type_marker[family],
                      markerfacecolor='none', markeredgecolor='k',
                      markersize=10, label=family)
        for family in ('Statistical', 'One-class', 'Reconstruction')
    ]
    plt.legend(handles=legend_proxies, loc='upper left', fontsize=20)

    plt.tight_layout()
    if not save_path:
        plt.show()
        return
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(f"{save_path}/scatter_{dataset_name}.pdf",
                dpi=600, bbox_inches='tight', format='pdf')
    plt.close()

# Write every per-dataset figure: all scatter plots first, then the FLOP and AUROC
# bar charts. Pass save_path=None to any call to display interactively instead.
datasets_to_plot = merged_result['dataset'].unique()

for dataset in datasets_to_plot:
    plot_scatter(merged_result, dataset)

for dataset in datasets_to_plot:
    plot_FLOPs(merged_result, dataset)
    plot_auroc(merged_result, dataset)