import glob
import os

import matplotlib
import numpy as np
import pandas as pd
import tensorboard as tb
from IPython.display import display
from matplotlib import pyplot as plt
from packaging import version
from scipy import stats
from tensorboard.backend.event_processing import event_accumulator


class Debug:
    """Mutable holder used to keep an object around for interactive inspection.

    Attributes are set freely by other code in this module; ``read_tb``
    stores the accumulator it loaded most recently in ``obj``.
    """

    def __init__(self):
        # Placeholder value so that ``obj`` exists on every instance even
        # before anything has been stored on it.
        self.obj = 3


# Module-level holder; read_tb() stores the accumulator it just loaded on
# ``debug.obj``.
debug = Debug()

# Display names of the models, arranged in three groups. The names follow
# the spelling produced by get_model_name(); the entries of the last group
# carry its "pretrained" ImageNet suffixes.
model_order = [
    [
        'BetaVAE',
        'Ada-GVAE',
        'SlowVAE',
        'PCL',
    ],
    [
        'MLP',
        'CNN',
        'CoordConv',
        'Coordinate Based',
        'Rotation-EQ',
        'Rotation-EQ-big',
        'Spatial Transformer',
        'RN50',
        'RN101',
        'DenseNet',
    ],
    [
        'RN50 (ImageNet-21k)',
        'RN101 (ImageNet-21k)',
        'DenseNet (ImageNet-1k)',
    ],
]

def read_tb(tb_path):
    """Load the TensorBoard event data found at ``tb_path``.

    Args:
        tb_path: Path handed to ``EventAccumulator``.

    Returns:
        The ``EventAccumulator`` after ``Reload()`` has been called on it.
        As a side effect the same object is stored on the module-level
        ``debug`` holder as ``debug.obj``.
    """
    # Per-category size guidance: scalars use 0 (0 = all events), every
    # other listed category is limited to 1.
    guidance = {
        event_accumulator.COMPRESSED_HISTOGRAMS: 1,
        event_accumulator.IMAGES: 1,
        event_accumulator.AUDIO: 1,
        event_accumulator.SCALARS: 0,
        event_accumulator.HISTOGRAMS: 1,
    }
    accumulator = event_accumulator.EventAccumulator(
        tb_path, size_guidance=guidance)
    accumulator.Reload()
    debug.obj = accumulator
    return accumulator


def filter_nets(name):
    """Return True when ``name`` contains ``'implicit'``, otherwise False.

    An additional ``'0.01'`` check used to follow an unconditional
    ``return`` inside the first branch and therefore could never execute;
    it has been dropped. The observable result is unchanged.

    Args:
        name: String (or other container) tested for ``'implicit'``.

    Returns:
        bool: Result of the membership test.
    """
    return 'implicit' in name


def get_model_name(dirname):
    """Translate a run directory name into a display name.

    The first key (in the order listed below) that occurs as a substring of
    ``dirname`` decides the base name. Suffixes are then appended for the
    substrings ``'lr'``, ``'nl'``, ``'last'`` and ``'pretrained'``.

    Args:
        dirname: Directory name to translate.

    Returns:
        The display name, or ``dirname`` itself (after printing a notice)
        when no key matches.
    """
    # Order is significant: the first match wins, so longer keys are listed
    # before keys that are substrings of them (e.g. 'rotation_frf' before
    # 'rotation', 'coordconv_pooling' before 'coordconv').
    display_names = (
        ('mlp', 'MLP'),
        ('deeper_cnn', 'Deeper CNN'),
        ('implicit', 'Coordinate Based'),
        ('rotation_frf', 'Rotation-EQ-big'),
        ('rotation', 'Rotation-EQ'),
        ('transformer', 'Spatial Transformer'),
        ('coordconv_pooling', 'CoordConv_MP'),
        ('coordconv', 'CoordConv'),
        ('betavae', "BetaVAE"),
        ('slowvae', 'SlowVAE'),
        ('pcl', 'PCL'),
        ('vanilla', 'CNN'),
        ('rn101', 'RN101'),
        ('rn50', 'RN50'),
        ('adagvae', 'Ada-GVAE'),
        ('densenet', 'DenseNet'),
    )
    base = next(
        (label for key, label in display_names if key in dirname), None)
    if base is None:
        print('Name not found, just naming', dirname)
        return dirname

    # Plain substring flags, appended in this fixed order.
    flag_suffixes = (('lr', '_lr'), ('nl', '_nl'), ('last', '_last'))
    pieces = [base]
    pieces.extend(tail for flag, tail in flag_suffixes if flag in dirname)

    if 'pretrained' in dirname:
        if 'dense' in dirname:
            pieces.append(' (ImageNet-1k)')
        elif 'rn50' in dirname or 'rn101' in dirname:
            pieces.append(' (ImageNet-21k)')
        else:
            pieces.append('pretrained')
    return ''.join(pieces)


