#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include "evd/Ciphertext.hpp"
#include "evd/Client.hpp"
#include "evd/Const.hpp"
#include "evd/MLWECiphertext.hpp"
#include "evd/Message.hpp"
#include "evd/SecretKey.hpp"
#include "evd/Server.hpp"
#include "evd/SwitchingKey.hpp"

using namespace evd;

// log2 of the rank; this value is handed to the Client and Server constructors.
constexpr u64 LOG_RANK = 7;
// Rank itself, 2^LOG_RANK (= 128). main() sizes key/query Messages and the
// per-rank key vectors with it.
constexpr u64 RANK = static_cast<u64>(1) << LOG_RANK;
// DEGREE / RANK. Not referenced anywhere else in the visible part of this file.
constexpr u64 STACK = DEGREE / RANK;
// Number of base vectors the benchmark encrypts and scores against.
constexpr u64 N = 1000 * 1000;
// Number of DEGREE-sized batches needed to cover N vectors, i.e. ceil(N / DEGREE).
constexpr u64 ITER = N / DEGREE + (N % DEGREE != 0 ? 1 : 0);

// Dense float matrix as returned by load_fvecs(): `n` vectors of `d`
// components each, flattened one vector after another into `data`
// (so vector i occupies data[i * d] .. data[i * d + d - 1]).
struct FVecs {
  uint32_t n{0};           // number of vectors
  uint32_t d{0};           // components per vector
  std::vector<float> data; // n * d values, vector-major
};

// Integer counterpart of FVecs: a count `n`, a per-vector length `d` and a
// flat `int` payload. Not referenced elsewhere in the visible part of this
// file; presumably laid out like FVecs (n * d values) -- confirm before use.
struct IVecs {
  uint32_t n{0};         // number of vectors
  uint32_t d{0};         // components per vector
  std::vector<int> data; // flattened payload
};

// Reads a whole .fvecs file into memory.
//
// Layout expected on disk: a sequence of records, each made of a 4-byte
// dimension header followed by that many 4-byte floats. The dimension of the
// first record defines `d`; the record count `n` is derived from the file size.
//
// @param path  File to read.
// @return      FVecs with n, d and the n * d floats (headers stripped).
// @throws std::runtime_error if the file cannot be opened or read, if its size
//         is not a multiple of the record size, if a record header disagrees
//         with the first one, or if the record count does not fit in uint32_t.
FVecs load_fvecs(const std::string &path) {
  static_assert(sizeof(float) == 4, "fvecs stores 32-bit floats");

  // Open positioned at the end so tellg() yields the file size.
  std::ifstream f(path, std::ios::binary | std::ios::ate);
  if (!f)
    throw std::runtime_error("File open fail: " + path);

  const std::streamoff end_pos = f.tellg();
  if (end_pos < 0) // tellg() reports failure as -1; do not treat it as a size
    throw std::runtime_error("File read fail");
  const size_t fsize = static_cast<size_t>(end_pos);
  f.seekg(0, std::ios::beg);

  // The first record's header fixes the dimension for the whole file.
  uint32_t d = 0;
  f.read(reinterpret_cast<char *>(&d), 4);
  if (!f)
    throw std::runtime_error("File read fail");

  const size_t bytes_per_vec = 4 + size_t(d) * 4;
  if (fsize % bytes_per_vec != 0)
    throw std::runtime_error("Wrong file size");

  const size_t count = fsize / bytes_per_vec;
  if (count > UINT32_MAX) // FVecs::n is 32-bit; refuse to truncate silently
    throw std::runtime_error("Too many vectors: " + path);

  // Read each record straight into its final place instead of buffering the
  // whole file first; this halves peak memory for large inputs.
  std::vector<float> dst(count * d);
  f.seekg(0, std::ios::beg);
  for (size_t i = 0; i < count; ++i) {
    uint32_t header = 0;
    f.read(reinterpret_cast<char *>(&header), 4);
    if (!f)
      throw std::runtime_error("File read fail");
    if (header != d) // flat n * d layout only makes sense with one dimension
      throw std::runtime_error("Inconsistent vector dimension: " + path);

    f.read(reinterpret_cast<char *>(dst.data() + i * d),
           static_cast<std::streamsize>(size_t(d) * 4));
    if (!f)
      throw std::runtime_error("File read fail");
  }
  return {static_cast<uint32_t>(count), d, std::move(dst)};
}

