#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

#include "evd/Ciphertext.hpp"
#include "evd/Client.hpp"
#include "evd/Const.hpp"
#include "evd/MLWECiphertext.hpp"
#include "evd/Message.hpp"
#include "evd/Random.hpp"
#include "evd/SecretKey.hpp"
#include "evd/Server.hpp"
#include "evd/SwitchingKey.hpp"

using namespace evd;

constexpr u64 LOG_RANK = 10;
constexpr u64 RANK = 1ULL << LOG_RANK;
constexpr u64 STACK = DEGREE / RANK;
constexpr u64 N = DEGREE;

int main() {
  Client client(LOG_RANK);

  // Gen HE Keys
  SecretKey secKey;
  SwitchingKey relinKey;
  std::vector<std::vector<SwitchingKey>> autedModPackKeys(RANK);
  std::vector<std::vector<MLWESwitchingKey>> autedModPackMLWEKeys(RANK);

  client.genSecKey(secKey);
  client.genRelinKey(relinKey, secKey);
  for (u64 i = 0; i < RANK; ++i) {
    client.genAutedModPackKeys(autedModPackKeys[i], secKey, 2 * i + 1);
    client.genInvAutedModPackKeys(autedModPackMLWEKeys[i], secKey, 2 * i + 1);
  }

  // Gen Query & Keys
  Message queryMsg(RANK);
  std::vector<Message> keyMsgs;

  double sum = 0;
  for (u64 i = 0; i < RANK; ++i) {
    queryMsg[i] = static_cast<int8_t>(Random::getRandomU8()) / 128.0;
    sum += queryMsg[i] * queryMsg[i];
  }
  sum = std::sqrt(sum);
  for (u64 i = 0; i < RANK; ++i)
    queryMsg[i] /= sum;

  keyMsgs.reserve(DEGREE);
  for (u64 i = 0; i < DEGREE; ++i) {
    keyMsgs.emplace_back(RANK);
    // sum = 0;
    // for (u64 j = 0; j < RANK; ++j) {
    //   keyMsgs[i][j] = static_cast<int8_t>(Random::getRandomU8()) / 128.0;
    //   sum += keyMsgs[i][j] * keyMsgs[i][j];
    // }
    // sum = std::sqrt(sum);
    // for (u64 j = 0; j < RANK; ++j)
    //   keyMsgs[i][j] /= sum;
  }

  const double scale = std::pow(2.0, LOG_SCALE);

  MLWECiphertext query(RANK);
  std::vector<MLWECiphertext> keys;

  client.encryptQuery(query, queryMsg, secKey, scale);

  keys.reserve(N);
  for (u64 i = 0; i < N; ++i) {
    keys.emplace_back(RANK);
    client.encryptKey(keys[i], keyMsgs[i], secKey, scale);
  }

  // Client send query and key to Server

  // Evaluate Inner Product
  Server server(LOG_RANK, relinKey, autedModPackKeys, autedModPackMLWEKeys);

  std::vector<Ciphertext> queryCache(RANK), keyCache(RANK * N / DEGREE);
  Ciphertext res;

  server.cacheQuery(queryCache, query);
  auto start = std::chrono::high_resolution_clock::now();
  for (u64 i = 0; i < 100; ++i)
    server.cacheKeys(keyCache, keys);
  auto end = std::chrono::high_resolution_clock::now();
  auto duration =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
  std::cout << duration.count() / 100 << "ms" << std::endl;

  server.innerProduct(res, queryCache, keyCache);

  const double doubleScale = std::pow(2.0, 2 * LOG_SCALE);

  // Decrypt
  Message dmsg(DEGREE);
  client.decrypt(dmsg, res, secKey, doubleScale);

  Message answer(DEGREE);
  for (u64 i = 0; i < N; ++i) {
    for (u64 j = 0; j < RANK; ++j)
      answer[i] += queryMsg[j] * keyMsgs[i][j];
  }
  double max_error = 0.0;
  for (u64 i = 0; i < N; ++i) {
    max_error = std::max(max_error, std::abs(answer[i] - dmsg[i]));
  }
  std::cout << "max error: " << max_error << std::endl;
}