from os.path import join

import numpy as np
import torch.utils.data
import torchvision
from scipy.stats import pearsonr

from model.Dataset import FullDataset


# Build the train/val/test datasets used for evaluation
def build_evaluation_datasets(config, spike_array_path, idx_from_training):
    '''
    Build the train/val/test datasets used to evaluate a trained model on the body images.

    Unlike dataset construction at training time, no new split is drawn here: the
    train/val/test indices recorded by the training run are reused, so every model is
    evaluated on exactly the split it was trained with.

    Args:
        config: configuration dict; its 'base_path' and 'stimulus_path' entries are joined
            to locate the stimuli.
        spike_array_path: path to the recorded responses, forwarded to FullDataset.
        idx_from_training: index dict written by the training scripts, with the keys
            'train', 'train_target', 'val', 'val_target', 'test' and 'test_target'.
            The val/test assignment is randomized per run, so these indices must be
            kept alongside each model.

    Returns: tuple (train_set, val_set, test_set) of FullDataset instances.
    '''
    stimulus_path = join(config['base_path'], config['stimulus_path'])

    def _split_dataset(idx_key, target_key):
        # Every split shares the stimulus/response sources and differs only in its indices.
        return FullDataset(
            stimulus_path=stimulus_path,
            name_path=None,
            transform=torchvision.transforms.ToTensor(),
            spike_array_path=spike_array_path,
            idx=idx_from_training[idx_key],
            target_idx=idx_from_training[target_key],
        )

    # Build in the same order as before: val, then test, then train.
    val_set = _split_dataset('val', 'val_target')
    test_set = _split_dataset('test', 'test_target')
    train_set = _split_dataset('train', 'train_target')

    return train_set, val_set, test_set


def predict(model, dataset, device='cuda', batch_size=32):
    '''
    Predict model responses on dataset
    Args:
        model: instance of class GaussianReadoutModel (a torch module; it is moved with
            `.to(device)` and put into eval mode)
        dataset: instance of class Dataset yielding (stimulus, target) pairs
        device: device to run inference on; defaults to 'cuda' as before
        batch_size: number of samples per forward pass

    Returns: tuple (y, yhat) of numpy arrays holding the recorded and predicted responses,
        concatenated along the first axis in dataset order (the loader does not shuffle).
    '''
    # Put the whole model (not only the readout) into eval mode. Otherwise any dropout or
    # batch-norm layers outside the readout would stay in training mode: predictions would be
    # stochastic and batch-norm running statistics would be updated with evaluation data.
    model.eval()
    model = model.to(device)

    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    y, yhat = [], []
    # Inference only: disable autograd so no computation graph is built for each batch.
    with torch.no_grad():
        for x, target in data_loader:
            x, target = x.to(device), target.to(device)
            output = model(x)
            y.append(target)
            yhat.append(output)
    y = torch.cat(y).cpu().numpy()
    yhat = torch.cat(yhat).cpu().numpy()

    return y, yhat


def compute_ro(y, yhat):
    '''
    Correlate recorded and predicted responses channel by channel.
    Args:
        y: 2-D array of recorded responses; column n holds channel n
        yhat: 2-D array of predicted responses, indexed like y
    Returns: array of shape (n_neurons,) with the Pearson correlation for each column of y
    '''
    n_channels = y.shape[1]
    return np.array([pearsonr(y[:, ch], yhat[:, ch])[0] for ch in range(n_channels)])

def evaluate(model, config, spike_array_path, idx_from_training):
    '''
    Score the model on the train, val and test splits used during its training.
    Args:
        model: trained model passed through to predict
        config: configuration dict passed to build_evaluation_datasets
        spike_array_path: path to the recorded responses
        idx_from_training: dictionary containing the train/val/test indices from training
    Returns: dict mapping 'Train', 'Val' and 'Test' to {'Correlation': per-channel correlations}
    '''
    splits = build_evaluation_datasets(config, spike_array_path, idx_from_training)

    results = {}
    for split_name, split_set in zip(('Train', 'Val', 'Test'), splits):
        recorded, predicted = predict(model, split_set)
        results[split_name] = {'Correlation': compute_ro(recorded, predicted)}
    return results