"""Compares saved predictions."""
import dataclasses
import functools
import os
from typing import List

from absl import app
from absl import flags

import h5py
import numpy as np
import torch

from npeff_torch.util import hdf5_utils

###############################################################################

FLAGS = flags.FLAGS


# Both flags default to None (i.e. not provided). Each takes a comma-separated
# list of HDF5 prediction files that together make up one set of predictions.
flags.DEFINE_list(
    'predictions_filepaths_1',
    None,
    'Comma-separated paths to the HDF5 prediction files forming the first '
    'set of predictions. Files are concatenated in the order given.',
)
flags.DEFINE_list(
    'predictions_filepaths_2',
    None,
    'Comma-separated paths to the HDF5 prediction files forming the second '
    'set of predictions. Files are concatenated in the order given.',
)


###############################################################################


@dataclasses.dataclass
class Info:
    """Logits and labels belonging to one set of saved predictions.

    Attributes:
        logits: Tensor of logits; its leading dimension indexes examples and
            its last dimension is reduced by argmax to obtain predictions.
        labels: Tensor of labels stored alongside the logits.
    """

    logits: torch.Tensor
    labels: torch.Tensor

    @property
    def n_examples(self) -> int:
        """Size of the leading dimension of `logits`."""
        return self.logits.size(0)

    @functools.cached_property
    def predictions(self) -> torch.Tensor:
        """Argmax of `logits` over the last dimension (computed once, then cached)."""
        return self.logits.argmax(dim=-1)


def _read_predictions(filepaths: List[str]) -> 'Info':
    logits = []
    labels = []
    for filepath in filepaths:
        with h5py.File(os.path.expanduser(filepath), "r") as f:
            logits.append(torch.from_numpy(hdf5_utils.load_h5_ds(f['data/logits'])))
            labels.append(torch.from_numpy(hdf5_utils.load_h5_ds(f['data/labels'])))

    return Info(
        logits=torch.cat(logits, dim=0),
        labels=torch.cat(labels, dim=0),
    )


###############################################################################


@torch.no_grad()
def main(_):
    """Prints the accuracy of each prediction set and their agreement rate.

    Reads the two prediction sets named by the `predictions_filepaths_1` and
    `predictions_filepaths_2` flags, checks that their labels are identical,
    and prints three numbers: the accuracy of each set against its labels and
    the fraction of examples on which the two sets predict the same class.

    Args:
        _: Unused positional command-line arguments passed by `app.run`.

    Raises:
        app.UsageError: If either filepaths flag was not provided.
        ValueError: If the labels of the two sets differ in shape or value.
    """
    if not FLAGS.predictions_filepaths_1 or not FLAGS.predictions_filepaths_2:
        raise app.UsageError(
            'Both --predictions_filepaths_1 and --predictions_filepaths_2 '
            'must be provided.'
        )

    info1 = _read_predictions(FLAGS.predictions_filepaths_1)
    info2 = _read_predictions(FLAGS.predictions_filepaths_2)

    # Validate explicitly instead of with `assert`, which is stripped when
    # Python runs with -O. The shape check comes first so that labels of
    # different (but broadcastable) shapes cannot compare equal by accident.
    if info1.labels.shape != info2.labels.shape:
        raise ValueError(
            'Label shapes differ between the two prediction sets: '
            f'{tuple(info1.labels.shape)} vs {tuple(info2.labels.shape)}.'
        )
    if not (info1.labels == info2.labels).all():
        raise ValueError('Labels differ between the two prediction sets.')

    def _match_rate(a, b):
        # Fraction of positions at which `a` and `b` are equal.
        return (a == b).to(torch.float32).mean().detach().cpu().numpy()

    acc1 = _match_rate(info1.predictions, info1.labels)
    acc2 = _match_rate(info2.predictions, info2.labels)
    similarity = _match_rate(info1.predictions, info2.predictions)

    print(f'acc1: {acc1}')
    print(f'acc2: {acc2}')
    print(f'similarity: {similarity}')


# Script entry point: hand `main` to absl's `app.run`, which parses the
# command-line flags before invoking it.
if __name__ == "__main__":
    app.run(main)
