'''
Class that performs data manipulation for the data recorded AFTER model fitting. This class
operates on data from a single recording (one recording day), i.e. all neurons have seen all
stimuli in the set.
'''

import numpy as np
from os.path import join
import pickle

class PostData():
    '''
    Load and index the data recorded AFTER model fitting for one recording day.

    On construction the following files are read, relative to ``config['base_path']``:

    * ``submission_data/spike_data/<recording_day>.npy``            -> ``self.pre_spike_matrix``
    * ``submission_data/post_spike_data/<recording_day>_2ndphase.npy`` -> ``self.spike_matrix``
    * ``submission_data/pretrained_models/<recording_day>_0.1/dict.pickle``   -> ``self.dict``
    * ``submission_data/pretrained_models/<recording_day>_0.1/filenames.txt`` -> ``self.filenames``

    After loading, :meth:`get_row_idx` is called to populate the per-category row-index lists.
    '''

    # Width of the zero-padded numeric part of stimulus filenames, e.g. 'Body_0000005.png'.
    _ID_WIDTH = 7

    def __init__(self, config, recording_day):
        '''
        Args:
            config: mapping that must contain a 'base_path' key (root of the data tree).
            recording_day: recording-day name, used as the stem of every file loaded.
        '''
        self.config = config
        self.recording_day = recording_day
        self.pre_spike_matrix = np.load(join(config['base_path'], 'submission_data/spike_data', recording_day) + '.npy')
        self.spike_matrix = np.load(join(config['base_path'], 'submission_data/post_spike_data', recording_day) + '_2ndphase.npy')

        model_dir = join(config['base_path'], 'submission_data/pretrained_models', recording_day + '_0.1')
        # NOTE: pickle can execute arbitrary code; only load dict.pickle files you produced yourself.
        with open(join(model_dir, 'dict.pickle'), 'rb') as handle:
            self.dict = pickle.load(handle)
        # np.loadtxt returns a 0-d array for a single-line file, which list() cannot iterate;
        # atleast_1d makes the one-filename case work like the general case.
        self.filenames = np.atleast_1d(np.loadtxt(join(model_dir, 'filenames.txt'), dtype=str)).tolist()

        self.get_row_idx()

    def get_row_idx(self):
        '''
        Populate the row indices (positions in ``self.filenames``) of each stimulus group.

        Nothing is returned; the following attributes (lists of int) are set:

        * ``avatar_rows``, ``body_rows``, ``object_rows``: every row whose filename starts
          with 'avatar_', 'Body_' or 'Object_' respectively, in file order.
        * ``good_body_rows`` / ``bad_body_rows``: rows of the Body stimuli whose ids are listed
          in ``self.dict['Predictions']['Body']['best_idx'][0]`` / ``['worst_idx'][0]``.
        * ``good_object_rows`` / ``bad_object_rows``: the same for 'Object'.

        Raises:
            ValueError: if a best/worst id has no matching filename.
        '''
        self.avatar_rows = self._rows_with_prefix('avatar_')
        self.body_rows = self._rows_with_prefix('Body_')
        self.object_rows = self._rows_with_prefix('Object_')

        # Build the filename -> row lookup once instead of a linear list.index per id.
        row_of = {}
        for row, filename in enumerate(self.filenames):
            row_of.setdefault(filename, row)  # keep the first occurrence, as list.index did

        predictions = self.dict['Predictions']
        self.good_body_rows = self._rows_for_ids(row_of, 'Body_', predictions['Body']['best_idx'][0])
        self.bad_body_rows = self._rows_for_ids(row_of, 'Body_', predictions['Body']['worst_idx'][0])
        self.good_object_rows = self._rows_for_ids(row_of, 'Object_', predictions['Object']['best_idx'][0])
        self.bad_object_rows = self._rows_for_ids(row_of, 'Object_', predictions['Object']['worst_idx'][0])

    def _rows_with_prefix(self, prefix):
        '''Return the positions of all filenames starting with ``prefix``, in file order.'''
        return [row for row, filename in enumerate(self.filenames) if filename.startswith(prefix)]

    def _rows_for_ids(self, row_of, prefix, ids):
        '''
        Map stimulus ids to rows, e.g. id 5 with prefix 'Body_' -> row of 'Body_0000005.png'.

        Args:
            row_of: dict from filename to its (first) row in ``self.filenames``.
            prefix: filename prefix, e.g. 'Body_' or 'Object_'.
            ids: iterable of scalar array/tensor elements; ``.item()`` is called on each.

        Raises:
            ValueError: if the constructed filename is not present in ``row_of``.
        '''
        rows = []
        for stim_id in ids:
            # zfill left-pads with zeros to _ID_WIDTH digits; longer ids are left unchanged.
            filename = prefix + str(stim_id.item()).zfill(self._ID_WIDTH) + '.png'
            try:
                rows.append(row_of[filename])
            except KeyError:
                raise ValueError(f'{filename!r} is not in list') from None
        return rows
