@@ -864,12 +864,13 @@ def on_train_epoch_start(self, trainer, pl_module):
864864 raise RuntimeError ("Trouble!" )
865865
866866 model = BoringModel ()
867- epoch_length = 64
867+ epoch_length = 2
868868 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
869869 trainer = Trainer (
870870 default_root_dir = tmp_path ,
871871 callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()],
872872 max_epochs = 5 ,
873+ limit_train_batches = epoch_length ,
873874 logger = False ,
874875 enable_progress_bar = False ,
875876 )
@@ -887,12 +888,13 @@ def on_train_epoch_end(self, trainer, pl_module):
887888 raise RuntimeError ("Trouble!" )
888889
889890 model = BoringModel ()
890- epoch_length = 64
891+ epoch_length = 2
891892 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
892893 trainer = Trainer (
893894 default_root_dir = tmp_path ,
894895 callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()],
895896 max_epochs = 5 ,
897+ limit_train_batches = epoch_length ,
896898 logger = False ,
897899 enable_progress_bar = False ,
898900 )
@@ -956,12 +958,13 @@ def on_validation_epoch_start(self, trainer, pl_module):
956958 raise RuntimeError ("Trouble!" )
957959
958960 model = BoringModel ()
959- epoch_length = 64
961+ epoch_length = 2
960962 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
961963 trainer = Trainer (
962964 default_root_dir = tmp_path ,
963965 callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()],
964966 max_epochs = 5 ,
967+ limit_train_batches = epoch_length ,
965968 logger = False ,
966969 enable_progress_bar = False ,
967970 )
@@ -979,12 +982,13 @@ def on_validation_epoch_end(self, trainer, pl_module):
979982 raise RuntimeError ("Trouble!" )
980983
981984 model = BoringModel ()
982- epoch_length = 64
985+ epoch_length = 2
983986 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
984987 trainer = Trainer (
985988 default_root_dir = tmp_path ,
986989 callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()],
987990 max_epochs = 5 ,
991+ limit_train_batches = epoch_length ,
988992 logger = False ,
989993 enable_progress_bar = False ,
990994 )
@@ -1002,12 +1006,13 @@ def on_validation_start(self, trainer, pl_module):
10021006 raise RuntimeError ("Trouble!" )
10031007
10041008 model = BoringModel ()
1005- epoch_length = 64
1009+ epoch_length = 2
10061010 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
10071011 trainer = Trainer (
10081012 default_root_dir = tmp_path ,
10091013 callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()],
10101014 max_epochs = 5 ,
1015+ limit_train_batches = epoch_length ,
10111016 logger = False ,
10121017 enable_progress_bar = False ,
10131018 )
@@ -1025,12 +1030,13 @@ def on_validation_end(self, trainer, pl_module):
10251030 raise RuntimeError ("Trouble!" )
10261031
10271032 model = BoringModel ()
1028- epoch_length = 64
1033+ epoch_length = 2
10291034 checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
10301035 trainer = Trainer (
10311036 default_root_dir = tmp_path ,
10321037 callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()],
10331038 max_epochs = 5 ,
1039+ limit_train_batches = epoch_length ,
10341040 logger = False ,
10351041 enable_progress_bar = False ,
10361042 )
0 commit comments