import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as mlines

# Lookup from display model name to its detector family. Built from a
# family -> models grouping; insertion order follows the grouping below.
model_type_map = {
    model: family
    for family, models in (
        ("One-class", ("HS-Tree", "CBLOF", "LOF", "IForest", "DeepSVDD")),
        ("Reconstruction", ("LUAD", "USAD", "PCA", "OmniAnomaly",
                            "LSTM-VAE", "LSTM-AE", "AnomalyTransformer", "TimesNet")),
        ("Statistical", ("ABOD", "DAGMM", "HBOS", "LODA", "Hotelling")),
    )
    for model in models
}

# Input tables: FLOP counts and performance metrics, both keyed by dataset and model name.
model_flops = pd.read_csv('./analysis/results/model_flops.csv')
model_performance = pd.read_csv('./analysis/results/model_performance.csv')

# Mean and standard deviation of every metric for each (dataset, model) pair.
# The comprehension yields auroc_mean, auroc_std, auprc_mean, ... in that order.
performance_stats = (
    model_performance
    .groupby(['dataset', 'model_name'])
    .agg(**{f'{metric}_{stat}': (metric, stat)
            for metric in ('auroc', 'auprc', 'vus_roc', 'vus_pr')
            for stat in ('mean', 'std')})
    .reset_index()
)

# Two-decimal string versions of each statistic; missing values become "".
performance_stats = performance_stats.assign(**{
    f'{stat_col}_str': performance_stats[stat_col].map(lambda v: "" if pd.isna(v) else f"{v:.2f}")
    for stat_col in ('auroc_mean', 'auroc_std', 'auprc_mean', 'auprc_std',
                     'vus_roc_mean', 'vus_roc_std', 'vus_pr_mean', 'vus_pr_std')
})

# Left join: every (dataset, model) with performance numbers is kept even
# when the FLOPs table has no matching row.
merged_result = pd.merge(performance_stats, model_flops, how='left', on=['dataset', 'model_name'])

# Switch raw identifiers to display names, then tag each model with its family.
rename_map = {'lstmAE': 'LSTM-AE', 'lstmVAE': 'LSTM-VAE', 'HSTree': 'HS-Tree'}
merged_result['model_name'] = merged_result['model_name'].replace(rename_map)
merged_result['type'] = merged_result['model_name'].map(model_type_map)

# FLOP columns rescaled to the selected unit; non-numeric entries become NaN.
UNIT = 'GFLOPs'
SCALE = {'kFLOPs': 1e3, 'MFLOPs': 1e6, 'GFLOPs': 1e9}[UNIT]
merged_result[f'train_{UNIT}'] = pd.to_numeric(merged_result['train_flops'], errors='coerce') / SCALE
merged_result[f'inference_{UNIT}'] = pd.to_numeric(merged_result['inference_flops'], errors='coerce') / SCALE
merged_result[f'DL_per_flops_{UNIT}'] = pd.to_numeric(merged_result['DL_per_flops'], errors='coerce') / SCALE

# String versions of the rescaled columns. Train/inference use "" for missing
# values, whereas the DL_per_flops column keeps NaN for them.
merged_result[f'train_{UNIT}_str'] = merged_result[f'train_{UNIT}'].map(
    lambda v: f"{v:.2f}" if pd.notna(v) else "")
merged_result[f'inference_{UNIT}_str'] = merged_result[f'inference_{UNIT}'].map(
    lambda v: f"{v:.2f}" if pd.notna(v) else "")
merged_result[f'DL_per_flops_{UNIT}_str'] = merged_result[f'DL_per_flops_{UNIT}'].map(
    lambda v: f"{v:.2f}" if pd.notna(v) else np.nan)

def min_max_scaling(series, invert=False):
    """Rescale *series* linearly by its own minimum and maximum.

    With ``invert=False`` the minimum maps to 0 and the maximum to 1; with
    ``invert=True`` the mapping is flipped (maximum -> 0, minimum -> 1).
    When minimum and maximum coincide the division is by zero, so the
    result is not a finite score (NaN for a pandas Series).
    """
    lowest = series.min()
    highest = series.max()
    span = highest - lowest
    distance = (highest - series) if invert else (series - lowest)
    return distance / span

