Arch (Network Model) Module

ppsci.arch

Arch

Bases: Layer

Base class for Network.

Source code in ppsci/arch/base.py
class Arch(nn.Layer):
    """Base class for Network."""

    input_keys: Tuple[str, ...]
    output_keys: Tuple[str, ...]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._input_transform: Callable[
            [Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]
        ] = None

        self._output_transform: Callable[
            [Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]],
            Dict[str, paddle.Tensor],
        ] = None

    def forward(self, *args, **kwargs):
        raise NotImplementedError("Arch.forward is not implemented")

    @property
    def num_params(self) -> int:
        """Return number of parameters within network.

        Returns:
            int: Number of parameters.
        """
        num = 0
        for name, param in self.named_parameters():
            if hasattr(param, "shape"):
                num += np.prod(list(param.shape))
            else:
                logger.warning(f"{name} has no attribute 'shape'")
        return num

    def concat_to_tensor(
        self, data_dict: Dict[str, paddle.Tensor], keys: Tuple[str, ...], axis=-1
    ) -> Tuple[paddle.Tensor, ...]:
        """Concatenate tensors from dict in the order of given keys.

        Args:
            data_dict (Dict[str, paddle.Tensor]): Dict contains tensor.
            keys (Tuple[str, ...]): Keys tensor fetched from.
            axis (int, optional): Axis concatenate at. Defaults to -1.

        Returns:
            Tuple[paddle.Tensor, ...]: Concatenated tensor.
        """
        if len(keys) == 1:
            return data_dict[keys[0]]
        data = [data_dict[key] for key in keys]
        return paddle.concat(data, axis)

    def split_to_dict(
        self, data_tensor: paddle.Tensor, keys: Tuple[str, ...], axis=-1
    ) -> Dict[str, paddle.Tensor]:
        """Split tensor and wrap into a dict by given keys.

        Args:
            data_tensor (paddle.Tensor): Tensor to be split.
            keys (Tuple[str, ...]): Keys tensor mapping to.
            axis (int, optional): Axis split at. Defaults to -1.

        Returns:
            Dict[str, paddle.Tensor]: Dict contains tensor.
        """
        if len(keys) == 1:
            return {keys[0]: data_tensor}
        data = paddle.split(data_tensor, len(keys), axis=axis)
        return {key: data[i] for i, key in enumerate(keys)}

    def register_input_transform(
        self,
        transform: Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]],
    ):
        """Register input transform.

        Args:
            transform (Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
                Input transform of network, receive a single tensor dict and return a single tensor dict.
        """
        self._input_transform = transform

    def register_output_transform(
        self,
        transform: Callable[
            [Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]],
            Dict[str, paddle.Tensor],
        ],
    ):
        """Register output transform.

        Args:
            transform (Callable[[Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
                Output transform of network, receive two single tensor dict(raw input
                and raw output) and return a single tensor dict(transformed output).
        """
        self._output_transform = transform

    def __str__(self):
        num_fc = 0
        num_conv = 0
        num_bn = 0
        for layer in self.sublayers(include_self=True):
            if isinstance(layer, nn.Linear):
                num_fc += 1
            elif isinstance(layer, (nn.Conv2D, nn.Conv3D, nn.Conv1D)):
                num_conv += 1
            elif isinstance(layer, (nn.BatchNorm, nn.BatchNorm2D, nn.BatchNorm3D)):
                num_bn += 1

        return ", ".join(
            [
                self.__class__.__name__,
                f"input_keys = {self.input_keys}",
                f"output_keys = {self.output_keys}",
                f"num_fc = {num_fc}",
                f"num_conv = {num_conv}",
                f"num_bn = {num_bn}",
                f"num_params = {self.num_params}",
            ]
        )
num_params: int property

Return number of parameters within network.

Returns:

| Type | Description |
| --- | --- |
| int | Number of parameters. |
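
For instance, a quick way to inspect a model's size (a minimal usage sketch, assuming ppsci is installed):

import ppsci

model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 5, 128)
print(model.num_params)  # total parameter count, accumulated from each param.shape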

concat_to_tensor(data_dict, keys, axis=-1)

Concatenate tensors from dict in the order of given keys.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_dict | Dict[str, paddle.Tensor] | Dict containing tensors. | required |
| keys | Tuple[str, ...] | Keys to fetch tensors from. | required |
| axis | int | Axis to concatenate along. Defaults to -1. | -1 |

Returns:

| Type | Description |
| --- | --- |
| Tuple[paddle.Tensor, ...] | Concatenated tensor. |

Source code in ppsci/arch/base.py
def concat_to_tensor(
    self, data_dict: Dict[str, paddle.Tensor], keys: Tuple[str, ...], axis=-1
) -> Tuple[paddle.Tensor, ...]:
    """Concatenate tensors from dict in the order of given keys.

    Args:
        data_dict (Dict[str, paddle.Tensor]): Dict contains tensor.
        keys (Tuple[str, ...]): Keys tensor fetched from.
        axis (int, optional): Axis concatenate at. Defaults to -1.

    Returns:
        Tuple[paddle.Tensor, ...]: Concatenated tensor.
    """
    if len(keys) == 1:
        return data_dict[keys[0]]
    data = [data_dict[key] for key in keys]
    return paddle.concat(data, axis)
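
A minimal sketch of the concatenation behavior, with hypothetical shapes:

import paddle
import ppsci

model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 5, 128)
data = {"x": paddle.rand([16, 1]), "y": paddle.rand([16, 1])}
xy = model.concat_to_tensor(data, ("x", "y"), axis=-1)
print(xy.shape)  # [16, 2]: tensors concatenated in the order of the given keys
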
register_input_transform(transform)

Register input transform.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| transform | Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]] | Input transform of the network; receives a single tensor dict and returns a single tensor dict. | required |

Source code in ppsci/arch/base.py
def register_input_transform(
    self,
    transform: Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]],
):
    """Register input transform.

    Args:
        transform (Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
            Input transform of network, receive a single tensor dict and return a single tensor dict.
    """
    self._input_transform = transform
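
For example, an input transform that shifts one coordinate before the network sees it (a minimal sketch; the keys and the shift are hypothetical):

import ppsci

model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 5, 128)

def shift_input(in_dict):
    # receives the raw input dict, returns the transformed input dict
    return {"x": in_dict["x"] + 1.0, "y": in_dict["y"]}

model.register_input_transform(shift_input)
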
register_output_transform(transform)

Register output transform.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| transform | Callable[[Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]] | Output transform of the network; receives two tensor dicts (raw input and raw output) and returns a single tensor dict (transformed output). | required |

Source code in ppsci/arch/base.py
def register_output_transform(
    self,
    transform: Callable[
        [Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]],
        Dict[str, paddle.Tensor],
    ],
):
    """Register output transform.

    Args:
        transform (Callable[[Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
            Output transform of network, receive two single tensor dict(raw input
            and raw output) and return a single tensor dict(transformed output).
    """
    self._output_transform = transform
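
A matching sketch for the output side, e.g. multiplying the prediction by a coordinate so the output vanishes at x = 0 (hypothetical keys and constraint):

import ppsci

model = ppsci.arch.MLP(("x",), ("u",), 5, 128)

def enforce_zero_at_origin(in_dict, out_dict):
    # receives (raw input dict, raw output dict), returns the transformed output dict
    return {"u": out_dict["u"] * in_dict["x"]}

model.register_output_transform(enforce_zero_at_origin)
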
split_to_dict(data_tensor, keys, axis=-1)

Split tensor and wrap into a dict by given keys.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_tensor | paddle.Tensor | Tensor to be split. | required |
| keys | Tuple[str, ...] | Keys to map the split tensors to. | required |
| axis | int | Axis to split along. Defaults to -1. | -1 |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, paddle.Tensor] | Dict containing the split tensors. |

Source code in ppsci/arch/base.py
def split_to_dict(
    self, data_tensor: paddle.Tensor, keys: Tuple[str, ...], axis=-1
) -> Dict[str, paddle.Tensor]:
    """Split tensor and wrap into a dict by given keys.

    Args:
        data_tensor (paddle.Tensor): Tensor to be split.
        keys (Tuple[str, ...]): Keys tensor mapping to.
        axis (int, optional): Axis split at. Defaults to -1.

    Returns:
        Dict[str, paddle.Tensor]: Dict contains tensor.
    """
    if len(keys) == 1:
        return {keys[0]: data_tensor}
    data = paddle.split(data_tensor, len(keys), axis=axis)
    return {key: data[i] for i, key in enumerate(keys)}
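
split_to_dict is the inverse of concat_to_tensor; a minimal round-trip sketch with hypothetical shapes:

import paddle
import ppsci

model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 5, 128)
data = {"x": paddle.rand([16, 1]), "y": paddle.rand([16, 1])}
xy = model.concat_to_tensor(data, ("x", "y"), axis=-1)   # shape [16, 2]
restored = model.split_to_dict(xy, ("x", "y"), axis=-1)  # {"x": [16, 1], "y": [16, 1]}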

AMGNet

Bases: Layer

A Multi-scale Graph neural Network model based on Encoder-Process-Decoder structure for flow field prediction.

https://doi.org/10.1080/09540091.2022.2131737

Code reference: https://github.com/baoshiaijhin/amgnet

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_keys | Tuple[str, ...] | Name of input keys, such as ("input",). | required |
| output_keys | Tuple[str, ...] | Name of output keys, such as ("pred",). | required |
| input_dim | int | Number of input dimensions. | required |
| output_dim | int | Number of output dimensions. | required |
| latent_dim | int | Number of hidden (feature) dimensions. | required |
| num_layers | int | Number of layers. | required |
| message_passing_aggregator | Literal["sum"] | Message aggregation method in the graph. Only "sum" is available for now. | required |
| message_passing_steps | int | Number of message passing steps in the graph. | required |
| speed | Literal["norm", "fast"] | Whether to use the vanilla ("norm") or fast method for graph_connectivity computation. | required |

Examples:

>>> import ppsci
>>> model = ppsci.arch.AMGNet(("input", ), ("pred", ), 5, 3, 64, 2)
Source code in ppsci/arch/amgnet.py
class AMGNet(nn.Layer):
    """A Multi-scale Graph neural Network model
    based on Encoder-Process-Decoder structure for flow field prediction.

    https://doi.org/10.1080/09540091.2022.2131737

    Code reference: https://github.com/baoshiaijhin/amgnet

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input", ).
        output_keys (Tuple[str, ...]): Name of output keys, such as ("pred", ).
        input_dim (int): Number of input dimension.
        output_dim (int): Number of output dimension.
        latent_dim (int): Number of hidden(feature) dimension.
        num_layers (int): Number of layer(s).
        message_passing_aggregator (Literal["sum"]): Message aggregator method in graph.
            Only "sum" available now.
        message_passing_steps (int): Message passing steps in graph.
        speed (str): Whether use vanilla method or fast method for graph_connectivity
            computation.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.AMGNet(("input", ), ("pred", ), 5, 3, 64, 2)
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        input_dim: int,
        output_dim: int,
        latent_dim: int,
        num_layers: int,
        message_passing_aggregator: Literal["sum"],
        message_passing_steps: int,
        speed: Literal["norm", "fast"],
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self._latent_dim = latent_dim
        self.speed = speed
        self._output_dim = output_dim
        self._num_layers = num_layers

        self.encoder = Encoder(input_dim, self._make_mlp, latent_dim=self._latent_dim)
        self.processor = Processor(
            make_mlp=self._make_mlp,
            output_dim=self._latent_dim,
            message_passing_steps=message_passing_steps,
            message_passing_aggregator=message_passing_aggregator,
            use_stochastic_message_passing=False,
        )
        self.post_processor = self._make_mlp(self._latent_dim, 128)
        self.decoder = Decoder(
            make_mlp=functools.partial(self._make_mlp, layer_norm=False),
            output_dim=self._output_dim,
        )

    def forward(self, x: Dict[str, "pgl.Graph"]) -> Dict[str, paddle.Tensor]:
        graphs = x[self.input_keys[0]]
        latent_graph = self.encoder(graphs)
        x, p = self.processor(latent_graph, speed=self.speed)
        node_features = self._spa_compute(x, p)
        pred_field = self.decoder(node_features)
        return {self.output_keys[0]: pred_field}

    def _make_mlp(self, output_dim: int, input_dim: int = 5, layer_norm: bool = True):
        widths = (self._latent_dim,) * self._num_layers + (output_dim,)
        network = FullyConnectedLayer(input_dim, widths)
        if layer_norm:
            network = nn.Sequential(network, nn.LayerNorm(normalized_shape=widths[-1]))
        return network

    def _spa_compute(self, x: List["pgl.Graph"], p):
        j = len(x) - 1
        node_features = x[j].x

        for k in range(1, j + 1):
            pos = p[-k]
            fine_nodes = x[-(k + 1)].pos
            feature = _knn_interpolate(node_features, pos, fine_nodes)
            node_features = x[-(k + 1)].x + feature
            node_features = self.post_processor(node_features)

        return node_features
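
Note that the doctest above omits three required arguments (message_passing_aggregator, message_passing_steps and speed have no defaults in the signature), so a full construction looks more like the sketch below; only "sum" aggregation is supported, the step count is illustrative, and the pgl dependency is assumed to be installed:

import ppsci

model = ppsci.arch.AMGNet(
    ("input",), ("pred",), 5, 3, 64, 2,
    message_passing_aggregator="sum",
    message_passing_steps=6,  # illustrative value
    speed="norm",
)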

MLP

Bases: Arch

Multi layer perceptron network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_keys | Tuple[str, ...] | Name of input keys, such as ("x", "y", "z"). | required |
| output_keys | Tuple[str, ...] | Name of output keys, such as ("u", "v", "w"). | required |
| num_layers | int | Number of hidden layers. | required |
| hidden_size | Union[int, Tuple[int, ...]] | Hidden layer size. An integer applied to all layers, or a list of integers specifying each layer's size. | required |
| activation | str | Name of activation function. Defaults to "tanh". | 'tanh' |
| skip_connection | bool | Whether to use skip connections. Defaults to False. | False |
| weight_norm | bool | Whether to apply weight norm on parameter(s). Defaults to False. | False |
| input_dim | Optional[int] | Number of input dimensions. Defaults to None. | None |
| output_dim | Optional[int] | Number of output dimensions. Defaults to None. | None |

Examples:

>>> import ppsci
>>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 5, 128)
Source code in ppsci/arch/mlp.py
class MLP(base.Arch):
    """Multi layer perceptron network.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z").
        output_keys (Tuple[str, ...]): Name of output keys, such as ("u", "v", "w").
        num_layers (int): Number of hidden layers.
        hidden_size (Union[int, Tuple[int, ...]]): Number of hidden size.
            An integer for all layers, or list of integer specify each layer's size.
        activation (str, optional): Name of activation function. Defaults to "tanh".
        skip_connection (bool, optional): Whether to use skip connection. Defaults to False.
        weight_norm (bool, optional): Whether to apply weight norm on parameter(s). Defaults to False.
        input_dim (Optional[int]): Number of input's dimension. Defaults to None.
        output_dim (Optional[int]): Number of output's dimension. Defaults to None.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 5, 128)
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        num_layers: int,
        hidden_size: Union[int, Tuple[int, ...]],
        activation: str = "tanh",
        skip_connection: bool = False,
        weight_norm: bool = False,
        input_dim: Optional[int] = None,
        output_dim: Optional[int] = None,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.linears = []
        self.acts = []
        if isinstance(hidden_size, (tuple, list)):
            if num_layers is not None:
                raise ValueError(
                    "num_layers should be None when hidden_size is specified"
                )
        elif isinstance(hidden_size, int):
            if not isinstance(num_layers, int):
                raise ValueError(
                    "num_layers should be an int when hidden_size is an int"
                )
            hidden_size = [hidden_size] * num_layers
        else:
            raise ValueError(
                f"hidden_size should be list of int or int"
                f"but got {type(hidden_size)}"
            )

        # initialize FC layer(s)
        cur_size = len(self.input_keys) if input_dim is None else input_dim
        for i, _size in enumerate(hidden_size):
            self.linears.append(
                WeightNormLinear(cur_size, _size)
                if weight_norm
                else nn.Linear(cur_size, _size)
            )
            # initialize activation function
            self.acts.append(
                act_mod.get_activation(activation)
                if activation != "stan"
                else act_mod.get_activation(activation)(_size)
            )
            # special initialization for certain activation
            # TODO: Adapt code below to a more elegant style
            if activation == "siren":
                if i == 0:
                    act_mod.Siren.init_for_first_layer(self.linears[-1])
                else:
                    act_mod.Siren.init_for_hidden_layer(self.linears[-1])

            cur_size = _size

        self.linears = nn.LayerList(self.linears)
        self.acts = nn.LayerList(self.acts)
        self.last_fc = nn.Linear(
            cur_size,
            len(self.output_keys) if output_dim is None else output_dim,
        )

        self.skip_connection = skip_connection

    def forward_tensor(self, x):
        y = x
        skip = None
        for i, linear in enumerate(self.linears):
            y = linear(y)
            if self.skip_connection and i % 2 == 0:
                if skip is not None:
                    skip = y
                    y = y + skip
                else:
                    skip = y
            y = self.acts[i](y)

        y = self.last_fc(y)

        return y

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        y = self.concat_to_tensor(x, self.input_keys, axis=-1)
        y = self.forward_tensor(y)
        y = self.split_to_dict(y, self.output_keys, axis=-1)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y
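
A minimal end-to-end sketch of the dict-in / dict-out convention (hypothetical batch size):

import paddle
import ppsci

model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 5, 128)
input_dict = {"x": paddle.rand([16, 1]), "y": paddle.rand([16, 1])}
output_dict = model(input_dict)
print(output_dict["u"].shape, output_dict["v"].shape)  # [16, 1] [16, 1]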

DeepONet

Bases: Arch

Deep operator network.

Lu et al. Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nat Mach Intell, 2021. https://doi.org/10.1038/s42256-021-00302-5

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| u_key | str | Name of function data for input function u(x). | required |
| y_key | str | Name of location data for input function G(u). | required |
| G_key | str | Output name of predicted G(u)(y). | required |
| num_loc | int | Number of sampled u(x), i.e. m in the paper. | required |
| num_features | int | Number of features extracted from u(x), same for y. | required |
| branch_num_layers | int | Number of hidden layers of the branch net. | required |
| trunk_num_layers | int | Number of hidden layers of the trunk net. | required |
| branch_hidden_size | Union[int, Tuple[int, ...]] | Hidden layer size of the branch net. An integer applied to all layers, or a list of integers specifying each layer's size. | required |
| trunk_hidden_size | Union[int, Tuple[int, ...]] | Hidden layer size of the trunk net. An integer applied to all layers, or a list of integers specifying each layer's size. | required |
| branch_skip_connection | bool | Whether to use skip connections for the branch net. Defaults to False. | False |
| trunk_skip_connection | bool | Whether to use skip connections for the trunk net. Defaults to False. | False |
| branch_activation | str | Name of activation function. Defaults to "tanh". | 'tanh' |
| trunk_activation | str | Name of activation function. Defaults to "tanh". | 'tanh' |
| branch_weight_norm | bool | Whether to apply weight norm on parameter(s) of the branch net. Defaults to False. | False |
| trunk_weight_norm | bool | Whether to apply weight norm on parameter(s) of the trunk net. Defaults to False. | False |
| use_bias | bool | Whether to add a bias to predicted G(u)(y). Defaults to True. | True |

Examples:

>>> import ppsci
>>> model = ppsci.arch.DeepONet(
...     "u", "y", "G",
...     100, 40,
...     1, 1,
...     40, 40,
...     branch_activation="relu", trunk_activation="relu",
...     use_bias=True,
... )
Source code in ppsci/arch/deeponet.py
class DeepONet(base.Arch):
    """Deep operator network.

    [Lu et al. Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nat Mach Intell, 2021.](https://doi.org/10.1038/s42256-021-00302-5)

    Args:
        u_key (str): Name of function data for input function u(x).
        y_key (str): Name of location data for input function G(u).
        G_key (str): Output name of predicted G(u)(y).
        num_loc (int): Number of sampled u(x), i.e. `m` in paper.
        num_features (int): Number of features extracted from u(x), same for y.
        branch_num_layers (int): Number of hidden layers of branch net.
        trunk_num_layers (int): Number of hidden layers of trunk net.
        branch_hidden_size (Union[int, Tuple[int, ...]]): Number of hidden size of branch net.
            An integer for all layers, or list of integer specify each layer's size.
        trunk_hidden_size (Union[int, Tuple[int, ...]]): Number of hidden size of trunk net.
            An integer for all layers, or list of integer specify each layer's size.
        branch_skip_connection (bool, optional): Whether to use skip connection for branch net. Defaults to False.
        trunk_skip_connection (bool, optional): Whether to use skip connection for trunk net. Defaults to False.
        branch_activation (str, optional): Name of activation function. Defaults to "tanh".
        trunk_activation (str, optional): Name of activation function. Defaults to "tanh".
        branch_weight_norm (bool, optional): Whether to apply weight norm on parameter(s) for branch net. Defaults to False.
        trunk_weight_norm (bool, optional): Whether to apply weight norm on parameter(s) for trunk net. Defaults to False.
        use_bias (bool, optional): Whether to add bias on predicted G(u)(y). Defaults to True.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.DeepONet(
        ...     "u", "y", "G",
        ...     100, 40,
        ...     1, 1,
        ...     40, 40,
        ...     branch_activation="relu", trunk_activation="relu",
        ...     use_bias=True,
        ... )
    """

    def __init__(
        self,
        u_key: str,
        y_key: str,
        G_key: str,
        num_loc: int,
        num_features: int,
        branch_num_layers: int,
        trunk_num_layers: int,
        branch_hidden_size: Union[int, Tuple[int, ...]],
        trunk_hidden_size: Union[int, Tuple[int, ...]],
        branch_skip_connection: bool = False,
        trunk_skip_connection: bool = False,
        branch_activation: str = "tanh",
        trunk_activation: str = "tanh",
        branch_weight_norm: bool = False,
        trunk_weight_norm: bool = False,
        use_bias: bool = True,
    ):
        super().__init__()
        self.u_key = u_key
        self.y_key = y_key
        self.input_keys = (u_key, y_key)
        self.output_keys = (G_key,)

        self.branch_net = mlp.MLP(
            (self.u_key,),
            ("b",),
            branch_num_layers,
            branch_hidden_size,
            branch_activation,
            branch_skip_connection,
            branch_weight_norm,
            input_dim=num_loc,
            output_dim=num_features,
        )

        self.trunk_net = mlp.MLP(
            (self.y_key,),
            ("t",),
            trunk_num_layers,
            trunk_hidden_size,
            trunk_activation,
            trunk_skip_connection,
            trunk_weight_norm,
            input_dim=1,
            output_dim=num_features,
        )
        self.trunk_act = act_mod.get_activation(trunk_activation)

        self.use_bias = use_bias
        if use_bias:
            # register bias to parameter for updating in optimizer and storage
            self.b = self.create_parameter(
                shape=(1,),
                attr=nn.initializer.Constant(0.0),
            )

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        # Branch net to encode the input function
        u_features = self.branch_net(x)[self.branch_net.output_keys[0]]

        # Trunk net to encode the domain of the output function
        y_features = self.trunk_net(x)
        y_features = self.trunk_act(y_features[self.trunk_net.output_keys[0]])

        # Dot product
        G_u = paddle.einsum("bi,bi->b", u_features, y_features)  # [batch_size, ]
        G_u = paddle.reshape(G_u, [-1, 1])  # reshape [batch_size, ] to [batch_size, 1]

        # Add bias
        if self.use_bias:
            G_u += self.b

        result_dict = {
            self.output_keys[0]: G_u,
        }
        if self._output_transform is not None:
            result_dict = self._output_transform(x, result_dict)

        return result_dict
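
A minimal forward sketch for the model above (hypothetical batch size): each row of "u" holds one input function sampled at num_loc = 100 sensors, and "y" holds one query location per sample:

import paddle
import ppsci

model = ppsci.arch.DeepONet(
    "u", "y", "G",
    100, 40,
    1, 1,
    40, 40,
    branch_activation="relu", trunk_activation="relu",
    use_bias=True,
)
input_dict = {"u": paddle.rand([8, 100]), "y": paddle.rand([8, 1])}
output_dict = model(input_dict)
print(output_dict["G"].shape)  # [8, 1]: one predicted G(u)(y) per sample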

DeepPhyLSTM

Bases: Arch

DeepPhyLSTM, a physics-informed LSTM network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_size | int | The input size. | required |
| output_size | int | The output size. | required |
| hidden_size | int | The hidden size. Defaults to 100. | 100 |
| model_type | int | The model type, 2 or 3; 2 builds two sub-models, 3 builds three sub-models. Defaults to 2. | 2 |

Examples:

>>> import ppsci
>>> model = ppsci.arch.DeepPhyLSTM(1, 1, 100)
Source code in ppsci/arch/phylstm.py
class DeepPhyLSTM(base.Arch):
    """DeepPhyLSTM init function.

    Args:
        input_size (int): The input size.
        output_size (int): The output size.
        hidden_size (int, optional): The hidden size. Defaults to 100.
        model_type (int, optional): The model type, value is 2 or 3, 2 indicates having two sub-models, 3 indicates having three submodels. Defaults to 2.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.DeepPhyLSTM(1, 1, 100)
    """

    def __init__(self, input_size, output_size, hidden_size=100, model_type=2):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.model_type = model_type

        if self.model_type == 2:
            self.lstm_model = nn.Sequential(
                nn.LSTM(input_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.Linear(hidden_size, 3 * output_size),
            )

            self.lstm_model_f = nn.Sequential(
                nn.LSTM(3 * output_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.Linear(hidden_size, output_size),
            )
        elif self.model_type == 3:
            self.lstm_model = nn.Sequential(
                nn.LSTM(1, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 3 * output_size),
            )

            self.lstm_model_f = nn.Sequential(
                nn.LSTM(3 * output_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, output_size),
            )

            self.lstm_model_g = nn.Sequential(
                nn.LSTM(2 * output_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, output_size),
            )
        else:
            raise ValueError(f"model_type should be 2 or 3, but got {model_type})")

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        if self.model_type == 2:
            result_dict = self._forward_type_2(x)
        elif self.model_type == 3:
            result_dict = self._forward_type_3(x)
        if self._output_transform is not None:
            result_dict = self._output_transform(x, result_dict)
        return result_dict

    def _forward_type_2(self, x):
        output = self.lstm_model(x["ag"])
        eta_pred = output[:, :, 0 : self.output_size]
        eta_dot_pred = output[:, :, self.output_size : 2 * self.output_size]
        g_pred = output[:, :, 2 * self.output_size :]

        # for ag_c
        output_c = self.lstm_model(x["ag_c"])
        eta_pred_c = output_c[:, :, 0 : self.output_size]
        eta_dot_pred_c = output_c[:, :, self.output_size : 2 * self.output_size]
        g_pred_c = output_c[:, :, 2 * self.output_size :]
        eta_t_pred_c = paddle.matmul(x["phi"], eta_pred_c)
        eta_tt_pred_c = paddle.matmul(x["phi"], eta_dot_pred_c)
        eta_dot1_pred_c = eta_dot_pred_c[:, :, 0:1]
        tmp = paddle.concat([eta_pred_c, eta_dot1_pred_c, g_pred_c], 2)
        f = self.lstm_model_f(tmp)
        lift_pred_c = eta_tt_pred_c + f

        return {
            "eta_pred": eta_pred,
            "eta_dot_pred": eta_dot_pred,
            "g_pred": g_pred,
            "eta_t_pred_c": eta_t_pred_c,
            "eta_dot_pred_c": eta_dot_pred_c,
            "lift_pred_c": lift_pred_c,
        }

    def _forward_type_3(self, x):
        # physics informed neural networks
        output = self.lstm_model(x["ag"])
        eta_pred = output[:, :, 0 : self.output_size]
        eta_dot_pred = output[:, :, self.output_size : 2 * self.output_size]
        g_pred = output[:, :, 2 * self.output_size :]

        output_c = self.lstm_model(x["ag_c"])
        eta_pred_c = output_c[:, :, 0 : self.output_size]
        eta_dot_pred_c = output_c[:, :, self.output_size : 2 * self.output_size]
        g_pred_c = output_c[:, :, 2 * self.output_size :]

        eta_t_pred_c = paddle.matmul(x["phi"], eta_pred_c)
        eta_tt_pred_c = paddle.matmul(x["phi"], eta_dot_pred_c)
        g_t_pred_c = paddle.matmul(x["phi"], g_pred_c)

        f = self.lstm_model_f(paddle.concat([eta_pred_c, eta_dot_pred_c, g_pred_c], 2))
        lift_pred_c = eta_tt_pred_c + f

        eta_dot1_pred_c = eta_dot_pred_c[:, :, 0:1]
        g_dot_pred_c = self.lstm_model_g(paddle.concat([eta_dot1_pred_c, g_pred_c], 2))

        return {
            "eta_pred": eta_pred,
            "eta_dot_pred": eta_dot_pred,
            "g_pred": g_pred,
            "eta_t_pred_c": eta_t_pred_c,
            "eta_dot_pred_c": eta_dot_pred_c,
            "lift_pred_c": lift_pred_c,
            "g_t_pred_c": g_t_pred_c,
            "g_dot_pred_c": g_dot_pred_c,
        }
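
A construction sketch of the two variants (sizes are illustrative): model_type=2 builds lstm_model plus lstm_model_f, while model_type=3 additionally builds lstm_model_g:

import ppsci

model2 = ppsci.arch.DeepPhyLSTM(1, 1, hidden_size=100, model_type=2)  # two sub-models
model3 = ppsci.arch.DeepPhyLSTM(1, 1, hidden_size=100, model_type=3)  # three sub-models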

LorenzEmbedding

Bases: Arch

Embedding Koopman model for the Lorenz ODE system.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_keys | Tuple[str, ...] | Input keys, such as ("states",). | required |
| output_keys | Tuple[str, ...] | Output keys, such as ("pred_states", "recover_states"). | required |
| mean | Optional[Tuple[float, ...]] | Mean of the training dataset. Defaults to None. | None |
| std | Optional[Tuple[float, ...]] | Standard deviation of the training dataset. Defaults to None. | None |
| input_size | int | Size of input data. Defaults to 3. | 3 |
| hidden_size | int | Hidden layer size. Defaults to 500. | 500 |
| embed_size | int | Embedding size. Defaults to 32. | 32 |
| drop | float | Dropout probability. Defaults to 0.0. | 0.0 |

Examples:

>>> import ppsci
>>> model = ppsci.arch.LorenzEmbedding(("x", "y"), ("u", "v"))
Source code in ppsci/arch/embedding_koopman.py
class LorenzEmbedding(base.Arch):
    """Embedding Koopman model for the Lorenz ODE system.

    Args:
        input_keys (Tuple[str, ...]): Input keys, such as ("states",).
        output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states").
        mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None.
        std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None.
        input_size (int, optional): Size of input data. Defaults to 3.
        hidden_size (int, optional): Number of hidden size. Defaults to 500.
        embed_size (int, optional): Number of embedding size. Defaults to 32.
        drop (float, optional):  Probability of dropout the units. Defaults to 0.0.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.LorenzEmbedding(("x", "y"), ("u", "v"))
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        input_size: int = 3,
        hidden_size: int = 500,
        embed_size: int = 32,
        drop: float = 0.0,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embed_size = embed_size

        # build observable network
        self.encoder_net = self.build_encoder(input_size, hidden_size, embed_size, drop)
        # build koopman operator
        self.k_diag, self.k_ut = self.build_koopman_operator(embed_size)
        # build recovery network
        self.decoder_net = self.build_decoder(input_size, hidden_size, embed_size)

        mean = [0.0, 0.0, 0.0] if mean is None else mean
        std = [1.0, 1.0, 1.0] if std is None else std
        self.register_buffer("mean", paddle.to_tensor(mean).reshape([1, 3]))
        self.register_buffer("std", paddle.to_tensor(std).reshape([1, 3]))

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Layer):
        if isinstance(m, nn.Linear):
            k = 1 / m.weight.shape[0]
            uniform = Uniform(-(k**0.5), k**0.5)
            uniform(m.weight)
            if m.bias is not None:
                uniform(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def build_encoder(
        self, input_size: int, hidden_size: int, embed_size: int, drop: float = 0.0
    ):
        net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
            nn.LayerNorm(embed_size),
            nn.Dropout(drop),
        )
        return net

    def build_decoder(self, input_size: int, hidden_size: int, embed_size: int):
        net = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
        )
        return net

    def build_koopman_operator(self, embed_size: int):
        # Learned Koopman operator
        data = paddle.linspace(1, 0, embed_size)
        k_diag = paddle.create_parameter(
            shape=data.shape,
            dtype=paddle.get_default_dtype(),
            default_initializer=nn.initializer.Assign(data),
        )

        data = 0.1 * paddle.rand([2 * embed_size - 3])
        k_ut = paddle.create_parameter(
            shape=data.shape,
            dtype=paddle.get_default_dtype(),
            default_initializer=nn.initializer.Assign(data),
        )
        return k_diag, k_ut

    def encoder(self, x: paddle.Tensor):
        x = self._normalize(x)
        g = self.encoder_net(x)
        return g

    def decoder(self, g: paddle.Tensor):
        out = self.decoder_net(g)
        x = self._unnormalize(out)
        return x

    def koopman_operation(self, embed_data: paddle.Tensor, k_matrix: paddle.Tensor):
        # Apply Koopman operation
        embed_pred_data = paddle.bmm(
            k_matrix.expand(
                [embed_data.shape[0], k_matrix.shape[0], k_matrix.shape[1]]
            ),
            embed_data.transpose([0, 2, 1]),
        ).transpose([0, 2, 1])
        return embed_pred_data

    def _normalize(self, x: paddle.Tensor):
        return (x - self.mean) / self.std

    def _unnormalize(self, x: paddle.Tensor):
        return self.std * x + self.mean

    def get_koopman_matrix(self):
        # # Koopman operator
        k_ut_tensor = self.k_ut * 1
        k_ut_tensor = paddle.diag(
            k_ut_tensor[0 : self.embed_size - 1], offset=1
        ) + paddle.diag(k_ut_tensor[self.embed_size - 1 :], offset=2)
        k_matrix = k_ut_tensor + (-1) * k_ut_tensor.t()
        k_matrix = k_matrix + paddle.diag(self.k_diag)
        return k_matrix

    def forward_tensor(self, x):
        k_matrix = self.get_koopman_matrix()
        embed_data = self.encoder(x)
        recover_data = self.decoder(embed_data)

        embed_pred_data = self.koopman_operation(embed_data, k_matrix)
        pred_data = self.decoder(embed_pred_data)

        return (pred_data[:, :-1, :], recover_data, k_matrix)

    def split_to_dict(
        self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]
    ):
        return {key: data_tensors[i] for i, key in enumerate(keys)}

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        x_tensor = self.concat_to_tensor(x, self.input_keys, axis=-1)
        y = self.forward_tensor(x_tensor)
        y = self.split_to_dict(y, self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y
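
get_koopman_matrix above assembles the learned operator from a banded part and a diagonal; a standalone sketch of the same assembly with a small embed_size:

import paddle

embed_size = 4
k_diag = paddle.linspace(1, 0, embed_size)
k_ut = 0.1 * paddle.rand([2 * embed_size - 3])  # entries for offsets 1 and 2

k_ut_tensor = paddle.diag(k_ut[: embed_size - 1], offset=1) + paddle.diag(
    k_ut[embed_size - 1 :], offset=2
)
k_matrix = k_ut_tensor - k_ut_tensor.t() + paddle.diag(k_diag)
print(k_matrix.shape)  # [4, 4]: a skew-symmetric band plus a learned diagonal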

RosslerEmbedding

Bases: LorenzEmbedding

Embedding Koopman model for the Rossler ODE system.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_keys | Tuple[str, ...] | Input keys, such as ("states",). | required |
| output_keys | Tuple[str, ...] | Output keys, such as ("pred_states", "recover_states"). | required |
| mean | Optional[Tuple[float, ...]] | Mean of the training dataset. Defaults to None. | None |
| std | Optional[Tuple[float, ...]] | Standard deviation of the training dataset. Defaults to None. | None |
| input_size | int | Size of input data. Defaults to 3. | 3 |
| hidden_size | int | Hidden layer size. Defaults to 500. | 500 |
| embed_size | int | Embedding size. Defaults to 32. | 32 |
| drop | float | Dropout probability. Defaults to 0.0. | 0.0 |

Examples:

>>> import ppsci
>>> model = ppsci.arch.RosslerEmbedding(("x", "y"), ("u", "v"))
Source code in ppsci/arch/embedding_koopman.py
class RosslerEmbedding(LorenzEmbedding):
    """Embedding Koopman model for the Rossler ODE system.

    Args:
        input_keys (Tuple[str, ...]): Input keys, such as ("states",).
        output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states").
        mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None.
        std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None.
        input_size (int, optional): Size of input data. Defaults to 3.
        hidden_size (int, optional): Number of hidden size. Defaults to 500.
        embed_size (int, optional): Number of embedding size. Defaults to 32.
        drop (float, optional):  Probability of dropout the units. Defaults to 0.0.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.RosslerEmbedding(("x", "y"), ("u", "v"))
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        input_size: int = 3,
        hidden_size: int = 500,
        embed_size: int = 32,
        drop: float = 0.0,
    ):
        super().__init__(
            input_keys,
            output_keys,
            mean,
            std,
            input_size,
            hidden_size,
            embed_size,
            drop,
        )

CylinderEmbedding

Bases: Arch

Embedding Koopman model for the Cylinder system.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_keys | Tuple[str, ...] | Input keys, such as ("states", "visc"). | required |
| output_keys | Tuple[str, ...] | Output keys, such as ("pred_states", "recover_states"). | required |
| mean | Optional[Tuple[float, ...]] | Mean of the training dataset. Defaults to None. | None |
| std | Optional[Tuple[float, ...]] | Standard deviation of the training dataset. Defaults to None. | None |
| embed_size | int | Embedding size. Defaults to 128. | 128 |
| encoder_channels | Optional[Tuple[int, ...]] | Number of channels in the encoder network. Defaults to None. | None |
| decoder_channels | Optional[Tuple[int, ...]] | Number of channels in the decoder network. Defaults to None. | None |
| drop | float | Dropout probability. Defaults to 0.0. | 0.0 |

Examples:

>>> import ppsci
>>> model = ppsci.arch.CylinderEmbedding(("x", "y"), ("u", "v"))
Source code in ppsci/arch/embedding_koopman.py
class CylinderEmbedding(base.Arch):
    """Embedding Koopman model for the Cylinder system.

    Args:
        input_keys (Tuple[str, ...]): Input keys, such as ("states", "visc").
        output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states").
        mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None.
        std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None.
        embed_size (int, optional): Number of embedding size. Defaults to 128.
        encoder_channels (Optional[Tuple[int, ...]]): Number of channels in encoder network. Defaults to None.
        decoder_channels (Optional[Tuple[int, ...]]): Number of channels in decoder network. Defaults to None.
        drop (float, optional):  Probability of dropout the units. Defaults to 0.0.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.CylinderEmbedding(("x", "y"), ("u", "v"))
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        embed_size: int = 128,
        encoder_channels: Optional[Tuple[int, ...]] = None,
        decoder_channels: Optional[Tuple[int, ...]] = None,
        drop: float = 0.0,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.embed_size = embed_size

        X, Y = np.meshgrid(np.linspace(-2, 14, 128), np.linspace(-4, 4, 64))
        self.mask = paddle.to_tensor(np.sqrt(X**2 + Y**2)).unsqueeze(0).unsqueeze(0)

        encoder_channels = (
            [4, 16, 32, 64, 128] if encoder_channels is None else encoder_channels
        )
        decoder_channels = (
            [embed_size // 32, 128, 64, 32, 16]
            if decoder_channels is None
            else decoder_channels
        )
        self.encoder_net = self.build_encoder(embed_size, encoder_channels, drop)
        self.k_diag_net, self.k_ut_net, self.k_lt_net = self.build_koopman_operator(
            embed_size
        )
        self.decoder_net = self.build_decoder(decoder_channels)

        xidx = []
        yidx = []
        for i in range(1, 5):
            yidx.append(np.arange(i, embed_size))
            xidx.append(np.arange(0, embed_size - i))
        self.xidx = paddle.to_tensor(np.concatenate(xidx), dtype="int64")
        self.yidx = paddle.to_tensor(np.concatenate(yidx), dtype="int64")

        mean = [0.0, 0.0, 0.0, 0.0] if mean is None else mean
        std = [1.0, 1.0, 1.0, 1.0] if std is None else std
        self.register_buffer("mean", paddle.to_tensor(mean).reshape([1, 4, 1, 1]))
        self.register_buffer("std", paddle.to_tensor(std).reshape([1, 4, 1, 1]))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            k = 1 / m.weight.shape[0]
            uniform = Uniform(-(k**0.5), k**0.5)
            uniform(m.weight)
            if m.bias is not None:
                uniform(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        elif isinstance(m, nn.Conv2D):
            k = 1 / (m.weight.shape[1] * m.weight.shape[2] * m.weight.shape[3])
            uniform = Uniform(-(k**0.5), k**0.5)
            uniform(m.weight)
            if m.bias is not None:
                uniform(m.bias)

    def _build_conv_relu_list(
        self, in_channels: Tuple[int, ...], out_channels: Tuple[int, ...]
    ):
        net_list = [
            nn.Conv2D(
                in_channels,
                out_channels,
                kernel_size=(3, 3),
                stride=2,
                padding=1,
                padding_mode="replicate",
            ),
            nn.ReLU(),
        ]
        return net_list

    def build_encoder(
        self, embed_size: int, channels: Tuple[int, ...], drop: float = 0.0
    ):
        net = []
        for i in range(1, len(channels)):
            net.extend(self._build_conv_relu_list(channels[i - 1], channels[i]))
        net.append(
            nn.Conv2D(
                channels[-1],
                embed_size // 32,
                kernel_size=(3, 3),
                padding=1,
                padding_mode="replicate",
            )
        )
        net.append(
            nn.LayerNorm(
                (4, 4, 8),
            )
        )
        net.append(nn.Dropout(drop))
        net = nn.Sequential(*net)
        return net

    def _build_upsample_conv_relu(
        self, in_channels: Tuple[int, ...], out_channels: Tuple[int, ...]
    ):
        net_list = [
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2D(
                in_channels,
                out_channels,
                kernel_size=(3, 3),
                stride=1,
                padding=1,
                padding_mode="replicate",
            ),
            nn.ReLU(),
        ]
        return net_list

    def build_decoder(self, channels: Tuple[int, ...]):
        net = []
        for i in range(1, len(channels)):
            net.extend(self._build_upsample_conv_relu(channels[i - 1], channels[i]))
        net.append(
            nn.Conv2D(
                channels[-1],
                3,
                kernel_size=(3, 3),
                stride=1,
                padding=1,
                padding_mode="replicate",
            ),
        )
        net = nn.Sequential(*net)
        return net

    def build_koopman_operator(self, embed_size: int):
        # Learned Koopman operator parameters
        k_diag_net = nn.Sequential(
            nn.Linear(1, 50), nn.ReLU(), nn.Linear(50, embed_size)
        )

        k_ut_net = nn.Sequential(
            nn.Linear(1, 50), nn.ReLU(), nn.Linear(50, 4 * embed_size - 10)
        )
        k_lt_net = nn.Sequential(
            nn.Linear(1, 50), nn.ReLU(), nn.Linear(50, 4 * embed_size - 10)
        )
        return k_diag_net, k_ut_net, k_lt_net

    def encoder(self, x: paddle.Tensor, viscosity: paddle.Tensor):
        B, T, C, H, W = x.shape
        x = x.reshape((B * T, C, H, W))
        viscosity = viscosity.repeat_interleave(T, axis=1).reshape((B * T, 1))
        x = paddle.concat(
            [x, viscosity.unsqueeze(-1).unsqueeze(-1) * paddle.ones_like(x[:, :1])],
            axis=1,
        )
        x = self._normalize(x)
        g = self.encoder_net(x)
        g = g.reshape([B, T, -1])
        return g

    def decoder(self, g: paddle.Tensor):
        B, T, _ = g.shape
        x = self.decoder_net(g.reshape([-1, self.embed_size // 32, 4, 8]))
        x = self._unnormalize(x)
        mask0 = (
            self.mask.repeat_interleave(x.shape[1], axis=1).repeat_interleave(
                x.shape[0], axis=0
            )
            < 1
        )
        x[mask0] = 0
        _, C, H, W = x.shape
        x = x.reshape([B, T, C, H, W])
        return x

    def get_koopman_matrix(self, g: paddle.Tensor, visc: paddle.Tensor):
        # # Koopman operator
        kMatrix = paddle.zeros([g.shape[0], self.embed_size, self.embed_size])
        kMatrix.stop_gradient = False
        # Populate the off diagonal terms
        kMatrixUT_data = self.k_ut_net(100 * visc)
        kMatrixLT_data = self.k_lt_net(100 * visc)

        kMatrix = kMatrix.transpose([1, 2, 0])
        kMatrixUT_data_t = kMatrixUT_data.transpose([1, 0])
        kMatrixLT_data_t = kMatrixLT_data.transpose([1, 0])
        kMatrix[self.xidx, self.yidx] = kMatrixUT_data_t
        kMatrix[self.yidx, self.xidx] = kMatrixLT_data_t

        # Populate the diagonal
        ind = np.diag_indices(kMatrix.shape[1])
        ind = paddle.to_tensor(ind, dtype="int64")

        kMatrixDiag = self.k_diag_net(100 * visc)
        kMatrixDiag_t = kMatrixDiag.transpose([1, 0])
        kMatrix[ind[0], ind[1]] = kMatrixDiag_t
        return kMatrix.transpose([2, 0, 1])

    def koopman_operation(self, embed_data: paddle.Tensor, k_matrix: paddle.Tensor):
        embed_pred_data = paddle.bmm(
            k_matrix, embed_data.transpose([0, 2, 1])
        ).transpose([0, 2, 1])
        return embed_pred_data

    def _normalize(self, x: paddle.Tensor):
        x = (x - self.mean) / self.std
        return x

    def _unnormalize(self, x: paddle.Tensor):
        return self.std[:, :3] * x + self.mean[:, :3]

    def forward_tensor(self, states, visc):
        # states.shape=(B, T, C, H, W)
        embed_data = self.encoder(states, visc)
        recover_data = self.decoder(embed_data)

        k_matrix = self.get_koopman_matrix(embed_data, visc)
        embed_pred_data = self.koopman_operation(embed_data, k_matrix)
        pred_data = self.decoder(embed_pred_data)

        return (pred_data[:, :-1], recover_data, k_matrix)

    def split_to_dict(
        self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]
    ):
        return {key: data_tensors[i] for i, key in enumerate(keys)}

    def forward(self, x):

        if self._input_transform is not None:
            x = self._input_transform(x)

        y = self.forward_tensor(**x)
        y = self.split_to_dict(y, self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y
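
Here the Koopman operator is conditioned on viscosity: three small MLPs emit the diagonal and the first four super- and sub-diagonals, which the precomputed xidx/yidx indices scatter into the dense matrix. A standalone numpy sketch of that index bookkeeping:

import numpy as np

embed_size = 8
xidx, yidx = [], []
for i in range(1, 5):
    yidx.append(np.arange(i, embed_size))
    xidx.append(np.arange(0, embed_size - i))
xidx = np.concatenate(xidx)
yidx = np.concatenate(yidx)
# (xidx[k], yidx[k]) walks the first four super-diagonals; swapping the two
# index arrays addresses the matching sub-diagonals
print(len(xidx), 4 * embed_size - 10)  # 22 22, matching k_ut_net's output width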

Generator

Bases: Arch

Generator net of a GAN. Note that this network uses a variant of ResBlock that is unique to the "tempoGAN" example and is not an open-source network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_keys | Tuple[str, ...] | Name of input keys, such as ("input1", "input2"). | required |
| output_keys | Tuple[str, ...] | Name of output keys, such as ("output1", "output2"). | required |
| in_channel | int | Number of input channels of the first conv layer. | required |
| out_channels_tuple | Tuple[Tuple[int, ...], ...] | Output channels of all conv layers, such as [[out_res0_conv0, out_res0_conv1], [out_res1_conv0, out_res1_conv1]]. | required |
| kernel_sizes_tuple | Tuple[Tuple[int, ...], ...] | Kernel sizes of all conv layers, such as [[kernel_size_res0_conv0, kernel_size_res0_conv1], [kernel_size_res1_conv0, kernel_size_res1_conv1]]. | required |
| strides_tuple | Tuple[Tuple[int, ...], ...] | Strides of all conv layers, such as [[stride_res0_conv0, stride_res0_conv1], [stride_res1_conv0, stride_res1_conv1]]. | required |
| use_bns_tuple | Tuple[Tuple[bool, ...], ...] | Whether to use a batch_norm layer after each conv layer. | required |
| acts_tuple | Tuple[Tuple[str, ...], ...] | Whether to use an activation layer after each conv layer and, if so, which activation to use, such as [[act_res0_conv0, act_res0_conv1], [act_res1_conv0, act_res1_conv1]]. | required |

Examples:

>>> import ppsci
>>> in_channel = 1
>>> rb_channel0 = (2, 8, 8)
>>> rb_channel1 = (128, 128, 128)
>>> rb_channel2 = (32, 8, 8)
>>> rb_channel3 = (2, 1, 1)
>>> out_channels_tuple = (rb_channel0, rb_channel1, rb_channel2, rb_channel3)
>>> kernel_sizes_tuple = (((5, 5), ) * 2 + ((1, 1), ), ) * 4
>>> strides_tuple = ((1, 1, 1), ) * 4
>>> use_bns_tuple = ((True, True, True), ) * 3 + ((False, False, False), )
>>> acts_tuple = (("relu", None, None), ) * 4
>>> model = ppsci.arch.Generator(("in",), ("out",), in_channel, out_channels_tuple, kernel_sizes_tuple, strides_tuple, use_bns_tuple, acts_tuple)
Source code in ppsci/arch/gan.py
class Generator(base.Arch):
    """Generator Net of GAN. Attention, the net using a kind of variant of ResBlock which is
        unique to "tempoGAN" example but not an open source network.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
        in_channel (int): Number of input channels of the first conv layer.
        out_channels_tuple (Tuple[Tuple[int, ...], ...]): Number of output channels of all conv layers,
            such as [[out_res0_conv0, out_res0_conv1], [out_res1_conv0, out_res1_conv1]]
        kernel_sizes_tuple (Tuple[Tuple[int, ...], ...]): Number of kernel_size of all conv layers,
            such as [[kernel_size_res0_conv0, kernel_size_res0_conv1], [kernel_size_res1_conv0, kernel_size_res1_conv1]]
        strides_tuple (Tuple[Tuple[int, ...], ...]): Number of stride of all conv layers,
            such as [[stride_res0_conv0, stride_res0_conv1], [stride_res1_conv0, stride_res1_conv1]]
        use_bns_tuple (Tuple[Tuple[bool, ...], ...]): Whether to use the batch_norm layer after each conv layer.
        acts_tuple (Tuple[Tuple[str, ...], ...]): Whether to use the activation layer after each conv layer. If so, which activation to use,
            such as [[act_res0_conv0, act_res0_conv1], [act_res1_conv0, act_res1_conv1]]

    Examples:
        >>> import ppsci
        >>> in_channel = 1
        >>> rb_channel0 = (2, 8, 8)
        >>> rb_channel1 = (128, 128, 128)
        >>> rb_channel2 = (32, 8, 8)
        >>> rb_channel3 = (2, 1, 1)
        >>> out_channels_tuple = (rb_channel0, rb_channel1, rb_channel2, rb_channel3)
        >>> kernel_sizes_tuple = (((5, 5), ) * 2 + ((1, 1), ), ) * 4
        >>> strides_tuple = ((1, 1, 1), ) * 4
        >>> use_bns_tuple = ((True, True, True), ) * 3 + ((False, False, False), )
        >>> acts_tuple = (("relu", None, None), ) * 4
        >>> model = ppsci.arch.Generator(("in",), ("out",), in_channel, out_channels_tuple, kernel_sizes_tuple, strides_tuple, use_bns_tuple, acts_tuple)
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        in_channel: int,
        out_channels_tuple: Tuple[Tuple[int, ...], ...],
        kernel_sizes_tuple: Tuple[Tuple[int, ...], ...],
        strides_tuple: Tuple[Tuple[int, ...], ...],
        use_bns_tuple: Tuple[Tuple[bool, ...], ...],
        acts_tuple: Tuple[Tuple[str, ...], ...],
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.in_channel = in_channel
        self.out_channels_tuple = out_channels_tuple
        self.kernel_sizes_tuple = kernel_sizes_tuple
        self.strides_tuple = strides_tuple
        self.use_bns_tuple = use_bns_tuple
        self.acts_tuple = acts_tuple

        self.init_blocks()

    def init_blocks(self):
        blocks_list = []
        for i in range(len(self.out_channels_tuple)):
            in_channel = (
                self.in_channel if i == 0 else self.out_channels_tuple[i - 1][-1]
            )
            blocks_list.append(
                VariantResBlock(
                    in_channel=in_channel,
                    out_channels=self.out_channels_tuple[i],
                    kernel_sizes=self.kernel_sizes_tuple[i],
                    strides=self.strides_tuple[i],
                    use_bns=self.use_bns_tuple[i],
                    acts=self.acts_tuple[i],
                    mean=0.0,
                    std=0.04,
                    value=0.1,
                )
            )
        self.blocks = nn.LayerList(blocks_list)

    def forward_tensor(self, x):
        y = x
        for block in self.blocks:
            y = block(y)
        return y

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        y = self.concat_to_tensor(x, self.input_keys, axis=-1)
        y = self.forward_tensor(y)
        y = self.split_to_dict(y, self.output_keys, axis=-1)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y
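
init_blocks chains the residual blocks by channel count: each block's input channel is the last conv width of the previous block. A plain-Python sketch using the hypothetical config from the Examples above:

in_channel = 1
out_channels_tuple = ((2, 8, 8), (128, 128, 128), (32, 8, 8), (2, 1, 1))
block_in_channels = [in_channel] + [chs[-1] for chs in out_channels_tuple[:-1]]
print(block_in_channels)  # [1, 8, 128, 8]: block i consumes the last width of block i-1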

Discriminator

Bases: Arch

Discriminator net of a GAN.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_keys | Tuple[str, ...] | Name of input keys, such as ("input1", "input2"). | required |
| output_keys | Tuple[str, ...] | Name of output keys, such as ("output1", "output2"). | required |
| in_channel | int | Number of input channels of the first conv layer. | required |
| out_channels | Tuple[int, ...] | Output channels of all conv layers, such as (out_conv0, out_conv1, out_conv2). | required |
| fc_channel | int | Number of input features of the linear layer; its number of output features is fixed to 1 to form a fully connected layer. | required |
| kernel_sizes | Tuple[int, ...] | Kernel sizes of all conv layers, such as (kernel_size_conv0, kernel_size_conv1, kernel_size_conv2). | required |
| strides | Tuple[int, ...] | Strides of all conv layers, such as (stride_conv0, stride_conv1, stride_conv2). | required |
| use_bns | Tuple[bool, ...] | Whether to use a batch_norm layer after each conv layer. | required |
| acts | Tuple[str, ...] | Whether to use an activation layer after each conv layer and, if so, which activation to use, such as (act_conv0, act_conv1, act_conv2). | required |

Examples:

>>> import ppsci
>>> in_channel = 2
>>> in_channel_tempo = 3
>>> out_channels = (32, 64, 128, 256)
>>> fc_channel = 65536
>>> kernel_sizes = ((4, 4), (4, 4), (4, 4), (4, 4))
>>> strides = (2, 2, 2, 1)
>>> use_bns = (False, True, True, True)
>>> acts = ("leaky_relu", "leaky_relu", "leaky_relu", "leaky_relu", None)
>>> output_keys_disc = ("out_1", "out_2", "out_3", "out_4", "out_5", "out_6", "out_7", "out_8", "out_9", "out_10")
>>> model = ppsci.arch.Discriminator(("in_1","in_2"), output_keys_disc, in_channel, out_channels, fc_channel, kernel_sizes, strides, use_bns, acts)
Source code in ppsci/arch/gan.py
class Discriminator(base.Arch):
    """Discriminator Net of GAN.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
        in_channel (int):  Number of input channels of the first conv layer.
        out_channels (Tuple[int, ...]): Number of output channels of all conv layers,
            such as (out_conv0, out_conv1, out_conv2).
        fc_channel (int):  Number of input features of the linear layer. The number of output features
            is fixed to 1 in this net to construct a fully connected layer.
        kernel_sizes (Tuple[int, ...]): Kernel sizes of all conv layers,
            such as (kernel_size_conv0, kernel_size_conv1, kernel_size_conv2).
        strides (Tuple[int, ...]): Strides of all conv layers,
            such as (stride_conv0, stride_conv1, stride_conv2).
        use_bns (Tuple[bool, ...]): Whether to use a batch_norm layer after each conv layer.
        acts (Tuple[str, ...]): Whether to use an activation layer after each conv layer and, if so, which activation to use,
            such as (act_conv0, act_conv1, act_conv2).

    Examples:
        >>> import ppsci
        >>> in_channel = 2
        >>> in_channel_tempo = 3
        >>> out_channels = (32, 64, 128, 256)
        >>> fc_channel = 65536
        >>> kernel_sizes = ((4, 4), (4, 4), (4, 4), (4, 4))
        >>> strides = (2, 2, 2, 1)
        >>> use_bns = (False, True, True, True)
        >>> acts = ("leaky_relu", "leaky_relu", "leaky_relu", "leaky_relu", None)
        >>> output_keys_disc = ("out_1", "out_2", "out_3", "out_4", "out_5", "out_6", "out_7", "out_8", "out_9", "out_10")
        >>> model = ppsci.arch.Discriminator(("in_1","in_2"), output_keys_disc, in_channel, out_channels, fc_channel, kernel_sizes, strides, use_bns, acts)
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        in_channel: int,
        out_channels: Tuple[int, ...],
        fc_channel: int,
        kernel_sizes: Tuple[int, ...],
        strides: Tuple[int, ...],
        use_bns: Tuple[bool, ...],
        acts: Tuple[str, ...],
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.in_channel = in_channel
        self.out_channels = out_channels
        self.fc_channel = fc_channel
        self.kernel_sizes = kernel_sizes
        self.strides = strides
        self.use_bns = use_bns
        self.acts = acts

        self.init_layers()

    def init_layers(self):
        layers_list = []
        for i in range(len(self.out_channels)):
            in_channel = self.in_channel if i == 0 else self.out_channels[i - 1]
            layers_list.append(
                Conv2DBlock(
                    in_channel=in_channel,
                    out_channel=self.out_channels[i],
                    kernel_size=self.kernel_sizes[i],
                    stride=self.strides[i],
                    use_bn=self.use_bns[i],
                    act=self.acts[i],
                    mean=0.0,
                    std=0.04,
                    value=0.1,
                )
            )

        layers_list.append(
            FCBlock(self.fc_channel, self.acts[4], mean=0.0, std=0.04, value=0.1)
        )
        self.layers = nn.LayerList(layers_list)

    def forward_tensor(self, x):
        y = x
        y_list = []
        for layer in self.layers:
            y = layer(y)
            y_list.append(y)
        return y_list  # y_conv1, y_conv2, y_conv3, y_conv4, y_fc(y_out)

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        y_list = []
        # y1_conv1, y1_conv2, y1_conv3, y1_conv4, y1_fc, y2_conv1, y2_conv2, y2_conv3, y2_conv4, y2_fc
        for k in x:
            y_list.extend(self.forward_tensor(x[k]))

        y = self.split_to_dict(y_list, self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)

        return y

    def split_to_dict(
        self, data_list: List[paddle.Tensor], keys: Tuple[str, ...]
    ) -> Dict[str, paddle.Tensor]:
        """Overwrite of split_to_dict() method belongs to Class base.Arch.

        Reason for overwriting is there is no concat_to_tensor() method called in "tempoGAN" example.
        That is because input in "tempoGAN" example is not in a regular format, but a format like:
        {
            "input1": paddle.concat([in1, in2], axis=1),
            "input2": paddle.concat([in1, in3], axis=1),
        }

        Args:
            data_list (List[paddle.Tensor]): The data to be split. It should be a list of tensors rather than a single paddle.Tensor.
            keys (Tuple[str, ...]): Keys of outputs.

        Returns:
            Dict[str, paddle.Tensor]: Dict with split data.
        """
        if len(keys) == 1:
            return {keys[0]: data_list[0]}
        return {key: data_list[i] for i, key in enumerate(keys)}
split_to_dict(data_list, keys)

Override of the split_to_dict() method of base.Arch.

This method is overridden because concat_to_tensor() is not called in the "tempoGAN" example: its input is not in the regular format, but in a format like:

    {
        "input1": paddle.concat([in1, in2], axis=1),
        "input2": paddle.concat([in1, in3], axis=1),
    }

Parameters:

- data_list (List[paddle.Tensor]): The data to be split. It should be a list of tensors rather than a single paddle.Tensor. Required.
- keys (Tuple[str, ...]): Keys of outputs. Required.

Returns:

- Dict[str, paddle.Tensor]: Dict with split data.

Source code in ppsci/arch/gan.py
def split_to_dict(
    self, data_list: List[paddle.Tensor], keys: Tuple[str, ...]
) -> Dict[str, paddle.Tensor]:
    """Overwrite of split_to_dict() method belongs to Class base.Arch.

    Reason for overwriting is there is no concat_to_tensor() method called in "tempoGAN" example.
    That is because input in "tempoGAN" example is not in a regular format, but a format like:
    {
        "input1": paddle.concat([in1, in2], axis=1),
        "input2": paddle.concat([in1, in3], axis=1),
    }

    Args:
        data_list (List[paddle.Tensor]): The data to be split. It should be a list of tensors rather than a single paddle.Tensor.
        keys (Tuple[str, ...]): Keys of outputs.

    Returns:
        Dict[str, paddle.Tensor]: Dict with split data.
    """
    if len(keys) == 1:
        return {keys[0]: data_list[0]}
    return {key: data_list[i] for i, key in enumerate(keys)}
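Since the Examples above declare 10 output keys for 2 input keys, it helps to see why: forward() runs forward_tensor() on each input key and collects every intermediate output (4 conv outputs plus 1 fc output), then maps the flattened list onto output_keys via this overridden split_to_dict(). A minimal sketch of the "tempoGAN"-style input (all tensors and shapes below are illustrative assumptions):

>>> import paddle
>>> in1 = paddle.randn([1, 1, 128, 128])  # hypothetical single-channel fields
>>> in2 = paddle.randn([1, 1, 128, 128])
>>> in3 = paddle.randn([1, 1, 128, 128])
>>> x = {
...     "input1": paddle.concat([in1, in2], axis=1),  # 2 channels, matching in_channel=2
...     "input2": paddle.concat([in1, in3], axis=1),
... }
>>> # model(x) would return {"out_1": ..., ..., "out_10": ...}:
>>> # 2 input keys x 5 layer outputs each = 10 output keys.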

PhysformerGPT2

Bases: Arch

Transformer decoder model for modeling physics.

Parameters:

- input_keys (Tuple[str, ...]): Input keys, such as ("embeds",). Required.
- output_keys (Tuple[str, ...]): Output keys, such as ("pred_embeds",). Required.
- num_layers (int): Number of transformer layers. Required.
- num_ctx (int): Context length of block. Required.
- embed_size (int): Embedding size. Required.
- num_heads (int): Number of heads in multi-head attention. Required.
- embd_pdrop (float): Dropout probability applied to embedding features. Defaults to 0.0.
- attn_pdrop (float): Dropout probability applied to attention weights. Defaults to 0.0.
- resid_pdrop (float): Dropout probability applied to block outputs. Defaults to 0.0.
- initializer_range (float): Initializer range of linear layers. Defaults to 0.05.

Examples:

>>> import ppsci
>>> model = ppsci.arch.PhysformerGPT2(("embeds", ), ("pred_embeds", ), 6, 16, 128, 4)
Source code in ppsci/arch/physx_transformer.py
class PhysformerGPT2(base.Arch):
    """Transformer decoder model for modeling physics.

    Args:
        input_keys (Tuple[str, ...]): Input keys, such as ("embeds",).
        output_keys (Tuple[str, ...]): Output keys, such as ("pred_embeds",).
        num_layers (int): Number of transformer layers.
        num_ctx (int): Context length of block.
        embed_size (int): Embedding size.
        num_heads (int): The number of heads in multi-head attention.
        embd_pdrop (float, optional): The dropout probability used on embedding features. Defaults to 0.0.
        attn_pdrop (float, optional): The dropout probability used on attention weights. Defaults to 0.0.
        resid_pdrop (float, optional): The dropout probability used on block outputs. Defaults to 0.0.
        initializer_range (float, optional): Initializer range of linear layer. Defaults to 0.05.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.PhysformerGPT2(("embeds", ), ("pred_embeds", ), 6, 16, 128, 4)
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        num_layers: int,
        num_ctx: int,
        embed_size: int,
        num_heads: int,
        embd_pdrop: float = 0.0,
        attn_pdrop: float = 0.0,
        resid_pdrop: float = 0.0,
        initializer_range: float = 0.05,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys

        self.num_layers = num_layers
        self.num_ctx = num_ctx
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.resid_pdrop = resid_pdrop
        self.initializer_range = initializer_range

        self.drop = nn.Dropout(embd_pdrop)
        self.blocks = nn.LayerList(
            [
                Block(
                    num_ctx, embed_size, num_heads, attn_pdrop, resid_pdrop, scale=True
                )
                for _ in range(num_layers)
            ]
        )
        self.ln = nn.LayerNorm(embed_size)
        self.linear = nn.Linear(embed_size, embed_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            normal_ = Normal(mean=0.0, std=self.initializer_range)
            normal_(module.weight)
            if module.bias is not None:
                zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            zeros_(module.bias)
            ones_(module.weight)

    def get_position_embed(self, x):
        B, N, _ = x.shape
        position_ids = paddle.arange(0, N, dtype=paddle.get_default_dtype()).reshape(
            [1, N, 1]
        )
        position_ids = position_ids.repeat_interleave(B, axis=0)

        position_embeds = paddle.zeros_like(x)
        i = paddle.arange(0, self.embed_size // 2).unsqueeze(0).unsqueeze(0)
        position_embeds[:, :, ::2] = paddle.sin(
            position_ids / 10000 ** (2 * i / self.embed_size)
        )
        position_embeds[:, :, 1::2] = paddle.cos(
            position_ids / 10000 ** (2 * i / self.embed_size)
        )
        return position_embeds

    def _generate_time_series(self, x, max_length):
        cur_len = x.shape[1]
        if cur_len >= max_length:
            raise ValueError(
                f"max_length({max_length}) should be larger than "
                f"the length of input context({cur_len})"
            )

        while cur_len < max_length:
            model_inputs = x[:, -1:]
            outputs = self.forward_tensor(model_inputs)
            next_output = outputs[0][:, -1:]
            x = paddle.concat([x, next_output], axis=1)
            cur_len = cur_len + 1
        return x

    @paddle.no_grad()
    def generate(self, x, max_length=256):
        if max_length <= 0:
            raise ValueError(
                f"max_length({max_length}) should be a strictly positive integer."
            )
        outputs = self._generate_time_series(x, max_length)
        return outputs

    def forward_tensor(self, x):
        position_embeds = self.get_position_embed(x)
        # Combine input embedding, position embedding
        hidden_states = x + position_embeds
        hidden_states = self.drop(hidden_states)

        # Loop through transformer self-attention layers
        for block in self.blocks:
            block_outputs = block(hidden_states)
            hidden_states = block_outputs[0]
        outputs = self.linear(self.ln(hidden_states))
        return (outputs,)

    def forward_eval(self, x):
        input_embeds = x[:, :1]
        outputs = self.generate(input_embeds)
        return (outputs[:, 1:],)

    def split_to_dict(self, data_tensors, keys):
        return {key: data_tensors[i] for i, key in enumerate(keys)}

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)
        x_tensor = self.concat_to_tensor(x, self.input_keys, axis=-1)
        if self.training:
            y = self.forward_tensor(x_tensor)
        else:
            y = self.forward_eval(x_tensor)
        y = self.split_to_dict(y, self.output_keys)
        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y
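Note that forward() branches on self.training: in train mode it performs a single parallel pass over the whole input sequence (forward_tensor), while in eval mode it keeps only the first step and rolls out autoregressively via generate() up to max_length (default 256). A minimal train-mode sketch (batch size and sequence length are assumptions):

>>> import paddle
>>> model = ppsci.arch.PhysformerGPT2(("embeds",), ("pred_embeds",), 6, 16, 128, 4)
>>> model.train()
>>> y = model({"embeds": paddle.randn([2, 16, 128])})  # [batch, seq_len, embed_size]
>>> print(y["pred_embeds"].shape)  # same shape as the input embeddings
[2, 16, 128]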

ModelList

Bases: Arch

ModelList layer, which wraps multiple models sharing the same inputs.

Parameters:

- model_list (Tuple[Arch, ...]): Model(s) nested in tuple. Required.

Examples:

>>> import ppsci
>>> model1 = ppsci.arch.MLP(("x", "y"), ("u", "v"), 10, 128)
>>> model2 = ppsci.arch.MLP(("x", "y"), ("w", "p"), 5, 128)
>>> model = ppsci.arch.ModelList((model1, model2))
Source code in ppsci/arch/model_list.py
class ModelList(base.Arch):
    """ModelList layer which wrap more than one model that shares inputs.

    Args:
        model_list (Tuple[base.Arch, ...]): Model(s) nested in tuple.

    Examples:
        >>> import ppsci
        >>> model1 = ppsci.arch.MLP(("x", "y"), ("u", "v"), 10, 128)
        >>> model2 = ppsci.arch.MLP(("x", "y"), ("w", "p"), 5, 128)
        >>> model = ppsci.arch.ModelList((model1, model2))
    """

    def __init__(
        self,
        model_list: Tuple[base.Arch, ...],
    ):
        super().__init__()
        self.input_keys = sum([model.input_keys for model in model_list], ())
        self.input_keys = set(self.input_keys)

        output_keys_set = set()
        for model in model_list:
            if len(output_keys_set & set(model.output_keys)):
                raise ValueError(
                    "output_keys of model from model_list should be unique, "
                    f"but got duplicate keys: {output_keys_set & set(model.output_keys)}"
                )
            output_keys_set = output_keys_set | set(model.output_keys)
        self.output_keys = tuple(output_keys_set)

        self.model_list = nn.LayerList(model_list)

    def forward(self, x):
        y_all = {}
        for model in self.model_list:
            y = model(x)
            y_all.update(y)

        return y_all
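A minimal sketch of the merged forward pass for the ModelList from the Examples above (input shapes are assumptions; each MLP concatenates its input keys internally):

>>> import paddle
>>> x = {"x": paddle.randn([8, 1]), "y": paddle.randn([8, 1])}
>>> y = model(x)
>>> sorted(y.keys())  # outputs of model1 ("u", "v") and model2 ("w", "p") merged
['p', 'u', 'v', 'w']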

AFNONet

Bases: Arch

Adaptive Fourier Neural Network.

Parameters:

- input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). Required.
- output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). Required.
- img_size (Tuple[int, ...]): Image size. Defaults to (720, 1440).
- patch_size (Tuple[int, ...]): Patch size. Defaults to (8, 8).
- in_channels (int): Number of input tensor channels. Defaults to 20.
- out_channels (int): Number of output tensor channels. Defaults to 20.
- embed_dim (int): Embedding dimension for PatchEmbed. Defaults to 768.
- depth (int): Depth of the transformer (number of blocks). Defaults to 12.
- mlp_ratio (float): Ratio used in the MLP. Defaults to 4.0.
- drop_rate (float): Drop ratio used in the MLP. Defaults to 0.0.
- drop_path_rate (float): Drop ratio used in DropPath. Defaults to 0.0.
- num_blocks (int): Number of blocks. Defaults to 8.
- sparsity_threshold (float): Threshold value for softshrink. Defaults to 0.01.
- hard_thresholding_fraction (float): Threshold value for the keep mode. Defaults to 1.0.
- num_timestamps (int): Number of timestamps. Defaults to 1.

Examples:

>>> import ppsci
>>> model = ppsci.arch.AFNONet(("input", ), ("output", ))
Source code in ppsci/arch/afno.py
class AFNONet(base.Arch):
    """Adaptive Fourier Neural Network.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output",).
        img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440).
        patch_size (Tuple[int, ...], optional): Patch size. Defaults to (8, 8).
        in_channels (int, optional): Number of input tensor channels. Defaults to 20.
        out_channels (int, optional): Number of output tensor channels. Defaults to 20.
        embed_dim (int, optional): Embedding dimension for PatchEmbed. Defaults to 768.
        depth (int, optional): Depth of the transformer (number of blocks). Defaults to 12.
        mlp_ratio (float, optional): Ratio used in the MLP. Defaults to 4.0.
        drop_rate (float, optional): Drop ratio used in the MLP. Defaults to 0.0.
        drop_path_rate (float, optional): Drop ratio used in DropPath. Defaults to 0.0.
        num_blocks (int, optional): Number of blocks. Defaults to 8.
        sparsity_threshold (float, optional): Threshold value for softshrink. Defaults to 0.01.
        hard_thresholding_fraction (float, optional): Threshold value for the keep mode. Defaults to 1.0.
        num_timestamps (int, optional): Number of timestamps. Defaults to 1.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.AFNONet(("input", ), ("output", ))
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        img_size: Tuple[int, ...] = (720, 1440),
        patch_size: Tuple[int, ...] = (8, 8),
        in_channels: int = 20,
        out_channels: int = 20,
        embed_dim: int = 768,
        depth: int = 12,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        num_blocks: int = 8,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
        num_timestamps: int = 1,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        self.num_timestamps = num_timestamps
        norm_layer = partial(nn.LayerNorm, epsilon=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=self.patch_size,
            in_channels=self.in_channels,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        data = paddle.zeros((1, num_patches, embed_dim))
        data = initializer.trunc_normal_(data, std=0.02)
        self.pos_embed = paddle.create_parameter(
            shape=data.shape,
            dtype=data.dtype,
            default_initializer=nn.initializer.Assign(data),
        )
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)]

        self.h = img_size[0] // self.patch_size[0]
        self.w = img_size[1] // self.patch_size[1]

        self.blocks = nn.LayerList(
            [
                Block(
                    dim=embed_dim,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    num_blocks=self.num_blocks,
                    sparsity_threshold=sparsity_threshold,
                    hard_thresholding_fraction=hard_thresholding_fraction,
                )
                for i in range(depth)
            ]
        )

        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(
            embed_dim,
            self.out_channels * self.patch_size[0] * self.patch_size[1],
            bias_attr=False,
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            initializer.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                initializer.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            initializer.ones_(m.weight)
            initializer.zeros_(m.bias)
        elif isinstance(m, nn.Conv2D):
            initializer.conv_init_(m)

    def forward_tensor(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        x = x.reshape((B, self.h, self.w, self.embed_dim))
        for block in self.blocks:
            x = block(x)

        x = self.head(x)

        b = x.shape[0]
        p1 = self.patch_size[0]
        p2 = self.patch_size[1]
        h = self.img_size[0] // self.patch_size[0]
        w = self.img_size[1] // self.patch_size[1]
        c_out = x.shape[3] // (p1 * p2)
        x = x.reshape((b, h, w, p1, p2, c_out))
        x = x.transpose((0, 5, 1, 3, 2, 4))
        x = x.reshape((b, c_out, h * p1, w * p2))

        return x

    def split_to_dict(
        self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]
    ):
        return {key: data_tensors[i] for i, key in enumerate(keys)}

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        x_tensor = self.concat_to_tensor(x, self.input_keys)

        y = []
        input = x_tensor
        for _ in range(self.num_timestamps):
            out = self.forward_tensor(input)
            y.append(out)
            input = out
        y = self.split_to_dict(y, self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y
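When num_timestamps > 1, forward() rolls the backbone out autoregressively, feeding each prediction back in as the next input, so output_keys should provide one key per timestamp. A minimal single-step sketch with a deliberately reduced configuration (all sizes below are assumptions chosen so that patch_size tiles img_size; they are not recommended settings):

>>> import paddle
>>> model = ppsci.arch.AFNONet(
...     ("input",), ("output",),
...     img_size=(32, 64), patch_size=(8, 8),
...     in_channels=3, out_channels=3,
...     embed_dim=64, depth=2, num_blocks=4,
... )
>>> y = model({"input": paddle.randn([1, 3, 32, 64])})  # NCHW, matching img_size
>>> print(y["output"].shape)
[1, 3, 32, 64]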

PrecipNet

Bases: Arch

Precipitation Network.

Parameters:

- input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). Required.
- output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). Required.
- wind_model (Arch): Wind model. Required.
- img_size (Tuple[int, ...]): Image size. Defaults to (720, 1440).
- patch_size (Tuple[int, ...]): Patch size. Defaults to (8, 8).
- in_channels (int): Number of input tensor channels. Defaults to 20.
- out_channels (int): Number of output tensor channels. Defaults to 1.
- embed_dim (int): Embedding dimension for PatchEmbed. Defaults to 768.
- depth (int): Depth of the transformer (number of blocks). Defaults to 12.
- mlp_ratio (float): Ratio used in the MLP. Defaults to 4.0.
- drop_rate (float): Drop ratio used in the MLP. Defaults to 0.0.
- drop_path_rate (float): Drop ratio used in DropPath. Defaults to 0.0.
- num_blocks (int): Number of blocks. Defaults to 8.
- sparsity_threshold (float): Threshold value for softshrink. Defaults to 0.01.
- hard_thresholding_fraction (float): Threshold value for the keep mode. Defaults to 1.0.
- num_timestamps (int): Number of timestamps. Defaults to 1.

Examples:

>>> import ppsci
>>> wind_model = ppsci.arch.AFNONet(("input", ), ("output", ))
>>> model = ppsci.arch.PrecipNet(("input", ), ("output", ), wind_model)
Source code in ppsci/arch/afno.py
class PrecipNet(base.Arch):
    """Precipitation Network.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output",).
        wind_model (base.Arch): Wind model.
        img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440).
        patch_size (Tuple[int, ...], optional): Patch size. Defaults to (8, 8).
        in_channels (int, optional): Number of input tensor channels. Defaults to 20.
        out_channels (int, optional): Number of output tensor channels. Defaults to 1.
        embed_dim (int, optional): Embedding dimension for PatchEmbed. Defaults to 768.
        depth (int, optional): Depth of the transformer (number of blocks). Defaults to 12.
        mlp_ratio (float, optional): Ratio used in the MLP. Defaults to 4.0.
        drop_rate (float, optional): Drop ratio used in the MLP. Defaults to 0.0.
        drop_path_rate (float, optional): Drop ratio used in DropPath. Defaults to 0.0.
        num_blocks (int, optional): Number of blocks. Defaults to 8.
        sparsity_threshold (float, optional): Threshold value for softshrink. Defaults to 0.01.
        hard_thresholding_fraction (float, optional): Threshold value for the keep mode. Defaults to 1.0.
        num_timestamps (int, optional): Number of timestamps. Defaults to 1.

    Examples:
        >>> import ppsci
        >>> wind_model = ppsci.arch.AFNONet(("input", ), ("output", ))
        >>> model = ppsci.arch.PrecipNet(("input", ), ("output", ), wind_model)
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        wind_model: base.Arch,
        img_size: Tuple[int, ...] = (720, 1440),
        patch_size: Tuple[int, ...] = (8, 8),
        in_channels: int = 20,
        out_channels: int = 1,
        embed_dim: int = 768,
        depth: int = 12,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        num_blocks: int = 8,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
        num_timestamps=1,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        self.num_timestamps = num_timestamps
        self.backbone = AFNONet(
            ("input",),
            ("output",),
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            out_channels=out_channels,
            embed_dim=embed_dim,
            depth=depth,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            num_blocks=num_blocks,
            sparsity_threshold=sparsity_threshold,
            hard_thresholding_fraction=hard_thresholding_fraction,
        )
        self.ppad = PeriodicPad2d(1)
        self.conv = nn.Conv2D(
            self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=0
        )
        self.act = nn.ReLU()
        self.apply(self._init_weights)
        self.wind_model = wind_model
        self.wind_model.eval()

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            initializer.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                initializer.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            initializer.ones_(m.weight)
            initializer.zeros_(m.bias)
        elif isinstance(m, nn.Conv2D):
            initializer.conv_init_(m)

    def forward_tensor(self, x):
        x = self.backbone.forward_tensor(x)
        x = self.ppad(x)
        x = self.conv(x)
        x = self.act(x)
        return x

    def split_to_dict(
        self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]
    ):
        return {key: data_tensors[i] for i, key in enumerate(keys)}

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        x_tensor = self.concat_to_tensor(x, self.input_keys)

        input_wind = x_tensor
        y = []
        for _ in range(self.num_timestamps):
            with paddle.no_grad():
                out_wind = self.wind_model.forward_tensor(input_wind)
            out = self.forward_tensor(out_wind)
            y.append(out)
            input_wind = out_wind
        y = self.split_to_dict(y, self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y
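At each timestamp, forward() first runs the frozen wind model under paddle.no_grad() and then passes the predicted wind field to the precipitation head (backbone, periodic padding, conv, ReLU); only the wind output is carried forward to the next step. A minimal sketch mirroring the reduced AFNONet configuration above (sizes are illustrative assumptions, not recommended settings):

>>> import paddle
>>> wind_model = ppsci.arch.AFNONet(
...     ("input",), ("output",),
...     img_size=(32, 64), patch_size=(8, 8),
...     in_channels=3, out_channels=3,
...     embed_dim=64, depth=2, num_blocks=4,
... )
>>> model = ppsci.arch.PrecipNet(
...     ("input",), ("output",), wind_model,
...     img_size=(32, 64), patch_size=(8, 8),
...     in_channels=3, out_channels=1,
...     embed_dim=64, depth=2, num_blocks=4,
... )
>>> y = model({"input": paddle.randn([1, 3, 32, 64])})
>>> print(y["output"].shape)  # single precipitation channel
[1, 1, 32, 64]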

UNetEx

Bases: Arch

U-Net

Ribeiro M D, Rehman A, Ahmed S, et al. DeepCFD: Efficient steady-state laminar flow approximation with deep convolutional neural networks[J]. arXiv preprint arXiv:2004.08826, 2020.

Parameters:

- input_key (str): Name of function data for input. Required.
- output_key (str): Name of function data for output. Required.
- in_channel (int): Number of channels of input. Required.
- out_channel (int): Number of channels of output. Required.
- kernel_size (int): Size of kernel of convolution layer. Defaults to 3.
- filters (Tuple[int, ...]): Number of filters. Defaults to (16, 32, 64).
- layers (int): Number of encoders or decoders. Defaults to 3.
- weight_norm (bool): Whether to use a weight normalization layer. Defaults to True.
- batch_norm (bool): Whether to add a batch normalization layer. Defaults to True.
- activation (Type[Layer]): Activation function. Defaults to nn.ReLU.
- final_activation (Optional[Type[Layer]]): Final activation function. Defaults to None.

Examples:

>>> import ppsci
>>> model = ppsci.arch.UNetEx("input", "output", 3, 3, kernel_size=5, filters=(8, 16, 32, 32), weight_norm=False, batch_norm=False)
Source code in ppsci/arch/unetex.py
class UNetEx(base.Arch):
    """U-Net

    [Ribeiro M D, Rehman A, Ahmed S, et al. DeepCFD: Efficient steady-state laminar flow approximation with deep convolutional neural networks[J]. arXiv preprint arXiv:2004.08826, 2020.](https://arxiv.org/abs/2004.08826)

    Args:
        input_key (str): Name of function data for input.
        output_key (str): Name of function data for output.
        in_channel (int): Number of channels of input.
        out_channel (int): Number of channels of output.
        kernel_size (int, optional): Size of kernel of convolution layer. Defaults to 3.
        filters (Tuple[int, ...], optional): Number of filters. Defaults to (16, 32, 64).
        layers (int, optional): Number of encoders or decoders. Defaults to 3.
        weight_norm (bool, optional): Whether to use a weight normalization layer. Defaults to True.
        batch_norm (bool, optional): Whether to add a batch normalization layer. Defaults to True.
        activation (Type[nn.Layer], optional): Activation function. Defaults to nn.ReLU.
        final_activation (Optional[Type[nn.Layer]]): Final activation function. Defaults to None.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.UNetEx("input", "output", 3, 3, kernel_size=5, filters=(8, 16, 32, 32), weight_norm=False, batch_norm=False)
    """

    def __init__(
        self,
        input_key: str,
        output_key: str,
        in_channel: int,
        out_channel: int,
        kernel_size: int = 3,
        filters: Tuple[int, ...] = (16, 32, 64),
        layers: int = 3,
        weight_norm: bool = True,
        batch_norm: bool = True,
        activation: Type[nn.Layer] = nn.ReLU,
        final_activation: Optional[Type[nn.Layer]] = None,
    ):
        if len(filters) == 0:
            raise ValueError("The filters should not be empty.")

        super().__init__()
        self.input_keys = (input_key,)
        self.output_keys = (output_key,)
        self.final_activation = final_activation
        self.encoder = create_encoder(
            in_channel,
            filters,
            kernel_size,
            weight_norm,
            batch_norm,
            activation,
            layers,
        )
        decoders = [
            create_decoder(
                1, filters, kernel_size, weight_norm, batch_norm, activation, layers
            )
            for i in range(out_channel)
        ]
        self.decoders = nn.Sequential(*decoders)

    def encode(self, x):
        tensors = []
        indices = []
        sizes = []
        for encoder in self.encoder:
            x = encoder(x)
            sizes.append(x.shape)
            tensors.append(x)
            x, ind = nn.functional.max_pool2d(x, 2, 2, return_mask=True)
            indices.append(ind)
        return x, tensors, indices, sizes

    def decode(self, x, tensors, indices, sizes):
        y = []
        for _decoder in self.decoders:
            _x = x
            _tensors = tensors[:]
            _indices = indices[:]
            _sizes = sizes[:]
            for decoder in _decoder:
                tensor = _tensors.pop()
                size = _sizes.pop()
                indice = _indices.pop()
                # upsample operations
                _x = nn.functional.max_unpool2d(_x, indice, 2, 2, output_size=size)
                _x = paddle.concat([tensor, _x], axis=1)
                _x = decoder(_x)
            y.append(_x)
        return paddle.concat(y, axis=1)

    def forward(self, x):
        x = x[self.input_keys[0]]
        x, tensors, indices, sizes = self.encode(x)
        x = self.decode(x, tensors, indices, sizes)
        if self.final_activation is not None:
            x = self.final_activation(x)
        return {self.output_keys[0]: x}
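decode() reverses each pooling step with the indices saved by encode() and concatenates the matching skip tensor, so the input height and width should be divisible by 2 once per encoder stage. A minimal sketch (shapes are assumptions for illustration):

>>> import paddle
>>> model = ppsci.arch.UNetEx("input", "output", 3, 3, filters=(8, 16, 32), layers=3)
>>> y = model({"input": paddle.randn([4, 3, 64, 64])})  # 64 = 8 * 2**3, pools cleanly
>>> print(y["output"].shape)  # one 1-channel decoder per output channel, concatenated
[4, 3, 64, 64]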

NowcastNet

Bases: Arch

The NowcastNet model.

Parameters:

- input_keys (Tuple[str, ...]): Name of input keys, such as ("input",). Required.
- output_keys (Tuple[str, ...]): Name of output keys, such as ("output",). Required.
- input_length (int): Input length. Defaults to 9.
- total_length (int): Total length. Defaults to 29.
- image_height (int): Image height. Defaults to 512.
- image_width (int): Image width. Defaults to 512.
- image_ch (int): Image channel. Defaults to 2.
- ngf (int): Noise Projector input length. Defaults to 32.

Examples:

>>> import ppsci
>>> model = ppsci.arch.NowcastNet(("input", ), ("output", ))
Source code in ppsci/arch/nowcastnet.py
class NowcastNet(base.Arch):
    """The NowcastNet model.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output",).
        input_length (int, optional): Input length. Defaults to 9.
        total_length (int, optional): Total length. Defaults to 29.
        image_height (int, optional): Image height. Defaults to 512.
        image_width (int, optional): Image width. Defaults to 512.
        image_ch (int, optional): Image channel. Defaults to 2.
        ngf (int, optional): Noise Projector input length. Defaults to 32.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.NowcastNet(("input", ), ("output", ))
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        input_length: int = 9,
        total_length: int = 29,
        image_height: int = 512,
        image_width: int = 512,
        image_ch: int = 2,
        ngf: int = 32,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys

        self.input_length = input_length
        self.total_length = total_length
        self.image_height = image_height
        self.image_width = image_width
        self.image_ch = image_ch
        self.ngf = ngf

        configs = collections.namedtuple(
            "Object", ["ngf", "evo_ic", "gen_oc", "ic_feature"]
        )
        configs.ngf = self.ngf
        configs.evo_ic = self.total_length - self.input_length
        configs.gen_oc = self.total_length - self.input_length
        configs.ic_feature = self.ngf * 10

        self.pred_length = self.total_length - self.input_length
        self.evo_net = Evolution_Network(self.input_length, self.pred_length, base_c=32)
        self.gen_enc = Generative_Encoder(self.total_length, base_c=self.ngf)
        self.gen_dec = Generative_Decoder(configs)
        self.proj = Noise_Projector(self.ngf)
        sample_tensor = paddle.zeros(shape=[1, 1, self.image_height, self.image_width])
        self.grid = make_grid(sample_tensor)

    def split_to_dict(
        self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]
    ):
        return {key: data_tensors[i] for i, key in enumerate(keys)}

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        x_tensor = self.concat_to_tensor(x, self.input_keys)

        y = []
        out = self.forward_tensor(x_tensor)
        y.append(out)
        y = self.split_to_dict(y, self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y

    def forward_tensor(self, x):
        all_frames = x[:, :, :, :, :1]
        frames = all_frames.transpose(perm=[0, 1, 4, 2, 3])
        batch = frames.shape[0]
        height = frames.shape[3]
        width = frames.shape[4]
        # Input Frames
        input_frames = frames[:, : self.input_length]
        input_frames = input_frames.reshape((batch, self.input_length, height, width))
        # Evolution Network
        intensity, motion = self.evo_net(input_frames)
        motion_ = motion.reshape((batch, self.pred_length, 2, height, width))
        intensity_ = intensity.reshape((batch, self.pred_length, 1, height, width))
        series = []
        last_frames = all_frames[:, self.input_length - 1 : self.input_length, :, :, 0]
        grid = self.grid.tile((batch, 1, 1, 1))
        for i in range(self.pred_length):
            last_frames = warp(
                last_frames, motion_[:, i], grid, mode="nearest", padding_mode="border"
            )
            last_frames = last_frames + intensity_[:, i]
            series.append(last_frames)
        evo_result = paddle.concat(x=series, axis=1)
        evo_result = evo_result / 128
        # Generative Network
        evo_feature = self.gen_enc(paddle.concat(x=[input_frames, evo_result], axis=1))
        noise = paddle.randn(shape=[batch, self.ngf, height // 32, width // 32])
        noise_feature = (
            self.proj(noise)
            .reshape((batch, -1, 4, 4, 8, 8))
            .transpose(perm=[0, 1, 4, 5, 2, 3])
            .reshape((batch, -1, height // 8, width // 8))
        )
        feature = paddle.concat(x=[evo_feature, noise_feature], axis=1)
        gen_result = self.gen_dec(feature, evo_result)
        return gen_result.unsqueeze(axis=-1)
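forward_tensor() expects a 5-D input of shape [batch, time, height, width, channel] and uses only the first channel as radar frames: the evolution network warps the last observed frame along the predicted motion fields, and the generative encoder/decoder pair refines the result with projected noise. A minimal sketch (the tensor below is an assumed random stand-in for real radar data):

>>> import paddle
>>> model = ppsci.arch.NowcastNet(("input",), ("output",))
>>> x = {"input": paddle.randn([1, 29, 512, 512, 2])}  # [B, T, H, W, C], T = total_length
>>> y = model(x)  # y["output"] holds the 20 generated frames (total_length - input_length)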

Last updated: November 17, 2023
Created: November 6, 2023