#!/usr/bin/env python
# coding: utf-8
import os
import argparse
from os.path import join
from pathlib import Path

import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from mnist_auto_aug.classwise_transforms import KEY_TO_TRANSFORM


def get_trials(path, trials_per_fold=None):
    """ Loads the trials.pkl results file found in the sub-folders of path,
    keeps the largest one and returns its complete trials sorted by value.

    Every direct sub-folder of `path` containing a ``trials.pkl`` file is
    inspected. The files are NOT concatenated: only the table with the most
    rows is kept (the first one encountered wins on ties). Rows whose
    ``state`` column equals ``"COMPLETE"`` are then sorted by decreasing
    ``value``.

    Parameters
    ----------

    path : str or path-like
        Path to main training folder.

    trials_per_fold : int, optional
        Number of trials to keep after sorting (e.g. 5 -> keep 5 best trials).
        Defaults to None (keep all trials).

    Returns
    -------
    pandas.DataFrame
        Sorted complete trials, re-indexed from 0. The previous row labels
        are kept in an extra ``index`` column.

    Raises
    ------
    FileNotFoundError
        If no sub-folder of `path` contains a ``trials.pkl`` file.
    """
    trials_df = None
    max_len = 0
    for job in os.listdir(path):
        job_folder = join(path, job)
        if os.path.isdir(job_folder) and "trials.pkl" in os.listdir(job_folder):
            # NOTE: unpickling can execute arbitrary code; only point this
            # function at result folders you trust.
            res = pd.read_pickle(join(job_folder, "trials.pkl"))
            # `trials_df is None` lets a zero-row table be used when nothing
            # larger exists, instead of leaving the variable unset.
            if trials_df is None or res.shape[0] > max_len:
                trials_df = res
                max_len = res.shape[0]

    if trials_df is None:
        raise FileNotFoundError(
            f"No 'trials.pkl' file found in any sub-folder of {path}"
        )

    complete_trials = trials_df[trials_df['state'] == "COMPLETE"]
    sorted_trials = complete_trials.sort_values(by="value", ascending=False)
    if trials_per_fold is not None:
        sorted_trials = sorted_trials[:trials_per_fold]
    return sorted_trials.reset_index()


def get_selected_transforms(sorted_trials, classwise=True):
    """ Extracts selected operations from the sorted trials DataFrame created
    with get_trials and returns them in long format.

    Rows are read by position, so the rank of a trial is its position in
    `sorted_trials` regardless of the row labels carried by the DataFrame.

    Parameters
    ----------

    sorted_trials : pandas.DataFrame
        Trials sorted from best to worst. When `classwise` is True it must
        hold ``params_operation_<c>`` / ``params_probability_<c>`` column
        pairs where ``<c>`` is an integer class label; otherwise it must hold
        the ``params_operation`` and ``params_probability`` columns.
    classwise : bool, optional
        Whether the trials hold one operation per class. Defaults to True.

    Returns
    -------
    pandas.DataFrame
        One row per (class, rank) with columns ``class``, ``rank``,
        ``operation`` and ``probability``. In the non-classwise case the
        class is the string ``'all_classes'``.
    """
    if classwise:
        # The class label is whatever follows the last underscore of every
        # column whose name contains 'operation_'.
        class_labels = [
            int(col.split("_")[-1])
            for col in sorted_trials.columns
            if 'operation_' in col
        ]
        column_groups = [
            (c, f"params_operation_{c}", f"params_probability_{c}")
            for c in class_labels
        ]
    else:
        column_groups = [
            ('all_classes', "params_operation", "params_probability")
        ]

    subpolicies = []
    for label, operation_col, proba_col in column_groups:
        # zip/enumerate walk the rows in order, independent of index labels.
        pairs = zip(sorted_trials[operation_col], sorted_trials[proba_col])
        for rank, (operation, proba) in enumerate(pairs):
            subpolicies.append({
                'class': label,
                'rank': rank,
                'operation': operation,
                'probability': proba,
            })
    return pd.DataFrame(subpolicies)


