跳转至

Utils.initializer(初始化) 模块

ppsci.utils.initializer

The initialization methods in this module are aligned with PyTorch's initialization. To use PaddlePaddle's own initialization methods instead, please refer to paddle.nn.initializer.

This code is based on torch.nn.init. The copyright of pytorch/pytorch is a BSD-style license, as found in the LICENSE file.

uniform_(tensor, a, b)

Modify tensor inplace using uniform_.

Parameters:

Name Type Description Default
tensor Tensor

Paddle Tensor.

required
a float

min value.

required
b float

max value.

required

Returns:

Type Description
Tensor

paddle.Tensor: Initialized tensor.

Examples:

>>> import paddle
>>> import ppsci
>>> param = paddle.empty((128, 256), "float32")
>>> param = ppsci.utils.initializer.uniform_(param, -1, 1)
Source code in ppsci/utils/initializer.py
def uniform_(tensor: paddle.Tensor, a: float, b: float) -> paddle.Tensor:
    """Fill ``tensor`` in place with values drawn uniformly between ``a`` and ``b``.

    Args:
        tensor (paddle.Tensor): Paddle Tensor to initialize.
        a (float): Lower bound (min value) of the sampled range.
        b (float): Upper bound (max value) of the sampled range.

    Returns:
        paddle.Tensor: Initialized tensor.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> param = paddle.empty((128, 256), "float32")
        >>> param = ppsci.utils.initializer.uniform_(param, -1, 1)
    """
    low, high = a, b
    return _no_grad_uniform_(tensor, low, high)

normal_(tensor, mean=0.0, std=1.0)

Modify tensor inplace using normal_.

Parameters:

Name Type Description Default
tensor Tensor

Paddle Tensor.

required
mean float

mean value. Defaults to 0.0.

0.0
std float

std value. Defaults to 1.0.

1.0

Returns:

Type Description
Tensor

paddle.Tensor: Initialized tensor.

Examples:

>>> import paddle
>>> import ppsci
>>> param = paddle.empty((128, 256), "float32")
>>> param = ppsci.utils.initializer.normal_(param, 0, 1)
Source code in ppsci/utils/initializer.py
def normal_(
    tensor: paddle.Tensor, mean: float = 0.0, std: float = 1.0
) -> paddle.Tensor:
    """Fill ``tensor`` in place with samples from a normal distribution.

    Args:
        tensor (paddle.Tensor): Paddle Tensor to initialize.
        mean (float, optional): Mean of the distribution. Defaults to 0.0.
        std (float, optional): Standard deviation of the distribution. Defaults to 1.0.

    Returns:
        paddle.Tensor: Initialized tensor.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> param = paddle.empty((128, 256), "float32")
        >>> param = ppsci.utils.initializer.normal_(param, 0, 1)
    """
    mu, sigma = mean, std
    return _no_grad_normal_(tensor, mu, sigma)

trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0)

Modify tensor inplace using trunc_normal_.

Parameters:

Name Type Description Default
tensor Tensor

Paddle Tensor.

required
mean float

The mean of the normal distribution. Defaults to 0.0.

0.0
std float

The standard deviation of the normal distribution. Defaults to 1.0.

1.0
a float

The minimum cutoff value. Defaults to -2.0.

-2.0
b float

The maximum cutoff value. Defaults to 2.0.

2.0

Returns:

Type Description
Tensor

paddle.Tensor: Initialized tensor.

Examples:

>>> import paddle
>>> import ppsci
>>> param = paddle.empty((128, 256), "float32")
>>> param = ppsci.utils.initializer.trunc_normal_(param, 0.0, 1.0)
Source code in ppsci/utils/initializer.py
def trunc_normal_(
    tensor: paddle.Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
) -> paddle.Tensor:
    """Fill ``tensor`` in place with samples from a truncated normal distribution.

    Args:
        tensor (paddle.Tensor): Paddle Tensor to initialize.
        mean (float, optional): The mean of the normal distribution. Defaults to 0.0.
        std (float, optional): The standard deviation of the normal distribution.
            Defaults to 1.0.
        a (float, optional): The minimum cutoff value. Defaults to -2.0.
        b (float, optional): The maximum cutoff value. Defaults to 2.0.

    Returns:
        paddle.Tensor: Initialized tensor.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> param = paddle.empty((128, 256), "float32")
        >>> param = ppsci.utils.initializer.trunc_normal_(param, 0.0, 1.0)
    """
    lower, upper = a, b
    return _no_grad_trunc_normal_(tensor, mean, std, lower, upper)

constant_(tensor, value=0.0)

Modify tensor inplace using constant_.

Parameters:

Name Type Description Default
tensor Tensor

