'''
Script to generate the plots from the paper showing model performance on bodies and objects

RUN:
python generate_correlation_plots.py
'''
import os
import pickle

import matplotlib.pyplot as plt
import torch
from scipy.stats import pearsonr, mannwhitneyu
from tqdm import tqdm

from evaluate.build_model import build_model
from evaluate.post_analysis.PostData import PostData
from evaluate.post_analysis.global_analysis.correlation_statistics import compute_conf_interval
from evaluate_on_objects import compute_stability
from evaluate.utils import *
from evaluate.utils import load_config
from model.Dataset import ImageDataset


# Recording sessions included in the analysis, named 'day_<DD>_<MM>_<YY>' (presumably the recording
# date). get_stats skips any day on which the requested monkey/region has no channels.
recording_days = [
    'day_09_02_24', 'day_12_02_24', 'day_23_02_24',
    'day_27_02_24', 'day_29_02_24', 'day_01_03_24',
    'day_05_03_24', 'day_06_03_24', 'day_18_04_24',
    'day_23_04_24', 'day_24_04_24', 'day_29_04_24',
    'day_30_04_24', 'day_02_05_24', 'day_03_05_24',
    'day_06_05_24', 'day_07_05_24', 'day_08_05_24',
]


def evaluate(model, dataset, spike_matrix=None, device=None, batch_size=256):
    '''
    Correlate the model's predictions with recorded responses, channel by channel.
    Args:
        model: model of class GaussianReadoutModel (any torch module mapping an image batch to
            per-channel predictions of shape (batch, n_channels) works)
        dataset: dataset of class ImageDataset; a DataLoader over it must yield image batches only
        spike_matrix: array of shape (n_img, n_neuron), rows in the same order as `dataset`.
            Required; the None default is kept only for backward compatibility.
        device: device the image batches are moved to. Defaults to the device of the model's first
            parameter, or CPU if the model has no parameters.
        batch_size: number of images per forward pass

    Returns: list containing output of scipy.stats.pearsonr for each channel, i.e. one (r, p) result
        per column of spike_matrix

    Raises:
        ValueError: if spike_matrix is None
    '''
    if spike_matrix is None:
        raise ValueError('evaluate() needs spike_matrix to correlate the predictions against')

    if device is None:
        # Follow the model instead of hard-coding 'cuda', so CPU-only runs work as well.
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device('cpu')

    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    predictions = []
    with torch.no_grad():
        for x in loader:
            predictions.append(model(x.to(device)).to('cpu'))
    # Hand plain numpy arrays to scipy rather than torch tensors.
    yhat = torch.cat(predictions).numpy()

    return [pearsonr(yhat[:, n], spike_matrix[:, n]) for n in range(spike_matrix.shape[-1])]


def _region_channels(day_idx, region):
    '''
    Return the channel indices recorded from `region` on one recording day.
    Args:
        day_idx: dict with 'anterior' (ASB) and 'posterior' (MSB) lists of channel indices
        region: str, 'ASB', 'MSB' or 'all' (both regions, sorted)
    '''
    if region == 'all':
        return sorted([*day_idx['anterior'], *day_idx['posterior']])
    if region == 'ASB':
        return day_idx['anterior']
    if region == 'MSB':
        return day_idx['posterior']
    raise ValueError("region must be 'ASB', 'MSB' or 'all', got {!r}".format(region))


def _test_target_rows(test_idx, val_idx, avatar_rows):
    '''
    Map held-out test images to their rows in the 2nd-phase spike matrix.

    `avatar_rows` is indexed by position within the sorted val+test indices, so each test index is
    located in that sorted list first. The rows are returned in that sorted order.
    '''
    val_test_idx = sorted([*val_idx, *test_idx])
    return [avatar_rows[pos] for pos, value in enumerate(val_test_idx) if value in test_idx]


def _keep(values, mask):
    '''Return the elements of `values` whose entry in the boolean `mask` is True.'''
    return [value for value, keep in zip(values, mask) if keep]


def _as_corr_array(results):
    '''Stack per-channel pearsonr results into an (n_channels, 2) array of (r, p); (0, 2) when empty.'''
    if len(results) == 0:
        return np.empty((0, 2))
    return np.array(results)


