
WGANGP

Note

  1. Before running, download CIFAR-10 and set data_path in wgangp_cifar10.yaml accordingly.
  2. Before running, download MNIST and set data_path in wgangp_mnist.yaml accordingly.
Model training:

# CIFAR10 Experiment
python wgangp_cifar10.py
# MNIST Experiment
python wgangp_mnist.py
# Toy Dataset Experiment
python wgangp_toy.py

Model evaluation (loading discriminator weights is optional: download the file listed in each comment and pass its local path as EVAL.pretrained_dis_model_path):

# CIFAR10 Experiment
# discriminator weights: https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_discriminator_cifar10.pdparams
python wgangp_cifar10.py mode=eval EVAL.pretrained_gen_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_generator_cifar10.pdparams
# MNIST Experiment
# discriminator weights: https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_discriminator_mnist.pdparams
python wgangp_mnist.py mode=eval EVAL.pretrained_gen_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_generator_mnist.pdparams
# Toy Dataset Experiment
# discriminator weights: https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_discriminator_toy_8gaussians.pdparams
python wgangp_toy.py mode=eval EVAL.pretrained_gen_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_generator_toy_8gaussians.pdparams
| Pretrained model | Metric |
| --- | --- |
| wgangp_cifar10_gen_pretrained.pdparams, wgangp_cifar10_dis_pretrained.pdparams | IS: 5.2 |

1. Background Introduction

In digital image processing and machine learning, Generative Adversarial Networks (GANs) have attracted wide attention for their strong image generation capabilities. However, traditional GANs can be unstable to train, especially when generating high-resolution or complex-scene images. To address this, researchers proposed the Wasserstein GAN with Gradient Penalty (WGAN-GP), which both stabilizes training and improves the quality of the generated images.

WGAN-GP trains the discriminator (called the critic) to estimate the Wasserstein-1 (earth mover's) distance between the real and generated data distributions, and trains the generator to minimize that estimate. The estimate is only valid if the critic is 1-Lipschitz; instead of enforcing this by clipping weights, WGAN-GP adds a gradient penalty that keeps the norm of the critic's input gradient close to 1. This makes training more stable, reduces the mode collapse often seen in traditional GANs, and yields more realistic samples.
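For intuition, the Wasserstein-1 distance that the critic is trained to estimate has a closed form between two 1-D empirical distributions with the same number of samples: sort both sample sets and average the absolute differences. The NumPy sketch below is only an illustration (the function name is hypothetical and it is not part of the PaddleScience implementation). It shows that the distance keeps growing as the distributions move apart, even once their supports no longer overlap.

```python
import numpy as np


def wasserstein_1d(a, b):
    """Wasserstein-1 distance between two 1-D empirical distributions with
    the same number of equally weighted samples.

    In 1-D the optimal transport plan matches sorted samples, so W1 is the
    mean absolute difference between the sorted arrays.
    """
    a = np.sort(np.asarray(a, dtype=float))
    b = np.sort(np.asarray(b, dtype=float))
    if a.shape != b.shape:
        raise ValueError("this sketch assumes equal sample counts")
    return float(np.mean(np.abs(a - b)))


# Shifting a distribution by d changes W1 by exactly d, overlap or not.
print(wasserstein_1d([0.0, 1.0, 2.0], [1.0, 2.0, 3.0]))    # 1.0
print(wasserstein_1d([0.0, 1.0, 2.0], [10.0, 11.0, 12.0]))  # 10.0
```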

2. Model Principle

WGAN-GP proposes an alternative to weight clipping for enforcing the critic's Lipschitz constraint: penalize the norm of the critic's gradient with respect to its input. This stabilizes the training of a wide variety of GAN architectures with almost no hyperparameter tuning.

2.1 Model Structure

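The WGAN-GP model in this case consists of a generator that maps noise to images and a CNN discriminator. In the CIFAR-10 experiment, both networks are additionally conditioned on class labels and the discriminator has an auxiliary classification output (AC-GAN style). The overall structure of the model is shown below.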

    noise===>generator===>fake_image==
                                      ==>discriminator===>Wasserstein Loss+Gradient Penalty
                               image==
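  • The generator is a convolutional neural network.

  • The discriminator (critic) is built from convolutional blocks. It takes an image as input and outputs a realness score; as in all WGANs, this score is an unbounded real number rather than a probability.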

2.2 Loss Function

The discriminator loss combines the Wasserstein loss with a gradient penalty. Its expression is:

\[ L_d = \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}} D(\tilde{x}) - \underset{x \sim \mathbb{P}_r}{\mathbb{E}}D(x) + \lambda \underset{\hat{x} \sim \mathbb{P}_{\hat{x}}}{\mathbb{E}} \left[ \left( \| \nabla_{\hat{x}} D(\hat{x}) \|_2 - 1 \right)^2 \right] \]

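where \(\mathbb{P}_g\) is the generator distribution, \(\mathbb{P}_r\) is the real data distribution, and \(\mathbb{P}_{\hat{x}}\) samples points uniformly along straight lines between pairs of real and generated samples: \(\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}\) with \(x \sim \mathbb{P}_r\), \(\tilde{x} \sim \mathbb{P}_g\), and \(\epsilon \sim U[0, 1]\). \(\lambda\) is the penalty coefficient; the code below uses \(\lambda = 10\) for CIFAR-10 and a configurable coefficient (lamda) for MNIST and the toy datasets.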
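As a worked example of \(L_d\), consider a linear critic \(D(x) = w^\top x\), chosen here (as an assumption for illustration) because its input gradient is \(w\) everywhere. The penalty then reduces to \(\lambda(\|w\|_2 - 1)^2\) no matter where the interpolates land, so a critic with \(\|w\|_2 = 2\) pays \(\lambda\) on top of its Wasserstein term. The NumPy sketch below uses a hypothetical function name and is not the PaddleScience code.

```python
import numpy as np


def linear_critic_loss(w, real, fake, lam=10.0, seed=0):
    """Critic loss L_d for a linear critic D(x) = x @ w.

    Returns (total, wasserstein_term, penalty_term) as Python floats.
    """
    rng = np.random.default_rng(seed)
    # x_hat = eps * x + (1 - eps) * x_tilde with eps ~ U[0, 1], one eps per pair
    eps = rng.uniform(size=(real.shape[0], 1))
    x_hat = eps * real + (1.0 - eps) * fake
    # For a linear critic, grad_{x_hat} D(x_hat) = w at every interpolate
    grads = np.broadcast_to(w, x_hat.shape)
    slopes = np.linalg.norm(grads, axis=1)
    penalty = lam * np.mean((slopes - 1.0) ** 2)
    wasserstein = np.mean(fake @ w) - np.mean(real @ w)
    return float(wasserstein + penalty), float(wasserstein), float(penalty)


w = np.array([2.0, 0.0])                   # ||w||_2 = 2: a 2-Lipschitz critic
real = np.array([[1.0, 0.0], [3.0, 0.0]])  # D(x) = 2 and 6
fake = np.zeros((2, 2))                    # D(x_tilde) = 0
print(linear_critic_loss(w, real, fake))   # (6.0, -4.0, 10.0)
```

For comparison, the unit-norm critic w = [1, 0] on the same data gives a Wasserstein term of -2.0 and no penalty, a lower total loss than 6.0, which is how the penalty pushes the critic toward unit gradient norm.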

The generator loss is the adversarial loss:

\[ L_g = - \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}D(\tilde{x}) \]

where \(\mathbb{P}_g\) is the generator distribution.

3. Model Construction

Next, we explain how to implement WGAN-GP with the PaddleScience framework. The following sections cover only the key steps; for other details, refer to the API documentation.

3.1 Dataset Introduction

This case uses three kinds of data: the CIFAR-10 dataset, the MNIST dataset, and toy datasets (swissroll/8gaussians/25gaussians).

The CIFAR-10 dataset contains 60,000 32x32 color images in 10 classes, with 6,000 images per class (50,000 for training and 10,000 for testing).

The CIFAR-10 dataset is available in three versions:

| Version | Size | md5sum |
| --- | --- | --- |
| CIFAR-10 python | 163 MB | c58f30108f718f92721af3b95e74349a |
| CIFAR-10 Matlab | 175 MB | 70270af85842c9e89bb428ec9976c926 |
| CIFAR-10 binary | 162 MB | c32a1d4ab5d03f1284b67883e8d87530 |

This implementation uses the CIFAR-10 python version.

The MNIST dataset contains 70,000 28x28 grayscale images of handwritten digits in 10 classes (60,000 for training and 10,000 for testing); the classes are roughly, but not exactly, balanced.

Toy datasets:

  • swissroll: a 3-D nonlinear manifold with a continuous rolled-up spiral structure; this case keeps two of its coordinates, giving a 2-D spiral.

  • 8gaussians: a 2-D synthetic dataset of eight Gaussian clusters whose centers are evenly spaced on a circle.

  • 25gaussians: a dense 2-D Gaussian mixture of 25 clusters arranged on a regular 5x5 grid with compact spacing.
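The eight centers listed explicitly in the 8gaussians branch of load_toy_data (Section 3.2) are the unit vectors at multiples of 45 degrees, scaled by 2. The short NumPy sketch below (hypothetical helper name) generates them from angles instead and confirms that they all lie on a circle of radius 2, which becomes roughly 1.414 after the final division by 1.414.

```python
import numpy as np


def eight_gaussian_centers(scale=2.0):
    """Unit vectors at multiples of 45 degrees, scaled to radius `scale`."""
    angles = np.arange(8) * np.pi / 4
    return scale * np.stack([np.cos(angles), np.sin(angles)], axis=1)


centers = eight_gaussian_centers()
radii = np.linalg.norm(centers, axis=1)
print(np.allclose(radii, 2.0))                       # True: one common circle
print(np.allclose(radii / 1.414, 1.414, atol=1e-3))  # True: radius after rescaling
```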

3.2 Build dataset API

Because the CIFAR-10 python training data are split across five batch files, the dataset cannot be loaded directly with PaddleScience's built-in dataset API. Instead, all data are read first and then wrapped with ppsci.data.dataset.array_dataset.NamedArrayDataset.

The code for reading the CIFAR-10 dataset is as follows:

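import numpy as np


def load_cifar10(input_keys, label_keys, data_path):
    # `unpickle` is a helper from the case's code (not shown here) that returns
    # the flattened training images [N, 3072] and their class labels
    images, class_labels = unpickle(data_path)
    images = images.astype("float32")
    # Scale pixel values from [0, 255] to [-1, 1)
    images = ((images / 256.0) - 0.5) * 2
    # Dequantize: add uniform noise of at most one pixel step (2 / 256 = 1 / 128)
    noise = np.random.uniform(size=images.shape, low=0.0, high=1.0 / 128)
    images = (images + noise).astype("float32")
    class_labels = np.array(class_labels, dtype="int32")
    # The class labels are the generator input; the real images act as the
    # "label" (real_data) that the discriminator loss compares against
    inputs = {input_keys[0]: class_labels}
    labels = {label_keys[0]: images}
    return inputs, labels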
Here, data_path is the path to the CIFAR-10 data.
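The unpickle helper called by load_cifar10 is not shown in this document. A minimal sketch of what such a helper could look like follows, assuming that the data path points to the extracted cifar-10-batches-py directory containing the standard data_batch_1 to data_batch_5 files, each a pickled dict with b"data" and b"labels" keys. The function name read_cifar10_batches is hypothetical.

```python
import os
import pickle

import numpy as np


def read_cifar10_batches(data_dir):
    """Hypothetical stand-in for `unpickle`: read the five CIFAR-10 training
    batches and return (images [N, 3072] uint8, labels list[int])."""
    images, labels = [], []
    for i in range(1, 6):
        with open(os.path.join(data_dir, f"data_batch_{i}"), "rb") as f:
            # The batches were pickled under Python 2, so the keys are bytes
            batch = pickle.load(f, encoding="bytes")
        images.append(batch[b"data"])
        labels.extend(batch[b"labels"])
    return np.concatenate(images, axis=0), labels
```

Each row stores the 1,024 red, then green, then blue values of one image, which is why the images can later be reshaped to [N, 3, 32, 32].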

The dataloader is configured as follows:

inputs, labels = load_cifar10(**cfg["DATA"])
dataloader_cfg = {
    "dataset": {
        "name": cfg["EVAL"]["dataset"]["name"],
        "input": inputs,
        "label": labels,
    },
    "sampler": {
        **cfg["TRAIN"]["sampler"],
    },
    "batch_size": cfg["TRAIN"]["batch_size"],
    "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
    "num_workers": cfg["TRAIN"]["num_workers"],
    "drop_last": cfg["TRAIN"]["drop_last"],
}

The MNIST data likewise cannot be loaded directly with PaddleScience's built-in dataset API, so all data are read first and then wrapped with ppsci.data.dataset.array_dataset.NamedArrayDataset.

The code for reading the MNIST dataset is as follows:

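import gzip
import pickle


def load_mnist(data_path, input_keys):
    # The gzipped pickle holds (train, valid, test) splits, each an
    # (images, labels) pair; only the training images are used
    with gzip.open(data_path, "rb") as f:
        train_data, _, _ = pickle.load(f, encoding="latin1")
    data, _ = train_data
    return {input_keys[0]: data}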

The dataloader is configured as follows:

inputs = load_mnist(**cfg["DATA"])
dataloader_cfg = {
    "dataset": {
        "name": cfg["EVAL"]["dataset"]["name"],
        "input": inputs,
    },
    "sampler": {
        **cfg["TRAIN"]["sampler"],
    },
    "batch_size": cfg["TRAIN"]["batch_size"],
    "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
    "num_workers": cfg["TRAIN"]["num_workers"],
    "drop_last": cfg["TRAIN"]["drop_last"],
}

Similarly, the toy datasets are not available through PaddleScience's built-in dataset API, so all data are generated first and then wrapped with ppsci.data.dataset.array_dataset.NamedArrayDataset.

The toy datasets are generated as follows:

import random

import numpy as np
from sklearn.datasets import make_swiss_roll


def load_toy_data(input_keys, mode):
    if mode == "25gaussians":
        # 4,000 points around each of 25 centers on a 5x5 grid with spacing 2
        data = []
        for _ in range(100000 // 25):
            for x in range(-2, 3):
                for y in range(-2, 3):
                    point = np.random.randn(2) * 0.05
                    point[0] += 2 * x
                    point[1] += 2 * y
                    data.append(point)
        data = np.array(data, dtype="float32")
        np.random.shuffle(data)
        data /= 2.828  # stdev
    elif mode == "swissroll":
        # Keep the x and z coordinates of the 3-D swiss roll: a 2-D spiral
        data = make_swiss_roll(n_samples=100000, noise=0.25)[0]
        data = data.astype("float32")[:, [0, 2]]
        data /= 7.5  # stdev plus a little
    elif mode == "8gaussians":
        # Eight centers evenly spaced on a circle of radius `scale`
        scale = 2.0
        centers = [
            (1, 0),
            (-1, 0),
            (0, 1),
            (0, -1),
            (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
            (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
            (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
            (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
        ]
        centers = [(scale * x, scale * y) for x, y in centers]
        data = []
        # 100000 // 8 = 12,500 points, each around a randomly chosen center
        for _ in range(100000 // 8):
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            data.append(point)
        data = np.array(data, dtype="float32")
        data /= 1.414  # stdev
    else:
        raise ValueError(f"Unknown toy dataset mode: {mode}")
    return {input_keys[0]: data}
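The triple loop in the 25gaussians branch builds its 100,000 points one at a time. A vectorized sketch is shown below. It draws from the same distribution (4,000 points around each of the 25 grid centers, noise standard deviation 0.05, same rescaling by 2.828) but uses NumPy's Generator API, so its random stream differs from the np.random.randn calls above. The function name is hypothetical.

```python
import numpy as np


def sample_25gaussians(n_per_center=4000, std=0.05, seed=0):
    """Vectorized variant of the 25gaussians branch of load_toy_data."""
    rng = np.random.default_rng(seed)
    # 5 x 5 grid of centers at -4, -2, 0, 2, 4 along each axis
    grid = 2.0 * np.arange(-2, 3)
    cx, cy = np.meshgrid(grid, grid)
    centers = np.stack([cx.ravel(), cy.ravel()], axis=1)  # shape (25, 2)
    points = np.repeat(centers, n_per_center, axis=0)
    points = points + rng.normal(scale=std, size=points.shape)
    rng.shuffle(points)  # shuffles rows in place
    return (points / 2.828).astype("float32")  # same rescaling as the loop version
```

Building all points in one array avoids 100,000 Python-level appends, which matters if the toy data are regenerated often.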

The dataloader is configured as follows:

inputs = load_toy_data(**cfg["DATA"])
dataloader_cfg = {
    "dataset": {
        "name": cfg["EVAL"]["dataset"]["name"],
        "input": inputs,
    },
    "sampler": {
        **cfg["TRAIN"]["sampler"],
    },
    "batch_size": cfg["TRAIN"]["batch_size"],
    "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
    "num_workers": cfg["TRAIN"]["num_workers"],
    "drop_last": cfg["TRAIN"]["drop_last"],
}

3.3 Model Construction

The WGAN-GP networks used in this case are not built into PaddleScience, so three custom generator/discriminator pairs are implemented: WganGpCifar10Generator and WganGpCifar10Discriminator, WganGpMnistGenerator and WganGpMnistDiscriminator, and WganGpToyGenerator and WganGpToyDiscriminator.

The model construction code is as follows:

WganGpCifar10Generator and WganGpCifar10Discriminator

generator_model = WganGpCifar10Generator(**cfg["MODEL"]["gen_net"])
discriminator_model = WganGpCifar10Discriminator(**cfg["MODEL"]["dis_net"])

WganGpMnistGenerator and WganGpMnistDiscriminator

generator_model = WganGpMnistGenerator(**cfg["MODEL"]["gen_net"])
discriminator_model = WganGpMnistDiscriminator(**cfg["MODEL"]["dis_net"])

WganGpToyGenerator and WganGpToyDiscriminator

generator_model = WganGpToyGenerator(**cfg["MODEL"]["gen_net"])
discriminator_model = WganGpToyDiscriminator(**cfg["MODEL"]["dis_net"])

Parameter configuration is as follows:

WganGpCifar10Generator and WganGpCifar10Discriminator

MODEL:
  gen_net:
    input_keys: [ "labels" ]
    output_keys: [ "fake_data" ]
    dim: 128
    output_dim: 3072
    label_num: 10
    use_label: true
  dis_net:
    input_keys: [ "data", "labels" ]
    output_keys: [ "disc_fake", "disc_acgan" ]
    dim: 128
    label_num: 10
    use_label: true

WganGpMnistGenerator and WganGpMnistDiscriminator

MODEL:
  gen_net:
    output_keys: [ "fake_data" ]
    dim: 64
    output_dim: 784
  dis_net:
    input_keys: [ "data" ]
    output_keys: [ "score" ]
    dim: 64

WganGpToyGenerator and WganGpToyDiscriminator

MODEL:
  gen_net:
    output_keys: [ "fake_data" ]
    dim: 512
  dis_net:
    input_keys: [ "data" ]
    output_keys: [ "score" ]
    dim: 512

3.4 Custom loss

The WGAN-GP losses are not built into PaddleScience, so they are implemented as custom losses. PaddleScience provides ppsci.loss.FunctionalLoss for this purpose: define the loss function first, then pass the function itself (here a bound method such as generator_funcs.loss) to FunctionalLoss. Note that a custom loss function must take dictionaries as input and return a dictionary.

3.4.1 Loss of Generator

The CIFAR-10 generator loss combines the adversarial loss with an auxiliary classification (AC-GAN) loss. Only the classification term is weighted, by acgan_scale_g; setting acgan_scale_g to 0 effectively removes it, and it is skipped entirely when the discriminator returns no classification output.

from typing import Dict

import paddle


class Cifar10GenFuncs:
    """
    Loss function for cifar10 generator

    Args:
        discriminator_model: discriminator model
        acgan_scale_g: scale of acgan loss for generator
    """

    def __init__(
        self,
        discriminator_model,
        acgan_scale_g=0.1,
    ):
        self.crossEntropyLoss = paddle.nn.CrossEntropyLoss()
        self.acgan_scale_g = acgan_scale_g
        self.discriminator_model = discriminator_model

    def loss(self, output_dict: Dict, *args):
        fake_image = output_dict["fake_data"]
        labels = output_dict["labels"]
        outputs = self.discriminator_model({"data": fake_image, "labels": labels})
        disc_fake, disc_fake_acgan = outputs["disc_fake"], outputs["disc_acgan"]
        gen_cost = -paddle.mean(disc_fake)
        if disc_fake_acgan is not None:
            gen_acgan_cost = self.crossEntropyLoss(disc_fake_acgan, labels)
            gen_cost += self.acgan_scale_g * gen_acgan_cost
        return {"loss_g": gen_cost}
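To check the generator objective by hand without Paddle, the sketch below re-implements the arithmetic of Cifar10GenFuncs.loss in NumPy on precomputed critic outputs, returning a dictionary as FunctionalLoss expects. The function names are illustrative and not part of PaddleScience; cross_entropy mirrors the default behavior of paddle.nn.CrossEntropyLoss (softmax over logits, hard integer labels, mean reduction).

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy over the batch (hard integer labels)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-np.mean(log_probs[np.arange(len(labels)), labels]))


def cifar10_gen_loss(disc_fake, disc_fake_acgan, labels, acgan_scale_g=0.1):
    """NumPy mirror of the arithmetic in Cifar10GenFuncs.loss."""
    gen_cost = -float(np.mean(disc_fake))  # adversarial term
    if disc_fake_acgan is not None:  # auxiliary classification term
        gen_cost += acgan_scale_g * cross_entropy(disc_fake_acgan, labels)
    return {"loss_g": gen_cost}


# Uniform logits over 10 classes give a cross-entropy of ln(10) ~= 2.3026, so
# loss_g = -mean([1, 3]) + 0.1 * ln(10) ~= -1.7697
out = cifar10_gen_loss(np.array([1.0, 3.0]), np.zeros((2, 10)), np.array([3, 7]))
print(round(out["loss_g"], 4))  # -1.7697
```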

The MNIST generator loss contains only the adversarial loss.

from typing import Dict

import paddle


class MnistGenFuncs:
    """
    Loss function for mnist generator

    Args:
        discriminator_model: discriminator model
    """

    def __init__(self, discriminator_model):
        self.discriminator_model = discriminator_model

    def loss(self, output_dict: Dict, *args):
        fake_data = output_dict["fake_data"]
        score = self.discriminator_model({"data": fake_data})["score"]
        gen_cost = -paddle.mean(score)
        return {"loss_g": gen_cost}

The toy generator loss contains only the adversarial loss.

from typing import Dict

import paddle


class ToyGenFuncs:
    """
    Loss function for toy generator

    Args:
        discriminator_model: discriminator model
    """

    def __init__(self, discriminator_model):
        self.discriminator_model = discriminator_model

    def loss(self, output_dict: Dict, *args):
        fake_data = output_dict["fake_data"]
        outputs = self.discriminator_model({"data": fake_data})
        disc_fake = outputs["score"]
        gen_cost = -paddle.mean(disc_fake)
        return {"loss_g": gen_cost}

3.4.2 Loss of Discriminator

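The CIFAR-10 discriminator loss combines the Wasserstein loss, the gradient penalty, and an auxiliary classification loss on the real images. Only the classification term has a configurable weight (acgan_scale); the gradient-penalty coefficient is fixed at 10 in the code.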

from typing import Dict

import paddle


class Cifar10DisFuncs:
    """
    Loss function for cifar10 discriminator

    Args:
        discriminator_model: discriminator model
        acgan_scale: scale of acgan loss for discriminator
    """

    def __init__(self, discriminator_model, acgan_scale):
        self.crossEntropyLoss = paddle.nn.CrossEntropyLoss()
        self.acgan_scale = acgan_scale
        self.discriminator_model = discriminator_model

    def loss(self, output_dict: Dict, label_dict: Dict, *args):
        fake_image = output_dict["fake_data"]
        real_image = label_dict["real_data"]
        labels = output_dict["labels"]
        disc_fake = self.discriminator_model({"data": fake_image, "labels": labels})[
            "disc_fake"
        ]
        out = self.discriminator_model({"data": real_image, "labels": labels})
        disc_real, disc_real_acgan = out["disc_fake"], out["disc_acgan"]
        gradient_penalty = self.compute_gradient_penalty(real_image, fake_image, labels)
        disc_cost = paddle.mean(disc_fake) - paddle.mean(disc_real)
        disc_wgan = disc_cost + gradient_penalty
        if disc_real_acgan is not None:
            disc_acgan_cost = self.crossEntropyLoss(disc_real_acgan, labels)
            disc_acgan = disc_acgan_cost.sum()
            disc_cost = disc_wgan + (self.acgan_scale * disc_acgan)
        else:
            disc_cost = disc_wgan
        return {"loss_d": disc_cost}

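    def compute_gradient_penalty(self, real_data, fake_data, labels):
        # Sample random points on straight lines between real and fake samples
        differences = fake_data - real_data
        alpha = paddle.rand([fake_data.shape[0], 1])
        interpolates = real_data + (alpha * differences)
        # Gradient of the critic score w.r.t. its input; create_graph=True keeps
        # it differentiable so that the penalty itself can be back-propagated
        gradients = paddle.grad(
            outputs=self.discriminator_model({"data": interpolates, "labels": labels})[
                "disc_fake"
            ],
            inputs=interpolates,
            create_graph=True,
            retain_graph=False,
        )[0]
        # Per-sample L2 norm of the input gradient (inputs are flattened to [N, D])
        slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=1))
        # Penalty coefficient lambda = 10, the value used in the WGAN-GP paper
        gradient_penalty = 10 * paddle.mean((slopes - 1.0) ** 2)
        return gradient_penalty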

The MNIST discriminator loss consists of the Wasserstein loss and the gradient penalty, scaled by the coefficient lamda.

from typing import Dict

import paddle


class MnistDisFuncs:
    """
    Loss function for mnist discriminator

    Args:
        discriminator_model: discriminator model
        lamda: gradient penalty coefficient
    """

    def __init__(self, discriminator_model, lamda):
        self.discriminator_model = discriminator_model
        self.lamda = lamda

    def loss(self, output_dict: Dict, *args):
        real_data = output_dict["real_data"]
        fake_data = output_dict["fake_data"]
        disc_fake = self.discriminator_model({"data": fake_data})["score"]
        disc_real = self.discriminator_model({"data": real_data})["score"]
        gradient_penalty = self.compute_gradient_penalty(real_data, fake_data)
        disc_cost = paddle.mean(disc_fake) - paddle.mean(disc_real)
        disc_cost = disc_cost + gradient_penalty
        loss = disc_cost
        return {"loss_d": loss}

    def compute_gradient_penalty(self, real_data, fake_data):
        differences = fake_data - real_data
        alpha = paddle.rand([fake_data.shape[0], 1])
        interpolates = real_data + (alpha * differences)
        gradients = paddle.grad(
            outputs=self.discriminator_model({"data": interpolates})["score"],
            inputs=interpolates,
            create_graph=True,
            retain_graph=False,
        )[0]
        slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=1))
        gradient_penalty = self.lamda * paddle.mean((slopes - 1.0) ** 2)
        return gradient_penalty

The toy discriminator loss consists of the Wasserstein loss and the gradient penalty, scaled by the coefficient lamda.

from typing import Dict

import paddle


class ToyDisFuncs:
    """
    Loss function for toy discriminator

    Args:
        discriminator_model: discriminator model
        lamda: gradient penalty coefficient
    """

    def __init__(self, discriminator_model, lamda):
        self.discriminator_model = discriminator_model
        self.lamda = lamda

    def loss(self, output_dict: Dict, *args):
        real_data = output_dict["real_data"]
        fake_data = output_dict["fake_data"]
        disc_fake = self.discriminator_model({"data": fake_data})["score"]
        disc_real = self.discriminator_model({"data": real_data})["score"]
        gradient_penalty = self.compute_gradient_penalty(real_data, fake_data)
        disc_cost = paddle.mean(disc_fake) - paddle.mean(disc_real)
        disc_cost = disc_cost + gradient_penalty
        loss = disc_cost
        return {"loss_d": loss}

    def compute_gradient_penalty(self, real_data, fake_data):
        differences = fake_data - real_data
        alpha = paddle.rand([fake_data.shape[0], 1])
        interpolates = real_data + (alpha * differences)
        gradients = paddle.grad(
            outputs=self.discriminator_model({"data": interpolates})["score"],
            inputs=interpolates,
            create_graph=True,
            retain_graph=False,
        )[0]
        slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=1))
        gradient_penalty = self.lamda * paddle.mean((slopes - 1.0) ** 2)
        return gradient_penalty
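In compute_gradient_penalty, paddle.grad differentiates the critic score with respect to the interpolated inputs. To make that step concrete without Paddle, the NumPy sketch below uses a toy one-hidden-layer critic (an arbitrary assumption, not one of this case's networks) whose input gradient has a closed form, checks that gradient against central finite differences, and then evaluates the same penalty formula with lambda = 10. All function names are hypothetical.

```python
import numpy as np


def critic(x, W, v):
    """Toy critic D(x) = tanh(x @ W) @ v, returning one score per row of x."""
    return np.tanh(x @ W) @ v


def critic_input_grad(x, W, v):
    """Closed-form dD/dx for each row of x: (v * (1 - tanh(x @ W)**2)) @ W.T."""
    h = np.tanh(x @ W)
    return (v * (1.0 - h ** 2)) @ W.T


def gradient_penalty(real, fake, W, v, lam=10.0, seed=0):
    """lam * mean((||grad D(x_hat)||_2 - 1)^2), mirroring compute_gradient_penalty."""
    rng = np.random.default_rng(seed)
    alpha = rng.uniform(size=(real.shape[0], 1))  # one alpha per sample pair
    interpolates = real + alpha * (fake - real)
    grads = critic_input_grad(interpolates, W, v)
    slopes = np.sqrt(np.sum(grads ** 2, axis=1))
    return float(lam * np.mean((slopes - 1.0) ** 2))


# Check the closed-form gradient against central finite differences.
rng = np.random.default_rng(1)
W, v = rng.normal(size=(4, 8)), rng.normal(size=8)
x = rng.normal(size=(3, 4))
step = 1e-6
fd = np.stack(
    [(critic(x + step * e, W, v) - critic(x - step * e, W, v)) / (2 * step) for e in np.eye(4)],
    axis=1,
)
print(np.allclose(fd, critic_input_grad(x, W, v), atol=1e-6))  # True
print(gradient_penalty(rng.normal(size=(5, 4)), rng.normal(size=(5, 4)), W, v))
```

A critic with all-zero weights has zero input gradient everywhere, so its penalty is exactly lambda; this is the penalty pulling a "flat" critic toward unit slope.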

3.5 Constraint Construction

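All cases use ppsci.constraint.SupervisedConstraint to construct constraints. Here, output_expr serves to pass input fields (labels in the CIFAR-10 case, real_data for the discriminator in the MNIST and toy cases) through to the output dictionary that the custom loss functions read.

The construction code is as follows: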

For Cifar10 experiment

constraint_generator = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
    output_expr={"labels": lambda out: out["labels"]},
    name="constraint_generator",
)
constraint_generator_dict = {constraint_generator.name: constraint_generator}

constraint_discriminator = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
    output_expr={"labels": lambda out: out["labels"]},
    name="constraint_discriminator",
)
constraint_discriminator_dict = {
    constraint_discriminator.name: constraint_discriminator
}

For MNIST experiment

constraint_generator = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
    name="constraint_generator",
)
constraint_generator_dict = {constraint_generator.name: constraint_generator}

constraint_discriminator = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
    output_expr={"real_data": lambda out: out["real_data"]},
    name="constraint_discriminator",
)
constraint_discriminator_dict = {
    constraint_discriminator.name: constraint_discriminator
}

For toy dataset experiment

constraint_generator = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
    name="constraint_generator",
)
constraint_generator_dict = {constraint_generator.name: constraint_generator}

constraint_discriminator = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
    output_expr={"real_data": lambda out: out["real_data"]},
    name="constraint_discriminator",
)
constraint_discriminator_dict = {
    constraint_discriminator.name: constraint_discriminator
}

3.6 Optimizer Construction

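WGAN-GP is trained with the Adam optimizer, which can be constructed directly with ppsci.optimizer.Adam. In the CIFAR-10 experiment, each network additionally gets its own Linear learning-rate scheduler. The code is as follows: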

For Cifar10 experiment

lr_scheduler_generator = Linear(**cfg["TRAIN"]["lr_scheduler_gen"])()
lr_scheduler_discriminator = Linear(**cfg["TRAIN"]["lr_scheduler_dis"])()

optimizer_generator = ppsci.optimizer.Adam(
    learning_rate=lr_scheduler_generator,
    beta1=cfg["TRAIN"]["optimizer"]["beta1"],
    beta2=cfg["TRAIN"]["optimizer"]["beta2"],
)
optimizer_discriminator = ppsci.optimizer.Adam(
    learning_rate=lr_scheduler_discriminator,
    beta1=cfg["TRAIN"]["optimizer"]["beta1"],
    beta2=cfg["TRAIN"]["optimizer"]["beta2"],
)
optimizer_generator = optimizer_generator(generator_model)
optimizer_discriminator = optimizer_discriminator(discriminator_model)

For MNIST experiment

optimizer = ppsci.optimizer.Adam(**cfg["TRAIN"]["optimizer"])
optimizer_generator = optimizer(generator_model)
optimizer_discriminator = optimizer(discriminator_model)

For toy dataset experiment

optimizer = ppsci.optimizer.Adam(**cfg["TRAIN"]["optimizer"])

optimizer_generator = optimizer(generator_model)
optimizer_discriminator = optimizer(discriminator_model)

3.7 Solver Construction

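Pass the constructed model, constraints, optimizer, and other parameters to ppsci.solver.Solver, creating one solver per network. Note that both solvers wrap generator_model: each constraint's forward pass runs the generator, the discriminator is called inside the custom loss functions, and the optimizer passed to each solver determines which network's parameters are updated.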

For Cifar10 experiment

solver_generator = ppsci.solver.Solver(
    model=generator_model,
    output_dir=os.path.join(cfg.output_dir, "generator"),
    constraint=constraint_generator_dict,
    optimizer=optimizer_generator,
    epochs=cfg.TRAIN.epochs_gen,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
    pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
)
solver_discriminator = ppsci.solver.Solver(
    model=generator_model,
    output_dir=os.path.join(cfg.output_dir, "discriminator"),
    constraint=constraint_discriminator_dict,
    optimizer=optimizer_discriminator,
    epochs=cfg.TRAIN.epochs_dis,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
    pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
)

For MNIST experiment

solver_generator = ppsci.solver.Solver(
    model=generator_model,
    output_dir=os.path.join(cfg.output_dir, "generator"),
    constraint=constraint_generator_dict,
    optimizer=optimizer_generator,
    epochs=cfg.TRAIN.epochs_gen,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
    pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
)
solver_discriminator = ppsci.solver.Solver(
    model=generator_model,
    output_dir=os.path.join(cfg.output_dir, "discriminator"),
    constraint=constraint_discriminator_dict,
    optimizer=optimizer_discriminator,
    epochs=cfg.TRAIN.epochs_dis,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
    pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
)

For toy dataset experiment

solver_generator = ppsci.solver.Solver(
    model=generator_model,
    output_dir=os.path.join(cfg.output_dir, "generator"),
    constraint=constraint_generator_dict,
    optimizer=optimizer_generator,
    epochs=cfg.TRAIN.epochs_gen,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
    pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
)
solver_discriminator = ppsci.solver.Solver(
    model=generator_model,
    output_dir=os.path.join(cfg.output_dir, "discriminator"),
    constraint=constraint_discriminator_dict,
    optimizer=optimizer_discriminator,
    epochs=cfg.TRAIN.epochs_dis,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
    pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
)

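3.8 Model Training

Each outer epoch first trains the discriminator with solver_discriminator.train() and then the generator with solver_generator.train(). Because each loss back-propagates through both networks, the clear_grad() call before each stage discards the gradients that the other stage left on that network's parameters.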

For Cifar10 experiment

for i in range(cfg.TRAIN.epochs):
    logger.message(f"\nEpoch: {i + 1}\n")
    optimizer_discriminator.clear_grad()
    solver_discriminator.train()
    optimizer_generator.clear_grad()
    solver_generator.train()

For MNIST experiment

for i in range(cfg.TRAIN.epochs):
    logger.message(f"\nEpoch: {i + 1}\n")
    optimizer_discriminator.clear_grad()
    solver_discriminator.train()
    optimizer_generator.clear_grad()
    solver_generator.train()

For toy dataset experiment

for i in range(cfg.TRAIN.epochs):
    logger.message(f"\nEpoch: {i + 1}\n")
    optimizer_discriminator.clear_grad()
    solver_discriminator.train()
    optimizer_generator.clear_grad()
    solver_generator.train()
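To reason about the update ratio between the two networks, the sketch below lists the optimizer steps taken per outer epoch, assuming that each Solver.train() call performs epochs × iters_per_epoch steps for its own network. The function name and the configuration values in the example are hypothetical; 5 discriminator steps per generator step corresponds to the n_critic = 5 setting used in the WGAN-GP paper.

```python
def update_schedule(epochs, epochs_dis, iters_per_epoch_dis, epochs_gen, iters_per_epoch_gen):
    """List (network, optimizer steps) in the order the training loop runs them.

    Assumes each Solver.train() call performs epochs * iters_per_epoch steps.
    """
    schedule = []
    for _ in range(epochs):
        schedule.append(("discriminator", epochs_dis * iters_per_epoch_dis))
        schedule.append(("generator", epochs_gen * iters_per_epoch_gen))
    return schedule


# Hypothetical config: 5 discriminator steps per generator step (n_critic = 5)
print(update_schedule(2, 1, 5, 1, 1))
# [('discriminator', 5), ('generator', 1), ('discriminator', 5), ('generator', 1)]
```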

3.9 Custom metric

Among these cases, only the CIFAR-10 case has a real evaluation metric, the Inception Score; the MNIST and toy cases have none. Because an error is raised when the metric is empty, a placeholder metric is defined for them.

Two custom metrics are therefore implemented.

PaddleScience provides ppsci.metric.FunctionalMetric for custom metrics: define the metric function first, then pass the function (or a bound method) to FunctionalMetric. Note that a custom metric function must take dictionaries as input and return a dictionary.

The implementation code of Inception Score is as follows:

from typing import Dict

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.vision import transforms


class InceptionScore:
    """
    Inception Score

    Args:
        eps: epsilon to avoid log(0)
        splits: number of splits the score is averaged over
        batch_size: number of images passed to Inception-v3 at a time
    """

    def __init__(self, eps=1e-16, splits=10, batch_size=64):
        self.inception_v3 = paddle.vision.inception_v3(pretrained=True)
        self.inception_v3.fc.bias.set_value(
            paddle.to_tensor(np.zeros(self.inception_v3.fc.bias.shape, dtype="float32"))
        )
        self.inception_v3.eval()
        self.eps = eps
        self.splits = splits
        self.softmax = paddle.nn.Softmax(axis=1)
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def inception_score(self, output_dict: Dict, label_dict, *args):
        with paddle.no_grad():
            # Flattened generator outputs in [-1, 1] -> NCHW images in [0, 255.99]
            images = output_dict["fake_data"]
            images = images.reshape((-1, 3, 32, 32))
            images = (images + 1.0) * (255.99 / 2)
            predict = []
            # Full batches first, then a final partial batch if one remains
            for start in range(0, images.shape[0], self.batch_size):
                image = images[start : start + self.batch_size]
                # Resize to Inception-v3's 299x299 input, rescale to [0, 1], normalize
                image = F.interpolate(image, size=(299, 299), mode="bilinear")
                image = image / 255
                image = self.transform(image)
                predict.append(self.inception_v3(image))
            predict = paddle.concat(predict, axis=0)
            # p(y|x), offset by eps so that log() stays finite
            predict = self.softmax(predict) + self.eps
            scores = []
            split_size = predict.shape[0] // self.splits
            for i in range(self.splits):
                part = predict[i * split_size : (i + 1) * split_size]
                # KL(p(y|x) || p(y)), with p(y) the mean prediction over the split
                kl = part * (paddle.log(part) - paddle.log(paddle.mean(part, 0)))
                kl = paddle.mean(paddle.sum(kl, 1))
                scores.append(paddle.exp(kl))
            scores = paddle.to_tensor(scores)
            return {"inception_score": paddle.mean(scores)}
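The InceptionScore class needs a pretrained Inception-v3, but its scoring arithmetic, \(\exp(\mathbb{E}_x \, \mathrm{KL}(p(y|x) \,\|\, p(y)))\), can be checked standalone on a matrix of class probabilities. The NumPy sketch below (hypothetical function name) mirrors the split, KL, and exp steps. The score ranges from 1, when every sample gets the same prediction, up to the number of classes, when confident predictions are spread evenly over all classes.

```python
import numpy as np


def inception_score_from_probs(probs, splits=1, eps=1e-16):
    """Inception Score from an [N, C] matrix of class probabilities p(y|x).

    Mirrors the split/KL/exp steps of InceptionScore.inception_score.
    """
    probs = probs + eps
    split_size = probs.shape[0] // splits
    scores = []
    for i in range(splits):
        part = probs[i * split_size : (i + 1) * split_size]
        marginal = part.mean(axis=0)  # p(y) estimated on this split
        kl = np.sum(part * (np.log(part) - np.log(marginal)), axis=1)
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores))


# Confident predictions spread over 10 classes reach the maximum, IS = 10;
# identical predictions for every sample give the minimum, IS = 1.
print(round(inception_score_from_probs(np.eye(10)), 4))              # 10.0
print(round(inception_score_from_probs(np.full((10, 10), 0.1)), 4))  # 1.0
```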

invalid_metric simply returns a constant 0:

def invalid_metric(*args, **kwargs):
    return {"invalid_metric": 0}

3.10 Validator Construction

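This case uses ppsci.validate.SupervisedValidator to construct the validator. In the CIFAR-10 validator, eval_inception_score is an instance of the InceptionScore class from Section 3.9.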

For Cifar10 experiment

validator = ppsci.validate.SupervisedValidator(
    dataloader_cfg=valid_dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
    output_expr={"labels": lambda out: out["labels"]},
    metric={
        "IS": ppsci.metric.FunctionalMetric(eval_inception_score.inception_score),
    },
    name="val",
)
validator_dict = {validator.name: validator}

For MNIST experiment

validator = ppsci.validate.SupervisedValidator(
    dataloader_cfg=valid_dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
    metric={
        "MAE": ppsci.metric.FunctionalMetric(invalid_metric),
    },
    name="val",
)
validator_dict = {validator.name: validator}

For toy dataset experiment

validator = ppsci.validate.SupervisedValidator(
    dataloader_cfg=valid_dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
    metric={"invalid_metric": ppsci.metric.FunctionalMetric(invalid_metric)},
    name="val",
)
validator_dict = {validator.name: validator}

3.11 Model Evaluation

Pass the model, validator, and pretrained weight path to ppsci.solver.Solver, then start evaluation with solver.eval().

For Cifar10 experiment

solver = ppsci.solver.Solver(
    model=generator_model,
    validator=validator_dict,
    pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
    output_dir=cfg.output_dir,
)

# evaluation
solver.eval()

For MNIST experiment

# initialize solver
solver = ppsci.solver.Solver(
    model=generator_model,
    validator=validator_dict,
    pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
    output_dir=cfg.output_dir,
)

# evaluation
solver.eval()

For toy dataset experiment

solver = ppsci.solver.Solver(
    model=generator_model,
    validator=validator_dict,
    pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
    output_dir=cfg.output_dir,
)

# eval
solver.eval()

3.12 Visualization

After evaluation, the results are visualized as images. The code is as follows:

For Cifar10 experiment

if cfg.VIS.vis:
    with solver.no_grad_context_manager(True):
        generator_model.eval()
        for batch_idx, (input_, _, _) in enumerate(validator.data_loader):
            if batch_idx + 1 > cfg.VIS.batch:
                break
            fake_image = generator_model(input_)["fake_data"]
            show_save_image(
                fake_image[0],
                f"{cfg.output_dir}/image{batch_idx}.png",
            )
    print(f"The visualizations are saved to {cfg.output_dir}")
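The implementation of show_save_image is not shown here. A minimal sketch of the conversion such a helper needs is below, assuming that each generator output is a flattened channel-first [3, 32, 32] image in [-1, 1] (the same layout the Inception Score code reshapes to). The function name is hypothetical; the resulting H x W x C uint8 array can be written with any image library.

```python
import numpy as np


def flat_to_hwc_uint8(sample):
    """Convert one flattened CIFAR-10 generator output in [-1, 1] to an
    H x W x C uint8 image."""
    image = np.asarray(sample, dtype=np.float32).reshape(3, 32, 32)
    image = np.transpose(image, (1, 2, 0))  # CHW -> HWC
    # Same [-1, 1] -> [0, 255.99] mapping as the metric, then truncate to integers
    image = (np.clip(image, -1.0, 1.0) + 1.0) * (255.99 / 2)
    return image.astype(np.uint8)
```

Clipping first guards against generator outputs that overshoot [-1, 1], which would otherwise wrap around when cast to uint8.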

For MNIST experiment

# visualization
if cfg.VIS.vis:
    with solver.no_grad_context_manager(True):
        for batch_idx, (input_, _, _) in enumerate(validator.data_loader):
            if batch_idx + 1 > cfg.VIS.batch:
                break
            fake_data = generator_model(input_)["fake_data"]
            show_mnist(
                fake_data[0],
                f"{cfg.output_dir}/image{batch_idx}.png",
            )
            show_mnist(
                input_["real_data"][0],
                f"{cfg.output_dir}/image_real_{batch_idx}.png",
            )
    print(f"The visualizations are saved to {cfg.output_dir}")
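The MNIST visualization saves one generated digit and one real digit per batch. To inspect many samples at once, it is often convenient to tile them into a single grid image. The sketch below assumes each sample can be reshaped to 28x28, for example a flattened 784-value vector. tile_mnist is a hypothetical helper and is not part of show_mnist.

```python
import numpy as np


def tile_mnist(samples: np.ndarray, ncols: int = 8) -> np.ndarray:
    """Arrange a batch of MNIST samples into one 2D grid image.

    samples: array of shape (N, 784) or (N, 28, 28).
    Returns an array of shape (nrows * 28, ncols * 28); unused cells stay zero.
    """
    digits = samples.reshape(-1, 28, 28)
    n = digits.shape[0]
    nrows = (n + ncols - 1) // ncols  # ceiling division
    grid = np.zeros((nrows * 28, ncols * 28), dtype=digits.dtype)
    for idx, digit in enumerate(digits):
        row, col = divmod(idx, ncols)
        grid[row * 28:(row + 1) * 28, col * 28:(col + 1) * 28] = digit
    return grid


batch = np.random.rand(10, 784).astype(np.float32)
print(tile_mnist(batch, ncols=4).shape)  # (84, 112)
```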

For toy dataset experiment

# visualization
if cfg.VIS.vis:
    with solver.no_grad_context_manager(True):
        input_, _, _ = next(iter(validator.data_loader))
        real_data = input_["real_data"]
        generate_toy_image(
            true_dist=real_data,
            discriminator=discriminator_model,
            path=os.path.join(cfg.output_dir, "image.png"),
        )
    print(f"The visualizations are saved to {cfg.output_dir}")
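For the toy experiment, generate_toy_image receives the real samples and the discriminator. Its implementation is not shown here. A common way to visualize a 2D critic is to evaluate it on a regular grid and draw contours over the real points. The sketch below illustrates that grid-evaluation step under this assumption. It uses a hypothetical stand-in critic (a plain NumPy function) in place of discriminator_model.

```python
import numpy as np


def critic_on_grid(critic, xlim=(-2.0, 2.0), ylim=(-2.0, 2.0), steps=128):
    """Evaluate a 2D critic on a regular grid, e.g. for a contour plot.

    critic: callable mapping an (M, 2) array of points to an (M,) array of scores.
    Returns (xs, ys, scores) with scores[i, j] evaluated at (xs[j], ys[i]).
    """
    xs = np.linspace(*xlim, steps)
    ys = np.linspace(*ylim, steps)
    xx, yy = np.meshgrid(xs, ys)                         # both (steps, steps)
    points = np.stack([xx.ravel(), yy.ravel()], axis=1)  # (steps * steps, 2)
    scores = critic(points).reshape(steps, steps)
    return xs, ys, scores


# Hypothetical stand-in critic: scores are highest near the origin
toy_critic = lambda p: -np.sum(p ** 2, axis=1)
xs, ys, scores = critic_on_grid(toy_critic, steps=64)
print(scores.shape)  # (64, 64)
```

The returned arrays match the layout matplotlib expects, so `plt.contour(xs, ys, scores)` followed by a scatter of the real samples gives the usual toy-GAN picture.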

4. Complete Code

For Cifar10 experiment

import os
import platform

import hydra
import paddle
from functions import Cifar10DisFuncs
from functions import Cifar10GenFuncs
from functions import InceptionScore
from functions import load_cifar10
from functions import show_save_image
from omegaconf import DictConfig
from wgangp_cifar10_model import WganGpCifar10Discriminator
from wgangp_cifar10_model import WganGpCifar10Generator

import ppsci
from ppsci.optimizer.lr_scheduler import Linear
from ppsci.utils import logger

os.environ["FLAGS_cudnn_deterministic"] = "1"


def evaluate(cfg: DictConfig):
    # set model
    generator_model = WganGpCifar10Generator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpCifar10Discriminator(**cfg["MODEL"]["dis_net"])
    if cfg.EVAL.pretrained_dis_model_path and os.path.exists(
        cfg.EVAL.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.EVAL.pretrained_dis_model_path))

    # set Loss
    generator_funcs = Cifar10GenFuncs(
        **cfg["LOSS"]["gen"], discriminator_model=discriminator_model
    )
    eval_inception_score = InceptionScore(**cfg["EVAL"]["inceptionscore"])

    # set data
    inputs, labels = load_cifar10(**cfg["DATA"])
    valid_dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
            "label": labels,
        },
        "batch_size": cfg["EVAL"]["batch_size"],
        "use_shared_memory": cfg["EVAL"]["use_shared_memory"],
        "num_workers": cfg["EVAL"]["num_workers"]
        if platform.system() != "Windows"
        else 0,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        metric={
            "IS": ppsci.metric.FunctionalMetric(eval_inception_score.inception_score),
        },
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=generator_model,
        validator=validator_dict,
        pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
        output_dir=cfg.output_dir,
    )

    # evaluation
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            generator_model.eval()
            for batch_idx, (input_, _, _) in enumerate(validator.data_loader):
                if batch_idx + 1 > cfg.VIS.batch:
                    break
                fake_image = generator_model(input_)["fake_data"]
                show_save_image(
                    fake_image[0],
                    f"{cfg.output_dir}/image{batch_idx}.png",
                )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # set model
    generator_model = WganGpCifar10Generator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpCifar10Discriminator(**cfg["MODEL"]["dis_net"])
    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
        cfg.TRAIN.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))

    # set Loss
    generator_funcs = Cifar10GenFuncs(
        **cfg["LOSS"]["gen"], discriminator_model=discriminator_model
    )
    discriminator_funcs = Cifar10DisFuncs(
        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs, labels = load_cifar10(**cfg["DATA"])
    dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
            "label": labels,
        },
        "sampler": {
            **cfg["TRAIN"]["sampler"],
        },
        "batch_size": cfg["TRAIN"]["batch_size"],
        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
        "num_workers": cfg["TRAIN"]["num_workers"],
        "drop_last": cfg["TRAIN"]["drop_last"],
    }

    # set constraint
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        name="constraint_generator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}

    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        name="constraint_discriminator",
    )
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizer
    lr_scheduler_generator = Linear(**cfg["TRAIN"]["lr_scheduler_gen"])()
    lr_scheduler_discriminator = Linear(**cfg["TRAIN"]["lr_scheduler_dis"])()

    optimizer_generator = ppsci.optimizer.Adam(
        learning_rate=lr_scheduler_generator,
        beta1=cfg["TRAIN"]["optimizer"]["beta1"],
        beta2=cfg["TRAIN"]["optimizer"]["beta2"],
    )
    optimizer_discriminator = ppsci.optimizer.Adam(
        learning_rate=lr_scheduler_discriminator,
        beta1=cfg["TRAIN"]["optimizer"]["beta1"],
        beta2=cfg["TRAIN"]["optimizer"]["beta2"],
    )
    optimizer_generator = optimizer_generator(generator_model)
    optimizer_discriminator = optimizer_discriminator(discriminator_model)

    # initialize solver
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )
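    # The discriminator solver also runs the generator forward pass: the generated
    # (fake) samples are scored by discriminator_model inside discriminator_funcs.loss,
    # while optimizer_discriminator only updates discriminator_model's parameters.
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,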
        output_dir=os.path.join(cfg.output_dir, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )

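    # train: alternate discriminator (critic) and generator updates
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        # clear any gradients left on the discriminator by the previous generator phase
        optimizer_discriminator.clear_grad()
        solver_discriminator.train()
        # clear any gradients left on the generator by the discriminator phase
        optimizer_generator.clear_grad()
        solver_generator.train()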

    # save model weight
    paddle.save(
        generator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_generator.pdparams"),
    )
    paddle.save(
        discriminator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
    )


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_cifar10.yaml")
def main(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg["seed"])
    logger.init_logger(
        cfg.LOGGER.name, log_file=os.path.join(cfg.output_dir, cfg.LOGGER.log_file)
    )
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
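The gradient penalty that gives WGAN-GP its name is presumably computed in the discriminator loss functions (Cifar10DisFuncs, MnistDisFuncs, ToyDisFuncs), whose source is not shown here. For reference, the critic loss from the WGAN-GP paper (Gulrajani et al.) is L = mean(D(x_fake)) - mean(D(x_real)) + lambda * mean((||grad D(x_hat)||_2 - 1)^2). Here x_hat = eps * x_real + (1 - eps) * x_fake, with eps drawn uniformly from [0, 1] for each sample.

The NumPy sketch below computes this loss for a small one-hidden-layer tanh critic. The critic's input gradient is derived by hand, so the sketch runs without an autodiff framework. In the Paddle code, that gradient would instead come from automatic differentiation. The critic architecture and lambda = 10 (the value used in the paper) are illustrative assumptions, not the settings from this example's config.

```python
import numpy as np


class TanhCritic:
    """Toy critic D(x) = v . tanh(W x + b) with an analytic input gradient."""

    def __init__(self, dim: int, hidden: int, rng: np.random.Generator):
        self.W = rng.normal(scale=0.5, size=(hidden, dim))
        self.b = rng.normal(scale=0.1, size=hidden)
        self.v = rng.normal(scale=0.5, size=hidden)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x: (N, dim) -> scores: (N,)
        return np.tanh(x @ self.W.T + self.b) @ self.v

    def input_grad(self, x: np.ndarray) -> np.ndarray:
        # dD/dx = W^T (v * (1 - tanh(W x + b)^2)), computed row-wise: (N, dim)
        h = np.tanh(x @ self.W.T + self.b)
        return (self.v * (1.0 - h ** 2)) @ self.W


def wgan_gp_critic_loss(critic, real, fake, lam=10.0, rng=None):
    """E[D(fake)] - E[D(real)] + lam * E[(||grad D(x_hat)||_2 - 1)^2]."""
    if rng is None:
        rng = np.random.default_rng()
    # One interpolation coefficient per sample, eps ~ U[0, 1]
    eps = rng.uniform(size=(real.shape[0], 1))
    x_hat = eps * real + (1.0 - eps) * fake
    # Penalize deviations of the critic's gradient norm from 1 on the interpolates
    grad_norm = np.linalg.norm(critic.input_grad(x_hat), axis=1)
    penalty = np.mean((grad_norm - 1.0) ** 2)
    wasserstein_term = np.mean(critic(fake)) - np.mean(critic(real))
    return wasserstein_term + lam * penalty


rng = np.random.default_rng(0)
critic = TanhCritic(dim=2, hidden=16, rng=rng)
real = rng.normal(size=(64, 2))
fake = rng.normal(loc=1.0, size=(64, 2))
print(wgan_gp_critic_loss(critic, real, fake, rng=rng))
```

The penalty is applied on random interpolates rather than on every input. This softly enforces the 1-Lipschitz constraint of the Wasserstein critic along lines between the real and generated distributions, and it replaces the weight clipping used in the original WGAN.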

For MNIST experiment

import os
import platform

import hydra
import paddle
from functions import MnistDisFuncs
from functions import MnistGenFuncs
from functions import invalid_metric
from functions import load_mnist
from functions import show_mnist
from omegaconf import DictConfig
from wgangp_mnist_model import WganGpMnistDiscriminator
from wgangp_mnist_model import WganGpMnistGenerator

import ppsci
from ppsci.utils import logger


def evaluate(cfg: DictConfig):
    # set model
    generator_model = WganGpMnistGenerator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpMnistDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.EVAL.pretrained_dis_model_path and os.path.exists(
        cfg.EVAL.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.EVAL.pretrained_dis_model_path))

    # set Loss
    generator_funcs = MnistGenFuncs(discriminator_model=discriminator_model)

    # set dataloader
    inputs = load_mnist(**cfg["DATA"])
    valid_dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "batch_size": cfg["EVAL"]["batch_size"],
        "use_shared_memory": cfg["EVAL"]["use_shared_memory"],
        "num_workers": cfg["EVAL"]["num_workers"]
        if platform.system() != "Windows"
        else 0,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        metric={
            "MAE": ppsci.metric.FunctionalMetric(invalid_metric),
        },
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=generator_model,
        validator=validator_dict,
        pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
        output_dir=cfg.output_dir,
    )

    # evaluation
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            for batch_idx, (input_, _, _) in enumerate(validator.data_loader):
                if batch_idx + 1 > cfg.VIS.batch:
                    break
                fake_data = generator_model(input_)["fake_data"]
                show_mnist(
                    fake_data[0],
                    f"{cfg.output_dir}/image{batch_idx}.png",
                )
                show_mnist(
                    input_["real_data"][0],
                    f"{cfg.output_dir}/image_real_{batch_idx}.png",
                )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # set model
    generator_model = WganGpMnistGenerator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpMnistDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
        cfg.TRAIN.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))

    # set Loss
    generator_funcs = MnistGenFuncs(discriminator_model=discriminator_model)
    discriminator_funcs = MnistDisFuncs(
        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs = load_mnist(**cfg["DATA"])
    dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "sampler": {
            **cfg["TRAIN"]["sampler"],
        },
        "batch_size": cfg["TRAIN"]["batch_size"],
        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
        "num_workers": cfg["TRAIN"]["num_workers"],
        "drop_last": cfg["TRAIN"]["drop_last"],
    }

    # set constraint
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        name="constraint_generator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}

    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"real_data": lambda out: out["real_data"]},
        name="constraint_discriminator",
    )
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizer
    optimizer = ppsci.optimizer.Adam(**cfg["TRAIN"]["optimizer"])
    optimizer_generator = optimizer(generator_model)
    optimizer_discriminator = optimizer(discriminator_model)

    # initialize solver
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )

    # train
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        optimizer_discriminator.clear_grad()
        solver_discriminator.train()
        optimizer_generator.clear_grad()
        solver_generator.train()

    # save model weight
    paddle.save(
        generator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_generator.pdparams"),
    )
    paddle.save(
        discriminator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
    )


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_mnist.yaml")
def main(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg["seed"])
    logger.init_logger(
        cfg.LOGGER.name, log_file=os.path.join(cfg.output_dir, cfg.LOGGER.log_file)
    )
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
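All three train functions alternate between the two solvers inside the outer loop over cfg.TRAIN.epochs. Assume that each solver.train() call runs epochs × iters_per_epoch optimizer steps for its own network. Under that assumption, the number of critic and generator updates per outer epoch follows directly from the config values. The WGAN-GP paper trains the critic several times per generator step (n_critic = 5). The numbers in the sketch below are hypothetical placeholders chosen to reproduce that ratio, not the values in wgangp_mnist.yaml.

```python
def update_schedule(epochs, epochs_dis, iters_per_epoch_dis, epochs_gen, iters_per_epoch_gen):
    """Return the sequence of ("D" | "G", steps) phases of the alternating loop.

    Assumes each solver.train() call performs epochs_* x iters_per_epoch_* steps.
    """
    phases = []
    for _ in range(epochs):
        phases.append(("D", epochs_dis * iters_per_epoch_dis))  # solver_discriminator.train()
        phases.append(("G", epochs_gen * iters_per_epoch_gen))  # solver_generator.train()
    return phases


# Hypothetical values giving 5 critic steps per generator step
schedule = update_schedule(epochs=3, epochs_dis=1, iters_per_epoch_dis=5,
                           epochs_gen=1, iters_per_epoch_gen=1)
print(schedule)  # [('D', 5), ('G', 1), ('D', 5), ('G', 1), ('D', 5), ('G', 1)]

critic_steps = sum(steps for who, steps in schedule if who == "D")
gen_steps = sum(steps for who, steps in schedule if who == "G")
print(critic_steps / gen_steps)  # 5.0
```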

For toy dataset experiment

import os
import platform

import hydra
import paddle
from functions import ToyDisFuncs
from functions import ToyGenFuncs
from functions import generate_toy_image
from functions import invalid_metric
from functions import load_toy_data
from omegaconf import DictConfig
from wgangp_toy_model import WganGpToyDiscriminator
from wgangp_toy_model import WganGpToyGenerator

import ppsci
from ppsci.utils import logger


def evaluate(cfg: DictConfig):
    # set model
    discriminator_model = WganGpToyDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.EVAL.pretrained_dis_model_path and os.path.exists(
        cfg.EVAL.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.EVAL.pretrained_dis_model_path))
    generator_model = WganGpToyGenerator(**cfg["MODEL"]["gen_net"])

    # set Loss
    generator_funcs = ToyGenFuncs(discriminator_model=discriminator_model)

    # set dataloader
    inputs = load_toy_data(**cfg["DATA"])
    valid_dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "batch_size": cfg["EVAL"]["batch_size"],
        "use_shared_memory": cfg["EVAL"]["use_shared_memory"],
        "num_workers": cfg["EVAL"]["num_workers"]
        if platform.system() != "Windows"
        else 0,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        metric={"invalid_metric": ppsci.metric.FunctionalMetric(invalid_metric)},
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=generator_model,
        validator=validator_dict,
        pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
        output_dir=cfg.output_dir,
    )

    # eval
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            input_, _, _ = next(iter(validator.data_loader))
            real_data = input_["real_data"]
            generate_toy_image(
                true_dist=real_data,
                discriminator=discriminator_model,
                path=os.path.join(cfg.output_dir, "image.png"),
            )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # set model
    generator_model = WganGpToyGenerator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpToyDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
        cfg.TRAIN.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))

    # set Loss
    generator_funcs = ToyGenFuncs(discriminator_model=discriminator_model)
    discriminator_funcs = ToyDisFuncs(
        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs = load_toy_data(**cfg["DATA"])
    dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "sampler": {
            **cfg["TRAIN"]["sampler"],
        },
        "batch_size": cfg["TRAIN"]["batch_size"],
        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
        "num_workers": cfg["TRAIN"]["num_workers"],
        "drop_last": cfg["TRAIN"]["drop_last"],
    }

    # set constraint
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        name="constraint_generator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}

    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"real_data": lambda out: out["real_data"]},
        name="constraint_discriminator",
    )
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizer
    optimizer = ppsci.optimizer.Adam(**cfg["TRAIN"]["optimizer"])

    optimizer_generator = optimizer(generator_model)
    optimizer_discriminator = optimizer(discriminator_model)

    # initialize solver
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )

    # train
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        optimizer_discriminator.clear_grad()
        solver_discriminator.train()
        optimizer_generator.clear_grad()
        solver_generator.train()

    # save model weight
    paddle.save(
        generator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_generator.pdparams"),
    )
    paddle.save(
        discriminator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
    )


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_toy.yaml")
def main(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg["seed"])
    logger.init_logger(
        cfg.LOGGER.name, log_file=os.path.join(cfg.output_dir, cfg.LOGGER.log_file)
    )
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
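The pretrained toy weights are named after the 8-Gaussians distribution (model_generator_toy_8gaussians.pdparams). load_toy_data is not shown here. The sketch below samples a distribution of that kind: eight Gaussian modes evenly spaced on a circle, modeled on the toy setup in the WGAN-GP paper's experiments. The scale of 2.0, the standard deviation of 0.02, and the final division by 1.414 are assumptions taken from that setup and may not match this example's data.

```python
import numpy as np


def sample_8gaussians(n: int, scale: float = 2.0, std: float = 0.02,
                      rng: np.random.Generator = None) -> np.ndarray:
    """Draw n 2D points from a mixture of 8 Gaussians placed on a circle."""
    if rng is None:
        rng = np.random.default_rng()
    angles = np.arange(8) * (np.pi / 4)  # 0, 45, ..., 315 degrees
    centers = scale * np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (8, 2)
    modes = rng.integers(0, 8, size=n)  # pick one mode per point
    points = centers[modes] + std * rng.standard_normal((n, 2))
    return points / 1.414  # rescale, as in the assumed reference setup


data = sample_8gaussians(256, rng=np.random.default_rng(0))
print(data.shape)  # (256, 2)
```

Samples like these are what the critic contours are compared against in the toy visualization. A well-trained generator should place mass on all eight modes rather than collapsing onto a few.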

6. References