
tempoGAN (temporally Generative Adversarial Networks)


Model training command:

# linux
wget -c https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat -P datasets/tempoGAN/
wget -c https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat -P datasets/tempoGAN/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat --create-dirs -o ./datasets/tempoGAN/2d_train.mat
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat --create-dirs -o ./datasets/tempoGAN/2d_valid.mat
python tempoGAN.py

Model evaluation command:

# linux
wget -c https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat -P datasets/tempoGAN/
wget -c https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat -P datasets/tempoGAN/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_train.mat --create-dirs -o ./datasets/tempoGAN/2d_train.mat
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat --create-dirs -o ./datasets/tempoGAN/2d_valid.mat
python tempoGAN.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/tempoGAN/tempogan_pretrained.pdparams

Model export command:

python tempoGAN.py mode=export

Model inference command:

# linux
wget -c https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat -P datasets/tempoGAN/
# windows
# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/tempoGAN/2d_valid.mat --create-dirs -o ./datasets/tempoGAN/2d_valid.mat
python tempoGAN.py mode=infer
Pretrained model and metrics:
tempogan_pretrained.pdparams: MSE 4.21e-5, PSNR 47.19, SSIM 0.9974

1. Background Introduction

In fluid simulation, capturing the complex details of turbulence has long been a challenge for numerical methods. Resolving these details with discrete models incurs huge computational costs, which quickly become infeasible for flows at human spatial and temporal scales. This has motivated fluid super-resolution, which combines fluid dynamics simulation with deep learning to recover high-resolution simulation results from low-resolution ones, greatly reducing the cost of producing high-resolution fluids. The technique applies to many kinds of fluid simulation, such as water flow, air flow, and flames.

A Generative Adversarial Network (GAN) is a deep learning architecture trained with unsupervised learning. A GAN contains at least two models: a Generator and a Discriminator. The Generator produces the output for the problem, and the Discriminator judges whether a given output is real or fake. The two are optimized together as an adversarial game, which eventually drives the Generator's output close to the real data.
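For reference, the adversarial game described above is usually written as the minimax objective of the original GAN formulation, where D(·) outputs the probability that its input is real, x is drawn from the real data distribution and z from a noise prior. This is the generic unconditional form, shown only to make the "game" concrete; the losses actually used in this example (binary cross-entropy on discriminator logits plus extra terms) are given in the custom-loss section later in this document.

```latex
\min_G \max_D \; V(D, G) =
\mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
+ \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

The Discriminator maximizes V by scoring real samples near 1 and generated samples near 0, while the Generator minimizes V by making D(G(z)) approach 1.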

tempoGAN extends the basic GAN with a temporal discriminator, Discriminator_tempo. Its network structure is the same as that of the basic Discriminator, but its input is several consecutive frames rather than a single frame, so it also accounts for temporal coherence.

In this problem, the network takes low-resolution fluid density data as input and produces the corresponding high-resolution density data, greatly saving time and computational cost.

2. Problem Definition

This problem involves three models: Generator, Discriminator, and the temporal discriminator Discriminator_tempo. Following the usual GAN training procedure, the three models are trained alternately, in the order Discriminator, Discriminator_tempo, Generator. Although GAN training is unsupervised, in this problem's network design the target value is fed into the network as an input during training.

3. Problem Solving

Next, we explain step by step how to translate the problem into PaddleScience code and solve it with deep learning. To help you get started with PaddleScience quickly, only key steps such as model construction and constraint construction are described below; for other details, please refer to the API Documentation.

3.1 Dataset Introduction

The dataset is a 2D fluid dataset generated with the open-source package mantaflow. It contains density values of a number of consecutive frames, converted from low-resolution and high-resolution fluid images, stored as a dictionary in .mat files.

Before running the code for this problem, download the training dataset and the validation dataset and store them at the following paths:

DATASET_PATH: ./datasets/tempoGAN/2d_train.mat
DATASET_PATH_VALID: ./datasets/tempoGAN/2d_valid.mat

3.2 Model Construction

[Figure: tempoGAN network model]

The figure above shows the complete tempoGAN architecture. This problem only covers a relatively simple case and does not include velocity and vorticity inputs, 3D data, data augmentation, the advection operator, and so on. If you are interested in the parts not covered in this document, you can modify the code yourself and experiment further.

As shown in the figure, the input of Generator is the interpolated low-resolution density data, and its output is the generated high-resolution density data. Discriminator receives two inputs: the interpolated low-resolution data concatenated with the Generator output, and the interpolated low-resolution data concatenated with the target high-resolution data. Discriminator_tempo receives several consecutive frames of high-resolution density data, both as generated by Generator and from the target.

Although the composition of input and output looks complicated, they are essentially fluid density data, so the mapping functions of the 3 networks are all \(f: \mathbb{R}^1 \to \mathbb{R}^1\).

Unlike simple MLP networks, GAN generators and discriminators can take many different network structures depending on the problem, which are not covered here. Because of this, the tempoGAN networks are not a generic built-in model and are implemented specifically for this problem.

The Generator in this problem consists of 4 improved residual blocks (Res Blocks). Discriminator and Discriminator_tempo are the same model, built from 4 convolutional layers: they share the same network structure but take different inputs. The network parameters of Generator, Discriminator, and Discriminator_tempo must also be defined separately.

For specific code, please refer to the gan.py file in Complete Code.

Because the generator and the discriminators use each other's (intermediate) outputs in their loss calculations, they are combined in a ModelList. In PaddleScience code:

# define Generator model
model_gen = ppsci.arch.Generator(**cfg.MODEL.gen_net)
model_gen.register_input_transform(gen_funcs.transform_in)
disc_funcs.model_gen = model_gen

model_tuple = (model_gen,)
# define Discriminators
if cfg.USE_SPATIALDISC:
    model_disc = ppsci.arch.Discriminator(**cfg.MODEL.disc_net)
    model_disc.register_input_transform(disc_funcs.transform_in)
    model_tuple += (model_disc,)

# define temporal Discriminators
if cfg.USE_TEMPODISC:
    model_disc_tempo = ppsci.arch.Discriminator(**cfg.MODEL.tempo_net)
    model_disc_tempo.register_input_transform(disc_funcs.transform_in_tempo)
    model_tuple += (model_disc_tempo,)

# define model_list
model_list = ppsci.arch.ModelList(model_tuple)

Note that the inputs the networks actually require differ from the fields stored in the dataset, so an input transform is registered on each model.

3.3 Transform Construction

The input of Generator is the interpolated low-resolution density data, while the dataset stores the original low-resolution data, so an interpolation transform is needed:

def transform_in(self, _in):
    ratio = 2
    input_dict = reshape_input(_in)
    density_low = input_dict["density_low"]
    density_low_inp = interpolate(density_low, ratio, "nearest")
    return {"input_gen": density_low_inp}
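To make the "nearest" interpolation concrete: for an integer ratio, nearest-neighbour upsampling copies each pixel into a ratio × ratio block. The NumPy sketch below illustrates that idea; nearest_upsample is a hypothetical stand-in written for this illustration, not the interpolate helper used in the example code.

```python
import numpy as np


def nearest_upsample(density: np.ndarray, ratio: int) -> np.ndarray:
    """Nearest-neighbour upsampling of an (N, C, H, W) array by an integer ratio.

    Hypothetical illustration: each input pixel is copied into a ratio x ratio block.
    """
    # Repeat rows (axis -2), then columns (axis -1), `ratio` times each.
    return np.repeat(np.repeat(density, ratio, axis=-2), ratio, axis=-1)


low = np.array([[[[0.1, 0.2],
                  [0.3, 0.4]]]], dtype=np.float32)  # shape (1, 1, 2, 2)
up = nearest_upsample(low, 2)
print(up.shape)  # (1, 1, 4, 4)
print(up[0, 0])
```

Nearest-neighbour interpolation adds no new detail; the Generator is responsible for synthesizing the fine-scale structure on top of this blocky input.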

The input transforms of Discriminator and Discriminator_tempo are more involved:

def transform_in(self, _in):
    ratio = 2
    input_dict = reshape_input(_in)
    density_low = input_dict["density_low"]
    density_high_from_target = input_dict["density_high"]

    density_low_inp = interpolate(density_low, ratio, "nearest")

    density_high_from_gen = self.model_gen(input_dict)["output_gen"]
    density_high_from_gen.stop_gradient = True

    density_input_from_target = paddle.concat(
        [density_low_inp, density_high_from_target], axis=1
    )
    density_input_from_gen = paddle.concat(
        [density_low_inp, density_high_from_gen], axis=1
    )
    return {
        "input_disc_from_target": density_input_from_target,
        "input_disc_from_gen": density_input_from_gen,
    }

def transform_in_tempo(self, _in):
    density_high_from_target = _in["density_high"]

    input_dict = reshape_input(_in)
    density_high_from_gen = self.model_gen(input_dict)["output_gen"]
    density_high_from_gen.stop_gradient = True

    input_trans = {
        "input_tempo_disc_from_target": density_high_from_target,
        "input_tempo_disc_from_gen": density_high_from_gen,
    }

    return dereshape_input(input_trans, 3)
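The spatial discriminator's two inputs are formed by stacking the conditioning input and a high-resolution field along the channel axis, as paddle.concat(..., axis=1) does in transform_in. The NumPy sketch below shows the resulting shapes; the batch size and the 64×64 resolution are hypothetical values chosen for illustration, and it assumes (as the concatenation requires) that the interpolated input and the high-resolution field share the same spatial size.

```python
import numpy as np

# Hypothetical shapes: a batch of 4 single-channel 64x64 density fields.
n, h, w = 4, 64, 64
rng = np.random.default_rng(0)
density_low_inp = rng.random((n, 1, h, w), dtype=np.float32)  # interpolated low-res input
density_high_from_target = rng.random((n, 1, h, w), dtype=np.float32)
density_high_from_gen = rng.random((n, 1, h, w), dtype=np.float32)

# Stack conditioning input and high-resolution field along the channel axis (axis=1).
input_disc_from_target = np.concatenate([density_low_inp, density_high_from_target], axis=1)
input_disc_from_gen = np.concatenate([density_low_inp, density_high_from_gen], axis=1)

print(input_disc_from_target.shape)  # (4, 2, 64, 64)
```

Because both inputs share channel 0, the discriminator judges each high-resolution field in the context of the same low-resolution input, which is what makes it a conditional discriminator.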

Here, the line

density_high_from_gen.stop_gradient = True

stops gradients from flowing back through this tensor. The tensor is used here only as an input to Discriminator and Discriminator_tempo, so it should not take part in backpropagation. Without this setting, because the tensor is the output of Generator, gradients would flow back along it into Generator during backpropagation and change Generator's parameters, which is not what we want when training the discriminators.

With this, we have instantiated a ModelList containing Generator, Discriminator, and Discriminator_tempo, each with its input transform.

3.4 Parameter and Hyperparameter Setting

We need to specify problem-related parameters, such as the dataset paths and the weights of the individual loss terms:

log_freq: 20
DATASET_PATH: ./datasets/tempoGAN/2d_train.mat
DATASET_PATH_VALID: ./datasets/tempoGAN/2d_valid.mat

# set working condition
USE_AMP: true
USE_SPATIALDISC: true
USE_TEMPODISC: true
WEIGHT_GEN: [5.0, 0.0, 1.0]  # lambda_l1, lambda_l2, lambda_t
WEIGHT_GEN_LAYER: [-1.0e-5, -1.0e-5, -1.0e-5, -1.0e-5, -1.0e-5]
WEIGHT_DISC: 1.0

Note that the configuration contains three boolean variables, USE_AMP, USE_SPATIALDISC and USE_TEMPODISC, which control respectively whether to use automatic mixed precision (AMP) training, whether to use Discriminator, and whether to use Discriminator_tempo. When both USE_SPATIALDISC and USE_TEMPODISC are set to false, the network reduces to a plain Generator model and is no longer a GAN.

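Hyperparameters such as the number of training epochs and the learning rate must also be specified. Because GAN training differs from ordinary single-model training, the epoch settings differ too: epochs is the number of outer rounds of alternating training, while epochs_gen and epochs_disc are the numbers of epochs that the Generator and the discriminators, respectively, are trained for in each round.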

TRAIN:
  epochs: 40000
  epochs_gen: 1
  epochs_disc: 1

3.5 Optimizer Construction

Training uses the Adam optimizer. The learning rate drops to \(1/20\) of its initial value halfway through training, so the Step learning-rate schedule is used. If by_epoch is set to True, the learning rate is updated per epoch; otherwise it is updated per iteration.

# initialize Adam optimizer
lr_scheduler_gen = ppsci.optimizer.lr_scheduler.Step(
    step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
)()
optimizer_gen = ppsci.optimizer.Adam(lr_scheduler_gen)(model_gen)
if cfg.USE_SPATIALDISC:
    lr_scheduler_disc = ppsci.optimizer.lr_scheduler.Step(
        step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
    )()
    optimizer_disc = ppsci.optimizer.Adam(lr_scheduler_disc)(model_disc)
if cfg.USE_TEMPODISC:
    lr_scheduler_disc_tempo = ppsci.optimizer.lr_scheduler.Step(
        step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
    )()
    optimizer_disc_tempo = ppsci.optimizer.Adam(lr_scheduler_disc_tempo)(
        (model_disc_tempo,)
    )
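As a sketch of what the Step schedule configured above does, assuming it follows the usual step-decay rule lr = base_lr · gamma^(epoch // step_size) (the rule documented for Paddle's paddle.optimizer.lr.StepDecay), and assuming cfg.TRAIN.lr_scheduler sets gamma = 0.05 to get the 1/20 reduction. The base learning rate 2.0e-4 is a hypothetical value, not taken from the configuration.

```python
def step_decay_lr(base_lr: float, epoch: int, step_size: int, gamma: float) -> float:
    """Learning rate under a step schedule: multiply by `gamma` every `step_size` epochs."""
    return base_lr * gamma ** (epoch // step_size)


# Assumed values: base_lr is hypothetical; gamma = 0.05 gives the 1/20 drop;
# step_size = epochs // 2, as passed to the Step scheduler above.
epochs = 40000
base_lr = 2.0e-4
for e in (0, 19999, 20000, 39999):
    print(e, step_decay_lr(base_lr, e, epochs // 2, 0.05))
```

With step_size = epochs // 2, the rate stays at base_lr for the first half of training and is multiplied by gamma exactly once, at the midpoint.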

3.6 Constraint Construction

This problem is not trained with separate label data. Even so, the supervised constraint SupervisedConstraint can still be used. Before defining constraints, the data-reading configuration, such as the file paths, must be specified. Because tempoGAN is trained in a self-supervised way, the dataset contains no label data; instead, part of the input data serves as the label, so the output_expr of the constraint must be set:

{
    "output_gen": lambda out: out["output_gen"],
    "density_high": lambda out: out["density_high"],
}
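The mechanism of output_expr can be illustrated in plain Python: each entry maps a name to a function of a result dictionary, and the named results are what the loss function receives. This sketch assumes the constraint evaluates output_expr on a dictionary holding both the inputs and the model outputs; how PaddleScience actually merges them is not shown in this document, and the merged dict below is hypothetical toy data.

```python
# Hypothetical merged dict of inputs and model outputs for one (tiny) batch.
merged = {
    "density_low": [[0.1]],
    "density_high": [[0.8]],  # raw input field, reused as the "label"
    "output_gen": [[0.7]],    # Generator prediction
}

output_expr = {
    "output_gen": lambda out: out["output_gen"],
    "density_high": lambda out: out["density_high"],
}

# Each key of output_expr becomes one entry of the dict handed to the loss function.
loss_inputs = {name: fn(merged) for name, fn in output_expr.items()}
print(loss_inputs)
```

This is why the Generator loss can compare output_gen with density_high even though the constraint defines no label field.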

3.6.1 Constraints of Generator

The Generator constraint is defined as follows; note the output_expr described above:

sup_constraint_gen = ppsci.constraint.SupervisedConstraint(
    {
        "dataset": {
            "name": "NamedArrayDataset",
            "input": {
                "density_low": dataset_train["density_low"],
                "density_high": dataset_train["density_high"],
            },
            "transforms": (
                {
                    "FunctionalTransform": {
                        "transform_func": data_funcs.transform,
                    },
                },
            ),
        },
        "batch_size": cfg.TRAIN.batch_size.sup_constraint,
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
    },
    ppsci.loss.FunctionalLoss(gen_funcs.loss_func_gen),
    {
        "output_gen": lambda out: out["output_gen"],
        "density_high": lambda out: out["density_high"],
    },
    name="sup_constraint_gen",
)

The first argument of SupervisedConstraint is the data-reading configuration of the constraint. Its dataset field describes the training dataset, with the following fields:

  1. name: the dataset type; NamedArrayDataset is a dataset built from named arrays, here the arrays read from the .mat file;
  2. input: input data as arrays;
  3. label: label data as arrays (the Generator constraint has none and takes its label from the inputs via output_expr);
  4. transforms: the data transforms to apply. FunctionalTransform is a custom data transform class provided by PaddleScience that lets you define your own transform of the input data in code. For the specific code, please refer to Custom loss and data transform;

The batch_size field is the batch size.

The sampler field specifies the sampling method, with the following fields:

  1. name: the sampler type; BatchSampler is a batch sampler;
  2. drop_last: whether to drop the last samples that cannot fill a complete mini-batch; defaults to False;
  3. shuffle: whether to shuffle the order of the sample indices; defaults to False;

The second argument is the loss function. FunctionalLoss is a custom loss class provided by PaddleScience that lets you define your own loss calculation in code instead of using built-in losses such as MSE. For the specific code, please refer to Custom loss and data transform.

The third argument is the output_expr of the constraint. As mentioned above, it lets the program use input data as the label.

The fourth argument is the name of the constraint. Each constraint needs a name so that it can be indexed later.

After constructing the constraints, wrap them in a dictionary keyed by their names for later access. Because USE_SPATIALDISC and USE_TEMPODISC make some of the Generator's constraints optional, the dictionary is first created with the constraint that always exists, and the optional constraints are added to it only when they are enabled:

constraint_gen = {sup_constraint_gen.name: sup_constraint_gen}
if cfg.USE_TEMPODISC:
    sup_constraint_gen_tempo = ppsci.constraint.SupervisedConstraint(
        {
            "dataset": {
                "name": "NamedArrayDataset",
                "input": {
                    "density_low": dataset_train["density_low_tempo"],
                    "density_high": dataset_train["density_high_tempo"],
                },
                "transforms": (
                    {
                        "FunctionalTransform": {
                            "transform_func": data_funcs.transform,
                        },
                    },
                ),
            },
            "batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
            "sampler": {
                "name": "BatchSampler",
                "drop_last": False,
                "shuffle": False,
            },
        },
        ppsci.loss.FunctionalLoss(gen_funcs.loss_func_gen_tempo),
        {
            "output_gen": lambda out: out["output_gen"],
            "density_high": lambda out: out["density_high"],
        },
        name="sup_constraint_gen_tempo",
    )
    constraint_gen[sup_constraint_gen_tempo.name] = sup_constraint_gen_tempo

3.6.2 Constraints of Discriminator

if cfg.USE_SPATIALDISC:
    sup_constraint_disc = ppsci.constraint.SupervisedConstraint(
        {
            "dataset": {
                "name": "NamedArrayDataset",
                "input": {
                    "density_low": dataset_train["density_low"],
                    "density_high": dataset_train["density_high"],
                },
                "label": {
                    "out_disc_from_target": np.ones(
                        (np.shape(dataset_train["density_high"])[0], 1),
                        dtype=paddle.get_default_dtype(),
                    ),
                    "out_disc_from_gen": np.ones(
                        (np.shape(dataset_train["density_high"])[0], 1),
                        dtype=paddle.get_default_dtype(),
                    ),
                },
                "transforms": (
                    {
                        "FunctionalTransform": {
                            "transform_func": data_funcs.transform,
                        },
                    },
                ),
            },
            "batch_size": cfg.TRAIN.batch_size.sup_constraint,
            "sampler": {
                "name": "BatchSampler",
                "drop_last": False,
                "shuffle": False,
            },
        },
        ppsci.loss.FunctionalLoss(disc_funcs.loss_func),
        name="sup_constraint_disc",
    )
    constraint_disc = {sup_constraint_disc.name: sup_constraint_disc}

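The parameters have the same meaning as in Constraints of Generator. The differences are that a label field is provided and no output_expr is passed; the all-ones label values are not used by loss_func, which builds its own 0/1 targets.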

3.6.3 Constraints of Discriminator_tempo

if cfg.USE_TEMPODISC:
    sup_constraint_disc_tempo = ppsci.constraint.SupervisedConstraint(
        {
            "dataset": {
                "name": "NamedArrayDataset",
                "input": {
                    "density_low": dataset_train["density_low_tempo"],
                    "density_high": dataset_train["density_high_tempo"],
                },
                "label": {
                    "out_disc_tempo_from_target": np.ones(
                        (np.shape(dataset_train["density_high_tempo"])[0], 1),
                        dtype=paddle.get_default_dtype(),
                    ),
                    "out_disc_tempo_from_gen": np.ones(
                        (np.shape(dataset_train["density_high_tempo"])[0], 1),
                        dtype=paddle.get_default_dtype(),
                    ),
                },
                "transforms": (
                    {
                        "FunctionalTransform": {
                            "transform_func": data_funcs.transform,
                        },
                    },
                ),
            },
            "batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
            "sampler": {
                "name": "BatchSampler",
                "drop_last": False,
                "shuffle": False,
            },
        },
        ppsci.loss.FunctionalLoss(disc_funcs.loss_func_tempo),
        name="sup_constraint_disc_tempo",
    )
    constraint_disc_tempo = {
        sup_constraint_disc_tempo.name: sup_constraint_disc_tempo
    }

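The parameters have the same meaning as in Constraints of Generator. As with Discriminator, a label field is provided instead of output_expr, and its all-ones values are not used by loss_func_tempo, which builds its own 0/1 targets.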

3.7 Visualizer Construction

Because of the nature of GAN training, this problem does not use PaddleScience's built-in visualizer; instead, it defines a custom inference function. The function reads validation data, runs inference, and saves the results as images. Calling it at regular intervals during training lets you monitor training progress.

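import os
from typing import Dict

import matplotlib.image as Img
import numpy as np
import paddle
import paddle.nn.functional as F

import ppsci

# split_data and concat_data are tiling helpers provided alongside this function (not shown).


def predict_and_save_plot(
    output_dir: str,
    epoch_id: int,
    solver_gen: ppsci.solver.Solver,
    dataset_valid: Dict[str, np.ndarray],
    tile_ratio: int = 1,
):
    """Predicting and plotting.

    Args:
        output_dir (str): Output dir path.
        epoch_id (int): Which epoch it is.
        solver_gen (ppsci.solver.Solver): Solver for predicting.
        dataset_valid (Dict[str, np.ndarray]): Validation dataset, keyed by field name.
        tile_ratio (int, optional): How many tiles of one dim. Defaults to 1.
    """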
    dir_pred = "predict/"
    os.makedirs(os.path.join(output_dir, dir_pred), exist_ok=True)

    start_idx = 190
    density_low = dataset_valid["density_low"][start_idx : start_idx + 3]
    density_high = dataset_valid["density_high"][start_idx : start_idx + 3]

    # tile
    density_low = (
        split_data(density_low, tile_ratio) if tile_ratio != 1 else density_low
    )
    density_high = (
        split_data(density_high, tile_ratio) if tile_ratio != 1 else density_high
    )

    pred_dict = solver_gen.predict(
        {
            "density_low": density_low,
            "density_high": density_high,
        },
        {"density_high": lambda out: out["output_gen"]},
        batch_size=tile_ratio * tile_ratio if tile_ratio != 1 else 3,
        no_grad=False,
    )
    if epoch_id == 1:
        # plot interpolated input image
        input_img = np.expand_dims(dataset_valid["density_low"][start_idx], axis=0)
        input_img = paddle.to_tensor(input_img, dtype=paddle.get_default_dtype())
        input_img = F.interpolate(
            input_img,
            [input_img.shape[-2] * 4, input_img.shape[-1] * 4],
            mode="nearest",
        ).numpy()
        Img.imsave(
            os.path.join(output_dir, dir_pred, "input.png"),
            np.squeeze(input_img),
            vmin=0.0,
            vmax=1.0,
            cmap="gray",
        )
        # plot target image
        Img.imsave(
            os.path.join(output_dir, dir_pred, "target.png"),
            np.squeeze(dataset_valid["density_high"][start_idx]),
            vmin=0.0,
            vmax=1.0,
            cmap="gray",
        )
    # plot pred image
    pred_img = (
        concat_data(pred_dict["density_high"].numpy(), tile_ratio)
        if tile_ratio != 1
        else np.squeeze(pred_dict["density_high"][0].numpy())
    )
    Img.imsave(
        os.path.join(output_dir, dir_pred, f"pred_epoch_{str(epoch_id)}.png"),
        pred_img,
        vmin=0.0,
        vmax=1.0,
        cmap="gray",
    )

3.8 Custom loss and data transform

Because this problem has no label data, its losses cannot be expressed with the built-in supervised losses and must be customized. To do this, first define the loss functions, then pass each function as an argument to FunctionalLoss. A custom loss function must follow the same interface as PaddleScience's built-in losses such as MSE: it takes dictionaries such as the model output output_dict as input, and it returns the loss as a dictionary that maps a name to a paddle.Tensor, as in the code below.

3.8.1 Loss of Generator

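The Generator loss provides an L1 loss, an L2 loss, an adversarial loss from Discriminator's judgment of the generated output (together with a layer loss on Discriminator's intermediate outputs), and an adversarial loss from Discriminator_tempo's judgment of the generated output. The L1, L2, and temporal terms are scaled by WEIGHT_GEN and the layer loss by WEIGHT_GEN_LAYER; a weight of 0 means the corresponding term does not contribute to training. The spatial adversarial and layer terms are added only when weight_gen_layer is set.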

def loss_func_gen(self, output_dict: Dict, *args) -> Dict[str, paddle.Tensor]:
    """Calculate the generator loss when the spatial discriminator is used.
        The loss consists of L1 loss, L2 loss and, with the spatial discriminator,
        the adversarial loss and the layer loss.
        Note that every loss term is optional because its weight might be 0.

    Args:
        output_dict (Dict): output dict of model.

    Returns:
        Dict[str, paddle.Tensor]: Loss of generator.
    """
    # l1 loss
    loss_l1 = F.l1_loss(
        output_dict["output_gen"], output_dict["density_high"], "mean"
    )
    losses = loss_l1 * self.weight_gen[0]

    # l2 loss
    loss_l2 = F.mse_loss(
        output_dict["output_gen"], output_dict["density_high"], "mean"
    )
    losses += loss_l2 * self.weight_gen[1]

    if self.weight_gen_layer is not None:
        # disc(generator_out) loss
        out_disc_from_gen = output_dict["out_disc_from_gen"][-1]
        label_ones = paddle.ones_like(out_disc_from_gen)
        loss_gen = F.binary_cross_entropy_with_logits(
            out_disc_from_gen, label_ones, reduction="mean"
        )
        losses += loss_gen

        # layer loss
        key_list = list(output_dict.keys())
        # ["out0_layer0","out0_layer1","out0_layer2","out0_layer3","out_disc_from_target",
        # "out1_layer0","out1_layer1","out1_layer2","out1_layer3","out_disc_from_gen"]
        loss_layer = 0
        for i in range(1, len(self.weight_gen_layer)):
            # i = 1,2,3,4 when WEIGHT_GEN_LAYER has 5 entries
            loss_layer += (
                self.weight_gen_layer[i]
                * F.mse_loss(
                    output_dict[key_list[i]],
                    output_dict[key_list[5 + i]],
                    reduction="sum",
                )
                / 2
            )
        losses += loss_layer * self.weight_gen_layer[0]

    return {"output_gen": losses}

def loss_func_gen_tempo(self, output_dict: Dict, *args) -> Dict[str, paddle.Tensor]:
    """Calculate the generator loss when the temporal discriminator is used.
        The loss is the cross-entropy loss from the temporal discriminator.

    Args:
        output_dict (Dict): output dict of model.

    Returns:
        Dict[str, paddle.Tensor]: Loss of generator.
    """
    out_disc_tempo_from_gen = output_dict["out_disc_tempo_from_gen"][-1]
    label_t_ones = paddle.ones_like(out_disc_tempo_from_gen)

    loss_gen_t = F.binary_cross_entropy_with_logits(
        out_disc_tempo_from_gen, label_t_ones, reduction="mean"
    )
    losses = loss_gen_t * self.weight_gen[2]
    return {"out_disc_tempo_from_gen": losses}
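Reading loss_func_gen and loss_func_gen_tempo term by term, the two Generator losses can be summarized as below. Here \(\hat{y}\) is output_gen, \(y\) is density_high, \(\ell_{\text{BCE}}(s, c)\) is the mean binary cross-entropy between logits \(s\) and the constant target \(c\), \(s\) and \(s_t\) are the last entries of out_disc_from_gen and out_disc_tempo_from_gen, \(\lambda_{l1}, \lambda_{l2}, \lambda_t\) are the three entries of WEIGHT_GEN, \(w_0\) to \(w_{K-1}\) are the \(K\) entries of WEIGHT_GEN_LAYER, and \(a_i, b_i\) are the outputs stored under key_list[i] and key_list[5 + i]. The last two terms of \(\mathcal{L}_G\) are added only when weight_gen_layer is not None. This is a summary of the code above, not a formula quoted from the tempoGAN paper.

```latex
\begin{aligned}
\mathcal{L}_G &= \lambda_{l1}\,\operatorname{mean}\lvert \hat{y} - y \rvert
 + \lambda_{l2}\,\operatorname{mean}(\hat{y} - y)^2
 + \ell_{\text{BCE}}(s, 1)
 + w_0 \sum_{i=1}^{K-1} \frac{w_i}{2}\,\lVert a_i - b_i \rVert_2^2 \\
\mathcal{L}_{G,t} &= \lambda_t\,\ell_{\text{BCE}}(s_t, 1)
\end{aligned}
```

With the configuration shown earlier (WEIGHT_GEN = [5.0, 0.0, 1.0]), the L2 term vanishes and the L1 term dominates the reconstruction part of \(\mathcal{L}_G\).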

3.8.2 Loss of Discriminator

Discriminator judges whether its input is real (target) data or fake (generated) data. Its loss is therefore the sum of the loss for classifying the Generator's output as fake and the loss for classifying the target data as real, with the latter scaled by weight_disc (WEIGHT_DISC).

def loss_func(self, output_dict, *args):
    out_disc_from_target = output_dict["out_disc_from_target"]
    out_disc_from_gen = output_dict["out_disc_from_gen"]

    label_ones = paddle.ones_like(out_disc_from_target)
    label_zeros = paddle.zeros_like(out_disc_from_gen)

    loss_disc_from_target = F.binary_cross_entropy_with_logits(
        out_disc_from_target, label_ones, reduction="mean"
    )
    loss_disc_from_gen = F.binary_cross_entropy_with_logits(
        out_disc_from_gen, label_zeros, reduction="mean"
    )
    losses = loss_disc_from_target * self.weight_disc + loss_disc_from_gen
    return {"CE_loss": losses}
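The structure of loss_func can be checked with a small NumPy re-implementation. The sketch below uses the standard numerically stable form of binary cross-entropy on logits, max(x, 0) − x·z + log(1 + exp(−|x|)), averaged over elements; it is an illustration of the loss above, not the Paddle implementation, and the function names are hypothetical.

```python
import numpy as np


def bce_with_logits(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean binary cross-entropy on raw logits (numerically stable form)."""
    x, z = logits, targets
    return float(np.mean(np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))))


def disc_loss(out_from_target: np.ndarray, out_from_gen: np.ndarray,
              weight_disc: float = 1.0) -> float:
    """Discriminator loss: target samples should score 1, generated samples 0."""
    loss_real = bce_with_logits(out_from_target, np.ones_like(out_from_target))
    loss_fake = bce_with_logits(out_from_gen, np.zeros_like(out_from_gen))
    return loss_real * weight_disc + loss_fake


# A discriminator that cannot tell real from fake outputs logit 0 (probability 0.5),
# so each term equals log(2) and the total is 2 * log(2), about 1.386.
undecided = np.zeros((8, 1))
print(disc_loss(undecided, undecided))
```

A confident, correct discriminator (large positive logits on targets, large negative logits on generated data) drives this loss toward 0, which is exactly the situation the Generator's adversarial term pushes against.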

3.8.3 Loss of Discriminator_tempo

The loss of Discriminator_tempo has the same form as that of Discriminator; only the input data differ.

def loss_func_tempo(self, output_dict, *args):
    out_disc_tempo_from_target = output_dict["out_disc_tempo_from_target"]
    out_disc_tempo_from_gen = output_dict["out_disc_tempo_from_gen"]

    label_ones = paddle.ones_like(out_disc_tempo_from_target)
    label_zeros = paddle.zeros_like(out_disc_tempo_from_gen)

    loss_disc_tempo_from_target = F.binary_cross_entropy_with_logits(
        out_disc_tempo_from_target, label_ones, reduction="mean"
    )
    loss_disc_tempo_from_gen = F.binary_cross_entropy_with_logits(
        out_disc_tempo_from_gen, label_zeros, reduction="mean"
    )
    losses = (
        loss_disc_tempo_from_target * self.weight_disc + loss_disc_tempo_from_gen
    )
    return {"CE_tempo_loss": losses}

3.8.4 Custom data transform

This problem provides an input data transform that randomly crops a tile from the input density data and then checks its density. If the density of the cropped tile is below a threshold, a new tile is cropped, until the density meets the condition or the maximum number of attempts is reached. The main purpose is to reduce the GPU memory needed for training, while the density check ensures that each tile carries enough information. The tile_ratio parameter in Parameter and Hyperparameter Setting is the ratio of the original image's side length to the tile's side length: with tile_ratio = 2, each tile is half the height and half the width of the original image, i.e. one quarter of its area.

class DataFuncs:
    """All functions used for data transform.

    Args:
        tile_ratio (int, optional): How many tiles of one dim. Defaults to 1.
        density_min (float, optional): Minimum mean density of one tile. Defaults to 0.02.
        max_turn (int, optional): Maximum number of attempts to take a tile from one image. Defaults to 20.
    """

    def __init__(
        self, tile_ratio: int = 1, density_min: float = 0.02, max_turn: int = 20
    ) -> None:
        self.tile_ratio = tile_ratio
        self.density_min = density_min
        self.max_turn = max_turn

    def transform(
        self,
        input_item: Dict[str, np.ndarray],
        label_item: Dict[str, np.ndarray],
        weight_item: Dict[str, np.ndarray],
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, np.ndarray]]:
        if self.tile_ratio == 1:
            return input_item, label_item, weight_item
        for _ in range(self.max_turn):
            rand_ratio = np.random.rand()
            density_low = self.cut_data(input_item["density_low"], rand_ratio)
            density_high = self.cut_data(input_item["density_high"], rand_ratio)
            if self.is_valid_tile(density_low):
                break

        input_item["density_low"] = density_low
        input_item["density_high"] = density_high
        return input_item, label_item, weight_item

    def cut_data(self, data: np.ndarray, rand_ratio: float) -> paddle.Tensor:
        # data: C,H,W
        _, H, W = data.shape
        if H % self.tile_ratio != 0 or W % self.tile_ratio != 0:
            raise ValueError(
                f"input images cannot be divided into {self.tile_ratio} parts evenly!"
            )
        tile_shape = [H // self.tile_ratio, W // self.tile_ratio]
        rand_shape = np.floor(rand_ratio * (np.array([H, W]) - np.array(tile_shape)))
        start = [int(rand_shape[0]), int(rand_shape[1])]
        end = [int(rand_shape[0] + tile_shape[0]), int(rand_shape[1] + tile_shape[1])]
        data = paddle.slice(
            paddle.to_tensor(data), axes=[-2, -1], starts=start, ends=end
        )

        return data

    def is_valid_tile(self, tile: paddle.Tensor):
        img_density = tile[0].sum()
        return img_density >= (
            self.density_min * tile.shape[0] * tile.shape[1] * tile.shape[2]
        )
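The tiling idea in DataFuncs can be reproduced with NumPy alone. The sketch below mirrors the logic shown above (one random fraction positions the window in both the low- and high-resolution arrays, and the low-resolution tile must reach a mean density of density_min), but random_tile is a hypothetical helper written for illustration, not the PaddleScience code, and the array sizes are toy values.

```python
import numpy as np


def random_tile(low: np.ndarray, high: np.ndarray, tile_ratio: int,
                density_min: float = 0.02, max_turn: int = 20, rng=None):
    """Crop matching windows from a (C, H, W) low/high pair, retrying until the
    low-resolution tile's mean density reaches density_min or max_turn is hit."""
    rng = np.random.default_rng() if rng is None else rng
    if tile_ratio == 1:
        return low, high

    def cut(data: np.ndarray, r: float) -> np.ndarray:
        _, h, w = data.shape
        if h % tile_ratio or w % tile_ratio:
            raise ValueError(f"image cannot be split into {tile_ratio} parts evenly")
        th, tw = h // tile_ratio, w // tile_ratio
        # The same fraction r positions the window along both axes, as in cut_data.
        top, left = int(np.floor(r * (h - th))), int(np.floor(r * (w - tw)))
        return data[:, top:top + th, left:left + tw]

    for _ in range(max(max_turn, 1)):  # always try at least once
        r = rng.random()
        tile_low, tile_high = cut(low, r), cut(high, r)
        # Same threshold as is_valid_tile: channel-0 sum vs density_min * C * H * W.
        if tile_low[0].sum() >= density_min * tile_low.size:
            break
    return tile_low, tile_high


rng = np.random.default_rng(0)
low = rng.random((1, 32, 32))
high = rng.random((1, 64, 64))
tl, th = random_tile(low, high, tile_ratio=2, rng=rng)
print(tl.shape, th.shape)  # (1, 16, 16) (1, 32, 32)
```

Using a single fraction for both arrays keeps the low- and high-resolution tiles spatially aligned even though the two arrays differ in size.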

Note that the code here only illustrates the idea of a data transform. Because each tile contains less information than the full image, this simple tiling noticeably degrades training. Therefore, set tile_ratio to 1 when GPU memory is sufficient; when it is not, prefer mixed precision training (USE_AMP) to reduce memory usage before resorting to tiling.

3.9 Model Training

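After completing the above settings, pass the instantiated objects above to ppsci.solver.Solver in order, then start training. Only the generator's solver is shown here. The solvers for the spatial and temporal discriminators are built the same way, each with its own constraints, optimizer, and learning rate scheduler: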

solver_gen = ppsci.solver.Solver(
    model_list,
    constraint_gen,
    cfg.output_dir,
    optimizer_gen,
    lr_scheduler_gen,
    cfg.TRAIN.epochs_gen,
    cfg.TRAIN.iters_per_epoch,
    eval_during_train=cfg.TRAIN.eval_during_train,
    use_amp=cfg.USE_AMP,
    amp_level=cfg.TRAIN.amp_level,
)

Note that a GAN is trained by alternating between several models. This differs from training a single model, or training several models in separate stages, so the solver.train API cannot simply be called once. For the full code, refer to tempoGAN.py in the Complete Code section.
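The alternating schedule used in tempoGAN.py can be summarized by the framework-free sketch below. StubSolver is a hypothetical stand-in for ppsci.solver.Solver that only records when its train() method is called. In the real script, each train() call runs that solver's own number of epochs (TRAIN.epochs_disc, TRAIN.epochs_disc_tempo, TRAIN.epochs_gen).

```python
from typing import List, Optional


class StubSolver:
    """Hypothetical stand-in for ppsci.solver.Solver: records train() calls."""

    def __init__(self, name: str, log: List[str]):
        self.name = name
        self.log = log

    def train(self) -> None:
        self.log.append(self.name)


def train_alternately(
    epochs: int,
    solver_gen: StubSolver,
    solver_disc: Optional[StubSolver] = None,
    solver_disc_tempo: Optional[StubSolver] = None,
) -> None:
    """One outer epoch: discriminator(s) first, then the generator."""
    for _ in range(epochs):
        # spatial discriminator, input (x, y, G(x)): learns to tell y from G(x)
        if solver_disc is not None:
            solver_disc.train()
        # temporal discriminator, input: 3 consecutive frames of y and of G(x)
        if solver_disc_tempo is not None:
            solver_disc_tempo.train()
        # generator is updated last, against the freshly trained discriminators
        solver_gen.train()


log: List[str] = []
train_alternately(
    2, StubSolver("gen", log), StubSolver("disc", log), StubSolver("disc_tempo", log)
)
print(log)  # ['disc', 'disc_tempo', 'gen', 'disc', 'disc_tempo', 'gen']
```

Because each discriminator is optional (USE_SPATIALDISC and USE_TEMPODISC in the configuration), the loop skips any solver that was not created.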

3.10 Model Evaluation

3.10.1 Evaluation during training

During training, the model's prediction for one fixed validation sample is saved only at certain epochs (the first epoch, every PRED_INTERVAL epochs, and the last epoch). The corresponding target image is saved once, at the first epoch. After training ends, the prediction from the last epoch is evaluated once, which gives an intuitive view of how well the model has been optimized. PaddleScience's built-in evaluator is not used, and no evaluation is performed during training:

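PRED_INTERVAL = 200  # plotting interval used in tempoGAN.py
for i in range(1, cfg.TRAIN.epochs + 1):
    logger.message(f"\nEpoch: {i}\n")
    # plotting during training
    if i == 1 or i % PRED_INTERVAL == 0 or i == cfg.TRAIN.epochs:
        func_module.predict_and_save_plot(
            cfg.output_dir, i, solver_gen, dataset_valid, cfg.TILE_RATIO
        )
    # ... alternately train the discriminator(s) and the generator (see Complete Code)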
############### evaluation for training ###############
img_target = (
    func_module.get_image_array(
        os.path.join(cfg.output_dir, "predict", "target.png")
    )
    / 255.0
)
img_pred = (
    func_module.get_image_array(
        os.path.join(
            cfg.output_dir, "predict", f"pred_epoch_{cfg.TRAIN.epochs}.png"
        )
    )
    / 255.0
)
eval_mse, eval_psnr, eval_ssim = func_module.evaluate_img(img_target, img_pred)
logger.message(f"MSE: {eval_mse}, PSNR: {eval_psnr}, SSIM: {eval_ssim}")

For the full code, refer to tempoGAN.py in the Complete Code section.
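When TILE_RATIO is not 1, predict_and_save_plot (in functions.py) does extra work around prediction. It cuts one validation frame into tile_ratio × tile_ratio equal tiles with split_data, predicts each tile, and stitches the predictions back together with concat_data. The NumPy sketch below shows only that round trip and its row-major tile order; the names split_tiles and merge_tiles are hypothetical.

```python
import numpy as np


def split_tiles(img: np.ndarray, tile_ratio: int) -> np.ndarray:
    """Split an H,W image into tile_ratio**2 tiles shaped (N, 1, h, w), row-major."""
    h, w = img.shape
    th, tw = h // tile_ratio, w // tile_ratio
    tiles = [
        img[i * th : (i + 1) * th, j * tw : (j + 1) * tw]
        for i in range(tile_ratio)
        for j in range(tile_ratio)
    ]
    return np.stack(tiles)[:, None]  # add a channel axis -> N, 1, h, w


def merge_tiles(tiles: np.ndarray, tile_ratio: int) -> np.ndarray:
    """Inverse of split_tiles: place row-major tiles back into one H,W image."""
    _, _, th, tw = tiles.shape
    img = np.empty((th * tile_ratio, tw * tile_ratio), dtype=tiles.dtype)
    for idx, tile in enumerate(tiles):
        i, j = divmod(idx, tile_ratio)  # row and column of this tile
        img[i * th : (i + 1) * th, j * tw : (j + 1) * tw] = tile[0]
    return img


img = np.arange(16, dtype=np.float32).reshape(4, 4)
tiles = split_tiles(img, 2)
print(tiles.shape)  # (4, 1, 2, 2)
print(np.array_equal(merge_tiles(tiles, 2), img))  # True
```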

3.10.2 Evaluation in eval

This problem is evaluated by comparing the model's super-resolution output with the ground-truth high-resolution image. Three image-similarity metrics are used: MSE (Mean Squared Error), PSNR (Peak Signal-to-Noise Ratio), and SSIM (Structural Similarity Index Measure). PaddleScience's built-in evaluator is therefore not used, and Solver.eval() is not called. The validator below serves only to provide the data loader.
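For reference, MSE and PSNR reduce to a few lines of NumPy. The sketch below assumes both images are already scaled to [0, 1], as the scale helper in the evaluation code does, so the peak value (data_range) is taken as 1.0. The complete code computes the metrics with scikit-image's mean_squared_error, peak_signal_noise_ratio, and structural_similarity instead. SSIM is omitted here, and the function names are hypothetical.

```python
import numpy as np


def mse(target: np.ndarray, pred: np.ndarray) -> float:
    """Mean squared error between two images of the same shape."""
    diff = target.astype(np.float64) - pred.astype(np.float64)
    return float(np.mean(diff**2))


def psnr(target: np.ndarray, pred: np.ndarray, data_range: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    err = mse(target, pred)
    if err == 0.0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range**2 / err))


target = np.zeros((4, 4))
pred = np.full((4, 4), 0.1)
print(mse(target, pred))  # approximately 0.01
print(psnr(target, pred))  # approximately 20.0
```

Higher PSNR means a smaller error relative to the peak value; each tenfold reduction in MSE adds 10 dB.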

def evaluate(cfg: DictConfig):
    if cfg.EVAL.save_outs:
        from matplotlib import image as Img

        os.makedirs(osp.join(cfg.output_dir, "eval_outs"), exist_ok=True)

    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")

    gen_funcs = func_module.GenFuncs(cfg.WEIGHT_GEN, None)

    # load dataset
    dataset_valid = hdf5storage.loadmat(cfg.DATASET_PATH_VALID)

    # define Generator model
    model_gen = ppsci.arch.Generator(**cfg.MODEL.gen_net)
    model_gen.register_input_transform(gen_funcs.transform_in)

    # define model_list
    model_list = ppsci.arch.ModelList((model_gen,))

    # load pretrained model
    save_load.load_pretrain(model_list, cfg.EVAL.pretrained_model_path)

    # set validator
    eval_dataloader_cfg = {
        "dataset": {
            "name": "NamedArrayDataset",
            "input": {
                "density_low": dataset_valid["density_low"],
            },
            "label": {"density_high": dataset_valid["density_high"]},
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": 1,
    }
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.MSELoss("mean"),
        {"density_high": lambda out: out["output_gen"]},
        metric={"metric": ppsci.metric.L2Rel()},
        name="sup_validator_gen",
    )

    # customized evaluation
    def scale(data):
        smax = np.max(data)
        smin = np.min(data)
        return (data - smin) / (smax - smin)

    eval_mse_list = []
    eval_psnr_list = []
    eval_ssim_list = []
    for i, (input, label, _) in enumerate(sup_validator.data_loader):
        output_dict = model_list({"density_low": input["density_low"]})
        output_arr = scale(np.squeeze(output_dict["output_gen"].numpy()))
        target_arr = scale(np.squeeze(label["density_high"].numpy()))

        eval_mse, eval_psnr, eval_ssim = func_module.evaluate_img(
            target_arr, output_arr
        )
        eval_mse_list.append(eval_mse)
        eval_psnr_list.append(eval_psnr)
        eval_ssim_list.append(eval_ssim)

        if cfg.EVAL.save_outs:
            Img.imsave(
                osp.join(cfg.output_dir, "eval_outs", f"out_{i}.png"),
                output_arr,
                vmin=0.0,
                vmax=1.0,
                cmap="gray",
            )
    logger.message(
        f"MSE: {np.mean(eval_mse_list)}, PSNR: {np.mean(eval_psnr_list)}, SSIM: {np.mean(eval_ssim_list)}"
    )
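The scale helper above normalizes each image by its own minimum and maximum. The metrics therefore compare the relative density distribution rather than absolute values. A constant image (smax == smin) would also cause a division by zero. A guarded variant could look like the sketch below. Returning all zeros for a constant image is an assumption made here, not the original code's behavior.

```python
import numpy as np


def safe_min_max_scale(data: np.ndarray) -> np.ndarray:
    """Min-max scale an array to [0, 1]; constant arrays map to all zeros."""
    smin, smax = np.min(data), np.max(data)
    if smax == smin:
        # assumption: treat a constant image as all zeros instead of dividing by 0
        return np.zeros_like(data, dtype=np.float64)
    return (data - smin) / (smax - smin)


print(safe_min_max_scale(np.array([2.0, 4.0, 6.0])).tolist())  # [0.0, 0.5, 1.0]
print(safe_min_max_scale(np.full(3, 7.0)).tolist())  # [0.0, 0.0, 0.0]
```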

In addition, the following snippet:

if cfg.EVAL.save_outs:
    Img.imsave(
        osp.join(cfg.output_dir, "eval_outs", f"out_{i}.png"),
        output_arr,
        vmin=0.0,
        vmax=1.0,
        cmap="gray",
    )

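provides the option of saving the model outputs, so that the super-resolution results can be inspected visually. Whether they are saved is controlled by the save_outs option in the EVAL section of the configuration file. It can also be overridden on the command line, e.g. EVAL.save_outs=true: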

# evaluation settings
EVAL:
  pretrained_model_path: null

4. Complete Code

The complete code consists of tempoGAN.py, which contains the PaddleScience training workflow, and functions.py, which contains all custom functions. In addition, the network structure code gan.py has been added to ppsci.arch and is shown below as well. It can serve as a reference if you need to customize the network structure.

tempoGAN.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from os import path as osp

import functions as func_module
import hydra
import numpy as np
import paddle
from omegaconf import DictConfig

import ppsci
from ppsci.utils import checker
from ppsci.utils import logger
from ppsci.utils import save_load

if not checker.dynamic_import_to_globals("hdf5storage"):
    raise ImportError(
        "Could not import hdf5storage python package. "
        "Please install it with `pip install hdf5storage`."
    )
import hdf5storage


def train(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info")

    gen_funcs = func_module.GenFuncs(
        cfg.WEIGHT_GEN, (cfg.WEIGHT_GEN_LAYER if cfg.USE_SPATIALDISC else None)
    )
    disc_funcs = func_module.DiscFuncs(cfg.WEIGHT_DISC)
    data_funcs = func_module.DataFuncs(cfg.TILE_RATIO)

    # load dataset
    logger.message(
        "Attention! Start loading datasets, this will take tens of seconds to several minutes, please wait patiently."
    )
    dataset_train = hdf5storage.loadmat(cfg.DATASET_PATH)
    logger.message("Finish loading training dataset.")
    dataset_valid = hdf5storage.loadmat(cfg.DATASET_PATH_VALID)
    logger.message("Finish loading validation dataset.")

    # define Generator model
    model_gen = ppsci.arch.Generator(**cfg.MODEL.gen_net)
    model_gen.register_input_transform(gen_funcs.transform_in)
    disc_funcs.model_gen = model_gen

    model_tuple = (model_gen,)
    # define Discriminators
    if cfg.USE_SPATIALDISC:
        model_disc = ppsci.arch.Discriminator(**cfg.MODEL.disc_net)
        model_disc.register_input_transform(disc_funcs.transform_in)
        model_tuple += (model_disc,)

    # define temporal Discriminators
    if cfg.USE_TEMPODISC:
        model_disc_tempo = ppsci.arch.Discriminator(**cfg.MODEL.tempo_net)
        model_disc_tempo.register_input_transform(disc_funcs.transform_in_tempo)
        model_tuple += (model_disc_tempo,)

    # define model_list
    model_list = ppsci.arch.ModelList(model_tuple)

    # initialize Adam optimizer
    lr_scheduler_gen = ppsci.optimizer.lr_scheduler.Step(
        step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
    )()
    optimizer_gen = ppsci.optimizer.Adam(lr_scheduler_gen)(model_gen)
    if cfg.USE_SPATIALDISC:
        lr_scheduler_disc = ppsci.optimizer.lr_scheduler.Step(
            step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
        )()
        optimizer_disc = ppsci.optimizer.Adam(lr_scheduler_disc)(model_disc)
    if cfg.USE_TEMPODISC:
        lr_scheduler_disc_tempo = ppsci.optimizer.lr_scheduler.Step(
            step_size=cfg.TRAIN.epochs // 2, **cfg.TRAIN.lr_scheduler
        )()
        optimizer_disc_tempo = ppsci.optimizer.Adam(lr_scheduler_disc_tempo)(
            (model_disc_tempo,)
        )

    # Generator
    # manually build constraint(s)
    sup_constraint_gen = ppsci.constraint.SupervisedConstraint(
        {
            "dataset": {
                "name": "NamedArrayDataset",
                "input": {
                    "density_low": dataset_train["density_low"],
                    "density_high": dataset_train["density_high"],
                },
                "transforms": (
                    {
                        "FunctionalTransform": {
                            "transform_func": data_funcs.transform,
                        },
                    },
                ),
            },
            "batch_size": cfg.TRAIN.batch_size.sup_constraint,
            "sampler": {
                "name": "BatchSampler",
                "drop_last": False,
                "shuffle": False,
            },
        },
        ppsci.loss.FunctionalLoss(gen_funcs.loss_func_gen),
        {
            "output_gen": lambda out: out["output_gen"],
            "density_high": lambda out: out["density_high"],
        },
        name="sup_constraint_gen",
    )
    constraint_gen = {sup_constraint_gen.name: sup_constraint_gen}
    if cfg.USE_TEMPODISC:
        sup_constraint_gen_tempo = ppsci.constraint.SupervisedConstraint(
            {
                "dataset": {
                    "name": "NamedArrayDataset",
                    "input": {
                        "density_low": dataset_train["density_low_tempo"],
                        "density_high": dataset_train["density_high_tempo"],
                    },
                    "transforms": (
                        {
                            "FunctionalTransform": {
                                "transform_func": data_funcs.transform,
                            },
                        },
                    ),
                },
                "batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
                "sampler": {
                    "name": "BatchSampler",
                    "drop_last": False,
                    "shuffle": False,
                },
            },
            ppsci.loss.FunctionalLoss(gen_funcs.loss_func_gen_tempo),
            {
                "output_gen": lambda out: out["output_gen"],
                "density_high": lambda out: out["density_high"],
            },
            name="sup_constraint_gen_tempo",
        )
        constraint_gen[sup_constraint_gen_tempo.name] = sup_constraint_gen_tempo

    # Discriminators
    # manually build constraint(s)
    if cfg.USE_SPATIALDISC:
        sup_constraint_disc = ppsci.constraint.SupervisedConstraint(
            {
                "dataset": {
                    "name": "NamedArrayDataset",
                    "input": {
                        "density_low": dataset_train["density_low"],
                        "density_high": dataset_train["density_high"],
                    },
                    "label": {
                        "out_disc_from_target": np.ones(
                            (np.shape(dataset_train["density_high"])[0], 1),
                            dtype=paddle.get_default_dtype(),
                        ),
                        "out_disc_from_gen": np.ones(
                            (np.shape(dataset_train["density_high"])[0], 1),
                            dtype=paddle.get_default_dtype(),
                        ),
                    },
                    "transforms": (
                        {
                            "FunctionalTransform": {
                                "transform_func": data_funcs.transform,
                            },
                        },
                    ),
                },
                "batch_size": cfg.TRAIN.batch_size.sup_constraint,
                "sampler": {
                    "name": "BatchSampler",
                    "drop_last": False,
                    "shuffle": False,
                },
            },
            ppsci.loss.FunctionalLoss(disc_funcs.loss_func),
            name="sup_constraint_disc",
        )
        constraint_disc = {sup_constraint_disc.name: sup_constraint_disc}

    # temporal Discriminators
    # manually build constraint(s)
    if cfg.USE_TEMPODISC:
        sup_constraint_disc_tempo = ppsci.constraint.SupervisedConstraint(
            {
                "dataset": {
                    "name": "NamedArrayDataset",
                    "input": {
                        "density_low": dataset_train["density_low_tempo"],
                        "density_high": dataset_train["density_high_tempo"],
                    },
                    "label": {
                        "out_disc_tempo_from_target": np.ones(
                            (np.shape(dataset_train["density_high_tempo"])[0], 1),
                            dtype=paddle.get_default_dtype(),
                        ),
                        "out_disc_tempo_from_gen": np.ones(
                            (np.shape(dataset_train["density_high_tempo"])[0], 1),
                            dtype=paddle.get_default_dtype(),
                        ),
                    },
                    "transforms": (
                        {
                            "FunctionalTransform": {
                                "transform_func": data_funcs.transform,
                            },
                        },
                    ),
                },
                "batch_size": int(cfg.TRAIN.batch_size.sup_constraint // 3),
                "sampler": {
                    "name": "BatchSampler",
                    "drop_last": False,
                    "shuffle": False,
                },
            },
            ppsci.loss.FunctionalLoss(disc_funcs.loss_func_tempo),
            name="sup_constraint_disc_tempo",
        )
        constraint_disc_tempo = {
            sup_constraint_disc_tempo.name: sup_constraint_disc_tempo
        }

    # initialize solver
    solver_gen = ppsci.solver.Solver(
        model_list,
        constraint_gen,
        cfg.output_dir,
        optimizer_gen,
        lr_scheduler_gen,
        cfg.TRAIN.epochs_gen,
        cfg.TRAIN.iters_per_epoch,
        eval_during_train=cfg.TRAIN.eval_during_train,
        use_amp=cfg.USE_AMP,
        amp_level=cfg.TRAIN.amp_level,
    )
    if cfg.USE_SPATIALDISC:
        solver_disc = ppsci.solver.Solver(
            model_list,
            constraint_disc,
            cfg.output_dir,
            optimizer_disc,
            lr_scheduler_disc,
            cfg.TRAIN.epochs_disc,
            cfg.TRAIN.iters_per_epoch,
            eval_during_train=cfg.TRAIN.eval_during_train,
            use_amp=cfg.USE_AMP,
            amp_level=cfg.TRAIN.amp_level,
        )
    if cfg.USE_TEMPODISC:
        solver_disc_tempo = ppsci.solver.Solver(
            model_list,
            constraint_disc_tempo,
            cfg.output_dir,
            optimizer_disc_tempo,
            lr_scheduler_disc_tempo,
            cfg.TRAIN.epochs_disc_tempo,
            cfg.TRAIN.iters_per_epoch,
            eval_during_train=cfg.TRAIN.eval_during_train,
            use_amp=cfg.USE_AMP,
            amp_level=cfg.TRAIN.amp_level,
        )

    PRED_INTERVAL = 200
    for i in range(1, cfg.TRAIN.epochs + 1):
        logger.message(f"\nEpoch: {i}\n")
        # plotting during training
        if i == 1 or i % PRED_INTERVAL == 0 or i == cfg.TRAIN.epochs:
            func_module.predict_and_save_plot(
                cfg.output_dir, i, solver_gen, dataset_valid, cfg.TILE_RATIO
            )

        disc_funcs.model_gen = model_gen
        # train disc, input: (x,y,G(x))
        if cfg.USE_SPATIALDISC:
            solver_disc.train()

        # train disc tempo, input: (y_3,G(x)_3)
        if cfg.USE_TEMPODISC:
            solver_disc_tempo.train()

        # train gen, input: (x,)
        solver_gen.train()

    ############### evaluation for training ###############
    img_target = (
        func_module.get_image_array(
            os.path.join(cfg.output_dir, "predict", "target.png")
        )
        / 255.0
    )
    img_pred = (
        func_module.get_image_array(
            os.path.join(
                cfg.output_dir, "predict", f"pred_epoch_{cfg.TRAIN.epochs}.png"
            )
        )
        / 255.0
    )
    eval_mse, eval_psnr, eval_ssim = func_module.evaluate_img(img_target, img_pred)
    logger.message(f"MSE: {eval_mse}, PSNR: {eval_psnr}, SSIM: {eval_ssim}")


def evaluate(cfg: DictConfig):
    if cfg.EVAL.save_outs:
        from matplotlib import image as Img

        os.makedirs(osp.join(cfg.output_dir, "eval_outs"), exist_ok=True)

    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")

    gen_funcs = func_module.GenFuncs(cfg.WEIGHT_GEN, None)

    # load dataset
    dataset_valid = hdf5storage.loadmat(cfg.DATASET_PATH_VALID)

    # define Generator model
    model_gen = ppsci.arch.Generator(**cfg.MODEL.gen_net)
    model_gen.register_input_transform(gen_funcs.transform_in)

    # define model_list
    model_list = ppsci.arch.ModelList((model_gen,))

    # load pretrained model
    save_load.load_pretrain(model_list, cfg.EVAL.pretrained_model_path)

    # set validator
    eval_dataloader_cfg = {
        "dataset": {
            "name": "NamedArrayDataset",
            "input": {
                "density_low": dataset_valid["density_low"],
            },
            "label": {"density_high": dataset_valid["density_high"]},
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": 1,
    }
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.MSELoss("mean"),
        {"density_high": lambda out: out["output_gen"]},
        metric={"metric": ppsci.metric.L2Rel()},
        name="sup_validator_gen",
    )

    # customized evaluation
    def scale(data):
        smax = np.max(data)
        smin = np.min(data)
        return (data - smin) / (smax - smin)

    eval_mse_list = []
    eval_psnr_list = []
    eval_ssim_list = []
    for i, (input, label, _) in enumerate(sup_validator.data_loader):
        output_dict = model_list({"density_low": input["density_low"]})
        output_arr = scale(np.squeeze(output_dict["output_gen"].numpy()))
        target_arr = scale(np.squeeze(label["density_high"].numpy()))

        eval_mse, eval_psnr, eval_ssim = func_module.evaluate_img(
            target_arr, output_arr
        )
        eval_mse_list.append(eval_mse)
        eval_psnr_list.append(eval_psnr)
        eval_ssim_list.append(eval_ssim)

        if cfg.EVAL.save_outs:
            Img.imsave(
                osp.join(cfg.output_dir, "eval_outs", f"out_{i}.png"),
                output_arr,
                vmin=0.0,
                vmax=1.0,
                cmap="gray",
            )
    logger.message(
        f"MSE: {np.mean(eval_mse_list)}, PSNR: {np.mean(eval_psnr_list)}, SSIM: {np.mean(eval_ssim_list)}"
    )


def export(cfg: DictConfig):
    from paddle.static import InputSpec

    # set models
    gen_funcs = func_module.GenFuncs(cfg.WEIGHT_GEN, None)
    model_gen = ppsci.arch.Generator(**cfg.MODEL.gen_net)
    model_gen.register_input_transform(gen_funcs.transform_in)

    # define model_list
    model_list = ppsci.arch.ModelList((model_gen,))

    # load pretrained model
    solver = ppsci.solver.Solver(
        model=model_list, pretrained_model_path=cfg.INFER.pretrained_model_path
    )

    # export models
    input_spec = [
        {"density_low": InputSpec([None, 1, 128, 128], "float32", name="density_low")},
    ]
    solver.export(input_spec, cfg.INFER.export_path, skip_prune_program=True)


def inference(cfg: DictConfig):
    from matplotlib import image as Img

    from deploy.python_infer import pinn_predictor

    # set model predictor
    predictor = pinn_predictor.PINNPredictor(cfg)

    # load dataset
    dataset_infer = {
        "density_low": hdf5storage.loadmat(cfg.DATASET_PATH_VALID)["density_low"]
    }

    output_dict = predictor.predict(dataset_infer, cfg.INFER.batch_size)

    # mapping data to cfg.INFER.output_keys
    output = [output_dict[key] for key in output_dict]

    def scale(data):
        smax = np.max(data)
        smin = np.min(data)
        return (data - smin) / (smax - smin)

    for i, img in enumerate(output[0]):
        img = scale(np.squeeze(img))
        Img.imsave(
            osp.join(cfg.output_dir, f"out_{i}.png"),
            img,
            vmin=0.0,
            vmax=1.0,
            cmap="gray",
        )


@hydra.main(version_base=None, config_path="./conf", config_name="tempogan.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should be in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


if __name__ == "__main__":
    main()
functions.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Dict
from typing import List
from typing import Tuple

import numpy as np
import paddle
import paddle.nn.functional as F
from matplotlib import image as Img
from PIL import Image
from skimage.metrics import mean_squared_error
from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import structural_similarity

import ppsci
from ppsci.utils import logger


# train
def interpolate(
    data: paddle.Tensor, ratio: int, mode: str = "nearest"
) -> paddle.Tensor:
    """Interpolate twice.

    Args:
        data (paddle.Tensor): The data to be interpolated.
        ratio (int): Ratio of one interpolation.
        mode (str, optional): Interpolation method. Defaults to "nearest".

    Returns:
        paddle.Tensor: Data interpolated.
    """
    for _ in range(2):
        data = F.interpolate(
            data,
            [data.shape[-2] * ratio, data.shape[-1] * ratio],
            mode=mode,
        )
    return data


def reshape_input(input_dict: Dict[str, paddle.Tensor]) -> Dict[str, paddle.Tensor]:
    """Reshape input data for the temporal Discriminator. Reshape data from N, C, H, W to N * C, 1, H, W,
        which merges the N and C dimensions into one while still keeping 4 dimensions,
        so that the data can be used for training.

    Args:
        input_dict (Dict[str, paddle.Tensor]): input data dict.

    Returns:
        Dict[str, paddle.Tensor]: reshaped data dict.
    """
    out_dict = {}
    for key in input_dict:
        input = input_dict[key]
        N, C, H, W = input.shape
        out_dict[key] = paddle.reshape(input, [N * C, 1, H, W])
    return out_dict


def dereshape_input(
    input_dict: Dict[str, paddle.Tensor], C: int
) -> Dict[str, paddle.Tensor]:
    """Dereshape input data for the temporal Discriminator. Dereshape data from 1, N * C, H, W to N, C, H, W.

    Args:
        input_dict (Dict[str, paddle.Tensor]): input data dict.
        C (int): Channel of dereshape.

    Returns:
        Dict[str, paddle.Tensor]: dereshaped data dict.
    """
    for key in input_dict:
        input = input_dict[key]
        _, N, H, W = input.shape
        if N < C:
            logger.warning(
                f"batch_size is smaller than {C}! Tempo needs at least {C} frames, input will be copied."
            )
            input_dict[key] = paddle.concat([input[:1]] * C, axis=1)
        else:
            N_new = int(N // C)
            input_dict[key] = paddle.reshape(input[: N_new * C], [-1, C, H, W])
    return input_dict


# predict
def split_data(data: np.ndarray, tile_ratio: int) -> np.ndarray:
    """Split a numpy image to tiles equally.

    Args:
        data (np.ndarray): The image to be split.
        tile_ratio (int): How many tiles of one dim.
            Number of result tiles is tile_ratio * tile_ratio for a 2d image.

    Returns:
        np.ndarray: Tiles in [N,C,H,W] shape.
    """
    _, _, h, w = data.shape
    tile_h, tile_w = h // tile_ratio, w // tile_ratio
    tiles = []
    for i in range(tile_ratio):
        for j in range(tile_ratio):
            tiles.append(
                data[
                    :1,
                    :,
                    i * tile_h : i * tile_h + tile_h,
                    j * tile_w : j * tile_w + tile_w,
                ],
            )
    return np.concatenate(tiles, axis=0)


def concat_data(data: np.ndarray, tile_ratio: int) -> np.ndarray:
    """Concatenate equally sized numpy tiles into one image.

    Args:
        data (np.ndarray): The tiles to be merged.
        tile_ratio (int): How many tiles of one dim.
            Number of input tiles is tile_ratio * tile_ratio for 2d result.

    Returns:
        np.ndarray: Image in [H,W] shape.
    """
    _, _, tile_h, tile_w = data.shape
    h, w = tile_h * tile_ratio, tile_w * tile_ratio
    data_whole = np.ones([h, w], dtype=paddle.get_default_dtype())
    tile_idx = 0
    for i in range(tile_ratio):
        for j in range(tile_ratio):
            data_whole[
                i * tile_h : i * tile_h + tile_h,
                j * tile_w : j * tile_w + tile_w,
            ] = data[tile_idx][0]
            tile_idx += 1
    return data_whole


def predict_and_save_plot(
    output_dir: str,
    epoch_id: int,
    solver_gen: ppsci.solver.Solver,
    dataset_valid: np.ndarray,
    tile_ratio: int = 1,
):
    """Predicting and plotting.

    Args:
        output_dir (str): Output dir path.
        epoch_id (int): Which epoch it is.
        solver_gen (ppsci.solver.Solver): Solver for predicting.
        dataset_valid (np.ndarray): Valid dataset.
        tile_ratio (int, optional): How many tiles of one dim. Defaults to 1.
    """
    dir_pred = "predict/"
    os.makedirs(os.path.join(output_dir, dir_pred), exist_ok=True)

    start_idx = 190
    density_low = dataset_valid["density_low"][start_idx : start_idx + 3]
    density_high = dataset_valid["density_high"][start_idx : start_idx + 3]

    # tile
    density_low = (
        split_data(density_low, tile_ratio) if tile_ratio != 1 else density_low
    )
    density_high = (
        split_data(density_high, tile_ratio) if tile_ratio != 1 else density_high
    )

    pred_dict = solver_gen.predict(
        {
            "density_low": density_low,
            "density_high": density_high,
        },
        {"density_high": lambda out: out["output_gen"]},
        batch_size=tile_ratio * tile_ratio if tile_ratio != 1 else 3,
        no_grad=False,
    )
    if epoch_id == 1:
        # plot interpolated input image
        input_img = np.expand_dims(dataset_valid["density_low"][start_idx], axis=0)
        input_img = paddle.to_tensor(input_img, dtype=paddle.get_default_dtype())
        input_img = F.interpolate(
            input_img,
            [input_img.shape[-2] * 4, input_img.shape[-1] * 4],
            mode="nearest",
        ).numpy()
        Img.imsave(
            os.path.join(output_dir, dir_pred, "input.png"),
            np.squeeze(input_img),
            vmin=0.0,
            vmax=1.0,
            cmap="gray",
        )
        # plot target image
        Img.imsave(
            os.path.join(output_dir, dir_pred, "target.png"),
            np.squeeze(dataset_valid["density_high"][start_idx]),
            vmin=0.0,
            vmax=1.0,
            cmap="gray",
        )
    # plot pred image
    pred_img = (
        concat_data(pred_dict["density_high"].numpy(), tile_ratio)
        if tile_ratio != 1
        else np.squeeze(pred_dict["density_high"][0].numpy())
    )
    Img.imsave(
        os.path.join(output_dir, dir_pred, f"pred_epoch_{str(epoch_id)}.png"),
        pred_img,
        vmin=0.0,
        vmax=1.0,
        cmap="gray",
    )


# evaluation
def evaluate_img(
    img_target: np.ndarray, img_pred: np.ndarray
) -> Tuple[float, float, float]:
    """Evaluate two images.

    Args:
        img_target (np.ndarray): Target image.
        img_pred (np.ndarray): Image generated by prediction.

    Returns:
        Tuple[float, float, float]: MSE, PSNR, SSIM.
    """
    eval_mse = mean_squared_error(img_target, img_pred)
    eval_psnr = peak_signal_noise_ratio(img_target, img_pred)
    eval_ssim = structural_similarity(img_target, img_pred, data_range=1.0)
    return eval_mse, eval_psnr, eval_ssim


def get_image_array(img_path):
    return np.array(Image.open(img_path).convert("L"))


class GenFuncs:
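    """Functions used by the Generator, including input transforms and losses.

    Args:
        weight_gen (List[float]): Loss weights. weight_gen[0], weight_gen[1] and weight_gen[2]
            scale the L1 loss, the L2 loss and the temporal adversarial loss respectively.
        weight_gen_layer (List[float], optional): Weights of the layer (feature) loss. weight_gen_layer[0]
            scales the total layer loss and the remaining entries weight the individual layers.
            If None, the spatial adversarial loss and the layer loss are skipped. Defaults to None.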
    """

    def __init__(
        self, weight_gen: List[float], weight_gen_layer: List[float] = None
    ) -> None:
        self.weight_gen = weight_gen
        self.weight_gen_layer = weight_gen_layer

    def transform_in(self, _in):
        ratio = 2
        input_dict = reshape_input(_in)
        density_low = input_dict["density_low"]
        density_low_inp = interpolate(density_low, ratio, "nearest")
        return {"input_gen": density_low_inp}

    def loss_func_gen(self, output_dict: Dict, *args) -> Dict[str, paddle.Tensor]:
        """Calculate the generator loss when the spatial discriminator is used.
            The loss consists of the L1 loss, the L2 loss, the adversarial loss and the layer loss.
            The L1 and L2 terms can be disabled by setting their weights to 0; the adversarial and
            layer terms are only added when weight_gen_layer is not None.

        Args:
            output_dict (Dict): Output dict of the model.

        Returns:
            Dict[str, paddle.Tensor]: Generator loss under the key "output_gen".
        """
        # l1 loss
        loss_l1 = F.l1_loss(
            output_dict["output_gen"], output_dict["density_high"], "mean"
        )
        losses = loss_l1 * self.weight_gen[0]

        # l2 loss
        loss_l2 = F.mse_loss(
            output_dict["output_gen"], output_dict["density_high"], "mean"
        )
        losses += loss_l2 * self.weight_gen[1]

        if self.weight_gen_layer is not None:
            # disc(generator_out) loss
            out_disc_from_gen = output_dict["out_disc_from_gen"][-1]
            label_ones = paddle.ones_like(out_disc_from_gen)
            loss_gen = F.binary_cross_entropy_with_logits(
                out_disc_from_gen, label_ones, reduction="mean"
            )
            losses += loss_gen

            # layer loss
            key_list = list(output_dict.keys())
            # ["out0_layer0","out0_layer1","out0_layer2","out0_layer3","out_disc_from_target",
            # "out1_layer0","out1_layer1","out1_layer2","out1_layer3","out_disc_from_gen"]
            loss_layer = 0
            for i in range(1, len(self.weight_gen_layer)):
                # i = 0,1,2,3
                loss_layer += (
                    self.weight_gen_layer[i]
                    * F.mse_loss(
                        output_dict[key_list[i]],
                        output_dict[key_list[5 + i]],
                        reduction="sum",
                    )
                    / 2
                )
            losses += loss_layer * self.weight_gen_layer[0]

        return {"output_gen": losses}

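    def loss_func_gen_tempo(self, output_dict: Dict, *args) -> Dict[str, paddle.Tensor]:
        """Calculate the generator loss when the temporal discriminator is used.
            The loss is the binary cross-entropy of the temporal discriminator's output on
            generated frames against "real" labels, scaled by weight_gen[2].

        Args:
            output_dict (Dict): Output dict of the model.

        Returns:
            Dict[str, paddle.Tensor]: Generator loss under the key "out_disc_tempo_from_gen".
        """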
        out_disc_tempo_from_gen = output_dict["out_disc_tempo_from_gen"][-1]
        label_t_ones = paddle.ones_like(out_disc_tempo_from_gen)

        loss_gen_t = F.binary_cross_entropy_with_logits(
            out_disc_tempo_from_gen, label_t_ones, reduction="mean"
        )
        losses = loss_gen_t * self.weight_gen[2]
        return {"out_disc_tempo_from_gen": losses}


class DiscFuncs:
    """Functions used by the spatial and temporal Discriminators, including input transforms and losses.

    Args:
        weight_disc (float): Weight of the discriminator loss on real (target) samples.
    """

    def __init__(self, weight_disc: float) -> None:
        self.weight_disc = weight_disc
        self.model_gen = None

    def transform_in(self, _in):
        ratio = 2
        input_dict = reshape_input(_in)
        density_low = input_dict["density_low"]
        density_high_from_target = input_dict["density_high"]

        density_low_inp = interpolate(density_low, ratio, "nearest")

        density_high_from_gen = self.model_gen(input_dict)["output_gen"]
        density_high_from_gen.stop_gradient = True

        density_input_from_target = paddle.concat(
            [density_low_inp, density_high_from_target], axis=1
        )
        density_input_from_gen = paddle.concat(
            [density_low_inp, density_high_from_gen], axis=1
        )
        return {
            "input_disc_from_target": density_input_from_target,
            "input_disc_from_gen": density_input_from_gen,
        }

    def transform_in_tempo(self, _in):
        density_high_from_target = _in["density_high"]

        input_dict = reshape_input(_in)
        density_high_from_gen = self.model_gen(input_dict)["output_gen"]
        density_high_from_gen.stop_gradient = True

        input_trans = {
            "input_tempo_disc_from_target": density_high_from_target,
            "input_tempo_disc_from_gen": density_high_from_gen,
        }

        return dereshape_input(input_trans, 3)

    def loss_func(self, output_dict, *args):
        out_disc_from_target = output_dict["out_disc_from_target"]
        out_disc_from_gen = output_dict["out_disc_from_gen"]

        label_ones = paddle.ones_like(out_disc_from_target)
        label_zeros = paddle.zeros_like(out_disc_from_gen)

        loss_disc_from_target = F.binary_cross_entropy_with_logits(
            out_disc_from_target, label_ones, reduction="mean"
        )
        loss_disc_from_gen = F.binary_cross_entropy_with_logits(
            out_disc_from_gen, label_zeros, reduction="mean"
        )
        losses = loss_disc_from_target * self.weight_disc + loss_disc_from_gen
        return {"CE_loss": losses}

    def loss_func_tempo(self, output_dict, *args):
        out_disc_tempo_from_target = output_dict["out_disc_tempo_from_target"]
        out_disc_tempo_from_gen = output_dict["out_disc_tempo_from_gen"]

        label_ones = paddle.ones_like(out_disc_tempo_from_target)
        label_zeros = paddle.zeros_like(out_disc_tempo_from_gen)

        loss_disc_tempo_from_target = F.binary_cross_entropy_with_logits(
            out_disc_tempo_from_target, label_ones, reduction="mean"
        )
        loss_disc_tempo_from_gen = F.binary_cross_entropy_with_logits(
            out_disc_tempo_from_gen, label_zeros, reduction="mean"
        )
        losses = (
            loss_disc_tempo_from_target * self.weight_disc + loss_disc_tempo_from_gen
        )
        return {"CE_tempo_loss": losses}


class DataFuncs:
    """Functions used for data transforms.

    Args:
        tile_ratio (int, optional): Number of tiles along each dimension. Defaults to 1.
        density_min (float, optional): Minimum density of a valid tile. Defaults to 0.02.
        max_turn (int, optional): Maximum number of attempts to sample a valid tile from one image. Defaults to 20.
    """

    def __init__(
        self, tile_ratio: int = 1, density_min: float = 0.02, max_turn: int = 20
    ) -> None:
        self.tile_ratio = tile_ratio
        self.density_min = density_min
        self.max_turn = max_turn

    def transform(
        self,
        input_item: Dict[str, np.ndarray],
        label_item: Dict[str, np.ndarray],
        weight_item: Dict[str, np.ndarray],
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, np.ndarray]]:
        if self.tile_ratio == 1:
            return input_item, label_item, weight_item
        for _ in range(self.max_turn):
            rand_ratio = np.random.rand()
            density_low = self.cut_data(input_item["density_low"], rand_ratio)
            density_high = self.cut_data(input_item["density_high"], rand_ratio)
            if self.is_valid_tile(density_low):
                break

        input_item["density_low"] = density_low
        input_item["density_high"] = density_high
        return input_item, label_item, weight_item

    def cut_data(self, data: np.ndarray, rand_ratio: float) -> paddle.Tensor:
        # data: C,H,W
        _, H, W = data.shape
        if H % self.tile_ratio != 0 or W % self.tile_ratio != 0:
            exit(
                f"ERROR: input images cannot be divided into {self.tile_ratio} parts evenly!"
            )
        tile_shape = [H // self.tile_ratio, W // self.tile_ratio]
        rand_shape = np.floor(rand_ratio * (np.array([H, W]) - np.array(tile_shape)))
        start = [int(rand_shape[0]), int(rand_shape[1])]
        end = [int(rand_shape[0] + tile_shape[0]), int(rand_shape[1] + tile_shape[1])]
        data = paddle.slice(
            paddle.to_tensor(data), axes=[-2, -1], starts=start, ends=end
        )

        return data

    def is_valid_tile(self, tile: paddle.Tensor):
        img_density = tile[0].sum()
        return img_density >= (
            self.density_min * tile.shape[0] * tile.shape[1] * tile.shape[2]
        )
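The loops in split_data and concat_data cut an image into a tile_ratio x tile_ratio grid in row-major order (tile index i * tile_ratio + j) and stitch it back. As an illustrative sketch only, the same layout can be expressed with NumPy reshape/transpose instead of Python loops. The names split_tiles and merge_tiles are hypothetical and not part of PaddleScience; like the listing, the sketch keeps only the first sample when splitting and only channel 0 when merging.

```python
import numpy as np


def split_tiles(image: np.ndarray, tile_ratio: int) -> np.ndarray:
    """Split the first sample of an [N, C, H, W] array into tile_ratio**2 tiles.

    Returns an array of shape [tile_ratio**2, C, H / tile_ratio, W / tile_ratio],
    ordered row-major (left to right, then top to bottom).
    """
    _, c, h, w = image.shape
    if h % tile_ratio or w % tile_ratio:
        raise ValueError("H and W must be divisible by tile_ratio")
    th, tw = h // tile_ratio, w // tile_ratio
    # [C, H, W] -> [C, row_block, th, col_block, tw]
    grid = image[0].reshape(c, tile_ratio, th, tile_ratio, tw)
    # -> [row_block, col_block, C, th, tw] -> [row_block * tile_ratio + col_block, C, th, tw]
    return grid.transpose(1, 3, 0, 2, 4).reshape(tile_ratio * tile_ratio, c, th, tw)


def merge_tiles(tiles: np.ndarray, tile_ratio: int) -> np.ndarray:
    """Stitch row-major tiles of shape [tile_ratio**2, C, th, tw] into one [H, W] image (channel 0)."""
    _, _, th, tw = tiles.shape
    grid = tiles[:, 0].reshape(tile_ratio, tile_ratio, th, tw)  # [row_block, col_block, th, tw]
    # -> [row_block, th, col_block, tw] -> [H, W]
    return grid.transpose(0, 2, 1, 3).reshape(tile_ratio * th, tile_ratio * tw)


if __name__ == "__main__":
    image = np.arange(64, dtype=np.float32).reshape(1, 1, 8, 8)
    tiles = split_tiles(image, 2)
    print(tiles.shape)  # (4, 1, 4, 4)
    print(np.array_equal(merge_tiles(tiles, 2), image[0, 0]))  # True
```

Splitting followed by merging recovers channel 0 of the original image exactly, which is the property predict_and_save_plot relies on when it tiles the input and reassembles the prediction.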
gan.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Dict
from typing import List
from typing import Tuple

import paddle
import paddle.nn as nn

from ppsci.arch import activation as act_mod
from ppsci.arch import base


class Conv2DBlock(nn.Layer):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        stride,
        use_bn,
        act,
        mean,
        std,
        value,
    ):
        super().__init__()
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=mean, std=std)
        )
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(value=value))
        self.conv_2d = nn.Conv2D(
            in_channel,
            out_channel,
            kernel_size,
            stride,
            padding="SAME",
            weight_attr=weight_attr,
            bias_attr=bias_attr,
        )
        self.bn = nn.BatchNorm2D(out_channel) if use_bn else None
        self.act = act_mod.get_activation(act) if act else None

    def forward(self, x):
        y = x
        y = self.conv_2d(y)
        if self.bn:
            y = self.bn(y)
        if self.act:
            y = self.act(y)
        return y


class VariantResBlock(nn.Layer):
    def __init__(
        self,
        in_channel,
        out_channels,
        kernel_sizes,
        strides,
        use_bns,
        acts,
        mean,
        std,
        value,
    ):
        super().__init__()
        self.conv_2d_0 = Conv2DBlock(
            in_channel=in_channel,
            out_channel=out_channels[0],
            kernel_size=kernel_sizes[0],
            stride=strides[0],
            use_bn=use_bns[0],
            act=acts[0],
            mean=mean,
            std=std,
            value=value,
        )
        self.conv_2d_1 = Conv2DBlock(
            in_channel=out_channels[0],
            out_channel=out_channels[1],
            kernel_size=kernel_sizes[1],
            stride=strides[1],
            use_bn=use_bns[1],
            act=acts[1],
            mean=mean,
            std=std,
            value=value,
        )

        self.conv_2d_2 = Conv2DBlock(
            in_channel=in_channel,
            out_channel=out_channels[2],
            kernel_size=kernel_sizes[2],
            stride=strides[2],
            use_bn=use_bns[2],
            act=acts[2],
            mean=mean,
            std=std,
            value=value,
        )

        self.act = act_mod.get_activation("relu")

    def forward(self, x):
        y = x
        y = self.conv_2d_0(y)
        y = self.conv_2d_1(y)
        short = self.conv_2d_2(x)
        y = paddle.add(y, short)
        y = self.act(y)
        return y


class FCBlock(nn.Layer):
    def __init__(self, in_channel, act, mean, std, value):
        super().__init__()
        self.flatten = nn.Flatten()
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=mean, std=std)
        )
        bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(value=value))
        self.linear = nn.Linear(
            in_channel,
            1,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
        )
        self.act = act_mod.get_activation(act) if act else None

    def forward(self, x):
        y = x
        y = self.flatten(y)
        y = self.linear(y)
        if self.act:
            y = self.act(y)
        return y


class Generator(base.Arch):
    """Generator Net of GAN. Note that the net uses a variant of ResBlock that is
        specific to the "tempoGAN" example rather than a standard open-source network.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
        in_channel (int): Number of input channels of the first conv layer.
        out_channels_tuple (Tuple[Tuple[int, ...], ...]): Number of output channels of all conv layers, three per
            residual block (conv0 and conv1 on the main path, conv2 on the shortcut path),
            such as [[out_res0_conv0, out_res0_conv1, out_res0_conv2], [out_res1_conv0, out_res1_conv1, out_res1_conv2]]
        kernel_sizes_tuple (Tuple[Tuple[int, ...], ...]): Kernel sizes of all conv layers,
            such as [[kernel_size_res0_conv0, kernel_size_res0_conv1, kernel_size_res0_conv2], [kernel_size_res1_conv0, kernel_size_res1_conv1, kernel_size_res1_conv2]]
        strides_tuple (Tuple[Tuple[int, ...], ...]): Strides of all conv layers,
            such as [[stride_res0_conv0, stride_res0_conv1, stride_res0_conv2], [stride_res1_conv0, stride_res1_conv1, stride_res1_conv2]]
        use_bns_tuple (Tuple[Tuple[bool, ...], ...]): Whether to use a batch_norm layer after each conv layer.
        acts_tuple (Tuple[Tuple[str, ...], ...]): Activation to use after each conv layer (None for no activation),
            such as [[act_res0_conv0, act_res0_conv1, act_res0_conv2], [act_res1_conv0, act_res1_conv1, act_res1_conv2]]

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> in_channel = 1
        >>> rb_channel0 = (2, 8, 8)
        >>> rb_channel1 = (128, 128, 128)
        >>> rb_channel2 = (32, 8, 8)
        >>> rb_channel3 = (2, 1, 1)
        >>> out_channels_tuple = (rb_channel0, rb_channel1, rb_channel2, rb_channel3)
        >>> kernel_sizes_tuple = (((5, 5), ) * 2 + ((1, 1), ), ) * 4
        >>> strides_tuple = ((1, 1, 1), ) * 4
        >>> use_bns_tuple = ((True, True, True), ) * 3 + ((False, False, False), )
        >>> acts_tuple = (("relu", None, None), ) * 4
        >>> model = ppsci.arch.Generator(("in",), ("out",), in_channel, out_channels_tuple, kernel_sizes_tuple, strides_tuple, use_bns_tuple, acts_tuple)
        >>> batch_size = 4
        >>> height = 64
        >>> width = 64
        >>> input_data = paddle.randn([batch_size, in_channel, height, width])
        >>> input_dict = {'in': input_data}
        >>> output_data = model(input_dict)
        >>> print(output_data['out'].shape)
        [4, 1, 64, 64]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        in_channel: int,
        out_channels_tuple: Tuple[Tuple[int, ...], ...],
        kernel_sizes_tuple: Tuple[Tuple[int, ...], ...],
        strides_tuple: Tuple[Tuple[int, ...], ...],
        use_bns_tuple: Tuple[Tuple[bool, ...], ...],
        acts_tuple: Tuple[Tuple[str, ...], ...],
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.in_channel = in_channel
        self.out_channels_tuple = out_channels_tuple
        self.kernel_sizes_tuple = kernel_sizes_tuple
        self.strides_tuple = strides_tuple
        self.use_bns_tuple = use_bns_tuple
        self.acts_tuple = acts_tuple

        self.init_blocks()

    def init_blocks(self):
        blocks_list = []
        for i in range(len(self.out_channels_tuple)):
            in_channel = (
                self.in_channel if i == 0 else self.out_channels_tuple[i - 1][-1]
            )
            blocks_list.append(
                VariantResBlock(
                    in_channel=in_channel,
                    out_channels=self.out_channels_tuple[i],
                    kernel_sizes=self.kernel_sizes_tuple[i],
                    strides=self.strides_tuple[i],
                    use_bns=self.use_bns_tuple[i],
                    acts=self.acts_tuple[i],
                    mean=0.0,
                    std=0.04,
                    value=0.1,
                )
            )
        self.blocks = nn.LayerList(blocks_list)

    def forward_tensor(self, x):
        y = x
        for block in self.blocks:
            y = block(y)
        return y

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        y = self.concat_to_tensor(x, self.input_keys, axis=-1)
        y = self.forward_tensor(y)
        y = self.split_to_dict(y, self.output_keys, axis=-1)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y


class Discriminator(base.Arch):
    """Discriminator Net of GAN.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
        in_channel (int): Number of input channels of the first conv layer.
        out_channels (Tuple[int, ...]): Number of output channels of all conv layers,
            such as (out_conv0, out_conv1, out_conv2).
        fc_channel (int): Number of input features of the final linear layer, i.e. the flattened size of the
            last conv output. The layer has a single output feature.
        kernel_sizes (Tuple[int, ...]): Kernel sizes of all conv layers,
            such as (kernel_size_conv0, kernel_size_conv1, kernel_size_conv2).
        strides (Tuple[int, ...]): Strides of all conv layers,
            such as (stride_conv0, stride_conv1, stride_conv2).
        use_bns (Tuple[bool, ...]): Whether to use a batch_norm layer after each conv layer.
        acts (Tuple[str, ...]): Activation to use after each conv layer (None for no activation), followed by
            the activation of the final linear layer, which is read from acts[4],
            such as (act_conv0, act_conv1, act_conv2, act_conv3, act_fc).

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> in_channel = 2
        >>> in_channel_tempo = 3
        >>> out_channels = (32, 64, 128, 256)
        >>> fc_channel = 65536
        >>> kernel_sizes = ((4, 4), (4, 4), (4, 4), (4, 4))
        >>> strides = (2, 2, 2, 1)
        >>> use_bns = (False, True, True, True)
        >>> acts = ("leaky_relu", "leaky_relu", "leaky_relu", "leaky_relu", None)
        >>> output_keys_disc = ("out_1", "out_2", "out_3", "out_4", "out_5", "out_6", "out_7", "out_8", "out_9", "out_10")
        >>> model = ppsci.arch.Discriminator(("in_1","in_2"), output_keys_disc, in_channel, out_channels, fc_channel, kernel_sizes, strides, use_bns, acts)
        >>> input_data = [paddle.randn([1, in_channel, 128, 128]), paddle.randn([1, in_channel, 128, 128])]
        >>> input_dict = {"in_1": input_data[0],"in_2": input_data[1]}
        >>> out_dict = model(input_dict)
        >>> for k, v in out_dict.items():
        ...     print(k, v.shape)
        out_1 [1, 32, 64, 64]
        out_2 [1, 64, 32, 32]
        out_3 [1, 128, 16, 16]
        out_4 [1, 256, 16, 16]
        out_5 [1, 1]
        out_6 [1, 32, 64, 64]
        out_7 [1, 64, 32, 32]
        out_8 [1, 128, 16, 16]
        out_9 [1, 256, 16, 16]
        out_10 [1, 1]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        in_channel: int,
        out_channels: Tuple[int, ...],
        fc_channel: int,
        kernel_sizes: Tuple[int, ...],
        strides: Tuple[int, ...],
        use_bns: Tuple[bool, ...],
        acts: Tuple[str, ...],
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.in_channel = in_channel
        self.out_channels = out_channels
        self.fc_channel = fc_channel
        self.kernel_sizes = kernel_sizes
        self.strides = strides
        self.use_bns = use_bns
        self.acts = acts

        self.init_layers()

    def init_layers(self):
        layers_list = []
        for i in range(len(self.out_channels)):
            in_channel = self.in_channel if i == 0 else self.out_channels[i - 1]
            layers_list.append(
                Conv2DBlock(
                    in_channel=in_channel,
                    out_channel=self.out_channels[i],
                    kernel_size=self.kernel_sizes[i],
                    stride=self.strides[i],
                    use_bn=self.use_bns[i],
                    act=self.acts[i],
                    mean=0.0,
                    std=0.04,
                    value=0.1,
                )
            )

        layers_list.append(
            FCBlock(self.fc_channel, self.acts[4], mean=0.0, std=0.04, value=0.1)
        )
        self.layers = nn.LayerList(layers_list)

    def forward_tensor(self, x):
        y = x
        y_list = []
        for layer in self.layers:
            y = layer(y)
            y_list.append(y)
        return y_list  # y_conv1, y_conv2, y_conv3, y_conv4, y_fc(y_out)

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        y_list = []
        # y1_conv1, y1_conv2, y1_conv3, y1_conv4, y1_fc, y2_conv1, y2_conv2, y2_conv3, y2_conv4, y2_fc
        for k in x:
            y_list.extend(self.forward_tensor(x[k]))

        y = self.split_to_dict(y_list, self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)

        return y

    @staticmethod
    def split_to_dict(
        data_list: List[paddle.Tensor], keys: Tuple[str, ...]
    ) -> Dict[str, paddle.Tensor]:
        """Override of the split_to_dict() method of base.Arch.

        It is overridden because concat_to_tensor() is not called in the "tempoGAN" example:
        the input there is not in the regular format but looks like:
        {
            "input1": paddle.concat([in1, in2], axis=1),
            "input2": paddle.concat([in1, in3], axis=1),
        }

        Args:
            data_list (List[paddle.Tensor]): The data to be split. It should be a list of tensors rather than a single paddle.Tensor.
            keys (Tuple[str, ...]): Keys of outputs.

        Returns:
            Dict[str, paddle.Tensor]: Dict with split data.
        """
        if len(keys) == 1:
            return {keys[0]: data_list[0]}
        return {key: data_list[i] for i, key in enumerate(keys)}
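The fc_channel argument of Discriminator must equal the number of features left after flattening the last conv output, so it depends on the input size and the strides. The helper below is a hypothetical sketch (not part of ppsci) that computes it, assuming padding="SAME" as in Conv2DBlock, where each conv maps a spatial size n to ceil(n / stride) independently of the kernel size. With the 128 x 128 input of the docstring example it reproduces fc_channel = 65536; for a different input or tile size the value has to be recomputed.

```python
import math
from typing import Sequence


def disc_fc_channel(
    height: int, width: int, out_channels: Sequence[int], strides: Sequence[int]
) -> int:
    """Flattened feature count entering the Discriminator's final linear layer.

    Assumes every conv uses padding="SAME", so a spatial size n becomes
    ceil(n / stride) regardless of the kernel size.
    """
    if len(out_channels) != len(strides):
        raise ValueError("out_channels and strides must have the same length")
    for stride in strides:
        height = math.ceil(height / stride)
        width = math.ceil(width / stride)
    # Channels of the last conv times its spatial size.
    return out_channels[-1] * height * width


if __name__ == "__main__":
    # Values from the Discriminator docstring example: 128 x 128 input.
    print(disc_fc_channel(128, 128, (32, 64, 128, 256), (2, 2, 2, 1)))  # 65536
```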

5. Result Display

After mixed-precision training, the model was evaluated on the test set using MSE, PSNR, and SSIM between the predictions and the targets. The results are:

MSE      PSNR (dB)  SSIM
4.21e-5  47.19      0.9974
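For reference, MSE and PSNR are related by PSNR = 10 log10(R^2 / MSE), where R is the data range (1.0 for density fields scaled to [0, 1]). evaluate_img uses the scikit-image implementations; the NumPy sketch below only writes out the standard definitions, assuming R = 1.0, and omits SSIM. If the reported values are averages of per-image scores (an assumption, since the aggregation is not shown here), the table's PSNR need not equal 10 log10(1 / 4.21e-5), about 43.8 dB, because averaging and the logarithm do not commute.

```python
import numpy as np


def mse(target: np.ndarray, pred: np.ndarray) -> float:
    """Mean squared error over all pixels."""
    diff = np.asarray(target, dtype=np.float64) - np.asarray(pred, dtype=np.float64)
    return float(np.mean(diff**2))


def psnr(target: np.ndarray, pred: np.ndarray, data_range: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    err = mse(target, pred)
    if err == 0.0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range**2 / err))


if __name__ == "__main__":
    target = np.zeros((4, 4))
    pred = np.full((4, 4), 0.1)  # every pixel off by 0.1
    print(mse(target, pred))   # approximately 0.01
    print(psnr(target, pred))  # approximately 20.0
```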

The figures below show, for one fluid super-resolution example, the input, the model prediction, and the target generated directly by the open-source package mantaflow (see Dataset Introduction). The prediction closely matches the target.

Input: low-resolution fluid density field

Prediction: high-resolution fluid density field inferred after mixed-precision training

Target: high-resolution fluid density field
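Because the low-resolution input is smaller than the prediction and the target, it has to be enlarged for a side-by-side comparison. predict_and_save_plot does this with F.interpolate at 4x scale and mode="nearest". For an integer factor and the default align_corners=False (an assumption about the settings), nearest-neighbor upsampling amounts to copying each pixel into a scale x scale block, as in this NumPy sketch. upsample_nearest is a hypothetical helper, not a PaddleScience function.

```python
import numpy as np


def upsample_nearest(image: np.ndarray, scale: int) -> np.ndarray:
    """Enlarge a 2D array by an integer factor, copying each pixel into a scale x scale block."""
    if scale < 1:
        raise ValueError("scale must be a positive integer")
    # Repeat rows, then columns.
    return np.repeat(np.repeat(image, scale, axis=0), scale, axis=1)


if __name__ == "__main__":
    low = np.array([[0.0, 1.0], [0.5, 0.25]])
    high = upsample_nearest(low, 4)
    print(high.shape)  # (8, 8)
```

Nearest-neighbor enlargement adds no detail, so the blocky input image shows exactly what information the generator starts from.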

6. References