import geopandas as gpd
import networkx as nx
import numpy as np
import pandas as pd
from loguru import logger
from pyproj import CRS
from pyproj.aoi import AreaOfInterest
# pylint: disable=no-name-in-module
from pyproj.database import query_utm_crs_info
from shapely import LineString, MultiLineString, Point, Polygon, from_wkt, line_merge, node
from shapely.geometry.base import BaseGeometry
[docs]
def clip_nx_graph(graph: nx.Graph, polygon: Polygon) -> nx.Graph:
    """
    Clip a NetworkX graph by a polygon and return the node-induced subgraph.

    Builds a GeoDataFrame of node points from the `x`/`y` node attributes using
    `graph.graph["crs"]`, clips it by the given polygon, then returns the node-induced
    subgraph (a view that keeps the original node/edge attributes).

    Parameters:
        graph (nx.Graph): Graph with node coords stored as `x`, `y` and a CRS in `graph.graph["crs"]`.
        polygon (Polygon): Clipping polygon in the same CRS as the graph.

    Returns:
        (nx.Graph): Subgraph containing only nodes whose points fall inside the polygon.
            An empty graph yields an empty subgraph.

    Raises:
        KeyError: If `graph.graph["crs"]` is missing, or a node lacks `x` / `y`.

    Notes:
        Edges are preserved only if both endpoints remain in the subgraph.
    """
    crs = graph.graph["crs"]
    node_ids = []
    node_points = []
    for node_id, data in graph.nodes(data=True):
        node_ids.append(node_id)
        node_points.append(Point(data["x"], data["y"]))
    if not node_ids:
        # A GeoDataFrame built from no rows has no geometry column and cannot carry a CRS,
        # so short-circuit instead of constructing one.
        return graph.subgraph([])
    points = gpd.GeoDataFrame({"id": node_ids}, geometry=node_points, crs=crs).clip(
        polygon, keep_geom_type=True
    )
    return graph.subgraph(points["id"].tolist())
def _fmt_top_sizes(sizes, top_k: int = 5) -> str:
ss = sorted(sizes, reverse=True)
if len(ss) <= top_k:
return "[" + ", ".join(map(str, ss)) + "]"
return "[" + ", ".join(map(str, ss[:top_k])) + ", …]"
[docs]
def keep_largest_strongly_connected_component(graph: nx.DiGraph, *, top_k_wcc_sizes: int = 5) -> nx.DiGraph:
    """
    Keep only the largest strongly connected component of a directed graph.

    Logs the sizes of weakly connected components (WCC) for visibility, then removes all
    nodes outside the largest strongly connected component (SCC) and returns the pruned copy.

    Parameters:
        graph (nx.DiGraph): Directed graph to prune (a copy is made; the input is not mutated).
        top_k_wcc_sizes (int): How many largest WCC sizes to show in the warning.

    Returns:
        (nx.DiGraph): Graph restricted to the largest SCC.

    Notes:
        - Uses `nx.weakly_connected_components` for a quick disconnectedness summary.
        - Nodes from all SCCs except the largest are removed.
        - If several SCCs tie for the largest size, the one yielded last by
          `nx.strongly_connected_components` is kept (stable sort by size).
    """
    graph = graph.copy()

    wcc_sizes = [len(component) for component in nx.weakly_connected_components(graph)]
    if len(wcc_sizes) > 1:
        logger.warning(
            f"Graph contains {len(wcc_sizes)} weakly connected components. "
            "This means the graph has disconnected groups if edge directions are ignored. "
            f"Component sizes: {_fmt_top_sizes(wcc_sizes, top_k=top_k_wcc_sizes)}"
        )

    # Ascending by size, so the largest SCC is last and everything before it is dropped.
    all_scc = sorted(nx.strongly_connected_components(graph), key=len)
    nodes_to_del = set().union(*all_scc[:-1])
    if nodes_to_del:
        logger.warning(
            f"Removing {len(nodes_to_del)} nodes from {len(all_scc) - 1} smaller strongly connected components. "
            "These are subgraphs where nodes are internally reachable but isolated from the rest. "
            f"Retaining only the largest strongly connected component ({len(all_scc[-1])} nodes)."
        )
        graph.remove_nodes_from(nodes_to_del)
    return graph
