'''
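Runs the full pipeline for one recording day: training on bodies, evaluating on bodies and objects, and running parallel backpropagation.
The only input is the recording-day ID (the name of the .npy file in submission_data/spike_data/ without its extension), passed with -d as in the example run statement below.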

RUN:
python train_and_evaluate.py -d day_05_03_24
'''

import argparse
import json
import os
import pickle

import numpy as np
import torch

from evaluate.build_model import build_model
from evaluate.evaluate_online_fit import evaluate
from evaluate_on_objects import evaluate_on_objects
from evaluate.train_online import run_training
from evaluate.utils import *
from parallel_backpropagation import parallel_backprop



def main():
    """Train, evaluate, save and analyse the model for one recording day.

    Pipeline:
      1. Parse ``-d/--recording_day`` (required) and check that
         ``submission_data/spike_data/<day>.npy`` exists.
      2. Load ``evaluate/configs/training.json``, adapt its paths to this
         machine via ``set_config`` and write the result back to that file.
      3. Build and train the model, then switch it to eval mode and evaluate it.
      4. Save the readout ``state_dict`` and a pickle holding the metrics, the
         spike matrix and the training indices under
         ``submission_data/saved_models/<day>/``.
      5. Run ``evaluate_on_objects`` and ``parallel_backprop`` on the model.

    All paths are resolved relative to the current working directory.
    """
    parser = argparse.ArgumentParser(
        description='Train, evaluate and run parallel backprop for one recording day.')
    # required=True: without it a missing -d only surfaced later as a TypeError
    # when building the spike-file path.
    parser.add_argument('-d', '--recording_day', required=True, type=str,
                        help='recording day ID, i.e. the name of the .npy file in '
                             'submission_data/spike_data without extension '
                             '(e.g. day_05_03_24)')

    args = parser.parse_args()
    recording_day = args.recording_day

    base_path = os.getcwd()
    print('Current base path:', base_path)

    # Spike matrix of the first experimental phase; original note says (475, n_neurons).
    spike_array_path = os.path.join(base_path, 'submission_data', 'spike_data',
                                    recording_day + '.npy')
    # Fail fast, before the shared config file is rewritten or training starts.
    if not os.path.isfile(spike_array_path):
        parser.error(f'spike data file not found: {spike_array_path}')

    # Load configs and set the correct paths for this machine.
    config_path = os.path.join(base_path, 'evaluate/configs', 'training.json')
    config = load_config(config_path=config_path)
    config = set_config(config, base_path)

    # Persist the machine-specific paths. Note: this overwrites training.json, and
    # the per-session keys added below ('save_path', 'recording_day') are NOT saved.
    with open(config_path, 'w') as f:
        json.dump(config, f)

    ### Set save path for this session
    config['save_path'] = os.path.join(base_path, 'submission_data/saved_models', recording_day)
    config['recording_day'] = recording_day
    os.makedirs(config['save_path'], exist_ok=True)

    ### Initialize model
    model = build_model(config, spike_array_path)

    ### Fit model ###
    train_result = run_training(model, config, spike_array_path)
    model = train_result['Model']
    idx_from_training = train_result['idx_dict']
    # Switch to inference mode before evaluation and saving.
    model.readout.eval()
    model.eval()

    ### Evaluate model ###
    metrics = evaluate(model, config, spike_array_path, idx_from_training)

    ### Save model ###
    # Only the readout's weights are stored.
    torch.save(model.readout.state_dict(), os.path.join(config['save_path'], 'model'))

    # Predictions on objects and bodies
    out_dict = {
        'Model Metrics': metrics,
        'Spike Matrix': np.load(spike_array_path),
        'Used Indices': idx_from_training,
    }

    with open(os.path.join(config['save_path'], 'dict.pickle'), 'wb') as handle:
        pickle.dump(out_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    ### Evaluate on objects ###
    evaluate_on_objects(config, recording_day, model, used_indices=idx_from_training)

    ### Compute parallel backpropagation ###
    parallel_backprop(recording_day, model)

# Script entry point: run the pipeline only when executed directly, not when imported.
if __name__ == '__main__':
    main()