import argparse
import os
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import logging

# TODO: move these common statements to a util file (don't want to import torch etc here)
from main_exp import get_task_sequence, get_classes_per_task
from main_table import load_df, calculate_means

# Figure defaults for every plot in this script: keep SVG text as text (svg.fonttype 'none') and use the Yrsa font.
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams["font.family"] = "Yrsa"
# Only show errors from matplotlib's font manager (presumably to silence missing-font warnings if Yrsa is absent).
logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)

# Suppress pandas PerformanceWarning messages, which are noise for this plotting script.
from warnings import simplefilter
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)


def check_results(df_res: pd.DataFrame, name):
    """
    Print a one-line histogram of how many result rows each iteration has.

    The line starts with `name` padded to 20 characters, followed by tab-terminated entries
    "<number of iterations> x <rows per iteration>", sorted by rows per iteration. For example "195 x 5"
    means that 195 iterations have exactly 5 rows (presumably one per finished run).

    :param df_res: results table with an 'iter' column.
    :param name: label printed at the start of the line.
    """
    rows_per_iter = Counter(df_res['iter'])
    histogram = Counter(rows_per_iter.values())
    entries = "".join(f"{histogram[n_rows]} x {n_rows}\t" for n_rows in sorted(histogram))
    print(f"{name:20s}: {entries}")


def get_intersections(data, levels, ignore_first=100):
    """
    Find, for each level, the first index label at which `data` reaches that level.

    :param data: pd.Series of values, indexed by iteration in all callers of this file.
    :param levels: iterable of threshold values.
    :param ignore_first: number of leading entries (by position, not by label) to skip before searching.
    :return: np.ndarray with one entry per level: the index label of the first entry after the skipped
             prefix whose value is >= level (up to a small tolerance), or None when no entry reaches it.
             If `data` has no more than `ignore_first` entries, every entry is None.
    """
    eps = 1e-8  # Tolerance for floating point round-off, so a value equal to the level counts as reached.
    searchable = data.iloc[ignore_first:]
    intersections = []
    for level in levels:
        reached = searchable >= level - eps
        # idxmax on a boolean Series gives the label of the first True. Check any() first: idxmax raises on
        # an empty Series and returns the first label when nothing is True (NaNs compare as False).
        intersections.append(reached.idxmax() if reached.any() else None)
    return np.array(intersections)


def get_colors_and_markers(exp_ids):
    """
    Choose one line color and one matplotlib marker per experiment, for 2 to 5 experiments.

    Earlier experiments get blue shades; the last one always gets '#f08b30ff' with the 'p' marker.

    :param exp_ids: sequence of experiment IDs, or None.
    :return: (colors, markers) lists with one entry per experiment, or (None, None) when exp_ids is None
             or its length is not between 2 and 5.
    """
    if exp_ids is None:
        return None, None

    # Built on every call so callers never share (and mutate) the same lists.
    palette_by_count = {
        2: (["#6aaed6ff", "#f08b30ff"],
            ['o', 'p']),
        3: (["#6aaed6ff", "#08306bff", "#f08b30ff"],
            ['o', 'd', 'p']),
        4: (["#6aaed6ff", "#2070b4ff", "#08306bff", "#f08b30ff"],
            ['o', 's', 'd', 'p']),
        5: (["#94c4dfff", "#4a98c9ff", "#1764abff", "#08306bff", "#f08b30ff"],
            ['o', 's', 'd', 'v', 'p']),
    }
    return palette_by_count.get(len(exp_ids), (None, None))


def get_max_iter(exp_settings, exp_ids, cut_off=1.0):
    """
    Get the final iteration (scaled by cut_off) over all experiments in exp_ids.

    For an experiment whose 'task' is 'na', the budget is max_iters times the number of comma-separated
    entries in 'num_classes_new'; otherwise it is max_iters. If max_iters cannot be parsed as an integer,
    a fallback budget of 19550 iterations is used.

    :param exp_settings: experiments table indexed by experiment ID.
    :param exp_ids: non-empty iterable of experiment IDs.
    :param cut_off: fraction of the budget to keep.
    :return: the largest scaled budget.
    """
    def _scaled_budget(settings):
        try:
            n_tasks = len(settings['num_classes_new'].split(',')) if settings['task'] == 'na' else 1
            return cut_off * int(settings['max_iters']) * n_tasks
        except ValueError:
            return cut_off * 19550

    return max(_scaled_budget(exp_settings.loc[e_id]) for e_id in exp_ids)


