import numpy as np


def _channel_summaries(x, qs):
    """Return [quantiles, mean, std] of ``x`` over its last axis, each with a trailing feature axis."""
    # np.quantile puts the quantile axis first; move it last so it lines up
    # with the (keepdims) mean and std for concatenation.
    quantiles = np.moveaxis(np.quantile(x, qs, axis=-1), 0, -1)
    mean = np.mean(x, axis=-1, keepdims=True)
    std = np.std(x, axis=-1, keepdims=True)
    return [quantiles, mean, std]


def expert_statistics(data, num_quantiles=5):
    """Compute hand-crafted summary statistics for (a batch of) simulated data sets.

    Parameters
    ----------
    data : np.ndarray
        Array whose last axis holds three channels: channel 0 is summarized as
        response times (RT), channel 1 as accuracy, channel 2 as EEG.
        Statistics are taken over the second-to-last axis (the observations),
        e.g. ``data`` of shape ``(batch_size, num_obs, 3)``. Any number of
        leading batch axes (including none) is supported.
    num_quantiles : int, optional (default: 5)
        Number of equally spaced quantiles in [0.025, 0.975] computed for the
        RT and EEG channels.

    Returns
    -------
    np.ndarray (float32)
        Shape ``(..., 2 * num_quantiles + 5)`` -- ``(batch_size, 15)`` for the
        default ``num_quantiles=5`` -- laid out as
        ``[rt_quantiles, rt_mean, rt_std, eeg_quantiles, eeg_mean, eeg_std, accuracy]``.
    """
    qs = np.linspace(0.025, 0.975, num_quantiles)

    rt_stats = _channel_summaries(data[..., 0], qs)
    eeg_stats = _channel_summaries(data[..., 2], qs)
    accuracy = np.mean(data[..., 1], axis=-1, keepdims=True)

    return np.concatenate(rt_stats + eeg_stats + [accuracy], axis=-1).astype(np.float32)


def configurator(input_dict, mode='direct_hybrid'):
    """Configures automatic and human summary statistics."""

    assert mode in [
        'expert',
        'learner',
        'direct_hybrid',
        'mmd_hybrid',
        'generative_hybrid'
    ]

    out_dict = {}

    # Case expert hand-crafted summary statistics only
    match mode:
        case 'expert':
            out_dict['direct_conditions'] = expert_statistics(input_dict['sim_data'])

        # Case learned summary statistics only
        case 'learner':
            out_dict['summary_conditions'] = input_dict['sim_data'].astype(np.float32)

        # Case learned and hand-crafted together, no optimization
        case 'direct_hybrid':
            out_dict['direct_conditions'] = expert_statistics(input_dict['sim_data'])
            out_dict['summary_conditions'] = input_dict['sim_data'].astype(np.float32)

        # Case learned and hand-crafted together, joint embedding
        case 'mmd_hybrid':
            expert_stats = expert_statistics(input_dict['sim_data'])
            raw_data = input_dict['sim_data'].astype(np.float32)
            out_dict['summary_conditions'] = (raw_data, expert_stats)

        # Case learned summaries are conditioned on expert statistics
        case 'generative_hybrid':
            out_dict['direct_conditions'] = expert_statistics(input_dict['sim_data'])
            out_dict['summary_conditions'] = input_dict['sim_data'].astype(np.float32)

        case _:
            raise NotImplementedError(f'No mode {mode} exists.')

    out_dict['parameters'] = input_dict['prior_draws'].astype(np.float32)
    return out_dict


def configurator_model_comparison(input_dict, base_config, mode='hybrid'):
    """Configures automatic and human summary statistics."""

    assert mode in [
        'expert',
        'learner',
        'hybrid',
    ]

    # Collect model inputs
    out_dict = base_config(input_dict)

    # Case expert hand-crafted summary statistics only
    match mode:
        case 'expert':
            out_dict['direct_conditions'] = expert_statistics(out_dict['summary_conditions'])
            del out_dict['summary_conditions']

        # Case learned summary statistics only - covered by default
        case 'learner':
            pass

        # Case learned and hand-crafted together, no optimization
        case 'hybrid':
            out_dict['direct_conditions'] = expert_statistics(out_dict['summary_conditions'])

        case _:
            raise NotImplementedError(f'No mode {mode} exists.')

    return out_dict