def compute_scores(df, w=0.5):
    """Append accuracy/cost trade-off scores, computed separately per dataset.

    Within each ``dataset`` group, ``auroc_mean`` is min-max scaled
    (``A_scaled``) and ``train_flops`` is min-max scaled with inversion
    (``F_scaled``, so the lowest FLOP count scores highest). Several
    combinations of the two are then added as columns, together with
    AUROC divided by the natural log of total / train / inference FLOPs.

    Args:
        df: table with ``dataset``, ``auroc_mean``, ``train_flops``,
            ``inference_flops`` and ``total_flops`` columns.
        w: weight given to the scaled AUROC; ``1 - w`` goes to the scaled FLOPs.

    Returns:
        The per-dataset groups concatenated, each carrying the new score columns.
    """
    cost_weight = 1 - w
    scored_groups = []
    for _, rows in df.groupby("dataset"):
        auroc = rows["auroc_mean"]
        accuracy = min_max_scaling(auroc, invert=False)
        cheapness = min_max_scaling(rows["train_flops"], invert=True)

        # assign() works on a copy of the group, so the input frame is untouched.
        scored_groups.append(rows.assign(
            A_scaled=accuracy,
            F_scaled=cheapness,
            weighted_sum=w * accuracy + cost_weight * cheapness,
            geometric_mean=(accuracy ** w) * (cheapness ** cost_weight),
            harmonic_mean=1 / ((w / accuracy) + (cost_weight / cheapness)),
            efficiency_ratio=accuracy / cheapness,
            efficiency_ratio_log=auroc / np.log(rows["total_flops"]),
            efficiency_ratio_log_tr=auroc / np.log(rows["train_flops"]),
            efficiency_ratio_log_inf=auroc / np.log(rows["inference_flops"]),
        ))
    return pd.concat(scored_groups)

# Add the per-dataset trade-off scores (AUROC and FLOPs weighted equally, w=0.5)
# and write the combined table to disk without the index column.
merged_result = compute_scores(merged_result, w=0.5)
merged_result.to_csv('./analysis/results/model_merged_results.csv', index=False)

def plot_FLOPs(dataframe, dataset_name, save_path='./analysis/figures/FLOPs'):
    """Draw a horizontal bar chart of training FLOPs for one dataset.

    Rows of ``dataframe`` whose ``dataset`` equals ``dataset_name`` are sorted
    by ``train_flops`` (largest first) and drawn on a log-scaled x axis. The
    upper x limit is the largest ``train_flops`` of the whole ``dataframe``,
    not just of the selected dataset.

    Args:
        dataframe: table with ``dataset``, ``model_name`` and ``train_flops`` columns.
        dataset_name: dataset to plot.
        save_path: directory for the PDF (created if needed); a falsy value
            shows the figure instead of saving it.
    """
    is_selected = dataframe['dataset'] == dataset_name
    data = dataframe[is_selected].sort_values('train_flops', ascending=False)

    # Figure height grows with the number of bars.
    fig_height = 0.3 * len(data) + 4
    plt.figure(figsize=(10, fig_height))
    plt.grid(axis='x', zorder=0)
    plt.barh(data['model_name'], data['train_flops'], color='tab:blue', zorder=3)

    plt.title(f'Train FLOPs (descending) — {dataset_name}', fontsize=24)
    plt.xlabel('Train FLOPs', fontsize=20)
    plt.xscale('log')
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)

    global_max_flops = dataframe['train_flops'].max()
    plt.xlim(1, global_max_flops)

    plt.tight_layout()
    if not save_path:
        plt.show()
        return

    os.makedirs(save_path, exist_ok=True)
    target = f'{save_path}/train_flops_{dataset_name}.pdf'
    plt.savefig(target, dpi=600, bbox_inches='tight', format='pdf')
    plt.close()

