
MoFlow (Flow Model for Molecular Graphs)

Precautions

  1. Before starting training and evaluation, download the QM9 and ZINC datasets and set FILE_PATH in the yaml configuration file to the path of the decompressed dataset. It is recommended to place it in example./datasets/moflow.
  2. Before starting training, testing, and optimization evaluation, install the additional chemistry and data-conversion packages with the command pip install -r requirements.txt. This installs the rdkit chemistry toolkit and the cairosvg data conversion and saving tool.
  3. Place the pre-trained model in the specified folder and update the corresponding yaml configuration file to match. Messages reporting that a generated molecule is unreasonable, which appear while running molecule generation, can be ignored.
# qm9 dataset model training
python moflow_train.py data_name=qm9

# zinc250k dataset model training
python moflow_train.py data_name=zinc250k
# qm9 dataset pre-trained model generation evaluation, where EVAL_mode=Reconstruct is reconstruction generation, EVAL_mode=Random is random generation, EVAL_mode=Inter2point is intermolecular interpolation generation, EVAL_mode=Intergrid is molecular grid interpolation generation, refer to 3.7 Model Generation Evaluation Construction for details
python test_generate.py data_name=qm9 EVAL_mode=Reconstruct EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/MoFlow/qm9/qm9_pretrained.pdparams

# zinc250k dataset pre-trained model generation evaluation, where EVAL_mode=Reconstruct is reconstruction generation, EVAL_mode=Random is random generation, EVAL_mode=Inter2point is intermolecular interpolation generation, EVAL_mode=Intergrid is molecular grid interpolation generation, refer to 3.7 Model Generation Evaluation Construction for details
python test_generate.py data_name=zinc250k EVAL_mode=Reconstruct EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/MoFlow/zinc250k/zinc250k_pretrained.pdparams
# Method 1: Do not use the pre-trained optimization model; the first run trains the model, the second run outputs the generated prediction results
# qm9 dataset pre-trained model optimization, where OPTIMIZE.property_name=qed optimizes from the latent space toward the QED property and OPTIMIZE.property_name=plogp optimizes from the latent space toward the plogp property; refer to 3.8 Model Optimization Construction for details
python optimize_moflow.py data_name=qm9  TRAIN.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/MoFlow/qm9/qm9_pretrained.pdparams  OPTIMIZE.property_name=qed

# zinc250k dataset pre-trained model optimization, where OPTIMIZE.property_name=qed optimizes from the latent space toward the QED property and OPTIMIZE.property_name=plogp optimizes from the latent space toward the plogp property; refer to 3.8 Model Optimization Construction for details
python optimize_moflow.py data_name=zinc250k  TRAIN.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/MoFlow/zinc250k/zinc250k_pretrained.pdparams OPTIMIZE.property_name=qed

# Method 2: Use the provided pre-trained optimization model; download it, then run to output the generated prediction results
# qm9 dataset pre-trained model optimization
mkdir -p ./outputs_moflow_optimize/qm9/
wget -c https://paddle-org.bj.bcebos.com/paddlescience/models/MoFlow/qm9/qed_opt_pretrained.pdparams -O ./outputs_moflow_optimize/qm9/qed_model.pdparams
python optimize_moflow.py data_name=qm9  TRAIN.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/MoFlow/qm9/qm9_pretrained.pdparams  OPTIMIZE.property_name=qed

wget -c https://paddle-org.bj.bcebos.com/paddlescience/models/MoFlow/qm9/plogp_opt_pretrained.pdparams -O ./outputs_moflow_optimize/qm9/plogp_model.pdparams
python optimize_moflow.py data_name=qm9  TRAIN.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/MoFlow/qm9/qm9_pretrained.pdparams  OPTIMIZE.property_name=plogp

# zinc250k dataset pre-trained model optimization
mkdir -p ./outputs_moflow_optimize/zinc250k/
wget -c https://paddle-org.bj.bcebos.com/paddlescience/models/MoFlow/zinc250k/qed_opt_pretrained.pdparams -O ./outputs_moflow_optimize/zinc250k/qed_model.pdparams
python optimize_moflow.py data_name=zinc250k  TRAIN.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/MoFlow/zinc250k/zinc250k_pretrained.pdparams OPTIMIZE.property_name=qed

wget -c https://paddle-org.bj.bcebos.com/paddlescience/models/MoFlow/zinc250k/plogp_opt_pretrained.pdparams -O ./outputs_moflow_optimize/zinc250k/plogp_model.pdparams
python optimize_moflow.py data_name=zinc250k  TRAIN.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/MoFlow/zinc250k/zinc250k_pretrained.pdparams OPTIMIZE.property_name=plogp
Pretrained Model Metrics
qm9 loss (Residual): 1.09976
zinc250k loss (Residual): 1.12570

1. Background Introduction

MoFlow is a flow-based graph generation model designed to accelerate the drug discovery process by generating molecular graphs with desired chemical properties. Such graph generation models usually include two steps: learning latent representations and generating molecular graphs. Generating novel molecular graphs that conform to chemical rules from latent representations is very challenging because molecular graphs have chemical constraints and combinatorial complexity.

MoFlow learns an invertible mapping between molecular graphs and their latent representations. First, bonds (edges) are generated by a Glow-based model; then atoms (nodes) are generated given the bonds by a novel graph conditional flow; finally, they are assembled into a chemically valid molecular graph through post-processing validity correction. Its advantages include exact and tractable likelihood training, efficient one-pass embedding and generation, a chemical validity guarantee, 100% reconstruction of training data, and good generalization ability. A variant of the Glow model generates the bonds (multi-type edges, such as single, double, and triple bonds), and a graph conditional flow based on graph convolution generates the atoms (multi-type nodes, such as C, N, etc.) conditioned on the bonds, so that atoms and bonds are assembled into valid molecular graphs that satisfy the bond valence constraints.

MoFlow is one of the first flow-based graph generation models that can generate molecular graphs in one shot through an invertible mapping with a validity guarantee. It has achieved good results in molecular graph generation, reconstruction, and optimization. Because one-pass inference and generation are very efficient, it has the potential to explore chemical space for drug discovery both efficiently and effectively.

2. Model Principle

This chapter only briefly introduces the model principle of MoFlow. For detailed model structure and derivation process, please read the paper MoFlow: An Invertible Flow Model for Generating Molecular Graphs.

2.1 Basic Framework of the Model

A flow model learns a sequence of invertible transformations \(f_Θ = f_L ◦ ... ◦ f_1\) between complex high-dimensional data \(X \sim P_\mathcal{X}(X)\) and a latent variable \(Z \sim P_\mathcal{Z}(Z)\) of the same dimension, where the latent distribution \(P_\mathcal{Z}(Z)\) is easy to model (for example, a strong independence assumption holds in such a latent space). The complex data distribution in the original space can then be modeled through the change-of-variables formula, where \(Z = f_Θ(X)\) and:

\[ \begin{aligned} P_\mathcal{X}(X) & = P_\mathcal{Z}(Z)\left|\det\left(\frac{\partial Z}{\partial X}\right)\right| \end{aligned} \]

To sample \(\widetilde{X} \sim P_\mathcal{X}(X)\), first sample \(\widetilde{Z} \sim P_\mathcal{Z}(Z)\) and then apply the inverse map \(\widetilde{X} = f_Θ^{-1}(\widetilde{Z})\). Let \(Z = f_Θ(X) = f_L ◦ ... ◦ f_1(X)\) and \(H_l = f_l(H_{l-1})\), where each \(f_l\) (\(l = 1, ..., L\), \(L ∈ \mathbb{N}^+\)) is an invertible map, \(H_0 = X\), \(H_L = Z\), and \(P_\mathcal{Z}(Z)\) is a standard isotropic Gaussian distribution with independent dimensions. The log-likelihood of \(X\) then follows from the change-of-variables formula:

\[ \begin{aligned} \log P_\mathcal{X}(X) &=\log P_\mathcal{Z}(Z) + \log \left|\det\left(\frac{\partial Z}{\partial X}\right) \right| \\ &= \sum_{i} \log P_{\mathcal{Z}_i}(Z_i) + \sum_{l=1}^L \log \left|\det\left(\frac{\partial f_l}{\partial H_{l−1}}\right)\right| \end{aligned} \]

Where \(P_{\mathcal{Z}_i}(Z_i)\) is the probability density of the \(i^{th}\) dimension of \(Z\), and \(f_Θ = f_L ◦ ... ◦ f_1\) is the invertible deep neural network to be learned.
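The log-likelihood decomposition above can be checked numerically on a toy flow. The following is a minimal sketch, not the MoFlow implementation: it assumes each layer \(f_l\) is an element-wise affine map with hypothetical parameters `scales` and `shifts`, so every Jacobian is diagonal and its log-determinant is a sum of log-scales.

```python
import numpy as np


def standard_normal_logpdf(z):
    """Sum over dimensions of log N(z_i; 0, 1), i.e. sum_i log P_{Z_i}(Z_i)."""
    return float(np.sum(-0.5 * (z ** 2 + np.log(2.0 * np.pi))))


def flow_log_likelihood(x, scales, shifts):
    """log P_X(x) for stacked element-wise affine layers H_l = a_l * H_{l-1} + b_l."""
    h = np.asarray(x, dtype=float)
    sum_log_det = 0.0
    for a, b in zip(scales, shifts):
        h = a * h + b
        # Element-wise map: diagonal Jacobian, so log|det| = sum_j log|a_j|.
        sum_log_det += float(np.sum(np.log(np.abs(a))))
    return standard_normal_logpdf(h) + sum_log_det


# Hypothetical two-layer flow on 2-dimensional data.
scales = [np.array([2.0, 0.5]), np.array([1.5, 3.0])]
shifts = [np.array([0.1, -0.2]), np.array([0.0, 1.0])]
x = np.array([0.3, -1.2])
log_px = flow_log_likelihood(x, scales, shifts)
print(f"log P_X(x) = {log_px:.6f}")
```

Because the composed map is itself affine, the result can be compared with the closed-form Gaussian density of \(X\), which is a convenient sanity check before moving to neural-network layers.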

The affine coupling layer is an invertible building block that is expressive while still allowing the Jacobian determinant to be computed efficiently. The affine coupling transformation \(Z = f_Θ(X): \mathbb{R}^n \mapsto \mathbb{R}^n\) is:

\[ \begin{aligned} Z_{1:d} & = X_{1:d} \\ Z_{d+1:n} & = X_{d+1:n} ⊙ e^{S_Θ(X_{1:d})} + T_Θ(X_{1:d}) \end{aligned} \]

Here \(X\) is divided into two partitions \(X = (X_{1:d}, X_{d+1:n})\). Because \(Z_{1:d} = X_{1:d}\), the transformation can be inverted by:

\[ \begin{aligned} X_{1:d} & = Z_{1:d} \\ X_{d+1:n} & = (Z_{d+1:n} - T_Θ(Z_{1:d}) )/ e^{S_Θ(Z_{1:d})} \end{aligned} \]

Expressiveness comes from the scale function \(S_Θ:\mathbb{R}^d \mapsto \mathbb{R}^{n-d}\) and the transformation function \(T_Θ:\mathbb{R}^d \mapsto \mathbb{R}^{n-d}\) used in the affine transformation of \(X_{d+1:n}\); both can be arbitrary neural networks. The Jacobian is triangular, so its determinant can be computed efficiently as \(\det(\frac{\partial Z}{\partial X}) = \exp(\sum_j S_Θ(X_{1:d})_j)\).
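The coupling equations above translate almost line by line into code. The sketch below is illustrative only: `scale_fn` and `shift_fn` are hypothetical single-layer stand-ins for \(S_Θ\) and \(T_Θ\) with random weights, chosen to show that invertibility and the log-determinant do not depend on what these networks look like.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 3

# Hypothetical stand-ins for S_Theta and T_Theta: one tanh layer each.
W_s = rng.normal(size=(d, n - d))
W_t = rng.normal(size=(d, n - d))


def scale_fn(x1):
    return np.tanh(x1 @ W_s)


def shift_fn(x1):
    return np.tanh(x1 @ W_t)


def coupling_forward(x):
    """Z_{1:d} = X_{1:d};  Z_{d+1:n} = X_{d+1:n} * exp(S(X_{1:d})) + T(X_{1:d})."""
    x1, x2 = x[:d], x[d:]
    s = scale_fn(x1)
    z = np.concatenate([x1, x2 * np.exp(s) + shift_fn(x1)])
    log_det = float(np.sum(s))  # log det(dZ/dX) = sum_j S(X_{1:d})_j
    return z, log_det


def coupling_inverse(z):
    """Recover X from Z using only Z_{1:d} to evaluate S and T."""
    z1, z2 = z[:d], z[d:]
    x2 = (z2 - shift_fn(z1)) / np.exp(scale_fn(z1))
    return np.concatenate([z1, x2])


x = rng.normal(size=n)
z, log_det = coupling_forward(x)
x_rec = coupling_inverse(z)
print("max reconstruction error:", np.max(np.abs(x_rec - x)))
```

Neither the inverse nor the log-determinant requires differentiating `scale_fn` or `shift_fn`, which is why arbitrary networks can be used in their place.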

2.2 Principle of MoFlow Model

Consider a molecular graph \(M\) as an undirected graph with atoms as nodes and bonds as edges. The space of molecular graphs is denoted \(\mathcal{M} = \mathcal{A} \times \mathcal{B} \subset \mathbb{R}^{n \times k} \times \mathbb{R}^{c \times n \times n}\), where \(\mathcal{A}\) is the set of atom matrices for \(n\) atoms and \(k\) atom types, with \(A(i,k)=1\) meaning that node \(i\) is an atom of type \(k\), and \(\mathcal{B}\) is the set of bond (edge) tensors for \(c\) bond types, with \(B(c,i,j)=1\) meaning that atoms \(i\) and \(j\) are connected by a bond of type \(c\). A molecule \(M \in \mathcal{M}\) can thus be regarded as an undirected graph with multi-type nodes and multi-type edges. The main goal is to learn a molecular generation model \(P_{\mathcal{M}}(M)\), that is, the probability of sampling any molecule \(M\) from \(P_{\mathcal{M}}\). To capture the combinatorial atom and bond structure of the molecular graph, \(P_{\mathcal{M}}(M)\) is decomposed into two parts:

\[ \begin{aligned} P_\mathcal{M}(M) = P_{\mathcal{M}}((A, B)) ≈ P_{\mathcal{A|B}}(A|B; θ_{\mathcal{A|B}})P_\mathcal{B}(B; θ_\mathcal{B}) \end{aligned} \]

Where \(P_{\mathcal{M}}\) is the distribution of molecular graphs, \(P_\mathcal{B}\) is the distribution of bonds (edges), similar to modeling multi-channel images, and \(P_{\mathcal{A|B}}\) is the conditional distribution of atoms (nodes) given bonds, modeled by utilizing graph convolution operations. \(θ_\mathcal{B}\) and \(θ_{\mathcal{A|B}}\) are learnable modeling parameters. The objective function of the model is as follows:

\[ \begin{aligned} \mathop{\arg\max}\limits_{θ_\mathcal{B}, θ_{\mathcal{A|B}}} \mathbb{E}_{M=(A,B) \sim P_{\mathcal{M}\text{-data}}} \left[ \log P_{\mathcal{A|B}}(A | B; θ_{\mathcal{A|B}}) + \log P_\mathcal{B}(B; θ_\mathcal{B}) \right] \end{aligned} \]
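To make the tensor representation \(M = (A, B)\) concrete, the sketch below encodes a small molecule by hand. It is a simplified illustration under stated assumptions: the atom and bond vocabularies, the `encode_molecule` helper, and the all-zero padding rows are hypothetical choices, and the actual dataset code may add extra channels (for example for empty nodes or "no bond").

```python
import numpy as np

ATOM_TYPES = ["C", "N", "O", "F"]            # k = 4 atom types (illustrative)
BOND_TYPES = ["single", "double", "triple"]  # c = 3 bond types (illustrative)


def encode_molecule(atoms, bonds, n):
    """Build A in R^{n x k} and B in R^{c x n x n} for a molecule padded to n nodes."""
    k, c = len(ATOM_TYPES), len(BOND_TYPES)
    A = np.zeros((n, k), dtype=np.float32)
    B = np.zeros((c, n, n), dtype=np.float32)
    for i, symbol in enumerate(atoms):
        A[i, ATOM_TYPES.index(symbol)] = 1.0  # A(i, type) = 1
    for i, j, bond in bonds:
        ch = BOND_TYPES.index(bond)
        B[ch, i, j] = 1.0
        B[ch, j, i] = 1.0  # undirected graph: each channel stays symmetric
    return A, B


# Heavy-atom skeleton of acetaldehyde, C-C=O, padded to n = 4 nodes.
A, B = encode_molecule(
    atoms=["C", "C", "O"],
    bonds=[(0, 1, "single"), (1, 2, "double")],
    n=4,
)
print(A.shape, B.shape)
```

With this layout, the bond tensor resembles a multi-channel image, and the atom matrix is the part whose distribution is modeled conditionally on \(B\).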

Given a bond tensor \(B \in \mathcal{B} \subset \mathbb{R}^{c×n×n}\), the goal is to generate the correct atom type matrix \(A \in \mathcal{A} \subset \mathbb{R}^{n×k}\) so that \(M = (A, B) \in \mathcal{M} \subset \mathbb{R}^{n×k+c×n×n}\) is a valid molecule. To this end, a \(B\)-conditional flow, the graph conditional flow \(f_\mathcal{A|B}\), is defined. It transforms \(A\) given \(B\) into the conditional latent variable \(Z_{A|B} = f_\mathcal{A|B}(A|B)\), which follows an isotropic Gaussian distribution \(P_{\mathcal{Z}_\mathcal{A|B}}\). The conditional probability \(P_\mathcal{A|B}\) of the atom features given the bond graph is then obtained through the conditional change-of-variables formula. The \(B\)-conditional flow \(Z_{A|B} = f_\mathcal{A|B}(A|B)\) is an invertible, dimension-preserving map with inverse transformation \(f^{-1}_\mathcal{A|B}(Z_{A|B} |B) = A|B\), where \(f_\mathcal{A|B}\) and \(f^{-1}_\mathcal{A|B}:\mathcal{A \times B} \mapsto \mathcal{A \times B}\). The condition \(B \in \mathcal{B}\) remains unchanged during the transformation. Under the assumption that \(A\) and \(B\) are independent, the Jacobian matrix of \(f_\mathcal{A|B}\) is:

\[ \begin{aligned} \frac{\partial f_\mathcal{A|B}}{\partial (A, B)}=\bigg[\begin{matrix} \frac{\partial f_\mathcal{A|B}}{\partial A} & \frac{\partial f_\mathcal{A|B}}{\partial B} \\ 0 & \mathbb{1}_B \end{matrix}\bigg] \end{aligned} \]

Having obtained the distribution, we can sample from it, use the inverse map to get \(A|B\), and use the Jacobian matrix to give the probability distribution of \(A|B\). The log likelihood of the conditional variable transformation formula is:

\[ \begin{aligned} \log P_\mathcal{A|B}(A|B) = \log P_{\mathcal{Z}_\mathcal{A|B}}(Z_{A|B}) + \log |\det \frac{\partial f_\mathcal{A|B}}{\partial A}| \end{aligned} \]

Like the flow-based RealNVP and Glow models, MoFlow introduces graph coupling layers to obtain an invertible mapping. In each graph coupling layer, the input \(A \in \mathbb{R}^{n×k}\) is split into two parts \(A = (A_1, A_2)\) along the row dimension \(n\), and the output \(Z_{A|B} = (Z_{A_1|B}, Z_{A_2|B}) = f_\mathcal{A|B}(A|B)\) is obtained as follows:

\[ \begin{aligned} Z_{A_1 |B} &= A_1 \\ Z_{A_2 |B} &= A_2 \odot \text{Sigmoid}(S_Θ(A_1 |B)) + T_Θ(A_1 |B) \end{aligned} \]

Inverting the above formula yields \(A_1\) and \(A_2\). The graph convolution layer is implemented with a Relational Graph Convolutional Network (R-GCN), specifically:

\[ \begin{aligned} \text{graphconv}(A_1) & = \sum_{i=1}^c \hat{B}_i (M \odot A)W_i + (M \odot A)W_0 \end{aligned} \]

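Here \(M\) is a binary mask that selects the partition \(A_1\) from \(A\), \(\hat{B}_i\) is the normalized adjacency matrix of bond type \(i\), and \(W_0\) and \(W_i\) are learnable weight matrices. In addition, multiple stacked graph convolution->BatchNorm1d->ReLU layers followed by a Multi-Layer Perceptron (MLP) output layer are used to construct the graph scaling function \(S_Θ\) and the graph transformation function \(T_Θ\). The Sigmoid function is applied to \(S_Θ\), in place of the exponential used in the basic affine coupling layer, for numerical stability when cascading multiple flow layers. The inverse map \(f^{-1}_\mathcal{A|B}\) of the graph coupling layer is: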

\[ \begin{aligned} A_1 &= Z_{A_1|B} \\ A_2 &= (Z_{A_2|B} - T_Θ(Z_{A_1|B}|B)) / \text{Sigmoid}(S_Θ(Z_{A_1|B}|B)) \end{aligned} \]

The logarithm of the Jacobian determinant of each graph coupling layer can be calculated as follows:

\[ \begin{aligned} \log | \det (\frac{\partial f_\mathcal{A|B}}{\partial A})|= \sum_j \log \text{Sigmoid}(S_Θ(A_1|B))_j \end{aligned} \]

Where \(j\) iterates over each element. Arbitrarily complex graph convolution structures can be used to construct \(S_Θ\) and \(T_Θ\), because the above calculation of the Jacobian determinant of \(f_\mathcal{A|B}\) does not involve the calculation of the Jacobian matrix of \(S_Θ\) or \(T_Θ\).
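The graph coupling layer can be sketched in a few lines. This is a simplified illustration rather than the PaddleScience code: \(S_Θ\) and \(T_Θ\) are each reduced to a single relational graph convolution with hypothetical random weights (no BatchNorm, ReLU, or MLP), the row mask selecting \(A_1\) is fixed by hand, and the adjacency normalization is an assumed row normalization.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, c = 5, 4, 3  # nodes, atom types, bond types


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


# Hypothetical weights: W[0] is the self-connection W_0, W[1:] one matrix per bond type.
W_s = rng.normal(scale=0.5, size=(c + 1, k, k))
W_t = rng.normal(scale=0.5, size=(c + 1, k, k))


def graph_conv(A, B_hat, mask, W):
    """sum_i B_hat_i (M * A) W_i + (M * A) W_0, so only the A_1 rows are visible."""
    masked = mask * A
    out = masked @ W[0]
    for i in range(c):
        out = out + B_hat[i] @ masked @ W[i + 1]
    return out


def coupling_forward(A, B_hat, mask):
    s = sigmoid(graph_conv(A, B_hat, mask, W_s))
    t = graph_conv(A, B_hat, mask, W_t)
    Z = mask * A + (1 - mask) * (A * s + t)           # A_1 rows pass through unchanged
    log_det = float(np.sum((1 - mask) * np.log(s)))   # sum over the elements of A_2
    return Z, log_det


def coupling_inverse(Z, B_hat, mask):
    # The masked rows of Z equal those of A, so S and T can be re-evaluated from Z.
    s = sigmoid(graph_conv(Z, B_hat, mask, W_s))
    t = graph_conv(Z, B_hat, mask, W_t)
    return mask * Z + (1 - mask) * (Z - t) / s


# Random symmetric bond tensor with an assumed row normalization.
B = rng.integers(0, 2, size=(c, n, n)).astype(float)
B = np.maximum(B, B.transpose(0, 2, 1))
degree = B.sum(axis=(0, 2)).reshape(1, n, 1)
B_hat = B / np.maximum(degree, 1.0)

mask = np.zeros((n, 1))
mask[:2] = 1.0  # A_1 = first two rows, A_2 = remaining rows
A = rng.normal(size=(n, k))

Z, log_det = coupling_forward(A, B_hat, mask)
A_rec = coupling_inverse(Z, B_hat, mask)
print("max reconstruction error:", np.max(np.abs(A_rec - A)))
```

Since the Sigmoid output lies in (0, 1), every term of the log-determinant is negative and bounded, which avoids the overflow an exponential scale could cause when many layers are stacked.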

When learning the atom representation, each row is normalized for data stability using a per-row mean \(\mu \in \mathbb{R}^{n \times 1}\) and variance \(\sigma^2 \in \mathbb{R}^{n \times 1}\), so that the normalized input is \(\hat A = \frac{A - \mu} {\sqrt{\sigma^2 + \epsilon}}\), where \(\epsilon\) is a small constant. The inverse transformation is \(A = \hat A \times \sqrt{\sigma^2 + \epsilon} + \mu\), and the logarithmic Jacobian determinant is:

\[ \begin{aligned} \log \left| \det \frac{\partial\, \text{actnorm2D}}{\partial X} \right| = \frac{k}{2}\sum_{i=1}^n | \log(\sigma^2_i + \epsilon) | \end{aligned} \]
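The row-wise normalization and its inverse can be illustrated as follows. This is a minimal sketch under an assumption: the per-row statistics `mu` and `var` are estimated once from a hypothetical reference batch and then held fixed, so the transform applied to a new sample is a plain invertible affine map.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 9, 5
eps = 1e-6

# Hypothetical reference batch of atom matrices, shape (batch, n, k).
batch = rng.normal(loc=2.0, scale=3.0, size=(32, n, k))

# One mean and one variance per row (node position), shape (n, 1).
mu = batch.mean(axis=(0, 2)).reshape(n, 1)
var = batch.var(axis=(0, 2)).reshape(n, 1)


def actnorm_forward(A):
    """A_hat = (A - mu) / sqrt(sigma^2 + eps), applied row by row."""
    return (A - mu) / np.sqrt(var + eps)


def actnorm_inverse(A_hat):
    """A = A_hat * sqrt(sigma^2 + eps) + mu."""
    return A_hat * np.sqrt(var + eps) + mu


A = batch[0]
A_rec = actnorm_inverse(actnorm_forward(A))
print("max round-trip error:", np.max(np.abs(A_rec - A)))
```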

The bond representation is learned following the Glow approach, with steps similar to those described above for the atom representation; for data stability, the \(1 \times 1\) convolution operation from the Glow model is also used.

Finally, chemical validity is checked: each combination of atoms and bonds must respect the valence limit of every atom. The bond valence constraint is defined as:

\[ \begin{aligned} \sum_{c,j}c \times B(c, i, j) \le \text{Valency}(\text{Atom}_i) + Ch \end{aligned} \]

Where \(c\) is the bond order (1 for a single bond, 2 for a double bond, 3 for a triple bond). Unlike other models, the formal charge \(Ch\) is included in the constraint, because a formal charge can allow additional bonds on charged atoms. For example, N in ammonium [NH4]+ may have 4 bonds instead of 3. Similarly, S+ and O+ may have 3 bonds instead of 2.
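The valence constraint can be checked directly on a bond tensor. The sketch below is illustrative only: the `MAX_VALENCE` table lists common valences for a few elements, the helper names are hypothetical, and all atoms are written explicitly as graph nodes.

```python
import numpy as np

MAX_VALENCE = {"C": 4, "N": 3, "O": 2, "F": 1}  # common valences, for illustration


def bond_tensor(n, bonds, n_types=3):
    """Build B of shape (c, n, n); channel index + 1 is the bond order."""
    B = np.zeros((n_types, n, n))
    for i, j, order in bonds:
        B[order - 1, i, j] = 1.0
        B[order - 1, j, i] = 1.0
    return B


def valence_violations(symbols, B, charges=None):
    """Indices i where sum_{c,j} c * B(c, i, j) > Valency(Atom_i) + Ch."""
    n = len(symbols)
    charges = charges if charges is not None else [0] * n
    orders = np.arange(1, B.shape[0] + 1).reshape(-1, 1, 1)
    used = (orders * B).sum(axis=(0, 2))  # total bond order attached to each atom
    return [i for i in range(n) if used[i] > MAX_VALENCE[symbols[i]] + charges[i]]


# O=C=O: carbon uses exactly 4, each oxygen exactly 2.
co2 = bond_tensor(3, [(0, 1, 2), (1, 2, 2)])
print(valence_violations(["O", "C", "O"], co2))

# Nitrogen with four single bonds: invalid when neutral, valid with charge +1.
symbols = ["N", "C", "C", "C", "C"]
n4 = bond_tensor(5, [(0, 1, 1), (0, 2, 1), (0, 3, 1), (0, 4, 1)])
print(valence_violations(symbols, n4))
print(valence_violations(symbols, n4, charges=[1, 0, 0, 0, 0]))
```

The second example mirrors the ammonium case in the text: the same bond pattern fails the check for neutral nitrogen and passes once the formal charge is taken into account.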

The model structure is shown in the figure:


MoFlow Model Structure Diagram

2.3 Dataset Introduction

The QM9 dataset is derived from an enumerated subset of the GDB-17 database. GDB-17 is a chemical universe containing 166 billion small organic molecules, and QM9 selects all stable molecules containing no more than 9 heavy atoms.

* Total number of molecules: about 134,000 (specifically 133,885 stable organic molecules).
* Atom composition: only four heavy-atom types, carbon (C), nitrogen (N), oxygen (O), and fluorine (F), plus hydrogen (H).
* Element vocabulary: in the QM9 implementation of MoFlow, the atom type list is strictly defined as ['C', 'N', 'O', 'F'].
* Maximum size: 9 heavy atoms, which means \(N=9\) in the tensor representation.

Every molecule has undergone high-precision Density Functional Theory (DFT) calculations, specifically at the B3LYP/6-31G(2df,p) level. These calculations provide geometric, energetic, electronic, and thermodynamic properties. In MoFlow research, the following properties are the focus of conditional generation and property optimization tasks:

| Symbol | Property Name | Unit | Physical Meaning | Relevance to Generation Model |
| --- | --- | --- | --- | --- |
| HOMO | Highest Occupied Molecular Orbital Energy | eV | Measures the electron-donating ability of the molecule. | The higher the HOMO energy, the easier it is for the molecule to lose electrons. |
| LUMO | Lowest Unoccupied Molecular Orbital Energy | eV | Measures the electron-accepting ability of the molecule. | The lower the LUMO energy, the easier it is for the molecule to accept electrons. |
| Gap (\(\Delta \epsilon\)) | HOMO-LUMO Gap | eV | Determines the chemical hardness and light absorption characteristics of the molecule. | This is one of the core goals of MoFlow for property optimization. |
| \(\mu\) | Dipole Moment | Debye | Describes the asymmetry of the molecular charge distribution. | Generating molecules with specific polarity is crucial in material design. |
| \(\alpha\) | Isotropic Polarizability | \(Bohr^3\) | The response of the molecule to an external electric field. | |
| \(U_0\) | Internal Energy at 0K | Hartree | Thermodynamic stability indicator of the molecule. | |
| \(C_v\) | Heat Capacity at Constant Volume | cal/mol K | Thermodynamic property. | Often used as the target for regression tasks. |

ZINC250k contains 249,455 molecular graphs. The subset was not constructed at random; it follows strict "drug-likeness" and "synthesizability" criteria, first established by Gómez-Bombarelli et al. in their pioneering automatic chemical design paper and later widely adopted by studies such as MoFlow. Screening criteria include:

* Heavy atom count limit: molecular size is limited to at most 38 heavy atoms. This is much larger than QM9 and allows for more complex ring systems and long chain structures.
* Element diversity: the chemical space is expanded from CHNOF to include halogens as well as phosphorus and sulfur. The atomic number list defined in the MoFlow code is `[6, 7, 8, 9, 15, 16, 17, 35, 53, 0]`, where the first nine entries correspond to C, N, O, F, P, S, Cl, Br, I respectively.
* LogP range: the lipid-water partition coefficient (LogP) of the molecule must fall within a specific range to ensure oral bioavailability.
* Synthetic Accessibility (SA): priority is given to molecules that are easy to synthesize.
* Structure filtering: molecules containing rings larger than 8 members were eliminated, as were complex charged salt forms, which reduces the topological difficulty of graph generation.

3. Model Implementation

Next, we explain how to reproduce structural reconstruction of drug molecules with the MoFlow model using PaddleScience code. Only the key steps of model construction, training, testing, and evaluation are described below; for other details, please refer to the API Documentation.

3.1 Data Processing

During data processing, the chemical molecular structures are first read and parsed with a chemistry library. Chemical bonds and atom nodes are then extracted from the structure field of the dataset, and the atomic structure and bond values are processed. In PaddleScience code this is expressed as follows:

data/dataset/moflow_dataset.py
                fail_count += 1
                self.logger.warning(f"parse(), type: {type(e).__name__}, {e.args}")
                continue
            except Exception as e:
                self.logger.warning(f"parse(), type: {type(e).__name__}, {e.args}")
                fail_count += 1
                continue
            # raw_data = misc.convert_to_dict(np.array([nodes, edges]), self.input_keys)

            all_nodes.append(nodes)
            all_edges.append(edges)
            # inputs.append(raw_data)

            success_count += 1

        labels = np.array(
            [*(df[label_col].values for label_col in self.label_keys)]
        ).T
        result = [np.array(all_nodes), np.array(all_edges)], labels
        self.logger.message(
            f"Preprocess finished. FAIL {fail_count}, "
            f"SUCCESS {success_count}, TOTAL {total_count}"
        )
    else:
        raise NotImplementedError

    return result

def transform_func(self, data_dict, label_dict):
    items = []
    length = len(next(iter(data_dict.values())))
    for idx in range(length):
        input_item = [value[idx] for key, value in data_dict.items()]
        label_item = [value[idx] for key, value in label_dict.items()]

The dataset is divided into training data and test data according to preset index labels. The qm9 and zinc250k datasets are processed in the same way, apart from some differences in the choice of features and atomic structure processing.
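The train/test split described above can be sketched as follows. This is a hedged illustration, not the dataset class itself: it assumes the split is supplied as a list of held-out sample indices (as the `valid_idx` entries in the configuration suggest), and `split_by_index` is a hypothetical helper.

```python
def split_by_index(n_total, valid_idx):
    """Return (train_idx, test_idx): held-out indices form the test set, the rest train."""
    held_out = {i for i in valid_idx if 0 <= i < n_total}
    train_idx = [i for i in range(n_total) if i not in held_out]
    test_idx = sorted(held_out)
    return train_idx, test_idx


# Toy example: 10 molecules, 3 of them held out for testing.
train_idx, test_idx = split_by_index(10, valid_idx=[7, 2, 5])
print(len(train_idx), len(test_idx))
```

Using a fixed index list rather than a random split keeps the evaluation set identical across runs and across implementations.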

3.2 Constraint Construction

This case learns chemical bond constraints from data, so, following the PaddleScience API structure, the built-in SupervisedConstraint is used to construct the supervised constraint. Before defining the constraint, the parameters used for data loading must first be specified.

examples/moflow/moflow_train.py
# set train dataloader config
train_dataloader_cfg = {
    "dataset": {
        "name": "MOlFLOWDataset",
        "file_path": cfg.FILE_PATH,
        "data_name": cfg.data_name,
        "mode": cfg.mode,
        "valid_idx": valid_idx,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.get(cfg.data_name).label_keys,
        "smiles_col": cfg.get(cfg.data_name).smiles_col,
        "transform_fn": transform_fn,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": True,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "num_workers": cfg.TRAIN.num_workers,
}

Here, the "dataset" field sets the Dataset class name to MOlFLOWDataset and the "sampler" field sets the Sampler class name to BatchSampler; batch_size is set to 256 and num_workers to 8.

The code for defining supervised constraints is as follows:

examples/moflow/moflow_train.py
# set constraint
output_keys = cfg.MODEL.output_keys
sup_constraint = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    ppsci.loss.FunctionalLoss(model.log_prob_loss),
    {key: (lambda out, k=key: out[k]) for key in output_keys},
    name="Sup_constraint",
)

constraint = {sup_constraint.name: sup_constraint}
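The `lambda out, k=key: out[k]` pattern in the constraint above relies on a Python idiom: binding the loop variable as a default argument. The standalone sketch below, using a plain dictionary as a stand-in for the model output, shows what would happen without it.

```python
output_keys = ["output", "sum_log_det"]

# Late binding: each lambda looks up `key` only when called, so all of them
# end up using the last value the comprehension assigned to `key`.
late = {key: (lambda out: out[key]) for key in output_keys}

# Default-argument binding: `k=key` captures the current key for each lambda.
bound = {key: (lambda out, k=key: out[k]) for key in output_keys}

out = {"output": 1.0, "sum_log_det": 2.0}  # stand-in for a model output dict
late_values = {name: fn(out) for name, fn in late.items()}
bound_values = {name: fn(out) for name, fn in bound.items()}

print(late_values)   # both entries read out["sum_log_det"]
print(bound_values)  # each entry reads its own key
```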

3.3 Model Construction

In this case, the drug molecule generation model is implemented with the MoFlowNet network. The model is encapsulated in the standard PaddleScience code format and calls the flow, Glow, and other sub-models separately. The model is built as follows:

examples/moflow/moflow_train.py
# set model
model_cfg = dict(cfg.MODEL)
model_cfg.update({"hyper_params": model_params})
model = ppsci.arch.MoFlowNet(**model_cfg)

Model network parameter configuration is as follows:

examples/moflow/moflow_train.py
# set training hyper-parameters
b_hidden_ch = cfg.get(cfg.data_name).b_hidden_ch
a_hidden_gnn = cfg.get(cfg.data_name).a_hidden_gnn
a_hidden_lin = cfg.get(cfg.data_name).a_hidden_lin
mask_row_size_list = list(cfg.get(cfg.data_name).mask_row_size_list)
mask_row_stride_list = list(cfg.get(cfg.data_name).mask_row_stride_list)
a_n_type = len(cfg.get(cfg.data_name).atomic_num_list)
atomic_num_list = list(cfg.get(cfg.data_name).atomic_num_list)

model_params = Hyperparameters(
    b_n_type=cfg.get(cfg.data_name).b_n_type,
    b_n_flow=cfg.get(cfg.data_name).b_n_flow,
    b_n_block=cfg.get(cfg.data_name).b_n_block,
    b_n_squeeze=cfg.get(cfg.data_name).b_n_squeeze,
    b_hidden_ch=b_hidden_ch,
    b_affine=True,
    b_conv_lu=cfg.get(cfg.data_name).b_conv_lu,
    a_n_node=cfg.get(cfg.data_name).a_n_node,
    a_n_type=a_n_type,
    a_hidden_gnn=a_hidden_gnn,
    a_hidden_lin=a_hidden_lin,
    a_n_flow=cfg.get(cfg.data_name).a_n_flow,
    a_n_block=cfg.get(cfg.data_name).a_n_block,
    mask_row_size_list=mask_row_size_list,
    mask_row_stride_list=mask_row_stride_list,
    a_affine=True,
    learn_dist=cfg.get(cfg.data_name).learn_dist,
    seed=cfg.seed,
    noise_scale=cfg.get(cfg.data_name).noise_scale,
)

logger.info("Model params:\n" + tabulate(model_params.print()))

Parameters are set through the configuration file as follows:

examples/moflow/conf/moflow_train.yaml
# general settings
mode: train # running mode: train/eval
data_name: qm9 # data select:qm9/zinc250k
seed: 1
output_dir: ${hydra:run.dir}
log_freq: 20

# set training hyper-parameters
qm9:
  b_n_flow: 10
  b_n_block: 1
  b_hidden_ch: [128,128]
  a_n_flow: 27
  a_n_block: 1
  a_hidden_gnn: [64]
  a_hidden_lin: [128,64]
  mask_row_size_list: [1]
  mask_row_stride_list: [1]
  learn_dist: True
  noise_scale: 0.6
  b_conv_lu: 1
  atomic_num_list: [6, 7, 8, 9, 0]
  b_n_type: 4
  b_n_squeeze: 3
  a_n_node: 9
  valid_idx: valid_idx_qm9.json
  label_keys: ['A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
  smiles_col: SMILES1

zinc250k:
  b_n_flow: 10
  b_n_block: 1
  b_hidden_ch: [512,512]
  a_n_flow: 38
  a_n_block: 1
  a_hidden_gnn: [256]
  a_hidden_lin: [512,64]
  mask_row_size_list: [1]
  mask_row_stride_list: [1]
  learn_dist: True
  noise_scale: 0.6
  b_conv_lu: 2
  atomic_num_list: [6, 7, 8, 9, 15, 16, 17, 35, 53, 0]
  b_n_type: 4
  b_n_squeeze: 19
  a_n_node: 38
  valid_idx: valid_idx_zinc.json
  label_keys: ['logP', 'qed', 'SAS']
  smiles_col: smiles

# set data path
FILE_PATH: ./datasets/moflow

# model settings
MODEL:
  input_keys: ["nodes", "edges"]
  output_keys: ["output", "sum_log_det"]

Here, data_name selects the dataset, and the block of network parameters with the matching name is then used. input_keys and output_keys are the names of the input and output variables of the network model respectively. hyper_params holds the network parameters for the selected dataset and is added to the model configuration after the dataset is chosen, so that models for different datasets are built in a uniform way. The model is trained with its own custom loss function.
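The per-dataset lookup used throughout the training script (`cfg.get(cfg.data_name)`) can be mimicked with plain dictionaries. The sketch below is an illustration only: it copies a few values from the yaml above into a hypothetical `config` dict and uses a hypothetical `select_hyper_params` helper, without Hydra or PaddleScience.

```python
config = {
    "data_name": "qm9",
    "qm9": {
        "a_n_node": 9,
        "b_n_squeeze": 3,
        "atomic_num_list": [6, 7, 8, 9, 0],
    },
    "zinc250k": {
        "a_n_node": 38,
        "b_n_squeeze": 19,
        "atomic_num_list": [6, 7, 8, 9, 15, 16, 17, 35, 53, 0],
    },
}


def select_hyper_params(cfg, data_name=None):
    """Pick the block named by data_name and derive a_n_type from the atom list."""
    name = data_name or cfg["data_name"]
    section = cfg[name]
    return {
        "data_name": name,
        "a_n_node": section["a_n_node"],
        "a_n_type": len(section["atomic_num_list"]),
        "b_n_squeeze": section["b_n_squeeze"],
    }


qm9_params = select_hyper_params(config)
zinc_params = select_hyper_params(config, data_name="zinc250k")
print(qm9_params)
print(zinc_params)
```

Overriding `data_name=zinc250k` on the command line has the same effect as the second call: the rest of the script is unchanged and only the selected parameter block differs.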

3.4 Learning Rate and Optimizer Construction

The learning rate in this case is set to 0.001 and the optimizer is Adam, expressed in PaddleScience code as follows:

examples/moflow/moflow_train.py
# init optimizer and lr scheduler
optimizer = ppsci.optimizer.Adam(cfg.TRAIN.learning_rate)(model)

3.5 Validator Construction

During training, the current model is evaluated on the validation set at fixed epoch intervals, which requires constructing a validator with SupervisedValidator. The code is as follows:

examples/moflow/moflow_train.py
# set eval dataloader config
eval_dataloader_cfg = {
    "dataset": {
        "name": "MOlFLOWDataset",
        "file_path": cfg.FILE_PATH,
        "data_name": cfg.data_name,
        "mode": "eval",
        "valid_idx": valid_idx,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.get(cfg.data_name).label_keys,
        "smiles_col": cfg.get(cfg.data_name).smiles_col,
        "transform_fn": transform_fn,
    },
    "batch_size": cfg.EVAL.batch_size,
}

# set validator
sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    ppsci.loss.FunctionalLoss(model.log_prob_loss),
    {key: (lambda out, k=key: out[k]) for key in output_keys},
    metric={
        "Valid": ppsci.metric.FunctionalMetric(
            eval_func(model, cfg.EVAL.batch_size, atomic_num_list)
        )
    },
    name="Sup_Validator",
)
validator = {sup_validator.name: sup_validator}

The evaluation metric uses a custom function that regenerates molecules from their latent vectors and evaluates the regenerated molecules independently. The custom evaluation metrics used here are valid, unique, and abs_unique.
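The three metric names can be illustrated with a small sketch. The definitions below are assumptions based on how such metrics are commonly computed for molecule generators (percentages of valid molecules, of distinct molecules among the valid ones, and of distinct molecules among all generated ones); the actual `eval_func` may differ, and `generation_metrics` is a hypothetical helper.

```python
def generation_metrics(smiles_list):
    """smiles_list holds one entry per generated molecule, None where decoding failed."""
    n = len(smiles_list)
    valid = [s for s in smiles_list if s is not None]
    unique = set(valid)
    return {
        # Assumed definitions, expressed as percentages.
        "valid": 100.0 * len(valid) / n if n else 0.0,
        "unique": 100.0 * len(unique) / len(valid) if valid else 0.0,
        "abs_unique": 100.0 * len(unique) / n if n else 0.0,
    }


metrics = generation_metrics(["CCO", "CCO", None, "C=O"])
print(metrics)
```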

3.6 Model Training and Evaluation

After completing the above settings, you only need to pass the instantiated objects to ppsci.solver.Solver in order, and then start training and evaluation.

examples/moflow/moflow_train.py
# initialize solver
solver = ppsci.solver.Solver(
    model,
    constraint,
    cfg.output_dir,
    optimizer,
    None,
    cfg.TRAIN.epochs,
    ITERS_PER_EPOCH,
    seed=cfg.seed,
    validator=validator,
    save_freq=cfg.TRAIN.save_freq,
    eval_during_train=cfg.TRAIN.eval_during_train,
    eval_freq=cfg.TRAIN.eval_freq,
    compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
    eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)
# train model
solver.train()

# validation for training
solver.eval()

3.7 Model Generation Evaluation Construction

Several evaluation methods are provided for the models trained on each dataset. The generation ability of the model is evaluated comprehensively through reconstruction, random generation, and interpolation generation. Each method has its own parameter configuration. The parameter configuration file is as follows:

examples/moflow/conf/moflow_test.yaml
# evaluation settings
EVAL:
  pretrained_model_path: null
  batch_size: 256
  num_workers: 0
  reconstruct: false
  int2point: false
  intgrid: false
  inter_times: 5
  correct_validity: true
  temperature: 1.0
  delta: 0.1
  n_experiments:
  save_fig: true

EVAL_mode: Intergrid #select EVAL_mode: Reconstruct/Random/Inter2point/Intergrid

Reconstruct: # reconstruction generation: reconstruct the molecules of the selected dataset
  batch_size: 256
  reconstruct: true
  n_experiments: 0

Random: # random generation: sample from the latent space for the selected dataset; 10000 samples, generated 5 times
  batch_size: 10000
  temperature: 0.85
  delta: 0.05
  n_experiments: 5
  save_fig: false
  correct_validity: true

Inter2point: # interpolation in the latent space between two molecules, visualized as generated molecular graphs
  batch_size: 1000
  int2point: true
  temperature: 0.65
  inter_times: 50
  correct_validity: true
  n_experiments: 0

Intergrid: # interpolation in the latent space over a molecular grid, visualized as generated molecular graphs
  batch_size: 1000
  temperature: 0.65
  delta: 5
  intgrid: true
  inter_times: 40
  correct_validity: true

Here, EVAL_mode selects the evaluation mode, and each mode evaluates the model differently. Reconstruct (reconstruction generation) reconstructs the drug molecules in the selected dataset. Random (random generation) samples randomly from the latent space for the selected dataset; with the settings above it generates 10000 samples in each of 5 runs. Inter2point (intermolecular interpolation generation) interpolates in the latent space between two molecules and visualizes the generated molecular graphs. Intergrid (molecular grid interpolation generation) interpolates in the latent space over a grid of molecules and visualizes the generated molecular graphs. Both interpolation modes store the newly generated molecules as images. The parameters of each mode, including result storage and the number of generated molecules, can be adjusted as needed. The rest of the configuration is the same as for training. When using models trained on different datasets, remember to change the data name and check the location of the pre-trained model.
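The latent-space operations behind the Random and Inter2point modes can be sketched without the model. The code below is an illustration under assumptions: `temperature` is taken to scale the standard deviation of the Gaussian prior, `inter_times` is taken to be the number of interpolation steps, and decoding the latent vectors back to molecules (the model's reverse pass) is left out.

```python
import numpy as np


def sample_latent(batch_size, z_dim, temperature, rng):
    """Draw z ~ N(0, temperature^2 * I); a lower temperature stays closer to the mode."""
    return temperature * rng.standard_normal((batch_size, z_dim))


def interpolate_two_points(z0, z1, inter_times):
    """Return inter_times + 1 evenly spaced latent vectors from z0 to z1 (inclusive)."""
    alphas = np.linspace(0.0, 1.0, inter_times + 1).reshape(-1, 1)
    return (1.0 - alphas) * z0 + alphas * z1


rng = np.random.default_rng(0)
z = sample_latent(batch_size=2, z_dim=8, temperature=0.65, rng=rng)
path = interpolate_two_points(z[0], z[1], inter_times=50)
print(path.shape)
```

Each row of `path` would then be decoded into a molecule, so neighboring rows are expected to yield structurally similar molecules if the latent space is smooth.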

The code for building the evaluator is:

examples/moflow/test_generate.py
dataloader_cfg = {
    "dataset": {
        "name": "MOlFLOWDataset",
        "file_path": cfg.FILE_PATH,
        "data_name": cfg.data_name,
        "mode": cfg.mode,
        "valid_idx": valid_idx,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.get(cfg.data_name).label_keys,
        "smiles_col": cfg.get(cfg.data_name).smiles_col,
        "transform_fn": transform_fn,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": True,
    },
    "batch_size": cfg.EVAL.batch_size,
    "num_workers": cfg.EVAL.num_workers,
}

test = ppsci.data.dataset.build_dataset(dataloader_cfg["dataset"])
dataloader_cfg["dataset"].update({"mode": "train"})
train = ppsci.data.dataset.build_dataset(dataloader_cfg["dataset"])
logger.info(
    "{} in total, {}  training data, {}  testing data, {} batchsize, train/batchsize {}".format(
        len(train) + len(test),
        len(train),
        len(test),
        batch_size,
        len(train) / batch_size,
    )
)

if cfg.EVAL.reconstruct:
    train_dataloader = ppsci.data.build_dataloader(train, dataloader_cfg)
    reconstruction_rate_list = []
    max_iter = len(train_dataloader)
    input_keys = cfg.MODEL.input_keys
    output_keys = cfg.MODEL.output_keys
    for i, batch in enumerate(train_dataloader, start=0):
        output_dict = model(batch[0])
        x = batch[0][input_keys[0]]
        adj = batch[0][input_keys[1]]
        z = output_dict[output_keys[0]]
        z0 = z[0].reshape([tuple(z[0].shape)[0], -1])
        z1 = z[1].reshape([tuple(z[1].shape)[0], -1])
        adj_rev, x_rev = model.reverse(paddle.concat(x=[z0, z1], axis=1))
        reverse_smiles = adj_to_smiles(adj_rev.cpu(), x_rev.cpu(), atomic_num_list)
        train_smiles = adj_to_smiles(adj.cpu(), x.cpu(), atomic_num_list)
        lb = np.array([int(a != b) for a, b in zip(train_smiles, reverse_smiles)])
        idx = np.where(lb)[0]
        if len(idx) > 0:
            for k in idx:
                logger.info(
                    "{}, train: {}, reverse: {}".format(
                        i * batch_size + k, train_smiles[k], reverse_smiles[k]
                    )
                )
        reconstruction_rate = 1.0 - lb.mean()
        reconstruction_rate_list.append(reconstruction_rate)
        logger.message(
            "iter/total: {}/{}, reconstruction_rate:{}".format(
                i, max_iter, reconstruction_rate
            )
        )
    reconstruction_rate_total = np.array(reconstruction_rate_list).mean()
    logger.message(
        "reconstruction_rate for all the train data:{} in {}".format(
            reconstruction_rate_total, len(train)
        )
    )
    exit(0)

if cfg.EVAL.int2point:
    inputs = train.input
    labels = train.label
    items = []
    for idx in range(len(train)):
        input_item = [value[idx] for key, value in inputs.items()]
        label_item = [value[idx] for key, value in labels.items()]
        item = input_item + label_item
        item = transform_fn(item)
        items.append(item)
    items = np.array(items, dtype=object).T
    inputs = {key: np.stack(items[i], axis=0) for i, key in enumerate(inputs)}

    mol_smiles = None
    gen_dir = osp.join(cfg.output_dir, cfg.EVAL_mode)
    logger.message("Dump figure in {}".format(gen_dir))
    if not osp.exists(gen_dir):
        os.makedirs(gen_dir)
    for seed in range(cfg.EVAL.inter_times):
        filepath = osp.join(
            gen_dir, "2points_interpolation-2point_molecules_seed{}".format(seed)
        )
        visualize_interpolation_between_2_points(
            filepath,
            model,
            mol_smiles=mol_smiles,
            mols_per_row=15,
            n_interpolation=50,
            atomic_num_list=atomic_num_list,
            seed=seed,
            true_data=inputs,
            data_name=cfg.data_name,
        )
    exit(0)

if cfg.EVAL.intgrid:
    inputs = train.input
    labels = train.label
    items = []
    for idx in range(len(train)):
        input_item = [value[idx] for key, value in inputs.items()]
        label_item = [value[idx] for key, value in labels.items()]
        item = input_item + label_item
        item = transform_fn(item)
        items.append(item)
    items = np.array(items, dtype=object).T
    inputs = {key: np.stack(items[i], axis=0) for i, key in enumerate(inputs)}

    mol_smiles = None
    gen_dir = os.path.join(cfg.output_dir, cfg.EVAL_mode)
    logger.message("Dump figure in {}".format(gen_dir))
    if not os.path.exists(gen_dir):
        os.makedirs(gen_dir)
    for seed in range(cfg.EVAL.inter_times):
        filepath = os.path.join(
            gen_dir, "generated_interpolation-grid_molecules_seed{}".format(seed)
        )
        visualize_interpolation(
            filepath,
            model,
            mol_smiles=mol_smiles,
            mols_per_row=9,
            delta=cfg.EVAL.delta,
            atomic_num_list=atomic_num_list,
            seed=seed,
            true_data=inputs,
            data_name=cfg.data_name,
            keep_duplicate=True,
        )
        filepath = os.path.join(
            gen_dir,
            "generated_interpolation-grid_molecules_seed{}_unique".format(seed),
        )
        visualize_interpolation(
            filepath,
            model,
            mol_smiles=mol_smiles,
            mols_per_row=9,
            delta=cfg.EVAL.delta,
            atomic_num_list=atomic_num_list,
            seed=seed,
            true_data=inputs,
            data_name=cfg.data_name,
            keep_duplicate=False,
        )
    exit(0)

inputs = train.input
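
The reconstruction metric used in the Reconstruct branch above can be illustrated in isolation. The sketch below is a simplified stand-in that compares SMILES strings directly; the SMILES values and the helper name `reconstruction_rate` are hypothetical. It mirrors the logic of the listing: a per-batch rate of one minus the mismatch fraction, then the mean of the per-batch rates.

```python
def reconstruction_rate(original_smiles, reconstructed_smiles):
    """Fraction of molecules whose reconstructed SMILES equals the original."""
    if len(original_smiles) != len(reconstructed_smiles):
        raise ValueError("both lists must have the same length")
    if not original_smiles:
        raise ValueError("at least one molecule is required")
    mismatches = [int(a != b) for a, b in zip(original_smiles, reconstructed_smiles)]
    return 1.0 - sum(mismatches) / len(mismatches)


# Hypothetical batches of (original, reconstructed) SMILES strings.
batches = [
    (["CCO", "CC", "C=O"], ["CCO", "CC", "C=O"]),
    (["CCN", "c1ccccc1"], ["CCN", "C1CCCCC1"]),
]

rates = [reconstruction_rate(orig, rev) for orig, rev in batches]
overall = sum(rates) / len(rates)  # unweighted mean over batches
print(rates, overall)
```

Because the overall figure is an unweighted mean of per-batch rates, a smaller final batch carries the same weight as a full one; weighting by batch size would be a possible refinement if an exact per-molecule rate is needed.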

3.8 Model Optimization Construction

After training, the model can be used for molecular optimization and constrained optimization. An additional MLP is trained to map the latent space to the QED or plogp property; it is then used to search for molecules with improved property values, and constrained optimization additionally enforces a similarity constraint (sim_cutoff in the code below). On the first run, this property model is trained on top of the selected pre-trained model and saved to the specified folder, with the file name prefixed by the property name (qed or plogp). On later runs, the saved property model is loaded and evaluated. The code is as follows:

examples/moflow/optimize_moflow.py
# set dataloader config
dataloader_cfg = {
    "dataset": {
        "name": "MOlFLOWDataset",
        "file_path": cfg.FILE_PATH,
        "data_name": cfg.data_name,
        "mode": cfg.mode,
        "valid_idx": valid_idx,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.get(cfg.data_name).label_keys,
        "smiles_col": cfg.get(cfg.data_name).smiles_col,
        "transform_fn": transform_fn,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": True,
    },
    "batch_size": cfg.OPTIMIZE.batch_size,
    "num_workers": 0,
}

# set model
model_cfg = dict(cfg.MODEL)
model_cfg.update({"hyper_params": model_params})
model = ppsci.arch.MoFlowNet(**model_cfg)
ppsci.utils.save_load.load_pretrain(model, path=cfg.TRAIN.pretrained_model_path)

model_prop_cfg = dict(cfg.MODEL_Prop)
model_prop_cfg.update(
    {
        "model": model,
        "hidden_size": hidden,
    }
)
property_model = ppsci.arch.MoFlowProp(**model_prop_cfg)
train = ppsci.data.dataset.build_dataset(dataloader_cfg["dataset"])
train_dataloader = ppsci.data.build_dataloader(train, dataloader_cfg)
train_idx = train.train_idx
property_model_path = osp.join(
    cfg.output_dir, "{}_model.pdparams".format(cfg.OPTIMIZE.property_name)
)

if not osp.exists(property_model_path):
    logger.message("Training regression model over molecular embedding:")
    property_csv_path = osp.join(
        cfg.FILE_PATH, "{}_property.csv".format(cfg.data_name)
    )
    prop_list = load_property_csv(property_csv_path, normalize=True)
    train_prop = [prop_list[i] for i in train_idx]
    # test_prop = [prop_list[i] for i in valid_idx]

    N = len(train)
    property_model = fit_model(
        property_model,
        train_dataloader,
        train_prop,
        N,
        property_name=cfg.OPTIMIZE.property_name,
        max_epochs=cfg.OPTIMIZE.max_epochs,
        learning_rate=cfg.OPTIMIZE.learning_rate,
        weight_decay=cfg.OPTIMIZE.weight_decay,
    )
    logger.message(
        "saving {} regression model to: {}".format(
            cfg.OPTIMIZE.property_name, property_model_path
        )
    )
    paddle.save(obj=property_model.state_dict(), path=property_model_path)

else:
    logger.message("Loading trained regression model for optimization")
    property_csv_path = osp.join(
        cfg.FILE_PATH, "{}_property.csv".format(cfg.data_name)
    )
    prop_list = load_property_csv(property_csv_path, normalize=True)
    train_prop = [prop_list[i] for i in train_idx]
    # test_prop = [prop_list[i] for i in valid_idx]

    logger.message(
        "loading {} regression model from: {}".format(
            cfg.OPTIMIZE.property_name, property_model_path
        )
    )

    state_dict = paddle.load(path=property_model_path)
    property_model.set_state_dict(state_dict)
    property_model.eval()
    model.eval()
    if cfg.OPTIMIZE.topscore:
        logger.message("Finding top score:")
        find_top_score_smiles(
            model,
            property_model,
            cfg.data_name,
            cfg.OPTIMIZE.property_name,
            train_prop,
            cfg.OPTIMIZE.topk,
            atomic_num_list,
            cfg.OPTIMIZE.debug,
            cfg.output_dir,
        )
    if cfg.OPTIMIZE.consopt:
        logger.message("Constrained optimization:")
        constrain_optimization_smiles(
            model,
            property_model,
            cfg.data_name,
            cfg.OPTIMIZE.property_name,
            train_prop,
            cfg.OPTIMIZE.topk,
            atomic_num_list,
            cfg.OPTIMIZE.debug,
            cfg.output_dir,
            sim_cutoff=cfg.OPTIMIZE.sim_cutoff,
        )
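
The first-run/later-run behaviour of the listing above follows a train-or-load pattern keyed on whether the saved file exists. The sketch below shows that pattern in isolation; it uses a JSON file and a trivial stand-in "model" (the mean of the property values), both of which are assumptions for illustration only and not how MoFlowProp is stored.

```python
import json
import os
import tempfile


def fit_property_model(values):
    """Stand-in for training: store the mean of the property values."""
    return {"mean": sum(values) / len(values)}


def train_or_load(model_path, values):
    """Train and save on the first call; load the saved file afterwards."""
    if not os.path.exists(model_path):
        params = fit_property_model(values)
        with open(model_path, "w") as f:
            json.dump(params, f)
        return "trained", params
    with open(model_path) as f:
        return "loaded", json.load(f)


with tempfile.TemporaryDirectory() as out_dir:
    # Hypothetical file name, prefixed by the property name.
    path = os.path.join(out_dir, "plogp_model.json")
    first = train_or_load(path, [1.0, 2.0, 3.0])
    second = train_or_load(path, [1.0, 2.0, 3.0])

print(first[0], second[0])
```

One consequence of this design is that retraining the property model requires removing the saved file first; in the listing above that file is `{property_name}_model.pdparams` under `cfg.output_dir`.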

The main parameters are similar to those used for training; the optimization parameters are set separately in the configuration file, as follows:

examples/moflow/conf/moflow_optimize.yaml
# optimize settings
OPTIMIZE:
  property_name: plogp # qed/plogp
  batch_size: 256
  topk: 800
  debug: false
  topscore: false
  max_epochs: 3
  learning_rate: 0.001
  weight_decay: 1e-2
  hidden: [16] # Hidden dimension list for output regression
  temperature: 1.0
  consopt: true

4. Complete Training Code

examples/moflow/moflow_train.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from os import path as osp

import hydra
import moflow_transform
import numpy as np
import paddle
from moflow_utils import Hyperparameters
from moflow_utils import check_validity
from omegaconf import DictConfig
from tabulate import tabulate

import ppsci
from ppsci.utils import logger


def infer(model, batch_size=20, temp=0.7, z_mu=None, true_adj=None):
    """generate mols

    Args:
        model (object): MoFlowNet model used for generation, in eval mode
        batch_size (int, optional): Batch size during evaluation per GPU. Defaults to 20.
        temp (float, optional): Temperature of the Gaussian distribution. Defaults to 0.7.
        z_mu (np.ndarray, optional): Latent vector of a molecule. Defaults to None.
        true_adj (paddle.Tensor, optional): True Adjacency. Defaults to None.

    Returns:
        Tuple(paddle.Tensor, paddle.Tensor): Adjacency and nodes
    """
    z_dim = model.b_size + model.a_size
    mu = np.zeros(z_dim)
    sigma_diag = np.ones(z_dim)
    if model.hyper_params.learn_dist:
        if len(model.ln_var) == 1:
            sigma_diag = np.sqrt(np.exp(model.ln_var.item())) * sigma_diag
        elif len(model.ln_var) == 2:
            sigma_diag[: model.b_size] = (
                np.sqrt(np.exp(model.ln_var[0].item())) * sigma_diag[: model.b_size]
            )
            sigma_diag[model.b_size :] = (
                np.sqrt(np.exp(model.ln_var[1].item())) * sigma_diag[model.b_size :]
            )
    sigma = temp * sigma_diag
    with paddle.no_grad():
        if z_mu is not None:
            mu = z_mu
            sigma = 0.01 * np.eye(z_dim)
        z = np.random.normal(mu, sigma, (batch_size, z_dim))
        z = paddle.to_tensor(data=z).astype(paddle.get_default_dtype())
        adj, x = model.reverse(z, true_adj=true_adj)
    return adj, x


class eval_func:
    def __init__(
        self,
        metrics_mode,
        batch_size,
        atomic_num_list,
        *args,
    ):
        super().__init__()
        self.metrics_mode = metrics_mode
        self.batch_size = batch_size
        self.atomic_num_list = atomic_num_list

    def __call__(
        self,
        output_dict,
        label_dict,
    ):
        self.metrics_mode.eval()
        adj, x = infer(self.metrics_mode, self.batch_size)
        validity_info = check_validity(adj, x, self.atomic_num_list)
        self.metrics_mode.train()
        results = dict()
        results["valid"] = validity_info["valid_ratio"]
        results["unique"] = validity_info["unique_ratio"]
        results["abs_unique"] = validity_info["abs_unique_ratio"]
        return results


def train(cfg: DictConfig):
    # set training hyper-parameters
    b_hidden_ch = cfg.get(cfg.data_name).b_hidden_ch
    a_hidden_gnn = cfg.get(cfg.data_name).a_hidden_gnn
    a_hidden_lin = cfg.get(cfg.data_name).a_hidden_lin
    mask_row_size_list = list(cfg.get(cfg.data_name).mask_row_size_list)
    mask_row_stride_list = list(cfg.get(cfg.data_name).mask_row_stride_list)
    a_n_type = len(cfg.get(cfg.data_name).atomic_num_list)
    atomic_num_list = list(cfg.get(cfg.data_name).atomic_num_list)

    model_params = Hyperparameters(
        b_n_type=cfg.get(cfg.data_name).b_n_type,
        b_n_flow=cfg.get(cfg.data_name).b_n_flow,
        b_n_block=cfg.get(cfg.data_name).b_n_block,
        b_n_squeeze=cfg.get(cfg.data_name).b_n_squeeze,
        b_hidden_ch=b_hidden_ch,
        b_affine=True,
        b_conv_lu=cfg.get(cfg.data_name).b_conv_lu,
        a_n_node=cfg.get(cfg.data_name).a_n_node,
        a_n_type=a_n_type,
        a_hidden_gnn=a_hidden_gnn,
        a_hidden_lin=a_hidden_lin,
        a_n_flow=cfg.get(cfg.data_name).a_n_flow,
        a_n_block=cfg.get(cfg.data_name).a_n_block,
        mask_row_size_list=mask_row_size_list,
        mask_row_stride_list=mask_row_stride_list,
        a_affine=True,
        learn_dist=cfg.get(cfg.data_name).learn_dist,
        seed=cfg.seed,
        noise_scale=cfg.get(cfg.data_name).noise_scale,
    )

    logger.info("Model params:\n" + tabulate(model_params.print()))

    # set transforms
    if cfg.data_name == "qm9":
        transform_fn = moflow_transform.transform_fn
    elif cfg.data_name == "zinc250k":
        transform_fn = moflow_transform.transform_fn_zinc250k

    # set select eval data
    valid_idx_path = osp.join(cfg.FILE_PATH, cfg.get(cfg.data_name).valid_idx)
    valid_idx = moflow_transform.get_val_ids(valid_idx_path, cfg.data_name)

    # set train dataloader config
    train_dataloader_cfg = {
        "dataset": {
            "name": "MOlFLOWDataset",
            "file_path": cfg.FILE_PATH,
            "data_name": cfg.data_name,
            "mode": cfg.mode,
            "valid_idx": valid_idx,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.get(cfg.data_name).label_keys,
            "smiles_col": cfg.get(cfg.data_name).smiles_col,
            "transform_fn": transform_fn,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": cfg.TRAIN.num_workers,
    }

    # set model
    model_cfg = dict(cfg.MODEL)
    model_cfg.update({"hyper_params": model_params})
    model = ppsci.arch.MoFlowNet(**model_cfg)

    # set constraint
    output_keys = cfg.MODEL.output_keys
    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        ppsci.loss.FunctionalLoss(model.log_prob_loss),
        {key: (lambda out, k=key: out[k]) for key in output_keys},
        name="Sup_constraint",
    )

    constraint = {sup_constraint.name: sup_constraint}

    # set iters_per_epoch by dataloader length
    ITERS_PER_EPOCH = len(sup_constraint.data_loader)

    # init optimizer and lr scheduler
    optimizer = ppsci.optimizer.Adam(cfg.TRAIN.learning_rate)(model)

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "MOlFLOWDataset",
            "file_path": cfg.FILE_PATH,
            "data_name": cfg.data_name,
            "mode": "eval",
            "valid_idx": valid_idx,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.get(cfg.data_name).label_keys,
            "smiles_col": cfg.get(cfg.data_name).smiles_col,
            "transform_fn": transform_fn,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.FunctionalLoss(model.log_prob_loss),
        {key: (lambda out, k=key: out[k]) for key in output_keys},
        metric={
            "Valid": ppsci.metric.FunctionalMetric(
                eval_func(model, cfg.EVAL.batch_size, atomic_num_list)
            )
        },
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        cfg.output_dir,
        optimizer,
        None,
        cfg.TRAIN.epochs,
        ITERS_PER_EPOCH,
        seed=cfg.seed,
        validator=validator,
        save_freq=cfg.TRAIN.save_freq,
        eval_during_train=cfg.TRAIN.eval_during_train,
        eval_freq=cfg.TRAIN.eval_freq,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # train model
    solver.train()

    # validation for training
    solver.eval()


@hydra.main(version_base=None, config_path="./conf", config_name="moflow_train.yaml")
def main(cfg: DictConfig):
    train(cfg)


if __name__ == "__main__":
    main()
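
The role of the temperature argument in `infer` can be seen in a small standalone sketch. It assumes a zero-mean Gaussian prior with a single learned log-variance and uses only the standard library; `sample_latent` and `empirical_std` are hypothetical helpers, not functions from the MoFlow code. As in `infer`, the standard deviation is the temperature multiplied by the square root of exp(ln_var).

```python
import math
import random


def sample_latent(z_dim, temp, ln_var=0.0, rng=None):
    """Draw one latent vector with standard deviation temp * sqrt(exp(ln_var))."""
    if rng is None:
        rng = random.Random()
    sigma = temp * math.sqrt(math.exp(ln_var))
    return [rng.gauss(0.0, sigma) for _ in range(z_dim)]


def empirical_std(values):
    """Population standard deviation of a list of floats."""
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))


rng = random.Random(0)  # fixed seed so the run is repeatable
cold = sample_latent(20000, temp=0.5, rng=rng)
hot = sample_latent(20000, temp=1.0, rng=rng)
print(empirical_std(cold), empirical_std(hot))
```

A lower temperature concentrates samples near the mean of the prior, which is why the generation modes above use temperatures below 1.0.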

5. References