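"""Compares mitigation results directly against control results.

Special-retain results from a mitigation run and a control run are paired by
forget-example index, optionally filtered, and the mean KL statistics of each
side are printed.
"""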

from typing import List, Optional

from absl import app
from absl import flags

import torch

from npeff_torch.unlearning.gradient_ascent import cda_results

###############################################################################

FLAGS = flags.FLAGS

# Each *_filepath flag takes a comma-separated list of files; the results from
# all files in one list are concatenated before pairing.
flags.DEFINE_list(
    'mitigation_special_retain_results_filepath', None,
    'Comma-separated paths to SpecialRetainCdaRunResults files for the mitigation runs.')
flags.DEFINE_list(
    'control_special_retain_results_filepath', None,
    'Comma-separated paths to SpecialRetainCdaRunResults files for the control runs.')

flags.DEFINE_float(
    'min_forget_set_kl', None,
    'Discard a mitigation/control pair if either run has a forget-example KL below this value.')

flags.DEFINE_bool(
    'print_kept_forget_example_indices', False,
    'If true, print the forget-example indices of the pairs that were kept.')

flags.DEFINE_list(
    'forget_example_indices', [],
    'If non-empty, then only keep results for these forget example indices.')


###############################################################################

def _get_forget_example_indices():
    """Return `--forget_example_indices` parsed as ints, or None if the flag is empty.

    A None return means "do not filter by forget-example index".
    """
    raw_indices = FLAGS.forget_example_indices
    if raw_indices:
        return list(map(int, raw_indices))
    return None


def _load_results(results_filepaths: List[str], device: torch.device):
    """Load single-example special-retain results from files and move them to a device.

    Args:
        results_filepaths: Paths passed one by one to
            `cda_results.SpecialRetainCdaRunResults.load_results_from_file`.
            The results from all files are concatenated in order.
        device: Device each loaded result is moved to via its `.to(device)`.

    Returns:
        A flat list of the loaded results, each moved to `device`.

    Raises:
        ValueError: If `results_filepaths` is None, or if any loaded result
            has a `forget_set_size` other than 1.
    """
    if results_filepaths is None:
        raise ValueError('results_filepaths must be provided, got None.')

    ret = []
    for filepath in results_filepaths:
        results = cda_results.SpecialRetainCdaRunResults.load_results_from_file(filepath)
        for r in results:
            # Downstream pairing keys each result by exactly one forget-example
            # index, so only single-example forgetting is supported. This is an
            # explicit check because `assert` is stripped under `python -O`.
            if r.forget_set_size != 1:
                raise ValueError(
                    f'Expected single-example forgetting in {filepath}, '
                    f'got forget_set_size={r.forget_set_size}.')
            ret.append(r.to(device))
    return ret


def _forget_example_index(result) -> int:
    """Return the single forget-example index stored on `result.forget_set_kls`."""
    return int(torch.squeeze(result.forget_set_kls.example_indices).detach().cpu().numpy())


def _make_pairs(mitigation_results, control_results, forget_example_indices):
    """Pair mitigation and control results by forget-example index and filter them.

    A pair is discarded if `forget_example_indices` is not None and the pair's
    index is not in it. A pair is also discarded if `FLAGS.min_forget_set_kl`
    is set and either run's forget-example KL is below it. When
    `FLAGS.min_forget_set_kl` is None, no KL threshold is applied.

    Args:
        mitigation_results: Results for the mitigation run, one per forget example.
        control_results: Results for the control run, one per forget example.
        forget_example_indices: Optional collection of ints to restrict to, or None.

    Returns:
        `(keep_pairs, discard_pairs)`: two lists of
        `(mitigation_result, control_result)` tuples, ordered like `control_results`.

    Raises:
        ValueError: If either input has duplicate forget-example indices, or the
            two inputs do not cover the same set of indices.
    """
    mitigation_by_index = {_forget_example_index(r): r for r in mitigation_results}
    control_by_index = {_forget_example_index(r): r for r in control_results}

    # Explicit checks instead of `assert`, which `python -O` strips.
    if len(mitigation_by_index) != len(mitigation_results):
        raise ValueError('Duplicate forget example indices in mitigation results.')
    if len(control_by_index) != len(control_results):
        raise ValueError('Duplicate forget example indices in control results.')
    if mitigation_by_index.keys() != control_by_index.keys():
        mismatched = sorted(mitigation_by_index.keys() ^ control_by_index.keys())
        raise ValueError(
            f'Mitigation and control results cover different forget example indices: {mismatched}')

    min_kl = FLAGS.min_forget_set_kl
    # Set lookup keeps the filter O(1) per example instead of O(len(indices)).
    allowed_indices = None if forget_example_indices is None else set(forget_example_indices)

    keep_pairs = []
    discard_pairs = []
    for example_index, control_result in control_by_index.items():
        mitigation_result = mitigation_by_index[example_index]
        pair = (mitigation_result, control_result)

        if allowed_indices is not None and example_index not in allowed_indices:
            discard_pairs.append(pair)
            continue

        if min_kl is not None:
            mitigation_kl = mitigation_result.forget_set_kls.kls.squeeze()
            control_kl = control_result.forget_set_kls.kls.squeeze()
            if mitigation_kl < min_kl or control_kl < min_kl:
                discard_pairs.append(pair)
                continue

        keep_pairs.append(pair)

    return keep_pairs, discard_pairs


def _summarize_result(result):
    """Return the per-run statistics that `_analyze_pairs` averages over pairs.

    Returns a 3-tuple of tensors:
    - the forget-example KL (squeezed);
    - the median KL over the special evaluation set;
    - the median KL over the evaluation set.
    """
    return (
        result.forget_set_kls.kls.squeeze(),
        torch.median(result.special_evaluation_set_kls.kls),
        torch.median(result.evaluation_set_kls.kls),
    )


# Metric labels, in the same order as the tuple returned by `_summarize_result`.
_METRIC_NAMES = ('forget_example_kls', 'special_evaluation_set_kls', 'evaluation_set_kls')


def _analyze_pairs(pairs):
    """Print the mean over pairs of each per-run statistic, mitigation first.

    For the mitigation run and then the control run, prints one line
    `<run>_<metric>: <mean>` for each statistic from `_summarize_result`.
    If `pairs` is empty, prints a notice instead.

    Args:
        pairs: Iterable of `(mitigation_result, control_result)` tuples.
    """
    pairs = list(pairs)
    if not pairs:
        # torch.stack rejects an empty list; report instead of crashing when
        # every pair was filtered out.
        print('No pairs to analyze.')
        return

    for run_name, results in (
        ('mitigation', [mitigation for mitigation, _ in pairs]),
        ('control', [control for _, control in pairs]),
    ):
        # Transpose the per-result (forget, special, eval) tuples into per-metric columns.
        columns = zip(*(_summarize_result(r) for r in results))
        for metric_name, values in zip(_METRIC_NAMES, columns):
            stacked = torch.stack(values, dim=0)
            print(f'{run_name}_{metric_name}: {torch.mean(stacked).detach().cpu().numpy()}')


def _print_kept_forget_example_indices(keep_pairs):
    """Print a header line followed by the set of forget-example indices kept.

    The index is read from the mitigation side of each pair. `_make_pairs`
    matches pairs on this index, so the control side carries the same one.
    """
    kept_indices = {
        int(torch.squeeze(mitigation.forget_set_kls.example_indices).detach().cpu().numpy())
        for mitigation, _ in keep_pairs
    }
    print('Kept forget example indices:')
    print(kept_indices)


###############################################################################


@torch.no_grad()
def main(_):
    """Compare mitigation results against control results and print a summary.

    Loads both result sets, pairs them by forget-example index, filters the
    pairs according to the flags, and prints statistics for the kept pairs.

    Raises:
        app.UsageError: If either results-filepath flag is unset or empty.
    """
    # Fail fast with a usage message instead of failing later inside
    # `_load_results` when a required path flag was not passed.
    for flag_name in ('mitigation_special_retain_results_filepath',
                      'control_special_retain_results_filepath'):
        if not getattr(FLAGS, flag_name):
            raise app.UsageError(f'--{flag_name} must be set.')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    forget_example_indices = _get_forget_example_indices()

    mitigation_results = _load_results(FLAGS.mitigation_special_retain_results_filepath, device)
    control_results = _load_results(FLAGS.control_special_retain_results_filepath, device)

    # Only the kept pairs are reported below; the discarded ones are unused.
    keep_pairs, _discard_pairs = _make_pairs(mitigation_results, control_results, forget_example_indices)

    if FLAGS.print_kept_forget_example_indices:
        _print_kept_forget_example_indices(keep_pairs)

    if FLAGS.min_forget_set_kl is None:
        print(f'{len(keep_pairs)} / {len(mitigation_results)} results kept (no forget set kl threshold)')
    else:
        print(f'{len(keep_pairs)} / {len(mitigation_results)} results with forget set kl above {FLAGS.min_forget_set_kl}')

    _analyze_pairs(keep_pairs)


# Script entry point: absl's app.run parses the flags defined above, then calls main().
if __name__ == "__main__":
    app.run(main)

