
Utils.symbolic (Symbolic Computation) Module

ppsci.utils.symbolic

Sympy to Python function conversion module

lambdify(expr, models=None, extra_parameters=None, graph_filename=None, create_graph=True, retain_graph=None, fuse_derivative=False)

Convert sympy expression to callable function.

Parameters:

expr (Union[Basic, List[Basic]], required):
    Sympy expression(s) to be converted. If a list of expressions is given, a list of callable functions is returned; otherwise a single callable function is returned (see the list-of-expressions sketch after the Returns section below).

models (Optional[Union[Arch, Tuple[Arch, ...]]], default: None):
    Model(s) used to compute forward results in LayerNode.

extra_parameters (Optional[ParameterList], default: None):
    Extra learnable parameters.

graph_filename (Optional[str], default: None):
    If not None and a valid string such as 'momentum_x', the computational graph of expr is saved to graph_filename.png.

create_graph (bool, default: True):
    Whether to create the gradient graphs of the computing process. When True, higher-order derivatives can be computed; when False, the gradient graphs are discarded.

retain_graph (Optional[bool], default: None):
    Whether to retain the forward graph used to compute gradients. When True, the graph is retained so that backward can be run twice on the same graph; when False, the graph is freed. When None (the default), it takes the same value as create_graph.

fuse_derivative (bool, default: False):
    Whether to fuse derivative nodes. For example, if expr is 'Derivative(u, x) + Derivative(u, y)', grad(u, x) + grad(u, y) is computed when fuse_derivative=False, while the more efficient sum(grad(u, [x, y])) is computed when fuse_derivative=True. This feature is experimental, so it is not enabled by default when used independently (see the derivative sketch after the Examples section below).

Returns:

Union[ComposedNode, List[ComposedNode]]:
    Callable object(s) that compute expr from the required input data given in a dict.
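
When a list of expressions is passed, the return value is a list of ComposedNode objects, one per expression, each of which is called with the same input dictionary. The sketch below is an illustrative addition (the symbols p and q and the input tensors are hypothetical, not part of the original docstring); the expressions are purely algebraic, so no model is required:

>>> import paddle
>>> import ppsci
>>> import sympy as sp
>>> p, q = sp.symbols("p q")
>>> # No Function terms appear in these expressions, so no model has to be passed.
>>> funcs = ppsci.lambdify([p + q, p * q])
>>> p_tensor = paddle.randn([4, 1])
>>> q_tensor = paddle.randn([4, 1])
>>> sum_tensor = funcs[0]({"p": p_tensor, "q": q_tensor})
>>> prod_tensor = funcs[1]({"p": p_tensor, "q": q_tensor})
>>> paddle.allclose(sum_tensor, p_tensor + q_tensor).item()
True
>>> paddle.allclose(prod_tensor, p_tensor * q_tensor).item()
True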

Examples:

>>> import paddle
>>> import ppsci
>>> import sympy as sp
>>> a, b, c, x, y = sp.symbols("a b c x y")
>>> u = sp.Function("u")(x, y)
>>> v = sp.Function("v")(x, y)
>>> z = -a + b * (c ** 2) + u * v + 2.3
>>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 4, 16)
>>> batch_size = 13
>>> a_tensor = paddle.randn([batch_size, 1])
>>> b_tensor = paddle.randn([batch_size, 1])
>>> c_tensor = paddle.randn([batch_size, 1])
>>> x_tensor = paddle.randn([batch_size, 1])
>>> y_tensor = paddle.randn([batch_size, 1])
>>> model_output_dict = model({"x": x_tensor, "y": y_tensor})
>>> u_tensor, v_tensor = model_output_dict["u"], model_output_dict["v"]
>>> z_tensor_manually = (
...     -a_tensor + b_tensor * (c_tensor ** 2)
...     + u_tensor * v_tensor + 2.3
... )
>>> z_tensor_sympy = ppsci.lambdify(z, model)(
...     {
...         "a": a_tensor,
...         "b": b_tensor,
...         "c": c_tensor,
...         "x": x_tensor,
...         "y": y_tensor,
...     }
... )
>>> paddle.allclose(z_tensor_manually, z_tensor_sympy).item()
True
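
The example above exercises only algebraic operators. The sketch below is an illustrative addition (not part of the original docstring) that reuses model, u, x, y, x_tensor and y_tensor from above to convert an expression containing derivatives; fuse_derivative=True merges the two derivatives of u into one fused gradient computation, and the input tensors need stop_gradient=False so the derivatives can be taken with respect to them:

>>> x_tensor.stop_gradient = False
>>> y_tensor.stop_gradient = False
>>> # du/dx + du/dy: both terms differentiate the same function u, so they are fusable
>>> du_expr = sp.Derivative(u, x) + sp.Derivative(u, y)
>>> du_func = ppsci.lambdify(du_expr, model, fuse_derivative=True)
>>> du_tensor = du_func({"x": x_tensor, "y": y_tensor})  # tensor of shape [batch_size, 1]
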
Source code in ppsci/utils/symbolic.py
def lambdify(
    expr: Union[sp.Basic, List[sp.Basic]],
    models: Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]] = None,
    extra_parameters: Optional[Sequence[paddle.Tensor]] = None,
    graph_filename: Optional[str] = None,
    create_graph: bool = True,
    retain_graph: Optional[bool] = None,
    fuse_derivative: bool = False,
) -> Union[ComposedNode, List[ComposedNode]]:
    """Convert sympy expression to callable function.

    Args:
        expr (Union[sp.Basic, List[sp.Basic]]): Sympy expression(s) to be converted.
            Will return callable functions in list if multiple expressions are given.
            else will return one single callable function.
        models (Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]]): Model(s) for
            computing forward result in `LayerNode`.
        extra_parameters (Optional[nn.ParameterList]): Extra learnable parameters.
            Defaults to None.
        graph_filename (Optional[str]): Save computational graph to `graph_filename.png`
            for given `expr`, if `graph_filename` is not None and a valid string,
            such as 'momentum_x'. Defaults to None.
        create_graph (bool, optional): Whether to create the gradient graphs of
            the computing process. When it is True, higher order derivatives are
            supported to compute. When it is False, the gradient graphs of the
            computing process would be discarded. Defaults to True.
        retain_graph (Optional[bool]): Whether to retain the forward graph which
            is used to calculate the gradient. When it is True, the graph would
            be retained, in which way users can calculate backward twice for the
            same graph. When it is False, the graph would be freed. Defaults to None,
            which means it is equal to `create_graph`.
        fuse_derivative (bool, optional): Whether to fuse the derivative nodes.
            For example, if `expr` is 'Derivative(u, x) + Derivative(u, y)'
            It will compute grad(u, x) + grad(u, y) if fuse_derivative=False,
            else will compute sum(grad(u, [x, y])) if fuse_derivative=True as is more
            efficient in backward-graph. Defaults to False, as it is experimental so not
            enabled by default if used independently.

    Returns:
        Union[ComposedNode, List[ComposedNode]]: Callable object(s) for computing expr
            with necessary input(s) data in dict given.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> import sympy as sp

        >>> a, b, c, x, y = sp.symbols("a b c x y")
        >>> u = sp.Function("u")(x, y)
        >>> v = sp.Function("v")(x, y)
        >>> z = -a + b * (c ** 2) + u * v + 2.3

        >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v"), 4, 16)

        >>> batch_size = 13
        >>> a_tensor = paddle.randn([batch_size, 1])
        >>> b_tensor = paddle.randn([batch_size, 1])
        >>> c_tensor = paddle.randn([batch_size, 1])
        >>> x_tensor = paddle.randn([batch_size, 1])
        >>> y_tensor = paddle.randn([batch_size, 1])

        >>> model_output_dict = model({"x": x_tensor, "y": y_tensor})
        >>> u_tensor, v_tensor = model_output_dict["u"], model_output_dict["v"]

        >>> z_tensor_manually = (
        ...     -a_tensor + b_tensor * (c_tensor ** 2)
        ...     + u_tensor * v_tensor + 2.3
        ... )
        >>> z_tensor_sympy = ppsci.lambdify(z, model)(
        ...     {
        ...         "a": a_tensor,
        ...         "b": b_tensor,
        ...         "c": c_tensor,
        ...         "x": x_tensor,
        ...         "y": y_tensor,
        ...     }
        ... )

        >>> paddle.allclose(z_tensor_manually, z_tensor_sympy).item()
        True
    """
    if not extra_parameters:
        extra_parameters = ()

    if isinstance(models, arch.ModelList):
        models = tuple(models.model_list[i] for i in range(len(models.model_list)))
    if not isinstance(models, (tuple, list)):
        models = (models,)

    def _expr_to_callable_nodes(
        single_expr: sp.Basic, graph_filename_: Optional[str] = None
    ) -> List[Node]:
        """Convert sympy expression to a sequence of nodes in topologic order.

        Args:
            single_expr (sp.Basic): Single sympy expression, such as "a+b*c".
            graph_filename_ (Optional[str]): Save computational graph to
            `/path/to/graph_filename.png` for given `expr`, if `graph_filename` is not
            None and a valid string, such as 'momentum_x'. Defaults to None.

        Returns:
            List[Node]: Sequence of callable nodes.
        """
        # NOTE: Those simplify methods may complicate given expr instead, so not use here
        # simplify expression to reduce nodes in tree
        # expr = sp.nsimplify(expr)
        # expr = sp.expand(expr)
        # expr = sp.simplify(expr)

        # remove 1.0 from sympy expression tree
        single_expr = single_expr.subs(1.0, 1)

        # convert sympy expression tree to list of nodes in post-order
        sympy_nodes: List[sp.Basic] = []
        sympy_nodes = _post_traverse(single_expr, sympy_nodes)

        # remove unnecessary symbol nodes already in input dict(except for parameter symbol)
        _parameter_names = tuple(param.name for param in extra_parameters)
        sympy_nodes = [
            node
            for node in sympy_nodes
            if (not node.is_Symbol) or (_cvt_to_key(node) in _parameter_names)
        ]

        # remove duplicated node(s) with topological order kept
        sympy_nodes = list(dict.fromkeys(sympy_nodes))

        # convert sympy node to callable node
        callable_nodes = []
        for i, node in enumerate(sympy_nodes):
            if isinstance(
                node, tuple(SYMPY_TO_PADDLE.keys()) + (sp.Add, sp.Mul, sp.Derivative)
            ):
                if isinstance(node, sp.Derivative):
                    callable_nodes.append(
                        DerivativeNode(node, create_graph, retain_graph)
                    )
                else:
                    callable_nodes.append(OperatorNode(node))
            elif isinstance(node, sp.Function):
                if str(node.func) == equation.DETACH_FUNC_NAME:
                    callable_nodes.append(DetachNode(node))
                    logger.debug(f"Detected detach node {node}")
                else:
                    match_index = None
                    for j, model in enumerate(models):
                        if str(node.func) in model.output_keys:
                            callable_nodes.append(
                                LayerNode(
                                    node,
                                    model,
                                )
                            )
                            if match_index is not None:
                                raise ValueError(
                                    f"Name of function: '{node}' should be unique along given"
                                    f" models, but got same output_key: '{str(node.func)}' "
                                    f"in given models[{match_index}] and models[{j}]."
                                )
                            match_index = j
                    # NOTE: Skip 'sdf' function, which should be already generated in
                    # given data_dict
                    if match_index is None and str(node.func) != "sdf":
                        raise ValueError(
                            f"Node {node} can not match any model in given model(s)."
                        )
            elif node.is_Number or node.is_NumberSymbol:
                callable_nodes.append(ConstantNode(node))
            elif isinstance(node, sp.Symbol):
                callable_nodes.append(
                    ParameterNode(
                        node,
                        *[
                            param
                            for param in extra_parameters
                            if param.name == node.name
                        ],
                    )
                )
            else:
                raise NotImplementedError(
                    f"The node {node} is not supported in lambdify."
                )

        # NOTE: visualize computational graph using 'pygraphviz'
        if isinstance(graph_filename, str):
            _visualize_graph(sympy_nodes, os.path.join(graph_filename, graph_filename_))

        return callable_nodes

    if isinstance(expr, sp.Basic):
        callable_nodes_group = [_expr_to_callable_nodes(expr, "expr")]
    else:
        callable_nodes_group = [
            _expr_to_callable_nodes(expr_i, f"expr_{i}")
            for i, expr_i in enumerate(expr)
        ]

    # [Optional] Fused derivatives nodes that with same function to be differentiated
    while fuse_derivative:
        candidate_pos: List[Tuple[int, int]] = []  # [(group_id, node_id), ...]

        # use 4-nested for-loop to find all potential mergable derivative nodes
        for i in range(len(callable_nodes_group)):
            for j in range(len(callable_nodes_group[i])):
                # skip non-derivative node
                if not isinstance(callable_nodes_group[i][j], DerivativeNode):
                    continue
                # skip sdf function since it is always already given in data_dict
                if callable_nodes_group[i][j].expr.args[0].name == "sdf":
                    continue
                # skip merged node
                if callable_nodes_group[i][j].merged:
                    continue

                candidate_pos = [[i, j]]
                for ii in range(len(callable_nodes_group)):
                    for jj in range(len(callable_nodes_group[ii])):
                        # skip non-derivative node
                        if not isinstance(callable_nodes_group[ii][jj], DerivativeNode):
                            continue

                        # skip same node
                        if i == ii and j == jj:
                            continue
                        # skip merged node
                        if callable_nodes_group[ii][jj].merged:
                            continue

                        # has same function item
                        if (
                            callable_nodes_group[i][j].expr.args[0]
                            == callable_nodes_group[ii][jj].expr.args[0]
                        ):
                            candidate_pos.append([ii, jj])

                if len(candidate_pos) > 1:
                    break
            if len(candidate_pos) > 1:
                break

        # merge all candidate nodes into one or more FusedDerivativeNode node
        if len(candidate_pos) > 1:
            fused_node_seq = _fuse_derivative_nodes(
                [callable_nodes_group[gid][nid].expr for gid, nid in candidate_pos]
            )
            assert isinstance(
                fused_node_seq, list
            ), "'fused_node_seq' should be list of 'FusedDerivativeNode'"
            gid0, nid0 = candidate_pos[0]
            logger.debug(
                f"Fused {len(candidate_pos)} derivatives nodes: "
                f"{[callable_nodes_group[i][j].expr for i, j in candidate_pos]} into"
                f" {len(fused_node_seq)} fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])"
            )

            # mark merged node
            for i, (gid, nid) in enumerate(candidate_pos):
                assert isinstance(callable_nodes_group[gid][nid], DerivativeNode)
                callable_nodes_group[gid][nid].merged = True

            # replace first mergable node with fused node sequence(packed in list)
            # then mask the rest merged node to None(except [gid0, nid0])
            for i, (gid, nid) in enumerate(candidate_pos[1:]):
                # keep the end node of each group to avoid generating empty callable
                # node sequence, this will not effect performance since cache strategy
                # in Node.forward
                if nid != len(callable_nodes_group[gid]) - 1:
                    callable_nodes_group[gid][nid] = None

            if nid0 == len(callable_nodes_group[gid0]) - 1:
                callable_nodes_group[gid0].insert(nid0, fused_node_seq)
            else:
                callable_nodes_group[gid0][nid0] = fused_node_seq

            # re-organize callable_nodes_group, remove None element and unpack list
            for i in range(len(callable_nodes_group)):
                tmp = []
                for j in range(len(callable_nodes_group[i])):
                    if isinstance(
                        callable_nodes_group[i][j], (Node, FusedDerivativeNode)
                    ):
                        tmp.append(callable_nodes_group[i][j])
                    elif isinstance(callable_nodes_group[i][j], list) and isinstance(
                        callable_nodes_group[i][j][0], FusedDerivativeNode
                    ):
                        tmp.extend(callable_nodes_group[i][j])
                    else:
                        assert (
                            callable_nodes_group[i][j] is None
                        ), f"Unexpected element: {callable_nodes_group[i][j]}"
                callable_nodes_group[i] = tmp
        else:
            # exit while loop if no more fused
            break

    # Compose callable nodes into one callable object
    if isinstance(expr, sp.Basic):
        return ComposedNode(callable_nodes_group[0])
    else:
        return [ComposedNode(callable_nodes) for callable_nodes in callable_nodes_group]