def estimate_crs_for_bounds(minx, miny, maxx, maxy) -> CRS:
    """
    Estimate a local UTM CRS for the given lon/lat bounds.

    Uses the bounds center to query a suitable WGS84 UTM CRS via `pyproj.database.query_utm_crs_info`
    and returns the first match as a `pyproj.CRS`.

    Parameters:
        minx (float): Western longitude.
        miny (float): Southern latitude.
        maxx (float): Eastern longitude.
        maxy (float): Northern latitude.

    Returns:
        (pyproj.CRS): UTM CRS suited for metric distance calculations near the bounds center.

    Raises:
        ValueError: If pyproj returns no WGS 84 UTM CRS for the bounds center.
    """
    x_center = np.mean([minx, maxx])
    y_center = np.mean([miny, maxy])
    # A degenerate (point) area of interest at the center selects the zone containing it.
    utm_crs_list = query_utm_crs_info(
        datum_name="WGS 84",
        area_of_interest=AreaOfInterest(
            west_lon_degree=x_center,
            south_lat_degree=y_center,
            east_lon_degree=x_center,
            north_lat_degree=y_center,
        ),
    )
    if not utm_crs_list:
        raise ValueError(
            f"No WGS 84 UTM CRS found for bounds center (lon={x_center}, lat={y_center}); "
            "check that the bounds are lon/lat degrees within UTM coverage"
        )
    crs = CRS.from_epsg(utm_crs_list[0].code)
    logger.debug(f"Estimated CRS for territory {crs}")
    return crs
def _edges_to_gdf(graph: nx.Graph, crs) -> gpd.GeoDataFrame:
    """
    Convert the edges of a NetworkX graph into a GeoDataFrame indexed by `(u, v)`.

    Edge attribute dicts are flattened into columns with `pandas.json_normalize`. Edges
    without a `geometry` attribute receive an empty `LineString`. This also applies when no
    edge has one, or the graph has no edges, so the result always has a geometry column.

    Parameters:
        graph (nx.Graph): Graph whose edges are converted. Parallel edges of a multigraph
            produce repeated `(u, v)` index entries.
        crs: CRS assigned to the resulting GeoDataFrame.

    Returns:
        (gpd.GeoDataFrame): One row per edge with MultiIndex `(u, v)`.
    """
    graph_df = pd.DataFrame(list(graph.edges(data=True)), columns=["u", "v", "data"])
    edge_data_expanded = pd.json_normalize(graph_df["data"].tolist())
    graph_df = pd.concat([graph_df.drop(columns=["data"]), edge_data_expanded], axis=1)
    if "geometry" not in graph_df.columns:
        # No edge carries geometry (or there are no edges). Create the column so the
        # GeoDataFrame can be built; it is filled with empty LineStrings below.
        graph_df["geometry"] = None
    graph_df = gpd.GeoDataFrame(graph_df, geometry="geometry", crs=crs).set_index(["u", "v"])
    graph_df["geometry"] = graph_df["geometry"].fillna(LineString())
    return graph_df
def _nodes_to_gdf(graph: nx.Graph, crs) -> gpd.GeoDataFrame:
    """
    Convert the nodes of a NetworkX graph into a GeoDataFrame of points indexed by node id.

    Every node attribute becomes a column. Point geometry is built from the `x` and `y`
    node attributes.

    Parameters:
        graph (nx.Graph): Graph whose nodes carry `x` and `y`.
        crs: CRS assigned to the result (any value GeoPandas accepts, e.g. `graph.graph["crs"]`).

    Returns:
        (gpd.GeoDataFrame): One row per node, or an empty GeoDataFrame for a graph without nodes.

    Raises:
        KeyError: If a node lacks `x` or `y`.
    """
    node_items = list(graph.nodes(data=True))
    if not node_items:
        # Unpacking zip(*[]) would fail; return an empty frame that still has a geometry column.
        return gpd.GeoDataFrame(geometry=[], crs=crs)
    node_ids = [node_id for node_id, _ in node_items]
    attrs = [data for _, data in node_items]
    node_geoms = [Point(data["x"], data["y"]) for data in attrs]
    return gpd.GeoDataFrame(attrs, index=node_ids, crs=crs, geometry=node_geoms)