Paddle Tensor.

required
value float

value to fill tensor. Defaults to 0.0.

0.0

Returns:

Type Description
Tensor

paddle.Tensor: Initialized tensor.

Examples:

>>> import paddle
>>> import ppsci
>>> param = paddle.empty((128, 256), "float32")
>>> param = ppsci.utils.initializer.constant_(param, 2)
Source code in ppsci/utils/initializer.py
def constant_(tensor: paddle.Tensor, value: float = 0.0) -> paddle.Tensor:
    """Fill every element of ``tensor`` in place with ``value``.

    Args:
        tensor (paddle.Tensor): Paddle Tensor to initialize.
        value (float, optional): Value written into the tensor. Defaults to 0.0.

    Returns:
        paddle.Tensor: Initialized tensor.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> param = paddle.empty((128, 256), "float32")
        >>> param = ppsci.utils.initializer.constant_(param, 2)
    """
    fill_value = value
    return _no_grad_fill_(tensor, fill_value)

ones_(tensor)

Modify tensor inplace using ones_.

Parameters:

Name Type Description Default
tensor Tensor

Paddle Tensor.

required

Returns:

Type Description
Tensor

paddle.Tensor: Initialized tensor.

Examples:

>>> import paddle
>>> import ppsci
>>> param = paddle.empty((128, 256), "float32")
>>> param = ppsci.utils.initializer.ones_(param)
Source code in ppsci/utils/initializer.py
def ones_(tensor: paddle.Tensor) -> paddle.Tensor:
    """Fill every element of ``tensor`` in place with 1.

    Args:
        tensor (paddle.Tensor): Paddle Tensor to initialize.

    Returns:
        paddle.Tensor: Initialized tensor.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> param = paddle.empty((128, 256), "float32")
        >>> param = ppsci.utils.initializer.ones_(param)
    """
    one = 1
    return _no_grad_fill_(tensor, one)

zeros_(tensor)

Modify tensor inplace using zeros_.

Parameters:

Name Type Description Default
tensor Tensor

Paddle Tensor.

required

Returns:

Type Description
Tensor

paddle.Tensor: Initialized tensor.

Examples:

>>> import paddle
>>> import ppsci
>>> param = paddle.empty((128, 256), "float32")
>>> param = ppsci.utils.initializer.zeros_(param)
Source code in ppsci/utils/initializer.py
def zeros_(tensor: paddle.Tensor) -> paddle.Tensor:
    """Fill every element of ``tensor`` in place with 0.

    Args:
        tensor (paddle.Tensor): Paddle Tensor to initialize.

    Returns:
        paddle.Tensor: Initialized tensor.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> param = paddle.empty((128, 256), "float32")
        >>> param = ppsci.utils.initializer.zeros_(param)
    """
    zero = 0
    return _no_grad_fill_(tensor, zero)

xavier_uniform_(tensor, gain=1.0, reverse=False)

Modify tensor inplace using xavier_uniform_.

Parameters:

Name Type Description Default
tensor Tensor

Paddle Tensor.

required
gain float

Hyperparameter. Defaults to 1.0.

1.0
reverse bool

Tensor data format order, False by default as [fout, fin, ...]. Defaults to False.

False

Returns:

Type Description
Tensor

paddle.Tensor: Initialized tensor.

Examples:

>>> import paddle
>>> import ppsci
>>> param = paddle.empty((128, 256), "float32")
>>> param = ppsci.utils.initializer.xavier_uniform_(param)
Source code in ppsci/utils/initializer.py
def xavier_uniform_(
    tensor: paddle.Tensor, gain: float = 1.0, reverse: bool = False
) -> paddle.Tensor:
    """Fill ``tensor`` in place using Xavier (Glorot) uniform initialization.

    Samples from ``U(-bound, bound)``, where
    ``bound = sqrt(3) * gain * sqrt(2 / (fan_in + fan_out))``.

    Args:
        tensor (paddle.Tensor): Paddle Tensor to initialize.
        gain (float, optional): Scaling factor applied to the standard deviation.
            Defaults to 1.0.
        reverse (bool, optional): Tensor data format order, False by default as
            [fout, fin, ...]. Defaults to False.

    Returns:
        paddle.Tensor: Initialized tensor.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> param = paddle.empty((128, 256), "float32")
        >>> param = ppsci.utils.initializer.xavier_uniform_(param)
    """
    n_in, n_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
    std_dev = gain * math.sqrt(2.0 / float(n_in + n_out))
    # A uniform distribution on [-bound, bound] has std bound / sqrt(3).
    bound = math.sqrt(3.0) * std_dev
    return _no_grad_uniform_(tensor, -bound, bound)