def plot_auroc(dataframe, dataset_name, save_path='./analysis/figures/AUROC'):
    """Draw a horizontal bar chart of mean AUROC for one dataset.

    Rows of ``dataframe`` whose ``dataset`` equals ``dataset_name`` are sorted
    by ``auroc_mean`` in ascending order and drawn on an x axis fixed to [0, 1].

    Args:
        dataframe: table with ``dataset``, ``model_name`` and ``auroc_mean`` columns.
        dataset_name: dataset to plot.
        save_path: directory for the PDF (created if needed); a falsy value
            shows the figure instead of saving it.
    """
    is_selected = dataframe['dataset'] == dataset_name
    data = dataframe[is_selected].sort_values('auroc_mean', ascending=True)

    # Figure height grows with the number of bars.
    fig_height = 0.3 * len(data) + 4
    plt.figure(figsize=(10, fig_height))
    plt.grid(axis='x', zorder=0)

    plt.barh(data['model_name'], data['auroc_mean'], color='tab:blue', zorder=3)

    plt.title(f'AUROC (ascending) — {dataset_name}', fontsize=24)
    plt.xlabel('AUROC', fontsize=20)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    plt.xlim(0, 1)

    plt.tight_layout()
    if not save_path:
        plt.show()
        return

    os.makedirs(save_path, exist_ok=True)
    target = f'{save_path}/auroc_{dataset_name}.pdf'
    plt.savefig(target, dpi=600, bbox_inches='tight', format='pdf')
    plt.close()

