
# Sparse constraint matrix that sums consecutive runs of variables inside the
# i-th of K stacked blocks.
#
# The variable vector is K blocks of N^2 entries laid end to end
# (ncol = K * N^2). Block i starts after offset N^2 * (i - 1). Row j of the
# result has a 1 on the N consecutive variables N * (j - 1) + 1 .. N * j of
# block i.
# NOTE(review): if each block is an N x N matrix stored column-major, row j
# sums column j of that matrix. Confirm the intended layout against callers.
#
# i: index of the block to constrain (expected 1..K).
# K: number of blocks in the variable vector.
# N: block side length (each block has N^2 variables).
# Returns an N x (K * N^2) slam::simple_triplet_matrix with unit entries.
rowsum_constrmat <- \(i, K, N) {
    block_size <- N^2
    # The N runs of N consecutive variables tile the whole block in order, so
    # the concatenated column indices are simply the block's full index range.
    cols <- block_size * (i - 1) + seq_len(block_size)
    # Each constraint row owns N consecutive entries of `cols`.
    rows <- rep(seq_len(N), each = N)
    slam::simple_triplet_matrix(
        rows, cols, rep(1, length(rows)),
        nrow = N, ncol = block_size * K
    )
}

# Sparse constraint matrix that sums stride-N slices of variables inside the
# i-th of K stacked blocks.
#
# The variable vector uses the same layout as rowsum_constrmat(): K blocks of
# N^2 entries, with block i offset by N^2 * (i - 1). Row j of the result has a
# 1 on the variables j, j + N, ..., j + (N - 1) * N of block i.
# NOTE(review): if each block is an N x N matrix stored column-major, row j
# sums row j of that matrix. Confirm the intended layout against callers.
#
# i: index of the block to constrain (expected 1..K).
# K: number of blocks in the variable vector.
# N: block side length (each block has N^2 variables).
# Returns an N x (K * N^2) slam::simple_triplet_matrix with unit entries.
colsum_constrmat <- \(i, K, N) {
    block_size <- N^2
    # outer() builds an N x N grid whose column j holds j, j + N, ...,
    # j + (N - 1) * N. Flattening it column-major lists constraint 1's
    # variables first, then constraint 2's, and so on.
    strided <- outer((seq_len(N) - 1) * N, seq_len(N), "+")
    cols <- block_size * (i - 1) + as.vector(strided)
    rows <- rep(seq_len(N), each = N)
    slam::simple_triplet_matrix(
        rows, cols, rep(1, length(rows)),
        nrow = N, ncol = block_size * K
    )
}

# Equality-constraint matrix over K stacked blocks of N^2 variables
# (presumably K transport plans in an OT barycenter LP).
#
# The result stacks two groups of rows:
#   1. For every block k in 1..K, the N rows of rowsum_constrmat(k, K, N).
#   2. For every consecutive pair (k, k + 1), k in 1..K-1, the difference
#      colsum_constrmat(k) - colsum_constrmat(k + 1) with its last row removed.
#      Setting these rows to zero forces all blocks to share the same
#      stride-sum marginal.
# NOTE(review): the last row of each difference is dropped, presumably because
# it is implied by the total-mass constraints. Confirm with the LP setup.
#
# K: number of blocks.
# N: block side length.
# Returns a sparse (K * N + (K - 1) * (N - 1)) x (K * N^2) matrix.
ot_barycenter_constrmat <- \(K, N) {
    marginal_rows <- do.call(
        "rbind",
        lapply(seq_len(K), \(k) rowsum_constrmat(k, K, N))
    )
    # When K == 1 there are no pairs, so this is NULL and rbind() drops it.
    coupling_rows <- do.call(
        "rbind",
        lapply(seq_len(K - 1), \(k) {
            colsum_constrmat(k, K, N)[-N, ] - colsum_constrmat(k + 1, K, N)[-N, ]
        })
    )
    rbind(marginal_rows, coupling_rows)
}

# Marginal constraints for a single block of N^2 variables (K = 1).
#
# N: block side length.
# Returns a 2N x N^2 sparse matrix. Rows 1..N are rowsum_constrmat(1, 1, N)
# and rows N + 1..2N are colsum_constrmat(1, 1, N), presumably the two
# marginal constraints of one transport plan.
ot_cost_constrmat <- \(N) {
    row_part <- rowsum_constrmat(i = 1, K = 1, N = N)
    col_part <- colsum_constrmat(i = 1, K = 1, N = N)
    rbind(row_part, col_part)
}

# Difference of the row-sum and column-sum constraints for a single block of
# N^2 variables (K = 1).
#
# N: block side length.
# Returns an N x N^2 sparse matrix equal to
# rowsum_constrmat(1, 1, N) - colsum_constrmat(1, 1, N). Setting it to zero
# presumably forces the two marginals of one plan to coincide. Confirm against
# callers.
ot_cost_constrmat_eq <- \(N) {
    row_part <- rowsum_constrmat(i = 1, K = 1, N = N)
    col_part <- colsum_constrmat(i = 1, K = 1, N = N)
    row_part - col_part
}
