跳转至

Equation(方程) 模块

ppsci.equation

PDE

Base class for Partial Differential Equation

Source code in ppsci/equation/pde/base.py
class PDE:
    """Base class for Partial Differential Equation.

    Subclasses build residual expressions with ``create_symbols`` and
    ``create_function`` and register them by name through ``add_equation``.
    """

    def __init__(self):
        super().__init__()
        # Registered equations, keyed by equation name.
        self.equations = {}
        # Container for the learnable parameter(s) of PDEs that have any.
        self.learnable_parameters = nn.ParameterList()
        # Function names that ``create_function`` wraps with the detach function;
        # subclasses assign this in their constructors.
        self.detach_keys: Optional[Tuple[str, ...]] = None

    def create_symbols(
        self, symbol_str: str
    ) -> Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]:
        """Create sympy symbol(s) from a string of names.

        Args:
            symbol_str (str): String containing symbol names, such as "x" or "x y z".

        Returns:
            Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]: The symbol(s) created by
                ``sympy.symbols``.
        """
        created = sympy.symbols(symbol_str)
        return created

    def create_function(
        self, name: str, invars: Tuple[sympy.Symbol, ...]
    ) -> sympy.Function:
        """Create a named sympy function applied to the given independent variables.

        Args:
            name (str): Function name, such as "u", "v" or "f".
            invars (Tuple[sympy.Symbol, ...]): Independent variables of the function.

        Returns:
            sympy.Function: ``name(*invars)``, or ``DETACH_FUNC_NAME(name(*invars))``
                when ``name`` is listed in ``self.detach_keys``.
        """
        applied = sympy.Function(name)(*invars)
        if not (self.detach_keys and name in self.detach_keys):
            return applied
        # Wrap `expression(...)` as `detach(expression(...))` for detached keys.
        return sympy.Function(DETACH_FUNC_NAME)(applied)

    def add_equation(self, name: str, equation: Callable):
        """Register an equation under ``name``, replacing any previous one.

        Args:
            name (str): Name of the equation.
            equation (Callable): Computation function (or expression) for the equation.
        """
        self.equations[name] = equation

    def parameters(self) -> List[paddle.Tensor]:
        """Collect the learnable parameters held by this PDE.

        Returns:
            List[Tensor]: Parameters from ``self.learnable_parameters``.
        """
        params = self.learnable_parameters.parameters()
        return params

    def state_dict(self) -> Dict[str, paddle.Tensor]:
        """Return the named learnable parameters as a dict."""
        named = self.learnable_parameters.state_dict()
        return named

    def set_state_dict(self, state_dict):
        """Load learnable-parameter values from ``state_dict``."""
        target = self.learnable_parameters
        target.set_state_dict(state_dict)

    def __str__(self):
        pieces = [type(self).__name__]
        pieces.extend(f"{eq_name}: {eq}" for eq_name, eq in self.equations.items())
        return ", ".join(pieces)
add_equation(name, equation)

Add an equation.

Parameters:

Name Type Description Default
name str

Name of equation

required
equation Callable

Computation function for equation.

required
Source code in ppsci/equation/pde/base.py
def add_equation(self, name: str, equation: Callable):
    """Register ``equation`` under ``name``, replacing any existing entry.

    Args:
        name (str): Name of the equation.
        equation (Callable): Computation function (or expression) for the equation.
    """
    self.equations[name] = equation
create_function(name, invars)

Create named function depending on given invars.

Parameters:

Name Type Description Default
name str

Function name. such as "u", "v", and "f".

required
invars Tuple[Symbol, ...]

Independent variables of the function.

required

Returns:

Type Description
Function

sympy.Function: Named sympy function.

Source code in ppsci/equation/pde/base.py
def create_function(
    self, name: str, invars: Tuple[sympy.Symbol, ...]
) -> sympy.Function:
    """Create a named sympy function applied to the given independent variables.

    Args:
        name (str): Function name, such as "u", "v" or "f".
        invars (Tuple[sympy.Symbol, ...]): Independent variables of the function.

    Returns:
        sympy.Function: ``name(*invars)``, or ``DETACH_FUNC_NAME(name(*invars))``
            when ``name`` is listed in ``self.detach_keys``.
    """
    applied = sympy.Function(name)(*invars)
    if not (self.detach_keys and name in self.detach_keys):
        return applied
    # Wrap `expression(...)` as `detach(expression(...))` for detached keys.
    return sympy.Function(DETACH_FUNC_NAME)(applied)
create_symbols(symbol_str)

Create symbols

Parameters:

Name Type Description Default
symbol_str str

String containing symbol names, such as "x" or "x y z".

required

Returns:

Type Description
Union[Symbol, Tuple[Symbol, ...]]

Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]: Created symbol(s).

Source code in ppsci/equation/pde/base.py
def create_symbols(
    self, symbol_str: str
) -> Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]:
    """Create sympy symbol(s) from a string of names.

    Args:
        symbol_str (str): String containing symbol names, such as "x" or "x y z".

    Returns:
        Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]: The symbol(s) created by
            ``sympy.symbols``.
    """
    created = sympy.symbols(symbol_str)
    return created
parameters()

Return parameters contained in PDE.

Returns:

Type Description
List[Tensor]

List[Tensor]: A list of parameters.

Source code in ppsci/equation/pde/base.py
def parameters(self) -> List[paddle.Tensor]:
    """Collect the learnable parameters held by this PDE.

    Returns:
        List[Tensor]: Parameters from ``self.learnable_parameters``.
    """
    params = self.learnable_parameters.parameters()
    return params
set_state_dict(state_dict)

Set state dict from dict.

Source code in ppsci/equation/pde/base.py
def set_state_dict(self, state_dict):
    """Load learnable-parameter values from ``state_dict``.

    Args:
        state_dict: Mapping forwarded unchanged to
            ``self.learnable_parameters.set_state_dict``.
    """
    target = self.learnable_parameters
    target.set_state_dict(state_dict)
state_dict()

Return named parameters in dict.

Source code in ppsci/equation/pde/base.py
def state_dict(self) -> Dict[str, paddle.Tensor]:
    """Return the named learnable parameters as a dict."""
    named = self.learnable_parameters.state_dict()
    return named

FractionalPoisson

Bases: PDE

Fractional Poisson equation, discretized with directional quadrature. Args: alpha (float): Fractional order of the operator. geom (geometry.Geometry): Computation geometry. resolution (Tuple[int, ...]): Quadrature resolution.

Examples:

