""""""
import collections
import gc
import itertools
import json
import os
import random
from typing import Dict, List, Optional

from absl import app
from absl import flags

import h5py
import torch
from torch.utils.data import DataLoader
import transformers
from tqdm import tqdm
from npeff_torch.decomps.kmeans import kmeans
from npeff_torch.decomps.npeff import lrm_npeff_decomps
from npeff_torch.peis.fishers.formats import pef_format_common
from npeff_torch.perturbations import evaluation_contexts
from npeff_torch.perturbations import perturbation_contexts
from npeff_torch.unlearning.gradient_ascent import cda_results
from npeff_torch.unlearning.gradient_ascent import gradient_ascent

from npeff_torch.util import hdf5_utils
from npeff_torch.util import hf_utils
from npeff_torch.util import tokenizer_utils


R"""
Design notes:
- Special retain set:
    - Can be selected either by coefficient (cosine) similarity or by the example's top component.
    - Can alternatively be selected at random, which serves as a control.
- The special evaluation examples are the complement of the special retain set within the set of
  most similar examples.
- A smaller baseline evaluation set may suffice here, since the goal is different.
"""

###############################################################################
# Allowed values for --decomposition_type.
_DECOMPOSITION_TYPES = ['npeff', 'kmeans']
# Allowed values for --similar_example_selection_strategy.
SIMILAR_EXAMPLE_SELECTION_STRATEGIES = ['cosine_similarity', 'top_coefficient_top_examples']
###############################################################################

FLAGS = flags.FLAGS


# Path of the HDF5 file that the results are written to. Its parent directory must exist.
flags.DEFINE_string('output_filepath', None, '')


# Path of the saved decomposition. NOTE: Despite the name, this is also the path of the
# saved k-means clustering when --decomposition_type=kmeans.
flags.DEFINE_string('npeff_filepath', None, '')
flags.DEFINE_enum('decomposition_type', 'npeff', _DECOMPOSITION_TYPES, 'Hack to let us use this for kmeans.')


flags.DEFINE_list('pef_filepaths', None, 
                  'The PEF files used to compute the NPEFF coefficients, which MUST be in '
                  'the same order. These MUST have the examples saved.')

