"""
Main function for EEGnet.
"""

from functools import reduce
import torch as t
import numpy as np
from tqdm import tqdm
from utils.eegnet_controller import train_subject_specific

# Settings read by main(); N_EPOCHS and QUANTIZED also serve as run()'s defaults.
N_EPOCHS = 100
DO_CV = False
QUANTIZED = False

# Further experiment switches; not referenced by run() or main().
BENCHMARK = False
N_TRIALS = 20
GRID_SEARCH = False
TEST_KERAS_MODEL = False


def run(do_cv=False, epochs=N_EPOCHS, quantized=QUANTIZED, export=True, silent=False):
    """
    Does one complete run over all data: trains and evaluates one subject-specific model for
    each of the nine subjects (IDs 1..9).

    Parameters:
    - do_cv:     If True, first run cross-validation for each subject (with the full `epochs`
                 budget) and train the final model for the best epoch count it reports.
    - epochs:    Maximum number of training epochs; also the length of the returned histories.
    - quantized: If True, train with train_subject_specific_quant instead of
                 train_subject_specific.
    - export:    If True, pass plot=True to the training function and write the metrics with
                 metrics_to_csv.
    - silent:    Forwarded to the training functions.

    Returns: (metrics, (loss_history, acc_history))
    - metrics:      tensor of shape (9, 4); row i is the first row of the metrics returned for
                    subject i + 1 (main() reads column 0 as the accuracy).
    - loss_history: tensor of shape (9, 2, epochs)
    - acc_history:  tensor of shape (9, 2, epochs)
      The size-2 axis is whatever the training function returns (presumably train/test —
      TODO confirm). If a subject is trained for fewer than `epochs` epochs (do_cv=True),
      only the leading entries are filled and the rest stay zero.
    """
    n_subjects = 9
    metrics = t.zeros((n_subjects, 4))
    loss_history = t.zeros((n_subjects, 2, epochs))
    acc_history = t.zeros((n_subjects, 2, epochs))

    for subject in range(1, n_subjects + 1):
        # Epoch count for this subject's final training run. Kept separate from `epochs` so
        # that every subject's CV search starts from the full budget and the history buffers
        # (sized with `epochs`) are not invalidated.
        train_epochs = epochs
        if do_cv:
            # NOTE(review): train_subject_specific_cv is not imported in this file; this branch
            # raises NameError unless it is provided elsewhere — confirm its import.
            _, _, best_epoch = train_subject_specific_cv(subject, epochs=epochs, silent=silent,
                                                         plot=False)
            train_epochs = best_epoch

        if quantized:
            # NOTE(review): train_subject_specific_quant is not imported in this file either.
            _model, subject_metrics, history = \
                train_subject_specific_quant(subject, epochs=train_epochs, silent=silent,
                                             plot=export)
        else:
            _model, subject_metrics, history = \
                train_subject_specific(subject, epochs=train_epochs, batch_size=32, lr=0.001,
                                       silent=silent, plot=export)

        loss, acc = history
        idx = subject - 1
        # Assumes the returned histories have exactly `train_epochs` entries on the last axis.
        loss_history[idx, :, :train_epochs] = loss
        acc_history[idx, :, :train_epochs] = acc
        metrics[idx, :] = subject_metrics[0, :]

    if export:
        # NOTE(review): metrics_to_csv is not imported in this file; with the default
        # export=True this raises NameError unless it is provided elsewhere — confirm its import.
        metrics_to_csv(metrics)

    return metrics, (loss_history, acc_history)

def main():
    """
    Main function used for testing.

    Trains on all subjects using the module-level settings DO_CV, N_EPOCHS and QUANTIZED,
    without exporting. It then prints the mean over subjects of metrics column 0 (the accuracy).
    """
    # normal procedure
    metrics, _ = run(do_cv=DO_CV, epochs=N_EPOCHS, quantized=QUANTIZED, export=False,
                     silent=False)
    # .item() converts the 0-d tensor to a Python float, so the output reads
    # "Average Accuracy: 0.75" instead of "Average Accuracy: tensor(0.7500)".
    print(f"\nAverage Accuracy: {metrics[:, 0].mean().item()}")

if __name__ == "__main__":
    main()
