"""
"""
import dataclasses
from typing import List

from absl import app
from absl import flags

import torch

from npeff_torch.decomps.npeff import lrm_npeff_decomps
from npeff_torch.unlearning.gradient_ascent import cda_results

###############################################################################
# Allowed values for the --similar_example_selection_strategy flag below.
SIMILAR_EXAMPLE_SELECTION_STRATEGIES = ['cosine_similarity', 'top_coefficient_top_examples']
###############################################################################

FLAGS = flags.FLAGS

# Input files. --cda_results_filepath is a list flag: every entry is loaded as
# a separate results file. --npeff_filepath is the decomposition to read the
# coefficients from.
flags.DEFINE_list(name='cda_results_filepath', default=None, help='')
flags.DEFINE_string(name='npeff_filepath', default=None, help='')

# How many examples to keep when ranking by cosine similarity / by KL. Both
# default to None, so they have to be supplied on the command line.
flags.DEFINE_integer(name='n_top_cos_sim', default=None, help='')
flags.DEFINE_integer(name='n_top_kl', default=None, help='')

# NOTE(review): this flag is defined here but is not read anywhere in the part
# of the file shown; confirm whether it is still needed.
flags.DEFINE_enum(
    name='similar_example_selection_strategy',
    default='cosine_similarity',
    enum_values=SIMILAR_EXAMPLE_SELECTION_STRATEGIES,
    help='',
)


R"""
Possible todos:
    - take into account the pef norms as an option?

"""

###############################################################################


def _load_results(device: torch.device):
    """Load every CDA run result named by --cda_results_filepath onto `device`.

    Args:
        device: Device each loaded result is moved to via its `.to(device)`.

    Returns:
        A flat list with the results of all files, in flag order and, within a
        file, in the order the loader yields them.

    Raises:
        ValueError: If any loaded result has `forget_set_size != 1`. The rest
            of this script assumes exactly one forgotten example per result.
    """
    loaded = []
    for filepath in FLAGS.cda_results_filepath:
        for run in cda_results.CdaRunResults.load_results_from_file(filepath):
            # Explicit raise rather than `assert`: asserts are removed when
            # Python runs with -O, which would let multi-example forget sets
            # slip through unnoticed.
            if run.forget_set_size != 1:
                raise ValueError(
                    f'Expected single-example forgetting in {filepath!r}, '
                    f'but found forget_set_size={run.forget_set_size}.'
                )
            loaded.append(run.to(device))
    return loaded


def _load_normalized_coeffs(device: torch.device) -> torch.Tensor:
    """Return the decomposition's W matrix on `device`, normalized per row.

    Only W is requested from the file at --npeff_filepath (G is not loaded).
    The result has shape [n_examples_total, n_components], and each example's
    coefficient vector is scaled to unit l2 norm along the last dimension.
    """
    coeffs = lrm_npeff_decomps.LrmNpeffDecomposition.load(
        FLAGS.npeff_filepath,
        load_W=True,
        load_G=False,
    ).W
    return torch.nn.functional.normalize(coeffs.to(device), dim=-1)


###############################################################################

def _print_summary(label: str, x: torch.Tensor):
    """Print the mean and the median of `x` on a single line tagged `label`."""

    def _as_float(scalar: torch.Tensor) -> float:
        # Move the value off the graph/device before converting it.
        return float(scalar.detach().cpu().numpy())

    # The median is evaluated before the mean.
    med = _as_float(x.median())
    avg = _as_float(x.mean())
    print(f'{label}: {avg} [mean], {med} [median]')


@dataclasses.dataclass
class AnalysisResult1:
    """Per-run statistics relating KL divergence to coefficient cosine similarity.

    Every field is a tensor; the field order is part of the constructor
    signature.
    """

    # KL divergence of the forgotten example itself.
    forget_example_kl: torch.Tensor

    # Highest coefficient cosine similarities with the forgotten example.
    top_cos_sims: torch.Tensor
    # KL divergences of the examples that have those highest cosine similarities.
    top_cos_sim_example_kls: torch.Tensor

    # Largest KL divergences among examples other than the forgotten one.
    top_kls: torch.Tensor
    # Coefficient cosine similarities of the examples with those largest KLs.
    top_kl_cos_sims: torch.Tensor

    # Mean / median over all non-forgotten examples.
    avg_cos_sim: torch.Tensor
    median_cos_sim: torch.Tensor
    avg_kl: torch.Tensor
    median_kl: torch.Tensor

    def print_summary(self):
        """Print the scalar fields, then mean/median summaries of the top-k fields."""

        def scalar(t: torch.Tensor) -> float:
            return float(t.detach().cpu().numpy())

        print(f'forget_example_kl: {scalar(self.forget_example_kl)}')

        print(f'avg_cos_sim: {scalar(self.avg_cos_sim)} [mean] {scalar(self.median_cos_sim)} [median]')
        print(f'avg_kl: {scalar(self.avg_kl)} [mean] {scalar(self.median_kl)} [median]')

        # Each label is the name of the attribute it summarizes.
        for field_name in (
            'top_cos_sims',
            'top_cos_sim_example_kls',
            'top_kls',
            'top_kl_cos_sims',
        ):
            _print_summary(field_name, getattr(self, field_name))