def _restore_edges_geom(nodes_gdf, edges_gdf) -> gpd.GeoDataFrame:
    """
    Replace empty edge geometries with straight lines between the edge's end nodes.

    Rows are updated by position rather than by `(u, v)` label. With parallel edges
    (repeated `(u, v)` labels), only the edges that actually lack geometry are rewritten.
    Label alignment would also overwrite sibling edges, or fail on duplicate labels.

    Parameters:
        nodes_gdf (gpd.GeoDataFrame): Node points indexed by node id.
        edges_gdf (gpd.GeoDataFrame): Edges with MultiIndex `(u, v)`. Modified in place.

    Returns:
        (gpd.GeoDataFrame): `edges_gdf` with empty geometries replaced.

    Raises:
        KeyError: If an edge endpoint is missing from `nodes_gdf`.
    """
    empty_positions = np.flatnonzero(edges_gdf.geometry.is_empty.to_numpy())
    if len(empty_positions) == 0:
        return edges_gdf
    u_ids = edges_gdf.index.get_level_values("u")[empty_positions]
    v_ids = edges_gdf.index.get_level_values("v")[empty_positions]
    starts = nodes_gdf.geometry.loc[u_ids]
    ends = nodes_gdf.geometry.loc[v_ids]
    geoms = list(edges_gdf.geometry)
    for pos, start, end in zip(empty_positions, starts, ends):
        geoms[pos] = LineString((start, end))
    # The new series shares edges_gdf's index object, so assignment needs no label
    # alignment, even when the index has duplicates.
    edges_gdf["geometry"] = gpd.GeoSeries(geoms, index=edges_gdf.index, crs=edges_gdf.crs)
    return edges_gdf
[docs]
def graph_to_gdf(
    graph: nx.MultiDiGraph, edges: bool = True, nodes: bool = True, restore_edge_geom=False
) -> gpd.GeoDataFrame | None:
    """
    Convert a NetworkX graph to GeoDataFrames (edges and/or nodes).

    Reads CRS from `graph.graph["crs"]`. Depending on flags, returns only nodes, only edges,
    or a concatenation of both. Optionally reconstructs missing edge geometries from node points.

    Parameters:
        graph (nx.MultiDiGraph): Graph with node coords (`x`, `y`) and optional edge `geometry`.
        edges (bool): If True, include edges.
        nodes (bool): If True, include nodes.
        restore_edge_geom (bool): If True, fill empty edge geometries with straight lines
            between node coordinates. Ignored when edges are not requested.

    Returns:
        (gpd.GeoDataFrame | None): Nodes and/or edges as a GeoDataFrame.
            If both `edges` and `nodes` are False, returns None.

    Raises:
        ValueError: If `graph.graph["crs"]` is missing.

    Notes:
        - Edge GeoDataFrame uses MultiIndex `(u, v)`.
        - When both are requested, the result is a vertical concat of nodes then edges.
    """
    try:
        crs = graph.graph["crs"]
    except KeyError as exc:
        raise ValueError("Graph does not have 'crs' attribute") from exc

    if not edges and not nodes:
        logger.debug("Neither edges or nodes were selected, graph_to_gdf returning None")
        return None

    # Node points are needed either for output or to rebuild missing edge geometries.
    nodes_gdf = _nodes_to_gdf(graph, crs) if nodes or restore_edge_geom else None
    if not edges:
        return nodes_gdf

    edges_gdf = _edges_to_gdf(graph, crs)
    if restore_edge_geom:
        edges_gdf = _restore_edges_geom(nodes_gdf, edges_gdf)
    if not nodes:
        return edges_gdf
    return pd.concat([nodes_gdf, edges_gdf])
