"""Uses per example information (such as PEFs or gradients) to do CDA."""
import dataclasses
from typing import List
from typing import Optional

from absl import app
from absl import flags

import torch

from npeff_torch.peis.fishers.formats import frdn_lrm_pefs
from npeff_torch.peis.gradients.formats import dn_gradients
from npeff_torch.unlearning.gradient_ascent import cda_results

###############################################################################
# Allowed values of --pei_type; each selects a loader in `_load_normalized_peis`
# and a similarity in `_compute_sims`.
_PEI_TYPES = ['pef', 'gradient']
###############################################################################

FLAGS = flags.FLAGS

# Files of CDA run results, each read with
# `cda_results.CdaRunResults.load_results_from_file`. Every run in them must
# forget exactly one example (enforced in `_load_results`).
flags.DEFINE_list('cda_results_filepath', None, '')

# TODO: Support something like --n_values_per_pei.
# Files of per-example information; their rows are concatenated in the order
# given. NOTE(review): row i of the result is used as the PEI of example index
# i in `_analyze1` — confirm the files cover the dataset in that order.
flags.DEFINE_list('pei_filepaths', None, '')
# Kind of per-example information stored in --pei_filepaths.
flags.DEFINE_enum('pei_type', None, _PEI_TYPES, '')

# Per run: how many evaluation examples most similar to the forgotten example
# to keep (the `k` passed to `torch.topk` on the similarities).
flags.DEFINE_integer('n_top_cos_sim', None, '')
# Per run: how many evaluation examples with the largest KL-divergence to keep
# (the `k` passed to `torch.topk` on the KL-divergences).
flags.DEFINE_integer('n_top_kl', None, '')


###############################################################################


def _load_results(device: torch.device) -> List['cda_results.CdaRunResults']:
    """Load every CDA run listed in --cda_results_filepath and move it to `device`.

    Args:
      device: device each loaded result is moved to via its `.to()` method.

    Returns:
      A flat list of results, ordered by file and then by position in the file.

    Raises:
      ValueError: if any run forgot more than one example. The rest of this
        script assumes single-example forgetting.
    """
    ret = []
    for filepath in FLAGS.cda_results_filepath:
        results = cda_results.CdaRunResults.load_results_from_file(filepath)
        for r in results:
            # Explicit check rather than `assert`, which `python -O` strips.
            if r.forget_set_size != 1:
                raise ValueError(
                    f'Expected single-example forgetting, but a run in {filepath} '
                    f'has forget_set_size={r.forget_set_size}.')
            ret.append(r.to(device))
    return ret


@torch.no_grad()
def _load_normalized_peis() -> torch.Tensor:
    """Load the PEIs from every --pei_filepaths file, normalize them, and stack them.

    For 'pef', each example's PEF is divided by the square root of its stored
    Frobenius norm. Presumably this gives the implied P^T P matrix unit
    Frobenius norm; TODO confirm what the stored norm measures. For 'gradient',
    each example's gradient is divided by its stored norm.

    Returns:
      The per-file tensors concatenated along dim 0, in --pei_filepaths order.

    Raises:
      ValueError: if --pei_type is not a supported value.
    """
    pei_type = FLAGS.pei_type
    if pei_type not in ('pef', 'gradient'):
        raise ValueError(pei_type)

    chunks = []
    for path in FLAGS.pei_filepaths:
        if pei_type == 'pef':
            chunk = frdn_lrm_pefs.load_pefs(path)
            scale = frdn_lrm_pefs.load_pef_frobenius_norms(path).sqrt_()
            chunk /= scale[:, None, None]
        else:
            chunk = dn_gradients.load_gradients(path)
            scale = dn_gradients.load_norms(path)
            chunk /= scale[:, None]
        chunks.append(chunk)

    return torch.cat(chunks, dim=0)


###############################################################################

def _compute_sims(
    peis: torch.Tensor,
    forget_example_pei: torch.Tensor,
    *,
    pei_type: Optional[str] = None,
) -> torch.Tensor:
    """Similarity of each row of `peis` with `forget_example_pei`.

    Args:
      peis: for 'pef', shape [examples, rank, d_proj]; for 'gradient', shape
        [examples, d].
      forget_example_pei: a single example's PEI, i.e. `peis` without the
        leading examples dimension.
      pei_type: 'pef' or 'gradient'. Defaults to --pei_type when None.

    Returns:
      A 1-D tensor with one similarity per example in `peis`.

    Raises:
      ValueError: if the PEI type is not supported.
    """
    if pei_type is None:
        pei_type = FLAGS.pei_type

    if pei_type == 'pef':
        # For factors P_e (rank x d_proj) and P_f, this is ||P_e P_f^T||_F^2,
        # which equals the Frobenius inner product <P_e^T P_e, P_f^T P_f>.
        temp = torch.einsum('eij,kj->eik', peis, forget_example_pei)
        return torch.einsum('eik,eik->e', temp, temp)

    elif pei_type == 'gradient':
        # Plain dot product per example (explicit output spec).
        return torch.einsum('ei,i->e', peis, forget_example_pei)

    else:
        raise ValueError(pei_type)