def get_stats(monkey, region, stability_threshold=0.6, selectivity_alpha=0.01):
    '''
    Computes descriptive stats for monkey/region, as given in the paper
    Args:
        monkey: str, 'monkeyG' or 'monkeyT'
        region: str, 'ASB', 'MSB' or 'all' (both regions)
        stability_threshold: channels whose test/retest stability is not above this value are dropped
        selectivity_alpha: a channel is kept only if a one-sided Mann-Whitney U test (object responses
            lower than body responses) gives p <= selectivity_alpha

    Returns: tuple (order, object_corr, avatar_corr, stability, min_number_objects)
        order: indices that sort the surviving channels by decreasing body (avatar) correlation
        object_corr, avatar_corr: arrays of shape (n_channels, 2) with the pearsonr (r, p) per
            channel, sorted by `order`
        stability: array of shape (n_channels,), sorted by `order`
        min_number_objects: smallest number of object images used on any included recording day
            (np.inf if no day was included)

    Raises:
        ValueError: if region is not 'ASB', 'MSB' or 'all'
    '''
    if region not in ('ASB', 'MSB', 'all'):
        raise ValueError("region must be 'ASB', 'MSB' or 'all', got {!r}".format(region))

    base_path = os.getcwd()
    config_path = join(base_path, 'evaluate/configs', 'training.json')
    config = load_config(config_path=config_path)
    config = set_config(config, base_path)

    # Neither lookup depends on the recording day, so read them once rather than on every iteration.
    with open(join(base_path, 'evaluate/data_information/asb_msb_idx.json'), "r") as jsonfile:
        monkey_idx = json.load(jsonfile)[monkey]
    with open(config['idx_path'], 'rb') as fp:
        stimulus_idx_dict = pickle.load(fp)

    # Per-channel outputs accumulated across recording days (only channels from `region` are kept)
    object_corr, avatar_corr, stability, idx_list = [], [], [], []
    object_responses, body_responses = [], []
    min_number_objects = np.inf

    print('Gathering recording days from monkey {} and region {}...'.format(monkey, region))
    for recording_day in tqdm(recording_days):
        # Decide which channels to keep before loading any data, so skipped days cost nothing.
        if recording_day not in monkey_idx:
            continue  # this day contains no recordings from this monkey
        idx = _region_channels(monkey_idx[recording_day], region)
        if len(idx) == 0:
            continue  # the specified region wasn't recorded that day

        post_data = PostData(config, recording_day)
        spike_matrix = post_data.spike_matrix  # spike matrix from 2nd phase
        # spike matrix from first phase
        pre_spike_array = np.load(join(base_path, 'submission_data/spike_data', recording_day + '.npy'))

        # Object dataset, restricted to the objects shown during this recording
        object_set = ImageDataset(stimulus_path=join(base_path, config['object_path']))
        object_set.set_subset([name for name in post_data.filenames if name.startswith('Object_')])
        # Keep track of smallest number of objects used across recording days
        min_number_objects = min(min_number_objects, len(object_set))

        # Get train/val/test split for this run
        model_dir = join(base_path, 'submission_data/pretrained_models', recording_day + '_0.1')
        with open(join(model_dir, 'dict.pickle'), 'rb') as handle:
            split = pickle.load(handle)['Used Indices']
        test_idx = split['test']  # these are the test image indices
        test_target_idx = _test_target_rows(test_idx, split['val'], post_data.avatar_rows)

        # Avatar set (held-out test images).
        # NOTE(review): test_target_idx follows sorted index order; confirm that
        # ImageDataset(idx=test_idx) yields the images in that same order.
        avatar_set = ImageDataset(stimulus_path=join(base_path, config['stimulus_path']), idx=test_idx)

        # Build the model and load the pretrained readout weights
        model = build_model(config,
                            join(base_path, 'submission_data/post_spike_data', recording_day) + '_2ndphase.npy')
        model.readout.load_state_dict(torch.load(join(model_dir, 'model')))
        model.eval()
        model.readout.eval()
        model = model.to('cuda')

        # Correlations/stability for every channel of the day; only the region's channels are kept
        day_object_corr = evaluate(model, object_set, spike_matrix=spike_matrix[post_data.object_rows, :])
        day_stability = compute_stability(pre_matrix=pre_spike_array, post_data=post_data,
                                          idx_dict=stimulus_idx_dict)
        day_avatar_corr = evaluate(model, avatar_set, spike_matrix=spike_matrix[test_target_idx, :])

        object_corr.extend(day_object_corr[i] for i in idx)
        avatar_corr.extend(day_avatar_corr[i] for i in idx)
        stability.extend(day_stability[i] for i in idx)
        # keep track of which neuron comes from which day
        idx_list.extend(recording_day + '_unit' + str(i).zfill(4) for i in idx)

        # Responses to bodies and objects, needed below to test whether a channel responds more to
        # bodies than to objects.
        for n in idx:
            object_responses.append(np.concatenate((spike_matrix[post_data.good_object_rows, n],
                                                    spike_matrix[post_data.bad_object_rows, n])))
            body_responses.append(np.concatenate((spike_matrix[post_data.good_body_rows, n],
                                                  spike_matrix[post_data.bad_body_rows, n])))

    # Remove channels whose test/retest stability is not above the threshold
    stable = [s > stability_threshold for s in stability]
    object_corr = _keep(object_corr, stable)
    avatar_corr = _keep(avatar_corr, stable)
    idx_list = _keep(idx_list, stable)
    object_responses = _keep(object_responses, stable)
    body_responses = _keep(body_responses, stable)
    stability = _keep(stability, stable)

    print('Remaining channels after stability:', np.array(idx_list))

    # Remove non body-category selective channels (objects must evoke lower responses than bodies).
    # Exactly one flag is appended per channel so the mask stays aligned with the channel lists;
    # a NaN p-value fails `p <= alpha` and therefore counts as non-selective.
    selective = []
    for o, b in zip(object_responses, body_responses):
        _, p = mannwhitneyu(np.asarray(o, dtype=float), np.asarray(b, dtype=float), alternative='less')
        selective.append(p <= selectivity_alpha)

    object_corr = _as_corr_array(_keep(object_corr, selective))
    avatar_corr = _as_corr_array(_keep(avatar_corr, selective))
    idx_list = np.array(_keep(idx_list, selective))
    stability = np.array(_keep(stability, selective))

    print('Remaining channels after selectivity:', idx_list)

    assert len(object_corr) == len(avatar_corr)

    # Order by correlation on bodies (descending)
    order = list(np.argsort(avatar_corr[:, 0]))[::-1]
    object_corr = object_corr[order]
    avatar_corr = avatar_corr[order]
    stability = stability[order]

    return order, object_corr, avatar_corr, stability, min_number_objects

