import argparse
import math
import os
from collections import Counter
from typing import Sequence

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# TODO: move shared helpers such as get_task_sequence to a util file, so that this script does not have to
#  import main_exp (and with it torch etc.).
from main_exp import get_task_sequence


def check_results(df_res: pd.DataFrame, name):
    """Print a one-line summary of how many rows each iteration has.

    The rows of ``df_res`` are counted per distinct value of the ``'iter'``
    column. The printed line then lists, for every row count in ascending
    order, how many iterations have exactly that many rows, formatted as
    ``"<number of iterations> x <row count>"`` and separated by tabs.

    Args:
        df_res: Results frame with an ``'iter'`` column.
        name: Label printed in front of the summary (left-aligned, padded to 20 characters).
    """
    rows_per_iter = Counter(df_res['iter'])
    iters_per_row_count = Counter(rows_per_iter.values())

    entries = [f"{iters_per_row_count[n_rows]} x {n_rows}\t" for n_rows in sorted(iters_per_row_count)]
    print(f"{name:20s}: " + ''.join(entries))


def load_df(exp_id, exp_settings, check_result=True):
    """Load the results of all finished runs of one experiment into a single frame.

    Every sub-directory of ``results/<exp_id>`` is treated as one run, except ``models`` and
    directories whose name starts with ``ignore``. Runs whose directory contains a
    ``running.txt`` file are skipped (presumably they are still in progress - TODO confirm).
    For the remaining runs ``results.csv`` is read and tagged with a ``run_name`` column.

    Args:
        exp_id: Experiment identifier; also the name of the directory below ``results``.
        exp_settings: Mapping (dict or ``pd.Series``) with at least ``'exp_name'``, and
            ``'max_iters'`` when the results contain a ``'task_iter'`` column.
        check_result: If true, print the row-count summary produced by ``check_results``.

    Returns:
        The concatenated results, or ``None`` when no finished run was found.

    Raises:
        FileNotFoundError: If the experiment directory, or the ``results.csv`` of a finished
            run, does not exist.
    """
    exp_dir = os.path.join('results', str(exp_id))
    # exp_dir = os.path.join('guy_results', str(exp_id))
    exp_name = f"{exp_id} {exp_settings['exp_name']}"

    # Only directories can be runs. Without the isdir check a stray file in the experiment
    # directory (notes, logs, .DS_Store, ...) is treated as a run and read_csv fails on
    # '<file>/results.csv'.
    result_dirs = [rf for rf in os.listdir(exp_dir)
                   if os.path.isdir(os.path.join(exp_dir, rf))
                   and not (rf == 'models' or rf.startswith('ignore'))]

    results = []
    for rd in result_dirs:
        run_dir = os.path.join(exp_dir, rd)
        if os.path.exists(os.path.join(run_dir, 'running.txt')):
            continue
        df = pd.read_csv(os.path.join(run_dir, 'results.csv'), dtype={'iter': int, 'num_classes_new': int})
        df['run_name'] = rd
        results.append(df)

    if not results:
        return None

    results = pd.concat(results, ignore_index=True)
    # The csv files use -1 as a placeholder for "not tested". That usually means the model was not
    # trained on that data yet, in which case an accuracy of 0 is correct, so -1 is mapped to 0 here.
    # np.nan would instead exclude those values from the means, which would be more accurate in the
    # (theoretical) case of untested data with non-zero performance. The files keep -1 so that
    # "not tested" stays distinguishable from a real 0.
    results.replace(-1, 0, inplace=True)

    if check_result:
        check_results(results, exp_name)

    if 'task_iter' in results:
        # 'iter' restarts for every task; turn it into a global iteration count.
        results['iter'] = results['task_iter'] * int(exp_settings['max_iters']) + results['iter']

    return results


def _rolling_mean_per_run(df, column, window):
    """Return ``df[column]`` smoothed with a rolling mean computed separately within each run."""
    smoothed = df.groupby('run_name', sort=False)[column].rolling(window, min_periods=1).mean()
    # groupby().rolling() puts the group key in front of the original index; dropping that level
    # restores the original row index so the result aligns with ``df`` on assignment.
    return smoothed.reset_index(0, drop=True)


