#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

#include "evd/Ciphertext.hpp"
#include "evd/Client.hpp"
#include "evd/Const.hpp"
#include "evd/MLWECiphertext.hpp"
#include "evd/Message.hpp"
#include "evd/Polynomial.hpp"
#include "evd/Random.hpp"
#include "evd/SecretKey.hpp"
#include "evd/Server.hpp"
#include "evd/SwitchingKey.hpp"

using namespace evd;

// Vector dimension used throughout this test: the query and every key have
// RANK = 2^LOG_RANK coordinates.
constexpr u64 LOG_RANK = 7;
constexpr u64 RANK = static_cast<u64>(1) << LOG_RANK;
// How many RANK-sized pieces fit in DEGREE (DEGREE is not defined in this
// file; presumably it comes from evd/Const.hpp — confirm).
constexpr u64 STACK = DEGREE / RANK;

int main() {
  Client client(LOG_RANK);

  // Gen HE Keys
  SecretKey secKey;
  SwitchingKey relinKey;
  std::vector<std::vector<SwitchingKey>> autedModPackKeys(RANK);
  std::vector<std::vector<MLWESwitchingKey>> autedModPackMLWEKeys(RANK);

  client.genSecKey(secKey);
  client.genRelinKey(relinKey, secKey);
  for (u64 i = 0; i < RANK; ++i) {
    client.genAutedModPackKeys(autedModPackKeys[i], secKey, 2 * i + 1);
    client.genInvAutedModPackKeys(autedModPackMLWEKeys[i], secKey, 2 * i + 1);
  }

  // Gen Query & Keys
  Message queryMsg(RANK);
  std::vector<Message> keyMsgs;

  double sum = 0;
  for (u64 i = 0; i < RANK; ++i) {
    queryMsg[i] = static_cast<int8_t>(Random::getRandomU8()) / 128.0;
    sum += queryMsg[i] * queryMsg[i];
  }
  sum = std::sqrt(sum);
  for (u64 i = 0; i < RANK; ++i)
    queryMsg[i] /= sum;

  keyMsgs.reserve(DEGREE);
  for (u64 i = 0; i < DEGREE; ++i) {
    keyMsgs.emplace_back(RANK);
    sum = 0;
    for (u64 j = 0; j < RANK; ++j) {
      keyMsgs[i][j] = static_cast<int8_t>(Random::getRandomU8()) / 128.0;
      sum += keyMsgs[i][j] * keyMsgs[i][j];
    }
    sum = std::sqrt(sum);
    for (u64 j = 0; j < RANK; ++j)
      keyMsgs[i][j] /= sum;
  }

  const double ctxtScale = std::pow(2.0, 35);
  const double ptxtScale = std::pow(2.0, 16.5);

  Polynomial query(RANK, MOD_Q);
  std::vector<MLWECiphertext> keys;

  client.encodeQuery(query, queryMsg, ptxtScale);

  keys.reserve(DEGREE);
  for (u64 i = 0; i < DEGREE; ++i) {
    keys.emplace_back(RANK);
    client.encryptKey(keys[i], keyMsgs[i], secKey, ctxtScale);
  }

  // Evaluate Inner Product
  Server server(LOG_RANK, relinKey, autedModPackKeys, autedModPackMLWEKeys);
  std::vector<Polynomial> queryCache;
  std::vector<Ciphertext> keyCache(RANK);
  Ciphertext res;

  queryCache.reserve(RANK);
  for (u64 i = 0; i < RANK; ++i)
    queryCache.emplace_back(DEGREE, MOD_Q);

  server.cacheQuery(queryCache, query);

  server.cacheKeys(keyCache, keys);

  server.innerProduct(res, keyCache, queryCache);

  const double doubleScale = std::pow(2.0, 2 * LOG_SCALE);

  // Decrypt
  Message dmsg(DEGREE);
  client.decrypt(dmsg, res, secKey, doubleScale);

  Message answer(DEGREE);
  for (u64 i = 0; i < DEGREE; ++i) {
    for (u64 j = 0; j < RANK; ++j)
      answer[i] += queryMsg[j] * keyMsgs[i][j];
  }
  double max_error = 0.0;
  for (u64 i = 0; i < DEGREE; ++i)
    max_error = std::max(max_error, std::abs(answer[i] - dmsg[i]));
  std::cout << "max error: " << max_error << std::endl;
}