[docs]
def gdf_to_graph(
    gdf: gpd.GeoDataFrame, project_gdf_attr=True, reproject_to_utm_crs=True, speed=5, check_intersections=True
) -> nx.DiGraph:
    """
    Convert a GeoDataFrame of LineStrings into a directed graph (nx.DiGraph).

    Reprojects to an estimated UTM CRS and explodes multi-part geometries. It optionally splits
    lines at mutual intersections (`shapely.node`), then merges connected segments
    (`shapely.line_merge`). Source attributes can be transferred to the merged lines via a
    small buffer around each line's midpoint. Each merged line becomes an edge between
    its first and last vertex. Lengths are computed in meters in the UTM CRS, and travel
    time uses the provided speed.

    Parameters:
        gdf (gpd.GeoDataFrame): Input with LineString geometries (other types are filtered out
            after exploding).
        project_gdf_attr (bool): If True, and `gdf` has columns besides geometry, copies source
            attributes to merged lines. The copy is a spatial join (`intersects`) of a 0.05-unit
            buffer around each merged line's midpoint against the source lines.
        reproject_to_utm_crs (bool): Lengths are always computed in the estimated UTM CRS.
            If True, the output geometry/coordinates stay in that UTM CRS; if False, they are
            reprojected back to the original `gdf.crs`.
        speed (float): Speed in km/h used to compute `time_min` for each edge.
        check_intersections (bool): If True, uses `shapely.node` before `line_merge` to split
            lines at their intersections.

    Returns:
        (nx.DiGraph): Directed graph with:
            - node attributes: `x`, `y`;
            - edge attributes: `geometry`, `length_meter`, `time_min`, plus projected attributes;
            - graph attributes: `crs` (the output CRS) and `speed m/min`.

    Raises:
        ValueError: If the input contains no valid LineStrings.

    Notes:
        - Attribute projection aggregates multi-matches via a uniqueness reducer (`unique_list`).
        - `speed` is internally converted to meters/minute.
        - Nodes are distinct endpoint coordinates (exact float equality).
        - The graph is first built undirected, so lines sharing both endpoints collapse
          into one edge. Attributes of later lines overwrite earlier ones.
        - The final `nx.DiGraph(...)` conversion creates both directions for every edge, with
          the same attribute values. The geometry therefore keeps the u->v vertex order on
          the reverse edge too.
    """

    def unique_list(agg_vals):
        # Reduce a group of joined values: a single distinct non-null value is returned as a
        # scalar; otherwise a list of the distinct values (possibly empty). List order follows
        # set iteration order. Values must be hashable.
        agg_vals = list(set(agg_vals.dropna()))
        if len(agg_vals) == 1:
            return agg_vals[0]
        return agg_vals

    original_crs = gdf.crs
    # Work in an estimated UTM CRS so that `.length` below is in meters.
    gdf = gdf.to_crs(gdf.estimate_utm_crs())
    gdf = gdf.explode(ignore_index=True)
    gdf = gdf[gdf.geom_type == "LineString"]
    if len(gdf) == 0:
        raise ValueError("Provided GeoDataFrame contains no valid LineStrings")
    if check_intersections:
        lines = line_merge(node(MultiLineString(gdf.geometry.to_list())))
    else:
        lines = line_merge(MultiLineString(gdf.geometry.to_list()))
    # line_merge can return a single LineString; wrap it so `.geoms` is always available.
    if isinstance(lines, LineString):
        lines = MultiLineString([lines])
    lines = gpd.GeoDataFrame(geometry=list(lines.geoms), crs=gdf.crs)
    if len(gdf.columns) > 1 and project_gdf_attr:
        # Midpoint of each merged line, buffered by 0.05 CRS units (meters in UTM), used to
        # find which source lines the merged line came from.
        lines_centroids = lines.copy()
        lines_centroids.geometry = lines_centroids.apply(
            lambda row: row.geometry.line_interpolate_point(row.geometry.length / 2), axis=1
        ).buffer(0.05, resolution=2)
        lines_with_attrs = gpd.sjoin(lines_centroids, gdf, how="left", predicate="intersects")
        aggregated_attrs = (
            lines_with_attrs.drop(columns=["geometry", "index_right"])  # drop the buffer geometry
            .groupby(lines_with_attrs.index)
            .agg(unique_list)
        )
        lines = pd.concat([lines, aggregated_attrs], axis=1)
    # Measured while still in the UTM CRS.
    lines["length_meter"] = np.round(lines.length, 2)
    if not reproject_to_utm_crs:
        lines = lines.to_crs(original_crs)
    # First and last vertex of each line as (x, y) tuples, in the output CRS.
    coords = lines.geometry.get_coordinates()
    coords_grouped_by_index = coords.reset_index(names="old_index").groupby("old_index")
    start_coords = coords_grouped_by_index.head(1).apply(lambda a: (a.x, a.y), axis=1).rename("start")
    end_coords = coords_grouped_by_index.tail(1).apply(lambda a: (a.x, a.y), axis=1).rename("end")
    # Start/end rows are matched to `lines` rows by position (both use a 0..n-1 RangeIndex).
    # NOTE(review): this assumes every line has at least one coordinate; confirm no empty
    # geometries can reach this point, or rows would shift.
    coords = pd.concat([start_coords.reset_index(), end_coords.reset_index()], axis=1)[["start", "end"]]
    lines = pd.concat([lines, coords], axis=1)
    # One integer node id per distinct endpoint coordinate.
    unique_coords = pd.concat([coords["start"], coords["end"]], ignore_index=True).unique()
    coord_to_index = {coord: idx for idx, coord in enumerate(unique_coords)}
    lines["u"] = lines["start"].map(coord_to_index)
    lines["v"] = lines["end"].map(coord_to_index)
    # km/h -> m/min
    speed = speed * 1000 / 60
    lines["time_min"] = np.round(lines["length_meter"] / speed, 2)
    graph = nx.Graph()
    # NOTE: `coords` is rebound here to a coordinate tuple; the DataFrame above is no longer used.
    for coords, node_id in coord_to_index.items():
        x, y = coords
        graph.add_node(node_id, x=float(x), y=float(y))
    columns_to_attr = lines.columns.difference(["start", "end", "u", "v"])
    for _, row in lines.iterrows():
        edge_attrs = {}
        for col in columns_to_attr:
            edge_attrs[col] = row[col]
        # On an undirected Graph, re-adding an existing (u, v) updates its attribute dict.
        graph.add_edge(row.u, row.v, **edge_attrs)
    graph.graph["crs"] = lines.crs
    graph.graph["speed m/min"] = speed
    # Undirected -> directed: every edge becomes two arcs (u->v and v->u).
    return nx.DiGraph(graph)
