# Naïve NAS algorithm for 3-block DenseNet-BC

import csv
import os
import re
import time
from collections import deque
from datetime import datetime, timedelta


def to_time(time_string):
    """
    Convert a time string such as "1:02:03.5", "0:00:01,25" or
    "2 days, 1:00:00" (str(timedelta)-like, decimal point or comma in the
    seconds) into a number of seconds.

    Args:
        time_string (str) - the string to process.

    Returns:
        float - the total number of seconds.
    """
    day_count = 0
    # The plural form is checked first, exactly as before.
    for separator in (" days, ", " day, "):
        if separator in time_string:
            day_text, time_string = time_string.split(separator, 1)
            day_count = int(day_text)
            break
    hh, mm, ss = time_string.split(":")
    total_minutes = (int(hh) + 24 * day_count) * 60 + int(mm)
    return float(ss.replace(",", ".")) + 60 * total_minutes


def _parse_decimal(text):
    """Convert an ft-log number, which may use a decimal comma, to float."""
    return float(text.replace(",", "."))


def get_score_data(file_path, eeu=100):
    """
    Parse a given ft-log to get the accuracy and loss for the validation and
    test set, the training and testing time, and the parameter counts.

    Epoch rows are rows with more than 6 ';'-separated fields whose first
    field is an epoch number; their third field is read as the loss and
    their fourth as the accuracy.

    Args:
        file_path (str) - path to the ft-log to be parsed.
        eeu (int) - corresponds to the 'every_epoch_until' parameter
            (default 100). For the sparsification format, epoch rows before
            epoch `eeu` always replace the candidate validation row, while
            later epoch rows replace it only if their loss is strictly lower.

    Returns:
        dict - with keys 'end epoch', 'train time', 'test time' (seconds),
        'total params', 'conv params', 'fc params', 'valid acc',
        'valid loss', 'test acc', 'test loss' (accuracies in percent).
        The dict is empty if the format is not recognized (e.g. an empty,
        CUT_SHORT or FAILED log, or a log without epoch rows).
    """
    # Only the last 16 rows before the architecture is specified are important
    # (hence the deque of maxlen=16).
    ft_log_rows_FIFO = deque(maxlen=16)
    # For prebuilt, the state with best loss is final.
    ft_log_best_loss_row = None
    ft_log_best_loss_row_eeu = None
    # This is the data to extract from the ft-log.
    score_data = {}

    # Open the feature log file and traverse the rows until the end.
    with open(file_path, "r") as ft_log_file:
        for r in csv.reader(ft_log_file, delimiter=';'):
            ft_log_rows_FIFO.append(r)
            # Only epoch rows take part in the best-loss tracking; fullmatch
            # rejects cells like "12abc" that int() could not parse.
            if len(r) > 6 and re.fullmatch(r"[0-9]+", r[0]):
                if (int(r[0]) < eeu
                        or ft_log_best_loss_row_eeu is None
                        or _parse_decimal(ft_log_best_loss_row_eeu[2])
                        > _parse_decimal(r[2])):
                    ft_log_best_loss_row_eeu = r
                if (ft_log_best_loss_row is None
                        or _parse_decimal(ft_log_best_loss_row[2])
                        > _parse_decimal(r[2])):
                    ft_log_best_loss_row = r

    # An empty file or a trailing blank row (csv yields []) has no label.
    last_row = ft_log_rows_FIFO[-1] if ft_log_rows_FIFO else []
    last_label = last_row[0] if last_row else None

    # Build the table row.
    if (last_label == "fully connected (in use)"
            and ft_log_best_loss_row_eeu is not None):
        # Sparsification.
        score_data.update({
            "end epoch": int(ft_log_rows_FIFO[-13][0])+1,
            "train time": to_time(ft_log_rows_FIFO[-8][1]),
            "test time": to_time(ft_log_rows_FIFO[-5][1]),
            "total params": int(ft_log_rows_FIFO[-4][1]),
            "conv params": int(ft_log_rows_FIFO[-2][1]),
            "fc params": int(ft_log_rows_FIFO[-1][1]),
            "valid acc": _parse_decimal(ft_log_best_loss_row_eeu[3])*100,
            "valid loss": _parse_decimal(ft_log_best_loss_row_eeu[2]),
            "test acc": _parse_decimal(ft_log_rows_FIFO[-12][3])*100,
            "test loss": _parse_decimal(ft_log_rows_FIFO[-12][2])
            })
    elif (last_label == "fully connected"
            and ft_log_best_loss_row is not None):
        # No-sparsification.
        score_data.update({
            "end epoch": int(ft_log_rows_FIFO[-12][0])+1,
            "train time": to_time(ft_log_rows_FIFO[-7][1]),
            "test time": to_time(ft_log_rows_FIFO[-4][1]),
            "total params": int(ft_log_rows_FIFO[-3][1]),
            "conv params": int(ft_log_rows_FIFO[-2][1]),
            "fc params": int(ft_log_rows_FIFO[-1][1]),
            "valid acc": _parse_decimal(ft_log_best_loss_row[3])*100,
            "valid loss": _parse_decimal(ft_log_best_loss_row[2]),
            "test acc": _parse_decimal(ft_log_rows_FIFO[-11][3])*100,
            "test loss": _parse_decimal(ft_log_rows_FIFO[-11][2])
            })
    else:
        print("Format unrecognized. Probably CUT_SHORT or FAILED.")

    return score_data


