#include <asio.hpp>
#include <asio/write.hpp>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

#include "evd/Ciphertext.hpp"
#include "evd/Client.hpp"
#include "evd/Const.hpp"
#include "evd/MLWECiphertext.hpp"
#include "evd/Message.hpp"
#include "evd/SecretKey.hpp"
#include "evd/SwitchingKey.hpp"

using namespace evd;
using asio::ip::tcp;

constexpr u64 LOG_RANK = 7;
constexpr u64 RANK = 1ULL << LOG_RANK;
constexpr u64 STACK = DEGREE / RANK;
constexpr u64 N = 1000000;
constexpr u64 ITER = (N + DEGREE - 1) / DEGREE;

/// In-memory contents of an .fvecs file: `n` vectors of `d` floats each,
/// stored contiguously so that vector i occupies data[i * d, i * d + d).
struct FVecs {
  uint32_t n = 0;
  uint32_t d = 0;
  std::vector<float> data;
};

/// Loads every vector stored in the .fvecs file at `path`.
///
/// The file is parsed as a sequence of records, each consisting of a 4-byte
/// dimension header followed by that many 4-byte floats. All records must
/// carry the same dimension as the first one.
///
/// @param path  Path of the file to read.
/// @return      The vector count, the per-vector dimension, and the floats.
/// @throws std::runtime_error if the file cannot be opened or read, if its
///         size is not a whole number of records, if it holds more records
///         than fit in a uint32_t, or if a record's dimension header differs
///         from the first record's.
FVecs load_fvecs(const std::string &path) {
  // Open positioned at the end so tellg() yields the file size.
  std::ifstream f(path, std::ios::binary | std::ios::ate);
  if (!f)
    throw std::runtime_error("File open fail: " + path);

  const std::streamoff end_pos = f.tellg();
  if (end_pos < 0)
    throw std::runtime_error("File read fail");
  const size_t fsize = static_cast<size_t>(end_pos);
  f.seekg(0, std::ios::beg);

  // The first header fixes the dimension every record must have.
  uint32_t d = 0;
  f.read(reinterpret_cast<char *>(&d), sizeof(d));
  if (!f)
    throw std::runtime_error("File read fail");

  const size_t payload_bytes = size_t(d) * sizeof(float);
  const size_t bytes_per_vec = sizeof(uint32_t) + payload_bytes;
  if (fsize % bytes_per_vec != 0)
    throw std::runtime_error("Wrong file size");

  const size_t count = fsize / bytes_per_vec;
  if (count > std::numeric_limits<uint32_t>::max())
    throw std::runtime_error("Too many vectors");

  // Read record by record straight into the destination: this validates each
  // header and avoids holding a second full copy of the file in memory.
  std::vector<float> dst(count * d);
  f.seekg(0, std::ios::beg);
  for (size_t i = 0; i < count; ++i) {
    uint32_t dim = 0;
    f.read(reinterpret_cast<char *>(&dim), sizeof(dim));
    if (!f)
      throw std::runtime_error("File read fail");
    if (dim != d)
      throw std::runtime_error("Inconsistent vector dimension");
    if (payload_bytes != 0) {
      f.read(reinterpret_cast<char *>(dst.data() + i * d),
             static_cast<std::streamsize>(payload_bytes));
      if (!f)
        throw std::runtime_error("File read fail");
    }
  }
  return {static_cast<uint32_t>(count), d, std::move(dst)};
}