def plot_auroc_scatter(dataframe, dataset_name, x_col='total_flops', y_col='auroc_mean', save_path='./analysis/figures/scatter'):
    """Draw an annotated FLOPs-vs-AUROC scatter plot for one dataset.

    Every model of ``dataset_name`` is drawn as one marker: the marker shape
    encodes the model family (``type`` column) and the five models with the
    largest ``y_col`` are coloured orange, the rest blue. Each marker is
    labelled with the model name, connected by an arrow.

    Label positions come from a hand-tuned table of offsets (in points) per
    dataset and model. A dataset or model missing from that table gets a
    default offset instead of raising ``KeyError``, so new datasets/models
    still produce a figure.

    The x axis is logarithmic and inverted (limits run from 10**15.5 down to
    10**6.5); reference lines are drawn at x = 10**11 and y = 0.6.

    Args:
        dataframe: table with ``dataset``, ``model_name``, ``type``, ``x_col``
            and ``y_col`` columns.
        dataset_name: dataset to plot.
        x_col: column plotted on the x axis (the axis label stays 'Total FLOPs').
        y_col: column plotted on the y axis and used for the top-5 highlight
            (the axis label stays 'AUROC').
        save_path: directory for the PDF (created if needed); a falsy value
            shows the figure instead of saving it.
    """
    data = dataframe[dataframe['dataset'] == dataset_name]

    # Membership is checked once per row, so keep the top-5 names in a set.
    top5 = set(data.nlargest(5, y_col)['model_name'])

    offsets_data = {'PSM':  {'ABOD': (-70, 30), 'CBLOF': (-70, 70), 'DAGMM': (0, -70), 'DeepSVDD': (0, 100), 'HBOS': (-25, 50), 'HS-Tree': (20, -40),
                             'Hotelling': (30, 100), 'IForest': (40, -80), 'LODA': (-30, -50), 'LOF': (-50, 0), 'LUAD': (-80, 30), 'OmniAnomaly': (-25, -35),
                             'PCA': (0, 40), 'USAD': (50,-50), 'LSTM-AE': (0, 100), 'LSTM-VAE': (0, -90), 'AnomalyTransformer': (-20, -110), 'TimesNet': (-120, -40)},
                    'MSL':  {'ABOD': (-20, 90), 'CBLOF': (10, 70), 'DAGMM': (0, 50), 'DeepSVDD': (100, -100), 'HBOS': (15, 60), 'HS-Tree': (10, 80),
                             'Hotelling': (70, -25), 'IForest': (25, -50), 'LODA': (65, -30), 'LOF': (-50, 0), 'LUAD': (-60, -40), 'OmniAnomaly': (-35, -140),
                             'PCA': (20, 100), 'USAD': (0,-50), 'LSTM-AE': (-45, 100), 'LSTM-VAE': (-40, -60), 'AnomalyTransformer': (-40, -90), 'TimesNet': (-10, 50)},
                    'SMAP': {'ABOD': (-30, 50), 'CBLOF': (-10, 120), 'DAGMM': (30, -30), 'DeepSVDD': (35, 80), 'HBOS': (5, 70), 'HS-Tree': (30, 80),
                             'Hotelling': (10, -60), 'IForest': (45, 90), 'LODA': (60, 30), 'LOF': (-60, 0), 'LUAD': (-60, 0), 'OmniAnomaly': (-130, -70),
                             'PCA': (70, -20), 'USAD': (0,-50), 'LSTM-AE': (-45, 150), 'LSTM-VAE': (-30, -70), 'AnomalyTransformer': (-80, 160), 'TimesNet': (-90, -40)},
                    'SMD':  {'ABOD': (-60, -50), 'CBLOF': (50, 100), 'DAGMM': (45, -65), 'DeepSVDD': (70, -95), 'HBOS': (30, 50), 'HS-Tree': (80, -60),
                             'Hotelling': (20, 70), 'IForest': (70, 90), 'LODA': (50, -70), 'LOF': (-40, -50), 'LUAD': (10, -60), 'OmniAnomaly': (-60, 70),
                             'PCA': (50, 70), 'USAD': (30,-50), 'LSTM-AE': (25, 120), 'LSTM-VAE': (20, -70), 'AnomalyTransformer': (-20, -100), 'TimesNet': (-80, 15)},
                    'SWaT': {'ABOD': (-15, -40), 'CBLOF': (30, -75), 'DAGMM': (-50, -30), 'DeepSVDD': (20, 70), 'HBOS': (40, -50), 'HS-Tree': (20, -35),
                             'Hotelling': (10, -70), 'IForest': (40, 70), 'LODA': (40, -60), 'LOF': (-50, -40), 'LUAD': (-70, 30), 'OmniAnomaly': (30, 85),
                             'PCA': (60, 45), 'USAD': (10,-50), 'LSTM-AE': (5, -120), 'LSTM-VAE': (-40, -90), 'AnomalyTransformer': (75, -70), 'TimesNet': (70, -30)},
                    'WADI': {'ABOD': (10, -40), 'CBLOF': (40, 60), 'DAGMM': (-50, -60), 'DeepSVDD': (70, -50), 'HBOS': (50, 50), 'HS-Tree': (20, -50),
                             'Hotelling': (70, 30), 'IForest': (-10, 80), 'LODA': (-30, 50), 'LOF': (-30, 45), 'LUAD': (10, 45), 'OmniAnomaly': (-25, 130),
                             'PCA': (50, -30), 'USAD': (45,-90), 'LSTM-AE': (-25, -50), 'LSTM-VAE': (-55, 120), 'AnomalyTransformer': (-5, -160), 'TimesNet': (-40, 80)},}

    # Label placement used when the table above has no entry for a dataset or
    # model: straight above the marker.
    default_offset = (0, 40)
    offsets = offsets_data.get(dataset_name, {})

    type_marker = {'One-class': 'o',        # circle
                   'Reconstruction': 's',   # square
                   'Statistical': '^'}      # triangle

    plt.figure(figsize=(10, 8))
    ax = plt.gca()
    # Draw one marker per model and annotate it with the model name.
    for _, row in data.iterrows():
        color = 'tab:orange' if row['model_name'] in top5 else 'tab:blue'
        # Models without a known family fall back to the circle marker.
        marker = type_marker.get(row['type'], 'o')

        ax.scatter(row[x_col], row[y_col],
                   s=450, c=color, edgecolors='k',
                   marker=marker, label=row['type'], zorder=30)
        dx, dy = offsets.get(row['model_name'], default_offset)
        ax.annotate(row['model_name'],
                    xy=(row[x_col], row[y_col]),
                    xytext=(dx, dy),
                    textcoords='offset points',
                    ha='center', va='center',
                    fontsize=22,
                    arrowprops=dict(arrowstyle='->',
                                    shrinkA=0, shrinkB=8,
                                    lw=2.5,
                                    color='0.1',
                                    relpos=(0.5, 0.8)),
                    zorder=40)

    plt.ylim(0.2, 1)
    # Left limit is the larger value, so the x axis runs from high to low FLOPs.
    plt.xlim(10**15.5, 10**6.5)
    plt.xscale('log')
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.tick_params(axis='x', which='minor', bottom=False, top=False)

    plt.xlabel('Total FLOPs', fontdict={'fontsize': 24})
    plt.ylabel('AUROC', fontdict={'fontsize': 24})
    # Reference lines splitting the plane into four quadrants.
    plt.axvline(x = 10**11, c = 'k', zorder=10)
    plt.axhline(y = 0.6, c = 'k', zorder=10)
    plt.grid(True, zorder=0)

    # The legend explains marker shapes only, so it is built from hollow
    # proxy markers rather than from the coloured scatter points.
    legend_proxies = [mlines.Line2D([], [], linestyle='None',
                                    marker=type_marker['Statistical'],
                                    markerfacecolor='none', markeredgecolor='k',
                                    markersize=10, label='Statistical'),
                      mlines.Line2D([], [], linestyle='None',
                                    marker=type_marker['One-class'],
                                    markerfacecolor='none', markeredgecolor='k',
                                    markersize=10, label='One-class'),
                      mlines.Line2D([], [], linestyle='None',
                                    marker=type_marker['Reconstruction'],
                                    markerfacecolor='none', markeredgecolor='k',
                                    markersize=10, label='Reconstruction')]

    plt.legend(handles=legend_proxies, loc='upper left', fontsize=20)

    plt.tight_layout()
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        plt.savefig(f"{save_path}/auroc_scatter_{dataset_name}.pdf",
                    dpi=600, bbox_inches='tight', format='pdf')
        plt.close()
    else:
        plt.show()