def highlight_max(data, color='yellow'):
    """Return styling strings that mark the maximum of a Series or DataFrame.

    Values are passed through ``replace('%', '', regex=True)`` and cast to
    float before the maximum is determined. A Series whose first element is
    a string is left completely unstyled.

    Args:
        data: ``pandas.Series`` (one style string per element is returned)
            or ``pandas.DataFrame`` (a same-shaped frame of style strings
            is returned).
        color: Accepted for interface compatibility; the emitted style is
            always ``'font-weight: bold'`` and does not use this value.

    Returns:
        list[str] for a Series, ``pandas.DataFrame`` for a DataFrame. Cells
        equal to the maximum hold ``'font-weight: bold'``, all others ``''``.
    """
    attr = 'font-weight: bold'

    if data.ndim == 1:
        # Inspect the first element by position: ``data[0]`` would be a
        # label lookup and breaks for Series indexed by e.g. column names.
        if len(data) > 0 and isinstance(data.iloc[0], str):
            return ['' for _ in range(len(data))]
        numeric = data.replace('%', '', regex=True).astype(float)
        is_max = numeric == numeric.max()
        return [attr if flag else '' for flag in is_max]

    # DataFrame: strip '%' and cast to float, then compare against the
    # global maximum.
    numeric = data.replace('%', '', regex=True).astype(float)
    is_max = numeric == numeric.max().max()
    return pd.DataFrame(np.where(is_max, attr, ''),
                        index=data.index, columns=data.columns)


def print_f(dataset, df):
    """Show ``df`` via IPython's ``display`` followed by two blank lines.

    Args:
        dataset: Accepted but not used by this function.
        df: Object handed to ``display``.
    """
    display(df)
    for _ in range(2):
        print()


# get one df per dataset with all results
def get_dfs(path, datasets, ood_types):
    """Collect the final TensorBoard values of every run into DataFrames.

    For each dataset and OOD type, directories matching
    ``{path}/{dataset}_{ood_type}*`` are scanned; every entry below
    ``writer*/`` in such a directory is read with ``read_tb`` and becomes
    one row.

    Args:
        path: Root directory containing the run directories.
        datasets: Iterable of dataset names (first part of the directory
            name).
        ood_types: Iterable of OOD type names (second part of the directory
            name).

    Returns:
        dict mapping each dataset name to a ``pandas.DataFrame`` with one
        row per non-empty event directory. Every row has a ``'models'``
        column plus one column per tensor tag (last 13 characters removed)
        and per scalar tag, each holding the last recorded value.

    Raises:
        AssertionError: If a run directory has more than four entries below
            ``writer*/``.
    """
    dfs = {}
    for dataset in datasets:
        rows = []
        for ood_type in ood_types:
            pattern = f'{path}/{dataset}_{ood_type}*'
            files = sorted(glob.glob(pattern))
            if not files:
                # The former in-loop assert could never fire on an empty
                # match; report the situation instead and carry on.
                print('no files in', pattern)
                continue

            for file in files:
                model_paths = glob.glob(f'{file}/writer*/*')
                model_name = get_model_name(os.path.basename(file))

                assert len(model_paths) <= 4, model_paths

                for model_path in model_paths:
                    ea = read_tb(model_path)
                    tags = ea.Tags()

                    if not tags['tensors']:
                        print('emtpy tb', model_path)
                        continue

                    new_row = {'models': model_name}

                    for tag in tags['tensors']:
                        # Last field of the last event; its first
                        # string_val entry is decoded as UTF-8.
                        last_event = ea.Tensors(tag)[-1]
                        text = last_event[-1].string_val[0].decode("utf-8")
                        new_row[tag[:-13]] = text  # remove text_summary suffix

                    for tag in tags['scalars']:
                        new_row[tag] = ea.Scalars(tag)[-1].value

                    # Mark runs that report only_train_last_layer == 'True'.
                    if new_row.get('only_train_last_layer') == 'True':
                        new_row['models'] = new_row['models'] + '_last'
                    rows.append(new_row)

        dfs[dataset] = pd.DataFrame(rows)
    return dfs


