FuXi

cd examples/fuxi
# Download the sample input data and model weights from https://pan.baidu.com/s/1PDeb-nwUprYtu9AKGnWnNw?pwd=fuxi#list/path=%2F
unzip Sample_Data.zip
unzip FuXi_EC.zip

# Modify the model and dataset paths in examples/fuxi/conf, then run inference
pip install -r requirements.txt
python predict.py

1. Background Introduction

The FuXi model is a machine learning (ML) weather forecasting system designed to generate 15-day global weather forecasts. It is trained on 39 years of ECMWF ERA5 reanalysis data at 0.25° spatial resolution and 6-hour temporal resolution. The FuXi system is named after Fu Xi, a figure in ancient Chinese mythology who is considered the first weather forecaster in China.

Key aspects and background of FuXi model development include:

  • Motivation: The development of FuXi was motivated by the limitations of current ML models in long-term weather forecasting due to error accumulation. Although ML models have shown promise in short-term forecasting, achieving performance comparable to ECMWF's traditional numerical weather prediction (NWP) models in long-term forecasting (e.g., 15 days) remains a challenge.

  • Cascade Model Architecture: To address error accumulation, FuXi adopts a novel Cascade ML model architecture. This architecture fine-tunes pre-trained models for specific 5-day forecast windows (0-5 days, 5-10 days, and 10-15 days), so that each model is optimized for its own range of forecast lead times.

  • Base Model: The base model of FuXi is an autoregressive model designed to extract complex features and learn relationships from high-dimensional weather data.

  • Training Process: The training process of FuXi includes two steps: pre-training and fine-tuning. The pre-training step optimizes the model to predict a single time step, while fine-tuning involves training the Cascade models for their respective forecast time windows.

  • Performance: The FuXi system demonstrates performance comparable to the ECMWF Ensemble Mean (EM) in 15-day forecasts and outperforms the ECMWF High-Resolution Forecast (HRES) in terms of effective forecast lead time.

The overall structure of the model is shown in the figure:

Figure: Model Structure

The FuXi model uses the fifth generation ECMWF reanalysis dataset ERA5. This dataset provides hourly data on surface and upper-air parameters from January 1940 to the present. The ERA5 dataset is generated by assimilating high-quality and abundant global observations using ECMWF's Integrated Forecast System (IFS) model. ERA5 data is widely considered a comprehensive and accurate reanalysis archive, making it suitable as ground truth for training the FuXi model. For the FuXi model, a subset of the ERA5 dataset spanning 39 years with 0.25° spatial resolution and 6-hour temporal resolution was used. The model aims to predict 5 upper-air atmospheric variables at 13 pressure levels and 5 surface variables. The dataset is divided into training, validation, and test sets. The training set contains 54,020 samples from 1979 to 2015, the validation set contains 2,920 samples from 2016 and 2017, and the out-of-sample test set contains 1,460 samples from 2018. In addition, two reference datasets HRES-fc0 and ENS-fc0 were created to evaluate the performance of ECMWF High-Resolution Forecast (HRES) and Ensemble Mean (EM).
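The sample counts and the variable count above can be cross-checked with a little arithmetic. This is a minimal sketch, assuming 4 samples per day (6-hour resolution) and 365-day years; with leap days included, 1979-2015 would give 54,056 samples, so the published counts appear to correspond to 365-day years.

```python
SAMPLES_PER_DAY = 24 // 6  # 6-hour temporal resolution -> 4 samples per day
DAYS_PER_YEAR = 365        # assumption: the published counts match 365-day years


def num_samples(years: int) -> int:
    """Number of 6-hourly samples in `years` 365-day years."""
    return years * DAYS_PER_YEAR * SAMPLES_PER_DAY


n_train = num_samples(2015 - 1979 + 1)  # 1979-2015 inclusive: 37 years
n_val = num_samples(2)                  # 2016 and 2017
n_test = num_samples(1)                 # 2018

# 5 upper-air variables on 13 pressure levels, plus 5 surface variables
n_upper, n_levels, n_surface = 5, 13, 5
n_vars = n_upper * n_levels + n_surface

print(n_train, n_val, n_test, n_vars)  # 54020 2920 1460 70
```

The total of 70 variables is the channel dimension that reappears in the 2×70×721×1440 input cube described in the next section.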

2. Model Principle

The FuXi model is an autoregressive model that takes the weather parameters of the two preceding time steps (\(X^{t-1}\), \(X^t\)) as input and predicts the weather parameters of the next time step (\(X^{t+1}\)), where t-1, t, and t+1 denote the previous, current, and next time steps, respectively. The time step is 6 hours. Feeding the model's output back in as input for the next prediction yields forecasts at successively longer lead times.
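The sliding two-step input window can be sketched in a few lines. This is an illustration only: `toy_model` is a hypothetical stand-in for FuXi that linearly extrapolates from its two inputs, and the arrays are tiny placeholders for the real 70×721×1440 fields.

```python
import numpy as np


def toy_model(window: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for FuXi: maps (X^{t-1}, X^t) to X^{t+1}."""
    prev, curr = window[0], window[1]
    return curr + (curr - prev)  # linear extrapolation


def rollout(x_prev: np.ndarray, x_curr: np.ndarray, model, num_steps: int) -> np.ndarray:
    """Autoregressive forecast: each prediction is fed back as the newest input."""
    window = np.stack([x_prev, x_curr])  # shape (2, ...)
    outputs = []
    for _ in range(num_steps):
        x_next = model(window)
        outputs.append(x_next)
        # Slide the window: drop the oldest step and append the prediction.
        window = np.stack([window[1], x_next])
    return np.stack(outputs)


preds = rollout(np.zeros(3), np.ones(3), toy_model, num_steps=4)
print(preds[:, 0])  # [2. 3. 4. 5.]
```

In predict.py, the ONNX model's full output is fed back directly as the next input and only its last time slice (`new_input[:, -1]`) is saved, which suggests the exported graph already returns the shifted two-step window.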

Generating a 15-day forecast using a single FuXi model requires 60 iterations. Unlike physics-based NWP models, pure data-driven ML models lack physical constraints, which can lead to significant error growth and unrealistic forecast results in long-term forecasts. Using autoregressive multi-step loss can effectively reduce cumulative errors in long-term forecasts. This loss function is similar to the cost function used in the 4D-Var data assimilation method, which aims to identify initial weather conditions that best fit observations within the assimilation time window. Although increasing the number of autoregressive steps can improve the accuracy of long-term forecasts, it also reduces the accuracy of short-term forecasts. Furthermore, similar to increasing the assimilation time window of 4D-Var, increasing the number of autoregressive steps requires more memory and computational resources to handle gradients during training.
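To make the multi-step loss concrete, the sketch below rolls a model forward over several steps and averages the per-step errors. It assumes mean absolute error as the per-step loss and uses a hypothetical persistence model; the actual FuXi training loss, weighting, and number of steps are not specified here.

```python
import numpy as np


def multi_step_loss(model, x_prev, x_curr, targets) -> float:
    """Mean per-step error over an autoregressive rollout of len(targets) steps.

    Assumption: mean absolute error per step; the real training loss may differ.
    """
    window = (x_prev, x_curr)
    losses = []
    for target in targets:
        pred = model(*window)
        losses.append(np.mean(np.abs(pred - target)))
        window = (window[1], pred)  # feed the prediction back in, as at inference
    return float(np.mean(losses))


def persistence(prev, curr):
    """Hypothetical toy model: repeat the latest state."""
    return curr


targets = [np.full(4, v) for v in (2.0, 3.0, 4.0)]
print(multi_step_loss(persistence, np.zeros(4), np.ones(4), targets))  # 2.0
```

The per-step errors here are 1, 2, and 3, so later steps dominate the average. Every extra step also adds another model call whose activations must be kept for backpropagation, which is the memory cost the paragraph above describes.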

When making iterative forecasts, error accumulation is inevitable as the forecast lead time increases. Furthermore, previous studies have shown that a single model cannot achieve optimal performance across all forecast lead times. To optimize performance for both short-term and long-term forecasts, the paper proposes a Cascade model architecture using pre-trained FuXi models that are fine-tuned to achieve optimal performance within specific 5-day forecast time windows. These time windows are referred to as FuXi-Short (0-5 days), FuXi-Medium (5-10 days), and FuXi-Long (10-15 days). The outputs of FuXi-Short and FuXi-Medium are used as inputs for FuXi-Medium and FuXi-Long at step 20 and step 40, respectively. Unlike the greedy hierarchical temporal aggregation strategy used in Pangu-Weather (which utilizes 4 models predicting 1-hour, 3-hour, 6-hour, and 24-hour forecast lead times respectively to reduce steps), the Cascade FuXi model does not have the problem of temporal inconsistency.
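The cascade hand-off can be expressed as a mapping from step index to model. This sketch assumes 0-based step indices, 20 steps of 6 hours per model (5 days × 24 h / 6 h), and that step `s` produces the forecast at lead time (s + 1) × 6 hours. The names `stage_for_step` and `lead_time_hours` are hypothetical.

```python
STEP_HOURS = 6
STEPS_PER_STAGE = 20  # 5 days * 24 h / 6 h
STAGES = ("FuXi-Short", "FuXi-Medium", "FuXi-Long")


def stage_for_step(step: int) -> str:
    """Cascade model that produces the forecast at 0-based step `step`."""
    total = STEPS_PER_STAGE * len(STAGES)
    if not 0 <= step < total:
        raise ValueError(f"step must be in [0, {total}), got {step}")
    return STAGES[step // STEPS_PER_STAGE]


def lead_time_hours(step: int) -> int:
    """Lead time (hours) of the forecast produced at 0-based step `step`."""
    return (step + 1) * STEP_HOURS


for step in (0, 19, 20, 39, 40, 59):
    print(step, stage_for_step(step), lead_time_hours(step))
```

Under these assumptions, FuXi-Medium takes over at step 20 (126 h) from FuXi-Short's last forecast at 120 h, FuXi-Long takes over at step 40, and step 59 reaches 360 h (15 days). Because every model advances by the same 6-hour step, there is no mixing of step sizes as in hierarchical temporal aggregation.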

The model architecture of the base FuXi model consists of three main parts, as stated in the paper: Cube Embedding, U-Transformer, and Fully Connected (FC) layer. The input data combines upper-air and surface variables and creates a data cube with dimensions 2×70×721×1440, where 2 represents the previous two time steps (t-1 and t), 70 represents the total number of input variables, and 721 and 1440 represent latitude (H) and longitude (W) grid points respectively.

First, joint spatiotemporal Cube Embedding reduces the high-dimensional input to C×180×360, where C, the number of channels, is set to 1536. The main purpose of Cube Embedding is to reduce the temporal and spatial dimensions of the input and cut data redundancy. The U-Transformer then processes the embedded data, and a simple FC layer produces the prediction. The output is first reshaped to 70×720×1440 and then restored to the original input shape of 70×721×1440 through bilinear interpolation.
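The shapes above can be reproduced with standard convolution arithmetic. This sketch assumes the cube is 2×4×4 (time × latitude × longitude), applied as an unpadded convolution whose kernel equals its stride, and that the FC layer maps each token back to a 4×4 patch of 70 variables. These sizes are inferred from the stated shapes, not read from the example code.

```python
def conv_out(size: int, kernel: int, stride: int) -> int:
    """Output length of an unpadded convolution along one axis."""
    return (size - kernel) // stride + 1


T, V, H, W = 2, 70, 721, 1440  # time steps, variables, latitude, longitude
PT, PH, PW = 2, 4, 4           # assumed cube size, used as both kernel and stride
C = 1536                       # embedding channels

# Cube Embedding: the time axis collapses to 1, latitude/longitude shrink 4x.
emb_shape = (C, conv_out(T, PT, PT), conv_out(H, PH, PH), conv_out(W, PW, PW))

# FC head: each of the 180x360 tokens expands back to a 4x4 patch of V variables.
fc_shape = (V, emb_shape[2] * PH, emb_shape[3] * PW)

# Bilinear interpolation then restores the 721-point latitude axis.
final_shape = (V, H, W)

print(emb_shape, fc_shape, final_shape)
# (1536, 1, 180, 360) (70, 720, 1440) (70, 721, 1440)
```

The extra latitude row (721 = 4 × 180 + 1) is dropped by the floor in the strided convolution, which is why the reshaped output has 720 rows and must be interpolated back to 721.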

The U-Transformer is built from 48 repeated Swin Transformer V2 blocks and computes scaled cosine attention as follows:

\[\mathrm{Attention}(Q, K, V) = \mathrm{SoftMax}\left(\cos(Q, K)/\tau + B\right)V\]

Here, B is the relative position bias and τ is a learnable scalar that is not shared between heads or layers. Because the cosine function is naturally normalized, the resulting attention values are milder than those of dot-product attention.
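A single-head NumPy sketch of scaled cosine attention is shown below. It follows the Swin Transformer V2 formulation, in which a softmax is applied to the similarity scores before weighting V. Window partitioning and multiple heads are omitted, τ is fixed rather than learned, and a random matrix stands in for the learned relative position bias B.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_cosine_attention(q, k, v, tau, bias):
    """Single-head scaled cosine attention.

    q, k: (N, d); v: (N, d_v); bias: (N, N) stand-in for the relative position bias B.
    """
    q_hat = q / np.linalg.norm(q, axis=-1, keepdims=True)
    k_hat = k / np.linalg.norm(k, axis=-1, keepdims=True)
    sim = q_hat @ k_hat.T / tau + bias  # cos(q_i, k_j) / tau + B_ij
    return softmax(sim, axis=-1) @ v


rng = np.random.default_rng(0)
N, d = 16, 8
q, k, v = (rng.standard_normal((N, d)) for _ in range(3))
bias = 0.1 * rng.standard_normal((N, N))

out = scaled_cosine_attention(q, k, v, tau=0.1, bias=bias)
out_scaled = scaled_cosine_attention(100 * q, k, v, tau=0.1, bias=bias)
print(out.shape, np.allclose(out, out_scaled))  # (16, 8) True
```

Scaling the queries by 100 leaves the output unchanged, because cosine similarity bounds each score to [-1/τ, 1/τ] before the bias is added. With plain dot-product attention, the same scaling would sharpen the softmax toward a few keys.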

This example runs inference with pre-trained weights. The inference process is introduced next.

3. Model Construction

In this example, FuXiPredictor is implemented to run inference on the ONNX model:

examples/fuxi/predict.py
class FuXiPredictor(base.Predictor):
    """General predictor for FuXi model.

    Args:
        cfg (DictConfig): Running configuration.
    """

    def __init__(
        self,
        cfg: DictConfig,
    ):
        print(f"cfg: {cfg}")
        assert cfg.INFER.engine == "onnx", "FuXi engine only supports 'onnx'."

        super().__init__(
            pdmodel_path=None,
            pdiparams_path=None,
            device=cfg.INFER.device,
            engine=cfg.INFER.engine,
            precision=cfg.INFER.precision,
            onnx_path=cfg.INFER.onnx_path,
            ir_optim=cfg.INFER.ir_optim,
            min_subgraph_size=cfg.INFER.min_subgraph_size,
            gpu_mem=cfg.INFER.gpu_mem,
            gpu_id=cfg.INFER.gpu_id,
            max_batch_size=cfg.INFER.max_batch_size,
            num_cpu_threads=cfg.INFER.num_cpu_threads,
        )
        self.log_freq = cfg.log_freq

        # get input names
        self.input_names = [
            input_node.name for input_node in self.predictor.get_inputs()
        ]

        # get output names
        self.output_names = [
            output_node.name for output_node in self.predictor.get_outputs()
        ]

        self.output_dir = cfg.output_dir

    def predict(
        self, input_data, tembs, global_step, stage, num_step, data, batch_size: int = 1
    ) -> tuple[np.ndarray, int]:
        """Predicts the output of the FuXi model for the given input.

        Args:
            input_data (np.ndarray): Atmospheric data of the two preceding time steps.
            tembs (np.ndarray): Encoded timestamps for all forecast steps.
            global_step (int): The global step of the forecast.
            stage (int): The stage (cascade model index) of the forecast.
            num_step (int): The number of forecast steps to run in this stage.
            data (xr.DataArray): Input data array, passed to save_like when writing outputs.
            batch_size (int, optional): Batch size; only 1 is currently supported. Defaults to 1.

        Returns:
            tuple[np.ndarray, int]: Prediction for one stage and the global step.
        """
        if batch_size != 1:
            raise ValueError(
                f"FuXiPredictor only supports batch_size=1, but got {batch_size}"
            )

        # prepare input dict
        for _ in range(0, num_step):
            input_dict = {
                self.input_names[0]: input_data,
                self.input_names[1]: tembs[global_step],
            }

            # run predictor
            new_input = self.predictor.run(None, input_dict)[0]
            output = new_input[:, -1]
            save_like(output, data, global_step, self.output_dir)
            print(
                f"stage: {stage}, global_step: {global_step+1:02d}, output: {output.min():.2f} {output.max():.2f}"
            )
            input_data = new_input
            global_step += 1

        return input_data, global_step

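FuXi adopts a cascade model structure: the three consecutive forecast periods (0-5 days, 5-10 days, and 10-15 days) are predicted by three models configured through fuxi_short.yaml, fuxi_medium.yaml, and fuxi_long.yaml, respectively. Each stage starts from the last input window returned by the previous stage.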
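In the complete code, the path of each stage's config file is built by plain string concatenation of `cfg.fuxi_config_dir`, `"fuxi_"`, the stage name, and `".yaml"`, so the directory value must end with a path separator. The sketch below mirrors that logic; the directory values `"conf/"` and `"conf"` are hypothetical examples, not the values shipped in fuxi.yaml.

```python
import os

STAGES = ["short", "medium", "long"]


def stage_config_path(fuxi_config_dir: str, stage: str) -> str:
    """Mirrors predict.py: plain concatenation, no separator is inserted."""
    return fuxi_config_dir + "fuxi_" + stage + ".yaml"


print([stage_config_path("conf/", s) for s in STAGES])
# ['conf/fuxi_short.yaml', 'conf/fuxi_medium.yaml', 'conf/fuxi_long.yaml']

print(stage_config_path("conf", "short"))  # conffuxi_short.yaml (missing slash)

# A separator-safe alternative, if you adapt the script:
print(os.path.join("conf", "fuxi_short.yaml"))
```

If OmegaConf reports that a file such as `conffuxi_short.yaml` cannot be found, a missing trailing slash in `fuxi_config_dir` is a likely cause.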

4. Result Visualization

Use examples/fuxi/visualize.py for plotting and result visualization.

5. Complete Code

examples/fuxi/predict.py
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hydra
import numpy as np
import paddle
import pandas as pd
import xarray as xr
from omegaconf import DictConfig
from omegaconf import OmegaConf
from packaging import version
from util import save_like

from deploy.python_infer import base
from ppsci.utils import logger


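def time_encoding(init_time, total_step, freq=6):
    """Encode the timestamps seen by each forecast step.

    For step i, the timestamps (i-1, i, i+1) * freq hours after init_time are mapped
    to (day_of_year / 366, hour / 24), and the sin and cos of these values are
    concatenated. Returns an array of shape (total_step, 1, 12).
    """
    init_time = np.array([init_time])
    tembs = []
    for i in range(total_step):
        # offsets of the previous, current and next time steps
        hours = np.array([pd.Timedelta(hours=t * freq) for t in [i - 1, i, i + 1]])
        times = init_time[:, None] + hours[None]
        times = [pd.Period(t, "H") for t in times.reshape(-1)]
        # normalized day-of-year and hour-of-day features
        times = [(p.day_of_year / 366, p.hour / 24) for p in times]
        temb = np.array(times, dtype=np.float32)
        temb = np.concatenate([np.sin(temb), np.cos(temb)], axis=-1)  # (3, 4)
        temb = temb.reshape(1, -1)  # (1, 12)
        tembs.append(temb)
    return np.stack(tembs)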


class FuXiPredictor(base.Predictor):
    """General predictor for FuXi model.

    Args:
        cfg (DictConfig): Running configuration.
    """

    def __init__(
        self,
        cfg: DictConfig,
    ):
        print(f"cfg: {cfg}")
        assert cfg.INFER.engine == "onnx", "FuXi engine only supports 'onnx'."

        super().__init__(
            pdmodel_path=None,
            pdiparams_path=None,
            device=cfg.INFER.device,
            engine=cfg.INFER.engine,
            precision=cfg.INFER.precision,
            onnx_path=cfg.INFER.onnx_path,
            ir_optim=cfg.INFER.ir_optim,
            min_subgraph_size=cfg.INFER.min_subgraph_size,
            gpu_mem=cfg.INFER.gpu_mem,
            gpu_id=cfg.INFER.gpu_id,
            max_batch_size=cfg.INFER.max_batch_size,
            num_cpu_threads=cfg.INFER.num_cpu_threads,
        )
        self.log_freq = cfg.log_freq

        # get input names
        self.input_names = [
            input_node.name for input_node in self.predictor.get_inputs()
        ]

        # get output names
        self.output_names = [
            output_node.name for output_node in self.predictor.get_outputs()
        ]

        self.output_dir = cfg.output_dir

    def predict(
        self, input_data, tembs, global_step, stage, num_step, data, batch_size: int = 1
    ) -> tuple[np.ndarray, int]:
        """Predicts the output of the FuXi model for the given input.

        Args:
            input_data (np.ndarray): Atmospheric data of the two preceding time steps.
            tembs (np.ndarray): Encoded timestamps for all forecast steps.
            global_step (int): The global step of the forecast.
            stage (int): The stage (cascade model index) of the forecast.
            num_step (int): The number of forecast steps to run in this stage.
            data (xr.DataArray): Input data array, passed to save_like when writing outputs.
            batch_size (int, optional): Batch size; only 1 is currently supported. Defaults to 1.

        Returns:
            tuple[np.ndarray, int]: Prediction for one stage and the global step.
        """
        if batch_size != 1:
            raise ValueError(
                f"FuXiPredictor only supports batch_size=1, but got {batch_size}"
            )

        # prepare input dict
        for _ in range(0, num_step):
            input_dict = {
                self.input_names[0]: input_data,
                self.input_names[1]: tembs[global_step],
            }

            # run predictor
            new_input = self.predictor.run(None, input_dict)[0]
            output = new_input[:, -1]
            save_like(output, data, global_step, self.output_dir)
            print(
                f"stage: {stage}, global_step: {global_step+1:02d}, output: {output.min():.2f} {output.max():.2f}"
            )
            input_data = new_input
            global_step += 1

        return input_data, global_step


def inference(cfg: DictConfig):
    # log paddlepaddle's version
    if version.Version(paddle.__version__) != version.Version("0.0.0"):
        paddle_version = paddle.__version__
        if version.Version(paddle.__version__) < version.Version("2.6.0"):
            logger.warning(
                f"Detected paddlepaddle version is '{paddle_version}', "
                "currently it is recommended to use release 2.6 or develop version."
            )
    else:
        paddle_version = f"develop({paddle.version.commit[:7]})"

    logger.info(f"Using paddlepaddle {paddle_version}")

    num_steps = cfg.num_steps
    stages = ["short", "medium", "long"]

    # load data
    data = xr.open_dataarray(cfg.input_file)

    total_step = sum(num_steps)
    init_time = pd.to_datetime(data.time.values[-1])
    tembs = time_encoding(init_time, total_step)

    print(f'init_time: {init_time.strftime(("%Y%m%d-%H"))}')
    print(f"latitude: {data.lat.values[0]} ~ {data.lat.values[-1]}")

    assert data.lat.values[0] == 90
    assert data.lat.values[-1] == -90

    input_data = data.values[None]

    step = 0
    for i, num_step in enumerate(num_steps):
        print(f"Inference {stages[i]} ...")
        cfg_path = cfg.fuxi_config_dir + "fuxi_" + stages[i] + ".yaml"
        config = OmegaConf.load(cfg_path)
        print(f"predictor_cfg: {config}")
        predictor = FuXiPredictor(config)
        # run predictor
        input_data, step = predictor.predict(
            input_data=input_data,
            tembs=tembs,
            global_step=step,
            stage=i,
            num_step=num_step,
            data=data,
        )

        if step > total_step:
            break


@hydra.main(version_base=None, config_path="./conf", config_name="fuxi.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['infer'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
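The `time_encoding` function in predict.py produces one 12-value timestamp feature per step. The sketch below is intended to mirror it using only `datetime` and NumPy, assuming that `pd.Period(...).day_of_year` and `.hour` agree with `timetuple().tm_yday` and `.hour` for the same timestamp. Like the original, it applies sin and cos directly to the normalized fractions without a 2π factor, so the features stay consistent with what the exported model expects. The name `time_encoding_sketch` is hypothetical.

```python
import datetime as dt

import numpy as np


def time_encoding_sketch(init_time: dt.datetime, total_step: int, freq: int = 6) -> np.ndarray:
    """Return timestamp features of shape (total_step, 1, 12)."""
    tembs = []
    for i in range(total_step):
        # Previous, current and next timestamps relative to step i.
        stamps = [init_time + dt.timedelta(hours=t * freq) for t in (i - 1, i, i + 1)]
        feats = np.array(
            [(s.timetuple().tm_yday / 366, s.hour / 24) for s in stamps],
            dtype=np.float32,
        )  # (3, 2): normalized day of year, hour of day
        temb = np.concatenate([np.sin(feats), np.cos(feats)], axis=-1)  # (3, 4)
        tembs.append(temb.reshape(1, -1))  # (1, 12)
    return np.stack(tembs)


tembs = time_encoding_sketch(dt.datetime(2023, 1, 1, 0), total_step=60)
print(tembs.shape)  # (60, 1, 12)
```

Each 12-value row is laid out as [sin(day), sin(hour), cos(day), cos(hour)] for the previous, current, and next timestamps in turn. FuXiPredictor.predict passes `tembs[global_step]` as the model's second input at every step.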

6. Result Display

The inference output consists of 60 NetCDF files, one per 6-hour time step: 20 steps from each of the three cascade models, covering the 15 days after the initialization time.

Use examples/fuxi/visualize.py for plotting and result visualization.

python3 visualize.py --data_dir outputs_fuxi/ --save_dir outputs_fuxi/ --step 6

The figure below shows the result:

Figure: Weather forecast result for the next 6 hours

7. References