def main():
    '''
    Run get_stats for both monkeys in the ASB and MSB regions and print a summary per monkey.
    Save one scatter plot per region to plots/corr_<region>.pdf under the current working directory.

    Each plot shows, per recording channel (monkeys side by side), three series: the correlation on
    the held-out body test set, the correlation on objects, and the recording stability. Dashed lines
    mark each monkey's mean correlations and the level returned by compute_conf_interval.
    '''
    base_path = os.getcwd()
    monkeys = ['monkeyG', 'monkeyT']
    marker_size = 2
    plt.rcParams.update({'font.size': 7})
    plt.rcParams['pdf.fonttype'] = 42  # write TrueType (Type 42) fonts into the PDF instead of Type 3
    # savefig does not create missing directories
    os.makedirs(join(base_path, 'plots'), exist_ok=True)

    for region in ['ASB', 'MSB']:
        output = {}
        for monkey in monkeys:
            order, object_corr, avatar_corr, stability, min_number_objects = get_stats(monkey, region)
            output[monkey] = [order, object_corr, avatar_corr, stability, min_number_objects]
            print('--- Results ---')
            print(monkey, region)
            print('Avatar correlation:', np.mean(avatar_corr[:, 0]))
            print('Object correlation:', np.mean(object_corr[:, 0]))
            print('Total number of selective channels:', len(object_corr))
            significant_channels = [channel for channel in object_corr if channel[1] < .05]
            print('Number of channels with significant object correlation:', len(significant_channels))

        # make plot (7 cm x 3 cm)
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7 / 2.54, 3 / 2.54))
        x_val = 0  # x position of the current monkey's first channel
        for i, monkey in enumerate(monkeys):
            _, object_corr, avatar_corr, stability, min_number_objects = output[monkey]
            n_channels = len(object_corr)
            xs = range(x_val, x_val + n_channels)
            ax.scatter(xs, avatar_corr[:, 0], label='Monkey body test set', s=marker_size, c='#32bed7')
            ax.scatter(xs, object_corr[:, 0], label='Objects', s=marker_size, c='#fbad27')
            ax.scatter(xs, stability, label='Recording stability', s=marker_size, c='#9b8476')

            # Correlation confidence interval (presumably the r needed for significance with this
            # many objects -- see compute_conf_interval)
            conf = compute_conf_interval(min_number_objects, 0.95)

            ax.hlines([conf, np.mean(object_corr[:, 0]), np.mean(avatar_corr[:, 0])], x_val, x_val + n_channels,
                      colors=['black', '#fbad27', '#32bed7'], linestyles='dashed', linewidth=1)
            ax.hlines([0], x_val, x_val + n_channels, colors='black', linewidth=1)
            if i == 0:
                ax.legend()
            else:
                # The previous monkey's last channel sits at x_val - 1, so the boundary is x_val - 0.5.
                ax.vlines([x_val - 0.5], 0, 1, linewidth=1, colors='black')
            x_val += n_channels

        ax.set_xlabel('Recording Channel')
        ax.set_ylabel('Correlation')
        ax.set_yticks(ticks=[-0.3, 0, 0.2, 0.4, 0.6, 0.8, 1.0])
        fig.suptitle(str(region))
        fig.savefig(join(base_path, 'plots/corr_' + region + '.pdf'), bbox_inches='tight')
        plt.show()
        plt.close(fig)

# Run the full analysis only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    main()