"""Meant for the SAE baseline of the paper."""
import collections
from typing import Any, FrozenSet, List

from absl import app
from absl import flags

import numpy as np
import torch

from scom.examination.coefficients import sparse_fixed_k_top_example_infos
from scom.examination.coefficients import top_example_infos

###############################################################################

FLAGS = flags.FLAGS

# Command-line configuration. The help strings are what `--help` prints, so
# each one states how the script actually uses the flag.
flags.DEFINE_string(
    'top_examples_info_filepath', None,
    'Path passed to SparseFixedKTopExamplesInfo.load; the top-examples info '
    'to analyze.')

flags.DEFINE_integer(
    'n_top_examples', None,
    'Number of top examples to request per component. Defaults to the number '
    'stored in the loaded info and must not exceed it.')
flags.DEFINE_float(
    'tuning_fraction', None,
    'A component counts as tuned to an attribute when the most common value '
    'of that attribute covers at least this fraction of its top examples.')

flags.DEFINE_integer(
    'min_top_examples', 1,
    'Components with fewer top examples than this are skipped.')

###############################################################################


# Developer notes: the input is loaded as a
# sparse_fixed_k_top_example_infos.SparseFixedKTopExamplesInfo, and each top
# example's label is read from example['labels'].

def _get_tokens(top_example: 'top_example_infos.TopExampleInfo') -> FrozenSet[int]:
    """Collect the distinct token ids at positions with a non-zero coefficient.

    Args:
        top_example: Record exposing `example['input_ids']` and
            `coefficients`. Assumed to be index-aligned arrays/tensors that
            support boolean-mask indexing -- TODO confirm against the
            TopExampleInfo definition.

    Returns:
        The selected token ids as plain ints, in a frozenset so the result is
        hashable and order-independent.
    """
    input_ids = top_example.example['input_ids']
    nonzero_mask = top_example.coefficients != 0.0
    return frozenset(map(int, input_ids[nonzero_mask]))


def _is_tuned(x: List[Any]) -> bool:
    """Tell whether one value dominates `x`.

    Args:
        x: Non-empty list of hashable values (one per top example).

    Returns:
        True when the most frequent value occurs at least
        `FLAGS.tuning_fraction * len(x)` times.

    Raises:
        ValueError: If `x` is empty (there is no most frequent value).
    """
    occurrences = collections.Counter(x)
    top_count = max(occurrences.values())
    required = FLAGS.tuning_fraction * len(x)
    return bool(top_count >= required)


def _determine_tunings(
    top_examples: List['top_example_infos.TopExampleInfo']
):
    """Evaluate, per attribute, whether a component's top examples are tuned.

    Args:
        top_examples: Non-empty list of top examples for one component.
            Whether logits are available is decided from the first entry only.

    Returns:
        Dict mapping 'label' and 'token' (and 'predictions' when the first
        example carries logits) to the boolean result of `_is_tuned`.
    """
    # Only the first example is inspected to decide if predictions exist.
    with_logits = top_examples[0].logits is not None

    label_values = [int(info.example['labels']) for info in top_examples]

    predicted_classes = (
        [int(np.argmax(info.logits)) for info in top_examples]
        if with_logits
        else None
    )

    token_sets = [_get_tokens(info) for info in top_examples]

    tunings = {}
    tunings['label'] = _is_tuned(label_values)
    tunings['token'] = _is_tuned(token_sets)
    if predicted_classes is not None:
        tunings['predictions'] = _is_tuned(predicted_classes)
    return tunings


###############################################################################


@torch.no_grad()
def main(_):
    """Print, per attribute, how many components are tuned.

    Loads the top-examples info from --top_examples_info_filepath, evaluates
    every component that has enough top examples, and prints one
    `<attribute>: <tuned> / <evaluated>` line per attribute in sorted order.

    Raises:
        app.UsageError: If --tuning_fraction was not provided.
        ValueError: If --n_top_examples exceeds the number of top examples
            available in the loaded info.
    """
    # Fail fast with a readable message; otherwise the missing flag only
    # surfaces later as a TypeError when the threshold is computed.
    if FLAGS.tuning_fraction is None:
        raise app.UsageError('--tuning_fraction must be specified.')

    # Named `infos` so the local does not shadow the imported
    # `top_example_infos` module.
    infos = sparse_fixed_k_top_example_infos.SparseFixedKTopExamplesInfo.load(
        FLAGS.top_examples_info_filepath)

    n_top_examples = FLAGS.n_top_examples
    if n_top_examples is None:
        n_top_examples = infos.n_top_examples
    elif n_top_examples > infos.n_top_examples:
        # Explicit raise rather than assert: asserts vanish under `python -O`.
        raise ValueError(
            f'--n_top_examples={n_top_examples} exceeds the '
            f'{infos.n_top_examples} top examples available in the loaded info.')

    # A component with no top examples cannot be evaluated, so at least one
    # is always required regardless of the flag value.
    min_required = max(FLAGS.min_top_examples, 1)

    tuning_indicators = collections.defaultdict(list)

    for component_index in infos.unique_component_indices:
        top_examples = infos.get_top_examples_for_component(component_index, n_top_examples)
        if len(top_examples) < min_required:
            continue

        for key, is_tuned in _determine_tunings(top_examples).items():
            tuning_indicators[key].append(is_tuned)

    for key in sorted(tuning_indicators):
        indicators = tuning_indicators[key]
        # Summing booleans counts the tuned components.
        print(f'{key}: {sum(indicators)} / {len(indicators)}')


if __name__ == "__main__":
    app.run(main)
