18 changes: 6 additions & 12 deletions examples/link_prediction/graph_store/storage_main.py
@@ -70,16 +70,14 @@
 from distutils.util import strtobool
 from typing import Literal, Optional

-# TODO(kmonte): Remove GLT imports from this file.
-import graphlearn_torch as glt
 import torch

 from gigl.common import Uri, UriFactory
 from gigl.common.logger import Logger
 from gigl.distributed.dataset_factory import build_dataset
 from gigl.distributed.dist_dataset import DistDataset
 from gigl.distributed.dist_range_partitioner import DistRangePartitioner
-from gigl.distributed.graph_store.storage_utils import register_dataset
+from gigl.distributed.dist_server import init_server, wait_and_shutdown_server
 from gigl.distributed.utils import get_free_ports_from_master_node, get_graph_store_info
 from gigl.distributed.utils.networking import get_free_ports_from_master_node
 from gigl.distributed.utils.serialized_graph_metadata_translator import (
@@ -103,12 +101,11 @@ def _run_storage_process(

     This function does the following:

-    1. "Registers" the dataset so that gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset can access it.
-    2. Initialized the GLT server.
+    1. Initializes the GiGL DistServer with the dataset.
        Under the hood this is synchronized with the clients initializing via gigl.distributed.graph_store.compute.init_compute_process,
        and after this call there will be Torch RPC connections between the storage nodes and compute nodes.
-    3. Initializes the Torch Distributed process group for the storage node.
-    4. Waits for the server to exit.
+    2. Initializes the Torch Distributed process group for the storage node.
+    3. Waits for the server to exit.
        Will wait until clients are also shutdown (with `gigl.distributed.graph_store.compute.shutdown_compute_proccess`)

     Args:
Expand All @@ -118,16 +115,13 @@ def _run_storage_process(
torch_process_port (int): The port for the Torch process.
storage_world_backend (Optional[str]): The backend for the storage Torch Distributed process group.
"""

# "Register" the dataset so that gigl.distributed.graph_store.remote_dist_dataset.RemoteDistDataset can access it.
register_dataset(dataset)
cluster_master_ip = cluster_info.storage_cluster_master_ip
logger.info(
f"Initializing GLT server for storage node process group {storage_rank} / {cluster_info.num_storage_nodes} on {cluster_master_ip}:{cluster_info.rpc_master_port}"
)
# Initialize the GLT server before starting the Torch Distributed process group.
# Otherwise, we saw intermittent hangs when initializing the server.
glt.distributed.init_server(
init_server(
num_servers=cluster_info.num_storage_nodes,
server_rank=storage_rank,
dataset=dataset,
@@ -156,7 +150,7 @@ def _run_storage_process(
     )
     # Wait for the server to exit.
     # Will wait until clients are also shutdown (with `gigl.distributed.graph_store.compute.shutdown_compute_proccess`)
-    glt.distributed.wait_and_shutdown_server()
+    wait_and_shutdown_server()
     logger.info(f"Storage node {storage_rank} exited")
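For orientation, a minimal sketch of the storage-node lifecycle after this change, pieced together from the hunks above. The `num_servers`, `server_rank`, and `dataset` keywords are visible in the diff; `master_addr` and `master_port` are assumed names for the RPC rendezvous arguments, and the process-group setup is elided:

```python
# Sketch only: argument names not shown in the diff are assumptions.
from gigl.distributed.dist_dataset import DistDataset
from gigl.distributed.dist_server import init_server, wait_and_shutdown_server


def run_storage_node(
    dataset: DistDataset,
    storage_rank: int,
    num_storage_nodes: int,
    master_ip: str,
    rpc_master_port: int,
) -> None:
    # 1. Start the GiGL DistServer. Per the docstring above, this is
    #    synchronized with clients calling init_compute_process; afterwards
    #    Torch RPC connections exist between storage and compute nodes.
    init_server(
        num_servers=num_storage_nodes,
        server_rank=storage_rank,
        dataset=dataset,
        master_addr=master_ip,        # assumed keyword
        master_port=rpc_master_port,  # assumed keyword
    )
    # 2. storage_main.py next joins a Torch Distributed process group for
    #    the storage nodes (elided here).
    # 3. Block until all clients have shut down, then tear the server down.
    wait_and_shutdown_server()
```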
4 changes: 2 additions & 2 deletions gigl/distributed/dist_ablp_neighborloader.py
@@ -20,7 +20,7 @@
 from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT
 from gigl.distributed.dist_context import DistributedContext
 from gigl.distributed.dist_dataset import DistDataset
-from gigl.distributed.dist_sampling_producer import DistSamplingProducer
+from gigl.distributed.dist_sampling_producer import DistAblpSamplingProducer
 from gigl.distributed.distributed_neighborloader import DEFAULT_NUM_CPU_THREADS
 from gigl.distributed.sampler import (
     NEGATIVE_LABEL_METADATA_KEY,
@@ -520,7 +520,7 @@ def __init__(
         if self.worker_options.pin_memory:
             self._channel.pin_memory()

-        self._mp_producer = DistSamplingProducer(
+        self._mp_producer = DistAblpSamplingProducer(
             self.data,
             self.input_data,
             self.sampling_config,
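Design note: this hunk is purely a rename at the call site; the ABLP neighbor loader still hands the producer the same `(data, input_data, sampling_config, ...)` arguments, so behavior is unchanged. The class itself is renamed in the next file; see the stub sketch after that diff.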
2 changes: 1 addition & 1 deletion gigl/distributed/dist_sampling_producer.py
@@ -165,7 +165,7 @@ def _sampling_worker_loop(
         shutdown_rpc(graceful=False)


-class DistSamplingProducer(DistMpSamplingProducer):
+class DistAblpSamplingProducer(DistMpSamplingProducer):
     def init(self):
         r"""Create the subprocess pool. Init samplers and rpc server."""
         if self.sampling_config.seed is not None:
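To make the rename concrete, a self-contained stub of the relationship the diff shows: `DistAblpSamplingProducer` (formerly `DistSamplingProducer`) remains a subclass of `DistMpSamplingProducer` and keeps its `init` override. The bodies below are illustrative stand-ins, not GiGL's real implementations:

```python
# Illustrative stand-ins only; the real classes live in
# gigl.distributed.dist_sampling_producer.
class DistMpSamplingProducer:
    def init(self) -> None:
        # Real base class: create the subprocess pool, init samplers and RPC.
        print("base init: spawn sampler workers, start RPC")


class DistAblpSamplingProducer(DistMpSamplingProducer):  # was: DistSamplingProducer
    def init(self) -> None:
        # Real subclass (per the diff): applies sampling_config.seed when it
        # is set; shown here as a print for illustration.
        print("ablp init: apply sampling seed if configured")
        super().init()


DistAblpSamplingProducer().init()
```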