"""Wrapper around discrete archives."""
import hydra
import torch
from scipy.spatial import KDTree
import numpy as np


class DiscreteArchiveSolutionModel:
    """Solution model backed by a discrete archive.

    Given target measures, :meth:`inference` returns the stored solution whose
    measures are nearest (Euclidean distance, via a KD-tree) to each target.
    The model holds no trainable parameters.
    """

    def __init__(self, cfg, seed, device, archive=None):
        """Initializes the model.

        Args:
            cfg: Config object. ``cfg.model`` is instantiated with hydra when
                ``archive`` is not given, and
                ``cfg.normalize_features_before_dist`` controls whether
                measures are divided by the archive's ``interval_size`` before
                distances are computed.
            seed: Seed for ``self.rng`` and for the hydra-instantiated archive.
            device: Torch device on which :meth:`inference` returns solutions.
            archive: Optional pre-built archive used instead of instantiating
                ``cfg.model``.
        """
        self.cfg = cfg
        self.rng = np.random.default_rng(seed)
        self.device = device

        if archive is None:
            self.model = hydra.utils.instantiate(self.cfg.model, seed=seed)
        else:
            # `archive` can be passed in to override the archive creation.
            self.model = archive

    def num_params(self):
        """Returns the number of trainable parameters, which is always 0."""
        return 0

    def _scale_measures(self, measures):
        """Maps measures into the space in which nearest neighbours are found.

        When ``cfg.normalize_features_before_dist`` is set, measures are
        divided by the archive's ``interval_size`` (assumed to broadcast
        against the measure array); otherwise they are returned unchanged.
        Both the tree and the queries go through this helper so they always
        live in the same space.
        """
        if self.cfg.normalize_features_before_dist:
            return measures / self.model.interval_size
        return measures

    def make_kd_tree(self):
        """Creates a kdtree with data currently in the archive.

        Also returns the archive data so as to facilitate indexing it.

        Returns:
            Tuple ``(kdtree, archive_data)``. ``kdtree`` is built over the
            (possibly normalized) ``archive_data["measures"]``, so tree index
            ``i`` refers to row ``i`` of that array.
        """
        archive_data = self.model.data()
        kdtree = KDTree(self._scale_measures(archive_data["measures"]))
        return kdtree, archive_data

    def inference(self, inputs, samples=None):
        """Retrieves the solution with the closest measures to each input.

        Args:
            inputs: Tensor of target measures, presumably shaped
                ``(batch, measure_dim)`` -- TODO confirm against callers.
            samples: If given, each input row is repeated ``samples`` times
                (``torch.repeat_interleave`` along dim 0) before lookup.

        Returns:
            float32 tensor on ``self.device`` holding one archive solution per
            (repeated) input row. Assumes ``archive_data["solution"]`` is
            row-aligned with ``archive_data["measures"]``.

        Raises:
            ValueError: If the archive holds no solutions.
        """
        if samples is None:
            measures = inputs
        else:
            measures = torch.repeat_interleave(inputs, samples, dim=0)

        # KDTree works on NumPy arrays, so leave autograd and the device first.
        measures = measures.detach().cpu().numpy()

        kdtree, archive_data = self.make_kd_tree()
        solutions = archive_data["solution"]
        if len(solutions) == 0:
            # An empty tree answers every query with the out-of-range index
            # `n`, which would otherwise surface as an obscure IndexError.
            raise ValueError(
                "Cannot retrieve solutions: the archive is empty.")

        _, indices = kdtree.query(self._scale_measures(measures))

        return torch.tensor(solutions[indices],
                            dtype=torch.float32,
                            device=self.device)