>>> import ppsci
>>> fpde = ppsci.equation.FractionalPoisson(ALPHA, geom["disk"], [8, 100])
Source code in ppsci/equation/fpde/fractional_poisson.py
class FractionalPoisson(PDE):
    r"""Fractional Poisson equation discretized with directional quadrature.

    The registered residual ``"fpde"`` is ``lhs - rhs``:

    - ``lhs`` is a sparse quadrature matrix (assembled in ``get_x``) applied to the
      output ``"u"``, scaled by a Gamma-function constant in ``alpha``.
    - ``rhs`` is a hard-coded source term in ``alpha`` and ``|x|^2``.

    NOTE: ``get_x`` must be called before the ``"fpde"`` residual is evaluated,
    because it builds ``self.int_mat``, which the residual reads. The residual and
    ``get_x`` use only the ``"x"``/``"y"`` input keys.

    Args:
        alpha (float): Fractional order of the operator.
        geom (geometry.Geometry): Computation geometry. The boundary-distance
            helper reads ``geom.center`` and ``geom.radius``, so a disk/ball-like
            geometry is expected.
        resolution (Tuple[int, ...]): Quadrature resolution.
            ``resolution[-1]`` is the number of points per unit length along each
            direction. In 2D, ``resolution[0]`` is the number of angular Gauss
            points. In 3D, ``resolution[0]`` / ``resolution[1]`` are the numbers of
            polar / azimuthal points.

    Examples:
        >>> import ppsci
        >>> fpde = ppsci.equation.FractionalPoisson(ALPHA, geom["disk"], [8, 100])
    """
    # NOTE: evaluated once, when the class is defined; later changes to Paddle's
    # default dtype are not picked up here.
    dtype = paddle.get_default_dtype()

    def __init__(
        self, alpha: float, geom: geometry.Geometry, resolution: Tuple[int, ...]
    ):
        super().__init__()
        self.alpha = alpha
        self.geom = geom
        self.resolution = resolution
        # Weights precomputed once, long enough for the geometry's diameter.
        self._w_init = self._init_weights()

        def compute_fpde_func(out):
            x = paddle.concat((out["x"], out["y"]), axis=1)
            y = out["u"]
            # Built by `get_x`; that method must run before this residual.
            indices, values, shape = self.int_mat
            int_mat = sparse.sparse_coo_tensor(
                [[p[0] for p in indices], [p[1] for p in indices]],
                values,
                shape,
                stop_gradient=False,
            )
            lhs = sparse.matmul(int_mat, y)
            lhs = lhs[:, 0]
            lhs *= (
                special.gamma((1 - self.alpha) / 2)
                * special.gamma((2 + self.alpha) / 2)
                / (2 * np.pi**1.5)
            )
            # `lhs` has one entry per int_mat row (one per interior point x0),
            # which are the leading rows of the training points.
            x = x[: paddle.numel(lhs)]
            rhs = (
                2**self.alpha
                * special.gamma(2 + self.alpha / 2)
                * special.gamma(1 + self.alpha / 2)
                * (1 - (1 + self.alpha / 2) * paddle.sum(x**2, axis=1))
            )
            res = lhs - rhs
            return res

        self.add_equation("fpde", compute_fpde_func)

    def _init_weights(self):
        """Return weights w_0 = 1, w_j = w_{j-1} * (j - 1 - alpha) / j."""
        n = self._dynamic_dist2npts(self.geom.diam) + 1
        w = [1.0]
        for j in range(1, n):
            w.append(w[-1] * (j - 1 - self.alpha) / j)
        return np.array(w, dtype=self.dtype)

    def get_x(self, x_f):
        """Build (once) the training points and the sparse quadrature matrix.

        The result is cached in ``self.train_x``; later calls return the cache
        regardless of ``x_f``.

        Args:
            x_f: Interior collocation points; must not contain boundary points.

        Returns:
            The training points as returned by ``misc.convert_to_dict`` with keys
            ("x", "y").

        Raises:
            ValueError: If ``x_f`` contains boundary points, or if ``geom.ndim`` is
                not 1, 2 or 3.
        """
        if hasattr(self, "train_x"):
            return self.train_x

        self.x0 = x_f
        if np.any(self.geom.on_boundary(self.x0)):
            raise ValueError("x0 contains boundary points.")

        # Quadrature directions and their weights, per dimensionality.
        if self.geom.ndim == 1:
            dirns, dirn_w = [-1, 1], [1, 1]
        elif self.geom.ndim == 2:
            gauss_x, gauss_w = np.polynomial.legendre.leggauss(self.resolution[0])
            gauss_x, gauss_w = gauss_x.astype(self.dtype), gauss_w.astype(self.dtype)
            thetas = np.pi * gauss_x + np.pi
            dirns = np.vstack((np.cos(thetas), np.sin(thetas))).T
            dirn_w = np.pi * gauss_w
        elif self.geom.ndim == 3:
            gauss_x, gauss_w = np.polynomial.legendre.leggauss(max(self.resolution[:2]))
            gauss_x, gauss_w = gauss_x.astype(self.dtype), gauss_w.astype(self.dtype)
            # NOTE(review): thetas/phis take a prefix of a max(resolution[:2])-point
            # rule; this is a proper product rule only when
            # resolution[0] == resolution[1] — confirm intended.
            thetas = (np.pi * gauss_x[: self.resolution[0]] + np.pi) / 2
            phis = np.pi * gauss_x[: self.resolution[1]] + np.pi
            dirns, dirn_w = [], []
            for i in range(self.resolution[0]):
                for j in range(self.resolution[1]):
                    dirns.append(
                        [
                            np.sin(thetas[i]) * np.cos(phis[j]),
                            np.sin(thetas[i]) * np.sin(phis[j]),
                            np.cos(thetas[i]),
                        ]
                    )
                    dirn_w.append(gauss_w[i] * gauss_w[j] * np.sin(thetas[i]))
            dirn_w = np.pi**2 / 2 * np.array(dirn_w)
        else:
            raise ValueError(
                "FractionalPoisson only supports 1D, 2D or 3D geometry, "
                f"but got ndim={self.geom.ndim}."
            )

        x, self.w = [], []
        for x0i in self.x0:
            xi = list(
                map(
                    lambda dirn: self.background_points(
                        x0i, dirn, self._dynamic_dist2npts, 0
                    ),
                    dirns,
                )
            )
            wi = list(
                map(
                    lambda i: dirn_w[i]
                    * np.linalg.norm(xi[i][1] - xi[i][0]) ** (-self.alpha)
                    * self.get_weight(len(xi[i]) - 1),
                    range(len(dirns)),
                )
            )
            # Only the first-order modification is implemented in this class.
            xi, wi = zip(*map(self.modify_first_order, xi, wi))
            x.append(np.vstack(xi))
            self.w.append(np.hstack(wi))
        self.x = np.vstack([self.x0] + x)
        self.int_mat = self._get_int_matrix(self.x0)
        self.train_x = misc.convert_to_dict(self.x, ("x", "y"))
        return self.train_x

    def get_weight(self, n):
        """Return the first ``n + 1`` precomputed weights."""
        return self._w_init[: n + 1]

    def background_points(self, x, dirn, dist2npt, shift):
        """Return equally spaced points from ``x`` towards the boundary along ``dirn``.

        The spacing is ``dx / n``, where ``dx`` is the distance from ``x`` to the
        boundary along ``dirn`` and ``n = max(dist2npt(dx), 1)``.
        """
        dirn = dirn / np.linalg.norm(dirn)
        dx = self.distance2boundary_unitdirn(x, -dirn)
        n = max(dist2npt(dx), 1)
        h = dx / n
        pts = x - np.arange(-shift, n - shift + 1, dtype=self.dtype)[:, None] * h * dirn
        return pts

    def distance2boundary_unitdirn(self, x, dirn):
        """Distance from ``x`` to the sphere ``(geom.center, geom.radius)`` along unit ``dirn``."""
        # https://en.wikipedia.org/wiki/Line%E2%80%93sphere_intersection
        xc = x - self.geom.center
        ad = np.dot(xc, dirn)
        return (
            -ad + (ad**2 - np.sum(xc * xc, axis=-1) + self.geom.radius**2) ** 0.5
        ).astype(self.dtype)

    def modify_first_order(self, x, w):
        """Prepend the extrapolated point ``2 * x[0] - x[1]`` and drop the last point.

        If the extrapolated point is outside the geometry, it and the first
        weight are dropped instead.
        """
        x = np.vstack(([2 * x[0] - x[1]], x[:-1]))
        if not self.geom.is_inside(x[0:1])[0]:
            return x[1:], w[1:]
        return x, w

    def _dynamic_dist2npts(self, dx):
        """Number of points for a segment of length ``dx`` (``resolution[-1]`` per unit)."""
        return int(math.ceil(self.resolution[-1] * dx))

    def _get_int_matrix(
        self, x: np.ndarray
    ) -> Tuple[list, np.ndarray, Tuple[int, int]]:
        """Assemble the quadrature matrix in COO form.

        Row ``i`` (one per point in ``x``) has one entry per weight in
        ``self.w[i]``. Column indices are consecutive indices into ``self.x``,
        starting right after the first ``x.shape[0]`` rows.

        Returns:
            Tuple of (indices, values, dense_shape), where ``indices`` is a list of
            ``[row, col]`` pairs, ``values`` is a 1-D array of ``self.dtype``, and
            ``dense_shape`` is ``(x.shape[0], self.x.shape[0])``.
        """
        dense_shape = (x.shape[0], self.x.shape[0])
        indices = []
        col = x.shape[0]
        for i in range(x.shape[0]):
            n_i = self.w[i].shape[0]
            indices.extend([i, col + k] for k in range(n_i))
            col += n_i
        if x.shape[0] > 0:
            # A single concatenation avoids the quadratic copying of a repeated
            # np.hstack inside the loop.
            values = np.concatenate([self.w[i] for i in range(x.shape[0])])
        else:
            values = np.zeros(0)
        return indices, values.astype(self.dtype), dense_shape

