跳转至

Solver(求解器) 模块

ppsci.solver

Solver

Class for solver.

Parameters:

Name Type Description Default
model Layer

Model.

required
constraint Optional[Dict[str, Constraint]]

Constraint(s) applied on model. Defaults to None.

None
output_dir Optional[str]

Output directory. Defaults to "./output/".

'./output/'
optimizer Optional[Optimizer]

Optimizer object. Defaults to None.

None
lr_scheduler Optional[LRScheduler]

Learning rate scheduler. Defaults to None.

None
epochs int

Training epoch(s). Defaults to 5.

5
iters_per_epoch int

Number of iterations within an epoch. Defaults to 20.

20
update_freq int

Update frequency of parameters. Defaults to 1.

1
save_freq int

Saving frequency for checkpoint. Defaults to 0.

0
log_freq int

Logging frequency. Defaults to 10.

10
eval_during_train bool

Whether to evaluate the model during training. Defaults to False.

False
start_eval_epoch int

Epoch number after which evaluation begins. Defaults to 1.

1
eval_freq int

Evaluation frequency. Defaults to 1.

1
seed int

Random seed. Defaults to 42.

42
use_vdl Optional[bool]

Whether use VisualDL to log scalars. Defaults to False.

False
use_wandb Optional[bool]

Whether use wandb to log data. Defaults to False.

False
wandb_config Optional[Dict[str, str]]

Config dict of WandB. Defaults to None.

None
device Literal['cpu', 'gpu', 'xpu']

Runtime device. Defaults to "gpu".

'gpu'
equation Optional[Dict[str, PDE]]

Equation dict. Defaults to None.

None
geom Optional[Dict[str, Geometry]]

Geometry dict. Defaults to None.

None
validator Optional[Dict[str, Validator]]

Validator dict. Defaults to None.

None
visualizer Optional[Dict[str, Visualizer]]

Visualizer dict. Defaults to None.

None
use_amp bool

Whether use AMP. Defaults to False.

False
amp_level Literal['O1', 'O2', 'O0']

AMP level. Defaults to "O0".

'O0'
pretrained_model_path Optional[str]

Pretrained model path. Defaults to None.

None
checkpoint_path Optional[str]

Checkpoint path. Defaults to None.

None
compute_metric_by_batch bool

Whether to calculate metrics batch-by-batch during evaluation. Defaults to False.

False
eval_with_no_grad bool

Whether set stop_gradient=True for every Tensor if no differentiation involved during computation, generally to save GPU memory and accelerate computing. Defaults to False.

False
to_static bool

Whether enable to_static for forward pass. Defaults to False.

False
loss_aggregator Optional[LossAggregator]

Loss aggregator, such as a multi-task learning loss aggregator. Defaults to None.

None

Examples:

>>> import ppsci
>>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20)
>>> opt = ppsci.optimizer.AdamW(1e-3)((model,))
>>> geom = ppsci.geometry.Rectangle((0, 0), (1, 1))
>>> pde_constraint = ppsci.constraint.InteriorConstraint(
...     {"u": lambda out: out["u"]},
...     {"u": 0},
...     geom,
...     {
...         "dataset": "IterableNamedArrayDataset",
...         "iters_per_epoch": 1,
...         "batch_size": 16,
...     },
...     ppsci.loss.MSELoss("mean"),
...     name="EQ",
... )
>>> solver = ppsci.solver.Solver(
...     model,
...     {"EQ": pde_constraint},
...     "./output",
...     opt,
...     None,
... )
Source code in ppsci/solver/solver.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
class Solver:
    """Class for solver.

    A `Solver` wires together a model, constraints, optimizer and evaluation
    components, and drives training/evaluation/visualization/prediction.

    Args:
        model (nn.Layer): Model.
        constraint (Optional[Dict[str, ppsci.constraint.Constraint]]): Constraint(s) applied on model. Defaults to None.
        output_dir (Optional[str]): Output directory. Defaults to "./output/".
        optimizer (Optional[optimizer.Optimizer]): Optimizer object. Defaults to None.
        lr_scheduler (Optional[optimizer.lr.LRScheduler]): Learning rate scheduler. Defaults to None.
        epochs (int, optional): Training epoch(s). Defaults to 5.
        iters_per_epoch (int, optional): Number of iterations within an epoch. Defaults to 20.
        update_freq (int, optional): Update frequency of parameters (gradient accumulation steps). Defaults to 1.
        save_freq (int, optional): Saving frequency for checkpoint. Defaults to 0.
        log_freq (int, optional): Logging frequency. Defaults to 10.
        eval_during_train (bool, optional): Whether to evaluate the model during training. Defaults to False.
        start_eval_epoch (int, optional): Epoch number after which evaluation begins. Defaults to 1.
        eval_freq (int, optional): Evaluation frequency. Defaults to 1.
        seed (int, optional): Random seed. Defaults to 42.
        use_vdl (Optional[bool]): Whether use VisualDL to log scalars. Defaults to False.
        use_wandb (Optional[bool]): Whether use wandb to log data. Defaults to False.
        wandb_config (Optional[Dict[str, str]]): Config dict of WandB. Defaults to None.
        device (Literal["cpu", "gpu", "xpu"], optional): Runtime device. Defaults to "gpu".
        equation (Optional[Dict[str, ppsci.equation.PDE]]): Equation dict. Defaults to None.
        geom (Optional[Dict[str, ppsci.geometry.Geometry]]): Geometry dict. Defaults to None.
        validator (Optional[Dict[str, ppsci.validate.Validator]]): Validator dict. Defaults to None.
        visualizer (Optional[Dict[str, ppsci.visualize.Visualizer]]): Visualizer dict. Defaults to None.
        use_amp (bool, optional): Whether use AMP. Defaults to False.
        amp_level (Literal["O1", "O2", "O0"], optional): AMP level. Defaults to "O0".
        pretrained_model_path (Optional[str]): Pretrained model path. Defaults to None.
        checkpoint_path (Optional[str]): Checkpoint path. Defaults to None.
        compute_metric_by_batch (bool, optional): Whether to calculate metrics batch-by-batch during evaluation. Defaults to False.
        eval_with_no_grad (bool, optional): Whether set `stop_gradient=True` for every Tensor if no differentiation
            involved during computation, generally to save GPU memory and accelerate computing. Defaults to False.
        to_static (bool, optional): Whether enable to_static for forward pass. Defaults to False.
        loss_aggregator (Optional[mtl.LossAggregator]): Loss aggregator, such as a multi-task learning loss aggregator. Defaults to None.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20)
        >>> opt = ppsci.optimizer.AdamW(1e-3)((model,))
        >>> geom = ppsci.geometry.Rectangle((0, 0), (1, 1))
        >>> pde_constraint = ppsci.constraint.InteriorConstraint(
        ...     {"u": lambda out: out["u"]},
        ...     {"u": 0},
        ...     geom,
        ...     {
        ...         "dataset": "IterableNamedArrayDataset",
        ...         "iters_per_epoch": 1,
        ...         "batch_size": 16,
        ...     },
        ...     ppsci.loss.MSELoss("mean"),
        ...     name="EQ",
        ... )
        >>> solver = ppsci.solver.Solver(
        ...     model,
        ...     {"EQ": pde_constraint},
        ...     "./output",
        ...     opt,
        ...     None,
        ... )  # doctest: +SKIP
    """

    def __init__(
        self,
        model: nn.Layer,
        constraint: Optional[Dict[str, ppsci.constraint.Constraint]] = None,
        output_dir: Optional[str] = "./output/",
        optimizer: Optional[optim.Optimizer] = None,
        lr_scheduler: Optional[optim.lr.LRScheduler] = None,
        epochs: int = 5,
        iters_per_epoch: int = 20,
        update_freq: int = 1,
        save_freq: int = 0,
        log_freq: int = 10,
        eval_during_train: bool = False,
        start_eval_epoch: int = 1,
        eval_freq: int = 1,
        seed: int = 42,
        use_vdl: bool = False,
        use_wandb: bool = False,
        wandb_config: Optional[Mapping] = None,
        device: Literal["cpu", "gpu", "xpu"] = "gpu",
        equation: Optional[Dict[str, ppsci.equation.PDE]] = None,
        geom: Optional[Dict[str, ppsci.geometry.Geometry]] = None,
        validator: Optional[Dict[str, ppsci.validate.Validator]] = None,
        visualizer: Optional[Dict[str, ppsci.visualize.Visualizer]] = None,
        use_amp: bool = False,
        amp_level: Literal["O1", "O2", "O0"] = "O0",
        pretrained_model_path: Optional[str] = None,
        checkpoint_path: Optional[str] = None,
        compute_metric_by_batch: bool = False,
        eval_with_no_grad: bool = False,
        to_static: bool = False,
        loss_aggregator: Optional[mtl.LossAggregator] = None,
    ):
        # set model
        self.model = model
        # set constraint
        self.constraint = constraint
        # set output directory
        self.output_dir = output_dir

        # set optimizer
        self.optimizer = optimizer
        # set learning rate scheduler
        self.lr_scheduler = lr_scheduler

        # set training hyper-parameter
        self.epochs = epochs
        self.iters_per_epoch = iters_per_epoch
        # set update_freq for gradient accumulation
        self.update_freq = update_freq
        # set checkpoint saving frequency
        self.save_freq = save_freq
        # set logging frequency
        self.log_freq = log_freq

        # set evaluation hyper-parameter
        self.eval_during_train = eval_during_train
        self.start_eval_epoch = start_eval_epoch
        self.eval_freq = eval_freq

        # initialize training log recorder for loss, time cost, metric, etc.
        self.train_output_info: Dict[str, misc.AverageMeter] = {}
        self.train_time_info = {
            "batch_cost": misc.AverageMeter("batch_cost", ".5f", postfix="s"),
            "reader_cost": misc.AverageMeter("reader_cost", ".5f", postfix="s"),
        }
        self.train_loss_info: Dict[str, misc.AverageMeter] = {}

        # initialize evaluation log recorder for loss, time cost, metric, etc.
        self.eval_output_info: Dict[str, misc.AverageMeter] = {}
        self.eval_time_info = {
            "batch_cost": misc.AverageMeter("batch_cost", ".5f", postfix="s"),
            "reader_cost": misc.AverageMeter("reader_cost", ".5f", postfix="s"),
        }

        # fix seed for reproducibility
        self.seed = seed

        # set running device, falling back to cpu when no accelerator is available
        if device != "cpu" and paddle.device.get_device() == "cpu":
            logger.warning(f"Set device({device}) to 'cpu' for only cpu available.")
            device = "cpu"
        self.device = paddle.set_device(device)

        # set equations for physics-driven or data-physics hybrid driven task, such as PINN
        self.equation = equation

        # set geometry for generating data
        self.geom = {} if geom is None else geom

        # set validator
        self.validator = validator

        # set visualizer
        self.visualizer = visualizer

        # set automatic mixed precision(AMP) configuration
        self.use_amp = use_amp
        self.amp_level = amp_level
        self.scaler = amp.GradScaler(True) if self.use_amp else None

        # whether calculate metrics by each batch during evaluation, mainly for memory efficiency
        self.compute_metric_by_batch = compute_metric_by_batch
        if validator is not None:
            # every metric's keep_batch must agree with compute_metric_by_batch,
            # otherwise evaluation would aggregate metrics inconsistently
            for metric in itertools.chain(
                *[_v.metric.values() for _v in self.validator.values()]
            ):
                if metric.keep_batch ^ compute_metric_by_batch:
                    raise ValueError(
                        f"{misc.typename(metric)}.keep_batch should be "
                        f"{compute_metric_by_batch} when compute_metric_by_batch="
                        f"{compute_metric_by_batch}."
                    )
        # whether set `stop_gradient=True` for every Tensor if no differentiation involved during evaluation
        self.eval_with_no_grad = eval_with_no_grad

        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        # initialize distributed environment
        if self.world_size > 1:
            # TODO(sensen): support different kind of DistributedStrategy
            fleet.init(is_collective=True)
            logger.warning(
                f"Detected 'world_size'({self.world_size}) > 1, it is recommended to "
                "scale up the learning rate and reduce the 'epochs' or "
                "'iters_per_epoch' according to the 'world_size' both linearly if you "
                "are training model."
            )

        # load pretrained model, usually used for transfer learning
        if pretrained_model_path is not None:
            save_load.load_pretrain(self.model, pretrained_model_path, self.equation)

        # initialize a dict for tracking best metric during training
        self.best_metric = {
            "metric": float("inf"),
            "epoch": 0,
        }
        # load model checkpoint, usually used for resume training
        if checkpoint_path is not None:
            if pretrained_model_path is not None:
                # NOTE: fixed missing space between the two concatenated fragments
                logger.warning(
                    "Detected 'pretrained_model_path' is given, weights in which might be "
                    "overridden by weights loaded from given 'checkpoint_path'."
                )
            loaded_metric = save_load.load_checkpoint(
                checkpoint_path, self.model, self.optimizer, self.scaler, self.equation
            )
            if isinstance(loaded_metric, dict):
                self.best_metric.update(loaded_metric)

        # decorate model(s) and optimizer(s) for AMP
        if self.use_amp:
            self.model, self.optimizer = amp.decorate(
                self.model,
                self.optimizer,
                self.amp_level,
                save_dtype="float32",
            )

        # choosing an appropriate training function for different optimizers
        if isinstance(self.optimizer, optim.LBFGS):
            self.train_epoch_func = ppsci.solver.train.train_LBFGS_epoch_func
            if self.update_freq != 1:
                # L-BFGS computes full (closure-based) steps, so gradient
                # accumulation is not applicable
                self.update_freq = 1
                logger.warning("Set 'update_freq' to 1 when using L-BFGS optimizer.")
        else:
            self.train_epoch_func = ppsci.solver.train.train_epoch_func

        # wrap model and optimizer to parallel object
        if self.world_size > 1:
            if isinstance(self.model, paddle.DataParallel):
                raise ValueError(
                    "Given model is already wrapped by paddle.DataParallel."
                    "Please do not wrap your model with DataParallel "
                    "before 'Solver.__init__' and keep it's type as 'nn.Layer'."
                )
            self.model = fleet.distributed_model(self.model)
            # re-expose input/output keys hidden by the DataParallel wrapper
            if hasattr(self.model, "input_keys"):
                self.model.input_keys = self.model._layers.input_keys
            if hasattr(self.model, "output_keys"):
                self.model.output_keys = self.model._layers.output_keys
            if self.optimizer is not None:
                self.optimizer = fleet.distributed_optimizer(self.optimizer)

        # set VisualDL tool
        self.vdl_writer = None
        if use_vdl:
            with misc.RankZeroOnly(self.rank) as is_master:
                if is_master:
                    self.vdl_writer = vdl.LogWriter(osp.join(output_dir, "vdl"))
                    # NOTE: log inside the master guard; on non-master ranks
                    # self.vdl_writer is None and has no '_logdir' attribute
                    logger.info(
                        "VisualDL tool is enabled for logging, you can view it by "
                        f"running: 'visualdl --logdir {self.vdl_writer._logdir} --port 8080'."
                    )

        # set WandB tool
        self.wandb_writer = None
        if use_wandb:
            try:
                import wandb
            except ModuleNotFoundError:
                raise ModuleNotFoundError(
                    "Please install 'wandb' with `pip install wandb` first."
                )
            with misc.RankZeroOnly(self.rank) as is_master:
                if is_master:
                    # tolerate wandb_config=None (its documented default)
                    self.wandb_writer = wandb.init(**(wandb_config or {}))

        self.global_step = 0

        # log paddlepaddle's version
        if version.Version(paddle.__version__) != version.Version("0.0.0"):
            paddle_version = paddle.__version__
            logger.warning(
                f"Detected paddlepaddle version is '{paddle_version}', "
                "currently it is recommended to use develop version until the "
                "release of version 2.6."
            )
        else:
            paddle_version = f"develop({paddle.version.commit[:7]})"

        logger.info(f"Using paddlepaddle {paddle_version} on device {self.device}")

        self.forward_helper = expression.ExpressionSolver()

        # whether enable static for forward pass, defaults to False
        jit.enable_to_static(to_static)
        logger.info(f"Set to_static={to_static} for computational optimization.")

        # use loss aggregator, use summation if None
        self.loss_aggregator = loss_aggregator

        # convert sympy to callable object if exist
        extra_parameters = []
        if self.equation:
            for equation in self.equation.values():
                extra_parameters += list(equation.learnable_parameters)

        def convert_expr(
            container_dict: Union[
                Dict[str, ppsci.constraint.Constraint],
                Dict[str, ppsci.validate.Validator],
                Dict[str, ppsci.visualize.Visualizer],
            ]
        ) -> None:
            # replace sympy expressions with callables bound to the model
            for container in container_dict.values():
                for name, expr in container.output_expr.items():
                    if isinstance(expr, sp.Basic):
                        container.output_expr[name] = ppsci.lambdify(
                            expr,
                            self.model,
                            extra_parameters,
                            # osp.join(self.output_dir, "symbolic_graph_visual", container.name, name), # HACK: Activate it for DEBUG.
                        )

        if self.constraint:
            convert_expr(self.constraint)

        if self.validator:
            convert_expr(self.validator)

        if self.visualizer:
            convert_expr(self.visualizer)

        # set up benchmark flag, will print memory stat if enabled
        self.benchmark_flag: bool = os.getenv("BENCHMARK_ROOT", None) is not None
    def train(self):
        """Training.

        Runs the main training loop, resuming from ``self.best_metric["epoch"]``
        when restored from a checkpoint. Within each epoch it optionally
        evaluates/visualizes, steps the per-epoch LR scheduler, saves periodic
        checkpoints, and always refreshes a "latest" checkpoint for resuming.
        """
        # resume the global step counter consistently with a restored epoch
        self.global_step = self.best_metric["epoch"] * self.iters_per_epoch

        for epoch_id in range(self.best_metric["epoch"] + 1, self.epochs + 1):
            self.train_epoch_func(self, epoch_id, self.log_freq)

            # log training summation at end of a epoch
            metric_msg = ", ".join(
                [self.train_output_info[key].avg_info for key in self.train_output_info]
            )
            logger.info(f"[Train][Epoch {epoch_id}/{self.epochs}][Avg] {metric_msg}")
            self.train_output_info.clear()

            # placeholder metric stored in checkpoints when no evaluation ran this epoch
            cur_metric = float("inf")
            # evaluate during training
            if (
                self.eval_during_train
                and epoch_id % self.eval_freq == 0
                and epoch_id >= self.start_eval_epoch
            ):
                cur_metric, metric_dict_group = self.eval(epoch_id)
                # track the best (lowest) metric seen so far and snapshot the model
                if cur_metric < self.best_metric["metric"]:
                    self.best_metric["metric"] = cur_metric
                    self.best_metric["epoch"] = epoch_id
                    save_load.save_checkpoint(
                        self.model,
                        self.optimizer,
                        self.best_metric,
                        self.scaler,
                        self.output_dir,
                        "best_model",
                        self.equation,
                    )
                logger.info(
                    f"[Eval][Epoch {epoch_id}]"
                    f"[best metric: {self.best_metric['metric']}]"
                )
                # forward evaluation metrics to VisualDL/WandB writers (if any)
                for metric_dict in metric_dict_group.values():
                    logger.scaler(
                        {f"eval/{k}": v for k, v in metric_dict.items()},
                        epoch_id,
                        self.vdl_writer,
                        self.wandb_writer,
                    )

                # visualize after evaluation
                if self.visualizer is not None:
                    self.visualize(epoch_id)

            # update learning rate by epoch
            if self.lr_scheduler is not None and self.lr_scheduler.by_epoch:
                self.lr_scheduler.step()

            # save epoch model every save_freq epochs
            if self.save_freq > 0 and epoch_id % self.save_freq == 0:
                save_load.save_checkpoint(
                    self.model,
                    self.optimizer,
                    {"metric": cur_metric, "epoch": epoch_id},
                    self.scaler,
                    self.output_dir,
                    f"epoch_{epoch_id}",
                    self.equation,
                )

            # save the latest model for convenient resume training
            save_load.save_checkpoint(
                self.model,
                self.optimizer,
                {"metric": cur_metric, "epoch": epoch_id},
                self.scaler,
                self.output_dir,
                "latest",
                self.equation,
            )

    @misc.run_on_eval_mode
    def eval(self, epoch_id: int = 0) -> Tuple[float, Dict[str, Dict[str, float]]]:
        """Evaluation.

        Args:
            epoch_id (int, optional): Epoch id. Defaults to 0.

        Returns:
            Tuple[float, Dict[str, Dict[str, float]]]: A target metric value(float) and
                all metric(s)(dict) of evaluation, used to judge the quality of the model.
        """
        # set eval func
        self.eval_func = ppsci.solver.eval.eval_func

        result = self.eval_func(self, epoch_id, self.log_freq)
        # summarize the averaged evaluation info collected during eval_func
        metric_msg = ", ".join(
            [self.eval_output_info[key].avg_info for key in self.eval_output_info]
        )
        logger.info(f"[Eval][Epoch {epoch_id}][Avg] {metric_msg}")
        self.eval_output_info.clear()

        return result

    @misc.run_on_eval_mode
    def visualize(self, epoch_id: int = 0):
        """Run all registered visualizer(s) on the current model.

        Args:
            epoch_id (int, optional): Epoch id. Defaults to 0.
        """
        # bind the visualization routine, then invoke it on this solver
        visu_fn = ppsci.solver.visu.visualize_func
        self.visu_func = visu_fn

        visu_fn(self, epoch_id)
        logger.info(f"[Visualize][Epoch {epoch_id}] Finish visualization")

    @misc.run_on_eval_mode
    def predict(
        self,
        input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
        expr_dict: Optional[Dict[str, Callable]] = None,
        batch_size: int = 64,
        no_grad: bool = True,
        return_numpy: bool = False,
    ) -> Dict[str, Union[paddle.Tensor, np.ndarray]]:
        """Pure prediction using model.forward(...) and expression(optional, if given).

        Args:
            input_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Input data in dict.
            expr_dict (Optional[Dict[str, Callable]]): Expression dict, which guide to
                compute equation variable with callable function. Defaults to None.
            batch_size (int, optional): Predicting by batch size. Defaults to 64.
            no_grad (bool): Whether set stop_gradient=True for entire prediction, mainly
                for memory-efficiency. Defaults to True.
            return_numpy (bool): Whether convert result from Tensor to numpy ndarray.
                Defaults to False.

        Returns:
            Dict[str, Union[paddle.Tensor, np.ndarray]]: Prediction in dict.
        """
        # number of samples inferred from the first entry; assumes all entries
        # share the same leading dimension
        num_samples = len(next(iter(input_dict.values())))
        num_pad = (self.world_size - num_samples % self.world_size) % self.world_size
        # pad with last element if `num_samples` is not divisible by `world_size`
        # ensuring every device get same number of data.
        if num_pad > 0:
            for k, v in input_dict.items():
                # repeat only along the first (sample) axis
                repeat_times = (num_pad, *(1 for _ in range(v.ndim - 1)))
                if isinstance(v, np.ndarray):
                    input_dict[k] = np.concatenate(
                        (
                            v,
                            np.tile(v[num_samples - 1 : num_samples], repeat_times),
                        ),
                    )
                elif isinstance(v, paddle.Tensor):
                    input_dict[k] = paddle.concat(
                        (
                            v,
                            paddle.tile(v[num_samples - 1 : num_samples], repeat_times),
                        ),
                    )
                else:
                    raise ValueError(f"Unsupported data type {type(v)}.")

        num_samples_pad = num_samples + num_pad
        local_num_samples_pad = num_samples_pad // self.world_size
        # each rank takes a strided shard: rank, rank+world_size, rank+2*world_size, ...
        local_input_dict = (
            {k: v[self.rank :: self.world_size] for k, v in input_dict.items()}
            if self.world_size > 1
            else input_dict
        )
        # ceiling division to cover the trailing partial batch
        local_batch_num = (local_num_samples_pad + (batch_size - 1)) // batch_size

        pred_dict = misc.Prettydefaultdict(list)
        with self.no_grad_context_manager(no_grad), self.no_sync_context_manager(
            self.world_size > 1, self.model
        ):
            for batch_id in range(local_batch_num):
                batch_input_dict = {}
                st = batch_id * batch_size
                ed = min(local_num_samples_pad, (batch_id + 1) * batch_size)

                # prepare batch input dict
                for key in local_input_dict:
                    if not paddle.is_tensor(local_input_dict[key]):
                        batch_input_dict[key] = paddle.to_tensor(
                            local_input_dict[key][st:ed], paddle.get_default_dtype()
                        )
                    else:
                        batch_input_dict[key] = local_input_dict[key][st:ed]
                    batch_input_dict[key].stop_gradient = no_grad

                # forward
                with self.autocast_context_manager(self.use_amp, self.amp_level):
                    batch_output_dict = self.forward_helper.visu_forward(
                        expr_dict, batch_input_dict, self.model
                    )

                # collect batch data
                for key, batch_output in batch_output_dict.items():
                    pred_dict[key].append(
                        batch_output.detach() if no_grad else batch_output
                    )

            # concatenate local predictions
            pred_dict = {key: paddle.concat(value) for key, value in pred_dict.items()}

            if self.world_size > 1:
                # gather global predictions from all devices if world_size > 1
                pred_dict = {
                    key: misc.all_gather(value) for key, value in pred_dict.items()
                }
                # rearrange predictions as the same order of input_dict according
                # to inverse permutation
                perm = np.arange(num_samples_pad, dtype="int64")
                perm = np.concatenate(
                    [perm[rank :: self.world_size] for rank in range(self.world_size)],
                    axis=0,
                )
                perm_inv = np.empty_like(perm)
                perm_inv[perm] = np.arange(num_samples_pad, dtype="int64")
                perm_inv = paddle.to_tensor(perm_inv)
                pred_dict = {key: value[perm_inv] for key, value in pred_dict.items()}
                # then discard predictions of padding data at the end if num_pad > 0
                if num_pad > 0:
                    pred_dict = {
                        key: value[:num_samples] for key, value in pred_dict.items()
                    }
                    # NOTE: Discard padding data in input_dict for consistency
                    for k in input_dict:
                        input_dict[k] = input_dict[k][:num_samples]

        # convert to numpy ndarray if specified
        if return_numpy:
            pred_dict = {
                k: (v.numpy() if paddle.is_tensor(v) else v)
                for k, v in pred_dict.items()
            }

        return pred_dict

    @misc.run_on_eval_mode
    def export(self):
        """Export to inference model.

        Raises:
            NotImplementedError: Always, since exporting is not implemented yet.
        """
        # placeholder until inference-model export is implemented
        raise NotImplementedError("model export is not supported yet.")

    def autocast_context_manager(
        self, enable: bool, level: Literal["O0", "O1", "O2"] = "O1"
    ) -> contextlib.AbstractContextManager:
        """Smart autocast context manager for Auto Mix Precision.

        Args:
            enable (bool): Enable autocast.
            level (Literal["O0", "O1", "O2"]): Autocast level.

        Returns:
            contextlib.AbstractContextManager: Smart autocast context manager.
        """
        if not enable:
            # no-op manager when autocast is disabled (suppress() on Python < 3.7)
            return (
                contextlib.nullcontext()
                if sys.version_info >= (3, 7)
                else contextlib.suppress()
            )
        return amp.auto_cast(level=level)

    def no_grad_context_manager(
        self, enable: bool
    ) -> contextlib.AbstractContextManager:
        """Smart no_grad context manager.

        Args:
            enable (bool): Enable no_grad.

        Returns:
            contextlib.AbstractContextManager: Smart no_grad context manager.
        """
        if not enable:
            # no-op manager when gradients are required (suppress() on Python < 3.7)
            return (
                contextlib.nullcontext()
                if sys.version_info >= (3, 7)
                else contextlib.suppress()
            )
        return paddle.no_grad()

    def no_sync_context_manager(
        self,
        enable: bool,
        ddp_model: paddle.DataParallel,
    ) -> contextlib.AbstractContextManager:
        """Smart no_sync context manager for given model.
        NOTE: Only `paddle.DataParallel` object has `no_sync` interface.

        Args:
            enable (bool): Enable no_sync.
            ddp_model (paddle.DataParallel): Model wrapped by paddle.DataParallel.

        Returns:
            contextlib.AbstractContextManager: Smart no_sync context manager.
        """
        if not enable:
            # no-op manager when gradient sync should proceed normally
            return (
                contextlib.nullcontext()
                if sys.version_info >= (3, 7)
                else contextlib.suppress()
            )
        if not isinstance(ddp_model, paddle.DataParallel):
            raise TypeError(
                "no_sync interface is only for model with type paddle.DataParallel, "
                f"but got type {misc.typename(ddp_model)}"
            )
        return ddp_model.no_sync()

    def plot_loss_history(
        self,
        by_epoch: bool = False,
        smooth_step: int = 1,
        use_semilogy: bool = True,
    ) -> None:
        """Plotting iteration/epoch-loss curve.

        Args:
            by_epoch (bool, optional): Whether the abscissa axis of the curve is epoch or iteration. Defaults to False.
            smooth_step (int, optional): How many steps of loss are squeezed to one point to smooth the curve. Defaults to 1.
            use_semilogy (bool, optional): Whether to set non-uniform coordinates for the y-axis. Defaults to True.
        """
        # collect per-loss histories, optionally averaged per epoch
        loss_dict = {}
        for name, meter in self.train_loss_info.items():
            history = np.asarray(meter.history)
            if by_epoch:
                # fold iterations into epochs and average within each epoch
                history = np.reshape(history, (-1, self.iters_per_epoch)).mean(axis=1)
            loss_dict[name] = list(history)

        misc.plot_curve(
            data=loss_dict,
            xlabel="Epoch" if by_epoch else "Iteration",
            ylabel="Loss",
            output_dir=self.output_dir,
            smooth_step=smooth_step,
            use_semilogy=use_semilogy,
        )