def _format_log_row(l_count, real_time, train_time, test_time, file_name,
                    score_data, score_keys):
    """
    Build one ';'-separated row of the main naive-NAS log.

    Args:
        l_count (int) - layers per block of the architecture being logged;
        real_time (float) - total wall-clock time so far, in seconds;
        train_time (float) - sum of training times so far, in seconds;
        test_time (float) - sum of test times so far, in seconds;
        file_name (str) - name of the best ft-log for the architecture;
        score_data (dict) - score data of that ft-log; must contain
            'test time' and every key in score_keys;
        score_keys (list of str) - score_data keys to append, in order.
            Keys ending in 'time' are written as durations.

    Returns:
        str - the row, newline-terminated. Durations are written as
        timedelta strings, and decimal points in numeric fields are written
        as commas. The file name is written verbatim.
    """
    def decimal_comma(value):
        return str(value).replace('.', ',')

    fields = [
        str(l_count),
        decimal_comma(timedelta(seconds=real_time)),
        decimal_comma(timedelta(seconds=train_time)),
        decimal_comma(timedelta(seconds=test_time)),
        decimal_comma(timedelta(
            seconds=train_time + score_data["test time"])),
        file_name,
    ]
    for key in score_keys:
        value = score_data[key]
        if key.endswith('time'):
            value = timedelta(seconds=value)
        fields.append(decimal_comma(value))
    return ';'.join(fields) + '\n'


