# -*- coding: utf-8 -*-
import argparse
import json
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d

# Named colours (CSS colour keywords) mapped to their hex codes.
# Insertion order matters: the plotting code pops entries from the END of a
# copy of this dict to pick one colour name per plotted curve.
COLORS = {
    'aliceblue': '#F0F8FF', 'aqua': '#00FFFF', 'aquamarine': '#7FFFD4',
    'azure': '#F0FFFF', 'beige': '#F5F5DC', 'bisque': '#FFE4C4',
    'black': '#000000', 'blanchedalmond': '#FFEBCD', 'blue': '#0000FF',
    'blueviolet': '#8A2BE2', 'brown': '#A52A2A', 'burlywood': '#DEB887',
    'cadetblue': '#5F9EA0', 'chartreuse': '#7FFF00', 'chocolate': '#D2691E',
    'coral': '#FF7F50', 'cornflowerblue': '#6495ED', 'cornsilk': '#FFF8DC',
    'crimson': '#DC143C', 'cyan': '#00FFFF', 'darkblue': '#00008B',
    'darkcyan': '#008B8B', 'darkgoldenrod': '#B8860B', 'darkgray': '#A9A9A9',
    'darkgreen': '#006400', 'darkkhaki': '#BDB76B', 'darkmagenta': '#8B008B',
    'darkolivegreen': '#556B2F', 'darkorange': '#FF8C00', 'darkorchid': '#9932CC',
    'darkred': '#8B0000', 'darksalmon': '#E9967A', 'darkseagreen': '#8FBC8F',
    'darkslateblue': '#483D8B', 'darkslategray': '#2F4F4F', 'darkturquoise': '#00CED1',
    'darkviolet': '#9400D3', 'deeppink': '#FF1493', 'deepskyblue': '#00BFFF',
    'dimgray': '#696969', 'dodgerblue': '#1E90FF', 'firebrick': '#B22222',
    'forestgreen': '#228B22', 'fuchsia': '#FF00FF', 'gainsboro': '#DCDCDC',
    'gold': '#FFD700', 'goldenrod': '#DAA520', 'gray': '#808080',
    'green': '#008000', 'greenyellow': '#ADFF2F', 'honeydew': '#F0FFF0',
    'hotpink': '#FF69B4', 'indianred': '#CD5C5C', 'indigo': '#4B0082',
    'ivory': '#FFFFF0', 'khaki': '#F0E68C', 'lavender': '#E6E6FA',
    'lavenderblush': '#FFF0F5', 'lawngreen': '#7CFC00', 'lemonchiffon': '#FFFACD',
    'lightblue': '#ADD8E6', 'lightcoral': '#F08080', 'lightcyan': '#E0FFFF',
    'lightgoldenrodyellow': '#FAFAD2', 'lightgreen': '#90EE90', 'lightgray': '#D3D3D3',
    'lightpink': '#FFB6C1', 'lightsalmon': '#FFA07A', 'lightseagreen': '#20B2AA',
    'lightskyblue': '#87CEFA', 'lightslategray': '#778899', 'lightsteelblue': '#B0C4DE',
    'lightyellow': '#FFFFE0', 'lime': '#00FF00', 'limegreen': '#32CD32',
    'linen': '#FAF0E6', 'magenta': '#FF00FF', 'maroon': '#800000',
    'mediumaquamarine': '#66CDAA', 'mediumblue': '#0000CD', 'mediumorchid': '#BA55D3',
    'mediumpurple': '#9370DB', 'mediumseagreen': '#3CB371', 'mediumslateblue': '#7B68EE',
    'mediumspringgreen': '#00FA9A', 'mediumturquoise': '#48D1CC', 'mediumvioletred': '#C71585',
    'midnightblue': '#191970', 'mintcream': '#F5FFFA', 'mistyrose': '#FFE4E1',
    'moccasin': '#FFE4B5', 'navy': '#000080', 'oldlace': '#FDF5E6',
    'olive': '#808000', 'olivedrab': '#6B8E23', 'orange': '#FFA500',
    'orangered': '#FF4500', 'orchid': '#DA70D6', 'palegoldenrod': '#EEE8AA',
    'palegreen': '#98FB98', 'paleturquoise': '#AFEEEE', 'palevioletred': '#DB7093',
    'papayawhip': '#FFEFD5', 'peachpuff': '#FFDAB9', 'peru': '#CD853F',
    'pink': '#FFC0CB', 'plum': '#DDA0DD', 'powderblue': '#B0E0E6',
    'purple': '#800080', 'red': '#FF0000', 'rosybrown': '#BC8F8F',
    'royalblue': '#4169E1', 'saddlebrown': '#8B4513', 'salmon': '#FA8072',
    # CSS defines sandybrown as #F4A460 (the previous '#FAA460' was a typo).
    'sandybrown': '#F4A460', 'seagreen': '#2E8B57', 'seashell': '#FFF5EE',
    'sienna': '#A0522D', 'silver': '#C0C0C0', 'skyblue': '#87CEEB',
    'slateblue': '#6A5ACD', 'slategray': '#708090', 'snow': '#FFFAFA',
    'springgreen': '#00FF7F', 'steelblue': '#4682B4', 'tan': '#D2B48C',
    'teal': '#008080', 'thistle': '#D8BFD8', 'tomato': '#FF6347',
    'turquoise': '#40E0D0', 'violet': '#EE82EE', 'wheat': '#F5DEB3',
    'yellow': '#FFFF00', 'yellowgreen': '#9ACD32',
}

