from os.path import join
import numpy as np

from model.GaussianReadout import FullGaussian2d
from model.GaussianReadoutModel import GaussianReadoutModel
from model.CoreModel import CoreModel


def build_model(config, spike_array_path):
    '''
    Builds the model used in the paper.

    Constructs a CNN backbone (``CoreModel``) and an isotropic
    ``FullGaussian2d`` readout with one output per recorded neuron, then
    wraps both in a ``GaussianReadoutModel``.

    Args:
        config: mapping of parameters used in the paper. Must provide
            ``'base_path'``, ``'resnet_weight_path'`` (joined onto
            ``base_path`` to locate the backbone weights),
            ``'readout_layer'``, ``'n_features'`` and ``'feat_map_size'``.
        spike_array_path: path to the spike data stored as a ``.npy``
            array. Axis 1 of the array indexes neurons and sets the
            readout's output dimension. The meaning of the other axes is
            not used here; presumably axis 0 is stimuli/trials (confirm).

    Returns: instance of class GaussianReadoutModel

    Raises:
        KeyError: if a required key is missing from ``config``.
    '''
    base_path = config['base_path']
    resnet_weight_path = config['resnet_weight_path']
    readout_layer = config['readout_layer']
    n_features = config['n_features']
    feat_map_size = config['feat_map_size']

    # Only the neuron count is needed. Memory-mapping reads just the .npy
    # header instead of loading the whole (possibly large) spike array into
    # RAM. The temporary memmap is released as soon as the shape is read.
    n_neurons = np.load(spike_array_path, mmap_mode='r').shape[1]

    # Build backbone CNN from pretrained weights, truncated at readout_layer
    # (presumably the layer whose feature maps feed the readout; confirm in CoreModel).
    core_model = CoreModel(weights=join(base_path, resnet_weight_path),
                           readout_layer=readout_layer)

    # Build readout layer: the feature map is assumed square
    # (feat_map_size x feat_map_size) with n_features channels,
    # and there is one output per neuron.
    readout = FullGaussian2d(in_shape=(n_features, feat_map_size, feat_map_size),
                             outdims=n_neurons,
                             bias=False,
                             gauss_type='isotropic',
                             init_sigma=0.5)

    # Build final model
    model = GaussianReadoutModel(core_model=core_model, readout=readout)

    return model