def do_violin_plot(dfs, metric, all_models=None, ood_types=None):
    """Draw one row of violin plots per dataset, one subplot per OOD setting.

    Args:
        dfs: dict mapping dataset name to a DataFrame with at least the
            columns ``'models'``, ``'modification'`` and ``metric``.
        metric: Column to plot; values are multiplied by 100 and clipped to
            ``[0, 1000]``.
        all_models: Models to plot, in x-axis order. When None, the unique
            values of each dataset's ``'models'`` column are used.
        ood_types: OOD settings to plot, one subplot each. When None, the
            unique values of each dataset's ``'modification'`` column are
            used.

    Returns:
        ``(fig, ax)`` for the last dataset processed, where ``ax`` is the
        last subplot of that figure; ``(None, None)`` when ``dfs`` is empty.
    """
    fig, ax = None, None
    for dataset, df in dfs.items():
        # Resolve the defaults per dataset so that one dataset's models or
        # settings do not leak into the next one.
        if all_models is None:
            models = np.unique(df['models'])
        else:
            models = all_models
        if ood_types is None:
            settings = np.unique(df['modification'])
        else:
            settings = ood_types
        print('mod3els', models)
        print(dataset, metric + "01")

        # squeeze=False keeps a 2-D axes array even for a single subplot.
        fig, axs = plt.subplots(1, len(settings), figsize=(30, 3),
                                squeeze=False)
        for ax, ood_setting in zip(axs[0], settings):
            for position, model in enumerate(models):
                data = df[(df['models'] == model)
                          & (df['modification'] == ood_setting)]
                if data.shape[0] == 0:
                    print('missing values for model', model, ood_setting,
                          dataset)
                    print('data shape', data.shape)
                    continue
                values = data[metric].to_numpy() * 100
                ax.violinplot(np.clip(values, 0, 1000), positions=[position])

            ax.set_ylim(-10, 110)
            ax.set_title(ood_setting)
            ax.set_xticks(range(len(models)))
            # Tick labels are truncated to eight characters.
            ax.set_xticklabels([n[:8] for n in models], rotation=45)
    return fig, ax