# Lift pandas' row limit so the statistics tables printed below are shown in
# full instead of being truncated.
pd.set_option('display.max_rows', None)

def interp(x, y, num=1000):
    """Resample the curve ``(x, y)`` on an evenly spaced grid.

    A piecewise-linear interpolant is built from the samples and evaluated at
    ``num`` points spread uniformly between the smallest and largest ``x``.

    Args:
        x: sample positions.
        y: sample values, one per entry of ``x``.
        num: number of grid points to generate.

    Returns:
        A ``(grid, values)`` tuple: the new x positions and the interpolated
        y values at those positions.
    """
    linear_fit = interp1d(x, y, kind='linear')
    lower = min(x)
    upper = max(x)
    grid = np.linspace(lower, upper, num=num)
    resampled = linear_fit(grid)
    return grid, resampled

# Optional single leading minus, a decimal mantissa ("5", "5.", "5.2", ".5")
# and an optional exponent ("e-3").  ASCII digits only, so every match is
# guaranteed to be accepted by float().
_NUMERIC_RE = re.compile(r'-?(?:[0-9]+\.?[0-9]*|\.[0-9]+)(?:[eE][+-]?[0-9]+)?')


def is_numeric(s):
    """Return True if the string ``s`` is a plain decimal number.

    Accepted forms are an optional leading ``-``, digits with at most one
    decimal point, and an optional exponent (e.g. ``'3'``, ``'-0.5'``,
    ``'.5'``, ``'1e-3'``).  Anything else -- including repeated minus signs
    such as ``'--5'``, the empty string, or non-ASCII digit characters --
    yields False, so a True result can always be passed to ``float``.

    Args:
        s: the string to test.

    Returns:
        bool: whether ``s`` looks like a number.
    """
    return _NUMERIC_RE.fullmatch(s) is not None

def hyper_parameter_split(hps_str):
    """Parse an encoded hyper-parameter string into a pandas Series.

    Every ``#...#`` segment found in ``hps_str`` is stripped of its ``#``
    delimiters and split on ``,`` into a (name, value) pair.  Values that
    ``is_numeric`` recognises are converted to ``float``; the others are kept
    as strings.

    Args:
        hps_str: the encoded hyper-parameter string.

    Returns:
        pandas.Series indexed by hyper-parameter name (empty when the string
        contains no ``#...#`` segment).
    """
    segments = re.findall(r"#.*?#", hps_str)
    # First pass: collect the raw name -> text pairs (later duplicates win).
    raw_pairs = dict(segment.strip('#').split(',') for segment in segments)
    # Second pass: turn number-like texts into floats.
    converted = {
        name: float(text) if is_numeric(text) else text
        for name, text in raw_pairs.items()
    }
    return pd.Series(converted)