int main(int argc, char *argv[]) try {
  asio::io_context io;
  tcp::resolver res(io);
  auto ep = res.resolve(argv[1], argv[2]);
  tcp::socket sock(io);
  asio::connect(sock, ep);

  Client client(LOG_RANK);

  // Gen HE Keys
  SecretKey secKey;
  SwitchingKey relinKey;
  std::vector<std::vector<SwitchingKey>> autedModPackKeys(RANK);
  std::vector<std::vector<MLWESwitchingKey>> autedModPackMLWEKeys(RANK);

  client.genSecKey(secKey);
  client.genRelinKey(relinKey, secKey);
  for (u64 i = 0; i < RANK; ++i) {
    client.genAutedModPackKeys(autedModPackKeys[i], secKey, 2 * i + 1);
    client.genInvAutedModPackKeys(autedModPackMLWEKeys[i], secKey, 2 * i + 1);
  }

  asio::write(sock, asio::buffer(relinKey.getPolyAModQ().getData(),
                                 DEGREE * sizeof(u64)));
  asio::write(sock, asio::buffer(relinKey.getPolyAModP().getData(),
                                 DEGREE * sizeof(u64)));
  asio::write(sock, asio::buffer(relinKey.getPolyBModQ().getData(),
                                 DEGREE * sizeof(u64)));
  asio::write(sock, asio::buffer(relinKey.getPolyBModP().getData(),
                                 DEGREE * sizeof(u64)));
  for (u64 i = 0; i < RANK; ++i) {
    autedModPackKeys[i].resize(STACK);
    for (u64 j = 0; j < STACK; ++j) {
      asio::write(sock,
                  asio::buffer(autedModPackKeys[i][j].getPolyAModQ().getData(),
                               DEGREE * sizeof(u64)));
      asio::write(sock,
                  asio::buffer(autedModPackKeys[i][j].getPolyAModP().getData(),
                               DEGREE * sizeof(u64)));
      asio::write(sock,
                  asio::buffer(autedModPackKeys[i][j].getPolyBModQ().getData(),
                               DEGREE * sizeof(u64)));
      asio::write(sock,
                  asio::buffer(autedModPackKeys[i][j].getPolyBModP().getData(),
                               DEGREE * sizeof(u64)));
    }
  }
  for (u64 i = 0; i < RANK; ++i) {
    for (u64 j = 0; j < STACK; ++j) {
      for (u64 k = 0; k < STACK; ++k) {
        asio::write(
            sock,
            asio::buffer(autedModPackMLWEKeys[i][j].getPolyAModQ(k).getData(),
                         RANK * sizeof(u64)));
        asio::write(
            sock,
            asio::buffer(autedModPackMLWEKeys[i][j].getPolyAModP(k).getData(),
                         RANK * sizeof(u64)));
        asio::write(
            sock,
            asio::buffer(autedModPackMLWEKeys[i][j].getPolyBModQ(k).getData(),
                         RANK * sizeof(u64)));
        asio::write(
            sock,
            asio::buffer(autedModPackMLWEKeys[i][j].getPolyBModP(k).getData(),
                         RANK * sizeof(u64)));
      }
    }
  }

  const double scale = std::pow(2.0, LOG_SCALE);

  FVecs B = load_fvecs(argv[3]);

  std::vector<MLWECiphertext> keys;
  keys.reserve(DEGREE);
  for (u64 i = 0; i < DEGREE; ++i)
    keys.emplace_back(RANK);
  for (u64 i = 0; i < ITER; ++i) {
#pragma omp parallel for
    for (u64 j = 0; j < DEGREE; ++j) {
      Message msg(RANK);
      if (i * DEGREE + j < N) {
        for (u64 k = 0; k < B.d; ++k)
          msg[k] = B.data[(i * DEGREE + j) * B.d + k];
      }
      client.encryptKey(keys[j], msg, secKey, scale);
    }
    for (u64 j = 0; j < DEGREE; ++j) {
      keys.emplace_back(RANK);
      for (u64 k = 0; k < STACK; ++k)
        asio::write(
            sock, asio::buffer(keys[j].getA(k).getData(), RANK * sizeof(u64)));
      asio::write(sock,
                  asio::buffer(keys[j].getB().getData(), RANK * sizeof(u64)));
    }
  }
  std::cout << "DB Key Send Done" << std::endl;

  FVecs Q = load_fvecs(argv[4]);
  FVecs scores = load_fvecs(argv[5]);

  double max_error = 0.0;
  double mean_error = 0.0;
  double std_error = 0.0;
  Message msg(RANK);
  MLWECiphertext query(RANK);
  std::vector<Ciphertext> queryCache(RANK);
  std::vector<Ciphertext> ret(ITER);
  std::vector<Message> dmsg;
  dmsg.reserve(ITER);
  for (u64 j = 0; j < ITER; ++j)
    dmsg.emplace_back(DEGREE);
  u64 n = 100; // 10K
  std::vector<std::vector<int>> topK(n);
  for (u64 i = 0; i < n; ++i) {
    for (u64 j = 0; j < Q.d; ++j)
      msg[j] = Q.data[i * Q.d + j];

    auto wholeStart = std::chrono::high_resolution_clock::now();

    auto start = std::chrono::high_resolution_clock::now();
    client.encryptQuery(query, msg, secKey, scale);
    auto end = std::chrono::high_resolution_clock::now();
    auto duration =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    std::cout << "encrypt query: " << duration.count() << "ms" << std::endl;

    for (u64 i = 0; i < STACK; ++i)
      asio::write(sock,
                  asio::buffer(query.getA(i).getData(), RANK * sizeof(u64)));
    asio::write(sock, asio::buffer(query.getB().getData(), RANK * sizeof(u64)));
    for (u64 i = 0; i < ITER; ++i) {
      asio::read(sock,
                 asio::buffer(ret[i].getA().getData(), DEGREE * sizeof(u64)));
      ret[i].getA().setIsNTT(true);
      asio::read(sock,
                 asio::buffer(ret[i].getB().getData(), DEGREE * sizeof(u64)));
      ret[i].getB().setIsNTT(true);
    }

    const double doubleScale = std::pow(2.0, 2 * LOG_SCALE);

    // Decrypt
    start = std::chrono::high_resolution_clock::now();
    client.decryptScore(dmsg, ret, secKey, doubleScale);
    end = std::chrono::high_resolution_clock::now();
    duration =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    std::cout << "decrypt query: " << duration.count() << "ms" << std::endl;

    start = std::chrono::high_resolution_clock::now();
    topK[i].clear();
    topK[i].resize(10);
    client.topKScore(topK[i], dmsg, 10);
    end = std::chrono::high_resolution_clock::now();
    duration =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    std::cout << "top 10: " << duration.count() << "ms" << std::endl;
    for (u64 j = 0; j < ITER; ++j) {
      for (u64 k = 0; k < DEGREE; ++k) {
        if (j * DEGREE + k >= N)
          break;
        double error =
            std::abs(scores.data[i * scores.d + j * DEGREE + k] - dmsg[j][k]);
        max_error = std::max(max_error, error);
        mean_error += error / N;
        std_error += error * error / N;
      }
    }
    auto wholeEnd = std::chrono::high_resolution_clock::now();
    duration = std::chrono::duration_cast<std::chrono::milliseconds>(
        wholeEnd - wholeStart);
    std::cout << "E2E: " << duration.count() << "ms" << std::endl;
  }
  std::cout << "Max error: " << max_error << std::endl;
  std::cout << "Mean error: " << mean_error / Q.n << std::endl;
  std::cout << "Std error: " << std_error / Q.n << std::endl;

} catch (std::exception &e) {
  std::cerr << e.what() << '\n';
}