###############################################################################


def _print_summary(label: str, x: torch.Tensor):
    """Print `label` followed by the mean and median of `x` on one line."""

    def as_float(t: torch.Tensor) -> float:
        return float(t.detach().cpu().numpy())

    median_value = as_float(x.median())
    mean_value = as_float(x.mean())
    print(f'{label}: {mean_value} [mean], {median_value} [median]')


@dataclasses.dataclass
class AnalysisResult1:
    """Similarity-vs-KL statistics for one single-example CDA run (built by `_analyze1`)."""

    # KL-divergence of the forgotten example itself (a single-element tensor).
    forget_example_kl: torch.Tensor

    # The top coefficient cosine similarities with the forgotten example.
    top_cos_sims: torch.Tensor
    # The KL-divergences of the examples with the highest cosine similarities with the forgotten example.
    top_cos_sim_example_kls: torch.Tensor

    # The largest KL-divergences among the evaluation examples (excluding the forgotten example).
    top_kls: torch.Tensor
    # The coefficient cosine similarities of the examples with the largest KL-divergences.
    top_kl_cos_sims: torch.Tensor

    # Mean/median across the evaluation examples other than the forgotten example.
    avg_cos_sim: torch.Tensor
    median_cos_sim: torch.Tensor
    avg_kl: torch.Tensor
    median_kl: torch.Tensor

    @staticmethod
    def _as_float(t: torch.Tensor) -> float:
        """Convert a single-element tensor (on any device) to a Python float."""
        return float(t.detach().cpu().numpy())

    def print_summary(self):
        """Print the scalar statistics, then mean/median summaries of the top-k tensors."""
        scalar = self._as_float
        print(f'forget_example_kl: {scalar(self.forget_example_kl)}')
        print(f'avg_cos_sim: {scalar(self.avg_cos_sim)} [mean] {scalar(self.median_cos_sim)} [median]')
        print(f'avg_kl: {scalar(self.avg_kl)} [mean] {scalar(self.median_kl)} [median]')
        _print_summary('top_cos_sims', self.top_cos_sims)
        _print_summary('top_cos_sim_example_kls', self.top_cos_sim_example_kls)
        _print_summary('top_kls', self.top_kls)
        _print_summary('top_kl_cos_sims', self.top_kl_cos_sims)


def _analyze1(
    result: 'cda_results.CdaRunResults',
    peis: torch.Tensor,
    device: torch.device,
    *,
    n_top_cos_sim: int,
    n_top_kl: int,
) -> 'AnalysisResult1':
    """Relate PEI similarity to KL-divergence for one single-example CDA run.

    The forgotten example is removed from the evaluation set. Each remaining
    evaluation example's PEI is compared with the forgotten example's PEI via
    `_compute_sims`. The top-k examples by similarity are then cross-referenced
    with the top-k examples by KL-divergence.

    Args:
      result: a CDA run whose forget set holds exactly one example.
      peis: normalized PEIs for all examples. Row i is assumed to belong to
        example index i (TODO confirm against the PEI files).
      device: device on which the similarities are computed.
      n_top_cos_sim: number of most-similar examples to keep; clamped to the
        number of available evaluation examples.
      n_top_kl: number of highest-KL examples to keep; clamped likewise.

    Returns:
      An `AnalysisResult1` summarizing this run.
    """
    forget_example_index = torch.squeeze(result.forget_set_kls.example_indices)

    # A single "not the forgotten example" mask, applied to both the evaluation
    # indices and the evaluation KLs, keeps them aligned element-for-element.
    all_eval_indices = result.evaluation_set_kls.example_indices
    keep = all_eval_indices != forget_example_index.to(all_eval_indices.device)

    # `peis` may live on a different device than `result`, so index it with
    # indices on its own device and move only the selected rows to `device`.
    forget_example_pei = peis[forget_example_index.to(peis.device)].to(device)
    eval_peis = peis[all_eval_indices[keep].to(peis.device)].to(device)

    sims = _compute_sims(eval_peis, forget_example_pei)
    evaluation_examples_kls = result.evaluation_set_kls.kls[keep]

    # torch.topk raises when k exceeds the number of elements; clamp k.
    top_sims, top_sim_indices = torch.topk(sims, k=min(n_top_cos_sim, sims.numel()))
    top_kls, top_kl_indices = torch.topk(
        evaluation_examples_kls, k=min(n_top_kl, evaluation_examples_kls.numel()))

    # Cross-reference: KLs of the most similar examples, and similarities of
    # the highest-KL examples.
    top_sim_kls = evaluation_examples_kls[top_sim_indices]
    top_kl_sims = sims[top_kl_indices]

    return AnalysisResult1(
        forget_example_kl=torch.squeeze(result.forget_set_kls.kls),
        top_cos_sims=top_sims,
        top_cos_sim_example_kls=top_sim_kls,
        top_kls=top_kls,
        top_kl_cos_sims=top_kl_sims,
        avg_cos_sim=sims.mean(),
        median_cos_sim=torch.median(sims),
        avg_kl=evaluation_examples_kls.mean(),
        median_kl=torch.median(evaluation_examples_kls),
    )