autocast_context_manager(enable, level='O1')

Smart autocast context manager for Auto Mix Precision.

Parameters:

Name Type Description Default
enable bool

Enable autocast.

required
level Literal['O0', 'O1', 'O2']

Autocast level.

'O1'

Returns:

Type Description
AbstractContextManager

contextlib.AbstractContextManager: Smart autocast context manager.

Source code in ppsci/solver/solver.py
def autocast_context_manager(
    self, enable: bool, level: Literal["O0", "O1", "O2"] = "O1"
) -> contextlib.AbstractContextManager:
    """Smart autocast context manager for Auto Mix Precision.

    Args:
        enable (bool): Whether autocast is enabled.
        level (Literal["O0", "O1", "O2"]): Autocast level. Defaults to "O1".

    Returns:
        contextlib.AbstractContextManager: AMP autocast context when enabled,
            otherwise a no-op context manager.
    """
    if not enable:
        # do-nothing context; contextlib.nullcontext requires python>=3.7
        if sys.version_info >= (3, 7):
            return contextlib.nullcontext()
        return contextlib.suppress()
    return amp.auto_cast(level=level)
eval(epoch_id=0)

Evaluation.

Parameters:

Name Type Description Default
epoch_id int

Epoch id. Defaults to 0.

0

Returns:

