跳转至

Visualize(可视化) 模块

ppsci.visualize

Visualizer

Base class for visualizer.

Parameters:

Name Type Description Default
input_dict Dict[str, ndarray]

Input dict.

required
output_expr Dict[str, Callable]

Output expression.

required
batch_size int

Batch size of data when computing result in visu.py.

required
num_timestamps int

Number of timestamps.

required
prefix str

Prefix for output file.

required
Source code in ppsci/visualize/base.py
class Visualizer:
    """Base class shared by all visualizers.

    Keeps the raw inputs and output expressions handed to a visualizer and
    caches their keys (in insertion order) as tuples for subclasses to use.

    Args:
        input_dict (Dict[str, np.ndarray]): Input dict.
        output_expr (Dict[str, Callable]): Output expression.
        batch_size (int): Batch size of data when computing result in visu.py.
        num_timestamps (int): Number of timestamps.
        prefix (str): Prefix for output file.
    """

    def __init__(
        self,
        input_dict: Dict[str, np.ndarray],
        output_expr: Dict[str, Callable],
        batch_size: int,
        num_timestamps: int,
        prefix: str,
    ):
        # Store the mappings untouched and remember their key order.
        self.input_dict = input_dict
        self.input_keys = tuple(input_dict)
        self.output_expr = output_expr
        self.output_keys = tuple(output_expr)
        self.batch_size = batch_size
        self.num_timestamps = num_timestamps
        self.prefix = prefix

    @abc.abstractmethod
    def save(self, data_dict):
        """Render ``data_dict`` and write the result to file(s).

        NOTE: this class does not use ``abc.ABCMeta``, so the decorator is
        informational only; subclasses in this package override the hook as
        ``save(filename, data_dict)``.
        """

    def __str__(self):
        summary = (
            ("input_keys", self.input_keys),
            ("output_keys", self.output_keys),
            ("output_expr", self.output_expr),
            ("batch_size", self.batch_size),
            ("num_timestamps", self.num_timestamps),
            ("output file prefix", self.prefix),
        )
        return ", ".join(f"{label}: {item}" for label, item in summary)
save(data_dict) abstractmethod

Visualize result from data_dict and save as files

Source code in ppsci/visualize/base.py
@abc.abstractmethod
def save(self, data_dict):
    """Render ``data_dict`` and write the visualization to file(s).

    Abstract hook: the body intentionally does nothing and yields ``None``.
    """
    return None

VisualizerScatter1D

Bases: Visualizer

Visualizer for 1d scatter data.

Parameters:

Name Type Description Default
input_dict Dict[str, ndarray]

Input dict.

required
coord_keys Tuple[str, ...]

Coordinate keys, such as ("x", "y").

required
output_expr Dict[str, Callable]

Output expression.

required
batch_size int

Batch size of data when computing result in visu.py. Defaults to 64.

64
num_timestamps int

Number of timestamps. Defaults to 1.

1
prefix str

Prefix for output file. Defaults to "plot".

'plot'

Examples:

>>> import ppsci
>>> visu_mat = {"t_f": np.random.randn(16, 1), "eta": np.random.randn(16, 1)}
>>> visualizer_eta = ppsci.visualize.VisualizerScatter1D(
...     visu_mat,
...     ("t_f",),
...     {"eta": lambda d: d["eta"]},
...     num_timestamps=1,
...     prefix="viv_pred",
... )
Source code in ppsci/visualize/visualizer.py
class VisualizerScatter1D(base.Visualizer):
    """Visualizer for 1d scatter data.

    Args:
        input_dict (Dict[str, np.ndarray]): Input dict.
        coord_keys (Tuple[str, ...]): Coordinate keys, such as ("x", "y").
        output_expr (Dict[str, Callable]): Output expression.
        batch_size (int, optional): Batch size of data when computing result in visu.py. Defaults to 64.
        num_timestamps (int, optional): Number of timestamps. Defaults to 1.
        prefix (str, optional): Prefix for output file. Defaults to "plot".

    Examples:
        >>> import ppsci
        >>> visu_mat = {"t_f": np.random.randn(16, 1), "eta": np.random.randn(16, 1)}
        >>> visualizer_eta = ppsci.visualize.VisualizerScatter1D(
        ...     visu_mat,
        ...     ("t_f",),
        ...     {"eta": lambda d: d["eta"]},
        ...     num_timestamps=1,
        ...     prefix="viv_pred",
        ... )
    """

    def __init__(
        self,
        input_dict: Dict[str, np.ndarray],
        coord_keys: Tuple[str, ...],
        output_expr: Dict[str, Callable],
        batch_size: int = 64,
        num_timestamps: int = 1,
        prefix: str = "plot",
    ):
        super().__init__(
            input_dict=input_dict,
            output_expr=output_expr,
            batch_size=batch_size,
            num_timestamps=num_timestamps,
            prefix=prefix,
        )
        # Keys looked up in ``data_dict`` as plot coordinates in ``save``.
        self.coord_keys = coord_keys

    def save(self, filename, data_dict):
        """Plot ``data_dict`` as 1D scatter figure(s) named after ``filename``.

        Args:
            filename (str): Output filename handed to the plotting helper.
            data_dict (Dict[str, np.ndarray]): Must contain every coordinate
                key and every output key of this visualizer.
        """
        coords = self.coord_keys
        values = self.output_keys
        plot.save_plot_from_1d_dict(
            filename, data_dict, coords, values, self.num_timestamps
        )

