add matrix_X argument

This commit is contained in:
Niclas
2026-03-11 09:56:15 +01:00
parent 8517c5534d
commit 14b4425570
2 changed files with 20 additions and 8 deletions

View File

@@ -44,6 +44,8 @@ source(here::here("R", "graphon_distribution.R"))
#' @param Fv Cumulative distribution function of the latent variable
#' \eqn{v}. Also has to be vectorised. Typical examples are
#' `pnorm`, `pexp`, ….
#' @param matrix_X matrix with the covariates at each node. Each row corresponds
#' to a single node with p attributes.
#' @param guard Positive numeric guard value. Default is `sqrt(.Machine$double.eps)`,
#' which is about `1.5e8` on most platforms small enough to be negligible
#' for most computations. If it is null, then it is not used.
@@ -107,6 +109,7 @@ compute_matrix <- function(
sample_X_fn,
fv,
Fv,
matrix_X = NULL,
guard = sqrt(.Machine$double.eps),
scaled = FALSE
) {
@@ -118,14 +121,21 @@ compute_matrix <- function(
if (!is.function(sample_X_fn)) stop("'sample_X_fn' must be a function")
if (!is.function(fv)) stop("'f_v' must be a function")
if (!is.function(Fv)) stop("'F_v' must be a function")
if (!is.null(matrix_X) && !is.matrix(matrix_X)) stop("matrix_X must be either null or a matrix")
## 1.2 Generate the Matrix X of covariates ===================================
# The withr environment allows us to capsulate the global state like the seed
# and enables a better reproduction
X <- withr::with_seed(seed, {
as.matrix(sample_X_fn(n))
})
if (nrow(X) != n) stop("`sample_X_fn` must return exactly `n` rows")
# If the argument matrix_X is present, use this matrix, otherwise generate one
# with sample_X_fn.
if (!is.null(matrix_X)) {
X <- matrix_X
} else {
# The withr environment allows us to encapsulate the global state like the seed
# and enables a better reproduction
X <- withr::with_seed(seed, {
as.matrix(sample_X_fn(n))
})
}
if (nrow(X) != n) stop(" the covariate matrix `X` must have exactly `n` rows")
if (ncol(X) != length(a)) {
stop("Number of columns of X (", ncol(X), ") must equal length(a) (", length(a), ")")
}