def _load_finished_runs(exp_dir):
    """
    Read results.csv of every finished run directory inside exp_dir.

    Skips the 'models' entry, entries whose name starts with 'ignore', and runs that still contain a
    running.txt file. Each returned DataFrame gets a 'run_name' column holding its directory name.

    :param exp_dir: path of one experiment's results directory.
    :return: list of DataFrames (possibly empty), in os.listdir order.
    """
    frames = []
    for rd in os.listdir(exp_dir):
        if rd == 'models' or rd.startswith('ignore'):
            continue
        run_dir = os.path.join(exp_dir, rd)
        if os.path.exists(os.path.join(run_dir, 'running.txt')):
            continue
        df = pd.read_csv(os.path.join(run_dir, 'results.csv'), dtype={'iter': int})
        df['run_name'] = rd
        frames.append(df)
    return frames


def full_plot(exp_ids, args, rolling=3, metric='acc'):
    """
    Plot train and test `metric` curves of several experiments in a grid.

    Row 0 shows the mean over all result columns matching the tag. Row j + 1 shows the mean over the
    columns selected by position with task_labels[j]. Column 0 is train_<metric>, column 1 is
    test_<metric>. The first experiment in exp_ids determines the task sequence. Experiments without
    finished runs are skipped.

    :param exp_ids: experiment IDs (index values of experiments.csv).
    :param args: parsed CLI arguments; reads args.sep, args.logx and args.save.
    :param rolling: window of the per-run rolling mean applied to every curve.
    :param metric: 'acc' or 'loss'; selects the result columns and the y-limits.
    """
    exp_details = pd.read_csv('experiments.csv', index_col='exp_id', dtype={'balance': bool})
    first_exp = exp_details.loc[exp_ids[0]]
    _, task_labels = get_task_sequence(first_exp['tasks_file'], first_exp['seed'])

    axes_tags = [f'train_{metric}', f'test_{metric}']

    _, axes = plt.subplots(len(task_labels) + 1, len(axes_tags), figsize=(10, 10), squeeze=False)
    sep_runs = args.sep

    for exp_id in exp_ids:
        exp_settings = exp_details.loc[exp_id]
        exp_label = f"{exp_id} {exp_settings['exp_name']}"

        # These bits should be replaced by the common methods in main_table to be good.
        frames = _load_finished_runs(os.path.join('results', str(exp_id)))
        if not frames:
            continue
        results = pd.concat(frames, ignore_index=True)
        results.replace(-1, 0, inplace=True)  # -1 was used as a placeholder in the csv files.
        check_results(results, exp_label)

        if 'task_iter' in results:
            # Make iterations continuous across tasks by offsetting each task with the per-task budget.
            results['iter'] = results['task_iter'] * int(exp_settings['max_iters']) + results['iter']

        for i, tag in enumerate(axes_tags):
            # Snapshot the tag's columns before any mean_<tag>* columns are added. Those new columns also
            # match the regex and would otherwise shift into the positional per-task selection below.
            tag_cols = results.filter(regex=tag)

            y_name = f'mean_{tag}'
            results[y_name] = tag_cols.mean(axis=1)
            # Rolling mean within each run; reset_index drops the run_name level to realign with rows.
            results[y_name] = (results.groupby('run_name', sort=False)[y_name]
                               .rolling(rolling, min_periods=1).mean().reset_index(0, drop=True))
            if not sep_runs:
                sns.lineplot(results, x='iter', y=y_name, ax=axes[0, i], label=exp_label, legend=(i == 0))
            else:
                sns.lineplot(results, x='iter', y=y_name, ax=axes[0, i], hue='run_name', legend=(i == 0))

            for j, tl in enumerate(task_labels):
                y_name = f'mean_{tag}_T{j + 1}'
                # NOTE(review): assumes the result columns are ordered by class so that tl indexes them
                # by position — not verified here.
                results[y_name] = tag_cols.iloc[:, tl].mean(axis=1)
                results[y_name] = (results.groupby('run_name', sort=False)[y_name]
                                   .rolling(rolling, min_periods=1).mean().reset_index(0, drop=True))
                if not sep_runs:
                    sns.lineplot(results, x='iter', y=y_name, ax=axes[j + 1, i], label=exp_settings['exp_name'],
                                 legend=False, errorbar='sd')
                else:
                    sns.lineplot(results, x='iter', y=y_name, ax=axes[j + 1, i], hue='run_name', legend=False)

    # A single legend call on the top-right axis is enough; it does not need repeating per axis.
    axes[0, 1].legend()
    for ax in axes.flatten():
        if args.logx:
            ax.set_xscale('log')
        ax.set_ylim((0, 1.0) if metric == 'acc' else (0.0, 3.5))
        ax.grid()

    if args.save is not None:
        plt.savefig(os.path.join('graphs', f"{args.save}.jpg"), dpi=300)
        plt.savefig(os.path.join('graphs', f"{args.save}.svg"))

    plt.show()


