utils.io
File and user I/O.
- class utils.io.DataAccumulatorHook(component: Module, log_dir: str, attributes: list[str], entries: int)
Wraps any Module object, accumulating selected attribute data on every forward call. If used with a Component, selectable attributes are those included in the dictionary returned by the Module's forward method, and are a subset of the object's loggable_props.
- Parameters:
component (torch.nn.Module) – Reference to the network component object whose loggable buffered attributes should be dumped to disk.
log_dir (str) – Path to the directory where the logged data should be written.
attributes (list of str) – Instance attribute names to be logged. Defaults to the full list component.loggable_props.
entries (int) – Number of log entries, equal to the number of simulation steps. Used to ensure that the last batch of the run is written to disk even though it is smaller than CHUNK_SIZE.
- forward(data: Tensor) → None
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- initialize_hdf() → None
Initializes the HDF5 file on the first iteration of this data accumulator instance.
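The role of entries and CHUNK_SIZE can be illustrated with a minimal buffering sketch. It is not the library's implementation: ChunkedAccumulator, write_chunk and the value CHUNK_SIZE = 4 are hypothetical, and plain floats stand in for tensors and the HDF5 file. It shows why the total number of entries must be known: without it, a final chunk smaller than CHUNK_SIZE would never be flushed.

```python
from typing import Callable, List

CHUNK_SIZE = 4  # hypothetical value; the real constant is defined by the library


class ChunkedAccumulator:
    """Hypothetical stand-in for the chunked buffering implied by `entries`.

    Records are buffered in memory and handed to `write_chunk` whenever the
    buffer reaches CHUNK_SIZE, or when the last expected entry arrives.
    """

    def __init__(self, entries: int, write_chunk: Callable[[List[float]], None]):
        self.entries = entries          # total number of calls expected in the run
        self.write_chunk = write_chunk  # e.g. a function appending to an HDF5 dataset
        self.buffer: List[float] = []
        self.seen = 0

    def __call__(self, value: float) -> None:
        self.buffer.append(value)
        self.seen += 1
        # Flush a full chunk, or the (possibly smaller) final chunk of the run.
        if len(self.buffer) == CHUNK_SIZE or self.seen == self.entries:
            self.write_chunk(self.buffer)
            self.buffer = []  # rebind so the written chunk is not mutated later


# Usage: 10 simulation steps with a chunk size of 4.
chunks: List[List[float]] = []
acc = ChunkedAccumulator(entries=10, write_chunk=chunks.append)
for step in range(10):
    acc(float(step))
```

With 10 entries this yields two full chunks of four values and a final chunk of two, which is the case the entries parameter exists to handle.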
- utils.io.ensure_dir(path: str = None) → str
Ensures that a path exists, creating any missing parts of it, and returns it.
- Parameters:
path (str) – File path to be verified or created.
- Returns:
path – The same file path, so it can be used directly in the calling context (e.g., when ensure_dir() wraps the value of an assignment).
- Return type:
str
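A minimal sketch of the behavior described above, assuming path names a directory and that parent directories should be created as needed. ensure_dir_sketch is a hypothetical name, and returning None unchanged when no path is given is a guess about the original's handling of the default.

```python
import os
import tempfile
from typing import Optional


def ensure_dir_sketch(path: Optional[str] = None) -> Optional[str]:
    # Assumption: `path` is a directory; create it and any missing parents.
    if path is not None:
        os.makedirs(path, exist_ok=True)  # no error if it already exists
    # Return the same path so the call can wrap an assignment.
    return path


# Usage: the returned path is assigned directly.
with tempfile.TemporaryDirectory() as base:
    log_dir = ensure_dir_sketch(os.path.join(base, "run_0", "logs"))
    created = os.path.isdir(log_dir)
```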
- utils.io.flatten(lst: list[list]) → list
Flattens a list of lists into a single list.
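One idiomatic way to flatten a list of lists is itertools.chain.from_iterable. This sketch assumes flattening is one level deep, as the list[list] type hint suggests; flatten_sketch is a hypothetical name, not the module's function.

```python
from itertools import chain
from typing import Any, List


def flatten_sketch(lst: List[list]) -> List[Any]:
    # Assumption: only one level of nesting is removed.
    return list(chain.from_iterable(lst))


flat = flatten_sketch([[1, 2], [3], []])
```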
- utils.io.load_yaml(path: str) → dict
Parses the YAML file at path and returns its contents as a dictionary.
- utils.io.log_settings(level: int = 20, stream: TextIOWrapper = sys.stdout, file: str = None, formatting: str = '%(asctime)s [%(levelname)s] %(message)s')
Temporary basic logger configuration.
- Parameters:
level (int) – Logger level, INFO by default.
stream – Destination stream, stdout by default.
file – Destination file, None by default.
formatting (str) – Message formatting.
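A configuration like this can be expressed with the standard library's logging.basicConfig. The sketch below is an assumption about how the parameters map, not the module's code: log_settings_sketch is hypothetical, and because basicConfig refuses stream and filename together, it gives file precedence. force=True (Python 3.8+) replaces any existing root handlers, which is one way to make repeated reconfiguration take effect.

```python
import io
import logging
import sys
from typing import Optional, TextIO


def log_settings_sketch(
    level: int = logging.INFO,  # 20, as in the documented default
    stream: TextIO = sys.stdout,
    file: Optional[str] = None,
    formatting: str = "%(asctime)s [%(levelname)s] %(message)s",
) -> None:
    # Assumption: a destination file takes precedence over the stream.
    if file is not None:
        logging.basicConfig(level=level, filename=file, format=formatting, force=True)
    else:
        logging.basicConfig(level=level, stream=stream, format=formatting, force=True)


# Usage: route INFO-level records to an in-memory stream.
buffer = io.StringIO()
log_settings_sketch(stream=buffer)
logging.info("simulation started")
logging.debug("not shown at INFO level")
captured = buffer.getvalue()
```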
- utils.io.plot_tensorboard(configuration: dict, data_dir: str, tb_dir: str)
Loads simulation output data from disk and writes it to TensorBoard for visual inspection.
- utils.io.save_yaml(data: dict, path: str)
Saves the dictionary data to the YAML file specified by path.