from collections import OrderedDict
from typing import Sequence

import arviz
import numpy as np
import pandas as pd


# Column order of the summary table / returned DataFrame.
_SUMMARY_METRICS = ('mean', 'std', 'median', 'ess', 'rhat')


def _split_into_halves(draws: np.ndarray) -> np.ndarray:
    """Stack the first and second halves of one chain as two pseudo-chains.

    Returns an array of shape ``(2, len(draws) // 2)``. When the number of
    draws is odd the middle draw is dropped so both halves have equal length
    (``np.split(draws, 2)`` would raise ``ValueError`` instead). For an even
    number of draws the result equals ``np.vstack(np.split(draws, 2))``.
    """
    half = len(draws) // 2
    # Use an explicit start index: ``draws[-half:]`` would return the whole
    # array when ``half == 0``.
    return np.vstack([draws[:half], draws[len(draws) - half:]])


def _print_summary(summary_dict: dict) -> None:
    """Print ``summary_dict`` (variable name -> metric name -> value) as a table."""
    # Name column is at least 10 characters wide, wider for long names.
    name_width = max([len(str(name)) for name in summary_dict] + [10])
    name_format = '{:>' + str(name_width) + '}'
    header_format = name_format + ' {:>9}' * len(_SUMMARY_METRICS)
    row_format = name_format + ' {:>9.3f}' * len(_SUMMARY_METRICS)
    print()
    print(header_format.format('', *_SUMMARY_METRICS))
    for name, stats in summary_dict.items():
        print(row_format.format(str(name), *(stats[m] for m in _SUMMARY_METRICS)))
    print()


def summarize(samples: np.ndarray, names: Sequence, should_print: bool = True) -> pd.DataFrame:
    """Summarization code inspired by the NumPyro library.

    Copyright Contributors to the Pyro project. SPDX-License-Identifier:
    Apache-2.0.

    Each column of ``samples`` is treated as the draws of one variable from a
    single Markov chain. That chain is split into two halves, which are passed
    to ``arviz.ess`` and ``arviz.rhat`` as two chains. If the number of draws
    is odd, the middle draw is left out of the ESS / R-hat computation. The
    mean, std and median use every draw.

    Args:
        samples: Array of samples generated by the Markov chain; the last axis
            indexes the variables (one column per entry of ``names``).
        names: Names of the variables sampled by the Markov chain. If a name
            is repeated, the later column overwrites the earlier one.
        should_print: Indicator variable for whether or not to print a table of
            the summary statistics.

    Returns:
        metrics: A dataframe with one row per variable and the columns
        ``mean``, ``std``, ``median``, ``ess`` and ``rhat``.

    Raises:
        ValueError: If the last axis of ``samples`` does not match
            ``len(names)``.

    """
    if samples.shape[-1] != len(names):
        raise ValueError(
            f'samples has {samples.shape[-1]} columns but {len(names)} names were given')
    columns = dict(zip(names, samples.T))
    summary_dict = {}
    for name, draws in columns.items():
        split = _split_into_halves(draws)
        summary_dict[name] = OrderedDict([
            ('mean', draws.mean()),
            ('std', draws.std()),
            ('median', np.median(draws)),
            ('ess', arviz.ess(split)),
            ('rhat', arviz.rhat(split)),
        ])

    if should_print:
        _print_summary(summary_dict)

    metrics = pd.DataFrame(summary_dict).T
    return metrics
