#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import collections
import functools
import timeit
import traceback
import weakref
from json import dumps as json_dumps
from typing import List, Optional

import numpy as np

from .ascii_table import ascii_table
class SnapshotList(list):
    """A ``list`` subclass whose instances can be weakly referenced.

    Instances of the built-in ``list`` type do not support weak references;
    instances of a subclass do. The class adds no behaviour of its own.
    """
class Timings:
    """Running statistics for the durations recorded under one timer key.

    Attributes:
        sum: Total of all recorded durations.
        count: Number of recorded durations.
        max: Largest recorded duration (``-inf`` until one is added).
        times: Every recorded duration in insertion order; the percentile
            properties are computed from this list.
    """

    sum: float
    count: int
    max: float
    times: List

    def __init__(
        self,
        sum: float = 0.0,
        count: int = 0,
        max: float = -float("inf"),
        times: Optional[List] = None,
    ):
        self.sum = sum
        self.count = count
        self.max = max
        # Build a fresh list per instance so samples are never shared
        # between Timings objects through a common default.
        self.times = [] if times is None else times

    @property
    def average(self) -> float:
        """Mean duration; ``0.0`` when ``sum`` is zero and nothing was recorded."""
        # ``count or 1`` avoids dividing by zero for an empty instance.
        return self.sum / (self.count or 1)

    def _percentile(self, q: float) -> float:
        """Return the ``q``-th percentile of ``times`` (NaN when empty)."""
        if not self.times:
            # np.percentile on an empty sequence is not meaningful; report
            # NaN explicitly instead of depending on NumPy's behaviour.
            return float("nan")
        return np.percentile(self.times, q)

    @property
    def p50(self) -> float:
        """Median of the recorded durations."""
        return self._percentile(50)

    @property
    def p90(self) -> float:
        """90th percentile of the recorded durations."""
        return self._percentile(90)

    @property
    def p99(self) -> float:
        """99th percentile of the recorded durations."""
        return self._percentile(99)

    def add(self, time):
        """Record one duration and update ``sum``, ``count`` and ``max``."""
        self.times.append(time)
        self.sum += time
        self.count += 1
        self.max = max(self.max, time)
# Time-unit conversion factors, each expressed in seconds and derived from
# the next smaller unit.
SECONDS_IN_MINUTE = 60
SECONDS_IN_HOUR = SECONDS_IN_MINUTE * 60
SECONDS_IN_DAY = SECONDS_IN_HOUR * 24
class Snapshot:
    """Collects timings per call path, starting from its creation time.

    Attributes:
        times: Maps a key (a sequence of ``(label, caller_id)`` pairs, one per
            nesting level) to the ``Timings`` recorded for that path. Entries
            are created on first access.
        start: ``timeit.default_timer()`` reading taken at construction.
    """

    # The only stage for which a PyTorchObserver line is emitted.
    _PEP_STAGE = "evaluate -> pytorch eval once"

    def __init__(self):
        self.times = collections.defaultdict(Timings)
        self.start = timeit.default_timer()

    @staticmethod
    def _path(key):
        """Render a key as ``"outer -> inner"`` using only the labels."""
        return " -> ".join(label for label, _ in key)

    def _print_pep(self, entries):
        """Print a ``PyTorchObserver`` JSON line for the matching stage.

        The reported value is the stage's average duration multiplied by
        1000 and formatted with one decimal place, with unit ``"ms"``.
        """
        for key, timings in entries:
            stage = self._path(key)
            if stage == self._PEP_STAGE:
                info = {
                    "type": stage,
                    "metric": "latency",
                    "unit": "ms",
                    "value": f"{timings.average * 1000:.1f}",
                }
                print("PyTorchObserver " + json_dumps(info))

    def report(self, report_pep=False):
        """Print a table of everything recorded in this snapshot.

        Args:
            report_pep: When true, additionally print a ``PyTorchObserver``
                JSON line for the ``"evaluate -> pytorch eval once"`` stage.

        Prints a short note and returns early when nothing was recorded.
        The table footer shows the wall time elapsed since the snapshot was
        created.
        """
        snapshot_total = timeit.default_timer() - self.start

        if not self.times:
            print(
                "Note: Nothing was reported. "
                'Please use timing.time("foo") to measure time.'
            )
            return

        # Sort once and reuse the ordering for both the table and the
        # optional PyTorchObserver output.
        entries = sorted(self.times.items())
        results = [
            {
                "name": self._path(key),
                "total": format_time(timings.sum),
                "avg": format_time(timings.average),
                "max": format_time(timings.max),
                "p50": format_time(timings.p50),
                "p90": format_time(timings.p90),
                "p99": format_time(timings.p99),
                "count": timings.count,
            }
            for key, timings in entries
        ]
        print(
            ascii_table(
                results,
                human_column_names={
                    "name": "Stage",
                    "total": "Total",
                    "avg": "Average",
                    "max": "Max",
                    "p50": "P50",
                    "p90": "P90",
                    "p99": "P99",
                    "count": "Count",
                },
                footer={"name": "Total time", "total": format_time(snapshot_total)},
                alignments={"name": "<"},
            )
        )
        if report_pep:
            self._print_pep(entries)