Biharmonic

Bases: PDE

Class for biharmonic equation with supporting special load.

\[ \nabla^4 \varphi = \dfrac{q}{D} \]

Parameters:

Name Type Description Default
dim int

Dimension of equation.

required
q Union[float, str, Basic]

Load.

required
D Union[float, str]

Rigidity.

required
detach_keys Optional[Tuple[str, ...]]

Keys used for detach during computing. Defaults to None.

None

Examples:

>>> import ppsci
>>> pde = ppsci.equation.Biharmonic(2, -1.0, 1.0)
Source code in ppsci/equation/pde/biharmonic.py
class Biharmonic(base.PDE):
    r"""Class for biharmonic equation with supporting special load.

    $$
    \nabla^4 \varphi = \dfrac{q}{D}
    $$

    The registered residual ``"biharmonic"`` is the sum over all pairs (i, j) of
    input variables of the fourth derivative of ``u`` (twice in x_i, twice in x_j),
    minus ``q / D``. Here ``u`` plays the role of :math:`\varphi`.

    Args:
        dim (int): Dimension of equation, one of 1, 2 or 3 (variables x, y, z).
        q (Union[float, str, sympy.Basic]): Load. A string is replaced by a function
            named ``"q"`` of the input variables; the string's content is not used
            as the name.
        D (Union[float, str]): Rigidity. A string is replaced by a function named
            ``"D"`` of the input variables, in the same way.
        detach_keys (Optional[Tuple[str, ...]]): Keys used for detach during computing.
            Defaults to None.

    Raises:
        ValueError: If ``dim`` is not 1, 2 or 3.

    Examples:
        >>> import ppsci
        >>> pde = ppsci.equation.Biharmonic(2, -1.0, 1.0)
    """

    def __init__(
        self,
        dim: int,
        q: Union[float, str, sympy.Basic],
        D: Union[float, str],
        detach_keys: Optional[Tuple[str, ...]] = None,
    ):
        # Slicing "x y z" below would silently truncate or empty other dims.
        if dim not in (1, 2, 3):
            raise ValueError(f"dim should be 1, 2 or 3, but got {dim}.")
        super().__init__()
        self.detach_keys = detach_keys

        invars = self.create_symbols("x y z")[:dim]
        u = self.create_function("u", invars)

        if isinstance(q, str):
            q = self.create_function("q", invars)
        if isinstance(D, str):
            D = self.create_function("D", invars)

        self.dim = dim
        self.q = q
        self.D = D

        biharmonic = -self.q / self.D
        for invar_i in invars:
            for invar_j in invars:
                biharmonic += u.diff(invar_i, 2).diff(invar_j, 2)

        self.add_equation("biharmonic", biharmonic)

Laplace

Bases: PDE

Class for laplace equation.

\[ \nabla^2 \varphi = 0 \]

Parameters:

Name Type Description Default
dim int

Dimension of equation.

required
detach_keys Optional[Tuple[str, ...]]

Keys used for detach during computing. Defaults to None.

None

Examples:

