Source code for wepy.reporter.walker_pkl
# Standard Library
import logging
logger = logging.getLogger(__name__)
# Standard Library
import os
import os.path as osp
import pickle
# First Party Library
from wepy.reporter.reporter import Reporter
[docs]
class WalkerPklReporter(Reporter):
    """Reporter that periodically pickles the current walkers as backups.

    Every ``freq`` completed cycles, the walkers passed to :meth:`report`
    are pickled to ``walkers_cycle_{cycle_idx}.pkl`` inside ``save_dir``.
    After each save, the pickle written ``num_backups`` saves earlier is
    deleted. Assuming cycle indices count up from 0 in steps of 1, this
    keeps only the ``num_backups`` most recent pickles. With
    ``num_backups=0``, each pickle is deleted right after it is written.

    Parameters
    ----------
    save_dir : str
        Directory in which to save the pickles. It is created by
        :meth:`init` if it does not already exist.
    freq : int
        Number of cycles between backups. Must be positive.
    num_backups : int
        Number of most recent walker pickles to keep. Must be >= 0.

    Raises
    ------
    ValueError
        If ``freq`` is not positive or ``num_backups`` is negative.
    """

    def __init__(self, save_dir="./", freq=100, num_backups=2):
        # Validate eagerly. Otherwise freq <= 0 only surfaces mid-simulation
        # (ZeroDivisionError or a never-matching modulo), and a negative
        # num_backups would try to delete a pickle that was never written.
        if freq <= 0:
            raise ValueError("freq must be positive, got {}".format(freq))
        if num_backups < 0:
            raise ValueError(
                "num_backups must be non-negative, got {}".format(num_backups)
            )

        # the directory in which to save the pickles
        self.save_dir = save_dir
        # the frequency of cycles to backup the walkers as a pickle
        self.backup_freq = freq
        # the number of most recent walker pickles to keep, this will remove the rest
        self.num_backups = num_backups

    def init(self, *args, **kwargs):
        """Create ``save_dir`` (and any missing parents) if needed.

        Extra positional and keyword arguments are accepted for the
        reporter interface and ignored.
        """
        # exist_ok=True avoids the race of a separate exists()/makedirs() check.
        # It still raises FileExistsError if save_dir exists but is not a directory.
        os.makedirs(self.save_dir, exist_ok=True)

    def _pkl_path(self, cycle_idx):
        """Return the path of the walker pickle for cycle ``cycle_idx``."""
        return osp.join(self.save_dir, "walkers_cycle_{}.pkl".format(cycle_idx))

    def report(self, cycle_idx=None, new_walkers=None, **kwargs):
        """Pickle ``new_walkers`` when this cycle lands on the backup frequency.

        Parameters
        ----------
        cycle_idx : int
            Zero-based index of the cycle just completed. It defaults to
            None but is effectively required: passing None raises TypeError.
        new_walkers : object
            The walkers to back up. They are pickled as-is.
        **kwargs
            Other reporter data. Ignored.
        """
        # total number of cycles completed
        n_cycles = cycle_idx + 1

        # only back up on cycles that land on the frequency
        if n_cycles % self.backup_freq != 0:
            return

        pkl_path = self._pkl_path(cycle_idx)

        # Write to a temporary file, then atomically swap it into place. This way
        # a job killed mid-dump never leaves a truncated pickle under the real name.
        tmp_path = pkl_path + ".tmp"
        with open(tmp_path, "wb") as wf:
            pickle.dump(new_walkers, wf)
        os.replace(tmp_path, pkl_path)

        # Remove the pickle written num_backups saves ago. This save is number
        # k = cycle_idx // backup_freq + 1, so an older save exists to remove
        # once k > num_backups.
        if (cycle_idx // self.backup_freq) >= self.num_backups:
            old_idx = cycle_idx - self.num_backups * self.backup_freq
            old_path = self._pkl_path(old_idx)
            try:
                os.remove(old_path)
            except FileNotFoundError:
                # Cleanup is best-effort. A missing old backup (e.g. deleted by
                # hand) should not abort a running simulation.
                logger.warning(
                    "Old walker pickle %s not found; skipping removal", old_path
                )