VisualizerScatter3D

Bases: Visualizer

Visualizer for 3d scatter data.

Parameters:

Name Type Description Default
input_dict Dict[str, ndarray]

Input dict.

required
output_expr Dict[str, Callable]

Output expression.

required
batch_size int

Batch size of data when computing result in visu.py. Defaults to 64.

64
num_timestamps int

Number of timestamps. Defaults to 1.

1
prefix str

Prefix for output file. Defaults to "plot3d_scatter".

'plot3d_scatter'

Examples:

>>> import ppsci
>>> vis_data = {"states": np.random.randn(16, 1)}
>>> visualizer = ppsci.visualize.VisualizerScatter3D(
...     vis_data,
...     {"states": lambda d: d["states"]},
...     num_timestamps=1,
...     prefix="result_states",
... )
Source code in ppsci/visualize/visualizer.py
class VisualizerScatter3D(base.Visualizer):
    """Visualizer for 3d scatter data.

    Args:
        input_dict (Dict[str, np.ndarray]): Input dict.
        output_expr (Dict[str, Callable]): Output expression.
        batch_size (int, optional): Batch size of data when computing result in visu.py. Defaults to 64.
        num_timestamps (int, optional): Number of timestamps. Defaults to 1.
        prefix (str, optional): Prefix for output file. Defaults to "plot3d_scatter".

    Examples:
        >>> import ppsci
        >>> vis_data = {"states": np.random.randn(16, 1)}
        >>> visualizer = ppsci.visualize.VisualizerScatter3D(
        ...     vis_data,
        ...     {"states": lambda d: d["states"]},
        ...     num_timestamps=1,
        ...     prefix="result_states",
        ... )
    """

    def __init__(
        self,
        input_dict: Dict[str, np.ndarray],
        output_expr: Dict[str, Callable],
        batch_size: int = 64,
        num_timestamps: int = 1,
        prefix: str = "plot3d_scatter",
    ):
        super().__init__(
            input_dict=input_dict,
            output_expr=output_expr,
            batch_size=batch_size,
            num_timestamps=num_timestamps,
            prefix=prefix,
        )

    def save(self, filename, data_dict):
        """Save 3D scatter plot(s) of this visualizer's output entries.

        If the first output entry is 3-D (per the original comments,
        ``(B, T, 3)``), one plot is written per leading index, named
        ``filename + str(index)``. Otherwise a single plot is written to
        ``filename``.
        """
        # Only entries named in output_keys are plotted.
        selected = {
            name: arr for name, arr in data_dict.items() if name in self.output_keys
        }
        first = selected[self.output_keys[0]]
        if len(first.shape) != 3:
            # value.shape=(T, 3)
            plot.save_plot_from_3d_dict(
                filename, selected, self.output_keys, self.num_timestamps
            )
            return
        # value.shape=(B, T, 3)
        for idx in range(first.shape[0]):
            sample = {name: arr[idx] for name, arr in selected.items()}
            plot.save_plot_from_3d_dict(
                filename + str(idx),
                sample,
                self.output_keys,
                self.num_timestamps,
            )

VisualizerVtu

Bases: Visualizer

Visualizer for 2D points data.

Parameters:

Name Type Description Default
input_dict Dict[str, ndarray]

Input dict.

required
output_expr Dict[str, Callable]

Output expression.

required
batch_size int

Batch size of data when computing result in visu.py. Defaults to 64.

64
num_timestamps int

Number of timestamps. Defaults to 1.

1
prefix str

Prefix for output file. Defaults to "vtu".

'vtu'

Examples:

