""""""
import os
from typing import List, Optional

from absl import app
from absl import flags

import h5py
import torch

from npeff_torch.decomps.npeff import lrm_npeff_decomps
from npeff_torch.examination.top_examples import component_filtering
from npeff_torch.examination.top_examples import top_examples_from_coeffs
from npeff_torch.peis.fishers.formats import pef_format_common

from npeff_torch.util import hdf5_utils

###############################################################################

FLAGS = flags.FLAGS


flags.DEFINE_string('output_filepath', None,
                    'Path of the HDF5 file to write. The pseudo-logits are saved under '
                    '"data/logits" and the labels under "data/labels". Required.')

flags.DEFINE_string('npeff_filepath', None,
                    'Path to the saved LRM-NPEFF decomposition whose coefficients are evaluated.')


flags.DEFINE_list('pef_filepaths', None,
                  'Comma-separated list of the PEF files used to compute the NPEFF coefficients. '
                  'They MUST be listed in the same order as was used for that computation.')

flags.DEFINE_list('n_examples_per_pef', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")


flags.DEFINE_integer('n_top_examples', None,
                     'Number of top examples of each component that are given to the label filter '
                     'when deciding whether that component is tuned to a label.')
flags.DEFINE_float('tuning_fraction', None,
                   'Value of fraction_threshold for component_filtering.SpecificLabelFilter, which '
                   'decides from its top examples whether a component is tuned to a given label.')


###############################################################################


def _read_n_examples_per_pef_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    if not flag_value:
        return None
    ret = []
    for n_examples in flag_value:
        n_examples = int(n_examples)
        if n_examples < 0:
            ret.append(None)
        else:
            ret.append(n_examples)
    return ret

###############################################################################


def _get_tuned_component_indices(
    top_examples_reader: 'top_examples_from_coeffs.TopExamplesReaderFromCoeffs',
) -> List[List[int]]:
    """Find, for every label, the components that are "tuned" to that label.

    A component is tuned to ``label`` when
    ``component_filtering.SpecificLabelFilter(label=label,
    fraction_threshold=FLAGS.tuning_fraction)`` accepts the component's top
    ``FLAGS.n_top_examples`` examples.

    Args:
        top_examples_reader: Reader built from the NPEFF coefficients and the
            PEF example metadata.

    Returns:
        One list per label (``top_examples_reader.logits.shape[-1]`` lists).
        Each holds the ascending indices of the components tuned to that
        label. The labels are filtered independently, so a component may appear
        under several labels or under none.
    """
    n_components = top_examples_reader.n_components
    n_classes = top_examples_reader.logits.shape[-1]

    # A component's top examples do not depend on the label being tested, so
    # fetch them once per component rather than once per (label, component) pair.
    top_examples_by_component = [
        top_examples_reader.get_top_examples_for_component(component_index, n_top_examples=FLAGS.n_top_examples)
        for component_index in range(n_components)
    ]

    ret = []
    for label in range(n_classes):
        label_filter = component_filtering.SpecificLabelFilter(label=label, fraction_threshold=FLAGS.tuning_fraction)
        label_components = [
            component_index
            for component_index, top_examples in enumerate(top_examples_by_component)
            if label_filter.does_component_pass(
                # NOTE: The tokenizer is presumably not needed by this filter in
                # general, so it is not passed here.
                tokenizer=None,
                component_index=component_index,
                top_examples=top_examples,
            )
        ]
        ret.append(label_components)

    return ret


def _make_weight(
    tuned_component_indices_by_label: List[List[int]],
    n_components: int,
    device: torch.device,
    *,
    value: float = 1.0,
) -> torch.Tensor:
    n_classes = len(tuned_component_indices_by_label)

    weight = torch.zeros([n_classes, n_components], dtype=torch.float32)
    for label, component_indices in enumerate(tuned_component_indices_by_label):
        for component_index in component_indices:
            weight[label, component_index] = value

    return weight.to(device)


###############################################################################


@torch.no_grad()
def main(_):
    """Evaluate a label classifier built from label-tuned NPEFF components.

    Steps:
      1. Load the LRM-NPEFF decomposition from --npeff_filepath, move it to the
         available device and call ``normalize_reduced_components_to_unit_norm_``.
      2. Read the PEF example metadata (labels, logits, ...) for the examples
         of the decomposition.
      3. For each label, select the components tuned to it.
      4. Give each example a pseudo-logit per label: the sum of its coefficients
         (rows of ``decomp.W``) over that label's tuned components. Predict with
         argmax and print the accuracy against the PEF labels.
      5. Write the pseudo-logits and labels to --output_filepath (HDF5).

    Raises:
        app.UsageError: If --output_filepath is not set.
        ValueError: If the decomposition and the PEF files disagree on the
            number of examples.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if FLAGS.output_filepath is None:
        raise app.UsageError('--output_filepath must be provided.')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    decomp = lrm_npeff_decomps.LrmNpeffDecomposition.load(FLAGS.npeff_filepath)
    decomp = decomp.to(device)
    decomp.normalize_reduced_components_to_unit_norm_()

    n_components = decomp.W.shape[1]

    n_examples_per_pef = _read_n_examples_per_pef_flag(FLAGS.n_examples_per_pef)
    pef_extra_infos = pef_format_common.PefExtraInfos.read_from_files(FLAGS.pef_filepaths, n_examples_per_pef)
    if decomp.n_examples != pef_extra_infos.n_examples:
        raise ValueError(
            f'The NPEFF decomposition has {decomp.n_examples} examples but the PEF files '
            f'provide {pef_extra_infos.n_examples}; check --pef_filepaths and --n_examples_per_pef.'
        )

    top_examples_reader = top_examples_from_coeffs.TopExamplesReaderFromCoeffs.create(
        coefficients=decomp.W.detach().cpu().numpy(),
        examples=pef_extra_infos.examples,
        labels=pef_extra_infos.labels,
        logits=pef_extra_infos.logits,
        top_log_probs_class_indices=pef_extra_infos.top_log_probs_class_indices,
        top_log_probs_values=pef_extra_infos.top_log_probs_values,
        token_positions=pef_extra_infos.token_positions,
    )

    # Build the (label x component) weight matrix from the label-tuned components.
    tuned_component_indices_by_label = _get_tuned_component_indices(top_examples_reader)

    weight = _make_weight(
        tuned_component_indices_by_label=tuned_component_indices_by_label,
        n_components=n_components,
        device=device,
    )

    # pseudo_logits[e, l] = sum_c W[e, c] * weight[l, c]; argmax over the last
    # dimension gives the predictions.
    pseudo_logits = torch.einsum('ec,lc->el', decomp.W, weight)

    predictions = torch.argmax(pseudo_logits, dim=-1).detach().cpu().numpy()
    acc = (predictions == pef_extra_infos.labels).mean()
    print(f'acc: {acc}')

    with h5py.File(os.path.expanduser(FLAGS.output_filepath), "w") as f:
        hdf5_utils.save_h5_ds(f, 'data/logits', pseudo_logits.detach().cpu().numpy())
        hdf5_utils.save_h5_ds(f, 'data/labels', pef_extra_infos.labels)


if __name__ == "__main__":
    # Script entry point: absl's app.run parses the flags defined above and then
    # invokes main.
    app.run(main)
