star-mc

Parallel estimation of Monte Carlo integrators
git clone git://git.meso-star.fr/star-mc.git
Log | Files | Refs | README | LICENSE

smc_estimator.c (12788B)


      1 /* Copyright (C) 2015-2018, 2021-2023 |Méso|Star> (contact@meso-star.com)
      2  *
      3  * This program is free software: you can redistribute it and/or modify
      4  * it under the terms of the GNU General Public License as published by
      5  * the Free Software Foundation, either version 3 of the License, or
      6  * (at your option) any later version.
      7  *
      8  * This program is distributed in the hope that it will be useful,
      9  * but WITHOUT ANY WARRANTY; without even the implied warranty of
     10  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
     11  * GNU General Public License for more details.
     12  *
     13  * You should have received a copy of the GNU General Public License
     14  * along with this program. If not, see <http://www.gnu.org/licenses/>. */
     15 
     16 #include "smc.h"
     17 #include "smc_device_c.h"
     18 #include "smc_type_c.h"
     19 
     20 #include <rsys/mem_allocator.h>
     21 
     22 #include <limits.h>
     23 #include <omp.h>
     24 #include <string.h>
     25 
/* Result of a Monte Carlo integration. The accumulators are opaque values
 * handled exclusively through the functions of `type'. */
struct smc_estimator {
  struct smc_type type; /* Copy of the type used to create/handle the values */
  void* value; /* Sum of the realisations */
  void* square_value; /* Sum of the squared realisations */
  size_t nsamples; /* Number of realisations accumulated in the sums */
  size_t nfailed; /* Number of integrand calls that returned an error */

  /* Statistics returned by smc_estimator_get_status. They are recomputed
   * there only when nsamples/nfailed differ from status.N/status.NF */
  struct smc_estimator_status status;