>>> import ppsci
>>> vis_points = {
...     "x": np.random.randn(128, 1),
...     "y": np.random.randn(128, 1),
...     "u": np.random.randn(128, 1),
...     "v": np.random.randn(128, 1),
... }
>>> visualizer_u_v =  ppsci.visualize.VisualizerVtu(
...     vis_points,
...     {"u": lambda d: d["u"], "v": lambda d: d["v"]},
...     num_timestamps=1,
...     prefix="result_u_v",
... )
Source code in ppsci/visualize/visualizer.py
class VisualizerVtu(base.Visualizer):
    """Visualizer for 2D points data.

    Args:
        input_dict (Dict[str, np.ndarray]): Input dict.
        output_expr (Dict[str, Callable]): Output expression.
        batch_size (int, optional): Batch size of data when computing result in visu.py. Defaults to 64.
        num_timestamps (int, optional): Number of timestamps. Defaults to 1.
        prefix (str, optional): Prefix for output file. Defaults to "vtu".

    Examples:
        >>> import ppsci
        >>> vis_points = {
        ...     "x": np.random.randn(128, 1),
        ...     "y": np.random.randn(128, 1),
        ...     "u": np.random.randn(128, 1),
        ...     "v": np.random.randn(128, 1),
        ... }
        >>> visualizer_u_v =  ppsci.visualize.VisualizerVtu(
        ...     vis_points,
        ...     {"u": lambda d: d["u"], "v": lambda d: d["v"]},
        ...     num_timestamps=1,
        ...     prefix="result_u_v",
        ... )
    """

    def __init__(
        self,
        input_dict: Dict[str, np.ndarray],
        output_expr: Dict[str, Callable],
        batch_size: int = 64,
        num_timestamps: int = 1,
        prefix: str = "vtu",
    ):
        super().__init__(
            input_dict=input_dict,
            output_expr=output_expr,
            batch_size=batch_size,
            num_timestamps=num_timestamps,
            prefix=prefix,
        )

    def save(self, filename, data_dict):
        """Write ``data_dict`` to a '*.vtu' file.

        The keys of ``input_dict`` serve as coordinate keys and the keys of
        ``output_expr`` as value keys.
        """
        coord_keys = self.input_keys
        value_keys = self.output_keys
        vtu.save_vtu_from_dict(
            filename, data_dict, coord_keys, value_keys, self.num_timestamps
        )

Visualizer2D

Bases: Visualizer

Visualizer for 2D data.

Parameters:

Name Type Description Default
input_dict Dict[str, ndarray]

Input dict.

required
output_expr Dict[str, Callable]

Output expression.

required
batch_size int

Batch size of data when computing result in visu.py. Defaults to 64.

64
num_timestamps int

Number of timestamps. Defaults to 1.

1
prefix str

Prefix for output file. Defaults to "plot2d".

'plot2d'

Examples:

>>> import ppsci
>>> vis_points = {
...     "x": np.random.randn(128, 1),
...     "y": np.random.randn(128, 1),
...     "u": np.random.randn(128, 1),
...     "v": np.random.randn(128, 1),
... }
>>> visualizer_u_v = ppsci.visualize.Visualizer2D(
...     vis_points,
...     {"u": lambda d: d["u"], "v": lambda d: d["v"]},
...     num_timestamps=1,
...     prefix="result_u_v",
... )
Source code in ppsci/visualize/visualizer.py
class Visualizer2D(base.Visualizer):
    """Visualizer for 2D data.

    This class only fixes default arguments; it defines no ``save`` of its
    own, so it inherits the base class hook.

    Args:
        input_dict (Dict[str, np.ndarray]): Input dict.
        output_expr (Dict[str, Callable]): Output expression.
        batch_size (int, optional): Batch size of data when computing result in visu.py. Defaults to 64.
        num_timestamps (int, optional): Number of timestamps. Defaults to 1.
        prefix (str, optional): Prefix for output file. Defaults to "plot2d".

    Examples:
        >>> import ppsci
        >>> vis_points = {
        ...     "x": np.random.randn(128, 1),
        ...     "y": np.random.randn(128, 1),
        ...     "u": np.random.randn(128, 1),
        ...     "v": np.random.randn(128, 1),
        ... }
        >>> visualizer_u_v = ppsci.visualize.Visualizer2D(
        ...     vis_points,
        ...     {"u": lambda d: d["u"], "v": lambda d: d["v"]},
        ...     num_timestamps=1,
        ...     prefix="result_u_v",
        ... )
    """

    def __init__(
        self,
        input_dict: Dict[str, np.ndarray],
        output_expr: Dict[str, Callable],
        batch_size: int = 64,
        num_timestamps: int = 1,
        prefix: str = "plot2d",
    ):
        super().__init__(
            input_dict=input_dict,
            output_expr=output_expr,
            batch_size=batch_size,
            num_timestamps=num_timestamps,
            prefix=prefix,
        )

Visualizer2DPlot

Bases: Visualizer2D

Visualizer for 2D data using matplotlib.

Parameters:

Name Type Description Default
input_dict Dict[str, ndarray]

Input dict.

required
output_expr Dict[str, Callable]

Output expression.

required
batch_size int

Batch size of data when computing result in visu.py. Defaults to 64.

64
num_timestamps int

Number of timestamps. Defaults to 1.

1
stride int

The time stride of visualization. Defaults to 1.

