#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <immintrin.h>
#include <iostream>
#include <limits>
#include <vector>

#include "cnpy.h"

#include "evd/Ciphertext.hpp"
#include "evd/Client.hpp"
#include "evd/Const.hpp"
#include "evd/MLWECiphertext.hpp"
#include "evd/Message.hpp"
#include "evd/SecretKey.hpp"
#include "evd/Server.hpp"
#include "evd/SwitchingKey.hpp"

using namespace evd;

// log2 of the vector dimension: every key/query vector has RANK elements.
constexpr u64 LOG_RANK = 9;
constexpr u64 RANK = 1ULL << LOG_RANK;
// Number of RANK-sized vectors per DEGREE-sized batch (presumably how many
// vectors share one ring element — confirm against evd).
constexpr u64 STACK = DEGREE / RANK;

// Number of key vectors (rows of the key matrix) that are encrypted and searched.
constexpr u64 N = 100000;
// Number of DEGREE-sized key batches needed to cover N keys (ceil division);
// constexpr like the other constants so it is usable in constant expressions.
constexpr u64 ITER = (N + DEGREE - 1) / DEGREE;

int main(int argc, char *argv[]) {
  Client client(LOG_RANK);

  // Gen HE Keys
  SecretKey secKey;
  SwitchingKey relinKey;
  std::vector<std::vector<SwitchingKey>> autedModPackKeys(RANK);
  std::vector<std::vector<MLWESwitchingKey>> autedModPackMLWEKeys(RANK);

  client.genSecKey(secKey);
  client.genRelinKey(relinKey, secKey);
  for (u64 i = 0; i < RANK; ++i) {
    client.genAutedModPackKeys(autedModPackKeys[i], secKey, 2 * i + 1);
    client.genInvAutedModPackKeys(autedModPackMLWEKeys[i], secKey, 2 * i + 1);
  }
  const double scale = std::pow(2.0, LOG_SCALE);

  cnpy::NpyArray B = cnpy::npy_load(argv[1]);

  Server server(LOG_RANK, relinKey, autedModPackKeys, autedModPackMLWEKeys);

  std::vector<std::vector<Ciphertext>> keyCaches(ITER);
#pragma omp parallel for
  for (u64 i = 0; i < ITER; ++i) {
    std::vector<MLWECiphertext> keys;
    keys.reserve(DEGREE);

    for (u64 j = 0; j < DEGREE; ++j) {
      Message msg(RANK);
      if (i * DEGREE + j < N) {
        for (u64 k = 0; k < RANK; ++k)
          msg[k] = _cvtsh_ss(B.data<uint16_t>()[(i * DEGREE + j) * RANK + k]);
      }
      keys.emplace_back(RANK);
      client.encryptKey(keys[j], msg, secKey, scale);
    }
    keyCaches[i].resize(RANK);
    server.cacheKeys(keyCaches[i], keys);
  }

  // Evaluate Inner Product

  cnpy::NpyArray Q = cnpy::npy_load(argv[2]);

  double max_error = 0.0;
  double mean_error = 0.0;
  double std_error = 0.0;

  double recall1 = 0.0, recall5 = 0.0, mrr = 0.0;

  u64 n = 10000; // 10K
  std::vector<std::vector<int>> topK(n);
  for (u64 i = 0; i < n; ++i) {
    std::cout << i << std::endl;
    Message msg(RANK);
    MLWECiphertext query(RANK);
    std::vector<Ciphertext> queryCache(RANK);
    std::vector<Ciphertext> res(ITER);

    for (u64 j = 0; j < RANK; ++j)
      msg[j] = _cvtsh_ss(Q.data<uint16_t>()[i * RANK + j]);
    client.encryptQuery(query, msg, secKey, scale);

    server.cacheQuery(queryCache, query);
    for (u64 j = 0; j < ITER; ++j)
      server.innerProduct(res[j], queryCache, keyCaches[j]);

    const double doubleScale = std::pow(2.0, 2 * LOG_SCALE);

    // Decrypt
    std::vector<Message> dmsg;
    dmsg.reserve(ITER);

    int max_idx = 0;
    double max_score = -1;
    for (u64 j = 0; j < ITER; ++j)
      dmsg.emplace_back(DEGREE);
    client.decryptScore(dmsg, res, secKey, doubleScale);
    topK[i].resize(10);
    client.topKScore(topK[i], dmsg, 10);
    for (u64 j = 0; j < ITER; ++j) {
      for (u64 k = 0; k < DEGREE; ++k) {
        if (j * DEGREE + k >= N)
          break;
        double score = 0.0;
        for (u64 l = 0; l < RANK; ++l)
          score += _cvtsh_ss(Q.data<uint16_t>()[i * RANK + l]) *
                   _cvtsh_ss(B.data<uint16_t>()[(j * DEGREE + k) * RANK + l]);
        if (score > max_score) {
          max_score = score;
          max_idx = j * DEGREE + k;
        }
        double error = std::abs(score - dmsg[j][k]);
        max_error = std::max(max_error, error);
        mean_error += error / N;
        std_error += error * error / N;
      }
    }
    for (u64 j = 0; j < 10; ++j) {
      if (max_idx == topK[i][j]) {
        if (j == 0)
          recall1 += 1;
        if (j < 5)
          recall5 += 1;
        mrr += 1.0 / (j + 1);
      }
    }
  }

  std::cout << "max error: " << max_error << std::endl;
  std::cout << "mean error: " << mean_error / n << std::endl;
  std::cout << "std error: " << std_error / n << std::endl;

  std::cout << "mrr: " << mrr / n << std::endl;
  std::cout << "recall1: " << recall1 / n << std::endl;
  std::cout << "recall5: " << recall5 / n << std::endl;
}