>>> import ppsci
>>> pde = ppsci.equation.Laplace(2)
Source code in ppsci/equation/pde/laplace.py
class Laplace(base.PDE):
    r"""Class for laplace equation.

    $$
    \nabla^2 \varphi = 0
    $$

    The registered residual ``"laplace"`` is the sum of second derivatives of the
    output ``u`` (playing the role of :math:`\varphi`) over the input variables.

    Args:
        dim (int): Dimension of equation, one of 1, 2 or 3 (variables x, y, z).
        detach_keys (Optional[Tuple[str, ...]]): Keys used for detach during computing.
            Defaults to None.

    Raises:
        ValueError: If ``dim`` is not 1, 2 or 3.

    Examples:
        >>> import ppsci
        >>> pde = ppsci.equation.Laplace(2)
    """

    def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None):
        # Slicing "x y z" below would silently truncate or empty other dims.
        if dim not in (1, 2, 3):
            raise ValueError(f"dim should be 1, 2 or 3, but got {dim}.")
        super().__init__()
        self.detach_keys = detach_keys

        invars = self.create_symbols("x y z")[:dim]
        u = self.create_function("u", invars)

        self.dim = dim

        laplace = 0
        for invar in invars:
            laplace += u.diff(invar, 2)

        self.add_equation("laplace", laplace)

LinearElasticity

Bases: PDE

Linear elasticity equations. Use either (E, nu) or (lambda_, mu) to define the material properties.

\[ \begin{cases} stress\_disp_{xx} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial u}{\partial x} - \sigma_{xx} \\ stress\_disp_{yy} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial v}{\partial y} - \sigma_{yy} \\ stress\_disp_{zz} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial w}{\partial z} - \sigma_{zz} \\ traction_{x} = n_x \sigma_{xx} + n_y \sigma_{xy} + n_z \sigma_{xz} \\ traction_{y} = n_x \sigma_{yx} + n_y \sigma_{yy} + n_z \sigma_{yz} \\ traction_{z} = n_x \sigma_{zx} + n_y \sigma_{zy} + n_z \sigma_{zz} \\ \end{cases} \]

Parameters:

Name Type Description Default
E Optional[Union[float, str]]

The Young's modulus. Defaults to None.

None
nu Optional[Union[float, str]]

The Poisson's ratio. Defaults to None.

None
lambda_ Optional[Union[float, str]]

Lamé's first parameter. Defaults to None.

None
mu Optional[Union[float, str]]

Lamé's second parameter (shear modulus). Defaults to None.

None
rho Union[float, str]

Mass density. Defaults to 1.

1
dim int

Dimension of the linear elasticity (2 or 3). Defaults to 3.

3
time bool

Whether contains time data. Defaults to False.

False
detach_keys Optional[Tuple[str, ...]]

Keys used for detach during computing. Defaults to None.

None

Examples:

