detectron2.checkpoint

class detectron2.checkpoint.Checkpointer(model: torch.nn.Module, save_dir: str = '', *, save_to_disk: bool = True, **checkpointables: Any)[source]

Bases: object

一个可以保存/加载模型以及其他可检查点对象的检查点器。

__init__(model: torch.nn.Module, save_dir: str = '', *, save_to_disk: bool = True, **checkpointables: Any)None[source]
参数
  • model (nn.Module) – 模型。

  • save_dir (str) – 保存和查找检查点的目录。

  • save_to_disk (bool) – 如果为 True,将检查点保存到磁盘,否则将为该检查点器禁用保存。

  • checkpointables (object) – 任何可检查点对象,即具有 state_dict()load_state_dict() 方法的对象。例如,它可以像 Checkpointer(model, “dir”, optimizer=optimizer) 一样使用。

add_checkpointable(key: str, checkpointable: Any)None[source]

为该检查点器添加要跟踪的可检查点对象。

参数
  • key (str) – 用于保存对象的键

  • checkpointable – 任何具有 state_dict()load_state_dict() 方法的对象

save(name: str, **kwargs: Any)None[source]

将模型和可检查点对象转储到文件。

参数
  • name (str) – 文件名。

  • kwargs (dict) – 要保存的其他任意数据。

load(path: str, checkpointables: Optional[List[str]] = None) → Dict[str, Any][source]

从给定的检查点加载。

参数
  • path (str) – 检查点的路径或 URL。如果为空,将不会加载任何内容。

  • checkpointables (list) – 要加载的可检查点名称列表。如果未指定(None),将加载所有可能的可检查点对象。

返回值

dict – 从检查点加载但尚未处理的额外数据。例如,那些使用 save(**extra_data)() 保存的数据。

has_checkpoint()bool[source]
返回值

bool – 目标目录中是否存在检查点。

get_checkpoint_file()str[source]
返回值

str – 目标目录中的最新检查点文件。

get_all_checkpoint_files() → List[str][source]
返回值

list

目标目录中所有可用的检查点文件(.pth 文件)。

目录。

resume_or_load(path: str, *, resume: bool = True) → Dict[str, Any][source]

如果 resume 为 True,此方法将尝试从最后一个检查点(如果存在)恢复。否则,从给定路径加载检查点。这在重新启动中断的训练作业时很有用。

参数
  • path (str) – 检查点的路径。

  • resume (bool) – 如果为 True,则从最后一个检查点(如果存在)恢复,并加载模型以及所有可检查点对象。否则,仅加载模型而不加载任何可检查点对象。

返回值

load() 相同。

tag_last_checkpoint(last_filename_basename: str)None[source]

标记最后一个检查点。

参数

last_filename_basename (str) – 最后一个文件名的基本名称。

class detectron2.checkpoint.PeriodicCheckpointer(checkpointer: fvcore.common.checkpoint.Checkpointer, period: int, max_iter: Optional[int] = None, max_to_keep: Optional[int] = None, file_prefix: str = 'model')[source]

Bases: object

定期保存检查点。当调用 .step(iteration) 时,它将根据给定的 checkpointer 执行 checkpointer.save,如果迭代是周期的倍数,或者如果 max_iter 达到。

checkpointer

底层检查点对象

类型

Checkpointer

__init__(checkpointer: fvcore.common.checkpoint.Checkpointer, period: int, max_iter: Optional[int] = None, max_to_keep: Optional[int] = None, file_prefix: str = 'model')None[source]
参数
  • checkpointer – 用于保存检查点的检查点对象。

  • period (int) – 保存检查点的周期。

  • max_iter (int) – 最大迭代次数。当它到达时,将保存一个名为“{file_prefix}_final”的检查点。

  • max_to_keep (int) – 要保留的最新的最大检查点数,以前的检查点将被删除

  • file_prefix (str) – 检查点文件名名前缀

step(iteration: int, **kwargs: Any)None[source]

在给定的迭代中执行适当的操作。

参数
  • iteration (int) – 当前迭代,范围在 [0, max_iter-1] 内。

  • kwargs (Any) – 要保存的额外数据,与 Checkpointer.save() 中的相同。

save(name: str, **kwargs: Any)None[source]

Checkpointer.save() 相同的参数。使用此方法可以在计划之外手动保存检查点。

参数
  • name (str) – 文件名。

  • kwargs (Any) – 要保存的额外数据,与 Checkpointer.save() 中的相同。

class detectron2.checkpoint.DetectionCheckpointer(model, save_dir='', *, save_to_disk=None, **checkpointables)[source]

基础类: fvcore.common.checkpoint.Checkpointer

Checkpointer 相同,但能够:1. 处理 detectron 和 detectron2 模型库中的模型,并应用对旧模型的转换。 2. 正确加载仅在主工作器上可用的检查点。

load(path, *args, **kwargs)[source]