def make_proba_matrix(selected_tf, operations=None):
    """ Builds probability matrix of selected transforms from DataFrame created with
    get_selected_transforms

    For every row of `selected_tf`, the probability ``p`` is added to the
    cell (class, operation) and ``1 - p`` to the cell (class, 'Identity').
    The accumulated matrix is then divided by the number of ranks
    (``max(rank) + 1``). Rows are read by position, so `selected_tf` may
    carry any index.

    Parameters
    ----------

    selected_tf : pandas.DataFrame
        Result of get_selected_transforms (columns ``class``, ``rank``,
        ``operation`` and ``probability``).
    operations : list, optional
        List of possible operations (other than the Identity) to use as matrix
        columns. Default to None (i.e. only transforms present in `selected_tf`
        will be used).

    Returns
    -------
    pandas.DataFrame
        Matrix indexed by class, with one column per operation followed by
        an ``'Identity'`` column.

    Raises
    ------
    KeyError
        If `selected_tf` contains an operation that is not listed in
        `operations`.
    """
    if operations is None:
        operations = selected_tf['operation'].unique()
    classes = selected_tf['class'].unique()
    n_ranks = selected_tf['rank'].max() + 1

    proba_matrix = pd.DataFrame(
        np.zeros((len(classes), len(operations) + 1)),
        index=list(classes),
        columns=list(operations) + ['Identity'],
    )

    # Iterate the three columns in lockstep instead of looking rows up by
    # label, which would only work with a default 0..n-1 index.
    rows = zip(
        selected_tf['class'],
        selected_tf['operation'],
        selected_tf['probability'],
    )
    for c, op, p in rows:
        proba_matrix.loc[c, op] += p
        # Whatever probability mass is not spent on the operation means the
        # sample is left untouched.
        proba_matrix.loc[c, 'Identity'] += (1 - p)

    proba_matrix /= n_ranks
    return proba_matrix


def get_perf_per_trial(path, sort=True, sort_by="valid acc", ascending=False):
    """ Gathers the per-trial test and validation results of every job folder
    in path into a single table.

    Each direct sub-folder of `path` holding a ``perf_per_tial.csv`` file
    (file name spelled exactly like this) contributes its rows. The files are
    read with explicitly supplied column names: ``trial``, ``test loss``,
    ``test acc``, ``valid loss`` and ``valid acc``.

    Parameters
    ----------

    path : str or path-like
        Path to main training folder.
    sort : bool, optional
        If True, rows are ordered by `sort_by`; if False they are ordered by
        increasing ``trial``. Defaults to True.
    sort_by : str, optional
        Column used to order the rows when `sort` is True. Defaults to
        "valid acc".
    ascending : bool, optional
        Sorting direction used when `sort` is True. Defaults to False.

    Returns
    -------
    pandas.DataFrame
        Concatenated results, re-indexed from 0 after ordering.
    """
    csv_name = "perf_per_tial.csv"
    column_names = ['trial', 'test loss', 'test acc', 'valid loss', 'valid acc']

    frames = []
    for entry in os.listdir(path):
        folder = join(path, entry)
        # Skip plain files and job folders that produced no results file.
        if not os.path.isdir(folder):
            continue
        if csv_name not in os.listdir(folder):
            continue
        frames.append(pd.read_csv(join(folder, csv_name), names=column_names))

    merged = pd.concat(frames, ignore_index=True)
    # Without explicit sorting, fall back to chronological (trial) order.
    key, direction = (sort_by, ascending) if sort else ('trial', True)
    return merged.sort_values(by=key, ascending=direction).reset_index(drop=True)


