"""Provides a direct comparison to the control set.

This version is adapted for k-means: it filters out pairs in which the special retain set (mitigation
run) or the special evaluation set (either run) contains an example from a different k-means cluster
than the forget example.

"""
from typing import List, Optional

from absl import app
from absl import flags

import torch

from npeff_torch.decomps.kmeans import kmeans
from npeff_torch.unlearning.gradient_ascent import cda_results

###############################################################################

FLAGS = flags.FLAGS

# File passed to `kmeans.KmeansClusteringTorch.load_cluster_assignments`.
flags.DEFINE_string(name='kmeans_filepath', default=None, help='')

# Result files (comma-separated) for the mitigation runs and for the control runs.
flags.DEFINE_list(name='mitigation_special_retain_results_filepath', default=None, help='')
flags.DEFINE_list(name='control_special_retain_results_filepath', default=None, help='')

# Pairs where either run's forget-example KL is below this value are discarded.
flags.DEFINE_float(name='min_forget_set_kl', default=None, help='')

# When set, print the forget-example indices of the pairs that survive filtering.
flags.DEFINE_bool(name='print_kept_forget_example_indices', default=False, help='')

###############################################################################


def _load_results(results_filepaths: List[str], device: torch.device):
    """Load every single-example CDA result from `results_filepaths` onto `device`.

    Args:
      results_filepaths: files readable by
        `cda_results.SpecialRetainCdaRunResults.load_results_from_file`.
      device: device each result is moved to via its `.to` method.

    Returns:
      A flat list of the loaded results, ordered by file and then by position in the file.

    Raises:
      AssertionError: if a loaded result forgot more than one example.
    """
    def _checked_results():
        for path in results_filepaths:
            for result in cda_results.SpecialRetainCdaRunResults.load_results_from_file(path):
                # Only single-example forgetting runs are supported by this comparison.
                assert result.forget_set_size == 1
                yield result

    return [result.to(device) for result in _checked_results()]


def _is_result_bad_with_kmeans(cluster_assignments, result, check_special_retain_set: bool) -> bool:
    """Return True if `result` touches examples outside its forget example's k-means cluster.

    Args:
      cluster_assignments: tensor indexed by example index; presumably holds each example's
        cluster id.
      result: a single-example result exposing `forget_set_kls`, `special_retain_set_kls` and
        `special_evaluation_set_kls`, each with an `example_indices` tensor.
      check_special_retain_set: when True, the special retain set must also lie entirely in
        the forget example's cluster; the special evaluation set is always checked.

    Returns:
      A Python bool; True if any checked example has a different cluster than the forget
      example.
    """
    def _clusters_of(example_indices):
        # Results may have been moved to a CUDA device while the assignments stay elsewhere
        # (e.g. on CPU). Indexing a tensor with index tensors on a different (non-CPU) device
        # raises, so bring the indices to the assignments' device first.
        return cluster_assignments[example_indices.to(cluster_assignments.device)]

    forget_example_cluster = _clusters_of(torch.squeeze(result.forget_set_kls.example_indices))

    if check_special_retain_set:
        retain_clusters = _clusters_of(result.special_retain_set_kls.example_indices)
        if torch.any(forget_example_cluster != retain_clusters):
            return True

    evaluation_clusters = _clusters_of(result.special_evaluation_set_kls.example_indices)
    return bool(torch.any(forget_example_cluster != evaluation_clusters))


def _make_pairs(cluster_assignments, mitigation_results, control_results):
    """Pair mitigation and control results by forget example and split them by the filters.

    A pair is discarded when either result's forget-example KL is below
    `FLAGS.min_forget_set_kl`, or when `_is_result_bad_with_kmeans` flags the mitigation result
    (special retain set checked too) or the control result (special retain set not checked).

    Args:
      cluster_assignments: passed through to `_is_result_bad_with_kmeans`.
      mitigation_results: single-example results from the mitigation runs.
      control_results: single-example results from the control runs.

    Returns:
      `(keep_pairs, discard_pairs)`, each a list of `(mitigation_result, control_result)`.

    Raises:
      AssertionError: if the two runs do not cover the same forget examples, or if a forget
        example appears more than once within a run.
    """
    def _forget_example_index(result):
        return int(torch.squeeze(result.forget_set_kls.example_indices).detach().cpu().numpy())

    mitigation_by_index = {_forget_example_index(r): r for r in mitigation_results}
    control_by_index = {_forget_example_index(r): r for r in control_results}

    # Same forget examples in both runs, and no duplicates within either run.
    assert set(mitigation_by_index.keys()) == set(control_by_index.keys())
    assert len(mitigation_by_index) == len(mitigation_results)
    assert len(control_by_index) == len(control_results)

    keep_pairs, discard_pairs = [], []
    for index, control_result in control_by_index.items():
        mitigation_result = mitigation_by_index[index]
        pair = (mitigation_result, control_result)

        mitigation_kl = mitigation_result.forget_set_kls.kls.squeeze()
        control_kl = control_result.forget_set_kls.kls.squeeze()
        min_kl = FLAGS.min_forget_set_kl

        # Short-circuits left to right, so the k-means checks only run for pairs that pass
        # the KL threshold.
        should_discard = (
            mitigation_kl < min_kl
            or control_kl < min_kl
            or _is_result_bad_with_kmeans(cluster_assignments, mitigation_result, check_special_retain_set=True)
            or _is_result_bad_with_kmeans(cluster_assignments, control_result, check_special_retain_set=False)
        )
        (discard_pairs if should_discard else keep_pairs).append(pair)

    return keep_pairs, discard_pairs


