WGANGP¶
Note
```sh
# CIFAR10 Experiment
python wgangp_cifar10.py
# MNIST Experiment
python wgangp_mnist.py
# Toy Dataset Experiment
python wgangp_toy.py
```
```sh
# CIFAR10 Experiment
# EVAL.pretrained_dis_model_path is the local path of the discriminator weights downloaded from
# https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_discriminator_cifar10.pdparams
python wgangp_cifar10.py mode=eval EVAL.pretrained_gen_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_generator_cifar10.pdparams
# MNIST Experiment
# EVAL.pretrained_dis_model_path is the local path of the discriminator weights downloaded from
# https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_discriminator_mnist.pdparams
python wgangp_mnist.py mode=eval EVAL.pretrained_gen_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_generator_mnist.pdparams
# Toy Dataset Experiment
# EVAL.pretrained_dis_model_path is the local path of the discriminator weights downloaded from
# https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_discriminator_toy_8gaussians.pdparams
python wgangp_toy.py mode=eval EVAL.pretrained_gen_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_generator_toy_8gaussians.pdparams
```
| Pretrained Model | Metric |
|---|---|
| wgangp_cifar10_gen_pretrained.pdparams, wgangp_cifar10_dis_pretrained.pdparams | IS: 5.2 |
1. Background Introduction¶
In the fields of digital image processing and machine learning, Generative Adversarial Networks (GANs) have attracted widespread attention due to their excellent image generation capabilities. However, traditional GAN architectures may encounter instability problems during training, especially when generating high-resolution or complex scene images. To solve these problems, researchers proposed Wasserstein Generative Adversarial Networks with Gradient Penalty (WGAN-GP), which not only enhances the stability of the training process, but also significantly improves the quality of generated images.
WGAN-GP replaces the original GAN objective with a loss based on the Wasserstein distance between the real data distribution and the generated data distribution, and adds a gradient penalty term that keeps training smooth and stable. This mitigates the mode collapse problem that is common in traditional GANs, while promoting more efficient training and more realistic image generation.
2. Model Principle¶
WGAN-GP proposes an alternative to weight clipping: penalize the norm of the critic's gradient with respect to its input. This stabilizes the training of a wide range of GAN architectures with almost no hyperparameter tuning.
2.1 Model Structure¶
WGAN-GP is a conditional adversarial network containing a noise-to-image generator and a CNN discriminator. The overall structure of the model is shown below.
- Generator is a convolutional neural network.
- Discriminator is a model composed of convolutional blocks. It takes an image as input and outputs an authenticity score for that image.
2.2 Loss Function¶
The discriminator's loss function uses Wasserstein loss and gradient penalty. Its expression is:
Where \(\mathbb{P}_g\) is the generator distribution, \(\mathbb{P}_r\) is the real data distribution, and \(\mathbb{P}_{\hat{x}}\) is a mixed interpolation sample from \(\mathbb{P}_g\) and \(\mathbb{P}_r\).
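For reference, the critic objective in the standard WGAN-GP formulation can be written with the distributions defined above as follows. The names \(L_D\) and \(\lambda\) (the gradient penalty coefficient) are symbols introduced here for illustration, and \(\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}\) with \(\epsilon \sim U[0, 1]\) is the usual way of drawing the interpolated samples.

```latex
L_D = \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}\left[D(\tilde{x})\right]
    - \underset{x \sim \mathbb{P}_r}{\mathbb{E}}\left[D(x)\right]
    + \lambda \underset{\hat{x} \sim \mathbb{P}_{\hat{x}}}{\mathbb{E}}
      \left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]
```

The first two terms form the Wasserstein loss, and the last term is the gradient penalty, which pushes the gradient norm of the critic towards 1 on the interpolated samples.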
The generator's loss function is the adversarial loss:

\[ L_G = - \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}\left[D(\tilde{x})\right] \]

where \(\mathbb{P}_g\) is the generator distribution.
3. Model Construction¶
Next, we will explain how to use the PaddleScience framework to implement WGAN-GP. The following content only elaborates on key steps. For other details, please refer to API Documentation.
3.1 Dataset Introduction¶
This case uses the CIFAR-10 dataset, the MNIST dataset and three toy datasets (swissroll/8gaussians/25gaussians).
The CIFAR-10 dataset contains 60000 32x32 color images, divided into 10 categories, with 6000 images per category.
The CIFAR-10 dataset is distributed in 3 versions:
| Version | Size | md5sum |
|---|---|---|
| CIFAR-10 python | 163 MB | c58f30108f718f92721af3b95e74349a |
| CIFAR-10 Matlab | 175 MB | 70270af85842c9e89bb428ec9976c926 |
| CIFAR-10 binary | 162 MB | c32a1d4ab5d03f1284b67883e8d87530 |
This implementation uses the CIFAR-10 python version.
The MNIST dataset contains 70000 28x28 grayscale images of handwritten digits (60000 for training and 10000 for testing), divided into 10 categories.
Toy datasets:
- Swissroll: a three-dimensional nonlinear manifold dataset with a continuous curled spiral structure.
- 8gaussians: a two-dimensional synthetic dataset containing eight symmetrically distributed Gaussian clusters, with centers uniformly distributed on a circle.
- 25gaussians: a high-density Gaussian mixture dataset consisting of 25 regularly arranged two-dimensional Gaussian distributions with compact cluster spacing.
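To make the 8gaussians description concrete, here is a minimal NumPy sketch of such a sampler. It is not the load_toy_data function used by this case; the function name sample_8gaussians and the radius, std and seed values are assumptions chosen for illustration.

```python
import numpy as np


def sample_8gaussians(n_samples, radius=2.0, std=0.02, seed=0):
    """Draw 2-D points from eight Gaussian clusters whose centers lie on a circle."""
    rng = np.random.default_rng(seed)
    # Eight centers spaced uniformly on a circle of the given radius.
    angles = 2.0 * np.pi * np.arange(8) / 8.0
    centers = radius * np.stack([np.cos(angles), np.sin(angles)], axis=1)  # shape (8, 2)
    # Pick a cluster for every sample, then add isotropic Gaussian noise.
    idx = rng.integers(0, 8, size=n_samples)
    noise = rng.normal(0.0, std, size=(n_samples, 2))
    return (centers[idx] + noise).astype("float32")


points = sample_8gaussians(256)
print(points.shape)  # (256, 2)
```

An array produced this way can be passed as the input of a NamedArrayDataset in the same way as the image data.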
3.2 Build dataset API¶
The CIFAR-10 dataset consists of 5 data files, and because of this organization it cannot be loaded directly with the built-in dataset API of PaddleScience. Instead, all data is read first and then wrapped with ppsci.data.dataset.array_dataset.NamedArrayDataset.
The code for reading the CIFAR-10 dataset is given below:
data_path is the path to the CIFAR-10 dataset.
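As a rough illustration of the reading step, the sketch below loads the five training batches of the CIFAR-10 python version into NumPy arrays. It is not the load_cifar10 function from functions.py: the name read_cifar10_batches is hypothetical, and any normalization or noise sampling done by the real loader is left out.

```python
import os
import pickle

import numpy as np


def read_cifar10_batches(data_path):
    """Read data_batch_1 .. data_batch_5 from the CIFAR-10 python version."""
    images, labels = [], []
    for i in range(1, 6):
        with open(os.path.join(data_path, f"data_batch_{i}"), "rb") as f:
            # The files were pickled with Python 2, so the dict keys are bytes.
            batch = pickle.load(f, encoding="bytes")
        # Each row holds one image as 3 * 32 * 32 uint8 values (channel-first).
        data = np.asarray(batch[b"data"], dtype=np.uint8)
        images.append(data.reshape(-1, 3, 32, 32))
        labels.append(np.asarray(batch[b"labels"], dtype=np.int64))
    return np.concatenate(images), np.concatenate(labels)
```

Calling read_cifar10_batches on the extracted cifar-10-batches-py directory would return an image array of shape (50000, 3, 32, 32) and a label array of shape (50000,).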
The configuration code of dataloader is given below:
The MNIST dataset cannot be loaded directly with the built-in dataset API of PaddleScience either, so all data is read first and then wrapped with ppsci.data.dataset.array_dataset.NamedArrayDataset.
The code for reading MNIST dataset is given below:
The configuration code of dataloader is given below:
The toy datasets cannot be loaded directly with the built-in dataset API of PaddleScience either, so all data is generated first and then wrapped with ppsci.data.dataset.array_dataset.NamedArrayDataset.
The code for generating the toy datasets is given below:
The configuration code of dataloader is given below:
3.3 Model Construction¶
WGAN-GP is not built into PaddleScience, so this case implements its own models: WganGpCifar10Generator and WganGpCifar10Discriminator, WganGpMnistGenerator and WganGpMnistDiscriminator, WganGpToyGenerator and WganGpToyDiscriminator.
The model construction code is as follows:
WganGpCifar10Generator and WganGpCifar10Discriminator
WganGpMnistGenerator and WganGpMnistDiscriminator
WganGpToyGenerator and WganGpToyDiscriminator
Parameter configuration is as follows:
WganGpCifar10Generator and WganGpCifar10Discriminator
WganGpMnistGenerator and WganGpMnistDiscriminator
WganGpToyGenerator and WganGpToyDiscriminator
3.4 Custom loss¶
The loss function of WGAN-GP is relatively complex and has to be implemented manually. PaddleScience provides an API for custom loss functions, ppsci.loss.FunctionalLoss: define the loss function first, and then pass the function itself as an argument to FunctionalLoss. Note that the inputs and outputs of a custom loss function must be dictionaries.
3.4.1 Loss of Generator¶
The loss of Cifar10_Generator contains adversarial loss and classification loss. Both losses have corresponding weights. If the weight of a certain loss is 0, it means that the loss item is not added during training.
The loss of MNIST_Generator only contains adversarial loss.
The loss of Toy_Generator only contains adversarial loss.
3.4.2 Loss of Discriminator¶
The loss of Cifar10_Discriminator contains Wasserstein loss, gradient penalty and classification loss. Among them, only the classification loss item has weight parameters.
The loss of MNIST_Discriminator contains Wasserstein loss and gradient penalty.
The loss of Toy_Discriminator contains Wasserstein loss and gradient penalty.
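The gradient penalty shared by all three discriminator losses can be illustrated with a small NumPy sketch. It is not the implementation used in functions.py: a real implementation would obtain the critic gradient through automatic differentiation, whereas here the gradient is supplied as a function, and the names gradient_penalty, critic_grad and the default lam=10.0 are assumptions for illustration.

```python
import numpy as np


def gradient_penalty(real, fake, critic_grad, lam=10.0, seed=0):
    """Compute lam * mean((||grad D(x_hat)||_2 - 1)^2) on interpolated samples."""
    rng = np.random.default_rng(seed)
    # One mixing weight per sample, drawn uniformly from [0, 1].
    eps = rng.uniform(0.0, 1.0, size=(real.shape[0], 1))
    # Interpolated samples lie on straight lines between real and fake points.
    x_hat = eps * real + (1.0 - eps) * fake
    grad_norm = np.linalg.norm(critic_grad(x_hat), axis=1)
    return lam * np.mean((grad_norm - 1.0) ** 2)


# Toy critic D(x) = 0.5 * ||x||^2, whose gradient with respect to x is x itself.
real = np.array([[1.0, 0.0], [0.0, 1.0]])
fake = np.array([[1.0, 0.0], [0.0, 1.0]])
penalty = gradient_penalty(real, fake, critic_grad=lambda x: x)
print(penalty)  # close to 0, because every gradient norm is 1
```

When the gradient norm deviates from 1, the penalty grows quadratically; for points of norm 2 under the same toy critic it evaluates to roughly lam.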
3.5 Constraint Construction¶
All cases use ppsci.constraint.SupervisedConstraint to construct constraints.
The construction code is as follows:
For Cifar10 experiment
For MNIST experiment
For toy dataset experiment
3.6 Optimizer Construction¶
WGAN-GP uses the Adam optimizer, which can be constructed directly by calling ppsci.optimizer.Adam. The code is as follows:
For Cifar10 experiment
For MNIST experiment
For toy dataset experiment
3.7 Solver Construction¶
Pass the constructed model, constraints, optimizer and other parameters to ppsci.solver.Solver.
For Cifar10 experiment
For MNIST experiment
For toy dataset experiment
3.8 Model Training¶
For Cifar10 experiment
For MNIST experiment
For toy dataset experiment
3.9 Custom metric¶
Only the CIFAR-10 case has a real evaluation metric, the Inception Score; the MNIST and toy cases have none. Because an error is reported when metric is empty, a placeholder metric, invalid_metric, is defined for those cases.
Two custom metrics are therefore implemented.
PaddleScience provides an API for custom metric functions, ppsci.metric.FunctionalMetric: define the metric function first, and then pass the function itself as an argument to FunctionalMetric. Note that the inputs and outputs of a custom metric function must be dictionaries.
The implementation code of Inception Score is as follows:
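As background for the metric, the Inception Score is defined as exp(E_x[KL(p(y|x) || p(y))]), where p(y|x) are class probabilities predicted by a pretrained classifier for generated images. The sketch below shows only this final formula in NumPy and is not the InceptionScore class used by this case: the function name inception_score_from_probs is hypothetical, and the probabilities are assumed to be given.

```python
import numpy as np


def inception_score_from_probs(probs, eps=1e-12):
    """Compute exp(mean_x KL(p(y|x) || p(y))) for an (N, num_classes) array."""
    probs = np.asarray(probs, dtype=np.float64)
    # Marginal class distribution p(y), averaged over all samples.
    marginal = probs.mean(axis=0, keepdims=True)
    # KL divergence between each conditional distribution and the marginal.
    kl = np.sum(probs * (np.log(probs + eps) - np.log(marginal + eps)), axis=1)
    return float(np.exp(kl.mean()))


uniform = np.full((4, 4), 0.25)  # uninformative predictions
one_hot = np.eye(4)  # confident predictions spread evenly over 4 classes
print(inception_score_from_probs(uniform))  # approximately 1
print(inception_score_from_probs(one_hot))  # approximately 4
```

The score ranges from 1 (every image gets the same prediction) up to the number of classes (confident predictions that cover all classes evenly).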
The code of invalid_metric is as follows
3.10 Validator Construction¶
This case uses ppsci.validate.SupervisedValidator to construct the validator.
For Cifar10 experiment
For MNIST experiment
For toy dataset experiment
3.11 Model Evaluation¶
After passing the model, validator and weight path to ppsci.solver.Solver, start evaluation through solver.eval().
For Cifar10 experiment
For MNIST experiment
For toy dataset experiment
3.12 Visualization¶
After evaluation, we visualize the results in the form of images, code as follows:
For Cifar10 experiment
For MNIST experiment
For toy dataset experiment
4. Complete Code¶
For Cifar10 experiment
```python
import os
import platform

import hydra
import paddle
from functions import Cifar10DisFuncs
from functions import Cifar10GenFuncs
from functions import InceptionScore
from functions import load_cifar10
from functions import show_save_image
from omegaconf import DictConfig
from wgangp_cifar10_model import WganGpCifar10Discriminator
from wgangp_cifar10_model import WganGpCifar10Generator

import ppsci
from ppsci.optimizer.lr_scheduler import Linear
from ppsci.utils import logger

os.environ["FLAGS_cudnn_deterministic"] = "1"


def evaluate(cfg: DictConfig):
    # set model
    generator_model = WganGpCifar10Generator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpCifar10Discriminator(**cfg["MODEL"]["dis_net"])
    if cfg.EVAL.pretrained_dis_model_path and os.path.exists(
        cfg.EVAL.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.EVAL.pretrained_dis_model_path))

    # set Loss
    generator_funcs = Cifar10GenFuncs(
        **cfg["LOSS"]["gen"], discriminator_model=discriminator_model
    )
    eval_inception_score = InceptionScore(**cfg["EVAL"]["inceptionscore"])

    # set data
    inputs, labels = load_cifar10(**cfg["DATA"])
    valid_dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
            "label": labels,
        },
        "batch_size": cfg["EVAL"]["batch_size"],
        "use_shared_memory": cfg["EVAL"]["use_shared_memory"],
        "num_workers": cfg["EVAL"]["num_workers"]
        if platform.system() != "Windows"
        else 0,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        metric={
            "IS": ppsci.metric.FunctionalMetric(eval_inception_score.inception_score),
        },
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=generator_model,
        validator=validator_dict,
        pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
        output_dir=cfg.output_dir,
    )

    # evaluation
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            generator_model.eval()
            for batch_idx, (input_, _, _) in enumerate(validator.data_loader):
                if batch_idx + 1 > cfg.VIS.batch:
                    break
                fake_image = generator_model(input_)["fake_data"]
                show_save_image(
                    fake_image[0],
                    f"{cfg.output_dir}/image{batch_idx}.png",
                )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # set model
    generator_model = WganGpCifar10Generator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpCifar10Discriminator(**cfg["MODEL"]["dis_net"])
    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
        cfg.TRAIN.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))

    # set Loss
    generator_funcs = Cifar10GenFuncs(
        **cfg["LOSS"]["gen"], discriminator_model=discriminator_model
    )
    discriminator_funcs = Cifar10DisFuncs(
        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs, labels = load_cifar10(**cfg["DATA"])
    dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
            "label": labels,
        },
        "sampler": {
            **cfg["TRAIN"]["sampler"],
        },
        "batch_size": cfg["TRAIN"]["batch_size"],
        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
        "num_workers": cfg["TRAIN"]["num_workers"],
        "drop_last": cfg["TRAIN"]["drop_last"],
    }

    # set constraint
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        name="constraint_generator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}

    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        name="constraint_discriminator",
    )
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizer
    lr_scheduler_generator = Linear(**cfg["TRAIN"]["lr_scheduler_gen"])()
    lr_scheduler_discriminator = Linear(**cfg["TRAIN"]["lr_scheduler_dis"])()

    optimizer_generator = ppsci.optimizer.Adam(
        learning_rate=lr_scheduler_generator,
        beta1=cfg["TRAIN"]["optimizer"]["beta1"],
        beta2=cfg["TRAIN"]["optimizer"]["beta2"],
    )
    optimizer_discriminator = ppsci.optimizer.Adam(
        learning_rate=lr_scheduler_discriminator,
        beta1=cfg["TRAIN"]["optimizer"]["beta1"],
        beta2=cfg["TRAIN"]["optimizer"]["beta2"],
    )
    optimizer_generator = optimizer_generator(generator_model)
    optimizer_discriminator = optimizer_discriminator(discriminator_model)

    # initialize solver
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )

    # train
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        optimizer_discriminator.clear_grad()
        solver_discriminator.train()
        optimizer_generator.clear_grad()
        solver_generator.train()

    # save model weight
    paddle.save(
        generator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_generator.pdparams"),
    )
    paddle.save(
        discriminator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
    )


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_cifar10.yaml")
def main(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg["seed"])
    logger.init_logger(
        cfg.LOGGER.name, log_file=os.path.join(cfg.output_dir, cfg.LOGGER.log_file)
    )
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
```
For MNIST experiment
```python
import os
import platform

import hydra
import paddle
from functions import MnistDisFuncs
from functions import MnistGenFuncs
from functions import invalid_metric
from functions import load_mnist
from functions import show_mnist
from omegaconf import DictConfig
from wgangp_mnist_model import WganGpMnistDiscriminator
from wgangp_mnist_model import WganGpMnistGenerator

import ppsci
from ppsci.utils import logger


def evaluate(cfg: DictConfig):
    # set model
    generator_model = WganGpMnistGenerator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpMnistDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.EVAL.pretrained_dis_model_path and os.path.exists(
        cfg.EVAL.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.EVAL.pretrained_dis_model_path))

    # set Loss
    generator_funcs = MnistGenFuncs(discriminator_model=discriminator_model)

    # set dataloader
    inputs = load_mnist(**cfg["DATA"])
    valid_dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "batch_size": cfg["EVAL"]["batch_size"],
        "use_shared_memory": cfg["EVAL"]["use_shared_memory"],
        "num_workers": cfg["EVAL"]["num_workers"]
        if platform.system() != "Windows"
        else 0,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        metric={
            "MAE": ppsci.metric.FunctionalMetric(invalid_metric),
        },
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=generator_model,
        validator=validator_dict,
        pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
        output_dir=cfg.output_dir,
    )

    # evaluation
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            for batch_idx, (input_, _, _) in enumerate(validator.data_loader):
                if batch_idx + 1 > cfg.VIS.batch:
                    break
                fake_data = generator_model(input_)["fake_data"]
                show_mnist(
                    fake_data[0],
                    f"{cfg.output_dir}/image{batch_idx}.png",
                )
                show_mnist(
                    input_["real_data"][0],
                    f"{cfg.output_dir}/image_real_{batch_idx}.png",
                )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # set model
    generator_model = WganGpMnistGenerator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpMnistDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
        cfg.TRAIN.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))

    # set Loss
    generator_funcs = MnistGenFuncs(discriminator_model=discriminator_model)
    discriminator_funcs = MnistDisFuncs(
        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs = load_mnist(**cfg["DATA"])
    dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "sampler": {
            **cfg["TRAIN"]["sampler"],
        },
        "batch_size": cfg["TRAIN"]["batch_size"],
        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
        "num_workers": cfg["TRAIN"]["num_workers"],
        "drop_last": cfg["TRAIN"]["drop_last"],
    }

    # set constraint
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        name="constraint_generator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}

    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"real_data": lambda out: out["real_data"]},
        name="constraint_discriminator",
    )
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizer
    optimizer = ppsci.optimizer.Adam(**cfg["TRAIN"]["optimizer"])
    optimizer_generator = optimizer(generator_model)
    optimizer_discriminator = optimizer(discriminator_model)

    # initialize solver
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )

    # train
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        optimizer_discriminator.clear_grad()
        solver_discriminator.train()
        optimizer_generator.clear_grad()
        solver_generator.train()

    # save model weight
    paddle.save(
        generator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_generator.pdparams"),
    )
    paddle.save(
        discriminator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
    )


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_mnist.yaml")
def main(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg["seed"])
    logger.init_logger(
        cfg.LOGGER.name, log_file=os.path.join(cfg.output_dir, cfg.LOGGER.log_file)
    )
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
```
For toy dataset experiment
```python
import os
import platform

import hydra
import paddle
from functions import ToyDisFuncs
from functions import ToyGenFuncs
from functions import generate_toy_image
from functions import invalid_metric
from functions import load_toy_data
from omegaconf import DictConfig
from wgangp_toy_model import WganGpToyDiscriminator
from wgangp_toy_model import WganGpToyGenerator

import ppsci
from ppsci.utils import logger


def evaluate(cfg: DictConfig):
    # set model
    discriminator_model = WganGpToyDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.EVAL.pretrained_dis_model_path and os.path.exists(
        cfg.EVAL.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.EVAL.pretrained_dis_model_path))
    generator_model = WganGpToyGenerator(**cfg["MODEL"]["gen_net"])

    # set Loss
    generator_funcs = ToyGenFuncs(discriminator_model=discriminator_model)

    # set dataloader
    inputs = load_toy_data(**cfg["DATA"])
    valid_dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "batch_size": cfg["EVAL"]["batch_size"],
        "use_shared_memory": cfg["EVAL"]["use_shared_memory"],
        "num_workers": cfg["EVAL"]["num_workers"]
        if platform.system() != "Windows"
        else 0,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        metric={"invalid_metric": ppsci.metric.FunctionalMetric(invalid_metric)},
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=generator_model,
        validator=validator_dict,
        # same key as the evaluation command and the other two experiments
        pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
        output_dir=cfg.output_dir,
    )

    # eval
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            input_, _, _ = next(iter(validator.data_loader))
            real_data = input_["real_data"]
            generate_toy_image(
                true_dist=real_data,
                discriminator=discriminator_model,
                path=os.path.join(cfg.output_dir, "image.png"),
            )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # set model
    generator_model = WganGpToyGenerator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpToyDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
        cfg.TRAIN.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))

    # set Loss
    generator_funcs = ToyGenFuncs(discriminator_model=discriminator_model)
    discriminator_funcs = ToyDisFuncs(
        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs = load_toy_data(**cfg["DATA"])
    dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "sampler": {
            **cfg["TRAIN"]["sampler"],
        },
        "batch_size": cfg["TRAIN"]["batch_size"],
        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
        "num_workers": cfg["TRAIN"]["num_workers"],
        "drop_last": cfg["TRAIN"]["drop_last"],
    }

    # set constraint
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        name="constraint_generator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}

    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"real_data": lambda out: out["real_data"]},
        name="constraint_discriminator",
    )
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizer
    optimizer = ppsci.optimizer.Adam(**cfg["TRAIN"]["optimizer"])
    optimizer_generator = optimizer(generator_model)
    optimizer_discriminator = optimizer(discriminator_model)

    # initialize solver
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )

    # train
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        optimizer_discriminator.clear_grad()
        solver_discriminator.train()
        optimizer_generator.clear_grad()
        solver_generator.train()

    # save model weight
    paddle.save(
        generator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_generator.pdparams"),
    )
    paddle.save(
        discriminator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
    )


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_toy.yaml")
def main(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg["seed"])
    logger.init_logger(
        cfg.LOGGER.name, log_file=os.path.join(cfg.output_dir, cfg.LOGGER.log_file)
    )
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
```