def get_perf_at_max_trial(perf_per_trial, trial, result_set='test', sorting_set='valid'):
    """ Reports the performance of the best trial found within a trial budget.

    Only rows whose ``trial`` value is at most `trial` are considered. Among
    them, the row with the highest ``'<sorting_set> acc'`` is selected and its
    `result_set` metrics are returned.

    Parameters
    ----------

    perf_per_trial : pandas.DataFrame
        Complete results from get_perf_per_trial.
    trial : int
        Largest trial number (inclusive) taken into account.
    result_set : str, optional
        Set whose metrics are returned. Defaults to 'test'.
    sorting_set : str, optional
        Set whose accuracy is used to pick the best trial. Defaults to
        'valid'.

    Returns
    -------
    scalar : Accuracy on result_set of the selected trial.
    scalar : Loss on result_set of the selected trial.
    """
    within_budget = perf_per_trial['trial'] <= trial
    ranked = (
        perf_per_trial.loc[within_budget]
        .sort_values(by=f'{sorting_set} acc', ascending=False)
        .reset_index(drop=True)
    )
    # After the descending sort, the first row is the selected trial.
    winner = ranked.iloc[0]
    accuracy = winner.loc[f'{result_set} acc']
    loss = winner.loc[f'{result_set} loss']
    return accuracy, loss



