
VelocityGAN

Note

  1. Before running, it is recommended to first read the dataset introduction and the data reading method (Sections 3.1 and 3.2).
  2. Download the OpenFWI dataset into the corresponding subdirectory of the FWIOpenData directory (e.g. Flatvel_A).
  3. Set the anno parameter in the YAML configuration file so that it matches the dataset.
# Training
python velocityGAN.py
# Evaluation
python velocityGAN.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/velocitygan/velocitygan_pretrained.pdparams
Pretrained model                   MAE      RMSE     SSIM
velocitygan_pretrained.pdparams    0.0669   0.0947   0.8511

1. Background Introduction

Subsurface velocity images play an important role in the earth sciences. They describe how fast seismic waves propagate through different underground regions and provide key information about the Earth's internal structure. Seismic waveform inversion is widely used to reconstruct subsurface velocity maps. Traditional physics-driven methods solve a numerical optimization problem that requires many iterations, each of which involves solving the wave equation. This is computationally expensive and often converges only to a local optimum, which limits the accuracy of the resulting image. Data-driven deep learning methods can alleviate these problems and produce more accurate velocity images in less time.

VelocityGAN is one such method: an end-to-end framework that generates high-quality velocity images directly from raw seismic waveform data. The paper reports that VelocityGAN outperforms traditional physics-driven waveform inversion methods and achieves state-of-the-art results on data-driven benchmarks.

2. Model Principle

As a data-driven deep learning method, VelocityGAN learns the mapping from waveform data to velocity images directly, without solving the wave equation. This section only briefly introduces the model principle; for details, please read VelocityGAN: Data-Driven Full-Waveform Inversion Using Conditional Adversarial Networks.

2.1 Model Structure

VelocityGAN is a conditional adversarial network containing an image-to-image generator and a CNN discriminator. The figure below shows the overall structure of the model.

[Figure: overall structure of VelocityGAN]

  • The Generator is a convolutional neural network with an encoder-decoder structure. The encoder extracts features from the seismic waveform data and progressively compresses them into a latent vector; the decoder then infers the corresponding velocity map from this latent vector.

  • The Discriminator is a network composed of 9 convolutional blocks. It takes a velocity image as input and outputs a score indicating how real the image appears.

2.2 Loss Function

The discriminator loss combines the Wasserstein loss with a gradient penalty:

\[ L_d = \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}} D(\tilde{x}) - \underset{x \sim \mathbb{P}_r}{\mathbb{E}}D(x) + \lambda \underset{\hat{x} \sim \mathbb{P}_{\hat{x}}}{\mathbb{E}} \left[ \left( \| \nabla_{\hat{x}} D(\hat{x}) \|_2 - 1 \right)^2 \right] \]

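Where \(\mathbb{P}_g\) is the generator distribution, \(\mathbb{P}_r\) is the real data distribution, and \(\mathbb{P}_{\hat{x}}\) is the distribution of samples interpolated between them: \(\hat{x} = \alpha x + (1 - \alpha)\tilde{x}\) with \(x \sim \mathbb{P}_r\), \(\tilde{x} \sim \mathbb{P}_g\), and \(\alpha \sim U(0, 1)\). \(\lambda\) is the weight of the gradient penalty (lambda_gp in the code).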
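
To make the three terms of \(L_d\) concrete, the NumPy sketch below evaluates them for a toy critic \(D(x) = \tfrac{1}{2}\|x\|^2\). This critic is a hypothetical stand-in for the real discriminator, chosen only because its gradient \(\nabla D(x) = x\) is known in closed form, so no autograd is needed. The per-sample interpolation factor mirrors the one drawn in compute_gradient_penalty; all names here are illustrative.

```python
import numpy as np


def toy_critic(x):
    # Hypothetical critic D(x) = 0.5 * ||x||^2, evaluated per sample (rows of x).
    return 0.5 * np.sum(x ** 2, axis=1)


def toy_critic_grad(x):
    # Closed-form gradient of the toy critic with respect to its input: dD/dx = x.
    return x


def discriminator_loss(real, fake, lam, rng):
    # Wasserstein terms: E[D(fake)] - E[D(real)].
    wasserstein = toy_critic(fake).mean() - toy_critic(real).mean()
    # One interpolation factor per sample, broadcast over the feature axis.
    alpha = rng.uniform(size=(real.shape[0], 1))
    x_hat = alpha * real + (1.0 - alpha) * fake
    # Penalize deviation of the gradient norm at x_hat from 1.
    grad_norm = np.linalg.norm(toy_critic_grad(x_hat), axis=1)
    gradient_penalty = np.mean((grad_norm - 1.0) ** 2)
    return wasserstein + lam * gradient_penalty, gradient_penalty


rng = np.random.default_rng(0)
real = rng.normal(size=(4, 3))
fake = rng.normal(size=(4, 3))
loss_d, gp = discriminator_loss(real, fake, lam=10.0, rng=rng)
print(f"L_d = {loss_d:.4f}, gradient penalty = {gp:.4f}")
```

The value lam=10.0 is only an illustrative choice; the weight actually used in this case is whatever lambda_gp is set to in the configuration.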

The generator's loss function is a combination of adversarial loss [\(- \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}D(\tilde{x})\)] and content loss (MAE, MSE). Its expression is:

\[ L_g = - \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}D(\tilde{x}) + \frac{\lambda_1}{w\cdot h} \sum_{i=1}^{w} \sum_{j=1}^{h} \left| \tilde{v}(i,j) - v(i,j) \right| + \frac{\lambda_2}{w\cdot h}\sum_{i=1}^{w} \sum_{j=1}^{h} \left( \tilde{v}(i,j) - v(i,j) \right)^2 \]

Where \(w\) and \(h\) are the width and height of the velocity map, and \(v(i,j)\) and \(\tilde{v}(i,j)\) are the true and predicted values of the velocity map at pixel \((i,j)\). \(\lambda_1\) and \(\lambda_2\) are hyperparameters that adjust the relative importance of the two content losses. In the code they correspond to lambda_g1v and lambda_g2v, and the adversarial term is additionally weighted by lambda_adv.
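
A small worked example helps check the formula term by term. The sketch below evaluates \(L_g\) with NumPy on a single 2 x 2 map; the critic score is a made-up number rather than the output of a real discriminator, and the lambda_adv argument (default 1.0, which reproduces the equation exactly) is an illustrative addition.

```python
import numpy as np


def generator_loss(critic_scores, pred, label, lambda_1, lambda_2, lambda_adv=1.0):
    """Evaluate L_g for a batch of predicted and true velocity maps."""
    # Adversarial term: negated mean critic score on the generated maps.
    adv = -np.mean(critic_scores)
    # Content terms: mean absolute and mean squared pixel error (the 1/(w*h) sums).
    mae = np.mean(np.abs(pred - label))
    mse = np.mean((pred - label) ** 2)
    return lambda_adv * adv + lambda_1 * mae + lambda_2 * mse


critic_scores = np.array([0.3])                  # made-up D(x~) for one generated map
pred = np.array([[[0.0, 0.5], [1.0, -1.0]]])     # batch of one 2 x 2 "velocity map"
label = np.zeros((1, 2, 2))

# Setting a lambda to 0 drops that content term from the total.
print(f"{generator_loss(critic_scores, pred, label, lambda_1=1.0, lambda_2=0.0):.4f}")
print(f"{generator_loss(critic_scores, pred, label, lambda_1=0.0, lambda_2=1.0):.4f}")
```

Here the MAE is 2.5 / 4 = 0.625 and the MSE is 2.25 / 4 = 0.5625, so the two calls give -0.3 + 0.625 and -0.3 + 0.5625 respectively.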

3. Model Construction

Next, we explain how to implement VelocityGAN with the PaddleScience framework. The following content covers only the key steps; for other details, please refer to the API Documentation.

3.1 Dataset Introduction

The dataset used is OpenFWI, open-sourced by the SMILE Team.

OpenFWI contains 12 datasets divided into four families: Vel Family, Fault Family, Style Family, and Kimberlina Family. This case mainly uses the first two families, whose configurations are as follows:

[Figure: configuration of the Vel Family and Fault Family datasets]

Each dataset contains waveform data and the corresponding velocity images. The figure below shows an example velocity image from each dataset.

[Figure: example velocity images from each dataset]

As the figure shows, the Vel Family covers both flat and curved geological interfaces, while the Fault Family additionally introduces geological faults.

Each sample consists of one velocity image and five waveform records, as shown in the figure below.

[Figure: a velocity image and its five seismic waveform records]

In the figure, the five red stars in a row represent five seismic sources on the surface, and 70 receivers are also placed on the surface. Seismic waves propagate downward and are reflected back, and each receiver records a sample every 0.001 seconds for 1000 time steps. The waveform data of each sample therefore has shape (5, 1000, 70), i.e. (sources, time steps, receivers).
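
The shape arithmetic can be checked with a few lines of NumPy. The zero-filled array below is only a placeholder with the stated layout, not real OpenFWI data.

```python
import numpy as np

n_sources, n_steps, n_receivers = 5, 1000, 70
dt = 0.001  # sampling interval in seconds

# One sample's waveform data, laid out as (sources, time steps, receivers).
waveform = np.zeros((n_sources, n_steps, n_receivers), dtype=np.float32)

# A single trace: what receiver 0 records for source 0 over time.
trace = waveform[0, :, 0]

record_length = n_steps * dt  # total recording time in seconds
print(waveform.shape, trace.shape, f"{record_length:.1f} s", waveform.nbytes, "bytes")
```

In float32 one sample holds 5 x 1000 x 70 = 350,000 values, i.e. 1.4 MB.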

Note: None of the data are field recordings; all of them are simulated. For details, please read OpenFWI: Large-Scale Multi-Structural Benchmark Datasets for Seismic Full Waveform Inversion.

3.2 Dataset API Construction

Since a dataset consists of 120 data files, passing in every file path individually is cumbersome. Instead, all paths are listed in a single text file, and the data are read by parsing these paths one by one. Because the built-in dataset APIs of PaddleScience do not support this reading scheme, we implemented a custom ppsci.data.dataset.FWIDataset.
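
As a rough illustration of this scheme, the sketch below writes a small annotation file and parses it back. The line format (a waveform file path and a velocity file path separated by a tab, both stored as .npy arrays) is an assumption made for illustration, not a confirmed description of FWIDataset; the file names and array shapes are placeholders.

```python
import os
import tempfile

import numpy as np


def read_anno(anno_path):
    # Assumed format: one "data_path<TAB>label_path" pair per line.
    pairs = []
    with open(anno_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            data_path, label_path = line.split("\t")
            pairs.append((data_path, label_path))
    return pairs


def load_all(anno_path):
    # Load every (waveform, velocity) file pair listed in the annotation file.
    return [(np.load(d), np.load(l)) for d, l in read_anno(anno_path)]


# Build a tiny fake dataset: 2 files, each holding 3 samples.
tmp = tempfile.mkdtemp()
lines = []
for k in range(2):
    d = os.path.join(tmp, f"data{k}.npy")
    l = os.path.join(tmp, f"model{k}.npy")
    np.save(d, np.zeros((3, 5, 1000, 70), dtype=np.float32))  # waveforms
    np.save(l, np.zeros((3, 1, 70, 70), dtype=np.float32))    # placeholder velocity-map size
    lines.append(f"{d}\t{l}")
anno = os.path.join(tmp, "train.txt")
with open(anno, "w") as f:
    f.write("\n".join(lines) + "\n")

files = load_all(anno)
print(len(files), files[0][0].shape, files[0][1].shape)
```

The point of the design is that the configuration only needs one path (anno), however many data files the dataset has.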

The dataloader configuration is as follows:

# set dataloader config
dataloader_cfg = {
    "dataset": {
        "name": "FWIDataset",
        "input_keys": ("data",),
        "label_keys": ("real_image",),
        "anno": cfg.TRAIN.dataset.anno,
        "preload": cfg.TRAIN.dataset.preload,
        "sample_ratio": cfg.TRAIN.dataset.sample_ratio,
        "file_size": ctx["file_size"],
        "transform_data": transform_data,
        "transform_label": transform_label,
    },
    "sampler": {
        "name": "BatchSampler",
        "shuffle": cfg.TRAIN.sampler.shuffle,
        "drop_last": cfg.TRAIN.sampler.drop_last,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "use_shared_memory": cfg.TRAIN.use_shared_memory,
    "num_workers": cfg.TRAIN.num_workers,
}
Here, dataset uses the custom FWIDataset, and anno is the path to the text file that lists the paths of all data files.
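
The file_size and sample_ratio entries suggest how a flat sample index and the time axis might be handled. The sketch below assumes, without confirmation from this page, that a global index is split into (file index, offset) using file_size, and that sample_ratio keeps every sample_ratio-th time step; locate and downsample_time are hypothetical helpers, and file_size = 500 is only an illustrative value (the real one comes from ctx["file_size"]).

```python
import numpy as np


def locate(idx, file_size):
    # Hypothetical: map a global sample index to (file index, offset inside that file).
    return idx // file_size, idx % file_size


def downsample_time(batch, sample_ratio):
    # Hypothetical: keep every `sample_ratio`-th time step of a
    # (N, sources, time steps, receivers) array.
    return batch[:, :, ::sample_ratio, :]


file_size = 500  # illustrative value only
print(locate(0, file_size), locate(499, file_size), locate(1234, file_size))

batch = np.zeros((2, 5, 1000, 70), dtype=np.float32)
print(downsample_time(batch, 1).shape, downsample_time(batch, 2).shape)
```

With this layout, sample 1234 would live in the third file (index 2) at offset 234, and sample_ratio = 2 would halve the 1000 time steps to 500.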

3.3 Model Construction

VelocityGAN is not built into PaddleScience and had to be implemented separately, so we added custom ppsci.arch.VelocityGenerator and ppsci.arch.VelocityDiscriminator.

The model construction code is as follows:

# set model
model_gen = ppsci.arch.VelocityGenerator(**cfg.MODEL.gen_net)
model_dis = ppsci.arch.VelocityDiscriminator(**cfg.MODEL.dis_net)

The parameter configuration is as follows:

# model settings
MODEL:
  gen_net:
    input_keys: ["data"]
    output_keys: ["fake_image"]
    dim1: 32
    dim2: 64
    dim3: 128
    dim4: 256
    dim5: 512
    sample_spatial: 1.0
  dis_net:
    input_keys: ["image"]
    output_keys: ["score"]
    dim1: 32
    dim2: 64
    dim3: 128
    dim4: 256

3.4 Custom Loss

The loss function of VelocityGAN is somewhat complex and has to be implemented by ourselves. PaddleScience provides an API for custom loss functions, ppsci.loss.FunctionalLoss: first define the loss function, then pass the function itself as an argument to FunctionalLoss. Note that the custom loss function takes dictionaries as input and must return a dictionary.
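
The dictionary convention can be illustrated with plain NumPy instead of Paddle. ToyGenFuncs and the toy critic below are hypothetical; they mimic the pattern used in this case, where a small class stores the discriminator and the loss weights, and its bound method is the callable handed to FunctionalLoss.

```python
import numpy as np


class ToyGenFuncs:
    """Hypothetical loss holder: keeps the critic and the weights so the
    loss method only needs the output and label dictionaries."""

    def __init__(self, critic, weight):
        self.critic = critic  # callable: {"image": array} -> {"score": array}
        self.weight = weight  # e.g. {"lambda_g1v": 1.0, "lambda_adv": 0.5}

    def loss_func_gen(self, output_dict, label_dict, *args):
        pred = output_dict["fake_image"]  # read inputs by key
        label = label_dict["real_image"]
        loss = self.weight["lambda_g1v"] * np.mean(np.abs(pred - label))
        loss += self.weight["lambda_adv"] * -np.mean(self.critic({"image": pred})["score"])
        return {"loss_g": loss}  # the result is also a dict


# Toy critic: the mean pixel value of each image is its "score".
critic = lambda d: {"score": d["image"].mean(axis=(1, 2, 3))}
funcs = ToyGenFuncs(critic, {"lambda_g1v": 1.0, "lambda_adv": 0.5})

loss_fn = funcs.loss_func_gen  # a bound method is an ordinary callable
out = loss_fn({"fake_image": np.ones((2, 1, 4, 4))}, {"real_image": np.zeros((2, 1, 4, 4))})
print(out)
```

Because the bound method already carries the critic and the weights, the loss framework only has to supply the two dictionaries.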

3.4.1 Loss of Generator

The generator loss consists of L1 loss, L2 loss, and adversarial loss, each with its own weight. If a weight is set to 0, the corresponding term is effectively excluded from training.

def loss_func_gen(self, output_dict, label_dict, *args):
    """Calculate loss of generator.
        The loss includes L1 loss, L2 loss, and adversarial loss. Each of these losses has a corresponding weight,
        and if the weight of any loss is zero, it means that this loss component is not added during training.

    Args:
        output_dict: Output dict of model.
        label_dict: Label dict.

    Returns:
        Loss of generator.
    """
    l1loss = paddle.nn.L1Loss()
    l2loss = paddle.nn.MSELoss()

    pred = output_dict["fake_image"]
    label = label_dict["real_image"]

    loss_g1v = l1loss(pred, label)
    loss_g2v = l2loss(pred, label)

    loss = (
        self.weight["lambda_g1v"] * loss_g1v + self.weight["lambda_g2v"] * loss_g2v
    )

    loss_adv = -paddle.mean(self.model_dis({"image": pred})["score"])

    loss += self.weight["lambda_adv"] * loss_adv

    return {"loss_g": loss}

3.4.2 Loss of Discriminator

The discriminator loss consists of the Wasserstein loss and the gradient penalty. Only the gradient penalty has a weight parameter (lambda_gp).

def loss_func_dis(self, output_dict, label_dict, *args):
    """Calculate loss of discriminator.
        The discriminator's loss includes Wasserstein loss and gradient penalty, and only the gradient penalty has a weight parameter.

    Args:
        output_dict: Output dict of model.
        label_dict: Label dict.

    Returns:
        Loss of discriminator.
    """
    pred = output_dict["fake_image"]
    pred.stop_gradient = True
    label = label_dict["real_image"]

    gradient_penalty = self.compute_gradient_penalty(label, pred)

    loss_real = paddle.mean(self.model_dis({"image": label})["score"])
    loss_fake = paddle.mean(self.model_dis({"image": pred})["score"])

    loss = -loss_real + loss_fake + gradient_penalty * self.weight["lambda_gp"]

    return {"loss_d": loss}

def compute_gradient_penalty(self, real_samples, fake_samples):
    """Calculate the gradient penalty.
        Generate a random interpolation factor, create mixed samples, process through the discriminator,
        compute the gradient of the output, apply L2 norm and constrain it to 1, and finally obtain the gradient penalty.

    Args:
        real_samples: Ground truth data from dataset.
        fake_samples: Generated data from generator.

    Returns:
        Gradient penalty.
    """
    alpha = paddle.rand([real_samples.shape[0], 1, 1, 1], dtype=real_samples.dtype)
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    interpolates.stop_gradient = False  # Allow gradients to be calculated
    d_interpolates = self.model_dis({"image": interpolates})["score"]

    gradients = paddle.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    gradients = gradients.reshape([gradients.shape[0], -1])
    gradient_penalty = paddle.mean((paddle.norm(gradients, p=2, axis=1) - 1) ** 2)
    return gradient_penalty

Note:

pred.stop_gradient = True

This setting excludes pred from gradient computation. pred is only used here as an input to the discriminator, so its gradient is not needed. Moreover, pred is the output of the generator: if its gradient were not stopped, gradients with respect to the generator's parameters would accumulate during discriminator training and corrupt the first batch of the subsequent generator training.

3.5 Constraint Construction

This case uses ppsci.constraint.SupervisedConstraint to construct constraints.

The construction code is as follows:

# set constraint
constraint_gen = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(gen_funcs.loss_func_gen),
    output_expr={"fake_image": lambda out: out["fake_image"]},
    name="cst_gen",
)
constraint_gen_dict = {constraint_gen.name: constraint_gen}

constraint_dis = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(dis_funcs.loss_func_dis),
    output_expr={"fake_image": lambda out: out["fake_image"]},
    name="cst_dis",
)
constraint_dis_dict = {constraint_dis.name: constraint_dis}

Here, output_expr specifies how output_dict is built from the model output, and name identifies the constraint for later indexing.

Once constructed, each constraint is placed in a dictionary keyed by its name so that it can be passed to ppsci.solver.Solver.

3.6 Optimizer Construction

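VelocityGAN uses the AdamW optimizer, which can be constructed directly with ppsci.optimizer.AdamW. The resulting object is then called on each model to create separate optimizers for the generator and the discriminator: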

# set optimizer
optimizer = ppsci.optimizer.AdamW(
    learning_rate=cfg.TRAIN.learning_rate, weight_decay=cfg.TRAIN.weight_decay
)
optimizer_g = optimizer(model_gen)
optimizer_d = optimizer(model_dis)

3.7 Solver Construction

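Pass the constructed model, constraints, optimizer, and other parameters to ppsci.solver.Solver. Note that both solvers receive model_gen as model: the solver runs the generator forward to produce fake_image, while the discriminator is only called inside the loss functions. Which network's parameters are updated is determined by the optimizer passed in (optimizer_g or optimizer_d).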

# initialize solver
solver_gen = ppsci.solver.Solver(
    model=model_gen,
    output_dir=cfg.output_dir,
    constraint=constraint_gen_dict,
    optimizer=optimizer_g,
    epochs=cfg.TRAIN.epochs_gen,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
)

solver_dis = ppsci.solver.Solver(
    model=model_gen,
    output_dir=cfg.output_dir,
    constraint=constraint_dis_dict,
    optimizer=optimizer_d,
    epochs=cfg.TRAIN.epochs_dis,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
)

3.8 Model Training

# training
for i in range(cfg.TRAIN.epochs):
    logger.message(f"\nEpoch: {i + 1}\n")
    solver_dis.train()
    solver_gen.train()
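
In each outer epoch the discriminator is trained first, then the generator. Assuming each train() call performs epochs x iters_per_epoch optimizer steps (an assumption about the Solver's internal loop, not verified here), the resulting update order can be sketched in plain Python. training_schedule is a hypothetical helper, and the argument values are illustrative, not the ones in velocityGAN.yaml.

```python
def training_schedule(epochs, epochs_dis, iters_dis, epochs_gen, iters_gen):
    # Yield ("dis" | "gen", outer_epoch) for every optimizer step, in order.
    for outer in range(1, epochs + 1):
        for _ in range(epochs_dis * iters_dis):  # solver_dis.train()
            yield "dis", outer
        for _ in range(epochs_gen * iters_gen):  # solver_gen.train()
            yield "gen", outer


steps = list(training_schedule(epochs=2, epochs_dis=1, iters_dis=3, epochs_gen=1, iters_gen=1))
print([s for s, _ in steps])
```

Tuning epochs_dis, iters_per_epoch_dis, epochs_gen, and iters_per_epoch_gen therefore controls how many discriminator updates are made per generator update.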

3.9 Custom Metric

The evaluation metrics in this case are MAE (Mean Absolute Error), RMSE (Root Mean Squared Error), and SSIM (Structural Similarity). PaddleScience provides built-in APIs for MAE and RMSE, while SSIM has to be implemented separately.
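
For reference, MAE and RMSE follow the textbook definitions below, sketched in NumPy on a four-pixel example. How ppsci.metric.MAE and ppsci.metric.RMSE aggregate over batches is not shown here; the sketch only illustrates the per-array formulas.

```python
import numpy as np


def mae(pred, label):
    # Mean of absolute pixel errors.
    return float(np.mean(np.abs(pred - label)))


def rmse(pred, label):
    # Square root of the mean squared pixel error.
    return float(np.sqrt(np.mean((pred - label) ** 2)))


pred = np.array([0.0, 0.5, 1.0, -1.0])
label = np.zeros(4)
print(mae(pred, label), rmse(pred, label))
```

RMSE weights large errors more heavily than MAE, so reporting both gives a sense of whether errors are spread evenly or concentrated in a few pixels.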

PaddleScience provides an API for custom metric functions, ppsci.metric.FunctionalMetric: first define the metric function, then pass the function itself as an argument to FunctionalMetric. Note that the custom metric function takes dictionaries as input and must return a dictionary.

The implementation code of SSIM is as follows:

from math import exp

import paddle


class SSIM(paddle.nn.Layer):
    """
    SSIM is used to measure the similarity between two images.

    Attributes:
        window_size (int): The size of the gaussian window used for computing SSIM. Defaults to 11.
        size_average (bool): If True, the SSIM values across spatial dimensions are averaged. Defaults to True.

    Methods:
        forward(img1, img2): Computes the SSIM score between two images using a gaussian filter defined by `window`.
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        _, channel, _, _ = img1.shape

        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel)
            if img1.place.is_gpu_place():
                window = window.cuda(img1.place.gpu_device_id())
            window = window.astype(img1.dtype)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def gaussian(window_size, sigma):
    gauss = paddle.to_tensor(
        data=[
            exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
            for x in range(window_size)
        ],
        dtype="float32",
    )
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = (
        paddle.mm(_1D_window, _1D_window.t())
        .astype("float32")
        .unsqueeze(0)
        .unsqueeze(0)
    )
    window = _2D_window.expand([channel, 1, window_size, window_size])
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = paddle.nn.functional.conv2d(
        x=img1, weight=window, padding=window_size // 2, groups=channel
    )
    mu2 = paddle.nn.functional.conv2d(
        x=img2, weight=window, padding=window_size // 2, groups=channel
    )

    mu1_sq = mu1.pow(y=2)
    mu2_sq = mu2.pow(y=2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = (
        paddle.nn.functional.conv2d(
            x=img1 * img1, weight=window, padding=window_size // 2, groups=channel
        )
        - mu1_sq
    )
    sigma2_sq = (
        paddle.nn.functional.conv2d(
            x=img2 * img2, weight=window, padding=window_size // 2, groups=channel
        )
        - mu2_sq
    )
    sigma12 = (
        paddle.nn.functional.conv2d(
            x=img1 * img2, weight=window, padding=window_size // 2, groups=channel
        )
        - mu1_mu2
    )

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = (
        (2 * mu1_mu2 + C1)
        * (2 * sigma12 + C2)
        / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    )

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(axis=1).mean(axis=1).mean(axis=1)


def ssim_metirc(output_dict, label_dict):
    ssim_loss = SSIM(window_size=11)
    metric_dict = {}

    for key in label_dict:
        ssim = ssim_loss(label_dict[key] / 2 + 0.5, output_dict[key] / 2 + 0.5)
        metric_dict[key] = ssim

    return metric_dict
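
The core SSIM formula can be checked without Paddle. The sketch below is a deliberate simplification: it computes means, variances, and covariance over the whole image instead of within a Gaussian window as _ssim does, but it uses the same constants C1 and C2 and the same [-1, 1] to [0, 1] rescaling as ssim_metirc. global_ssim and to_unit_range are hypothetical helpers, and the random 70 x 70 images are placeholders.

```python
import numpy as np

C1, C2 = 0.01 ** 2, 0.03 ** 2  # same stabilizing constants as in _ssim


def global_ssim(x, y):
    # Simplified SSIM: statistics over the whole image instead of a Gaussian window.
    mu_x, mu_y = x.mean(), y.mean()
    var_x, var_y = x.var(), y.var()
    cov_xy = ((x - mu_x) * (y - mu_y)).mean()
    return ((2 * mu_x * mu_y + C1) * (2 * cov_xy + C2)) / (
        (mu_x ** 2 + mu_y ** 2 + C1) * (var_x + var_y + C2)
    )


def to_unit_range(v):
    # Rescale values from [-1, 1] to [0, 1], as ssim_metirc does before SSIM.
    return v / 2 + 0.5


rng = np.random.default_rng(0)
label = rng.uniform(-1, 1, size=(70, 70))
noisy = np.clip(label + rng.normal(scale=0.2, size=label.shape), -1, 1)
print(global_ssim(to_unit_range(label), to_unit_range(label)))  # identical images
print(global_ssim(to_unit_range(label), to_unit_range(noisy)))  # degraded copy
```

Identical images give an SSIM of 1, and any degradation pushes the value below 1; the windowed version in _ssim applies the same formula locally and then averages the resulting map.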

3.10 Validator Construction

This case uses ppsci.validate.SupervisedValidator to construct the validator.

# set validator
validator = ppsci.validate.SupervisedValidator(
    dataloader_cfg=valid_dataloader_cfg,
    loss=ppsci.loss.MAELoss("mean"),
    output_expr={"real_image": lambda out: out["fake_image"]},
    metric={
        "MAE": ppsci.metric.MAE(),
        "RMSE": ppsci.metric.RMSE(),
        "SSIM": ppsci.metric.FunctionalMetric(func_module.ssim_metirc),
    },
    name="val",
)
validator_dict = {validator.name: validator}

3.11 Model Evaluation

Pass the model, validator, and pretrained weight path to ppsci.solver.Solver, then start evaluation with solver.eval().

# initialize solver
solver = ppsci.solver.Solver(
    model=model_gen,
    validator=validator_dict,
    pretrained_model_path=cfg.EVAL.pretrained_model_path,
)

# evaluation
solver.eval()

3.12 Visualization

After evaluation, the predicted and ground-truth velocity images are visualized as follows:

# visualization
if cfg.VIS.vis:
    with solver.no_grad_context_manager(True):
        for batch_idx, (input_, label_, _) in enumerate(validator.data_loader):
            if batch_idx + 1 > cfg.VIS.vb:
                break
            fake_image = model_gen(input_)["fake_image"].numpy()
            real_image = label_["real_image"].numpy()
            for i in range(cfg.VIS.vsa):
                plot_velocity(
                    fake_image[i, 0],
                    real_image[i, 0],
                    f"{cfg.output_dir}/V_{batch_idx}_{i}.png",
                )
    print(f"The visualizations are saved to {cfg.output_dir}")

4. Complete Code

velocityGAN.py
import json
import os
import sys

import functions as func_module
import hydra
import paddle
from functions import plot_velocity
from omegaconf import DictConfig

import ppsci
from ppsci.utils import logger

os.environ["FLAGS_embedding_deterministic"] = "1"
os.environ["FLAGS_cudnn_deterministic"] = "1"
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
os.environ["NCCL_ALGO"] = "Tree"


def evaluate(cfg: DictConfig):
    # get dataset configuration information
    with open(cfg.DATASET_CONFIG) as f:
        try:
            ctx = json.load(f)[cfg.DATASET]
        except KeyError:
            print("Unsupported dataset.")
            sys.exit()

    if cfg.file_size is not None:
        ctx["file_size"] = cfg.file_size

    # get data transformation
    transform_data, transform_label = func_module.create_transform(ctx, cfg.k)

    # set model
    model_gen = ppsci.arch.VelocityGenerator(**cfg.MODEL.gen_net)

    # set valid_dataloader_cfg
    valid_dataloader_cfg = {
        "dataset": {
            "name": "FWIDataset",
            "input_keys": ("data",),
            "label_keys": ("real_image",),
            "anno": cfg.EVAL.dataset.anno,
            "preload": cfg.EVAL.dataset.preload,
            "sample_ratio": cfg.EVAL.dataset.sample_ratio,
            "file_size": ctx["file_size"],
            "transform_data": transform_data,
            "transform_label": transform_label,
        },
        "batch_size": cfg.EVAL.batch_size,
        "use_shared_memory": cfg.EVAL.use_shared_memory,
        "num_workers": cfg.EVAL.num_workers,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.MAELoss("mean"),
        output_expr={"real_image": lambda out: out["fake_image"]},
        metric={
            "MAE": ppsci.metric.MAE(),
            "RMSE": ppsci.metric.RMSE(),
            "SSIM": ppsci.metric.FunctionalMetric(func_module.ssim_metirc),
        },
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=model_gen,
        validator=validator_dict,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
    )

    # evaluation
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            for batch_idx, (input_, label_, _) in enumerate(validator.data_loader):
                if batch_idx + 1 > cfg.VIS.vb:
                    break
                fake_image = model_gen(input_)["fake_image"].numpy()
                real_image = label_["real_image"].numpy()
                for i in range(cfg.VIS.vsa):
                    plot_velocity(
                        fake_image[i, 0],
                        real_image[i, 0],
                        f"{cfg.output_dir}/V_{batch_idx}_{i}.png",
                    )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # get dataset configuration information
    with open(cfg.DATASET_CONFIG) as f:
        try:
            ctx = json.load(f)[cfg.DATASET]
        except KeyError:
            print("Unsupported dataset.")
            sys.exit()

    if cfg.file_size is not None:
        ctx["file_size"] = cfg.file_size

    # get data transformation
    transform_data, transform_label = func_module.create_transform(ctx, cfg.k)

    # set model
    model_gen = ppsci.arch.VelocityGenerator(**cfg.MODEL.gen_net)
    model_dis = ppsci.arch.VelocityDiscriminator(**cfg.MODEL.dis_net)

    # set class for loss function
    gen_funcs = func_module.GenFuncs(model_dis, cfg.WEIGHT_DICT.gen)
    dis_funcs = func_module.DisFuncs(model_dis, cfg.WEIGHT_DICT.dis)

    # set dataloader config
    dataloader_cfg = {
        "dataset": {
            "name": "FWIDataset",
            "input_keys": ("data",),
            "label_keys": ("real_image",),
            "anno": cfg.TRAIN.dataset.anno,
            "preload": cfg.TRAIN.dataset.preload,
            "sample_ratio": cfg.TRAIN.dataset.sample_ratio,
            "file_size": ctx["file_size"],
            "transform_data": transform_data,
            "transform_label": transform_label,
        },
        "sampler": {
            "name": "BatchSampler",
            "shuffle": cfg.TRAIN.sampler.shuffle,
            "drop_last": cfg.TRAIN.sampler.drop_last,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "use_shared_memory": cfg.TRAIN.use_shared_memory,
        "num_workers": cfg.TRAIN.num_workers,
    }

    # set constraint
    constraint_gen = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(gen_funcs.loss_func_gen),
        output_expr={"fake_image": lambda out: out["fake_image"]},
        name="cst_gen",
    )
    constraint_gen_dict = {constraint_gen.name: constraint_gen}

    constraint_dis = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(dis_funcs.loss_func_dis),
        output_expr={"fake_image": lambda out: out["fake_image"]},
        name="cst_dis",
    )
    constraint_dis_dict = {constraint_dis.name: constraint_dis}

    # set optimizer
    optimizer = ppsci.optimizer.AdamW(
        learning_rate=cfg.TRAIN.learning_rate, weight_decay=cfg.TRAIN.weight_decay
    )
    optimizer_g = optimizer(model_gen)
    optimizer_d = optimizer(model_dis)

    # initialize solver
    solver_gen = ppsci.solver.Solver(
        model=model_gen,
        output_dir=cfg.output_dir,
        constraint=constraint_gen_dict,
        optimizer=optimizer_g,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
    )

    solver_dis = ppsci.solver.Solver(
        model=model_gen,
        output_dir=cfg.output_dir,
        constraint=constraint_dis_dict,
        optimizer=optimizer_d,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
    )

    # training
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        solver_dis.train()
        solver_gen.train()

    # save model weight
    paddle.save(
        model_gen.state_dict(), os.path.join(cfg.output_dir, "model_gen.pdparams")
    )


@hydra.main(version_base=None, config_path="./conf", config_name="velocityGAN.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()

5. Result Display

Results of training on the FlatVel-A dataset:

MAE RMSE SSIM
0.0669 0.0947 0.8511

[Figure: visualization results on FlatVel-A]

6. References