1616import cProfile
1717import io
1818import logging
19+ import os
1920import pstats
21+ import tempfile
2022from pathlib import Path
2123from typing import Dict , Optional , Tuple , Union
2224
2325from typing_extensions import override
2426
27+ from lightning .fabric .utilities .cloud_io import get_filesystem
2528from lightning .pytorch .profilers .profiler import Profiler
2629
2730log = logging .getLogger (__name__ )
@@ -40,6 +43,7 @@ def __init__(
4043 dirpath : Optional [Union [str , Path ]] = None ,
4144 filename : Optional [str ] = None ,
4245 line_count_restriction : float = 1.0 ,
46+ dump_stats : bool = False ,
4347 ) -> None :
4448 """
4549 Args:
@@ -54,13 +58,17 @@ def __init__(
5458 reported for each action. either an integer (to select a count of lines),
5559 or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
5660
61+ dump_stats: Whether to save raw profiler results. When ``True`` then ``dirpath`` must be provided.
62+
5763 Raises:
5864 ValueError:
5965 If you attempt to stop recording an action which was never started.
6066 """
6167 super ().__init__ (dirpath = dirpath , filename = filename )
6268 self .profiled_actions : Dict [str , cProfile .Profile ] = {}
6369 self .line_count_restriction = line_count_restriction
70+ self .dump_stats = dump_stats
71+ assert not self .dump_stats or self .dirpath is not None , "dirname must be provided for dump_states to work"
6472
6573 @override
6674 def start (self , action_name : str ) -> None :
@@ -75,10 +83,28 @@ def stop(self, action_name: str) -> None:
7583 raise ValueError (f"Attempting to stop recording an action ({ action_name } ) which was never started." )
7684 pr .disable ()
7785
86+ def _maybe_dump_stats (self , action_name : str , pr : cProfile .Profile ) -> None :
87+ if not self .dump_stats :
88+ return
89+ assert self .dirpath # redundant, but needed for mypy
90+ dst_filepath = os .path .join (self .dirpath , self ._prepare_filename (action_name = action_name , extension = ".prof" ))
91+ dst_fs = get_filesystem (dst_filepath )
92+ dst_fs .mkdirs (self .dirpath , exist_ok = True )
93+ # temporarily save to local since pstats can only dump into a local file
94+ with tempfile .TemporaryDirectory (prefix = "test" , suffix = "test" , dir = os .getcwd ()) as tmp_dir , dst_fs .open (
95+ dst_filepath , "wb"
96+ ) as dst_file :
97+ src_filepath = os .path .join (tmp_dir , "tmp.prof" )
98+ pr .dump_stats (src_filepath )
99+ src_fs = get_filesystem (src_filepath )
100+ with src_fs .open (src_filepath , "rb" ) as src_file :
101+ dst_file .write (src_file .read ())
102+
78103 @override
79104 def summary (self ) -> str :
80105 recorded_stats = {}
81106 for action_name , pr in self .profiled_actions .items ():
107+ self ._maybe_dump_stats (action_name , pr )
82108 s = io .StringIO ()
83109 ps = pstats .Stats (pr , stream = s ).strip_dirs ().sort_stats ("cumulative" )
84110 ps .print_stats (self .line_count_restriction )
0 commit comments