def calculate_means(df, tags, task_labels, rolling=1):
    """Add mean columns over all classes and over the classes of each task.

    For every tag, the columns whose name matches the tag (regex search) are averaged row-wise
    into ``mean_<tag>``. In addition, for the j-th entry of ``task_labels`` the matching columns
    at those positions are averaged into ``mean_<tag>_T<j+1>``. Columns starting with ``mean_``
    are never used as input, so calling the function again on the returned frame gives the same
    result instead of averaging earlier aggregates together with the class columns.

    Args:
        df: Results frame; modified in place. Needs a ``'run_name'`` column when ``rolling > 1``.
        tags: Column-name patterns, e.g. ``'test_acc'``.
        task_labels: One sequence of positional column indices per task, indexing into the
            columns matched by a tag.
        rolling: Window of the rolling mean applied per run to every new column; values of 1
            or less disable smoothing.

    Returns:
        Tuple ``(df, all_names)`` with the modified frame and the names of the added columns,
        in the order they were created.
    """
    all_names = []

    for tag in tags:
        # Pick the per-class columns once, before any aggregate column is added for this tag.
        matched = df.filter(regex=tag)
        is_class_col = [not str(col).startswith('mean_') for col in matched.columns]
        class_values = matched.loc[:, is_class_col]

        # Mean of all classes first, then the mean for each task.
        targets = [(f'mean_{tag}', class_values)]
        targets += [(f'mean_{tag}_T{j + 1}', class_values.iloc[:, tl]) for j, tl in enumerate(task_labels)]

        for y_name, values in targets:
            all_names.append(y_name)
            df[y_name] = values.mean(axis=1)
            if rolling > 1:
                df[y_name] = _rolling_mean_per_run(df, y_name, rolling)

    return df, all_names