xavier_normal_(tensor, gain=1.0, reverse=False)

Modify tensor inplace using xavier_normal_.

Parameters:

Name Type Description Default
tensor Tensor

Paddle Tensor.

required
gain float

Hyperparameter. Defaults to 1.0.

1.0
reverse bool

Tensor data format order, False by default as [fout, fin, ...]. Defaults to False.

False

Returns:

Type Description
Tensor

paddle.Tensor: Initialized tensor.

Examples:

>>> import paddle
>>> import ppsci
>>> param = paddle.empty((128, 256), "float32")
>>> param = ppsci.utils.initializer.xavier_normal_(param)
Source code in ppsci/utils/initializer.py
def xavier_normal_(
    tensor: paddle.Tensor, gain: float = 1.0, reverse: bool = False
) -> paddle.Tensor:
    """Fill ``tensor`` in place using Xavier (Glorot) normal initialization.

    Samples from ``N(0, std^2)``, where ``std = gain * sqrt(2 / (fan_in + fan_out))``.

    Args:
        tensor (paddle.Tensor): Paddle Tensor to initialize.
        gain (float, optional): Scaling factor applied to the standard deviation.
            Defaults to 1.0.
        reverse (bool, optional): Tensor data format order, False by
            default as [fout, fin, ...]. Defaults to False.

    Returns:
        paddle.Tensor: Initialized tensor.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> param = paddle.empty((128, 256), "float32")
        >>> param = ppsci.utils.initializer.xavier_normal_(param)
    """
    n_in, n_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
    std_dev = gain * math.sqrt(2.0 / float(n_in + n_out))
    return _no_grad_normal_(tensor, 0, std_dev)

kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', reverse=False)

Modify tensor inplace using kaiming_uniform method.

Parameters:

Name Type Description Default
tensor Tensor

Paddle Tensor.

required
a float

The negative slope of the rectifier used after this layer. Defaults to 0.

0
mode Literal["fan_in", "fan_out"]

Either "fan_in" or "fan_out". Defaults to "fan_in".

'fan_in'
nonlinearity str

Nonlinearity method name. Defaults to "leaky_relu".

'leaky_relu'
reverse bool

Tensor data format order, False by default as [fout, fin, ...]. Defaults to False.

False

Returns:

Type Description
Tensor

paddle.Tensor: Initialized tensor.

Examples:

>>> import paddle
>>> import ppsci
>>> param = paddle.empty((128, 256), "float32")
>>> param = ppsci.utils.initializer.kaiming_uniform_(param)
Source code in ppsci/utils/initializer.py
def kaiming_uniform_(
    tensor: paddle.Tensor,
    a: float = 0,
    mode: Literal["fan_in", "fan_out"] = "fan_in",
    nonlinearity: str = "leaky_relu",
    reverse: bool = False,
) -> paddle.Tensor:
    """Fill ``tensor`` in place using Kaiming (He) uniform initialization.

    Samples from ``U(-bound, bound)``, where ``bound = sqrt(3) * gain / sqrt(fan)``.
    ``fan`` is chosen by ``mode`` and ``gain`` is derived from ``nonlinearity`` and ``a``.

    Args:
        tensor (paddle.Tensor): Paddle Tensor to initialize.
        a (float, optional): The negative slope of the rectifier used after this layer.
            Defaults to 0.
        mode (Literal["fan_in", "fan_out"], optional): Either "fan_in" or
            "fan_out". Defaults to "fan_in".
        nonlinearity (str, optional): Nonlinearity method name. Defaults to "leaky_relu".
        reverse (bool, optional): Tensor data format order, False by default as
            [fout, fin, ...]. Defaults to False.

    Returns:
        paddle.Tensor: Initialized tensor.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> param = paddle.empty((128, 256), "float32")
        >>> param = ppsci.utils.initializer.kaiming_uniform_(param)
    """
    selected_fan = _calculate_correct_fan(tensor, mode, reverse)
    act_gain = _calculate_gain(nonlinearity, a)
    std_dev = act_gain / math.sqrt(selected_fan)
    # A uniform distribution on [-bound, bound] has std bound / sqrt(3).
    bound = math.sqrt(3.0) * std_dev
    return _no_grad_uniform_(tensor, -bound, bound)

kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', reverse=False)

Modify tensor inplace using kaiming_normal_.

Parameters:

Name Type Description Default
tensor Tensor

Paddle Tensor.

required
a float

The negative slope of the rectifier used after this layer. Defaults to 0.

0
mode Literal["fan_in", "fan_out"]

Either 'fan_in' (default) or 'fan_out'. Defaults to "fan_in".

'fan_in'
nonlinearity str

