Metric (Evaluation Metrics) Module

ppsci.metric

Metric

Bases: Layer

Base class for metric.

Source code in ppsci/metric/base.py
class Metric(nn.Layer):
    """Base class for metric."""

    def __init__(self, keep_batch: bool = False):
        super().__init__()
        self.keep_batch = keep_batch
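
As a hedged sketch of how this base class is extended (the subclass below is hypothetical, not part of ppsci), a custom metric reuses the base __init__ to store keep_batch and implements forward to return one value per variable key:

>>> import paddle
>>> from ppsci.metric import base
>>> class MaxAbsError(base.Metric):  # hypothetical metric, for illustration only
...     def forward(self, output_dict, label_dict):
...         # one scalar per variable key, mirroring the built-in metrics
...         return {
...             key: paddle.abs(output_dict[key] - label_dict[key]).max()
...             for key in label_dict
...         }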

FunctionalMetric

Bases: Metric

Functional metric class, which allows a custom metric-computing function to be supplied via metric_expr for complex computation cases.

Parameters:

    metric_expr (Callable): Expression of metric calculation. (required)
    keep_batch (bool, optional): Whether to keep the batch axis. Defaults to False.

Examples:

>>> import ppsci
>>> import paddle
>>> def metric_expr(output_dict, *args):
...     rel_l2 = 0
...     for key in output_dict:
...         length = len(output_dict[key]) // 2
...         out = output_dict[key][:length]
...         label = output_dict[key][length:]
...         rel_l2 += paddle.norm(out - label) / paddle.norm(label)
...     return {"l2": rel_l2}
>>> metric_dict = ppsci.metric.FunctionalMetric(metric_expr)
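
A minimal usage sketch, assuming the metric_expr defined above (tensor values are illustrative): the first half of each tensor is treated as the prediction and the second half as the label, and the label dictionary may be omitted because metric_expr ignores it.

>>> output_dict = {"u": paddle.to_tensor([1.0, 2.0, 1.1, 1.9])}
>>> result = metric_dict(output_dict)
>>> print(list(result.keys()))
['l2']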
Source code in ppsci/metric/func.py
class FunctionalMetric(base.Metric):
    r"""Functional metric class, which allows to use custom metric computing function from given metric_expr for complex computation cases.

    Args:
        metric_expr (Callable): expression of metric calculation.
        keep_batch (bool, optional): Whether keep batch axis. Defaults to False.

    Examples:
        >>> import ppsci
        >>> import paddle
        >>> def metric_expr(output_dict, *args):
        ...     rel_l2 = 0
        ...     for key in output_dict:
        ...         length = int(len(output_dict[key])/2)
        ...         out = output_dict[key][:length]
        ...         label = output_dict[key][length:]
        ...         rel_l2 += paddle.norm(out - label) / paddle.norm(label)
        ...     return {"l2": rel_l2}
        >>> metric_dict = ppsci.metric.FunctionalMetric(metric_expr)
    """

    def __init__(
        self,
        metric_expr: Callable,
        keep_batch: bool = False,
    ):
        super().__init__(keep_batch)
        self.metric_expr = metric_expr

    def forward(self, output_dict, label_dict=None):
        return self.metric_expr(output_dict, label_dict)

MAE

Bases: Metric

Mean absolute error.

\[ metric = \dfrac{1}{N} \Vert \mathbf{x} - \mathbf{y} \Vert_1 \]
\[ \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} \]

Parameters:

    keep_batch (bool, optional): Whether to keep the batch axis. Defaults to False.

Examples:

>>> import ppsci
>>> metric = ppsci.metric.MAE()
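
Continuing the example with illustrative values: each dictionary maps a variable name to a tensor, and one value is returned per key.

>>> import paddle
>>> output_dict = {"u": paddle.to_tensor([[1.0], [2.0]])}
>>> label_dict = {"u": paddle.to_tensor([[0.5], [2.5]])}
>>> result = metric(output_dict, label_dict)  # {"u": 0.5}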
Source code in ppsci/metric/mae.py
class MAE(base.Metric):
    r"""Mean absolute error.

    $$
    metric = \dfrac{1}{N} \Vert \mathbf{x} - \mathbf{y} \Vert_1
    $$

    $$
    \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N}
    $$

    Args:
        keep_batch (bool, optional): Whether keep batch axis. Defaults to False.

    Examples:
        >>> import ppsci
        >>> metric = ppsci.metric.MAE()
    """

    def __init__(self, keep_batch: bool = False):
        super().__init__(keep_batch)

    @paddle.no_grad()
    def forward(self, output_dict, label_dict):
        metric_dict = {}
        for key in label_dict:
            mae = F.l1_loss(output_dict[key], label_dict[key], "none")
            if self.keep_batch:
                metric_dict[key] = mae.mean(axis=tuple(range(1, mae.ndim)))
            else:
                metric_dict[key] = mae.mean()

        return metric_dict

MSE

Bases: Metric

Mean square error.

\[ metric = \dfrac{1}{N} \Vert \mathbf{x} - \mathbf{y} \Vert_2^2 \]
\[ \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} \]

Parameters:

    keep_batch (bool, optional): Whether to keep the batch axis. Defaults to False.

Examples:

>>> import ppsci
>>> metric = ppsci.metric.MSE()
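
A sketch of the keep_batch behavior (values are illustrative): with keep_batch=True, the result keeps one value per sample, averaging only over the non-batch axes.

>>> import paddle
>>> metric_pb = ppsci.metric.MSE(keep_batch=True)
>>> output_dict = {"u": paddle.to_tensor([[1.0, 1.0], [2.0, 2.0]])}
>>> label_dict = {"u": paddle.to_tensor([[0.0, 0.0], [2.0, 2.0]])}
>>> result = metric_pb(output_dict, label_dict)  # {"u": [1.0, 0.0]}, one value per sample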
Source code in ppsci/metric/mse.py
class MSE(base.Metric):
    r"""Mean square error

    $$
    metric = \dfrac{1}{N} \Vert \mathbf{x} - \mathbf{y} \Vert_2^2
    $$

    $$
    \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N}
    $$

    Args:
        keep_batch (bool, optional): Whether keep batch axis. Defaults to False.

    Examples:
        >>> import ppsci
        >>> metric = ppsci.metric.MSE()
    """

    def __init__(self, keep_batch: bool = False):
        super().__init__(keep_batch)

    @paddle.no_grad()
    def forward(self, output_dict, label_dict):
        metric_dict = {}
        for key in label_dict:
            mse = F.mse_loss(output_dict[key], label_dict[key], "none")
            if self.keep_batch:
                metric_dict[key] = mse.mean(axis=tuple(range(1, mse.ndim)))
            else:
                metric_dict[key] = mse.mean()

        return metric_dict

RMSE

Bases: Metric

Root mean square error.

\[ metric = \sqrt{\dfrac{1}{N} \Vert \mathbf{x} - \mathbf{y} \Vert_2^2} \]
\[ \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} \]

Parameters:

    keep_batch (bool, optional): Whether to keep the batch axis. Must be False for this metric; passing True raises ValueError. Defaults to False.

Examples:

>>> import ppsci
>>> metric = ppsci.metric.RMSE()
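
A brief sketch with illustrative values; note from the source below that RMSE does not support keep_batch=True.

>>> import paddle
>>> output_dict = {"u": paddle.to_tensor([3.0, 0.0])}
>>> label_dict = {"u": paddle.to_tensor([0.0, 4.0])}
>>> result = metric(output_dict, label_dict)  # {"u": sqrt((9 + 16) / 2)}, about 3.5355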
Source code in ppsci/metric/rmse.py
class RMSE(base.Metric):
    r"""Root mean square error

    $$
    metric = \sqrt{\dfrac{1}{N} \Vert \mathbf{x} - \mathbf{y} \Vert_2^2}
    $$

    $$
    \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N}
    $$

    Args:
        keep_batch (bool, optional): Whether keep batch axis. Defaults to False.

    Examples:
        >>> import ppsci
        >>> metric = ppsci.metric.RMSE()
    """

    def __init__(self, keep_batch: bool = False):
        if keep_batch:
            raise ValueError(f"keep_batch should be False, but got {keep_batch}.")
        super().__init__(keep_batch)

    @paddle.no_grad()
    def forward(self, output_dict, label_dict):
        metric_dict = {}
        for key in label_dict:
            rmse = F.mse_loss(output_dict[key], label_dict[key], "mean") ** 0.5
            metric_dict[key] = rmse

        return metric_dict

L2Rel

Bases: Metric

Class for l2 relative error.

NOTE: This metric API is slightly different from MeanL2Rel; the differences are as follows:

  • L2Rel regards the input sample as a whole and calculates the l2 relative error of that whole;
  • MeanL2Rel calculates L2Rel separately for each input sample and returns the average l2 relative error over all samples.
\[ metric = \dfrac{\Vert \mathbf{x} - \mathbf{y} \Vert_2}{\max(\Vert \mathbf{y} \Vert_2, \epsilon)} \]
\[ \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} \]

Parameters:

    keep_batch (bool, optional): Whether to keep the batch axis. Must be False for this metric; passing True raises ValueError. Defaults to False.

Examples:

>>> import ppsci
>>> metric = ppsci.metric.L2Rel()
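
A sketch contrasting with MeanL2Rel below (values are illustrative): the whole tensor is treated as one vector, so a single relative error is returned per key.

>>> import paddle
>>> output_dict = {"u": paddle.to_tensor([0.0, 2.0])}
>>> label_dict = {"u": paddle.to_tensor([0.0, 1.0])}
>>> result = metric(output_dict, label_dict)  # {"u": ||x - y||_2 / ||y||_2 = 1.0}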
Source code in ppsci/metric/l2_rel.py
class L2Rel(base.Metric):
    r"""Class for l2 relative error.

    NOTE: This metric API is slightly different from `MeanL2Rel`, difference is as below:

    - `L2Rel` regards the input sample as a whole and calculates the l2 relative error of the whole;
    - `MeanL2Rel` will calculate L2Rel separately for each input sample and return the average of l2 relative error for all samples.

    $$
    metric = \dfrac{\Vert \mathbf{x} - \mathbf{y} \Vert_2}{\max(\Vert \mathbf{y} \Vert_2, \epsilon)}
    $$

    $$
    \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N}
    $$

    Args:
        keep_batch (bool, optional): Whether keep batch axis. Defaults to False.

    Examples:
        >>> import ppsci
        >>> metric = ppsci.metric.L2Rel()
    """

    # NOTE: Avoid divide by zero in result
    # see https://github.com/scikit-learn/scikit-learn/pull/15007
    EPS: float = np.finfo(np.float32).eps

    def __init__(self, keep_batch: bool = False):
        if keep_batch:
            raise ValueError(f"keep_batch should be False, but got {keep_batch}.")
        super().__init__(keep_batch)

    @paddle.no_grad()
    def forward(self, output_dict, label_dict):
        metric_dict = {}
        for key in label_dict:
            rel_l2 = paddle.norm(label_dict[key] - output_dict[key], p=2) / paddle.norm(
                label_dict[key], p=2
            ).clip(min=self.EPS)
            metric_dict[key] = rel_l2

        return metric_dict

MeanL2Rel

Bases: Metric

Class for mean l2 relative error.

NOTE: This metric API is slightly different from L2Rel; the differences are as follows:

  • MeanL2Rel calculates L2Rel separately for each input sample and returns the average l2 relative error over all samples;
  • L2Rel regards the input sample as a whole and calculates the l2 relative error of that whole.
\[ metric = \dfrac{1}{M} \sum_{i=1}^{M}\dfrac{\Vert \mathbf{x_i} - \mathbf{y_i} \Vert_2}{\max(\Vert \mathbf{y_i} \Vert_2, \epsilon) } \]
\[ \mathbf{x_i}, \mathbf{y_i} \in \mathcal{R}^{N} \]

Parameters:

    keep_batch (bool, optional): Whether to keep the batch axis. Defaults to False.

Examples:

>>> import ppsci
>>> metric = ppsci.metric.MeanL2Rel()
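
A sketch of the per-sample behavior (values are illustrative): the norm is taken along axis 1, so each row is one sample, and the per-sample relative errors are averaged unless keep_batch=True.

>>> import paddle
>>> output_dict = {"u": paddle.to_tensor([[0.0, 2.0], [3.0, 0.0]])}
>>> label_dict = {"u": paddle.to_tensor([[0.0, 1.0], [1.0, 0.0]])}
>>> result = metric(output_dict, label_dict)  # {"u": (1.0 + 2.0) / 2 = 1.5}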
Source code in ppsci/metric/l2_rel.py
class MeanL2Rel(base.Metric):
    r"""Class for mean l2 relative error.

    NOTE: This metric API is slightly different from `L2Rel`, difference is as below:

    - `MeanL2Rel` will calculate L2Rel separately for each input sample and return the average of l2 relative error for all samples.
    - `L2Rel` regards the input sample as a whole and calculates the l2 relative error of the whole;

    $$
    metric = \dfrac{1}{M} \sum_{i=1}^{M}\dfrac{\Vert \mathbf{x_i} - \mathbf{y_i} \Vert_2}{\max(\Vert \mathbf{y_i} \Vert_2, \epsilon) }
    $$

    $$
    \mathbf{x_i}, \mathbf{y_i} \in \mathcal{R}^{N}
    $$

    Args:
        keep_batch (bool, optional): Whether keep batch axis. Defaults to False.

    Examples:
        >>> import ppsci
        >>> metric = ppsci.metric.MeanL2Rel()
    """

    # NOTE: Avoid divide by zero in result
    # see https://github.com/scikit-learn/scikit-learn/pull/15007
    EPS: float = np.finfo(np.float32).eps

    def __init__(self, keep_batch: bool = False):
        super().__init__(keep_batch)

    @paddle.no_grad()
    def forward(self, output_dict, label_dict):
        metric_dict = {}
        for key in label_dict:
            rel_l2 = paddle.norm(
                label_dict[key] - output_dict[key], p=2, axis=1
            ) / paddle.norm(label_dict[key], p=2, axis=1).clip(min=self.EPS)
            if self.keep_batch:
                metric_dict[key] = rel_l2
            else:
                metric_dict[key] = rel_l2.mean()

        return metric_dict

LatitudeWeightedACC

Bases: Metric

Latitude weighted anomaly correlation coefficient.

\[ metric = \dfrac{\sum\limits_{m,n}{L_mX_{mn}Y_{mn}}}{\sqrt{\sum\limits_{m,n}{L_mX_{mn}^{2}}\sum\limits_{m,n}{L_mY_{mn}^{2}}}} \]
\[ L_m = N_{lat}\dfrac{\cos(lat_m)}{\sum\limits_{j=1}^{N_{lat}}\cos(lat_j)} \]

\(lat_m\) is the latitude at index m, and \(N_{lat}\) is the number of latitudes, set by num_lat.

Parameters:

    num_lat (int): Number of latitudes. (required)
    mean (Optional[Union[np.array, Tuple[float, ...]]]): Mean of training data; may be None. (required)
    keep_batch (bool, optional): Whether to keep the batch axis. Defaults to False.
    variable_dict (Optional[Dict[str, int]]): Variable dictionary; the key is the name of a variable and the value is its index. Defaults to None.
    unlog (bool, optional): Whether to apply expm1 to all elements in the array. Defaults to False.
    scale (float, optional): The scale value used after expm1. Defaults to 1e-5.

Examples:

>>> import numpy as np
>>> import ppsci
>>> mean = np.random.randn(20, 720, 1440)
>>> metric = ppsci.metric.LatitudeWeightedACC(720, mean=mean)
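
A shape-oriented sketch (the sizes and the mean=None choice are illustrative): inputs follow a (batch, channel, latitude, longitude) layout, where the latitude axis must match num_lat.

>>> import paddle
>>> acc = ppsci.metric.LatitudeWeightedACC(8, mean=None)  # small num_lat for illustration
>>> output_dict = {"u": paddle.randn([2, 3, 8, 16])}  # (batch, channel, lat, lon)
>>> label_dict = {"u": paddle.randn([2, 3, 8, 16])}
>>> result = acc(output_dict, label_dict)  # {"u": scalar ACC averaged over batch and channel}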
Source code in ppsci/metric/anomaly_coef.py
class LatitudeWeightedACC(base.Metric):
    r"""Latitude weighted anomaly correlation coefficient.

    $$
    metric =
        \dfrac{\sum\limits_{m,n}{L_mX_{mn}Y_{mn}}}{\sqrt{\sum\limits_{m,n}{L_mX_{mn}^{2}}\sum\limits_{m,n}{L_mY_{mn}^{2}}}}
    $$

    $$
    L_m = N_{lat}\dfrac{\cos(lat_m)}{\sum\limits_{j=1}^{N_{lat}}\cos(lat_j)}
    $$

    $lat_m$ is the latitude at m.
    $N_{lat}$ is the number of latitude set by `num_lat`.

    Args:
        num_lat (int): Number of latitude.
        mean (Optional[Union[np.array, Tuple[float, ...]]]): Mean of training data. Defaults to None.
        keep_batch (bool, optional): Whether keep batch axis. Defaults to False.
        variable_dict (Optional[Dict[str, int]]): Variable dictionary, the key is the name of a variable and
            the value is its index. Defaults to None.
        unlog (bool, optional): whether calculate expm1 for all elements in the array. Defaults to False.
        scale (float, optional): The scale value used after expm1. Defaults to 1e-5.

    Examples:
        >>> import numpy as np
        >>> import ppsci
        >>> mean = np.random.randn(20, 720, 1440)
        >>> metric = ppsci.metric.LatitudeWeightedACC(720, mean=mean)
    """

    def __init__(
        self,
        num_lat: int,
        mean: Optional[Union[np.array, Tuple[float, ...]]],
        keep_batch: bool = False,
        variable_dict: Optional[Dict[str, int]] = None,
        unlog: bool = False,
        scale: float = 1e-5,
    ):
        super().__init__(keep_batch)
        self.num_lat = num_lat
        self.mean = (
            None if mean is None else paddle.to_tensor(mean, paddle.get_default_dtype())
        )
        self.variable_dict = variable_dict
        self.unlog = unlog
        self.scale = scale

        self.weight = self.get_latitude_weight(num_lat)

    def get_latitude_weight(self, num_lat: int = 720):
        lat_t = paddle.linspace(start=0, stop=1, num=num_lat)
        lat_t = paddle.cos(3.1416 * (0.5 - lat_t))
        weight = num_lat * lat_t / paddle.sum(lat_t)
        weight = weight.reshape((1, 1, -1, 1))
        return weight

    def scale_expm1(self, x: paddle.Tensor):
        return self.scale * paddle.expm1(x)

    @paddle.no_grad()
    def forward(self, output_dict, label_dict):
        metric_dict = {}
        for key in label_dict:
            output = (
                self.scale_expm1(output_dict[key]) if self.unlog else output_dict[key]
            )
            label = self.scale_expm1(label_dict[key]) if self.unlog else label_dict[key]

            if self.mean is not None:
                output = output - self.mean
                label = label - self.mean

            rmse = paddle.sum(
                self.weight * output * label, axis=(-1, -2)
            ) / paddle.sqrt(
                paddle.sum(self.weight * output**2, axis=(-1, -2))
                * paddle.sum(self.weight * label**2, axis=(-1, -2))
            )

            if self.variable_dict is not None:
                for variable_name, idx in self.variable_dict.items():
                    if self.keep_batch:
                        metric_dict[f"{key}.{variable_name}"] = rmse[:, idx]
                    else:
                        metric_dict[f"{key}.{variable_name}"] = rmse[:, idx].mean()
            else:
                if self.keep_batch:
                    metric_dict[key] = rmse.mean(axis=1)
                else:
                    metric_dict[key] = rmse.mean()
        return metric_dict

LatitudeWeightedRMSE

Bases: Metric

Latitude weighted root mean square error.

\[ metric =\sqrt{\dfrac{1}{MN}\sum\limits_{m=1}^{M}\sum\limits_{n=1}^{N}L_m(X_{mn}-Y_{mn})^{2}} \]
\[ L_m = N_{lat}\dfrac{\cos(lat_m)}{\sum\limits_{j=1}^{N_{lat}}\cos(lat_j)} \]

\(lat_m\) is the latitude at index m, and \(N_{lat}\) is the number of latitudes, set by num_lat.

Parameters:

    num_lat (int): Number of latitudes. (required)
    std (Optional[Union[np.array, Tuple[float, ...]]]): Standard deviation of the training dataset. Defaults to None.
    keep_batch (bool, optional): Whether to keep the batch axis. Defaults to False.
    variable_dict (Optional[Dict[str, int]]): Variable dictionary; the key is the name of a variable and the value is its index. Defaults to None.
    unlog (bool, optional): Whether to apply expm1 to all elements in the array. Defaults to False.
    scale (float, optional): The scale value used after expm1. Defaults to 1e-5.

Examples:

>>> import numpy as np
>>> import ppsci
>>> std = np.random.randn(20, 1, 1)
>>> metric = ppsci.metric.LatitudeWeightedRMSE(720, std=std)
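
A shape-oriented sketch (sizes are illustrative): when std is given, it rescales the per-channel RMSE, so its length should match the channel axis.

>>> import numpy as np
>>> import paddle
>>> rmse_metric = ppsci.metric.LatitudeWeightedRMSE(8, std=np.ones((3,), dtype="float32"))
>>> output_dict = {"u": paddle.randn([2, 3, 8, 16])}  # (batch, channel, lat, lon)
>>> label_dict = {"u": paddle.randn([2, 3, 8, 16])}
>>> result = rmse_metric(output_dict, label_dict)  # {"u": scalar latitude-weighted RMSE}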
Source code in ppsci/metric/rmse.py
class LatitudeWeightedRMSE(base.Metric):
    r"""Latitude weighted root mean square error.

    $$
    metric =\sqrt{\dfrac{1}{MN}\sum\limits_{m=1}^{M}\sum\limits_{n=1}^{N}L_m(X_{mn}-Y_{mn})^{2}}
    $$

    $$
    L_m = N_{lat}\dfrac{\cos(lat_m)}{\sum\limits_{j=1}^{N_{lat}}\cos(lat_j)}
    $$

    $lat_m$ is the latitude at m.
    $N_{lat}$ is the number of latitude set by `num_lat`.

    Args:
        num_lat (int): Number of latitude.
        std (Optional[Union[np.array, Tuple[float, ...]]]): Standard Deviation of training dataset. Defaults to None.
        keep_batch (bool, optional): Whether keep batch axis. Defaults to False.
        variable_dict (Optional[Dict[str, int]]): Variable dictionary, the key is the name of a variable and
            the value is its index. Defaults to None.
        unlog (bool, optional): whether calculate expm1 for all elements in the array. Defaults to False.
        scale (float, optional): The scale value used after expm1. Defaults to 1e-5.

    Examples:
        >>> import numpy as np
        >>> import ppsci
        >>> std = np.random.randn(20, 1, 1)
        >>> metric = ppsci.metric.LatitudeWeightedRMSE(720, std=std)
    """

    def __init__(
        self,
        num_lat: int,
        std: Optional[Union[np.array, Tuple[float, ...]]] = None,
        keep_batch: bool = False,
        variable_dict: Dict[str, int] = None,
        unlog: bool = False,
        scale: float = 1e-5,
    ):
        super().__init__(keep_batch)
        self.num_lat = num_lat
        self.std = (
            None
            if std is None
            else paddle.to_tensor(std, paddle.get_default_dtype()).reshape((1, -1))
        )
        self.variable_dict = variable_dict
        self.unlog = unlog
        self.scale = scale
        self.weight = self.get_latitude_weight(num_lat)

    def get_latitude_weight(self, num_lat: int = 720):
        lat_t = paddle.linspace(start=0, stop=1, num=num_lat)
        lat_t = paddle.cos(3.1416 * (0.5 - lat_t))
        weight = num_lat * lat_t / paddle.sum(lat_t)
        weight = weight.reshape((1, 1, -1, 1))
        return weight

    def scale_expm1(self, x: paddle.Tensor):
        return self.scale * paddle.expm1(x)

    @paddle.no_grad()
    def forward(self, output_dict, label_dict):
        metric_dict = {}
        for key in label_dict:
            output = (
                self.scale_expm1(output_dict[key]) if self.unlog else output_dict[key]
            )
            label = self.scale_expm1(label_dict[key]) if self.unlog else label_dict[key]

            mse = F.mse_loss(output, label, "none")
            rmse = (mse * self.weight).mean(axis=(-1, -2)) ** 0.5
            if self.std is not None:
                rmse = rmse * self.std
            if self.variable_dict is not None:
                for variable_name, idx in self.variable_dict.items():
                    metric_dict[f"{key}.{variable_name}"] = (
                        rmse[:, idx] if self.keep_batch else rmse[:, idx].mean()
                    )
            else:
                metric_dict[key] = rmse.mean(axis=1) if self.keep_batch else rmse.mean()

        return metric_dict

Last updated: November 17, 2023
Created: November 6, 2023