#include <algorithm>
#include <chrono>
#include <cmath>
#include <iostream>
#include <vector>

#include "evd/Ciphertext.hpp"
#include "evd/Client.hpp"
#include "evd/Const.hpp"
#include "evd/HEval.hpp"
#include "evd/Message.hpp"
#include "evd/PIRServer.hpp"
#include "evd/Polynomial.hpp"
#include "evd/Random.hpp"
#include "evd/SecretKey.hpp"
#include "evd/SwitchingKey.hpp"

using namespace evd;

constexpr u64 LOG_RANK = 10;
constexpr double FIRST_SCALE = 25.25;
constexpr double SECOND_SCALE = 25.25;

constexpr u64 IDX = 1024;

constexpr u64 RANK = 1ULL << LOG_RANK;

int main() {
  Client client(LOG_RANK);

  // Gen HE Keys
  SecretKey secKey;
  SwitchingKey relinKey;
  std::vector<SwitchingKey> invAutKeys(RANK);

  client.genSecKey(secKey);
  client.genRelinKey(relinKey, secKey);
  client.genInvAutKeys(invAutKeys, secKey, RANK);

  // Gen Query & Keys
  std::vector<Polynomial> db;

  db.reserve(RANK * RANK);

  HEval eval(LOG_RANK);

  for (u64 i = 0; i < RANK * RANK; ++i) {
    db.emplace_back(DEGREE, MOD_Q);
    Random::sampleUniform(db[i]);
    for (u64 j = 0; j < DEGREE; ++j) {
      db[i][j] &= 3;
      db[i][j] = (db[i][j] > 1) ? (MOD_Q - db[i][j] + 1) : db[i][j];
    }
    eval.ntt(db[i], db[i]);
  }

  Ciphertext firstDim, secondDim, res;

  client.encryptPIR(firstDim, IDX / RANK, secKey, std::pow(2.0, FIRST_SCALE));
  client.encryptPIR(secondDim, IDX % RANK, secKey, std::pow(2.0, SECOND_SCALE));

  // Evaluate Inner Product
  PIRServer server(LOG_RANK, relinKey, invAutKeys);

  server.pir(res, firstDim, secondDim, db);

  // Decrypt
  Message dmsg(DEGREE);
  client.decrypt(dmsg, res, secKey, std::pow(2.0, FIRST_SCALE + SECOND_SCALE));

  Polynomial answer(DEGREE, MOD_Q);
  eval.intt(answer, db[IDX]);
  double max_error = 0.0;
  for (u64 i = 0; i < DEGREE; ++i) {
    double ans = answer[i] > MOD_Q / 2 ? static_cast<double>(answer[i]) - MOD_Q
                                       : answer[i];
    max_error = std::max(max_error, std::abs(ans - dmsg[i]));
  }
  std::cout << "max error: " << max_error << std::endl;
}