import numpy as np


def summarize_expert(x):
    """Custom summary statistics from Papamakarios et al., (2016).

    Computes nine hand-crafted statistics for every bivariate time series in the
    batch. The input array is not modified.

    Parameters
    ----------
    x : np.ndarray of shape (batch_size, n_points, 2)
        Batch of time series with the two components on the last axis.

    Returns
    -------
    np.ndarray of shape (batch_size, 9), dtype float32
        Columns, in order:
        - the mean of each component (2);
        - ``log1p`` of the variance of each component (2);
        - the auto-correlation of each component at lag ``int(0.2 * n_points)`` (2);
        - the auto-correlation of each component at lag ``int(0.4 * n_points)`` (2);
        - the zero-lag (Pearson) cross-correlation between the two components (1).
    """
    x = np.asarray(x)
    n_points = x.shape[1]
    lag1 = int(0.2 * n_points)
    lag2 = int(0.4 * n_points)

    # Means of the two time series
    x_means = np.mean(x, axis=1)

    # Log-variances of the two time series (log1p keeps zero variance finite)
    x_logvars = np.log1p(np.var(x, axis=1))

    # Auto-correlations at lags of 20% and 40% of the series length (in steps)
    x_auto11_1, x_auto12_2 = _lagged_autocorrelations(x, lag1)
    x_auto21_1, x_auto22_2 = _lagged_autocorrelations(x, lag2)

    # Zero-lag cross-correlation: mean product of the standardized components.
    # Computed on a new array so that the caller's data is never overwritten.
    standardized = (x - x_means[:, np.newaxis, :]) / np.std(x, axis=1, keepdims=True)
    x_cross = np.mean(standardized[:, :, 0] * standardized[:, :, 1], axis=1, keepdims=True)

    stats = np.c_[x_means, x_logvars, x_auto11_1, x_auto12_2, x_auto21_1, x_auto22_2, x_cross]
    return stats.astype(np.float32)


def _lagged_autocorrelations(x, lag):
    """Return the lag-``lag`` auto-correlation of each of the two components.

    Parameters
    ----------
    x : np.ndarray of shape (batch_size, n_points, 2)
        Batch of bivariate time series.
    lag : int
        Non-negative shift in time steps. ``lag == 0`` correlates each series
        with itself (≈1.0 for non-constant series).

    Returns
    -------
    tuple of two np.ndarray of shape (batch_size,)
        Auto-correlations of component 0 and of component 1.
    """
    n_points = x.shape[1]
    # Explicit end index: ``x[:, :-lag]`` would be empty when lag == 0.
    leading = x[:, :n_points - lag]
    trailing = x[:, lag:]
    # Each 4x4 matrix is over the columns [comp0(t), comp1(t), comp0(t+lag), comp1(t+lag)],
    # so entries (0, 2) and (1, 3) are the per-component auto-correlations.
    corr = np.array([np.corrcoef(np.c_[a, b], rowvar=False) for a, b in zip(leading, trailing)])
    return corr[:, 0, 2], corr[:, 1, 3]


def match_type(mode, out_dict, raw_data, summary_conditions, direct_conditions=None):
    """Caution: works in-place!"""

    match mode:

        # Case learned summary statistics only
        case 'learner':
            out_dict['summary_conditions'] = summary_conditions

        # Case learned and hand-crafted together, no optimization
        case 'direct_hybrid':
            if direct_conditions is None:
                out_dict['direct_conditions'] = summarize_expert(raw_data)
            else:
                out_dict['direct_conditions'] = direct_conditions
            out_dict['summary_conditions'] = summary_conditions

        # Case learned and hand-crafted together, joint embedding
        case 'mmd_hybrid':
            if direct_conditions is None:
                expert_stats = summarize_expert(raw_data)
            else:
                expert_stats = direct_conditions
            out_dict['summary_conditions'] = (summary_conditions, expert_stats)

        # Case learned summaries are conditioned on expert statistics
        case 'generative_hybrid':
            if direct_conditions is None:
                out_dict['direct_conditions'] = summarize_expert(raw_data)
            else:
                out_dict['direct_conditions'] = direct_conditions
            out_dict['summary_conditions'] = summary_conditions

        case _:
            raise NotImplementedError(f'No mode {mode} exists.')


def configurator(input_dict, mode='hybrid', summary_type='transformer', T=20, scale=1000, ai_expert=False):
    """Configures automatic and expert summary statistics.

    Builds the network input dictionary for a batch of simulations.

    Parameters
    ----------
    input_dict : dict
        Must contain ``'sim_data'`` (array whose first axis is the batch) and
        ``'prior_draws'`` (array). Must also contain ``'summaries'`` when
        ``ai_expert`` is True.
    mode : str
        One of ``'expert'``, ``'learner'``, ``'direct_hybrid'``, ``'mmd_hybrid'`` or
        ``'generative_hybrid'``. NOTE(review): the default ``'hybrid'`` is not an
        accepted mode, so ``mode`` must always be passed explicitly — confirm the
        intended default.
    summary_type : str
        If ``'transformer'``, a time encoding ``linspace(0, T, T) / T`` is appended as
        an extra channel on the last axis of the summary-network input.
    T : int
        Length of the time encoding. With ``summary_type='transformer'`` it must equal
        ``sim_data.shape[1]``, otherwise the concatenation fails.
    scale : float
        Divisor applied to ``sim_data`` before any statistics are computed.
    ai_expert : bool
        If True, ``input_dict['summaries']`` is used as the expert statistics instead
        of computing them with ``summarize_expert``.

    Returns
    -------
    dict
        ``'summary_conditions'`` and/or ``'direct_conditions'`` depending on ``mode``,
        plus ``'parameters'`` (``prior_draws`` cast to float32).

    Raises
    ------
    ValueError
        If ``mode`` is not one of the accepted modes.
    """
    valid_modes = ('expert', 'learner', 'direct_hybrid', 'mmd_hybrid', 'generative_hybrid')
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if mode not in valid_modes:
        raise ValueError(f'Unknown mode {mode!r}; expected one of {valid_modes}.')

    out_dict = {}
    time_series = input_dict['sim_data'] / scale
    batch_size = time_series.shape[0]

    # Case expert: only hand-crafted (or externally provided) statistics
    if mode == 'expert':
        if ai_expert:
            out_dict['direct_conditions'] = input_dict['summaries']
        else:
            out_dict['direct_conditions'] = summarize_expert(time_series).astype(np.float32)

    # Otherwise learned summary statistics are needed as well
    else:
        if summary_type == 'transformer':
            # Positional information: normalized time index, one channel, shared across the batch
            time_encodings = np.linspace(0, T, T, dtype=np.float32) / T
            time_encodings = np.stack([time_encodings] * batch_size, axis=0)[..., None]
            summary_conditions = np.concatenate([time_series, time_encodings], axis=-1)
            summary_conditions = summary_conditions.astype(np.float32)
        else:
            summary_conditions = time_series.astype(np.float32)

        # None tells match_type to compute expert statistics from time_series itself
        direct_conditions = input_dict['summaries'] if ai_expert else None
        match_type(mode, out_dict, time_series, summary_conditions, direct_conditions=direct_conditions)

    # Add in parameters and return
    out_dict['parameters'] = input_dict['prior_draws'].astype(np.float32)
    return out_dict
