
2D-Biharmonic

# train
python biharmonic2d.py
# evaluate with the pretrained model
python biharmonic2d.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/biharmonic2d/biharmonic2d_pretrained.pdparams
# export inference model
python biharmonic2d.py mode=export
# inference with the exported model
python biharmonic2d.py mode=infer
Pretrained model | Metric
biharmonic2d_pretrained.pdparams | l2_error: 0.02774

1. Background

The biharmonic equation describes the relationship among stress, strain and load. As a fourth-order partial differential equation, it is difficult to handle with traditional numerical methods. This case applies the PINNs (Physics-Informed Neural Networks) approach to the biharmonic equation on a 2D rectangular plate and solves it with deep learning based on the linear-elasticity relations.

2. Problem Definition

The structure in this case is a rectangular plate of length 2 m, width 3 m and thickness 0.01 m. The plate is fixed along its four edges, and a sinusoidally distributed load \(q=q_0\sin(\dfrac{\pi x}{a})\sin(\dfrac{\pi y}{b})\) is applied on its surface, where \(q_0=980\ Pa\) and \(a\), \(b\) are the plate length and width. The governing PDE is the 2D biharmonic equation:

\[\nabla^4w=(\dfrac{\partial^2}{\partial x^2}+\dfrac{\partial^2}{\partial y^2})(\dfrac{\partial^2}{\partial x^2}+\dfrac{\partial^2}{\partial y^2})w=\dfrac{q}{D}\]

where \(w\) is the plate deflection and \(D\) is the bending (flexural) stiffness, computed as:

\[D=\dfrac{Et^3}{12(1-\nu^2)}\]

where \(E=201880.0\times10^{6}\ Pa\) is Young's modulus, \(t\) is the plate thickness, and \(\nu=0.25\) is Poisson's ratio.
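
Plugging in the plate thickness \(t=0.01\ m\) gives the concrete stiffness used throughout this case; a quick standalone check in plain Python (not part of the case code):

# standalone numeric check of the bending stiffness D
E = 201880.0e6   # Pa
t = 0.01         # m, plate thickness
nu = 0.25        # Poisson's ratio
D = E * t**3 / (12.0 * (1.0 - nu**2))
print(D)         # ≈ 17944.9 N·m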

From the plate deflection \(w\), the bending moments and shear forces can be computed as:

\[ \begin{cases} M_x=-D(\dfrac{\partial^2w}{\partial x^2}+\nu\dfrac{\partial^2w}{\partial y^2}) \\ M_y=-D(\dfrac{\partial^2w}{\partial y^2}+\nu\dfrac{\partial^2w}{\partial x^2}) \\ M_{xy}=D(1-\nu)\dfrac{\partial^2w}{\partial x\partial y} \\ Q_x=-D\dfrac{\partial}{\partial x}(\dfrac{\partial^2w}{\partial x^2}+\dfrac{\partial^2w}{\partial y^2}) \\ Q_y=-D\dfrac{\partial}{\partial y}(\dfrac{\partial^2w}{\partial x^2}+\dfrac{\partial^2w}{\partial y^2}) \\ \end{cases} \]

Since the four edges of the plate are fixed, on the edges \(x=0\) and \(x=a\) the deflection \(w\) and the moment \(M_y\) are 0, and on the edges \(y=0\) and \(y=b\) the deflection \(w\) and the moment \(M_x\) are 0, i.e.:

\[ \begin{cases} w|_{x=0}=w|_{x=a}=0 \\ M_y|_{x=0}=M_y|_{x=a}=0 \\ w|_{y=0}=w|_{y=b}=0 \\ M_x|_{y=0}=M_x|_{y=b}=0 \\ \end{cases} \]

The goal is to solve for the deflection \(w\) at every point of the plate surface and, from it, compute the moments and shear forces \(M_x\), \(M_y\), \(M_{xy}\), \(Q_x\), \(Q_y\), 6 physical quantities in total. The constants are defined in the config file as follows:

log_freq: 20

# set working condition
E: 201880.0e+6  # Pa = N/m2
NU: 0.25
Q_0: 980     # Pa = N/m2
LENGTH: 2        # m
WIDTH: 3         # m
HEIGHT: 0.01     # m, plate thickness

3. Problem Solving

The following walks through how to convert the problem into PaddleScience code step by step and solve it with deep learning. To keep the walkthrough focused on understanding PaddleScience, only the key steps such as model construction, equation construction and computational domain construction are described; for the remaining details, please refer to the API documentation.

3.1 Model Construction

In the biharmonic2d problem, every known coordinate point \((x, y)\) has corresponding unknowns: the deflection \(w\) in the loading (z) direction, the moments \((M_x, M_y, M_{xy})\) and the shear forces \((Q_x, Q_y)\). Since the moments and shear forces are computed from the deflection, the only unknown that actually needs to be learned is the deflection \(w\), so a single model suffices:

\[w = f(x,y)\]

In the equation above, \(f\) is the deflection model disp_net, expressed in PaddleScience code as follows:

# set models
disp_net = ppsci.arch.MLP(**cfg.MODEL)

So that specific variable values can be accessed accurately and quickly during computation, the input variable names of the model are specified as ("x", "y"). To match the built-in PaddleScience equation API ppsci.equation.Biharmonic, the output variable name is ("u",) rather than ("w",); these names are kept consistent throughout the rest of the code.

By specifying the number of layers and neurons of the MLP, a neural network model disp_net with 5 hidden layers of 20 neurons each is instantiated, using tanh as the activation function and WeightNorm weight normalization.
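
For reference, a minimal sketch of the explicit call that **cfg.MODEL is assumed to expand to, based on the description above (parameter names follow the usual ppsci.arch.MLP signature; the authoritative values live in the yaml config):

# assumed explicit equivalent of ppsci.arch.MLP(**cfg.MODEL); illustrative only
disp_net = ppsci.arch.MLP(
    input_keys=("x", "y"),   # coordinates
    output_keys=("u",),      # deflection w, named "u" to match ppsci.equation.Biharmonic
    num_layers=5,            # 5 hidden layers
    hidden_size=20,          # 20 neurons per layer
    activation="tanh",
    weight_norm=True,
)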

3.2 Equation Construction

This case involves the biharmonic equation, for which the built-in ppsci.equation.Biharmonic can be used directly. Since the load \(q\) is non-uniform, a custom load distribution function is defined and passed to the API.

# set equation
x, y = sp.symbols("x y")
Q = cfg.Q_0 * sp.sin(np.pi * x / cfg.LENGTH) * sp.sin(np.pi * y / cfg.WIDTH)
equation = {
    "Biharmonic": ppsci.equation.Biharmonic(
        dim=2, q=Q, D=cfg.E * (cfg.HEIGHT**3) / (12.0 * (1.0 - cfg.NU**2))
    ),
}
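
For intuition, the single residual named biharmonic produced by this equation object corresponds to \(\nabla^4w-\dfrac{q}{D}\); below is a standalone sympy sketch of that expression, with the plate constants written in directly (illustrative only; ppsci builds the actual expression internally from dim, q and D):

# standalone sympy sketch of the interior residual; not part of the case code
import sympy as sp

x, y = sp.symbols("x y")
w = sp.Function("u")(x, y)                                  # network output, named "u"
q = 980 * sp.sin(sp.pi * x / 2) * sp.sin(sp.pi * y / 3)     # Q_0=980, LENGTH=2, WIDTH=3
D = 201880.0e6 * 0.01**3 / (12.0 * (1.0 - 0.25**2))
residual = sp.diff(w, x, 4) + 2 * sp.diff(w, x, 2, y, 2) + sp.diff(w, y, 4) - q / D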

3.3 Computational Domain Construction

Since the plate thickness is very small, the geometric domain of the problem is treated as a 2D rectangle of length 2 and width 3, built with the built-in ppsci.geometry.Rectangle API:

# set geometry
plate = ppsci.geometry.Rectangle((0, 0), (cfg.LENGTH, cfg.WIDTH))
geom = {"geo": plate}
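
As a quick sanity check (illustrative only, and assuming the geometry object exposes the sample_interior method used in other PaddleScience examples), a few points can be drawn from the rectangle directly:

# hypothetical usage sketch: draw a few random points from the interior of the rectangle
points = plate.sample_interior(4)                 # dict of arrays keyed by "x" and "y"
print(points["x"].shape, points["y"].shape)       # expected: (4, 1) (4, 1)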

3.4 Constraint Construction

This case involves 9 constraints in total. Before building them, a shared dataloader configuration is created so that it can be reused across the constraints.

# set dataloader config
train_dataloader_cfg = {
    "dataset": "NamedArrayDataset",
    "iters_per_epoch": cfg.TRAIN.iters_per_epoch,
    "sampler": {
        "name": "BatchSampler",
        "drop_last": True,
        "shuffle": True,
    },
}
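
Each constraint below reuses this base config by merging in its own batch size with dict unpacking; a tiny standalone illustration of the pattern (the batch size value here is a placeholder, the real ones come from cfg.TRAIN.batch_size):

# illustration of the {**base, ...} pattern used by every constraint below
interior_dataloader_cfg = {**train_dataloader_cfg, "batch_size": 128}  # 128 is a placeholder
# train_dataloader_cfg itself stays unchanged; only the merged copy carries batch_size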

3.4.1 Interior Constraint

Take the InteriorConstraint applied to the interior points of the plate as an example; the code is as follows:

interior = ppsci.constraint.InteriorConstraint(
    equation["Biharmonic"].equations,
    {"biharmonic": 0},
    geom["geo"],
    {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.interior},
    ppsci.loss.MSELoss(),
    criteria=lambda x, y: ((0 < x) & (x < cfg.LENGTH) & (0 < y) & (y < cfg.WIDTH)),
    weight_dict={"biharmonic": cfg.TRAIN.weight.interior},
    name="INTERIOR",
)

The first argument of InteriorConstraint is the equation (or system of equations) expression describing how the constraint target is computed; here we pass equation["Biharmonic"].equations, instantiated in Section 3.2 Equation Construction;

The second argument is the target value of the constrained variable; in this problem, the single value biharmonic associated with the biharmonic equation is to be optimized towards 0;

The third argument is the computational domain the constraint acts on; here we pass geom["geo"], instantiated in Section 3.3 Computational Domain Construction;

The fourth argument is the sampling configuration on the computational domain; here batch_size is read from the config file (cfg.TRAIN.batch_size.interior for this interior constraint);

The fifth argument is the loss function; here the commonly used MSE loss is chosen, with reduction set to "mean", so the loss terms of all participating data points are averaged;

The sixth argument is the geometric point filter. Since this constraint is applied only to the interior of the plate, the points sampled from geo need to be filtered; a lambda filter function is passed in, which receives the tensors x, y of the sampled points and returns a boolean tensor indicating whether each point satisfies the condition (True if it does, False otherwise);
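
A tiny standalone numpy example of how such a filter behaves (the coordinate values here are made up; x and y arrive as column arrays of sampled coordinates):

# standalone illustration of the criteria callable: it returns a boolean mask per point
import numpy as np

x = np.array([[0.0], [1.0], [2.0]])
y = np.array([[0.0], [1.5], [3.0]])
mask = (0 < x) & (x < 2.0) & (0 < y) & (y < 3.0)   # LENGTH=2, WIDTH=3
print(mask.ravel())   # [False  True False] -> only the middle point is strictly interior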

The seventh argument is the weight of each point in the loss computation; here it is set in the config as:

weight:
  bc: 125
  interior: 8000

The eighth argument is the name of the constraint; every constraint needs a name so it can be indexed later. Here it is named "INTERIOR".

3.4.2 Boundary Constraints

As described in Section 2. Problem Definition, the deflection \(w\) at \(x=0\) is 0, which gives the boundary condition below; the other 7 boundary conditions are defined in the same way (a moment-condition example is shown right after this code block):

# set constraint
bc_left = ppsci.constraint.BoundaryConstraint(
    {"w": lambda d: d["u"]},
    {"w": 0},
    geom["geo"],
    {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc},
    ppsci.loss.MSELoss(),
    criteria=lambda x, y: x == 0,
    weight_dict={"w": cfg.TRAIN.weight.bc},
    name="BC_LEFT",
)
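
For example, the \(M_y=0\) condition on the same edge \(x=0\) is expressed with second derivatives of the network output via hessian (taken from the full code in Section 4; the positive constant factor \(D\) is omitted since the target is 0):

bc_left_My = ppsci.constraint.BoundaryConstraint(
    {
        "M_y": lambda d: -(
            cfg.NU * hessian(d["u"], d["x"]) + hessian(d["u"], d["y"])
        )
    },
    {"M_y": 0},
    geom["geo"],
    {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc},
    ppsci.loss.MSELoss(),
    criteria=lambda x, y: x == 0,
    weight_dict={"M_y": cfg.TRAIN.weight.bc},
    name="BC_LEFT_My",
)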

After the equation constraint and the boundary constraints have been built, they are wrapped into a dictionary keyed by the names given above, for convenient access later.

# wrap constraints together
constraint = {
    bc_left.name: bc_left,
    bc_right.name: bc_right,
    bc_up.name: bc_up,
    bc_bottom.name: bc_bottom,
    bc_left_My.name: bc_left_My,
    bc_right_My.name: bc_right_My,
    bc_up_Mx.name: bc_up_Mx,
    bc_bottom_Mx.name: bc_bottom_Mx,
    interior.name: interior,
}

3.5 Optimizer Construction

The training process calls optimizers to update the model parameters. Here a short phase of Adam training is run first, followed by fine-tuning with the L-BFGS optimizer.

optimizer_adam = ppsci.optimizer.Adam(**cfg.TRAIN.optimizer.adam)(disp_net)
optimizer_lbfgs = ppsci.optimizer.LBFGS(**cfg.TRAIN.optimizer.lbfgs)(disp_net)

3.6 Hyperparameter Settings

Next, the number of training epochs, learning rates and other optimizer parameters are specified in the config file.

# training settings
TRAIN:
  epochs: 1000
  iters_per_epoch: 1
  optimizer:
    adam:
      learning_rate: 1.0e-3
    lbfgs:
      learning_rate: 1.0
      max_iter: 50000

3.7 Model Training

With everything set up, the instantiated objects are passed in order to ppsci.solver.Solver and training is started. Note that the two optimization stages require separately constructed Solver instances.

# initialize adam solver
solver_adam = ppsci.solver.Solver(
    disp_net,
    constraint,
    cfg.output_dir,
    optimizer_adam,
    None,
    cfg.TRAIN.epochs,
    cfg.TRAIN.iters_per_epoch,
    save_freq=cfg.TRAIN.save_freq,
    log_freq=cfg.log_freq,
    seed=cfg.seed,
    equation=equation,
    geom=geom,
    checkpoint_path=cfg.TRAIN.checkpoint_path,
    pretrained_model_path=cfg.TRAIN.pretrained_model_path,
)
# train model
solver_adam.train()
# plot loss
solver_adam.plot_loss_history(by_epoch=True)
# initialize lbfgs solver
solver_lbfgs = ppsci.solver.Solver(
    disp_net,
    constraint,
    cfg.output_dir,
    optimizer_lbfgs,
    None,
    1,
    1,
    save_freq=cfg.TRAIN.save_freq,
    log_freq=cfg.log_freq,
    seed=cfg.seed,
    equation=equation,
    geom=geom,
    checkpoint_path=cfg.TRAIN.checkpoint_path,
    pretrained_model_path=cfg.TRAIN.pretrained_model_path,
)
# fine-tune model with L-BFGS
solver_lbfgs.train()

3.8 Model Evaluation and Visualization

After training, the trained model can be evaluated and visualized in eval mode. Due to the particularities of this case, no evaluator or visualizer objects are built; custom code is used instead.
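
The reference (label) evaluated below is the analytical deflection for this sinusoidal load, which is exactly what the code computes, and the reported l2_error is the relative \(L_2\) error of the predicted deflection with respect to that label:

\[w_{ref}(x,y)=\dfrac{q_0}{\pi^4 D\left(\dfrac{1}{a^2}+\dfrac{1}{b^2}\right)^2}\sin(\dfrac{\pi x}{a})\sin(\dfrac{\pi y}{b}), \qquad \text{l2\_error}=\dfrac{\lVert w_{pred}-w_{ref}\rVert_2}{\lVert w_{ref}\rVert_2}\]

where \(a\) and \(b\) are the plate length and width.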

def evaluate(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")

    # set models
    disp_net = ppsci.arch.MLP(**cfg.MODEL)

    # load pretrained model
    solver = ppsci.solver.Solver(
        model=disp_net, pretrained_model_path=cfg.EVAL.pretrained_model_path
    )

    # generate samples
    num_x = 201
    num_y = 301
    num_cords = num_x * num_y
    logger.info(f"num_cords: {num_cords}")
    x_grad, y_grad = np.meshgrid(
        np.linspace(start=0, stop=cfg.LENGTH, num=num_x, endpoint=True),
        np.linspace(start=0, stop=cfg.WIDTH, num=num_y, endpoint=True),
    )
    x_faltten = paddle.to_tensor(
        x_grad.flatten()[:, None], dtype=paddle.get_default_dtype(), stop_gradient=False
    )
    y_faltten = paddle.to_tensor(
        y_grad.flatten()[:, None], dtype=paddle.get_default_dtype(), stop_gradient=False
    )
    outs_pred = solver.predict(
        {"x": x_faltten, "y": y_faltten}, batch_size=num_cords, no_grad=False
    )

    # generate label
    D = cfg.E * (cfg.HEIGHT**3) / (12.0 * (1.0 - cfg.NU**2))
    Q = cfg.Q_0 / (
        (np.pi**4) * D * ((1 / (cfg.LENGTH**2) + 1 / (cfg.WIDTH**2)) ** 2)
    )
    outs_label = (
        paddle.to_tensor(Q, dtype=paddle.get_default_dtype())
        * paddle.sin(
            paddle.to_tensor(np.pi / cfg.LENGTH, dtype=paddle.get_default_dtype())
            * x_faltten,
        )
        * paddle.sin(
            paddle.to_tensor(np.pi / cfg.WIDTH, dtype=paddle.get_default_dtype())
            * y_faltten,
        )
    )

    # eval
    l2_error = ppsci.metric.L2Rel()(outs_pred, {"u": outs_label})["u"]
    logger.info(f"l2_error: {float(l2_error)}")

    # compute other pred outs
    def compute_outs(w, x, y):
        D = cfg.E * (cfg.HEIGHT**3) / (12.0 * (1.0 - cfg.NU**2))
        w_x2 = hessian(w, x)
        w_y2 = hessian(w, y)
        w_x_y = jacobian(jacobian(w, x), y)
        M_x = -(w_x2 + cfg.NU * w_y2) * D
        M_y = -(cfg.NU * w_x2 + w_y2) * D
        M_xy = (1 - cfg.NU) * w_x_y * D
        Q_x = -jacobian((w_x2 + w_y2), x) * D
        Q_y = -jacobian((w_x2 + w_y2), y) * D
        return {"Mx": M_x, "Mxy": M_xy, "My": M_y, "Qx": Q_x, "Qy": Q_y, "w": w}

    outs = compute_outs(outs_pred["u"], x_faltten, y_faltten)

    # plotting
    griddata_points = paddle.concat([x_faltten, y_faltten], axis=-1).numpy()
    griddata_xi = (x_grad, y_grad)
    boundary = [0, cfg.LENGTH, 0, cfg.WIDTH]
    plotting(
        "eval_Mx_Mxy_My_Qx_Qy_w",
        cfg.output_dir,
        {k: v.numpy() for k, v in outs.items()},
        griddata_points,
        griddata_xi,
        boundary,
    )

4. Full Code

biharmonic2d.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from os import path as osp

import hydra
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import paddle
import sympy as sp
from mpl_toolkits.axes_grid1 import make_axes_locatable
from omegaconf import DictConfig
from scipy.interpolate import griddata

import ppsci
from ppsci.autodiff import hessian
from ppsci.autodiff import jacobian
from ppsci.utils import logger


def plotting(figname, output_dir, data, griddata_points, griddata_xi, boundary):
    plt.clf()
    fig = plt.figure(figname, figsize=(15, 12))
    gs = gridspec.GridSpec(2, 3)
    gs.update(top=0.8, bottom=0.2, left=0.1, right=0.9, wspace=0.5)

    for i, key in enumerate(data):
        plot_data = griddata(
            griddata_points,
            data[key].flatten(),
            griddata_xi,
            method="cubic",
        )

        ax = plt.subplot(gs[i // 3, i % 3])
        h = ax.imshow(
            plot_data,
            interpolation="nearest",
            cmap="jet",
            extent=boundary,
            origin="lower",
            aspect="auto",
        )
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        fig.colorbar(h, cax=cax)
        ax.axis("equal")
        ax.set_xlim(0, boundary[1])
        ax.set_ylim(0, boundary[3])
        ax.set_xlabel("$x$")
        ax.set_ylabel("$y$")
        plt.tick_params(labelsize=12)
        ax.set_title(key, fontsize=10)

    plt.savefig(osp.join(output_dir, figname))
    plt.close()


def train(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")

    # set models
    disp_net = ppsci.arch.MLP(**cfg.MODEL)

    # set optimizer
    optimizer_adam = ppsci.optimizer.Adam(**cfg.TRAIN.optimizer.adam)(disp_net)
    optimizer_lbfgs = ppsci.optimizer.LBFGS(**cfg.TRAIN.optimizer.lbfgs)(disp_net)

    # set equation
    x, y = sp.symbols("x y")
    Q = cfg.Q_0 * sp.sin(np.pi * x / cfg.LENGTH) * sp.sin(np.pi * y / cfg.WIDTH)
    equation = {
        "Biharmonic": ppsci.equation.Biharmonic(
            dim=2, q=Q, D=cfg.E * (cfg.HEIGHT**3) / (12.0 * (1.0 - cfg.NU**2))
        ),
    }

    # set geometry
    plate = ppsci.geometry.Rectangle((0, 0), (cfg.LENGTH, cfg.WIDTH))
    geom = {"geo": plate}

    # set dataloader config
    train_dataloader_cfg = {
        "dataset": "NamedArrayDataset",
        "iters_per_epoch": cfg.TRAIN.iters_per_epoch,
        "sampler": {
            "name": "BatchSampler",
            "drop_last": True,
            "shuffle": True,
        },
    }

    # set constraint
    bc_left = ppsci.constraint.BoundaryConstraint(
        {"w": lambda d: d["u"]},
        {"w": 0},
        geom["geo"],
        {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc},
        ppsci.loss.MSELoss(),
        criteria=lambda x, y: x == 0,
        weight_dict={"w": cfg.TRAIN.weight.bc},
        name="BC_LEFT",
    )
    bc_right = ppsci.constraint.BoundaryConstraint(
        {"w": lambda d: d["u"]},
        {"w": 0},
        geom["geo"],
        {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc},
        ppsci.loss.MSELoss(),
        criteria=lambda x, y: x == cfg.LENGTH,
        weight_dict={"w": cfg.TRAIN.weight.bc},
        name="BC_RIGHT",
    )
    bc_up = ppsci.constraint.BoundaryConstraint(
        {"w": lambda d: d["u"]},
        {"w": 0},
        geom["geo"],
        {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc},
        ppsci.loss.MSELoss(),
        criteria=lambda x, y: y == 0,
        weight_dict={"w": cfg.TRAIN.weight.bc},
        name="BC_UP",
    )
    bc_bottom = ppsci.constraint.BoundaryConstraint(
        {"w": lambda d: d["u"]},
        {"w": 0},
        geom["geo"],
        {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc},
        ppsci.loss.MSELoss(),
        criteria=lambda x, y: y == cfg.WIDTH,
        weight_dict={"w": cfg.TRAIN.weight.bc},
        name="BC_BOTTOM",
    )
    bc_left_My = ppsci.constraint.BoundaryConstraint(
        {
            "M_y": lambda d: -(
                cfg.NU * hessian(d["u"], d["x"]) + hessian(d["u"], d["y"])
            )
        },
        {"M_y": 0},
        geom["geo"],
        {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc},
        ppsci.loss.MSELoss(),
        criteria=lambda x, y: x == 0,
        weight_dict={"M_y": cfg.TRAIN.weight.bc},
        name="BC_LEFT_My",
    )
    bc_right_My = ppsci.constraint.BoundaryConstraint(
        {
            "M_y": lambda d: -(
                cfg.NU * hessian(d["u"], d["x"]) + hessian(d["u"], d["y"])
            )
        },
        {"M_y": 0},
        geom["geo"],
        {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc},
        ppsci.loss.MSELoss(),
        criteria=lambda x, y: x == cfg.LENGTH,
        weight_dict={"M_y": cfg.TRAIN.weight.bc},
        name="BC_RIGHT_My",
    )
    bc_up_Mx = ppsci.constraint.BoundaryConstraint(
        {
            "M_x": lambda d: -(
                hessian(d["u"], d["x"]) + cfg.NU * hessian(d["u"], d["y"])
            )
        },
        {"M_x": 0},
        geom["geo"],
        {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc},
        ppsci.loss.MSELoss(),
        criteria=lambda x, y: y == 0,
        weight_dict={"M_x": cfg.TRAIN.weight.bc},
        name="BC_UP_Mx",
    )
    bc_bottom_Mx = ppsci.constraint.BoundaryConstraint(
        {
            "M_x": lambda d: -(
                hessian(d["u"], d["x"]) + cfg.NU * hessian(d["u"], d["y"])
            )
        },
        {"M_x": 0},
        geom["geo"],
        {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc},
        ppsci.loss.MSELoss(),
        criteria=lambda x, y: y == cfg.WIDTH,
        weight_dict={"M_x": cfg.TRAIN.weight.bc},
        name="BC_BOTTOM_Mx",
    )
    interior = ppsci.constraint.InteriorConstraint(
        equation["Biharmonic"].equations,
        {"biharmonic": 0},
        geom["geo"],
        {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.interior},
        ppsci.loss.MSELoss(),
        criteria=lambda x, y: ((0 < x) & (x < cfg.LENGTH) & (0 < y) & (y < cfg.WIDTH)),
        weight_dict={"biharmonic": cfg.TRAIN.weight.interior},
        name="INTERIOR",
    )
    # wrap constraints together
    constraint = {
        bc_left.name: bc_left,
        bc_right.name: bc_right,
        bc_up.name: bc_up,
        bc_bottom.name: bc_bottom,
        bc_left_My.name: bc_left_My,
        bc_right_My.name: bc_right_My,
        bc_up_Mx.name: bc_up_Mx,
        bc_bottom_Mx.name: bc_bottom_Mx,
        interior.name: interior,
    }

    # initialize adam solver
    solver_adam = ppsci.solver.Solver(
        disp_net,
        constraint,
        cfg.output_dir,
        optimizer_adam,
        None,
        cfg.TRAIN.epochs,
        cfg.TRAIN.iters_per_epoch,
        save_freq=cfg.TRAIN.save_freq,
        log_freq=cfg.log_freq,
        seed=cfg.seed,
        equation=equation,
        geom=geom,
        checkpoint_path=cfg.TRAIN.checkpoint_path,
        pretrained_model_path=cfg.TRAIN.pretrained_model_path,
    )
    # train model
    solver_adam.train()
    # plot loss
    solver_adam.plot_loss_history(by_epoch=True)
    # initialize lbfgs solver
    solver_lbfgs = ppsci.solver.Solver(
        disp_net,
        constraint,
        cfg.output_dir,
        optimizer_lbfgs,
        None,
        1,
        1,
        save_freq=cfg.TRAIN.save_freq,
        log_freq=cfg.log_freq,
        seed=cfg.seed,
        equation=equation,
        geom=geom,
        checkpoint_path=cfg.TRAIN.checkpoint_path,
        pretrained_model_path=cfg.TRAIN.pretrained_model_path,
    )
    # fine-tune model with L-BFGS
    solver_lbfgs.train()


def evaluate(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")

    # set models
    disp_net = ppsci.arch.MLP(**cfg.MODEL)

    # load pretrained model
    solver = ppsci.solver.Solver(
        model=disp_net, pretrained_model_path=cfg.EVAL.pretrained_model_path
    )

    # generate samples
    num_x = 201
    num_y = 301
    num_cords = num_x * num_y
    logger.info(f"num_cords: {num_cords}")
    x_grad, y_grad = np.meshgrid(
        np.linspace(start=0, stop=cfg.LENGTH, num=num_x, endpoint=True),
        np.linspace(start=0, stop=cfg.WIDTH, num=num_y, endpoint=True),
    )
    x_faltten = paddle.to_tensor(
        x_grad.flatten()[:, None], dtype=paddle.get_default_dtype(), stop_gradient=False
    )
    y_faltten = paddle.to_tensor(
        y_grad.flatten()[:, None], dtype=paddle.get_default_dtype(), stop_gradient=False
    )
    outs_pred = solver.predict(
        {"x": x_faltten, "y": y_faltten}, batch_size=num_cords, no_grad=False
    )

    # generate label
    D = cfg.E * (cfg.HEIGHT**3) / (12.0 * (1.0 - cfg.NU**2))
    Q = cfg.Q_0 / (
        (np.pi**4) * D * ((1 / (cfg.LENGTH**2) + 1 / (cfg.WIDTH**2)) ** 2)
    )
    outs_label = (
        paddle.to_tensor(Q, dtype=paddle.get_default_dtype())
        * paddle.sin(
            paddle.to_tensor(np.pi / cfg.LENGTH, dtype=paddle.get_default_dtype())
            * x_faltten,
        )
        * paddle.sin(
            paddle.to_tensor(np.pi / cfg.WIDTH, dtype=paddle.get_default_dtype())
            * y_faltten,
        )
    )

    # eval
    l2_error = ppsci.metric.L2Rel()(outs_pred, {"u": outs_label})["u"]
    logger.info(f"l2_error: {float(l2_error)}")

    # compute other pred outs
    def compute_outs(w, x, y):
        D = cfg.E * (cfg.HEIGHT**3) / (12.0 * (1.0 - cfg.NU**2))
        w_x2 = hessian(w, x)
        w_y2 = hessian(w, y)
        w_x_y = jacobian(jacobian(w, x), y)
        M_x = -(w_x2 + cfg.NU * w_y2) * D
        M_y = -(cfg.NU * w_x2 + w_y2) * D
        M_xy = (1 - cfg.NU) * w_x_y * D
        Q_x = -jacobian((w_x2 + w_y2), x) * D
        Q_y = -jacobian((w_x2 + w_y2), y) * D
        return {"Mx": M_x, "Mxy": M_xy, "My": M_y, "Qx": Q_x, "Qy": Q_y, "w": w}

    outs = compute_outs(outs_pred["u"], x_faltten, y_faltten)

    # plotting
    griddata_points = paddle.concat([x_faltten, y_faltten], axis=-1).numpy()
    griddata_xi = (x_grad, y_grad)
    boundary = [0, cfg.LENGTH, 0, cfg.WIDTH]
    plotting(
        "eval_Mx_Mxy_My_Qx_Qy_w",
        cfg.output_dir,
        {k: v.numpy() for k, v in outs.items()},
        griddata_points,
        griddata_xi,
        boundary,
    )


def export(cfg: DictConfig):
    from paddle import nn
    from paddle.static import InputSpec

    # set models
    disp_net = ppsci.arch.MLP(**cfg.MODEL)

    # load pretrained model
    solver = ppsci.solver.Solver(
        model=disp_net, pretrained_model_path=cfg.INFER.pretrained_model_path
    )

    class Wrapped_Model(nn.Layer):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x):
            model_out = self.model(x)
            outs = self.compute_outs(model_out["u"], x["x"], x["y"])
            return outs

        def compute_outs(self, w, x, y):
            D = cfg.E * (cfg.HEIGHT**3) / (12.0 * (1.0 - cfg.NU**2))
            w_x2 = hessian(w, x)
            w_y2 = hessian(w, y)
            w_x_y = jacobian(jacobian(w, x), y)
            M_x = -(w_x2 + cfg.NU * w_y2) * D
            M_y = -(cfg.NU * w_x2 + w_y2) * D
            M_xy = (1 - cfg.NU) * w_x_y * D
            Q_x = -jacobian((w_x2 + w_y2), x) * D
            Q_y = -jacobian((w_x2 + w_y2), y) * D
            return {"Mx": M_x, "Mxy": M_xy, "My": M_y, "Qx": Q_x, "Qy": Q_y, "w": w}

    solver.model = Wrapped_Model(solver.model)

    # export models
    input_spec = [
        {key: InputSpec([None, 1], "float32", name=key) for key in disp_net.input_keys},
    ]
    solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
    from deploy.python_infer import pinn_predictor

    # set model predictor
    predictor = pinn_predictor.PINNPredictor(cfg)

    # generate samples
    num_x = 201
    num_y = 301
    x_grad, y_grad = np.meshgrid(
        np.linspace(
            start=0, stop=cfg.LENGTH, num=num_x, endpoint=True, dtype=np.float32
        ),
        np.linspace(
            start=0, stop=cfg.WIDTH, num=num_y, endpoint=True, dtype=np.float32
        ),
    )
    x_faltten = x_grad.reshape(-1, 1)
    y_faltten = y_grad.reshape(-1, 1)

    output_dict = predictor.predict(
        {"x": x_faltten, "y": y_faltten}, cfg.INFER.batch_size
    )

    # mapping data to cfg.INFER.output_keys
    output_dict = {
        store_key: output_dict[infer_key]
        for store_key, infer_key in zip(cfg.INFER.output_keys, output_dict.keys())
    }

    # plotting
    griddata_points = np.concatenate([x_faltten, y_faltten], axis=-1)
    griddata_xi = (x_grad, y_grad)
    boundary = [0, cfg.LENGTH, 0, cfg.WIDTH]
    plotting(
        "eval_Mx_Mxy_My_Qx_Qy_w",
        cfg.output_dir,
        output_dict,
        griddata_points,
        griddata_xi,
        boundary,
    )


@hydra.main(version_base=None, config_path="./conf", config_name="biharmonic2d.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


if __name__ == "__main__":
    main()

5. Results

The model predictions and the analytical solutions of the deflection \(w\), the moments \(M_x, M_y, M_{xy}\) and the shear forces \(Q_x, Q_y\) are shown below.

biharmonic2d_pred.jpg

Model predictions of the moments Mx, My, Mxy, the shear forces Qx, Qy and the deflection w

biharmonic2d_label_M.jpg

Analytical solutions of the moments Mx, My, Mxy

biharmonic2d_label_Q_w.jpg

Analytical solutions of the shear forces Qx, Qy and the deflection w

The model predictions agree well with the analytical solutions.

6. References

A Physics Informed Neural Network Approach to Solution and Identification of Biharmonic Equations of Elasticity