""""""
from typing import List, Optional

from absl import app
from absl import flags

import torch

from npeff_torch.unlearning.gradient_ascent import cda_results

###############################################################################

FLAGS = flags.FLAGS

# Parsed by absl as a comma-separated list; each entry is a results file.
flags.DEFINE_list(
    'special_retain_results_filepath', None,
    'Comma-separated paths to saved SpecialRetainCdaRunResults files. '
    'Every run in each file must forget exactly one example.')

flags.DEFINE_float(
    'min_forget_set_kl', None,
    'Only applies for reporting of aggregate statistics: runs whose forget '
    'example KL is below this value are excluded from the aggregate.')


###############################################################################


def _load_results(device: torch.device):
    """Load every run from the files named by --special_retain_results_filepath.

    Args:
        device: Each loaded result is moved here via its `to(device)` method.

    Returns:
        A flat list of results, ordered by file and then by position in file.

    Raises:
        ValueError: If the flag is unset, or if any run's forget set does not
            contain exactly one example.
    """
    filepaths = FLAGS.special_retain_results_filepath
    if filepaths is None:
        raise ValueError('--special_retain_results_filepath must be set.')

    ret = []
    for filepath in filepaths:
        results = cda_results.SpecialRetainCdaRunResults.load_results_from_file(filepath)
        for r in results:
            # This script only summarizes single-example forgetting. Use an
            # explicit check rather than `assert`, which `python -O` strips.
            if r.forget_set_size != 1:
                raise ValueError(
                    f'Expected single-example forgetting, but a run in {filepath} '
                    f'has forget_set_size={r.forget_set_size}.')
            ret.append(r.to(device))
    return ret


def _print_summary(label: str, x: torch.Tensor):
    median = float(torch.median(x).detach().cpu().numpy())
    arith_mean = float(torch.mean(x).detach().cpu().numpy())
    print(f'{label}: {arith_mean} [mean], {median} [median]')


def _print_result_summary(result: 'cda_results.SpecialRetainCdaRunResults'):
    """Print KL statistics for a single run, followed by a blank line.

    Prints the forget example's KL, then the mean and median of the
    special-retain, special-evaluation and evaluation KL tensors.
    """
    # Runs forget exactly one example, so `forget_set_kls.kls` holds a single
    # value; `.item()` extracts it from a one-element tensor of any shape.
    print(f'forget_example_kl: {result.forget_set_kls.kls.item()}')
    _print_summary('special_retain_set_kls', result.special_retain_set_kls.kls)
    _print_summary('special_evaluation_set_kls', result.special_evaluation_set_kls.kls)
    _print_summary('evaluation_set_kls', result.evaluation_set_kls.kls)
    print()


def _print_aggregate_summary(results: List['cda_results.SpecialRetainCdaRunResults'], min_forget_set_kl: Optional[float]):
    forget_example_kls = []

    special_evaluation_set_kls_means = []
    special_evaluation_set_kls_medians = []

    evaluation_set_kls_means = []
    evaluation_set_kls_medians = []

    for r in results:
        forget_example_kl = r.forget_set_kls.kls.squeeze()
        if min_forget_set_kl is not None and forget_example_kl < min_forget_set_kl:
            continue

        forget_example_kls.append(forget_example_kl)

        special_evaluation_set_kls_means.append(torch.mean(r.special_evaluation_set_kls.kls))
        special_evaluation_set_kls_medians.append(torch.median(r.special_evaluation_set_kls.kls))

        evaluation_set_kls_means.append(torch.mean(r.evaluation_set_kls.kls))
        evaluation_set_kls_medians.append(torch.median(r.evaluation_set_kls.kls))

    forget_example_kls = torch.stack(forget_example_kls, dim=0)
    special_evaluation_set_kls_means = torch.stack(special_evaluation_set_kls_means, dim=0)
    special_evaluation_set_kls_medians = torch.stack(special_evaluation_set_kls_medians, dim=0)
    evaluation_set_kls_means = torch.stack(evaluation_set_kls_means, dim=0)
    evaluation_set_kls_medians = torch.stack(evaluation_set_kls_medians, dim=0)

    if min_forget_set_kl is not None:
        print(f'{forget_example_kls.numel()} / {len(results)} results with forget set kl above {min_forget_set_kl}')

    _print_summary('forget_example_kls', forget_example_kls)
    _print_summary('special_evaluation_set_kls_means', special_evaluation_set_kls_means)
    _print_summary('special_evaluation_set_kls_medians', special_evaluation_set_kls_medians)
    _print_summary('evaluation_set_kls_means', evaluation_set_kls_means)
    _print_summary('evaluation_set_kls_medians', evaluation_set_kls_medians)
    print()


###############################################################################


@torch.no_grad()
def main(_):
    """Entry point: load results, print per-run summaries, then the aggregate."""
    # Use the GPU when one is visible, otherwise fall back to the CPU.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    all_results = _load_results(torch.device(device_name))

    for run_result in all_results:
        _print_result_summary(run_result)

    # TODO: Need to count the number of ones where the unlearning worked.
    print()
    _print_aggregate_summary(all_results, FLAGS.min_forget_set_kl)


if __name__ == "__main__":
    # absl parses the command-line flags before invoking `main`.
    app.run(main)
