1919from lightning_utilities .test .warning import no_warning_call
2020
2121from lightning .fabric .utilities .warnings import PossibleUserWarning
22- from lightning .pytorch .callbacks import ModelCheckpoint
22+ from lightning .pytorch .callbacks import EarlyStopping , ModelCheckpoint
2323from lightning .pytorch .demos .boring_classes import BoringModel
2424from lightning .pytorch .trainer .trainer import Trainer
2525
@@ -92,7 +92,16 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count,
9292 (min_epochs/steps is satisfied).
9393
9494 """
95- model = BoringModel ()
95+
96+ class NewBoring (BoringModel ):
97+ def training_step (self , batch , batch_idx ):
98+ self .log ("loss" , self .step (batch ))
99+ return {"loss" : self .step (batch )}
100+
101+ model = NewBoring ()
102+ # create a stopping condition with a high threshold so it triggers immediately
103+ # check the condition before validation so the count is unaffected
104+ stopping = EarlyStopping (monitor = "loss" , check_on_train_epoch_end = True , stopping_threshold = 100 )
96105 trainer = Trainer (
97106 default_root_dir = tmp_path ,
98107 num_sanity_val_steps = 0 ,
@@ -103,8 +112,8 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count,
103112 min_steps = min_steps ,
104113 enable_model_summary = False ,
105114 enable_checkpointing = False ,
115+ callbacks = [stopping ],
106116 )
107- trainer .should_stop = True # Request to stop before min_epochs/min_steps are reached
108117 trainer .fit_loop .epoch_loop .val_loop .run = Mock ()
109118 trainer .fit (model )
110119 assert trainer .fit_loop .epoch_loop .val_loop .run .call_count == val_count
0 commit comments