
UNetFormer

Note

  1. Before running, it is recommended to first read the Dataset and Data Reading Method documentation.
  2. Download the [Vaihingen Dataset] into the corresponding subdirectory of the data directory (e.g. data/vaihingen/train_images).
  3. Run tools/vaihingen_patch_split.py to process the original dataset into trainable data.

The directory structure is as follows:

airs
├── unetformer(code)
├── model_weights (save the model weights trained on ISPRS vaihingen)
├── fig_results (save the masks predicted by models)
├── lightning_logs (CSV format training logs)
├── data
   ├── vaihingen
      ├── train_images (original)
      ├── train_masks (original)
      ├── test_images (original)
      ├── test_masks (original)
      ├── test_masks_eroded (original)
      ├── train (processed)
      ├── test (processed)
# Download [Vaihingen Dataset] to the corresponding subdirectory in `data` directory (e.g. `data/vaihingen/train_images`)
# Create training dataset
python tools/vaihingen_patch_split.py --img-dir "data/vaihingen/train_images" --mask-dir "data/vaihingen/train_masks" --output-img-dir "data/vaihingen/train/images_1024" --output-mask-dir "data/vaihingen/train/masks_1024" --mode "train" --split-size 1024 --stride 512
# Create test dataset
python tools/vaihingen_patch_split.py --img-dir "data/vaihingen/test_images" --mask-dir "data/vaihingen/test_masks_eroded" --output-img-dir "data/vaihingen/test/images_1024" --output-mask-dir "data/vaihingen/test/masks_1024" --mode "val" --split-size 1024 --stride 1024 --eroded
# Create masks_1024_rgb visualization dataset
python tools/vaihingen_patch_split.py --img-dir "data/vaihingen/test_images" --mask-dir "data/vaihingen/test_masks" --output-img-dir "data/vaihingen/test/images_1024" --output-mask-dir "data/vaihingen/test/masks_1024_rgb" --mode "val" --split-size 1024 --stride 1024 --gt
# Model training
python train_supervision.py -c config/vaihingen/unetformer.py
# Download processed [Vaihingen Test Dataset](https://paddle-org.bj.bcebos.com/paddlescience/datasets/unetformer/test.zip), and unzip.
wget -c https://paddle-org.bj.bcebos.com/paddlescience/datasets/unetformer/test.zip -P ./data/vaihingen/
unzip -q ./data/vaihingen/test.zip -d data/vaihingen/
# Download pretrained model file
wget -c https://paddle-org.bj.bcebos.com/paddlescience/models/unetformer/unetformer-r18-512-crop-ms-e105_epoch0_best.pdparams -P ./model_weights/vaihingen/unetformer-r18-512-crop-ms-e105/
python vaihingen_test.py -c config/vaihingen/unetformer.py -o fig_results/vaihingen/unetformer --rgb

1. Background Introduction

Semantic segmentation of remotely sensed urban scene imagery is needed in many practical applications, such as land cover mapping, urban change detection, environmental protection, and economic assessment. Driven by the rapid development of deep learning, convolutional neural networks (CNNs) have dominated semantic segmentation for many years. CNNs use hierarchical feature representations and extract local information well, but the local nature of the convolution layer limits their ability to capture global context. In recent years, the Transformer architecture has become an active research direction in computer vision and has shown great potential for modeling global information, improving performance on vision tasks such as image classification, object detection, and especially semantic segmentation.

The UNetFormer paper proposes a Transformer-based decoder and constructs a UNet-like Transformer (UNetFormer) for real-time urban scene segmentation. For efficiency, UNetFormer uses a lightweight ResNet18 as the encoder and develops an efficient global-local attention mechanism in the decoder to model global and local information simultaneously. Combined with a Swin Transformer encoder, the proposed decoder also achieved state-of-the-art results on the Vaihingen dataset at the time of publication (91.3% F1 score and 84.1% mIoU).

2. Model Principle

This section only briefly introduces the model principle. For details, see the paper "UNetFormer: A UNet-like Transformer for Efficient Semantic Segmentation of Remote Sensing Urban Scene Imagery".

2.1 Model Structure

UNetFormer is a deep learning network built around a Transformer-based decoder. The figure below shows the overall structure of the model.

[Figure: overall structure of UNetFormer]

  • Each ResBlock corresponds to one module of the ResNet18 network.

  • GLTB (global-local Transformer block) consists of global-local attention, an MLP, two batch normalization layers, and two sum operations.
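
The GLTB description above can be read as two residual steps. The following is a minimal structural sketch, not the actual implementation: it assumes a pre-normalization ordering (normalize, apply the sub-layer, then add the input back), and `gltb`, `add`, and the toy stand-in callables are hypothetical names that operate on plain lists of floats rather than Paddle tensors.

```python
from typing import Callable, List

Vector = List[float]
Op = Callable[[Vector], Vector]


def add(a: Vector, b: Vector) -> Vector:
    """Element-wise sum, standing in for a residual (skip) connection."""
    return [x + y for x, y in zip(a, b)]


def gltb(x: Vector, norm1: Op, attention: Op, norm2: Op, mlp: Op) -> Vector:
    """Assumed GLTB data flow: two normalization layers and two sum operations."""
    x = add(x, attention(norm1(x)))  # first sum: global-local attention branch
    x = add(x, mlp(norm2(x)))        # second sum: MLP branch
    return x


# Toy stand-ins so the data flow can be traced by hand.
identity: Op = lambda v: list(v)
double: Op = lambda v: [2.0 * t for t in v]
plus_one: Op = lambda v: [t + 1.0 for t in v]

out = gltb([1.0, 2.0], identity, double, identity, plus_one)
print(out)
```

In the real block, the stand-ins would be replaced by batch normalization layers, the global-local attention module, and the MLP.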

2.2 Loss Function

The model's loss function consists of two parts: a main loss and an auxiliary loss. The main loss \(\mathcal{L}_{p}\) is the sum of the soft cross-entropy loss (SoftCrossEntropyLoss) \(\mathcal{L}_{c e}\) and the Dice loss (DiceLoss) \(\mathcal{L}_{\text {dice }}\):

\[ \mathcal{L}_{c e}=-\frac{1}{N} \sum_{n=1}^{N} \sum_{k=1}^{K} y_{k}^{(n)} \log \hat{y}_{k}^{(n)} \]
\[ \mathcal{L}_{\text {dice }}=1-\frac{2}{N} \sum_{n=1}^{N} \sum_{k=1}^{K} \frac{\hat{y}_{k}^{(n)} y_{k}^{(n)}}{\hat{y}_{k}^{(n)}+y_{k}^{(n)}} \]
\[ \mathcal{L}_{\text {p }}=\mathcal{L}_{c e}+\mathcal{L}_{\text {dice }} \]

where \(N\) and \(K\) are the number of samples and the number of categories, respectively. \(y^{(n)}\) and \(\hat{y}^{(n)}\) are the one-hot encoding of the label and the corresponding softmax output for sample \(n \in[1, \ldots, N]\).

A cross-entropy loss is additionally used as the auxiliary loss \(\mathcal{L}_{a u x}\), weighted by the coefficient \(\alpha\). The total loss is:

\[ \mathcal{L}=\mathcal{L}_{p}+\alpha \times \mathcal{L}_{a u x} \]

Where \(\alpha\) defaults to 0.4.
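
As a worked illustration of the formulas in this section, the following sketch evaluates the cross-entropy, Dice, and total loss on plain Python lists. It is a direct transcription of the equations, not the Paddle implementation; the function names and the toy inputs are illustrative assumptions, and terms with a zero label are skipped because their numerator is zero.

```python
import math
from typing import List

Matrix = List[List[float]]  # one row per sample, one column per category


def cross_entropy(y_true: Matrix, y_prob: Matrix) -> float:
    """L_ce: negative mean log-probability of the labelled category."""
    total = 0.0
    for labels, probs in zip(y_true, y_prob):
        for y, p in zip(labels, probs):
            if y > 0:  # terms with y = 0 contribute nothing
                total += y * math.log(p)
    return -total / len(y_true)


def dice_loss(y_true: Matrix, y_prob: Matrix) -> float:
    """L_dice as written above: 1 - (2/N) * sum of (p * y) / (p + y)."""
    total = 0.0
    for labels, probs in zip(y_true, y_prob):
        for y, p in zip(labels, probs):
            if y > 0:  # zero numerator otherwise; also avoids 0/0
                total += (p * y) / (p + y)
    return 1.0 - 2.0 * total / len(y_true)


def total_loss(y_true: Matrix, main_prob: Matrix, aux_prob: Matrix,
               alpha: float = 0.4) -> float:
    """L = L_p + alpha * L_aux, with L_p = L_ce + L_dice."""
    main = cross_entropy(y_true, main_prob) + dice_loss(y_true, main_prob)
    return main + alpha * cross_entropy(y_true, aux_prob)


labels = [[1.0, 0.0]]   # one sample, two categories, one-hot label
probs = [[0.5, 0.5]]    # an uninformative softmax output
print(cross_entropy(labels, probs), dice_loss(labels, probs))
print(total_loss(labels, probs, probs))
```

For this toy input the cross-entropy equals ln 2 and the Dice term equals 1/3, so the total is 1.4 × ln 2 + 1/3.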

3. Model Construction

Below we explain the key parts of building UNetFormer with PaddleScience.

3.1 Dataset Introduction

The dataset used is the Vaihingen dataset released by ISPRS.

ISPRS provides two state-of-the-art airborne image datasets for its urban classification and 3D building reconstruction test projects. Both consist of high-resolution true orthophoto (TOP) tiles and corresponding digital surface models (DSMs) generated by dense image matching, and both cover urban scenes. Vaihingen is a relatively small village with many detached buildings and small multi-story buildings. The dataset contains 33 remote sensing image patches of different sizes, each extracted from a larger TOP mosaic; the patches were selected to avoid areas with no data. The ground sampling distance of both the TOP and the DSM is 9 cm. The images are 8-bit TIFF files with three bands: near infrared, red, and green. The DSMs are single-band TIFF files whose gray levels (corresponding to DSM heights) are encoded as 32-bit floating point values.

image-vaihingen

Each dataset has been manually classified into the six most common land cover categories:

① Impervious surface (RGB: 255, 255, 255)

② Building (RGB: 0, 0, 255)

③ Low vegetation (RGB: 0, 255, 255)

④ Tree (RGB: 0, 255, 0)

⑤ Car (RGB: 255, 255, 0)

⑥ Background (RGB: 255, 0, 0)

The background class includes water bodies and objects that differ from the other defined categories (such as containers, tennis courts, and swimming pools); these are usually not objects of interest in urban scenes.

3.2 Build dataset API

The dataset consists of 33 very large remote sensing images. To facilitate training, a custom patch-splitting script splits the original images into 1024×1024 trainable patches. See GeoSeg/tools/vaihingen_patch_split.py for the code.
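
To make the effect of the split size and stride concrete, here is a minimal sketch of how sliding-window patch origins can be computed. It is not taken from vaihingen_patch_split.py: `patch_origins` is a hypothetical helper, and the edge handling (shifting the last window back so it stays inside the image) is an assumption; the actual script may pad the image instead.

```python
from typing import List, Tuple


def patch_origins(height: int, width: int,
                  split_size: int = 1024, stride: int = 512) -> List[Tuple[int, int]]:
    """Return the top-left (y, x) corner of every sliding-window patch."""

    def axis_starts(length: int) -> List[int]:
        if length <= split_size:
            return [0]
        starts = list(range(0, length - split_size + 1, stride))
        # Assumed edge handling: add one final window flush with the border.
        if starts[-1] + split_size < length:
            starts.append(length - split_size)
        return starts

    return [(y, x) for y in axis_starts(height) for x in axis_starts(width)]


# stride < split_size gives overlapping patches (training setting above).
train_patches = patch_origins(2048, 2048, split_size=1024, stride=512)
# stride == split_size gives non-overlapping patches (test setting above).
test_patches = patch_origins(2048, 2048, split_size=1024, stride=1024)
print(len(train_patches), len(test_patches))
```

With a stride of half the patch size, each interior pixel is covered by several training patches, which acts as a simple form of data augmentation; the non-overlapping test split avoids counting pixels more than once during evaluation.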

3.3 Model Construction

The model construction code for this case is as follows

Parameter configuration is as follows:

max_epoch = 105
ignore_index = len(CLASSES)
train_batch_size = 8
val_batch_size = 8
lr = 0.0006
weight_decay = 0.01
backbone_lr = 6e-05
backbone_weight_decay = 0.01
num_classes = len(CLASSES)
classes = CLASSES
weights_name = "unetformer-r18-512-crop-ms-e105"
weights_path = "model_weights/vaihingen/{}".format(weights_name)
test_weights_name = "unetformer-r18-512-crop-ms-e105_epoch0_best"
log_name = "vaihingen/{}".format(weights_name)
monitor = "val_F1"
monitor_mode = "max"
save_top_k = 1
save_last = True
check_val_every_n_epoch = 1
pretrained_ckpt_path = None
gpus = "auto"
resume_ckpt_path = None
net = UNetFormer(num_classes=num_classes)
loss = UnetFormerLoss(ignore_index=ignore_index)
use_aux_loss = True
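
The configuration sets a smaller learning rate for the backbone (backbone_lr = 6e-05) than for the rest of the network (lr = 0.0006). The sketch below shows one way such per-layer options could be resolved by parameter name. `group_params` and the parameter names are hypothetical, and treating the pattern as a regular expression matched from the start of the name is an assumption rather than confirmed library behavior.

```python
import re
from typing import Dict, List


def group_params(names: List[str],
                 layerwise: Dict[str, Dict[str, float]],
                 default: Dict[str, float]) -> List[Dict[str, object]]:
    """Attach optimizer options to each parameter name; first matching pattern wins."""
    groups = []
    for name in names:
        options = dict(default)
        for pattern, overrides in layerwise.items():
            if re.match(pattern, name):
                options.update(overrides)
                break
        groups.append({"name": name, **options})
    return groups


# Values copied from the configuration above; the names are made up.
default_options = {"lr": 0.0006, "weight_decay": 0.01}
layerwise_options = {"backbone.*": {"lr": 6e-05, "weight_decay": 0.01}}
param_names = ["backbone.conv1.weight", "decoder.b4.weight"]

groups = group_params(param_names, layerwise_options, default_options)
for g in groups:
    print(g["name"], g["lr"])
```

A lower backbone learning rate is a common choice when the encoder starts from pretrained weights and only the decoder is trained from scratch.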

3.4 Loss Function

UNetFormer's loss function combines the SoftCrossEntropyLoss cross-entropy loss and the DiceLoss Dice loss.

3.4.1 SoftCrossEntropyLoss

class SoftCrossEntropyLoss(paddle.nn.Layer):
    """
    Drop-in replacement for nn.CrossEntropyLoss with few additions:
    - Support of label smoothing
    """

    __constants__ = ["reduction", "ignore_index", "smooth_factor"]

    def __init__(
        self,
        reduction: str = "mean",
        smooth_factor: float = 0.0,
        ignore_index: Optional[int] = -100,
        dim=1,
    ):
        super().__init__()
        self.smooth_factor = smooth_factor
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.dim = dim

    def forward(self, input: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor:
        log_prob = paddle.nn.functional.log_softmax(x=input, axis=self.dim)
        return label_smoothed_nll_loss(
            log_prob,
            target,
            epsilon=self.smooth_factor,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            dim=self.dim,
        )
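
The helper label_smoothed_nll_loss is not shown in this document. As a rough illustration of what label smoothing does, the sketch below computes a smoothed negative log-likelihood for a single sample in plain Python. The formula used (a weight of 1 − ε on the target term plus ε/K spread over all K classes) is a common formulation and is assumed here, not confirmed to match the helper.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one sample's logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def label_smoothed_nll(logits: List[float], target: int, epsilon: float = 0.05) -> float:
    """Assumed form: (1 - eps) * NLL(target) + (eps / K) * sum of NLL over all classes."""
    log_prob = log_softmax(logits)
    nll = -log_prob[target]
    smooth = -sum(log_prob)
    return (1.0 - epsilon) * nll + (epsilon / len(logits)) * smooth


confident = [5.0, 0.0]  # strongly predicts class 0
plain = label_smoothed_nll(confident, target=0, epsilon=0.0)
smoothed = label_smoothed_nll(confident, target=0, epsilon=0.05)
print(plain, smoothed)
```

With ε = 0 the function reduces to ordinary cross-entropy; with ε > 0 a very confident correct prediction is penalized slightly, which discourages over-confident logits.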

3.4.2 DiceLoss

class DiceLoss(paddle.nn.Layer):
    """
    Implementation of Dice loss for image segmentation task.
    It supports binary, multiclass and multilabel cases
    """

    def __init__(
        self,
        mode: str = "multiclass",
        classes: List[int] = None,
        log_loss=False,
        from_logits=True,
        smooth: float = 0.0,
        ignore_index=None,
        eps=1e-07,
    ):
        """

        :param mode: Metric mode {'binary', 'multiclass', 'multilabel'}
        :param classes: Optional list of classes that contribute in loss computation;
        By default, all channels are included.
        :param log_loss: If True, loss computed as `-log(dice)`; otherwise `1 - dice`
        :param from_logits: If True assumes input is raw logits
        :param smooth: Smoothing constant passed to `soft_dice_score`
        :param ignore_index: Label that indicates ignored pixels (does not contribute to loss)
        :param eps: Small epsilon for numerical stability
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super(DiceLoss, self).__init__()
        self.mode = mode
        if classes is not None:
            assert (
                mode != BINARY_MODE
            ), "Masking classes is not supported with mode=binary"
            classes = to_tensor(classes, dtype="int64")
        self.classes = classes
        self.from_logits = from_logits
        self.smooth = smooth
        self.eps = eps
        self.ignore_index = ignore_index
        self.log_loss = log_loss

    def forward(self, y_pred: paddle.Tensor, y_true: paddle.Tensor) -> paddle.Tensor:
        """

        :param y_pred: NxCxHxW
        :param y_true: NxHxW
        :return: scalar
        """
        assert y_true.shape[0] == y_pred.shape[0]
        if self.from_logits:
            if self.mode == MULTICLASS_MODE:
                y_pred = paddle.nn.functional.log_softmax(y_pred, axis=1).exp()
            else:
                y_pred = paddle.nn.functional.log_sigmoid(x=y_pred).exp()
        bs = y_true.shape[0]
        num_classes = y_pred.shape[1]
        dims = 0, 2
        if self.mode == BINARY_MODE:
            y_true = y_true.view(bs, 1, -1)
            y_pred = y_pred.view(bs, 1, -1)
            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * paddle.cast(mask, dtype="float32")
                y_true = y_true * paddle.cast(mask, dtype="float32")
        if self.mode == MULTICLASS_MODE:
            y_true = y_true.view(bs, -1)
            y_pred = y_pred.view(bs, num_classes, -1)
            if self.ignore_index is not None:
                # Zero out ignored pixels in both the prediction and the target.
                mask = y_true != self.ignore_index
                mask = paddle.cast(mask, dtype="float32")
                y_pred = paddle.cast(
                    y_pred * mask.unsqueeze(axis=1), dtype="float32"
                )
                mask_float = paddle.cast(mask, dtype=y_true.dtype)
                masked_y_true = (y_true * mask_float).astype("int64")
                y_true = paddle.nn.functional.one_hot(
                    num_classes=num_classes, x=masked_y_true
                ).astype("int64")
                mask = paddle.cast(mask, dtype="int64")
                y_true = y_true.transpose(perm=[0, 2, 1]) * mask.unsqueeze(axis=1)
            else:
                y_true = paddle.nn.functional.one_hot(
                    num_classes=num_classes, x=y_true
                ).astype("int64")
                y_true = y_true.transpose(perm=[0, 2, 1])
        if self.mode == MULTILABEL_MODE:
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)
            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * paddle.cast(mask, dtype="float32")
                y_true = y_true * paddle.cast(mask, dtype="float32")
        scores = soft_dice_score(
            y_pred,
            y_true.astype(dtype=y_pred.dtype),
            smooth=self.smooth,
            eps=self.eps,
            dims=dims,
        )
        if self.log_loss:
            loss = -paddle.log(x=scores.clip(min=self.eps))
        else:
            loss = 1.0 - scores
        mask = y_true.sum(axis=dims) > 0
        loss *= mask.astype(loss.dtype)
        if self.classes is not None:
            loss = loss[self.classes]
        return loss.mean()
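
The multiclass branch above computes one Dice score per class over all pixels in the batch, zeroes the loss for classes absent from the target, and averages over classes. The sketch below mirrors that flow on flat Python lists. The body of `soft_dice_score` is not shown in this document, so the formula used here, (2·intersection + smooth) / max(cardinality + smooth, eps), is an assumption.

```python
from typing import List


def soft_dice_score(pred: List[float], target: List[float],
                    smooth: float = 0.0, eps: float = 1e-7) -> float:
    """Assumed soft Dice: (2 * intersection + smooth) / max(cardinality + smooth, eps)."""
    intersection = sum(p * t for p, t in zip(pred, target))
    cardinality = sum(p + t for p, t in zip(pred, target))
    return (2.0 * intersection + smooth) / max(cardinality + smooth, eps)


def multiclass_dice_loss(probs: List[List[float]], labels: List[int],
                         num_classes: int, smooth: float = 0.0) -> float:
    """probs[i][k] is the probability of class k at pixel i; labels[i] is its class index."""
    losses = []
    for k in range(num_classes):
        pred_k = [p[k] for p in probs]
        true_k = [1.0 if y == k else 0.0 for y in labels]
        if sum(true_k) == 0:
            losses.append(0.0)  # classes absent from the target contribute zero loss
        else:
            losses.append(1.0 - soft_dice_score(pred_k, true_k, smooth))
    return sum(losses) / num_classes


perfect = multiclass_dice_loss([[1.0, 0.0], [0.0, 1.0]], [0, 1], num_classes=2)
uncertain = multiclass_dice_loss([[0.5, 0.5], [0.5, 0.5]], [0, 0], num_classes=2)
print(perfect, uncertain)
```

Note that the average is taken over all classes, including the absent ones, which matches `loss.mean()` being applied after the mask in the listing above.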

3.4.3 JointLoss

SoftCrossEntropyLoss and DiceLoss are combined using JointLoss:

class JointLoss(paddle.nn.Layer):
    """
    Wrap two loss functions into one. This class computes a weighted sum of two losses.
    """

    def __init__(
        self,
        first: paddle.nn.Layer,
        second: paddle.nn.Layer,
        first_weight=1.0,
        second_weight=1.0,
    ):
        super().__init__()
        self.first = WeightedLoss(first, first_weight)
        self.second = WeightedLoss(second, second_weight)

    def forward(self, *input):
        return self.first(*input) + self.second(*input)

3.4.4 UNetFormerLoss

class UnetFormerLoss(paddle.nn.Layer):
    def __init__(self, ignore_index=255):
        super().__init__()
        self.main_loss = JointLoss(
            SoftCrossEntropyLoss(smooth_factor=0.05, ignore_index=ignore_index),
            DiceLoss(smooth=0.05, ignore_index=ignore_index),
            1.0,
            1.0,
        )
        self.aux_loss = SoftCrossEntropyLoss(
            smooth_factor=0.05, ignore_index=ignore_index
        )

    def forward(self, logits, labels):
        if self.training and len(logits) == 2:
            logit_main, logit_aux = logits
            loss = self.main_loss(logit_main, labels) + 0.4 * self.aux_loss(
                logit_aux, labels
            )
        else:
            loss = self.main_loss(logits, labels)
        return loss

3.5 Optimizer Construction

UNetFormer uses the AdamW optimizer, which can be constructed directly by calling paddle.optimizer.AdamW. The code is as follows:

layerwise_params = {
    "backbone.*": dict(lr=backbone_lr, weight_decay=backbone_weight_decay)
}
net_params = process_model_params(net, layerwise_params=layerwise_params)
optimizer = paddle.optimizer.AdamW(
    parameters=net_params, learning_rate=lr, weight_decay=weight_decay
)
tmp_lr = paddle.optimizer.lr.CosineAnnealingWarmRestarts(
    T_0=15, T_mult=2, learning_rate=optimizer.get_lr()
)
optimizer.set_lr_scheduler(tmp_lr)
lr_scheduler = tmp_lr
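
To show what the schedule configured above does over a training run, the following sketch evaluates the standard cosine-annealing-with-warm-restarts formula in plain Python. It assumes a minimum learning rate of 0 and one scheduler step per epoch; `cycle_position` and `warm_restart_lr` are illustrative names, not Paddle APIs.

```python
import math
from typing import List, Tuple


def cycle_position(epoch: int, t_0: int = 15, t_mult: int = 2) -> Tuple[int, int]:
    """Return (epochs since the last restart, length of the current cycle)."""
    t_cur, t_i = epoch, t_0
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult  # each cycle is t_mult times longer than the previous one
    return t_cur, t_i


def warm_restart_lr(epoch: int, base_lr: float = 0.0006, t_0: int = 15,
                    t_mult: int = 2, eta_min: float = 0.0) -> float:
    """Cosine decay from base_lr to eta_min within each cycle."""
    t_cur, t_i = cycle_position(epoch, t_0, t_mult)
    return eta_min + (base_lr - eta_min) * (1.0 + math.cos(math.pi * t_cur / t_i)) / 2.0


max_epoch = 105
restarts: List[int] = [e for e in range(max_epoch) if cycle_position(e)[0] == 0]
print(restarts)
print(warm_restart_lr(0), warm_restart_lr(14), warm_restart_lr(15))
```

With T_0 = 15 and T_mult = 2 the cycles last 15, 30, and 60 epochs, which sum to 105, so under these assumptions the configured max_epoch ends exactly at the end of the third cycle.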

3.6 Model Training

    checkpoint_callback = ModelCheckpoint(
        save_top_k=config.save_top_k,
        monitor=config.monitor,
        save_last=config.save_last,
        mode=config.monitor_mode,
        dirpath=config.weights_path,
        filename=config.weights_name,
    )

    logger = CSVLogger("lightning_logs", name=config.log_name)

    model = Supervision_Train(config)

    if config.pretrained_ckpt_path:
        state_dict = paddle.load(config.pretrained_ckpt_path)
        model.set_state_dict(state_dict)

    paddle.set_device("gpu")

    optimizer, lr_scheduler = model.configure_optimizers()

    train_loader = model.train_dataloader()
    val_loader = model.val_dataloader()

    for epoch in range(config.max_epoch):
        print(f"Epoch {epoch+1}/{config.max_epoch}")
        model.train()
        train_losses = []
        for batch_idx, batch in enumerate(train_loader):
            output = model.training_step(batch, batch_idx)
            loss = output["loss"]
            train_losses.append(loss.item())
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            if batch_idx % 10 == 0:
                print(
                    f"  Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}"
                )

        train_log = model.on_train_epoch_end()
        train_log["loss"] = np.mean(train_losses)
        if (epoch + 1) % config.check_val_every_n_epoch == 0:
            model.eval()
            val_losses = []
            for batch_idx, batch in enumerate(val_loader):
                output = model.validation_step(batch, batch_idx)
                val_losses.append(output["loss_val"].item())
            val_log = model.on_validation_epoch_end()
            val_log["loss_val"] = np.mean(val_losses)
            checkpoint_callback.on_validation_epoch_end(None, model, val_log)
            logger.log_metrics(epoch, train_log, val_log)
        if lr_scheduler:
            lr_scheduler.step()
        # Note: as written, the checkpoint is restored only after the first epoch has run.
        if config.resume_ckpt_path and epoch == 0:
            state = paddle.load(config.resume_ckpt_path)
            model.set_state_dict(state["model_state_dict"])
            optimizer.set_state_dict(state["optimizer_state_dict"])
            if lr_scheduler and "lr_scheduler_state_dict" in state:
                lr_scheduler.set_state_dict(state["lr_scheduler_state_dict"])
            print(f"Resumed training from checkpoint: {config.resume_ckpt_path}")


if __name__ == "__main__":
    main()
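
The training loop hands each validation log to a checkpoint callback configured with monitor = "val_F1", monitor_mode = "max", and save_top_k = 1. The internals of that callback are not shown here; the sketch below only illustrates the underlying idea of tracking the single best value of a monitored metric. `BestMetricTracker` is a hypothetical class and the metric values are made up.

```python
from typing import Dict, Optional


class BestMetricTracker:
    """Remember the best value of one monitored metric (the save_top_k = 1 case)."""

    def __init__(self, monitor: str = "val_F1", mode: str = "max"):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.monitor = monitor
        self.mode = mode
        self.best_value: Optional[float] = None
        self.best_epoch: Optional[int] = None

    def update(self, epoch: int, metrics: Dict[str, float]) -> bool:
        """Return True when the monitored metric improved, i.e. a checkpoint should be saved."""
        value = metrics[self.monitor]
        if self.best_value is None:
            improved = True
        elif self.mode == "max":
            improved = value > self.best_value
        else:
            improved = value < self.best_value
        if improved:
            self.best_value, self.best_epoch = value, epoch
        return improved


tracker = BestMetricTracker(monitor="val_F1", mode="max")
history = [{"val_F1": 0.80}, {"val_F1": 0.85}, {"val_F1": 0.83}]  # illustrative values
flags = [tracker.update(epoch, log) for epoch, log in enumerate(history)]
print(flags, tracker.best_epoch, tracker.best_value)
```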

3.7 Model Testing

def main():
    seed_everything(42)
    args = get_args()
    config = py2cfg(args.config_path)
    args.output_path.mkdir(exist_ok=True, parents=True)
    model = Supervision_Train.load_from_checkpoint(
        os.path.join(config.weights_path, config.test_weights_name + ".pdparams"),
        config=config,
    )
    model.eval()
    evaluator = Evaluator(num_class=config.num_classes)
    evaluator.reset()
    test_dataset = config.test_dataset
    test_loader = paddle.io.DataLoader(
        dataset=test_dataset,
        batch_size=2,
        num_workers=4,
        drop_last=False,
        shuffle=False,
    )

    results = []
    with paddle.no_grad():
        for batch in tqdm(test_loader):
            images = batch["img"]
            images = images.astype("float32")
            raw_predictions = model(images)

            raw_predictions = paddle.nn.functional.softmax(raw_predictions, axis=1)
            predictions = raw_predictions.argmax(axis=1)

            image_ids = batch["img_id"]
            masks_true = batch["gt_semantic_seg"]

            for i in range(len(image_ids)):
                mask = predictions[i].numpy()
                evaluator.add_batch(pre_image=mask, gt_image=masks_true[i].numpy())
                mask_name = image_ids[i]
                results.append((mask, str(args.output_path / mask_name), args.rgb))

    iou_per_class = evaluator.Intersection_over_Union()
    f1_per_class = evaluator.F1()
    OA = evaluator.OA()

    for class_name, class_iou, class_f1 in zip(
        config.classes, iou_per_class, f1_per_class
    ):
        print(f"F1_{class_name}: {class_f1:.4f}, IOU_{class_name}: {class_iou:.4f}")

    print(
        f"F1: {np.nanmean(f1_per_class[:-1]):.4f}, "
        f"mIOU: {np.nanmean(iou_per_class[:-1]):.4f}, "
        f"OA: {OA:.4f}"
    )

    t0 = time.time()
    with mp.Pool(processes=mp.cpu_count()) as pool:
        pool.map(img_writer, results)
    t1 = time.time()
    print(f"Images writing time: {t1 - t0:.2f} seconds")
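
The test script reports per-class IoU and F1 plus overall accuracy (OA), and averages F1 and IoU over all classes except the last one. The Evaluator class itself is not shown; the sketch below computes the same kinds of metrics from a confusion matrix using their standard definitions. The function names, the toy matrix, and the convention that rows are ground truth and columns are predictions are all assumptions.

```python
from typing import List, Tuple


def per_class_metrics(cm: List[List[int]]) -> Tuple[List[float], List[float]]:
    """Return (IoU per class, F1 per class) from a square confusion matrix."""
    n = len(cm)
    ious, f1s = [], []
    for k in range(n):
        tp = cm[k][k]
        fn = sum(cm[k]) - tp                        # true class k, predicted otherwise
        fp = sum(cm[r][k] for r in range(n)) - tp   # predicted k, true class differs
        iou_den = tp + fp + fn
        f1_den = 2 * tp + fp + fn
        ious.append(tp / iou_den if iou_den else float("nan"))
        f1s.append(2 * tp / f1_den if f1_den else float("nan"))
    return ious, f1s


def overall_accuracy(cm: List[List[int]]) -> float:
    """Fraction of pixels on the diagonal of the confusion matrix."""
    total = sum(sum(row) for row in cm)
    return sum(cm[k][k] for k in range(len(cm))) / total


# Toy 3-class matrix; the last class plays the role of the background class.
cm = [[3, 1, 0],
      [1, 5, 0],
      [0, 0, 2]]
ious, f1s = per_class_metrics(cm)
mean_iou = sum(ious[:-1]) / len(ious[:-1])  # exclude the last class, as in the script
mean_f1 = sum(f1s[:-1]) / len(f1s[:-1])
print(ious, f1s, overall_accuracy(cm), mean_iou, mean_f1)
```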

4. Result Display

Results of training on the Vaihingen dataset:

F1      mIOU    OA
0.9062  0.8318  0.9283

image-vaihingen1

image-vaihingen2

Comparing the two images shows that the model accurately segments the contours of buildings, trees, cars, and other objects in the remote sensing images and handles overlapping areas well.

6. References