Unverified 提交 d0fa0042 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

New `@try_export` decorator (#9096)

* New export decorator * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * New export decorator * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rename fcn to func * rename to @try_export Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
上级 eab35f66
差异被折叠。
......@@ -148,6 +148,7 @@ class Profile(contextlib.ContextDecorator):
def __enter__(self):
self.start = self.time()
return self
def __exit__(self, type, value, traceback):
self.dt = self.time() - self.start # delta-time
......@@ -220,10 +221,10 @@ def methods(instance):
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
# Print function arguments (optional args dict)
x = inspect.currentframe().f_back # previous frame
file, _, fcn, _, _ = inspect.getframeinfo(x)
file, _, func, _, _ = inspect.getframeinfo(x)
if args is None: # get args automatically
args, _, _, frm = inspect.getargvalues(x)
args = {k: v for k, v in frm.items() if k in args}
......@@ -231,7 +232,7 @@ def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
file = Path(file).resolve().relative_to(ROOT).with_suffix('')
except ValueError:
file = Path(file).stem
s = (f'{file}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
......@@ -255,7 +256,13 @@ def init_seeds(seed=0, deterministic=False):
def intersect_dicts(da, db, exclude=()):
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
def get_default_args(func):
# Get func() default arguments
signature = inspect.signature(func)
return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
def get_latest_run(search_dir='.'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论