def _min_max_scale(values):
    """Rescale ``values`` linearly to [0, 1]; a constant array maps to zeros."""
    low, high = values.min(), values.max()
    if high > low:
        return (values - low) / (high - low)
    return 0 * values


def excel_show(args, target_col='accuracy'):
    """Interactively analyse and plot the metric scores stored in Excel sheets.

    The function loads one sheet per entry of ``args.excels``, asks the user
    (via ``input``) for a mode, a hyper-parameter selection and the metrics to
    draw, prints summary statistics and finally saves and shows a figure.

    Modes:
        ``'0'``: ``target_col`` (scaled by 100) is the x axis.  For every
            sheet the correlation of each metric with ``target_col`` and the
            gap between the best/worst ``target_col`` and the value picked by
            each metric's argmax/argmin are printed.
        ``'1'``: a hyper-parameter column (or ``train_epoch`` / ``method``)
            chosen by the user is the x axis; metrics are aggregated per
            x value and min-max normalised before plotting.
        anything else: the function returns without plotting.

    Args:
        args: namespace providing ``excels`` (paths of sheets, or directories
            containing ``metric_scores.xlsx``), ``last_epoch``, ``concat``,
            ``per_class_eval``, ``fig_path`` and optionally ``max_col``
            (use max instead of mean when aggregating in mode ``'1'``;
            treated as False when the attribute is absent).
        target_col: column used as reference metric / default x axis.

    Raises:
        ValueError: if ``args.excels`` is empty or, in mode ``'0'``, no row
            matches the selected hyper-parameters.
    """
    if not args.excels:
        raise ValueError("no metric excels given; pass at least one path via --excels")

    # ------------------------------------------------------------------ load
    dfs = []
    excel_names = []
    val_cols = None
    for excel in args.excels:
        if os.path.isdir(excel):
            excel = os.path.join(excel, "metric_scores.xlsx")
        df = pd.read_excel(excel, engine='openpyxl')
        if 'hyper-parameters' not in df.columns:
            df['hyper-parameters'] = 'agnostic'
        # Keep only the columns that are present in every sheet.
        if val_cols is None:
            val_cols = set(df.columns)
        else:
            val_cols = set(df.columns) & val_cols
        if args.last_epoch:
            # NOTE(review): "last epoch" is hard-coded as epoch index 9.
            # .copy() makes the filtered rows an independent frame so the
            # 'method' column below is not written onto a slice.
            df = df[df["train_epoch"] == 9].copy()
        # The name of the directory holding the sheet identifies the run;
        # os.path keeps this working with either path separator.
        excel_name = os.path.basename(os.path.dirname(excel))
        df['method'] = excel_name.split("_")[1]
        dfs.append(df)
        excel_names.append(excel_name)

    general_metrics = ['accuracy', 'source_accuracy']
    discrepancy_metrics = ['a_distance', 'MCD', 'MDD']
    assign_cost_metrics = ['entropy', 'MI', 'SND', 'ClassAMI', 'ISM']
    image_metrics = ['ACM']
    statistic_metrics = ['IWCV', 'DEV', 'DEVN', 'BNM']
    all_metrics = (general_metrics + discrepancy_metrics + assign_cost_metrics
                   + image_metrics + statistic_metrics)
    # Known metrics shared by all sheets, in the canonical order above
    # ('hyper-parameters' and 'class accuracy' are not metrics and drop out).
    val_cols = [col for col in all_metrics if col in val_cols]
    print(val_cols)

    if args.concat:
        # Merge the sheets that share the 1st and 3rd "_"-separated name
        # fields, i.e. that differ only in the 2nd field (the method).
        dataset_transfers = {}
        for df, excel_name in zip(dfs, excel_names):
            fields = excel_name.split("_")
            name = fields[0] + "_" + fields[2]
            dataset_transfers.setdefault(name, []).append(df)
        dfs = [pd.concat(df_list, axis=0) for df_list in dataset_transfers.values()]
        excel_names = list(dataset_transfers)
        print(excel_names)

    # ------------------------------------------------------------ user input
    mode = input(f"input the mode:\n0): {target_col} as x-tick;\n1): hyper-parameter as x-tick;\nelse): quit;\n")
    if mode not in ('0', '1'):
        # The prompt promises that any other answer quits.
        return

    if mode == '0':
        for df in dfs:
            df[target_col] = 100 * df[target_col]

    hps = set()
    for df in dfs:
        hps |= set(df['hyper-parameters'])
    hps = sorted(hps)

    hp_info = ''.join(f"{i}): {hp};\n" for i, hp in enumerate(hps))
    hp_info += 'a): all;\n'

    input_hps = input("input the hyper-parameter(split with ',' if you have multiple hyper-parameters):\n" + hp_info)
    selected_hps = hps
    if input_hps != 'a':
        selected_hps = [hps[int(hp)] for hp in input_hps.split(',')]

    # ------------------------------------------------------------ statistics
    if mode == '0':
        matched = [(df[df['hyper-parameters'].isin(selected_hps)], name)
                   for name, df in zip(excel_names, dfs)
                   if set(selected_hps) & set(df['hyper-parameters'])]
        if not matched:
            raise ValueError("no rows match the selected hyper-parameters")
        all_selected = [frame for frame, _ in matched]
        groups = [name for _, name in matched]

        acc_infos = []
        max_accs = []
        best_accs = []
        for group, selected in zip(groups, all_selected):
            x = selected[target_col].values
            df_hp = selected['hyper-parameters'].values
            df_epoch = selected['train_epoch'].values
            if args.per_class_eval:
                df_class = selected['class accuracy'].values
            selected = selected[val_cols]

            corr = selected.corr()[target_col]
            min_acc, max_acc = x.min(), x.max()
            # Row position of every metric's maximum / minimum ...
            best_idx = selected.apply(lambda _col: _col.argmax())
            worst_idx = selected.apply(lambda _col: _col.argmin())
            # ... and the target value found at that row.
            best_acc = best_idx.map(lambda idx: x[idx])
            worst_acc = worst_idx.map(lambda idx: x[idx])
            acc_info = pd.concat([corr, max_acc - best_acc, worst_acc - min_acc], axis=1,
                                 keys=['corr', 'diff_best', 'diff_worst'])
            print(f"hyper-parameter or excel-name '{group}':")
            print(f"min_acc: {min_acc}, max_acc: {max_acc}")
            print(f"the relations of the {target_col} and other metric:\n{acc_info}")
            print("==================================================")
            acc_infos.append(acc_info)
            max_accs.append(max_acc)
            print(df_hp[best_idx[target_col]], df_epoch[best_idx[target_col]])
            if args.per_class_eval:
                print(df_class[best_idx[target_col]])
            if 'ACM' in val_cols:
                best_accs.append(best_acc['ACM'])
                print(df_hp[best_idx['ACM']], df_epoch[best_idx['ACM']])
                if args.per_class_eval:
                    print(df_class[best_idx['ACM']])

        print("max_acc list:", max_accs)
        print("best_acc list:", best_accs)
        print(f"avg max_acc: {sum(max_accs) / len(max_accs)}")
        acc_infos = sum(acc_infos) / len(acc_infos)
        print(f"the relations of the {target_col} and other metric by average:\n{acc_infos}")
        print([-np.round(acc_infos['corr'][row], 2) for row in acc_infos.index])
        print([np.round(acc_infos['diff_best'][row], 2) for row in acc_infos.index])

    else:  # mode == '1'
        dfs = [df[df['hyper-parameters'].isin(selected_hps)] for df in dfs]
        for i, df in enumerate(dfs):
            # Expand the encoded hyper-parameter string into one column each.
            df_hps = df['hyper-parameters'].apply(hyper_parameter_split)
            dfs[i] = pd.concat([df, df_hps], axis=1)
        all_selected = dfs
        groups = excel_names
        target_col = input(f"the default x-tick '{target_col}' should be replaced, input the new x-tick from {list(df_hps.columns)+['train_epoch']+['method']}:\n")

    # -------------------------------------------------------- collect curves
    print("all cols:", val_cols)
    colors = COLORS.copy()
    # NOTE(review): the prompt mentions "r to reset x-tick", which is not
    # implemented; the answer 'a' selects every metric column.
    y_names = input("input the y-tick(split with ',' if you have multiple y-ticks) or r to reset x-tick:\n")
    if y_names == 'a':
        y_names = val_cols
    else:
        y_names = y_names.split(',')
    # The CLI may not define --max_col; fall back to mean aggregation then.
    use_max = getattr(args, 'max_col', False)
    x_all, y_all = defaultdict(list), defaultdict(list)
    for group, selected in zip(groups, all_selected):
        if mode == '1':
            x = selected[target_col].unique()
        else:
            x = selected[target_col].values
        x_sort = np.sort(x)
        x_argsort = np.argsort(x)
        # Drop requested metrics that this frame does not contain.
        y_names = [y_name.strip() for y_name in y_names if y_name.strip() in selected]
        for y_name in y_names:
            if mode == '1':
                # One aggregated value (max or mean) per distinct x value.
                per_tick = [selected[selected[target_col] == v][y_name] for v in x_sort]
                y_sort = np.array([col.max() if use_max else col.mean() for col in per_tick])
            else:
                y = selected[y_name].values
                y_sort = y[x_argsort]
            x_all[y_name].append(x_sort)
            y_all[y_name].append(y_sort)

    # ------------------------------------------------------------------ plot
    for y_name in y_names:
        x_values = x_all[y_name]
        y_values = y_all[y_name]
        if mode == '1':
            if target_col == 'method':
                assert not args.concat
                # One point per distinct method: the best score among the
                # sheets belonging to that method.
                xs = []
                ys = []
                for x in x_values:
                    if x in xs:
                        continue
                    xs.append(x)
                    y = np.max([y_values[i] for i in range(len(x_values)) if x_values[i] == x])
                    ys.append(y)
                xs = np.concatenate(xs, axis=0)
                ys = _min_max_scale(np.array(ys))
            else:
                ys = np.mean(y_values, axis=0)
                print(y_name, ys.max(), ys.min())
                # Guarded scaling: a constant curve becomes zeros, not NaN.
                ys = _min_max_scale(ys)
                if target_col not in ["train_epoch", "temperature", "margin"]:
                    plt.xscale("log")
                xs = np.mean(x_values, axis=0)
            plt.plot(xs, ys, c=colors.popitem()[0], marker="^")
        else:
            # Resample every sheet's curve on a common-size grid, then
            # average the curves point-wise.
            x_values_interp, y_values_interp = [], []
            for x, y in zip(x_values, y_values):
                x_interp, y_interp = interp(x, y)
                x_values_interp.append(x_interp)
                y_values_interp.append(y_interp)
            ys = np.mean(y_values_interp, axis=0)
            plt.plot(np.mean(x_values_interp, axis=0), ys, c=colors.popitem()[0])
    plt.legend(y_names, loc="lower right", fontsize=13.5)
    if target_col == "lr":
        plt.xlabel("learning_rate", fontsize=13.5)
    else:
        plt.xlabel(target_col, fontsize=13.5)
    plt.ylabel("metric score", fontsize=13.5)
    # Save BEFORE showing: once an interactive window opened by show() is
    # closed the figure is gone and savefig() would write an empty image.
    plt.savefig(args.fig_path)
    print(f"save figure to {args.fig_path}")
    plt.show()


def build_parser():
    """Create the command-line parser of the metric visualisation script.

    Returns:
        argparse.ArgumentParser: parser exposing ``excels``, ``concat``,
        ``last_epoch``, ``per_class_eval``, ``max_col`` and ``fig_path``.
    """
    parser = argparse.ArgumentParser(description='draw excel')
    # At least one sheet is needed; without it there is nothing to load.
    parser.add_argument('--excels', required=True, nargs='+',
                        help='specify the path of metric excels')
    parser.add_argument('--concat', action='store_true', default=False,
                        help='concat the excels of all methods on a dataset')
    parser.add_argument('--last_epoch', action='store_true', default=False,
                        help='only evaluate models of the last_epoch')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    # Read by excel_show in hyper-parameter mode; it was never declared before.
    parser.add_argument('--max_col', action='store_true', default=False,
                        help='aggregate each x-tick with the max instead of the mean '
                             'when a hyper-parameter is the x-tick')
    parser.add_argument('--fig_path', type=str, default="/home/username/DAmetric_logs/visual.png",
                        help='output path of the plotted figure')
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    excel_show(args)