def naive_NAS(dataset, save_dir, training_config, eeu=0, n_retrains=1,
              impr_thresh=0.01, max_layer_count=12, max_param_count=None,
              data_dir="/dev/shm"):
    """
    Naïve NAS algorithm for 3-block DenseNet-BC with k=12:
    Train and test various architectures with the same number of layers in each
    block (layer_count), progressively increasing that number of layers.
    Stop with the following stopping criteria (if active):
        - no significant improvement w.r.t. the previous architecture,
        - the architecture exceeds max_layer_count layers per block,
        - the architecture exceeds max_param_count trainable parameters,
        - no ft-log in a recognized format was produced for the current
          architecture (e.g. every training failed or was cut short).

    A row is appended to '<save_dir>/<run_id>.csv' after each accepted
    architecture. When the algorithm stops, a final row for the last
    accepted architecture (with updated total times) is appended, if any
    architecture was accepted.

    Args:
        dataset (str) - the dataset for which to create the networks;
        save_dir (str) - dir where logs, models, hypers should be saved;
        training_config (str) - configuration for naïve training to be passed
            to the initializer;
        eeu (int) - corresponds to the 'every_epoch_until' parameter
            (default 0);
        n_retrains (int) - number of times that the network is trained from
            scratch to find a better validation performance (default 1);
        impr_thresh (float or None) - minimum difference between the
            validation accuracies of the current and last architecture.
            These are the 'valid acc' values of get_score_data, i.e.
            percentages. None disables this criterion (default 0.01);
        max_layer_count (int or None) - maximum number of layers per block in
            the network (default 12);
        max_param_count (int or None) - maximum number of trainable parameters
            in the network (default None);
        data_dir (str or None) - passed to run_DensEMANN.py as '--data'; the
            option is omitted if None or empty (default "/dev/shm").
    """
    base_config = "-m DenseNet-BC -lnl {0} -k 12"
    score_keys_list = ['end epoch', 'train time', 'test time', 'total params',
                       'conv params', 'fc params', 'valid acc', 'valid loss',
                       'test acc', 'test loss']
    l_count = 1  # layer count.
    real_time_count = 0  # real execution time (wall).
    train_time_count = 0  # sum of GPU training times.
    test_time_count = 0  # sum of GPU test times.
    stop_criterion = ""  # why the algorithm stopped.

    # Best run before the current architecture.
    best_file_name_last = ""
    best_score_data_last = {}

    # Create logs and write headers.
    run_id = "naive_NAS_{0}_{1}".format(
        dataset, datetime.now().strftime("%Y_%m_%d_%H%M%S"))
    logs_dir = os.path.join(save_dir, run_id)
    os.makedirs(logs_dir, exist_ok=True)
    main_log_path = os.path.join(save_dir, run_id + '.csv')
    with open(main_log_path, 'w') as f:
        f.write('l count;total real time;total train time;total test time;'
                'total train + last test time;'
                'best file name;end epoch;train time;test time;'
                'total params;conv params;fc params;'
                'valid acc;valid loss;test acc;test loss\n')

    data_option = ' --data "{}"'.format(data_dir) if data_dir else ""

    # Set a reference for real execution time.
    real_time_start = time.perf_counter()

    while True:
        # Train the candidate network n_retrains times.
        for _ in range(n_retrains):
            exit_status = os.system(
                "python run_DensEMANN.py --train --test"
                " -ds {0} --prebuilt -prlr DensEMANN {3}"
                "{4}"
                " --save \"{1}\""
                " {2}".format(
                    dataset, logs_dir, training_config,
                    base_config.format(l_count), data_option))
            if exit_status != 0:
                print("Warning: training command exited with status {}"
                      .format(exit_status))

        # Read the results of all the retrains for this architecture.
        prefix = 'DenseNet-BC_{0}_prebuilt_k=12_lnl={1}_'.format(
            dataset, l_count)
        runs = []
        for file_name in sorted(os.listdir(logs_dir)):
            if file_name.startswith(prefix) and file_name.endswith('csv'):
                score_data = get_score_data(
                    os.path.join(logs_dir, file_name), eeu=eeu)
                # Unrecognized (CUT_SHORT/FAILED) logs yield {}: skip them.
                if score_data:
                    runs.append((file_name, score_data))

        # STOPPING CRITERION: no usable result for this architecture.
        if not runs:
            real_time_count = time.perf_counter() - real_time_start
            stop_criterion = "no_valid_run"
            break

        # STOPPING CRITERION: max param count.
        if max_param_count is not None and any(
                s["total params"] > max_param_count for _, s in runs):
            # Do not update real_time_count (nor the time sums), as we didn't
            # need to train the network to know the param count.
            stop_criterion = "max_param_count"
            break

        # Best retrain = lowest validation loss (first one on ties).
        best_file_name_curr, best_score_data_curr = min(
            runs, key=lambda run: run[1]["valid loss"])
        # Add up the training and test times.
        for _, s in runs:
            train_time_count += s["train time"]
            test_time_count += s["test time"]

        # STOPPING CRITERION: improvement threshold on accuracy.
        if (not best_score_data_last or impr_thresh is None
                or best_score_data_curr["valid acc"]
                - best_score_data_last["valid acc"] >= impr_thresh):
            best_file_name_last = best_file_name_curr
            best_score_data_last = best_score_data_curr
        else:
            real_time_count = time.perf_counter() - real_time_start
            stop_criterion = "impr_thresh"
            break

        # STOPPING CRITERION: max layer count.
        l_count += 1
        if max_layer_count is not None and l_count > max_layer_count:
            real_time_count = time.perf_counter() - real_time_start
            stop_criterion = "max_layer_count"
            break

        # Add up the real execution time and log the accepted architecture.
        real_time_count = time.perf_counter() - real_time_start
        with open(main_log_path, 'a') as f:
            f.write(_format_log_row(
                l_count-1, real_time_count, train_time_count,
                test_time_count, best_file_name_last, best_score_data_last,
                score_keys_list))

    # Final row for the last accepted architecture (if there is one).
    if best_score_data_last:
        with open(main_log_path, 'a') as f:
            f.write(_format_log_row(
                l_count-1, real_time_count, train_time_count,
                test_time_count, best_file_name_last, best_score_data_last,
                score_keys_list))
    print("Stopping because of " + stop_criterion)


