"""Collateral damage assessment for gradient ascent."""
import collections
import gc
import itertools
import json
import os
import random
from typing import List, Optional, Sequence, Tuple

from absl import app
from absl import flags

import h5py
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import transformers

from npeff_torch.peis.fishers.formats import pef_format_common
from npeff_torch.perturbations import evaluation_contexts
from npeff_torch.perturbations import perturbation_contexts
from npeff_torch.unlearning.gradient_ascent import cda_results
from npeff_torch.unlearning.gradient_ascent import gradient_ascent

from npeff_torch.util import hdf5_utils
from npeff_torch.util import hf_utils
from npeff_torch.util import tokenizer_utils


R"""

Script stuff:
    - Get the examples from PEIs, must have NPEFF coefficients.
    - Maybe want separate retain set?
    - This script is currently only for unlearning of a single example.
    - Allow different ways to specify the examples to unlearn:
        - Example indices within the PEFs.
        - Top examples of components.
            - Specify component and indices of top examples for each component
    - Examples to compute KLs of the unlearned model:
        - Compute the KL for the unlearned example.
        - Maybe first k examples of the retain set.
        - Top examples of components?
            - How to specify the components?
            - Probably have a fixed number of top examples for each example.
        => Unlearned example, the top non-zero components for the unlearned example, a decent number of the random PEIs examples.
    - What to save:
        - KLs for each example evaluated on.
            - Can analyze later.
        - Some information about the unlearning process.

General stuff:



"""

###############################################################################

# Handle to the absl flag registry; flag values are read throughout this file as FLAGS.<name>.
FLAGS = flags.FLAGS


# Path of the HDF5 results file written by main(). _check_valid_output_filepath() requires the
# flag to be set and its parent directory to already exist.
flags.DEFINE_string('output_filepath', None, '')


flags.DEFINE_list('pef_filepaths', None, 
                  'The PEF files used to compute the NPEFF coefficients, which MUST be in '
                  'the same order. These MUST have the examples saved.')