1
xticks Optional[Tuple[float, ...]]

The list of xtick locations. Defaults to None.

None
yticks Optional[Tuple[float, ...]]

The list of ytick locations. Defaults to None.

None
prefix str

Prefix for output file. Defaults to "plot2d".

'plot2d'

Examples:

>>> import ppsci
>>> vis_data = {
...     "target_ux": np.random.randn(128, 20, 1),
...     "pred_ux": np.random.randn(128, 20, 1),
... }
>>> visualizer_states = ppsci.visualize.Visualizer2DPlot(
...     vis_data,
...     {
...         "target_ux": lambda d: d["states"][:, :, 0],
...         "pred_ux": lambda d: output_transform(d)[:, :, 0],
...     },
...     batch_size=1,
...     num_timestamps=10,
...     stride=20,
...     xticks=np.linspace(-2, 14, 9),
...     yticks=np.linspace(-4, 4, 5),
...     prefix="result_states",
... )
Source code in ppsci/visualize/visualizer.py
class Visualizer2DPlot(Visualizer2D):
    """Visualizer for 2D data using matplotlib.

    Args:
        input_dict (Dict[str, np.ndarray]): Input dict.
        output_expr (Dict[str, Callable]): Output expression.
        batch_size (int, optional): Batch size of data when computing result in visu.py. Defaults to 64.
        num_timestamps (int, optional): Number of timestamps. Defaults to 1.
        stride (int, optional): The time stride of visualization. Defaults to 1.
        xticks (Optional[Tuple[float,...]]): The list of xtick locations. Defaults to None.
        yticks (Optional[Tuple[float,...]]): The list of ytick locations. Defaults to None.
        prefix (str, optional): Prefix for output file. Defaults to "plot2d".

    Examples:
        >>> import ppsci
        >>> vis_data = {
        ...     "target_ux": np.random.randn(128, 20, 1),
        ...     "pred_ux": np.random.randn(128, 20, 1),
        ... }
        >>> visualizer_states = ppsci.visualize.Visualizer2DPlot(
        ...     vis_data,
        ...     {
        ...         "target_ux": lambda d: d["states"][:, :, 0],
        ...         "pred_ux": lambda d: output_transform(d)[:, :, 0],
        ...     },
        ...     batch_size=1,
        ...     num_timestamps=10,
        ...     stride=20,
        ...     xticks=np.linspace(-2, 14, 9),
        ...     yticks=np.linspace(-4, 4, 5),
        ...     prefix="result_states",
        ... )
    """

    def __init__(
        self,
        input_dict: Dict[str, np.ndarray],
        output_expr: Dict[str, Callable],
        batch_size: int = 64,
        num_timestamps: int = 1,
        stride: int = 1,
        xticks: Optional[Tuple[float, ...]] = None,
        yticks: Optional[Tuple[float, ...]] = None,
        prefix: str = "plot2d",
    ):
        super().__init__(input_dict, output_expr, batch_size, num_timestamps, prefix)
        # Plot options forwarded verbatim to plot.save_plot_from_2d_dict.
        self.stride = stride
        self.xticks = xticks
        self.yticks = yticks

    def save(self, filename, data_dict):
        """Save 2D plot(s) of this visualizer's output entries.

        If the first output entry is 4-D (per the original comments,
        ``(B, T, H, W)``), one figure set is written per leading index, named
        ``filename + str(index)``. Otherwise a single set is written to
        ``filename``.
        """
        selected = {
            name: arr for name, arr in data_dict.items() if name in self.output_keys
        }
        first = selected[self.output_keys[0]]
        if len(first.shape) == 4:
            # value.shape=(B, T, H, W); generator keeps slicing interleaved
            # with plotting, one sample at a time.
            jobs = (
                (filename + str(idx), {name: arr[idx] for name, arr in selected.items()})
                for idx in range(first.shape[0])
            )
        else:
            # value.shape=(T, H, W)
            jobs = ((filename, selected),)
        for out_name, payload in jobs:
            plot.save_plot_from_2d_dict(
                out_name,
                payload,
                self.output_keys,
                self.num_timestamps,
                self.stride,
                self.xticks,
                self.yticks,
            )

Visualizer3D

Bases: Visualizer

Visualizer for 3D plot data.

Parameters:

Name Type Description Default
input_dict Dict[str, ndarray]

Input dict.

required
output_expr Dict[str, Callable]

Output expression.

required
batch_size int

Batch size of data when computing result in visu.py. Defaults to 64.

64
label_dict Dict[str, ndarray]

Label dict.

None
time_list Optional[Tuple[float, ...]]

Time list.

None
prefix str

Prefix for output file.

