
Optimizer.lr_scheduler (Learning Rate) Module

ppsci.optimizer.lr_scheduler

Linear

Bases: LRBase

Linear learning rate decay.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epochs | int | total epoch(s). | required |
| iters_per_epoch | int | number of iterations within an epoch. | required |
| learning_rate | float | learning rate. | required |
| end_lr | float | The minimum final learning rate. Defaults to 0.0. | 0.0 |
| power | float | Power of polynomial. Defaults to 1.0. | 1.0 |
| cycle | bool | Whether the learning rate rises again after decaying to end_lr. Defaults to False. | False |
| warmup_epoch | int | number of warmup epochs. Defaults to 0. | 0 |
| warmup_start_lr | float | start learning rate within warmup. Defaults to 0.0. | 0.0 |
| last_epoch | int | last epoch. Defaults to -1. | -1 |
| by_epoch | bool | learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False. | False |

Examples:

>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.Linear(10, 2, 0.001)()
Source code in ppsci/optimizer/lr_scheduler.py
class Linear(LRBase):
    """Linear learning rate decay.

    Args:
        epochs (int): total epoch(s).
        iters_per_epoch (int): number of iterations within an epoch.
        learning_rate (float): learning rate.
        end_lr (float, optional): The minimum final learning rate. Defaults to 0.0.
        power (float, optional): Power of polynomial. Defaults to 1.0.
        cycle (bool, optional): Whether the learning rate rises again after decaying to end_lr. Defaults to False.
        warmup_epoch (int, optional): number of warmup epochs. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.

    Examples:
        >>> import ppsci
        >>> lr = ppsci.optimizer.lr_scheduler.Linear(10, 2, 0.001)()
    """

    def __init__(
        self,
        epochs: int,
        iters_per_epoch: int,
        learning_rate: float,
        end_lr: float = 0.0,
        power: float = 1.0,
        cycle: bool = False,
        warmup_epoch: int = 0,
        warmup_start_lr: float = 0.0,
        last_epoch: int = -1,
        by_epoch: bool = False,
    ):
        super().__init__(
            epochs,
            iters_per_epoch,
            learning_rate,
            warmup_epoch,
            warmup_start_lr,
            last_epoch,
            by_epoch,
        )
        self.decay_steps = (epochs - self.warmup_epoch) * iters_per_epoch
        self.end_lr = end_lr
        self.power = power
        self.cycle = cycle
        self.warmup_steps = round(self.warmup_epoch * iters_per_epoch)
        if self.by_epoch:
            self.decay_steps = self.epochs - self.warmup_epoch

    def __call__(self):
        learning_rate = (
            lr.PolynomialDecay(
                learning_rate=self.learning_rate,
                decay_steps=self.decay_steps,
                end_lr=self.end_lr,
                power=self.power,
                cycle=self.cycle,
                last_epoch=self.last_epoch,
            )
            if self.decay_steps > 0
            else Constant(self.learning_rate)
        )

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
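
The object returned by calling the wrapper is a paddle learning-rate scheduler (lr.PolynomialDecay here, or a warmup wrapper when warmup_epoch > 0), so it can be passed straight to a paddle optimizer and stepped once per iteration. The snippet below is only a sketch of that wiring: the model, data, and loop bounds are placeholders, and it assumes the standard paddle.optimizer.lr.LRScheduler interface (step()).

import paddle
import ppsci

sched = ppsci.optimizer.lr_scheduler.Linear(10, 2, 0.001, end_lr=0.0)()  # decays over 10 * 2 = 20 iterations
model = paddle.nn.Linear(10, 1)  # placeholder model
opt = paddle.optimizer.SGD(learning_rate=sched, parameters=model.parameters())

for epoch in range(10):
    for _ in range(2):  # iters_per_epoch = 2
        loss = paddle.mean(model(paddle.rand([8, 10])))  # placeholder data and loss
        loss.backward()
        opt.step()
        opt.clear_grad()
        sched.step()  # by_epoch=False: advance the schedule once per iteration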

Cosine

Bases: LRBase

Cosine learning rate decay.

lr = eta_min + 0.5 * (learning_rate - eta_min) * (math.cos(epoch * (math.pi / epochs)) + 1)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epochs | int | total epoch(s). | required |
| iters_per_epoch | int | number of iterations within an epoch. | required |
| learning_rate | float | learning rate. | required |
| eta_min | float | Minimum learning rate. Defaults to 0.0. | 0.0 |
| warmup_epoch | int | The epoch numbers for LinearWarmup. Defaults to 0. | 0 |
| warmup_start_lr | float | start learning rate within warmup. Defaults to 0.0. | 0.0 |
| last_epoch | int | last epoch. Defaults to -1. | -1 |
| by_epoch | bool | learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False. | False |

Examples:

>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.Cosine(10, 2, 1e-3)()
Source code in ppsci/optimizer/lr_scheduler.py
class Cosine(LRBase):
    r"""Cosine learning rate decay.

    lr = eta_min + 0.5 * (learning_rate - eta_min) * (math.cos(epoch * (math.pi / epochs)) + 1)

    Args:
        epochs (int): total epoch(s).
        iters_per_epoch (int): number of iterations within an epoch.
        learning_rate (float): learning rate.
        eta_min (float, optional): Minimum learning rate. Defaults to 0.0.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True,
            else by iter. Defaults to False.

    Examples:
        >>> import ppsci
        >>> lr = ppsci.optimizer.lr_scheduler.Cosine(10, 2, 1e-3)()
    """

    def __init__(
        self,
        epochs: int,
        iters_per_epoch: int,
        learning_rate: float,
        eta_min: float = 0.0,
        warmup_epoch: int = 0,
        warmup_start_lr: float = 0.0,
        last_epoch: int = -1,
        by_epoch: bool = False,
    ):
        super().__init__(
            epochs,
            iters_per_epoch,
            learning_rate,
            warmup_epoch,
            warmup_start_lr,
            last_epoch,
            by_epoch,
        )
        self.T_max = (self.epochs - self.warmup_epoch) * self.iters_per_epoch
        self.eta_min = eta_min
        if self.by_epoch:
            self.T_max = self.epochs - self.warmup_epoch

    def __call__(self):
        learning_rate = (
            lr.CosineAnnealingDecay(
                learning_rate=self.learning_rate,
                T_max=self.T_max,
                eta_min=self.eta_min,
                last_epoch=self.last_epoch,
            )
            if self.T_max > 0
            else Constant(self.learning_rate)
        )

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
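
When warmup_epoch > 0 the cosine schedule is wrapped in a LinearWarmup, so the learning rate first ramps from warmup_start_lr to learning_rate and only then anneals toward eta_min over T_max = (epochs - warmup_epoch) * iters_per_epoch steps. An illustrative sketch, assuming the returned object exposes the usual get_lr()/step() methods of paddle.optimizer.lr.LRScheduler:

import ppsci

# 2 warmup epochs (4 iterations) ramp the rate from 0.0 to 1e-3,
# then cosine annealing runs over the remaining (10 - 2) * 2 = 16 iterations toward eta_min
sched = ppsci.optimizer.lr_scheduler.Cosine(10, 2, 1e-3, eta_min=1e-5, warmup_epoch=2)()
for it in range(20):
    print(it, sched.get_lr())
    sched.step()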

Step

Bases: LRBase

Step learning rate decay.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epochs | int | total epoch(s). | required |
| iters_per_epoch | int | number of iterations within an epoch. | required |
| learning_rate | float | learning rate. | required |
| step_size | int | the interval (in epochs) to update. | required |
| gamma | float | The ratio by which the learning rate will be reduced: new_lr = origin_lr * gamma. It should be less than 1.0. | required |
| warmup_epoch | int | The epoch numbers for LinearWarmup. Defaults to 0. | 0 |
| warmup_start_lr | float | start learning rate within warmup. Defaults to 0.0. | 0.0 |
| last_epoch | int | last epoch. Defaults to -1. | -1 |
| by_epoch | bool | learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False. | False |

Examples:

>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.Step(10, 1, 1e-3, 2, 0.95)()
Source code in ppsci/optimizer/lr_scheduler.py
class Step(LRBase):
    """Step learning rate decay.

    Args:
        epochs (int): total epoch(s).
        iters_per_epoch (int): number of iterations within an epoch.
        learning_rate (float): learning rate.
        step_size (int): the interval to update.
        gamma (float): The ratio by which the learning rate will be reduced.
            ``new_lr = origin_lr * gamma``. It should be less than 1.0.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True,
            else by iter. Defaults to False.

    Examples:
        >>> import ppsci
        >>> lr = ppsci.optimizer.lr_scheduler.Step(10, 1, 1e-3, 2, 0.95)()
    """

    def __init__(
        self,
        epochs: int,
        iters_per_epoch: int,
        learning_rate: float,
        step_size: int,
        gamma: float,
        warmup_epoch: int = 0,
        warmup_start_lr: float = 0.0,
        last_epoch: int = -1,
        by_epoch: bool = False,
    ):
        super().__init__(
            epochs,
            iters_per_epoch,
            learning_rate,
            warmup_epoch,
            warmup_start_lr,
            last_epoch,
            by_epoch,
        )
        self.step_size = step_size * iters_per_epoch
        self.gamma = gamma
        if self.by_epoch:
            self.step_size = step_size

    def __call__(self):
        learning_rate = lr.StepDecay(
            learning_rate=self.learning_rate,
            step_size=self.step_size,
            gamma=self.gamma,
            last_epoch=self.last_epoch,
        )

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
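
Because step_size is multiplied by iters_per_epoch when by_epoch is False, the learning rate is multiplied by gamma every step_size epochs regardless of how the counter is advanced. A minimal sketch, assuming the returned paddle StepDecay scheduler's get_lr()/step() interface:

import ppsci

sched = ppsci.optimizer.lr_scheduler.Step(10, 1, 1e-3, step_size=2, gamma=0.95)()
for it in range(6):
    print(it, sched.get_lr())  # lr = 1e-3 * 0.95 ** (it // 2)
    sched.step()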

Piecewise

Bases: LRBase

Piecewise learning rate decay

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epochs | int | total epoch(s). | required |
| iters_per_epoch | int | number of iterations within an epoch. | required |
| decay_epochs | Tuple[int, ...] | Tuple of epoch numbers at which the learning rate changes; each element is a python int. | required |
| values | Tuple[float, ...] | Tuple of learning rate values that will be picked during different epoch boundaries. | required |
| warmup_epoch | int | The epoch numbers for LinearWarmup. Defaults to 0. | 0 |
| warmup_start_lr | float | start learning rate within warmup. Defaults to 0.0. | 0.0 |
| last_epoch | int | last epoch. Defaults to -1. | -1 |
| by_epoch | bool | learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False. | False |

Examples:

>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.Piecewise(
...     10, 1, [2, 4], (1e-3, 1e-4, 1e-5)
... )()
Source code in ppsci/optimizer/lr_scheduler.py
class Piecewise(LRBase):
    """Piecewise learning rate decay

    Args:
        epochs (int): total epoch(s)
        iters_per_epoch (int): number of iterations within an epoch
        decay_epochs (Tuple[int, ...]): Tuple of epoch numbers at which the learning rate
            changes. Each element is a python int.
        values (Tuple[float, ...]): Tuple of learning rate values that will be picked during
            different epoch boundaries.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True,
            else by iter. Defaults to False.

    Examples:
        >>> import ppsci
        >>> lr = ppsci.optimizer.lr_scheduler.Piecewise(
        ...     10, 1, [2, 4], (1e-3, 1e-4, 1e-5)
        ... )()
    """

    def __init__(
        self,
        epochs: int,
        iters_per_epoch: int,
        decay_epochs: Tuple[int, ...],
        values: Tuple[float, ...],
        warmup_epoch: int = 0,
        warmup_start_lr: float = 0.0,
        last_epoch: int = -1,
        by_epoch: bool = False,
    ):
        super().__init__(
            epochs,
            iters_per_epoch,
            values[0],
            warmup_epoch,
            warmup_start_lr,
            last_epoch,
            by_epoch,
        )
        self.values = values
        self.boundaries_steps = [e * iters_per_epoch for e in decay_epochs]
        if self.by_epoch is True:
            self.boundaries_steps = decay_epochs

    def __call__(self):
        learning_rate = lr.PiecewiseDecay(
            boundaries=self.boundaries_steps,
            values=self.values,
            last_epoch=self.last_epoch,
        )

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
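
The boundaries passed to lr.PiecewiseDecay are decay_epochs scaled by iters_per_epoch (unless by_epoch is True), and values[i] is used between consecutive boundaries, so values must contain one more entry than decay_epochs. A small sketch, assuming the returned paddle scheduler exposes get_lr()/step():

import ppsci

# lr is 1e-3 for epochs [0, 2), 1e-4 for [2, 4), and 1e-5 afterwards
sched = ppsci.optimizer.lr_scheduler.Piecewise(10, 1, (2, 4), (1e-3, 1e-4, 1e-5))()
for epoch in range(6):
    print(epoch, sched.get_lr())
    sched.step()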

MultiStepDecay

Bases: LRBase

MultiStepDecay learning rate decay

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epochs | int | total epoch(s). | required |
| iters_per_epoch | int | number of iterations within an epoch. | required |
| learning_rate | float | learning rate. | required |
| milestones | Tuple[int, ...] | Tuple of milestone epochs at which the learning rate is reduced; must be increasing. | required |
| gamma | float | The ratio by which the learning rate will be reduced: new_lr = origin_lr * gamma. It should be less than 1.0. Defaults to 0.1. | 0.1 |
| warmup_epoch | int | The epoch numbers for LinearWarmup. Defaults to 0. | 0 |
| warmup_start_lr | float | start learning rate within warmup. Defaults to 0.0. | 0.0 |
| last_epoch | int | last epoch. Defaults to -1. | -1 |
| by_epoch | bool | learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False. | False |

Examples:

>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.MultiStepDecay(10, 1, 1e-3, (4, 5))()
Source code in ppsci/optimizer/lr_scheduler.py
class MultiStepDecay(LRBase):
    """MultiStepDecay learning rate decay

    Args:
        epochs (int): total epoch(s)
        iters_per_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        milestones (Tuple[int, ...]): Tuple of milestone epochs at which the learning rate
            is reduced; must be increasing.
        gamma (float, optional): The ratio by which the learning rate will be reduced.
            `new_lr = origin_lr * gamma`. It should be less than 1.0. Defaults to 0.1.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True,
            else by iter. Defaults to False.

    Examples:
        >>> import ppsci
        >>> lr = ppsci.optimizer.lr_scheduler.MultiStepDecay(10, 1, 1e-3, (4, 5))()
    """

    def __init__(
        self,
        epochs: int,
        iters_per_epoch: int,
        learning_rate: float,
        milestones: Tuple[int, ...],
        gamma: float = 0.1,
        warmup_epoch: int = 0,
        warmup_start_lr: float = 0.0,
        last_epoch: int = -1,
        by_epoch: bool = False,
    ):
        super().__init__(
            epochs,
            iters_per_epoch,
            learning_rate,
            warmup_epoch,
            warmup_start_lr,
            last_epoch,
            by_epoch,
        )
        self.milestones = [x * iters_per_epoch for x in milestones]
        self.gamma = gamma
        if self.by_epoch:
            self.milestones = milestones

    def __call__(self):
        learning_rate = lr.MultiStepDecay(
            learning_rate=self.learning_rate,
            milestones=self.milestones,
            gamma=self.gamma,
            last_epoch=self.last_epoch,
        )

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
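
With milestones=(4, 5) and iters_per_epoch=1, the learning rate is multiplied by gamma when the step counter reaches 4 and again at 5. A minimal sketch under the same get_lr()/step() assumption as the examples above:

import ppsci

sched = ppsci.optimizer.lr_scheduler.MultiStepDecay(10, 1, 1e-3, (4, 5), gamma=0.1)()
for it in range(7):
    print(it, sched.get_lr())  # 1e-3 for steps 0-3, 1e-4 at step 4, 1e-5 from step 5 on
    sched.step()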

ExponentialDecay

Bases: LRBase

ExponentialDecay learning rate decay.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epochs | int | total epoch(s). | required |
| iters_per_epoch | int | number of iterations within an epoch. | required |
| learning_rate | float | learning rate. | required |
| gamma | float | The decay ratio applied over decay_steps steps. | required |
| decay_steps | int | The number of steps over which the learning rate is multiplied by gamma. | required |
| warmup_epoch | int | number of warmup epochs. Defaults to 0. | 0 |
| warmup_start_lr | float | start learning rate within warmup. Defaults to 0.0. | 0.0 |
| last_epoch | int | last epoch. Defaults to -1. | -1 |
| by_epoch | bool | learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False. | False |

Examples:

>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.ExponentialDecay(10, 2, 1e-3, 0.95, 3)()
Source code in ppsci/optimizer/lr_scheduler.py
class ExponentialDecay(LRBase):
    """ExponentialDecay learning rate decay.

    Args:
        epochs (int): total epoch(s).
        iters_per_epoch (int): number of iterations within an epoch.
        learning_rate (float): learning rate.
        gamma (float): The decay ratio applied over ``decay_steps`` steps.
        decay_steps (int): The number of steps over which the learning rate is multiplied by ``gamma``.
        warmup_epoch (int, optional): number of warmup epochs. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.

    Examples:
        >>> import ppsci
        >>> lr = ppsci.optimizer.lr_scheduler.ExponentialDecay(10, 2, 1e-3, 0.95, 3)()
    """

    def __init__(
        self,
        epochs: int,
        iters_per_epoch: int,
        learning_rate: float,
        gamma: float,
        decay_steps: int,
        warmup_epoch: int = 0,
        warmup_start_lr: float = 0.0,
        last_epoch: int = -1,
        by_epoch: bool = False,
    ):
        super().__init__(
            epochs,
            iters_per_epoch,
            learning_rate,
            warmup_epoch,
            warmup_start_lr,
            last_epoch,
            by_epoch,
        )
        self.decay_steps = decay_steps
        self.gamma = gamma
        self.warmup_steps = round(self.warmup_epoch * iters_per_epoch)
        if self.by_epoch:
            self.decay_steps /= iters_per_epoch

    def __call__(self):
        learning_rate = lr.ExponentialDecay(
            learning_rate=self.learning_rate,
            gamma=self.gamma ** (1 / self.decay_steps),
            last_epoch=self.last_epoch,
        )

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
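
The per-step factor is gamma ** (1 / decay_steps), so the learning rate shrinks by a total factor of gamma every decay_steps iterations; with gamma=0.95 and decay_steps=3 the rate after 3 steps is about 1e-3 * 0.95. A short sketch, again assuming the standard get_lr()/step() interface:

import ppsci

sched = ppsci.optimizer.lr_scheduler.ExponentialDecay(10, 2, 1e-3, gamma=0.95, decay_steps=3)()
for it in range(4):
    print(it, sched.get_lr())  # 1e-3 * 0.95 ** (it / 3)
    sched.step()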

CosineWarmRestarts

Bases: LRBase

Set the learning rate using a cosine annealing schedule with warm restarts.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epochs | int | Total epoch(s). | required |
| iters_per_epoch | int | Number of iterations within an epoch. | required |
| learning_rate | float | Learning rate. | required |
| T_0 | int | Number of epochs until the first restart (converted to iterations when by_epoch is False). | required |
| T_mult | int | A factor by which T_i increases after a restart. | required |
| eta_min | float | Minimum learning rate. Defaults to 0.0. | 0.0 |
| warmup_epoch | int | The epoch numbers for LinearWarmup. Defaults to 0. | 0 |
| warmup_start_lr | float | Start learning rate within warmup. Defaults to 0.0. | 0.0 |
| last_epoch | int | Last epoch. Defaults to -1. | -1 |
| by_epoch | bool | Learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False. | False |

Examples:

>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.CosineWarmRestarts(20, 1, 1e-3, 14, 2)()
Source code in ppsci/optimizer/lr_scheduler.py
class CosineWarmRestarts(LRBase):
    """Set the learning rate using a cosine annealing schedule with warm restarts.

    Args:
        epochs (int): Total epoch(s)
        iters_per_epoch (int): Number of iterations within an epoch
        learning_rate (float): Learning rate
        T_0 (int): Number of epochs until the first restart (converted to iterations when by_epoch is False).
        T_mult (int): A factor by which T_i increases after a restart.
        eta_min (float, optional): Minimum learning rate. Defaults to 0.0.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): Start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): Last epoch. Defaults to -1.
        by_epoch (bool, optional): Learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.

    Examples:
        >>> import ppsci
        >>> lr = ppsci.optimizer.lr_scheduler.CosineWarmRestarts(20, 1, 1e-3, 14, 2)()
    """

    def __init__(
        self,
        epochs: int,
        iters_per_epoch: int,
        learning_rate: float,
        T_0: int,
        T_mult: int,
        eta_min: float = 0.0,
        warmup_epoch: int = 0,
        warmup_start_lr: float = 0.0,
        last_epoch: int = -1,
        by_epoch: bool = False,
    ):
        super().__init__(
            epochs,
            iters_per_epoch,
            learning_rate,
            warmup_epoch,
            warmup_start_lr,
            last_epoch,
            by_epoch,
        )
        self.T_0 = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        if self.by_epoch is False:
            self.T_0 = T_0 * iters_per_epoch

    def __call__(self):
        learning_rate = CosineAnnealingWarmRestarts(
            learning_rate=self.learning_rate,
            T_0=self.T_0,
            T_mult=self.T_mult,
            eta_min=self.eta_min,
            last_epoch=self.last_epoch,
            verbose=self.verbose,
        )

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
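
Since by_epoch defaults to False, T_0 is converted to iterations (T_0 * iters_per_epoch), and the cosine schedule restarts from learning_rate after T_0 iterations, then after T_0 * T_mult more, and so on. A rough sketch, assuming the CosineAnnealingWarmRestarts object used here follows the paddle LRScheduler interface (get_lr()/step()):

import ppsci

# first restart after 14 iterations, the second cycle lasts 14 * 2 = 28 iterations
sched = ppsci.optimizer.lr_scheduler.CosineWarmRestarts(20, 1, 1e-3, T_0=14, T_mult=2)()
for it in range(30):
    current = sched.get_lr()  # anneals from 1e-3 toward eta_min, then jumps back to 1e-3 at each restart
    sched.step()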

OneCycleLR

Bases: LRBase

Sets the learning rate according to the one cycle learning rate scheduler. The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate.

It has been proposed in Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates.

Please note that the default behavior of this scheduler follows the fastai implementation of one cycle, which claims that "unpublished work has shown even better results by using only two phases". If you want the behavior of this scheduler to be consistent with the paper, please set three_phase=True.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epochs | int | Total epoch(s). | required |
| iters_per_epoch | int | Number of iterations within an epoch. | required |
| max_learning_rate | float | The maximum learning rate. It also determines the initial learning rate via divide_factor. | required |
| divide_factor | float | Initial learning rate will be determined by initial_learning_rate = max_learning_rate / divide_factor. Defaults to 25.0. | 25.0 |
| end_learning_rate | float | The minimum learning rate during training; it should be much less than the initial learning rate. Defaults to 0.0001. | 0.0001 |
| phase_pct | float | The percentage of total steps used to increase the learning rate. Defaults to 0.3. | 0.3 |
| anneal_strategy | str | Strategy of adjusting learning rate. "cos" for cosine annealing, "linear" for linear annealing. Defaults to "cos". | 'cos' |
| three_phase | bool | Whether to use three phases. Defaults to False. | False |
| warmup_epoch | int | The epoch numbers for LinearWarmup. Defaults to 0. | 0 |
| warmup_start_lr | float | Start learning rate within warmup. Defaults to 0.0. | 0.0 |
| last_epoch | int | Last epoch. Defaults to -1. | -1 |
| by_epoch | bool | Learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False. | False |

Examples:

>>> import ppsci
>>> lr = ppsci.optimizer.lr_scheduler.OneCycleLR(100, 1, 1e-3)()
Source code in ppsci/optimizer/lr_scheduler.py
class OneCycleLR(LRBase):
    """Sets the learning rate according to the one cycle learning rate scheduler.
    The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then
    from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate.

    It has been proposed in [Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates](https://arxiv.org/abs/1708.07120).

    Please note that the default behavior of this scheduler follows the fastai implementation of one cycle,
    which claims that **"unpublished work has shown even better results by using only two phases"**.
    If you want the behavior of this scheduler to be consistent with the paper, please set `three_phase=True`.

    Args:
        epochs (int): Total epoch(s).
        iters_per_epoch (int): Number of iterations within an epoch.
        max_learning_rate (float): The maximum learning rate. It also determines the initial learning rate via `divide_factor`.
        divide_factor (float, optional): Initial learning rate will be determined by initial_learning_rate = max_learning_rate / divide_factor. Defaults to 25.0.
        end_learning_rate (float, optional): The minimum learning rate during training, it should be much less than initial learning rate. Defaults to 0.0001.
        phase_pct (float, optional): The percentage of total steps used to increase the learning rate. Defaults to 0.3.
        anneal_strategy (str, optional): Strategy of adjusting learning rate. "cos" for cosine annealing, "linear" for linear annealing. Defaults to "cos".
        three_phase (bool, optional): Whether to use three phase. Defaults to False.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): Last epoch. Defaults to -1.
        by_epoch (bool, optional): Learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.

    Examples:
        >>> import ppsci
        >>> lr = ppsci.optimizer.lr_scheduler.OneCycleLR(100, 1, 1e-3)()
    """

    def __init__(
        self,
        epochs: int,
        iters_per_epoch: int,
        max_learning_rate: float,
        divide_factor: float = 25.0,
        end_learning_rate: float = 0.0001,
        phase_pct: float = 0.3,
        anneal_strategy: str = "cos",
        three_phase: bool = False,
        warmup_epoch: int = 0,
        warmup_start_lr: float = 0.0,
        last_epoch: int = -1,
        by_epoch: bool = False,
    ):
        super().__init__(
            epochs,
            iters_per_epoch,
            max_learning_rate,
            warmup_epoch,
            warmup_start_lr,
            last_epoch,
            by_epoch,
        )
        self.total_steps = epochs
        if not by_epoch:
            self.total_steps *= iters_per_epoch
        self.divide_factor = divide_factor
        self.end_learning_rate = end_learning_rate
        self.phase_pct = phase_pct
        self.anneal_strategy = anneal_strategy
        self.three_phase = three_phase

    def __call__(self):
        learning_rate = lr.OneCycleLR(
            max_learning_rate=self.learning_rate,
            total_steps=self.total_steps,
            divide_factor=self.divide_factor,
            end_learning_rate=self.end_learning_rate,
            phase_pct=self.phase_pct,
            anneal_strategy=self.anneal_strategy,
            three_phase=self.three_phase,
            last_epoch=self.last_epoch,
            verbose=self.verbose,
        )

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
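
In the one-cycle policy the learning rate starts at max_learning_rate / divide_factor, rises to max_learning_rate over the first phase_pct of total_steps, and then anneals down to end_learning_rate, so the scheduler should be stepped exactly once per iteration for the cycle to line up with training. The loop below is only a sketch: the model and data are placeholders and the standard paddle LRScheduler step() interface is assumed.

import paddle
import ppsci

sched = ppsci.optimizer.lr_scheduler.OneCycleLR(100, 1, 1e-3)()  # total_steps = 100 * 1
model = paddle.nn.Linear(4, 1)  # placeholder model
opt = paddle.optimizer.Adam(learning_rate=sched, parameters=model.parameters())

for step in range(100):
    loss = paddle.mean(model(paddle.rand([16, 4])))  # placeholder data and loss
    loss.backward()
    opt.step()
    opt.clear_grad()
    sched.step()  # lr rises for the first ~30 steps (phase_pct=0.3), then anneals to end_learning_rate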

Last updated: November 17, 2023
Created: November 6, 2023