[docs]
def write_gml(graph: nx.Graph, gml_path: str) -> nx.Graph:
    """
    Write a NetworkX graph to GML, coercing node coordinates to plain floats.

    Ensures node attributes `x` and `y` are Python `float`, then writes the graph using
    `stringizer=str` so any non-primitive attribute values are serialized as strings.

    Parameters:
        graph (nx.Graph): Input graph. Not mutated; a sanitized copy is written.
        gml_path (str): Output GML file path.

    Returns:
        (nx.Graph): The sanitized copy of the graph that was written to disk.

    Raises:
        ValueError: If a node's `x` or `y` cannot be converted to float.
    """
    # Graph.copy() copies each attribute dict, so the coercion below does not touch the input.
    graph = graph.copy()
    for node_id, data in graph.nodes(data=True):
        for axis in ("x", "y"):
            if axis not in data:
                continue
            try:
                data[axis] = float(data[axis])
            except (TypeError, ValueError, OverflowError) as exc:
                raise ValueError(f"Node {node_id} has non-numeric {axis}={data[axis]!r}") from exc
    # Values GML cannot represent natively are written as their `str()` form.
    nx.write_gml(graph, gml_path, stringizer=str)
    return graph
[docs]
def read_gml(gml_path: str, **nx_kwargs) -> nx.Graph:
    """
    Read a GML file into a NetworkX graph and cast edge `geometry` from WKT strings to shapely.

    Loads the graph via `networkx.read_gml`. When an edge attribute `geometry` is a string,
    it is parsed with `shapely.from_wkt`; strings that do not parse are left unchanged.

    Parameters:
        gml_path (str): Path to the GML file.
        **nx_kwargs: Additional keyword arguments forwarded to `networkx.read_gml`.

    Returns:
        (nx.Graph): The graph with edge `geometry` parsed to shapely objects where possible.
    """
    graph = nx.read_gml(gml_path, **nx_kwargs)
    # `edges(data=True)` yields the live attribute dict of every edge, including each parallel
    # edge of a multigraph, so one loop covers simple and multi graphs alike.
    for _, _, data in graph.edges(data=True):
        wkt = data.get("geometry")
        if not isinstance(wkt, str):
            continue
        # on_invalid="ignore" returns None for unparsable WKT instead of raising.
        geom = from_wkt(wkt, on_invalid="ignore")
        if geom is not None:
            data["geometry"] = geom
    return graph