  struct smc_device* dev; /* The estimator holds a reference on its device */
  ref_T ref; /* Reference counter of the estimator */
};
     38 
     39 /*******************************************************************************
     40  * Helper functions
     41  ******************************************************************************/
     42 static res_T
     43 estimator_create
     44   (struct smc_device* dev,
     45    const struct smc_type* type,
     46    void* ctx,
     47    struct smc_estimator** out_estimator)
     48 {
     49   struct smc_estimator* estimator = NULL;
     50   res_T res = RES_OK;
     51   ASSERT(out_estimator && dev && type);
     52 
     53   estimator = MEM_CALLOC(dev->allocator, 1, sizeof(struct smc_estimator));
     54   if(!estimator) {
     55     res = RES_MEM_ERR;
     56     goto error;
     57   }
     58   SMC(device_ref_get(dev));
     59   estimator->dev = dev;
     60   ref_init(&estimator->ref);
     61 
     62   #define TYPE_CREATE(Dst) {                                                   \
     63     (Dst) = type->create(dev->allocator, ctx);                                 \
     64     if(!(Dst)) {                                                               \
     65       res = RES_MEM_ERR;                                                       \
     66       goto error;                                                              \
     67     }                                                                          \
     68     type->zero((Dst));                                                         \
     69   } (void)0
     70   TYPE_CREATE(estimator->value);
     71   TYPE_CREATE(estimator->square_value);
     72   TYPE_CREATE(estimator->status.E);
     73   TYPE_CREATE(estimator->status.V);
     74   TYPE_CREATE(estimator->status.SE);
     75   #undef TYPE_CREATE
     76   estimator->nsamples = 0;
     77   estimator->nfailed = 0;
     78   estimator->status.N = 0;
     79   estimator->status.NF = 0;
     80   estimator->type = *type;
     81 
     82 exit:
     83   *out_estimator = estimator;
     84   return res;
     85 error:
     86   if(estimator) {
     87     SMC(estimator_ref_put(estimator));
     88     estimator = NULL;
     89   }
     90   goto exit;
     91 }
     92 
     93 static char
     94 check_integrator(struct smc_integrator* integrator)
     95 {
     96   ASSERT(integrator);
     97   return integrator->integrand
     98       && integrator->type
     99       && integrator->max_realisations
    100       && check_type(integrator->type);
    101 }
    102 
    103 static void
    104 estimator_release(ref_T* ref)
    105 {
    106   struct smc_estimator* estimator;
    107   struct smc_device* dev;
    108   ASSERT(ref);
    109 
    110   estimator = CONTAINER_OF(ref, struct smc_estimator, ref);
    111   dev = estimator->dev;
    112   if(estimator->value)
    113     estimator->type.destroy(dev->allocator, estimator->value);
    114   if(estimator->square_value)
    115     estimator->type.destroy(dev->allocator, estimator->square_value);
    116   if(estimator->status.E)
    117     estimator->type.destroy(dev->allocator, estimator->status.E);
    118   if(estimator->status.V)
    119     estimator->type.destroy(dev->allocator, estimator->status.V);
    120   if(estimator->status.SE)
    121     estimator->type.destroy(dev->allocator, estimator->status.SE);
    122   MEM_RM(dev->allocator, estimator);
    123   SMC(device_ref_put(dev));
    124 }
    125 
    126 /*******************************************************************************
    127  * Exported functions
    128  ******************************************************************************/
    129 res_T
    130 smc_solve
    131   (struct smc_device* dev,
    132    struct smc_integrator* integrator,
    133    void* ctx,
    134    struct smc_estimator** out_estimator)
    135 {
    136   struct smc_estimator* estimator = NULL;
    137   int64_t i;
    138   unsigned nthreads = 0;
    139   char* raw = NULL;
    140   void** vals = NULL;
    141   void** accums = NULL;
    142   void** accums_sqr = NULL;
    143   size_t* nsamples = NULL;
    144   size_t nfailed = 0;
    145   int progress = 0;
    146   ATOMIC cancel = 0;
    147   ATOMIC nsolved_realisations = 0;
    148   res_T res = RES_OK;
    149 
    150   if(!dev || !integrator || !out_estimator || !check_integrator(integrator)) {
    151     res = RES_BAD_ARG;
    152     goto error;
    153   }
    154   SMC(device_get_threads_count(dev, &nthreads));
    155 
    156   /* Create per thread temporary variables */
    157   raw = MEM_CALLOC(dev->allocator, nthreads, sizeof(void*)*3 + sizeof(size_t));
    158   if(!raw) {
    159     res = RES_MEM_ERR;
    160     goto error;
    161   }
    162   vals       = (void**) (raw + 0 * sizeof(void*) * nthreads);
    163   accums     = (void**) (raw + 1 * sizeof(void*) * nthreads);
    164   accums_sqr = (void**) (raw + 2 * sizeof(void*) * nthreads);
    165   nsamples   = (size_t*)(raw + 3 * sizeof(void*) * nthreads);
    166   #define TYPE_CREATE(Dst) {                                                   \
    167     (Dst) = integrator->type->create(dev->allocator, ctx);                     \
    168     if(!(Dst)) {                                                               \
    169       res = RES_MEM_ERR;                                                       \
    170       goto error;                                                              \
    171     }                                                                          \
    172     integrator->type->zero((Dst));                                             \
    173   } (void)0
    174   FOR_EACH(i, 0, (int64_t)nthreads) {
    175     TYPE_CREATE(vals[i]);
    176     TYPE_CREATE(accums[i]);
    177     TYPE_CREATE(accums_sqr[i]);
    178     nsamples[i] = 0;
    179   }
    180   #undef TYPE_CREATE
    181 
    182   /* Create the estimator */
    183   res = estimator_create(dev, integrator->type, ctx, &estimator);
    184   if(res != RES_OK) goto error;
    185 
    186   /* Parallel evaluation of the simulation */
    187   log_info(dev, "Solving: %3d%%\r", progress);
    188   #pragma omp parallel for schedule(static)
    189   for(i = 0; i < (int64_t)integrator->max_realisations; ++i) {
    190     const int ithread = omp_get_thread_num();
    191     int64_t n = 0;
    192     int pcent = 0;
    193     res_T res_local = RES_OK;
    194 
    195     if(ATOMIC_GET(&cancel)) continue;
    196 
    197     res_local = integrator->integrand
    198       (vals[ithread], dev->rngs[ithread], (unsigned)ithread, (uint64_t)i, ctx);
    199 
    200     if(res_local != RES_OK) {
    201       #pragma omp critical
    202       {
    203         nfailed += 1;
    204         if(nfailed > integrator->max_failures) {
    205           ATOMIC_SET(&cancel, 1);
    206         }
    207       }
    208       continue;
    209     }
    210 
    211     /* call succeded */
    212     integrator->type->add(accums[ithread], accums[ithread], vals[ithread]);
    213     integrator->type->mul(vals[ithread], vals[ithread], vals[ithread]);
    214     integrator->type->add(accums_sqr[ithread], accums_sqr[ithread], vals[ithread]);
    215     ++nsamples[ithread];
    216 
    217     n = ATOMIC_INCR(&nsolved_realisations);
    218     pcent = (int)
    219       ((double)n * 100.0 / (double)integrator->max_realisations + 0.5/*round*/);
    220     #pragma omp critical
    221     if(pcent > progress) {
    222       progress = pcent;
    223       log_info(dev, "Solving: %3d%%\r", progress);
    224     }
    225   }
    226   log_info(dev, "Solving: %3d%%\n", progress);
    227 
    228   /* Merge the parallel estimation into the final estimator */
    229   FOR_EACH(i, 0, (int64_t)nthreads) {
    230     estimator->nsamples += nsamples[i];
    231     integrator->type->add(estimator->value, estimator->value, accums[i]);
    232     integrator->type->add
    233       (estimator->square_value, estimator->square_value, accums_sqr[i]);
    234   }
    235   estimator->nfailed = nfailed;
    236 
    237 exit:
    238   if(raw) { /* Release temporary variables */
    239     FOR_EACH(i, 0, (int64_t)nthreads) {
    240       if(vals[i]) integrator->type->destroy(dev->allocator, vals[i]);
    241       if(accums[i]) integrator->type->destroy(dev->allocator, accums[i]);
    242       if(accums_sqr[i]) integrator->type->destroy(dev->allocator, accums_sqr[i]);
    243     }
    244     MEM_RM(dev->allocator, raw);
    245   }
    246   if(out_estimator) *out_estimator = estimator;
    247   return res;
    248 error:
    249   if(estimator) {
    250     SMC(estimator_ref_put(estimator));
    251     estimator = NULL;
    252   }
    253   goto exit;
    254 }
    255 
    256 res_T
    257 smc_solve_N
    258   (struct smc_device* dev,
    259    struct smc_integrator* integrator,
    260    const size_t count,
    261    void* ctx,
    262    const size_t sizeof_ctx,
    263    struct smc_estimator* estimators[])
    264 {
    265   void** vals = NULL;
    266   int64_t i;
    267   unsigned nthreads = 0;
    268   int progress = 0;
    269   ATOMIC cancel = 0;
    270   ATOMIC nsolved = 0;
    271   res_T res = RES_OK;
    272 
    273   if(!estimators) {
    274     res = RES_BAD_ARG;
    275     goto error;
    276   }
    277   memset(estimators, 0, sizeof(struct smc_estimator*) * count);
    278 
    279   if(!dev || !integrator || !count || !check_integrator(integrator)) {
    280     res = RES_BAD_ARG;
    281     goto error;
    282   }
    283 
    284   /* Create the estimators */
    285   FOR_EACH(i, 0, (int64_t)count) {
    286     res = estimator_create
    287       (dev, integrator->type, (char*)ctx + (size_t)i*sizeof_ctx, estimators+i);
    288     if(res != RES_OK) goto error;
    289   }
    290 
    291   /* Create the per thread temporary variables */
    292   SMC(device_get_threads_count(dev, &nthreads));
    293   vals = MEM_CALLOC(dev->allocator, nthreads, sizeof(void*));
    294   FOR_EACH(i, 0, (int64_t)nthreads) {
    295     vals[i] = integrator->type->create
    296       (dev->allocator, (char*)ctx + (size_t)i*sizeof_ctx);
    297     if(!vals[i]) {
    298       res = RES_MEM_ERR;
    299       goto error;
    300     }
    301   }
    302 
    303   /* Parallel estimation of N simulations */
    304   log_info(dev, "Solving: %3d%%\r", progress);
    305   #pragma omp parallel for schedule(static, 1)
    306   for(i = 0; i < (int64_t)count; ++i) {
    307     size_t istep;
    308     int64_t n = 0;
    309     int pcent = 0;
    310     const int ithread = omp_get_thread_num();
    311     res_T res_local = RES_OK;
    312 
    313     if(ATOMIC_GET(&cancel)) continue;
    314 
    315     FOR_EACH(istep, 0, integrator->max_realisations) {
    316       if(ATOMIC_GET(&cancel)) break;
    317 
    318       res_local = integrator->integrand
    319         (vals[ithread], dev->rngs[ithread], (unsigned)ithread, (uint64_t)i,
    320          (char*)ctx + (size_t)i*sizeof_ctx);
    321 
    322       if(res_local != RES_OK) {
    323         ++estimators[i]->nfailed;
    324         if(estimators[i]->nfailed > integrator->max_failures) {
    325           ATOMIC_SET(&cancel, 1);
    326         }
    327         break;
    328       }
    329 
    330       /* call succeded */
    331       integrator->type->add
    332         (estimators[i]->value, estimators[i]->value, vals[ithread]);
    333       integrator->type->mul
    334         (vals[ithread], vals[ithread], vals[ithread]);
    335       integrator->type->add
    336         (estimators[i]->square_value, estimators[i]->square_value, vals[ithread]);
    337       ++estimators[i]->nsamples;
    338     }
    339 
    340     n = ATOMIC_INCR(&nsolved);
    341     pcent = (int)((double)n * 100.0 / (double)count + 0.5/*round*/);
    342     #pragma omp critical
    343     if(pcent > progress) {
    344       progress = pcent;
    345       log_info(dev, "Solving: %3d%%\r", progress);
    346     }
    347     
    348   }
    349   log_info(dev, "Solving: %3d%%\n", progress);
    350 
    351 exit:
    352   if(vals) {
    353     FOR_EACH(i, 0, (int64_t)nthreads) {
    354       if(vals[i]) integrator->type->destroy(dev->allocator, vals[i]);
    355     }
    356     MEM_RM(dev->allocator, vals);
    357   }
    358   return res;
    359 error:
    360   if(estimators) {
    361     FOR_EACH(i, 0, (int64_t)count) {
    362       if(estimators[i]) {
    363         SMC(estimator_ref_put(estimators[i]));
    364         estimators[i] = NULL;
    365       }
    366     }
    367   }
    368   goto exit;
    369 }
    370 
    371 res_T
    372 smc_estimator_ref_get(struct smc_estimator* estimator)
    373 {
    374   if(!estimator) return RES_BAD_ARG;
    375   ref_get(&estimator->ref);
    376   return RES_OK;
    377 }
    378 
    379 res_T
    380 smc_estimator_ref_put(struct smc_estimator* estimator)
    381 {
    382   if(!estimator) return RES_BAD_ARG;
    383   ref_put(&estimator->ref, estimator_release);
    384   return RES_OK;
    385 }
    386 
    387 res_T
    388 smc_estimator_get_status
    389   (struct smc_estimator* estimator,
    390    struct smc_estimator_status* status)
    391 {
    392   if(!estimator || !status)
    393     return RES_BAD_ARG;
    394 
    395   if(estimator->nsamples != estimator->status.N
    396     || estimator->nfailed != estimator->status.NF)
    397   {
    398     estimator->status.N = estimator->nsamples;
    399     estimator->status.NF = estimator->nfailed;
    400 
    401     if(estimator->nsamples > 0) {
    402       /* Variance */
    403       estimator->type.divi
    404         (estimator->status.E, estimator->square_value, estimator->nsamples);
    405       estimator->type.mul
    406         (estimator->status.V, estimator->value, estimator->value);
    407       estimator->type.divi
    408         (estimator->status.V,
    409          estimator->status.V,
    410          estimator->nsamples * estimator->nsamples);
    411       estimator->type.sub
    412         (estimator->status.V, estimator->status.E, estimator->status.V);
    413       /* Standard error */
    414       estimator->type.divi
    415         (estimator->status.SE, estimator->status.V, estimator->nsamples);
    416       estimator->type.sqrt(estimator->status.SE, estimator->status.SE);
    417       /* Expected value */
    418       estimator->type.divi
    419         (estimator->status.E, estimator->value,  estimator->nsamples);
    420     }
    421     else {
    422       estimator->type.zero(estimator->status.E);
    423       estimator->type.zero(estimator->status.SE);
    424       estimator->type.zero(estimator->status.V);
    425     }
    426   }
    427   *status = estimator->status;
    428   return RES_OK;
    429 }