'vtu'
Source code in ppsci/visualize/visualizer.py
class Visualizer3D(base.Visualizer):
    """Visualizer for 3D plot data.

    Args:
        input_dict (Dict[str, np.ndarray]): Input dict.
        output_expr (Dict[str, Callable]): Output expression.
        batch_size (int, optional): Batch size of data when computing result in visu.py. Defaults to 64.
        label_dict (Optional[Dict[str, np.ndarray]]): Label dict. Defaults to None.
        time_list (Optional[Tuple[float, ...]]): Time list. The number of
            timestamps is ``len(time_list)``; when None, the data is treated as
            a single timestamp. Defaults to None.
        prefix (str, optional): Prefix for output file. Defaults to "vtu".
    """

    def __init__(
        self,
        input_dict: Dict[str, np.ndarray],
        output_expr: Dict[str, Callable],
        batch_size: int = 64,
        label_dict: Optional[Dict[str, np.ndarray]] = None,
        time_list: Optional[Tuple[float, ...]] = None,
        prefix: str = "vtu",
    ):
        self.label = label_dict
        self.time_list = time_list
        # ``time_list`` defaults to None, and ``len(None)`` raises TypeError;
        # fall back to a single timestamp so the documented default works.
        num_timestamps = 1 if time_list is None else len(time_list)
        super().__init__(input_dict, output_expr, batch_size, num_timestamps, prefix)

    def save(self, filename: str, data_dict: Dict[str, np.ndarray]):
        """Write one ``predict_{i}.vtu`` file per timestamp (i starts at 1).

        Every array in ``data_dict`` is split along its first axis into
        ``num_timestamps`` equal chunks; chunk ``i`` (rows ``i*n`` to
        ``(i+1)*n``) is written to ``osp.join(filename, f"predict_{i+1}.vtu")``.
        Coordinate keys are the keys of ``input_dict`` except ``"t"``.

        Args:
            filename (str): Output directory path that per-timestamp files are
                joined onto.
            data_dict (Dict[str, np.ndarray]): Data holding coordinate and
                output entries, stacked timestamp by timestamp along axis 0.
        """
        total_rows = next(iter(data_dict.values())).shape[0]
        # Integer division: rows per timestamp.
        n = total_rows // self.num_timestamps
        coord_keys = [key for key in self.input_dict if key != "t"]
        for i in range(self.num_timestamps):
            vtu.save_vtu_to_mesh(
                osp.join(filename, f"predict_{i+1}.vtu"),
                {key: data_dict[key][i * n : (i + 1) * n] for key in data_dict},
                coord_keys,
                self.output_keys,
            )

VisualizerWeather

Bases: Visualizer

Visualizer for weather data using matplotlib.

Parameters:

Name Type Description Default
input_dict Dict[str, ndarray]

Input dict.

required
output_expr Dict[str, Callable]

Output expression.

required
xticks Tuple[float, ...]

The list of xtick locations.

required
xticklabels Tuple[str, ...]

The x-axis' tick labels.

required
yticks Tuple[float, ...]

The list of ytick locations.

required
yticklabels Tuple[str, ...]

The y-axis' tick labels.

required
vmin float

Minimum value that the colormap covers.

required
vmax float

Maximum value that the colormap covers.

required
colorbar_label str

The color-bar label. Defaults to "".

''
log_norm bool

Whether use log norm. Defaults to False.

False
batch_size int

Batch size of data when computing result in visu.py. Defaults to 1.

1
num_timestamps int

Number of timestamps. Defaults to 1.

1
prefix str

Prefix for output file. Defaults to "plot_weather".

'plot_weather'

Examples:

>>> import ppsci
>>> import numpy as np
>>> vis_data = {
...     "output_6h": np.random.randn(1, 720, 1440),
...     "target_6h": np.random.randn(1, 720, 1440),
... }
>>> visualizer_weather = ppsci.visualize.VisualizerWeather(
...     vis_data,
...     {
...         "output_6h": lambda d: d["output_6h"],
...         "target_6h": lambda d: d["target_6h"],
...     },
...     xticks=np.linspace(0, 1439, 13),
...     xticklabels=[str(i) for i in range(360, -1, -30)],
...     yticks=np.linspace(0, 719, 7),
...     yticklabels=[str(i) for i in range(90, -91, -30)],
...     vmin=0,
...     vmax=25,
...     prefix="result_states",
... )
Source code in ppsci/visualize/visualizer.py
class VisualizerWeather(base.Visualizer):
    """Visualizer for weather data using matplotlib.

    Args:
        input_dict (Dict[str, np.ndarray]): Input dict.
        output_expr (Dict[str, Callable]): Output expression.
        xticks (Tuple[float, ...]): The list of xtick locations.
        xticklabels (Tuple[str, ...]): The x-axis' tick labels.
        yticks (Tuple[float, ...]): The list of ytick locations.
        yticklabels (Tuple[str, ...]): The y-axis' tick labels.
        vmin (float): Minimum value that the colormap covers.
        vmax (float): Maximum value that the colormap covers.
        colorbar_label (str, optional): The color-bar label. Defaults to "".
        log_norm (bool, optional): Whether use log norm. Defaults to False.
        batch_size (int, optional): Batch size of data when computing result in visu.py. Defaults to 1.
        num_timestamps (int, optional): Number of timestamps. Defaults to 1.
        prefix (str, optional): Prefix for output file. Defaults to "plot_weather".

    Examples:
        >>> import ppsci
        >>> import numpy as np
        >>> vis_data = {
        ...     "output_6h": np.random.randn(1, 720, 1440),
        ...     "target_6h": np.random.randn(1, 720, 1440),
        ... }
        >>> visualizer_weather = ppsci.visualize.VisualizerWeather(
        ...     vis_data,
        ...     {
        ...         "output_6h": lambda d: d["output_6h"],
        ...         "target_6h": lambda d: d["target_6h"],
        ...     },
        ...     xticks=np.linspace(0, 1439, 13),
        ...     xticklabels=[str(i) for i in range(360, -1, -30)],
        ...     yticks=np.linspace(0, 719, 7),
        ...     yticklabels=[str(i) for i in range(90, -91, -30)],
        ...     vmin=0,
        ...     vmax=25,
        ...     prefix="result_states",
        ... )
    """

    def __init__(
        self,
        input_dict: Dict[str, np.ndarray],
        output_expr: Dict[str, Callable],
        xticks: Tuple[float, ...],
        xticklabels: Tuple[str, ...],
        yticks: Tuple[float, ...],
        yticklabels: Tuple[str, ...],
        vmin: float,
        vmax: float,
        colorbar_label: str = "",
        log_norm: bool = False,
        batch_size: int = 1,
        num_timestamps: int = 1,
        prefix: str = "plot_weather",
    ):
        super().__init__(
            input_dict=input_dict,
            output_expr=output_expr,
            batch_size=batch_size,
            num_timestamps=num_timestamps,
            prefix=prefix,
        )
        # Axis and colormap options forwarded to the weather plotting helper.
        self.xticks, self.xticklabels = xticks, xticklabels
        self.yticks, self.yticklabels = yticks, yticklabels
        self.vmin, self.vmax = vmin, vmax
        self.colorbar_label = colorbar_label
        self.log_norm = log_norm

    def save(self, filename, data_dict):
        """Write one weather figure per entry along the first axis.

        The file for index ``i`` is named ``filename + str(i)``. Raises
        KeyError if any output key is missing from ``data_dict``.
        """
        selected = {name: data_dict[name] for name in self.output_keys}
        batch = selected[self.output_keys[0]]
        # value.shape=(B, H, W)
        for idx in range(batch.shape[0]):
            plot.save_plot_weather_from_dict(
                filename + str(idx),
                {name: arr[idx] for name, arr in selected.items()},
                self.output_keys,
                self.xticks,
                self.xticklabels,
                self.yticks,
                self.yticklabels,
                self.vmin,
                self.vmax,
                self.colorbar_label,
                self.log_norm,
                self.num_timestamps,
            )

VisualizerRadar

Bases: Visualizer

Visualizer for NowcastNet Radar Dataset.

Parameters:

Name Type Description Default
input_dict Dict[str, ndarray]

Input dict.

required
output_expr Dict[str, Callable]

Output expression.

required
batch_size int

Batch size of data when computing result in visu.py. Defaults to 64.

64
num_timestamps int

Number of timestamps. Defaults to 1.

1
prefix str

Prefix for output file. Defaults to "vtu".

'vtu'
case_type str

Case type.

'normal'
total_length int

Total length.

29

Examples:

>>> import ppsci
>>> frames_tensor = paddle.randn([1, 29, 512, 512, 2])
>>> visualizer =  ppsci.visualize.VisualizerRadar(
...     {"input": frames_tensor},
...     {"output": lambda out: out["output"]},
...     num_timestamps=1,
...     prefix="v_nowcastnet",
... )
Source code in ppsci/visualize/radar.py
class VisualizerRadar(base.Visualizer):
    """Visualizer for NowcastNet Radar Dataset.

    Args:
        input_dict (Dict[str, np.ndarray]): Input dict.
        output_expr (Dict[str, Callable]): Output expression.
        batch_size (int, optional): Batch size of data when computing result in visu.py. Defaults to 64.
        num_timestamps (int, optional): Number of timestamps. Defaults to 1.
        prefix (str, optional): Prefix for output file. Defaults to "vtu".
        case_type (str, optional): Case type. ``"normal"`` crops axes 1 and 2
            of each plotted sample to indices 64:448; any other value keeps
            them whole. Defaults to "normal".
        total_length (int, optional): Total length, used to build the plot
            labels. Defaults to 29.

    Examples:
        >>> import ppsci
        >>> frames_tensor = paddle.randn([1, 29, 512, 512, 2])
        >>> visualizer =  ppsci.visualize.VisualizerRadar(
        ...     {"input": frames_tensor},
        ...     {"output": lambda out: out["output"]},
        ...     num_timestamps=1,
        ...     prefix="v_nowcastnet",
        ... )
    """

    def __init__(
        self,
        input_dict: Dict[str, np.ndarray],
        output_expr: Dict[str, Callable],
        batch_size: int = 64,
        num_timestamps: int = 1,
        prefix: str = "vtu",
        case_type: str = "normal",
        total_length: int = 29,
    ):
        super().__init__(input_dict, output_expr, batch_size, num_timestamps, prefix)
        self.case_type = case_type
        self.total_length = total_length
        self.input_dict = input_dict

    def save(self, path, data_dict):
        """Save ground-truth and generated frame plots into directory ``path``.

        Ground truth is the first entry of ``input_dict``. Generated frames
        are the *second* entry of ``data_dict`` (the original comment gives
        the keys as {"input", "output"}). Only the first sample (index 0) is
        plotted, with the last two entries of its first axis dropped.
        """
        if not os.path.exists(path):
            os.makedirs(path)
        ground_truth = self.input_dict[tuple(self.input_dict)[0]]
        # keys: {"input", "output"}
        generated = data_dict[tuple(data_dict)[1]]
        vis_info = {"vmin": 1, "vmax": 40}
        if self.case_type == "normal":
            # Centre crop: 256 +/- 192 -> indices 64:448 on axes 1 and 2.
            window = slice(256 - 192, 256 + 192)
            gt_frames = ground_truth[0][:-2, window, window]
            pred_frames = generated[0][:-2, window, window]
        else:
            gt_frames = ground_truth[0][:-2]
            pred_frames = generated[0][:-2]
        gt_labels = [f"gt{i + 1}" for i in range(self.total_length)]
        pred_labels = [f"pd{i + 1}" for i in range(9, self.total_length)]
        for frames, labels in ((gt_frames, gt_labels), (pred_frames, pred_labels)):
            save_plots(
                frames,
                labels=labels,
                res_path=path,
                vmin=vis_info["vmin"],
                vmax=vis_info["vmax"],
            )

save_vtu_from_dict(filename, data_dict, coord_keys, value_keys, num_timestamps=1)

Save dict data to '*.vtu' file.

Parameters:

Name Type Description Default
filename str

Output filename.

required
data_dict Dict[str, ndarray]

Data in dict.

required
coord_keys Tuple[str, ...]

Tuple of coordinate keys, such as ("x", "y").

required
value_keys Tuple[str, ...]

Tuple of value keys, such as ("u", "v").

required
num_timestamps int

Number of timestamps in data_dict. Defaults to 1.

1
Source code in ppsci/visualize/vtu.py
def save_vtu_from_dict(
    filename: str,
    data_dict: Dict[str, np.ndarray],
    coord_keys: Tuple[str, ...],
    value_keys: Tuple[str, ...],
    num_timestamps: int = 1,
):
    """Save dict data to '*.vtu' file.

    Args:
        filename (str): Output filename.
        data_dict (Dict[str, np.ndarray]): Data in dict.
        coord_keys (Tuple[str, ...]): Tuple of coordinate keys, such as ("x", "y").
        value_keys (Tuple[str, ...]): Tuple of value keys, such as ("u", "v").
        num_timestamps (int, optional): Number of timestamps in data_dict. Defaults to 1.

    Raises:
        ValueError: If ``coord_keys`` does not contain 2, 3 or 4 keys.
    """
    key_count = len(coord_keys)
    if key_count not in (2, 3, 4):
        raise ValueError(f"ndim of coord ({key_count}) should be 2, 3 or 4")

    # "t" and "sdf" count toward the check above but are left out of the
    # coordinate array.
    coord_parts = [data_dict[key] for key in coord_keys if key not in ("t", "sdf")]
    value_parts = [data_dict[key] for key in value_keys] if value_keys else None

    points = np.concatenate(coord_parts, axis=1)
    values = None if value_parts is None else np.concatenate(value_parts, axis=1)

    _save_vtu_from_array(filename, points, values, value_keys, num_timestamps)

save_vtu_to_mesh(filename, data_dict, coord_keys, value_keys)

Save data into .vtu format by meshio.

Parameters:

Name Type Description Default
filename str

File name.

required
data_dict Dict[str, ndarray]

Data in dict.

required
coord_keys Tuple[str, ...]

Tuple of coord key. such as ("x", "y").

required
value_keys Tuple[str, ...]

Tuple of value key. such as ("u", "v").

