commit 2ad6e27171a42a7aa5845f68e7547e6b7a2e69c6
parent 0613381152e7772d8b5d350a7870a405642da296
Author: Vincent Forest <vincent.forest@meso-star.com>
Date: Fri, 3 Dec 2021 16:49:31 +0100
Fix the ssp_rng_proxy_read function
Cached RNG states were not cleaned when the proxy state was updated. As
a result, proxy-managed RNGs generated random numbers from the previous
proxy states saved into the cache.
Diffstat:
2 files changed, 154 insertions(+), 13 deletions(-)
diff --git a/src/ssp_rng_proxy.c b/src/ssp_rng_proxy.c
@@ -57,7 +57,7 @@ struct rng_state_cache {
CLBK(rng_proxy_cb_T, ARG1(const struct ssp_rng_proxy*));
enum rng_proxy_sig {
- RNG_PROXY_SIG_READ,
+ RNG_PROXY_SIG_SET_STATE,
RNG_PROXY_SIGS_COUNT__
};
@@ -106,7 +106,7 @@ struct ssp_rng_proxy {
* |######| Bucket 0 | ... | Bucket N-1 |####| Bucket 0 | ...
* |######| 1st pool | | 1st pool |####| 2nd pool |
* +------+------------+- -+------------+----+------------+-
- * \ / \_________sequence_size_______/ /
+ * \ / \_________sequence_size_______/ /
* sequence \________sequence_pitch__________/
* offset
*/
@@ -179,6 +179,17 @@ rng_state_cache_release(struct rng_state_cache* cache)
darray_char_release(&cache->state_scratch);
}
+static void
+rng_state_cache_clear(struct rng_state_cache* cache)
+{
+ ASSERT(cache->stream);
+ rewind(cache->stream);
+ cache->read = cache->write = ftell(cache->stream);
+ cache->nstates = 0;
+ cache->no_wstream = 0;
+ cache->no_rstream = 0;
+}
+
static char
rng_state_cache_is_empty(struct rng_state_cache* cache)
{
@@ -305,11 +316,11 @@ struct rng_bucket {
struct ssp_rng_proxy* proxy; /* The RNG proxy */
size_t name; /* Unique bucket identifier in [0, #buckets) */
size_t count; /* Remaining unique random numbers in `pool' */
- rng_proxy_cb_T cb_on_proxy_read;
+ rng_proxy_cb_T cb_on_proxy_set_state;
};
static void
-rng_bucket_on_proxy_read(const struct ssp_rng_proxy* proxy, void* ctx)
+rng_bucket_on_proxy_set_state(const struct ssp_rng_proxy* proxy, void* ctx)
{
struct rng_bucket* rng = (struct rng_bucket*)ctx;
ASSERT(proxy && ctx && rng->proxy == proxy);
@@ -395,7 +406,7 @@ rng_bucket_release(void* data)
struct rng_bucket* rng = (struct rng_bucket*)data;
ASSERT(data && rng->proxy);
ATOMIC_SET(&rng->proxy->buckets[rng->name], 0);
- CLBK_DISCONNECT(&rng->cb_on_proxy_read);
+ CLBK_DISCONNECT(&rng->cb_on_proxy_set_state);
SSP(rng_proxy_ref_put(rng->proxy));
}
@@ -520,6 +531,19 @@ error:
goto exit;
}
+static void
+rng_proxy_clear_caches(struct ssp_rng_proxy* proxy)
+{
+ size_t ibucket;
+ ASSERT(proxy);
+
+ mutex_lock(proxy->mutex);
+ FOR_EACH(ibucket, 0, sa_size(proxy->pools)) {
+ rng_state_cache_clear(proxy->states+ibucket);
+ }
+ mutex_unlock(proxy->mutex);
+}
+
void
rng_proxy_release(ref_T* ref)
{
@@ -683,7 +707,12 @@ ssp_rng_proxy_read(struct ssp_rng_proxy* proxy, FILE* stream)
mutex_unlock(proxy->mutex);
if(res != RES_OK) return res;
- SIG_BROADCAST(proxy->signals+RNG_PROXY_SIG_READ, rng_proxy_cb_T, ARG1(proxy));
+ /* Discard the cached RNG states */
+ rng_proxy_clear_caches(proxy);
+
+ /* Notify to bucket RNGs that the proxy RNG state was updated */
+ SIG_BROADCAST
+ (proxy->signals+RNG_PROXY_SIG_SET_STATE, rng_proxy_cb_T, ARG1(proxy));
return RES_OK;
}
@@ -752,9 +781,11 @@ ssp_rng_proxy_create_rng
/* The bucket RNG listens the "write" signal of the proxy to reset its
* internal RNs counter on "write" invocation. */
- CLBK_INIT(&bucket->cb_on_proxy_read);
- CLBK_SETUP(&bucket->cb_on_proxy_read, rng_bucket_on_proxy_read, bucket);
- SIG_CONNECT_CLBK(proxy->signals+RNG_PROXY_SIG_READ, &bucket->cb_on_proxy_read);
+ CLBK_INIT(&bucket->cb_on_proxy_set_state);
+ CLBK_SETUP
+ (&bucket->cb_on_proxy_set_state, rng_bucket_on_proxy_set_state, bucket);
+ SIG_CONNECT_CLBK
+ (proxy->signals+RNG_PROXY_SIG_SET_STATE, &bucket->cb_on_proxy_set_state);
exit:
if(out_rng) *out_rng = rng;
diff --git a/src/test_ssp_rng_proxy.c b/src/test_ssp_rng_proxy.c
@@ -395,25 +395,28 @@ test_read(void)
CHK(ssp_rng_create(NULL, SSP_RNG_MT19937_64, &rng) == RES_OK);
CHK(ssp_rng_discard(rng, COUNT) == RES_OK);
+ /* Create a RNG state */
stream = tmpfile();
CHK(stream != NULL);
CHK(ssp_rng_write(rng, stream) == RES_OK);
rewind(stream);
+ /* Create a proxy from the RNG state */
CHK(ssp_rng_proxy_create(NULL, SSP_RNG_MT19937_64, 4, &proxy) == RES_OK);
CHK(ssp_rng_proxy_read(NULL, NULL) == RES_BAD_ARG);
CHK(ssp_rng_proxy_read(proxy, NULL) == RES_BAD_ARG);
CHK(ssp_rng_proxy_read(NULL, stream) == RES_BAD_ARG);
CHK(ssp_rng_proxy_read(proxy, stream) == RES_OK);
+ /* Create the list of RNG managed by the proxy */
FOR_EACH(i, 0, 4) CHK(ssp_rng_proxy_create_rng(proxy, i, &rng1[i]) == RES_OK);
+ /* Check the random number sequences */
r[0] = ssp_rng_get(rng);
CHK(r[0] == ssp_rng_get(rng1[0]));
CHK(r[0] != ssp_rng_get(rng1[1]));
CHK(r[0] != ssp_rng_get(rng1[2]));
CHK(r[0] != ssp_rng_get(rng1[3]));
-
FOR_EACH(i, 0, NRANDS) {
FOR_EACH(j, 0, 4) {
r[j] = ssp_rng_get(rng1[j]);
@@ -424,7 +427,113 @@ test_read(void)
CHK(ssp_rng_proxy_ref_put(proxy) == RES_OK);
CHK(ssp_rng_ref_put(rng) == RES_OK);
FOR_EACH(i, 0, 4) CHK(ssp_rng_ref_put(rng1[i]) == RES_OK);
- fclose(stream);
+
+ CHK(fclose(stream) == 0);
+}
+
+static void
+test_read_with_cached_states(void)
+{
+ struct ssp_rng_proxy_create2_args args = SSP_RNG_PROXY_CREATE2_ARGS_NULL;
+ struct ssp_rng_proxy* proxy;
+ struct ssp_rng* rng;
+ struct ssp_rng* rng0;
+ struct ssp_rng* rng1;
+ FILE* stream;
+ size_t iseq;
+ size_t i;
+
+ CHK(ssp_rng_create(NULL, SSP_RNG_MT19937_64, &rng) == RES_OK);
+
+ /* Create a proxy with a very small sequence size of 2 RNs per bucket in
+ * order to maximize the number of cached states. Furthermore, use the
+ * Mersenne Twister RNG type since its internal state is the greatest one of
+ * the proposed builtin type and is thus the one that will fill quickly the
+ * cache stream. */
+ args.type = SSP_RNG_MT19937_64;
+ args.sequence_size = 4;
+ args.sequence_pitch = 4;
+ args.nbuckets = 2;
+ CHK(ssp_rng_proxy_create2(NULL, &args, &proxy) == RES_OK);
+ CHK(ssp_rng_proxy_create_rng(proxy, 0, &rng0) == RES_OK);
+ CHK(ssp_rng_proxy_create_rng(proxy, 1, &rng1) == RES_OK);
+
+ /* Check RNG states */
+ CHK(ssp_rng_get(rng) == ssp_rng_get(rng0));
+ CHK(ssp_rng_get(rng) == ssp_rng_get(rng0));
+ CHK(ssp_rng_get(rng) == ssp_rng_get(rng1));
+ CHK(ssp_rng_get(rng) == ssp_rng_get(rng1));
+
+ /* Discard several RNs for the 1st RNG to cache several states for the rng1 */
+ CHK(ssp_rng_discard(rng0, 100) == RES_OK);
+
+ /* Generate a RNG state */
+ stream = tmpfile();
+ CHK(stream != NULL);
+ CHK(ssp_rng_discard(rng, 100000) == RES_OK);
+ CHK(ssp_rng_write(rng, stream) == RES_OK);
+ rewind(stream);
+
+ /* Setup the proxy state and check that the RNGs managed by the proxy now use
+ * the new state properly, even though RNG states were previously cached */
+ CHK(ssp_rng_proxy_read(proxy, stream) == RES_OK);
+ FOR_EACH(iseq, 0, 100) {
+ FOR_EACH(i, 0, args.sequence_size) {
+ if(i < args.sequence_size/2) {
+ CHK(ssp_rng_get(rng) == ssp_rng_get(rng0));
+ } else {
+ CHK(ssp_rng_get(rng) == ssp_rng_get(rng1));
+ }
+ }
+ }
+
+ /* Discard several RNs for the first RNG only to make under pressure the
+ * cache stream of 'rng1'. The cache stream limit is set to 32 MB and the
+ * size of a Mersenne Twister RNG state is greater than 6 KB. Consquently,
+ * ~5500 RNG states will exceed the cache stream, i.e. 5500*2 = 11000
+ * random generations (since there is 2 RNs per bucket). Above this limit,
+ * 'rng1' will not rely anymore on the proxy RNG to manage its state. */
+ CHK(ssp_rng_discard(rng0, 20000) == RES_OK);
+
+ /* Generate a new RNG state */
+ rewind(stream);
+ CHK(ssp_rng_discard(rng, 100000) == RES_OK);
+ CHK(ssp_rng_write(rng, stream) == RES_OK);
+ rewind(stream);
+
+ /* Now setup a new proxy state and check that the RNGs managed by the proxy
+ * correctly use the new state */
+ CHK(ssp_rng_proxy_read(proxy, stream) == RES_OK);
+ FOR_EACH(iseq, 0, 100) {
+ FOR_EACH(i, 0, args.sequence_size) {
+ if(i < args.sequence_size/2) {
+ CHK(ssp_rng_get(rng) == ssp_rng_get(rng0));
+ } else {
+ CHK(ssp_rng_get(rng) == ssp_rng_get(rng1));
+ }
+ }
+ }
+
+ /* Discard several RNs for the 1st RNG to cache several states for the rng1 */
+ CHK(ssp_rng_discard(rng0, 100) == RES_OK);
+
+ /* Finally verify that the cache was reset on proxy read */
+ FOR_EACH(iseq, 0, 100) {
+ FOR_EACH(i, 0, args.sequence_size) {
+ if(i < args.sequence_size/2) {
+ CHK(ssp_rng_discard(rng, 1) == RES_OK);
+ } else {
+ CHK(ssp_rng_get(rng) == ssp_rng_get(rng1));
+ }
+ }
+ }
+
+ CHK(ssp_rng_proxy_ref_put(proxy) == RES_OK);
+ CHK(ssp_rng_ref_put(rng0) == RES_OK);
+ CHK(ssp_rng_ref_put(rng1) == RES_OK);
+ CHK(ssp_rng_ref_put(rng) == RES_OK);
+
+ CHK(fclose(stream) == 0);
}
static void
@@ -460,7 +569,7 @@ test_write(void)
CHK(ssp_rng_ref_put(rng0) == RES_OK);
CHK(ssp_rng_ref_put(rng1) == RES_OK);
- fclose(stream);
+ CHK(fclose(stream) == 0);
}
static void
@@ -484,7 +593,7 @@ test_cache(void)
args.sequence_pitch = 4;
args.nbuckets = 2;
- /* Simply test that the RNs generated by the proxy are the same thant the
+ /* Simply test that the RNs generated by the proxy are the same than the
* ones generated by a regular RNG. Since each RNG invocation are
* interleaved, the cache pressure is very low, at most 1 RNG state is
* cached. */
@@ -550,6 +659,7 @@ main(int argc, char** argv)
test_multi_proxies();
test_proxy_from_rng();
test_read();
+ test_read_with_cached_states();
test_write();
test_cache();
CHK(mem_allocated_size() == 0);