[docs]
def reproject_graph(graph: nx.Graph, target_crs) -> nx.Graph:
    """
    Reproject node coordinates (`x`, `y`) and edge geometries to a new CRS (in place).

    Transforms node points and shapely edge geometries with `GeoSeries.to_crs(target_crs)`
    and writes the results back into the graph's attribute dicts. It then sets
    `graph.graph["crs"]` to the target CRS.

    Parameters:
        graph (nx.Graph): Graph with current CRS in `graph.graph["crs"]`. Nodes carry `x`, `y`
            in that CRS; edges may carry shapely `geometry` in the same CRS.
        target_crs: Target CRS accepted by `pyproj.CRS.from_user_input` (EPSG int, string like
            `"EPSG:3857"`, or a `pyproj.CRS`).

    Returns:
        (nx.Graph): The same graph instance (mutated in place) with updated coordinates/geometries
            and CRS.

    Raises:
        ValueError: If `graph.graph["crs"]` is missing.

    Notes:
        - Only nodes with both `x` and `y` are updated.
        - Edges without shapely geometry are left unchanged.
        - If an edge `geometry` is stored as a WKT string, it is not reprojected; parse it first.
        - `graph.graph["crs"]` is always stored as a `pyproj.CRS`, even when there was
          nothing to transform.
    """
    try:
        current_crs = graph.graph["crs"]
    except KeyError as exc:
        raise ValueError("Graph does not have 'crs' attribute") from exc
    # Normalize once so the stored CRS has the same type regardless of what was transformed.
    target_crs = CRS.from_user_input(target_crs)

    # Live node attribute dicts; updating them updates the graph.
    node_attrs = [data for _, data in graph.nodes(data=True) if "x" in data and "y" in data]
    if node_attrs:
        points = gpd.GeoSeries(
            [Point(float(data["x"]), float(data["y"])) for data in node_attrs], crs=current_crs
        ).to_crs(target_crs)
        for data, point in zip(node_attrs, points):
            data["x"] = float(point.x)
            data["y"] = float(point.y)

    # `edges(data=True)` yields every edge's live dict, parallel multigraph edges included.
    edge_attrs = [
        data for _, _, data in graph.edges(data=True) if isinstance(data.get("geometry"), BaseGeometry)
    ]
    if edge_attrs:
        geoms = gpd.GeoSeries([data["geometry"] for data in edge_attrs], crs=current_crs).to_crs(target_crs)
        for data, geom in zip(edge_attrs, geoms):
            data["geometry"] = geom

    graph.graph["crs"] = target_crs
    return graph