class HierarchicalTimer:
    """Timer that records nested, labelled durations into live snapshots.

    Attributes:
        current_stack: Open timers, outermost first, stored as
            ``(label, caller_id, start_time)`` tuples.
        all_snapshots: Weak references to every snapshot handed out by
            ``snapshot()``.
    """

    def __init__(self):
        self.current_stack = []
        self.all_snapshots = SnapshotList()

    def snapshot(self):
        """Create and register a new ``Snapshot``.

        Only a weak reference is kept here, so the caller must hold on to the
        returned object for it to keep receiving timings.
        """
        snapshot = Snapshot()
        self.all_snapshots.append(weakref.ref(snapshot))
        return snapshot

    def _clean_snapshots(self):
        """Drop references to snapshots that no longer exist."""
        # Prune in place: rebinding to a plain list would silently replace
        # the SnapshotList container with a different type.
        self.all_snapshots[:] = [
            ref for ref in self.all_snapshots if ref() is not None
        ]

    def push(self, label, caller_id):
        """Open a timer named ``label`` nested under the current stack."""
        self.current_stack.append((label, caller_id, timeit.default_timer()))

    def pop(self):
        """Close the innermost timer and record its duration.

        The duration is added to every live snapshot under a key made of the
        ``(label, caller_id)`` pairs of the whole current stack.

        Raises:
            IndexError: If no timer is open.
        """
        start_time = self.current_stack[-1][2]
        # Read the clock before doing any bookkeeping so that building the
        # key is not counted in the measured duration.
        delta = timeit.default_timer() - start_time
        key = tuple((label, caller) for label, caller, _ in self.current_stack)
        for ref in self.all_snapshots:
            snapshot = ref()
            if snapshot is not None:
                snapshot.times[key].add(delta)
        self.current_stack.pop()
        # Dead references are pruned here because pop() already walks the
        # whole registry.
        self._clean_snapshots()

    def time(self, label):
        """Return a context manager / decorator that times ``label``."""
        return _TimerContextManager(label, self)
class _TimerContextManager:
def __init__(self, label, timer, caller_id=None):
self.label = label
self.timer = timer
self.caller_id = caller_id
def __enter__(self):
if self.caller_id:
caller_id = self.caller_id
else:
stack = traceback.extract_stack()
caller = stack[-2]
caller_id = (caller.filename, caller.line)
self.timer.push(self.label, caller_id)
def __exit__(self, *exception_info):
self.timer.pop()
def __call__(self, fn):
"""Decorator syntax"""
caller_id = (fn.__code__.co_filename, fn.__code__.co_firstlineno)
timer_context = _TimerContextManager(self.label, self.timer, caller_id)
@functools.wraps(fn)
def wrapper(*args, **kwargs):
with timer_context:
return fn(*args, **kwargs)
return wrapper
# Module-level timer shared by the convenience names below.
TIMER = HierarchicalTimer()
# Aliases for the shared timer's bound methods. Note that ``time`` here is
# the timer's context-manager factory, not the standard-library module.
time = TIMER.time
snapshot = TIMER.snapshot
# Snapshot created when this module is first imported. The timer itself only
# keeps a weak reference, so this name is what keeps the snapshot alive.
SNAPSHOT = TIMER.snapshot()
# Prints the table of everything recorded in SNAPSHOT.
report = SNAPSHOT.report
def report_snapshot(fn):
    """Decorator that prints a timing report after each call to ``fn``.

    A new snapshot is taken from the module-level ``TIMER`` just before
    ``fn`` runs and reported once ``fn`` returns, so the report contains the
    timings that completed while ``fn`` was running.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper with ``fn``'s metadata that returns whatever ``fn`` returns.
        If ``fn`` raises, the exception propagates and no report is printed.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Named so that it does not shadow the module-level ``snapshot``.
        call_snapshot = TIMER.snapshot()
        result = fn(*args, **kwargs)
        call_snapshot.report()
        return result

    return wrapper