#include <asio.hpp>
#include <asio/bind_executor.hpp>
#include <asio/error_code.hpp>
#include <asio/write.hpp>
#include <iostream>
#include <memory>

#include "evd/Ciphertext.hpp"
#include "evd/Const.hpp"
#include "evd/MLWECiphertext.hpp"
#include "evd/Server.hpp"
#include "evd/SwitchingKey.hpp"

using namespace evd;

using asio::ip::tcp;

constexpr u64 LOG_RANK = 7;
constexpr u64 RANK = 1ULL << LOG_RANK;
constexpr u64 STACK = DEGREE / RANK;
constexpr u64 N = 1000000;
constexpr u64 ITER = (N + DEGREE - 1) / DEGREE;

class Session : public std::enable_shared_from_this<Session> {
  tcp::socket sock_;

  SwitchingKey relinKey;
  std::vector<std::vector<SwitchingKey>> autedModPackKeys;
  std::vector<std::vector<MLWESwitchingKey>> autedModPackMLWEKeys;

  Server *server;

  std::vector<std::vector<Ciphertext>> keyCaches;

  MLWECiphertext query;
  std::vector<Ciphertext> queryCache;

  std::vector<Ciphertext> res;

  asio::strand<asio::basic_socket<asio::ip::tcp>::executor_type> st;

public:
  explicit Session(tcp::socket s)
      : sock_(std::move(s)), autedModPackKeys(RANK), autedModPackMLWEKeys(RANK),
        keyCaches(ITER), query(RANK), queryCache(RANK), res(ITER),
        st(sock_.get_executor()) {}

  void start() {
    asio::read(sock_, asio::buffer(relinKey.getPolyAModQ().getData(),
                                   DEGREE * sizeof(u64)));
    relinKey.getPolyAModQ().setIsNTT(true);
    asio::read(sock_, asio::buffer(relinKey.getPolyAModP().getData(),
                                   DEGREE * sizeof(u64)));
    relinKey.getPolyAModP().setIsNTT(true);
    asio::read(sock_, asio::buffer(relinKey.getPolyBModQ().getData(),
                                   DEGREE * sizeof(u64)));
    relinKey.getPolyBModQ().setIsNTT(true);
    asio::read(sock_, asio::buffer(relinKey.getPolyBModP().getData(),
                                   DEGREE * sizeof(u64)));
    relinKey.getPolyBModP().setIsNTT(true);
    for (u64 i = 0; i < RANK; ++i) {
      autedModPackKeys[i].resize(STACK);
      for (u64 j = 0; j < STACK; ++j) {
        asio::read(sock_,
                   asio::buffer(autedModPackKeys[i][j].getPolyAModQ().getData(),
                                DEGREE * sizeof(u64)));
        autedModPackKeys[i][j].getPolyAModQ().setIsNTT(true);
        asio::read(sock_,
                   asio::buffer(autedModPackKeys[i][j].getPolyAModP().getData(),
                                DEGREE * sizeof(u64)));
        autedModPackKeys[i][j].getPolyAModP().setIsNTT(true);
        asio::read(sock_,
                   asio::buffer(autedModPackKeys[i][j].getPolyBModQ().getData(),
                                DEGREE * sizeof(u64)));
        autedModPackKeys[i][j].getPolyBModQ().setIsNTT(true);
        asio::read(sock_,
                   asio::buffer(autedModPackKeys[i][j].getPolyBModP().getData(),
                                DEGREE * sizeof(u64)));
        autedModPackKeys[i][j].getPolyBModP().setIsNTT(true);
      }
    }
    for (u64 i = 0; i < RANK; ++i) {
      autedModPackMLWEKeys[i].reserve(STACK);
      for (u64 j = 0; j < STACK; ++j) {
        autedModPackMLWEKeys[i].emplace_back(RANK);
        for (u64 k = 0; k < STACK; ++k) {
          asio::read(
              sock_,
              asio::buffer(autedModPackMLWEKeys[i][j].getPolyAModQ(k).getData(),
                           RANK * sizeof(u64)));
          autedModPackMLWEKeys[i][j].getPolyAModQ(k).setIsNTT(true);
          asio::read(
              sock_,
              asio::buffer(autedModPackMLWEKeys[i][j].getPolyAModP(k).getData(),
                           RANK * sizeof(u64)));
          autedModPackMLWEKeys[i][j].getPolyAModP(k).setIsNTT(true);

          asio::read(
              sock_,
              asio::buffer(autedModPackMLWEKeys[i][j].getPolyBModQ(k).getData(),
                           RANK * sizeof(u64)));
          autedModPackMLWEKeys[i][j].getPolyBModQ(k).setIsNTT(true);
          asio::read(
              sock_,
              asio::buffer(autedModPackMLWEKeys[i][j].getPolyBModP(k).getData(),
                           RANK * sizeof(u64)));
          autedModPackMLWEKeys[i][j].getPolyBModP(k).setIsNTT(true);
        }
      }
    }
    std::cout << "Eval Keys Received" << std::endl;
    server =
        new Server(LOG_RANK, relinKey, autedModPackKeys, autedModPackMLWEKeys);
    std::cout << "Server Initialized" << std::endl;
    for (u64 i = 0; i < ITER; ++i) {
      std::vector<MLWECiphertext> keys;
      keys.reserve(DEGREE);

      for (u64 j = 0; j < DEGREE; ++j) {
        keys.emplace_back(RANK);
        for (u64 k = 0; k < STACK; ++k)
          asio::read(sock_, asio::buffer(keys[j].getA(k).getData(),
                                         RANK * sizeof(u64)));
        asio::read(sock_,
                   asio::buffer(keys[j].getB().getData(), RANK * sizeof(u64)));
      }
      keyCaches[i].resize(RANK);
      server->cacheKeys(keyCaches[i], keys);
    }
    std::cout << "DB Keys Received" << std::endl;
    for (u64 i = 0;; ++i) {
      std::cout << i << std::endl;
      do_read();
    }
  }

private:
  void do_read() {
    for (u64 i = 0; i < STACK; ++i)
      asio::read(sock_,
                 asio::buffer(query.getA(i).getData(), RANK * sizeof(u64)));
    asio::read(sock_, asio::buffer(query.getB().getData(), RANK * sizeof(u64)));
    server->cacheQuery(queryCache, query);
    for (u64 i = 0; i < ITER; ++i) {
      server->innerProduct(res[i], queryCache, keyCaches[i]);
      asio::async_write(
          sock_, asio::buffer(res[i].getA().getData(), DEGREE * sizeof(u64)),
          asio::bind_executor(st, [](auto, size_t) {}));
      asio::async_write(
          sock_, asio::buffer(res[i].getB().getData(), DEGREE * sizeof(u64)),
          asio::bind_executor(st, [](auto, size_t) {}));
    }
  }
};

int main() {
  asio::io_context io;
  tcp::acceptor acc(io, {tcp::v4(), 9000});

  auto do_accept = [&](auto &&self) {
    acc.async_accept([&](auto ec, tcp::socket s) {
      if (!ec)
        std::make_shared<Session>(std::move(s))->start();
    });
  };
  do_accept(do_accept);
  io.run();
  while (true) {
  };
}
