import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# FLOP-count table. The path is relative to the current working directory.
# A 'model_name' column is required below; 'dataset' and the FLOP columns are
# handled as described where they are used.
flops_result = pd.read_csv('analysis/results/model_flops.csv')

# Throughput of each hardware scenario. FLOP counts are divided by these rates
# and the result is plotted as seconds, so the values act as FLOP/s.
# A 'gpu' rate of 0 makes every model fall back to the 'cpu' rate (see choose_rate).
# NOTE(review): the mobile 'gpu' rate (1.9e9) is lower than the mobile 'cpu'
# rate (14.8e9) -- confirm the unit/exponent is intended.
our_settings   = {'cpu': 112e9, 'gpu': 21.952e12}
mobile_setting = {'cpu': 14.8e9,  'gpu': 1.9e9}
edge_setting   = {'cpu': 5.6e9,   'gpu': 0.0}

# Model-name groups. Membership in DL_models decides whether a model may use
# the 'gpu' rate and whether its x-axis slot is shaded in the figures.
ML_models = ['HBOS', 'LODA', 'ABOD', 'PCA', 'LOF', 'Hotelling', 'IForest', 'HSTree', 'CBLOF']
DL_models = ['USAD', 'DAGMM', 'LUAD', 'lstmAE', 'lstmVAE', 'OmniAnomaly', 'DeepSVDD']
# Not referenced in the part of the file visible here.
DEFAULT_MODEL_ORDER = ML_models + DL_models

# Locate the optional deep-learning FLOP column. Both spellings are accepted;
# the plural one is renamed so that later code only deals with 'DL_per_flop'.
# dl_col stays None when the CSV has neither.
if 'DL_per_flop' in flops_result.columns:
    dl_col = 'DL_per_flop'
elif 'DL_per_flops' in flops_result.columns:
    flops_result = flops_result.rename(columns={'DL_per_flops': 'DL_per_flop'})
    dl_col = 'DL_per_flop'
else:
    dl_col = None

# Coerce the FLOP columns that are present to numbers; unparsable cells become NaN.
for col in ['train_flops', 'inference_flops', 'total_flops']:
    if col in flops_result.columns:
        flops_result[col] = pd.to_numeric(flops_result[col], errors='coerce')

if dl_col is not None:
    flops_result[dl_col] = pd.to_numeric(flops_result[dl_col], errors='coerce')

# Identifier columns are compared against string lists later on, so force str.
# 'model_name' is mandatory (a KeyError is raised here if it is absent);
# 'dataset' is optional.
flops_result['model_name'] = flops_result['model_name'].astype(str)
if 'dataset' in flops_result.columns:
    flops_result['dataset'] = flops_result['dataset'].astype(str)
# Boolean row mask: True for rows whose model is listed in DL_models.
is_gpu_model = flops_result['model_name'].isin(DL_models)
# True only when the DL column exists and holds at least one finite value on a DL-model row.
has_valid_dl = False
if dl_col is not None:
    has_valid_dl = np.any(np.isfinite(flops_result.loc[is_gpu_model, dl_col]))

# Effective training FLOPs: for DL-model rows with a finite, positive value in
# the DL column, that value replaces 'train_flops'; every other row keeps
# 'train_flops'. Note that 'train_flops' and 'inference_flops' are required
# from here on, although the coercion loop above treats them as optional.
if has_valid_dl:
    flops_result['train_flops_eff'] = np.where(is_gpu_model & np.isfinite(flops_result[dl_col]) & (flops_result[dl_col] > 0),
                                               flops_result[dl_col],
                                               flops_result['train_flops'])
else:
    flops_result['train_flops_eff'] = flops_result['train_flops']

# Total is recomputed from the effective training FLOPs; the CSV's own
# 'total_flops' column is not used for this.
flops_result['total_flops_eff'] = flops_result['train_flops_eff'] + flops_result['inference_flops']

