4 changes: 2 additions & 2 deletions pychunkedgraph/app/segmentation/common.py
@@ -791,8 +791,8 @@ def handle_subgraph(table_id, root_id, only_internal_edges=True):
supervoxels = np.concatenate(
[agg.supervoxels for agg in l2id_agglomeration_d.values()]
)
-mask0 = np.in1d(edges.node_ids1, supervoxels)
-mask1 = np.in1d(edges.node_ids2, supervoxels)
+mask0 = np.isin(edges.node_ids1, supervoxels)
+mask1 = np.isin(edges.node_ids2, supervoxels)
edges = edges[mask0 & mask1]

return edges
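Note: `np.in1d` is deprecated since NumPy 1.25 in favor of `np.isin`, which it now simply wraps, so for 1-D inputs the swap is behavior-preserving. A quick illustration of the swap and of the `mask0 & mask1` idiom used in this hunk (keep only edges whose both endpoints lie in the supervoxel set); the values below are illustrative, not from this PR:

```python
import numpy as np

supervoxels = np.array([10, 11, 12], dtype=np.uint64)
node_ids1 = np.array([10, 11, 99], dtype=np.uint64)  # edge endpoints, column 1
node_ids2 = np.array([11, 98, 12], dtype=np.uint64)  # edge endpoints, column 2

# np.isin replaces the deprecated np.in1d; identical semantics for 1-D input
mask0 = np.isin(node_ids1, supervoxels)
mask1 = np.isin(node_ids2, supervoxels)
print(mask0 & mask1)  # [ True False False] -> only the first edge is fully internal
```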
7 changes: 4 additions & 3 deletions pychunkedgraph/graph/chunkedgraph.py
@@ -1,5 +1,4 @@
# pylint: disable=invalid-name, missing-docstring, too-many-lines, import-outside-toplevel, unsupported-binary-operation
-
import time
import typing
import datetime
@@ -734,8 +733,8 @@ def get_l2_agglomerations(
else:
all_chunk_edges = all_chunk_edges.get_pairs()
supervoxels = self.get_children(level2_ids, flatten=True)
-mask0 = np.in1d(all_chunk_edges[:, 0], supervoxels)
-mask1 = np.in1d(all_chunk_edges[:, 1], supervoxels)
+mask0 = np.isin(all_chunk_edges[:, 0], supervoxels)
+mask1 = np.isin(all_chunk_edges[:, 1], supervoxels)
return all_chunk_edges[mask0 & mask1]

l2id_children_d = self.get_children(level2_ids)
@@ -807,6 +806,7 @@ def add_edges(
source_coords: typing.Sequence[int] = None,
sink_coords: typing.Sequence[int] = None,
allow_same_segment_merge: typing.Optional[bool] = False,
+stitch_mode: typing.Optional[bool] = False,
) -> operation.GraphEditOperation.Result:
"""
Adds an edge to the chunkedgraph
@@ -823,6 +823,7 @@ def add_edges(
source_coords=source_coords,
sink_coords=sink_coords,
allow_same_segment_merge=allow_same_segment_merge,
+stitch_mode=stitch_mode,
).execute()

def remove_edges(
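The public `add_edges` method just threads the new keyword through to the operation object it constructs. A minimal sketch of that forwarding pattern; the class and function names here are stand-ins, not the real pychunkedgraph signatures:

```python
# toy version of the pattern this hunk extends
class _Op:
    def __init__(self, stitch_mode: bool = False):
        self.stitch_mode = stitch_mode

    def execute(self) -> str:
        return f"executed(stitch_mode={self.stitch_mode})"

def add_edges(stitch_mode: bool = False) -> str:
    # the public method only forwards the flag; the operation decides behavior
    return _Op(stitch_mode=stitch_mode).execute()

print(add_edges(stitch_mode=True))  # executed(stitch_mode=True)
```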
4 changes: 2 additions & 2 deletions pychunkedgraph/graph/chunks/utils.py
@@ -99,7 +99,7 @@ def get_chunk_coordinates_multiple(meta, ids: np.ndarray) -> np.ndarray:
y_offset = x_offset - bits_per_dim
z_offset = y_offset - bits_per_dim

-ids = np.array(ids, dtype=int, copy=False)
+ids = np.asarray(ids, dtype=int)
X = ids >> x_offset & 2**bits_per_dim - 1
Y = ids >> y_offset & 2**bits_per_dim - 1
Z = ids >> z_offset & 2**bits_per_dim - 1
@@ -154,7 +154,7 @@ def get_chunk_ids_from_node_ids(meta, ids: Iterable[np.uint64]) -> np.ndarray:
bits_per_dims = np.array([meta.bitmasks[l] for l in get_chunk_layers(meta, ids)])
offsets = 64 - meta.graph_config.LAYER_ID_BITS - 3 * bits_per_dims

-ids = np.array(ids, dtype=int, copy=False)
+ids = np.asarray(ids, dtype=int)
cids1 = np.array((ids >> offsets) << offsets, dtype=np.uint64)
# cids2 = np.vectorize(get_chunk_id)(meta, ids)
# assert np.all(cids1 == cids2)
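Context for this change: in NumPy 2.0, `np.array(x, copy=False)` raises a `ValueError` whenever a copy cannot be avoided (for example, a dtype conversion), instead of silently copying as it did before; `np.asarray(x, dtype=...)` keeps the old copy-only-if-needed behavior. A small demonstration:

```python
import numpy as np

ids = np.array([1, 2, 3], dtype=np.uint64)

# NumPy 2.x: raises ValueError, because the uint64 -> int64 cast needs a copy
try:
    np.array(ids, dtype=int, copy=False)
except ValueError as e:
    print(e)

# np.asarray copies only when needed (here the cast forces one copy),
# and returns the input unchanged when dtype already matches
out = np.asarray(ids, dtype=int)
print(out.dtype)  # int64
```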
48 changes: 24 additions & 24 deletions pychunkedgraph/graph/cutting.py
@@ -62,7 +62,7 @@ def merge_cross_chunk_edges_graph_tool(
if len(mapping) > 0:
mapping = np.concatenate(mapping)
u_nodes = np.unique(edges)
-u_unmapped_nodes = u_nodes[~np.in1d(u_nodes, mapping)]
+u_unmapped_nodes = u_nodes[~np.isin(u_nodes, mapping)]
unmapped_mapping = np.concatenate(
[u_unmapped_nodes.reshape(-1, 1), u_unmapped_nodes.reshape(-1, 1)], axis=1
)
@@ -189,9 +189,9 @@ def _build_gt_graph(self, edges, affs):
) = flatgraph.build_gt_graph(comb_edges, comb_affs, make_directed=True)

self.source_graph_ids = np.where(
-np.in1d(self.unique_supervoxel_ids, self.sources)
+np.isin(self.unique_supervoxel_ids, self.sources)
)[0]
-self.sink_graph_ids = np.where(np.in1d(self.unique_supervoxel_ids, self.sinks))[
+self.sink_graph_ids = np.where(np.isin(self.unique_supervoxel_ids, self.sinks))[
0
]

@@ -395,10 +395,10 @@ def _remap_cut_edge_set(self, cut_edge_set):

remapped_cutset = np.array(remapped_cutset, dtype=np.uint64)

-remapped_cutset_flattened_view = remapped_cutset.view(dtype="u8,u8")
-edges_flattened_view = self.cg_edges.view(dtype="u8,u8")
+remapped_cutset_flattened_view = remapped_cutset.view(dtype="u8,u8").ravel()
+edges_flattened_view = self.cg_edges.view(dtype="u8,u8").ravel()

-cutset_mask = np.in1d(remapped_cutset_flattened_view, edges_flattened_view)
+cutset_mask = np.isin(remapped_cutset_flattened_view, edges_flattened_view)

return remapped_cutset[cutset_mask]
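The added `.ravel()` is tied to the `np.in1d` -> `np.isin` migration: viewing an `(N, 2)` uint64 edge array as the compound dtype `"u8,u8"` yields shape `(N, 1)`, and while `in1d` always flattened its arguments, `isin` preserves input shape, so without raveling the mask would come back `(N, 1)` and break boolean row indexing. Illustration with made-up edges:

```python
import numpy as np

edges = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.uint64)
view = edges.view(dtype="u8,u8")       # one (u8, u8) struct per row
print(view.shape)                      # (3, 1), not (3,)

# ravel both sides so the mask is 1-D, matching the old in1d behavior
mask = np.isin(view.ravel(), view.ravel()[:2])
print(mask.shape)                      # (3,)
print(edges[mask])                     # first two rows selected as intended
```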

@@ -432,8 +432,8 @@ def _get_split_preview_connected_components(self, cut_edge_set):
max_sinks = 0
i = 0
for cc in ccs_test_post_cut:
-num_sources = np.count_nonzero(np.in1d(self.source_graph_ids, cc))
-num_sinks = np.count_nonzero(np.in1d(self.sink_graph_ids, cc))
+num_sources = np.count_nonzero(np.isin(self.source_graph_ids, cc))
+num_sinks = np.count_nonzero(np.isin(self.sink_graph_ids, cc))
if num_sources > max_sources:
max_sources = num_sources
max_source_index = i
@@ -486,8 +486,8 @@ def _filter_graph_connected_components(self):
# If connected component contains no sources or no sinks,
# remove its nodes from the mincut computation
if not (
-np.any(np.in1d(self.source_graph_ids, cc))
-and np.any(np.in1d(self.sink_graph_ids, cc))
+np.any(np.isin(self.source_graph_ids, cc))
+and np.any(np.isin(self.sink_graph_ids, cc))
):
for node_id in cc:
removed[node_id] = True
@@ -525,13 +525,13 @@ def _gt_mincut_sanity_check(self, partition):
np.array(np.where(partition.a == i_cc)[0], dtype=int)
]

-if np.any(np.in1d(self.sources, cc_list)):
-assert np.all(np.in1d(self.sources, cc_list))
-assert ~np.any(np.in1d(self.sinks, cc_list))
+if np.any(np.isin(self.sources, cc_list)):
+assert np.all(np.isin(self.sources, cc_list))
+assert ~np.any(np.isin(self.sinks, cc_list))

-if np.any(np.in1d(self.sinks, cc_list)):
-assert np.all(np.in1d(self.sinks, cc_list))
-assert ~np.any(np.in1d(self.sources, cc_list))
+if np.any(np.isin(self.sinks, cc_list)):
+assert np.all(np.isin(self.sinks, cc_list))
+assert ~np.any(np.isin(self.sources, cc_list))

def _sink_and_source_connectivity_sanity_check(self, cut_edge_set):
"""
@@ -555,19 +555,19 @@ def _sink_and_source_connectivity_sanity_check(self, cut_edge_set):
illegal_split = False
try:
for cc in ccs_test_post_cut:
-if np.any(np.in1d(self.source_graph_ids, cc)):
-assert np.all(np.in1d(self.source_graph_ids, cc))
-assert ~np.any(np.in1d(self.sink_graph_ids, cc))
+if np.any(np.isin(self.source_graph_ids, cc)):
+assert np.all(np.isin(self.source_graph_ids, cc))
+assert ~np.any(np.isin(self.sink_graph_ids, cc))
if (
len(self.source_path_vertices) == len(cc)
and self.disallow_isolating_cut
):
if not self.partition_edges_within_label(cc):
raise IsolatingCutException("Source")

-if np.any(np.in1d(self.sink_graph_ids, cc)):
-assert np.all(np.in1d(self.sink_graph_ids, cc))
-assert ~np.any(np.in1d(self.source_graph_ids, cc))
+if np.any(np.isin(self.sink_graph_ids, cc)):
+assert np.all(np.isin(self.sink_graph_ids, cc))
+assert ~np.any(np.isin(self.source_graph_ids, cc))
if (
len(self.sink_path_vertices) == len(cc)
and self.disallow_isolating_cut
@@ -664,8 +664,8 @@ def run_split_preview(
supervoxels = np.concatenate(
[agg.supervoxels for agg in l2id_agglomeration_d.values()]
)
-mask0 = np.in1d(edges.node_ids1, supervoxels)
-mask1 = np.in1d(edges.node_ids2, supervoxels)
+mask0 = np.isin(edges.node_ids1, supervoxels)
+mask1 = np.isin(edges.node_ids2, supervoxels)
edges = edges[mask0 & mask1]
edges_to_remove, illegal_split = run_multicut(
edges,
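The assertion blocks converted in this file all encode the same mincut invariant: after the cut, a connected component must contain either all of the sources and none of the sinks, or all of the sinks and none of the sources. Distilled into a standalone check, as a sketch of the invariant only, not of the surrounding graph machinery:

```python
import numpy as np

def check_component(cc: np.ndarray, sources: np.ndarray, sinks: np.ndarray) -> None:
    """Raise if a post-cut component mixes or splits the seed sets."""
    has_sources = np.any(np.isin(sources, cc))
    has_sinks = np.any(np.isin(sinks, cc))
    if has_sources:
        assert np.all(np.isin(sources, cc)), "sources were split apart"
        assert not has_sinks, "component still connects sources to sinks"
    if has_sinks:
        assert np.all(np.isin(sinks, cc)), "sinks were split apart"

# passes: component holds both sources and no sinks
check_component(np.array([1, 2, 3]), sources=np.array([1, 2]), sinks=np.array([9]))
```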
4 changes: 3 additions & 1 deletion pychunkedgraph/graph/edges/utils.py
@@ -69,7 +69,9 @@ def merge_cross_edge_dicts(x_edges_d1: Dict, x_edges_d2: Dict) -> Dict:
Combines two cross chunk dictionaries of form
{node_id: {layer id : edge list}}.
"""
-node_ids = np.unique(list(x_edges_d1.keys()) + list(x_edges_d2.keys()))
+node_ids = np.unique(
+    np.array(list(x_edges_d1.keys()) + list(x_edges_d2.keys()), dtype=basetypes.NODE_ID)
+)
result_d = {}
for node_id in node_ids:
cross_edge_ds = [x_edges_d1.get(node_id, {}), x_edges_d2.get(node_id, {})]
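The explicit `dtype=basetypes.NODE_ID` cast guards against NumPy's promotion rules: a key list mixing signed and unsigned 64-bit integers is promoted to `float64`, which silently corrupts node IDs above 2**53. The same concern presumably motivates the `astype(basetypes.NODE_ID)` casts added in operation.py below. Demonstration of the failure mode:

```python
import numpy as np

keys = [np.uint64(2**63 + 3), np.int64(5)]   # uint64 and int64 keys mixed

loose = np.unique(keys)                       # uint64 + int64 promotes to float64
exact = np.unique(np.array(keys, dtype=np.uint64))
print(loose.dtype, exact.dtype)               # float64 uint64
print(int(loose.max()) == 2**63 + 3)          # False: precision already lost
print(int(exact.max()) == 2**63 + 3)          # True: IDs preserved exactly
```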
6 changes: 3 additions & 3 deletions pychunkedgraph/graph/locks.py
@@ -49,12 +49,12 @@ def __init__(
self.privileged_mode = privileged_mode

def __enter__(self):
-if not self.operation_id:
-self.operation_id = self.cg.id_client.create_operation_id()
-
if self.privileged_mode:
assert self.operation_id is not None, "Please provide operation ID."
warn("Warning: Privileged mode without acquiring lock.")
return self
+if not self.operation_id:
+self.operation_id = self.cg.id_client.create_operation_id()
+
nodes_ts = self.cg.get_node_timestamps(self.root_ids, return_numpy=0)
min_ts = min(nodes_ts)
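This reordering changes behavior, not just style: previously a missing operation ID was auto-created before the privileged-mode check, so `assert self.operation_id is not None` could never fire; now privileged callers must supply their own ID, and auto-creation only happens on the locking path. A toy reimplementation of just the `__enter__` ordering (the real class is `RootLock` in this file; everything else is stubbed):

```python
from warnings import warn

class _LockSketch:
    """Toy version showing only the __enter__ ordering from this hunk."""
    def __init__(self, operation_id=None, privileged_mode=False):
        self.operation_id = operation_id
        self.privileged_mode = privileged_mode

    def __enter__(self):
        if self.privileged_mode:
            # now reachable: fires if the caller supplied no ID
            assert self.operation_id is not None, "Please provide operation ID."
            warn("Warning: Privileged mode without acquiring lock.")
            return self
        if self.operation_id is None:
            self.operation_id = "auto-created"  # stand-in for the id_client call
        return self

    def __exit__(self, *exc):
        return False

with _LockSketch() as lk:
    print(lk.operation_id)  # auto-created
```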
4 changes: 2 additions & 2 deletions pychunkedgraph/graph/misc.py
@@ -142,7 +142,7 @@ def get_contact_sites(
)

# Build area lookup dictionary
-cs_svs = edges[~np.in1d(edges, sv_ids).reshape(-1, 2)]
+cs_svs = edges[~np.isin(edges, sv_ids).reshape(-1, 2)]
area_dict = collections.defaultdict(int)

for area, sv_id in zip(areas, cs_svs):
@@ -165,7 +165,7 @@ def get_contact_sites(
cs_dict = collections.defaultdict(list)
for cc in ccs:
cc_sv_ids = unique_ids[cc]
-cc_sv_ids = cc_sv_ids[np.in1d(cc_sv_ids, u_cs_svs)]
+cc_sv_ids = cc_sv_ids[np.isin(cc_sv_ids, u_cs_svs)]
cs_areas = area_dict_vec(cc_sv_ids)
partner_root_id = (
int(cg.get_root(cc_sv_ids[0], time_stamp=time_stamp))
91 changes: 58 additions & 33 deletions pychunkedgraph/graph/operation.py
Expand Up @@ -25,6 +25,7 @@
from . import attributes
from .edges import Edges
from .edges.utils import get_edges_status
+from .edits import get_profiler
from .utils import basetypes
from .utils import serializers
from .cache import CacheService
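`get_profiler` itself is not part of this diff, so its implementation is an assumption; the call sites below only require an object whose `profile(label)` method returns a context manager. A minimal stand-in with that interface:

```python
import time
from contextlib import contextmanager

class _ProfilerSketch:
    """Assumed interface only: the real get_profiler() lives in .edits."""
    @contextmanager
    def profile(self, label: str):
        start = time.perf_counter()
        try:
            yield
        finally:
            print(f"{label}: {time.perf_counter() - start:.4f}s")

profiler = _ProfilerSketch()
with profiler.profile("merge_apply_get_roots"):
    sum(range(1_000_000))  # stand-in for the profiled work
```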
@@ -419,6 +420,7 @@ def execute(
op_type = "merge" if is_merge else "split"
self.parent_ts = parent_ts
root_ids = self._update_root_ids()
+self.privileged_mode = self.privileged_mode or (is_merge and self.stitch_mode)
with locks.RootLock(
self.cg,
root_ids,
@@ -571,6 +573,7 @@ class MergeOperation(GraphEditOperation):
"affinities",
"bbox_offset",
"allow_same_segment_merge",
"stitch_mode",
]

def __init__(
@@ -584,13 +587,15 @@ def __init__(
bbox_offset: Tuple[int, int, int] = (240, 240, 24),
affinities: Optional[Sequence[np.float32]] = None,
allow_same_segment_merge: Optional[bool] = False,
+stitch_mode: bool = False,
) -> None:
super().__init__(
cg, user_id=user_id, source_coords=source_coords, sink_coords=sink_coords
)
self.added_edges = np.atleast_2d(added_edges).astype(basetypes.NODE_ID)
self.bbox_offset = np.atleast_1d(bbox_offset).astype(basetypes.COORDINATES)
self.allow_same_segment_merge = allow_same_segment_merge
+self.stitch_mode = stitch_mode

self.affinities = None
if affinities is not None:
@@ -615,40 +620,55 @@ def _apply(
def _apply(
self, *, operation_id, timestamp
) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]:
-        root_ids = set(
-            self.cg.get_roots(
-                self.added_edges.ravel(), assert_roots=True, time_stamp=self.parent_ts
-            )
-        )
+        profiler = get_profiler()
+
+        with profiler.profile("merge_apply_get_roots"):
+            root_ids = set(
+                self.cg.get_roots(
+                    self.added_edges.ravel(), assert_roots=True, time_stamp=self.parent_ts
+                )
+            )
         if len(root_ids) < 2 and not self.allow_same_segment_merge:
-            raise PreconditionError("Supervoxels must belong to different objects.")
-        bbox = get_bbox(self.source_coords, self.sink_coords, self.bbox_offset)
-        with TimeIt("subgraph", self.cg.graph_id, operation_id):
-            edges = self.cg.get_subgraph(
-                root_ids,
-                bbox=bbox,
-                bbox_is_coordinate=True,
-                edges_only=True,
-            )
-
-        if self.allow_same_segment_merge:
-            inactive_edges = types.empty_2d
-        else:
-            with TimeIt("preprocess", self.cg.graph_id, operation_id):
-                inactive_edges = edits.merge_preprocess(
-                    self.cg,
-                    subgraph_edges=edges,
-                    supervoxels=self.added_edges.ravel(),
-                    parent_ts=self.parent_ts,
-                )
-
-        atomic_edges, fake_edge_rows = edits.check_fake_edges(
-            self.cg,
-            atomic_edges=self.added_edges,
-            inactive_edges=inactive_edges,
-            time_stamp=timestamp,
-            parent_ts=self.parent_ts,
-        )
+            raise PreconditionError(
+                "Supervoxels must belong to different objects."
+                f" Tried to merge {self.added_edges.ravel()},"
+                f" which all belong to {tuple(root_ids)[0]}."
+            )
+
+        atomic_edges = self.added_edges
+        fake_edge_rows = []
+        if not self.stitch_mode:
+            bbox = get_bbox(self.source_coords, self.sink_coords, self.bbox_offset)
+            with profiler.profile("get_subgraph"):
+                with TimeIt("subgraph", self.cg.graph_id, operation_id):
+                    edges = self.cg.get_subgraph(
+                        root_ids,
+                        bbox=bbox,
+                        bbox_is_coordinate=True,
+                        edges_only=True,
+                    )
+
+            if self.allow_same_segment_merge:
+                inactive_edges = types.empty_2d
+            else:
+                with profiler.profile("merge_preprocess"):
+                    with TimeIt("preprocess", self.cg.graph_id, operation_id):
+                        inactive_edges = edits.merge_preprocess(
+                            self.cg,
+                            subgraph_edges=edges,
+                            supervoxels=self.added_edges.ravel(),
+                            parent_ts=self.parent_ts,
+                        )
+
+            with profiler.profile("check_fake_edges"):
+                atomic_edges, fake_edge_rows = edits.check_fake_edges(
+                    self.cg,
+                    atomic_edges=self.added_edges,
+                    inactive_edges=inactive_edges,
+                    time_stamp=timestamp,
+                    parent_ts=self.parent_ts,
+                )
with TimeIt("add_edges", self.cg.graph_id, operation_id):
new_roots, new_l2_ids, new_entries = edits.add_edges(
self.cg,
@@ -657,6 +677,7 @@ def _apply(
time_stamp=timestamp,
parent_ts=self.parent_ts,
allow_same_segment_merge=self.allow_same_segment_merge,
+stitch_mode=self.stitch_mode,
)
return new_roots, new_l2_ids, fake_edge_rows + new_entries

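Net effect of the `_apply` rework: in stitch mode the merge skips the bounding-box subgraph read, `merge_preprocess`, and the fake-edge check, passing the caller's edges straight through to `edits.add_edges`; combined with the `execute()` change above, a stitch merge also runs in privileged mode, i.e. without taking root locks. The gate, reduced to its skeleton, with the expensive ChunkedGraph path stubbed out:

```python
import numpy as np

def choose_edges(added_edges: np.ndarray, stitch_mode: bool):
    """Skeleton of the branch above; not the real pychunkedgraph code."""
    atomic_edges = added_edges          # stitch mode trusts the input as-is
    fake_edge_rows: list = []
    if not stitch_mode:
        # normal path (stubbed): subgraph fetch, merge_preprocess, then
        # check_fake_edges may rewrite both values
        raise NotImplementedError("requires a live ChunkedGraph")
    return atomic_edges, fake_edge_rows

edges = np.array([[101, 202]], dtype=np.uint64)
print(choose_edges(edges, stitch_mode=True))  # (array([[101, 202]], ...), [])
```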
@@ -867,18 +888,20 @@ def __init__(
self.bbox_offset = np.atleast_1d(bbox_offset).astype(basetypes.COORDINATES)
self.path_augment = path_augment
self.disallow_isolating_cut = disallow_isolating_cut
-if np.any(np.in1d(self.sink_ids, self.source_ids)):
+if np.any(np.isin(self.sink_ids, self.source_ids)):
raise PreconditionError(
"Supervoxels exist in both sink and source, "
"try placing the points further apart."
)

-ids = np.concatenate([self.source_ids, self.sink_ids])
+ids = np.concatenate([self.source_ids, self.sink_ids]).astype(basetypes.NODE_ID)
layers = self.cg.get_chunk_layers(ids)
assert np.sum(layers) == layers.size, "IDs must be supervoxels."

def _update_root_ids(self) -> np.ndarray:
-sink_and_source_ids = np.concatenate((self.source_ids, self.sink_ids))
+sink_and_source_ids = np.concatenate((self.source_ids, self.sink_ids)).astype(
+    basetypes.NODE_ID
+)
root_ids = np.unique(
self.cg.get_roots(
sink_and_source_ids, assert_roots=True, time_stamp=self.parent_ts
@@ -894,7 +917,9 @@ def _apply(
# Verify that sink and source are from the same root object
root_ids = set(
self.cg.get_roots(
-np.concatenate([self.source_ids, self.sink_ids]),
+np.concatenate([self.source_ids, self.sink_ids]).astype(
+    basetypes.NODE_ID
+),
assert_roots=True,
time_stamp=self.parent_ts,
)
@@ -915,9 +940,9 @@
edges = reduce(lambda x, y: x + y, edges_tuple, Edges([], []))
supervoxels = np.concatenate(
[agg.supervoxels for agg in l2id_agglomeration_d.values()]
-)
-mask0 = np.in1d(edges.node_ids1, supervoxels)
-mask1 = np.in1d(edges.node_ids2, supervoxels)
+).astype(basetypes.NODE_ID)
+mask0 = np.isin(edges.node_ids1, supervoxels)
+mask1 = np.isin(edges.node_ids2, supervoxels)
edges = edges[mask0 & mask1]
if len(edges) == 0:
raise PreconditionError("No local edges found.")