#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
#include <cmath>
#include <vector>

using namespace Rcpp;

// Sample an undirected graph from the graphon
//   f(u, v) = rho * (u^2 + v^2 + sqrt(u) + sqrt(v)) / 4
// evaluated at the latent positions x.
//
// For every unordered pair i < j an edge is drawn independently with
// probability f(x_i, x_j), using one R::runif(0, 1) draw per pair (R's RNG),
// consumed in order of ascending i, then ascending j.
//
// @param x    latent positions, one per node (the formula presumably assumes
//             values in [0, 1] -- TODO confirm; a negative entry yields a NaN
//             probability, so that node never receives an edge).
// @param rho  multiplier applied to the graphon (default 1).
// @return     symmetric n x n sparse adjacency matrix holding 1 for each edge,
//             with an empty diagonal.
// [[Rcpp::export]]
arma::sp_mat sample_graphon_airoldi(const arma::rowvec& x, double rho = 1) {
  const arma::uword n = x.n_elem;

  // sqrt(x_k) does not depend on the partner node; compute it once per node
  // instead of once per pair.
  const arma::rowvec root_x = arma::sqrt(x);

  // Collect edge coordinates and build the sparse matrix in one batch;
  // element-by-element insertion into an arma::sp_mat is much slower.
  std::vector<arma::uword> rows;
  std::vector<arma::uword> cols;

  for (arma::uword i = 0; i < n; ++i) {
    for (arma::uword j = i + 1; j < n; ++j) {
      const double prob =
          rho * (x(i) * x(i) + x(j) * x(j) + root_x(i) + root_x(j)) / 4;
      // Alternative graphon once tried here:
      //   1 / (1 + exp(-10 * (x_i^2 + x_j^2)))
      if (R::runif(0, 1) < prob) {
        // Store both orientations: the graph is undirected.
        rows.push_back(i);
        cols.push_back(j);
        rows.push_back(j);
        cols.push_back(i);
      }
    }
  }

  if (rows.empty()) {
    return arma::sp_mat(n, n);
  }

  // 2 x nnz matrix: first row = row indices, second row = column indices.
  const arma::umat locations =
      arma::join_cols(arma::urowvec(rows), arma::urowvec(cols));
  const arma::vec values(static_cast<arma::uword>(rows.size()), arma::fill::ones);
  return arma::sp_mat(locations, values, n, n);
}

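/*
Earlier draft of sample_graphon_airoldi_alt, kept for reference only; it is
not compiled. NOTE(review): the second pass below loops i, j over
0..S.n_elem-1 and indexes x and A by those counters directly, never reading
the values stored in S -- confirm whether x(S(i)) / A(S(i), S(j)) was intended
before reviving this code.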
arma::sp_mat sample_graphon_airoldi_alt(const arma::rowvec& x, const arma::rowvec& S, double m,
                                        double c = 1,
                                        double rho = 1) {
  int n = x.n_elem;
  arma::sp_mat A(n, n);
  double eps_2 = c/(5*log(m));
  for (int i = 0; i < n; ++i) {
    for (int j = i + 1; j < n; ++j) {
      double prob = rho * (x(i)*x(i) + x(j)*x(j) + sqrt(x(i)) + sqrt(x(j)) )/4;
      if (R::runif(0, 1) < prob) {
        A(i, j) = 1;
        A(j, i) = 1; // Assuming undirected graph
      }
    }
  }
  // second pass, through S x S
  int S_len = S.n_elem;
  for (int i = 0; i < S_len; ++i) {
    for (int j = i + 1; j < S_len; ++j) {
      double prob = rho * (x(i)*x(i) + x(j)*x(j) + sqrt(x(i)) + sqrt(x(j)) )/4;
      if (R::runif(0, 1) < prob - eps_2) {
        A(i, j) = 1;
        A(j, i) = 1; // Assuming undirected graph
      } else {
        A(i, j) = 0;
        A(j, i) = 0;
      }
    }
  }
  return A;
}
*/

// Sample an undirected graph from the graphon
//   f(u, v) = rho * (u^2 + v^2 + sqrt(u) + sqrt(v)) / 4
// with a planted bump: when both endpoints lie strictly inside the band
// (0.5 - delta, 0.5 + delta), the edge probability is raised by eps * rho.
//
// For every unordered pair i < j an edge is drawn independently, using one
// R::runif(0, 1) draw per pair (R's RNG), consumed in order of ascending i,
// then ascending j.
//
// @param x      latent positions, one per node (presumably in [0, 1] --
//               TODO confirm).
// @param rho    multiplier applied to both the graphon and the bump (default 1).
// @param eps    bump size before scaling by rho (default 0.01).
// @param delta  half-width of the open band around 0.5 (default 0.1).
// @return       symmetric n x n sparse adjacency matrix holding 1 for each
//               edge, with an empty diagonal.
// [[Rcpp::export]]
arma::sp_mat sample_graphon_airoldi_alt(const arma::rowvec& x,
                                        double rho = 1,
                                        double eps = 0.01,
                                        double delta = 0.1) {
  const arma::uword n = x.n_elem;

  // Per-node quantities, hoisted out of the O(n^2) pair loop.
  const arma::rowvec root_x = arma::sqrt(x);
  std::vector<char> in_band(n);  // char, not bool, to avoid vector<bool>
  for (arma::uword k = 0; k < n; ++k) {
    in_band[k] = (x(k) < 0.5 + delta) && (x(k) > 0.5 - delta);
  }
  const double bump = eps * rho;

  // Collect edge coordinates and build the sparse matrix in one batch;
  // element-by-element insertion into an arma::sp_mat is much slower.
  std::vector<arma::uword> rows;
  std::vector<arma::uword> cols;

  for (arma::uword i = 0; i < n; ++i) {
    for (arma::uword j = i + 1; j < n; ++j) {
      double prob =
          rho * (x(i) * x(i) + x(j) * x(j) + root_x(i) + root_x(j)) / 4;
      if (in_band[i] && in_band[j]) {
        prob += bump;
      }
      if (R::runif(0, 1) < prob) {
        // Store both orientations: the graph is undirected.
        rows.push_back(i);
        cols.push_back(j);
        rows.push_back(j);
        cols.push_back(i);
      }
    }
  }

  if (rows.empty()) {
    return arma::sp_mat(n, n);
  }

  // 2 x nnz matrix: first row = row indices, second row = column indices.
  const arma::umat locations =
      arma::join_cols(arma::urowvec(rows), arma::urowvec(cols));
  const arma::vec values(static_cast<arma::uword>(rows.size()), arma::fill::ones);
  return arma::sp_mat(locations, values, n, n);
}