Skip to content

Commit 0c21547

Browse files
committed
feat: add dump_stats to advanced profiler (#19698)
1 parent 8b378f0 commit 0c21547

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1616

1717
- Added `on_exception` hook to `LightningDataModule` ([#19601](https://github.com/Lightning-AI/pytorch-lightning/pull/19601))
1818

19+
- Added `dump_stats` flag to `AdvancedProfiler` ([#19698](https://github.com/Lightning-AI/pytorch-lightning/pull/19698))
20+
1921
-
2022

2123
### Changed

src/lightning/pytorch/profilers/advanced.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616
import cProfile
1717
import io
1818
import logging
19+
import os
1920
import pstats
21+
import tempfile
2022
from pathlib import Path
2123
from typing import Dict, Optional, Tuple, Union
2224

2325
from typing_extensions import override
2426

27+
from lightning.fabric.utilities.cloud_io import get_filesystem
2528
from lightning.pytorch.profilers.profiler import Profiler
2629

2730
log = logging.getLogger(__name__)
@@ -40,6 +43,7 @@ def __init__(
4043
dirpath: Optional[Union[str, Path]] = None,
4144
filename: Optional[str] = None,
4245
line_count_restriction: float = 1.0,
46+
dump_stats: bool = False,
4347
) -> None:
4448
"""
4549
Args:
@@ -54,13 +58,17 @@ def __init__(
5458
reported for each action. either an integer (to select a count of lines),
5559
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
5660
61+
dump_stats: Whether to save raw profiler results. When ``True`` then ``dirpath`` must be provided.
62+
5763
Raises:
5864
ValueError:
5965
If you attempt to stop recording an action which was never started.
6066
"""
6167
super().__init__(dirpath=dirpath, filename=filename)
6268
self.profiled_actions: Dict[str, cProfile.Profile] = {}
6369
self.line_count_restriction = line_count_restriction
70+
self.dump_stats = dump_stats
71+
assert not self.dump_stats or self.dirpath is not None, "dirname must be provided for dump_states to work"
6472

6573
@override
6674
def start(self, action_name: str) -> None:
@@ -75,10 +83,28 @@ def stop(self, action_name: str) -> None:
7583
raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
7684
pr.disable()
7785

86+
def _maybe_dump_stats(self, action_name: str, pr: cProfile.Profile) -> None:
    """Save the raw ``cProfile`` results of one action as a ``.prof`` file under ``self.dirpath``.

    Does nothing unless ``self.dump_stats`` is enabled. The destination file name is
    whatever ``self._prepare_filename`` returns for ``action_name`` with the ``.prof``
    extension, and the destination is written through the filesystem object that
    ``get_filesystem`` returns for that path.

    Args:
        action_name: Name of the profiled action whose stats are dumped.
        pr: The profile object holding the recorded stats for that action.
    """
    if not self.dump_stats:
        return
    assert self.dirpath  # redundant, but needed for mypy
    dst_filepath = os.path.join(self.dirpath, self._prepare_filename(action_name=action_name, extension=".prof"))
    dst_fs = get_filesystem(dst_filepath)
    dst_fs.mkdirs(self.dirpath, exist_ok=True)
    # ``Profile.dump_stats`` can only write to a local path, so stage the dump in a
    # throw-away directory. Use the system temp location rather than the current
    # working directory, which may be read-only and should not be littered.
    with tempfile.TemporaryDirectory(prefix="advanced_profiler_") as tmp_dir:
        src_filepath = os.path.join(tmp_dir, "tmp.prof")
        pr.dump_stats(src_filepath)
        with open(src_filepath, "rb") as src_file:
            raw_stats = src_file.read()
    # Open the destination only once the dump succeeded, so a failure above
    # cannot leave an empty or truncated ``.prof`` file behind.
    with dst_fs.open(dst_filepath, "wb") as dst_file:
        dst_file.write(raw_stats)
102+
78103
@override
79104
def summary(self) -> str:
80105
recorded_stats = {}
81106
for action_name, pr in self.profiled_actions.items():
107+
self._maybe_dump_stats(action_name, pr)
82108
s = io.StringIO()
83109
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats("cumulative")
84110
ps.print_stats(self.line_count_restriction)

tests/tests_pytorch/profilers/test_profiler.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,25 @@ def test_advanced_profiler_describe(tmp_path, advanced_profiler):
299299
assert len(data) > 0
300300

301301

302+
def test_advanced_profiler_dump_states(tmp_path):
    """Ensure the profiler dumps stats during summary."""
    advanced_profiler = AdvancedProfiler(dirpath=tmp_path, dump_stats=True)
    # record at least one event
    with advanced_profiler.profile(action_name := "test"):
        pass
    # describe() is expected to trigger the dump of the raw stats to file
    advanced_profiler.describe()
    path = advanced_profiler.dirpath / f"{action_name}.prof"
    data = path.read_bytes()
    assert len(data) > 0
313+
314+
315+
def test_advanced_profiler_dump_states_needs_dirpath():
    """Check that enabling ``dump_stats`` without a ``dirpath`` fails at construction time."""
    expected_error = AssertionError
    with pytest.raises(expected_error):
        AdvancedProfiler(dump_stats=True)
319+
320+
302321
def test_advanced_profiler_value_errors(advanced_profiler):
303322
"""Ensure errors are raised where expected."""
304323
action = "test"

0 commit comments

Comments
 (0)