Type Description
Tuple[float, Dict[str, Dict[str, float]]]

Tuple[float, Dict[str, Dict[str, float]]]: A target metric value (float) and all metric(s) (dict) of evaluation, used to judge the quality of the model.

Source code in ppsci/solver/solver.py
@misc.run_on_eval_mode
def eval(self, epoch_id: int = 0) -> Tuple[float, Dict[str, Dict[str, float]]]:
    """Evaluation.

    Args:
        epoch_id (int, optional): Epoch id. Defaults to 0.

    Returns:
        Tuple[float, Dict[str, Dict[str, float]]]: A target metric value (float) and
            all metric(s) (dict) of evaluation, used to judge the quality of the model.
    """
    # set eval func
    self.eval_func = ppsci.solver.eval.eval_func

    result = self.eval_func(self, epoch_id, self.log_freq)
    # summarize averaged metrics collected in eval_output_info during evaluation
    metric_msg = ", ".join(
        [self.eval_output_info[key].avg_info for key in self.eval_output_info]
    )
    logger.info(f"[Eval][Epoch {epoch_id}][Avg] {metric_msg}")
    # reset accumulated info so the next evaluation starts from a clean state
    self.eval_output_info.clear()

    return result
export()

Export to inference model.

Source code in ppsci/solver/solver.py
@misc.run_on_eval_mode
def export(self):
    """Export to inference model.

    Raises:
        NotImplementedError: Always raised; model export is not supported yet.
    """
    raise NotImplementedError("model export is not supported yet.")