def eye_catcher_plot(exp_ids, rolling=3, tag='test_acc', colors=None, markers=None, args=None):
    """
    Plot the overall `tag` curve of several experiments on a single axis.

    The final experiment in exp_ids is the reference. The maximum of its per-iteration mean 'mean_<tag>'
    (smoothed with a rolling mean over `rolling` iterations) defines dashed lines at 95%, 99% and 100% of
    that value. Unless args.sep is set, each experiment gets markers where its mean curve first reaches
    those lines, and "max,it_95,it_99,it_100" is printed for it.

    :param exp_ids: experiment IDs; the first one selects the task sequence, the last is the reference.
    :param rolling: window of the per-run rolling mean.
    :param tag: regex selecting the result columns to average (e.g. 'test_acc').
    :param colors: optional line color per experiment.
    :param markers: optional matplotlib marker format per experiment (default 'o').
    :param args: parsed CLI arguments (required despite the default). Reads cut_off, figsize, sep, logx
                 and save, and setup_axes reads xlim and ylim.
    :raises ValueError: if the reference (last) experiment has no results.
    """
    exp_settings = pd.read_csv('experiments.csv', index_col='exp_id', dtype={'balance': bool})
    # exp_settings = pd.read_csv('guy_experiments.csv', index_col='exp_id', dtype={'balance': bool})
    first_exp = exp_settings.loc[exp_ids[0]]
    _, task_labels = get_task_sequence(first_exp['tasks_file'], first_exp['seed'])

    max_iter = get_max_iter(exp_settings, exp_ids, cut_off=args.cut_off)

    # Reference levels from the final experiment.
    ref_id = exp_ids[-1]
    ref_results = load_df(ref_id, exp_settings.loc[ref_id])
    if ref_results is None:
        raise ValueError(f"No results found for reference experiment {ref_id}")
    ref_results = ref_results[ref_results['iter'] <= max_iter]
    ref_results, all_names = calculate_means(ref_results, [tag], task_labels)
    avg_results = ref_results.groupby('iter')[all_names].agg(['mean', 'sem'])
    max_acc = max(avg_results[f'mean_{tag}']['mean'].rolling(rolling, min_periods=0).mean())
    acc_levels = [f * max_acc for f in [0.95, 0.99, 1.0]]

    _, axes = plt.subplots(1, 1, figsize=args.figsize)

    y_name = f'mean_{tag}'
    last = len(exp_ids) - 1

    for i, exp_id in enumerate(exp_ids):
        this_exp = exp_settings.loc[exp_id]
        results = load_df(exp_id, this_exp)
        if results is None:
            continue

        results[y_name] = results.filter(regex=tag).mean(axis=1)
        # Rolling mean within each run; reset_index drops the run_name level to realign with rows.
        results[y_name] = (results.groupby('run_name', sort=False)[y_name]
                           .rolling(rolling, min_periods=1).mean().reset_index(0, drop=True))
        results = results[results['iter'] <= max_iter]

        if args.sep:
            # One line per run instead of the aggregated curve; no level markers.
            sns.lineplot(results, x='iter', y=y_name, ax=axes, legend=True, marker='o', hue='run_name')
            continue

        is_reference = i == last
        color = colors[i] if colors is not None else None
        # The reference experiment is drawn below the others (lower zorder).
        sns.lineplot(results, x='iter', y=y_name, ax=axes, label=f"{exp_id} {this_exp['exp_name']}",
                     legend=True, zorder=10 if is_reference else 15, color=color)

        marker = markers[i] if markers is not None else 'o'
        marker_c = '#ffff00' if is_reference else 'r'
        mean_acc = results.groupby('iter')[y_name].mean()
        intersections = get_intersections(mean_acc, acc_levels)

        print(",".join(str(value) for value in [max(mean_acc), *intersections]))
        plt.plot(intersections, acc_levels, marker, c=marker_c, zorder=20, markersize=4.0,
                 markeredgecolor='k', markeredgewidth=0.5)

    # Round the x-range up to the next multiple of 10k iterations.
    axes_lim = 10_000 * (((max_iter - 1) // 10_000) + 1)
    axes.hlines(acc_levels, -500, axes_lim + 500, linestyles='--', colors='k', zorder=5, linewidth=0.5)

    if args.logx:
        axes.set_xscale('log')

    setup_axes(axes, axes_lim, args)

    # Lay out before saving so the saved SVG matches the figure that plt.show() displays.
    plt.tight_layout()
    if args.save:
        plt.savefig(f'./graphs/{args.save}.svg')

    plt.show()


def levels_plot(exp_ids, rolling=3, tag='test_acc', colors=None, markers=None, args=None, levels=(0.95, 0.99, 1.0)):
    """
    Plot, for each experiment, the iteration at which its mean `tag` curve first reaches each level.

    Levels are fractions of the reference maximum: the largest rolling-mean (window `rolling`) value of
    the final experiment's per-iteration mean 'mean_<tag>'. The x-axis has one position per level; the
    y-axis is labelled in units of 10k iterations. The intersections of each experiment are printed.

    :param exp_ids: experiment IDs; the first one selects the task sequence, the last is the reference.
    :param rolling: window of the per-run rolling mean.
    :param tag: regex selecting the result columns to average.
    :param colors: unused; kept for a signature parallel to eye_catcher_plot.
    :param markers: optional matplotlib marker per experiment (default 'o').
    :param args: parsed CLI arguments (required despite the default); reads cut_off and save.
    :param levels: fractions of the reference maximum to locate.
    :raises ValueError: if the reference (last) experiment has no results.
    """
    exp_settings = pd.read_csv('experiments.csv', index_col='exp_id', dtype={'balance': bool})
    first_exp = exp_settings.loc[exp_ids[0]]
    _, task_labels = get_task_sequence(first_exp['tasks_file'], first_exp['seed'])

    ref_id = exp_ids[-1]
    ref_exp = exp_settings.loc[ref_id]
    try:
        max_iter = args.cut_off * int(ref_exp['max_iters'])
    except ValueError:
        # max_iters is not an integer for this experiment; fall back to a fixed budget.
        max_iter = args.cut_off * 19550

    ref_results = load_df(ref_id, ref_exp)
    if ref_results is None:
        raise ValueError(f"No results found for reference experiment {ref_id}")
    # Inclusive cut-off, the same as the per-experiment curves below.
    ref_results = ref_results[ref_results['iter'] <= max_iter]
    ref_results, all_names = calculate_means(ref_results, [tag], task_labels)
    avg_results = ref_results.groupby('iter')[all_names].agg(['mean', 'sem'])
    max_acc = max(avg_results[f'mean_{tag}']['mean'].rolling(rolling, min_periods=0).mean())
    acc_levels = [f * max_acc for f in levels]

    _, axes = plt.subplots(1, 1, figsize=(0.8, 2))
    # fig, axes = plt.subplots(1, 1, figsize=(2.4, 6))

    y_name = f'mean_{tag}'
    last = len(exp_ids) - 1

    for i, exp_id in enumerate(exp_ids):
        results = load_df(exp_id, exp_settings.loc[exp_id])
        if results is None:
            continue

        results[y_name] = results.filter(regex=tag).mean(axis=1)
        # Rolling mean within each run; reset_index drops the run_name level to realign with rows.
        results[y_name] = (results.groupby('run_name', sort=False)[y_name]
                           .rolling(rolling, min_periods=1).mean().reset_index(0, drop=True))
        results = results[results['iter'] <= max_iter]

        marker = markers[i] if markers is not None else 'o'
        marker_c = '#ffff00' if i == last else 'r'
        intersections = get_intersections(results.groupby('iter')[y_name].mean(), acc_levels)

        print(intersections)
        plt.plot(np.arange(len(acc_levels)), intersections, ls=(1, (2, 1)), lw=0.25, marker=marker, color='k',
                 markerfacecolor=marker_c, zorder=20, markersize=4.0, markeredgecolor='k', markeredgewidth=0.5)

    # Round the y-range up to the next multiple of 10k iterations.
    axes_lim = 10_000 * (((max_iter - 1) // 10_000) + 1)

    axes.spines[['bottom', 'top', 'right', 'left']].set_visible(False)
    axes.tick_params(length=0)

    y_ticks = axes.get_yticks()
    axes.set_yticks(y_ticks, y_ticks / 10_000)
    axes.set_xticks(np.arange(len(levels)), [f"{l*100:.0f}%" for l in levels], rotation=90)

    # NOTE(review): markers sit at x = 0 .. len(levels) - 1, so this leaves about one extra unit of empty
    # space on the right; confirm that is intended.
    axes.set_xlim((-0.2, len(levels) + 0.2))
    axes.set_ylim((-500, axes_lim + 500))

    axes.grid(axis='both')

    # Lay out before saving so the saved SVG matches the figure that plt.show() displays.
    plt.tight_layout()
    if args.save:
        plt.savefig(f'./graphs/{args.save}.svg')

    plt.show()


def special_multitask_plot(multi_exp, scratch_exps, rolling=3, args=None):
    """
    Compare a multi-task experiment with per-task experiments trained from scratch.

    The figure contains:
    - The mean test accuracy of multi_exp.
    - A black staircase built from the cumulative class fractions np.cumsum(cls_per_task) / total. This is
      presumably the best reachable accuracy per stage; confirm.
    - For each entry i of scratch_exps, dashed lines over task segment i at 95%, 99% and 100% of that
      experiment's best rolling-mean test accuracy.
    - Yellow markers where the multi-task curve first reaches those lines within the segment.

    :param multi_exp: ID of the multi-task experiment.
    :param scratch_exps: one experiment ID (or None to skip) per task segment, in task order.
    :param rolling: window of the rolling mean used for the scratch maxima.
    :param args: parsed CLI arguments (required despite the default); reads save, and setup_axes reads
                 xlim and ylim.
    :raises ValueError: if multi_exp has no results.
    """
    exp_settings = pd.read_csv('experiments.csv', index_col='exp_id', dtype={'balance': bool})
    multi_settings = exp_settings.loc[multi_exp]
    _, task_labels = get_task_sequence(multi_settings['tasks_file'], multi_settings['seed'])

    iters_per_task = int(multi_settings['max_iters'])
    max_iter = get_max_iter(exp_settings, [multi_exp])
    print(max_iter)

    results = load_df(multi_exp, multi_settings)
    if results is None:
        raise ValueError(f"No results found for multi-task experiment {multi_exp}")
    results, _ = calculate_means(results, ['test_acc'], task_labels)

    _, axes = plt.subplots(1, 1, figsize=(3.6, 2))
    label = f"{multi_exp} {multi_settings['exp_name']}"
    sns.lineplot(results, x='iter', y='mean_test_acc', ax=axes, label=label, legend=True, zorder=15, color='firebrick')

    cls_per_task = get_classes_per_task(multi_settings)
    max_levels = np.cumsum(cls_per_task) / np.sum(cls_per_task)
    # One edge per task start; stairs() needs exactly one more edge than values, hence max_levels[1:].
    x_stairs = np.arange(0, iters_per_task * len(cls_per_task), iters_per_task)
    axes.stairs(max_levels[1:], x_stairs, color='k')

    for i, exp_id in enumerate(scratch_exps):
        if exp_id is None:
            continue

        start, stop = i * iters_per_task, (i + 1) * iters_per_task

        s_results = load_df(exp_id, exp_settings.loc[exp_id])
        if s_results is None:
            continue
        s_results, _ = calculate_means(s_results, ['test_acc'], task_labels)

        max_acc = max(s_results.groupby('iter')['mean_test_acc'].mean().rolling(rolling, min_periods=0).mean())
        acc_levels = [f * max_acc for f in [0.95, 0.99, 1.0]]
        axes.hlines(acc_levels, start, stop, linestyles='--', colors='k', zorder=5, linewidth=0.5)

        # Skip the first 10 iterations of the segment. get_intersections additionally ignores the first
        # 100 entries (by position) of the per-iteration means.
        multi_part = results[(results['iter'] >= start + 10) & (results['iter'] < stop)]
        intersections = get_intersections(multi_part.groupby('iter')['mean_test_acc'].mean(), acc_levels)
        print(intersections)

        plt.plot(intersections, acc_levels, 'o', c='yellow', zorder=20, markersize=4.0,
                 markeredgecolor='k', markeredgewidth=0.5)

    setup_axes(axes, max_iter, args)
    axes.set_xticks(x_stairs, [f"T{i}" for i in range(len(x_stairs))])

    # Lay out before saving so the saved SVG matches the figure that plt.show() displays.
    plt.tight_layout()
    if args.save:
        plt.savefig(f'./graphs/{args.save}.svg')

    plt.show()


def setup_axes(axes, axes_lim, args):
    """
    Apply the shared minimal styling of the single-axis plots.

    - Hides all spines and tick marks.
    - Relabels the current x ticks in units of 10k iterations.
    - Sets the x-range to args.xlim, or to (-500, axes_lim + 500) when args.xlim is None.
    - Applies args.ylim and a "Test Accuracy" y-label, and turns on the grid.

    :param axes: matplotlib Axes to style.
    :param axes_lim: right end of the default x-range (before padding).
    :param args: parsed CLI arguments; reads xlim and ylim.
    """
    # Frameless look: no spines, zero-length ticks.
    axes.spines[['bottom', 'top', 'right', 'left']].set_visible(False)
    axes.tick_params(length=0)

    # Keep the automatically chosen tick positions, but label them in units of 10k iterations.
    tick_positions = axes.get_xticks()
    tick_labels = tick_positions / 10_000
    axes.set_xticks(tick_positions, tick_labels)
    axes.set_xlabel("Iterations")

    x_range = (-500, axes_lim + 500) if args.xlim is None else args.xlim
    axes.set_xlim(x_range)

    axes.set_ylabel("Test Accuracy")
    axes.set_ylim(args.ylim)

    axes.grid()



def main():
    """
    Make the plots for a list of experiment IDs.

    Modes (checked in this order):
      --eye    only the overall metric given by --tag, with markers where each curve first reaches 95%,
               99% and 100% of the final experiment's best value (eye_catcher_plot).
      --level  the iterations at which those levels are reached, one column per level (levels_plot).
      --multi  the hard-coded multi-task vs. from-scratch comparison (special_multitask_plot).
      default  training and test accuracy, overall and per task (full_plot; it also plots loss if you
               change its metric parameter to 'loss').

    Experiment IDs come from --exps; without it, the default list below is used. cut_off was used to not
    show until the end of training, but that shouldn't be used anymore with cosine scheduling.

    The code automatically groups the runs that are finished (those that no longer have a running.txt
    file in their result folder) and calculates means and standard errors over them. It prints a simple
    statistic of how many points there are with how many results (e.g. 195 x 5 means that there are 195
    iterations with 5 results).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--eye', action='store_true')
    parser.add_argument('--level', action='store_true')
    parser.add_argument('--multi', action='store_true')
    parser.add_argument('--cut_off', type=float, default=1.0)
    parser.add_argument('--sep', action='store_true')
    parser.add_argument('--logx', action='store_true')
    parser.add_argument('--tag', default='test_acc',
                        help='the tag to match in the results file, probably test_acc, test_loss, train_acc or '
                             'train_loss for --eye plots and acc or loss for full plot')
    parser.add_argument('--exps', nargs='*',
                        help='The IDs of the experiments. If omitted, the default list in main() is used.')
    parser.add_argument('--figsize', nargs=2, default=(3.6, 2))
    parser.add_argument('--ylim', nargs=2, default=(0.545, 0.705))
    parser.add_argument('--xlim', nargs=2, default=None)
    parser.add_argument('--save')
    args = parser.parse_args()

    # nargs=2 yields strings when given on the command line; normalise to float tuples.
    args.figsize = (float(args.figsize[0]), float(args.figsize[1]))
    args.ylim = (float(args.ylim[0]), float(args.ylim[1]))
    args.xlim = (float(args.xlim[0]), float(args.xlim[1])) if args.xlim is not None else None

    # The first experiment will determine the task file that is used! Mostly important for the full plots.
    # !! The final experiment is the one that is used to draw the lines at 100, 99 and 95% accuracy.

    # Experiment groups used before (pass them with --exps):
    # [170, 171, 172, 175, 174]         # Scheduling at different points
    # [168, 169, 175, 174]              # Different ratios of old and new data.
    # [167, 166, 157]                   # Eye-catcher figure
    # [59, 62, 34]                      # Too short 70 + 30
    # [173, 175, 176, 174]              # Static solutions
    # [164, 166, 177, 157]              # Dynamic solutions (add 177 for S&P with L2-init)
    # [132, 130]                        # Loss plateau
    # [156, 164, 157]                   # L2-init + sample
    # [158, 176, 178, 174]              # S&P and joint + sample
    # [194, 195, 196, 197, 179, 157]    # Sample with zero probability
    # [194, 179, 177, 165, 157]         # Sample with 0%, 10% and joint.
    # [182, 186, 190, 157]              # 90 + 10 combine
    # [179, 183, 187, 157]              # 70 + 30 combine
    # [180, 184, 188, 157]              # 50 + 50 combine
    # [181, 185, 189, 157]              # 30 + 70 combine
    # [205, 206, 207, 198]
    # [202, 203, 204, 198]
    # [421, 169, 175, 174]
    # [682, 683, 684]
    # [410, 411, 412, 413, 414, 409, 506]
    # [454, 456, 457, 409, 506]
    # [460, 463, 465, 466, 409, 506]
    default_exp_ids = [95, 59, 22]

    # IDs given on the command line take precedence over the default list.
    exp_ids = [int(e) for e in args.exps] if args.exps else default_exp_ids

    rolling = 5
    # colors, markers = get_colors_and_markers(exp_ids)
    colors, markers = None, None

    if args.eye:
        eye_catcher_plot(exp_ids, rolling=rolling, colors=colors, markers=markers, args=args, tag=args.tag)
    elif args.level:
        levels_plot(exp_ids, rolling=rolling, args=args, markers=markers, tag='test_acc')
    elif args.multi:
        # special_multitask_plot(192, [215, 216, 217, 218, 219], args=args)
        special_multitask_plot(193, [220, 221, 222, 223, 224, 225, 226, 227, 228, 229], args=args)
    else:
        full_plot(exp_ids, args, rolling=rolling, metric='acc')


# Run the plotting CLI only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
