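'''
Evaluates a trained model on a single recording day (avatar test set, object
stimuli, and pre/post recording stability) as a reference for visualizations.

RUN:
python evaluate_on_objects.py -d day_05_03_24 [-m path/to/model_checkpoint]
'''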

import matplotlib.pyplot as plt
import torch
from scipy.stats import pearsonr
from tqdm import tqdm
import os
import argparse
import pickle

from evaluate.post_analysis.PostData import PostData
from evaluate.build_model import build_model
from evaluate.utils import *
from model.Dataset import ImageDataset


def evaluate_on_objects(config, recording_day, model, used_indices, plot=True):
    '''
    Evaluates model fit for a specific recording day.

    For every channel this computes:
      - the Pearson correlation between predictions and recorded responses on
        the avatar test images,
      - the same correlation on the object images,
      - the test/retest stability of the channel, i.e. the correlation between
        first-phase and second-phase responses to the avatar val/test images.
    Results are printed and, optionally, plotted and saved as a PDF.

    Args:
        config: dict containing parameters; 'base_path', 'object_path' and
            'stimulus_path' are read here
        recording_day: str indicating the recording day. see generate_correlation_plots.py for possible values
        model: instance of class GaussianReadoutModel
        used_indices: train/val/test idx for the trained model; the 'val',
            'test', 'val_target' and 'test_target' entries are read
        plot: whether to plot the results (saved to
            <base_path>/plots/corr_<recording_day>.pdf)

    Returns: None
    '''
    print('Evaluating...')
    post_data = PostData(config, recording_day)

    base_path = config['base_path']
    post_spikes = post_data.spike_matrix
    n_channels = post_spikes.shape[1]

    object_set = ImageDataset(stimulus_path=join(base_path, config['object_path']))
    object_set.set_subset([name for name in post_data.filenames if name.startswith('Object_')])

    test_idx = used_indices['test']  # these are the test image indices
    # Positions of the test images within the sorted val+test images (the 75
    # val/test indices). They are mapped through post_data.avatar_rows to get
    # the rows of the test images in the post-data spike matrix.
    test_lookup = set(test_idx)
    val_test_idx = sorted([*used_indices['val'], *test_idx])
    test_target_idx = [post_data.avatar_rows[i]
                       for i, value in enumerate(val_test_idx) if value in test_lookup]

    avatar_set = ImageDataset(stimulus_path=join(base_path, config['stimulus_path']), idx=test_idx)

    model.eval()
    model.readout.eval()
    model = model.to('cuda')

    def evaluate(model, dataset, spike_matrix=None, plot=False, domain=None, avatar_means=None):
        '''
        Predicts responses for every item of `dataset` (in order, no shuffling)
        and returns the per-channel Pearson correlation with `spike_matrix`.
        The rows of `spike_matrix` are assumed to be in the same order as `dataset`.

        domain='avatar' feeds x[0] of each batch to the model. domain='Object'
        penalises predictions above `avatar_means` (y - 3 * relu(y - avatar_means)).
        Any other value feeds batches unchanged. If `plot` is set, one
        predicted-vs-recorded scatter is shown per channel.
        '''
        loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=False)
        yhat = []
        with torch.no_grad():
            for x in loader:
                if domain == 'avatar':
                    x = x[0]
                    yhat.append(model(x.to('cuda')).to('cpu'))
                elif domain == 'Object':
                    y = model(x.to('cuda')).to('cpu')
                    overshoot = y - avatar_means
                    y = y - 3 * torch.nn.functional.relu(overshoot)
                    yhat.append(y)
                else:
                    yhat.append(model(x.to('cuda')).to('cpu'))
        yhat = torch.cat(yhat)
        ros = []
        for n in range(spike_matrix.shape[-1]):
            ros.append(pearsonr(yhat[:, n], spike_matrix[:, n])[0])
            if plot:
                plt.scatter(yhat[:, n], spike_matrix[:, n])
                plt.xlabel('Predicted Responses')
                plt.ylabel('Recorded Responses')
                plt.title(str(n))
                plt.show()
        return ros

    object_corr = evaluate(model, object_set, spike_matrix=post_spikes[post_data.object_rows, :])
    # NOTE(review): no domain='avatar' is passed here, so batches are fed to the
    # model as-is (no x[0] unpacking) — confirm ImageDataset(idx=...) yields
    # plain image tensors.
    avatar_corr = evaluate(model, avatar_set, spike_matrix=post_spikes[test_target_idx, :])
    print('Mean Correlation on avatar Test Set:', np.mean(avatar_corr))
    print('Channel Correlation on avatar Test Set:', (avatar_corr))
    print()
    print('Mean Correlation on Objects:', np.mean(object_corr))
    print('Channel Correlation on Objects:', (object_corr))
    print()

    # Stability: first-phase responses to the val/test images vs. the
    # second-phase responses in the avatar rows.
    pre_spike_array_path = join(base_path, 'submission_data/spike_data', recording_day) + '.npy'
    pre_spike_array = np.load(pre_spike_array_path)
    pre_val_test = sorted([*used_indices['val_target'], *used_indices['test_target']])
    post_avatar = post_spikes[post_data.avatar_rows, :]
    stability = [pearsonr(pre_spike_array[pre_val_test, n], post_avatar[:, n])[0]
                 for n in range(pre_spike_array.shape[-1])]
    print('Channel recording stability:', stability)
    print()

    if plot:
        channels = range(n_channels)
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 6))
        ax.scatter(channels, avatar_corr, label='avatar Test Set')
        ax.scatter(channels, object_corr, label='Objects')
        ax.scatter(channels, stability, label='Stability', c='tab:brown')
        ax.set_ylim(-0.5, 1)
        ax.set_ylabel('Correlation')
        ax.set_xlabel('Channel')
        ax.legend()
        # Zero line and the two mean correlations, spanning the actual channel
        # range instead of a fixed 16 channels.
        ax.hlines([0, np.mean(avatar_corr), np.mean(object_corr)], 0, n_channels - 1,
                  colors=['black', 'tab:blue', 'tab:orange'], linestyles='dashed')
        fig.suptitle('Correlation after training on the avatar stimuli')
        save_path = join(base_path, 'plots/corr_' + recording_day + '.pdf')
        print('Saving to:', save_path)
        # Save this specific figure rather than whatever figure is current.
        fig.savefig(save_path, bbox_inches='tight')
        plt.show()
    return None



def compute_activations(model, dataset, avatar_set=False, return_predictions=False):
    '''
    Compute output of core model.

    Runs ``model(x, return_features=True)`` over ``dataset`` in order, in
    batches of 256, without gradients and on 'cuda'. The third element of the
    model output is collected as the activations.

    Args:
        model: network returning a 3-tuple when called with return_features=True
        dataset: torch dataset to iterate (not shuffled)
        avatar_set: if True, only the first element of each batch is fed to the model
        return_predictions: if True, also return the first element of the model output

    Returns: squeezed activations, or (activations, predictions), both squeezed
    '''
    batches = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=False)
    feature_chunks, output_chunks = [], []
    with torch.no_grad():
        for batch in tqdm(batches):
            inputs = batch[0] if avatar_set else batch
            output, _, features = model(inputs.to('cuda'), return_features=True)
            feature_chunks.append(features.to('cpu'))
            output_chunks.append(output.to('cpu'))
    all_features = torch.cat(feature_chunks)
    all_outputs = torch.cat(output_chunks)

    if return_predictions:
        return all_features.squeeze(), all_outputs.squeeze()
    return all_features.squeeze()


def compute_stability(pre_matrix, post_data, idx_dict, verbose=False):
    '''
    Test/retest stability of recording electrodes.

    The rows of ``pre_matrix`` are assumed to follow the sorted order of the
    combined train+val image indices. The val rows are selected from it and
    correlated, channel by channel, with the avatar rows of the second-phase
    spike matrix. Those avatar rows are assumed to be in the same image order
    as the sorted val indices.

    Args:
        pre_matrix: array with data from first experimental phase (rows x channels)
        post_data: data from 2nd phase, instance of class PostData; its
            ``spike_matrix`` and ``avatar_rows`` are read
        idx_dict: train/val idx; the 'train' and 'val' entries are read
        verbose: if True, print per-channel means and correlations
    Returns: list containing per-channel stability (Pearson r)

    '''
    val_idx = sorted(idx_dict['val'])
    val_lookup = set(val_idx)
    all_idx = sorted([*idx_dict['train'], *val_idx])

    # Positions of the val images within the sorted train+val ordering. An
    # explicit int dtype keeps this a valid index array even when it is empty.
    val_target_idx = np.asarray(
        [pos for pos, image_idx in enumerate(all_idx) if image_idx in val_lookup],
        dtype=int,
    )

    pre_array = pre_matrix[val_target_idx, :]
    post_array = post_data.spike_matrix[post_data.avatar_rows, :]

    if verbose:
        print('Correlation between Pre and Post recordings')
    stability = []
    for n in range(pre_array.shape[-1]):
        ro = pearsonr(pre_array[:, n], post_array[:, n])[0]  # computed once, reused for printing
        if verbose:
            print(' Channel:', n)
            print('Pre/Post:', np.mean(pre_array[:, n]), np.mean(post_array[:, n]))
            print('Correlation:', ro)
        stability.append(ro)
    if verbose:
        print('Val set mean phase 1:', np.mean(pre_array, axis=0))
        print('Val set mean phase 2:', np.mean(post_array, axis=0))

    return stability



def main():
    '''
    Command-line entry point.

    Loads the training config, the pretrained readout weights and the
    train/val/test split for one recording day, then runs evaluate_on_objects
    with plotting enabled.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--recording_day',
                        help='name of the np file containing data of the recording day', type=str,
                        default='day_05_03_24')
    parser.add_argument('-m', '--model_checkpoint',
                        help='path to the model checkpoint containing the readout weights', type=str,
                        default=None)
    # Parse the CLI before touching the config so `-h` works without it.
    args = parser.parse_args()

    base_path = os.getcwd()
    config_path = join(base_path, 'evaluate/configs', 'training.json')
    config = load_config(config_path=config_path)
    config = set_config(config, base_path)

    # Directory holding the default readout checkpoint ('model') and the data split ('dict.pickle').
    pretrained_dir = join(base_path, 'submission_data', 'pretrained_models', args.recording_day + '_0.1')

    # set default model
    if args.model_checkpoint is None:
        args.model_checkpoint = join(pretrained_dir, 'model')

    # Get train/val/test split for this run.
    # pickle can execute code on load: only use trusted, locally produced files.
    # NOTE(review): the split is always read from the default pretrained dir,
    # even when -m points elsewhere — confirm a custom checkpoint used this split.
    with open(join(pretrained_dir, 'dict.pickle'), 'rb') as handle:
        d = pickle.load(handle)

    # Split indices; evaluate_on_objects reads the 'val', 'test', 'val_target'
    # and 'test_target' entries.
    used_indices = d['Used Indices']

    model = build_model(config,
                        join(base_path, 'submission_data/post_spike_data', args.recording_day) + '_2ndphase.npy')
    # Load pretrained readout weights onto the CPU regardless of the device
    # they were saved from; the whole model is moved to the GPU below.
    checkpoint = torch.load(args.model_checkpoint, map_location='cpu')
    model.readout.load_state_dict(checkpoint)

    model.eval()
    model.readout.eval()
    model = model.to('cuda')

    evaluate_on_objects(config, args.recording_day, model, used_indices, plot=True)


if __name__ == '__main__':
    main()