no_grad_context_manager(enable)

Smart no_grad context manager.

Parameters:

Name Type Description Default
enable bool

Enable no_grad.

required

Returns:

Type Description
AbstractContextManager

contextlib.AbstractContextManager: Smart no_grad context manager.

Source code in ppsci/solver/solver.py
def no_grad_context_manager(
    self, enable: bool
) -> contextlib.AbstractContextManager:
    """Smart no_grad context manager.

    Args:
        enable (bool): Whether gradient computation should be disabled.

    Returns:
        contextlib.AbstractContextManager: ``paddle.no_grad()`` when enabled,
            otherwise a no-op context manager.
    """
    if not enable:
        # do-nothing context; contextlib.nullcontext requires python>=3.7
        if sys.version_info >= (3, 7):
            return contextlib.nullcontext()
        return contextlib.suppress()
    return paddle.no_grad()
no_sync_context_manager(enable, ddp_model)

Smart no_sync context manager for given model. NOTE: Only paddle.DataParallel object has no_sync interface.

Parameters:

Name Type Description Default
enable bool

Enable no_sync.

required
ddp_model DataParallel

Model wrapped by paddle.DataParallel, whose gradient synchronization is skipped inside the context.

required

Returns:

Type Description
AbstractContextManager

contextlib.AbstractContextManager: Smart no_sync context manager.

Source code in ppsci/solver/solver.py
def no_sync_context_manager(
    self,
    enable: bool,
    ddp_model: paddle.DataParallel,
) -> contextlib.AbstractContextManager:
    """Smart no_sync context manager for given model.
    NOTE: Only `paddle.DataParallel` object has `no_sync` interface.

    Args:
        enable (bool): Enable no_sync.
        ddp_model (paddle.DataParallel): Model wrapped by `paddle.DataParallel`,
            whose gradient synchronization is skipped inside the returned context.

    Returns:
        contextlib.AbstractContextManager: Smart no_sync context manager.

    Raises:
        TypeError: If `enable` is True but `ddp_model` is not a
            `paddle.DataParallel` instance.
    """
    if enable:
        if not isinstance(ddp_model, paddle.DataParallel):
            raise TypeError(
                "no_sync interface is only for model with type paddle.DataParallel, "
                f"but got type {misc.typename(ddp_model)}"
            )
        ctx_manager = ddp_model.no_sync()
    else:
        # do-nothing context; contextlib.nullcontext requires python>=3.7
        ctx_manager = (
            contextlib.nullcontext()
            if sys.version_info >= (3, 7)
            else contextlib.suppress()
        )
    return ctx_manager
plot_loss_history(by_epoch=False, smooth_step=1, use_semilogy=True)

Plotting iteration/epoch-loss curve.

Parameters:

Name Type Description Default
by_epoch bool

Whether the abscissa axis of the curve is epoch or iteration. Defaults to False.

False
smooth_step int

How many steps of loss are squeezed to one point to smooth the curve. Defaults to 1.

1
use_semilogy bool

Whether to set non-uniform coordinates for the y-axis. Defaults to True.

True
Source code in ppsci/solver/solver.py
def plot_loss_history(
    self,
    by_epoch: bool = False,
    smooth_step: int = 1,
    use_semilogy: bool = True,
) -> None:
    """Plot the training loss curve against iteration or epoch.

    Args:
        by_epoch (bool, optional): Use epoch instead of iteration as the x-axis. Defaults to False.
        smooth_step (int, optional): Number of loss steps averaged into one plotted point. Defaults to 1.
        use_semilogy (bool, optional): Use a logarithmic scale for the y-axis. Defaults to True.
    """
    curves = {}
    for name, meter in self.train_loss_info.items():
        history = np.asarray(meter.history)
        if by_epoch:
            # collapse each epoch's iterations into a single averaged value
            # (assumes len(history) is a multiple of iters_per_epoch)
            history = history.reshape(-1, self.iters_per_epoch).mean(axis=1)
        curves[name] = list(history)

    misc.plot_curve(
        data=curves,
        xlabel="Epoch" if by_epoch else "Iteration",
        ylabel="Loss",
        output_dir=self.output_dir,
        smooth_step=smooth_step,
        use_semilogy=use_semilogy,
    )
predict(input_dict, expr_dict=None, batch_size=64, no_grad=True, return_numpy=False)

Pure prediction using model.forward(...) and expression(optional, if given).

Parameters:

Name Type Description Default
input_dict Dict[str, Union[ndarray, Tensor]]

Input data in dict.

required
expr_dict Optional[Dict[str, Callable]]

Expression dict, which guide to compute equation variable with callable function. Defaults to None.

None
batch_size int

Predicting by batch size. Defaults to 64.

64
no_grad bool

Whether set stop_gradient=True for entire prediction, mainly for memory-efficiency. Defaults to True.

True
return_numpy bool

Whether convert result from Tensor to numpy ndarray. Defaults to False.

False

Returns:

Type Description
Dict[str, Union[Tensor, ndarray]]

Dict[str, Union[paddle.Tensor, np.ndarray]]: Prediction in dict.

Source code in ppsci/solver/solver.py
@misc.run_on_eval_mode
def predict(
    self,
    input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
    expr_dict: Optional[Dict[str, Callable]] = None,
    batch_size: int = 64,
    no_grad: bool = True,
    return_numpy: bool = False,
) -> Dict[str, Union[paddle.Tensor, np.ndarray]]:
    """Pure prediction using model.forward(...) and expression(optional, if given).

    Args:
        input_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Input data in dict.
        expr_dict (Optional[Dict[str, Callable]]): Expression dict, which guide to
            compute equation variable with callable function. Defaults to None.
        batch_size (int, optional): Predicting by batch size. Defaults to 64.
        no_grad (bool): Whether set stop_gradient=True for entire prediction, mainly
            for memory-efficiency. Defaults to True.
        return_numpy (bool): Whether convert result from Tensor to numpy ndarray.
            Defaults to False.

    Returns:
        Dict[str, Union[paddle.Tensor, np.ndarray]]: Prediction in dict.

    NOTE: `input_dict` is modified in place when padding is required (values are
        padded and then truncated back to the original length), so callers should
        not rely on object identity of its values across this call.
    """
    num_samples = len(next(iter(input_dict.values())))
    num_pad = (self.world_size - num_samples % self.world_size) % self.world_size
    # pad with last element if `num_samples` is not divisible by `world_size`
    # ensuring every device get same number of data.
    if num_pad > 0:
        for k, v in input_dict.items():
            # repeat only along the first (sample) axis
            repeat_times = (num_pad, *(1 for _ in range(v.ndim - 1)))
            if isinstance(v, np.ndarray):
                input_dict[k] = np.concatenate(
                    (
                        v,
                        np.tile(v[num_samples - 1 : num_samples], repeat_times),
                    ),
                )
            elif isinstance(v, paddle.Tensor):
                input_dict[k] = paddle.concat(
                    (
                        v,
                        paddle.tile(v[num_samples - 1 : num_samples], repeat_times),
                    ),
                )
            else:
                raise ValueError(f"Unsupported data type {type(v)}.")

    num_samples_pad = num_samples + num_pad
    local_num_samples_pad = num_samples_pad // self.world_size
    # shard data across ranks in an interleaved fashion (rank r takes samples
    # r, r+world_size, r+2*world_size, ...); undone by the inverse permutation below
    local_input_dict = (
        {k: v[self.rank :: self.world_size] for k, v in input_dict.items()}
        if self.world_size > 1
        else input_dict
    )
    # ceil-divide to cover a possibly-partial last batch
    local_batch_num = (local_num_samples_pad + (batch_size - 1)) // batch_size

    pred_dict = misc.Prettydefaultdict(list)
    with self.no_grad_context_manager(no_grad), self.no_sync_context_manager(
        self.world_size > 1, self.model
    ):
        for batch_id in range(local_batch_num):
            batch_input_dict = {}
            st = batch_id * batch_size
            ed = min(local_num_samples_pad, (batch_id + 1) * batch_size)

            # prepare batch input dict
            for key in local_input_dict:
                if not paddle.is_tensor(local_input_dict[key]):
                    batch_input_dict[key] = paddle.to_tensor(
                        local_input_dict[key][st:ed], paddle.get_default_dtype()
                    )
                else:
                    batch_input_dict[key] = local_input_dict[key][st:ed]
                batch_input_dict[key].stop_gradient = no_grad

            # forward
            with self.autocast_context_manager(self.use_amp, self.amp_level):
                batch_output_dict = self.forward_helper.visu_forward(
                    expr_dict, batch_input_dict, self.model
                )

            # collect batch data
            for key, batch_output in batch_output_dict.items():
                pred_dict[key].append(
                    batch_output.detach() if no_grad else batch_output
                )

        # concatenate local predictions
        pred_dict = {key: paddle.concat(value) for key, value in pred_dict.items()}

        if self.world_size > 1:
            # gather global predictions from all devices if world_size > 1
            pred_dict = {
                key: misc.all_gather(value) for key, value in pred_dict.items()
            }
            # rearrange predictions as the same order of input_dict according
            # to inverse permutation
            perm = np.arange(num_samples_pad, dtype="int64")
            # `perm` is the concatenation order produced by all_gather
            # (rank 0's interleaved shard first, then rank 1's, ...)
            perm = np.concatenate(
                [perm[rank :: self.world_size] for rank in range(self.world_size)],
                axis=0,
            )
            # invert it so predictions line up with the original sample order
            perm_inv = np.empty_like(perm)
            perm_inv[perm] = np.arange(num_samples_pad, dtype="int64")
            perm_inv = paddle.to_tensor(perm_inv)
            pred_dict = {key: value[perm_inv] for key, value in pred_dict.items()}
            # then discard predictions of padding data at the end if num_pad > 0
            if num_pad > 0:
                pred_dict = {
                    key: value[:num_samples] for key, value in pred_dict.items()
                }
                # NOTE: Discard padding data in input_dict for consistency
                for k in input_dict:
                    input_dict[k] = input_dict[k][:num_samples]

    # convert to numpy ndarray if specified
    if return_numpy:
        pred_dict = {
            k: (v.numpy() if paddle.is_tensor(v) else v)
            for k, v in pred_dict.items()
        }

    return pred_dict