>>> import ppsci
>>> pde = ppsci.equation.LinearElasticity(
...     E=None, nu=None, lambda_=1e4, mu=100, dim=3
... )
Source code in ppsci/equation/pde/linear_elasticity.py
class LinearElasticity(base.PDE):
    r"""Linear elasticity equations.
    Use either (E, nu) or (lambda_, mu) to define the material properties.

    $$
    \begin{cases}
        stress\_disp_{xx} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial u}{\partial x} - \sigma_{xx} \\
        stress\_disp_{yy} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial v}{\partial y} - \sigma_{yy} \\
        stress\_disp_{zz} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial w}{\partial z} - \sigma_{zz} \\
        stress\_disp_{xy} = \mu(\dfrac{\partial u}{\partial y} + \dfrac{\partial v}{\partial x}) - \sigma_{xy} \\
        stress\_disp_{xz} = \mu(\dfrac{\partial u}{\partial z} + \dfrac{\partial w}{\partial x}) - \sigma_{xz} \\
        stress\_disp_{yz} = \mu(\dfrac{\partial v}{\partial z} + \dfrac{\partial w}{\partial y}) - \sigma_{yz} \\
        equilibrium_{x} = \rho \dfrac{\partial^2 u}{\partial t^2} - (\dfrac{\partial \sigma_{xx}}{\partial x} + \dfrac{\partial \sigma_{xy}}{\partial y} + \dfrac{\partial \sigma_{xz}}{\partial z}) \\
        equilibrium_{y} = \rho \dfrac{\partial^2 v}{\partial t^2} - (\dfrac{\partial \sigma_{xy}}{\partial x} + \dfrac{\partial \sigma_{yy}}{\partial y} + \dfrac{\partial \sigma_{yz}}{\partial z}) \\
        equilibrium_{z} = \rho \dfrac{\partial^2 w}{\partial t^2} - (\dfrac{\partial \sigma_{xz}}{\partial x} + \dfrac{\partial \sigma_{yz}}{\partial y} + \dfrac{\partial \sigma_{zz}}{\partial z}) \\
        traction_{x} = n_x \sigma_{xx} + n_y \sigma_{xy} + n_z \sigma_{xz} \\
        traction_{y} = n_x \sigma_{xy} + n_y \sigma_{yy} + n_z \sigma_{yz} \\
        traction_{z} = n_x \sigma_{xz} + n_y \sigma_{yz} + n_z \sigma_{zz} \\
    \end{cases}
    $$

    When ``lambda_`` is None, the Lamé parameters are derived from (E, nu) as
    ``lambda = nu * E / ((1 + nu) * (1 - 2 * nu))`` and ``mu = E / (2 * (1 + nu))``.
    In that case any given ``mu`` is ignored. String material parameters become
    functions, named by the string, of the input variables.

    For ``dim == 2``, ``w`` and all z-related stresses are zero, and only the x/y
    equations are registered. Without ``time``, the inertia terms vanish.

    Args:
        E (Optional[Union[float, str]]): The Young's modulus. Defaults to None.
        nu (Optional[Union[float, str]]): The Poisson's ratio. Defaults to None.
        lambda_ (Optional[Union[float, str]]): Lamé's first parameter. Defaults to None.
        mu (Optional[Union[float, str]]): Lamé's second parameter (shear modulus). Defaults to None.
        rho (Union[float, str], optional): Mass density. Defaults to 1.
        dim (int, optional): Dimension of the linear elasticity (2 or 3). Defaults to 3.
        time (bool, optional): Whether contains time data. Defaults to False.
        detach_keys (Optional[Tuple[str, ...]]): Keys used for detach during computing.
            Defaults to None.

    Raises:
        ValueError: If neither a complete (E, nu) pair nor a complete
            (lambda_, mu) pair is given.

    Examples:
        >>> import ppsci
        >>> pde = ppsci.equation.LinearElasticity(
        ...     E=None, nu=None, lambda_=1e4, mu=100, dim=3
        ... )
    """

    def __init__(
        self,
        E: Optional[Union[float, str]] = None,
        nu: Optional[Union[float, str]] = None,
        lambda_: Optional[Union[float, str]] = None,
        mu: Optional[Union[float, str]] = None,
        rho: Union[float, str] = 1,
        dim: int = 3,
        time: bool = False,
        detach_keys: Optional[Tuple[str, ...]] = None,
    ):
        # Fail early with a clear message instead of a TypeError on `None`
        # arithmetic further down.
        if lambda_ is None and (E is None or nu is None):
            raise ValueError(
                "Either (E, nu) or (lambda_, mu) should be given, but got "
                f"E={E}, nu={nu}, lambda_={lambda_}, mu={mu}."
            )
        if lambda_ is not None and mu is None:
            raise ValueError(
                "mu should be given together with lambda_, but got "
                f"lambda_={lambda_}, mu={mu}."
            )

        super().__init__()
        self.detach_keys = detach_keys
        self.dim = dim
        self.time = time

        t, x, y, z = self.create_symbols("t x y z")
        normal_x, normal_y, normal_z = self.create_symbols("normal_x normal_y normal_z")
        invars = (x, y)
        if time:
            invars = (t,) + invars
        if self.dim == 3:
            invars += (z,)

        u = self.create_function("u", invars)
        v = self.create_function("v", invars)
        w = self.create_function("w", invars) if dim == 3 else sp.Number(0)

        sigma_xx = self.create_function("sigma_xx", invars)
        sigma_yy = self.create_function("sigma_yy", invars)
        sigma_xy = self.create_function("sigma_xy", invars)
        sigma_zz = (
            self.create_function("sigma_zz", invars) if dim == 3 else sp.Number(0)
        )
        sigma_xz = (
            self.create_function("sigma_xz", invars) if dim == 3 else sp.Number(0)
        )
        sigma_yz = (
            self.create_function("sigma_yz", invars) if dim == 3 else sp.Number(0)
        )

        # compute lambda and mu
        if lambda_ is None:
            if isinstance(nu, str):
                nu = self.create_function(nu, invars)
            if isinstance(E, str):
                E = self.create_function(E, invars)
            lambda_ = nu * E / ((1 + nu) * (1 - 2 * nu))
            mu = E / (2 * (1 + nu))
        else:
            if isinstance(lambda_, str):
                lambda_ = self.create_function(lambda_, invars)
            if isinstance(mu, str):
                mu = self.create_function(mu, invars)

        if isinstance(rho, str):
            rho = self.create_function(rho, invars)

        self.E = E
        self.nu = nu
        self.lambda_ = lambda_
        self.mu = mu
        self.rho = rho

        # compute stress equations
        stress_disp_xx = (
            lambda_ * (u.diff(x) + v.diff(y) + w.diff(z))
            + 2 * mu * u.diff(x)
            - sigma_xx
        )
        stress_disp_yy = (
            lambda_ * (u.diff(x) + v.diff(y) + w.diff(z))
            + 2 * mu * v.diff(y)
            - sigma_yy
        )
        stress_disp_zz = (
            lambda_ * (u.diff(x) + v.diff(y) + w.diff(z))
            + 2 * mu * w.diff(z)
            - sigma_zz
        )
        stress_disp_xy = mu * (u.diff(y) + v.diff(x)) - sigma_xy
        stress_disp_xz = mu * (u.diff(z) + w.diff(x)) - sigma_xz
        stress_disp_yz = mu * (v.diff(z) + w.diff(y)) - sigma_yz

        # compute equilibrium equations (inertia terms vanish when t is not an invar)
        equilibrium_x = rho * ((u.diff(t)).diff(t)) - (
            sigma_xx.diff(x) + sigma_xy.diff(y) + sigma_xz.diff(z)
        )
        equilibrium_y = rho * ((v.diff(t)).diff(t)) - (
            sigma_xy.diff(x) + sigma_yy.diff(y) + sigma_yz.diff(z)
        )
        equilibrium_z = rho * ((w.diff(t)).diff(t)) - (
            sigma_xz.diff(x) + sigma_yz.diff(y) + sigma_zz.diff(z)
        )

        # compute traction equations (symmetric stress tensor)
        traction_x = normal_x * sigma_xx + normal_y * sigma_xy + normal_z * sigma_xz
        traction_y = normal_x * sigma_xy + normal_y * sigma_yy + normal_z * sigma_yz
        traction_z = normal_x * sigma_xz + normal_y * sigma_yz + normal_z * sigma_zz

        # add stress equations
        self.add_equation("stress_disp_xx", stress_disp_xx)
        self.add_equation("stress_disp_yy", stress_disp_yy)
        self.add_equation("stress_disp_xy", stress_disp_xy)
        if self.dim == 3:
            self.add_equation("stress_disp_zz", stress_disp_zz)
            self.add_equation("stress_disp_xz", stress_disp_xz)
            self.add_equation("stress_disp_yz", stress_disp_yz)

        # add equilibrium equations
        self.add_equation("equilibrium_x", equilibrium_x)
        self.add_equation("equilibrium_y", equilibrium_y)
        if self.dim == 3:
            self.add_equation("equilibrium_z", equilibrium_z)

        # add traction equations
        self.add_equation("traction_x", traction_x)
        self.add_equation("traction_y", traction_y)
        if self.dim == 3:
            self.add_equation("traction_z", traction_z)

NavierStokes

Bases: PDE

Class for navier-stokes equation.

\[ \begin{cases} \dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z} = 0 \\ \dfrac{\partial u}{\partial t} + u\dfrac{\partial u}{\partial x} + v\dfrac{\partial u}{\partial y} + w\dfrac{\partial u}{\partial z} = - \dfrac{1}{\rho}\dfrac{\partial p}{\partial x} + \nu( \dfrac{\partial ^2 u}{\partial x ^2} + \dfrac{\partial ^2 u}{\partial y ^2} + \dfrac{\partial ^2 u}{\partial z ^2} ) \\ \dfrac{\partial v}{\partial t} + u\dfrac{\partial v}{\partial x} + v\dfrac{\partial v}{\partial y} + w\dfrac{\partial v}{\partial z} = - \dfrac{1}{\rho}\dfrac{\partial p}{\partial y} + \nu( \dfrac{\partial ^2 v}{\partial x ^2} + \dfrac{\partial ^2 v}{\partial y ^2} + \dfrac{\partial ^2 v}{\partial z ^2} ) \\ \dfrac{\partial w}{\partial t} + u\dfrac{\partial w}{\partial x} + v\dfrac{\partial w}{\partial y} + w\dfrac{\partial w}{\partial z} = - \dfrac{1}{\rho}\dfrac{\partial p}{\partial z} + \nu( \dfrac{\partial ^2 w}{\partial x ^2} + \dfrac{\partial ^2 w}{\partial y ^2} + \dfrac{\partial ^2 w}{\partial z ^2} ) \\ \end{cases} \]

