"""

Adapted for k-means: allows filtering out examples where the special retain and control sets
did not belong to the same cluster as the forgotten example.

"""
from typing import List, Optional

from absl import app
from absl import flags

import torch

from npeff_torch.decomps.kmeans import kmeans
from npeff_torch.unlearning.gradient_ascent import cda_results

###############################################################################

FLAGS = flags.FLAGS

# Path passed to KmeansClusteringTorch.load_cluster_assignments in main().
flags.DEFINE_string('kmeans_filepath', None, '')

# List of result files (absl list flag, comma-separated on the command line);
# each one is loaded with SpecialRetainCdaRunResults.load_results_from_file.
flags.DEFINE_list('special_retain_results_filepath', None, '')

# Optional threshold: results whose forget-example KL is below this value are
# left out of the aggregate statistics. None disables the threshold.
flags.DEFINE_float('min_forget_set_kl', None, 'Only applies for reporting of aggregate statistics.')

# When True, a result is also discarded if any special retain example lies in a
# different cluster than the forget example. The special evaluation set is
# checked regardless of this flag.
flags.DEFINE_bool('check_special_retain_set', True, '')


###############################################################################

def _load_results(device: torch.device):
    """Load the results from every file listed in --special_retain_results_filepath.

    Args:
        device: Device each loaded result is moved to through its ``to`` method.

    Returns:
        A single flat list holding the results of all files, in flag order.

    Raises:
        ValueError: If a loaded result does not have a forget set of size 1.
    """
    ret = []
    for filepath in FLAGS.special_retain_results_filepath:
        results = cda_results.SpecialRetainCdaRunResults.load_results_from_file(filepath)
        for r in results:
            # The rest of the script treats the forget set as a single example.
            # Raise explicitly rather than assert so the check is not stripped
            # when Python runs with -O.
            if r.forget_set_size != 1:
                raise ValueError(
                    f'Expected single-example forgetting, but a result in {filepath!r} '
                    f'has forget_set_size={r.forget_set_size}.'
                )
            ret.append(r.to(device))
    return ret


def _print_summary(label: str, x: torch.Tensor):
    """Print the arithmetic mean and the median of ``x`` on one line.

    Args:
        label: Text printed in front of the statistics.
        x: Tensor to summarize; reduced over all of its elements.
    """
    # Tensor.item() returns the Python scalar directly, on any device and with
    # or without autograd history. It avoids the NumPy round trip, which is
    # unavailable for dtypes NumPy has no equivalent for (e.g. bfloat16).
    median = float(torch.median(x).item())
    arith_mean = float(torch.mean(x).item())
    print(f'{label}: {arith_mean} [mean], {median} [median]')


# def _print_result_summary(result: 'cda_results.SpecialRetainCdaRunResults'):
#     print(f'forget_example_kl: {float(result.forget_set_kls.kls.squeeze().detach().cpu().numpy())}')
#     _print_summary('special_retain_set_kls', result.special_retain_set_kls.kls)
#     _print_summary('special_evaluation_set_kls', result.special_evaluation_set_kls.kls)
#     _print_summary('evaluation_set_kls', result.evaluation_set_kls.kls)
#     print()


def _is_result_bad_with_kmeans(cluster_assignments, result):
    """Tell whether ``result`` mixes clusters and should therefore be skipped.

    A result is "bad" when any special evaluation example is assigned to a
    different cluster than the forget example. The same test is applied to the
    special retain examples, but only if --check_special_retain_set is set.

    Args:
        cluster_assignments: Tensor indexed by example index, giving each
            example's cluster.
        result: Run result exposing ``forget_set_kls``, ``special_retain_set_kls``
            and ``special_evaluation_set_kls``, each with ``example_indices``.

    Returns:
        True if the result should be discarded, False otherwise.
    """
    forget_index = torch.squeeze(result.forget_set_kls.example_indices)
    target_cluster = cluster_assignments[forget_index]

    if FLAGS.check_special_retain_set:
        retain_clusters = cluster_assignments[result.special_retain_set_kls.example_indices]
        if torch.any(target_cluster != retain_clusters):
            return True

    evaluation_clusters = cluster_assignments[result.special_evaluation_set_kls.example_indices]
    return bool(torch.any(target_cluster != evaluation_clusters))


