Skip to content

Commit 9a6838d

Browse files
Removed dependency on pandas, instead use generic csv (#736)
* removed dependency on pandas, instead use generic csv
* remove mnist files, pushed by accident
* added docstring and small fixes
* Update memory.py
* fixed path

Co-authored-by: William Falcon <waf2107@columbia.edu>
1 parent deffbab commit 9a6838d

File tree

4 files changed

+70
-30
lines changed

4 files changed

+70
-30
lines changed

pytorch_lightning/core/lightning.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
import warnings
55
from abc import ABC, abstractmethod
66
from argparse import Namespace
7+
import csv
78

8-
9-
import pandas as pd
109
import torch
1110
import torch.distributed as dist
12-
#
11+
1312
from pytorch_lightning.core.decorators import data_loader
1413
from pytorch_lightning.core.grads import GradInformation
1514
from pytorch_lightning.core.hooks import ModelHooks
@@ -1217,10 +1216,12 @@ def load_hparams_from_tags_csv(tags_csv):
12171216
logging.warning(f'Missing Tags: {tags_csv}.')
12181217
return Namespace()
12191218

1220-
tags_df = pd.read_csv(tags_csv)
1221-
dic = tags_df.to_dict(orient='records')
1222-
ns_dict = {row['key']: convert(row['value']) for row in dic}
1223-
ns = Namespace(**ns_dict)
1219+
tags = {}
1220+
with open(tags_csv) as f:
1221+
csv_reader = csv.reader(f, delimiter=',')
1222+
for row in list(csv_reader)[1:]:
1223+
tags[row[0]] = convert(row[1])
1224+
ns = Namespace(**tags)
12241225
return ns
12251226

12261227

pytorch_lightning/core/memory.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from subprocess import PIPE
1010

1111
import numpy as np
12-
import pandas as pd
1312
import torch
1413

1514

@@ -146,24 +145,14 @@ def make_summary(self):
146145
147146
Layer Name, Layer Type, Input Size, Output Size, Number of Parameters
148147
'''
149-
150-
cols = ['Name', 'Type', 'Params']
151-
if self.model.example_input_array is not None:
152-
cols.extend(['In_sizes', 'Out_sizes'])
153-
154-
df = pd.DataFrame(np.zeros((len(self.layer_names), len(cols))))
155-
df.columns = cols
156-
157-
df['Name'] = self.layer_names
158-
df['Type'] = self.layer_types
159-
df['Params'] = self.param_nums
160-
df['Params'] = df['Params'].map(get_human_readable_count)
161-
148+
arrays = [['Name', self.layer_names],
149+
['Type', self.layer_types],
150+
['Params', list(map(get_human_readable_count, self.param_nums))]]
162151
if self.model.example_input_array is not None:
163-
df['In_sizes'] = self.in_sizes
164-
df['Out_sizes'] = self.out_sizes
152+
arrays.append(['In sizes', self.in_sizes])
153+
arrays.append(['Out sizes', self.out_sizes])
165154

166-
self.summary = df
155+
self.summary = _format_summary_table(*arrays)
167156
return
168157

169158
def summarize(self):
@@ -176,6 +165,51 @@ def summarize(self):
176165
self.make_summary()
177166

178167

168+
def _format_summary_table(*cols):
169+
'''
170+
Takes in a number of arrays, each specifying a column in
171+
the summary table, and combines them all into one big
172+
string defining the summary table that are nicely formatted.
173+
'''
174+
n_rows = len(cols[0][1])
175+
n_cols = 1 + len(cols)
176+
177+
# Layer counter
178+
counter = list(map(str, list(range(n_rows))))
179+
counter_len = max([len(c) for c in counter])
180+
181+
# Get formatting length of each column
182+
length = []
183+
for c in cols:
184+
str_l = len(c[0]) # default length is header length
185+
for a in c[1]:
186+
if isinstance(a, np.ndarray):
187+
array_string = '[' + ', '.join([str(j) for j in a]) + ']'
188+
str_l = max(len(array_string), str_l)
189+
else:
190+
str_l = max(len(a), str_l)
191+
length.append(str_l)
192+
193+
# Formatting
194+
s = '{:<{}}'
195+
full_length = sum(length) + 3 * n_cols
196+
header = [s.format(' ', counter_len)] + [s.format(c[0], l) for c, l in zip(cols, length)]
197+
198+
# Summary = header + divider + Rest of table
199+
summary = ' | '.join(header) + '\n' + '-' * full_length
200+
for i in range(n_rows):
201+
line = s.format(counter[i], counter_len)
202+
for c, l in zip(cols, length):
203+
if isinstance(c[1][i], np.ndarray):
204+
array_string = '[' + ', '.join([str(j) for j in c[1][i]]) + ']'
205+
line += ' | ' + array_string + ' ' * (l - len(array_string))
206+
else:
207+
line += ' | ' + s.format(c[1][i], l)
208+
summary += '\n' + line
209+
210+
return summary
211+
212+
179213
def print_mem_stack(): # pragma: no cover
180214
for obj in gc.get_objects():
181215
try:

pytorch_lightning/logging/tensorboard.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pkg_resources import parse_version
55

66
import torch
7-
import pandas as pd
7+
import csv
88
from torch.utils.tensorboard import SummaryWriter
99

1010
from .base import LightningLoggerBase, rank_zero_only
@@ -108,12 +108,17 @@ def save(self):
108108
dir_path = os.path.join(self.save_dir, self.name, 'version_%s' % self.version)
109109
if not os.path.isdir(dir_path):
110110
dir_path = self.save_dir
111+
111112
# prepare the file path
112113
meta_tags_path = os.path.join(dir_path, self.NAME_CSV_TAGS)
114+
113115
# save the metatags file
114-
df = pd.DataFrame({'key': list(self.tags.keys()),
115-
'value': list(self.tags.values())})
116-
df.to_csv(meta_tags_path, index=False)
116+
with open(meta_tags_path, 'w', newline='') as csvfile:
117+
fieldnames = ['key', 'value']
118+
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
119+
writer.writerow({'key': 'key', 'value': 'value'})
120+
for k, v in self.tags.items():
121+
writer.writerow({'key': k, 'value': v})
117122

118123
@rank_zero_only
119124
def finalize(self, status):

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ tqdm>=4.35.0
33
numpy>=1.16.4
44
torch>=1.1
55
torchvision>=0.4.0, < 0.5 # the 0.5. has some issues with torch JIT
6-
pandas>=0.24 # lower version do not support py3.7
76
tensorboard>=1.14
8-
future>=0.17.1 # required for builtins in setup.py
7+
future>=0.17.1 # required for builtins in setup.py
8+

0 commit comments

Comments
 (0)