def _analyze1(
    result: 'cda_results.CdaRunResults',
    W: torch.Tensor,
    *,
    n_top_cos_sim: int,
    n_top_kl: int,
):
    """Compare KLs and coefficient similarities for one single-example run.

    Args:
        result: Run result providing `forget_set_kls` and `evaluation_set_kls`,
            each with `example_indices` and `kls`.
        W: Per-example coefficient matrix indexed by example index. The caller
            in this file passes rows normalized to unit l2 norm, which is what
            makes the dot products below cosine similarities.
        n_top_cos_sim: Number of most-similar examples to keep.
        n_top_kl: Number of highest-KL examples to keep.

    Returns:
        An `AnalysisResult1` computed over the evaluation examples, excluding
        the forgotten example.
    """
    forget_info = result.forget_set_kls
    eval_info = result.evaluation_set_kls

    forget_idx = torch.squeeze(forget_info.example_indices)
    eval_indices = eval_info.example_indices

    # The forgotten example must not show up among the "other" examples, so it
    # is masked out of both the coefficients and the KLs.
    keep_mask = eval_indices != forget_idx
    other_W = W[eval_indices][keep_mask]
    other_kls = eval_info.kls[keep_mask]

    # Since coeffs are non-negative, the cosine similarities will always be
    # non-negative.
    cos_sims = torch.einsum('ei,i', other_W, W[forget_idx])

    best_cos_sims, best_cos_sim_positions = torch.topk(cos_sims, k=n_top_cos_sim)
    best_kls, best_kl_positions = torch.topk(other_kls, k=n_top_kl)

    return AnalysisResult1(
        forget_example_kl=torch.squeeze(forget_info.kls),
        top_cos_sims=best_cos_sims,
        top_cos_sim_example_kls=other_kls[best_cos_sim_positions],
        top_kls=best_kls,
        top_kl_cos_sims=cos_sims[best_kl_positions],
        avg_cos_sim=cos_sims.mean(),
        median_cos_sim=torch.median(cos_sims),
        avg_kl=other_kls.mean(),
        median_kl=torch.median(other_kls),
    )


def _print_aggregate_summary(analysis_results: List['AnalysisResult1'], factor: float = 1.0):
    """Print the fraction of runs whose top-k statistics exceed the baseline.

    For every result, four comparisons are made, each against `factor` times
    the corresponding statistic over all non-forgotten examples:
      * mean / median KL of the top-cosine-similarity examples vs. the
        mean / median KL;
      * mean / median cosine similarity of the top-KL examples vs. the
        mean / median cosine similarity.

    Args:
        analysis_results: Per-run analysis results to aggregate.
        factor: Multiplier applied to the baseline statistic before comparing.

    An empty `analysis_results` prints `nan` for each fraction instead of
    raising `ZeroDivisionError`.
    """
    # Count if the top examples by cosine sim have kl mean/median greater than
    # the (scaled) mean/median kl.
    greater_mean_kl_count = 0
    greater_median_kl_count = 0
    # Count if the top examples by kl have cosine sim greater than the
    # (scaled) mean/median cosine sim.
    greater_mean_cos_sim_count = 0
    greater_median_cos_sim_count = 0

    for r in analysis_results:
        if torch.mean(r.top_cos_sim_example_kls) > factor * r.avg_kl:
            greater_mean_kl_count += 1
        if torch.median(r.top_cos_sim_example_kls) > factor * r.median_kl:
            greater_median_kl_count += 1
        if torch.mean(r.top_kl_cos_sims) > factor * r.avg_cos_sim:
            greater_mean_cos_sim_count += 1
        if torch.median(r.top_kl_cos_sims) > factor * r.median_cos_sim:
            greater_median_cos_sim_count += 1

    n_results = len(analysis_results)

    def _fraction(count: int) -> float:
        # With nothing to aggregate the fraction is undefined; report nan
        # rather than dividing by zero.
        return count / n_results if n_results else float('nan')

    print(f'factor: {factor}')
    print(f'larger kl [mean]: {_fraction(greater_mean_kl_count)}')
    print(f'larger kl [median]: {_fraction(greater_median_kl_count)}')
    print(f'larger cos sim [mean]: {_fraction(greater_mean_cos_sim_count)}')
    print(f'larger cos sim [median]: {_fraction(greater_median_cos_sim_count)}')


###############################################################################


@torch.no_grad()
def main(_):
    """Analyze every loaded run, printing per-run and aggregate summaries.

    Runs entirely under `torch.no_grad()`. Uses CUDA when available, otherwise
    the CPU. The positional argument supplied by `app.run` is ignored.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    results = _load_results(device)
    W = _load_normalized_coeffs(device)

    analysis_results = []
    for result in results:
        analysis_result = _analyze1(
            result,
            W,
            n_top_cos_sim=FLAGS.n_top_cos_sim,
            n_top_kl=FLAGS.n_top_kl,
        )
        analysis_results.append(analysis_result)

        analysis_result.print_summary()
        print()

    # One aggregate report per threshold multiplier, each followed by a blank
    # line, with one extra blank line before the first report.
    print()
    for factor in (1.0, 2.0, 3.0, 5.0, 10.0):
        _print_aggregate_summary(analysis_results, factor=factor)
        print()


# Script entry point: hand control to absl, which parses the command-line
# flags defined above and then invokes main.
if __name__ == "__main__":
    app.run(main)
