
FourCastNet

AI Studio Quick Experience

Before starting training and evaluation, please download the dataset.

# Wind speed pretrain model
python train_pretrain.py
# Wind speed finetune model
python train_finetune.py
# Precipitation model training
python train_precip.py
# Wind speed pretrain model evaluation
python train_pretrain.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/fourcastnet/pretrain.pdparams
# Wind speed finetune model evaluation
python train_finetune.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/fourcastnet/finetune.pdparams
# Precipitation model evaluation
python train_precip.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/fourcastnet/precip.pdparams WIND_MODEL_PATH=https://paddle-org.bj.bcebos.com/paddlescience/models/fourcastnet/finetune.pdparams
# Wind speed pretrain model export
python train_pretrain.py mode=export
# Wind speed finetune model export
python train_finetune.py mode=export
# Precipitation model export
python train_precip.py mode=export
# Download wind speed prediction small sample data
wget -c https://paddle-org.bj.bcebos.com/paddlescience/datasets/FourcastNet/global_stds.npy -P ./datasets/era5/stat/
wget -c https://paddle-org.bj.bcebos.com/paddlescience/datasets/FourcastNet/global_means.npy -P ./datasets/era5/stat/
wget -c https://paddle-org.bj.bcebos.com/paddlescience/datasets/FourcastNet/2018-04-04_n6_precip.npy -P ./datasets/era5/test/
wget -c https://paddle-org.bj.bcebos.com/paddlescience/datasets/FourcastNet/2018-04-04_n6.npy -P ./datasets/era5/test/
# Download precipitation prediction small sample data
wget -c https://paddle-org.bj.bcebos.com/paddlescience/datasets/FourcastNet/2018-09-08_n32.npy -P ./datasets/era5/test/
# Wind speed pretrain model inference
python train_pretrain.py mode=infer
# Wind speed finetune model inference
python train_finetune.py mode=infer
# Precipitation model inference
python train_precip.py mode=infer
Model | Variable Name | ACC/RMSE (6h) | ACC/RMSE (30h) | ACC/RMSE (60h) | ACC/RMSE (120h) | ACC/RMSE (192h)
Wind Speed Model | U10 | 0.991/0.567 | 0.963/1.130 | 0.891/1.930 | 0.645/3.438 | 0.371/4.915

Model | Variable Name | ACC/RMSE (6h) | ACC/RMSE (12h) | ACC/RMSE (24h) | ACC/RMSE (36h)
Precipitation Model | TP | 0.808/1.390 | 0.760/1.540 | 0.668/1.690 | 0.590/1.920

1. Background Introduction

Weather forecasting typically employs two approaches: physics-based and data-driven methods. Physics-based methods, such as the Integrated Forecasting System (IFS), rely on governing equations to model atmospheric variable relationships, often utilizing over 150 variables across 50+ vertical levels. In contrast, data-driven methods leverage large datasets to train neural networks, learning mappings from input to output without explicit physical equations.

FourCastNet is a data-driven weather forecasting algorithm utilizing Adaptive Fourier Neural Operators (AFNO). It focuses on predicting 10-meter wind speed and 6-hour total precipitation, enabling early warnings for extreme weather. Compared to IFS, FourCastNet uses only 20 atmospheric variables at 5 vertical heights, offering significantly faster inference speeds with reduced input complexity.

2. Model Principle

This chapter only briefly introduces the model principle of FourCastNet. For detailed theoretical derivation, please read FourCastNet: A Global Data-driven High-resolution Weather Model using Adaptive Fourier Neural Operators.

FourCastNet employs the AFNO network, adapting an architecture previously used in image segmentation. AFNO addresses the limitations of Vision Transformers (ViT) by integrating Fourier Neural Operators (FNO). It utilizes Fourier transforms for token interaction, significantly reducing the computational cost of self-attention in high-resolution settings. For further details, refer to the AFNO, FNO, and ViT papers.
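To make "Fourier transforms for token interaction" concrete, here is a toy sketch, not the AFNO layer itself. It mixes a 1D sequence of scalar tokens by reweighting its Fourier modes, so every output token depends on every input token without any pairwise attention. The function names and the per-mode weights are hypothetical, and a naive O(N^2) DFT is used only to keep the example dependency-free; a real implementation would use an FFT and learned weights.

```python
import cmath


def dft(seq):
    """Naive discrete Fourier transform (O(N^2)), for illustration only."""
    n = len(seq)
    return [
        sum(seq[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


def idft(spec):
    """Inverse of dft()."""
    n = len(spec)
    return [
        sum(spec[k] * cmath.exp(2j * cmath.pi * k * t / n) for k in range(n)) / n
        for t in range(n)
    ]


def fourier_mix(tokens, freq_weights):
    """Mix tokens globally by scaling each Fourier mode (hypothetical helper)."""
    spectrum = dft(tokens)
    mixed = [w * s for w, s in zip(freq_weights, spectrum)]
    # The toy weights below keep the result real; drop the rounding-error imaginary part.
    return [value.real for value in idft(mixed)]


tokens = [1.0, 2.0, 3.0, 4.0]
# All-ones weights leave the sequence unchanged.
identity = fourier_mix(tokens, [1.0, 1.0, 1.0, 1.0])
# Keeping only mode 0 replaces every token with the sequence mean.
mean_only = fourier_mix(tokens, [1.0, 0.0, 0.0, 0.0])
```

Keeping only the zero-frequency mode shows the global nature of the mixing: each output equals the mean of all inputs.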

The overall structure of the model is shown in the figure:

fourcastnet-arch

FourCastNet Network Model

The FourCastNet paper trained a wind speed model and a precipitation model. Next, the training and inference processes of these two models will be introduced.

2.1 Training and Inference Process of Wind Speed Model

Model training involves two stages: pre-training and fine-tuning.

In the pre-training stage, the model is initialized with random weights. As shown below, \(X(k)\) represents atmospheric data at time \(k\), \(X(k+1)\) is the model's prediction for time \(k+1\), and \(X_{true}(k+1)\) is the ground truth. The model minimizes the L2 loss between the predicted output and the ground truth.

fourcastnet-pretraining

Wind speed model pre-training

The second stage, fine-tuning, aims to enhance accuracy for medium- to long-range forecasting. Here, the model performs autoregressive prediction: the output for time \(k+1\) (generated from input at time \(k\)) is fed back as input to predict time \(k+2\). This multi-step prediction process improves the model's long-term stability and performance.

fourcastnet-finetuning

Wind speed model fine-tuning

In the inference stage, given data at time \(k\), prediction results at times \(k+1\), \(k+2\), \(k+3\), etc. can be obtained through continuous iteration.

fourcastnet-inference

Wind speed model inference
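The iteration described above can be sketched in a few lines. This is a minimal illustration under stated assumptions: `rollout` and `toy_step` are hypothetical names, and `toy_step` is a stand-in for the trained wind speed model, which in practice maps a full atmospheric state \(X(k)\) to \(X(k+1)\).

```python
from typing import Callable, List

State = List[float]


def rollout(step: Callable[[State], State], x_k: State, num_steps: int) -> List[State]:
    """Apply a one-step model repeatedly, feeding each output back as input."""
    predictions = []
    state = x_k
    for _ in range(num_steps):
        state = step(state)  # prediction for the next 6-hour step
        predictions.append(state)
    return predictions


def toy_step(state: State) -> State:
    """Hypothetical stand-in for the trained model: halves every value."""
    return [0.5 * value for value in state]


# Predictions for times k+1, k+2, k+3 from the state at time k.
forecasts = rollout(toy_step, [8.0, 4.0], num_steps=3)
```

Fine-tuning uses the same loop during training, with the loss accumulated over the predicted steps, which is why it helps long-range stability.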

2.2 Training and Inference Process of Precipitation Model

The precipitation model training relies on the pre-trained wind speed model. As illustrated below, the wind speed model takes atmospheric data \(X(k)\) to predict \(X(k+1)\). This predicted state \(X(k+1)\) then serves as input to the precipitation model, which outputs the precipitation forecast \(p(k+1)\). The model is trained by minimizing the L2 loss between the predicted precipitation \(p(k+1)\) and the ground truth \(p_{true}(k+1)\).

precip-training

Precipitation model training

Note that while the precipitation model is being trained, the parameters of the wind speed model are frozen and are not updated by the optimizer.

In the inference stage, given data at time \(k\), atmospheric variable prediction results at times \(k+1\), \(k+2\), \(k+3\), etc. can be obtained through continuous iteration using the wind speed model, and used as input to the precipitation model to predict precipitation at corresponding times.

precip-inference

Precipitation model inference
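The two-model pipeline can be sketched as a composition of functions. The names below are hypothetical: `wind_step` stands in for the frozen wind speed model and `precip_head` for the precipitation model, both reduced to toy arithmetic so the example runs on its own.

```python
def precip_rollout(wind_step, precip_head, x_k, num_steps):
    """Advance the atmospheric state with the wind model, then diagnose precipitation."""
    state = x_k
    precipitation = []
    for _ in range(num_steps):
        state = wind_step(state)                  # X(k+i), from the frozen wind model
        precipitation.append(precip_head(state))  # p(k+i), from the precipitation model
    return precipitation


def toy_wind_step(state):
    """Hypothetical wind model: shifts every variable by 1."""
    return [value + 1.0 for value in state]


def toy_precip_head(state):
    """Hypothetical precipitation model: sums the state variables."""
    return sum(state)


p = precip_rollout(toy_wind_step, toy_precip_head, [0.0, 1.0], num_steps=3)
```

Only `precip_head` would receive gradient updates during training; `wind_step` is evaluated but never changed.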

3. Wind Speed Model Implementation

Next, we will explain how to implement the training and inference of the FourCastNet wind speed model based on PaddleScience code. For other details in this case, please refer to API Documentation.

Info

A complete reproduction requires more than 5 TB of storage and 64 GPUs. If you only want to learn how the FourCastNet algorithm works, we recommend training on a small subset of the training dataset to reduce cost.

3.1 Dataset Introduction

We use the ERA5 dataset processed by FourCastNet. The dataset has a resolution of 0.25 degrees (\(720 \times 1440\) grid), with each point representing approximately 30 km. Covering the period 1979-2018, the data is split into training, validation, and test sets by year:

Dataset | Year
Training set | 1979-2015
Validation set | 2016-2017
Test set | 2018
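The grid dimensions quoted above can be checked with a little arithmetic. The sketch below assumes a regular latitude-longitude grid that includes both poles and a mean Earth radius of 6371 km (an assumption, not a value taken from the dataset). It yields 721 latitude rows, of which 720 are kept after cropping, and an equatorial spacing slightly below the rounded 30 km figure.

```python
import math

RESOLUTION_DEG = 0.25
EARTH_RADIUS_KM = 6371.0  # assumed mean Earth radius

num_lon = round(360 / RESOLUTION_DEG)      # longitude wraps around, so no extra column
num_lat = round(180 / RESOLUTION_DEG) + 1  # +1 because both poles are grid rows

# Distance between neighbouring grid points along the equator.
spacing_km = 2 * math.pi * EARTH_RADIUS_KM / num_lon
```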

The dataset can be downloaded from here.

Model training uses 20 atmospheric variables distributed over 5 vertical levels, as shown in the table below:

fourcastnet-vars

20 atmospheric variables

Here, \(T\), \(U\), \(V\), \(Z\), and \(RH\) denote temperature, zonal wind speed, meridional wind speed, geopotential, and relative humidity at the specified vertical level. \(U_{10}\) and \(V_{10}\) denote the zonal and meridional wind speed 10 meters above the ground, and \(T_{2m}\) denotes the temperature 2 meters above the ground. \(sp\) denotes surface pressure, \(mslp\) mean sea level pressure, and \(TCWV\) total column water vapor.

Data is sampled at 6-hour intervals (00:00, 06:00, 12:00, 18:00). Training and inference involve predicting the state at the next 6-hour interval; for example, taking 20 atmospheric variables at 00:00 as input to predict the variables at 06:00.
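Because samples are 6 hours apart, a timestamp maps to a position along the time axis of a yearly file by counting 6-hour steps since January 1. The helper below is a hypothetical sketch of that mapping using only the standard library; it assumes each file starts at 00:00 on January 1 of its year.

```python
from datetime import datetime

HOURS_PER_STEP = 6


def time_index(date_str: str) -> int:
    """Return the 6-hourly sample index of a timestamp within its year."""
    when = datetime.strptime(date_str, "%Y-%m-%d %H:%M:%S")
    year_start = datetime(when.year, 1, 1)
    hours = (when - year_start).total_seconds() / 3600
    return int(hours // HOURS_PER_STEP)


first = time_index("2018-01-01 06:00:00")  # second sample of the year
idx = time_index("2018-09-08 00:00:00")    # 250 full days after January 1
```

The model input is the sample at `idx` and the prediction target is the sample at `idx + 1`.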

3.2 Model Pre-training

We first list the parameters defined in the configuration file. The meaning of each parameter is explained where it is used below.

examples/fourcastnet/conf/fourcastnet_pretrain.yaml
# set training hyper-parameters
IMG_H: 720
IMG_W: 1440
# FourCastNet uses 20 atmospheric variables; their indices in the dataset run from 0 to 19.
# The variable names are 'u10', 'v10', 't2m', 'sp', 'msl', 't850', 'u1000', 'v1000', 'z1000',
# 'u850', 'v850', 'z850', 'u500', 'v500', 'z500', 't500', 'z50', 'r500', 'r850', 'tcwv'.
# You can obtain detailed information about each variable from
# https://cds.climate.copernicus.eu/cdsapp#!/search?text=era5&type=dataset
VARS_CHANNEL: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
USE_SAMPLED_DATA: false

# set train data path
TRAIN_FILE_PATH: ./datasets/era5/train
DATA_MEAN_PATH: ./datasets/era5/stat/global_means.npy
DATA_STD_PATH: ./datasets/era5/stat/global_stds.npy
DATA_TIME_MEAN_PATH: ./datasets/era5/stat/time_means.npy

# set evaluate data path

3.2.1 Constraint Construction

Since this is a data-driven task, we use PaddleScience's SupervisedConstraint. Before defining the constraint, we configure data loading and preprocessing parameters. The preprocessing steps are implemented as follows:

examples/fourcastnet/train_pretrain.py
data_mean, data_std = fourcast_utils.get_mean_std(
    cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
)
data_time_mean = fourcast_utils.get_time_mean(
    cfg.DATA_TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W, cfg.VARS_CHANNEL
)
data_time_mean_normalize = np.expand_dims(
    (data_time_mean[0] - data_mean) / data_std, 0
)
# set train transforms
transforms = [
    {"SqueezeData": {}},
    {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
    {"Normalize": {"mean": data_mean, "std": data_std}},

Data preprocessing consists of 3 transforms:

  1. SqueezeData: Reduces the dimensionality of the training data. If the input has 4 dimensions, dimensions 0 and 1 are merged so that the result has 3 dimensions.
  2. CropData: Crops a specified region from the training data. Because the original data in the ERA5 dataset has size \(721 \times 1440\), this case crops it to \(720 \times 1440\), following the original paper.
  3. Normalize: Normalizes the data using the mean and standard deviation of the training dataset.
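The three steps can be illustrated on a tiny nested list. This is a sketch of the idea only, not the PaddleScience transforms: the helper names are hypothetical, plain lists replace NumPy arrays to keep the example dependency-free, and one scalar mean and standard deviation per channel is assumed.

```python
def squeeze(data):
    """Merge the first two axes of a 4D nested list: (A, B, H, W) -> (A*B, H, W)."""
    return [plane for group in data for plane in group]


def crop(data, xmin, xmax):
    """Keep rows xmin[0]:xmax[0] and columns xmin[1]:xmax[1] of every plane."""
    (h0, w0), (h1, w1) = xmin, xmax
    return [[row[w0:w1] for row in plane[h0:h1]] for plane in data]


def normalize(data, mean, std):
    """Standardize each channel with its own mean and standard deviation."""
    return [
        [[(value - m) / s for value in row] for row in plane]
        for plane, m, s in zip(data, mean, std)
    ]


# Toy input of shape (1, 2, 3, 2): 1 time step, 2 channels, 3 rows, 2 columns.
raw = [[
    [[1.0, 3.0], [5.0, 7.0], [9.0, 9.0]],
    [[10.0, 20.0], [15.0, 25.0], [0.0, 0.0]],
]]

planes = squeeze(raw)                             # shape (2, 3, 2)
cropped = crop(planes, xmin=(0, 0), xmax=(2, 2))  # drop the last row, like 721 -> 720
result = normalize(cropped, mean=[1.0, 10.0], std=[2.0, 5.0])
```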

Full reproduction of FourCastNet requires over 5TB of storage and 64 GPUs. To accommodate different resource availabilities, we offer two training methods (both yield similar convergence):

Method A (Sufficient Storage): Each node stores the full 5TB+ dataset. Data is randomly selected from the complete set using global shuffle, as shown below.

fourcastnet-vars

Global shuffle

In this method, the code for data loading is as follows:

examples/fourcastnet/train_pretrain.py
if not cfg.USE_SAMPLED_DATA:
    train_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.TRAIN_FILE_PATH,
            "input_keys": cfg.MODEL.afno.input_keys,
            "label_keys": cfg.MODEL.afno.output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "transforms": transforms,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": True,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": 8,

Here, the "dataset" field sets the Dataset class to ERA5Dataset and the "sampler" field sets the Sampler class to BatchSampler; batch_size is set to 1 and num_workers to 8.

Method B (Limited Storage): The dataset is evenly partitioned across nodes. You can use ppsci/fourcastnet/sample_data.py to sample data. To use this method, set USE_SAMPLED_DATA to True (Method A is the default). Training uses local shuffle, where each node samples from its local partition. For example, splitting across 8 nodes reduces the per-node storage requirement to approximately 1.2TB.

fourcastnet-vars

Local shuffle

In this method, the code for data loading is as follows:

examples/fourcastnet/train_pretrain.py
else:
    NUM_GPUS_PER_NODE = 8
    train_dataloader_cfg = {
        "dataset": {
            "name": "ERA5SampledDataset",
            "file_path": cfg.TRAIN_FILE_PATH,
            "input_keys": cfg.MODEL.afno.input_keys,
            "label_keys": cfg.MODEL.afno.output_keys,
        },
        "sampler": {
            "name": "DistributedBatchSampler",
            "drop_last": True,
            "shuffle": True,
            "num_replicas": NUM_GPUS_PER_NODE,
            "rank": dist.get_rank() % NUM_GPUS_PER_NODE,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": 8,

Here, the "dataset" field sets the Dataset class to ERA5SampledDataset and the "sampler" field sets the Sampler class to DistributedBatchSampler; batch_size is set to 1 and num_workers to 8.

If you do not need a complete reproduction of FourCastNet, simply use the default setting of this case (Method A).
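The difference between the two shuffling strategies can be sketched with sample indices alone. The functions below are hypothetical and assume contiguous, equally sized shards; how sample_data.py actually assigns samples to nodes is not shown here.

```python
import random


def global_shuffle(num_samples, seed):
    """Method A: every node can draw from one permutation of the full dataset."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    return indices


def partition(num_samples, num_nodes):
    """Method B, step 1: split the indices evenly into one shard per node."""
    shard = num_samples // num_nodes
    return [list(range(n * shard, (n + 1) * shard)) for n in range(num_nodes)]


def local_shuffle(shards, seed):
    """Method B, step 2: each node shuffles only the shard it stores."""
    shuffled = []
    for node, shard in enumerate(shards):
        local = list(shard)
        random.Random(seed + node).shuffle(local)
        shuffled.append(local)
    return shuffled


everything = global_shuffle(16, seed=0)
shards = partition(16, num_nodes=4)
local = local_shuffle(shards, seed=0)
```

With local shuffle a node never sees samples outside its shard, which is what lowers the per-node storage requirement.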

The code for defining supervised constraints is as follows:

examples/fourcastnet/train_pretrain.py
    }
# set constraint
sup_constraint = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    ppsci.loss.L2RelLoss(),
    name="Sup",
)

The first parameter of SupervisedConstraint is the data loading configuration; here it is the train_dataloader_cfg defined above.

The second parameter is the loss function; here L2RelLoss is used.

The third parameter is the name of the constraint, which makes it easy to index later; here it is "Sup".
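As a rough guide to what a relative L2 loss measures, the sketch below divides the L2 norm of the error by the L2 norm of the target. This definition is an assumption for illustration; the exact reduction used by ppsci.loss.L2RelLoss (for example across the batch) should be checked in the API documentation.

```python
import math


def l2_rel_loss(pred, target):
    """Relative L2 error: ||pred - target||_2 / ||target||_2 (assumed definition)."""
    error_norm = math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)))
    target_norm = math.sqrt(sum(t ** 2 for t in target))
    return error_norm / target_norm


# Error norm is 1 and target norm is 5, so the loss is 0.2.
loss = l2_rel_loss([3.0, 5.0], [3.0, 4.0])
```

Dividing by the target norm makes the loss comparable across variables with very different magnitudes.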

3.2.2 Model Construction

In this case, the wind speed model is based on the AFNONet network model, expressed in PaddleScience code as follows:

examples/fourcastnet/train_pretrain.py
# set model

The parameters of the network model are set through the configuration file as follows:

examples/fourcastnet/conf/fourcastnet_pretrain.yaml
# set inference data path
INFER_FILE_PATH: ./datasets/era5/test/2018-09-08_n32.npy

# model settings

Among them, input_keys and output_keys represent the names of input and output variables of the network model respectively.

3.2.3 Learning Rate and Optimizer Construction

This case uses a Cosine learning rate schedule with the learning rate set to 5e-4, together with the Adam optimizer. In PaddleScience code:

examples/fourcastnet/train_pretrain.py
# init optimizer and lr scheduler
lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH})
lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(**lr_scheduler_cfg)()
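For intuition, a plain cosine annealing schedule can be written as follows. This is a generic sketch, not the ppsci.optimizer.lr_scheduler.Cosine implementation: the function name, the `min_lr` parameter, and the absence of warmup are assumptions.

```python
import math


def cosine_lr(step, total_steps, base_lr=5e-4, min_lr=0.0):
    """Decay the learning rate from base_lr to min_lr along half a cosine wave."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


start = cosine_lr(0, 100)      # base learning rate
halfway = cosine_lr(50, 100)   # half of the base learning rate
end = cosine_lr(100, 100)      # decayed to min_lr
```

Here `total_steps` would be the number of epochs multiplied by ITERS_PER_EPOCH, which is why the scheduler configuration is given iters_per_epoch.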

3.2.4 Validator Construction

During training, the model is evaluated on the validation set at fixed epoch intervals, so a validator must be built with SupervisedValidator. The code is as follows:

examples/fourcastnet/train_pretrain.py
# set eval dataloader config
eval_dataloader_cfg = {
    "dataset": {
        "name": "ERA5Dataset",
        "file_path": cfg.VALID_FILE_PATH,
        "input_keys": cfg.MODEL.afno.input_keys,
        "label_keys": cfg.MODEL.afno.output_keys,
        "vars_channel": cfg.VARS_CHANNEL,
        "transforms": transforms,
        "training": False,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": False,
    },
    "batch_size": cfg.EVAL.batch_size,
}

# set validator
sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    ppsci.loss.L2RelLoss(),
    metric={
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
            num_lat=cfg.IMG_H,
            std=data_std,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
        "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
            num_lat=cfg.IMG_H,
            mean=data_time_mean_normalize,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
    },
    name="Sup_Validator",
)

SupervisedValidator is configured much like SupervisedConstraint. The difference is that a validator also needs evaluation metrics, passed through the metric argument. Three metrics are used here: MAE, LatitudeWeightedRMSE, and LatitudeWeightedACC.
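Latitude weighting compensates for the fact that grid cells on a regular latitude-longitude grid shrink toward the poles. The sketch below follows the weighting described in the FourCastNet paper, where each row is weighted by the cosine of its latitude, normalized so the weights average to 1. It is not the ppsci.metric.LatitudeWeightedRMSE implementation; the latitude layout is assumed, and the rescaling by `std` is omitted.

```python
import math


def latitude_weights(num_lat):
    """cos(latitude) weights normalized to mean 1, for rows from +90 to -90 degrees."""
    lats = [90.0 - 180.0 * j / (num_lat - 1) for j in range(num_lat)]
    cosines = [math.cos(math.radians(lat)) for lat in lats]
    mean_cos = sum(cosines) / num_lat
    return [c / mean_cos for c in cosines]


def lat_weighted_rmse(pred, target):
    """RMSE over a (num_lat, num_lon) field with each row weighted by its latitude."""
    num_lat, num_lon = len(pred), len(pred[0])
    weights = latitude_weights(num_lat)
    total = sum(
        weights[j] * (pred[j][k] - target[j][k]) ** 2
        for j in range(num_lat)
        for k in range(num_lon)
    )
    return math.sqrt(total / (num_lat * num_lon))


# A uniform error of 2 gives an RMSE of 2, because the weights average to 1.
pred = [[2.0] * 4 for _ in range(5)]
target = [[0.0] * 4 for _ in range(5)]
rmse = lat_weighted_rmse(pred, target)
```

An error confined to the polar rows would contribute almost nothing, while the same error at the equator would count more than in an unweighted RMSE.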

3.2.5 Model Training and Evaluation

After completing the above settings, you only need to pass the instantiated objects to ppsci.solver.Solver in order, and then start training and evaluation.

examples/fourcastnet/train_pretrain.py
# initialize solver
solver = ppsci.solver.Solver(
    model,
    constraint,
    cfg.output_dir,
    optimizer,
    lr_scheduler,
    cfg.TRAIN.epochs,
    ITERS_PER_EPOCH,
    eval_during_train=True,
    seed=cfg.seed,
    validator=validator,
    compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
    eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)
# train model
solver.train()
# evaluate after finished training

3.3 Model Fine-tuning

Having covered pre-training, we now discuss fine-tuning the wind speed model. Since the process is similar, we focus only on the differences. Key parameters for fine-tuning are defined below:

examples/fourcastnet/conf/fourcastnet_finetune.yaml
# set training hyper-parameters
IMG_H: 720
IMG_W: 1440
# FourCastNet uses 20 atmospheric variables; their indices in the dataset run from 0 to 19.
# The variable names are 'u10', 'v10', 't2m', 'sp', 'msl', 't850', 'u1000', 'v1000', 'z1000',
# 'u850', 'v850', 'z850', 'u500', 'v500', 'z500', 't500', 'z50', 'r500', 'r850', 'tcwv'.
# You can obtain detailed information about each variable from
# https://cds.climate.copernicus.eu/cdsapp#!/search?text=era5&type=dataset
VARS_CHANNEL: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

# set train data path
TRAIN_FILE_PATH: ./datasets/era5/train
DATA_MEAN_PATH: ./datasets/era5/stat/global_means.npy
DATA_STD_PATH: ./datasets/era5/stat/global_stds.npy
DATA_TIME_MEAN_PATH: ./datasets/era5/stat/time_means.npy

# set evaluate data path
VALID_FILE_PATH: ./datasets/era5/test

# set test data path

The fine-tuning program adds a num_timestamps parameter that controls how many autoregressive time steps are iterated during fine-tuning. It is first used in the data loading configuration, where it sets the number of ground-truth time steps the dataset returns for each sample. The code is as follows:

examples/fourcastnet/train_finetune.py
]
# set train dataloader config
train_dataloader_cfg = {
    "dataset": {
        "name": "ERA5Dataset",
        "file_path": cfg.TRAIN_FILE_PATH,
        "input_keys": cfg.MODEL.afno.input_keys,
        "label_keys": output_keys,
        "vars_channel": cfg.VARS_CHANNEL,
        "num_label_timestamps": cfg.TRAIN.num_timestamps,
        "transforms": transforms,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": True,
        "shuffle": True,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "num_workers": 8,

The num_timestamps parameter is set through the configuration file as follows:

examples/fourcastnet/conf/fourcastnet_finetune.yaml
epochs: ${TRAIN.epochs}

In addition, unlike pre-training, the fine-tuning model must also be constructed with num_timestamps, which sets the number of time steps the model predicts. The code is as follows:

examples/fourcastnet/train_finetune.py
# set model
model_cfg = dict(cfg.MODEL.afno)
model_cfg.update(
    {"output_keys": output_keys, "num_timestamps": cfg.TRAIN.num_timestamps}

The fine-tuning program also includes code for evaluating the model on the test set and for visualizing its predictions. These two parts are described in detail below.

3.3.1 Evaluating Model on Test Set

Following the settings in the paper, evaluation on the test set sets num_timestamps to 32 through the configuration file and uses a stride of 8 time steps between adjacent test samples.

examples/fourcastnet/conf/fourcastnet_finetune.yaml
num_timestamps: 2
pretrained_model_path: outputs_fourcastnet_pretrain/checkpoints/latest
checkpoint_path: null

The code for constructing the model is:

examples/fourcastnet/train_finetune.py
# set model
model_cfg = dict(cfg.MODEL.afno)
model_cfg.update(
    {"output_keys": output_keys, "num_timestamps": cfg.EVAL.num_timestamps}
)

The code for constructing the validator is:

examples/fourcastnet/train_finetune.py
# set eval dataloader config
eval_dataloader_cfg = {
    "dataset": {
        "name": "ERA5Dataset",
        "file_path": cfg.TEST_FILE_PATH,
        "input_keys": cfg.MODEL.afno.input_keys,
        "label_keys": output_keys,
        "vars_channel": cfg.VARS_CHANNEL,
        "transforms": transforms,
        "num_label_timestamps": cfg.EVAL.num_timestamps,
        "training": False,
        "stride": 8,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": False,
    },
    "batch_size": cfg.EVAL.batch_size,
}

# set metric
metric = {
    "MAE": ppsci.metric.MAE(keep_batch=True),
    "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
        num_lat=cfg.IMG_H,
        std=data_std,
        keep_batch=True,
        variable_dict={"u10": 0, "v10": 1},
    ),
    "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
        num_lat=cfg.IMG_H,
        mean=data_time_mean_normalize,
        keep_batch=True,
        variable_dict={"u10": 0, "v10": 1},
    ),
}

# set validator for testing
sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    ppsci.loss.L2RelLoss(),
    metric=metric,
    name="Sup_Validator",
)
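The "stride" and "num_label_timestamps" settings in the evaluation dataset determine which time indices form each test sample. The sketch below shows one plausible indexing scheme under stated assumptions: `sample_windows` is a hypothetical helper, not the ERA5Dataset implementation, and the year is assumed to hold 1460 six-hourly samples (365 days).

```python
def sample_windows(num_times, num_label_timestamps, stride):
    """Return (input_index, label_indices) pairs for multi-step evaluation."""
    windows = []
    # The last usable start must leave room for all label time steps.
    last_start = num_times - num_label_timestamps - 1
    for start in range(0, last_start + 1, stride):
        labels = list(range(start + 1, start + num_label_timestamps + 1))
        windows.append((start, labels))
    return windows


windows = sample_windows(num_times=1460, num_label_timestamps=32, stride=8)
first_input, first_labels = windows[0]
lead_time_hours = len(first_labels) * 6  # 32 steps of 6 hours each
```

With 32 label steps at 6-hour spacing, the longest lead time is 192 hours, which matches the last column of the wind speed results table.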

3.3.2 Visualizer Construction

The wind speed model employs autoregressive inference. We first configure the input data:

examples/fourcastnet/train_finetune.py
# set visualizer data
DATE_STRINGS = ("2018-09-08 00:00:00",)
vis_data = get_vis_data(
    cfg.TEST_FILE_PATH,
    DATE_STRINGS,
    cfg.EVAL.num_timestamps,
    cfg.VARS_CHANNEL,
    cfg.IMG_H,
    data_mean,
    data_std,

examples/fourcastnet/train_finetune.py
def get_vis_data(
    file_path: str,
    date_strings: Tuple[str, ...],
    num_timestamps: int,
    vars_channel: Tuple[int, ...],
    img_h: int,
    data_mean: np.ndarray,
    data_std: np.ndarray,
):
    _file = h5py.File(file_path, "r")["fields"]
    data = []
    for date_str in date_strings:
        hours_since_jan_01_epoch = fourcast_utils.date_to_hours(date_str)
        ic = int(hours_since_jan_01_epoch / 6)
        data.append(_file[ic : ic + num_timestamps + 1, vars_channel, 0:img_h])
    data = np.asarray(data)

    vis_data = {"input": (data[:, 0] - data_mean) / data_std}
    for t in range(num_timestamps):
        hour = (t + 1) * 6
        data_t = data[:, t + 1]
        wind_data = []
        for i in range(data_t.shape[0]):
            wind_data.append((data_t[i][0] ** 2 + data_t[i][1] ** 2) ** 0.5)
        vis_data[f"target_{hour}h"] = np.asarray(wind_data)

The code above reads the model input for the times listed in DATE_STRINGS. The get_vis_data function also reads the ground truth for the corresponding forecast times, which is visualized alongside the model predictions for comparison.

Because the model predicts the zonal and meridional wind components separately, the two components must be combined into a wind speed magnitude. The code is as follows:

examples/fourcastnet/train_finetune.py
def output_wind_func(d, var_name, data_mean, data_std):
    output = (d[var_name] * data_std) + data_mean
    wind_data = []
    for i in range(output.shape[0]):
        wind_data.append((output[i][0] ** 2 + output[i][1] ** 2) ** 0.5)
    return paddle.to_tensor(wind_data, paddle.get_default_dtype())

vis_output_expr = {}
for i in range(cfg.EVAL.num_timestamps):
    hour = (i + 1) * 6
    vis_output_expr[f"output_{hour}h"] = functools.partial(
        output_wind_func,
        var_name=f"output_{i}",
        data_mean=paddle.to_tensor(data_mean, paddle.get_default_dtype()),
        data_std=paddle.to_tensor(data_std, paddle.get_default_dtype()),
    )

Finally, the code for constructing the visualizer is as follows:

examples/fourcastnet/train_finetune.py
    vis_output_expr[f"target_{hour}h"] = lambda d, hour=hour: d[f"target_{hour}h"]
# set visualizer
visualizer = {
    "visualize_wind": ppsci.visualize.VisualizerWeather(
        vis_data,
        vis_output_expr,
        xticks=np.linspace(0, 1439, 13),
        xticklabels=[str(i) for i in range(360, -1, -30)],
        yticks=np.linspace(0, 719, 7),
        yticklabels=[str(i) for i in range(90, -91, -30)],
        vmin=0,
        vmax=25,
        colorbar_label="m/s",
        batch_size=cfg.EVAL.batch_size,
        num_timestamps=cfg.EVAL.num_timestamps,
        prefix="wind",
    )

The constructed model, validator, and visualizer above will be passed to ppsci.solver.Solver for evaluating performance on the test set and visualization.

examples/fourcastnet/train_finetune.py
solver = ppsci.solver.Solver(
    model,
    output_dir=cfg.output_dir,
    validator=validator,
    visualizer=visualizer,
    pretrained_model_path=cfg.EVAL.pretrained_model_path,
    compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
    eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)
solver.eval()
# visualize prediction from pretrained_model_path

4. Precipitation Model Implementation

We first list the parameters defined in the configuration file. The meaning of each parameter is explained where it is used below.

examples/fourcastnet/conf/fourcastnet_precip.yaml
# set training hyper-parameters
IMG_H: 720
IMG_W: 1440
# FourCastNet uses 20 atmospheric variables; their indices in the dataset run from 0 to 19.
# The variable names are 'u10', 'v10', 't2m', 'sp', 'msl', 't850', 'u1000', 'v1000', 'z1000',
# 'u850', 'v850', 'z850', 'u500', 'v500', 'z500', 't500', 'z50', 'r500', 'r850', 'tcwv'.
# You can obtain detailed information about each variable from
# https://cds.climate.copernicus.eu/cdsapp#!/search?text=era5&type=dataset
VARS_CHANNEL: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

# set train data path
WIND_TRAIN_FILE_PATH: ./datasets/era5/train
WIND_MEAN_PATH: ./datasets/era5/stat/global_means.npy
WIND_STD_PATH: ./datasets/era5/stat/global_stds.npy
WIND_TIME_MEAN_PATH: ./datasets/era5/stat/time_means.npy

TRAIN_FILE_PATH: ./datasets/era5/precip/train
TIME_MEAN_PATH: ./datasets/era5/stat/precip/time_means.npy

# set evaluate data path
WIND_VALID_FILE_PATH: ./datasets/era5/test
VALID_FILE_PATH: ./datasets/era5/precip/test

# set test data path
WIND_TEST_FILE_PATH: ./datasets/era5/out_of_sample/2018.h5
TEST_FILE_PATH: ./datasets/era5/precip/out_of_sample/2018.h5

# set wind model path

4.1 Constraint Construction

Since this is a data-driven task, we use PaddleScience's built-in SupervisedConstraint. Before defining the constraint, we specify the parameters used for data loading, starting with data preprocessing:

examples/fourcastnet/train_precip.py
wind_data_mean, wind_data_std = fourcast_utils.get_mean_std(
    cfg.WIND_MEAN_PATH, cfg.WIND_STD_PATH, cfg.VARS_CHANNEL
)
data_time_mean = fourcast_utils.get_time_mean(
    cfg.TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W
)

# set train transforms
transforms = [
    {"SqueezeData": {}},
    {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
    {
        "Normalize": {
            "mean": wind_data_mean,
            "std": wind_data_std,
            "apply_keys": ("input",),
        }
    },
    {"Log1p": {"scale": 1e-5, "apply_keys": ("label",)}},

Data preprocessing consists of 4 transforms:

  1. SqueezeData: Reduces the dimensionality of the training data. If the input has 4 dimensions, dimensions 0 and 1 are merged so that the result has 3 dimensions.
  2. CropData: Crops a specified region from the training data. Because the original data in the ERA5 dataset has size \(721 \times 1440\), this case crops it to \(720 \times 1440\), following the original paper.
  3. Normalize: Normalizes the data using the mean and standard deviation of the training dataset. Here, the apply_keys field restricts this transform to the input data.
  4. Log1p: Maps the data to logarithmic space. Here, the apply_keys field restricts this transform to the ground truth data.
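The logarithmic mapping can be sketched as follows. The formula log(1 + x / scale) is an assumption based on the transform's name and its scale argument of 1e-5; the helper names are hypothetical, and `unlog` stands for the inverse that the unlog=True metric option presumably applies before computing errors.

```python
import math

SCALE = 1e-5  # matches the "scale" value in the Log1p transform configuration


def log1p_transform(tp, scale=SCALE):
    """Compress precipitation values: log(1 + tp / scale) (assumed formula)."""
    return math.log1p(tp / scale)


def unlog(value, scale=SCALE):
    """Invert log1p_transform to recover values in the original units."""
    return scale * math.expm1(value)


dry = log1p_transform(0.0)      # zero precipitation stays at zero
wet = log1p_transform(0.002)    # log(1 + 200)
recovered = unlog(wet)
```

Precipitation is zero over most of the globe with a long tail of large values, so the logarithm keeps the rare heavy-rain cells from dominating the loss.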

The code for data loading is as follows:

examples/fourcastnet/train_precip.py
# set train dataloader config
train_dataloader_cfg = {
    "dataset": {
        "name": "ERA5Dataset",
        "file_path": cfg.WIND_TRAIN_FILE_PATH,
        "input_keys": cfg.MODEL.precip.input_keys,
        "label_keys": cfg.MODEL.precip.output_keys,
        "vars_channel": cfg.VARS_CHANNEL,
        "precip_file_path": cfg.TRAIN_FILE_PATH,
        "transforms": transforms,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": True,
        "shuffle": True,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "num_workers": 8,

Here, the "dataset" field sets the Dataset class to ERA5Dataset and the "sampler" field sets the Sampler class to BatchSampler; batch_size is set to 1 and num_workers to 8.

The code for defining supervised constraints is as follows:

examples/fourcastnet/train_precip.py
}
# set constraint
sup_constraint = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    ppsci.loss.L2RelLoss(),
    name="Sup",
)

The first parameter of SupervisedConstraint is the data loading configuration; here it is the train_dataloader_cfg defined above.

The second parameter is the loss function; here L2RelLoss is used.

The third parameter is the name of the constraint, which makes it easy to index later; here it is "Sup".

4.2 Model Construction

We first define the wind speed model architecture and load its pre-trained weights. Then, we define the precipitation model:

examples/fourcastnet/train_precip.py
# set model
wind_model = ppsci.arch.AFNONet(**cfg.MODEL.afno)
ppsci.utils.save_load.load_pretrain(wind_model, path=cfg.WIND_MODEL_PATH)
model_cfg = dict(cfg.MODEL.precip)
model_cfg.update({"wind_model": wind_model})

The parameters for defining the model are set through configuration as follows:

examples/fourcastnet/conf/fourcastnet_precip.yaml
# set inference data path
WIND_INFER_PATH: ./datasets/era5/test/2018-04-04_n6.npy
INFER_FILE_PATH: ./datasets/era5/test/2018-04-04_n6_precip.npy

# model settings
MODEL:
  afno:

Among them, input_keys and output_keys represent the names of input and output variables of the network model respectively.

4.3 Learning Rate and Optimizer Construction

This case uses a Cosine learning rate schedule with the learning rate set to 2.5e-4, together with the Adam optimizer. In PaddleScience code:

examples/fourcastnet/train_precip.py
# init optimizer and lr scheduler
lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH})
lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(**lr_scheduler_cfg)()

4.4 Validator Construction

During training, the model is evaluated on the validation set at fixed epoch intervals, so a validator must be built with SupervisedValidator. The code is as follows:

examples/fourcastnet/train_precip.py
# set eval dataloader config
eval_dataloader_cfg = {
    "dataset": {
        "name": "ERA5Dataset",
        "file_path": cfg.WIND_VALID_FILE_PATH,
        "input_keys": cfg.MODEL.precip.input_keys,
        "label_keys": cfg.MODEL.precip.output_keys,
        "vars_channel": cfg.VARS_CHANNEL,
        "precip_file_path": cfg.VALID_FILE_PATH,
        "transforms": transforms,
        "training": False,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": False,
    },
    "batch_size": cfg.EVAL.batch_size,
}

# set metric
metric = {
    "MAE": ppsci.metric.MAE(keep_batch=True),
    "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
        num_lat=cfg.IMG_H, keep_batch=True, unlog=True
    ),
    "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
        num_lat=cfg.IMG_H, mean=data_time_mean, keep_batch=True, unlog=True
    ),
}

# set validator
sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    ppsci.loss.L2RelLoss(),
    metric=metric,
    name="Sup_Validator",
)

SupervisedValidator is similar to SupervisedConstraint. The difference is that the validator also needs evaluation metrics, passed through the metric argument. Three metrics are used here: MAE, LatitudeWeightedRMSE and LatitudeWeightedACC.
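
The idea behind the two latitude-weighted metrics can be sketched with NumPy. The sketch assumes an equiangular grid whose rows run from +90 to -90 degrees and weights proportional to cos(latitude), normalized to mean 1; the function names are hypothetical, and the ppsci.metric implementations (including options such as unlog) may differ in detail.

```python
import numpy as np


def latitude_weights(num_lat: int) -> np.ndarray:
    """cos(latitude) weights normalized to mean 1, shape (num_lat, 1)."""
    lat_deg = 90.0 - np.arange(num_lat) * 180.0 / (num_lat - 1)
    cos_lat = np.cos(np.deg2rad(lat_deg))
    return (cos_lat / cos_lat.mean())[:, None]


def lat_weighted_rmse(pred: np.ndarray, target: np.ndarray) -> float:
    """RMSE over a (num_lat, num_lon) field with latitude weighting."""
    w = latitude_weights(pred.shape[0])
    return float(np.sqrt(np.mean(w * (pred - target) ** 2)))


def lat_weighted_acc(pred: np.ndarray, target: np.ndarray, clim: np.ndarray) -> float:
    """Anomaly correlation of pred and target relative to a climatology field."""
    w = latitude_weights(pred.shape[0])
    pa, ta = pred - clim, target - clim
    return float(np.sum(w * pa * ta) / np.sqrt(np.sum(w * pa**2) * np.sum(w * ta**2)))


rng = np.random.default_rng(0)
target = rng.normal(size=(8, 16))  # toy field, not ERA5 data
clim = np.zeros_like(target)  # toy climatology
rmse_same = lat_weighted_rmse(target, target)  # perfect forecast gives 0
acc_same = lat_weighted_acc(target, target, clim)  # perfect forecast gives 1
acc_flipped = lat_weighted_acc(-target, target, clim)  # inverted anomalies give -1
```

The weighting compensates for grid cells near the poles covering less area than cells near the equator.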

4.5 Model Training and Evaluation

After completing the above settings, you only need to pass the instantiated objects to ppsci.solver.Solver in order, and then start training and evaluation.

examples/fourcastnet/train_precip.py
# initialize solver
solver = ppsci.solver.Solver(
    model,
    constraint,
    cfg.output_dir,
    optimizer,
    lr_scheduler,
    cfg.TRAIN.epochs,
    ITERS_PER_EPOCH,
    eval_during_train=True,
    validator=validator,
    compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
    eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)
# train model
solver.train()
# evaluate after finished training
solver.eval()

4.6 Evaluating Model on Test Set

According to the settings in the paper, when evaluating the model on the test set, num_timestamps is set to 6, and the interval between two adjacent test samples is 8.
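
To make these two settings concrete, the sketch below lists which time indices each test sample would cover. It assumes that a sample starting at index s uses step s as input and steps s+1 to s+num_timestamps as labels, and that consecutive samples start 8 steps apart; the helper strided_windows is hypothetical and is not the ERA5Dataset implementation. One step corresponds to 6 hours, as in the hour = (t + 1) * 6 computation used in this document.

```python
from typing import List, Tuple

HOURS_PER_STEP = 6  # one dataset step is 6 hours


def strided_windows(
    num_steps: int, num_timestamps: int = 6, stride: int = 8
) -> List[Tuple[int, List[int]]]:
    """Return (input_index, label_indices) pairs for strided test samples."""
    windows = []
    # the last label index start + num_timestamps must stay below num_steps
    for start in range(0, num_steps - num_timestamps, stride):
        labels = [start + k for k in range(1, num_timestamps + 1)]
        windows.append((start, labels))
    return windows


windows = strided_windows(num_steps=24)
lead_hours = [k * HOURS_PER_STEP for k in range(1, 7)]
```

With 24 available steps this yields three samples starting at indices 0, 8 and 16, each evaluated at lead times from 6 to 36 hours.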

The code for constructing the model is:

examples/fourcastnet/train_precip.py
# set model for testing
wind_model = ppsci.arch.AFNONet(**cfg.MODEL.afno)
ppsci.utils.save_load.load_pretrain(wind_model, path=cfg.WIND_MODEL_PATH)
model_cfg = dict(cfg.MODEL.precip)
model_cfg.update(
    {
        "output_keys": output_keys,
        "num_timestamps": cfg.EVAL.num_timestamps,
        "wind_model": wind_model,
    }
)
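
Setting num_timestamps makes the model return one output per forecast step, keyed output_0, output_1, and so on. A toy sketch of such an autoregressive rollout is shown here, assuming each step feeds the previous prediction back in as the next input; toy_step is a stand-in for one 6-hour model step and is not the real network.

```python
from typing import Callable, Dict, List


def rollout(
    step_fn: Callable[[List[float]], List[float]],
    x0: List[float],
    num_timestamps: int,
) -> Dict[str, List[float]]:
    """Apply step_fn repeatedly, feeding each prediction back as the next input."""
    outputs = {}
    state = x0
    for i in range(num_timestamps):
        state = step_fn(state)
        outputs[f"output_{i}"] = state
    return outputs


def toy_step(x: List[float]) -> List[float]:
    # stand-in for a single model step; halves every value
    return [0.5 * v for v in x]


preds = rollout(toy_step, [8.0, 4.0], num_timestamps=3)
```

Because every step consumes the previous prediction, errors can accumulate with lead time, which is why metrics are reported per timestamp.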

The code for constructing the validator is:

examples/fourcastnet/train_precip.py
eval_dataloader_cfg = {
    "dataset": {
        "name": "ERA5Dataset",
        "file_path": cfg.WIND_TEST_FILE_PATH,
        "input_keys": cfg.MODEL.precip.input_keys,
        "label_keys": output_keys,
        "vars_channel": cfg.VARS_CHANNEL,
        "precip_file_path": cfg.TEST_FILE_PATH,
        "num_label_timestamps": cfg.EVAL.num_timestamps,
        "stride": 8,
        "transforms": transforms,
        "training": False,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": False,
    },
    "batch_size": cfg.EVAL.batch_size,
}
# set metric
metric = {
    "MAE": ppsci.metric.MAE(keep_batch=True),
    "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
        num_lat=cfg.IMG_H, keep_batch=True, unlog=True
    ),
    "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
        num_lat=cfg.IMG_H, mean=data_time_mean, keep_batch=True, unlog=True
    ),
}

# set validator for testing
sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    ppsci.loss.L2RelLoss(),
    metric=metric,
    name="Sup_Validator",
)

4.7 Visualizer Construction

The precipitation model runs inference autoregressively, so the input data for inference must be prepared first. The code is as follows:

examples/fourcastnet/train_precip.py
# set visualizer data
DATE_STRINGS = ("2018-04-04 00:00:00",)
vis_data = get_vis_data(
    cfg.WIND_TEST_FILE_PATH,
    cfg.TEST_FILE_PATH,
    DATE_STRINGS,
    cfg.EVAL.num_timestamps,
    cfg.VARS_CHANNEL,
    cfg.IMG_H,
    wind_data_mean,
    wind_data_std,
)

examples/fourcastnet/train_precip.py
def get_vis_data(
    wind_file_path: str,
    file_path: str,
    date_strings: Tuple[str, ...],
    num_timestamps: int,
    vars_channel: Tuple[int, ...],
    img_h: int,
    data_mean: np.ndarray,
    data_std: np.ndarray,
):
    __wind_file = h5py.File(wind_file_path, "r")["fields"]
    _file = h5py.File(file_path, "r")["tp"]
    wind_data = []
    data = []
    for date_str in date_strings:
        hours_since_jan_01_epoch = fourcast_utils.date_to_hours(date_str)
        ic = int(hours_since_jan_01_epoch / 6)
        wind_data.append(__wind_file[ic, vars_channel, 0:img_h])
        data.append(_file[ic + 1 : ic + num_timestamps + 1, 0:img_h])
    wind_data = np.asarray(wind_data)
    data = np.asarray(data)

    vis_data = {"input": (wind_data - data_mean) / data_std}
    for t in range(num_timestamps):
        hour = (t + 1) * 6
        data_t = data[:, t]
        vis_data[f"target_{hour}h"] = np.asarray(data_t)
    return vis_data

In the above code, the model input is read according to the time parameter DATE_STRINGS. The get_vis_data function also reads the ground truth data for the corresponding times, which is visualized alongside the model predictions for comparison.
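
The conversion from a date string to the index ic can be sketched with the standard library. The implementation of fourcast_utils.date_to_hours is not shown in this document, so the version below is an assumption based on the variable name hours_since_jan_01_epoch: it counts hours since 00:00 on January 1 of the same year.

```python
from datetime import datetime


def date_to_hours(date_str: str) -> int:
    """Hours elapsed since 00:00 on January 1 of the same year (assumed behaviour)."""
    date = datetime.strptime(date_str, "%Y-%m-%d %H:%M:%S")
    year_start = datetime(date.year, 1, 1)
    return int((date - year_start).total_seconds() // 3600)


hours = date_to_hours("2018-04-04 00:00:00")  # 93 days * 24 hours
ic = hours // 6  # index of the initial condition in 6-hourly data
```

For the date used above this gives 2232 hours and an initial-condition index of 372.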

Because the model works with log-transformed precipitation, its outputs must be mapped back to linear space. The code is as follows:

examples/fourcastnet/train_precip.py
def output_precip_func(d, var_name):
    output = 1e-2 * paddle.expm1(d[var_name][0])
    return output

visu_output_expr = {}
for i in range(cfg.EVAL.num_timestamps):
    hour = (i + 1) * 6
    visu_output_expr[f"output_{hour}h"] = functools.partial(
        output_precip_func,
        var_name=f"output_{i}",
    )
    visu_output_expr[f"target_{hour}h"] = (
        lambda d, hour=hour: d[f"target_{hour}h"] * 1000
    )

Finally, the code for constructing the visualizer is as follows:

examples/fourcastnet/train_precip.py
# set visualizer
visualizer = {
    "visualize_precip": ppsci.visualize.VisualizerWeather(
        vis_data,
        visu_output_expr,
        xticks=np.linspace(0, 1439, 13),
        xticklabels=[str(i) for i in range(360, -1, -30)],
        yticks=np.linspace(0, 719, 7),
        yticklabels=[str(i) for i in range(90, -91, -30)],
        vmin=0.001,
        vmax=130,
        colorbar_label="mm",
        log_norm=True,
        batch_size=cfg.EVAL.batch_size,
        num_timestamps=cfg.EVAL.num_timestamps,
        prefix="precip",
    )
}
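
The log remapping applied by output_precip_func earlier in this section can be checked with a small round trip. The forward transform is not shown in this document, so the sketch assumes it is log(1 + tp / eps) with eps = 1e-5 and tp in metres. That assumption is consistent with the factor 1e-2 in output_precip_func and with multiplying the targets by 1000 to obtain millimetres. The function names are hypothetical.

```python
import math

EPS_M = 1e-5  # assumed offset of the log transform, in metres


def log_transform(tp_m: float) -> float:
    """Assumed forward transform applied to total precipitation in metres."""
    return math.log1p(tp_m / EPS_M)


def to_mm(log_tp: float) -> float:
    """Inverse transform in millimetres: 1000 * EPS_M * expm1(x) = 1e-2 * expm1(x)."""
    return 1e-2 * math.expm1(log_tp)


tp_m = 0.004  # 4 mm of precipitation expressed in metres
recovered_mm = to_mm(log_transform(tp_m))
```

Using log1p and expm1 rather than log and exp keeps the round trip accurate for the very small precipitation values that dominate the data.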

The constructed model, validator, and visualizer above will be passed to ppsci.solver.Solver for evaluating performance on the test set and visualization.

examples/fourcastnet/train_precip.py
solver = ppsci.solver.Solver(
    model,
    output_dir=cfg.output_dir,
    validator=validator,
    visualizer=visualizer,
    pretrained_model_path=cfg.EVAL.pretrained_model_path,
    compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
    eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)
solver.eval()
# visualize prediction
solver.visualize()

5. Complete Code

examples/fourcastnet/train_pretrain.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from os import path as osp

import hydra
import numpy as np
import paddle.distributed as dist
import utils as fourcast_utils
from omegaconf import DictConfig

import ppsci
from ppsci.utils import logger


def get_data_stat(cfg: DictConfig):
    data_mean, data_std = fourcast_utils.get_mean_std(
        cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.DATA_TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W, cfg.VARS_CHANNEL
    )
    data_time_mean_normalize = np.expand_dims(
        (data_time_mean[0] - data_mean) / data_std, 0
    )
    return data_mean, data_std, data_time_mean_normalize


def train(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info")

    data_mean, data_std = fourcast_utils.get_mean_std(
        cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.DATA_TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W, cfg.VARS_CHANNEL
    )
    data_time_mean_normalize = np.expand_dims(
        (data_time_mean[0] - data_mean) / data_std, 0
    )
    # set train transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
        {"Normalize": {"mean": data_mean, "std": data_std}},
    ]

    # set train dataloader config
    if not cfg.USE_SAMPLED_DATA:
        train_dataloader_cfg = {
            "dataset": {
                "name": "ERA5Dataset",
                "file_path": cfg.TRAIN_FILE_PATH,
                "input_keys": cfg.MODEL.afno.input_keys,
                "label_keys": cfg.MODEL.afno.output_keys,
                "vars_channel": cfg.VARS_CHANNEL,
                "transforms": transforms,
            },
            "sampler": {
                "name": "BatchSampler",
                "drop_last": True,
                "shuffle": True,
            },
            "batch_size": cfg.TRAIN.batch_size,
            "num_workers": 8,
        }
    else:
        NUM_GPUS_PER_NODE = 8
        train_dataloader_cfg = {
            "dataset": {
                "name": "ERA5SampledDataset",
                "file_path": cfg.TRAIN_FILE_PATH,
                "input_keys": cfg.MODEL.afno.input_keys,
                "label_keys": cfg.MODEL.afno.output_keys,
            },
            "sampler": {
                "name": "DistributedBatchSampler",
                "drop_last": True,
                "shuffle": True,
                "num_replicas": NUM_GPUS_PER_NODE,
                "rank": dist.get_rank() % NUM_GPUS_PER_NODE,
            },
            "batch_size": cfg.TRAIN.batch_size,
            "num_workers": 8,
        }
    # set constraint
    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        name="Sup",
    )
    constraint = {sup_constraint.name: sup_constraint}

    # set iters_per_epoch by dataloader length
    ITERS_PER_EPOCH = len(sup_constraint.data_loader)

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.VALID_FILE_PATH,
            "input_keys": cfg.MODEL.afno.input_keys,
            "label_keys": cfg.MODEL.afno.output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "transforms": transforms,
            "training": False,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric={
            "MAE": ppsci.metric.MAE(keep_batch=True),
            "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
                num_lat=cfg.IMG_H,
                std=data_std,
                keep_batch=True,
                variable_dict={"u10": 0, "v10": 1},
            ),
            "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
                num_lat=cfg.IMG_H,
                mean=data_time_mean_normalize,
                keep_batch=True,
                variable_dict={"u10": 0, "v10": 1},
            ),
        },
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set model
    model = ppsci.arch.AFNONet(**cfg.MODEL.afno)

    # init optimizer and lr scheduler
    lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
    lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH})
    lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(**lr_scheduler_cfg)()

    optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        cfg.output_dir,
        optimizer,
        lr_scheduler,
        cfg.TRAIN.epochs,
        ITERS_PER_EPOCH,
        eval_during_train=True,
        seed=cfg.seed,
        validator=validator,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # train model
    solver.train()
    # evaluate after finished training
    solver.eval()


def evaluate(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")

    data_mean, data_std = fourcast_utils.get_mean_std(
        cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.DATA_TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W, cfg.VARS_CHANNEL
    )
    data_time_mean_normalize = np.expand_dims(
        (data_time_mean[0] - data_mean) / data_std, 0
    )
    # set eval transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
        {"Normalize": {"mean": data_mean, "std": data_std}},
    ]

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.VALID_FILE_PATH,
            "input_keys": cfg.MODEL.afno.input_keys,
            "label_keys": cfg.MODEL.afno.output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "transforms": transforms,
            "training": False,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric={
            "MAE": ppsci.metric.MAE(keep_batch=True),
            "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
                num_lat=cfg.IMG_H,
                std=data_std,
                keep_batch=True,
                variable_dict={"u10": 0, "v10": 1},
            ),
            "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
                num_lat=cfg.IMG_H,
                mean=data_time_mean_normalize,
                keep_batch=True,
                variable_dict={"u10": 0, "v10": 1},
            ),
        },
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set model
    model = ppsci.arch.AFNONet(**cfg.MODEL.afno)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        output_dir=cfg.output_dir,
        log_freq=cfg.log_freq,
        seed=cfg.seed,
        validator=validator,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # evaluate
    solver.eval()


def export(cfg: DictConfig):
    # set model
    model = ppsci.arch.AFNONet(**cfg.MODEL.afno)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        pretrained_model_path=cfg.INFER.pretrained_model_path,
    )

    # export model
    from paddle.static import InputSpec

    input_spec = [
        {
            key: InputSpec([None, 20, cfg.IMG_H, cfg.IMG_W], "float32", name=key)
            for key in model.input_keys
        },
    ]
    solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
    from deploy.python_infer import pinn_predictor

    predictor = pinn_predictor.PINNPredictor(cfg)

    data_mean, data_std = fourcast_utils.get_mean_std(
        cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
    )

    data = np.load(cfg.INFER_FILE_PATH)
    input_0 = (data[:, 0] - data_mean) / data_std
    all_data = input_0

    for t in range(cfg.INFER.num_timestamps):
        data_t = data[:, t + 1]
        data_t = (data_t - data_mean) / data_std
        all_data = np.concatenate((all_data, data_t), axis=0)

    input_dict = {cfg.MODEL.afno.input_keys[0]: all_data}

    vis_output = predictor.predict(input_dict, cfg.INFER.batch_size)

    vis_dict = {
        store_key: vis_output[infer_key]
        for store_key, infer_key in zip(cfg.MODEL.afno.output_keys, vis_output.keys())
    }

    def output_wind_func(output, data_mean, data_std):
        output = (output * data_std) + data_mean
        wind_data = (output[0] ** 2 + output[1] ** 2) ** 0.5
        return wind_data

    wind_pred = []
    pred_dict = {}
    for i in range(cfg.INFER.num_timestamps):
        hour = (i + 1) * 6
        wind_ = [
            output_wind_func(
                vis_dict[cfg.MODEL.afno.output_keys[0]][i], data_mean, data_std
            )
        ]
        wind_pred.append(wind_)
        pred_dict[f"output_{hour}h"] = np.asarray(wind_)
    output_dict = {cfg.MODEL.afno.output_keys[0]: np.array(wind_pred)}

    wind_pred = []
    target_dict = {}
    for i in range(cfg.INFER.num_timestamps):
        hour = (i + 1) * 6
        wind_ = [(data[0][i][0] ** 2 + data[0][i][1] ** 2) ** 0.5]
        target_dict[f"target_{hour}h"] = np.asarray(wind_)

    vis_dict = {**pred_dict, **target_dict}

    plot_expr_dict = {}
    for hour in range(6, 6 + cfg.INFER.num_timestamps * 6, 6):
        plot_expr_dict.update(
            {
                f"target_{hour}h": lambda d, hour=hour: d[f"target_{hour}h"],
                f"output_{hour}h": lambda d, hour=hour: d[f"output_{hour}h"],
            }
        )

    visualizer_weather = ppsci.visualize.VisualizerWeather(
        vis_dict,
        plot_expr_dict,
        xticks=np.linspace(0, cfg.IMG_W - 1, 13),
        xticklabels=[str(i) for i in range(360, -1, -30)],
        yticks=np.linspace(0, cfg.IMG_H - 1, 7),
        yticklabels=[str(i) for i in range(90, -91, -30)],
        vmin=0,
        vmax=25,
        colorbar_label="m/s",
        batch_size=1,
        num_timestamps=cfg.INFER.num_timestamps,
        prefix="wind",
    )
    visualizer_weather.save(cfg.INFER.export_path, vis_dict)
    save_path = osp.join(cfg.INFER.export_path, "predict.npy")
    os.makedirs(cfg.INFER.export_path, exist_ok=True)
    np.save(save_path, output_dict[cfg.MODEL.afno.output_keys[0]])


@hydra.main(
    version_base=None, config_path="./conf", config_name="fourcastnet_pretrain.yaml"
)
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should be in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


if __name__ == "__main__":
    main()

examples/fourcastnet/train_finetune.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
from os import path as osp
from typing import Tuple

import h5py
import hydra
import numpy as np
import paddle
import utils as fourcast_utils
from omegaconf import DictConfig

import ppsci
from ppsci.utils import logger


def get_vis_data(
    file_path: str,
    date_strings: Tuple[str, ...],
    num_timestamps: int,
    vars_channel: Tuple[int, ...],
    img_h: int,
    data_mean: np.ndarray,
    data_std: np.ndarray,
):
    _file = h5py.File(file_path, "r")["fields"]
    data = []
    for date_str in date_strings:
        hours_since_jan_01_epoch = fourcast_utils.date_to_hours(date_str)
        ic = int(hours_since_jan_01_epoch / 6)
        data.append(_file[ic : ic + num_timestamps + 1, vars_channel, 0:img_h])
    data = np.asarray(data)

    vis_data = {"input": (data[:, 0] - data_mean) / data_std}
    for t in range(num_timestamps):
        hour = (t + 1) * 6
        data_t = data[:, t + 1]
        wind_data = []
        for i in range(data_t.shape[0]):
            wind_data.append((data_t[i][0] ** 2 + data_t[i][1] ** 2) ** 0.5)
        vis_data[f"target_{hour}h"] = np.asarray(wind_data)
    return vis_data


def train(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.set_random_seed(cfg.seed)

    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info")

    # set training hyper-parameters
    output_keys = tuple(f"output_{i}" for i in range(cfg.TRAIN.num_timestamps))

    data_mean, data_std = fourcast_utils.get_mean_std(
        cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.DATA_TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W, cfg.VARS_CHANNEL
    )
    data_time_mean_normalize = np.expand_dims(
        (data_time_mean[0] - data_mean) / data_std, 0
    )

    # set transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
        {"Normalize": {"mean": data_mean, "std": data_std}},
    ]
    # set train dataloader config
    train_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.TRAIN_FILE_PATH,
            "input_keys": cfg.MODEL.afno.input_keys,
            "label_keys": output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "num_label_timestamps": cfg.TRAIN.num_timestamps,
            "transforms": transforms,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": True,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": 8,
    }
    # set constraint
    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        name="Sup",
    )
    constraint = {sup_constraint.name: sup_constraint}

    # set iters_per_epoch by dataloader length
    ITERS_PER_EPOCH = len(sup_constraint.data_loader)

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.VALID_FILE_PATH,
            "input_keys": cfg.MODEL.afno.input_keys,
            "label_keys": output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "transforms": transforms,
            "num_label_timestamps": cfg.TRAIN.num_timestamps,
            "training": False,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set metric
    metric = {
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
            num_lat=cfg.IMG_H,
            std=data_std,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
        "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
            num_lat=cfg.IMG_H,
            mean=data_time_mean_normalize,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric=metric,
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set model
    model_cfg = dict(cfg.MODEL.afno)
    model_cfg.update(
        {"output_keys": output_keys, "num_timestamps": cfg.TRAIN.num_timestamps}
    )

    model = ppsci.arch.AFNONet(**model_cfg)

    # init optimizer and lr scheduler
    lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
    lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH})
    lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(**lr_scheduler_cfg)()
    optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        cfg.output_dir,
        optimizer,
        lr_scheduler,
        cfg.TRAIN.epochs,
        ITERS_PER_EPOCH,
        eval_during_train=True,
        validator=validator,
        pretrained_model_path=cfg.TRAIN.pretrained_model_path,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # train model
    solver.train()
    # evaluate after finished training
    solver.eval()


def evaluate(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")

    # set testing hyper-parameters
    output_keys = tuple(f"output_{i}" for i in range(cfg.EVAL.num_timestamps))

    data_mean, data_std = fourcast_utils.get_mean_std(
        cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.DATA_TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W, cfg.VARS_CHANNEL
    )
    data_time_mean_normalize = np.expand_dims(
        (data_time_mean[0] - data_mean) / data_std, 0
    )

    # set transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
        {"Normalize": {"mean": data_mean, "std": data_std}},
    ]

    # set model
    model_cfg = dict(cfg.MODEL.afno)
    model_cfg.update(
        {"output_keys": output_keys, "num_timestamps": cfg.EVAL.num_timestamps}
    )
    model = ppsci.arch.AFNONet(**model_cfg)

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.TEST_FILE_PATH,
            "input_keys": cfg.MODEL.afno.input_keys,
            "label_keys": output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "transforms": transforms,
            "num_label_timestamps": cfg.EVAL.num_timestamps,
            "training": False,
            "stride": 8,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set metric
    metric = {
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
            num_lat=cfg.IMG_H,
            std=data_std,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
        "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
            num_lat=cfg.IMG_H,
            mean=data_time_mean_normalize,
            keep_batch=True,
            variable_dict={"u10": 0, "v10": 1},
        ),
    }

    # set validator for testing
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric=metric,
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set visualizer data
    DATE_STRINGS = ("2018-09-08 00:00:00",)
    vis_data = get_vis_data(
        cfg.TEST_FILE_PATH,
        DATE_STRINGS,
        cfg.EVAL.num_timestamps,
        cfg.VARS_CHANNEL,
        cfg.IMG_H,
        data_mean,
        data_std,
    )

    def output_wind_func(d, var_name, data_mean, data_std):
        output = (d[var_name] * data_std) + data_mean
        wind_data = []
        for i in range(output.shape[0]):
            wind_data.append((output[i][0] ** 2 + output[i][1] ** 2) ** 0.5)
        return paddle.to_tensor(wind_data, paddle.get_default_dtype())

    vis_output_expr = {}
    for i in range(cfg.EVAL.num_timestamps):
        hour = (i + 1) * 6
        vis_output_expr[f"output_{hour}h"] = functools.partial(
            output_wind_func,
            var_name=f"output_{i}",
            data_mean=paddle.to_tensor(data_mean, paddle.get_default_dtype()),
            data_std=paddle.to_tensor(data_std, paddle.get_default_dtype()),
        )
        vis_output_expr[f"target_{hour}h"] = lambda d, hour=hour: d[f"target_{hour}h"]
    # set visualizer
    visualizer = {
        "visualize_wind": ppsci.visualize.VisualizerWeather(
            vis_data,
            vis_output_expr,
            xticks=np.linspace(0, 1439, 13),
            xticklabels=[str(i) for i in range(360, -1, -30)],
            yticks=np.linspace(0, 719, 7),
            yticklabels=[str(i) for i in range(90, -91, -30)],
            vmin=0,
            vmax=25,
            colorbar_label="m/s",
            batch_size=cfg.EVAL.batch_size,
            num_timestamps=cfg.EVAL.num_timestamps,
            prefix="wind",
        )
    }

    solver = ppsci.solver.Solver(
        model,
        output_dir=cfg.output_dir,
        validator=validator,
        visualizer=visualizer,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    solver.eval()
    # visualize prediction from pretrained_model_path
    solver.visualize()


def export(cfg: DictConfig):
    # set model
    model = ppsci.arch.AFNONet(**cfg.MODEL.afno)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        pretrained_model_path=cfg.INFER.pretrained_model_path,
    )
    # export model
    from paddle.static import InputSpec

    input_spec = [
        {
            key: InputSpec([None, 20, cfg.IMG_H, cfg.IMG_W], "float32", name=key)
            for key in model.input_keys
        },
    ]
    solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
    from deploy.python_infer import pinn_predictor

    predictor = pinn_predictor.PINNPredictor(cfg)

    data_mean, data_std = fourcast_utils.get_mean_std(
        cfg.DATA_MEAN_PATH, cfg.DATA_STD_PATH, cfg.VARS_CHANNEL
    )

    data = np.load(cfg.INFER_FILE_PATH)
    input_0 = (data[:, 0] - data_mean) / data_std
    all_data = input_0

    for t in range(cfg.INFER.num_timestamps):
        data_t = data[:, t + 1]
        data_t = (data_t - data_mean) / data_std
        all_data = np.concatenate((all_data, data_t), axis=0)

    input_dict = {cfg.MODEL.afno.input_keys[0]: all_data}

    vis_output = predictor.predict(input_dict, cfg.INFER.batch_size)

    vis_dict = {
        store_key: vis_output[infer_key]
        for store_key, infer_key in zip(cfg.MODEL.afno.output_keys, vis_output.keys())
    }

    def output_wind_func(output, data_mean, data_std):
        output = (output * data_std) + data_mean
        wind_data = (output[0] ** 2 + output[1] ** 2) ** 0.5
        return wind_data

    wind_pred = []
    pred_dict = {}
    for i in range(cfg.INFER.num_timestamps):
        hour = (i + 1) * 6
        wind_ = [
            output_wind_func(
                vis_dict[cfg.MODEL.afno.output_keys[0]][i], data_mean, data_std
            )
        ]
        wind_pred.append(wind_)
        pred_dict[f"output_{hour}h"] = np.asarray(wind_)
    output_dict = {cfg.MODEL.afno.output_keys[0]: np.array(wind_pred)}

    wind_pred = []
    target_dict = {}
    for i in range(cfg.INFER.num_timestamps):
        hour = (i + 1) * 6
        wind_ = [(data[0][i][0] ** 2 + data[0][i][1] ** 2) ** 0.5]
        target_dict[f"target_{hour}h"] = np.asarray(wind_)

    vis_dict = {**pred_dict, **target_dict}

    plot_expr_dict = {}
    for hour in range(6, 6 + cfg.INFER.num_timestamps * 6, 6):
        plot_expr_dict.update(
            {
                f"target_{hour}h": lambda d, hour=hour: d[f"target_{hour}h"],
                f"output_{hour}h": lambda d, hour=hour: d[f"output_{hour}h"],
            }
        )

    visualizer_weather = ppsci.visualize.VisualizerWeather(
        vis_dict,
        plot_expr_dict,
        xticks=np.linspace(0, cfg.IMG_W - 1, 13),
        xticklabels=[str(i) for i in range(360, -1, -30)],
        yticks=np.linspace(0, cfg.IMG_H - 1, 7),
        yticklabels=[str(i) for i in range(90, -91, -30)],
        vmin=0,
        vmax=25,
        colorbar_label="m/s",
        batch_size=1,
        num_timestamps=cfg.INFER.num_timestamps,
        prefix="wind",
    )
    visualizer_weather.save(cfg.INFER.export_path, vis_dict)
    save_path = osp.join(cfg.INFER.export_path, "predict.npy")
    os.makedirs(cfg.INFER.export_path, exist_ok=True)
    np.save(save_path, output_dict[cfg.MODEL.afno.output_keys[0]])


@hydra.main(
    version_base=None, config_path="./conf", config_name="fourcastnet_finetune.yaml"
)
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should be in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


if __name__ == "__main__":
    main()
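
The `output_wind_func` helpers in the listing above undo the normalization and then combine the two wind components into a speed. The following is a minimal NumPy-only sketch of that step, not the script's own code: the `wind_speed` name, the toy statistics, and the `(2, H, W)` layout (channel 0 = u10, channel 1 = v10, following `variable_dict`) are assumptions chosen for illustration.

```python
import numpy as np


def wind_speed(normalized, mean, std):
    """Denormalize a (2, H, W) array of [u10, v10] and return the (H, W) wind speed."""
    uv = normalized * std + mean  # undo (x - mean) / std
    return np.sqrt(uv[0] ** 2 + uv[1] ** 2)


# Toy statistics shaped (2, 1, 1) so they broadcast over (2, H, W); values are made up.
mean = np.array([1.0, -2.0]).reshape(2, 1, 1)
std = np.array([2.0, 4.0]).reshape(2, 1, 1)

# Physical field with u10 = 3 m/s and v10 = 4 m/s everywhere.
uv_true = np.stack([np.full((4, 8), 3.0), np.full((4, 8), 4.0)])
normalized = (uv_true - mean) / std

speed = wind_speed(normalized, mean, std)
print(speed.shape, float(speed[0, 0]))
```

Because the speed is computed after denormalization, it is in the same physical units as the raw data, which is why the visualizer colorbar is labeled in m/s.
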
examples/fourcastnet/train_precip.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
import os.path as osp
from typing import Tuple

import h5py
import hydra
import numpy as np
import paddle
import utils as fourcast_utils
from omegaconf import DictConfig

import ppsci
from ppsci.utils import logger


def get_vis_data(
    wind_file_path: str,
    file_path: str,
    date_strings: Tuple[str, ...],
    num_timestamps: int,
    vars_channel: Tuple[int, ...],
    img_h: int,
    data_mean: np.ndarray,
    data_std: np.ndarray,
):
    __wind_file = h5py.File(wind_file_path, "r")["fields"]
    _file = h5py.File(file_path, "r")["tp"]
    wind_data = []
    data = []
    for date_str in date_strings:
        hours_since_jan_01_epoch = fourcast_utils.date_to_hours(date_str)
        ic = int(hours_since_jan_01_epoch / 6)
        wind_data.append(__wind_file[ic, vars_channel, 0:img_h])
        data.append(_file[ic + 1 : ic + num_timestamps + 1, 0:img_h])
    wind_data = np.asarray(wind_data)
    data = np.asarray(data)

    vis_data = {"input": (wind_data - data_mean) / data_std}
    for t in range(num_timestamps):
        hour = (t + 1) * 6
        data_t = data[:, t]
        vis_data[f"target_{hour}h"] = np.asarray(data_t)
    return vis_data


def train(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", f"{cfg.output_dir}/train.log", "info")

    wind_data_mean, wind_data_std = fourcast_utils.get_mean_std(
        cfg.WIND_MEAN_PATH, cfg.WIND_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W
    )

    # set train transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
        {
            "Normalize": {
                "mean": wind_data_mean,
                "std": wind_data_std,
                "apply_keys": ("input",),
            }
        },
        {"Log1p": {"scale": 1e-5, "apply_keys": ("label",)}},
    ]

    # set train dataloader config
    train_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.WIND_TRAIN_FILE_PATH,
            "input_keys": cfg.MODEL.precip.input_keys,
            "label_keys": cfg.MODEL.precip.output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "precip_file_path": cfg.TRAIN_FILE_PATH,
            "transforms": transforms,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": True,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": 8,
    }
    # set constraint
    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        name="Sup",
    )
    constraint = {sup_constraint.name: sup_constraint}

    # set iters_per_epoch by dataloader length
    ITERS_PER_EPOCH = len(sup_constraint.data_loader)

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.WIND_VALID_FILE_PATH,
            "input_keys": cfg.MODEL.precip.input_keys,
            "label_keys": cfg.MODEL.precip.output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "precip_file_path": cfg.VALID_FILE_PATH,
            "transforms": transforms,
            "training": False,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set metric
    metric = {
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
            num_lat=cfg.IMG_H, keep_batch=True, unlog=True
        ),
        "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
            num_lat=cfg.IMG_H, mean=data_time_mean, keep_batch=True, unlog=True
        ),
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric=metric,
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set model
    wind_model = ppsci.arch.AFNONet(**cfg.MODEL.afno)
    ppsci.utils.save_load.load_pretrain(wind_model, path=cfg.WIND_MODEL_PATH)
    model_cfg = dict(cfg.MODEL.precip)
    model_cfg.update({"wind_model": wind_model})
    model = ppsci.arch.PrecipNet(**model_cfg)

    # init optimizer and lr scheduler
    lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
    lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH})
    lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(**lr_scheduler_cfg)()
    optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        cfg.output_dir,
        optimizer,
        lr_scheduler,
        cfg.TRAIN.epochs,
        ITERS_PER_EPOCH,
        eval_during_train=True,
        validator=validator,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # train model
    solver.train()
    # evaluate after finished training
    solver.eval()


def evaluate(cfg: DictConfig):
    # set random seed for reproducibility
    ppsci.utils.misc.set_random_seed(cfg.seed)
    # initialize logger
    logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")

    # set testing hyper-parameters
    output_keys = tuple(f"output_{i}" for i in range(cfg.EVAL.num_timestamps))

    # set model for testing
    wind_model = ppsci.arch.AFNONet(**cfg.MODEL.afno)
    ppsci.utils.save_load.load_pretrain(wind_model, path=cfg.WIND_MODEL_PATH)
    model_cfg = dict(cfg.MODEL.precip)
    model_cfg.update(
        {
            "output_keys": output_keys,
            "num_timestamps": cfg.EVAL.num_timestamps,
            "wind_model": wind_model,
        }
    )
    model = ppsci.arch.PrecipNet(**model_cfg)

    wind_data_mean, wind_data_std = fourcast_utils.get_mean_std(
        cfg.WIND_MEAN_PATH, cfg.WIND_STD_PATH, cfg.VARS_CHANNEL
    )
    data_time_mean = fourcast_utils.get_time_mean(
        cfg.TIME_MEAN_PATH, cfg.IMG_H, cfg.IMG_W
    )

    # set eval transforms
    transforms = [
        {"SqueezeData": {}},
        {"CropData": {"xmin": (0, 0), "xmax": (cfg.IMG_H, cfg.IMG_W)}},
        {
            "Normalize": {
                "mean": wind_data_mean,
                "std": wind_data_std,
                "apply_keys": ("input",),
            }
        },
        {"Log1p": {"scale": 1e-5, "apply_keys": ("label",)}},
    ]

    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5Dataset",
            "file_path": cfg.WIND_TEST_FILE_PATH,
            "input_keys": cfg.MODEL.precip.input_keys,
            "label_keys": output_keys,
            "vars_channel": cfg.VARS_CHANNEL,
            "precip_file_path": cfg.TEST_FILE_PATH,
            "num_label_timestamps": cfg.EVAL.num_timestamps,
            "stride": 8,
            "transforms": transforms,
            "training": False,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": False,
        },
        "batch_size": cfg.EVAL.batch_size,
    }
    # set metric
    metric = {
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "LatitudeWeightedRMSE": ppsci.metric.LatitudeWeightedRMSE(
            num_lat=cfg.IMG_H, keep_batch=True, unlog=True
        ),
        "LatitudeWeightedACC": ppsci.metric.LatitudeWeightedACC(
            num_lat=cfg.IMG_H, mean=data_time_mean, keep_batch=True, unlog=True
        ),
    }

    # set validator for testing
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.L2RelLoss(),
        metric=metric,
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set visualizer data
    DATE_STRINGS = ("2018-04-04 00:00:00",)
    vis_data = get_vis_data(
        cfg.WIND_TEST_FILE_PATH,
        cfg.TEST_FILE_PATH,
        DATE_STRINGS,
        cfg.EVAL.num_timestamps,
        cfg.VARS_CHANNEL,
        cfg.IMG_H,
        wind_data_mean,
        wind_data_std,
    )

    def output_precip_func(d, var_name):
        output = 1e-2 * paddle.expm1(d[var_name][0])
        return output

    visu_output_expr = {}
    for i in range(cfg.EVAL.num_timestamps):
        hour = (i + 1) * 6
        visu_output_expr[f"output_{hour}h"] = functools.partial(
            output_precip_func,
            var_name=f"output_{i}",
        )
        visu_output_expr[f"target_{hour}h"] = (
            lambda d, hour=hour: d[f"target_{hour}h"] * 1000
        )
    # set visualizer
    visualizer = {
        "visualize_precip": ppsci.visualize.VisualizerWeather(
            vis_data,
            visu_output_expr,
            xticks=np.linspace(0, 1439, 13),
            xticklabels=[str(i) for i in range(360, -1, -30)],
            yticks=np.linspace(0, 719, 7),
            yticklabels=[str(i) for i in range(90, -91, -30)],
            vmin=0.001,
            vmax=130,
            colorbar_label="mm",
            log_norm=True,
            batch_size=cfg.EVAL.batch_size,
            num_timestamps=cfg.EVAL.num_timestamps,
            prefix="precip",
        )
    }

    solver = ppsci.solver.Solver(
        model,
        output_dir=cfg.output_dir,
        validator=validator,
        visualizer=visualizer,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    solver.eval()
    # visualize prediction
    solver.visualize()


def export(cfg: DictConfig):
    # set model
    wind_model = ppsci.arch.AFNONet(**cfg.MODEL.afno)
    ppsci.utils.save_load.load_pretrain(wind_model, path=cfg.INFER.WIND_MODEL_PATH)
    output_keys = tuple(f"output_{i}" for i in range(cfg.INFER.num_timestamps))
    model_cfg = dict(cfg.MODEL.precip)
    model_cfg.update(
        {
            "output_keys": output_keys,
            "num_timestamps": cfg.INFER.num_timestamps,
            "wind_model": wind_model,
        }
    )
    model = ppsci.arch.PrecipNet(**model_cfg)
    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        pretrained_model_path=cfg.INFER.pretrained_model_path,
    )
    # export model
    from paddle.static import InputSpec

    input_spec = [
        {
            key: InputSpec([None, 20, cfg.IMG_H, cfg.IMG_W], "float32", name=key)
            for key in model.input_keys
        },
    ]
    solver.export(input_spec, cfg.INFER.export_path)


def inference(cfg: DictConfig):
    output_keys = tuple(f"output_{i}" for i in range(cfg.INFER.num_timestamps))
    model_cfg = dict(cfg.MODEL.precip)
    model_cfg.update(
        {
            "output_keys": output_keys,
        }
    )

    from deploy.python_infer import pinn_predictor

    predictor = pinn_predictor.PINNPredictor(cfg)

    data_mean, data_std = fourcast_utils.get_mean_std(
        cfg.WIND_MEAN_PATH, cfg.WIND_STD_PATH, cfg.VARS_CHANNEL
    )

    wind_data = np.load(cfg.WIND_INFER_PATH)
    data = np.load(cfg.INFER_FILE_PATH)

    input_datas = (wind_data - data_mean) / data_std
    input_dict = {cfg.MODEL.precip.input_keys[0]: input_datas}
    vis_datas = {cfg.MODEL.precip.input_keys[0]: input_datas}

    for t in range(cfg.INFER.num_timestamps):
        hour = (t + 1) * 6
        data_t = data[:, t] * 1000
        vis_datas[f"target_{hour}h"] = np.asarray(data_t)

    vis_output = predictor.predict(input_dict, cfg.INFER.batch_size)

    re_dict = {
        store_key: vis_output[infer_key]
        for store_key, infer_key in zip(model_cfg["output_keys"], vis_output.keys())
    }

    plot_dict = vis_datas

    output_dict = {}
    for t in range(cfg.INFER.num_timestamps):
        hour = (t + 1) * 6
        output_dict[f"output_{t}"] = 1e-2 * np.expm1(re_dict[f"output_{t}"][0])
        plot_dict[f"output_{hour}h"] = output_dict[f"output_{t}"]
    output = np.concatenate(list(output_dict.values()), axis=0)
    output_dict[cfg.MODEL.precip.output_keys[0]] = output

    plot_expr_dict = {}
    for hour in range(6, 6 + cfg.INFER.num_timestamps * 6, 6):
        plot_expr_dict.update(
            {
                f"target_{hour}h": lambda d, hour=hour: d[f"target_{hour}h"],
                f"output_{hour}h": lambda d, hour=hour: d[f"output_{hour}h"],
            }
        )

    visualizer_weather = ppsci.visualize.VisualizerWeather(
        plot_dict,
        plot_expr_dict,
        xticks=np.linspace(0, cfg.IMG_W - 1, 13),
        xticklabels=[str(i) for i in range(360, -1, -30)],
        yticks=np.linspace(0, cfg.IMG_H - 1, 7),
        yticklabels=[str(i) for i in range(90, -91, -30)],
        vmin=0.001,
        vmax=130,
        colorbar_label="mm",
        log_norm=True,
        batch_size=1,
        num_timestamps=cfg.INFER.num_timestamps,
        prefix="precip",
    )
    visualizer_weather.save(cfg.INFER.export_path, plot_dict)
    save_path = osp.join(cfg.INFER.export_path, "predict.npy")
    os.makedirs(cfg.INFER.export_path, exist_ok=True)
    np.save(save_path, output_dict)


@hydra.main(
    version_base=None, config_path="./conf", config_name="fourcastnet_precip.yaml"
)
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should be in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


if __name__ == "__main__":
    main()
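
The precipitation script applies a `Log1p` transform with `scale=1e-5` to the labels and later recovers plottable values with `1e-2 * expm1(...)`. The sketch below shows why those two constants fit together. It assumes that the transform computes `log(1 + x / scale)` and that raw precipitation is stored in metres (the listing multiplies targets by 1000 and labels the colorbar "mm"); the function names and sample values are hypothetical.

```python
import numpy as np

SCALE = 1e-5  # same value as the "scale" passed to Log1p in the listing


def log1p_forward(tp_m, scale=SCALE):
    """Assumed form of the Log1p transform: log(1 + tp / scale)."""
    return np.log1p(tp_m / scale)


def to_millimetres(y, scale=SCALE):
    """Invert the assumed transform and convert metres to millimetres."""
    return 1000.0 * scale * np.expm1(y)  # 1000 * 1e-5 == 1e-2, as in the listing


tp_m = np.array([0.0, 1e-5, 2.5e-3, 0.05])  # made-up precipitation totals in metres
y = log1p_forward(tp_m)
tp_mm = to_millimetres(y)
print(tp_mm)
```

Under these assumptions, the round trip returns the original totals expressed in millimetres, so predictions and the `* 1000` targets share one colorbar scale.
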

6. Result Display

The figure below compares the wind speed model's predictions with the ground truth at 6-hour intervals.

Wind speed model predictions ("output") vs. ground truth ("target")

The figure below compares the precipitation model's predictions with the ground truth at 6-hour intervals.

Precipitation model predictions ("output") vs. ground truth ("target")
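
Besides the figures, the inference functions write a `predict.npy` file. In the precipitation script the object passed to `np.save` is a Python dict, which NumPy stores as a pickled 0-d object array, so reading it back needs `allow_pickle=True` followed by `.item()`. The sketch below illustrates that round trip with a stand-in dict; the keys, shapes, and temporary path are made up for the example.

```python
import os
import tempfile

import numpy as np

# Stand-in for the dict built in inference(); keys and shapes here are made up.
output_dict = {
    "output_0": np.zeros((2, 4), dtype=np.float32),
    "output_1": np.ones((2, 4), dtype=np.float32),
}

with tempfile.TemporaryDirectory() as tmp_dir:
    save_path = os.path.join(tmp_dir, "predict.npy")
    np.save(save_path, output_dict)  # the dict is pickled inside a 0-d object array

    # allow_pickle=True is required for object arrays; only use it on trusted files.
    loaded = np.load(save_path, allow_pickle=True).item()

print(sorted(loaded))
```

The wind speed script saves a plain array instead, so its `predict.npy` can be read with a default `np.load` call.
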