def choose_rate(model_name: str, setting: dict) -> float:
    """Return the throughput rate that applies to ``model_name`` under ``setting``.

    A model listed in ``DL_models`` uses the ``'gpu'`` rate when that rate is
    positive. In every other case the ``'cpu'`` rate is returned. A missing
    key counts as ``0.0``.
    """
    gpu_rate = float(setting.get('gpu', 0.0))
    # The membership test is only reached when a usable GPU rate exists.
    if gpu_rate > 0 and model_name in DL_models:
        return gpu_rate
    return float(setting.get('cpu', 0.0))

def add_time_columns(df: pd.DataFrame, setting: dict, tag: str) -> pd.DataFrame:
    """Add ``train_time_<tag>``, ``inference_time_<tag>`` and ``total_time_<tag>`` to ``df``.

    Each column is the matching FLOP column divided by the per-row rate that
    ``choose_rate`` picks for the row's model. Rows whose rate is not positive
    get NaN. ``df`` is modified in place and also returned.
    """
    per_row_rate = df['model_name'].map(lambda name: choose_rate(name, setting)).astype(float)
    rate_is_usable = per_row_rate > 0

    # (output prefix, source FLOP column), in the order the columns are appended.
    flop_sources = (
        ('train', 'train_flops_eff'),
        ('inference', 'inference_flops'),
        ('total', 'total_flops_eff'),
    )
    for prefix, flops_col in flop_sources:
        df[f'{prefix}_time_{tag}'] = np.where(rate_is_usable, df[flops_col] / per_row_rate, np.nan)
    return df

# Append the estimated-time columns of every hardware scenario. Each tag
# becomes the suffix of its `*_time_<tag>` columns.
scenario_settings = {
    'Highly Resourced': our_settings,
    'Mobile': mobile_setting,
    'Edge': edge_setting,
}
for tag, setting in scenario_settings.items():
    flops_result = add_time_columns(flops_result, setting, tag)

# Quick sanity check: identifier and time columns of the first rows.
preview = flops_result.filter(regex='^(dataset|model_name|.*_time_).*$')
print(preview.head())


def _mean_by_model_one(df: pd.DataFrame, col: str, datasets: list | None) -> pd.DataFrame:
    use = df if (datasets is None or 'dataset' not in df.columns) else df[df['dataset'].isin(datasets)]
    g = use.groupby('model_name', dropna=True)[[col]].mean(numeric_only=True)
    g = g.rename(columns={col: 'mean'}).reset_index()
    return g

def _collect_means_for_metric(df: pd.DataFrame, metric: str, datasets: list | None):
    """Return ``{scenario: per-model mean frame}`` for the ``<metric>_time_<scenario>`` columns.

    The scenarios are 'Highly Resourced', 'Mobile' and 'Edge', in that order.
    ``datasets`` is forwarded unchanged to ``_mean_by_model_one``.
    """
    scenarios = ('Highly Resourced', 'Mobile', 'Edge')
    return {
        scenario: _mean_by_model_one(df, f'{metric}_time_{scenario}', datasets)
        for scenario in scenarios
    }

def _order_by_ours(means_dict) -> list[str]:
    """Return model names sorted by ascending 'Highly Resourced' mean.

    Models whose mean is NaN are left out of the result.
    """
    reference = means_dict['Highly Resourced']
    ranked = reference.dropna(subset=['mean']).sort_values('mean', ascending=True)
    return ranked['model_name'].tolist()