if __name__ == '__main__':
    # Setting the variables.
    dataset = "C10+"
    # Training configurations for run_DensEMANN.py. Placeholders, filled in
    # below: {0} first element of an epoch_count pair, {1} sum of the pair
    # (total number of epochs), {2} dsd_middle, {3} dsd_pattern.
    training_config_dict = {
        "no_DSD": "-nep {1} --no-sparsify",
        # "DSD": "-nep {1} --sparsify --spars_sched_func sched_dsd"
        # " --granularity weight --end_sparsity 0 --dsd_middle {2}"
        # " --dsd_pattern {3} --spars_end_epoch {0} --rlr_start_epoch {0}"
        # " --every_epoch_until {0}",
        # "no_DSD_cutout": "-nep {1} --no-sparsify --cutout",
        # "DSD_cutout": "-nep {1} --sparsify --spars_sched_func sched_dsd"
        # " --granularity weight --end_sparsity 0 --dsd_middle {2}"
        # " --dsd_pattern {3} --spars_end_epoch {0} --rlr_start_epoch {0}"
        # " --every_epoch_until {0} --cutout"
    }
    # Epoch-count pairs; see the placeholder description above.
    epoch_count_list = [(200, 100), (400, 200)]
    dsd_middle = 80
    dsd_pattern = 'square'

    # Experiment battery.
    # naive_NAS(
    #     dataset, os.path.join(os.getcwd(), "ft-logs", "RANDOM_TEST"),
    #     training_config_dict["no_DSD"].format(0, 1),
    #     eeu=0, n_retrains=2, max_layer_count=4, max_param_count=None)
    for epoch_count in epoch_count_list:
        for training_config in training_config_dict:
            # Five independent NAS runs per configuration.
            for _ in range(5):
                naive_NAS(
                    dataset,
                    # Derived from `dataset` so the two cannot diverge.
                    os.path.join(".", "ft-logs", "Naive_NAS_one_block",
                                 training_config,
                                 "{}_same-k".format(dataset)),
                    training_config_dict[training_config].format(
                        epoch_count[0], sum(epoch_count),
                        dsd_middle, dsd_pattern),
                    # Matches '--every_epoch_until {0}' in DSD configs.
                    eeu=(epoch_count[0] if training_config.startswith("DSD")
                         else 0),
                    n_retrains=1, max_layer_count=None, max_param_count=None)
