跳转至

Arch(网络模型) 模块

ppsci.arch

Arch

Bases: Layer

Base class for Network.

Source code in ppsci/arch/base.py
class Arch(nn.Layer):
    """Base class for Network.

    Subclasses set ``input_keys``/``output_keys`` and implement ``forward``.
    Transforms registered through ``register_input_transform`` and
    ``register_output_transform`` are only stored by this base class; applying
    them is left to each subclass's ``forward``.
    """

    input_keys: Tuple[str, ...]
    output_keys: Tuple[str, ...]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Optional hook: maps the input dict to the dict actually fed to the network.
        self._input_transform: Callable[
            [Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]
        ] = None

        # Optional hook: receives (input dict, raw output dict) and returns the
        # transformed output dict.
        self._output_transform: Callable[
            [Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]],
            Dict[str, paddle.Tensor],
        ] = None

    def forward(self, *args, **kwargs):
        raise NotImplementedError("Arch.forward is not implemented")

    @property
    def num_params(self) -> int:
        """Return number of parameters within network.

        Returns:
            int: Number of parameters.
        """
        num = 0
        for name, param in self.named_parameters():
            if hasattr(param, "shape"):
                # Cast to a builtin int so the property matches its ``-> int``
                # annotation instead of leaking a numpy integer type.
                num += int(np.prod(list(param.shape), dtype="int"))
            else:
                logger.warning(f"{name} has no attribute 'shape'")
        return num

    @staticmethod
    def concat_to_tensor(
        data_dict: Dict[str, paddle.Tensor], keys: Tuple[str, ...], axis=-1
    ) -> paddle.Tensor:
        """Concatenate tensors from dict in the order of given keys.

        Args:
            data_dict (Dict[str, paddle.Tensor]): Dict contains tensor.
            keys (Tuple[str, ...]): Keys tensor fetched from.
            axis (int, optional): Axis concatenate at. Defaults to -1.

        Returns:
            paddle.Tensor: Concatenated tensor. When only one key is given, the
                stored tensor itself is returned without copying.

        Examples:
            >>> import paddle
            >>> import ppsci
            >>> model = ppsci.arch.Arch()
            >>> # fetch one tensor
            >>> out = model.concat_to_tensor({'x':paddle.rand([64, 64, 1])}, ('x',))
            >>> print(out.dtype, out.shape)
            paddle.float32 [64, 64, 1]
            >>> # fetch more tensors
            >>> out = model.concat_to_tensor(
            ...     {'x1':paddle.rand([64, 64, 1]), 'x2':paddle.rand([64, 64, 1])},
            ...     ('x1', 'x2'),
            ...     axis=2)
            >>> print(out.dtype, out.shape)
            paddle.float32 [64, 64, 2]

        """
        if len(keys) == 1:
            # single key: skip the concat and hand back the stored tensor
            return data_dict[keys[0]]
        data = [data_dict[key] for key in keys]
        return paddle.concat(data, axis)

    @staticmethod
    def split_to_dict(
        data_tensor: paddle.Tensor, keys: Tuple[str, ...], axis=-1
    ) -> Dict[str, paddle.Tensor]:
        """Split tensor and wrap into a dict by given keys.

        Args:
            data_tensor (paddle.Tensor): Tensor to be split.
            keys (Tuple[str, ...]): Keys tensor mapping to.
            axis (int, optional): Axis split at. Defaults to -1.

        Returns:
            Dict[str, paddle.Tensor]: Dict contains tensor.

        Examples:
            >>> import paddle
            >>> import ppsci
            >>> model = ppsci.arch.Arch()
            >>> # split one tensor
            >>> out = model.split_to_dict(paddle.rand([64, 64, 1]), ('x',))
            >>> for k, v in out.items():
            ...     print(f"{k} {v.dtype} {v.shape}")
            x paddle.float32 [64, 64, 1]
            >>> # split more tensors
            >>> out = model.split_to_dict(paddle.rand([64, 64, 2]), ('x1', 'x2'), axis=2)
            >>> for k, v in out.items():
            ...     print(f"{k} {v.dtype} {v.shape}")
            x1 paddle.float32 [64, 64, 1]
            x2 paddle.float32 [64, 64, 1]

        """
        if len(keys) == 1:
            return {keys[0]: data_tensor}
        # equal-sized chunks, one per key, taken in key order
        data = paddle.split(data_tensor, len(keys), axis=axis)
        return {key: data[i] for i, key in enumerate(keys)}

    def register_input_transform(
        self,
        transform: Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]],
    ):
        """Register input transform.

        Args:
            transform (Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
                Input transform of network; receives a single tensor dict and
                returns a single tensor dict.

        Examples:
            >>> import paddle
            >>> import ppsci
            >>> def transform_in(in_):
            ...     x = in_["x"]
            ...     # transform input
            ...     x_ = 2.0 * x
            ...     input_trans = {"2x": x_}
            ...     return input_trans
            >>> # `MLP` inherits from `Arch`
            >>> model = ppsci.arch.MLP(
            ...     input_keys=("2x",),
            ...     output_keys=("y",),
            ...     num_layers=5,
            ...     hidden_size=32)
            >>> model.register_input_transform(transform_in)
            >>> out = model({"x":paddle.rand([64, 64, 1])})
            >>> for k, v in out.items():
            ...     print(f"{k} {v.dtype} {v.shape}")
            y paddle.float32 [64, 64, 1]

        """
        self._input_transform = transform

    def register_output_transform(
        self,
        transform: Callable[
            [Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]],
            Dict[str, paddle.Tensor],
        ],
    ):
        """Register output transform.

        Args:
            transform (Callable[[Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
                Output transform of network; receives two tensor dicts (input
                and raw output) and returns a single tensor dict (transformed output).

        Examples:
            >>> import paddle
            >>> import ppsci
            >>> def transform_out(in_, out):
            ...     x = in_["x"]
            ...     y = out["y"]
            ...     u = 2.0 * x * y
            ...     output_trans = {"u": u}
            ...     return output_trans
            >>> # `MLP` inherits from `Arch`
            >>> model = ppsci.arch.MLP(
            ...     input_keys=("x",),
            ...     output_keys=("y",),
            ...     num_layers=5,
            ...     hidden_size=32)
            >>> model.register_output_transform(transform_out)
            >>> out = model({"x":paddle.rand([64, 64, 1])})
            >>> for k, v in out.items():
            ...     print(f"{k} {v.dtype} {v.shape}")
            u paddle.float32 [64, 64, 1]

        """
        self._output_transform = transform

    def freeze(self):
        """Freeze all parameters.

        Examples:
            >>> import ppsci
            >>> model = ppsci.arch.Arch()
            >>> # freeze all parameters and make model `eval`
            >>> model.freeze()
            >>> assert not model.training
            >>> for p in model.parameters():
            ...     assert p.stop_gradient

        """
        for param in self.parameters():
            param.stop_gradient = True

        self.eval()

    def unfreeze(self):
        """Unfreeze all parameters.

        Examples:
            >>> import ppsci
            >>> model = ppsci.arch.Arch()
            >>> # unfreeze all parameters and make model `train`
            >>> model.unfreeze()
            >>> assert model.training
            >>> for p in model.parameters():
            ...     assert not p.stop_gradient

        """
        for param in self.parameters():
            param.stop_gradient = False

        self.train()

    def __str__(self):
        """Summarize the network: class name, keys, layer counts and parameter count."""
        num_fc = 0
        num_conv = 0
        num_bn = 0
        for layer in self.sublayers(include_self=True):
            if isinstance(layer, nn.Linear):
                num_fc += 1
            elif isinstance(layer, (nn.Conv1D, nn.Conv2D, nn.Conv3D)):
                num_conv += 1
            elif isinstance(
                layer, (nn.BatchNorm, nn.BatchNorm1D, nn.BatchNorm2D, nn.BatchNorm3D)
            ):
                # include BatchNorm1D so 1-D models are counted like Conv1D is
                num_bn += 1

        return ", ".join(
            [
                self.__class__.__name__,
                # the base class only annotates these keys, so a bare Arch()
                # has no such attribute; fall back to None instead of raising
                f"input_keys = {getattr(self, 'input_keys', None)}",
                f"output_keys = {getattr(self, 'output_keys', None)}",
                f"num_fc = {num_fc}",
                f"num_conv = {num_conv}",
                f"num_bn = {num_bn}",
                f"num_params = {self.num_params}",
            ]
        )
num_params: int property

Return number of parameters within network.

Returns:

Name Type Description
int int

Number of parameters.

concat_to_tensor(data_dict, keys, axis=-1) staticmethod

Concatenate tensors from dict in the order of given keys.

Parameters:

Name Type Description Default
data_dict Dict[str, Tensor]

Dict contains tensor.

required
keys Tuple[str, ...]

Keys tensor fetched from.

required
axis int

Axis concatenate at. Defaults to -1.

-1

Returns:

Type Description
Tensor

paddle.Tensor: Concatenated tensor (the stored tensor itself when only one key is given).

Examples:

>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.Arch()
>>> # fetch one tensor
>>> out = model.concat_to_tensor({'x':paddle.rand([64, 64, 1])}, ('x',))
>>> print(out.dtype, out.shape)
paddle.float32 [64, 64, 1]
>>> # fetch more tensors
>>> out = model.concat_to_tensor(
...     {'x1':paddle.rand([64, 64, 1]), 'x2':paddle.rand([64, 64, 1])},
...     ('x1', 'x2'),
...     axis=2)
>>> print(out.dtype, out.shape)
paddle.float32 [64, 64, 2]
Source code in ppsci/arch/base.py
@staticmethod
def concat_to_tensor(
    data_dict: Dict[str, paddle.Tensor], keys: Tuple[str, ...], axis=-1
) -> paddle.Tensor:
    """Concatenate tensors from dict in the order of given keys.

    Args:
        data_dict (Dict[str, paddle.Tensor]): Dict contains tensor.
        keys (Tuple[str, ...]): Keys tensor fetched from.
        axis (int, optional): Axis concatenate at. Defaults to -1.

    Returns:
        paddle.Tensor: Concatenated tensor. When only one key is given, the
            stored tensor itself is returned without copying.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.Arch()
        >>> # fetch one tensor
        >>> out = model.concat_to_tensor({'x':paddle.rand([64, 64, 1])}, ('x',))
        >>> print(out.dtype, out.shape)
        paddle.float32 [64, 64, 1]
        >>> # fetch more tensors
        >>> out = model.concat_to_tensor(
        ...     {'x1':paddle.rand([64, 64, 1]), 'x2':paddle.rand([64, 64, 1])},
        ...     ('x1', 'x2'),
        ...     axis=2)
        >>> print(out.dtype, out.shape)
        paddle.float32 [64, 64, 2]

    """
    if len(keys) == 1:
        # single key: skip the concat and hand back the stored tensor
        return data_dict[keys[0]]
    data = [data_dict[key] for key in keys]
    return paddle.concat(data, axis)
freeze()

Freeze all parameters.

Examples:

>>> import ppsci
>>> model = ppsci.arch.Arch()
>>> # freeze all parameters and make model `eval`
>>> model.freeze()
>>> assert not model.training
>>> for p in model.parameters():
...     assert p.stop_gradient
Source code in ppsci/arch/base.py
def freeze(self):
    """Stop gradients for every parameter and switch the model to eval mode.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.Arch()
        >>> # freeze all parameters and make model `eval`
        >>> model.freeze()
        >>> assert not model.training
        >>> for p in model.parameters():
        ...     assert p.stop_gradient

    """
    frozen = True
    for weight in self.parameters():
        weight.stop_gradient = frozen
    # switch mode only after every parameter has been frozen
    self.eval()
register_input_transform(transform)

Register input transform.

Parameters:

Name Type Description Default
transform Callable[[Dict[str, Tensor]], Dict[str, Tensor]]

Input transform of network; receives a single tensor dict and returns a single tensor dict.

required

Examples:

>>> import paddle
>>> import ppsci
>>> def transform_in(in_):
...     x = in_["x"]
...     # transform input
...     x_ = 2.0 * x
...     input_trans = {"2x": x_}
...     return input_trans
>>> # `MLP` inherits from `Arch`
>>> model = ppsci.arch.MLP(
...     input_keys=("2x",),
...     output_keys=("y",),
...     num_layers=5,
...     hidden_size=32)
>>> model.register_input_transform(transform_in)
>>> out = model({"x":paddle.rand([64, 64, 1])})
>>> for k, v in out.items():
...     print(f"{k} {v.dtype} {v.shape}")
y paddle.float32 [64, 64, 1]
Source code in ppsci/arch/base.py
def register_input_transform(
    self,
    transform: Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]],
):
    """Register input transform.

    The callable is stored on ``self._input_transform``; replacing any
    previously registered input transform.

    Args:
        transform (Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
            Input transform of network; receives a single tensor dict and
            returns a single tensor dict.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> def transform_in(in_):
        ...     x = in_["x"]
        ...     # transform input
        ...     x_ = 2.0 * x
        ...     input_trans = {"2x": x_}
        ...     return input_trans
        >>> # `MLP` inherits from `Arch`
        >>> model = ppsci.arch.MLP(
        ...     input_keys=("2x",),
        ...     output_keys=("y",),
        ...     num_layers=5,
        ...     hidden_size=32)
        >>> model.register_input_transform(transform_in)
        >>> out = model({"x":paddle.rand([64, 64, 1])})
        >>> for k, v in out.items():
        ...     print(f"{k} {v.dtype} {v.shape}")
        y paddle.float32 [64, 64, 1]

    """
    self._input_transform = transform
register_output_transform(transform)

Register output transform.

Parameters:

Name Type Description Default
transform Callable[[Dict[str, Tensor], Dict[str, Tensor]], Dict[str, Tensor]]

Output transform of network; receives two tensor dicts (raw input and raw output) and returns a single tensor dict (transformed output).

required

Examples:

>>> import paddle
>>> import ppsci
>>> def transform_out(in_, out):
...     x = in_["x"]
...     y = out["y"]
...     u = 2.0 * x * y
...     output_trans = {"u": u}
...     return output_trans
>>> # `MLP` inherits from `Arch`
>>> model = ppsci.arch.MLP(
...     input_keys=("x",),
...     output_keys=("y",),
...     num_layers=5,
...     hidden_size=32)
>>> model.register_output_transform(transform_out)
>>> out = model({"x":paddle.rand([64, 64, 1])})
>>> for k, v in out.items():
...     print(f"{k} {v.dtype} {v.shape}")
u paddle.float32 [64, 64, 1]
Source code in ppsci/arch/base.py
def register_output_transform(
    self,
    transform: Callable[
        [Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]],
        Dict[str, paddle.Tensor],
    ],
):
    """Register output transform.

    The callable is stored on ``self._output_transform``; replacing any
    previously registered output transform.

    Args:
        transform (Callable[[Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
            Output transform of network; receives two tensor dicts (input
            and raw output) and returns a single tensor dict (transformed output).

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> def transform_out(in_, out):
        ...     x = in_["x"]
        ...     y = out["y"]
        ...     u = 2.0 * x * y
        ...     output_trans = {"u": u}
        ...     return output_trans
        >>> # `MLP` inherits from `Arch`
        >>> model = ppsci.arch.MLP(
        ...     input_keys=("x",),
        ...     output_keys=("y",),
        ...     num_layers=5,
        ...     hidden_size=32)
        >>> model.register_output_transform(transform_out)
        >>> out = model({"x":paddle.rand([64, 64, 1])})
        >>> for k, v in out.items():
        ...     print(f"{k} {v.dtype} {v.shape}")
        u paddle.float32 [64, 64, 1]

    """
    self._output_transform = transform
split_to_dict(data_tensor, keys, axis=-1) staticmethod

Split tensor and wrap into a dict by given keys.

Parameters:

Name Type Description Default
data_tensor Tensor

Tensor to be split.

required
keys Tuple[str, ...]

Keys tensor mapping to.

required
axis int

Axis split at. Defaults to -1.

-1

Returns:

Type Description
Dict[str, Tensor]

Dict[str, paddle.Tensor]: Dict contains tensor.

Examples:

>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.Arch()
>>> # split one tensor
>>> out = model.split_to_dict(paddle.rand([64, 64, 1]), ('x',))
>>> for k, v in out.items():
...     print(f"{k} {v.dtype} {v.shape}")
x paddle.float32 [64, 64, 1]
>>> # split more tensors
>>> out = model.split_to_dict(paddle.rand([64, 64, 2]), ('x1', 'x2'), axis=2)
>>> for k, v in out.items():
...     print(f"{k} {v.dtype} {v.shape}")
x1 paddle.float32 [64, 64, 1]
x2 paddle.float32 [64, 64, 1]
Source code in ppsci/arch/base.py
@staticmethod
def split_to_dict(
    data_tensor: paddle.Tensor, keys: Tuple[str, ...], axis=-1
) -> Dict[str, paddle.Tensor]:
    """Cut a tensor into equal chunks along ``axis`` and label each chunk by key.

    Args:
        data_tensor (paddle.Tensor): Tensor to be split.
        keys (Tuple[str, ...]): Keys the chunks are mapped to, in order.
        axis (int, optional): Axis split at. Defaults to -1.

    Returns:
        Dict[str, paddle.Tensor]: Mapping from each key to its chunk. With a
            single key the input tensor itself is stored without splitting.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.Arch()
        >>> # split one tensor
        >>> out = model.split_to_dict(paddle.rand([64, 64, 1]), ('x',))
        >>> for k, v in out.items():
        ...     print(f"{k} {v.dtype} {v.shape}")
        x paddle.float32 [64, 64, 1]
        >>> # split more tensors
        >>> out = model.split_to_dict(paddle.rand([64, 64, 2]), ('x1', 'x2'), axis=2)
        >>> for k, v in out.items():
        ...     print(f"{k} {v.dtype} {v.shape}")
        x1 paddle.float32 [64, 64, 1]
        x2 paddle.float32 [64, 64, 1]

    """
    num_keys = len(keys)
    if num_keys != 1:
        chunks = paddle.split(data_tensor, num_keys, axis=axis)
        return dict(zip(keys, chunks))
    return {keys[0]: data_tensor}
unfreeze()

Unfreeze all parameters.

Examples:

>>> import ppsci
>>> model = ppsci.arch.Arch()
>>> # unfreeze all parameters and make model `train`
>>> model.unfreeze()
>>> assert model.training
>>> for p in model.parameters():
...     assert not p.stop_gradient
Source code in ppsci/arch/base.py
def unfreeze(self):
    """Re-enable gradients for every parameter and switch the model to train mode.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.Arch()
        >>> # unfreeze all parameters and make model `train`
        >>> model.unfreeze()
        >>> assert model.training
        >>> for p in model.parameters():
        ...     assert not p.stop_gradient

    """
    frozen = False
    for weight in self.parameters():
        weight.stop_gradient = frozen
    # switch mode only after every parameter has been unfrozen
    self.train()

AMGNet

Bases: Layer

A Multi-scale Graph neural Network model based on Encoder-Process-Decoder structure for flow field prediction.

https://doi.org/10.1080/09540091.2022.2131737

Code reference: https://github.com/baoshiaijhin/amgnet

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Name of input keys, such as ("input", ).

required
output_keys Tuple[str, ...]

Name of output keys, such as ("pred", ).

required
input_dim int

Number of input dimension.

required
output_dim int

Number of output dimension.

required
latent_dim int

Number of hidden(feature) dimension.

required
num_layers int

Number of layer(s).

required
message_passing_aggregator Literal['sum']

Message aggregator method in graph. Only "sum" available now.

required
message_passing_steps int

Message passing steps in graph.

required
speed Literal['norm', 'fast']

Whether to use the vanilla ("norm") or the fast ("fast") method for graph_connectivity computation.

required

Examples:

>>> import ppsci
>>> model = ppsci.arch.AMGNet(
...     ("input", ), ("pred", ), 5, 3, 64, 2, "sum", 6, "norm",
... )
Source code in ppsci/arch/amgnet.py
class AMGNet(nn.Layer):
    """A Multi-scale Graph neural Network model
    based on Encoder-Process-Decoder structure for flow field prediction.

    https://doi.org/10.1080/09540091.2022.2131737

    Code reference: https://github.com/baoshiaijhin/amgnet

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input", ).
        output_keys (Tuple[str, ...]): Name of output keys, such as ("pred", ).
        input_dim (int): Number of input dimension.
        output_dim (int): Number of output dimension.
        latent_dim (int): Number of hidden(feature) dimension.
        num_layers (int): Number of layer(s).
        message_passing_aggregator (Literal["sum"]): Message aggregator method in graph.
            Only "sum" available now.
        message_passing_steps (int): Message passing steps in graph.
        speed (Literal["norm", "fast"]): Whether to use the vanilla ("norm") or the
            fast ("fast") method for graph_connectivity computation.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.AMGNet(
        ...     ("input", ), ("pred", ), 5, 3, 64, 2, "sum", 6, "norm",
        ... )
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        input_dim: int,
        output_dim: int,
        latent_dim: int,
        num_layers: int,
        message_passing_aggregator: Literal["sum"],
        message_passing_steps: int,
        speed: Literal["norm", "fast"],
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self._latent_dim = latent_dim
        self.speed = speed
        self._output_dim = output_dim
        self._num_layers = num_layers

        self.encoder = Encoder(input_dim, self._make_mlp, latent_dim=self._latent_dim)
        self.processor = Processor(
            make_mlp=self._make_mlp,
            output_dim=self._latent_dim,
            message_passing_steps=message_passing_steps,
            message_passing_aggregator=message_passing_aggregator,
            use_stochastic_message_passing=False,
        )
        # In `_spa_compute` this MLP's own `latent_dim`-wide output is
        # interpolated and added back into its next input, so its input width
        # must be `latent_dim`. It used to be hard-coded to 128, which only
        # matched when latent_dim == 128.
        self.post_processor = self._make_mlp(
            self._latent_dim, input_dim=self._latent_dim
        )
        self.decoder = Decoder(
            make_mlp=functools.partial(self._make_mlp, layer_norm=False),
            output_dim=self._output_dim,
        )

    def forward(self, x: Dict[str, "pgl.Graph"]) -> Dict[str, paddle.Tensor]:
        """Encode the input graph, run multi-level message passing, and decode.

        Only the first input key and the first output key are used.
        """
        graphs = x[self.input_keys[0]]
        latent_graph = self.encoder(graphs)
        graph_levels, positions = self.processor(latent_graph, speed=self.speed)
        node_features = self._spa_compute(graph_levels, positions)
        pred_field = self.decoder(node_features)
        return {self.output_keys[0]: pred_field}

    def _make_mlp(self, output_dim: int, input_dim: int = 5, layer_norm: bool = True):
        """Build `num_layers` hidden layers of width `latent_dim` plus an
        `output_dim` output layer, optionally followed by LayerNorm."""
        widths = (self._latent_dim,) * self._num_layers + (output_dim,)
        network = FullyConnectedLayer(input_dim, widths)
        if layer_norm:
            network = nn.Sequential(network, nn.LayerNorm(normalized_shape=widths[-1]))
        return network

    def _spa_compute(self, x: List["pgl.Graph"], p):
        """Propagate features from the last graph in `x` back toward the first.

        At each step the current features are interpolated (via
        `_knn_interpolate`) from `p[-k]` onto the positions of the next graph
        toward the front of the list. They are then added to that graph's own
        features and passed through `post_processor`.
        NOTE(review): presumably `x` is ordered fine -> coarse and `p` holds the
        coarse-level positions; confirm against `Processor`.
        """
        j = len(x) - 1
        node_features = x[j].x

        for k in range(1, j + 1):
            pos = p[-k]
            fine_nodes = x[-(k + 1)].pos
            feature = _knn_interpolate(node_features, pos, fine_nodes)
            node_features = x[-(k + 1)].x + feature
            node_features = self.post_processor(node_features)

        return node_features

MLP

Bases: Arch

Multi layer perceptron network.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Name of input keys, such as ("x", "y", "z").

required
output_keys Tuple[str, ...]

Name of output keys, such as ("u", "v", "w").

required
num_layers int

Number of hidden layers.

required
hidden_size Union[int, Tuple[int, ...]]

Hidden layer size(s). An integer applied to all layers, or a list of integers specifying each layer's size.

required
activation str

Name of activation function. Defaults to "tanh".

'tanh'
skip_connection bool

Whether to use skip connection. Defaults to False.

False
weight_norm bool

Whether to apply weight norm on parameter(s). Defaults to False.

False
input_dim Optional[int]

Number of input's dimension. Defaults to None.

None
output_dim Optional[int]

Number of output's dimension. Defaults to None.

None
periods Optional[Dict[int, Tuple[float, bool]]]

Period of each input key; input in the given channel will be period embedded if specified. Each tuple of the periods list is [period, trainable]. Defaults to None.

None
fourier Optional[Dict[str, Union[float, int]]]

Random fourier feature embedding, e.g. {'dim': 256, 'scale': 1.0}. Defaults to None.

None
random_weight Optional[Dict[str, float]]

Mean and std of random weight factorization layer, e.g. {"mean": 0.5, "std": 0.1}. Defaults to None.

None

Examples:

>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.MLP(
...     input_keys=("x", "y"),
...     output_keys=("u", "v"),
...     num_layers=5,
...     hidden_size=128
... )
>>> input_dict = {"x": paddle.rand([64, 64, 1]),
...               "y": paddle.rand([64, 64, 1])}
>>> output_dict = model(input_dict)
>>> print(output_dict["u"].shape)
[64, 64, 1]
>>> print(output_dict["v"].shape)
[64, 64, 1]
Source code in ppsci/arch/mlp.py
class MLP(base.Arch):
    """Multi layer perceptron network.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z").
        output_keys (Tuple[str, ...]): Name of output keys, such as ("u", "v", "w").
        num_layers (int): Number of hidden layers. Must be None when `hidden_size`
            is a list/tuple.
        hidden_size (Union[int, Tuple[int, ...]]): Hidden layer size(s). An integer
            applied to all `num_layers` layers, or a list of integers specifying
            each layer's size.
        activation (str, optional): Name of activation function. Defaults to "tanh".
        skip_connection (bool, optional): Whether to use residual skip connections
            between even-indexed hidden layers. Defaults to False.
        weight_norm (bool, optional): Whether to apply weight norm on parameter(s). Defaults to False.
        input_dim (Optional[int]): Number of input's dimension. Defaults to None.
        output_dim (Optional[int]): Number of output's dimension. Defaults to None.
        periods (Optional[Dict[int, Tuple[float, bool]]]): Period of each input key,
            input in given channel will be period embedded if specified, each tuple of
            periods list is [period, trainable]. Defaults to None.
        fourier (Optional[Dict[str, Union[float, int]]]): Random fourier feature embedding,
            e.g. {'dim': 256, 'scale': 1.0}. Defaults to None.
        random_weight (Optional[Dict[str, float]]): Mean and std of random weight
            factorization layer, e.g. {"mean": 0.5, "std": 0.1}. Defaults to None.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.MLP(
        ...     input_keys=("x", "y"),
        ...     output_keys=("u", "v"),
        ...     num_layers=5,
        ...     hidden_size=128
        ... )
        >>> input_dict = {"x": paddle.rand([64, 64, 1]),
        ...               "y": paddle.rand([64, 64, 1])}
        >>> output_dict = model(input_dict)
        >>> print(output_dict["u"].shape)
        [64, 64, 1]
        >>> print(output_dict["v"].shape)
        [64, 64, 1]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        num_layers: int,
        hidden_size: Union[int, Tuple[int, ...]],
        activation: str = "tanh",
        skip_connection: bool = False,
        weight_norm: bool = False,
        input_dim: Optional[int] = None,
        output_dim: Optional[int] = None,
        periods: Optional[Dict[int, Tuple[float, bool]]] = None,
        fourier: Optional[Dict[str, Union[float, int]]] = None,
        random_weight: Optional[Dict[str, float]] = None,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.linears = []
        self.acts = []
        self.periods = periods
        self.fourier = fourier
        if periods:
            self.period_emb = PeriodEmbedding(periods)

        if isinstance(hidden_size, (tuple, list)):
            if num_layers is not None:
                raise ValueError(
                    "num_layers should be None when hidden_size is specified"
                )
        elif isinstance(hidden_size, int):
            if not isinstance(num_layers, int):
                raise ValueError(
                    "num_layers should be an int when hidden_size is an int"
                )
            hidden_size = [hidden_size] * num_layers
        else:
            raise ValueError(
                f"hidden_size should be list of int or int, but got {type(hidden_size)}"
            )

        # initialize FC layer(s)
        cur_size = len(self.input_keys) if input_dim is None else input_dim
        if input_dim is None and periods:
            # period embedded channel(s) will be doubled automatically
            # if input_dim is not specified
            cur_size += len(periods)

        if fourier:
            self.fourier_emb = FourierEmbedding(
                cur_size, fourier["dim"], fourier["scale"]
            )
            cur_size = fourier["dim"]

        for i, _size in enumerate(hidden_size):
            if weight_norm:
                self.linears.append(WeightNormLinear(cur_size, _size))
            elif random_weight:
                self.linears.append(
                    RandomWeightFactorization(
                        cur_size,
                        _size,
                        mean=random_weight["mean"],
                        std=random_weight["std"],
                    )
                )
            else:
                self.linears.append(nn.Linear(cur_size, _size))

            # initialize activation function; "stan" is constructed per layer
            # with the layer width
            self.acts.append(
                act_mod.get_activation(activation)
                if activation != "stan"
                else act_mod.get_activation(activation)(_size)
            )
            # special initialization for certain activation
            # TODO: Adapt code below to a more elegant style
            if activation == "siren":
                if i == 0:
                    act_mod.Siren.init_for_first_layer(self.linears[-1])
                else:
                    act_mod.Siren.init_for_hidden_layer(self.linears[-1])

            cur_size = _size

        self.linears = nn.LayerList(self.linears)
        self.acts = nn.LayerList(self.acts)
        if random_weight:
            self.last_fc = RandomWeightFactorization(
                cur_size,
                len(self.output_keys) if output_dim is None else output_dim,
                mean=random_weight["mean"],
                std=random_weight["std"],
            )
        else:
            self.last_fc = nn.Linear(
                cur_size,
                len(self.output_keys) if output_dim is None else output_dim,
            )

        self.skip_connection = skip_connection

    def forward_tensor(self, x):
        """Run hidden layers and the final linear layer on a concatenated tensor.

        With `skip_connection` enabled, the pre-activation output of every
        even-indexed hidden layer (0, 2, 4, ...) is kept. From the second such
        layer on, the previously kept value is added as a residual before the
        activation. Residual-linked layers must therefore share the same width.
        """
        y = x
        skip = None
        for i, linear in enumerate(self.linears):
            y = linear(y)
            if self.skip_connection and i % 2 == 0:
                if skip is not None:
                    # Add the previous even layer's output as the residual, then
                    # keep the current pre-addition output for the next one.
                    # (Assigning `skip = y` first would compute `y + y`, i.e. a
                    # plain doubling rather than a skip connection.)
                    y, skip = y + skip, y
                else:
                    skip = y
            y = self.acts[i](y)

        y = self.last_fc(y)

        return y

    def forward(self, x):
        """Map an input dict to an output dict keyed by `output_keys`.

        Pipeline: optional input transform -> optional period embedding ->
        concatenation of `input_keys` on the last axis -> optional fourier
        embedding -> `forward_tensor` -> split into `output_keys` -> optional
        output transform. The output transform receives the dict produced after
        the input transform and period embedding.
        """
        if self._input_transform is not None:
            x = self._input_transform(x)

        if self.periods:
            x = self.period_emb(x)

        y = self.concat_to_tensor(x, self.input_keys, axis=-1)

        if self.fourier:
            y = self.fourier_emb(y)

        y = self.forward_tensor(y)
        y = self.split_to_dict(y, self.output_keys, axis=-1)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y

ModifiedMLP

Bases: Arch

Modified Multi layer perceptron network.

Understanding and mitigating gradient pathologies in physics-informed neural networks. https://arxiv.org/pdf/2001.04536.pdf.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Name of input keys, such as ("x", "y", "z").

required
output_keys Tuple[str, ...]

Name of output keys, such as ("u", "v", "w").

required
num_layers int

Number of hidden layers.

required
hidden_size int

Number of hidden size, an integer for all layers.

required
activation str

Name of activation function. Defaults to "tanh".

'tanh'
skip_connection bool

Whether to use skip connection. Defaults to False.

False
weight_norm bool

Whether to apply weight norm on parameter(s). Defaults to False.

False
input_dim Optional[int]

Number of input's dimension. Defaults to None.

None
output_dim Optional[int]

Number of output's dimension. Defaults to None.

None

Examples:

>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.ModifiedMLP(
...     input_keys=("x", "y"),
...     output_keys=("u", "v"),
...     num_layers=5,
...     hidden_size=128
... )
>>> input_dict = {"x": paddle.rand([64, 64, 1]),
...               "y": paddle.rand([64, 64, 1])}
>>> output_dict = model(input_dict)
>>> print(output_dict["u"].shape)
[64, 64, 1]
>>> print(output_dict["v"].shape)
[64, 64, 1]
Source code in ppsci/arch/mlp.py
class ModifiedMLP(base.Arch):
    """Modified Multi layer perceptron network.

    Understanding and mitigating gradient pathologies in physics-informed
    neural networks. https://arxiv.org/pdf/2001.04536.pdf.

    Two embedding branches ``U`` and ``V`` are computed from the input, and the
    activated output ``Z`` of every hidden layer is mixed element-wise as
    ``(1 - Z) * U + Z * V``. Because of this mixing, every hidden layer must have
    the same width as the embeddings, so ``hidden_size`` must be a single int.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("x", "y", "z").
        output_keys (Tuple[str, ...]): Name of output keys, such as ("u", "v", "w").
        num_layers (int): Number of hidden layers.
        hidden_size (int): Hidden size shared by all hidden layers.
        activation (str, optional): Name of activation function. Defaults to "tanh".
        skip_connection (bool, optional): Whether to use skip connection. Defaults to False.
        weight_norm (bool, optional): Whether to apply weight norm on the embedding
            and hidden layers (the final output layer is a plain linear layer).
            Defaults to False.
        input_dim (Optional[int]): Number of input's dimension. Defaults to None,
            meaning ``len(input_keys)``.
        output_dim (Optional[int]): Number of output's dimension. Defaults to None,
            meaning ``len(output_keys)``.

    Raises:
        ValueError: If ``hidden_size`` is not an int, or ``num_layers`` is not an int.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.ModifiedMLP(
        ...     input_keys=("x", "y"),
        ...     output_keys=("u", "v"),
        ...     num_layers=5,
        ...     hidden_size=128
        ... )
        >>> input_dict = {"x": paddle.rand([64, 64, 1]),
        ...               "y": paddle.rand([64, 64, 1])}
        >>> output_dict = model(input_dict)
        >>> print(output_dict["u"].shape)
        [64, 64, 1]
        >>> print(output_dict["v"].shape)
        [64, 64, 1]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        num_layers: int,
        hidden_size: int,
        activation: str = "tanh",
        skip_connection: bool = False,
        weight_norm: bool = False,
        input_dim: Optional[int] = None,
        output_dim: Optional[int] = None,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.linears = []
        self.acts = []
        # The U/V gating requires equal widths, so only a single int is accepted.
        if isinstance(hidden_size, int):
            if not isinstance(num_layers, int):
                raise ValueError("num_layers should be an int")
            hidden_size = [hidden_size] * num_layers
        else:
            raise ValueError(f"hidden_size should be int, but got {type(hidden_size)}")

        # initialize FC layer(s)
        cur_size = len(self.input_keys) if input_dim is None else input_dim
        # Embedding branch U.
        self.embed_u = nn.Sequential(
            (
                WeightNormLinear(cur_size, hidden_size[0])
                if weight_norm
                else nn.Linear(cur_size, hidden_size[0])
            ),
            (
                act_mod.get_activation(activation)
                if activation != "stan"
                else act_mod.get_activation(activation)(hidden_size[0])
            ),
        )
        # Embedding branch V.
        self.embed_v = nn.Sequential(
            (
                WeightNormLinear(cur_size, hidden_size[0])
                if weight_norm
                else nn.Linear(cur_size, hidden_size[0])
            ),
            (
                act_mod.get_activation(activation)
                if activation != "stan"
                else act_mod.get_activation(activation)(hidden_size[0])
            ),
        )

        for i, _size in enumerate(hidden_size):
            self.linears.append(
                WeightNormLinear(cur_size, _size)
                if weight_norm
                else nn.Linear(cur_size, _size)
            )
            # initialize activation function ("stan" needs the layer width)
            self.acts.append(
                act_mod.get_activation(activation)
                if activation != "stan"
                else act_mod.get_activation(activation)(_size)
            )
            # special initialization for certain activation
            # TODO: Adapt code below to a more elegant style
            if activation == "siren":
                if i == 0:
                    act_mod.Siren.init_for_first_layer(self.linears[-1])
                else:
                    act_mod.Siren.init_for_hidden_layer(self.linears[-1])

            cur_size = _size

        self.linears = nn.LayerList(self.linears)
        self.acts = nn.LayerList(self.acts)
        self.last_fc = nn.Linear(
            cur_size,
            len(self.output_keys) if output_dim is None else output_dim,
        )

        self.skip_connection = skip_connection

    def forward_tensor(self, x):
        """Apply the network to an already-concatenated input tensor."""
        # Gating embeddings, both computed from the raw input.
        u = self.embed_u(x)
        v = self.embed_v(x)

        y = x
        skip = None
        for i, linear in enumerate(self.linears):
            y = linear(y)
            y = self.acts[i](y)
            # Element-wise mixing between the two embeddings.
            y = (1 - y) * u + y * v
            if self.skip_connection and i % 2 == 0:
                if skip is not None:
                    # Add the previously stored tensor and remember the current
                    # (pre-addition) output for the next skip. The tuple
                    # assignment evaluates ``y + skip`` with the OLD ``skip``;
                    # assigning ``skip = y`` first would merely double ``y``.
                    skip, y = y, y + skip
                else:
                    skip = y

        y = self.last_fc(y)

        return y

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        y = self.concat_to_tensor(x, self.input_keys, axis=-1)
        y = self.forward_tensor(y)
        y = self.split_to_dict(y, self.output_keys, axis=-1)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y

DeepONet

Bases: Arch

Deep operator network.

Lu et al. Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nat Mach Intell, 2021.

Parameters:

Name Type Description Default
u_key str

Name of function data for input function u(x).

required
y_key str

Name of location data y at which the output function G(u) is evaluated.

required
G_key str

Output name of predicted G(u)(y).

required
num_loc int

Number of sampled points of u(x), i.e. m in the paper.

required
num_features int

Number of features extracted from u(x), same for y.

required
branch_num_layers int

Number of hidden layers of branch net.

required
trunk_num_layers int

Number of hidden layers of trunk net.

required
branch_hidden_size Union[int, Tuple[int, ...]]

Hidden size of the branch net: an integer shared by all layers, or a list of integers specifying each layer's size.

required
trunk_hidden_size Union[int, Tuple[int, ...]]

Hidden size of the trunk net: an integer shared by all layers, or a list of integers specifying each layer's size.

required
branch_skip_connection bool

Whether to use skip connection for branch net. Defaults to False.

False
trunk_skip_connection bool

Whether to use skip connection for trunk net. Defaults to False.

False
branch_activation str

Name of activation function. Defaults to "tanh".

'tanh'
trunk_activation str

Name of activation function. Defaults to "tanh".

'tanh'
branch_weight_norm bool

Whether to apply weight norm on parameter(s) for branch net. Defaults to False.

False
trunk_weight_norm bool

Whether to apply weight norm on parameter(s) for trunk net. Defaults to False.

False
use_bias bool

Whether to add bias on predicted G(u)(y). Defaults to True.

True

Examples:

>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.DeepONet(
...     "u", "y", "G",
...     100, 40,
...     1, 1,
...     40, 40,
...     branch_activation="relu", trunk_activation="relu",
...     use_bias=True,
... )
>>> input_dict = {"u": paddle.rand([200, 100]),
...               "y": paddle.rand([200, 1])}
>>> output_dict = model(input_dict)
>>> print(output_dict["G"].shape)
[200, 1]
Source code in ppsci/arch/deeponet.py
class DeepONet(base.Arch):
    """Deep operator network.

    [Lu et al. Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nat Mach Intell, 2021.](https://doi.org/10.1038/s42256-021-00302-5)

    The branch net encodes ``u`` into ``num_features`` features, the trunk net
    encodes the location ``y`` into ``num_features`` features (followed by the
    trunk activation), and the prediction is their per-sample dot product plus
    an optional scalar bias. The trunk net is built with ``input_dim=1``.

    Args:
        u_key (str): Name of function data for input function u(x).
        y_key (str): Name of location data y at which G(u) is evaluated.
        G_key (str): Output name of predicted G(u)(y).
        num_loc (int): Number of sampled points of u(x), i.e. `m` in paper.
        num_features (int): Number of features extracted from u(x), same for y.
        branch_num_layers (int): Number of hidden layers of branch net.
        trunk_num_layers (int): Number of hidden layers of trunk net.
        branch_hidden_size (Union[int, Tuple[int, ...]]): Hidden size of branch net.
            An integer for all layers, or list of integer specify each layer's size.
        trunk_hidden_size (Union[int, Tuple[int, ...]]): Hidden size of trunk net.
            An integer for all layers, or list of integer specify each layer's size.
        branch_skip_connection (bool, optional): Whether to use skip connection for branch net. Defaults to False.
        trunk_skip_connection (bool, optional): Whether to use skip connection for trunk net. Defaults to False.
        branch_activation (str, optional): Name of activation function. Defaults to "tanh".
        trunk_activation (str, optional): Name of activation function. Defaults to "tanh".
        branch_weight_norm (bool, optional): Whether to apply weight norm on parameter(s) for branch net. Defaults to False.
        trunk_weight_norm (bool, optional): Whether to apply weight norm on parameter(s) for trunk net. Defaults to False.
        use_bias (bool, optional): Whether to add bias on predicted G(u)(y). Defaults to True.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.DeepONet(
        ...     "u", "y", "G",
        ...     100, 40,
        ...     1, 1,
        ...     40, 40,
        ...     branch_activation="relu", trunk_activation="relu",
        ...     use_bias=True,
        ... )
        >>> input_dict = {"u": paddle.rand([200, 100]),
        ...               "y": paddle.rand([200, 1])}
        >>> output_dict = model(input_dict)
        >>> print(output_dict["G"].shape)
        [200, 1]
    """

    def __init__(
        self,
        u_key: str,
        y_key: str,
        G_key: str,
        num_loc: int,
        num_features: int,
        branch_num_layers: int,
        trunk_num_layers: int,
        branch_hidden_size: Union[int, Tuple[int, ...]],
        trunk_hidden_size: Union[int, Tuple[int, ...]],
        branch_skip_connection: bool = False,
        trunk_skip_connection: bool = False,
        branch_activation: str = "tanh",
        trunk_activation: str = "tanh",
        branch_weight_norm: bool = False,
        trunk_weight_norm: bool = False,
        use_bias: bool = True,
    ):
        super().__init__()
        self.u_key = u_key
        self.y_key = y_key
        self.input_keys = (u_key, y_key)
        self.output_keys = (G_key,)

        self.branch_net = mlp.MLP(
            (self.u_key,),
            ("b",),
            branch_num_layers,
            branch_hidden_size,
            branch_activation,
            branch_skip_connection,
            branch_weight_norm,
            input_dim=num_loc,
            output_dim=num_features,
        )

        self.trunk_net = mlp.MLP(
            (self.y_key,),
            ("t",),
            trunk_num_layers,
            trunk_hidden_size,
            trunk_activation,
            trunk_skip_connection,
            trunk_weight_norm,
            input_dim=1,
            output_dim=num_features,
        )
        # "stan" is a factory that must be instantiated with the feature width
        # (same convention as the MLP layers); other activations are used as-is.
        self.trunk_act = (
            act_mod.get_activation(trunk_activation)
            if trunk_activation != "stan"
            else act_mod.get_activation(trunk_activation)(num_features)
        )

        self.use_bias = use_bias
        if use_bias:
            # register bias to parameter for updating in optimizer and storage
            self.b = self.create_parameter(
                shape=(1,),
                attr=nn.initializer.Constant(0.0),
            )

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        # Branch net to encode the input function
        u_features = self.branch_net(x)[self.branch_net.output_keys[0]]

        # Trunk net to encode the domain of the output function
        y_features = self.trunk_net(x)
        y_features = self.trunk_act(y_features[self.trunk_net.output_keys[0]])

        # Dot product over the feature axis (features are rank-2 per the einsum)
        G_u = paddle.einsum("bi,bi->b", u_features, y_features)  # [batch_size, ]
        G_u = paddle.reshape(G_u, [-1, 1])  # reshape [batch_size, ] to [batch_size, 1]

        # Add bias
        if self.use_bias:
            G_u += self.b

        result_dict = {
            self.output_keys[0]: G_u,
        }
        if self._output_transform is not None:
            result_dict = self._output_transform(x, result_dict)

        return result_dict

DeepPhyLSTM

Bases: Arch

Deep physics-informed LSTM (PhyLSTM) network.

Parameters:

Name Type Description Default
input_size int

The input size.

required
output_size int

The output size.

required
hidden_size int

The hidden size. Defaults to 100.

100
model_type int

The model type, either 2 or 3: 2 builds two sub-models, 3 builds three sub-models. Defaults to 2.

2

Examples:

>>> import paddle
>>> import ppsci
>>> # model_type is `2`
>>> model = ppsci.arch.DeepPhyLSTM(
...     input_size=16,
...     output_size=1,
...     hidden_size=100,
...     model_type=2)
>>> out = model(
...     {"ag":paddle.rand([64, 16, 16]),
...     "ag_c":paddle.rand([64, 16, 16]),
...     "phi":paddle.rand([1, 16, 16])})
>>> for k, v in out.items():
...     print(f"{k} {v.dtype} {v.shape}")
eta_pred paddle.float32 [64, 16, 1]
eta_dot_pred paddle.float32 [64, 16, 1]
g_pred paddle.float32 [64, 16, 1]
eta_t_pred_c paddle.float32 [64, 16, 1]
eta_dot_pred_c paddle.float32 [64, 16, 1]
lift_pred_c paddle.float32 [64, 16, 1]
>>> # model_type is `3`
>>> model = ppsci.arch.DeepPhyLSTM(
...     input_size=16,
...     output_size=1,
...     hidden_size=100,
...     model_type=3)
>>> out = model(
...     {"ag":paddle.rand([64, 16, 1]),
...     "ag_c":paddle.rand([64, 16, 1]),
...     "phi":paddle.rand([1, 16, 16])})
>>> for k, v in out.items():
...     print(f"{k} {v.dtype} {v.shape}")
eta_pred paddle.float32 [64, 16, 1]
eta_dot_pred paddle.float32 [64, 16, 1]
g_pred paddle.float32 [64, 16, 1]
eta_t_pred_c paddle.float32 [64, 16, 1]
eta_dot_pred_c paddle.float32 [64, 16, 1]
lift_pred_c paddle.float32 [64, 16, 1]
g_t_pred_c paddle.float32 [64, 16, 1]
g_dot_pred_c paddle.float32 [64, 16, 1]
Source code in ppsci/arch/phylstm.py
class DeepPhyLSTM(base.Arch):
    """Deep physics-informed LSTM (PhyLSTM) network.

    ``lstm_model`` maps an input sequence to ``3 * output_size`` features that are
    split into ``eta``, ``eta_dot`` and ``g`` predictions. ``lstm_model_f`` (and,
    for ``model_type=3``, ``lstm_model_g``) are additional sub-models fed with
    concatenations of those predictions.

    Args:
        input_size (int): The input size.
        output_size (int): The output size.
        hidden_size (int, optional): The hidden size. Defaults to 100.
        model_type (int, optional): The model type, value is 2 or 3, 2 indicates having two sub-models, 3 indicates having three submodels. Defaults to 2.

    Raises:
        ValueError: If ``model_type`` is not 2 or 3.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> # model_type is `2`
        >>> model = ppsci.arch.DeepPhyLSTM(
        ...     input_size=16,
        ...     output_size=1,
        ...     hidden_size=100,
        ...     model_type=2)
        >>> out = model(
        ...     {"ag":paddle.rand([64, 16, 16]),
        ...     "ag_c":paddle.rand([64, 16, 16]),
        ...     "phi":paddle.rand([1, 16, 16])})
        >>> for k, v in out.items():
        ...     print(f"{k} {v.dtype} {v.shape}")
        eta_pred paddle.float32 [64, 16, 1]
        eta_dot_pred paddle.float32 [64, 16, 1]
        g_pred paddle.float32 [64, 16, 1]
        eta_t_pred_c paddle.float32 [64, 16, 1]
        eta_dot_pred_c paddle.float32 [64, 16, 1]
        lift_pred_c paddle.float32 [64, 16, 1]
        >>> # model_type is `3`
        >>> model = ppsci.arch.DeepPhyLSTM(
        ...     input_size=16,
        ...     output_size=1,
        ...     hidden_size=100,
        ...     model_type=3)
        >>> out = model(
        ...     {"ag":paddle.rand([64, 16, 1]),
        ...     "ag_c":paddle.rand([64, 16, 1]),
        ...     "phi":paddle.rand([1, 16, 16])})
        >>> for k, v in out.items():
        ...     print(f"{k} {v.dtype} {v.shape}")
        eta_pred paddle.float32 [64, 16, 1]
        eta_dot_pred paddle.float32 [64, 16, 1]
        g_pred paddle.float32 [64, 16, 1]
        eta_t_pred_c paddle.float32 [64, 16, 1]
        eta_dot_pred_c paddle.float32 [64, 16, 1]
        lift_pred_c paddle.float32 [64, 16, 1]
        g_t_pred_c paddle.float32 [64, 16, 1]
        g_dot_pred_c paddle.float32 [64, 16, 1]
    """

    def __init__(self, input_size, output_size, hidden_size=100, model_type=2):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.model_type = model_type

        if self.model_type == 2:
            self.lstm_model = nn.Sequential(
                nn.LSTM(input_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.Linear(hidden_size, 3 * output_size),
            )

            # _forward_type_2 feeds [eta (output_size), first eta_dot component (1),
            # g (output_size)], i.e. 2 * output_size + 1 features (== 3 when
            # output_size == 1, matching the previous hard-coded size).
            self.lstm_model_f = nn.Sequential(
                nn.LSTM(2 * output_size + 1, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.Linear(hidden_size, output_size),
            )
        elif self.model_type == 3:
            self.lstm_model = nn.Sequential(
                nn.LSTM(1, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 3 * output_size),
            )

            # _forward_type_3 feeds [eta, eta_dot, g], i.e. 3 * output_size features.
            self.lstm_model_f = nn.Sequential(
                nn.LSTM(3 * output_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, output_size),
            )

            # _forward_type_3 feeds [first eta_dot component (1), g (output_size)],
            # i.e. output_size + 1 features (== 2 when output_size == 1).
            self.lstm_model_g = nn.Sequential(
                nn.LSTM(output_size + 1, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.LSTM(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, output_size),
            )
        else:
            raise ValueError(f"model_type should be 2 or 3, but got {model_type}")

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        if self.model_type == 2:
            result_dict = self._forward_type_2(x)
        elif self.model_type == 3:
            result_dict = self._forward_type_3(x)
        else:
            # Guard against model_type being changed after construction.
            raise ValueError(
                f"model_type should be 2 or 3, but got {self.model_type}"
            )
        if self._output_transform is not None:
            result_dict = self._output_transform(x, result_dict)
        return result_dict

    def _forward_type_2(self, x):
        output = self.lstm_model(x["ag"])
        eta_pred = output[:, :, 0 : self.output_size]
        eta_dot_pred = output[:, :, self.output_size : 2 * self.output_size]
        g_pred = output[:, :, 2 * self.output_size :]

        # for ag_c
        output_c = self.lstm_model(x["ag_c"])
        eta_pred_c = output_c[:, :, 0 : self.output_size]
        eta_dot_pred_c = output_c[:, :, self.output_size : 2 * self.output_size]
        g_pred_c = output_c[:, :, 2 * self.output_size :]
        eta_t_pred_c = paddle.matmul(x["phi"], eta_pred_c)
        eta_tt_pred_c = paddle.matmul(x["phi"], eta_dot_pred_c)
        eta_dot1_pred_c = eta_dot_pred_c[:, :, 0:1]
        tmp = paddle.concat([eta_pred_c, eta_dot1_pred_c, g_pred_c], 2)
        f = self.lstm_model_f(tmp)
        lift_pred_c = eta_tt_pred_c + f

        return {
            "eta_pred": eta_pred,
            "eta_dot_pred": eta_dot_pred,
            "g_pred": g_pred,
            "eta_t_pred_c": eta_t_pred_c,
            "eta_dot_pred_c": eta_dot_pred_c,
            "lift_pred_c": lift_pred_c,
        }

    def _forward_type_3(self, x):
        # physics informed neural networks
        output = self.lstm_model(x["ag"])
        eta_pred = output[:, :, 0 : self.output_size]
        eta_dot_pred = output[:, :, self.output_size : 2 * self.output_size]
        g_pred = output[:, :, 2 * self.output_size :]

        output_c = self.lstm_model(x["ag_c"])
        eta_pred_c = output_c[:, :, 0 : self.output_size]
        eta_dot_pred_c = output_c[:, :, self.output_size : 2 * self.output_size]
        g_pred_c = output_c[:, :, 2 * self.output_size :]

        eta_t_pred_c = paddle.matmul(x["phi"], eta_pred_c)
        eta_tt_pred_c = paddle.matmul(x["phi"], eta_dot_pred_c)
        g_t_pred_c = paddle.matmul(x["phi"], g_pred_c)

        f = self.lstm_model_f(paddle.concat([eta_pred_c, eta_dot_pred_c, g_pred_c], 2))
        lift_pred_c = eta_tt_pred_c + f

        eta_dot1_pred_c = eta_dot_pred_c[:, :, 0:1]
        g_dot_pred_c = self.lstm_model_g(paddle.concat([eta_dot1_pred_c, g_pred_c], 2))

        return {
            "eta_pred": eta_pred,
            "eta_dot_pred": eta_dot_pred,
            "g_pred": g_pred,
            "eta_t_pred_c": eta_t_pred_c,
            "eta_dot_pred_c": eta_dot_pred_c,
            "lift_pred_c": lift_pred_c,
            "g_t_pred_c": g_t_pred_c,
            "g_dot_pred_c": g_dot_pred_c,
        }

LorenzEmbedding

Bases: Arch

Embedding Koopman model for the Lorenz ODE system.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Input keys, such as ("states",).

required
output_keys Tuple[str, ...]

Output keys, such as ("pred_states", "recover_states").

required
mean Optional[Tuple[float, ...]]

Mean of training dataset. Defaults to None.

None
std Optional[Tuple[float, ...]]

Standard Deviation of training dataset. Defaults to None.

None
input_size int

Size of input data. Defaults to 3.

3
hidden_size int

Number of hidden size. Defaults to 500.

500
embed_size int

Number of embedding size. Defaults to 32.

32
drop float

Probability of dropout the units. Defaults to 0.0.

0.0

Examples:

>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.LorenzEmbedding(
...     input_keys=("x", "y"),
...     output_keys=("u", "v"),
...     input_size=3,
...     hidden_size=500,
...     embed_size=32,
...     drop=0.0,
...     mean=None,
...     std=None,
... )
>>> x_shape = [8, 3, 2]
>>> y_shape = [8, 3, 1]
>>> input_dict = {"x": paddle.rand(x_shape),
...               "y": paddle.rand(y_shape)}
>>> output_dict = model(input_dict)
>>> print(output_dict["u"].shape)
[8, 2, 3]
>>> print(output_dict["v"].shape)
[8, 3, 3]
Source code in ppsci/arch/embedding_koopman.py
class LorenzEmbedding(base.Arch):
    """Embedding Koopman model for the Lorenz ODE system.

    Inputs are normalized with ``mean``/``std``, encoded into an embedding,
    advanced one step by a learned Koopman matrix, and decoded back.

    Args:
        input_keys (Tuple[str, ...]): Input keys, such as ("states",).
        output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states").
        mean (Optional[Tuple[float, ...]]): Mean of training dataset, one value per
            input feature (``input_size`` values). Defaults to None (all zeros).
        std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset, one
            value per input feature (``input_size`` values). Defaults to None (all ones).
        input_size (int, optional): Size of input data. Defaults to 3.
        hidden_size (int, optional): Number of hidden size. Defaults to 500.
        embed_size (int, optional): Number of embedding size. Defaults to 32.
        drop (float, optional):  Probability of dropout the units. Defaults to 0.0.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.LorenzEmbedding(
        ...     input_keys=("x", "y"),
        ...     output_keys=("u", "v"),
        ...     input_size=3,
        ...     hidden_size=500,
        ...     embed_size=32,
        ...     drop=0.0,
        ...     mean=None,
        ...     std=None,
        ... )
        >>> x_shape = [8, 3, 2]
        >>> y_shape = [8, 3, 1]
        >>> input_dict = {"x": paddle.rand(x_shape),
        ...               "y": paddle.rand(y_shape)}
        >>> output_dict = model(input_dict)
        >>> print(output_dict["u"].shape)
        [8, 2, 3]
        >>> print(output_dict["v"].shape)
        [8, 3, 3]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        input_size: int = 3,
        hidden_size: int = 500,
        embed_size: int = 32,
        drop: float = 0.0,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embed_size = embed_size

        # build observable network
        self.encoder_net = self.build_encoder(input_size, hidden_size, embed_size, drop)
        # build koopman operator
        self.k_diag, self.k_ut = self.build_koopman_operator(embed_size)
        # build recovery network
        self.decoder_net = self.build_decoder(input_size, hidden_size, embed_size)

        # Normalization statistics: one value per input feature, so they follow
        # input_size instead of assuming a 3-dimensional state.
        mean = [0.0] * input_size if mean is None else mean
        std = [1.0] * input_size if std is None else std
        self.register_buffer("mean", paddle.to_tensor(mean).reshape([1, input_size]))
        self.register_buffer("std", paddle.to_tensor(std).reshape([1, input_size]))

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Layer):
        if isinstance(m, nn.Linear):
            # Paddle Linear weights are [in_features, out_features], so this is
            # U(-1/sqrt(in_features), 1/sqrt(in_features)).
            k = 1 / m.weight.shape[0]
            uniform = Uniform(-(k**0.5), k**0.5)
            uniform(m.weight)
            if m.bias is not None:
                uniform(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def build_encoder(
        self, input_size: int, hidden_size: int, embed_size: int, drop: float = 0.0
    ):
        net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
            nn.LayerNorm(embed_size),
            nn.Dropout(drop),
        )
        return net

    def build_decoder(self, input_size: int, hidden_size: int, embed_size: int):
        net = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
        )
        return net

    def build_koopman_operator(self, embed_size: int):
        # Learned Koopman operator: a diagonal plus the entries of the first
        # (embed_size - 1) and second (embed_size - 2) super-diagonals.
        data = paddle.linspace(1, 0, embed_size)
        k_diag = paddle.create_parameter(
            shape=data.shape,
            dtype=paddle.get_default_dtype(),
            default_initializer=nn.initializer.Assign(data),
        )

        data = 0.1 * paddle.rand([2 * embed_size - 3])
        k_ut = paddle.create_parameter(
            shape=data.shape,
            dtype=paddle.get_default_dtype(),
            default_initializer=nn.initializer.Assign(data),
        )
        return k_diag, k_ut

    def encoder(self, x: paddle.Tensor):
        x = self._normalize(x)
        g = self.encoder_net(x)
        return g

    def decoder(self, g: paddle.Tensor):
        out = self.decoder_net(g)
        x = self._unnormalize(out)
        return x

    def koopman_operation(self, embed_data: paddle.Tensor, k_matrix: paddle.Tensor):
        # Apply the Koopman matrix to every time step: embed_data is rank-3 and
        # its last axis is multiplied by k_matrix (via transpose + bmm).
        embed_pred_data = paddle.bmm(
            k_matrix.expand(
                [embed_data.shape[0], k_matrix.shape[0], k_matrix.shape[1]]
            ),
            embed_data.transpose([0, 2, 1]),
        ).transpose([0, 2, 1])
        return embed_pred_data

    def _normalize(self, x: paddle.Tensor):
        return (x - self.mean) / self.std

    def _unnormalize(self, x: paddle.Tensor):
        return self.std * x + self.mean

    def get_koopman_matrix(self):
        # Build K = diag(k_diag) + U - U^T, where U holds k_ut on the first two
        # super-diagonals (so the off-diagonal part is skew-symmetric).
        k_ut_tensor = self.k_ut * 1
        k_ut_tensor = paddle.diag(
            k_ut_tensor[0 : self.embed_size - 1], offset=1
        ) + paddle.diag(k_ut_tensor[self.embed_size - 1 :], offset=2)
        k_matrix = k_ut_tensor + (-1) * k_ut_tensor.t()
        k_matrix = k_matrix + paddle.diag(self.k_diag)
        return k_matrix

    def forward_tensor(self, x):
        k_matrix = self.get_koopman_matrix()
        embed_data = self.encoder(x)
        recover_data = self.decoder(embed_data)

        embed_pred_data = self.koopman_operation(embed_data, k_matrix)
        pred_data = self.decoder(embed_pred_data)

        # The last predicted step has no target, so it is dropped.
        return (pred_data[:, :-1, :], recover_data, k_matrix)

    @staticmethod
    def split_to_dict(data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]):
        return {key: data_tensors[i] for i, key in enumerate(keys)}

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        x_tensor = self.concat_to_tensor(x, self.input_keys, axis=-1)
        y = self.forward_tensor(x_tensor)
        y = self.split_to_dict(y, self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y

RosslerEmbedding

Bases: LorenzEmbedding

Embedding Koopman model for the Rossler ODE system.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Input keys, such as ("states",).

required
output_keys Tuple[str, ...]

Output keys, such as ("pred_states", "recover_states").

required
mean Optional[Tuple[float, ...]]

Mean of training dataset. Defaults to None.

None
std Optional[Tuple[float, ...]]

Standard Deviation of training dataset. Defaults to None.

None
input_size int

Size of input data. Defaults to 3.

3
hidden_size int

Number of hidden size. Defaults to 500.

500
embed_size int

Number of embedding size. Defaults to 32.

32
drop float

Probability of dropout the units. Defaults to 0.0.

0.0

Examples:

>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.RosslerEmbedding(
...     input_keys=("x", "y"),
...     output_keys=("u", "v"),
...     input_size=3,
...     hidden_size=500,
...     embed_size=32,
...     drop=0.0,
...     mean=None,
...     std=None,
... )
>>> x_shape = [8, 3, 2]
>>> y_shape = [8, 3, 1]
>>> input_dict = {"x": paddle.rand(x_shape),
...               "y": paddle.rand(y_shape)}
>>> output_dict = model(input_dict)
>>> print(output_dict["u"].shape)
[8, 2, 3]
>>> print(output_dict["v"].shape)
[8, 3, 3]
Source code in ppsci/arch/embedding_koopman.py
class RosslerEmbedding(LorenzEmbedding):
    """Embedding Koopman model for the Rossler ODE system.

    This model reuses the ``LorenzEmbedding`` architecture as-is; it only
    forwards its constructor arguments (in the same order) to
    ``LorenzEmbedding.__init__``.

    Args:
        input_keys (Tuple[str, ...]): Input keys, such as ("states",).
        output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states").
        mean (Optional[Tuple[float, ...]]): Mean of training dataset. Defaults to None.
        std (Optional[Tuple[float, ...]]): Standard Deviation of training dataset. Defaults to None.
        input_size (int, optional): Size of input data. Defaults to 3.
        hidden_size (int, optional): Hidden layer size. Defaults to 500.
        embed_size (int, optional): Embedding size. Defaults to 32.
        drop (float, optional): Dropout probability. Defaults to 0.0.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.RosslerEmbedding(
        ...     input_keys=("x", "y"),
        ...     output_keys=("u", "v"),
        ...     input_size=3,
        ...     hidden_size=500,
        ...     embed_size=32,
        ...     drop=0.0,
        ...     mean=None,
        ...     std=None,
        ... )
        >>> x_shape = [8, 3, 2]
        >>> y_shape = [8, 3, 1]
        >>> input_dict = {"x": paddle.rand(x_shape),
        ...               "y": paddle.rand(y_shape)}
        >>> output_dict = model(input_dict)
        >>> print(output_dict["u"].shape)
        [8, 2, 3]
        >>> print(output_dict["v"].shape)
        [8, 3, 3]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        input_size: int = 3,
        hidden_size: int = 500,
        embed_size: int = 32,
        drop: float = 0.0,
    ):
        # Arguments are passed positionally, matching LorenzEmbedding's order.
        super().__init__(
            input_keys,
            output_keys,
            mean,
            std,
            input_size,
            hidden_size,
            embed_size,
            drop,
        )

CylinderEmbedding

Bases: Arch

Embedding Koopman model for the Cylinder system.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Input keys, such as ("states", "visc").

required
output_keys Tuple[str, ...]

Output keys, such as ("pred_states", "recover_states").

required
mean Optional[Tuple[float, ...]]

Mean of training dataset. Defaults to None.

None
std Optional[Tuple[float, ...]]

Standard Deviation of training dataset. Defaults to None.

None
embed_size int

Embedding size. Defaults to 128.

128
encoder_channels Optional[Tuple[int, ...]]

Number of channels in encoder network. Defaults to None.

None
decoder_channels Optional[Tuple[int, ...]]

Number of channels in decoder network. Defaults to None.

None
drop float

Dropout probability. Defaults to 0.0.

0.0

Examples:

>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.CylinderEmbedding(("states", "visc"), ("pred_states", "recover_states"))
>>> states_shape = [32, 10, 3, 64, 128]
>>> visc_shape = [32, 1]
>>> input_dict = {"states" : paddle.rand(states_shape),
...               "visc" : paddle.rand(visc_shape)}
>>> out_dict = model(input_dict)
>>> print(out_dict["pred_states"].shape)
[32, 9, 3, 64, 128]
>>> print(out_dict["recover_states"].shape)
[32, 10, 3, 64, 128]
Source code in ppsci/arch/embedding_koopman.py
class CylinderEmbedding(base.Arch):
    """Embedding Koopman model for the Cylinder system.

    Each frame of ``states`` (3 channels) is stacked with a channel filled with
    its viscosity and normalized with ``mean``/``std``. It is then encoded by
    stride-2 convolutions into ``embed_size // 32`` feature maps. With the
    default encoder and 64x128 frames these maps are 4x8, i.e. a flat embedding
    of ``embed_size`` values. A Koopman matrix predicted from the viscosity
    advances every embedding by one step. The decoder (bilinear upsampling plus
    convolutions) maps embeddings back to 3-channel fields.

    Note:
        ``forward`` unpacks the input dict as keyword arguments into
        ``forward_tensor(states, visc)``. The input dict must therefore contain
        exactly the keys ``"states"`` and ``"visc"``.

    Args:
        input_keys (Tuple[str, ...]): Input keys, such as ("states", "visc").
        output_keys (Tuple[str, ...]): Output keys, such as ("pred_states", "recover_states").
        mean (Optional[Tuple[float, ...]]): Per-channel mean of training dataset
            (4 values: 3 state channels followed by viscosity). Defaults to None (zeros).
        std (Optional[Tuple[float, ...]]): Per-channel standard deviation of training
            dataset (4 values: 3 state channels followed by viscosity). Defaults to None (ones).
        embed_size (int, optional): Embedding size; must be a multiple of 32. Defaults to 128.
        encoder_channels (Optional[Tuple[int, ...]]): Number of channels in encoder network.
            Defaults to None, i.e. [4, 16, 32, 64, 128].
        decoder_channels (Optional[Tuple[int, ...]]): Number of channels in decoder network.
            Defaults to None, i.e. [embed_size // 32, 128, 64, 32, 16].
        drop (float, optional): Dropout probability applied to the encoder output. Defaults to 0.0.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.CylinderEmbedding(("states", "visc"), ("pred_states", "recover_states"))
        >>> states_shape = [32, 10, 3, 64, 128]
        >>> visc_shape = [32, 1]
        >>> input_dict = {"states" : paddle.rand(states_shape),
        ...               "visc" : paddle.rand(visc_shape)}
        >>> out_dict = model(input_dict)
        >>> print(out_dict["pred_states"].shape)
        [32, 9, 3, 64, 128]
        >>> print(out_dict["recover_states"].shape)
        [32, 10, 3, 64, 128]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        embed_size: int = 128,
        encoder_channels: Optional[Tuple[int, ...]] = None,
        decoder_channels: Optional[Tuple[int, ...]] = None,
        drop: float = 0.0,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.embed_size = embed_size

        # Distance of every point on the 64x128 grid from the origin; `decoder`
        # zeroes its output wherever this distance is < 1 (presumably the
        # cylinder body).
        X, Y = np.meshgrid(np.linspace(-2, 14, 128), np.linspace(-4, 4, 64))
        self.mask = paddle.to_tensor(np.sqrt(X**2 + Y**2)).unsqueeze(0).unsqueeze(0)

        encoder_channels = (
            [4, 16, 32, 64, 128] if encoder_channels is None else encoder_channels
        )
        decoder_channels = (
            [embed_size // 32, 128, 64, 32, 16]
            if decoder_channels is None
            else decoder_channels
        )
        self.encoder_net = self.build_encoder(embed_size, encoder_channels, drop)
        self.k_diag_net, self.k_ut_net, self.k_lt_net = self.build_koopman_operator(
            embed_size
        )
        self.decoder_net = self.build_decoder(decoder_channels)

        # Index pairs (xidx[j], yidx[j]) covering the first four off-diagonals
        # (yidx - xidx in {1, 2, 3, 4}). There are 4 * embed_size - 10 pairs,
        # which matches the output width of k_ut_net / k_lt_net.
        xidx = []
        yidx = []
        for i in range(1, 5):
            yidx.append(np.arange(i, embed_size))
            xidx.append(np.arange(0, embed_size - i))
        self.xidx = paddle.to_tensor(np.concatenate(xidx), dtype="int64")
        self.yidx = paddle.to_tensor(np.concatenate(yidx), dtype="int64")

        # Normalization statistics, shaped to broadcast over (N, 4, H, W).
        mean = [0.0, 0.0, 0.0, 0.0] if mean is None else mean
        std = [1.0, 1.0, 1.0, 1.0] if std is None else std
        self.register_buffer("mean", paddle.to_tensor(mean).reshape([1, 4, 1, 1]))
        self.register_buffer("std", paddle.to_tensor(std).reshape([1, 4, 1, 1]))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear/Conv2D with U(-sqrt(k), sqrt(k)), k = 1 / fan_in; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            # Paddle Linear weights are laid out as [in_features, out_features].
            k = 1 / m.weight.shape[0]
            uniform = Uniform(-(k**0.5), k**0.5)
            uniform(m.weight)
            if m.bias is not None:
                uniform(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        elif isinstance(m, nn.Conv2D):
            # Conv2D weights are [out, in, kh, kw]; fan-in is in * kh * kw.
            k = 1 / (m.weight.shape[1] * m.weight.shape[2] * m.weight.shape[3])
            uniform = Uniform(-(k**0.5), k**0.5)
            uniform(m.weight)
            if m.bias is not None:
                uniform(m.bias)

    def _build_conv_relu_list(self, in_channels: int, out_channels: int):
        """Return [stride-2 3x3 conv, ReLU], halving the spatial size."""
        net_list = [
            nn.Conv2D(
                in_channels,
                out_channels,
                kernel_size=(3, 3),
                stride=2,
                padding=1,
                padding_mode="replicate",
            ),
            nn.ReLU(),
        ]
        return net_list

    def build_encoder(
        self, embed_size: int, channels: Tuple[int, ...], drop: float = 0.0
    ):
        """Build the convolutional encoder.

        It stacks one stride-2 conv + ReLU per consecutive pair in ``channels``.
        A final conv projects to ``embed_size // 32`` channels, followed by
        LayerNorm and Dropout.
        """
        net = []
        for i in range(1, len(channels)):
            net.extend(self._build_conv_relu_list(channels[i - 1], channels[i]))
        net.append(
            nn.Conv2D(
                channels[-1],
                embed_size // 32,
                kernel_size=(3, 3),
                padding=1,
                padding_mode="replicate",
            )
        )
        # Normalize over (channels, H, W). The channel count must match the conv
        # above (previously hard-coded to 4, valid only for embed_size=128).
        # The 4x8 spatial size matches the reshape in `decoder`.
        net.append(
            nn.LayerNorm(
                (embed_size // 32, 4, 8),
            )
        )
        net.append(nn.Dropout(drop))
        net = nn.Sequential(*net)
        return net

    def _build_upsample_conv_relu(self, in_channels: int, out_channels: int):
        """Return [2x bilinear upsample, 3x3 conv, ReLU]."""
        net_list = [
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2D(
                in_channels,
                out_channels,
                kernel_size=(3, 3),
                stride=1,
                padding=1,
                padding_mode="replicate",
            ),
            nn.ReLU(),
        ]
        return net_list

    def build_decoder(self, channels: Tuple[int, ...]):
        """Build the decoder: one upsample block per consecutive channel pair, then a conv to 3 channels."""
        net = []
        for i in range(1, len(channels)):
            net.extend(self._build_upsample_conv_relu(channels[i - 1], channels[i]))
        net.append(
            nn.Conv2D(
                channels[-1],
                3,
                kernel_size=(3, 3),
                stride=1,
                padding=1,
                padding_mode="replicate",
            ),
        )
        net = nn.Sequential(*net)
        return net

    def build_koopman_operator(self, embed_size: int):
        """Build MLPs mapping a scalar viscosity to the Koopman matrix entries.

        Returns:
            Tuple of (diagonal net with ``embed_size`` outputs, upper and lower
            off-diagonal nets with ``4 * embed_size - 10`` outputs each).
        """
        # Learned Koopman operator parameters
        k_diag_net = nn.Sequential(
            nn.Linear(1, 50), nn.ReLU(), nn.Linear(50, embed_size)
        )

        k_ut_net = nn.Sequential(
            nn.Linear(1, 50), nn.ReLU(), nn.Linear(50, 4 * embed_size - 10)
        )
        k_lt_net = nn.Sequential(
            nn.Linear(1, 50), nn.ReLU(), nn.Linear(50, 4 * embed_size - 10)
        )
        return k_diag_net, k_ut_net, k_lt_net

    def encoder(self, x: paddle.Tensor, viscosity: paddle.Tensor):
        """Encode states (B, T, C, H, W) with viscosity (B, 1) into embeddings (B, T, -1)."""
        B, T, C, H, W = x.shape
        x = x.reshape((B * T, C, H, W))
        # One viscosity value per frame, broadcast into an extra input channel.
        viscosity = viscosity.repeat_interleave(T, axis=1).reshape((B * T, 1))
        x = paddle.concat(
            [x, viscosity.unsqueeze(-1).unsqueeze(-1) * paddle.ones_like(x[:, :1])],
            axis=1,
        )
        x = self._normalize(x)
        g = self.encoder_net(x)
        g = g.reshape([B, T, -1])
        return g

    def decoder(self, g: paddle.Tensor):
        """Decode embeddings (B, T, embed_size) into masked, unnormalized fields (B, T, C, H, W)."""
        B, T, _ = g.shape
        x = self.decoder_net(g.reshape([-1, self.embed_size // 32, 4, 8]))
        x = self._unnormalize(x)
        mask0 = (
            self.mask.repeat_interleave(x.shape[1], axis=1).repeat_interleave(
                x.shape[0], axis=0
            )
            < 1
        )
        x[mask0] = 0
        _, C, H, W = x.shape
        x = x.reshape([B, T, C, H, W])
        return x

    def get_koopman_matrix(self, g: paddle.Tensor, visc: paddle.Tensor):
        """Assemble one (embed_size, embed_size) Koopman matrix per batch entry from ``visc``."""
        # Koopman operator
        kMatrix = paddle.zeros([g.shape[0], self.embed_size, self.embed_size])
        kMatrix.stop_gradient = False
        # Populate the off diagonal terms
        kMatrixUT_data = self.k_ut_net(100 * visc)
        kMatrixLT_data = self.k_lt_net(100 * visc)

        # Move the batch axis last so indexing on the first two axes writes
        # a whole batch column at once.
        kMatrix = kMatrix.transpose([1, 2, 0])
        kMatrixUT_data_t = kMatrixUT_data.transpose([1, 0])
        kMatrixLT_data_t = kMatrixLT_data.transpose([1, 0])
        kMatrix[self.xidx, self.yidx] = kMatrixUT_data_t
        kMatrix[self.yidx, self.xidx] = kMatrixLT_data_t

        # Populate the diagonal
        ind = np.diag_indices(kMatrix.shape[1])
        ind = paddle.to_tensor(ind, dtype="int64")

        kMatrixDiag = self.k_diag_net(100 * visc)
        kMatrixDiag_t = kMatrixDiag.transpose([1, 0])
        kMatrix[ind[0], ind[1]] = kMatrixDiag_t
        return kMatrix.transpose([2, 0, 1])

    def koopman_operation(self, embed_data: paddle.Tensor, k_matrix: paddle.Tensor):
        """Apply each batch's Koopman matrix to all of its time steps: (B, T, E) -> (B, T, E)."""
        embed_pred_data = paddle.bmm(
            k_matrix, embed_data.transpose([0, 2, 1])
        ).transpose([0, 2, 1])
        return embed_pred_data

    def _normalize(self, x: paddle.Tensor):
        """Normalize all 4 input channels (states + viscosity)."""
        x = (x - self.mean) / self.std
        return x

    def _unnormalize(self, x: paddle.Tensor):
        """Undo normalization for the 3 state channels only."""
        return self.std[:, :3] * x + self.mean[:, :3]

    def forward_tensor(self, states, visc):
        """Return (one-step predictions for frames 1..T-1, reconstructions, Koopman matrices)."""
        # states.shape=(B, T, C, H, W)
        embed_data = self.encoder(states, visc)
        recover_data = self.decoder(embed_data)

        k_matrix = self.get_koopman_matrix(embed_data, visc)
        embed_pred_data = self.koopman_operation(embed_data, k_matrix)
        pred_data = self.decoder(embed_pred_data)

        # Drop the prediction made from the last frame (presumably it has no
        # target frame within the sequence).
        return (pred_data[:, :-1], recover_data, k_matrix)

    @staticmethod
    def split_to_dict(data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]):
        """Name the leading tensors positionally by ``keys``."""
        return {key: data_tensors[i] for i, key in enumerate(keys)}

    def forward(self, x):
        """Run the model on a dict with keys ``"states"`` and ``"visc"``."""
        if self._input_transform is not None:
            x = self._input_transform(x)

        # The dict is unpacked as keyword arguments, so its keys must be
        # exactly `states` and `visc` (input_keys is not used for lookup here).
        y = self.forward_tensor(**x)
        y = self.split_to_dict(y, self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y

Generator

Bases: Arch

Generator Net of GAN. Note: the network uses a variant of ResBlock that is specific to the "tempoGAN" example rather than a standard open-source network.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Name of input keys, such as ("input1", "input2").

required
output_keys Tuple[str, ...]

Name of output keys, such as ("output1", "output2").

required
in_channel int

Number of input channels of the first conv layer.

required
out_channels_tuple Tuple[Tuple[int, ...], ...]

Number of output channels of all conv layers, such as [[out_res0_conv0, out_res0_conv1], [out_res1_conv0, out_res1_conv1]]

required
kernel_sizes_tuple Tuple[Tuple[int, ...], ...]

Kernel sizes of all conv layers, such as [[kernel_size_res0_conv0, kernel_size_res0_conv1], [kernel_size_res1_conv0, kernel_size_res1_conv1]]

required
strides_tuple Tuple[Tuple[int, ...], ...]

Strides of all conv layers, such as [[stride_res0_conv0, stride_res0_conv1], [stride_res1_conv0, stride_res1_conv1]]

required
use_bns_tuple Tuple[Tuple[bool, ...], ...]

Whether to use the batch_norm layer after each conv layer.

required
acts_tuple Tuple[Tuple[str, ...], ...]

Whether to use an activation layer after each conv layer and, if so, which activation to use, such as [[act_res0_conv0, act_res0_conv1], [act_res1_conv0, act_res1_conv1]]

required

Examples:

>>> import paddle
>>> import ppsci
>>> in_channel = 1
>>> rb_channel0 = (2, 8, 8)
>>> rb_channel1 = (128, 128, 128)
>>> rb_channel2 = (32, 8, 8)
>>> rb_channel3 = (2, 1, 1)
>>> out_channels_tuple = (rb_channel0, rb_channel1, rb_channel2, rb_channel3)
>>> kernel_sizes_tuple = (((5, 5), ) * 2 + ((1, 1), ), ) * 4
>>> strides_tuple = ((1, 1, 1), ) * 4
>>> use_bns_tuple = ((True, True, True), ) * 3 + ((False, False, False), )
>>> acts_tuple = (("relu", None, None), ) * 4
>>> model = ppsci.arch.Generator(("in",), ("out",), in_channel, out_channels_tuple, kernel_sizes_tuple, strides_tuple, use_bns_tuple, acts_tuple)
>>> batch_size = 4
>>> height = 64
>>> width = 64
>>> input_data = paddle.randn([batch_size, in_channel, height, width])
>>> input_dict = {'in': input_data}
>>> output_data = model(input_dict)
>>> print(output_data['out'].shape)
[4, 1, 64, 64]
Source code in ppsci/arch/gan.py
class Generator(base.Arch):
    """Generator Net of GAN.

    Note: the network uses a variant of ResBlock that is specific to the
    "tempoGAN" example rather than a standard open-source network. It is a
    chain of ``VariantResBlock`` modules. Block ``i`` is configured from the
    ``i``-th entry of each ``*_tuple`` argument. Its input channel count is
    ``in_channel`` for the first block and the last output channel count of
    the previous block otherwise.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
        in_channel (int): Number of input channels of the first conv layer.
        out_channels_tuple (Tuple[Tuple[int, ...], ...]): Number of output channels of all conv layers,
            such as [[out_res0_conv0, out_res0_conv1], [out_res1_conv0, out_res1_conv1]]
        kernel_sizes_tuple (Tuple[Tuple[int, ...], ...]): Kernel sizes of all conv layers,
            such as [[kernel_size_res0_conv0, kernel_size_res0_conv1], [kernel_size_res1_conv0, kernel_size_res1_conv1]]
        strides_tuple (Tuple[Tuple[int, ...], ...]): Strides of all conv layers,
            such as [[stride_res0_conv0, stride_res0_conv1], [stride_res1_conv0, stride_res1_conv1]]
        use_bns_tuple (Tuple[Tuple[bool, ...], ...]): Whether to use the batch_norm layer after each conv layer.
        acts_tuple (Tuple[Tuple[str, ...], ...]): Whether to use an activation layer after each conv layer
            and, if so, which activation to use,
            such as [[act_res0_conv0, act_res0_conv1], [act_res1_conv0, act_res1_conv1]]

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> in_channel = 1
        >>> rb_channel0 = (2, 8, 8)
        >>> rb_channel1 = (128, 128, 128)
        >>> rb_channel2 = (32, 8, 8)
        >>> rb_channel3 = (2, 1, 1)
        >>> out_channels_tuple = (rb_channel0, rb_channel1, rb_channel2, rb_channel3)
        >>> kernel_sizes_tuple = (((5, 5), ) * 2 + ((1, 1), ), ) * 4
        >>> strides_tuple = ((1, 1, 1), ) * 4
        >>> use_bns_tuple = ((True, True, True), ) * 3 + ((False, False, False), )
        >>> acts_tuple = (("relu", None, None), ) * 4
        >>> model = ppsci.arch.Generator(("in",), ("out",), in_channel, out_channels_tuple, kernel_sizes_tuple, strides_tuple, use_bns_tuple, acts_tuple)
        >>> batch_size = 4
        >>> height = 64
        >>> width = 64
        >>> input_data = paddle.randn([batch_size, in_channel, height, width])
        >>> input_dict = {'in': input_data}
        >>> output_data = model(input_dict)
        >>> print(output_data['out'].shape)
        [4, 1, 64, 64]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        in_channel: int,
        out_channels_tuple: Tuple[Tuple[int, ...], ...],
        kernel_sizes_tuple: Tuple[Tuple[int, ...], ...],
        strides_tuple: Tuple[Tuple[int, ...], ...],
        use_bns_tuple: Tuple[Tuple[bool, ...], ...],
        acts_tuple: Tuple[Tuple[str, ...], ...],
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.in_channel = in_channel
        self.out_channels_tuple = out_channels_tuple
        self.kernel_sizes_tuple = kernel_sizes_tuple
        self.strides_tuple = strides_tuple
        self.use_bns_tuple = use_bns_tuple
        self.acts_tuple = acts_tuple

        self.init_blocks()

    def init_blocks(self):
        """Build one VariantResBlock per entry of ``out_channels_tuple`` into ``self.blocks``."""
        blocks_list = []
        for i, out_channels in enumerate(self.out_channels_tuple):
            # Each block consumes the last conv output of the previous block.
            in_channel = (
                self.in_channel if i == 0 else self.out_channels_tuple[i - 1][-1]
            )
            blocks_list.append(
                VariantResBlock(
                    in_channel=in_channel,
                    out_channels=out_channels,
                    kernel_sizes=self.kernel_sizes_tuple[i],
                    strides=self.strides_tuple[i],
                    use_bns=self.use_bns_tuple[i],
                    acts=self.acts_tuple[i],
                    mean=0.0,
                    std=0.04,
                    value=0.1,
                )
            )
        self.blocks = nn.LayerList(blocks_list)

    def forward_tensor(self, x):
        """Pass ``x`` sequentially through every block."""
        y = x
        for block in self.blocks:
            y = block(y)
        return y

    def forward(self, x):
        """Run the generator on a dict of inputs and return a dict keyed by ``output_keys``."""
        if self._input_transform is not None:
            x = self._input_transform(x)

        # NOTE(review): inputs are concatenated along the last axis. For NCHW
        # tensors with several input keys this joins them along width, not
        # channels. Confirm this is intended (a single key is unaffected).
        y = self.concat_to_tensor(x, self.input_keys, axis=-1)
        y = self.forward_tensor(y)
        y = self.split_to_dict(y, self.output_keys, axis=-1)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y

Discriminator

Bases: Arch

Discriminator Net of GAN.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Name of input keys, such as ("input1", "input2").

required
output_keys Tuple[str, ...]

Name of output keys, such as ("output1", "output2").

required
in_channel int

Number of input channels of the first conv layer.

required
out_channels Tuple[int, ...]

Number of output channels of all conv layers, such as (out_conv0, out_conv1, out_conv2).

required
fc_channel int

Number of input features of the final linear layer. Its number of output features is fixed to 1 in this Net, forming a fully-connected output layer.

required
kernel_sizes Tuple[int, ...]

Kernel sizes of all conv layers, such as (kernel_size_conv0, kernel_size_conv1, kernel_size_conv2).

required
strides Tuple[int, ...]

Strides of all conv layers, such as (stride_conv0, stride_conv1, stride_conv2).

required
use_bns Tuple[bool, ...]

Whether to use the batch_norm layer after each conv layer.

required
acts Tuple[str, ...]

Whether to use an activation layer after each conv layer and, if so, which activation to use, such as (act_conv0, act_conv1, act_conv2).

required

Examples:

>>> import paddle
>>> import ppsci
>>> in_channel = 2
>>> in_channel_tempo = 3
>>> out_channels = (32, 64, 128, 256)
>>> fc_channel = 65536
>>> kernel_sizes = ((4, 4), (4, 4), (4, 4), (4, 4))
>>> strides = (2, 2, 2, 1)
>>> use_bns = (False, True, True, True)
>>> acts = ("leaky_relu", "leaky_relu", "leaky_relu", "leaky_relu", None)
>>> output_keys_disc = ("out_1", "out_2", "out_3", "out_4", "out_5", "out_6", "out_7", "out_8", "out_9", "out_10")
>>> model = ppsci.arch.Discriminator(("in_1","in_2"), output_keys_disc, in_channel, out_channels, fc_channel, kernel_sizes, strides, use_bns, acts)
>>> input_data = [paddle.to_tensor(paddle.randn([1, in_channel, 128, 128])),paddle.to_tensor(paddle.randn([1, in_channel, 128, 128]))]
>>> input_dict = {"in_1": input_data[0],"in_2": input_data[1]}
>>> out_dict = model(input_dict)
>>> for k, v in out_dict.items():
...     print(k, v.shape)
out_1 [1, 32, 64, 64]
out_2 [1, 64, 32, 32]
out_3 [1, 128, 16, 16]
out_4 [1, 256, 16, 16]
out_5 [1, 1]
out_6 [1, 32, 64, 64]
out_7 [1, 64, 32, 32]
out_8 [1, 128, 16, 16]
out_9 [1, 256, 16, 16]
out_10 [1, 1]
Source code in ppsci/arch/gan.py
class Discriminator(base.Arch):
    """Discriminator Net of GAN.

    The network is a stack of ``Conv2DBlock`` layers followed by one ``FCBlock``.
    Every intermediate output is returned, not just the final score.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input1", "input2").
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output1", "output2").
        in_channel (int): Number of input channels of the first conv layer.
        out_channels (Tuple[int, ...]): Number of output channels of all conv layers,
            such as (out_conv0, out_conv1, out_conv2).
        fc_channel (int): Number of input features of the final linear layer (e.g.
            65536 = 256 * 16 * 16 in the example below). Its number of output features
            is fixed to 1 in this Net, forming a fully-connected output layer.
        kernel_sizes (Tuple[int, ...]): Kernel sizes of all conv layers,
            such as (kernel_size_conv0, kernel_size_conv1, kernel_size_conv2).
        strides (Tuple[int, ...]): Strides of all conv layers,
            such as (stride_conv0, stride_conv1, stride_conv2).
        use_bns (Tuple[bool, ...]): Whether to use the batch_norm layer after each conv layer.
        acts (Tuple[str, ...]): Whether to use an activation layer after each conv layer and,
            if so, which activation to use, such as (act_conv0, act_conv1, act_conv2).
            It must have one more entry than ``out_channels``; the extra last entry is the
            activation of the final linear layer.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> in_channel = 2
        >>> in_channel_tempo = 3
        >>> out_channels = (32, 64, 128, 256)
        >>> fc_channel = 65536
        >>> kernel_sizes = ((4, 4), (4, 4), (4, 4), (4, 4))
        >>> strides = (2, 2, 2, 1)
        >>> use_bns = (False, True, True, True)
        >>> acts = ("leaky_relu", "leaky_relu", "leaky_relu", "leaky_relu", None)
        >>> output_keys_disc = ("out_1", "out_2", "out_3", "out_4", "out_5", "out_6", "out_7", "out_8", "out_9", "out_10")
        >>> model = ppsci.arch.Discriminator(("in_1","in_2"), output_keys_disc, in_channel, out_channels, fc_channel, kernel_sizes, strides, use_bns, acts)
        >>> input_data = [paddle.to_tensor(paddle.randn([1, in_channel, 128, 128])),paddle.to_tensor(paddle.randn([1, in_channel, 128, 128]))]
        >>> input_dict = {"in_1": input_data[0],"in_2": input_data[1]}
        >>> out_dict = model(input_dict)
        >>> for k, v in out_dict.items():
        ...     print(k, v.shape)
        out_1 [1, 32, 64, 64]
        out_2 [1, 64, 32, 32]
        out_3 [1, 128, 16, 16]
        out_4 [1, 256, 16, 16]
        out_5 [1, 1]
        out_6 [1, 32, 64, 64]
        out_7 [1, 64, 32, 32]
        out_8 [1, 128, 16, 16]
        out_9 [1, 256, 16, 16]
        out_10 [1, 1]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        in_channel: int,
        out_channels: Tuple[int, ...],
        fc_channel: int,
        kernel_sizes: Tuple[int, ...],
        strides: Tuple[int, ...],
        use_bns: Tuple[bool, ...],
        acts: Tuple[str, ...],
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.in_channel = in_channel
        self.out_channels = out_channels
        self.fc_channel = fc_channel
        self.kernel_sizes = kernel_sizes
        self.strides = strides
        self.use_bns = use_bns
        self.acts = acts

        self.init_layers()

    def init_layers(self):
        """Build the conv blocks followed by the final FC block into ``self.layers``."""
        layers_list = []
        for i, out_channel in enumerate(self.out_channels):
            in_channel = self.in_channel if i == 0 else self.out_channels[i - 1]
            layers_list.append(
                Conv2DBlock(
                    in_channel=in_channel,
                    out_channel=out_channel,
                    kernel_size=self.kernel_sizes[i],
                    stride=self.strides[i],
                    use_bn=self.use_bns[i],
                    act=self.acts[i],
                    mean=0.0,
                    std=0.04,
                    value=0.1,
                )
            )

        # The FC activation is the entry right after the conv activations.
        # It was previously hard-coded to index 4, which only works with
        # exactly 4 conv layers.
        fc_act = self.acts[len(self.out_channels)]
        layers_list.append(
            FCBlock(self.fc_channel, fc_act, mean=0.0, std=0.04, value=0.1)
        )
        self.layers = nn.LayerList(layers_list)

    def forward_tensor(self, x):
        """Return the outputs of every layer, in order: conv outputs, then the FC output."""
        y = x
        y_list = []
        for layer in self.layers:
            y = layer(y)
            y_list.append(y)
        return y_list  # y_conv1, y_conv2, y_conv3, y_conv4, y_fc(y_out)

    def forward(self, x):
        """Run every tensor in ``x`` through the network and name all outputs by ``output_keys``."""
        if self._input_transform is not None:
            x = self._input_transform(x)

        y_list = []
        # Every entry of `x` (not only `input_keys`) is processed, in dict
        # iteration order. Outputs are collected as
        # y1_conv1, y1_conv2, y1_conv3, y1_conv4, y1_fc, y2_conv1, y2_conv2, y2_conv3, y2_conv4, y2_fc
        for k in x:
            y_list.extend(self.forward_tensor(x[k]))

        y = self.split_to_dict(y_list, self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)

        return y

    @staticmethod
    def split_to_dict(
        data_list: List[paddle.Tensor], keys: Tuple[str, ...]
    ) -> Dict[str, paddle.Tensor]:
        """Override of the split_to_dict() method of base.Arch.

        It is overridden because the "tempoGAN" example does not call
        concat_to_tensor(): its input is not in the regular format, but looks like:
        {
            "input1": paddle.concat([in1, in2], axis=1),
            "input2": paddle.concat([in1, in3], axis=1),
        }

        Args:
            data_list (List[paddle.Tensor]): The data to be split. It should be a list of tensor(s), but not a paddle.Tensor.
            keys (Tuple[str, ...]): Keys of outputs.

        Returns:
            Dict[str, paddle.Tensor]: Dict with split data.
        """
        # Positional pairing; this also covers the single-key case.
        return {key: data_list[i] for i, key in enumerate(keys)}
split_to_dict(data_list, keys) staticmethod

Override of the split_to_dict() method of base.Arch.

It is overridden because the "tempoGAN" example does not call concat_to_tensor(): its input is not in the regular format, but looks like: { "input1": paddle.concat([in1, in2], axis=1), "input2": paddle.concat([in1, in3], axis=1), }

Parameters:

Name Type Description Default
data_list List[Tensor]

The data to be split. It should be a list of tensor(s), but not a paddle.Tensor.

required
keys Tuple[str, ...]

Keys of outputs.

required

Returns:

Type Description
Dict[str, Tensor]

Dict[str, paddle.Tensor]: Dict with split data.

Source code in ppsci/arch/gan.py
@staticmethod
def split_to_dict(
    data_list: List[paddle.Tensor], keys: Tuple[str, ...]
) -> Dict[str, paddle.Tensor]:
    """Override of the split_to_dict() method of base.Arch.

    The "tempoGAN" example never calls concat_to_tensor(), because its input
    is not in the regular format but looks like:
    {
        "input1": paddle.concat([in1, in2], axis=1),
        "input2": paddle.concat([in1, in3], axis=1),
    }
    Outputs therefore arrive as a list and are named positionally.

    Args:
        data_list (List[paddle.Tensor]): The data to be split. It should be a list of tensor(s), but not a paddle.Tensor.
        keys (Tuple[str, ...]): Keys of outputs.

    Returns:
        Dict[str, paddle.Tensor]: Dict with split data.
    """
    # Pair keys[i] with data_list[i]; a single key yields {keys[0]: data_list[0]}.
    named = {}
    for position, name in enumerate(keys):
        named[name] = data_list[position]
    return named

PhysformerGPT2

Bases: Arch

Transformer decoder model for modeling physics.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Input keys, such as ("embeds",).

required
output_keys Tuple[str, ...]

Output keys, such as ("pred_embeds",).

required
num_layers int

Number of transformer layers.

required
num_ctx int

Context length of block.

required
embed_size int

The number of embedding size.

required
num_heads int

The number of heads in multi-head attention.

required
embd_pdrop float

The dropout probability used on embedding features. Defaults to 0.0.

0.0
attn_pdrop float

The dropout probability used on attention weights. Defaults to 0.0.

0.0
resid_pdrop float

The dropout probability used on block outputs. Defaults to 0.0.

0.0
initializer_range float

Initializer range of linear layer. Defaults to 0.05.

0.05
embedding_model Optional[Arch]

Embedding model. If this parameter is set, the embedding model will map the input data to the embedding space and the output data to the physical space. Defaults to None.

None

Examples:

>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.PhysformerGPT2(("embeds", ), ("pred_embeds", ), 6, 16, 128, 4)
>>> data = paddle.to_tensor(paddle.randn([10, 16, 128]))
>>> inputs = {"embeds": data}
>>> outputs = model(inputs)
>>> print(outputs["pred_embeds"].shape)
[10, 16, 128]
Source code in ppsci/arch/physx_transformer.py
class PhysformerGPT2(base.Arch):
    """Transformer decoder model for modeling physics.

    In training mode the whole input sequence is processed in a single pass by
    ``forward_tensor``. In eval mode ``forward_eval`` seeds an autoregressive
    rollout with the first time step of the input (see ``forward_eval``).

    Args:
        input_keys (Tuple[str, ...]): Input keys, such as ("embeds",).
        output_keys (Tuple[str, ...]): Output keys, such as ("pred_embeds",).
        num_layers (int): Number of transformer layers.
        num_ctx (int): Context length of block.
        embed_size (int): The embedding size.
        num_heads (int): The number of heads in multi-head attention.
        embd_pdrop (float, optional): The dropout probability used on embedding features. Defaults to 0.0.
        attn_pdrop (float, optional): The dropout probability used on attention weights. Defaults to 0.0.
        resid_pdrop (float, optional): The dropout probability used on block outputs. Defaults to 0.0.
        initializer_range (float, optional): Initializer range of linear layer. Defaults to 0.05.
        embedding_model (Optional[base.Arch]): Embedding model. If this parameter is set,
            the embedding model will map the input data to the embedding space and the
            output data to the physical space. Defaults to None.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.PhysformerGPT2(("embeds", ), ("pred_embeds", ), 6, 16, 128, 4)
        >>> data = paddle.to_tensor(paddle.randn([10, 16, 128]))
        >>> inputs = {"embeds": data}
        >>> outputs = model(inputs)
        >>> print(outputs["pred_embeds"].shape)
        [10, 16, 128]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        num_layers: int,
        num_ctx: int,
        embed_size: int,
        num_heads: int,
        embd_pdrop: float = 0.0,
        attn_pdrop: float = 0.0,
        resid_pdrop: float = 0.0,
        initializer_range: float = 0.05,
        embedding_model: Optional[base.Arch] = None,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys

        self.num_layers = num_layers
        self.num_ctx = num_ctx
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.resid_pdrop = resid_pdrop
        self.initializer_range = initializer_range

        self.drop = nn.Dropout(embd_pdrop)
        self.blocks = nn.LayerList(
            [
                Block(
                    num_ctx, embed_size, num_heads, attn_pdrop, resid_pdrop, scale=True
                )
                for _ in range(num_layers)
            ]
        )
        self.ln = nn.LayerNorm(embed_size)
        self.linear = nn.Linear(embed_size, embed_size)

        # Initialize the layers created above. This deliberately runs BEFORE
        # ``embedding_model`` is attached, so a user-supplied embedding model's
        # weights are not touched by ``_init_weights``.
        self.apply(self._init_weights)
        self.embedding_model = embedding_model

    def _init_weights(self, module):
        """Initialize one sublayer (used via ``self.apply``).

        Linear weights are drawn from Normal(0, ``initializer_range``) with zero
        bias; LayerNorm weights are set to one and biases to zero. Other layer
        types are left unchanged.
        """
        if isinstance(module, nn.Linear):
            normal_ = Normal(mean=0.0, std=self.initializer_range)
            normal_(module.weight)
            if module.bias is not None:
                zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            zeros_(module.bias)
            ones_(module.weight)

    def get_position_embed(self, x):
        """Build sinusoidal position embeddings shaped like ``x``.

        For position ``pos`` and channel pair ``k``, channel ``2k`` holds
        ``sin(pos / 10000 ** (2k / embed_size))`` and channel ``2k + 1`` holds the
        matching ``cos``.

        Args:
            x (paddle.Tensor): Rank-3 tensor ``(B, N, C)``; presumably
                ``C == embed_size`` and even (TODO confirm with callers).

        Returns:
            paddle.Tensor: Position embeddings with the same shape as ``x``.
        """
        B, N, _ = x.shape
        # Positions 0..N-1 as a (1, N, 1) column, repeated over the batch.
        position_ids = paddle.arange(0, N, dtype=paddle.get_default_dtype()).reshape(
            [1, N, 1]
        )
        position_ids = position_ids.repeat_interleave(B, axis=0)

        position_embeds = paddle.zeros_like(x)
        # Channel-pair index k, broadcast as (1, 1, embed_size // 2).
        # NOTE(review): ``i`` is an integer tensor; confirm that
        # ``2 * i / self.embed_size`` is true (float) division on the Paddle
        # version in use.
        i = paddle.arange(0, self.embed_size // 2).unsqueeze(0).unsqueeze(0)
        position_embeds[:, :, ::2] = paddle.sin(
            position_ids / 10000 ** (2 * i / self.embed_size)
        )
        position_embeds[:, :, 1::2] = paddle.cos(
            position_ids / 10000 ** (2 * i / self.embed_size)
        )
        return position_embeds

    def _generate_time_series(self, x, max_length):
        """Autoregressively extend ``x`` along axis 1 until it has ``max_length`` steps.

        Raises:
            ValueError: If ``x`` already has ``max_length`` or more steps.
        """
        cur_len = x.shape[1]
        if cur_len >= max_length:
            raise ValueError(
                f"max_length({max_length}) should be larger than "
                f"the length of input context({cur_len})"
            )

        while cur_len < max_length:
            # NOTE(review): only the most recent step is fed to the model and no
            # attention cache is kept, so each prediction sees a single-step
            # context. Confirm this is intended.
            model_inputs = x[:, -1:]
            outputs = self.forward_tensor(model_inputs)
            next_output = outputs[0][:, -1:]
            x = paddle.concat([x, next_output], axis=1)
            cur_len = cur_len + 1
        return x

    @paddle.no_grad()
    def generate(self, x, max_length=256):
        """Gradient-free rollout of ``x`` up to ``max_length`` total steps (input included).

        Raises:
            ValueError: If ``max_length`` is not positive, or is not larger than
                the number of steps already in ``x``.
        """
        if max_length <= 0:
            raise ValueError(
                f"max_length({max_length}) should be a strictly positive integer."
            )
        outputs = self._generate_time_series(x, max_length)
        return outputs

    def forward_tensor(self, x):
        """Single transformer pass over the whole sequence ``x``.

        Returns:
            Tuple[paddle.Tensor]: 1-tuple holding ``linear(ln(hidden_states))``.
        """
        position_embeds = self.get_position_embed(x)
        # Combine input embedding, position embedding
        hidden_states = x + position_embeds
        hidden_states = self.drop(hidden_states)

        # Loop through transformer self-attention layers; only the first element
        # of each block's output is carried forward.
        for block in self.blocks:
            block_outputs = block(hidden_states)
            hidden_states = block_outputs[0]
        outputs = self.linear(self.ln(hidden_states))
        return (outputs,)

    def forward_eval(self, x):
        """Eval-mode prediction seeded by the first time step of ``x``.

        Only ``x[:, :1]`` is used. ``generate`` is called with its default
        ``max_length`` (256) and the seed step is dropped. The result therefore
        always has 255 steps along axis 1, independent of ``x``'s length.

        Returns:
            Tuple[paddle.Tensor]: 1-tuple holding the generated steps.
        """
        input_embeds = x[:, :1]
        outputs = self.generate(input_embeds)
        return (outputs[:, 1:],)

    @staticmethod
    def split_to_dict(data_tensors, keys):
        """Map the i-th tensor of ``data_tensors`` to the i-th key of ``keys``."""
        return {key: data_tensors[i] for i, key in enumerate(keys)}

    def forward(self, x):
        """Run the model on an input dict and return an output dict keyed by ``output_keys``.

        If ``embedding_model`` is set, inputs are encoded before the transformer
        and the prediction is decoded afterwards. A ``CylinderEmbedding`` encoder
        additionally receives ``x["visc"]``.
        """
        if self._input_transform is not None:
            x = self._input_transform(x)
        x_tensor = self.concat_to_tensor(x, self.input_keys, axis=-1)
        if self.embedding_model is not None:
            if isinstance(self.embedding_model, CylinderEmbedding):
                x_tensor = self.embedding_model.encoder(x_tensor, x["visc"])
            else:
                x_tensor = self.embedding_model.encoder(x_tensor)

        if self.training:
            y = self.forward_tensor(x_tensor)
        else:
            y = self.forward_eval(x_tensor)

        if self.embedding_model is not None:
            y = (self.embedding_model.decoder(y[0]),)

        y = self.split_to_dict(y, self.output_keys)
        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y

ModelList

Bases: Arch

ModelList layer which wraps more than one model sharing the same inputs.

Parameters:

Name Type Description Default
model_list Tuple[Arch, ...]

Model(s) nested in tuple.

required

Examples:

>>> import paddle
>>> import ppsci
>>> model1 = ppsci.arch.MLP(("x", "y"), ("u", "v"), 10, 128)
>>> model2 = ppsci.arch.MLP(("x", "y"), ("w", "p"), 5, 128)
>>> model = ppsci.arch.ModelList((model1, model2))
>>> input_dict = {"x": paddle.rand([64, 64, 1]),"y": paddle.rand([64, 64, 1])}
>>> output_dict = model(input_dict)
>>> for k, v in output_dict.items():
...     print(k, v.shape)
u [64, 64, 1]
v [64, 64, 1]
w [64, 64, 1]
p [64, 64, 1]
Source code in ppsci/arch/model_list.py
class ModelList(base.Arch):
    """ModelList layer which wraps more than one model that shares inputs.

    Every sub-model receives the same input dict, and their output dicts are
    merged into one. The ``output_keys`` of the sub-models must therefore not
    overlap.

    Args:
        model_list (Tuple[base.Arch, ...]): Model(s) nested in tuple.

    Raises:
        ValueError: If two sub-models declare the same output key.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model1 = ppsci.arch.MLP(("x", "y"), ("u", "v"), 10, 128)
        >>> model2 = ppsci.arch.MLP(("x", "y"), ("w", "p"), 5, 128)
        >>> model = ppsci.arch.ModelList((model1, model2))
        >>> input_dict = {"x": paddle.rand([64, 64, 1]),"y": paddle.rand([64, 64, 1])}
        >>> output_dict = model(input_dict)
        >>> for k, v in output_dict.items():
        ...     print(k, v.shape)
        u [64, 64, 1]
        v [64, 64, 1]
        w [64, 64, 1]
        p [64, 64, 1]
    """

    def __init__(
        self,
        model_list: Tuple[base.Arch, ...],
    ):
        super().__init__()
        # Union of all sub-models' input keys (kept as a set, as before). A
        # flattening comprehension is used instead of ``sum(..., ())`` so that
        # sub-models whose ``input_keys`` is a list (e.g. read from a config
        # file) do not raise ``TypeError``.
        self.input_keys = {key for model in model_list for key in model.input_keys}

        # Output keys are gathered in an insertion-ordered dict rather than a set.
        # This gives ``self.output_keys`` a deterministic order: sub-model order
        # first, then each sub-model's own key order. A set of strings iterates
        # in an order that changes between processes due to hash randomization.
        merged_output_keys = {}
        for model in model_list:
            duplicate_keys = merged_output_keys.keys() & set(model.output_keys)
            if duplicate_keys:
                raise ValueError(
                    "output_keys of model from model_list should be unique, "
                    f"but got duplicate keys: {duplicate_keys}"
                )
            merged_output_keys.update(dict.fromkeys(model.output_keys))
        self.output_keys = tuple(merged_output_keys)

        self.model_list = nn.LayerList(model_list)

    def forward(self, x):
        """Feed ``x`` to every sub-model and merge their output dicts.

        Args:
            x (Dict[str, paddle.Tensor]): Input dict shared by all sub-models.

        Returns:
            Dict[str, paddle.Tensor]: Union of all sub-models' outputs.
        """
        y_all = {}
        for model in self.model_list:
            y_all.update(model(x))
        return y_all

AFNONet

Bases: Arch

Adaptive Fourier Neural Network.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Name of input keys, such as ("input",).

required
output_keys Tuple[str, ...]

Name of output keys, such as ("output",).

required
img_size Tuple[int, ...]

Image size. Defaults to (720, 1440).

(720, 1440)
patch_size Tuple[int, ...]

Patch size. Defaults to (8, 8).

(8, 8)
in_channels int

The input tensor channels. Defaults to 20.

20
out_channels int

The output tensor channels. Defaults to 20.

20
embed_dim int

The embedding dimension for PatchEmbed. Defaults to 768.

768
depth int

Number of transformer depth. Defaults to 12.

12
mlp_ratio float

Number of ratio used in MLP. Defaults to 4.0.

4.0
drop_rate float

The drop ratio used in MLP. Defaults to 0.0.

0.0
drop_path_rate float

The drop ratio used in DropPath. Defaults to 0.0.

0.0
num_blocks int

Number of blocks. Defaults to 8.

8
sparsity_threshold float

The value of threshold for softshrink. Defaults to 0.01.

0.01
hard_thresholding_fraction float

The value of threshold for keep mode. Defaults to 1.0.

1.0
num_timestamps int

Number of timestamps, i.e. autoregressive prediction steps. Defaults to 1.

1

Examples:

>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.AFNONet(("input", ), ("output", ))
>>> input_data = {"input": paddle.randn([1, 20, 720, 1440])}
>>> output_data = model(input_data)
>>> for k, v in output_data.items():
...     print(k, v.shape)
output [1, 20, 720, 1440]
Source code in ppsci/arch/afno.py
class AFNONet(base.Arch):
    """Adaptive Fourier Neural Network.

    The input image is split into patches and embedded, then passed through
    ``depth`` ``Block`` layers on a ``(img_size[0] // patch_size[0],
    img_size[1] // patch_size[1])`` token grid. Finally it is projected back to
    an image with ``out_channels`` channels. With ``num_timestamps > 1`` the
    model is rolled out autoregressively, feeding each prediction back as the
    next input.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output",).
        img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440).
        patch_size (Tuple[int, ...], optional): Patch size. Defaults to (8, 8).
        in_channels (int, optional): The input tensor channels. Defaults to 20.
        out_channels (int, optional): The output tensor channels. Defaults to 20.
        embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768.
        depth (int, optional): Number of transformer depth. Defaults to 12.
        mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0.
        drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0.
        drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0.
        num_blocks (int, optional): Number of blocks. Defaults to 8.
        sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01.
        hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0.
        num_timestamps (int, optional): Number of autoregressive prediction steps. Defaults to 1.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.AFNONet(("input", ), ("output", ))
        >>> input_data = {"input": paddle.randn([1, 20, 720, 1440])}
        >>> output_data = model(input_data)
        >>> for k, v in output_data.items():
        ...     print(k, v.shape)
        output [1, 20, 720, 1440]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        img_size: Tuple[int, ...] = (720, 1440),
        patch_size: Tuple[int, ...] = (8, 8),
        in_channels: int = 20,
        out_channels: int = 20,
        embed_dim: int = 768,
        depth: int = 12,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        num_blocks: int = 8,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
        num_timestamps: int = 1,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        self.num_timestamps = num_timestamps
        norm_layer = partial(nn.LayerNorm, epsilon=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=self.patch_size,
            in_channels=self.in_channels,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        # Learnable positional embedding, one ``embed_dim`` vector per patch,
        # initialized from a truncated normal with std 0.02.
        data = paddle.zeros((1, num_patches, embed_dim))
        data = initializer.trunc_normal_(data, std=0.02)
        self.pos_embed = paddle.create_parameter(
            shape=data.shape,
            dtype=data.dtype,
            default_initializer=nn.initializer.Assign(data),
        )
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Per-block drop-path rates, increasing linearly from 0 to drop_path_rate.
        dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)]

        # Height and width of the patch-token grid.
        self.h = img_size[0] // self.patch_size[0]
        self.w = img_size[1] // self.patch_size[1]

        self.blocks = nn.LayerList(
            [
                Block(
                    dim=embed_dim,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    num_blocks=self.num_blocks,
                    sparsity_threshold=sparsity_threshold,
                    hard_thresholding_fraction=hard_thresholding_fraction,
                )
                for i in range(depth)
            ]
        )

        self.norm = norm_layer(embed_dim)
        # Projects every token to the out_channels * p1 * p2 values of its patch.
        self.head = nn.Linear(
            embed_dim,
            self.out_channels * self.patch_size[0] * self.patch_size[1],
            bias_attr=False,
        )

        # Re-initialize all Linear / LayerNorm / Conv2D sublayers. The order of
        # the statements above determines the random initialization.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one sublayer (used via ``self.apply``).

        Linear: truncated-normal weights (std 0.02), zero bias. LayerNorm: weight
        one, bias zero. Conv2D: ``initializer.conv_init_``.
        """
        if isinstance(m, nn.Linear):
            initializer.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                initializer.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            initializer.ones_(m.weight)
            initializer.zeros_(m.bias)
        elif isinstance(m, nn.Conv2D):
            initializer.conv_init_(m)

    def forward_tensor(self, x):
        """Predict one step: image tensor in, image tensor out.

        Args:
            x (paddle.Tensor): Presumably ``(B, in_channels, H, W)`` with
                ``(H, W) == img_size`` (TODO confirm against ``PatchEmbed``).

        Returns:
            paddle.Tensor: ``(B, c_out, h * p1, w * p2)`` where ``c_out`` is the
            head's output width divided by ``p1 * p2``, i.e. ``out_channels``.
        """
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Lay the tokens out on the (h, w) grid; assumes patch_embed returns
        # h * w tokens of size embed_dim per sample.
        x = x.reshape((B, self.h, self.w, self.embed_dim))
        for block in self.blocks:
            x = block(x)

        x = self.head(x)

        # Un-patchify: (b, h, w, p1*p2*c_out) -> (b, h, w, p1, p2, c_out)
        # -> (b, c_out, h, p1, w, p2) -> (b, c_out, h*p1, w*p2).
        b = x.shape[0]
        p1 = self.patch_size[0]
        p2 = self.patch_size[1]
        h = self.img_size[0] // self.patch_size[0]
        w = self.img_size[1] // self.patch_size[1]
        c_out = x.shape[3] // (p1 * p2)
        x = x.reshape((b, h, w, p1, p2, c_out))
        x = x.transpose((0, 5, 1, 3, 2, 4))
        x = x.reshape((b, c_out, h * p1, w * p2))

        return x

    @staticmethod
    def split_to_dict(data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]):
        """Map the i-th tensor of ``data_tensors`` to the i-th key of ``keys``.

        Extra tensors beyond ``len(keys)`` are dropped; fewer tensors than keys
        raises ``IndexError``.
        """
        return {key: data_tensors[i] for i, key in enumerate(keys)}

    def forward(self, x):
        """Roll the model out for ``num_timestamps`` steps.

        Each step's output is fed back as the next step's input, so for
        ``num_timestamps > 1`` the output presumably needs ``in_channels``
        channels (TODO confirm). The i-th step's output is stored under the
        i-th output key.
        """
        if self._input_transform is not None:
            x = self._input_transform(x)

        x_tensor = self.concat_to_tensor(x, self.input_keys)

        y = []
        input = x_tensor
        for _ in range(self.num_timestamps):
            out = self.forward_tensor(input)
            y.append(out)
            input = out
        y = self.split_to_dict(y, self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y

PrecipNet

Bases: Arch

Precipitation Network.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Name of input keys, such as ("input",).

required
output_keys Tuple[str, ...]

Name of output keys, such as ("output",).

required
wind_model Arch

Wind model.

required
img_size Tuple[int, ...]

Image size. Defaults to (720, 1440).

(720, 1440)
patch_size Tuple[int, ...]

Patch size. Defaults to (8, 8).

(8, 8)
in_channels int

The input tensor channels. Defaults to 20.

20
out_channels int

The output tensor channels. Defaults to 1.

1
embed_dim int

The embedding dimension for PatchEmbed. Defaults to 768.

768
depth int

Number of transformer depth. Defaults to 12.

12
mlp_ratio float

Number of ratio used in MLP. Defaults to 4.0.

4.0
drop_rate float

The drop ratio used in MLP. Defaults to 0.0.

0.0
drop_path_rate float

The drop ratio used in DropPath. Defaults to 0.0.

0.0
num_blocks int

Number of blocks. Defaults to 8.

8
sparsity_threshold float

The value of threshold for softshrink. Defaults to 0.01.

0.01
hard_thresholding_fraction float

The value of threshold for keep mode. Defaults to 1.0.

1.0
num_timestamps int

Number of timestamps, i.e. autoregressive prediction steps. Defaults to 1.

1

Examples:

>>> import paddle
>>> import ppsci
>>> wind_model = ppsci.arch.AFNONet(("input", ), ("output", ))
>>> model = ppsci.arch.PrecipNet(("input", ), ("output", ), wind_model)
>>> data = paddle.randn([1, 20, 720, 1440])
>>> data_dict = {"input": data}
>>> output = model.forward(data_dict)
>>> print(output['output'].shape)
[1, 1, 720, 1440]
Source code in ppsci/arch/afno.py
class PrecipNet(base.Arch):
    """Precipitation Network.

    A frozen ``wind_model`` rolls the input state forward step by step without
    gradients. At every step an ``AFNONet`` backbone, followed by a periodic pad,
    a 3x3 convolution and a ReLU, maps the predicted state to an
    ``out_channels``-channel output.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output",).
        wind_model (base.Arch): Wind model.
        img_size (Tuple[int, ...], optional): Image size. Defaults to (720, 1440).
        patch_size (Tuple[int, ...], optional): Patch size. Defaults to (8, 8).
        in_channels (int, optional): The input tensor channels. Defaults to 20.
        out_channels (int, optional): The output tensor channels. Defaults to 1.
        embed_dim (int, optional): The embedding dimension for PatchEmbed. Defaults to 768.
        depth (int, optional): Number of transformer depth. Defaults to 12.
        mlp_ratio (float, optional): Number of ratio used in MLP. Defaults to 4.0.
        drop_rate (float, optional): The drop ratio used in MLP. Defaults to 0.0.
        drop_path_rate (float, optional): The drop ratio used in DropPath. Defaults to 0.0.
        num_blocks (int, optional): Number of blocks. Defaults to 8.
        sparsity_threshold (float, optional): The value of threshold for softshrink. Defaults to 0.01.
        hard_thresholding_fraction (float, optional): The value of threshold for keep mode. Defaults to 1.0.
        num_timestamps (int, optional): Number of autoregressive prediction steps. Defaults to 1.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> wind_model = ppsci.arch.AFNONet(("input", ), ("output", ))
        >>> model = ppsci.arch.PrecipNet(("input", ), ("output", ), wind_model)
        >>> data = paddle.randn([1, 20, 720, 1440])
        >>> data_dict = {"input": data}
        >>> output = model.forward(data_dict)
        >>> print(output['output'].shape)
        [1, 1, 720, 1440]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        wind_model: base.Arch,
        img_size: Tuple[int, ...] = (720, 1440),
        patch_size: Tuple[int, ...] = (8, 8),
        in_channels: int = 20,
        out_channels: int = 1,
        embed_dim: int = 768,
        depth: int = 12,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        num_blocks: int = 8,
        sparsity_threshold: float = 0.01,
        hard_thresholding_fraction: float = 1.0,
        num_timestamps: int = 1,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        self.num_timestamps = num_timestamps
        # Only ``backbone.forward_tensor`` is used, so the placeholder keys
        # below are never read.
        self.backbone = AFNONet(
            ("input",),
            ("output",),
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            out_channels=out_channels,
            embed_dim=embed_dim,
            depth=depth,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            num_blocks=num_blocks,
            sparsity_threshold=sparsity_threshold,
            hard_thresholding_fraction=hard_thresholding_fraction,
        )
        # Pad by 1, then an unpadded 3x3 convolution; presumably this keeps the
        # spatial size unchanged (TODO confirm PeriodicPad2d pads each side).
        self.ppad = PeriodicPad2d(1)
        self.conv = nn.Conv2D(
            self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=0
        )
        self.act = nn.ReLU()
        # Runs BEFORE ``wind_model`` is attached, so the supplied wind model's
        # weights are not re-initialized by ``_init_weights``.
        self.apply(self._init_weights)
        self.wind_model = wind_model
        # NOTE(review): Paddle's Layer.train() normally recurses into sublayers,
        # so a later ``model.train()`` may switch wind_model back to training
        # mode (dropout / drop-path active). Confirm this is acceptable.
        self.wind_model.eval()

    def _init_weights(self, m):
        """Initialize one sublayer (used via ``self.apply``).

        Linear: truncated-normal weights (std 0.02), zero bias. LayerNorm: weight
        one, bias zero. Conv2D: ``initializer.conv_init_``.
        """
        if isinstance(m, nn.Linear):
            initializer.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                initializer.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            initializer.ones_(m.weight)
            initializer.zeros_(m.bias)
        elif isinstance(m, nn.Conv2D):
            initializer.conv_init_(m)

    def forward_tensor(self, x):
        """Map one (wind-model-predicted) state tensor to a non-negative output tensor."""
        x = self.backbone.forward_tensor(x)
        x = self.ppad(x)
        x = self.conv(x)
        # ReLU clamps the output to be non-negative.
        x = self.act(x)
        return x

    @staticmethod
    def split_to_dict(data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]):
        """Map the i-th tensor of ``data_tensors`` to the i-th key of ``keys``.

        Extra tensors beyond ``len(keys)`` are dropped; fewer tensors than keys
        raises ``IndexError``.
        """
        return {key: data_tensors[i] for i, key in enumerate(keys)}

    def forward(self, x):
        """Roll out ``num_timestamps`` steps and return one output per step.

        At each step the wind model (without gradients) predicts the next state
        from the current one. The output is computed from that prediction, and
        the prediction becomes the next step's wind-model input. The i-th step's
        output is stored under the i-th output key.
        """
        if self._input_transform is not None:
            x = self._input_transform(x)

        x_tensor = self.concat_to_tensor(x, self.input_keys)

        input_wind = x_tensor
        y = []
        for _ in range(self.num_timestamps):
            with paddle.no_grad():
                out_wind = self.wind_model.forward_tensor(input_wind)
            out = self.forward_tensor(out_wind)
            y.append(out)
            input_wind = out_wind
        y = self.split_to_dict(y, self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y

PhyCRNet

Bases: Arch

Physics-informed convolutional-recurrent neural networks.

Parameters:

Name Type Description Default
input_channels int

The input channels.

required
hidden_channels Tuple[int, ...]

The hidden channels.

required
input_kernel_size Tuple[int, ...]

The input kernel size(s).

required
input_stride Tuple[int, ...]

The input stride(s).

required
input_padding Tuple[int, ...]

The input padding(s).

required
dt float

The dt parameter.

required
num_layers Tuple[int, ...]

The number of layers.

required
upscale_factor int

The upscale factor.

required
step int

Number of time steps to roll out. Defaults to 1.

1
effective_step Tuple[int, ...]

Indices of the steps whose outputs are collected. Defaults to (1, ).

(1,)

Examples:

>>> import ppsci
>>> model = ppsci.arch.PhyCRNet(
...     input_channels=2,
...     hidden_channels=[8, 32, 128, 128],
...     input_kernel_size=[4, 4, 4, 3],
...     input_stride=[2, 2, 2, 1],
...     input_padding=[1, 1, 1, 1],
...     dt=0.002,
...     num_layers=[3, 1],
...     upscale_factor=8
... )
Source code in ppsci/arch/phycrnet.py
class PhyCRNet(base.Arch):
    """Physics-informed convolutional-recurrent neural networks.

    At every time step the current state is passed through a stack of encoder
    blocks and then a stack of ConvLSTM cells. The result is upsampled with a
    pixel shuffle and a convolution, and added to the current state as a
    residual update scaled by ``dt``.

    Args:
        input_channels (int): The input channels.
        hidden_channels (Tuple[int, ...]): The hidden channels. Either a list or
            a tuple is accepted.
        input_kernel_size (Tuple[int, ...]):  The input kernel size(s).
        input_stride (Tuple[int, ...]): The input stride(s).
        input_padding (Tuple[int, ...]): The input padding(s).
        dt (float): The dt parameter, used to scale the residual update.
        num_layers (Tuple[int, ...]): ``(num_encoder, num_convlstm)``: number of
            encoder blocks and number of ConvLSTM cells.
        upscale_factor (int): The upscale factor of the pixel shuffle.
        step (int, optional): Number of time steps to roll out. Defaults to 1.
        effective_step (Tuple[int, ...], optional): Indices of the steps whose
            outputs are collected. Defaults to (1, ).

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.PhyCRNet(
        ...     input_channels=2,
        ...     hidden_channels=[8, 32, 128, 128],
        ...     input_kernel_size=[4, 4, 4, 3],
        ...     input_stride=[2, 2, 2, 1],
        ...     input_padding=[1, 1, 1, 1],
        ...     dt=0.002,
        ...     num_layers=[3, 1],
        ...     upscale_factor=8
        ... )
    """

    def __init__(
        self,
        input_channels: int,
        hidden_channels: Tuple[int, ...],
        input_kernel_size: Tuple[int, ...],
        input_stride: Tuple[int, ...],
        input_padding: Tuple[int, ...],
        dt: float,
        num_layers: Tuple[int, ...],
        upscale_factor: int,
        step: int = 1,
        effective_step: Tuple[int, ...] = (1,),
    ):
        super().__init__()

        # Input channels of every layer: the raw input for the first encoder
        # block, then the hidden channels of the preceding layer. ``list()``
        # makes this work when ``hidden_channels`` is a tuple, as annotated;
        # ``list + tuple`` would raise TypeError.
        self.input_channels = [input_channels] + list(hidden_channels)
        self.hidden_channels = hidden_channels
        self.input_kernel_size = input_kernel_size
        self.input_stride = input_stride
        self.input_padding = input_padding
        self.step = step
        self.effective_step = effective_step
        self._all_layers = []
        self.dt = dt
        self.upscale_factor = upscale_factor

        # number of layers
        self.num_encoder = num_layers[0]
        self.num_convlstm = num_layers[1]

        # encoder - downsampling; layers [0, num_encoder)
        self.encoder = paddle.nn.LayerList(
            [
                encoder_block(
                    input_channels=self.input_channels[i],
                    hidden_channels=self.hidden_channels[i],
                    input_kernel_size=self.input_kernel_size[i],
                    input_stride=self.input_stride[i],
                    input_padding=self.input_padding[i],
                )
                for i in range(self.num_encoder)
            ]
        )

        # ConvLSTM; layers [num_encoder, num_encoder + num_convlstm)
        self.convlstm = paddle.nn.LayerList(
            [
                ConvLSTMCell(
                    input_channels=self.input_channels[i],
                    hidden_channels=self.hidden_channels[i],
                    input_kernel_size=self.input_kernel_size[i],
                    input_stride=self.input_stride[i],
                    input_padding=self.input_padding[i],
                )
                for i in range(self.num_encoder, self.num_encoder + self.num_convlstm)
            ]
        )

        # output layer
        # NOTE: fixed at 2 -> 2 channels; the residual update in ``forward``
        # therefore expects a 2-channel state.
        self.output_layer = nn.Conv2D(
            2, 2, kernel_size=5, stride=1, padding=2, padding_mode="circular"
        )

        # pixelshuffle - upscale
        self.pixelshuffle = nn.PixelShuffle(self.upscale_factor)

        # initialize weights, then zero the output layer's bias
        self.apply(_initialize_weights)
        initializer_0 = paddle.nn.initializer.Constant(0.0)
        initializer_0(self.output_layer.bias)
        self.enable_transform = True

    def forward(self, x):
        """Roll the model out for ``self.step`` time steps.

        Args:
            x (Dict[str, paddle.Tensor]): Input dict. ``x["input"]`` is the state
                at the first step, and ``x["initial_state"]`` holds one
                ``prev_state`` per ConvLSTM cell, passed to that cell's
                ``init_hidden_tensor`` at step 0. The latter is also stored on
                ``self.initial_state``.

        Returns:
            Dict: ``{"outputs": [...], "second_last_state": [...]}``.
            ``outputs`` holds the states produced at the steps listed in
            ``effective_step``. ``second_last_state`` is a copy of the ConvLSTM
            ``(h, c)`` pairs after step ``step - 2``, or an empty list when
            ``step < 2``. Input/output transforms are applied only when
            ``enable_transform`` is true.
        """
        if self.enable_transform and self._input_transform is not None:
            x = self._input_transform(x)
        output_x = x

        self.initial_state = x["initial_state"]
        x = x["input"]
        internal_state = []
        outputs = []
        second_last_state = []

        for step in range(self.step):
            xt = x

            # encoder
            for encoder in self.encoder:
                x = encoder(x)

            # convlstm; ``i`` indexes both the cell and its recurrent state
            for i, lstm in enumerate(self.convlstm):
                if step == 0:
                    (h, c) = lstm.init_hidden_tensor(prev_state=self.initial_state[i])
                    internal_state.append((h, c))

                # one-step forward
                (h, c) = internal_state[i]
                x, new_c = lstm(x, h, c)
                internal_state[i] = (x, new_c)

            # output
            x = self.pixelshuffle(x)
            x = self.output_layer(x)

            # residual connection
            x = xt + self.dt * x

            if step == (self.step - 2):
                second_last_state = internal_state.copy()

            if step in self.effective_step:
                outputs.append(x)

        result_dict = {"outputs": outputs, "second_last_state": second_last_state}
        if self.enable_transform and self._output_transform is not None:
            result_dict = self._output_transform(output_x, result_dict)
        return result_dict

UNetEx

Bases: Arch

U-Net Extension for CFD.

Reference: Ribeiro M D, Rehman A, Ahmed S, et al. DeepCFD: Efficient steady-state laminar flow approximation with deep convolutional neural networks[J]. arXiv preprint arXiv:2004.08826, 2020.

Parameters:

Name Type Description Default
input_key str

Name of function data for input.

required
output_key str

Name of function data for output.

required
in_channel int

Number of channels of input.

required
out_channel int

Number of channels of output.

required
kernel_size int

Size of kernel of convolution layer. Defaults to 3.

3
filters Tuple[int, ...]

Number of filters. Defaults to (16, 32, 64).

(16, 32, 64)
layers int

Number of encoders or decoders. Defaults to 3.

3
weight_norm bool

Whether use weight normalization layer. Defaults to True.

True
batch_norm bool

Whether add batch normalization layer. Defaults to True.

True
activation Type[Layer]

Name of activation function. Defaults to nn.ReLU.

ReLU
final_activation Optional[Type[Layer]]

Name of final activation function. Defaults to None.

None

Examples:

>>> import ppsci
>>> model = ppsci.arch.UNetEx(
...     input_key="input",
...     output_key="output",
...     in_channel=3,
...     out_channel=3,
...     kernel_size=5,
...     filters=(4, 4, 4, 4),
...     layers=3,
...     weight_norm=False,
...     batch_norm=False,
...     activation=None,
...     final_activation=None,
... )
>>> input_dict = {'input': paddle.rand([4, 3, 4, 4])}
>>> output_dict = model(input_dict)
>>> print(output_dict['output'])
>>> print(output_dict['output'].shape)
[4, 3, 4, 4]
Source code in ppsci/arch/unetex.py
class UNetEx(base.Arch):
    """U-Net Extension for CFD.

    Reference: [Ribeiro M D, Rehman A, Ahmed S, et al. DeepCFD: Efficient steady-state laminar flow approximation with deep convolutional neural networks[J]. arXiv preprint arXiv:2004.08826, 2020.](https://arxiv.org/abs/2004.08826)

    One encoder is shared by ``out_channel`` independent decoders; the decoder
    outputs are concatenated along axis 1 to form the final output.

    Args:
        input_key (str): Name of function data for input.
        output_key (str): Name of function data for output.
        in_channel (int): Number of channels of input.
        out_channel (int): Number of channels of output.
        kernel_size (int, optional): Size of kernel of convolution layer. Defaults to 3.
        filters (Tuple[int, ...], optional): Number of filters. Defaults to (16, 32, 64).
        layers (int, optional): Number of encoders or decoders. Defaults to 3.
        weight_norm (bool, optional): Whether to use weight normalization layers. Defaults to True.
        batch_norm (bool, optional): Whether to add batch normalization layers. Defaults to True.
        activation (Type[nn.Layer], optional): Activation layer class used in the
            encoder and decoders. Defaults to nn.ReLU.
        final_activation (Optional[Type[nn.Layer]]): Activation applied to the
            concatenated output. A layer class is instantiated with no arguments;
            an already-constructed layer (or any other callable) is used as-is.
            Defaults to None (no final activation).

    Raises:
        ValueError: If ``filters`` is empty.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.UNetEx(
        ...     input_key="input",
        ...     output_key="output",
        ...     in_channel=3,
        ...     out_channel=3,
        ...     kernel_size=5,
        ...     filters=(4, 4, 4, 4),
        ...     layers=3,
        ...     weight_norm=False,
        ...     batch_norm=False,
        ...     activation=None,
        ...     final_activation=None,
        ... )
        >>> input_dict = {'input': paddle.rand([4, 3, 4, 4])}
        >>> output_dict = model(input_dict)
        >>> print(output_dict['output']) # doctest: +SKIP
        >>> print(output_dict['output'].shape)
        [4, 3, 4, 4]
    """

    def __init__(
        self,
        input_key: str,
        output_key: str,
        in_channel: int,
        out_channel: int,
        kernel_size: int = 3,
        filters: Tuple[int, ...] = (16, 32, 64),
        layers: int = 3,
        weight_norm: bool = True,
        batch_norm: bool = True,
        activation: Type[nn.Layer] = nn.ReLU,
        final_activation: Optional[Type[nn.Layer]] = None,
    ):
        if len(filters) == 0:
            raise ValueError("The filters shouldn't be empty ")

        super().__init__()
        self.input_keys = (input_key,)
        self.output_keys = (output_key,)
        # ``final_activation`` is annotated as a layer *class*. Calling a class on
        # a tensor would construct a new layer rather than apply it, so build the
        # instance once here. Instances and plain callables are kept unchanged.
        if isinstance(final_activation, type):
            final_activation = final_activation()
        self.final_activation = final_activation
        self.encoder = create_encoder(
            in_channel,
            filters,
            kernel_size,
            weight_norm,
            batch_norm,
            activation,
            layers,
        )
        # One decoder per output channel; create_decoder is called with 1 as its
        # first argument (presumably one output channel per decoder — TODO confirm).
        decoders = [
            create_decoder(
                1, filters, kernel_size, weight_norm, batch_norm, activation, layers
            )
            for _ in range(out_channel)
        ]
        self.decoders = nn.Sequential(*decoders)

    def encode(self, x):
        """Run every encoder stage, each followed by a 2x2 / stride-2 max-pool.

        Args:
            x (paddle.Tensor): Input tensor.

        Returns:
            Tuple: ``(x, tensors, indices, sizes)`` where ``x`` is the pooled output
            of the last stage, ``tensors`` holds every stage's output before pooling
            (skip connections), ``indices`` holds the pooling masks and ``sizes``
            the pre-pooling shapes needed by ``max_unpool2d``.
        """
        tensors = []
        indices = []
        sizes = []
        for encoder in self.encoder:
            x = encoder(x)
            sizes.append(x.shape)
            tensors.append(x)
            x, ind = nn.functional.max_pool2d(x, 2, 2, return_mask=True)
            indices.append(ind)
        return x, tensors, indices, sizes

    def decode(self, x, tensors, indices, sizes):
        """Decode the encoded tensor with each per-channel decoder.

        Every decoder walks its stages in order, consuming the skip connections
        deepest-first: unpool with the stored mask/size, concatenate the skip
        tensor on axis 1, then apply the stage.

        Args:
            x (paddle.Tensor): Output of ``encode``.
            tensors (List[paddle.Tensor]): Skip-connection tensors from ``encode``.
            indices (List[paddle.Tensor]): Pooling masks from ``encode``.
            sizes (List): Pre-pooling shapes from ``encode``.

        Returns:
            paddle.Tensor: Outputs of all decoders concatenated along axis 1.
        """
        y = []
        for _decoder in self.decoders:
            _x = x
            # copy the lists: every decoder pops from its own copy
            _tensors = tensors[:]
            _indices = indices[:]
            _sizes = sizes[:]
            for decoder in _decoder:
                tensor = _tensors.pop()
                size = _sizes.pop()
                indice = _indices.pop()
                # upsample operations
                _x = nn.functional.max_unpool2d(_x, indice, 2, 2, output_size=size)
                _x = paddle.concat([tensor, _x], axis=1)
                _x = decoder(_x)
            y.append(_x)
        return paddle.concat(y, axis=1)

    def forward(self, x):
        x = x[self.input_keys[0]]
        x, tensors, indices, sizes = self.encode(x)
        x = self.decode(x, tensors, indices, sizes)
        if self.final_activation is not None:
            x = self.final_activation(x)
        return {self.output_keys[0]: x}

USCNN

Bases: Arch

Physics-informed convolutional neural networks.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Name of input keys, such as ("coords",).

required
output_keys Tuple[str, ...]

Name of output keys, such as ("outputV",).

required
hidden_size Union[int, Tuple[int, ...]]

Number of hidden channels of the convolutional layers.

required
h float

The spatial step.

required
nx int

The number of grid points along the x-axis.

required
ny int

The number of grid points along the y-axis.

required
nvar_in int

Number of input channels. Defaults to 1.

1
nvar_out int

Number of output channels. Defaults to 1.

1
pad_singleside int

Padding size added to each side for the hard boundary constraint. Defaults to 1.

1
k int

Kernel size of the convolutional layers. Defaults to 5.

5
s int

Stride of the convolutional layers. Defaults to 1.

1
p int

Padding of the convolutional layers. Defaults to 2.

2

Examples:

>>> import ppsci
>>> model = ppsci.arch.USCNN(
...     ["coords"],
...     ["outputV"],
...     [16, 32, 16],
...     h=0.01,
...     ny=19,
...     nx=84,
...     nvar_in=2,
...     nvar_out=1,
...     pad_singleside=1,
... )
Source code in ppsci/arch/uscnn.py
class USCNN(base.Arch):
    """Physics-informed convolutional neural networks.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("coords",).
        output_keys (Tuple[str, ...]): Name of output keys, such as ("outputV",).
        hidden_size (Union[int, Tuple[int, ...]]): Hidden channels of the three
            hidden convolutional layers. An integer uses the same number of
            channels for all three; a sequence must contain at least three items
            (only the first three are used).
        h (float): The spatial step.
        nx (int): The number of grid points along the x-axis.
        ny (int): The number of grid points along the y-axis.
        nvar_in (int, optional): Number of input channels. Defaults to 1.
        nvar_out (int, optional): Number of output channels of the last
            convolution. Defaults to 1.
        pad_singleside (int, optional): Zero padding added to every side of the
            output for the hard boundary constraint. Defaults to 1.
        k (int, optional): Kernel size of the convolutions. Defaults to 5.
        s (int, optional): Stride of the convolutions. Defaults to 1.
        p (int, optional): Padding of the convolutions. Defaults to 2.

    Raises:
        ValueError: If ``hidden_size`` is a sequence with fewer than three items.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.USCNN(
        ...     ["coords"],
        ...     ["outputV"],
        ...     [16, 32, 16],
        ...     h=0.01,
        ...     ny=19,
        ...     nx=84,
        ...     nvar_in=2,
        ...     nvar_out=1,
        ...     pad_singleside=1,
        ... )
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        hidden_size: Union[int, Tuple[int, ...]],
        h: float,
        nx: int,
        ny: int,
        nvar_in: int = 1,
        nvar_out: int = 1,
        pad_singleside: int = 1,
        k: int = 5,
        s: int = 1,
        p: int = 2,
    ):
        # The annotation allows a single int; expand it to the three hidden
        # layers so that indexing below works for both forms.
        if isinstance(hidden_size, int):
            hidden_size = (hidden_size,) * 3
        elif len(hidden_size) < 3:
            raise ValueError(
                f"hidden_size should contain at least 3 values, but got {hidden_size}"
            )

        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.nvar_in = nvar_in
        self.nvar_out = nvar_out
        self.k = k
        self.s = s
        self.p = p
        self.deltaX = h
        self.nx = nx
        self.ny = ny
        self.pad_singleside = pad_singleside
        self.relu = nn.ReLU()
        # resample the input onto the interior grid of size (ny - 2) x (nx - 2)
        self.US = nn.Upsample(size=[self.ny - 2, self.nx - 2], mode="bicubic")
        self.conv1 = nn.Conv2D(
            self.nvar_in, hidden_size[0], kernel_size=k, stride=s, padding=p
        )
        self.conv2 = nn.Conv2D(
            hidden_size[0], hidden_size[1], kernel_size=k, stride=s, padding=p
        )
        self.conv3 = nn.Conv2D(
            hidden_size[1], hidden_size[2], kernel_size=k, stride=s, padding=p
        )
        self.conv4 = nn.Conv2D(
            hidden_size[2], self.nvar_out, kernel_size=k, stride=s, padding=p
        )
        # upscale factor 1: leaves the tensor unchanged
        self.pixel_shuffle = nn.PixelShuffle(1)
        self.apply(self.init_weights)
        self.udfpad = nn.Pad2D(
            [pad_singleside, pad_singleside, pad_singleside, pad_singleside], value=0
        )

    def init_weights(self, m):
        """Uniformly initialize Conv2D weights and biases.

        The bound is ``1 / sqrt(fan_in)``, where fan_in is the product of all
        weight dimensions except the first.
        """
        if isinstance(m, nn.Conv2D):
            bound = 1 / np.sqrt(np.prod(m.weight.shape[1:]))
            ppsci.utils.initializer.uniform_(m.weight, -bound, bound)
            if m.bias is not None:
                ppsci.utils.initializer.uniform_(m.bias, -bound, bound)

    def forward(self, x):
        # apply the registered input transform, consistently with other archs
        if self._input_transform is not None:
            x = self._input_transform(x)

        y = self.concat_to_tensor(x, self.input_keys, axis=-1)
        y = self.US(y)
        y = self.relu(self.conv1(y))
        y = self.relu(self.conv2(y))
        y = self.relu(self.conv3(y))
        y = self.pixel_shuffle(self.conv4(y))

        # Zero-pad every side by pad_singleside (hard boundary constraint). With
        # the default k/s/p and pad_singleside == 1, this restores an ny x nx grid.
        # NOTE(review): other pad_singleside values give a different size — confirm.
        y = self.udfpad(y)
        # keep only the first channel, reshaped to [batch, 1, H, W]
        y = y[:, 0, :, :].reshape([y.shape[0], 1, y.shape[2], y.shape[3]])
        y = self.split_to_dict(y, self.output_keys)
        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y

NowcastNet

Bases: Arch

The NowcastNet model.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Name of input keys, such as ("input",).

required
output_keys Tuple[str, ...]

Name of output keys, such as ("output",).

required
input_length int

Input length. Defaults to 9.

9
total_length int

Total length. Defaults to 29.

29
image_height int

Image height. Defaults to 512.

512
image_width int

Image width. Defaults to 512.

512
image_ch int

Image channel. Defaults to 2.

2
ngf int

Number of base filters of the generative network, also used as the number of input channels of the Noise Projector. Defaults to 32.

32

Examples:

>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.NowcastNet(("input", ), ("output", ))
>>> input_data = paddle.rand([1, 9, 512, 512, 2])
>>> input_dict = {"input": input_data}
>>> output_dict = model(input_dict)
>>> print(output_dict["output"].shape)
[1, 20, 512, 512, 1]
Source code in ppsci/arch/nowcastnet.py
class NowcastNet(base.Arch):
    """The NowcastNet model.

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output",).
        input_length (int, optional): Input length. Defaults to 9.
        total_length (int, optional): Total length. Defaults to 29.
        image_height (int, optional): Image height. Defaults to 512.
        image_width (int, optional): Image width. Defaults to 512.
        image_ch (int, optional): Image channel. Defaults to 2.
        ngf (int, optional): Number of base filters of the generative network,
            also used as the number of input channels of the Noise Projector.
            Defaults to 32.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.NowcastNet(("input", ), ("output", ))
        >>> input_data = paddle.rand([1, 9, 512, 512, 2])
        >>> input_dict = {"input": input_data}
        >>> output_dict = model(input_dict)
        >>> print(output_dict["output"].shape)
        [1, 20, 512, 512, 1]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        input_length: int = 9,
        total_length: int = 29,
        image_height: int = 512,
        image_width: int = 512,
        image_ch: int = 2,
        ngf: int = 32,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys

        self.input_length = input_length
        self.total_length = total_length
        self.image_height = image_height
        self.image_width = image_width
        self.image_ch = image_ch
        self.ngf = ngf
        # number of frames to be forecast
        self.pred_length = self.total_length - self.input_length

        # Build a real config *instance* for the generative decoder. Previously
        # the field descriptors of the namedtuple class itself were overwritten
        # and the class was passed around. Attribute access is identical.
        Config = collections.namedtuple(
            "Object", ["ngf", "evo_ic", "gen_oc", "ic_feature"]
        )
        configs = Config(
            ngf=self.ngf,
            evo_ic=self.pred_length,
            gen_oc=self.pred_length,
            ic_feature=self.ngf * 10,
        )

        self.evo_net = Evolution_Network(self.input_length, self.pred_length, base_c=32)
        self.gen_enc = Generative_Encoder(self.total_length, base_c=self.ngf)
        self.gen_dec = Generative_Decoder(configs)
        self.proj = Noise_Projector(self.ngf)
        # sampling grid used by ``warp``, built once for the configured image size
        sample_tensor = paddle.zeros(shape=[1, 1, self.image_height, self.image_width])
        self.grid = make_grid(sample_tensor)

    @staticmethod
    def split_to_dict(data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]):
        """Map ``keys[i]`` to ``data_tensors[i]`` for every key."""
        return {key: data_tensors[i] for i, key in enumerate(keys)}

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        x_tensor = self.concat_to_tensor(x, self.input_keys)

        y = self.split_to_dict([self.forward_tensor(x_tensor)], self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y

    def forward_tensor(self, x):
        """Forecast future frames from the input sequence.

        Args:
            x (paddle.Tensor): 5-D input; the docstring example uses shape
                [batch, time, height, width, channel], and only channel 0 is used.

        Returns:
            paddle.Tensor: Generated frames with a trailing axis of size 1; the
            example yields [batch, pred_length, height, width, 1].
        """
        all_frames = x[:, :, :, :, :1]
        # -> [batch, time, 1, height, width]
        frames = all_frames.transpose(perm=[0, 1, 4, 2, 3])
        batch = frames.shape[0]
        height = frames.shape[3]
        width = frames.shape[4]
        # Input Frames
        input_frames = frames[:, : self.input_length]
        input_frames = input_frames.reshape((batch, self.input_length, height, width))
        # Evolution Network
        intensity, motion = self.evo_net(input_frames)
        motion_ = motion.reshape((batch, self.pred_length, 2, height, width))
        intensity_ = intensity.reshape((batch, self.pred_length, 1, height, width))
        series = []
        # start the autoregressive warp from the last observed frame
        last_frames = all_frames[:, self.input_length - 1 : self.input_length, :, :, 0]
        grid = self.grid.tile((batch, 1, 1, 1))
        for i in range(self.pred_length):
            last_frames = warp(
                last_frames, motion_[:, i], grid, mode="nearest", padding_mode="border"
            )
            last_frames = last_frames + intensity_[:, i]
            series.append(last_frames)
        evo_result = paddle.concat(x=series, axis=1)
        # NOTE(review): fixed scaling constant; its meaning is not visible here
        evo_result = evo_result / 128
        # Generative Network
        evo_feature = self.gen_enc(paddle.concat(x=[input_frames, evo_result], axis=1))
        noise = paddle.randn(shape=[batch, self.ngf, height // 32, width // 32])
        # rearrange the projected noise into a (height // 8, width // 8) feature map
        noise_feature = (
            self.proj(noise)
            .reshape((batch, -1, 4, 4, 8, 8))
            .transpose(perm=[0, 1, 4, 5, 2, 3])
            .reshape((batch, -1, height // 8, width // 8))
        )
        feature = paddle.concat(x=[evo_feature, noise_feature], axis=1)
        gen_result = self.gen_dec(feature, evo_result)
        return gen_result.unsqueeze(axis=-1)

HEDeepONets

Bases: Arch

Physics-informed deep operator networks.

Parameters:

Name Type Description Default
heat_input_keys Tuple[str, ...]

Name of input data for heat boundary.

required
cold_input_keys Tuple[str, ...]

Name of input data for cold boundary.

required
trunk_input_keys Tuple[str, ...]

Name of input data for trunk net.

required
output_keys Tuple[str, ...]

Names of the predicted temperature outputs.

required
heat_num_loc int

Number of sampled input data for heat boundary.

required
cold_num_loc int

Number of sampled input data for cold boundary.

required
num_features int

Number of features extracted from heat boundary, same for cold boundary and trunk net.

required
branch_num_layers int

Number of hidden layers of branch net.

required
trunk_num_layers int

Number of hidden layers of trunk net.

required
branch_hidden_size Union[int, Tuple[int, ...]]

Hidden size of the branch net: an integer for all layers, or a list of integers specifying each layer's size.

required
trunk_hidden_size Union[int, Tuple[int, ...]]

Hidden size of the trunk net: an integer for all layers, or a list of integers specifying each layer's size.

required
branch_skip_connection bool

Whether to use skip connection for branch net. Defaults to False.

False
trunk_skip_connection bool

Whether to use skip connection for trunk net. Defaults to False.

False
branch_activation str

Name of activation function for branch net. Defaults to "tanh".

'tanh'
trunk_activation str

Name of activation function for trunk net. Defaults to "tanh".

'tanh'
branch_weight_norm bool

Whether to apply weight norm on parameter(s) for branch net. Defaults to False.

False
trunk_weight_norm bool

Whether to apply weight norm on parameter(s) for trunk net. Defaults to False.

False
use_bias bool

Whether to add bias on predicted G(u)(y). Defaults to True.

True

Examples:

>>> import ppsci
>>> model = ppsci.arch.HEDeepONets(
...     ('qm_h',),
...     ('qm_c',),
...     ("x",'t'),
...     ("T_h",'T_c','T_w'),
...     1,
...     1,
...     100,
...     9,
...     6,
...     256,
...     128,
...     branch_activation="swish",
...     trunk_activation="swish",
...     use_bias=True,
... )
Source code in ppsci/arch/he_deeponets.py
class HEDeepONets(base.Arch):
    """Physics-informed deep operator networks.

    Two branch nets (heat-boundary and cold-boundary input functions) and one
    trunk net each produce ``num_features * len(output_keys)`` features. Every
    output key owns one contiguous block of ``num_features`` features. Its
    prediction is the sum over that block of the element-wise product of the
    three nets' features, plus an optional per-output bias.

    Args:
        heat_input_keys (Tuple[str, ...]): Name of input data for heat boundary.
        cold_input_keys (Tuple[str, ...]): Name of input data for cold boundary.
        trunk_input_keys (Tuple[str, ...]): Name of input data for trunk net.
        output_keys (Tuple[str, ...]): Names of the predicted temperature outputs,
            one prediction per key.
        heat_num_loc (int): Number of sampled input data for heat boundary.
        cold_num_loc (int): Number of sampled input data for cold boundary.
        num_features (int): Number of features extracted from heat boundary, same for cold boundary and trunk net.
        branch_num_layers (int): Number of hidden layers of branch net.
        trunk_num_layers (int): Number of hidden layers of trunk net.
        branch_hidden_size (Union[int, Tuple[int, ...]]): Hidden size of the branch net:
            an integer for all layers, or a list of integers specifying each layer's size.
        trunk_hidden_size (Union[int, Tuple[int, ...]]): Hidden size of the trunk net:
            an integer for all layers, or a list of integers specifying each layer's size.
        branch_skip_connection (bool, optional): Whether to use skip connection for branch net. Defaults to False.
        trunk_skip_connection (bool, optional): Whether to use skip connection for trunk net. Defaults to False.
        branch_activation (str, optional): Name of activation function for branch net. Defaults to "tanh".
        trunk_activation (str, optional): Name of activation function for trunk net. Defaults to "tanh".
        branch_weight_norm (bool, optional): Whether to apply weight norm on parameter(s) for branch net. Defaults to False.
        trunk_weight_norm (bool, optional): Whether to apply weight norm on parameter(s) for trunk net. Defaults to False.
        use_bias (bool, optional): Whether to add bias on predicted G(u)(y). Defaults to True.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.HEDeepONets(
        ...     ('qm_h',),
        ...     ('qm_c',),
        ...     ("x",'t'),
        ...     ("T_h",'T_c','T_w'),
        ...     1,
        ...     1,
        ...     100,
        ...     9,
        ...     6,
        ...     256,
        ...     128,
        ...     branch_activation="swish",
        ...     trunk_activation="swish",
        ...     use_bias=True,
        ... )
    """

    def __init__(
        self,
        heat_input_keys: Tuple[str, ...],
        cold_input_keys: Tuple[str, ...],
        trunk_input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        heat_num_loc: int,
        cold_num_loc: int,
        num_features: int,
        branch_num_layers: int,
        trunk_num_layers: int,
        branch_hidden_size: Union[int, Tuple[int, ...]],
        trunk_hidden_size: Union[int, Tuple[int, ...]],
        branch_skip_connection: bool = False,
        trunk_skip_connection: bool = False,
        branch_activation: str = "tanh",
        trunk_activation: str = "tanh",
        branch_weight_norm: bool = False,
        trunk_weight_norm: bool = False,
        use_bias: bool = True,
    ):
        super().__init__()
        self.trunk_input_keys = trunk_input_keys
        self.heat_input_keys = heat_input_keys
        self.cold_input_keys = cold_input_keys
        self.input_keys = (
            self.trunk_input_keys + self.heat_input_keys + self.cold_input_keys
        )
        self.output_keys = output_keys
        self.num_features = num_features

        self.heat_net = mlp.MLP(
            self.heat_input_keys,
            ("h",),
            branch_num_layers,
            branch_hidden_size,
            branch_activation,
            branch_skip_connection,
            branch_weight_norm,
            input_dim=heat_num_loc,
            output_dim=num_features * len(self.output_keys),
        )

        self.cold_net = mlp.MLP(
            self.cold_input_keys,
            ("c",),
            branch_num_layers,
            branch_hidden_size,
            branch_activation,
            branch_skip_connection,
            branch_weight_norm,
            input_dim=cold_num_loc,
            output_dim=num_features * len(self.output_keys),
        )

        self.trunk_net = mlp.MLP(
            self.trunk_input_keys,
            ("t",),
            trunk_num_layers,
            trunk_hidden_size,
            trunk_activation,
            trunk_skip_connection,
            trunk_weight_norm,
            input_dim=len(self.trunk_input_keys),
            output_dim=num_features * len(self.output_keys),
        )
        self.trunk_act = act_mod.get_activation(trunk_activation)
        # NOTE(review): heat_act / cold_act are not applied in forward; they are
        # kept so the layer's sublayer structure is unchanged.
        self.heat_act = act_mod.get_activation(branch_activation)
        self.cold_act = act_mod.get_activation(branch_activation)

        self.use_bias = use_bias
        if use_bias:
            # register bias to parameter for updating in optimizer and storage
            self.b = self.create_parameter(
                shape=(len(self.output_keys),),
                attr=nn.initializer.Constant(0.0),
            )

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        # Branch nets encode the heat- and cold-boundary input functions
        heat_features = self.heat_net(x)[self.heat_net.output_keys[0]]
        cold_features = self.cold_net(x)[self.cold_net.output_keys[0]]
        # Trunk net encodes the domain of the output function
        y_features = self.trunk_net(x)[self.trunk_net.output_keys[0]]
        y_features = self.trunk_act(y_features)

        # Output i uses feature columns [i * num_features, (i + 1) * num_features)
        # of every net, matching the ``num_features * len(output_keys)`` widths
        # chosen in __init__. Works for any number of output keys.
        result_dict = {}
        for i, key in enumerate(self.output_keys):
            start = i * self.num_features
            end = start + self.num_features
            g_u = paddle.sum(
                heat_features[:, start:end]
                * y_features[:, start:end]
                * cold_features[:, start:end],
                axis=1,
                keepdim=True,
            )
            # Add bias
            if self.use_bias:
                g_u = g_u + self.b[i]
            result_dict[key] = g_u

        if self._output_transform is not None:
            result_dict = self._output_transform(x, result_dict)

        return result_dict

DGMR

Bases: Arch

Deep Generative Model of Radar. Nowcasting GAN is an attempt to recreate DeepMind's Skillful Nowcasting GAN from https://arxiv.org/abs/2104.00954, slightly modified for multiple satellite channels.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Name of input keys, such as ("input",).

required
output_keys Tuple[str, ...]

Name of output keys, such as ("output",).

required
forecast_steps int

Number of steps to predict in the future

18
input_channels int

Number of input channels per image

1
gen_lr float

Learning rate for the generator

5e-05
disc_lr float

Learning rate for the discriminators, shared for both temporal and spatial discriminator

0.0002
conv_type str

Type of 2d convolution to use, see satflow/models/utils.py for options

'standard'
beta1 float

Beta1 for Adam optimizer

0.0
beta2 float

Beta2 for Adam optimizer

0.999
num_samples int

Number of samples of the latent space to sample for training/validation

6
grid_lambda float

Lambda for the grid regularization loss

20.0
output_shape int

Shape of the output predictions; it should generally be the same as the input shape.

256
generation_steps int

Number of generation steps to use in the forward pass. The paper uses 6 and chooses the best one for the loss, but this needs a large amount of GPU memory, so fewer steps might work better for training.

6
context_channels int

Number of output channels for the lowest block of conditioning stack

384
latent_channels int

Number of channels that the latent space should be reshaped to, input dimension into ConvGRU, also affects the number of channels for other linked inputs/outputs

768

Examples:

>>> import ppsci
>>> import paddle
>>> model = ppsci.arch.DGMR(("input", ), ("output", ))
>>> input_dict = {"input": paddle.randn((1, 4, 1, 256, 256))}
>>> output_dict = model(input_dict)
>>> print(output_dict["output"].shape)
[1, 18, 1, 256, 256]
Source code in ppsci/arch/dgmr.py
class DGMR(base.Arch):
    """Deep Generative Model of Radar.

    A Nowcasting GAN re-implementing DeepMind's Skillful Nowcasting GAN
    (https://arxiv.org/abs/2104.00954), slightly modified for multiple satellite
    channels. ``forward`` runs only the generator. The discriminator and the
    optimizer-related hyper-parameters are just stored on the instance
    (presumably for use by training code).

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output",).
        forecast_steps (int, optional): Number of steps to predict in the future.
        input_channels (int, optional): Number of input channels per image.
        gen_lr (float, optional): Learning rate for the generator.
        disc_lr (float, optional): Learning rate for the discriminators, shared by
            the temporal and the spatial discriminator.
        conv_type (str, optional): Type of 2d convolution to use, see
            satflow/models/utils.py for options.
        beta1 (float, optional): Beta1 for Adam optimizer.
        beta2 (float, optional): Beta2 for Adam optimizer.
        num_samples (int, optional): Number of latent-space samples drawn for
            training/validation.
        grid_lambda (float, optional): Lambda for the grid regularization loss.
        output_shape (int, optional): Shape of the output predictions; it should
            generally be the same as the input shape.
        generation_steps (int, optional): Number of generation steps to use in the
            forward pass. The paper uses 6 and chooses the best one for the loss,
            which needs a lot of GPU memory, so fewer might work better for training.
        context_channels (int, optional): Number of output channels for the lowest
            block of the conditioning stack.
        latent_channels (int, optional): Number of channels the latent space is
            reshaped to; this is the input dimension of the ConvGRU and also affects
            the channel count of other linked inputs/outputs.

    Examples:
        >>> import ppsci
        >>> import paddle
        >>> model = ppsci.arch.DGMR(("input", ), ("output", ))
        >>> input_dict = {"input": paddle.randn((1, 4, 1, 256, 256))}
        >>> output_dict = model(input_dict) # doctest: +SKIP
        >>> print(output_dict["output"].shape) # doctest: +SKIP
        [1, 18, 1, 256, 256]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        forecast_steps: int = 18,
        input_channels: int = 1,
        output_shape: int = 256,
        gen_lr: float = 5e-05,
        disc_lr: float = 0.0002,
        conv_type: str = "standard",
        num_samples: int = 6,
        grid_lambda: float = 20.0,
        beta1: float = 0.0,
        beta2: float = 0.999,
        latent_channels: int = 768,
        context_channels: int = 384,
        generation_steps: int = 6,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys

        # training hyper-parameters: stored only
        self.gen_lr, self.disc_lr = gen_lr, disc_lr
        self.beta1, self.beta2 = beta1, beta2
        self.grid_lambda = grid_lambda
        self.num_samples = num_samples

        # architecture hyper-parameters
        self.latent_channels = latent_channels
        self.context_channels = context_channels
        self.input_channels = input_channels
        self.generation_steps = generation_steps

        # Sub-networks (creation order kept: it fixes parameter initialization).
        latent_size = output_shape // 32
        self.conditioning_stack = ContextConditioningStack(
            input_channels=input_channels,
            conv_type=conv_type,
            output_channels=context_channels,
        )
        self.latent_stack = LatentConditioningStack(
            shape=(8 * input_channels, latent_size, latent_size),
            output_channels=latent_channels,
        )
        self.sampler = Sampler(
            forecast_steps=forecast_steps,
            latent_channels=latent_channels,
            context_channels=context_channels,
        )
        # the generator reuses the three stacks built above
        self.generator = Generator(
            self.conditioning_stack, self.latent_stack, self.sampler
        )
        self.discriminator = Discriminator(input_channels)

        self.global_iteration = 0
        self.automatic_optimization = False

    def split_to_dict(
        self, data_tensors: Tuple[paddle.Tensor, ...], keys: Tuple[str, ...]
    ):
        """Pair each key with the tensor at the same position."""
        mapping = {}
        for idx, name in enumerate(keys):
            mapping[name] = data_tensors[idx]
        return mapping

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        prediction = self.generator(self.concat_to_tensor(x, self.input_keys))
        y = self.split_to_dict([prediction], self.output_keys)

        if self._output_transform is not None:
            y = self._output_transform(x, y)
        return y

ChipDeepONets

Bases: Arch

Multi-branch physics-informed deep operator neural network. The network consists of three branch networks: random heat source, boundary function, and boundary type, as well as a trunk network.

Parameters:

Name Type Description Default
branch_input_keys Tuple[str, ...]

Name of input data for internal heat source on branch nets.

required
BCtype_input_keys Tuple[str, ...]

Name of input data for boundary types on branch nets.

required
BC_input_keys Tuple[str, ...]

Name of input data for boundary on branch nets.

required
trunk_input_keys Tuple[str, ...]

Name of input data for trunk net.

required
output_keys Tuple[str, ...]

Output name of predicted temperature.

required
num_loc int

Number of sampled input data for internal heat source.

required
bctype_loc int

Number of sampled input data for boundary types.

required
BC_num_loc int

Number of sampled input data for boundary.

required
num_features int

Number of features extracted from trunk net, same for all branch nets.

required
branch_num_layers int

Number of hidden layers of internal heat source on branch nets.

required
BC_num_layers int

Number of hidden layers of boundary on branch nets.

required
trunk_num_layers int

Number of hidden layers of trunk net.

required
branch_hidden_size Union[int, Tuple[int, ...]]

Hidden size of the internal-heat-source branch net: an integer for all layers, or a list of integers specifying each layer's size.

required
BC_hidden_size Union[int, Tuple[int, ...]]

Hidden size of the boundary branch net: an integer for all layers, or a list of integers specifying each layer's size.

required
trunk_hidden_size Union[int, Tuple[int, ...]]

Hidden size of the trunk net: an integer for all layers, or a list of integers specifying each layer's size.

required
branch_skip_connection bool

Whether to use skip connection for internal heat source on branch net. Defaults to False.

False
BC_skip_connection bool

Whether to use skip connection for boundary on branch net. Defaults to False.

False
trunk_skip_connection bool

Whether to use skip connection for trunk net. Defaults to False.

False
branch_activation str

Name of activation function for internal heat source on branch net. Defaults to "tanh".

'tanh'
BC_activation str

Name of activation function for boundary on branch net. Defaults to "tanh".

'tanh'
trunk_activation str

Name of activation function for trunk net. Defaults to "tanh".

'tanh'
branch_weight_norm bool

Whether to apply weight norm on parameter(s) for internal heat source on branch net. Defaults to False.

False
BC_weight_norm bool

Whether to apply weight norm on parameter(s) for boundary on branch net. Defaults to False.

False
trunk_weight_norm bool

Whether to apply weight norm on parameter(s) for trunk net. Defaults to False.

False
use_bias bool

Whether to add bias on predicted G(u)(y). Defaults to True.

True

Examples:

>>> import ppsci
>>> model = ppsci.arch.ChipDeepONets(
...     ('u',),
...     ('bc',),
...     ('bc_data',),
...     ("x",'y'),
...     ("T",),
...     324,
...     1,
...     76,
...     400,
...     9,
...     9,
...     6,
...     256,
...     256,
...     128,
...     branch_activation="swish",
...     BC_activation="swish",
...     trunk_activation="swish",
...     use_bias=True,
... )
Source code in ppsci/arch/chip_deeponets.py
class ChipDeepONets(base.Arch):
    """Multi-branch physics-informed deep operator network (DeepONet).

    Three branch nets encode the random internal heat source, the boundary
    function and the boundary type, and a trunk net encodes the query
    coordinates. The prediction is the element-wise product of the four
    feature vectors summed over the feature axis, optionally shifted by a
    learnable scalar bias.

    Args:
        branch_input_keys (Tuple[str, ...]): Input key(s) of the internal heat
            source, fed to the heat-source branch net.
        BCtype_input_keys (Tuple[str, ...]): Input key(s) of the boundary type,
            fed to the boundary-type branch net.
        BC_input_keys (Tuple[str, ...]): Input key(s) of the boundary function,
            fed to the boundary branch net.
        trunk_input_keys (Tuple[str, ...]): Input key(s) of the trunk net.
        output_keys (Tuple[str, ...]): Output key(s); only the first one (the
            predicted temperature) is written.
        num_loc (int): Number of sampled points of the internal heat source,
            i.e. the input dimension of the heat-source branch net.
        bctype_loc (int): Input dimension of the boundary-type branch net.
        BC_num_loc (int): Number of sampled points of the boundary function,
            i.e. the input dimension of the boundary branch net.
        num_features (int): Output feature size shared by the trunk net and
            every branch net.
        branch_num_layers (int): Number of hidden layers of the heat-source
            branch net.
        BC_num_layers (int): Number of hidden layers of the boundary and
            boundary-type branch nets.
        trunk_num_layers (int): Number of hidden layers of the trunk net.
        branch_hidden_size (Union[int, Tuple[int, ...]]): Hidden size of the
            heat-source branch net; an integer applies to all layers, a
            sequence gives each layer's size.
        BC_hidden_size (Union[int, Tuple[int, ...]]): Hidden size of the
            boundary and boundary-type branch nets, same convention as above.
        trunk_hidden_size (Union[int, Tuple[int, ...]]): Hidden size of the
            trunk net, same convention as above.
        branch_skip_connection (bool, optional): Whether the heat-source branch
            net uses skip connections. Defaults to False.
        BC_skip_connection (bool, optional): Whether the boundary and
            boundary-type branch nets use skip connections. Defaults to False.
        trunk_skip_connection (bool, optional): Whether the trunk net uses skip
            connections. Defaults to False.
        branch_activation (str, optional): Activation name of the heat-source
            branch net. Defaults to "tanh".
        BC_activation (str, optional): Activation name of the boundary and
            boundary-type branch nets. Defaults to "tanh".
        trunk_activation (str, optional): Activation name of the trunk net; it
            is also applied once more to the trunk output in ``forward``.
            Defaults to "tanh".
        branch_weight_norm (bool, optional): Whether weight norm is applied in
            the heat-source branch net. Defaults to False.
        BC_weight_norm (bool, optional): Whether weight norm is applied in the
            boundary and boundary-type branch nets. Defaults to False.
        trunk_weight_norm (bool, optional): Whether weight norm is applied in
            the trunk net. Defaults to False.
        use_bias (bool, optional): Whether a learnable scalar bias (initialized
            to 0) is added to the predicted G(u)(y). Defaults to True.

    Examples:
        >>> import ppsci
        >>> model = ppsci.arch.ChipDeepONets(
        ...     ('u',),
        ...     ('bc',),
        ...     ('bc_data',),
        ...     ("x",'y'),
        ...     ("T",),
        ...     324,
        ...     1,
        ...     76,
        ...     400,
        ...     9,
        ...     9,
        ...     6,
        ...     256,
        ...     256,
        ...     128,
        ...     branch_activation="swish",
        ...     BC_activation="swish",
        ...     trunk_activation="swish",
        ...     use_bias=True,
        ... )
    """

    def __init__(
        self,
        branch_input_keys: Tuple[str, ...],
        BCtype_input_keys: Tuple[str, ...],
        BC_input_keys: Tuple[str, ...],
        trunk_input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        num_loc: int,
        bctype_loc: int,
        BC_num_loc: int,
        num_features: int,
        branch_num_layers: int,
        BC_num_layers: int,
        trunk_num_layers: int,
        branch_hidden_size: Union[int, Tuple[int, ...]],
        BC_hidden_size: Union[int, Tuple[int, ...]],
        trunk_hidden_size: Union[int, Tuple[int, ...]],
        branch_skip_connection: bool = False,
        BC_skip_connection: bool = False,
        trunk_skip_connection: bool = False,
        branch_activation: str = "tanh",
        BC_activation: str = "tanh",
        trunk_activation: str = "tanh",
        branch_weight_norm: bool = False,
        BC_weight_norm: bool = False,
        trunk_weight_norm: bool = False,
        use_bias: bool = True,
    ):
        super().__init__()
        self.trunk_input_keys = trunk_input_keys
        self.branch_input_keys = branch_input_keys
        self.BCtype_input_keys = BCtype_input_keys
        self.BC_input_keys = BC_input_keys
        # The combined key order is: trunk, heat source, boundary, boundary type.
        all_keys = trunk_input_keys + branch_input_keys
        all_keys = all_keys + BC_input_keys + BCtype_input_keys
        self.input_keys = all_keys
        self.output_keys = output_keys

        def _make_net(keys, out_name, depth, width, act, skip, wnorm, in_dim):
            # Every sub-network projects onto the same feature size so that
            # the four encodings can be multiplied element-wise.
            return mlp.MLP(
                keys,
                (out_name,),
                depth,
                width,
                act,
                skip,
                wnorm,
                input_dim=in_dim,
                output_dim=num_features,
            )

        # NOTE: the construction order below fixes the parameter creation
        # order (and therefore seeded initialization); keep it stable.
        self.branch_net = _make_net(
            branch_input_keys,
            "b",
            branch_num_layers,
            branch_hidden_size,
            branch_activation,
            branch_skip_connection,
            branch_weight_norm,
            num_loc,
        )
        # The boundary-type branch reuses the BC_* hyper-parameters.
        self.BCtype_net = _make_net(
            BCtype_input_keys,
            "bctype",
            BC_num_layers,
            BC_hidden_size,
            BC_activation,
            BC_skip_connection,
            BC_weight_norm,
            bctype_loc,
        )
        self.BC_net = _make_net(
            BC_input_keys,
            "bc",
            BC_num_layers,
            BC_hidden_size,
            BC_activation,
            BC_skip_connection,
            BC_weight_norm,
            BC_num_loc,
        )
        self.trunk_net = _make_net(
            trunk_input_keys,
            "t",
            trunk_num_layers,
            trunk_hidden_size,
            trunk_activation,
            trunk_skip_connection,
            trunk_weight_norm,
            len(trunk_input_keys),
        )

        self.trunk_act = act_mod.get_activation(trunk_activation)
        # bc_act / branch_act are created (and registered) but not used in
        # forward; kept so the set of sublayers stays unchanged.
        self.bc_act = act_mod.get_activation(BC_activation)
        self.branch_act = act_mod.get_activation(branch_activation)

        self.use_bias = use_bias
        if use_bias:
            # Registered as a parameter so the optimizer updates it and it is
            # saved together with the model.
            self.b = self.create_parameter(
                shape=(1,),
                attr=nn.initializer.Constant(0.0),
            )

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        def _encode(net):
            # Each sub-network returns a dict; take its single output entry.
            return net(x)[net.output_keys[0]]

        # Branch nets encode the input functions.
        u_features = _encode(self.branch_net)
        bc_features = _encode(self.BC_net)
        bctype_features = _encode(self.BCtype_net)
        # Trunk net encodes the output-function domain, followed by an extra activation.
        y_features = self.trunk_act(_encode(self.trunk_net))

        # Generalized dot product over the feature axis.
        G_u = paddle.sum(
            u_features * y_features * bc_features * bctype_features,
            axis=1,
            keepdim=True,
        )
        if self.use_bias:
            G_u += self.b

        outputs = {self.output_keys[0]: G_u}
        if self._output_transform is not None:
            outputs = self._output_transform(x, outputs)
        return outputs

AutoEncoder

Bases: Arch

AutoEncoder is a class that represents an autoencoder neural network model.

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

A tuple of input keys.

required
output_keys Tuple[str, ...]

A tuple of output keys.

required
input_dim int

The dimension of the input data.

required
latent_dim int

The dimension of the latent space.

required
hidden_dim int

The dimension of the hidden layer.

required

Examples:

>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.AutoEncoder(
...    input_keys=("input1",),
...    output_keys=("mu", "log_sigma", "decoder_z",),
...    input_dim=100,
...    latent_dim=50,
...    hidden_dim=200
... )
>>> input_dict = {"input1": paddle.rand([200, 100]),}
>>> output_dict = model(input_dict)
>>> print(output_dict["mu"].shape)
[200, 50]
>>> print(output_dict["log_sigma"].shape)
[200, 50]
>>> print(output_dict["decoder_z"].shape)
[200, 100]
Source code in ppsci/arch/vae.py
class AutoEncoder(base.Arch):
    """
    AutoEncoder is a class that represents an autoencoder neural network model.

    The encoder maps the input to a mean ``mu`` and a log standard deviation
    ``log_sigma`` of a latent code. A latent sample
    ``z = mu + eps * exp(log_sigma)`` with ``eps ~ N(0, 1)`` is then decoded
    back to the input dimension.

    Input/output transforms registered on the model (see ``base.Arch``) are
    applied in ``forward``, in the same way as other architectures.

    Args:
        input_keys (Tuple[str, ...]): A tuple of input keys; their tensors are
            concatenated along the last axis.
        output_keys (Tuple[str, ...]): A tuple of output keys. The first three
            receive, in order, ``mu``, ``log_sigma`` and the decoded output.
        input_dim (int): The dimension of the input data.
        latent_dim (int): The dimension of the latent space.
        hidden_dim (int): The dimension of the hidden layer.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.AutoEncoder(
        ...    input_keys=("input1",),
        ...    output_keys=("mu", "log_sigma", "decoder_z",),
        ...    input_dim=100,
        ...    latent_dim=50,
        ...    hidden_dim=200
        ... )
        >>> input_dict = {"input1": paddle.rand([200, 100]),}
        >>> output_dict = model(input_dict)
        >>> print(output_dict["mu"].shape)
        [200, 50]
        >>> print(output_dict["log_sigma"].shape)
        [200, 50]
        >>> print(output_dict["decoder_z"].shape)
        [200, 100]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        input_dim: int,
        latent_dim: int,
        hidden_dim: int,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        # encoder: shared hidden layer followed by two latent heads
        self._encoder_linear = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
        )
        self._encoder_mu = nn.Linear(hidden_dim, latent_dim)
        self._encoder_log_sigma = nn.Linear(hidden_dim, latent_dim)

        # decoder: latent code back to the input dimension
        self._decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, input_dim),
        )

    def encoder(self, x):
        """Return ``(mu, log_sigma)`` of the latent distribution for tensor ``x``."""
        h = self._encoder_linear(x)
        mu = self._encoder_mu(h)
        log_sigma = self._encoder_log_sigma(h)
        return mu, log_sigma

    def decoder(self, x):
        """Decode latent tensor ``x`` back to the input dimension."""
        return self._decoder(x)

    def forward_tensor(self, x):
        """Encode ``x``, draw a reparameterized latent sample and decode it.

        Returns:
            Tuple of ``(mu, log_sigma, decoded)`` tensors.
        """
        mu, log_sigma = self.encoder(x)
        # Reparameterization trick: exp(log_sigma) is the standard deviation.
        z = mu + paddle.randn(mu.shape) * paddle.exp(log_sigma)
        return mu, log_sigma, self.decoder(z)

    def forward(self, x):
        # Honor transforms registered on the base Arch; previously they were
        # silently ignored by this model.
        if self._input_transform is not None:
            x = self._input_transform(x)

        # Keep the (possibly transformed) input dict intact for the output
        # transform; work on the concatenated tensor separately.
        y = self.concat_to_tensor(x, self.input_keys, axis=-1)
        mu, log_sigma, decoder_z = self.forward_tensor(y)
        result_dict = {
            self.output_keys[0]: mu,
            self.output_keys[1]: log_sigma,
            self.output_keys[2]: decoder_z,
        }

        if self._output_transform is not None:
            result_dict = self._output_transform(x, result_dict)

        return result_dict

CuboidTransformer

Bases: Arch

Cuboid Transformer for spatiotemporal forecasting

We adopt the Non-autoregressive encoder-decoder architecture. The decoder takes the multi-scale memory output from the encoder.

The initial downsampling and final upsampling layers are: Downsampling: [K x Conv2D --> PatchMerge]; Upsampling: [Nearest Interpolation-based Upsample --> K x Conv2D].

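Data flow: x --> downsample (optional) --> (+pos_embed) --> encoder --> multi-scale memory mem_l; initial_z (+pos_embed) --> FC. The decoder consumes the encoder memory mem_l together with the projected initial_z, and its output passes through upsample (optional) to produce y.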

Parameters:

Name Type Description Default
input_keys Tuple[str, ...]

Name of input keys, such as ("input",).

required
output_keys Tuple[str, ...]

Name of output keys, such as ("output",).

required
input_shape Tuple[int, ...]

The shape of the input data.

required
target_shape Tuple[int, ...]

The shape of the target data.

required
base_units int

The base units. Defaults to 128.

128
block_units int

The block units. Defaults to None.

None
scale_alpha float

We scale up the channels based on the formula: - round_to(base_units * max(downsample_scale) ** units_alpha, 4). Defaults to 1.0.

1.0
num_heads int

The number of heads. Defaults to 4.

4
attn_drop float

The attention dropout. Defaults to 0.0.

0.0
proj_drop float

The projection dropout. Defaults to 0.0.

0.0
ffn_drop float

The ffn dropout. Defaults to 0.0.

0.0
downsample int

The rate of downsample. Defaults to 2.

2
downsample_type str

The type of downsample. Defaults to "patch_merge".

'patch_merge'
upsample_type str

The type of upsample. Defaults to "upsample".

'upsample'
upsample_kernel_size int

The kernel size of upsample. Defaults to 3.

3
enc_depth list

The depth of encoder. Defaults to [4, 4, 4].

[4, 4, 4]
enc_attn_patterns str

The pattern of encoder attention. Defaults to None.

None
enc_cuboid_size list

The cuboid size of encoder. Defaults to [(4, 4, 4), (4, 4, 4)].

[(4, 4, 4), (4, 4, 4)]
enc_cuboid_strategy list

The cuboid strategy of encoder. Defaults to [("l", "l", "l"), ("d", "d", "d")].

[('l', 'l', 'l'), ('d', 'd', 'd')]
enc_shift_size list

The shift size of encoder. Defaults to [(0, 0, 0), (0, 0, 0)].

[(0, 0, 0), (0, 0, 0)]
enc_use_inter_ffn bool

Whether to use intermediate FFN for encoder. Defaults to True.

True
dec_depth list

The depth of decoder. Defaults to [2, 2].

[2, 2]
dec_cross_start int

The cross start of decoder. Defaults to 0.

0
dec_self_attn_patterns str

The self-attention patterns of the decoder. Defaults to None.

None
dec_self_cuboid_size list

The cuboid size of decoder. Defaults to [(4, 4, 4), (4, 4, 4)].

[(4, 4, 4), (4, 4, 4)]
dec_self_cuboid_strategy list

The strategy of decoder. Defaults to [("l", "l", "l"), ("d", "d", "d")].

[('l', 'l', 'l'), ('d', 'd', 'd')]
dec_self_shift_size list

The shift size of decoder. Defaults to [(1, 1, 1), (0, 0, 0)].

[(1, 1, 1), (0, 0, 0)]
dec_cross_attn_patterns _type_

The cross attention patterns of decoder. Defaults to None.

None
dec_cross_cuboid_hw list

The cuboid_hw of decoder. Defaults to [(4, 4), (4, 4)].

[(4, 4), (4, 4)]
dec_cross_cuboid_strategy list

The cuboid strategy of decoder. Defaults to [("l", "l", "l"), ("d", "l", "l")].

[('l', 'l', 'l'), ('d', 'l', 'l')]
dec_cross_shift_hw list

The shift_hw of decoder. Defaults to [(0, 0), (0, 0)].

[(0, 0), (0, 0)]
dec_cross_n_temporal list

The cross_n_temporal of decoder. Defaults to [1, 2].

[1, 2]
dec_cross_last_n_frames int

The cross_last_n_frames of decoder. Defaults to None.

None
dec_use_inter_ffn bool

Whether to use intermediate FFN for decoder. Defaults to True.

True
dec_hierarchical_pos_embed bool

Whether to use hierarchical pos_embed for decoder. Defaults to False.

False
num_global_vectors int

The num of global vectors. Defaults to 4.

4
use_dec_self_global bool

Whether to use global vector for decoder. Defaults to True.

True
dec_self_update_global bool

Whether to update global vector for decoder. Defaults to True.

True
use_dec_cross_global bool

Whether to use cross global vector for decoder. Defaults to True.

True
use_global_vector_ffn bool

Whether to use global vector FFN. Defaults to True.

True
use_global_self_attn bool

Whether to use global attentions. Defaults to False.

False
separate_global_qkv bool

Whether to separate global qkv. Defaults to False.

False
global_dim_ratio int

The ratio of global dim. Defaults to 1.

1
self_pattern str

The pattern. Defaults to "axial".

'axial'
cross_self_pattern str

The self cross pattern. Defaults to "axial".

'axial'
cross_pattern str

The cross pattern. Defaults to "cross_1x1".

'cross_1x1'
z_init_method str

How the initial input to the decoder is initialized. Defaults to "nearest_interp".

'nearest_interp'
initial_downsample_type str

The downsample type of initial. Defaults to "conv".

'conv'
initial_downsample_activation str

The downsample activation of initial. Defaults to "leaky".

'leaky'
initial_downsample_scale int

The downsample scale of initial. Defaults to 1.

1
initial_downsample_conv_layers int

The conv layer of downsample of initial. Defaults to 2.

2
final_upsample_conv_layers int

The conv layer of final upsample. Defaults to 2.

2
initial_downsample_stack_conv_num_layers int

The num of stack conv layer of initial downsample. Defaults to 1.

1
initial_downsample_stack_conv_dim_list list

The dim list of stack conv of initial downsample. Defaults to None.

None
initial_downsample_stack_conv_downscale_list list

The downscale list of stack conv of initial downsample. Defaults to [1].

[1]
initial_downsample_stack_conv_num_conv_list list

The num of stack conv list of initial downsample. Defaults to [2].

[2]
ffn_activation str

The activation of FFN. Defaults to "leaky".

'leaky'
gated_ffn bool

Whether to use gate FFN. Defaults to False.

False
norm_layer str

The type of normalization layer. Defaults to "layer_norm".

'layer_norm'
padding_type str

The type of padding. Defaults to "ignore".

'ignore'
pos_embed_type str

The type of positional embedding. Defaults to "t+hw".

't+hw'
checkpoint_level bool

Whether to use checkpoint. Defaults to True.

True
use_relative_pos bool

Whether to use relative positions. Defaults to True.

True
self_attn_use_final_proj bool

Whether to use final projection. Defaults to True.

True
dec_use_first_self_attn bool

Whether to use first self attention for decoder. Defaults to False.

False
attn_linear_init_mode str

The mode of attention linear init. Defaults to "0".

'0'
ffn_linear_init_mode str

The mode of FFN linear init. Defaults to "0".

'0'
conv_init_mode str

The mode of conv init. Defaults to "0".

'0'
down_up_linear_init_mode str

The mode of downsample and upsample linear init. Defaults to "0".

'0'
norm_init_mode str

The mode of normalization init. Defaults to "0".

'0'
Source code in ppsci/arch/cuboid_transformer.py
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
class CuboidTransformer(base.Arch):
    """Cuboid Transformer for spatiotemporal forecasting

    We adopt the Non-autoregressive encoder-decoder architecture.
    The decoder takes the multi-scale memory output from the encoder.

    The initial downsampling / upsampling layers will be
    Downsampling: [K x Conv2D --> PatchMerge]
    Upsampling: [Nearest Interpolation-based Upsample --> K x Conv2D]

    x --> downsample (optional) ---> (+pos_embed) ---> enc --> mem_l         initial_z (+pos_embed) ---> FC
                                                     |            |
                                                     |------------|
                                                           |
                                                           |
             y <--- upsample (optional) <--- dec <----------

    Args:
        input_keys (Tuple[str, ...]): Name of input keys, such as ("input",).
        output_keys (Tuple[str, ...]): Name of output keys, such as ("output",).
        input_shape (Tuple[int, ...]): The shape of the input data, unpacked as (T_in, H_in, W_in, C_in).
        target_shape (Tuple[int, ...]): The shape of the target data, unpacked as (T_out, H_out, W_out, C_out).
            H_out and W_out must equal H_in and W_in.
        base_units (int, optional): The base units. Defaults to 128.
        block_units (int, optional): The block units. Defaults to None.
        scale_alpha (float, optional): We scale up the channels based on the formula:
            - round_to(base_units * max(downsample_scale) ** scale_alpha, 4). Defaults to 1.0.
        num_heads (int, optional): The number of heads. Defaults to 4.
        attn_drop (float, optional): The attention dropout. Defaults to 0.0.
        proj_drop (float, optional): The projection dropout. Defaults to 0.0.
        ffn_drop (float, optional): The ffn dropout. Defaults to 0.0.
        downsample (int, optional): The rate of downsample. Defaults to 2.
        downsample_type (str, optional): The type of downsample. Defaults to "patch_merge".
        upsample_type (str, optional): The type of upsample. Defaults to "upsample".
        upsample_kernel_size (int, optional): The kernel size of upsample. Defaults to 3.
        enc_depth (list, optional): The depth of encoder. Defaults to [4, 4, 4].
        enc_attn_patterns (str, optional): The pattern of encoder attention. Overridden by
            `[self_pattern] * len(enc_depth)` whenever `self_pattern` is a str. Defaults to None.
        enc_cuboid_size (list, optional): The cuboid size of encoder. Defaults to [(4, 4, 4), (4, 4, 4)].
        enc_cuboid_strategy (list, optional): The cuboid strategy of encoder. Defaults to [("l", "l", "l"), ("d", "d", "d")].
        enc_shift_size (list, optional): The shift size of encoder. Defaults to [(0, 0, 0), (0, 0, 0)].
        enc_use_inter_ffn (bool, optional): Whether to use intermediate FFN for encoder. Defaults to True.
        dec_depth (list, optional): The depth of decoder. Must have the same length as `enc_depth`. Defaults to [2, 2].
        dec_cross_start (int, optional): The cross start of decoder. Defaults to 0.
        dec_self_attn_patterns (str, optional): The self-attention patterns of decoder. Overridden by
            `[cross_self_pattern] * len(enc_depth)` whenever `cross_self_pattern` is a str. Defaults to None.
        dec_self_cuboid_size (list, optional): The cuboid size of decoder. Defaults to [(4, 4, 4), (4, 4, 4)].
        dec_self_cuboid_strategy (list, optional): The strategy of decoder. Defaults to [("l", "l", "l"), ("d", "d", "d")].
        dec_self_shift_size (list, optional): The shift size of decoder. Defaults to [(1, 1, 1), (0, 0, 0)].
        dec_cross_attn_patterns (str, optional): The cross attention patterns of decoder. Overridden by
            `[cross_pattern] * len(enc_depth)` whenever `cross_pattern` is a str. Defaults to None.
        dec_cross_cuboid_hw (list, optional): The cuboid_hw of decoder. Defaults to [(4, 4), (4, 4)].
        dec_cross_cuboid_strategy (list, optional): The cuboid strategy of decoder. Defaults to [("l", "l", "l"), ("d", "l", "l")].
        dec_cross_shift_hw (list, optional): The shift_hw of decoder. Defaults to [(0, 0), (0, 0)].
        dec_cross_n_temporal (list, optional): The cross_n_temporal of decoder. Defaults to [1, 2].
        dec_cross_last_n_frames (int, optional): The cross_last_n_frames of decoder. Defaults to None.
        dec_use_inter_ffn (bool, optional): Whether to use intermediate FFN for decoder. Defaults to True.
        dec_hierarchical_pos_embed (bool, optional): Whether to use hierarchical pos_embed for decoder. Defaults to False.
        num_global_vectors (int, optional): The num of global vectors. Defaults to 4.
        use_dec_self_global (bool, optional): Whether to use global vector for decoder. Defaults to True.
        dec_self_update_global (bool, optional): Whether to update global vector for decoder. Defaults to True.
        use_dec_cross_global (bool, optional): Whether to use cross global vector for decoder. Defaults to True.
        use_global_vector_ffn (bool, optional): Whether to use global vector FFN. Defaults to True.
        use_global_self_attn (bool, optional): Whether to use global attentions. Defaults to False.
        separate_global_qkv (bool, optional): Whether to separate global qkv. Defaults to False.
        global_dim_ratio (int, optional): The ratio of global dim. Values other than 1 require
            `separate_global_qkv=True`. Defaults to 1.
        self_pattern (str, optional): The pattern. Defaults to "axial".
        cross_self_pattern (str, optional): The self cross pattern. Defaults to "axial".
        cross_pattern (str, optional): The cross pattern. Defaults to "cross_1x1".
        z_init_method (str, optional): How the initial input to the decoder is initialized. One of
            "zeros", "nearest_interp", "last", "mean". Defaults to "nearest_interp".
        initial_downsample_type (str, optional): The downsample type of initial, "conv" or "stack_conv". Defaults to "conv".
        initial_downsample_activation (str, optional): The downsample activation of initial. Defaults to "leaky".
        initial_downsample_scale (int, optional): The downsample scale of initial. Defaults to 1.
        initial_downsample_conv_layers (int, optional): The conv layer of downsample of initial. Defaults to 2.
        final_upsample_conv_layers (int, optional): The conv layer of final upsample. Defaults to 2.
        initial_downsample_stack_conv_num_layers (int, optional): The num of stack conv layer of initial downsample. Defaults to 1.
        initial_downsample_stack_conv_dim_list (list, optional): The dim list of stack conv of initial downsample. Defaults to None.
        initial_downsample_stack_conv_downscale_list (list, optional): The downscale list of stack conv of initial downsample. Defaults to [1].
        initial_downsample_stack_conv_num_conv_list (list, optional): The num of stack conv list of initial downsample. Defaults to [2].
        ffn_activation (str, optional): The activation of FFN. Defaults to "leaky".
        gated_ffn (bool, optional): Whether to use gate FFN. Defaults to False.
        norm_layer (str, optional): The type of normalization. Defaults to "layer_norm".
        padding_type (str, optional): The type of padding. Defaults to "ignore".
        pos_embed_type (str, optional): The type of positional embedding. Defaults to "t+hw".
        checkpoint_level (bool, optional): Whether to use checkpoint. Defaults to True.
        use_relative_pos (bool, optional): Whether to use relative position. Defaults to True.
        self_attn_use_final_proj (bool, optional): Whether to use final projection. Defaults to True.
        dec_use_first_self_attn (bool, optional): Whether to use first self attention for decoder. Defaults to False.
        attn_linear_init_mode (str, optional): The mode of attention linear init. Defaults to "0".
        ffn_linear_init_mode (str, optional): The mode of FFN linear init. Defaults to "0".
        conv_init_mode (str, optional): The mode of conv init. Defaults to "0".
        down_up_linear_init_mode (str, optional): The mode of downsample and upsample linear init. Defaults to "0".
        norm_init_mode (str, optional): The mode of normalization init. Defaults to "0".
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        input_shape: Tuple[int, ...],
        target_shape: Tuple[int, ...],
        base_units: int = 128,
        block_units: int = None,
        scale_alpha: float = 1.0,
        num_heads: int = 4,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        ffn_drop: float = 0.0,
        downsample: int = 2,
        downsample_type: str = "patch_merge",
        upsample_type: str = "upsample",
        upsample_kernel_size: int = 3,
        enc_depth: Tuple[int, ...] = [4, 4, 4],
        enc_attn_patterns: str = None,
        enc_cuboid_size: Tuple[Tuple[int, ...], ...] = [(4, 4, 4), (4, 4, 4)],
        enc_cuboid_strategy: Tuple[Tuple[str, ...], ...] = [
            ("l", "l", "l"),
            ("d", "d", "d"),
        ],
        enc_shift_size: Tuple[Tuple[int, ...], ...] = [(0, 0, 0), (0, 0, 0)],
        enc_use_inter_ffn: bool = True,
        dec_depth: Tuple[int, ...] = [2, 2],
        dec_cross_start: int = 0,
        dec_self_attn_patterns: str = None,
        dec_self_cuboid_size: Tuple[Tuple[int, ...], ...] = [(4, 4, 4), (4, 4, 4)],
        dec_self_cuboid_strategy: Tuple[Tuple[str, ...], ...] = [
            ("l", "l", "l"),
            ("d", "d", "d"),
        ],
        dec_self_shift_size: Tuple[Tuple[int, ...], ...] = [(1, 1, 1), (0, 0, 0)],
        dec_cross_attn_patterns: str = None,
        dec_cross_cuboid_hw: Tuple[Tuple[int, ...], ...] = [(4, 4), (4, 4)],
        dec_cross_cuboid_strategy: Tuple[Tuple[str, ...], ...] = [
            ("l", "l", "l"),
            ("d", "l", "l"),
        ],
        dec_cross_shift_hw: Tuple[Tuple[int, ...], ...] = [(0, 0), (0, 0)],
        dec_cross_n_temporal: Tuple[int, ...] = [1, 2],
        dec_cross_last_n_frames: int = None,
        dec_use_inter_ffn: bool = True,
        dec_hierarchical_pos_embed: bool = False,
        num_global_vectors: int = 4,
        use_dec_self_global: bool = True,
        dec_self_update_global: bool = True,
        use_dec_cross_global: bool = True,
        use_global_vector_ffn: bool = True,
        use_global_self_attn: bool = False,
        separate_global_qkv: bool = False,
        global_dim_ratio: int = 1,
        self_pattern: str = "axial",
        cross_self_pattern: str = "axial",
        cross_pattern: str = "cross_1x1",
        z_init_method: str = "nearest_interp",
        initial_downsample_type: str = "conv",
        initial_downsample_activation: str = "leaky",
        initial_downsample_scale: int = 1,
        initial_downsample_conv_layers: int = 2,
        final_upsample_conv_layers: int = 2,
        initial_downsample_stack_conv_num_layers: int = 1,
        initial_downsample_stack_conv_dim_list: Tuple[int, ...] = None,
        initial_downsample_stack_conv_downscale_list: Tuple[int, ...] = [1],
        initial_downsample_stack_conv_num_conv_list: Tuple[int, ...] = [2],
        ffn_activation: str = "leaky",
        gated_ffn: bool = False,
        norm_layer: str = "layer_norm",
        padding_type: str = "ignore",
        pos_embed_type: str = "t+hw",
        checkpoint_level: bool = True,
        use_relative_pos: bool = True,
        self_attn_use_final_proj: bool = True,
        dec_use_first_self_attn: bool = False,
        attn_linear_init_mode: str = "0",
        ffn_linear_init_mode: str = "0",
        conv_init_mode: str = "0",
        down_up_linear_init_mode: str = "0",
        norm_init_mode: str = "0",
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.attn_linear_init_mode = attn_linear_init_mode
        self.ffn_linear_init_mode = ffn_linear_init_mode
        self.conv_init_mode = conv_init_mode
        self.down_up_linear_init_mode = down_up_linear_init_mode
        self.norm_init_mode = norm_init_mode
        assert len(enc_depth) == len(dec_depth)
        self.base_units = base_units
        self.num_global_vectors = num_global_vectors

        # NOTE: a str-valued *_pattern replaces the explicit per-block pattern
        # list, so with the defaults the enc_/dec_*_attn_patterns arguments are
        # ignored. This mirrors the original behavior and is kept as-is.
        num_blocks = len(enc_depth)
        if isinstance(self_pattern, str):
            enc_attn_patterns = [self_pattern] * num_blocks

        if isinstance(cross_self_pattern, str):
            dec_self_attn_patterns = [cross_self_pattern] * num_blocks

        if isinstance(cross_pattern, str):
            dec_cross_attn_patterns = [cross_pattern] * num_blocks

        if global_dim_ratio != 1:
            assert (
                separate_global_qkv is True
            ), "Setting global_dim_ratio != 1 requires separate_global_qkv == True."
        self.global_dim_ratio = global_dim_ratio
        self.z_init_method = z_init_method
        assert self.z_init_method in ["zeros", "nearest_interp", "last", "mean"]
        self.input_shape = input_shape
        self.target_shape = target_shape
        T_in, H_in, W_in, C_in = input_shape
        T_out, H_out, W_out, C_out = target_shape
        assert H_in == H_out and W_in == W_out
        if self.num_global_vectors > 0:
            # Learnable global vectors; the zeros tensor only provides the
            # shape/dtype, actual values are set in reset_parameters().
            init_data = paddle.zeros(
                (self.num_global_vectors, global_dim_ratio * base_units)
            )
            self.init_global_vectors = paddle.create_parameter(
                shape=init_data.shape,
                dtype=init_data.dtype,
                default_initializer=nn.initializer.Constant(0.0),
            )

            self.init_global_vectors.stop_gradient = False
        new_input_shape = self.get_initial_encoder_final_decoder(
            initial_downsample_scale=initial_downsample_scale,
            initial_downsample_type=initial_downsample_type,
            activation=initial_downsample_activation,
            initial_downsample_conv_layers=initial_downsample_conv_layers,
            final_upsample_conv_layers=final_upsample_conv_layers,
            padding_type=padding_type,
            initial_downsample_stack_conv_num_layers=initial_downsample_stack_conv_num_layers,
            initial_downsample_stack_conv_dim_list=initial_downsample_stack_conv_dim_list,
            initial_downsample_stack_conv_downscale_list=initial_downsample_stack_conv_downscale_list,
            initial_downsample_stack_conv_num_conv_list=initial_downsample_stack_conv_num_conv_list,
        )
        # The encoder operates on the (possibly downsampled) shape produced by
        # the initial encoder.
        T_in, H_in, W_in, _ = new_input_shape
        self.encoder = cuboid_encoder.CuboidTransformerEncoder(
            input_shape=(T_in, H_in, W_in, base_units),
            base_units=base_units,
            block_units=block_units,
            scale_alpha=scale_alpha,
            depth=enc_depth,
            downsample=downsample,
            downsample_type=downsample_type,
            block_attn_patterns=enc_attn_patterns,
            block_cuboid_size=enc_cuboid_size,
            block_strategy=enc_cuboid_strategy,
            block_shift_size=enc_shift_size,
            num_heads=num_heads,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            ffn_drop=ffn_drop,
            gated_ffn=gated_ffn,
            ffn_activation=ffn_activation,
            norm_layer=norm_layer,
            use_inter_ffn=enc_use_inter_ffn,
            padding_type=padding_type,
            use_global_vector=num_global_vectors > 0,
            use_global_vector_ffn=use_global_vector_ffn,
            use_global_self_attn=use_global_self_attn,
            separate_global_qkv=separate_global_qkv,
            global_dim_ratio=global_dim_ratio,
            checkpoint_level=checkpoint_level,
            use_relative_pos=use_relative_pos,
            self_attn_use_final_proj=self_attn_use_final_proj,
            attn_linear_init_mode=attn_linear_init_mode,
            ffn_linear_init_mode=ffn_linear_init_mode,
            conv_init_mode=conv_init_mode,
            down_linear_init_mode=down_up_linear_init_mode,
            norm_init_mode=norm_init_mode,
        )
        self.enc_pos_embed = cuboid_decoder.PosEmbed(
            embed_dim=base_units, typ=pos_embed_type, maxH=H_in, maxW=W_in, maxT=T_in
        )
        mem_shapes = self.encoder.get_mem_shapes()
        self.z_proj = paddle.nn.Linear(
            in_features=mem_shapes[-1][-1], out_features=mem_shapes[-1][-1]
        )
        self.dec_pos_embed = cuboid_decoder.PosEmbed(
            embed_dim=mem_shapes[-1][-1],
            typ=pos_embed_type,
            maxT=T_out,
            maxH=mem_shapes[-1][1],
            maxW=mem_shapes[-1][2],
        )
        self.decoder = cuboid_decoder.CuboidTransformerDecoder(
            target_temporal_length=T_out,
            mem_shapes=mem_shapes,
            cross_start=dec_cross_start,
            depth=dec_depth,
            upsample_type=upsample_type,
            block_self_attn_patterns=dec_self_attn_patterns,
            block_self_cuboid_size=dec_self_cuboid_size,
            block_self_shift_size=dec_self_shift_size,
            block_self_cuboid_strategy=dec_self_cuboid_strategy,
            block_cross_attn_patterns=dec_cross_attn_patterns,
            block_cross_cuboid_hw=dec_cross_cuboid_hw,
            block_cross_shift_hw=dec_cross_shift_hw,
            block_cross_cuboid_strategy=dec_cross_cuboid_strategy,
            block_cross_n_temporal=dec_cross_n_temporal,
            cross_last_n_frames=dec_cross_last_n_frames,
            num_heads=num_heads,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            ffn_drop=ffn_drop,
            upsample_kernel_size=upsample_kernel_size,
            ffn_activation=ffn_activation,
            gated_ffn=gated_ffn,
            norm_layer=norm_layer,
            use_inter_ffn=dec_use_inter_ffn,
            max_temporal_relative=T_in + T_out,
            padding_type=padding_type,
            hierarchical_pos_embed=dec_hierarchical_pos_embed,
            pos_embed_type=pos_embed_type,
            use_self_global=num_global_vectors > 0 and use_dec_self_global,
            self_update_global=dec_self_update_global,
            use_cross_global=num_global_vectors > 0 and use_dec_cross_global,
            use_global_vector_ffn=use_global_vector_ffn,
            use_global_self_attn=use_global_self_attn,
            separate_global_qkv=separate_global_qkv,
            global_dim_ratio=global_dim_ratio,
            checkpoint_level=checkpoint_level,
            use_relative_pos=use_relative_pos,
            self_attn_use_final_proj=self_attn_use_final_proj,
            use_first_self_attn=dec_use_first_self_attn,
            attn_linear_init_mode=attn_linear_init_mode,
            ffn_linear_init_mode=ffn_linear_init_mode,
            conv_init_mode=conv_init_mode,
            up_linear_init_mode=down_up_linear_init_mode,
            norm_init_mode=norm_init_mode,
        )
        self.reset_parameters()

    def get_initial_encoder_final_decoder(
        self,
        initial_downsample_type,
        activation,
        initial_downsample_scale,
        initial_downsample_conv_layers,
        final_upsample_conv_layers,
        padding_type,
        initial_downsample_stack_conv_num_layers,
        initial_downsample_stack_conv_dim_list,
        initial_downsample_stack_conv_downscale_list,
        initial_downsample_stack_conv_num_conv_list,
    ):
        """Build `initial_encoder`, `final_decoder` and `dec_final_proj`.

        Sets `self.initial_encoder`, `self.final_decoder`, `self.dec_final_proj`,
        `self.initial_downsample_type` and `self.input_shape_after_initial_downsample`.

        Args:
            initial_downsample_type (str): "conv" or "stack_conv".
            activation (str): Activation used by the initial encoder / final decoder.
            initial_downsample_scale (Union[int, Sequence[int]]): For "conv": an int
                (applied to H and W), a 2-tuple (H, W) or a 3-tuple (T, H, W).
            initial_downsample_conv_layers (int): Number of conv layers for "conv".
            final_upsample_conv_layers (int): Number of conv layers in the "conv" final decoder.
            padding_type (str): Padding type passed to the initial encoder.
            initial_downsample_stack_conv_num_layers (int): Number of merge stages for "stack_conv".
            initial_downsample_stack_conv_dim_list (Optional[Sequence[int]]): Output dims per
                stage for "stack_conv"; defaults to `base_units` for every stage.
            initial_downsample_stack_conv_downscale_list (Sequence[int]): Downscale per stage.
            initial_downsample_stack_conv_num_conv_list (Sequence[int]): Conv count per stage;
                reversed for the decoder.

        Returns:
            The input shape after the initial encoder, unpacked by the caller as
            (T, H, W, C).

        Raises:
            NotImplementedError: For an unsupported downsample type or scale format.
        """
        _, _, _, C_in = self.input_shape
        T_out, H_out, W_out, C_out = self.target_shape
        self.initial_downsample_type = initial_downsample_type
        if self.initial_downsample_type == "conv":
            # Normalize the scale to a (T, H, W) triple; time is never
            # downsampled unless an explicit 3-tuple is given.
            if isinstance(initial_downsample_scale, int):
                initial_downsample_scale = (
                    1,
                    initial_downsample_scale,
                    initial_downsample_scale,
                )
            elif len(initial_downsample_scale) == 2:
                initial_downsample_scale = 1, *initial_downsample_scale
            elif len(initial_downsample_scale) == 3:
                initial_downsample_scale = tuple(initial_downsample_scale)
            else:
                raise NotImplementedError(
                    f"initial_downsample_scale {initial_downsample_scale} format not supported!"
                )
            self.initial_encoder = InitialEncoder(
                dim=C_in,
                out_dim=self.base_units,
                downsample_scale=initial_downsample_scale,
                num_conv_layers=initial_downsample_conv_layers,
                padding_type=padding_type,
                activation=activation,
                conv_init_mode=self.conv_init_mode,
                linear_init_mode=self.down_up_linear_init_mode,
                norm_init_mode=self.norm_init_mode,
            )

            self.final_decoder = FinalDecoder(
                dim=self.base_units,
                target_thw=(T_out, H_out, W_out),
                num_conv_layers=final_upsample_conv_layers,
                activation=activation,
                conv_init_mode=self.conv_init_mode,
                linear_init_mode=self.down_up_linear_init_mode,
                norm_init_mode=self.norm_init_mode,
            )
            new_input_shape = self.initial_encoder.patch_merge.get_out_shape(
                self.input_shape
            )
            self.dec_final_proj = paddle.nn.Linear(
                in_features=self.base_units, out_features=C_out
            )
        elif self.initial_downsample_type == "stack_conv":
            if initial_downsample_stack_conv_dim_list is None:
                initial_downsample_stack_conv_dim_list = [
                    self.base_units
                ] * initial_downsample_stack_conv_num_layers
            self.initial_encoder = InitialStackPatchMergingEncoder(
                num_merge=initial_downsample_stack_conv_num_layers,
                in_dim=C_in,
                out_dim_list=initial_downsample_stack_conv_dim_list,
                downsample_scale_list=initial_downsample_stack_conv_downscale_list,
                num_conv_per_merge_list=initial_downsample_stack_conv_num_conv_list,
                padding_type=padding_type,
                activation=activation,
                conv_init_mode=self.conv_init_mode,
                linear_init_mode=self.down_up_linear_init_mode,
                norm_init_mode=self.norm_init_mode,
            )
            # The decoder mirrors the encoder stages, computed on the target shape.
            initial_encoder_out_shape_list = self.initial_encoder.get_out_shape_list(
                self.target_shape
            )
            (
                dec_target_shape_list,
                dec_in_dim,
            ) = FinalStackUpsamplingDecoder.get_init_params(
                enc_input_shape=self.target_shape,
                enc_out_shape_list=initial_encoder_out_shape_list,
                large_channel=True,
            )
            self.final_decoder = FinalStackUpsamplingDecoder(
                target_shape_list=dec_target_shape_list,
                in_dim=dec_in_dim,
                num_conv_per_up_list=initial_downsample_stack_conv_num_conv_list[::-1],
                activation=activation,
                conv_init_mode=self.conv_init_mode,
                linear_init_mode=self.down_up_linear_init_mode,
                norm_init_mode=self.norm_init_mode,
            )
            self.dec_final_proj = paddle.nn.Linear(
                in_features=dec_target_shape_list[-1][-1], out_features=C_out
            )
            new_input_shape = self.initial_encoder.get_out_shape_list(self.input_shape)[
                -1
            ]
        else:
            raise NotImplementedError(f"{self.initial_downsample_type} is invalid.")
        self.input_shape_after_initial_downsample = new_input_shape
        return new_input_shape

    def reset_parameters(self):
        """(Re)initialize all sub-module parameters and the global vectors."""
        if self.num_global_vectors > 0:
            self.init_global_vectors = initializer.trunc_normal_(
                self.init_global_vectors, std=0.02
            )
        # Prefer a sub-module's own initializer when it provides one.
        if hasattr(self.initial_encoder, "reset_parameters"):
            self.initial_encoder.reset_parameters()
        else:
            cuboid_utils.apply_initialization(
                self.initial_encoder,
                conv_mode=self.conv_init_mode,
                linear_mode=self.down_up_linear_init_mode,
                norm_mode=self.norm_init_mode,
            )
        if hasattr(self.final_decoder, "reset_parameters"):
            self.final_decoder.reset_parameters()
        else:
            cuboid_utils.apply_initialization(
                self.final_decoder,
                conv_mode=self.conv_init_mode,
                linear_mode=self.down_up_linear_init_mode,
                norm_mode=self.norm_init_mode,
            )
        cuboid_utils.apply_initialization(
            self.dec_final_proj, linear_mode=self.down_up_linear_init_mode
        )
        self.encoder.reset_parameters()
        self.enc_pos_embed.reset_parameters()
        self.decoder.reset_parameters()
        self.dec_pos_embed.reset_parameters()
        cuboid_utils.apply_initialization(self.z_proj, linear_mode="0")

    def get_initial_z(self, final_mem, T_out):
        """Build the initial decoder input from the last encoder memory.

        Args:
            final_mem (paddle.Tensor): Last encoder memory, indexed as
                (B, T, H, W, C) by the transposes/slices below.
            T_out (int): Number of target time steps.

        Returns:
            paddle.Tensor: Initial decoder input with `T_out` time steps,
                projected by `z_proj`.

        Raises:
            NotImplementedError: If `z_init_method` is not supported.
        """
        B = final_mem.shape[0]
        # paddle's Tensor.shape is a list; build every target shape as a list
        # (tuple + list concatenation would raise TypeError).
        trailing_shape = list(final_mem.shape[2:])
        if self.z_init_method == "zeros":
            z_shape = [1, T_out] + trailing_shape
            initial_z = paddle.zeros(shape=z_shape, dtype=final_mem.dtype)
            initial_z = self.z_proj(self.dec_pos_embed(initial_z)).expand(
                shape=[B, -1, -1, -1, -1]
            )
        elif self.z_init_method == "nearest_interp":
            # Interpolate along time (and keep H, W) in channels-first layout.
            initial_z = paddle.nn.functional.interpolate(
                x=final_mem.transpose(perm=[0, 4, 1, 2, 3]),
                size=(T_out, final_mem.shape[2], final_mem.shape[3]),
            ).transpose(perm=[0, 2, 3, 4, 1])
            initial_z = self.z_proj(initial_z)
        elif self.z_init_method == "last":
            initial_z = paddle.broadcast_to(
                x=final_mem[:, -1:, :, :, :], shape=[B, T_out] + trailing_shape
            )
            initial_z = self.z_proj(initial_z)
        elif self.z_init_method == "mean":
            initial_z = paddle.broadcast_to(
                x=final_mem.mean(axis=1, keepdim=True),
                shape=[B, T_out] + trailing_shape,
            )
            initial_z = self.z_proj(initial_z)
        else:
            raise NotImplementedError(
                f"z_init_method {self.z_init_method} is not supported."
            )
        return initial_z

    def forward(self, x: "paddle.Tensor", verbose: bool = False) -> "paddle.Tensor":
        """
        Args:
            x (Dict[str, paddle.Tensor]): Input dict. The tensors under `input_keys`
                are concatenated into a tensor of shape (B, T, H, W, C), or a 6-D
                tensor (N, B, T, H, W, C) whose two leading dims are merged before
                the network and split again on the output.
            verbose (bool): if True, print intermediate shapes.

        Returns:
            Dict[str, paddle.Tensor]: Every key in `output_keys` maps to the same
                output tensor of shape (B, T_out, H, W, C_out), or
                (N, B, T_out, H, W, C_out) for 6-D input.
        """

        x = self.concat_to_tensor(x, self.input_keys)
        flag_ndim = x.ndim
        if flag_ndim == 6:
            # Remember the second leading dim so the merge can be undone on the
            # output instead of collapsing everything under a leading dim of 1.
            inner_batch = x.shape[1]
            x = x.reshape([-1, *x.shape[2:]])
        B, _, _, _, _ = x.shape

        T_out = self.target_shape[0]
        x = self.initial_encoder(x)
        x = self.enc_pos_embed(x)

        if self.num_global_vectors > 0:
            init_global_vectors = self.init_global_vectors.expand(
                shape=[
                    B,
                    self.num_global_vectors,
                    self.global_dim_ratio * self.base_units,
                ]
            )
            mem_l, mem_global_vector_l = self.encoder(x, init_global_vectors)
        else:
            mem_l = self.encoder(x)

        if verbose:
            for i, mem in enumerate(mem_l):
                print(f"mem[{i}].shape = {mem.shape}")
        initial_z = self.get_initial_z(final_mem=mem_l[-1], T_out=T_out)

        if self.num_global_vectors > 0:
            dec_out = self.decoder(initial_z, mem_l, mem_global_vector_l)
        else:
            dec_out = self.decoder(initial_z, mem_l)

        dec_out = self.final_decoder(dec_out)

        out = self.dec_final_proj(dec_out)
        if flag_ndim == 6:
            out = out.reshape([-1, inner_batch, *out.shape[1:]])
        return {key: out for key in self.output_keys}
forward(x, verbose=False)

Parameters:

Name Type Description Default
x Tensor

Tensor with shape (B, T, H, W, C).

required
verbose bool

if True, print intermediate shapes.

False

Returns:

Name Type Description
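out Dict[str, Tensor]

Dict mapping every key in output_keys to the same output tensor of shape (B, T_out, H, W, C_out), or (N, B, T_out, H, W, C_out) for 6-D input.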

Source code in ppsci/arch/cuboid_transformer.py
def forward(self, x: "paddle.Tensor", verbose: bool = False) -> "paddle.Tensor":
    """
    Args:
        x (Dict[str, paddle.Tensor]): Input dict. The tensors under `input_keys`
            are concatenated into a tensor of shape (B, T, H, W, C), or a 6-D
            tensor (N, B, T, H, W, C) whose two leading dims are merged before
            the network and split again on the output.
        verbose (bool): if True, print intermediate shapes.

    Returns:
        Dict[str, paddle.Tensor]: Every key in `output_keys` maps to the same
            output tensor of shape (B, T_out, H, W, C_out), or
            (N, B, T_out, H, W, C_out) for 6-D input.
    """

    x = self.concat_to_tensor(x, self.input_keys)
    flag_ndim = x.ndim
    if flag_ndim == 6:
        # Remember the second leading dim so the merge can be undone on the
        # output instead of collapsing everything under a leading dim of 1.
        inner_batch = x.shape[1]
        x = x.reshape([-1, *x.shape[2:]])
    B, _, _, _, _ = x.shape

    T_out = self.target_shape[0]
    x = self.initial_encoder(x)
    x = self.enc_pos_embed(x)

    if self.num_global_vectors > 0:
        init_global_vectors = self.init_global_vectors.expand(
            shape=[
                B,
                self.num_global_vectors,
                self.global_dim_ratio * self.base_units,
            ]
        )
        mem_l, mem_global_vector_l = self.encoder(x, init_global_vectors)
    else:
        mem_l = self.encoder(x)

    if verbose:
        for i, mem in enumerate(mem_l):
            print(f"mem[{i}].shape = {mem.shape}")
    initial_z = self.get_initial_z(final_mem=mem_l[-1], T_out=T_out)

    if self.num_global_vectors > 0:
        dec_out = self.decoder(initial_z, mem_l, mem_global_vector_l)
    else:
        dec_out = self.decoder(initial_z, mem_l)

    dec_out = self.final_decoder(dec_out)

    out = self.dec_final_proj(dec_out)
    if flag_ndim == 6:
        out = out.reshape([-1, inner_batch, *out.shape[1:]])
    return {key: out for key in self.output_keys}