Nonlinearity method name. Defaults to "leaky_relu".

'leaky_relu'
reverse bool

Tensor data format order, False by default as [fout, fin, ...]. Defaults to False.

False

Returns:

Type Description
Tensor

paddle.Tensor: Initialized tensor.

Examples:

>>> import paddle
>>> import ppsci
>>> param = paddle.empty((128, 256), "float32")
>>> param = ppsci.utils.initializer.kaiming_normal_(param)
Source code in ppsci/utils/initializer.py
def kaiming_normal_(
    tensor: paddle.Tensor,
    a: float = 0,
    mode: Literal["fan_in", "fan_out"] = "fan_in",
    nonlinearity: str = "leaky_relu",
    reverse: bool = False,
) -> paddle.Tensor:
    """Fill ``tensor`` in place using Kaiming (He) normal initialization.

    Samples from ``N(0, std^2)``, where ``std = gain / sqrt(fan)``.
    ``fan`` is chosen by ``mode`` and ``gain`` is derived from ``nonlinearity`` and ``a``.

    Args:
        tensor (paddle.Tensor): Paddle Tensor to initialize.
        a (float, optional): The negative slope of the rectifier used after this layer.
            Defaults to 0.
        mode (Literal["fan_in", "fan_out"], optional): Either
            'fan_in' (default) or 'fan_out'. Defaults to "fan_in".
        nonlinearity (str, optional): Nonlinearity method name. Defaults to "leaky_relu".
        reverse (bool, optional): Tensor data format order, False by default as
            [fout, fin, ...]. Defaults to False.

    Returns:
        paddle.Tensor: Initialized tensor.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> param = paddle.empty((128, 256), "float32")
        >>> param = ppsci.utils.initializer.kaiming_normal_(param)
    """
    selected_fan = _calculate_correct_fan(tensor, mode, reverse)
    act_gain = _calculate_gain(nonlinearity, a)
    std_dev = act_gain / math.sqrt(selected_fan)
    return _no_grad_normal_(tensor, 0, std_dev)

linear_init_(module)

Initialize module's weight and bias as it is a linear layer.

Parameters:

Name Type Description Default
module Layer

Linear Layer to be initialized.

required
Source code in ppsci/utils/initializer.py
def linear_init_(module: nn.Layer) -> None:
    """Initialize module's weight and bias as it is a linear layer.

    Mirrors PyTorch's linear-layer initialization:
    - the weight is filled with ``kaiming_uniform_(a=sqrt(5))``, i.e.
      ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``;
    - the bias, if present, is filled with
      ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``.

    A Paddle linear weight is laid out as ``[fin, fout]``, the reverse of the
    default ``[fout, fin, ...]`` order. Both fans are therefore computed with
    ``reverse=True``, so the weight and bias use the same ``fan_in``.

    Args:
        module (nn.Layer): Linear Layer to be initialized. It must expose a 2-D
            ``weight`` and a ``bias`` attribute (``bias`` may be None).
    """
    # reverse=True: without it, fan_in would be read from the output dimension.
    # The weight bound would then disagree with the bias bound computed below.
    kaiming_uniform_(module.weight, a=math.sqrt(5), reverse=True)
    if module.bias is not None:
        fan_in, _ = _calculate_fan_in_and_fan_out(module.weight, reverse=True)
        # A zero fan_in would divide by zero; a bound of 0 leaves the bias at 0.
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        uniform_(module.bias, -bound, bound)

conv_init_(module)

Initialize module's weight and bias as it is a conv layer.

Parameters:

Name Type Description Default
module Layer

Convolution Layer to be initialized.

required
Source code in ppsci/utils/initializer.py
def conv_init_(module: nn.Layer) -> None:
    """Initialize a convolution layer's weight and bias in place.

    - The weight is filled with ``kaiming_uniform_(a=sqrt(5))``.
    - The bias, if present and ``fan_in`` is non-zero, is filled with
      ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``.
    - The weight is treated in the default ``[fout, fin, ...]`` layout
      (``reverse=False``).

    Args:
        module (nn.Layer): Convolution Layer to be initialized. It must expose
            ``weight`` and ``bias`` attributes (``bias`` may be None).
    """
    kaiming_uniform_(module.weight, a=math.sqrt(5))
    if module.bias is None:
        return
    fan_in, _ = _calculate_fan_in_and_fan_out(module.weight, reverse=False)
    # Skip the bias when fan_in is 0 to avoid dividing by zero.
    if fan_in == 0:
        return
    limit = 1 / math.sqrt(fan_in)
    uniform_(module.bias, -limit, limit)

最后更新: November 17, 2023
创建日期: November 6, 2023