smc_estimator.c (12788B)
1 /* Copyright (C) 2015-2018, 2021-2023 |Méso|Star> (contact@meso-star.com) 2 * 3 * This program is free software: you can redistribute it and/or modify 4 * it under the terms of the GNU General Public License as published by 5 * the Free Software Foundation, either version 3 of the License, or 6 * (at your option) any later version. 7 * 8 * This program is distributed in the hope that it will be useful, 9 * but WITHOUT ANY WARRANTY; without even the implied warranty of 10 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 * GNU General Public License for more details. 12 * 13 * You should have received a copy of the GNU General Public License 14 * along with this program. If not, see <http://www.gnu.org/licenses/>. */ 15 16 #include "smc.h" 17 #include "smc_device_c.h" 18 #include "smc_type_c.h" 19 20 #include <rsys/mem_allocator.h> 21 22 #include <limits.h> 23 #include <omp.h> 24 #include <string.h> 25 26 struct smc_estimator { 27 struct smc_type type; 28 void* value; 29 void* square_value; 30 size_t nsamples; 31 size_t nfailed; 32 33 struct smc_estimator_status status; 34 35 struct smc_device* dev; 36 ref_T ref; 37 }; 38 39 /******************************************************************************* 40 * Helper functions 41 ******************************************************************************/ 42 static res_T 43 estimator_create 44 (struct smc_device* dev, 45 const struct smc_type* type, 46 void* ctx, 47 struct smc_estimator** out_estimator) 48 { 49 struct smc_estimator* estimator = NULL; 50 res_T res = RES_OK; 51 ASSERT(out_estimator && dev && type); 52 53 estimator = MEM_CALLOC(dev->allocator, 1, sizeof(struct smc_estimator)); 54 if(!estimator) { 55 res = RES_MEM_ERR; 56 goto error; 57 } 58 SMC(device_ref_get(dev)); 59 estimator->dev = dev; 60 ref_init(&estimator->ref); 61 62 #define TYPE_CREATE(Dst) { \ 63 (Dst) = type->create(dev->allocator, ctx); \ 64 if(!(Dst)) { \ 65 res = RES_MEM_ERR; \ 66 goto error; \ 67 } \ 68 type->zero((Dst)); \ 69 } (void)0 70 TYPE_CREATE(estimator->value); 71 TYPE_CREATE(estimator->square_value); 72 TYPE_CREATE(estimator->status.E); 73 TYPE_CREATE(estimator->status.V); 74 TYPE_CREATE(estimator->status.SE); 75 #undef TYPE_CREATE 76 estimator->nsamples = 0; 77 estimator->nfailed = 0; 78 estimator->status.N = 0; 79 estimator->status.NF = 0; 80 estimator->type = *type; 81 82 exit: 83 *out_estimator = estimator; 84 return res; 85 error: 86 if(estimator) { 87 SMC(estimator_ref_put(estimator)); 88 estimator = NULL; 89 } 90 goto exit; 91 } 92 93 static char 94 check_integrator(struct smc_integrator* integrator) 95 { 96 ASSERT(integrator); 97 return integrator->integrand 98 && integrator->type 99 && integrator->max_realisations 100 && check_type(integrator->type); 101 } 102 103 static void 104 estimator_release(ref_T* ref) 105 { 106 struct smc_estimator* estimator; 107 struct smc_device* dev; 108 ASSERT(ref); 109 110 estimator = CONTAINER_OF(ref, struct smc_estimator, ref); 111 dev = estimator->dev; 112 if(estimator->value) 113 estimator->type.destroy(dev->allocator, estimator->value); 114 if(estimator->square_value) 115 estimator->type.destroy(dev->allocator, estimator->square_value); 116 if(estimator->status.E) 117 estimator->type.destroy(dev->allocator, estimator->status.E); 118 if(estimator->status.V) 119 estimator->type.destroy(dev->allocator, estimator->status.V); 120 if(estimator->status.SE) 121 estimator->type.destroy(dev->allocator, estimator->status.SE); 122 MEM_RM(dev->allocator, estimator); 123 SMC(device_ref_put(dev)); 124 } 125 126 /******************************************************************************* 127 * Exported functions 128 ******************************************************************************/ 129 res_T 130 smc_solve 131 (struct smc_device* dev, 132 struct smc_integrator* integrator, 133 void* ctx, 134 struct smc_estimator** out_estimator) 135 { 136 struct smc_estimator* estimator = NULL; 137 int64_t i; 138 unsigned nthreads = 0; 139 char* raw = NULL; 140 void** vals = NULL; 141 void** accums = NULL; 142 void** accums_sqr = NULL; 143 size_t* nsamples = NULL; 144 size_t nfailed = 0; 145 int progress = 0; 146 ATOMIC cancel = 0; 147 ATOMIC nsolved_realisations = 0; 148 res_T res = RES_OK; 149 150 if(!dev || !integrator || !out_estimator || !check_integrator(integrator)) { 151 res = RES_BAD_ARG; 152 goto error; 153 } 154 SMC(device_get_threads_count(dev, &nthreads)); 155 156 /* Create per thread temporary variables */ 157 raw = MEM_CALLOC(dev->allocator, nthreads, sizeof(void*)*3 + sizeof(size_t)); 158 if(!raw) { 159 res = RES_MEM_ERR; 160 goto error; 161 } 162 vals = (void**) (raw + 0 * sizeof(void*) * nthreads); 163 accums = (void**) (raw + 1 * sizeof(void*) * nthreads); 164 accums_sqr = (void**) (raw + 2 * sizeof(void*) * nthreads); 165 nsamples = (size_t*)(raw + 3 * sizeof(void*) * nthreads); 166 #define TYPE_CREATE(Dst) { \ 167 (Dst) = integrator->type->create(dev->allocator, ctx); \ 168 if(!(Dst)) { \ 169 res = RES_MEM_ERR; \ 170 goto error; \ 171 } \ 172 integrator->type->zero((Dst)); \ 173 } (void)0 174 FOR_EACH(i, 0, (int64_t)nthreads) { 175 TYPE_CREATE(vals[i]); 176 TYPE_CREATE(accums[i]); 177 TYPE_CREATE(accums_sqr[i]); 178 nsamples[i] = 0; 179 } 180 #undef TYPE_CREATE 181 182 /* Create the estimator */ 183 res = estimator_create(dev, integrator->type, ctx, &estimator); 184 if(res != RES_OK) goto error; 185 186 /* Parallel evaluation of the simulation */ 187 log_info(dev, "Solving: %3d%%\r", progress); 188 #pragma omp parallel for schedule(static) 189 for(i = 0; i < (int64_t)integrator->max_realisations; ++i) { 190 const int ithread = omp_get_thread_num(); 191 int64_t n = 0; 192 int pcent = 0; 193 res_T res_local = RES_OK; 194 195 if(ATOMIC_GET(&cancel)) continue; 196 197 res_local = integrator->integrand 198 (vals[ithread], dev->rngs[ithread], (unsigned)ithread, (uint64_t)i, ctx); 199 200 if(res_local != RES_OK) { 201 #pragma omp critical 202 { 203 nfailed += 1; 204 if(nfailed > integrator->max_failures) { 205 ATOMIC_SET(&cancel, 1); 206 } 207 } 208 continue; 209 } 210 211 /* call succeded */ 212 integrator->type->add(accums[ithread], accums[ithread], vals[ithread]); 213 integrator->type->mul(vals[ithread], vals[ithread], vals[ithread]); 214 integrator->type->add(accums_sqr[ithread], accums_sqr[ithread], vals[ithread]); 215 ++nsamples[ithread]; 216 217 n = ATOMIC_INCR(&nsolved_realisations); 218 pcent = (int) 219 ((double)n * 100.0 / (double)integrator->max_realisations + 0.5/*round*/); 220 #pragma omp critical 221 if(pcent > progress) { 222 progress = pcent; 223 log_info(dev, "Solving: %3d%%\r", progress); 224 } 225 } 226 log_info(dev, "Solving: %3d%%\n", progress); 227 228 /* Merge the parallel estimation into the final estimator */ 229 FOR_EACH(i, 0, (int64_t)nthreads) { 230 estimator->nsamples += nsamples[i]; 231 integrator->type->add(estimator->value, estimator->value, accums[i]); 232 integrator->type->add 233 (estimator->square_value, estimator->square_value, accums_sqr[i]); 234 } 235 estimator->nfailed = nfailed; 236 237 exit: 238 if(raw) { /* Release temporary variables */ 239 FOR_EACH(i, 0, (int64_t)nthreads) { 240 if(vals[i]) integrator->type->destroy(dev->allocator, vals[i]); 241 if(accums[i]) integrator->type->destroy(dev->allocator, accums[i]); 242 if(accums_sqr[i]) integrator->type->destroy(dev->allocator, accums_sqr[i]); 243 } 244 MEM_RM(dev->allocator, raw); 245 } 246 if(out_estimator) *out_estimator = estimator; 247 return res; 248 error: 249 if(estimator) { 250 SMC(estimator_ref_put(estimator)); 251 estimator = NULL; 252 } 253 goto exit; 254 } 255 256 res_T 257 smc_solve_N 258 (struct smc_device* dev, 259 struct smc_integrator* integrator, 260 const size_t count, 261 void* ctx, 262 const size_t sizeof_ctx, 263 struct smc_estimator* estimators[]) 264 { 265 void** vals = NULL; 266 int64_t i; 267 unsigned nthreads = 0; 268 int progress = 0; 269 ATOMIC cancel = 0; 270 ATOMIC nsolved = 0; 271 res_T res = RES_OK; 272 273 if(!estimators) { 274 res = RES_BAD_ARG; 275 goto error; 276 } 277 memset(estimators, 0, sizeof(struct smc_estimator*) * count); 278 279 if(!dev || !integrator || !count || !check_integrator(integrator)) { 280 res = RES_BAD_ARG; 281 goto error; 282 } 283 284 /* Create the estimators */ 285 FOR_EACH(i, 0, (int64_t)count) { 286 res = estimator_create 287 (dev, integrator->type, (char*)ctx + (size_t)i*sizeof_ctx, estimators+i); 288 if(res != RES_OK) goto error; 289 } 290 291 /* Create the per thread temporary variables */ 292 SMC(device_get_threads_count(dev, &nthreads)); 293 vals = MEM_CALLOC(dev->allocator, nthreads, sizeof(void*)); 294 FOR_EACH(i, 0, (int64_t)nthreads) { 295 vals[i] = integrator->type->create 296 (dev->allocator, (char*)ctx + (size_t)i*sizeof_ctx); 297 if(!vals[i]) { 298 res = RES_MEM_ERR; 299 goto error; 300 } 301 } 302 303 /* Parallel estimation of N simulations */ 304 log_info(dev, "Solving: %3d%%\r", progress); 305 #pragma omp parallel for schedule(static, 1) 306 for(i = 0; i < (int64_t)count; ++i) { 307 size_t istep; 308 int64_t n = 0; 309 int pcent = 0; 310 const int ithread = omp_get_thread_num(); 311 res_T res_local = RES_OK; 312 313 if(ATOMIC_GET(&cancel)) continue; 314 315 FOR_EACH(istep, 0, integrator->max_realisations) { 316 if(ATOMIC_GET(&cancel)) break; 317 318 res_local = integrator->integrand 319 (vals[ithread], dev->rngs[ithread], (unsigned)ithread, (uint64_t)i, 320 (char*)ctx + (size_t)i*sizeof_ctx); 321 322 if(res_local != RES_OK) { 323 ++estimators[i]->nfailed; 324 if(estimators[i]->nfailed > integrator->max_failures) { 325 ATOMIC_SET(&cancel, 1); 326 } 327 break; 328 } 329 330 /* call succeded */ 331 integrator->type->add 332 (estimators[i]->value, estimators[i]->value, vals[ithread]); 333 integrator->type->mul 334 (vals[ithread], vals[ithread], vals[ithread]); 335 integrator->type->add 336 (estimators[i]->square_value, estimators[i]->square_value, vals[ithread]); 337 ++estimators[i]->nsamples; 338 } 339 340 n = ATOMIC_INCR(&nsolved); 341 pcent = (int)((double)n * 100.0 / (double)count + 0.5/*round*/); 342 #pragma omp critical 343 if(pcent > progress) { 344 progress = pcent; 345 log_info(dev, "Solving: %3d%%\r", progress); 346 } 347 348 } 349 log_info(dev, "Solving: %3d%%\n", progress); 350 351 exit: 352 if(vals) { 353 FOR_EACH(i, 0, (int64_t)nthreads) { 354 if(vals[i]) integrator->type->destroy(dev->allocator, vals[i]); 355 } 356 MEM_RM(dev->allocator, vals); 357 } 358 return res; 359 error: 360 if(estimators) { 361 FOR_EACH(i, 0, (int64_t)count) { 362 if(estimators[i]) { 363 SMC(estimator_ref_put(estimators[i])); 364 estimators[i] = NULL; 365 } 366 } 367 } 368 goto exit; 369 } 370 371 res_T 372 smc_estimator_ref_get(struct smc_estimator* estimator) 373 { 374 if(!estimator) return RES_BAD_ARG; 375 ref_get(&estimator->ref); 376 return RES_OK; 377 } 378 379 res_T 380 smc_estimator_ref_put(struct smc_estimator* estimator) 381 { 382 if(!estimator) return RES_BAD_ARG; 383 ref_put(&estimator->ref, estimator_release); 384 return RES_OK; 385 } 386 387 res_T 388 smc_estimator_get_status 389 (struct smc_estimator* estimator, 390 struct smc_estimator_status* status) 391 { 392 if(!estimator || !status) 393 return RES_BAD_ARG; 394 395 if(estimator->nsamples != estimator->status.N 396 || estimator->nfailed != estimator->status.NF) 397 { 398 estimator->status.N = estimator->nsamples; 399 estimator->status.NF = estimator->nfailed; 400 401 if(estimator->nsamples > 0) { 402 /* Variance */ 403 estimator->type.divi 404 (estimator->status.E, estimator->square_value, estimator->nsamples); 405 estimator->type.mul 406 (estimator->status.V, estimator->value, estimator->value); 407 estimator->type.divi 408 (estimator->status.V, 409 estimator->status.V, 410 estimator->nsamples * estimator->nsamples); 411 estimator->type.sub 412 (estimator->status.V, estimator->status.E, estimator->status.V); 413 /* Standard error */ 414 estimator->type.divi 415 (estimator->status.SE, estimator->status.V, estimator->nsamples); 416 estimator->type.sqrt(estimator->status.SE, estimator->status.SE); 417 /* Expected value */ 418 estimator->type.divi 419 (estimator->status.E, estimator->value, estimator->nsamples); 420 } 421 else { 422 estimator->type.zero(estimator->status.E); 423 estimator->type.zero(estimator->status.SE); 424 estimator->type.zero(estimator->status.V); 425 } 426 } 427 *status = estimator->status; 428 return RES_OK; 429 }