# Script entry point: reads the search results stored under the experiment
# folder (sub-folders 'baseline', 'single' and 'classwise') and writes two PDF
# figures into its 'figures' sub-folder.
if __name__ == "__main__":
    
    # Command-line interface: one positional argument, the experiment folder.
    # NOTE(review): the help text talks about saving search results, whereas
    # this script only reads results from that folder and saves figures into
    # it -- consider rewording.
    parser = argparse.ArgumentParser(
        description='PyTorch MNIST Example - plotting script'
    )
    parser.add_argument(
        'experiment_path',
        help='Path to folder where to save the search results.'
    )
    args = parser.parse_args()
    
        
    # Input folders (one per experimental setting) and output folder.
    experiment_path = Path(args.experiment_path)
    baseline_path = experiment_path / 'baseline'
    no_classwise_path = experiment_path / 'single'
    classwise_path = experiment_path / 'classwise'
    figures_path = experiment_path / 'figures'
    # exist_ok=True lets the script be re-run on the same experiment folder.
    figures_path.mkdir(exist_ok=True)


    # Global matplotlib defaults: no LaTeX text rendering, 8pt fonts and
    # tightly cropped figures on save.
    plt.rc("text", usetex=False)
    plt.rc('font', size=8)
    plt.rc('xtick', labelsize=8)
    plt.rc('ytick', labelsize=8)
    plt.rc('savefig', bbox='tight')


    # Display labels substituted (via DataFrame.replace below) for the
    # operation names found in the trials; the values also serve as the
    # heatmap columns. Presumably the keys are the names used by the search.
    transforms_mapping = {
        "hflip": "Hor. Flip",
        "vflip": "Ver. Flip",
        "90rot": "90 Rot.",
        "180rot": "180 Rot.",
    }


    # Selected transforms not classwise: keep the 5 best complete trials.
    sorted_trials_noncw = get_trials(str(no_classwise_path), trials_per_fold=5)
    long_selected_tr_df_noncw = get_selected_transforms(
        sorted_trials_noncw, classwise=False
    ).replace(transforms_mapping)

    proba_matrix_noncw = make_proba_matrix(
        long_selected_tr_df_noncw, list(transforms_mapping.values())
    )
    # The non-classwise matrix has a single row (class 'all_classes');
    # give it a readable label for the plot.
    proba_matrix_noncw.index = ['All classes']

    
    # Selected transforms classwise: same processing, one row per class.
    sorted_trials_cw = get_trials(classwise_path, trials_per_fold=5)
    long_selected_tr_df_cw = get_selected_transforms(
        sorted_trials_cw
    ).replace(transforms_mapping)
    
    proba_matrix_cw = make_proba_matrix(
        long_selected_tr_df_cw, list(transforms_mapping.values())
    )

    # Merge the two dataframes: the 'All classes' row comes first, followed
    # by the per-class rows.
    grouped_proba_matrix = pd.concat([proba_matrix_noncw, proba_matrix_cw])

    # Reproduce Fig.2 right
    # NOTE(review): the learning-curve section below is also labelled
    # "Figure.2 right"; one of the two is presumably the left panel -- confirm.
    # NOTE(review): sns.set() applies seaborn's own theme on top of the
    # plt.rc settings above -- confirm the resulting font sizes are intended.
    sns.set(font_scale=0.8)
    plt.subplots(figsize=(2.2,2.2))
    # Heatmap of the probability matrix, each cell annotated with its value.
    sns.heatmap(
        grouped_proba_matrix, annot=True, 
        cmap=sns.light_palette("seagreen", as_cmap=True),
        annot_kws={"fontsize":8}
    )
    plt.ylabel("Class")
    plt.xticks(rotation=60)
    plt.savefig(figures_path / 'MNIST_transforms.pdf')


    # Performance curve
    # For every trial budget, report the test metrics of the trial with the
    # best validation accuracy among the trials seen so far.

    # Load non-classwise perf
    # sort=False orders the rows by trial number. The x-axis values are
    # 0..n_rows-1 (assumes trial numbers span that same range -- TODO confirm).
    no_classwise_perf = get_perf_per_trial(no_classwise_path, sort=False)
    no_classwise_trials_line = np.arange(no_classwise_perf.shape[0])
    # Stacking the (accuracy, loss) pairs gives an (n_trials, 2) array:
    # column 0 is the test accuracy, column 1 the test loss.
    no_classwise_perf = np.stack([
        get_perf_at_max_trial(no_classwise_perf, trial)
        for trial in no_classwise_trials_line
    ])

    # Load classwise perf (same processing as above)
    classwise_perf = get_perf_per_trial(classwise_path, sort=False)
    classwise_trials_line = np.arange(classwise_perf.shape[0])
    classwise_perf = np.stack([
        get_perf_at_max_trial(classwise_perf, trial)
        for trial in classwise_trials_line
    ])

    # Load baseline perf: test accuracy of the first row once ordered by
    # trial number.
    baseline_perf = get_perf_per_trial(baseline_path, sort=False)
    baseline_test_acc = baseline_perf.loc[0, 'test acc']

    
    # Reproduce Figure.2 right
    # (p1, p2 and p3 keep the plot handles but are not used afterwards.)
    plt.subplots(figsize=(3.3,2.2))
    p2 = plt.plot(
        no_classwise_trials_line, no_classwise_perf[:, 0],
        label="Standard setting"
    )
    # Extend the standard-setting curve as a flat line from its last trial up
    # to the last class-wise trial (presumably because the class-wise search
    # ran for more trials -- confirm).
    plt.hlines(
        no_classwise_perf[-1, 0], no_classwise_trials_line[-1],
        classwise_trials_line[-1]
    )
    p3 = plt.plot(
        classwise_trials_line, classwise_perf[:, 0],
        label="Class-wise setting"
    )
    # plt.xlim() is read after the curves are drawn so that the baseline
    # spans the current x-range; keep this call after the plot calls above.
    p1 = plt.hlines(
        baseline_test_acc, *plt.xlim(), linestyle="--",
        color='k', label="No augmentation"
    )
    plt.ylabel("Test accuracy")
    plt.xlabel("Number of subpolicies sampled")
    plt.grid()
    plt.legend(loc=0, fontsize='small')
    # NOTE(review): the x values start at 0, which cannot be displayed on a
    # logarithmic axis -- confirm that dropping the first point is acceptable.
    plt.xscale("log")
    plt.tight_layout()
    # NOTE(review): 'learnig' is spelled this way in the output file name;
    # left as is since other tooling may rely on it.
    plt.savefig(figures_path / 'MNIST_learnig_curve.pdf')
    
    print(f"Figures generated in {figures_path} folder")

