# coding=utf-8
# Copyright 2023 The Soar Neurips2023 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Computes KMR of various VQ setups."""
import os

from absl import app
from absl import flags
import h5py
import matplotlib.pyplot as plt
import numpy as np
import utils

# Command-line flags. In this file they are read only by calc_kmr().
# --hdf5 has no default (None), so it must be supplied before calc_kmr() runs.
_HDF5 = flags.DEFINE_string("hdf5", None, "Path to hdf5 of dataset.")
# Number of centers passed to utils.train_kmeans.
_NUM_CENTERS = flags.DEFINE_integer("num_centers", 2000, "# k-means centers.")
# Passed as the last argument of utils.compute_avq_centers.
_ETA = flags.DEFINE_float("eta", 2.5, "AVQ eta.")
# Passed to utils.soar_assign.
_LAMBDA = flags.DEFINE_float("lambda", 1, "SOAR lambda.")


def calc_kmr():
  """Computes the KMR curves for three assignment setups and saves them.

  Reads the "train", "test" and "neighbors" datasets from the HDF5 file named
  by --hdf5, trains k-means with --num_centers centers, recomputes the centers
  with utils.compute_avq_centers (--eta) and derives a SOAR assignment with
  utils.soar_assign (--lambda). Three arrays are then written with np.save:

    npys/kmr1.npy: primary assignment only.
    npys/kmr2.npy: primary plus the second assignment from
      utils.redo_assignment.
    npys/kmr3.npy: primary plus the SOAR assignment.

  Raises:
    ValueError: If --hdf5 was not provided.
  """
  if _HDF5.value is None:
    # Fail with a clear message instead of an opaque error from h5py.
    raise ValueError("--hdf5 must be set to the path of the dataset file.")

  # Create the output directory up front so a missing directory cannot make
  # np.save fail only after all of the expensive work below has finished.
  os.makedirs("npys", exist_ok=True)

  # `[()]` materializes each dataset as an in-memory array, so the file can be
  # closed as soon as the three reads are done.
  with h5py.File(_HDF5.value, "r") as hdf5:
    ds = utils.normalize(hdf5["train"][()])
    qs = utils.normalize(hdf5["test"][()])
    gt = hdf5["neighbors"][()]
  print("Dataset shape:", ds.shape)
  print("Query shape:", qs.shape)
  print("Ground truth shape:", gt.shape, flush=True)

  eta = _ETA.value
  soar_l = _LAMBDA.value
  num_centers = _NUM_CENTERS.value

  # The second value returned by train_kmeans is discarded: the assignment is
  # recomputed by redo_assignment immediately below.
  orig_centers, _ = utils.train_kmeans(ds, num_centers)
  print("Performing no-SOAR spilled assignment...", flush=True)
  toke, toke2 = utils.redo_assignment(orig_centers, ds)
  print("# empty partitions:", num_centers - len(set(toke)))
  print("Updating centers...", flush=True)
  centers = utils.compute_avq_centers(ds, orig_centers, toke, eta)
  print("SOAR...", flush=True)
  toke3 = utils.soar_assign(ds, centers, toke, soar_l, True)

  print("KMR...", flush=True)
  np.save("npys/kmr1.npy", utils.kmr(centers, toke, None, qs, gt))
  np.save("npys/kmr2.npy", utils.kmr(centers, toke, toke2, qs, gt))
  np.save("npys/kmr3.npy", utils.kmr(centers, toke, toke3, qs, gt))


def main(argv):
  """Reports and plots the KMR curves previously saved under npys/.

  For each recall target, prints the target (as a percentage), the first index
  at which each curve exceeds it, and the ratio of the "No spilling" index to
  the "SOAR" index. Then plots the three curves and writes out/kmr.pdf.

  Args:
    argv: Positional command-line arguments handed over by absl; ignored.
  """
  del argv  # Unused.
  # The npys/kmr*.npy inputs are written by calc_kmr(), which is not invoked
  # here (the call is disabled); run it beforehand to (re)generate them.

  # (saved curve path, legend label) for every setup, in plotting order.
  series = (
      ("npys/kmr1.npy", "No spilling"),
      ("npys/kmr2.npy", "Spilling, no SOAR"),
      ("npys/kmr3.npy", "SOAR"),
  )
  curves = [np.load(path) for path, _ in series]

  for target in (0.80, 0.85, 0.90, 0.95):
    # NOTE: argmax also yields 0 when a curve never exceeds the target.
    first_hits = [(curve > target).argmax() for curve in curves]
    print(100 * target, first_hits, first_hits[0] / first_hits[2])

  plt.figure(figsize=(3.6, 2.5))
  for (_, label), curve in zip(series, curves):
    plt.plot(100 * curve, label=label)
  plt.legend()

  axes = plt.gca()
  axes.yaxis.set_ticks(np.linspace(85, 100, 4))
  axes.grid(linestyle="--", axis="y")
  axes.set_xlim([0, 335000])
  axes.set_ylim([85, 100])
  axes.set_xlabel("Datapoints Searched")
  axes.set_ylabel("Recall@100 (Percent)")
  plt.savefig("out/kmr.pdf", bbox_inches="tight", pad_inches=0.02)
  plt.close()


# Script entry point: hand control to absl, which calls main().
if __name__ == "__main__":
  app.run(main)