# Parsed by _read_n_examples_per_pef_flag(): negative entries become None, and an unset/empty
# flag becomes None as a whole.
flags.DEFINE_list('n_examples_per_pef', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")


# --model is passed to `from_pretrained` of the class named by --model_cls, which is looked up as
# an attribute of the `transformers` module (see _read_in_model).
flags.DEFINE_string('model', None, '')
flags.DEFINE_string('model_cls', None, '')

flags.DEFINE_string("tokenizer", None, "If left None, assumed to be equal to --model.")


# Ways to specify the examples to run unlearning on.
# The entries are converted to ints by _get_unlearned_example_indices().
flags.DEFINE_list('unlearned_example_indices', [],
                  'The indices of examples within the ones present in --pef_filepaths to run unlearning on.')
# TODO: Add method based on component top examples?


# Specification of the unlearning procedure.

# Learning rate of the Adam optimizer, batch size of the retain-set DataLoader, and number of
# retain batches (one optimizer step each) used in _unlearn_example().
flags.DEFINE_float('learning_rate', None, '')
flags.DEFINE_integer('retain_set_batch_size', None, '')
flags.DEFINE_integer('n_unlearning_batches', None, '')

# Forwarded unchanged to gradient_ascent.LastTokenKlOnlyGradientAscent; their exact meaning is
# defined there, not in this file.
flags.DEFINE_float('epsilon_forget', None, '')
flags.DEFINE_float('epsilon_retain', None, '')

# TODO: Maybe some KL threshold on the unlearned example where we don't evaluate if its kl is lower than that.
flags.DEFINE_integer('n_trials_per_example', 1, 'Runs this many trials per example.')


# Specification of evaluation of the unlearned model.

# Used as the sample size `k` when main() draws evaluation indices from range(n_examples).
flags.DEFINE_integer('n_random_evaluation_examples', None,
                     'Number of random examples from --pef_filepaths to select to evaluate on.')
flags.DEFINE_integer('evaluation_examples_selection_seed', 53245,
                     'Used to select the random subset of examples to evaluate on.')

# Batch size passed to Examples.get_batches() in _evaluate_unlearned_model().
flags.DEFINE_integer('evaluation_batch_size', None, '')


# TODO: Some parameter filtering of parameters to update?
# TODO: Specification of retain set other than from the pefs.

# flags.DEFINE_bool('include_embeddings', True, 'Whether to include the embeddings in the parameters we use when unlearning.')
# flags.DEFINE_bool('include_layer_norms', True, 'Whether to include the layer norms in the parameters we use when unlearning.')


###############################################################################


def _make_metadata():
    """Collect the run configuration into a dict that main() serializes with json.dumps.

    Most entries mirror the raw flag values. ``n_examples_per_pef`` and
    ``unlearned_example_indices`` are stored in their parsed form instead.

    Returns:
        A dict keyed by flag name.
    """
    # Entries are inserted in a fixed order so the serialized layout stays the same.
    metadata = {
        'pef_filepaths': FLAGS.pef_filepaths,
        'n_examples_per_pef': _read_n_examples_per_pef_flag(FLAGS.n_examples_per_pef),
    }
    for flag_name in ('model', 'model_cls', 'tokenizer'):
        metadata[flag_name] = getattr(FLAGS, flag_name)

    metadata['unlearned_example_indices'] = _get_unlearned_example_indices()

    # The remaining entries are plain copies of the flag values.
    mirrored_flag_names = (
        'learning_rate',
        'retain_set_batch_size',
        'n_unlearning_batches',
        'epsilon_forget',
        'epsilon_retain',
        'n_trials_per_example',
        'n_random_evaluation_examples',
        'evaluation_examples_selection_seed',
        'evaluation_batch_size',
    )
    for flag_name in mirrored_flag_names:
        metadata[flag_name] = getattr(FLAGS, flag_name)

    return metadata


###############################################################################


def _check_valid_output_filepath():
    """Validate --output_filepath and return it with a leading ``~`` expanded.

    Explicit exceptions are raised instead of using ``assert`` so that the
    validation still runs when Python is started with ``-O``.

    Returns:
        The user-expanded output file path.

    Raises:
        ValueError: If the flag is unset or empty, or if the directory that
            would contain the output file does not exist.
    """
    output_filepath = FLAGS.output_filepath
    if output_filepath is None:
        raise ValueError('The --output_filepath flag must be provided.')

    output_filepath = os.path.expanduser(output_filepath)
    if not output_filepath:
        raise ValueError('Invalid --output_filepath.')

    # A bare file name (e.g. "results.h5") has an empty dirname, which means the
    # current working directory; os.path.isdir('') would wrongly reject it.
    parent_dir = os.path.dirname(output_filepath) or os.curdir
    if not os.path.isdir(parent_dir):
        raise ValueError('Invalid --output_filepath.')

    return output_filepath


def _read_n_examples_per_pef_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    if not flag_value:
        return None
    ret = []
    for n_examples in flag_value:
        n_examples = int(n_examples)
        if n_examples < 0:
            ret.append(None)
        else:
            ret.append(n_examples)
    return ret


def _get_unlearned_example_indices() -> List[int]:
    """Return the --unlearned_example_indices entries converted to ints, in flag order."""
    return list(map(int, FLAGS.unlearned_example_indices))


###############################################################################


def _read_in_model(tokenizer, device: torch.device) -> 'transformers.PreTrainedModel':
    """Load a fresh copy of the model given by --model_cls / --model onto `device`.

    Args:
        tokenizer: Tokenizer whose `pad_token_id` is used when the model config
            does not define one.
        device: Device the loaded model is moved to.

    Returns:
        The loaded model, switched to eval mode.
    """
    model_factory = getattr(transformers, FLAGS.model_cls)
    model = model_factory.from_pretrained(FLAGS.model)
    model = model.to(device)
    # TODO: See if we want this for during unlearning?
    model.eval()

    # Without a padding token the model raises:
    #   ValueError: Cannot handle batch sizes > 1 if no padding token is defined.
    config = model.config
    if config.pad_token_id is None:
        config.pad_token_id = tokenizer.pad_token_id

    return model


def _clean_up_memory():
    gc.collect()
    torch.cuda.empty_cache()


###############################################################################


def _unlearn_example(
    *,
    example_index: int,
    original_model: transformers.PreTrainedModel,
    tokenizer: transformers.PreTrainedTokenizer,
    all_examples: evaluation_contexts.Examples,
    device: torch.device,
):
    """Unlearn a single example from a freshly loaded copy of the model.

    A new model is loaded with `_read_in_model` and updated with Adam for
    --n_unlearning_batches steps. Each step uses the loss computed by
    `gradient_ascent.LastTokenKlOnlyGradientAscent` from the single forget
    example and one batch of retain examples.

    Args:
        example_index: Index within `all_examples` of the example to unlearn.
        original_model: Reference model handed to the unlearner.
        tokenizer: Tokenizer forwarded to `_read_in_model`.
        all_examples: Examples from which both the forget example and the
            retain batches are drawn.
        device: Device for the index tensor, the new model and the loss computation.

    Returns:
        A tuple `(unlearned_model, forget_examples, loss_infos)`, where
        `loss_infos` maps every key of the unlearner's loss info to a float32
        tensor holding one value per unlearning step.
    """
    forget_index = torch.tensor([example_index], dtype=torch.int64, device=device)
    forget_examples = all_examples.gather_examples(forget_index)
    forget_batch = forget_examples.examples

    # The retain set is every example, forget example included. That should hopefully be fine
    # because the forget example is heavily oversampled through the forget batch.
    # NOTE(review): cycle_dataloader presumably repeats the dataloader indefinitely -- confirm.
    retain_batches = hf_utils.cycle_dataloader(
        DataLoader(
            all_examples.as_torch_dataset(),
            batch_size=FLAGS.retain_set_batch_size,
            shuffle=True,
            drop_last=True,
        )
    )

    unlearned_model = _read_in_model(tokenizer, device)

    unlearner = gradient_ascent.LastTokenKlOnlyGradientAscent(
        model=unlearned_model,
        original_model=original_model,
        epsilon_forget=FLAGS.epsilon_forget,
        epsilon_retain=FLAGS.epsilon_retain,
    )

    optimizer = torch.optim.Adam(unlearned_model.parameters(), lr=FLAGS.learning_rate)

    n_steps = FLAGS.n_unlearning_batches
    progress_bar = tqdm(total=n_steps)
    history = {}
    for retain_batch in itertools.islice(retain_batches, n_steps):
        step_info = unlearner.compute_loss_info(
            forget_batch=forget_batch,
            retain_batch=retain_batch,
            device=device,
        )

        # One optimizer update per retain batch.
        step_info['loss'].backward()
        optimizer.step()
        optimizer.zero_grad()

        # Record plain Python floats so no tensors from the step are kept around.
        step_scalars = {
            name: float(value.detach().cpu().numpy())
            for name, value in step_info.items()
        }
        for name, scalar in step_scalars.items():
            history.setdefault(name, []).append(scalar)

        progress_bar.update(1)
        progress_bar.set_postfix(step_scalars)

    progress_bar.close()

    loss_infos = {}
    for name, scalars in history.items():
        loss_infos[name] = torch.tensor(scalars, dtype=torch.float32)

    # TODO: Maybe also return some info related to the unlearning procedure?
    return unlearned_model, forget_examples, loss_infos


@torch.no_grad()
def _evaluate_unlearned_model(
    *,
    original_model: transformers.PreTrainedModel,
    unlearned_model: transformers.PreTrainedModel,
    evaluation_examples: evaluation_contexts.Examples,
    device: torch.device,
):
    """Compare the unlearned model against the original on a set of examples.

    Both models are put into eval mode and the whole function runs with
    gradient tracking disabled.

    Args:
        original_model: The model before unlearning.
        unlearned_model: The model after unlearning.
        evaluation_examples: Examples to evaluate, processed in batches of
            --evaluation_batch_size.
        device: Device handed to the model evaluators.

    Returns:
        The `kls` of every batch's `ExamplesBatchPerturbationInfo`, concatenated
        along dim 0 and detached. (Presumably one KL value per example between
        the two models' outputs; the definition lives in `perturbation_contexts`.)
    """
    unlearned_model.eval()
    original_model.eval()

    after_evaluator = evaluation_contexts.ModelEvaluator(model=unlearned_model, device=device)
    before_evaluator = evaluation_contexts.ModelEvaluator(model=original_model, device=device)

    def _batch_kls(batch):
        # The original model is evaluated first, then the unlearned one.
        perturb_info = perturbation_contexts.ExamplesBatchPerturbationInfo(
            original_batch_info=before_evaluator.compute_batch_info(batch),
            perturbed_batch_info=after_evaluator.compute_batch_info(batch),
        )
        return perturb_info.kls

    batches = tqdm(evaluation_examples.get_batches(FLAGS.evaluation_batch_size))
    per_batch_kls = [_batch_kls(batch) for batch in batches]

    return torch.cat(per_batch_kls, dim=0).detach()


###############################################################################


def main(_):
    """Run gradient-ascent unlearning per requested example and save the KL results.

    For every index in --unlearned_example_indices and each of
    --n_trials_per_example trials, this:
      1. unlearns the example starting from a freshly loaded model,
      2. evaluates the unlearned model against the original model on the
         forgotten example and on a fixed random subset of the PEF examples,
      3. writes the results to the HDF5 file at --output_filepath under
         ``data/examples/<example_index>/trial/<trial_index>``.

    The run metadata (see `_make_metadata`) is stored as a JSON string in the
    ``metadata`` attribute of the ``data`` group.

    Args:
        _: Unused; present because `app.run` invokes `main` with an argument.
    """
    # Check this now so that we don't error out after doing all the work.
    output_filepath = _check_valid_output_filepath()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer or FLAGS.model)

    original_model = _read_in_model(tokenizer, device)

    n_examples_per_pef = _read_n_examples_per_pef_flag(FLAGS.n_examples_per_pef)
    pef_extra_infos = pef_format_common.PefExtraInfos.read_from_files(FLAGS.pef_filepaths, n_examples_per_pef)
    assert pef_extra_infos.parameter_infos is not None
    assert pef_extra_infos.examples is not None
    assert pef_extra_infos.n_examples is not None

    all_examples = evaluation_contexts.Examples(
        examples={k: torch.from_numpy(v) for k, v in pef_extra_infos.examples.items()},
        token_positions=torch.from_numpy(pef_extra_infos.token_positions) if pef_extra_infos.token_positions is not None else None,
    )
    all_examples = all_examples.to(device)

    unlearned_example_indices = _get_unlearned_example_indices()

    # The evaluation subset is drawn once from a seeded RNG, so every unlearning run below is
    # evaluated on the same examples.
    rng = random.Random(FLAGS.evaluation_examples_selection_seed)
    evaluation_example_indices = rng.sample(range(pef_extra_infos.n_examples), k=FLAGS.n_random_evaluation_examples)
    evaluation_example_indices = torch.tensor(evaluation_example_indices, dtype=torch.int64, device=device)
    evaluation_examples = all_examples.gather_examples(evaluation_example_indices)

    with h5py.File(os.path.expanduser(output_filepath), "w") as f:
        data_group = f.create_group('data')
        data_group.attrs['metadata'] = json.dumps(_make_metadata())

        for example_index in unlearned_example_indices:
            for trial_index in range(FLAGS.n_trials_per_example):
                unlearned_model, forget_examples, loss_infos = _unlearn_example(
                    example_index=example_index,
                    original_model=original_model,
                    tokenizer=tokenizer,
                    all_examples=all_examples,
                    device=device,
                )
                # The optimizer and other unlearning-only state went out of scope when
                # _unlearn_example returned, so release it before evaluation.
                _clean_up_memory()

                forget_example_kls = _evaluate_unlearned_model(
                    original_model=original_model,
                    unlearned_model=unlearned_model,
                    evaluation_examples=forget_examples,
                    device=device,
                )
                evaluation_example_kls = _evaluate_unlearned_model(
                    original_model=original_model,
                    unlearned_model=unlearned_model,
                    evaluation_examples=evaluation_examples,
                    device=device,
                )

                forget_set_kls = cda_results.EvaluationKls(
                    example_indices=torch.tensor([example_index], dtype=torch.int64, device=device),
                    kls=forget_example_kls,
                )
                evaluation_set_kls = cda_results.EvaluationKls(
                    example_indices=evaluation_example_indices,
                    kls=evaluation_example_kls,
                )

                run_results = cda_results.CdaRunResults(
                    forget_set_kls=forget_set_kls,
                    evaluation_set_kls=evaluation_set_kls,
                    unlearning_training_infos=loss_infos,
                )

                run_group = f.create_group(f'data/examples/{example_index}/trial/{trial_index}')
                run_results.save_to_group(run_group)

                # Drop this trial's model before cleaning up. Otherwise the name keeps it alive
                # until the next _unlearn_example call returns, so the old model would still
                # occupy memory while the next one is loaded and trained, and the cleanup below
                # could not reclaim it.
                del unlearned_model
                _clean_up_memory()


# Script entry point: hand control to absl's app.run, which invokes main().
if __name__ == "__main__":
    app.run(main)