def plot_estimated_time(data, order, ylim_sets, data_type='train', dataset=None, save_path='./analysis/figures/FLOPS_scenarios'):
    """Draw a broken-y-axis line chart of estimated time per model and save it as a PDF.

    One panel is stacked per entry of ``ylim_sets`` so that values spanning
    several orders of magnitude stay readable. Diagonal marks on the facing
    edges of neighbouring panels indicate the axis breaks.

    Args:
        data: Mapping with the keys ``'Highly Resourced'``, ``'Mobile'`` and
            ``'Edge'``; each value is a DataFrame with ``'model_name'`` and
            ``'mean'`` columns.
        order: Model names in x-axis order. A name missing from a scenario's
            frame is plotted as a gap (NaN after reindexing).
        ylim_sets: ``[low, high]`` y-limits, one pair per panel, listed from
            the bottom panel to the top panel.
        data_type: ``'train'`` or ``'inference'``; selects the preset y ticks
            of the all-dataset figure and is embedded in the file name.
        dataset: Dataset name used in the file name, or ``None`` for the
            all-dataset mean figure.
        save_path: Output directory; created when missing.
    """
    nrows = len(ylim_sets)
    # Panels are created top-to-bottom, whereas the limits arrive bottom-to-top.
    ylim_sets = list(reversed(ylim_sets))
    # squeeze=False keeps an array of Axes even for a single panel, so every
    # loop and index below works for any positive panel count.
    fig, axes = plt.subplots(nrows, 1, sharex=False, figsize=(18, 10), squeeze=False)
    axes = axes[:, 0]
    fig.subplots_adjust(hspace=0.05)

    scenario_styles = (
        ('Highly Resourced', dict(marker='o', linestyle='-', color='black')),
        ('Mobile', dict(marker='s', linestyle='--', color='blue')),
        ('Edge', dict(marker='^', linestyle='-.', color='red')),
    )
    # The series do not depend on the panel, so align them to `order` once.
    series = {
        name: data[name].set_index('model_name').reindex(order)['mean']
        for name, _ in scenario_styles
    }
    # Every panel shows the same three curves; only the y-window differs.
    for ax in axes:
        for name, style in scenario_styles:
            ax.plot(series[name], label=name, markersize=8, **style)

    # Shade the x-slots that belong to deep-learning models.
    for i, name in enumerate(order):
        if name in DL_models:
            for ax in axes:
                ax.axvspan(i - 0.5, i + 0.5, color='gray', alpha=0.1, zorder=0)

    for ax, ylim in zip(axes, ylim_sets):
        ax.set_ylim(ylim)

    # Only the bottom panel keeps x ticks and labels.
    for ax in axes[:-1]:
        ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    axes[-1].xaxis.tick_bottom()

    # Slanted break marks, positioned in axes coordinates at the panel corners.
    d = 0.3
    break_marker = dict(marker=[(-1, -d), (1, d)], markersize=12,
                        linestyle="none", color='k', mec='k', mew=1, clip_on=False)
    for i, ax in enumerate(axes):
        if i > 0:
            # A panel sits above this one: mark the top edge.
            ax.plot([0, 1], [1, 1], transform=ax.transAxes, **break_marker)
        if i < nrows - 1:
            # A panel sits below this one: mark the bottom edge.
            ax.plot([0, 1], [0, 0], transform=ax.transAxes, **break_marker)

    if dataset is None:
        # Hand-tuned ticks for the all-dataset mean figure. Entries are
        # (panel index, ticks) and are applied in order, after set_ylim.
        if data_type == 'train':
            tick_presets = (
                (0, [15000, 30000, 45000]),
                (1, [1500, 3000, 4500]),
                (2, [150, 300, 450]),
                (-1, [0, 4, 8]),
            )
        elif data_type == 'inference':
            tick_presets = ((0, [500, 1500, 2500]),)
        else:
            tick_presets = ()
        for panel, ticks in tick_presets:
            # Skip presets for panels that do not exist in this figure.
            if -nrows <= panel < nrows:
                axes[panel].set_yticks(ticks)

    for ax in axes:
        ax.tick_params(axis='x', which='major', length=10, width=1.5, labelsize=18)
        ax.tick_params(axis='y', which='major', length=10, width=1.5, labelsize=18)
    fig.supylabel("Estimated Time (s)", fontsize=24, x=0.06)
    # Model names go on the bottom panel only.
    axes[-1].set_xticks(np.arange(len(order)))
    axes[-1].set_xticklabels(order, rotation=30, ha='right', fontsize=18)
    axes[0].legend(loc='upper left', ncol=3, fontsize=20)

    os.makedirs(save_path, exist_ok=True)
    if dataset is not None:
        out_file = f"{save_path}/Estimated_time_{dataset}_{data_type}.pdf"
    else:
        out_file = f"{save_path}/Estimated_time_mean_{data_type}.pdf"
    fig.savefig(out_file, dpi=600, bbox_inches='tight', format='pdf')
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Per-model means over all datasets, one frame per hardware scenario.
# NOTE(review): `train_means` is built from the 'total' metric (effective
# training + inference time) although it is named and plotted as 'train' --
# confirm this is intended.
train_means = _collect_means_for_metric(flops_result, metric='total', datasets=None)
infer_means = _collect_means_for_metric(flops_result, metric='inference', datasets=None)
# x-axis order: ascending 'Highly Resourced' mean of the frames above. The same
# order is reused for every figure produced below.
order = _order_by_ours(train_means)