def _print_aggregate_summary(cluster_assignments: torch.Tensor, results: List['cda_results.SpecialRetainCdaRunResults'], min_forget_set_kl: Optional[float]):
    """Print mean/median statistics aggregated over the results that pass the filters.

    A result is skipped when its forget-example KL is below ``min_forget_set_kl``
    (if a threshold is given) or when ``_is_result_bad_with_kmeans`` rejects it.
    For each remaining result, the mean and median of its special evaluation set
    KLs and of its evaluation set KLs are collected; the collected values are
    then summarized in turn, together with the forget-example KLs.

    Args:
        cluster_assignments: Cluster assignment per example, forwarded to
            ``_is_result_bad_with_kmeans``.
        results: Run results to aggregate.
        min_forget_set_kl: Minimum forget-example KL a result needs in order to
            be included, or None to disable this filter.
    """
    forget_example_kls = []

    special_evaluation_set_kls_means = []
    special_evaluation_set_kls_medians = []

    evaluation_set_kls_means = []
    evaluation_set_kls_medians = []

    for r in results:
        forget_example_kl = r.forget_set_kls.kls.squeeze()
        if min_forget_set_kl is not None and forget_example_kl < min_forget_set_kl:
            continue
        if _is_result_bad_with_kmeans(cluster_assignments, r):
            continue

        forget_example_kls.append(forget_example_kl)

        special_evaluation_set_kls_means.append(torch.mean(r.special_evaluation_set_kls.kls))
        special_evaluation_set_kls_medians.append(torch.median(r.special_evaluation_set_kls.kls))

        evaluation_set_kls_means.append(torch.mean(r.evaluation_set_kls.kls))
        evaluation_set_kls_medians.append(torch.median(r.evaluation_set_kls.kls))

    if not forget_example_kls:
        # torch.stack raises on an empty list, so report the empty case
        # explicitly instead of crashing when every result was filtered out.
        if min_forget_set_kl is not None:
            print(f'0 / {len(results)} results with forget set kl above {min_forget_set_kl}')
        print('No results left after filtering; skipping aggregate statistics.')
        print()
        return

    forget_example_kls = torch.stack(forget_example_kls, dim=0)
    special_evaluation_set_kls_means = torch.stack(special_evaluation_set_kls_means, dim=0)
    special_evaluation_set_kls_medians = torch.stack(special_evaluation_set_kls_medians, dim=0)
    evaluation_set_kls_means = torch.stack(evaluation_set_kls_means, dim=0)
    evaluation_set_kls_medians = torch.stack(evaluation_set_kls_medians, dim=0)

    if min_forget_set_kl is not None:
        # NOTE: this count reflects both filters (KL threshold and the k-means
        # cluster check), not only the KL threshold named in the message.
        print(f'{forget_example_kls.numel()} / {len(results)} results with forget set kl above {min_forget_set_kl}')

    _print_summary('forget_example_kls', forget_example_kls)
    _print_summary('special_evaluation_set_kls_means', special_evaluation_set_kls_means)
    _print_summary('special_evaluation_set_kls_medians', special_evaluation_set_kls_medians)
    _print_summary('evaluation_set_kls_means', evaluation_set_kls_means)
    _print_summary('evaluation_set_kls_medians', evaluation_set_kls_medians)
    print()


###############################################################################


@torch.no_grad()
def main(_):
    """Load the cluster assignments and the run results, then print the aggregate summary.

    Runs with gradient tracking disabled. Uses CUDA when it is available and
    falls back to the CPU otherwise.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    assignments = kmeans.KmeansClusteringTorch.load_cluster_assignments(FLAGS.kmeans_filepath)
    assignments = assignments.to(device)

    loaded_results = _load_results(device)

    # Per-result reporting is currently disabled:
    # for result in results:
    #     _print_result_summary(result)

    # TODO: Need to count the number of ones where the unlearning worked.
    print()
    _print_aggregate_summary(assignments, loaded_results, FLAGS.min_forget_set_kl)


# Script entry point: hand main to absl's app.run when executed directly.
if __name__ == "__main__":
    app.run(main)