Parameters:

Name Type Description Default
nu Union[float, str]

Kinematic viscosity (multiplies the viscous term directly, with no density factor).

required
rho Union[float, str]

Density.

required
dim int

Dimension of equation.

required
time bool

Whether the equation is time-dependent.

required
detach_keys Optional[Tuple[str, ...]]

Keys used for detach during computing. Defaults to None.

None

Examples:

>>> import ppsci
>>> pde = ppsci.equation.NavierStokes(0.1, 1.0, 3, False)
Source code in ppsci/equation/pde/navier_stokes.py
class NavierStokes(base.PDE):
    r"""Class for navier-stokes equation.

    $$
    \begin{cases}
        \dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z} = 0 \\
        \dfrac{\partial u}{\partial t} + u\dfrac{\partial u}{\partial x} + v\dfrac{\partial u}{\partial y} + w\dfrac{\partial u}{\partial z} =
            - \dfrac{1}{\rho}\dfrac{\partial p}{\partial x}
            + \nu(
                \dfrac{\partial ^2 u}{\partial x ^2}
                + \dfrac{\partial ^2 u}{\partial y ^2}
                + \dfrac{\partial ^2 u}{\partial z ^2}
            ) \\
        \dfrac{\partial v}{\partial t} + u\dfrac{\partial v}{\partial x} + v\dfrac{\partial v}{\partial y} + w\dfrac{\partial v}{\partial z} =
            - \dfrac{1}{\rho}\dfrac{\partial p}{\partial y}
            + \nu(
                \dfrac{\partial ^2 v}{\partial x ^2}
                + \dfrac{\partial ^2 v}{\partial y ^2}
                + \dfrac{\partial ^2 v}{\partial z ^2}
            ) \\
        \dfrac{\partial w}{\partial t} + u\dfrac{\partial w}{\partial x} + v\dfrac{\partial w}{\partial y} + w\dfrac{\partial w}{\partial z} =
            - \dfrac{1}{\rho}\dfrac{\partial p}{\partial z}
            + \nu(
                \dfrac{\partial ^2 w}{\partial x ^2}
                + \dfrac{\partial ^2 w}{\partial y ^2}
                + \dfrac{\partial ^2 w}{\partial z ^2}
            ) \\
    \end{cases}
    $$

    The viscous term is implemented in divergence form, e.g.
    ``(nu * u_x)_x + (nu * u_y)_y + (nu * u_z)_z``. This reduces to the Laplacian
    form above when ``nu`` is constant.

    Args:
        nu (Union[float, str]): Kinematic viscosity. It multiplies the viscous term
            directly, with no density factor. A string is parsed with sympy; if it
            parses to a single symbol, that symbol is appended to the input
            variables of u, v, w and p.
        rho (Union[float, str]): Density. A string is handled like ``nu``.
        dim (int): Dimension of equation.
        time (bool): Whether the equation is time-dependent.
        detach_keys (Optional[Tuple[str, ...]]): Keys used for detach during computing.
            Defaults to None.

    Examples:
        >>> import ppsci
        >>> pde = ppsci.equation.NavierStokes(0.1, 1.0, 3, False)
    """

    def __init__(
        self,
        nu: Union[float, str],
        rho: Union[float, str],
        dim: int,
        time: bool,
        detach_keys: Optional[Tuple[str, ...]] = None,
    ):
        super().__init__()
        self.detach_keys = detach_keys
        self.dim = dim
        self.time = time

        t, x, y, z = self.create_symbols("t x y z")
        invars = (x, y)
        if time:
            invars = (t,) + invars
        if dim == 3:
            invars += (z,)

        # A symbolic coefficient becomes an extra input variable of the outputs.
        if isinstance(nu, str):
            nu = sp_parser.parse_expr(nu)
            if isinstance(nu, sp.Symbol):
                invars += (nu,)

        if isinstance(rho, str):
            rho = sp_parser.parse_expr(rho)
            if isinstance(rho, sp.Symbol):
                invars += (rho,)

        self.nu = nu
        self.rho = rho

        u = self.create_function("u", invars)
        v = self.create_function("v", invars)
        w = self.create_function("w", invars) if dim == 3 else sp.Number(0)
        p = self.create_function("p", invars)

        continuity = u.diff(x) + v.diff(y) + w.diff(z)
        momentum_x = (
            u.diff(t)
            + u * u.diff(x)
            + v * u.diff(y)
            + w * u.diff(z)
            - (
                (nu * u.diff(x)).diff(x)
                + (nu * u.diff(y)).diff(y)
                + (nu * u.diff(z)).diff(z)
            )
            + 1 / rho * p.diff(x)
        )
        momentum_y = (
            v.diff(t)
            + u * v.diff(x)
            + v * v.diff(y)
            + w * v.diff(z)
            - (
                (nu * v.diff(x)).diff(x)
                + (nu * v.diff(y)).diff(y)
                + (nu * v.diff(z)).diff(z)
            )
            + 1 / rho * p.diff(y)
        )
        momentum_z = (
            w.diff(t)
            + u * w.diff(x)
            + v * w.diff(y)
            + w * w.diff(z)
            - (
                (nu * w.diff(x)).diff(x)
                + (nu * w.diff(y)).diff(y)
                + (nu * w.diff(z)).diff(z)
            )
            + 1 / rho * p.diff(z)
        )
        self.add_equation("continuity", continuity)
        self.add_equation("momentum_x", momentum_x)
        self.add_equation("momentum_y", momentum_y)
        if self.dim == 3:
            self.add_equation("momentum_z", momentum_z)

NormalDotVec

Bases: PDE

Normal Dot Vector.

\[ \mathbf{n} \cdot \mathbf{v} = 0 \]

Parameters:

Name Type Description Default
vec_keys Tuple[str, ...]

Keys for vectors, such as ("u", "v", "w") for velocity vector.

required
detach_keys Optional[Tuple[str, ...]]

Keys used for detach during computing. Defaults to None.

None

Examples:

