#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

using namespace Rcpp;

// Logistic (sigmoid) function: returns 1 / (1 + e^(-x)).
double sigm(double x){
  const double denom = 1.0 + std::exp(-x);
  return 1.0 / denom;
}


// Kernel weight between the point pair (u, v) and the reference pair (x, y):
//   exp( -(||u - x|| + ||v - y||) / (2 * bw^2) )
// The two distances come from arma::norm with its default norm type and are
// added as-is, i.e. they are NOT squared.
// NOTE(review): the 2 * bw^2 scaling resembles a Gaussian (RBF) kernel, which
// would use squared distances — confirm that unsquared norms are intended.
double Ker(const arma::rowvec& u, const arma::rowvec& v,
         const arma::rowvec& x, const arma::rowvec& y, double bw){
  const double dist_sum = arma::norm(u - x) + arma::norm(v - y);
  const double scale = 2 * bw * bw;
  return std::exp( -dist_sum / scale );
}

// Evaluates the kernel expansion at the point pair (u, v) and squashes it
// through the logistic function:
//   sigm( sum_i alpha(i) * Ker(u, v, x.row(i), y.row(i), bw) )
// where i runs over the rows of x.
//
//   u, v   query points (row vectors)
//   x, y   reference points, one per row; row i of x is paired with row i of y
//   alpha  weights; entry i multiplies the kernel value of reference pair i
//   bw     bandwidth forwarded unchanged to Ker
//
// Only the first x.n_rows rows of y and entries of alpha are read. An empty x
// yields sigm(0).
// [[Rcpp::export]]
double f(const arma::rowvec& u, const arma::rowvec& v, const arma::mat& x, 
         const arma::mat& y, const arma::vec& alpha, double bw){
  double acc = 0.0;
  // Keep the row count in Armadillo's own unsigned index type instead of
  // narrowing it to int.
  const arma::uword n_ref = x.n_rows;
  for (arma::uword i = 0; i < n_ref; ++i){
    acc += alpha(i) * Ker(u, v, x.row(i), y.row(i), bw);
  }
  return sigm(acc);
}

// Samples a random undirected graph on n nodes.
//
// Each node gets a latent row vector with x.n_cols entries, filled via
// arma::fill::randn. For every unordered pair i < j an edge is added when a
// uniform(0, 1) draw is below rho * f(V_i, V_j, x, y, alpha, bw).
//
//   n      number of nodes; must be non-negative
//   x, y   reference points forwarded to f
//   alpha  kernel weights forwarded to f
//   bw     kernel bandwidth forwarded to f
//   rho    multiplier applied to the value returned by f
//
// Returns an n x n sparse matrix with a 1 at (i, j) and (j, i) for every
// sampled edge; the diagonal is never written. Raises an R error when n is
// negative.
// [[Rcpp::export]]
arma::sp_mat sample_graphon(int n, const arma::mat& x, const arma::mat& y,
                            const arma::vec& alpha,
                            double bw, double rho) {
  // A negative int would otherwise be converted to a huge unsigned dimension.
  if (n < 0) {
    Rcpp::stop("sample_graphon: n must be non-negative");
  }
  const arma::uword n_nodes = static_cast<arma::uword>(n);
  const arma::uword d = x.n_cols;

  arma::sp_mat A(n_nodes, n_nodes);
  arma::mat V(n_nodes, d, arma::fill::randn);

  for (arma::uword i = 0; i < n_nodes; ++i) {
    // The double loop is O(n^2) calls to f; let the user cancel from R.
    Rcpp::checkUserInterrupt();

    // Row i is the same for the whole inner loop, so extract it once.
    const arma::rowvec latent_i = V.row(i);

    for (arma::uword j = i + 1; j < n_nodes; ++j) {
      const double prob = rho * f(latent_i, V.row(j), x, y, alpha, bw);
      // One uniform draw is consumed per pair, whether or not an edge results.
      if (R::runif(0, 1) < prob) {
        A(i, j) = 1;
        A(j, i) = 1; // mirror the entry: the graph is undirected
      }
    }
  }

  return A;
}