def plot_auprc_scatter(dataframe, dataset_name, x_col='total_flops', y_col='auprc_mean', save_path='./analysis/figures/scatter'):
    """Draw an annotated FLOPs-vs-AUPRC scatter plot for one dataset.

    Every model of ``dataset_name`` is drawn as one marker: the marker shape
    encodes the model family (``type`` column) and the five models with the
    largest ``y_col`` are coloured orange, the rest blue. Each marker is
    labelled with the model name, connected by an arrow.

    Label positions come from a hand-tuned table of offsets (in points) per
    dataset and model. A dataset or model missing from that table gets a
    default offset instead of raising ``KeyError``, so new datasets/models
    still produce a figure.

    The x axis is logarithmic and inverted (limits run from 10**15.5 down to
    10**6.5); reference lines are drawn at x = 10**11 and y = 0.4.

    Args:
        dataframe: table with ``dataset``, ``model_name``, ``type``, ``x_col``
            and ``y_col`` columns.
        dataset_name: dataset to plot.
        x_col: column plotted on the x axis (the axis label stays 'Total FLOPs').
        y_col: column plotted on the y axis and used for the top-5 highlight
            (the axis label stays 'AUPRC').
        save_path: directory for the PDF (created if needed); a falsy value
            shows the figure instead of saving it.
    """
    data = dataframe[dataframe['dataset'] == dataset_name]

    # Membership is checked once per row, so keep the top-5 names in a set.
    top5 = set(data.nlargest(5, y_col)['model_name'])

    offsets_data = {'PSM':  {'ABOD': (-20, 100), 'CBLOF': (65, 60), 'DAGMM': (0, -50), 'DeepSVDD': (90, -120), 'HBOS': (10, 50), 'HS-Tree': (30, -40),
                             'Hotelling': (20, 110), 'IForest': (35, -60), 'LODA': (-15, -50), 'LOF': (-70, 0), 'LUAD': (-60, 40), 'OmniAnomaly': (-60, -190),
                             'PCA': (-15, 70), 'USAD': (25, -50), 'LSTM-AE': (-10, 100), 'LSTM-VAE': (30, 50), 'AnomalyTransformer': (-60, -150), 'TimesNet': (10, -90)},
                    'MSL':  {'ABOD': (-10, 90), 'CBLOF': (-30, 150), 'DAGMM': (50, -50), 'DeepSVDD': (-70, -65), 'HBOS': (20, 90), 'HS-Tree': (50, -40),
                             'Hotelling': (90, 190), 'IForest': (35, 60), 'LODA': (60, -50), 'LOF': (-80, -50), 'LUAD': (-100, -40), 'OmniAnomaly': (-55, 250),
                             'PCA': (50, 220), 'USAD': (25, 120), 'LSTM-AE': (-45, 50), 'LSTM-VAE': (-45, -40), 'AnomalyTransformer': (-120, 200), 'TimesNet': (-80, 70)},
                    'SMAP': {'ABOD': (-50, 70), 'CBLOF': (-15, 70), 'DAGMM': (50, -40), 'DeepSVDD': (-5, 170), 'HBOS': (20, 50), 'HS-Tree': (-10, 90),
                             'Hotelling': (40, -50), 'IForest': (80, 120), 'LODA': (90, -10), 'LOF': (-80, 0), 'LUAD': (-80, 0), 'OmniAnomaly': (-170, -40),
                             'PCA': (100, -30), 'USAD': (-50,-50), 'LSTM-AE': (-40, 40), 'LSTM-VAE': (-80, -55), 'AnomalyTransformer': (-40, 240), 'TimesNet': (-50, 150)},
                    'SMD':  {'ABOD': (-50, 40), 'CBLOF': (15, 200), 'DAGMM': (40, 150), 'DeepSVDD': (5, 270), 'HBOS': (50, 50), 'HS-Tree': (60, -30),
                             'Hotelling': (90, 170), 'IForest': (80, 120), 'LODA': (100, -30), 'LOF': (-50, -30), 'LUAD': (-40, 60), 'OmniAnomaly': (-50, -50),
                             'PCA': (20, -55), 'USAD': (-5, 50), 'LSTM-AE': (-30, 90), 'LSTM-VAE': (-20, 230), 'AnomalyTransformer': (-40, 300), 'TimesNet': (-20, 90)},
                    'SWaT': {'ABOD': (-70, 10), 'CBLOF': (30, -120), 'DAGMM': (10, 55), 'DeepSVDD': (-50, 30), 'HBOS': (70, -50), 'HS-Tree': (20, -35),
                             'Hotelling': (50, 70), 'IForest': (50, -90), 'LODA': (40, -90), 'LOF': (-50, 0), 'LUAD': (60, 40), 'OmniAnomaly': (-55, -160),
                             'PCA': (60, 30), 'USAD': (20,-50), 'LSTM-AE': (-30, -60), 'LSTM-VAE': (-80, -40), 'AnomalyTransformer': (95, -90), 'TimesNet': (-10, -45)},
                    'WADI': {'ABOD': (-30, 30), 'CBLOF': (40, 70), 'DAGMM': (-60, 40), 'DeepSVDD': (70, -20), 'HBOS': (80, -30), 'HS-Tree': (50, -45),
                             'Hotelling': (-5, 40), 'IForest': (60, 50), 'LODA': (30, 60), 'LOF': (-40, -20), 'LUAD': (-40, 40), 'OmniAnomaly': (80, 140),
                             'PCA': (70, 30), 'USAD': (80,-10), 'LSTM-AE': (-10, 100), 'LSTM-VAE': (-50, 70), 'AnomalyTransformer': (-10, -40), 'TimesNet': (-40, 140)},}

    # Label placement used when the table above has no entry for a dataset or
    # model: straight above the marker.
    default_offset = (0, 40)
    offsets = offsets_data.get(dataset_name, {})

    type_marker = {'One-class': 'o',        # circle
                   'Reconstruction': 's',   # square
                   'Statistical': '^'}      # triangle

    plt.figure(figsize=(10, 8))
    ax = plt.gca()
    # Draw one marker per model and annotate it with the model name.
    for _, row in data.iterrows():
        color = 'tab:orange' if row['model_name'] in top5 else 'tab:blue'
        # Models without a known family fall back to the circle marker.
        marker = type_marker.get(row['type'], 'o')

        ax.scatter(row[x_col], row[y_col],
                   s=450, c=color, edgecolors='k',
                   marker=marker, label=row['type'], zorder=30)
        dx, dy = offsets.get(row['model_name'], default_offset)
        ax.annotate(row['model_name'],
                    xy=(row[x_col], row[y_col]),
                    xytext=(dx, dy),
                    textcoords='offset points',
                    ha='center', va='center',
                    fontsize=22,
                    arrowprops=dict(arrowstyle='->',
                                    shrinkA=0, shrinkB=8,
                                    lw=2.5,
                                    color='0.1',
                                    relpos=(0.5, 0.8)),
                    zorder=40)

    plt.ylim(0, 0.8)
    # Left limit is the larger value, so the x axis runs from high to low FLOPs.
    plt.xlim(10**15.5, 10**6.5)
    plt.xscale('log')
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.tick_params(axis='x', which='minor', bottom=False, top=False)

    plt.xlabel('Total FLOPs', fontdict={'fontsize': 24})
    plt.ylabel('AUPRC', fontdict={'fontsize': 24})
    # Reference lines splitting the plane into four quadrants.
    plt.axvline(x = 10**11, c = 'k', zorder=10)
    plt.axhline(y = 0.4, c = 'k', zorder=10)
    plt.grid(True, zorder=0)

    # The legend explains marker shapes only, so it is built from hollow
    # proxy markers rather than from the coloured scatter points.
    legend_proxies = [mlines.Line2D([], [], linestyle='None',
                                    marker=type_marker['Statistical'],
                                    markerfacecolor='none', markeredgecolor='k',
                                    markersize=10, label='Statistical'),
                      mlines.Line2D([], [], linestyle='None',
                                    marker=type_marker['One-class'],
                                    markerfacecolor='none', markeredgecolor='k',
                                    markersize=10, label='One-class'),
                      mlines.Line2D([], [], linestyle='None',
                                    marker=type_marker['Reconstruction'],
                                    markerfacecolor='none', markeredgecolor='k',
                                    markersize=10, label='Reconstruction')]

    plt.legend(handles=legend_proxies, loc='upper left', fontsize=20)

    plt.tight_layout()
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        plt.savefig(f"{save_path}/auprc_scatter_{dataset_name}.pdf",
                    dpi=600, bbox_inches='tight', format='pdf')
        plt.close()
    else:
        plt.show()

# Scatter plots (FLOPs vs. AUROC and FLOPs vs. AUPRC) for every dataset in the merged table.
for dataset in merged_result['dataset'].unique():
    plot_auroc_scatter(merged_result, dataset)  # add save_path=None to show the figure instead of saving it
    plot_auprc_scatter(merged_result, dataset)  # add save_path=None to show the figure instead of saving it

# Bar charts (training FLOPs and mean AUROC) for every dataset in the merged table.
for dataset in merged_result['dataset'].unique():
    plot_FLOPs(merged_result, dataset)  # add save_path=None to show the figure instead of saving it
    plot_auroc(merged_result, dataset)  # add save_path=None to show the figure instead of saving it