Loss Module

ppsci.loss

Loss

Bases: Layer

Base class for loss.

Source code in ppsci/loss/base.py
class Loss(nn.Layer):
    """Base class for loss."""

    def __init__(
        self,
        reduction: Literal["mean", "sum"],
        weight: Optional[Union[float, Dict[str, float]]] = None,
    ):
        super().__init__()
        self.reduction = reduction
        self.weight = weight

    def __str__(self):
        return f"{self.__class__.__name__}(reduction={self.reduction}, weight={self.weight})"

FunctionalLoss

Bases: Loss

Functional loss class, which allows using a custom loss-computing function, given as loss_expr, for complex computation cases.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss_expr | Callable | Expression of loss calculation. | required |
| reduction | Literal['mean', 'sum'] | Reduction method. Defaults to "mean". | 'mean' |
| weight | Optional[Union[float, Dict[str, float]]] | Weight for loss. Defaults to None. | None |

Examples:

>>> import ppsci
>>> import paddle.nn.functional as F
>>> def loss_expr(output_dict, *args):
...     losses = 0
...     for key in output_dict:
...         length = len(output_dict[key]) // 2
...         out = output_dict[key][:length]
...         label = output_dict[key][length:]
...         losses += F.mse_loss(out, label, "sum")
...     return losses
>>> loss = ppsci.loss.FunctionalLoss(loss_expr)
Source code in ppsci/loss/func.py
class FunctionalLoss(base.Loss):
    r"""Functional loss class, which allows to use custom loss computing function from given loss_expr for complex computation cases.

    Args:
        loss_expr (Callable): expression of loss calculation.
        reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
        weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.

    Examples:
        >>> import ppsci
        >>> import paddle.nn.functional as F
        >>> def loss_expr(output_dict, *args):
        ...     losses = 0
        ...     for key in output_dict:
        ...         length = len(output_dict[key]) // 2
        ...         out = output_dict[key][:length]
        ...         label = output_dict[key][length:]
        ...         losses += F.mse_loss(out, label, "sum")
        ...     return losses
        >>> loss = ppsci.loss.FunctionalLoss(loss_expr)
    """

    def __init__(
        self,
        loss_expr: Callable,
        reduction: Literal["mean", "sum"] = "mean",
        weight: Optional[Union[float, Dict[str, float]]] = None,
    ):
        if reduction not in ["mean", "sum"]:
            raise ValueError(
                f"reduction should be 'mean' or 'sum', but got {reduction}"
            )
        super().__init__(reduction, weight)
        self.loss_expr = loss_expr

    def forward(self, output_dict, label_dict=None, weight_dict=None):
        return self.loss_expr(output_dict, label_dict, weight_dict)
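
FunctionalLoss simply forwards output_dict, label_dict and weight_dict to loss_expr, so any callable with that signature works. A minimal sketch with a plain MSE expression (the key "u" is hypothetical):

>>> import paddle
>>> import paddle.nn.functional as F
>>> import ppsci
>>> def plain_mse(output_dict, label_dict=None, weight_dict=None):
...     return F.mse_loss(output_dict["u"], label_dict["u"], "mean")
>>> loss = ppsci.loss.FunctionalLoss(plain_mse)
>>> out = {"u": paddle.to_tensor([[1.0], [2.0]])}
>>> lab = {"u": paddle.to_tensor([[1.0], [0.0]])}
>>> loss(out, lab)  # mean of [0, 4] -> 2.0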

L1Loss

Bases: Loss

Class for l1 loss.

\[ L = \Vert \mathbf{x} - \mathbf{y} \Vert_1 \]
\[ \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} \]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| reduction | Literal['mean', 'sum'] | Reduction method. Defaults to "mean". | 'mean' |
| weight | Optional[Union[float, Dict[str, float]]] | Weight for loss. Defaults to None. | None |

Examples:

>>> import ppsci
>>> loss = ppsci.loss.L1Loss()
Source code in ppsci/loss/l1.py
class L1Loss(base.Loss):
    r"""Class for l1 loss.

    $$
    L = \Vert \mathbf{x} - \mathbf{y} \Vert_1
    $$

    $$
    \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N}
    $$

    Args:
        reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
        weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.

    Examples:
        >>> import ppsci
        >>> loss = ppsci.loss.L1Loss()
    """

    def __init__(
        self,
        reduction: Literal["mean", "sum"] = "mean",
        weight: Optional[Union[float, Dict[str, float]]] = None,
    ):
        if reduction not in ["mean", "sum"]:
            raise ValueError(
                f"reduction should be 'mean' or 'sum', but got {reduction}"
            )
        super().__init__(reduction, weight)

    def forward(self, output_dict, label_dict, weight_dict=None):
        losses = 0.0
        for key in label_dict:
            loss = F.l1_loss(output_dict[key], label_dict[key], "none")
            if weight_dict:
                loss *= weight_dict[key]

            if "area" in output_dict:
                loss *= output_dict["area"]

            loss = loss.sum(axis=1)

            if self.reduction == "sum":
                loss = loss.sum()
            elif self.reduction == "mean":
                loss = loss.mean()

            if isinstance(self.weight, (float, int)):
                loss *= self.weight
            elif isinstance(self.weight, dict) and key in self.weight:
                loss *= self.weight[key]

            losses += loss
        return losses
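
Note that forward first sums the absolute errors over the feature axis and only then applies the reduction over the batch. A numeric sketch (the key "v" is hypothetical):

>>> import paddle
>>> import ppsci
>>> out = {"v": paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])}
>>> lab = {"v": paddle.zeros([2, 2])}
>>> ppsci.loss.L1Loss("mean")(out, lab)  # (3 + 7) / 2 = 5.0
>>> ppsci.loss.L1Loss("sum")(out, lab)   # 3 + 7 = 10.0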

L2Loss

Bases: Loss

Class for l2 loss.

\[ L =\Vert \mathbf{x} - \mathbf{y} \Vert_2 \]
\[ \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} \]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| reduction | Literal['mean', 'sum'] | Reduction method. Defaults to "mean". | 'mean' |
| weight | Optional[Union[float, Dict[str, float]]] | Weight for loss. Defaults to None. | None |

Examples:

>>> import ppsci
>>> loss = ppsci.loss.L2Loss()
Source code in ppsci/loss/l2.py
class L2Loss(base.Loss):
    r"""Class for l2 loss.

    $$
    L =\Vert \mathbf{x} - \mathbf{y} \Vert_2
    $$

    $$
    \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N}
    $$

    Args:
        reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
        weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.

    Examples:
        >>> import ppsci
        >>> loss = ppsci.loss.L2Loss()
    """

    def __init__(
        self,
        reduction: Literal["mean", "sum"] = "mean",
        weight: Optional[Union[float, Dict[str, float]]] = None,
    ):
        if reduction not in ["mean", "sum"]:
            raise ValueError(
                f"reduction should be 'mean' or 'sum', but got {reduction}"
            )
        super().__init__(reduction, weight)

    def forward(self, output_dict, label_dict, weight_dict=None):
        losses = 0.0
        for key in label_dict:
            loss = F.mse_loss(output_dict[key], label_dict[key], "none")
            if weight_dict:
                loss *= weight_dict[key]

            if "area" in output_dict:
                loss *= output_dict["area"]

            loss = loss.sum(axis=1).sqrt()

            if self.reduction == "sum":
                loss = loss.sum()
            elif self.reduction == "mean":
                loss = loss.mean()

            if isinstance(self.weight, (float, int)):
                loss *= self.weight
            elif isinstance(self.weight, dict) and key in self.weight:
                loss *= self.weight[key]

            losses += loss
        return losses
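
The forward pass computes the per-sample Euclidean norm of the error (the square root of the summed squared errors of each row) before reducing, so it agrees with paddle.norm along axis 1 when no area or weight is applied. A sketch (the key "u" is hypothetical):

>>> import paddle
>>> import ppsci
>>> out = {"u": paddle.to_tensor([[3.0, 4.0]])}
>>> lab = {"u": paddle.zeros([1, 2])}
>>> ppsci.loss.L2Loss("mean")(out, lab)  # sqrt(3^2 + 4^2) = 5.0
>>> paddle.norm(out["u"] - lab["u"], p=2, axis=1).mean()  # same value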

L2RelLoss

Bases: Loss

Class for l2 relative loss.

\[ L = \dfrac{\Vert \mathbf{x} - \mathbf{y} \Vert_2}{\Vert \mathbf{y} \Vert_2} \]
\[ \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} \]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| reduction | Literal['mean', 'sum'] | Specifies the reduction to apply to the output: 'mean' \| 'sum'. Defaults to "mean". | 'mean' |
| weight | Optional[Union[float, Dict[str, float]]] | Weight for loss. Defaults to None. | None |

Examples:

>>> import ppsci
>>> loss = ppsci.loss.L2RelLoss()
Source code in ppsci/loss/l2.py
class L2RelLoss(base.Loss):
    r"""Class for l2 relative loss.

    $$
    L = \dfrac{\Vert \mathbf{x} - \mathbf{y} \Vert_2}{\Vert \mathbf{y} \Vert_2}
    $$

    $$
    \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N}
    $$

    Args:
        reduction (Literal["mean", "sum"], optional): Specifies the reduction to apply to the output: 'mean' | 'sum'. Defaults to "mean".
        weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.

    Examples:
        >>> import ppsci
        >>> loss = ppsci.loss.L2RelLoss()
    """

    def __init__(
        self,
        reduction: Literal["mean", "sum"] = "mean",
        weight: Optional[Union[float, Dict[str, float]]] = None,
    ):
        if reduction not in ["mean", "sum"]:
            raise ValueError(
                f"reduction should be 'mean' or 'sum', but got {reduction}"
            )
        super().__init__(reduction, weight)

    def rel_loss(self, x, y):
        batch_size = x.shape[0]
        x_ = x.reshape((batch_size, -1))
        y_ = y.reshape((batch_size, -1))
        diff_norms = paddle.norm(x_ - y_, p=2, axis=1)
        y_norms = paddle.norm(y_, p=2, axis=1)
        return diff_norms / y_norms

    def forward(self, output_dict, label_dict, weight_dict=None):
        losses = 0
        for key in label_dict:
            loss = self.rel_loss(output_dict[key], label_dict[key])
            if weight_dict:
                loss *= weight_dict[key]

            if self.reduction == "sum":
                loss = loss.sum()
            elif self.reduction == "mean":
                loss = loss.mean()

            if isinstance(self.weight, float):
                loss *= self.weight
            elif isinstance(self.weight, dict) and key in self.weight:
                loss *= self.weight[key]

            losses += loss

        return losses
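
Each sample is flattened and its error norm is divided by the corresponding label norm, making the loss scale-invariant with respect to the labels. A sketch (the key "u" is hypothetical):

>>> import paddle
>>> import ppsci
>>> out = {"u": paddle.to_tensor([[1.0, 1.0]])}
>>> lab = {"u": paddle.to_tensor([[2.0, 2.0]])}
>>> ppsci.loss.L2RelLoss("mean")(out, lab)  # sqrt(2) / sqrt(8) = 0.5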

MAELoss

Bases: Loss

Class for mean absolute error loss.

\[ L = \begin{cases} \dfrac{1}{N} \Vert {\mathbf{x}-\mathbf{y}} \Vert_1, & \text{if reduction='mean'} \\ \Vert {\mathbf{x}-\mathbf{y}} \Vert_1, & \text{if reduction='sum'} \end{cases} \]
\[ \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} \]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| reduction | Literal['mean', 'sum'] | Reduction method. Defaults to "mean". | 'mean' |
| weight | Optional[Union[float, Dict[str, float]]] | Weight for loss. Defaults to None. | None |

Examples:

>>> import ppsci
>>> loss = ppsci.loss.MAELoss("mean")
Source code in ppsci/loss/mae.py
class MAELoss(base.Loss):
    r"""Class for mean absolute error loss.

    $$
    L =
    \begin{cases}
        \dfrac{1}{N} \Vert {\mathbf{x}-\mathbf{y}} \Vert_1, & \text{if reduction='mean'} \\
        \Vert {\mathbf{x}-\mathbf{y}} \Vert_1, & \text{if reduction='sum'}
    \end{cases}
    $$

    $$
    \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N}
    $$

    Args:
        reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
        weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.

    Examples:
        >>> import ppsci
        >>> loss = ppsci.loss.MAELoss("mean")
    """

    def __init__(
        self,
        reduction: Literal["mean", "sum"] = "mean",
        weight: Optional[Union[float, Dict[str, float]]] = None,
    ):
        if reduction not in ["mean", "sum"]:
            raise ValueError(
                f"reduction should be 'mean' or 'sum', but got {reduction}"
            )
        super().__init__(reduction, weight)

    def forward(self, output_dict, label_dict, weight_dict=None):
        losses = 0.0
        for key in label_dict:
            loss = F.l1_loss(output_dict[key], label_dict[key], "none")
            if weight_dict:
                loss *= weight_dict[key]

            if "area" in output_dict:
                loss *= output_dict["area"]

            if self.reduction == "sum":
                loss = loss.sum()
            elif self.reduction == "mean":
                loss = loss.mean()
            if isinstance(self.weight, (float, int)):
                loss *= self.weight
            elif isinstance(self.weight, dict) and key in self.weight:
                loss *= self.weight[key]

            losses += loss
        return losses
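
Unlike L1Loss, MAELoss applies the reduction directly over all elements without first summing over the feature axis, so "mean" is the element-wise mean absolute error. A sketch with the same values as the L1Loss sketch above, which gave 5.0 there:

>>> import paddle
>>> import ppsci
>>> out = {"u": paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])}
>>> lab = {"u": paddle.zeros([2, 2])}
>>> ppsci.loss.MAELoss("mean")(out, lab)  # (1 + 2 + 3 + 4) / 4 = 2.5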

MSELoss

Bases: Loss

Class for mean squared error loss.

\[ L = \begin{cases} \dfrac{1}{N} \Vert {\mathbf{x}-\mathbf{y}} \Vert_2^2, & \text{if reduction='mean'} \\ \Vert {\mathbf{x}-\mathbf{y}} \Vert_2^2, & \text{if reduction='sum'} \end{cases} \]
\[ \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} \]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| reduction | Literal['mean', 'sum'] | Reduction method. Defaults to "mean". | 'mean' |
| weight | Optional[Union[float, Dict[str, float]]] | Weight for loss. Defaults to None. | None |

Examples:

>>> import ppsci
>>> loss = ppsci.loss.MSELoss("mean")
Source code in ppsci/loss/mse.py
class MSELoss(base.Loss):
    r"""Class for mean squared error loss.

    $$
    L =
    \begin{cases}
        \dfrac{1}{N} \Vert {\mathbf{x}-\mathbf{y}} \Vert_2^2, & \text{if reduction='mean'} \\
        \Vert {\mathbf{x}-\mathbf{y}} \Vert_2^2, & \text{if reduction='sum'}
    \end{cases}
    $$

    $$
    \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N}
    $$

    Args:
        reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
        weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.

    Examples:
        >>> import ppsci
        >>> loss = ppsci.loss.MSELoss("mean")
    """

    def __init__(
        self,
        reduction: Literal["mean", "sum"] = "mean",
        weight: Optional[Union[float, Dict[str, float]]] = None,
    ):
        if reduction not in ["mean", "sum"]:
            raise ValueError(
                f"reduction should be 'mean' or 'sum', but got {reduction}"
            )
        super().__init__(reduction, weight)

    def forward(self, output_dict, label_dict, weight_dict=None):
        losses = 0.0
        for key in label_dict:
            loss = F.mse_loss(output_dict[key], label_dict[key], "none")
            if weight_dict:
                loss *= weight_dict[key]

            if "area" in output_dict:
                loss *= output_dict["area"]

            if self.reduction == "sum":
                loss = loss.sum()
            elif self.reduction == "mean":
                loss = loss.mean()
            if isinstance(self.weight, (float, int)):
                loss *= self.weight
            elif isinstance(self.weight, dict) and key in self.weight:
                loss *= self.weight[key]

            losses += loss
        return losses

MSELossWithL2Decay

Bases: MSELoss

MSELoss with L2 decay.

\[ L = \begin{cases} \dfrac{1}{N} \Vert {\mathbf{x}-\mathbf{y}} \Vert_2^2 + \displaystyle\sum_{i=1}^{M}{\Vert \mathbf{K_i} \Vert_F^2}, & \text{if reduction='mean'} \\ \Vert {\mathbf{x}-\mathbf{y}} \Vert_2^2 + \displaystyle\sum_{i=1}^{M}{\Vert \mathbf{K_i} \Vert_F^2}, & \text{if reduction='sum'} \end{cases} \]
\[ \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N}, \mathbf{K_i} \in \mathcal{R}^{O_i \times P_i} \]

\(M\) is the number of outputs that regularization is applied to.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| reduction | Literal['mean', 'sum'] | Specifies the reduction to apply to the output: 'mean' \| 'sum'. Defaults to "mean". | 'mean' |
| regularization_dict | Optional[Dict[str, float]] | Regularization dictionary. Defaults to None. | None |
| weight | Optional[Union[float, Dict[str, float]]] | Weight for loss. Defaults to None. | None |

Raises:

| Type | Description |
| --- | --- |
| ValueError | reduction should be 'mean' or 'sum'. |

Examples:

>>> import ppsci
>>> loss = ppsci.loss.MSELossWithL2Decay("mean", {"k_matrix": 2.0})
Source code in ppsci/loss/mse.py
class MSELossWithL2Decay(MSELoss):
    r"""MSELoss with L2 decay.

    $$
    L =
    \begin{cases}
        \dfrac{1}{N} \Vert {\mathbf{x}-\mathbf{y}} \Vert_2^2 + \displaystyle\sum_{i=1}^{M}{\Vert \mathbf{K_i} \Vert_F^2}, & \text{if reduction='mean'} \\
         \Vert {\mathbf{x}-\mathbf{y}} \Vert_2^2 + \displaystyle\sum_{i=1}^{M}{\Vert \mathbf{K_i} \Vert_F^2}, & \text{if reduction='sum'}
    \end{cases}
    $$

    $$
    \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N}, \mathbf{K_i} \in \mathcal{R}^{O_i \times P_i}
    $$

    $M$ is the number of outputs that regularization is applied to.

    Args:
        reduction (Literal["mean", "sum"], optional): Specifies the reduction to apply to the output: 'mean' | 'sum'. Defaults to "mean".
        regularization_dict (Optional[Dict[str, float]]): Regularization dictionary. Defaults to None.
        weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.

    Raises:
        ValueError: reduction should be 'mean' or 'sum'.

    Examples:
        >>> import ppsci
        >>> loss = ppsci.loss.MSELossWithL2Decay("mean", {"k_matrix": 2.0})
    """

    def __init__(
        self,
        reduction: Literal["mean", "sum"] = "mean",
        regularization_dict: Optional[Dict[str, float]] = None,
        weight: Optional[Union[float, Dict[str, float]]] = None,
    ):
        if reduction not in ["mean", "sum"]:
            raise ValueError(
                f"reduction should be 'mean' or 'sum', but got {reduction}"
            )
        super().__init__(reduction, weight)
        self.regularization_dict = regularization_dict

    def forward(self, output_dict, label_dict, weight_dict=None):
        losses = super().forward(output_dict, label_dict, weight_dict)

        if self.regularization_dict is not None:
            for reg_key, reg_weight in self.regularization_dict.items():
                loss = output_dict[reg_key].pow(2).sum()
                losses += loss * reg_weight
        return losses
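
The regularization term is computed from the entries of output_dict named in regularization_dict and added on top of the plain MSE term. A sketch reusing the "k_matrix" key from the example above (the key "u" is hypothetical):

>>> import paddle
>>> import ppsci
>>> loss = ppsci.loss.MSELossWithL2Decay("sum", {"k_matrix": 2.0})
>>> out = {
...     "u": paddle.to_tensor([[1.0]]),
...     "k_matrix": paddle.to_tensor([[1.0, 2.0]]),
... }
>>> lab = {"u": paddle.to_tensor([[0.0]])}
>>> loss(out, lab)  # 1 + 2.0 * (1 + 4) = 11.0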

IntegralLoss

Bases: Loss

Class for integral loss with Monte-Carlo integration algorithm.

\[ L = \begin{cases} \dfrac{1}{N} \Vert \displaystyle\sum_{i=1}^{M}{\mathbf{s}_i \cdot \mathbf{x}_i} - \mathbf{y} \Vert_2^2, & \text{if reduction='mean'} \\ \Vert \displaystyle\sum_{i=1}^{M}{\mathbf{s}_i \cdot \mathbf{x}_i} - \mathbf{y} \Vert_2^2, & \text{if reduction='sum'} \end{cases} \]
\[ \mathbf{x}, \mathbf{s} \in \mathcal{R}^{M \times N}, \mathbf{y} \in \mathcal{R}^{N} \]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| reduction | Literal['mean', 'sum'] | Reduction method. Defaults to "mean". | 'mean' |
| weight | Optional[Union[float, Dict[str, float]]] | Weight for loss. Defaults to None. | None |

Examples:

>>> import ppsci
>>> loss = ppsci.loss.IntegralLoss("mean")
Source code in ppsci/loss/integral.py
class IntegralLoss(base.Loss):
    r"""Class for integral loss with Monte-Carlo integration algorithm.

    $$
    L =
    \begin{cases}
        \dfrac{1}{N} \Vert \displaystyle\sum_{i=1}^{M}{\mathbf{s}_i \cdot \mathbf{x}_i} - \mathbf{y} \Vert_2^2, & \text{if reduction='mean'} \\
        \Vert \displaystyle\sum_{i=1}^{M}{\mathbf{s}_i \cdot \mathbf{x}_i} - \mathbf{y} \Vert_2^2, & \text{if reduction='sum'}
    \end{cases}
    $$

    $$
    \mathbf{x}, \mathbf{s} \in \mathcal{R}^{M \times N}, \mathbf{y} \in \mathcal{R}^{N}
    $$

    Args:
        reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
        weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.

    Examples:
        >>> import ppsci
        >>> loss = ppsci.loss.IntegralLoss("mean")
    """

    def __init__(
        self,
        reduction: Literal["mean", "sum"] = "mean",
        weight: Optional[Union[float, Dict[str, float]]] = None,
    ):
        if reduction not in ["mean", "sum"]:
            raise ValueError(
                f"reduction should be 'mean' or 'sum', but got {reduction}"
            )
        super().__init__(reduction, weight)

    def forward(self, output_dict, label_dict, weight_dict=None):
        losses = 0.0
        for key in label_dict:
            loss = F.mse_loss(
                (output_dict[key] * output_dict["area"]).sum(axis=1),
                label_dict[key],
                "none",
            )
            if weight_dict:
                loss *= weight_dict[key]

            if self.reduction == "sum":
                loss = loss.sum()
            elif self.reduction == "mean":
                loss = loss.mean()

            if isinstance(self.weight, (float, int)):
                loss *= self.weight
            elif isinstance(self.weight, dict) and key in self.weight:
                loss *= self.weight[key]

            losses += loss
        return losses
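
output_dict must carry an "area" entry holding the quadrature weights \(\mathbf{s}\); their product with the outputs is summed over axis 1 to form the Monte-Carlo estimate that is compared against the labeled integral. A sketch (the key "q" is hypothetical):

>>> import paddle
>>> import ppsci
>>> out = {
...     "q": paddle.to_tensor([[1.0, 2.0, 3.0]]),
...     "area": paddle.to_tensor([[0.5, 0.5, 0.5]]),
... }
>>> lab = {"q": paddle.to_tensor([3.0])}
>>> ppsci.loss.IntegralLoss("mean")(out, lab)  # estimate 0.5*(1+2+3) = 3.0 matches the label -> 0.0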

PeriodicL1Loss

Bases: Loss

Class for periodic l1 loss.

\[ L = \Vert \mathbf{x_l}-\mathbf{x_r} \Vert_1 \]

\(\mathbf{x_l} \in \mathcal{R}^{N}\) is the first half of batch output, \(\mathbf{x_r} \in \mathcal{R}^{N}\) is the second half of batch output.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| reduction | Literal['mean', 'sum'] | Reduction method. Defaults to "mean". | 'mean' |
| weight | Optional[Union[float, Dict[str, float]]] | Weight for loss. Defaults to None. | None |

Examples:

>>> import ppsci
>>> loss = ppsci.loss.PeriodicL1Loss("mean")
Source code in ppsci/loss/l1.py
class PeriodicL1Loss(base.Loss):
    r"""Class for periodic l1 loss.

    $$
    L = \Vert \mathbf{x_l}-\mathbf{x_r} \Vert_1
    $$

    $\mathbf{x_l} \in \mathcal{R}^{N}$ is the first half of batch output,
    $\mathbf{x_r} \in \mathcal{R}^{N}$ is the second half of batch output.

    Args:
        reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
        weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.

    Examples:
        >>> import ppsci
        >>> loss = ppsci.loss.PeriodicL1Loss("mean")
    """

    def __init__(
        self,
        reduction: Literal["mean", "sum"] = "mean",
        weight: Optional[Union[float, Dict[str, float]]] = None,
    ):
        if reduction not in ["mean", "sum"]:
            raise ValueError(
                f"reduction should be 'mean' or 'sum', but got {reduction}"
            )
        super().__init__(reduction, weight)

    def forward(self, output_dict, label_dict, weight_dict=None):
        losses = 0.0
        for key in label_dict:
            n_output = len(output_dict[key])
            if n_output % 2 > 0:
                raise ValueError(
                    f"Length of output({n_output}) of key({key}) should be even."
                )

            n_output //= 2
            loss = F.l1_loss(
                output_dict[key][:n_output], output_dict[key][n_output:], "none"
            )
            if weight_dict:
                loss *= weight_dict[key]
            if "area" in output_dict:
                loss *= output_dict["area"]

            loss = loss.sum(axis=1)

            if self.reduction == "sum":
                loss = loss.sum()
            elif self.reduction == "mean":
                loss = loss.mean()

            if isinstance(self.weight, (float, int)):
                loss *= self.weight
            elif isinstance(self.weight, dict) and key in self.weight:
                loss *= self.weight[key]

            losses += loss
        return losses
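
The batch of each key must have an even length: the first half is compared against the second half (for example, samples on the two sides of a periodic boundary). Label values are never read; only the keys select which outputs to use. A sketch (the key "u" is hypothetical):

>>> import paddle
>>> import ppsci
>>> out = {"u": paddle.to_tensor([[1.0], [2.0], [1.5], [2.0]])}
>>> lab = {"u": paddle.zeros([4, 1])}  # values unused, only the key is read
>>> ppsci.loss.PeriodicL1Loss("mean")(out, lab)  # (|1-1.5| + |2-2|) / 2 = 0.25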

PeriodicL2Loss

Bases: Loss

Class for Periodic l2 loss.

\[ L = \Vert \mathbf{x_l}-\mathbf{x_r} \Vert_2 \]

\(\mathbf{x_l} \in \mathcal{R}^{N}\) is the first half of batch output, \(\mathbf{x_r} \in \mathcal{R}^{N}\) is the second half of batch output.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| reduction | Literal['mean', 'sum'] | Reduction method. Defaults to "mean". | 'mean' |
| weight | Optional[Union[float, Dict[str, float]]] | Weight for loss. Defaults to None. | None |

Examples:

>>> import ppsci
>>> loss = ppsci.loss.PeriodicL2Loss()
Source code in ppsci/loss/l2.py
class PeriodicL2Loss(base.Loss):
    r"""Class for Periodic l2 loss.

    $$
    L = \Vert \mathbf{x_l}-\mathbf{x_r} \Vert_2
    $$

    $\mathbf{x_l} \in \mathcal{R}^{N}$ is the first half of batch output,
    $\mathbf{x_r} \in \mathcal{R}^{N}$ is the second half of batch output.

    Args:
        reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
        weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.

    Examples:
        >>> import ppsci
        >>> loss = ppsci.loss.PeriodicL2Loss()
    """

    def __init__(
        self,
        reduction: Literal["mean", "sum"] = "mean",
        weight: Optional[Union[float, Dict[str, float]]] = None,
    ):
        if reduction not in ["mean", "sum"]:
            raise ValueError(
                f"reduction should be 'mean' or 'sum', but got {reduction}"
            )
        super().__init__(reduction, weight)

    def forward(self, output_dict, label_dict, weight_dict=None):
        losses = 0.0
        for key in label_dict:
            n_output = len(output_dict[key])
            if n_output % 2 > 0:
                raise ValueError(
                    f"Length of output({n_output}) of key({key}) should be even."
                )
            n_output //= 2

            loss = F.mse_loss(
                output_dict[key][:n_output], output_dict[key][n_output:], "none"
            )
            if weight_dict:
                loss *= weight_dict[key]

            if "area" in output_dict:
                loss *= output_dict["area"]

            loss = loss.sum(axis=1).sqrt()

            if self.reduction == "sum":
                loss = loss.sum()
            elif self.reduction == "mean":
                loss = loss.mean()

            if isinstance(self.weight, (float, int)):
                loss *= self.weight
            elif isinstance(self.weight, dict) and key in self.weight:
                loss *= self.weight[key]

            losses += loss
        return losses

PeriodicMSELoss

Bases: Loss

Class for periodic mean squared error loss.

\[ L = \begin{cases} \dfrac{1}{N} \Vert \mathbf{x_l}-\mathbf{x_r} \Vert_2^2, & \text{if reduction='mean'} \\ \Vert \mathbf{x_l}-\mathbf{x_r} \Vert_2^2, & \text{if reduction='sum'} \end{cases} \]

\(\mathbf{x_l} \in \mathcal{R}^{N}\) is the first half of batch output, \(\mathbf{x_r} \in \mathcal{R}^{N}\) is the second half of batch output.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| reduction | Literal['mean', 'sum'] | Reduction method. Defaults to "mean". | 'mean' |
| weight | Optional[Union[float, Dict[str, float]]] | Weight for loss. Defaults to None. | None |
Source code in ppsci/loss/mse.py
class PeriodicMSELoss(base.Loss):
    r"""Class for periodic mean squared error loss.

    $$
    L =
    \begin{cases}
        \dfrac{1}{N} \Vert \mathbf{x_l}-\mathbf{x_r} \Vert_2^2, & \text{if reduction='mean'} \\
        \Vert \mathbf{x_l}-\mathbf{x_r} \Vert_2^2, & \text{if reduction='sum'}
    \end{cases}
    $$

    $\mathbf{x_l} \in \mathcal{R}^{N}$ is the first half of batch output,
    $\mathbf{x_r} \in \mathcal{R}^{N}$ is the second half of batch output.

    Args:
        reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
        weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
    """

    def __init__(
        self,
        reduction: Literal["mean", "sum"] = "mean",
        weight: Optional[Union[float, Dict[str, float]]] = None,
    ):
        if reduction not in ["mean", "sum"]:
            raise ValueError(
                f"reduction should be 'mean' or 'sum', but got {reduction}"
            )
        super().__init__(reduction, weight)

    def forward(self, output_dict, label_dict, weight_dict=None):
        losses = 0.0
        for key in label_dict:
            n_output = len(output_dict[key])
            if n_output % 2 > 0:
                raise ValueError(
                    f"Length of output({n_output}) of key({key}) should be even."
                )

            n_output //= 2
            loss = F.mse_loss(
                output_dict[key][:n_output], output_dict[key][n_output:], "none"
            )
            if weight_dict:
                loss *= weight_dict[key]
            if "area" in output_dict:
                loss *= output_dict["area"]

            if self.reduction == "sum":
                loss = loss.sum()
            elif self.reduction == "mean":
                loss = loss.mean()

            if isinstance(self.weight, (float, int)):
                loss *= self.weight
            elif isinstance(self.weight, dict) and key in self.weight:
                loss *= self.weight[key]

            losses += loss
        return losses

ppsci.loss.mtl

LossAggregator

Base class of loss aggregators, mainly for multi-task learning.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Layer | Training model. | required |
Source code in ppsci/loss/mtl/base.py
class LossAggregator:
    """Base class of loss aggregator mainly for multitask learning.

    Args:
        model (nn.Layer): Training model.
    """

    def __init__(self, model: nn.Layer) -> None:
        self.model = model
        self.step = 0
        self.param_num = 0
        for param in self.model.parameters():
            if not param.stop_gradient:
                self.param_num += 1

    def __call__(self, losses, step: int = 0):
        self.losses = losses
        self.loss_num = len(losses)
        self.step = step
        return self

    def backward(self) -> None:
        raise NotImplementedError(
            f"'backward' should be implemented in subclass {self.__class__.__name__}"
        )
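
Subclasses implement backward, which consumes self.losses and writes gradients into the model parameters. A minimal sketch that just sums the task losses (assuming the module path ppsci/loss/mtl/base.py shown above is importable as ppsci.loss.mtl.base):

>>> import paddle
>>> from ppsci.loss.mtl.base import LossAggregator
>>> class SumAggregator(LossAggregator):
...     def backward(self) -> None:
...         # no weighting or projection, just add up the task losses
...         total = self.losses[0]
...         for loss in self.losses[1:]:
...             total = total + loss
...         total.backward()
>>> model = paddle.nn.Linear(3, 4)
>>> aggregator = SumAggregator(model)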

PCGrad

Bases: LossAggregator

Projecting Conflicting Gradients

Gradient Surgery for Multi-Task Learning: https://papers.nips.cc/paper/2020/hash/3fe78a8acf5fda99de95303940a2420c-Abstract.html

Code reference: https://github.com/tianheyu927/PCGrad/blob/master/PCGrad_tf.py

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Layer | Training model. | required |

Examples:

>>> import paddle
>>> from ppsci.loss import mtl
>>> model = paddle.nn.Linear(3, 4)
>>> loss_aggregator = mtl.PCGrad(model)
>>> for i in range(5):
...     x1 = paddle.randn([8, 3])
...     x2 = paddle.randn([8, 3])
...     y1 = model(x1)
...     y2 = model(x2)
...     loss1 = paddle.sum(y1)
...     loss2 = paddle.sum((y2 - 2) ** 2)
...     loss_aggregator([loss1, loss2]).backward()
Source code in ppsci/loss/mtl/pcgrad.py
class PCGrad(LossAggregator):
    r"""
    **P**rojecting **C**onflicting Gradients

    [Gradient Surgery for Multi-Task Learning](https://papers.nips.cc/paper/2020/hash/3fe78a8acf5fda99de95303940a2420c-Abstract.html)

    Code reference: [https://github.com/tianheyu927/PCGrad/blob/master/PCGrad_tf.py](\
        https://github.com/tianheyu927/PCGrad/blob/master/PCGrad_tf.py)

    Args:
        model (nn.Layer): Training model.

    Examples:
        >>> import paddle
        >>> from ppsci.loss import mtl
        >>> model = paddle.nn.Linear(3, 4)
        >>> loss_aggregator = mtl.PCGrad(model)
        >>> for i in range(5):
        ...     x1 = paddle.randn([8, 3])
        ...     x2 = paddle.randn([8, 3])
        ...     y1 = model(x1)
        ...     y2 = model(x2)
        ...     loss1 = paddle.sum(y1)
        ...     loss2 = paddle.sum((y2 - 2) ** 2)
        ...     loss_aggregator([loss1, loss2]).backward()
    """

    def __init__(self, model: nn.Layer) -> None:
        super().__init__(model)
        self._zero = paddle.zeros([])

    def backward(self) -> None:
        np.random.shuffle(self.losses)
        grads_list = self._compute_grads()
        with paddle.no_grad():
            refined_grads = self._refine_grads(grads_list)
            self._set_grads(refined_grads)

    def _compute_grads(self) -> List[paddle.Tensor]:
        # compute all gradients derived by each loss
        grads_list = []  # one flattened gradient vector per loss
        for loss in self.losses:
            # backward with current loss
            loss.backward()
            grads_list.append(
                paddle.concat(
                    [
                        param.grad.clone().reshape([-1])
                        for param in self.model.parameters()
                        if param.grad is not None
                    ],
                    axis=0,
                )
            )
            # clear gradients for current loss for not affecting other loss
            self.model.clear_gradients()

        return grads_list

    def _refine_grads(self, grads_list: List[paddle.Tensor]) -> List[paddle.Tensor]:
        def proj_grad(grad: paddle.Tensor):
            for k in range(self.loss_num):
                inner_product = paddle.sum(grad * grads_list[k])
                proj_direction = inner_product / paddle.sum(
                    grads_list[k] * grads_list[k]
                )
                grad = grad - paddle.minimum(proj_direction, self._zero) * grads_list[k]
            return grad

        grads_list = [proj_grad(grad) for grad in grads_list]

        # Unpack flattened projected gradients back to their original shapes.
        proj_grads: List[paddle.Tensor] = []
        for j in range(self.loss_num):
            start_idx = 0
            for idx, var in enumerate(self.model.parameters()):
                grad_shape = var.shape
                flatten_dim = var.numel()
                refined_grad = grads_list[j][start_idx : start_idx + flatten_dim]
                refined_grad = paddle.reshape(refined_grad, grad_shape)
                if len(proj_grads) < self.param_num:
                    proj_grads.append(refined_grad)
                else:
                    proj_grads[idx] += refined_grad
                start_idx += flatten_dim
        return proj_grads

    def _set_grads(self, grads_list: List[paddle.Tensor]) -> None:
        for i, param in enumerate(self.model.parameters()):
            param.grad = grads_list[i]
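
Since backward writes the projected gradients into each parameter's grad, a standard optimizer step applies them. A sketch extending the example above with an Adam optimizer (the learning rate is arbitrary):

>>> import paddle
>>> from ppsci.loss import mtl
>>> model = paddle.nn.Linear(3, 4)
>>> opt = paddle.optimizer.Adam(learning_rate=1e-3, parameters=model.parameters())
>>> loss_aggregator = mtl.PCGrad(model)
>>> for i in range(5):
...     loss1 = paddle.sum(model(paddle.randn([8, 3])))
...     loss2 = paddle.sum((model(paddle.randn([8, 3])) - 2) ** 2)
...     loss_aggregator([loss1, loss2]).backward()
...     opt.step()
...     opt.clear_grad()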

AGDA

Bases: LossAggregator

Adaptive Gradient Descent Algorithm

Physics-informed neural network based on a new adaptive gradient descent algorithm for solving partial differential equations of flow problems: https://pubs.aip.org/aip/pof/article-abstract/35/6/063608/2899773/Physics-informed-neural-network-based-on-a-new

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Layer | Training model. | required |
| M | int | Smoothing period. Defaults to 100. | 100 |
| gamma | float | Smooth factor. Defaults to 0.999. | 0.999 |

Examples:

>>> import paddle
>>> from ppsci.loss import mtl
>>> model = paddle.nn.Linear(3, 4)
>>> loss_aggregator = mtl.AGDA(model)
>>> for i in range(5):
...     x1 = paddle.randn([8, 3])
...     x2 = paddle.randn([8, 3])
...     y1 = model(x1)
...     y2 = model(x2)
...     loss1 = paddle.sum(y1)
...     loss2 = paddle.sum((y2 - 2) ** 2)
...     loss_aggregator([loss1, loss2]).backward()
Source code in ppsci/loss/mtl/agda.py
class AGDA(LossAggregator):
    r"""
    **A**daptive **G**radient **D**escent **A**lgorithm

    [Physics-informed neural network based on a new adaptive gradient descent algorithm for solving partial differential equations of flow problems](\
        https://pubs.aip.org/aip/pof/article-abstract/35/6/063608/2899773/Physics-informed-neural-network-based-on-a-new)

    Args:
        model (nn.Layer): Training model.
        M (int, optional): Smoothing period. Defaults to 100.
        gamma (float, optional): Smooth factor. Defaults to 0.999.

    Examples:
        >>> import paddle
        >>> from ppsci.loss import mtl
        >>> model = paddle.nn.Linear(3, 4)
        >>> loss_aggregator = mtl.AGDA(model)
        >>> for i in range(5):
        ...     x1 = paddle.randn([8, 3])
        ...     x2 = paddle.randn([8, 3])
        ...     y1 = model(x1)
        ...     y2 = model(x2)
        ...     loss1 = paddle.sum(y1)
        ...     loss2 = paddle.sum((y2 - 2) ** 2)
        ...     loss_aggregator([loss1, loss2]).backward()
    """

    def __init__(self, model: nn.Layer, M: int = 100, gamma: float = 0.999) -> None:
        super().__init__(model)
        self.M = M
        self.gamma = gamma
        self.Lf_smooth = 0
        self.Lu_smooth = 0
        self.Lf_tilde_acc = 0.0
        self.Lu_tilde_acc = 0.0

    def __call__(self, losses, step: int = 0):
        if len(losses) != 2:
            raise ValueError(
                f"Number of losses(tasks) for AGDA shoule be 2, but got {len(losses)}"
            )
        return super().__call__(losses, step)

    def backward(self) -> None:
        grads_list = self._compute_grads()
        with paddle.no_grad():
            refined_grads = self._refine_grads(grads_list)
            self._set_grads(refined_grads)

    def _compute_grads(self) -> List[paddle.Tensor]:
        # compute all gradients derived by each loss
        grads_list = []  # one flattened gradient vector per loss
        for loss in self.losses:
            # backward with current loss
            loss.backward()
            grads_list.append(
                paddle.concat(
                    [
                        param.grad.clone().reshape([-1])
                        for param in self.model.parameters()
                        if param.grad is not None
                    ],
                    axis=0,
                )
            )
            # clear gradients for current loss for not affecting other loss
            self.model.clear_gradients()

        return grads_list

    def _refine_grads(self, grads_list: List[paddle.Tensor]) -> List[paddle.Tensor]:
        # compute moving average of L^smooth_i(n) - eq.(16)
        self.Lf_smooth = (
            self.gamma * self.Lf_smooth + (1 - self.gamma) * self.losses[0].item()
        )
        self.Lu_smooth = (
            self.gamma * self.Lu_smooth + (1 - self.gamma) * self.losses[1].item()
        )

        # compute L^smooth_i(kM) - eq.(17)
        if self.step % self.M == 0:
            Lf_smooth_kM = self.Lf_smooth
            Lu_smooth_kM = self.Lu_smooth
        Lf_tilde = self.Lf_smooth / Lf_smooth_kM
        Lu_tilde = self.Lu_smooth / Lu_smooth_kM

        # compute r_i(n) - eq.(18)
        self.Lf_tilde_acc += Lf_tilde
        self.Lu_tilde_acc += Lu_tilde
        rf = Lf_tilde / self.Lf_tilde_acc
        ru = Lu_tilde / self.Lu_tilde_acc

        # compute E(g(n)) - step1(1)
        gf_magn = (grads_list[0] * grads_list[0]).sum().sqrt()
        gu_magn = (grads_list[1] * grads_list[1]).sum().sqrt()
        Eg = (gf_magn + gu_magn) / 2

        # compute \omega_f(n) - step1(2)
        omega_f = (rf * (Eg - gf_magn) + gf_magn) / gf_magn
        omega_u = (ru * (Eg - gu_magn) + gu_magn) / gu_magn

        # compute g_bar(n) - step1(3)
        gf_bar = omega_f * grads_list[0]
        gu_bar = omega_u * grads_list[1]

        # compute gradient projection - step2(1)
        dot_product = (gf_bar * gu_bar).sum()
        if dot_product < 0:
            gu_bar = gu_bar - (dot_product / (gf_bar * gf_bar).sum()) * gf_bar
        grads_list = [gf_bar, gu_bar]

        proj_grads: List[paddle.Tensor] = []
        for j in range(len(self.losses)):
            start_idx = 0
            for idx, var in enumerate(self.model.parameters()):
                grad_shape = var.shape
                flatten_dim = var.numel()
                refined_grad = grads_list[j][start_idx : start_idx + flatten_dim]
                refined_grad = paddle.reshape(refined_grad, grad_shape)
                if len(proj_grads) < self.param_num:
                    proj_grads.append(refined_grad)
                else:
                    proj_grads[idx] += refined_grad
                start_idx += flatten_dim
        return proj_grads

    def _set_grads(self, grads_list: List[paddle.Tensor]) -> None:
        for i, param in enumerate(self.model.parameters()):
            param.grad = grads_list[i]