required
Source code in ppsci/visualize/vtu.py
def save_vtu_to_mesh(
    filename: str,
    data_dict: Dict[str, np.ndarray],
    coord_keys: Tuple[str, ...],
    value_keys: Tuple[str, ...],
):
    """Save data into .vtu format by meshio.

    Every point becomes a "vertex" cell. The arrays named by ``value_keys``
    are attached as point data. The parent directory of ``filename`` is
    created if needed.

    Args:
        filename (str): File name.
        data_dict (Dict[str, np.ndarray]): Data in dict.
        coord_keys (Tuple[str, ...]): Tuple of coordinate keys, such as ("x", "y").
        value_keys (Tuple[str, ...]): Tuple of value keys, such as ("u", "v").
    """
    # Point count is taken from the first entry of data_dict.
    npoint = len(next(iter(data_dict.values())))
    coord_ndim = len(coord_keys)

    # np.stack needs a sequence: passing a generator was deprecated in
    # NumPy 1.16 and raises TypeError in current releases, so build a list.
    # The result is (coord_ndim, npoint); it is transposed for meshio below.
    points = np.stack([data_dict[key] for key in coord_keys]).reshape(
        coord_ndim, npoint
    )
    mesh = meshio.Mesh(
        points=points.T, cells=[("vertex", np.arange(npoint).reshape(npoint, 1))]
    )
    mesh.point_data = {key: data_dict[key] for key in value_keys}
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when filename actually has a parent component.
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    mesh.write(filename)

save_plot_from_1d_dict(filename, data_dict, coord_keys, value_keys, num_timestamps=1)

Plot dict data as file.

Parameters:

Name Type Description Default
filename str

Output filename.

required
data_dict Dict[str, Union[ndarray, Tensor]]

Data in dict.

required
coord_keys Tuple[str, ...]

Tuple of coord key. such as ("x", "y").

required
value_keys Tuple[str, ...]

Tuple of value key. such as ("u", "v").

required
num_timestamps int

Number of timestamps in data_dict. Defaults to 1.

1
Source code in ppsci/visualize/plot.py
def save_plot_from_1d_dict(
    filename, data_dict, coord_keys, value_keys, num_timestamps=1
):
    """Plot dict data as file.

    Args:
        filename (str): Output filename.
        data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
        coord_keys (Tuple[str, ...]): Tuple of coordinate keys, such as ("x", "y").
        value_keys (Tuple[str, ...]): Tuple of value keys, such as ("u", "v").
        num_timestamps (int, optional): Number of timestamps in data_dict. Defaults to 1.

    Raises:
        ValueError: If the number of non-"t" coordinate keys is not 1, 2 or 3.
    """
    n_space = len(coord_keys) - int("t" in coord_keys)
    if n_space not in [1, 2, 3]:
        raise ValueError(f"ndim of space coord ({n_space}) should be 1, 2 or 3")

    coord_parts = [data_dict[key] for key in coord_keys if key != "t"]
    value_parts = [data_dict[key] for key in value_keys] if value_keys else None

    def _as_arrays(parts):
        # Whether to convert is decided by the first element only.
        if isinstance(parts[0], paddle.Tensor):
            return [part.numpy() for part in parts]
        return list(parts)

    coord = np.concatenate(_as_arrays(coord_parts), axis=1)
    value = (
        None if value_parts is None else np.concatenate(_as_arrays(value_parts), axis=1)
    )

    _save_plot_from_1d_array(filename, coord, value, value_keys, num_timestamps)

save_plot_from_3d_dict(filename, data_dict, visu_keys, num_timestamps=1)

Plot dict data as file.

Parameters:

Name Type Description Default
filename str

Output filename.

required
data_dict Dict[str, Union[ndarray, Tensor]]

Data in dict.

required
visu_keys Tuple[str, ...]

Keys for visualizing data. such as ("u", "v").

required
num_timestamps int

Number of timestamps in data_dict. Defaults to 1.

1
Source code in ppsci/visualize/plot.py
def save_plot_from_3d_dict(
    filename: str,
    data_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
    visu_keys: Tuple[str, ...],
    num_timestamps: int = 1,
):
    """Plot dict data as file.

    Entries named by ``visu_keys`` are gathered in order. If the first one is
    a paddle Tensor, all of them are converted to numpy first.

    Args:
        filename (str): Output filename.
        data_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]): Data in dict.
        visu_keys (Tuple[str, ...]): Keys for visualizing data, such as ("u", "v").
        num_timestamps (int, optional): Number of timestamps in data_dict. Defaults to 1.
    """
    gathered = [data_dict[key] for key in visu_keys]
    to_plot = (
        [item.numpy() for item in gathered]
        if isinstance(gathered[0], paddle.Tensor)
        else gathered
    )
    _save_plot_from_3d_array(filename, to_plot, visu_keys, num_timestamps)

最后更新: November 17, 2023
创建日期: November 6, 2023