utils.io

File and user I/O.

class utils.io.DataAccumulatorHook(component: Module, log_dir: str, attributes: list[str], entries: int)

Wraps any Module object, accumulating selected attribute data on every forward call.

If used with a Component, selectable attributes are those included in the dictionary returned by the Module’s forward method, and are a subset of the object’s _loggable_props_.

Parameters:
  • component (torch.nn.Module) – Reference to the network component object whose loggable buffered attributes should be dumped to disk.

  • log_dir (str) – Path to the directory where the logged data should be written.

  • attributes (list of str) – Instance attribute names to be logged. Defaults to the full list component.loggable_props.

  • entries (int) – Number of log entries, equal to the number of simulation steps. Used to ensure that the last batch of the run is written to disk even though it is smaller than CHUNK_SIZE.

forward(data: Tensor) → None

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass must be defined within this function, callers should invoke the Module instance itself rather than calling forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.

initialize_hdf() → None

Initializes the HDF5 file on the first iteration of this data accumulator instance.
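The interplay between CHUNK_SIZE, entries and the lazily created output file can be illustrated with a minimal, dependency-free sketch. Everything in it is hypothetical: AccumulatorSketch and initialize_storage() are illustrative names, CHUNK_SIZE = 4 is an arbitrary value, and a plain Python list stands in for the HDF5 file. It assumes the wrapped component returns a dictionary per step, as described above for Components.

```python
from typing import Any, Dict, List, Optional

CHUNK_SIZE = 4  # hypothetical value; the real constant's value is not documented here


class AccumulatorSketch:
    """Hypothetical stand-in for DataAccumulatorHook's buffering logic."""

    def __init__(self, attributes: List[str], entries: int):
        self.attributes = attributes
        self.entries = entries  # total number of steps expected in the run
        self.step = 0
        self.buffer: List[Dict[str, Any]] = []
        # A list of flushed chunks stands in for the HDF5 file
        self.written: Optional[List[List[Dict[str, Any]]]] = None

    def initialize_storage(self) -> None:
        # Created lazily on the first call, mirroring initialize_hdf()
        self.written = []

    def __call__(self, data: Dict[str, Any]) -> None:
        if self.written is None:
            self.initialize_storage()
        # Keep only the selected attributes from the component's output dictionary
        self.buffer.append({k: data[k] for k in self.attributes})
        self.step += 1
        # Flush a full chunk, or the final partial chunk once the last step is reached
        if len(self.buffer) == CHUNK_SIZE or self.step == self.entries:
            self.written.append(self.buffer)
            self.buffer = []
```

With entries=10 and CHUNK_SIZE=4, the sketch writes chunks of 4, 4 and 2 records. Without the step == entries check, the final 2 records would never leave the buffer, which is the situation the entries parameter is documented to prevent.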

utils.io.ensure_dir(path: str = None) → str

Creates any missing parts of path and returns it.

Parameters:

path (str) – File path to be verified or created.

Returns:

path – The same file path, so it can be used in the calling context (e.g., if ensure_dir() wraps an assignment).

Return type:

str
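A minimal sketch of the documented behavior, assuming path names a directory (the implementation may instead treat it as a file path and create only its parent directories). ensure_dir_sketch is a hypothetical name, and the None default of the real function is not modeled here.

```python
import os


def ensure_dir_sketch(path: str) -> str:
    """Create `path` and any missing parent directories, then return it."""
    # exist_ok=True makes the call a no-op when the directory already exists
    os.makedirs(path, exist_ok=True)
    return path


# Usage: wrap an assignment so the directory exists as soon as the name is bound
# log_dir = ensure_dir_sketch(os.path.join("runs", "experiment_1"))
```

Returning the argument unchanged is what allows the call to wrap an assignment, as noted under Returns.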

utils.io.flatten(lst: list[list]) → list

Flattens a list of lists into a single list.
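One common way to implement this, shown as a sketch under the assumption that only one level of nesting is removed (flatten_sketch is a hypothetical name; the actual implementation is not shown in this reference):

```python
from itertools import chain
from typing import Any, List


def flatten_sketch(lst: List[list]) -> List[Any]:
    """Concatenate the inner lists of `lst` into one list (one level only)."""
    # chain.from_iterable lazily yields the elements of each inner list in order
    return list(chain.from_iterable(lst))


# flatten_sketch([[1, 2], [3], []]) returns [1, 2, 3]
```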

utils.io.load_yaml(path: str) → dict

Parses the YAML file at path and returns its contents as a dictionary.

utils.io.log_settings(level: int = 20, stream: TextIOWrapper = sys.stdout, file: str = None, formatting: str = '%(asctime)s [%(levelname)s] %(message)s')

Temporary basic logger configuration.

Parameters:
  • level (int) – Logger level, INFO (20) by default.

  • stream (TextIOWrapper) – Destination stream, stdout by default.

  • file (str) – Destination file, None by default.

  • formatting (str) – Format string for log messages, '%(asctime)s [%(levelname)s] %(message)s' by default.
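A configuration like this can be expressed with the standard library's logging.basicConfig. The sketch below is an assumption about how the parameters combine: it always logs to stream and additionally to file when one is given, since basicConfig itself rejects receiving both a stream and a filename. log_settings_sketch is a hypothetical name.

```python
import logging
import sys
from typing import Optional, TextIO


def log_settings_sketch(
    level: int = logging.INFO,
    stream: TextIO = sys.stdout,
    file: Optional[str] = None,
    formatting: str = "%(asctime)s [%(levelname)s] %(message)s",
) -> None:
    # Always log to the stream; also log to a file when one is given (assumption)
    handlers: list = [logging.StreamHandler(stream)]
    if file is not None:
        handlers.append(logging.FileHandler(file))
    # force=True (Python 3.8+) replaces any handlers already attached to the root logger
    logging.basicConfig(level=level, format=formatting, handlers=handlers, force=True)
```

With the defaults, a call such as logging.info("hello") prints a line ending in "[INFO] hello" to stdout, while DEBUG messages are suppressed.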

utils.io.plot_tensorboard(configuration: dict, data_dir: str, tb_dir: str)

Loads simulation output data from disk and writes it to TensorBoard for visual inspection.

utils.io.save_yaml(data: dict, path: str)

Saves the dictionary data to the YAML file at path.
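The load_yaml / save_yaml pair can be sketched as a round trip with the third-party PyYAML package. This assumes the module uses PyYAML's safe loader and dumper, which this reference does not confirm; save_yaml_sketch and load_yaml_sketch are hypothetical names.

```python
import yaml  # PyYAML, a third-party package (assumed, not confirmed by the docs)


def save_yaml_sketch(data: dict, path: str) -> None:
    # safe_dump emits only standard YAML tags; sort_keys=False keeps insertion order
    with open(path, "w", encoding="utf-8") as f:
        yaml.safe_dump(data, f, sort_keys=False)


def load_yaml_sketch(path: str) -> dict:
    # safe_load refuses to construct arbitrary Python objects from the file
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
```

Using the safe variants is a deliberate choice for configuration files: a dictionary of plain values round-trips unchanged, and untrusted YAML cannot instantiate arbitrary objects.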