99from subprocess import PIPE
1010
1111import numpy as np
12- import pandas as pd
1312import torch
1413
1514
@@ -146,24 +145,14 @@ def make_summary(self):
146145
147146 Layer Name, Layer Type, Input Size, Output Size, Number of Parameters
148147 '''
149-
150- cols = ['Name' , 'Type' , 'Params' ]
151- if self .model .example_input_array is not None :
152- cols .extend (['In_sizes' , 'Out_sizes' ])
153-
154- df = pd .DataFrame (np .zeros ((len (self .layer_names ), len (cols ))))
155- df .columns = cols
156-
157- df ['Name' ] = self .layer_names
158- df ['Type' ] = self .layer_types
159- df ['Params' ] = self .param_nums
160- df ['Params' ] = df ['Params' ].map (get_human_readable_count )
161-
148+ arrays = [['Name' , self .layer_names ],
149+ ['Type' , self .layer_types ],
150+ ['Params' , list (map (get_human_readable_count , self .param_nums ))]]
162151 if self .model .example_input_array is not None :
163- df [ 'In_sizes' ] = self .in_sizes
164- df [ 'Out_sizes' ] = self .out_sizes
152+ arrays . append ([ 'In sizes' , self .in_sizes ])
153+ arrays . append ([ 'Out sizes' , self .out_sizes ])
165154
166- self .summary = df
155+ self .summary = _format_summary_table ( * arrays )
167156 return
168157
169158 def summarize (self ):
@@ -176,6 +165,51 @@ def summarize(self):
176165 self .make_summary ()
177166
178167
168+ def _format_summary_table (* cols ):
169+ '''
170+ Takes in a number of arrays, each specifying a column in
171+ the summary table, and combines them all into one big
172+ string defining the summary table that are nicely formatted.
173+ '''
174+ n_rows = len (cols [0 ][1 ])
175+ n_cols = 1 + len (cols )
176+
177+ # Layer counter
178+ counter = list (map (str , list (range (n_rows ))))
179+ counter_len = max ([len (c ) for c in counter ])
180+
181+ # Get formatting length of each column
182+ length = []
183+ for c in cols :
184+ str_l = len (c [0 ]) # default length is header length
185+ for a in c [1 ]:
186+ if isinstance (a , np .ndarray ):
187+ array_string = '[' + ', ' .join ([str (j ) for j in a ]) + ']'
188+ str_l = max (len (array_string ), str_l )
189+ else :
190+ str_l = max (len (a ), str_l )
191+ length .append (str_l )
192+
193+ # Formatting
194+ s = '{:<{}}'
195+ full_length = sum (length ) + 3 * n_cols
196+ header = [s .format (' ' , counter_len )] + [s .format (c [0 ], l ) for c , l in zip (cols , length )]
197+
198+ # Summary = header + divider + Rest of table
199+ summary = ' | ' .join (header ) + '\n ' + '-' * full_length
200+ for i in range (n_rows ):
201+ line = s .format (counter [i ], counter_len )
202+ for c , l in zip (cols , length ):
203+ if isinstance (c [1 ][i ], np .ndarray ):
204+ array_string = '[' + ', ' .join ([str (j ) for j in c [1 ][i ]]) + ']'
205+ line += ' | ' + array_string + ' ' * (l - len (array_string ))
206+ else :
207+ line += ' | ' + s .format (c [1 ][i ], l )
208+ summary += '\n ' + line
209+
210+ return summary
211+
212+
179213def print_mem_stack (): # pragma: no cover
180214 for obj in gc .get_objects ():
181215 try :
0 commit comments