# Hand-picked y-windows for the broken-axis figures, keyed by 'mean' (all
# datasets) or by dataset name. Each list goes from the bottom panel to the top
# panel, and its length sets the number of panels; values that fall between
# two consecutive windows are not drawn.
ylim_sets = {'mean': {'train_ylim': [[0, 10], [30, 500], [520, 5000], [5500, 45000]],
                      'infer_ylim': [[0, 2.5], [2.6, 16], [40, 350], [400, 2500]]},
             'PSM':  {'train_ylim': [[0, 15], [15, 150], [500, 3500], [3500, 25000]],
                      'infer_ylim': [[0, 2.5], [2.6, 15], [20, 250]]},
             'MSL': {'train_ylim': [[0, 20], [20, 200], [200, 1500]],
                     'infer_ylim': [[0, 2.5], [2.6, 15], [20, 150]]},
             'SMAP': {'train_ylim': [[0, 25], [28, 260], [280, 3500]],
                      'infer_ylim': [[0, 2.5], [3, 40], [50, 400], [500, 2300]]},
             'SMD': {'train_ylim': [[0, 15], [50, 1700], [2800, 15000], [16000, 120000]],
                     'infer_ylim': [[0, 10], [10, 35], [80, 1200], [2000, 8100]]},
             'SWaT': {'train_ylim': [[0, 70], [80, 1050], [1500, 10000], [11000, 18000]],
                      'infer_ylim': [[0, 1.3], [1.3, 10], [15, 190], [800, 5000]]},
             'WADI': {'train_ylim': [[0, 20], [80, 1400], [1700, 15000], [17000, 130000]],
                      'infer_ylim': [[0, 1.3], [1.4, 12], [15, 190], [200, 1400]]}}
# All-dataset figures (dataset=None selects the "mean" file name and the preset y ticks).
plot_estimated_time(train_means, order, ylim_sets['mean']['train_ylim'], data_type='train')
plot_estimated_time(infer_means, order, ylim_sets['mean']['infer_ylim'], data_type='inference')


# One 'train' and one 'inference' figure per dataset. The means are restricted
# to that dataset's rows, while the x-axis keeps the model order derived from
# the all-dataset means so the figures are comparable.
for dataset in ['PSM', 'MSL', 'SMAP', 'SMD', 'SWaT', 'WADI']:
    # As above, the 'train' figure is fed the 'total' metric.
    train_means_ds = _collect_means_for_metric(flops_result, metric='total', datasets=[dataset])
    infer_means_ds = _collect_means_for_metric(flops_result, metric='inference', datasets=[dataset])
    plot_estimated_time(train_means_ds, order, ylim_sets[dataset]['train_ylim'], data_type='train', dataset=dataset)
    plot_estimated_time(infer_means_ds, order, ylim_sets[dataset]['infer_ylim'], data_type='inference', dataset=dataset)