train()

Training.

Source code in ppsci/solver/solver.py
def train(self):
    """Training.

    Runs the main training loop from the epoch after `best_metric["epoch"]`
    (non-zero when resuming from a checkpoint) up to `self.epochs`, optionally
    evaluating/visualizing during training, and saving "best_model",
    periodic "epoch_{id}" and "latest" checkpoints.
    """
    # resume global step counter from the last finished epoch
    self.global_step = self.best_metric["epoch"] * self.iters_per_epoch

    for epoch_id in range(self.best_metric["epoch"] + 1, self.epochs + 1):
        self.train_epoch_func(self, epoch_id, self.log_freq)

        # log training summation at end of a epoch
        metric_msg = ", ".join(
            [self.train_output_info[key].avg_info for key in self.train_output_info]
        )
        logger.info(f"[Train][Epoch {epoch_id}/{self.epochs}][Avg] {metric_msg}")
        self.train_output_info.clear()

        # placeholder metric stored in checkpoints when no evaluation ran this epoch
        cur_metric = float("inf")
        # evaluate during training
        if (
            self.eval_during_train
            and epoch_id % self.eval_freq == 0
            and epoch_id >= self.start_eval_epoch
        ):
            cur_metric, metric_dict_group = self.eval(epoch_id)
            # lower metric is considered better; keep the best checkpoint
            if cur_metric < self.best_metric["metric"]:
                self.best_metric["metric"] = cur_metric
                self.best_metric["epoch"] = epoch_id
                save_load.save_checkpoint(
                    self.model,
                    self.optimizer,
                    self.best_metric,
                    self.scaler,
                    self.output_dir,
                    "best_model",
                    self.equation,
                )
            logger.info(
                f"[Eval][Epoch {epoch_id}]"
                f"[best metric: {self.best_metric['metric']}]"
            )
            for metric_dict in metric_dict_group.values():
                logger.scaler(
                    {f"eval/{k}": v for k, v in metric_dict.items()},
                    epoch_id,
                    self.vdl_writer,
                    self.wandb_writer,
                )

            # visualize after evaluation
            if self.visualizer is not None:
                self.visualize(epoch_id)

        # update learning rate by epoch
        if self.lr_scheduler is not None and self.lr_scheduler.by_epoch:
            self.lr_scheduler.step()

        # save epoch model every save_freq epochs
        if self.save_freq > 0 and epoch_id % self.save_freq == 0:
            save_load.save_checkpoint(
                self.model,
                self.optimizer,
                {"metric": cur_metric, "epoch": epoch_id},
                self.scaler,
                self.output_dir,
                f"epoch_{epoch_id}",
                self.equation,
            )

        # save the latest model for convenient resume training
        save_load.save_checkpoint(
            self.model,
            self.optimizer,
            {"metric": cur_metric, "epoch": epoch_id},
            self.scaler,
            self.output_dir,
            "latest",
            self.equation,
        )
visualize(epoch_id=0)

Visualization.

Parameters:

Name Type Description Default
epoch_id int

Epoch id. Defaults to 0.

0
Source code in ppsci/solver/solver.py
@misc.run_on_eval_mode
def visualize(self, epoch_id: int = 0):
    """Visualization.

    Delegates to `ppsci.solver.visu.visualize_func`, which runs every
    registered visualizer and saves its output under the output directory.

    Args:
        epoch_id (int, optional): Epoch id. Defaults to 0.
    """
    # set visualize func
    self.visu_func = ppsci.solver.visu.visualize_func

    self.visu_func(self, epoch_id)
    logger.info(f"[Visualize][Epoch {epoch_id}] Finish visualization")

ppsci.solver.train

train_epoch_func(solver, epoch_id, log_freq)

Train program for one epoch.

Parameters:

Name Type Description Default
solver Solver

Main solver.

required
epoch_id int

Epoch id.

required
log_freq int

Log training information every log_freq steps.