>>> import ppsci
>>> pde = ppsci.equation.NormalDotVec(("u", "v", "w"))
Source code in ppsci/equation/pde/normal_dot_vec.py
class NormalDotVec(base.PDE):
    r"""Normal Dot Vector.

    $$
    \mathbf{n} \cdot \mathbf{v} = 0
    $$

    The registered residual ``"normal_dot_vec"`` pairs ``normal_x``,
    ``normal_y`` and ``normal_z`` with the given vector keys, in order.

    Args:
        vec_keys (Tuple[str, ...]): Keys for vectors, such as ("u", "v", "w") for
            velocity vector. Between 1 and 3 keys.
        detach_keys (Optional[Tuple[str, ...]]): Keys used for detach during computing.
            Defaults to None.

    Raises:
        ValueError: If ``vec_keys`` is empty or has more than 3 keys.

    Examples:
        >>> import ppsci
        >>> pde = ppsci.equation.NormalDotVec(("u", "v", "w"))
    """

    def __init__(
        self, vec_keys: Tuple[str, ...], detach_keys: Optional[Tuple[str, ...]] = None
    ):
        super().__init__()
        self.detach_keys = detach_keys
        if not vec_keys:
            raise ValueError(f"len(vec_keys)({len(vec_keys)}) should be larger than 0.")
        if len(vec_keys) > 3:
            # Only three normal components exist; extra keys would be silently dropped.
            raise ValueError(
                f"len(vec_keys)({len(vec_keys)}) should be no larger than 3, as only "
                "normal_x, normal_y and normal_z are available."
            )

        self.vec_keys = vec_keys
        vec_vars = self.create_symbols(" ".join(vec_keys))
        # For a single name `create_symbols` returns a bare Symbol, not a tuple.
        if not isinstance(vec_vars, tuple):
            vec_vars = (vec_vars,)
        normals = self.create_symbols("normal_x normal_y normal_z")

        normal_dot_vec = 0
        for normal, vec in zip(normals, vec_vars):
            normal_dot_vec += normal * vec

        self.add_equation("normal_dot_vec", normal_dot_vec)

Poisson

Bases: PDE

Class for poisson equation.

\[ \nabla^2 \varphi = C \]

Parameters:

Name Type Description Default
dim int

Dimension of equation.

required
detach_keys Optional[Tuple[str, ...]]

Keys used for detach during computing. Defaults to None.

None

Examples:

>>> import ppsci
>>> pde = ppsci.equation.Poisson(2)
Source code in ppsci/equation/pde/poisson.py
class Poisson(base.PDE):
    r"""Class for poisson equation.

    $$
    \nabla^2 \varphi = C
    $$

    The registered residual ``"poisson"`` is the sum of second derivatives of the
    output ``p`` (playing the role of :math:`\varphi`) over the input variables.
    The constant ``C`` is not part of the expression.

    Args:
        dim (int): Dimension of equation, one of 1, 2 or 3 (variables x, y, z).
        detach_keys (Optional[Tuple[str, ...]]): Keys used for detach during computing.
            Defaults to None.

    Raises:
        ValueError: If ``dim`` is not 1, 2 or 3.

    Examples:
        >>> import ppsci
        >>> pde = ppsci.equation.Poisson(2)
    """

    def __init__(self, dim: int, detach_keys: Optional[Tuple[str, ...]] = None):
        # Slicing "x y z" below would silently truncate or empty other dims.
        if dim not in (1, 2, 3):
            raise ValueError(f"dim should be 1, 2 or 3, but got {dim}.")
        super().__init__()
        self.detach_keys = detach_keys
        invars = self.create_symbols("x y z")[:dim]
        p = self.create_function("p", invars)
        self.dim = dim

        poisson = 0
        for invar in invars:
            poisson += p.diff(invar, 2)

        self.add_equation("poisson", poisson)

Vibration

Bases: PDE

Vortex induced vibration equation.

\[ \rho \dfrac{\partial^2 \eta}{\partial t^2} + e^{k1} \dfrac{\partial \eta}{\partial t} + e^{k2} \eta = f \]

Parameters:

Name Type Description Default
rho float

Generalized mass.

required
k1 float

Initial value of the learnable parameter for modal damping (the damping coefficient is \(e^{k1}\)).

required
k2 float

Initial value of the learnable parameter for generalized stiffness (the stiffness coefficient is \(e^{k2}\)).

required

Examples:

>>> import ppsci
>>> pde = ppsci.equation.Vibration(1.0, 4.0, -1.0)
Source code in ppsci/equation/pde/viv.py
class Vibration(base.PDE):
    r"""Vortex-induced vibration equation.

    $$
    \rho \dfrac{\partial^2 \eta}{\partial t^2} + e^{k1} \dfrac{\partial \eta}{\partial t} + e^{k2} \eta = f
    $$

    `k1` and `k2` become 0-D learnable parameters (default paddle dtype) that
    are appended to `learnable_parameters`. `rho` is used as a constant
    coefficient of the inertia term.

    Args:
        rho (float): Generalized mass.
        k1 (float): Initial value of the learnable parameter for modal damping.
        k2 (float): Initial value of the learnable parameter for generalized stiffness.

    Examples:
        >>> import ppsci
        >>> pde = ppsci.equation.Vibration(1.0, 4.0, -1.0)
    """

    def __init__(self, rho: float, k1: float, k2: float):
        super().__init__()
        self.rho = rho

        def _scalar_parameter(param_name: str, init_value: float):
            # 0-D (shape=[]) trainable tensor filled with `init_value`.
            return paddle.create_parameter(
                shape=[],
                dtype=paddle.get_default_dtype(),
                name=param_name,
                default_initializer=initializer.Constant(init_value),
            )

        self.k1 = _scalar_parameter("k1", k1)
        self.k2 = _scalar_parameter("k2", k2)
        self.learnable_parameters.append(self.k1)
        self.learnable_parameters.append(self.k2)

        t_f = self.create_symbols("t_f")
        eta = self.create_function("eta", (t_f,))
        # Symbols carry the parameters' names, presumably so the expression
        # evaluator can map them back to the learnable tensors -- confirm.
        k1_sym, k2_sym = (self.create_symbols(prm.name) for prm in (self.k1, self.k2))

        inertia = self.rho * eta.diff(t_f, 2)
        damping = sp.exp(k1_sym) * eta.diff(t_f)
        stiffness = sp.exp(k2_sym) * eta
        self.add_equation("f", inertia + damping + stiffness)

Volterra

Bases: PDE

A Volterra integral equation of the second kind, discretized with the Gaussian quadrature algorithm.

\[ x(t) - f(t)=\int_a^t K(t, s) x(s) d s \]

Volterra integral equation

Gaussian quadrature

Parameters:

Name Type Description Default
bound float

Lower bound a for Volterra integral equation.

required
num_points int

Number of sampled upper bounds t in the integral interval.

required
quad_deg int

Number of Gauss-Legendre quadrature points used for each integral.

required
kernel_func Callable

Kernel func K(t,s).

required
func Callable

x(t) - f(t) in Volterra integral equation.

required

Examples:

>>> import ppsci
>>> import numpy as np
>>> vol_eq = ppsci.equation.Volterra(
...     0, 12, 20, lambda t, s: np.exp(s - t), lambda out: out["u"],
... )
Source code in ppsci/equation/ide/volterra.py
class Volterra(PDE):
    r"""A Volterra integral equation of the second kind, discretized with Gaussian quadrature.

    $$
    x(t) - f(t)=\int_a^t K(t, s) x(s) d s
    $$

    [Volterra integral equation](https://en.wikipedia.org/wiki/Volterra_integral_equation)

    [Gaussian quadrature](https://en.wikipedia.org/wiki/Gaussian_quadrature#Change_of_interval)

    Args:
        bound (float): Lower bound `a` for Volterra integral equation.
        num_points (int): Number of sampled upper bounds `t` in the integral interval.
        quad_deg (int): Number of Gauss-Legendre quadrature points used for each integral.
        kernel_func (Callable): Kernel func `K(t,s)`.
        func (Callable): `x(t) - f(t)` in Volterra integral equation.

    Examples:
        >>> import ppsci
        >>> import numpy as np
        >>> vol_eq = ppsci.equation.Volterra(
        ...     0, 12, 20, lambda t, s: np.exp(s - t), lambda out: out["u"],
        ... )
    """

    # NOTE: evaluated once at import time and kept only for backward
    # compatibility. Instances read the default dtype in `__init__` instead,
    # so a `paddle.set_default_dtype(...)` call made after import is honored.
    dtype = paddle.get_default_dtype()

    def __init__(
        self,
        bound: float,
        num_points: int,
        quad_deg: int,
        kernel_func: Callable,
        func: Callable,
    ):
        super().__init__()
        # Per-instance dtype (shadows the import-time class attribute) so that
        # quadrature nodes/weights and the integral matrix match the current
        # default dtype used by the model inputs.
        self.dtype = paddle.get_default_dtype()
        self.bound = bound
        self.num_points = num_points
        self.quad_deg = quad_deg
        self.kernel_func = kernel_func
        self.func = func

        # Gauss-Legendre nodes and weights on the reference interval [-1, 1].
        quad_x, quad_w = np.polynomial.legendre.leggauss(quad_deg)
        self.quad_x = paddle.to_tensor(
            quad_x.astype(self.dtype).reshape([-1, 1])
        )  # [Q, 1]
        self.quad_w = quad_w.astype(self.dtype)  # [Q, ]

        def compute_volterra_func(out):
            x, u = out["x"], out["u"]
            lhs = self.func(out)

            int_mat = paddle.to_tensor(self._get_int_matrix(x), stop_gradient=False)
            rhs = paddle.mm(int_mat, u)  # (N, 1)

            # Only the first N rows correspond to the upper bounds t_i; the
            # remaining rows of the input are quadrature points.
            volterra = lhs[: len(rhs)] - rhs
            return volterra

        self.add_equation("volterra", compute_volterra_func)

    def get_quad_points(self, t: paddle.Tensor) -> paddle.Tensor:
        """Scale and shift quad_x from [-1, 1] to the range [a, t] for every t.

        Uses the affine change of interval s = (t - a) / 2 * xi + (t + a) / 2.
        Reference: https://en.wikipedia.org/wiki/Gaussian_quadrature#Change_of_interval

        Args:
            t (paddle.Tensor): Tensor array of upper bounds 't' for integral.

        Returns:
            paddle.Tensor: Transformed points in desired range with shape of [N, Q].
        """
        lower, upper = self.bound, t
        half_width = (upper - lower) / 2
        midpoint = (upper + lower) / 2
        # Assumes t has shape [N, 1]: [N, 1] @ [1, Q] -> [N, Q].
        return half_width @ self.quad_x.T + midpoint

    def _get_quad_weights(self, t: float) -> np.ndarray:
        """Scale the quadrature weights to the interval [a, t].

        Reference: https://en.wikipedia.org/wiki/Gaussian_quadrature#Change_of_interval

        Args:
            t (float): Upper bound 't' for integral.

        Returns:
            np.ndarray: Transformed weights in desired range with shape of [Q, ].
        """
        a, b = self.bound, t
        return (b - a) / 2 * self.quad_w

    def _get_int_matrix(self, x: paddle.Tensor) -> np.ndarray:
        """Build the matrix mapping `u` at all input points to the N integrals.

        Rows of `x` are indexed as: the first `num_points` rows are the upper
        bounds t_i, followed by `num_points` consecutive groups of `quad_deg`
        rows. Group i (rows [N + Q*i, N + Q*(i+1))) holds the quadrature
        points for t_i.

        Args:
            x (paddle.Tensor): Input points; presumably of shape
                [num_points * (1 + quad_deg), 1] -- confirm against caller.

        Returns:
            np.ndarray: Matrix of shape [num_points, num_points * (1 + quad_deg)]
                whose row i holds the scaled weights times K(t_i, s) in the
                columns of group i and zeros elsewhere.
        """
        int_mat = np.zeros(
            (self.num_points, self.num_points + (self.num_points * self.quad_deg)),
            dtype=self.dtype,
        )
        for i in range(self.num_points):
            xi = float(x[i])
            beg = self.num_points + self.quad_deg * i
            end = self.num_points + self.quad_deg * (i + 1)
            K = np.ravel(
                self.kernel_func(np.full((self.quad_deg, 1), xi), x[beg:end].numpy())
            )
            int_mat[i, beg:end] = self._get_quad_weights(xi) * K
        return int_mat
get_quad_points(t)

Scale and shift quad_x from [-1, 1] to the range [a, b]. Reference: https://en.wikipedia.org/wiki/Gaussian_quadrature#Change_of_interval

Parameters:

Name Type Description Default
t Tensor

Tensor array of upper bounds 't' for integral.

required

Returns:

Type Description
Tensor

paddle.Tensor: Transformed points in desired range with shape of [N, Q].

Source code in ppsci/equation/ide/volterra.py
def get_quad_points(self, t: paddle.Tensor) -> paddle.Tensor:
    """Scale and shift quad_x from [-1, 1] to the range [a, t] for every t.

    Uses the affine change of interval s = (t - a) / 2 * xi + (t + a) / 2.
    Reference: https://en.wikipedia.org/wiki/Gaussian_quadrature#Change_of_interval

    Args:
        t (paddle.Tensor): Tensor array of upper bounds 't' for integral.

    Returns:
        paddle.Tensor: Transformed points in desired range with shape of [N, Q].
    """
    lower, upper = self.bound, t
    half_width = (upper - lower) / 2
    midpoint = (upper + lower) / 2
    # Assumes t has shape [N, 1]: [N, 1] @ [1, Q] -> [N, Q].
    return half_width @ self.quad_x.T + midpoint

最后更新: November 17, 2023
创建日期: November 6, 2023