From 03fb44733f901598072bb8ca2f6b1b68e9483d72 Mon Sep 17 00:00:00 2001 From: ms_yan Date: Fri, 22 Jul 2022 14:36:31 +0800 Subject: [PATCH] fix graph api example --- .../mindspore/dataset/engine/graphdata.py | 45 ++++++++++--------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/mindspore/python/mindspore/dataset/engine/graphdata.py b/mindspore/python/mindspore/dataset/engine/graphdata.py index 50960e84be90..190690c39337 100644 --- a/mindspore/python/mindspore/dataset/engine/graphdata.py +++ b/mindspore/python/mindspore/dataset/engine/graphdata.py @@ -550,9 +550,9 @@ class Graph(GraphData): format should be dict, key is feature type, which is represented with string, value should be numpy.array, its shape is not restricted. node_type(Union[list, numpy.ndarray], optional): type of nodes, each element should be string which represent - type of corresponding node. If not provided, default type for each node is '0'. + type of corresponding node. If not provided, default type for each node is "0". edge_type(Union[list, numpy.ndarray], optional): type of edges, each element should be string which represent - type of corresponding edge. If not provided, default type for each edge is '0'. + type of corresponding edge. If not provided, default type for each edge is "0". num_parallel_workers (int, optional): Number of workers to process the dataset in parallel (default=None). working_mode (str, optional): Set working mode, now supports 'local'/'client'/'server' (default='local'). @@ -592,10 +592,13 @@ class Graph(GraphData): ValueError: If `num_client` is not in range [1, 255]. Examples: + >>> import numpy as np + >>> from mindspore.dataset import Graph + >>> >>> # 1) Only provide edges for creating graph, as this is the only required input parameter >>> edges = np.array([[1, 2], [0, 1]], dtype=np.int32) >>> graph = Graph(edges) - >>> graph_info = g.graph_info() + >>> graph_info = graph.graph_info() >>> >>> # 2) Setting node_feat and edge_feat for corresponding node and edge >>> # first dimension of feature shape should be corresponding node num or edge num. @@ -623,9 +626,9 @@ class Graph(GraphData): if node_feat != dict(): num_nodes = node_feat.get(list(node_feat.keys())[0]).shape[0] - node_type = replace_none(node_type, np.array(['0'] * num_nodes)) + node_type = replace_none(node_type, np.array(["0"] * num_nodes)) node_type = np.array(node_type) - edge_type = replace_none(edge_type, np.array(['0'] * edges.shape[1])) + edge_type = replace_none(edge_type, np.array(["0"] * edges.shape[1])) edge_type = np.array(edge_type) self._working_mode = working_mode @@ -693,14 +696,14 @@ class Graph(GraphData): Get all edges in the graph. Args: - edge_type (str): Specify the type of edge, default edge_type is '0' when init graph without specify + edge_type (str): Specify the type of edge, default edge_type is "0" when init graph without specify edge_type. Returns: numpy.ndarray, array of edges. Examples: - >>> edges = graph.get_all_edges(edge_type='0') + >>> edges = graph.get_all_edges(edge_type="0") Raises: TypeError: If `edge_type` is not string. @@ -819,11 +822,11 @@ class Graph(GraphData): Examples: >>> from mindspore.dataset.engine import OutputFormat - >>> nodes = graph.get_all_nodes(node_type=1) - >>> neighbors = graph.get_all_neighbors(node_list=nodes, neighbor_type='0') - >>> neighbors_coo = graph.get_all_neighbors(node_list=nodes, neighbor_type='0', + >>> nodes = graph.get_all_nodes(node_type="0") + >>> neighbors = graph.get_all_neighbors(node_list=nodes, neighbor_type="0") + >>> neighbors_coo = graph.get_all_neighbors(node_list=nodes, neighbor_type="0", ... output_format=OutputFormat.COO) - >>> offset_table, neighbors_csr = graph.get_all_neighbors(node_list=nodes, neighbor_type='0', + >>> offset_table, neighbors_csr = graph.get_all_neighbors(node_list=nodes, neighbor_type="0", ... output_format=OutputFormat.CSR) Raises: @@ -869,9 +872,9 @@ class Graph(GraphData): numpy.ndarray, array of neighbors. Examples: - >>> nodes = graph.get_all_nodes(node_type=1) + >>> nodes = graph.get_all_nodes(node_type="0") >>> neighbors = graph.get_sampled_neighbors(node_list=nodes, neighbor_nums=[2, 2], - ... neighbor_types=[2, 1]) + ... neighbor_types=["0", "0"]) Raises: TypeError: If `node_list` is not list or ndarray. @@ -906,9 +909,9 @@ class Graph(GraphData): numpy.ndarray, array of neighbors. Examples: - >>> nodes = graph.get_all_nodes(node_type=1) - >>> neg_neighbors = graph.get_neg_sampled_neighbors(node_list=nodes, neg_neighbor_num=5, - ... neg_neighbor_type='0') + >>> nodes = graph.get_all_nodes(node_type="0") + >>> neg_neighbors = graph.get_neg_sampled_neighbors(node_list=nodes, neg_neighbor_num=3, + ... neg_neighbor_type="0") Raises: TypeError: If `node_list` is not list or ndarray. @@ -937,8 +940,8 @@ class Graph(GraphData): numpy.ndarray, array of features. Examples: - >>> nodes = graph.get_all_nodes(node_type='0') - >>> features = graph.get_node_feature(node_list=nodes, feature_types=["feature_1", "feature_2"]) + >>> nodes = graph.get_all_nodes(node_type="0") + >>> features = graph.get_node_feature(node_list=nodes, feature_types=["node_feature_1"]) Raises: TypeError: If `node_list` is not list or ndarray. @@ -973,8 +976,8 @@ class Graph(GraphData): numpy.ndarray, array of features. Examples: - >>> edges = graph.get_all_edges(edge_type='0') - >>> features = graph.get_edge_feature(edge_list=edges, feature_types=["feature_1"]) + >>> edges = graph.get_all_edges(edge_type="0") + >>> features = graph.get_edge_feature(edge_list=edges, feature_types=["edge_feature_1"]) Raises: TypeError: If `edge_list` is not list or ndarray. @@ -1008,7 +1011,7 @@ class Graph(GraphData): numpy.ndarray, array of features. Examples: - >>> features = graph.get_graph_feature(feature_types=['feature_1', 'feature_2']) + >>> features = graph.get_graph_feature(feature_types=['graph_feature_1']) Raises: TypeError: If `feature_types` is not list or ndarray. -- Gitee