required
Source code in ppsci/solver/train.py
def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
    """Train program for one epoch.

    For each iteration: draws one batch per constraint, runs a joint forward
    pass, accumulates the constraint losses, backpropagates (optionally through
    AMP loss scaling and/or a loss aggregator), and steps the optimizer every
    `solver.update_freq` iterations.

    Args:
        solver (solver.Solver): Main solver.
        epoch_id (int): Epoch id.
        log_freq (int): Log training information every `log_freq` steps.
    """
    batch_tic = time.perf_counter()

    for iter_id in range(1, solver.iters_per_epoch + 1):
        total_loss = 0
        loss_dict = misc.Prettydefaultdict(float)
        loss_dict["loss"] = 0.0
        total_batch_size = 0
        reader_cost = 0
        batch_cost = 0
        reader_tic = time.perf_counter()

        input_dicts = []
        label_dicts = []
        weight_dicts = []
        for _, _constraint in solver.constraint.items():
            # restart the data iterator when a constraint's loader is exhausted
            try:
                input_dict, label_dict, weight_dict = next(_constraint.data_iter)
            except StopIteration:
                _constraint.data_iter = iter(_constraint.data_loader)
                input_dict, label_dict, weight_dict = next(_constraint.data_iter)
            # profile code below
            # profiler.add_profiler_step(solver.cfg["profiler_options"])
            if iter_id == 5:
                # 5 step for warmup
                for key in solver.train_time_info:
                    solver.train_time_info[key].reset()
            reader_cost += time.perf_counter() - reader_tic
            # inputs need gradients so equation residuals can be differentiated
            for v in input_dict.values():
                if hasattr(v, "stop_gradient"):
                    v.stop_gradient = False

            # gather each constraint's input, label, weight to a list
            input_dicts.append(input_dict)
            label_dicts.append(label_dict)
            weight_dicts.append(weight_dict)
            total_batch_size += next(iter(input_dict.values())).shape[0]
            reader_tic = time.perf_counter()

        with solver.no_sync_context_manager(solver.world_size > 1, solver.model):
            # forward for every constraint, including model and equation expression
            with solver.autocast_context_manager(solver.use_amp, solver.amp_level):
                constraint_losses = solver.forward_helper.train_forward(
                    tuple(
                        _constraint.output_expr
                        for _constraint in solver.constraint.values()
                    ),
                    input_dicts,
                    solver.model,
                    solver.constraint,
                    label_dicts,
                    weight_dicts,
                )
                # accumulate all losses
                for i, _constraint in enumerate(solver.constraint.values()):
                    total_loss += constraint_losses[i]
                    loss_dict[_constraint.name] += (
                        float(constraint_losses[i]) / solver.update_freq
                    )
                if solver.update_freq > 1:
                    # average over accumulation steps for gradient accumulation
                    total_loss = total_loss / solver.update_freq
                loss_dict["loss"] = float(total_loss)

            # backward
            if solver.loss_aggregator is None:
                if solver.use_amp:
                    total_loss_scaled = solver.scaler.scale(total_loss)
                    total_loss_scaled.backward()
                else:
                    total_loss.backward()
            else:
                aggregated_loss = solver.loss_aggregator(
                    constraint_losses, solver.global_step
                )
                if solver.use_amp:
                    # BUGFIX: scale the aggregated loss as well, so that
                    # `total_loss_scaled` is always defined when AMP is enabled
                    # and `scaler.minimize` below receives a scaled loss
                    # (previously this path raised NameError under AMP).
                    total_loss_scaled = solver.scaler.scale(aggregated_loss)
                    total_loss_scaled.backward()
                else:
                    aggregated_loss.backward()

        # update parameters
        if iter_id % solver.update_freq == 0 or iter_id == solver.iters_per_epoch:
            if solver.world_size > 1:
                # fuse + allreduce manually before optimization if use DDP + no_sync
                # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
                hpu.fused_allreduce_gradients(list(solver.model.parameters()), None)
            if solver.use_amp:
                solver.scaler.minimize(solver.optimizer, total_loss_scaled)
            else:
                solver.optimizer.step()
            solver.optimizer.clear_grad()

        # update learning rate by step
        if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
            solver.lr_scheduler.step()

        batch_cost += time.perf_counter() - batch_tic

        # update and log training information
        solver.global_step += 1
        solver.train_time_info["reader_cost"].update(reader_cost)
        solver.train_time_info["batch_cost"].update(batch_cost)
        printer.update_train_loss(solver, loss_dict, total_batch_size)
        if iter_id == 1 or iter_id % log_freq == 0:
            printer.log_train_info(solver, total_batch_size, epoch_id, iter_id)

        batch_tic = time.perf_counter()

train_LBFGS_epoch_func(solver, epoch_id, log_freq)

Train function for one epoch with L-BFGS optimizer.

NOTE: L-BFGS training program do not support AMP now.

Parameters:

Name Type Description Default
solver Solver

Main solver.

required
epoch_id int

Epoch id.

required
log_freq int

Log training information every log_freq steps.

required
Source code in ppsci/solver/train.py
def train_LBFGS_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
    """Train function for one epoch with L-BFGS optimizer.

    NOTE: L-BFGS training program do not support AMP now.

    Args:
        solver (solver.Solver): Main solver.
        epoch_id (int): Epoch id.
        log_freq (int): Log training information every `log_freq` steps.
    """
    batch_tic = time.perf_counter()

    for iter_id in range(1, solver.iters_per_epoch + 1):
        loss_dict = misc.Prettydefaultdict(float)
        loss_dict["loss"] = 0.0
        total_batch_size = 0
        reader_cost = 0
        batch_cost = 0
        reader_tic = time.perf_counter()

        input_dicts = []
        label_dicts = []
        weight_dicts = []
        for _, _constraint in solver.constraint.items():
            # restart the data iterator when a constraint's loader is exhausted
            try:
                input_dict, label_dict, weight_dict = next(_constraint.data_iter)
            except StopIteration:
                _constraint.data_iter = iter(_constraint.data_loader)
                input_dict, label_dict, weight_dict = next(_constraint.data_iter)
            reader_cost += time.perf_counter() - reader_tic
            # inputs need gradients so equation residuals can be differentiated
            for v in input_dict.values():
                if hasattr(v, "stop_gradient"):
                    v.stop_gradient = False

            # gather all constraint data into list
            input_dicts.append(input_dict)
            label_dicts.append(label_dict)
            weight_dicts.append(weight_dict)
            total_batch_size += next(iter(input_dict.values())).shape[0]
            reader_tic = time.perf_counter()

        def closure():
            """Forward-backward closure function for LBFGS optimizer.

            NOTE(review): L-BFGS line search may invoke this closure multiple
            times per `optimizer.step`; `loss_dict` (captured from the enclosing
            scope) is overwritten on each invocation, so logged losses reflect
            the last closure evaluation — confirm against paddle LBFGS docs.

            Returns:
                Tensor: Computed loss.
            """
            total_loss = 0
            with solver.no_sync_context_manager(solver.world_size > 1, solver.model):
                with solver.autocast_context_manager(solver.use_amp, solver.amp_level):
                    # forward for every constraint, including model and equation expression
                    constraint_losses = solver.forward_helper.train_forward(
                        tuple(
                            _constraint.output_expr
                            for _constraint in solver.constraint.values()
                        ),
                        input_dicts,
                        solver.model,
                        solver.constraint,
                        label_dicts,
                        weight_dicts,
                    )
                    # accumulate all losses
                    for i, _constraint in enumerate(solver.constraint.values()):
                        total_loss += constraint_losses[i]
                        loss_dict[_constraint.name] = float(constraint_losses[i])
                    loss_dict["loss"] = float(total_loss)

                # backward
                solver.optimizer.clear_grad()
                if solver.loss_aggregator is None:
                    total_loss.backward()
                else:
                    solver.loss_aggregator(
                        constraint_losses, solver.global_step
                    ).backward()

            if solver.world_size > 1:
                # fuse + allreduce manually before optimization if use DDP model
                # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
                hpu.fused_allreduce_gradients(list(solver.model.parameters()), None)

            return total_loss

        # update parameters
        solver.optimizer.step(closure)

        # update learning rate by step
        if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
            solver.lr_scheduler.step()

        batch_cost += time.perf_counter() - batch_tic

        # update and log training information
        solver.global_step += 1
        solver.train_time_info["reader_cost"].update(reader_cost)
        solver.train_time_info["batch_cost"].update(batch_cost)
        printer.update_train_loss(solver, loss_dict, total_batch_size)
        if iter_id == 1 or iter_id % log_freq == 0:
            printer.log_train_info(solver, total_batch_size, epoch_id, iter_id)

        batch_tic = time.perf_counter()

ppsci.solver.eval

eval_func(solver, epoch_id, log_freq)

Evaluation function.

Parameters:

Name Type Description Default
solver Solver

Main Solver.

required
epoch_id int

Epoch id.

required
log_freq int

Log evaluation information every log_freq steps.

required

Returns:

Type Description
Tuple[float, Dict[str, Dict[str, float]]]

Tuple[float, Dict[str, Dict[str, float]]]: Target metric and all metric dicts computed during evaluation.

Source code in ppsci/solver/eval.py
def eval_func(
    solver: "solver.Solver", epoch_id: int, log_freq: int
) -> Tuple[float, Dict[str, Dict[str, float]]]:
    """Evaluation function.

    Dispatches to batch-wise or dataset-wise metric computation according to
    ``solver.compute_metric_by_batch``.

    Args:
        solver (solver.Solver): Main Solver.
        epoch_id (int): Epoch id.
        log_freq (int): Log evaluation information every `log_freq` steps.

    Returns:
        Tuple[float, Dict[str, Dict[str, float]]]: Target metric and all metric dicts
            computed during evaluation.
    """
    impl = _eval_by_batch if solver.compute_metric_by_batch else _eval_by_dataset
    return impl(solver, epoch_id, log_freq)

ppsci.solver.visu

visualize_func(solver, epoch_id)

Visualization program.

Parameters:

Name Type Description Default
solver Solver

Main Solver.