// Writes the first `n` rows of `data` as an .ivecs file: for every row a
// 4-byte header holding `d`, followed by the row's first `d` ints.
//
// @param path  Output file; truncated if it already exists.
// @param n     Number of rows to write.
// @param d     Number of ints written per row (also the record header).
// @param data  Source rows; must contain at least `n` rows of at least `d`
//              ints each.
// @throws std::runtime_error if `data` is smaller than n x d, or if the file
//         cannot be opened or written.
void save_ivecs(const std::string &path, uint32_t n, uint32_t d,
                const std::vector<std::vector<int>> &data) {
  static_assert(sizeof(int) == 4, "ivecs stores 32-bit integers");

  // Validate before touching the file so a bad call neither reads past the
  // end of a row nor leaves a truncated output behind.
  if (data.size() < n)
    throw std::runtime_error("Too few rows to write: " + path);
  for (uint32_t i = 0; i < n; ++i) {
    if (data[i].size() < d)
      throw std::runtime_error("Row shorter than dimension: " + path);
  }

  std::ofstream f(path, std::ios::binary | std::ios::trunc);
  if (!f)
    throw std::runtime_error("File open fail: " + path);

  const std::streamsize payload_bytes =
      static_cast<std::streamsize>(size_t(d) * 4);
  for (uint32_t i = 0; i < n; ++i) {
    f.write(reinterpret_cast<const char *>(&d), 4); // header
    f.write(reinterpret_cast<const char *>(data[i].data()),
            payload_bytes); // payload
  }
  f.close(); // flush; a failed write or close leaves the stream in a failed state
  if (!f)
    throw std::runtime_error("File write fail: " + path);
}

int main(int argc, char *argv[]) {
  Client client(LOG_RANK);

  // Gen HE Keys
  SecretKey secKey;
  SwitchingKey relinKey;
  std::vector<std::vector<SwitchingKey>> autedModPackKeys(RANK);
  std::vector<std::vector<MLWESwitchingKey>> autedModPackMLWEKeys(RANK);

  client.genSecKey(secKey);
  client.genRelinKey(relinKey, secKey);
  for (u64 i = 0; i < RANK; ++i) {
    client.genAutedModPackKeys(autedModPackKeys[i], secKey, 2 * i + 1);
    client.genInvAutedModPackKeys(autedModPackMLWEKeys[i], secKey, 2 * i + 1);
  }
  const double scale = std::pow(2.0, LOG_SCALE);

  FVecs B = load_fvecs(argv[1]);

  Server server(LOG_RANK, relinKey, autedModPackKeys, autedModPackMLWEKeys);

  std::vector<std::vector<Ciphertext>> keyCaches(ITER);
#pragma omp parallel for
  for (u64 i = 0; i < ITER; ++i) {
    std::vector<MLWECiphertext> keys;
    keys.reserve(DEGREE);

    for (u64 j = 0; j < DEGREE; ++j) {
      Message msg(RANK);
      if (i * DEGREE + j < N) {
        for (u64 k = 0; k < B.d; ++k)
          msg[k] = B.data[(i * DEGREE + j) * B.d + k];
      }
      keys.emplace_back(RANK);
      client.encryptKey(keys[j], msg, secKey, scale);
    }
    keyCaches[i].resize(RANK);
    server.cacheKeys(keyCaches[i], keys);
  }

  // Evaluate Inner Product

  FVecs Q = load_fvecs(argv[2]);

  FVecs scores = load_fvecs(argv[3]);

  double max_error = 0.0;
  double mean_error = 0.0;
  double std_error = 0.0;

  u64 n = 10000; // 10K
  std::vector<std::vector<int>> topK(n);
  for (u64 i = 0; i < n; ++i) {
    std::cout << i << std::endl;
    Message msg(RANK);
    MLWECiphertext query(RANK);
    std::vector<Ciphertext> queryCache(RANK);
    std::vector<Ciphertext> res(ITER);

    for (u64 j = 0; j < Q.d; ++j)
      msg[j] = Q.data[i * Q.d + j];
    client.encryptQuery(query, msg, secKey, scale);

    server.cacheQuery(queryCache, query);
    for (u64 j = 0; j < ITER; ++j)
      server.innerProduct(res[j], queryCache, keyCaches[j]);

    const double doubleScale = std::pow(2.0, 2 * LOG_SCALE);

    // Decrypt
    std::vector<Message> dmsg;
    dmsg.reserve(ITER);
    for (u64 j = 0; j < ITER; ++j)
      dmsg.emplace_back(DEGREE);
    client.decryptScore(dmsg, res, secKey, doubleScale);
    topK[i].resize(10);
    client.topKScore(topK[i], dmsg, 10);
    for (u64 j = 0; j < ITER; ++j) {
      for (u64 k = 0; k < DEGREE; ++k) {
        if (j * DEGREE + k >= N)
          break;
        double error =
            std::abs(scores.data[i * scores.d + j * DEGREE + k] - dmsg[j][k]);
        max_error = std::max(max_error, error);
        mean_error += error / N;
        std_error += error * error / N;
      }
    }
  }
  save_ivecs(argv[4], n, 10, topK);
  std::cout << "max error: " << max_error << std::endl;
  std::cout << "mean error: " << mean_error / Q.n << std::endl;
  std::cout << "std error: " << std_error / Q.n << std::endl;
}