def main():
    """Print a comparison table for a hard-coded list of experiments.

    Experiment settings are read from ``experiments.csv`` (indexed by ``exp_id``). For each
    experiment in ``exp_ids`` the finished runs are loaded, mean accuracies are computed and
    averaged over runs per iteration, and the values at the last iteration are collected.
    When a baseline is available, the ratio between the iteration at which the baseline peaks
    and the iteration at which the experiment first reaches ``lim * baseline peak`` is added
    for every ``lim`` in ``limits``.

    The table is printed to stdout as comma-separated cells containing LaTeX math. Experiments
    without finished runs are left out. If the baseline has no finished runs, the ratio
    columns are omitted.

    Command line flags:
        --acc: print all train/test means instead of only the overall test mean (in percent).
        --sep: also print the ratios per task instead of only for the overall test mean.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--acc', action='store_true',
                        help='If set print also the test and train accuracies of the model')
    parser.add_argument('--sep', action='store_true',
                        help='If set will also print levels per task ')

    args = parser.parse_args()

    # Baseline experiment that every experiment in exp_ids is compared against.
    # base_exp_id = 157   # with l2-init
    base_exp_id = 174   # without l2-init
    # base_exp_id = 198

    # Alternative experiment selections, kept for reference.
    # exp_ids = [
    #     157, 160,
    #     182, 186, 190,
    #     179, 183, 187,
    #     180, 184, 188,
    #     181, 185, 189,
    # ]
    # exp_ids = [174, 175, 176, 164, 421, 170, 177, 161, 156, 179, 187]  # Continuous Ablation table
    # exp_ids = [198, 199, 200]
    # NOTE(review): 234 appears twice in the active list - confirm whether one was meant to be 233.
    exp_ids = [174, 230, 157, 234, 231, 232, 234, 235, 236, 237]

    metric = 'acc'
    rolling = 5
    limits = [0.95, 0.99, 1.00]
    tags = [f'train_{metric}', f'test_{metric}']
    test_mean_name = f'mean_test_{metric}'

    exp_details = pd.read_csv('experiments.csv', index_col='exp_id', dtype={'balance': bool})

    all_results = []
    all_exp_names = []
    # Defined up front so the header below also works when nothing could be loaded.
    all_names = []

    for exp_id in exp_ids:
        exp_settings = exp_details.loc[exp_id]
        exp_name = f"{exp_id} {exp_settings['exp_name']}"
        _, task_labels = get_task_sequence(exp_settings['tasks_file'], exp_settings['seed'])

        # The baseline is re-aggregated for every experiment because the per-task means depend
        # on the task labels of the experiment it is compared with.
        base_results = None
        if base_exp_id is not None:
            base_df = load_df(base_exp_id, exp_details.loc[base_exp_id])
            # load_df returns None when the baseline has no finished runs; treat that like
            # "no baseline" instead of crashing in calculate_means.
            if base_df is not None:
                base_df, all_names = calculate_means(base_df, tags, task_labels, rolling=rolling)
                base_results = base_df.groupby('iter')[all_names].agg(['mean', 'sem'])

        results = load_df(exp_id, exp_settings)
        if results is None:
            continue

        all_exp_names.append(exp_name)
        results, all_names = calculate_means(results, tags, task_labels, rolling=rolling)
        avg_results = results.groupby('iter')[all_names].agg(['mean', 'sem'])

        # Values at the last iteration; each entry is a (mean, sem) pair.
        if args.acc:
            exp_results = [avg_results[y_name].iloc[-1].values for y_name in all_names]
        else:
            exp_results = [100 * avg_results[test_mean_name].iloc[-1].values]

        names = all_names if args.sep else [test_mean_name]
        for y_name in names:
            if 'test' not in y_name or base_results is None:
                continue

            base = base_results[y_name]['mean']
            other = avg_results[y_name]['mean']

            # This was only important with multistep, could just as well remove this.
            base_final_iter = get_last_iteration_before_scheduler(exp_details.loc[base_exp_id], base)
            other_final_iter = get_last_iteration_before_scheduler(exp_settings, other)

            # Ignore the first 100 iterations and everything after the scheduler milestone.
            base_window = base[(100 < base.index) & (base.index <= base_final_iter)]
            other_window = other[(100 < other.index) & (other.index <= other_final_iter)]

            for lim in limits:
                res, max_iter = get_time_to_x(base_window, other_window, lim)
                # Flip for percentages. Single-element entries mark ratio cells for the printing
                # below; a limit that was never reached comes out as -1.
                exp_results.append([max_iter / res])

        all_results.append(exp_results)

    if args.acc:
        # One header cell per name: the rows below print only the mean of each (mean, sem) pair.
        print(',', end='')
        for n in all_names:
            print(f'{n},', end='')
        print()

    for en, res in zip(all_exp_names, all_results):
        print(en, end=',')
        for r in res:
            if len(r) == 2:
                # (mean, sem) pair: only the mean is printed.
                print(f"${r[0]:.2f}$", end=',')
            else:
                print(f"$\\times {r[0]:.2f}$", end=',')
        print()


def get_last_iteration_before_scheduler(exp_setting, result, *, num_samples_per_class=500,
                                        milestone_fraction=0.6):
    """Return the iteration up to which results should be compared.

    For experiments with the ``'multistep'`` scheduler this is ``milestone_fraction`` of the
    total number of training iterations (epochs times batches per epoch). For every other
    scheduler it is one past the largest iteration in ``result``, i.e. nothing is cut off.

    Args:
        exp_setting: Mapping (dict or ``pd.Series``) with ``'sched'`` and, for multistep,
            ``'num_classes_new'``, ``'batch_size_new'``, ``'epochs'`` (``'na'`` if unset) and
            ``'max_iters'`` (only read when epochs is ``'na'``).
        result: Series indexed by iteration; only used for non-multistep schedulers.
        num_samples_per_class: Training samples per class. The default of 500 is only true
            for Cifar100.
        milestone_fraction: Fraction of training after which the multistep scheduler is
            assumed to take its first step. This may in the future no longer be 0.6.

    Returns:
        The cut-off iteration (a float for multistep).
    """
    if exp_setting['sched'] != 'multistep':
        return max(result.index) + 1

    num_new_classes = int(exp_setting['num_classes_new'])  # This will not work with the multi-task setting.
    batch_size = exp_setting['batch_size_new']
    len_train_set = num_samples_per_class * num_new_classes
    batches_per_epoch = len_train_set / batch_size

    if exp_setting['epochs'] == 'na':
        # Derive the number of epochs from the iteration budget.
        # TODO: double check this, not 100% sure this always gives the correct result.
        epochs = math.ceil(int(exp_setting['max_iters']) / math.ceil(len_train_set / batch_size))
    else:
        epochs = int(exp_setting['epochs'])

    return milestone_fraction * epochs * batches_per_epoch


def get_time_to_x(base: pd.Series, other: pd.Series, alpha: float):
    """Find when ``other`` first reaches a fraction ``alpha`` of the maximum of ``base``.

    Args:
        base: Baseline values indexed by iteration.
        other: Values of the compared experiment indexed by iteration.
        alpha: Fraction of the baseline maximum that has to be reached.

    Returns:
        ``(reached, base_peak)``: the index label at which ``other`` first is at least
        ``alpha * base.max()`` and the index label at which ``base`` attains its maximum.
        If the limit is never reached - including when either series is empty or the limit
        is NaN - the sentinel ``(-1.0, 1.0)`` is returned.
    """
    not_reached = (-1.0, 1.0)

    # An empty window (e.g. a run shorter than the evaluated iteration range) cannot reach
    # anything; idxmax would raise on it.
    if base.empty or other.empty:
        return not_reached

    limit = alpha * base.max()
    hits = other >= limit
    # Checking the hits directly (instead of "value < limit") also treats a NaN limit as
    # not reached.
    if not hits.any():
        return not_reached

    # idxmax on a boolean series gives the label of the first True.
    return hits.idxmax(), base.idxmax()


# Only build and print the table when this file is run as a script, not when it is imported.
if __name__ == '__main__':
    main()
