"""Given saved JSONs of perturbation results, creates a summary."""
import json
import os
from typing import Optional

from absl import app
from absl import flags

import torch

from npeff_torch.perturbations import perturbation_results

###############################################################################

FLAGS = flags.FLAGS

# Paths of the perturbation-result JSON files read by _load_results().
# Leaving this empty skips the perturbation summary.
flags.DEFINE_list(
    name='perturbation_result_filepaths',
    default=[],
    help=(
        'Should have been generated by the run_lrm_npeff_perturbations.py script. '
        'If multiple are provided, then these are concatenated.'
    ),
)

# Paths of the Fisher-norm-ratio JSON files read by _load_fisher_norm_ratios().
# Leaving this empty skips the Fisher norm ratio summary.
flags.DEFINE_list(
    name='fisher_norm_ratio_filepaths',
    default=[],
    help=(
        'Should have been generated by the compute_fisher_norm_ratios.py script. '
        'If multiple are provided, then these are concatenated.'
    ),
)

###############################################################################


def _load_results() -> Optional['perturbation_results.ExperimentPerturbationResults']:
    """Load and merge the perturbation results named by the flag.

    Every path in --perturbation_result_filepaths is user-expanded, read as
    JSON and parsed with ExperimentPerturbationResults.from_json. The
    component results of all files are concatenated in flag order.

    Returns:
        A single ExperimentPerturbationResults holding the concatenated
        component results, or None when the flag is empty.
    """
    filepaths = FLAGS.perturbation_result_filepaths
    if not filepaths:
        return None

    results_cls = perturbation_results.ExperimentPerturbationResults

    merged = []
    for path in map(os.path.expanduser, filepaths):
        with open(path, 'rt') as fp:
            payload = json.load(fp)
        merged.extend(results_cls.from_json(payload).component_results)

    return results_cls(component_results=merged)


def _load_fisher_norm_ratios() -> Optional[torch.Tensor]:
    """Load per-item Fisher norm ratios from the files named by the flag.

    Each file in --fisher_norm_ratio_filepaths is read as JSON and iterated;
    every item contributes
    item['top_examples_mean_norm'] / item['baseline_examples_mean_norm'].
    Ratios from all files are concatenated in flag order. There is no guard
    against a zero baseline norm.

    Returns:
        A 1-D float32 tensor of the ratios, or None when the flag is empty.
    """
    filepaths = FLAGS.fisher_norm_ratio_filepaths
    if not filepaths:
        return None

    ratios = []
    for path in filepaths:
        with open(os.path.expanduser(path), 'rt') as fp:
            entries = json.load(fp)
        ratios.extend(
            entry['top_examples_mean_norm'] / entry['baseline_examples_mean_norm']
            for entry in entries
        )

    return torch.tensor(ratios, dtype=torch.float32)


def _geo_mean(x: torch.Tensor) -> torch.Tensor:
    """Return the geometric mean of all elements of x as a 0-d tensor.

    Computed in log space: exp(sum(log(x)) / numel(x)).
    """
    log_total = x.log().sum()
    return (log_total / x.numel()).exp()


def _print_summary(label: str, x: torch.Tensor):
    """Print the median, arithmetic mean and geometric mean of x on one line.

    Args:
        label: Text printed in front of the statistics.
        x: Tensor whose elements are summarised.
    """
    reducers = (torch.median, torch.mean, _geo_mean)
    # The generator is consumed in order, so each statistic is computed and
    # converted to a Python float before the next one is evaluated.
    median, arith_mean, geo_mean = (
        float(reduce(x).detach().cpu().numpy()) for reduce in reducers
    )
    print(f'{label}: {median} [median], {arith_mean} [arith_mean], {geo_mean} [geo_mean]')


###############################################################################


def main(_):
    """Print summaries for whichever inputs were supplied via flags.

    Both inputs are loaded up front. The perturbation summary (component
    count plus three statistics lines) is printed first when results were
    given, followed by the Fisher norm ratio summary when ratios were given.
    """
    results = _load_results()
    fisher_norm_ratios = _load_fisher_norm_ratios()

    if results is not None:
        # Fetch every tensor before printing anything, so a failure in any
        # getter happens before the first line of output.
        labelled = [
            ('kl_ratios', results.get_greatest_kl_ratios()),
            ('top_examples_kls', results.get_greatest_kl_ratio_top_examples_kls()),
            ('baseline_examples_kls', results.get_greatest_kl_ratio_baseline_examples_kls()),
        ]

        print('n_components:', results.get_n_components())
        for label, values in labelled:
            _print_summary(label, values)

    if fisher_norm_ratios is not None:
        print('n_components:', fisher_norm_ratios.numel())
        _print_summary('fisher_norm_ratios', fisher_norm_ratios)


# Script entry point: only when executed directly (not imported), hand `main`
# to absl's app runner.
if __name__ == "__main__":
    app.run(main)