def _print_aggregate_summary(analysis_results: List['AnalysisResult1'], factor: float = 1.0):
    """Print, across runs, how often the "top" selections beat the scaled overall statistics.

    For each run, two checks are made:
      * Do the examples most similar to the forgotten example have a mean/median
        KL-divergence greater than `factor` times the run's overall mean/median KL?
      * Do the highest-KL examples have a mean/median similarity greater than
        `factor` times the run's overall mean/median similarity?
    The fraction of runs passing each check is printed.

    Args:
      analysis_results: per-run results from `_analyze1`.
      factor: multiplier applied to the overall mean/median before comparing.
    """
    print(f'factor: {factor}')

    n_results = len(analysis_results)
    if n_results == 0:
        # The fractions below would otherwise divide by zero.
        print('no analysis results to aggregate')
        return

    # Count runs where the top examples by cosine sim have kl mean/median greater than the scaled mean/median kl.
    greater_mean_kl_count = 0
    greater_median_kl_count = 0
    # Count runs where the top examples by kl have cosine sim mean/median greater than the scaled mean/median cosine sim.
    greater_mean_cos_sim_count = 0
    greater_median_cos_sim_count = 0

    for r in analysis_results:
        if torch.mean(r.top_cos_sim_example_kls) > factor * r.avg_kl:
            greater_mean_kl_count += 1
        if torch.median(r.top_cos_sim_example_kls) > factor * r.median_kl:
            greater_median_kl_count += 1
        if torch.mean(r.top_kl_cos_sims) > factor * r.avg_cos_sim:
            greater_mean_cos_sim_count += 1
        if torch.median(r.top_kl_cos_sims) > factor * r.median_cos_sim:
            greater_median_cos_sim_count += 1

    print(f'larger kl [mean]: {greater_mean_kl_count / n_results}')
    print(f'larger kl [median]: {greater_median_kl_count / n_results}')
    print(f'larger cos sim [mean]: {greater_mean_cos_sim_count / n_results}')
    print(f'larger cos sim [median]: {greater_median_cos_sim_count / n_results}')


###############################################################################

# TODO: Figure out how to move stuff to device and whatnot without causing overflow.
@torch.no_grad()
def main(_):
    """Analyze each single-example CDA run, then print per-run and aggregate summaries.

    Raises:
      app.UsageError: if a flag this script needs was not provided.
    """
    # All of these flags default to None, and nothing below works without them.
    for flag_name in ('cda_results_filepath', 'pei_filepaths', 'pei_type', 'n_top_cos_sim', 'n_top_kl'):
        if getattr(FLAGS, flag_name) is None:
            raise app.UsageError(f'--{flag_name} must be set.')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    results = _load_results(device)
    if not results:
        # Nothing to analyze, and the aggregate fractions would divide by zero.
        print('No CDA results were loaded; nothing to analyze.')
        return
    peis = _load_normalized_peis()

    analysis_results = []
    for result in results:
        analysis_result = _analyze1(result, peis, device=device, n_top_cos_sim=FLAGS.n_top_cos_sim, n_top_kl=FLAGS.n_top_kl)
        analysis_results.append(analysis_result)

        analysis_result.print_summary()
        print()

    print()
    # Report the aggregate fractions at increasingly strict thresholds.
    for factor in (1.0, 2.0, 3.0, 5.0, 10.0):
        _print_aggregate_summary(analysis_results, factor=factor)
        print()


if __name__ == "__main__":
    app.run(main)