def _analyze_pairs(pairs):
    """Print the mean of per-result KL statistics over all pairs, for mitigation and control.

    For every result, three statistics are derived:
      * ``forget_example_kls``: the forget example's KL, squeezed to a scalar;
      * ``special_evaluation_set_kls``: the median KL over the special evaluation set;
      * ``evaluation_set_kls``: the median KL over the evaluation set.
    Each statistic is averaged over the pairs and printed as ``<side>_<statistic>: <mean>``.
    All mitigation lines are printed first, then all control lines.

    Args:
      pairs: iterable of ``(mitigation_result, control_result)`` tuples.

    If there are no pairs (e.g. everything was filtered out), a short message is printed
    instead; `torch.stack` raises on an empty list.
    """
    pairs = list(pairs)
    if not pairs:
        print('No pairs to analyze.')
        return

    statistics = (
        ('forget_example_kls', lambda r: r.forget_set_kls.kls.squeeze()),
        ('special_evaluation_set_kls', lambda r: torch.median(r.special_evaluation_set_kls.kls)),
        ('evaluation_set_kls', lambda r: torch.median(r.evaluation_set_kls.kls)),
    )

    # Compute every mean before printing, so a failure part-way through leaves no partial output.
    lines = []
    for side_index, side_name in enumerate(('mitigation', 'control')):
        for stat_name, compute in statistics:
            per_pair = torch.stack([compute(pair[side_index]) for pair in pairs], dim=0)
            mean = torch.mean(per_pair).detach().cpu().numpy()
            lines.append(f'{side_name}_{stat_name}: {mean}')

    for line in lines:
        print(line)


def _print_kept_forget_example_indices(keep_pairs):
    """Print the set of forget-example indices covered by `keep_pairs`.

    Only the mitigation result of each `(mitigation_result, control_result)` pair is read; the
    pairs are expected to share one forget example, so the control result should carry the
    same index.
    """
    indices = {
        int(torch.squeeze(mitigation.forget_set_kls.example_indices).detach().cpu().numpy())
        for mitigation, _ in keep_pairs
    }
    print('Kept forget example indices:')
    print(indices)


###############################################################################


@torch.no_grad()
def main(_):
    """Compare mitigation results against control results on the pairs that pass filtering.

    Loads the k-means cluster assignments and both sets of results, pairs them by forget
    example, optionally prints the kept forget-example indices, and prints summary KL
    statistics for the kept pairs.

    Raises:
      app.UsageError: if one of the required flags was not supplied.
    """
    # These flags all default to None. Validate them up front so a missing flag produces a
    # usage message, rather than a later failure inside loading or filtering
    # (e.g. `tensor < None`).
    required_flags = (
        'kmeans_filepath',
        'mitigation_special_retain_results_filepath',
        'control_special_retain_results_filepath',
        'min_forget_set_kl',
    )
    missing_flags = [name for name in required_flags if getattr(FLAGS, name) is None]
    if missing_flags:
        raise app.UsageError(
            'Missing required flag(s): ' + ', '.join('--' + name for name in missing_flags))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cluster_assignments = kmeans.KmeansClusteringTorch.load_cluster_assignments(FLAGS.kmeans_filepath)
    if isinstance(cluster_assignments, torch.Tensor):
        # The results are moved to `device` below and their index tensors are used to index
        # the assignments. Indexing a CPU tensor with CUDA indices raises, so keep both on the
        # same device.
        cluster_assignments = cluster_assignments.to(device)

    mitigation_results = _load_results(FLAGS.mitigation_special_retain_results_filepath, device)
    control_results = _load_results(FLAGS.control_special_retain_results_filepath, device)

    keep_pairs, _ = _make_pairs(cluster_assignments, mitigation_results, control_results)

    if FLAGS.print_kept_forget_example_indices:
        _print_kept_forget_example_indices(keep_pairs)

    # NOTE: keep_pairs excludes both low-KL pairs and pairs with a k-means cluster mismatch,
    # so this count reflects both filters, not only the KL threshold.
    print(f'{len(keep_pairs)} / {len(mitigation_results)} results with forget set kl above {FLAGS.min_forget_set_kl}')

    _analyze_pairs(keep_pairs)


if __name__ == '__main__':
    # Hand control to absl, which parses the command-line flags before invoking `main`.
    app.run(main)