def plot_bars(dataset,
              ood_types,
              model_dicts,
              metric,
              colors=None,
              model_names=None,
              figsize=(12, 1.2),
              pairs=False):
    """Draw grouped bars (mean with std error bars) of ``metric`` per model.

    Args:
        dataset: Key used to pick the DataFrame out of every entry of
            ``model_dicts``.
        ood_types: OOD settings; each one produces a group of bars.
        model_dicts: Sequence of dicts mapping dataset name to a DataFrame
            with the columns ``'models'``, ``'modification'`` and
            ``metric``.
        metric: Column whose values are averaged per bar.
        colors: One colormap (name or object) per entry of ``model_dicts``.
            When None, ``'viridis'`` is used for every entry.
        model_names: One sequence of model names per entry of
            ``model_dicts``. When None, the unique ``'models'`` values of
            each entry's DataFrame are used.
        figsize: Figure size passed to ``plt.subplots``.
        pairs: When True, two extra positions are skipped after every bar
            placed at an odd x position.

    Returns:
        ``(fig, ax, tick_positions, n_models)`` where ``tick_positions``
        holds one x position per OOD type and ``n_models`` is the number of
        bars drawn for the first OOD type.
    """
    if model_names is None:
        model_names = [np.unique(dfs[dataset]['models'])
                       for dfs in model_dicts]
    if colors is None:
        colors = ['viridis'] * len(model_dicts)

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    i = -5  # running x position; advanced before every group and bar
    tick_positions = []
    n_models = 0
    for ood_int, ood_type in enumerate(ood_types):
        i += 4
        tick_positions.append(i + 3)
        for dfs, c, models in zip(model_dicts, colors, model_names):
            df = dfs[dataset]
            # pyplot.get_cmap is used because matplotlib.cm.get_cmap is no
            # longer available in recent Matplotlib releases.
            cmap = plt.get_cmap(c)
            for col, model in enumerate(models):
                i += 1

                color = cmap((col + 1) / len(models))
                data = df[
                    (df['models'] == model) & (df['modification'] == ood_type)]
                data = data[metric].to_numpy()
                data = np.clip(data, 0.01, 1000)
                mean = np.mean(data)
                if np.isnan(mean):
                    print('model', model, 'mod', ood_type, 'mean', mean,
                          'nan-mean', np.nanmean(data), data)
                ax.bar(i, np.clip(mean, 0.01, 1000), yerr=np.std(data),
                       color=color, width=1.0,
                       label=model,
                       capsize=3.)
                if ood_int == 0:
                    n_models += 1
                if pairs and i % 2 == 1:
                    i += 2

    ax.set_xticks(np.array(tick_positions))

    # Hide the right and top spines
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    return fig, ax, tick_positions, n_models


# model selection
def model_selection(df,
                    model_name,
                    modification='extrapolation',
                    score='dis_lib_full/dci'):
    """Select the runs of ``model_name`` with the best-scoring hyper-parameter.

    For every value of the model's hyper-parameter column, the mean of
    ``score`` over the rows with the requested ``modification`` is computed;
    the rows belonging to the value with the highest mean are returned.
    Values without any usable score (mean is NaN) are ignored unless every
    value is NaN, in which case the first value is used.

    Args:
        df: DataFrame with the columns ``'models'``, ``'modification'``,
            ``score`` and the model's hyper-parameter column.
        model_name: One of ``'BetaVAE'``, ``'SlowVAE'``, ``'Ada-GVAE'``.
        modification: Value of the ``'modification'`` column to select.
        score: Column used to rank the hyper-parameter values.

    Returns:
        The rows of ``df`` for ``model_name``, ``modification`` and the
        best hyper-parameter value.

    Raises:
        KeyError: If ``model_name`` has no known hyper-parameter column.
        ValueError: If ``df`` holds no hyper-parameter value for
            ``model_name``.
    """
    model_to_factors = {'BetaVAE': 'vae_beta',
                        'SlowVAE': 'slowvae_gamma',
                        'Ada-GVAE': 'vae_beta',
                        }
    hyper_param_name = model_to_factors[model_name]

    is_model = df['models'] == model_name
    in_setting = is_model & (df['modification'] == modification)

    hyper_param_values = np.unique(df[is_model][hyper_param_name])
    if len(hyper_param_values) == 0:
        raise ValueError(
            f'no {hyper_param_name} values found for model {model_name}')

    score_per_hyper_param = np.array(
        [np.mean(df[in_setting & (df[hyper_param_name] == value)][score])
         for value in hyper_param_values],
        dtype=float)

    # np.argmax would treat NaN (no runs for this modification) as the
    # maximum, so NaN entries are skipped whenever a real score exists.
    if np.isnan(score_per_hyper_param).all():
        best_index = 0
    else:
        best_index = int(np.nanargmax(score_per_hyper_param))
    best_hyper_param = hyper_param_values[best_index]

    return df[in_setting & (df[hyper_param_name] == best_hyper_param)]