required
epoch_id int

Epoch id.

required
Source code in ppsci/solver/visu.py
def visualize_func(solver: "solver.Solver", epoch_id: int):
    """Visualization program.

    Runs every registered visualizer batch-by-batch through the model and
    saves the concatenated inputs/outputs to disk on the master rank only.

    Args:
        solver (solver.Solver): Main Solver.
        epoch_id (int): Epoch id.
    """
    for _, _visualizer in solver.visualizer.items():
        all_input = misc.Prettydefaultdict(list)
        all_output = misc.Prettydefaultdict(list)

        # NOTE: 'visualize_func' now do not apply data sharding(different from 'Solver.predict'),
        # where every rank receive same input data and compute same output data
        # (which will cause computational redundancy),
        # but only the 0-rank(master) device save the visualization result into disk.
        # TODO(HydrogenSulfate): This will be optimized in the future.

        input_dict = _visualizer.input_dict
        batch_size = _visualizer.batch_size
        num_samples = len(next(iter(input_dict.values())))
        # ceil-divide to cover a possibly-partial last batch
        batch_num = (num_samples + (batch_size - 1)) // batch_size

        for batch_id in range(batch_num):
            batch_input_dict = {}
            st = batch_id * batch_size
            ed = min(num_samples, (batch_id + 1) * batch_size)

            # prepare batch input dict
            for key in input_dict:
                if not paddle.is_tensor(input_dict[key]):
                    batch_input_dict[key] = paddle.to_tensor(
                        input_dict[key][st:ed], paddle.get_default_dtype()
                    )
                else:
                    batch_input_dict[key] = input_dict[key][st:ed]
                # keep gradients enabled: output expressions may differentiate
                # w.r.t. the inputs
                batch_input_dict[key].stop_gradient = False

            # forward
            with solver.autocast_context_manager(
                solver.use_amp, solver.amp_level
            ), solver.no_grad_context_manager(solver.eval_with_no_grad):
                batch_output_dict = solver.forward_helper.visu_forward(
                    _visualizer.output_expr, batch_input_dict, solver.model
                )

            # collect batch data with dtype fixed to float32 regardless of the dtypes of
            # paddle runtime, which is most compatible with almost visualization tools.
            for key, batch_input in batch_input_dict.items():
                all_input[key].append(batch_input.detach().astype("float32"))
            for key, batch_output in batch_output_dict.items():
                all_output[key].append(batch_output.detach().astype("float32"))

        # concatenate all data
        for key in all_input:
            all_input[key] = paddle.concat(all_input[key])
        for key in all_output:
            all_output[key] = paddle.concat(all_output[key])

        # save visualization
        with misc.RankZeroOnly(solver.rank) as is_master:
            if is_master:
                visual_dir = osp.join(solver.output_dir, "visual", f"epoch_{epoch_id}")
                os.makedirs(visual_dir, exist_ok=True)
                _visualizer.save(
                    osp.join(visual_dir, _visualizer.prefix),
                    {**all_input, **all_output},
                )

ppsci.solver.printer

update_train_loss(trainer, loss_dict, batch_size)

Source code in ppsci/solver/printer.py
def update_train_loss(
    trainer: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
):
    """Accumulate one training step's losses into the trainer's running meters."""
    for name, value in loss_dict.items():
        value = float(value)
        # batch-size-weighted meter used for averaged epoch logging
        if name not in trainer.train_output_info:
            trainer.train_output_info[name] = misc.AverageMeter(name, "7.5f")
        trainer.train_output_info[name].update(value, batch_size)
        # unweighted per-iteration meter kept for loss-history plotting
        if name not in trainer.train_loss_info:
            trainer.train_loss_info[name] = misc.AverageMeter(name, ".5f")
        trainer.train_loss_info[name].update(value)

update_eval_loss(trainer, loss_dict, batch_size)

Source code in ppsci/solver/printer.py
def update_eval_loss(
    trainer: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
):
    """Accumulate one evaluation step's losses into the trainer's running meters."""
    for name, value in loss_dict.items():
        # lazily create a batch-size-weighted meter per metric key
        if name not in trainer.eval_output_info:
            trainer.eval_output_info[name] = misc.AverageMeter(name, "7.5f")
        trainer.eval_output_info[name].update(float(value), batch_size)

log_train_info(trainer, batch_size, epoch_id, iter_id)

Source code in ppsci/solver/printer.py
def log_train_info(
    trainer: "solver.Solver", batch_size: int, epoch_id: int, iter_id: int
):
    """Log one training step's lr, metrics, timing, throughput and ETA, and
    push the scalar values to the configured VisualDL/WandB writers."""
    lr_msg = f"lr: {trainer.optimizer.get_lr():.5f}"

    metric_msg = ", ".join(
        [
            f"{key}: {trainer.train_output_info[key].avg:.5f}"
            for key in trainer.train_output_info
        ]
    )

    time_msg = ", ".join(
        [trainer.train_time_info[key].mean for key in trainer.train_time_info]
    )

    ips_msg = (
        f"ips: {batch_size / trainer.train_time_info['batch_cost'].avg:.5f} samples/s"
    )

    # ETA = remaining iterations (current epoch's tail + all future epochs)
    # times the running average cost per batch
    eta_sec = (
        (trainer.epochs - epoch_id + 1) * trainer.iters_per_epoch - iter_id
    ) * trainer.train_time_info["batch_cost"].avg
    eta_msg = f"eta: {str(datetime.timedelta(seconds=int(eta_sec))):s}"

    log_str = (
        f"[Train][Epoch {epoch_id}/{trainer.epochs}]"
        f"[Iter: {iter_id}/{trainer.iters_per_epoch}] {lr_msg}, "
        f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
    )
    # append device memory statistics in benchmark mode
    if trainer.benchmark_flag:
        max_mem_reserved_msg = (
            f"max_mem_reserved: {device.cuda.max_memory_reserved()} B"
        )
        max_mem_allocated_msg = (
            f"max_mem_allocated: {device.cuda.max_memory_allocated()} B"
        )
        log_str += f", {max_mem_reserved_msg}, {max_mem_allocated_msg}"
    logger.info(log_str)

    logger.scaler(
        {
            "train/lr": trainer.optimizer.get_lr(),
            **{
                f"train/{key}": trainer.train_output_info[key].avg
                for key in trainer.train_output_info
            },
        },
        step=trainer.global_step,
        vdl_writer=trainer.vdl_writer,
        wandb_writer=trainer.wandb_writer,
    )

log_eval_info(trainer, batch_size, epoch_id, iters_per_epoch, iter_id)

Source code in ppsci/solver/printer.py
def log_eval_info(
    trainer: "solver.Solver",
    batch_size: int,
    epoch_id: int,
    iters_per_epoch: int,
    iter_id: int,
):
    """Log one evaluation step's metrics, timing, throughput and ETA, and
    push the scalar values to the configured VisualDL/WandB writers."""
    metric_msg = ", ".join(
        [
            f"{key}: {trainer.eval_output_info[key].avg:.5f}"
            for key in trainer.eval_output_info
        ]
    )

    time_msg = ", ".join(
        [trainer.eval_time_info[key].mean for key in trainer.eval_time_info]
    )

    # BUGFIX: a space was missing between the number and its unit
    # (previously rendered as e.g. "ips: 0.12345samples/s");
    # now consistent with the train-side log message.
    ips_msg = (
        f"ips: {batch_size / trainer.eval_time_info['batch_cost'].avg:.5f} samples/s"
    )

    # ETA = remaining evaluation iterations times the average cost per batch
    eta_sec = (iters_per_epoch - iter_id) * trainer.eval_time_info["batch_cost"].avg
    eta_msg = f"eta: {str(datetime.timedelta(seconds=int(eta_sec))):s}"
    logger.info(
        f"[Eval][Epoch {epoch_id}][Iter: {iter_id}/{iters_per_epoch}] "
        f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
    )

    logger.scaler(
        {
            f"eval/{key}": trainer.eval_output_info[key].avg
            for key in trainer.eval_output_info
        },
        step=trainer.global_step,
        vdl_writer=trainer.vdl_writer,
        wandb_writer=trainer.wandb_writer,
    )

最后更新: November 17, 2023
创建日期: November 6, 2023