flags.DEFINE_list('n_examples_per_pef', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")


# --model is passed to `from_pretrained` of the class named --model_cls, which is looked
# up as an attribute of the `transformers` module.
flags.DEFINE_string('model', None, '')
flags.DEFINE_string('model_cls', None, '')

flags.DEFINE_string("tokenizer", None, "If left None, assumed to be equal to --model.")


# Ways to specify the examples to run unlearning on.
flags.DEFINE_list('unlearned_example_indices', [],
                  'The indices of examples within the ones present in --pef_filepaths to run unlearning on.')


# Configuration of the special retain set and similar example evaluation set.

flags.DEFINE_enum('similar_example_selection_strategy', None, SIMILAR_EXAMPLE_SELECTION_STRATEGIES, '')

# Number of most-similar examples (excluding the unlearned example itself) that get split
# between the special retain set and the special evaluation set.
flags.DEFINE_integer('overall_similar_example_set_size', None, '')
flags.DEFINE_integer('special_retain_set_size', None, 
                     'Ideally, should be about half of the --overall_similar_example_set_size to get similar examples to evaluate on.')

flags.DEFINE_bool('random_special_retain_set', False, 'This acts as a control if true.')


# Specification of the unlearning procedure.

# Learning rate of the Adam optimizer used for unlearning.
flags.DEFINE_float('learning_rate', None, '')
# Batch size of the general retain batches; the special retain examples are appended to
# every such batch on top of this.
flags.DEFINE_integer('retain_set_batch_size', None, '')
# Number of optimizer steps taken per unlearning run.
flags.DEFINE_integer('n_unlearning_batches', None, '')

# Passed through unchanged to the gradient ascent unlearner.
flags.DEFINE_float('epsilon_forget', None, '')
flags.DEFINE_float('epsilon_retain', None, '')

# TODO: Maybe add a KL threshold on the unlearned example, below which we skip evaluation.
flags.DEFINE_integer('n_trials_per_example', 1, 'Runs this many trials per example.')


# Specification of evaluation of the unlearned model.

flags.DEFINE_integer('n_random_evaluation_examples', None,
                     'Number of random examples from --pef_filepaths to select to evaluate on.')
# Base seed; the individual RNGs in this script are seeded with fixed offsets from it.
flags.DEFINE_integer('rng_seed', 53245, '')

flags.DEFINE_integer('evaluation_batch_size', None, '')


###############################################################################


def _check_valid_output_filepath():
    """Validates --output_filepath and returns it with `~` expanded.

    Returns:
        The value of --output_filepath after user expansion.

    Raises:
        ValueError: If the flag was not provided or if the directory that would
            contain the output file does not exist.
    """
    # Explicit raises instead of asserts so the checks survive `python -O`.
    if FLAGS.output_filepath is None:
        raise ValueError('The --output_filepath flag must be provided.')

    output_filepath = os.path.expanduser(FLAGS.output_filepath)

    # A bare file name (e.g. "results.h5") has an empty dirname, and
    # os.path.isdir('') is False; such a path lives in the current directory.
    parent_dir = os.path.dirname(output_filepath) or os.curdir
    if not os.path.isdir(parent_dir):
        raise ValueError('Invalid --output_filepath.')

    return output_filepath


def _read_n_examples_per_pef_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    if not flag_value:
        return None
    ret = []
    for n_examples in flag_value:
        n_examples = int(n_examples)
        if n_examples < 0:
            ret.append(None)
        else:
            ret.append(n_examples)
    return ret


def _get_unlearned_example_indices() -> List[int]:
    """Returns the entries of --unlearned_example_indices parsed as integers."""
    return list(map(int, FLAGS.unlearned_example_indices))


###############################################################################


@torch.no_grad()
def _load_maybe_normalized_coeffs(device: torch.device) -> torch.Tensor:
    """Loads the per-example coefficient matrix used to find similar examples.

    With the `cosine_similarity` strategy, every row is scaled to unit l2 norm so
    that dot products between rows are cosine similarities. With the
    `top_coefficient_top_examples` strategy (npeff only), the rows are NOT
    normalized; instead the decomposition's
    `normalize_reduced_components_to_unit_norm_()` is applied before its W is
    returned.

    Args:
        device: Device that the returned tensor is placed on.

    Returns:
        A tensor with one row per example, i.e. shape [n_examples_total, n_components].

    Raises:
        ValueError: On an unsupported decomposition type, selection strategy, or
            combination of the two.
    """
    decomposition_type = FLAGS.decomposition_type
    strategy = FLAGS.similar_example_selection_strategy

    if decomposition_type == 'kmeans':
        if strategy != 'cosine_similarity':
            raise ValueError('Only supporting --similar_example_selection_strategy=cosine_similarity for kmeans.')

        clustering_filepath = FLAGS.npeff_filepath
        n_clusters = kmeans.KmeansClusteringTorch.load_n_clusters(clustering_filepath)
        assignments = kmeans.KmeansClusteringTorch.load_cluster_assignments(clustering_filepath).to(device)
        one_hot_W = torch.nn.functional.one_hot(assignments, num_classes=n_clusters).type(torch.float32)

        # All members of a cluster have identical one-hot rows. A little seeded noise
        # breaks those exact ties, so which cluster members get selected is fixed by
        # the seed.
        noise_generator = torch.Generator(device)
        noise_generator.manual_seed(346747 + FLAGS.rng_seed)
        one_hot_W += 1e-6 * torch.randn(
            *one_hot_W.shape, generator=noise_generator, dtype=torch.float32, device=device)

        return torch.nn.functional.normalize(one_hot_W, dim=-1)

    if decomposition_type != 'npeff':
        raise ValueError(decomposition_type)

    # G is only loaded for the strategy that normalizes the components below.
    use_top_coefficient = strategy == 'top_coefficient_top_examples'
    decomp = lrm_npeff_decomps.LrmNpeffDecomposition.load(
        FLAGS.npeff_filepath,
        load_W=True,
        load_G=use_top_coefficient,
    ).to(device)

    if use_top_coefficient:
        decomp.normalize_reduced_components_to_unit_norm_()
        return decomp.W

    if strategy == 'cosine_similarity':
        return torch.nn.functional.normalize(decomp.W, dim=-1)

    raise ValueError(strategy)


def _read_in_model(tokenizer, device: torch.device) -> 'transformers.PreTrainedModel':
    """Loads a fresh copy of --model onto `device` and puts it in eval mode.

    Args:
        tokenizer: Tokenizer whose `pad_token_id` is used if the model config has none.
        device: Device to move the model to.

    Returns:
        The loaded model, an instance of the `transformers` class named by --model_cls.
    """
    model_cls = getattr(transformers, FLAGS.model_cls)
    model = model_cls.from_pretrained(FLAGS.model)
    model = model.to(device)
    # TODO: See if we want this for during unlearning?
    model.eval()

    # Without a padding token id, the following error can occur:
    #   ValueError: Cannot handle batch sizes > 1 if no padding token is defined.
    config = model.config
    if config.pad_token_id is None:
        config.pad_token_id = tokenizer.pad_token_id

    return model


def _clean_up_memory():
    """Runs Python garbage collection and then empties PyTorch's CUDA cache."""
    # Collect first so that objects freed by the garbage collector are already gone
    # by the time the CUDA cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()


###############################################################################


@torch.no_grad()
def _cat_batches(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    return {
        k: torch.cat([b[k] for b in batches], dim=0)
        # Doesn't do any checking that the keys are the same here, mate.
        for k in batches[0].keys()
    }


@torch.no_grad()
def _get_special_retain_and_eval_sets(
    W: torch.Tensor,
    pef_extra_infos: 'pef_format_common.PefExtraInfos',
    example_index: int,
):
    """Selects the special retain set and special evaluation set for one example.

    The --overall_similar_example_set_size examples most similar to the example
    (per --similar_example_selection_strategy, excluding the example itself) are
    split at random into a special retain part of --special_retain_set_size
    examples and a special evaluation part containing the rest. If
    --random_special_retain_set is set, the special retain set is instead drawn
    at random from all examples, while the special evaluation set is unchanged.

    The random choices are seeded by `example_index + --rng_seed`, so they are
    the same on every call for a given example.

    Args:
        W: Coefficient matrix with one row per example.
        pef_extra_infos: Unused.
        example_index: Row of `W` belonging to the example being unlearned.

    Returns:
        Tuple (special_retain_example_indices, special_eval_example_indices) of
        int64 tensors on W's device holding indices of rows of `W`.

    Raises:
        ValueError: On an unknown --similar_example_selection_strategy.
    """
    strategy = FLAGS.similar_example_selection_strategy
    n_similar = FLAGS.overall_similar_example_set_size
    n_retain = FLAGS.special_retain_set_size

    rng = random.Random(example_index + FLAGS.rng_seed)
    example_coeffs = W[example_index]

    if strategy == 'cosine_similarity':
        # Dot product of the example's coefficients with every row of W.
        scores = torch.einsum('ei,i', W, example_coeffs)
    elif strategy == 'top_coefficient_top_examples':
        # Every example's coefficient on this example's largest component.
        scores = W[:, torch.argmax(example_coeffs)]
    else:
        raise ValueError(strategy)

    # Take one extra candidate so that n_similar remain even after dropping the
    # example itself, should it show up among the top scores.
    candidates = torch.topk(scores, k=1 + n_similar).indices
    candidates = candidates[candidates != example_index][:n_similar]

    def _to_index_tensor(values):
        return torch.tensor(values, dtype=torch.int64, device=W.device)

    # Positions within `candidates`: a random subset is retained, the rest evaluated.
    retain_positions = rng.sample(range(n_similar), k=n_retain)
    retained = set(retain_positions)
    eval_positions = [pos for pos in range(n_similar) if pos not in retained]

    special_eval_example_indices = candidates[_to_index_tensor(eval_positions)]

    if FLAGS.random_special_retain_set:
        random_rows = rng.sample(range(W.shape[0]), k=n_retain)
        special_retain_example_indices = _to_index_tensor(random_rows)
    else:
        special_retain_example_indices = candidates[_to_index_tensor(retain_positions)]

    return special_retain_example_indices, special_eval_example_indices


def _unlearn_example(
    *,
    example_index: int,
    special_retain_example_indices: torch.Tensor,
    original_model: transformers.PreTrainedModel,
    tokenizer: transformers.PreTrainedTokenizer,
    all_examples: evaluation_contexts.Examples,
    device: torch.device,
):
    """Unlearns a single example starting from a freshly loaded copy of the model.

    Runs --n_unlearning_batches Adam steps on the loss of the gradient ascent
    unlearner. At every step the forget batch is the single example, and the
    retain batch is a shuffled batch drawn from all examples with the special
    retain examples appended.

    Args:
        example_index: Index within `all_examples` of the example to unlearn.
        special_retain_example_indices: Indices of the examples that are added to
            every retain batch.
        original_model: The unmodified model, handed to the unlearner as reference.
        tokenizer: Tokenizer used when loading the copy of the model to unlearn.
        all_examples: All examples available to this script.
        device: Device to run on.

    Returns:
        Tuple (unlearned_model, forget_examples, loss_infos), where `loss_infos`
        maps each key of the unlearner's loss info to a float32 tensor with one
        entry per step.
    """
    forget_examples = all_examples.gather_examples(
        torch.tensor([example_index], dtype=torch.int64, device=device))
    forget_batch = forget_examples.examples

    special_retain_batch = all_examples.gather_examples(special_retain_example_indices).examples

    # Include all examples in the retain dataset. This includes the forget dataset. Hopefully, this
    # shouldn't be an issue since they will be heavily oversampled in the forget dataset.
    general_retain_loader = DataLoader(
        all_examples.as_torch_dataset(),
        batch_size=FLAGS.retain_set_batch_size,
        shuffle=True,
        drop_last=True,
    )
    general_retain_batches = hf_utils.cycle_dataloader(general_retain_loader)

    unlearned_model = _read_in_model(tokenizer, device)

    unlearner = gradient_ascent.LastTokenKlOnlyGradientAscent(
        model=unlearned_model,
        original_model=original_model,
        epsilon_forget=FLAGS.epsilon_forget,
        epsilon_retain=FLAGS.epsilon_retain,
    )

    optimizer = torch.optim.Adam(unlearned_model.parameters(), lr=FLAGS.learning_rate)

    n_steps = FLAGS.n_unlearning_batches
    progress_bar = tqdm(total=n_steps)
    history = collections.defaultdict(list)
    for general_retain_batch in itertools.islice(general_retain_batches, n_steps):
        step_retain_batch = _cat_batches([general_retain_batch, special_retain_batch])
        step_info = unlearner.compute_loss_info(
            forget_batch=forget_batch, retain_batch=step_retain_batch, device=device)

        step_info['loss'].backward()
        optimizer.step()
        optimizer.zero_grad()

        # Convert to plain floats for logging and for the progress bar.
        step_scalars = {}
        for name, value in step_info.items():
            scalar = float(value.detach().cpu().numpy())
            step_scalars[name] = scalar
            history[name].append(scalar)

        progress_bar.update(1)
        progress_bar.set_postfix(step_scalars)

    progress_bar.close()

    loss_infos = {}
    for name, values in history.items():
        loss_infos[name] = torch.tensor(values, dtype=torch.float32)

    # TODO: Maybe also return some info related to the unlearning procedure?
    return unlearned_model, forget_examples, loss_infos


@torch.no_grad()
def _evaluate_unlearned_model(
    *,
    original_model: transformers.PreTrainedModel,
    unlearned_model: transformers.PreTrainedModel,
    evaluation_examples: evaluation_contexts.Examples,
    device: torch.device,
):
    """Computes the KLs between the unlearned and the original model on some examples.

    Both models are put in eval mode. The examples are processed in batches of
    --evaluation_batch_size.

    Args:
        original_model: The unmodified model.
        unlearned_model: The model after unlearning.
        evaluation_examples: The examples to evaluate on.
        device: Device to run on.

    Returns:
        The per-batch KLs concatenated along dim 0 and detached.
    """
    unlearned_model.eval()
    original_model.eval()

    unlearned_evaluator = evaluation_contexts.ModelEvaluator(model=unlearned_model, device=device)
    original_evaluator = evaluation_contexts.ModelEvaluator(model=original_model, device=device)

    def _batch_kls(batch):
        unlearned_info = unlearned_evaluator.compute_batch_info(batch)
        original_info = original_evaluator.compute_batch_info(batch)
        # NOTE: The order is swapped so that the KLs represent D_KL(altered||original).
        perturb_info = perturbation_contexts.ExamplesBatchPerturbationInfo(
            original_batch_info=unlearned_info,
            perturbed_batch_info=original_info,
        )
        return perturb_info.kls

    batches = tqdm(evaluation_examples.get_batches(FLAGS.evaluation_batch_size))
    per_batch_kls = [_batch_kls(batch) for batch in batches]

    return torch.cat(per_batch_kls, dim=0).detach()


###############################################################################


def _run_unlearning_trial(
    *,
    W,
    pef_extra_infos,
    example_index,
    original_model,
    tokenizer,
    all_examples,
    evaluation_examples,
    evaluation_example_indices,
    device,
):
    """Runs one unlearning trial for one example and returns its evaluation results.

    The unlearned model only lives inside this function, so it can be freed as
    soon as the function returns instead of lingering until the next trial has
    already loaded its own copy of the model.
    """
    special_retain_example_indices, special_eval_example_indices = _get_special_retain_and_eval_sets(
        W, pef_extra_infos, example_index)

    unlearned_model, forget_examples, loss_infos = _unlearn_example(
        example_index=example_index,
        special_retain_example_indices=special_retain_example_indices,
        original_model=original_model,
        tokenizer=tokenizer,
        all_examples=all_examples,
        device=device,
    )
    _clean_up_memory()

    def evaluate(examples):
        return _evaluate_unlearned_model(
            original_model=original_model,
            unlearned_model=unlearned_model,
            evaluation_examples=examples,
            device=device,
        )

    forget_set_kls = cda_results.EvaluationKls(
        example_indices=torch.tensor([example_index], dtype=torch.int64, device=device),
        kls=evaluate(forget_examples),
    )
    special_retain_set_kls = cda_results.EvaluationKls(
        example_indices=special_retain_example_indices,
        kls=evaluate(all_examples.gather_examples(special_retain_example_indices)),
    )
    special_evaluation_set_kls = cda_results.EvaluationKls(
        example_indices=special_eval_example_indices,
        kls=evaluate(all_examples.gather_examples(special_eval_example_indices)),
    )
    evaluation_set_kls = cda_results.EvaluationKls(
        example_indices=evaluation_example_indices,
        kls=evaluate(evaluation_examples),
    )

    return cda_results.SpecialRetainCdaRunResults(
        forget_set_kls=forget_set_kls,
        special_retain_set_kls=special_retain_set_kls,
        special_evaluation_set_kls=special_evaluation_set_kls,
        evaluation_set_kls=evaluation_set_kls,
        unlearning_training_infos=loss_infos,
    )


def main(_):
    """Unlearns each requested example and saves the resulting KLs to an HDF5 file.

    For every index in --unlearned_example_indices and every trial, a fresh copy
    of the model is unlearned on that example and then compared to the original
    model on the forget example, the special retain set, the special evaluation
    set, and a fixed random evaluation set. Each run is saved to the group
    `data/examples/<example_index>/trial/<trial_index>` of the output file.

    Raises:
        ValueError: If --special_retain_set_size exceeds
            --overall_similar_example_set_size, or if --output_filepath is invalid.
    """
    # NOTE(review): equal sizes are accepted here but leave the special evaluation
    # set empty; confirm that the evaluation handles an empty set of examples.
    if FLAGS.overall_similar_example_set_size < FLAGS.special_retain_set_size:
        raise ValueError(
            '--special_retain_set_size must not exceed --overall_similar_example_set_size.')

    # Check this now so that we don't error out after doing all the work.
    output_filepath = _check_valid_output_filepath()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    W = _load_maybe_normalized_coeffs(device)

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer or FLAGS.model)

    original_model = _read_in_model(tokenizer, device)

    n_examples_per_pef = _read_n_examples_per_pef_flag(FLAGS.n_examples_per_pef)
    pef_extra_infos = pef_format_common.PefExtraInfos.read_from_files(FLAGS.pef_filepaths, n_examples_per_pef)
    assert pef_extra_infos.parameter_infos is not None
    assert pef_extra_infos.examples is not None
    assert pef_extra_infos.n_examples is not None

    all_examples = evaluation_contexts.Examples(
        examples={k: torch.from_numpy(v) for k, v in pef_extra_infos.examples.items()},
        token_positions=torch.from_numpy(pef_extra_infos.token_positions) if pef_extra_infos.token_positions is not None else None,
    )
    all_examples = all_examples.to(device)

    unlearned_example_indices = _get_unlearned_example_indices()

    # The random evaluation set is drawn once and shared by all runs.
    rng = random.Random(534534634 + FLAGS.rng_seed)
    evaluation_example_indices = rng.sample(range(pef_extra_infos.n_examples), k=FLAGS.n_random_evaluation_examples)
    evaluation_example_indices = torch.tensor(evaluation_example_indices, dtype=torch.int64, device=device)
    evaluation_examples = all_examples.gather_examples(evaluation_example_indices)

    # The path returned by _check_valid_output_filepath is already user-expanded.
    with h5py.File(output_filepath, "w") as f:
        data_group = f.create_group('data')
        data_group.attrs['metadata'] = json.dumps({})

        for example_index in unlearned_example_indices:
            for trial_index in range(FLAGS.n_trials_per_example):
                run_results = _run_unlearning_trial(
                    W=W,
                    pef_extra_infos=pef_extra_infos,
                    example_index=example_index,
                    original_model=original_model,
                    tokenizer=tokenizer,
                    all_examples=all_examples,
                    evaluation_examples=evaluation_examples,
                    evaluation_example_indices=evaluation_example_indices,
                    device=device,
                )

                run_group = f.create_group(f'data/examples/{example_index}/trial/{trial_index}')
                run_results.save_to_group(run_group)
                # Push finished runs to disk so that they survive a crash in a later run.
                f.flush()

                # Drop the last references to this run's tensors before cleaning up;
                # otherwise the clean up cannot release their memory.
                del run_results
                _clean_up_memory()


# Script entry point: absl parses the command-line flags and then calls main.
if __name__ == "__main__":
    app.run(main)
