# Source code for iduedu.modules.public_transport_builders

import geopandas as gpd
import networkx as nx
import numpy as np
import pandas as pd
from shapely import LineString, MultiPolygon, Polygon
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import process_map, thread_map

from iduedu import config
from iduedu.enums.pt_enums import PublicTrasport
from iduedu.modules.overpass_parsers import parse_overpass_subway_data, parse_overpass_to_edgenode

from .graph_transformers import clip_nx_graph, estimate_crs_for_bounds
from .overpass_downloaders import (
    get_4326_boundary,
    get_routes_by_poly,
)

logger = config.logger


def _graph_data_to_nx(graph_df, keep_geometry: bool = True, additional_data=None) -> nx.DiGraph:
    """
    Build a directed public-transport graph from a mixed edge/node DataFrame, with optional subway add-ons.

    Input `graph_df` contains both node rows (non-null `node_id`) and edge rows (null `node_id`).
    Nodes are grouped/merged (platforms by rounded coords; other types by (coord, route, type, then by ref_id)),
    then edges are mapped to these merged nodes. If `additional_data` (subway entrances/transfers) is provided,
    edges with `u_ref`/`v_ref` are joined by `ref_id`.

    Edge `length_meter` / `time_min` values that are missing are computed here: `boarding` edges and edges
    without geometry get `0.0, 0.0`; otherwise length is `geometry.length` and time is
    `length / PublicTrasport[<TYPE>].avg_speed`, both rounded to 3 decimals.

    Parameters:
        graph_df (pd.DataFrame): Mixed table with:
            node rows → columns: `node_id`, `point` (projected meters), `route`, `type`, `ref_id`, `extra_data`;
            edge rows → columns: `u`, `v`, `type`, `extra_data`, `route`, `geometry`.
        keep_geometry (bool): If True, store shapely `geometry` on edges; otherwise the edge `geometry`
            attribute is still present but set to None.
        additional_data (tuple | None): Optional `(additional_edges, additional_nodes)` produced by
            `parse_overpass_subway_data`. `additional_edges` must have `u_ref`, `v_ref`, `type`, optional
            `geometry`, `route`, `extra_data`; `additional_nodes` must have `ref_id`, `point`, `type`, `extra_data`.
            Note that `additional_edges` receives new `u`/`v` columns in place before it is filtered.

    Returns:
        (nx.DiGraph): Directed graph whose node keys are the row indices of the merged node table.
            Node attrs: `x`, `y`, `type`, `route`, `ref_id` plus the keys of `extra_data`.
            Edge attrs: `route`, `type`, `geometry`, `length_meter`, `time_min` plus the keys of `extra_data`.
    """

    def avg_point(points):
        # Mean of all non-None points, returned as a tuple; None when nothing is left to average.
        arr = np.asarray([p for p in points if p is not None], dtype=float)
        return tuple(arr.mean(axis=0)) if len(arr) else None

    def merge_dicts_last(dicts):
        # Merge a Series of dicts into one; missing entries are skipped and later dicts win on key clashes.
        out = {}
        for d in dicts.dropna():
            for k, v in d.items():
                out[k] = v
        return out

    nodes_col = ["node_id", "point", "route", "type", "ref_id", "extra_data"]
    edges_col = ["u", "v", "type", "extra_data", "route", "geometry"]
    # Node rows carry a node_id; every row without one is treated as an edge row.
    graph_nodes = graph_df[~graph_df["node_id"].isna()][nodes_col].copy()
    graph_edges = graph_df[graph_df["node_id"].isna()][edges_col].copy()

    if additional_data is not None:
        additional_edges, additional_nodes = additional_data
        # Outer join keeps route nodes without extra data as well as extra nodes unknown to the routes.
        graph_nodes_combined = graph_nodes.merge(additional_nodes, left_on="ref_id", right_on="ref_id", how="outer")
        for column in nodes_col:
            if column in graph_nodes_combined.columns:
                continue
            else:
                # Column existed on both sides, so the merge suffixed it: prefer the additional (`_y`)
                # value and fall back to the route-derived (`_x`) one.
                # assumes any column missing here exists in both frames (i.e. has `_x`/`_y`) — TODO confirm
                graph_nodes_combined[column] = graph_nodes_combined[f"{column}_y"].combine_first(
                    graph_nodes_combined[f"{column}_x"]
                )
        graph_nodes_combined = graph_nodes_combined[nodes_col]

        # Nodes that have a ref_id but no coordinates after the merge.
        no_point_refs = graph_nodes_combined[
            (graph_nodes_combined["point"].isna()) & (~graph_nodes_combined["ref_id"].isna())
        ][["ref_id"]].copy()
        if len(no_point_refs) > 0:
            # Find additional edges that end (v_ref) at a point-less ref, keeping one edge per u_ref.
            no_point_refs = no_point_refs.merge(
                additional_edges[["v_ref", "u_ref"]], left_on="ref_id", right_on="v_ref"
            )[["v_ref", "u_ref"]].drop_duplicates(subset="u_ref")
            # Resolve the other end (u_ref) of those edges to a route node.
            no_point_refs = no_point_refs.merge(
                graph_nodes[["node_id", "type", "ref_id"]], left_on="u_ref", right_on="ref_id"
            ).drop(columns=["ref_id"])
            # Follow the route edges leaving that node ...
            no_point_refs = no_point_refs.merge(
                graph_edges[["u", "v"]], left_on="node_id", right_on="u", how="left"
            ).drop(columns=["u"])
            # ... and look up the nodes they lead to.
            no_point_refs = no_point_refs.merge(
                graph_nodes[["node_id", "type", "ref_id"]].add_prefix("potential_"),
                left_on="v",
                right_on="potential_node_id",
            )
            # Only neighbouring platforms are candidates for taking over the point-less ref.
            no_point_refs = no_point_refs[no_point_refs["potential_type"] == "platform"]
            remap_refs = no_point_refs.set_index("potential_ref_id")["v_ref"].to_dict()
            s = graph_nodes_combined["ref_id"].copy()
            mapped = s.map(remap_refs)
            mask = mapped.notna()
            # Re-label the matched platforms with the point-less ref and mark them as subway platforms.
            graph_nodes_combined.loc[mask, "ref_id"] = mapped[mask]
            graph_nodes_combined.loc[mask, "type"] = "subway_platform"

            # Drop only the rows that have neither a node_id nor a point.
            # NOTE(review): graph_nodes is replaced by the combined table only inside this branch; when no
            # point-less refs exist, the rows from additional_nodes never reach graph_nodes — confirm intended.
            graph_nodes = graph_nodes_combined.dropna(subset=["node_id", "point"], how="all").copy()
            graph_nodes["route"] = graph_nodes["route"].fillna("subway_transit")
            # TODO delete duplicated platform node

    # Platforms: merge every platform that falls on the same rounded coordinate, regardless of route.
    platforms = graph_nodes[graph_nodes["type"] == "platform"].copy()
    platforms["point_group"] = platforms["point"].apply(lambda p: (round(p[0]), round(p[1])))
    platforms = platforms.groupby("point_group", as_index=False).agg(
        point=("point", "first"),
        node_id=("node_id", lambda s: tuple(s.dropna())),
        route=("route", lambda s: tuple(s)),
        ref_id=("ref_id", lambda s: tuple(s.dropna())),
        extra_data=("extra_data", merge_dicts_last),
    )
    platforms["type"] = "platform"

    # Everything else: first merge per (rounded coordinate, route, type) ...
    not_platforms = graph_nodes[graph_nodes["type"] != "platform"].copy()
    not_platforms["point_group"] = not_platforms["point"].apply(lambda p: (round(p[0]), round(p[1])))
    not_platforms = not_platforms.groupby(["point_group", "route", "type"], as_index=False).agg(
        point=("point", "first"),
        node_id=("node_id", lambda s: tuple(s.dropna())),
        ref_id=("ref_id", lambda s: tuple(set(s.dropna()))),
        extra_data=("extra_data", merge_dicts_last),
    )
    # ... then merge groups that share the same ref_id tuple, route and type, averaging their points.
    # node_id holds tuples at this stage; "sum" presumably concatenates them — TODO confirm.
    # NOTE(review): rows whose ref_id tuple is empty also compare equal here, so all ref-less nodes of one
    # route and type end up in a single group — confirm this is intended.
    not_platforms = not_platforms.groupby(["ref_id", "route", "type"], as_index=False).agg(
        point=("point", avg_point),
        node_id=("node_id", "sum"),
        extra_data=("extra_data", merge_dicts_last),
    )

    # The row index of this table becomes the node key in the resulting graph.
    all_nodes = pd.concat([platforms, not_platforms], ignore_index=True)

    # Lookup tables from original node ids / ref ids to the merged node index.
    map_nodeid_to_idx = {}
    map_refid_to_idx = {}

    for idx, row in all_nodes.iterrows():
        for nid in row["node_id"] or []:
            map_nodeid_to_idx[nid] = idx
        rids = row["ref_id"]
        if not isinstance(rids, (list, tuple)):
            rids = [rids]
        for rid in rids or []:
            map_refid_to_idx[rid] = idx

    # Re-point route edges at merged nodes; edges with an endpoint that did not map are dropped.
    edges_existing = graph_edges[["route", "type", "u", "v", "geometry", "extra_data"]].copy()
    edges_existing["u"] = edges_existing["u"].map(map_nodeid_to_idx)
    edges_existing["v"] = edges_existing["v"].map(map_nodeid_to_idx)
    edges_existing = edges_existing.dropna(subset=["u", "v"]).copy()

    if additional_data is not None:
        # Additional edges are addressed by ref id rather than node id.
        # Note: these two assignments write onto the frame passed in by the caller.
        additional_edges["u"] = additional_edges["u_ref"].map(map_refid_to_idx)
        additional_edges["v"] = additional_edges["v_ref"].map(map_refid_to_idx)
        additional_edges = additional_edges.dropna(subset=["u", "v"]).copy()

        def _ensure_geom(row):
            # Keep a geometry that is not None; otherwise draw a straight segment between the endpoints.
            if row.get("geometry") is not None:
                return row["geometry"]
            pu = all_nodes.loc[int(row["u"]), "point"]
            pv = all_nodes.loc[int(row["v"]), "point"]
            return LineString([pu, pv])

        additional_edges["geometry"] = additional_edges.apply(_ensure_geom, axis=1)
    else:
        additional_edges = pd.DataFrame()

    edges = pd.concat([edges_existing, additional_edges], ignore_index=True)

    # Make sure both metric columns exist so the "missing" mask below can be computed.
    if "length_meter" not in edges.columns:
        edges["length_meter"] = np.nan
    if "time_min" not in edges.columns:
        edges["time_min"] = np.nan

    def calc_len_time(row):
        # Boarding edges and edges without geometry get zero length and zero time.
        if row.type == "boarding":
            return 0.0, 0.0
        geom = row.geometry
        if geom is None:
            return 0.0, 0.0
        length = float(round(geom.length, 3))
        # The edge type is looked up by its upper-cased name in the PublicTrasport enum.
        # presumably avg_speed is in meters per minute, given the `time_min` column name — TODO confirm
        speed = PublicTrasport[row.type.upper()].avg_speed
        return length, float(round(length / speed, 3))

    # Only fill rows where at least one of the two metrics is missing; both are overwritten for such rows.
    mask_missing = edges["length_meter"].isna() | edges["time_min"].isna()

    vals = edges.loc[mask_missing].apply(calc_len_time, axis=1, result_type="expand")
    vals.columns = ["length_meter", "time_min"]
    edges.loc[mask_missing, ["length_meter", "time_min"]] = vals

    graph = nx.DiGraph()

    for idx, node in all_nodes.iterrows():
        # Platforms carry a tuple of routes: de-duplicate it and unwrap it when only one route remains.
        route = list(set(node["route"])) if isinstance(node["route"], tuple) else [node["route"]]
        if len(route) == 1:
            route = route[0]
        graph.add_node(
            idx,
            x=float(node["point"][0]),
            y=float(node["point"][1]),
            type=node["type"],
            route=route,
            # First ref id of a non-empty tuple; any other value (including an empty tuple) is stored as-is.
            ref_id=(node["ref_id"][0] if isinstance(node["ref_id"], tuple) and node["ref_id"] else node["ref_id"]),
            **(node["extra_data"] if isinstance(node["extra_data"], dict) else {}),
        )

    for _, e in edges.iterrows():
        graph.add_edge(
            # Endpoints are cast back to int so they match the integer node keys added above.
            int(e["u"]),
            int(e["v"]),
            route=e["route"],
            type=e["type"],
            geometry=(e["geometry"] if keep_geometry else None),
            length_meter=e["length_meter"],
            time_min=e["time_min"],
            **(e["extra_data"] if isinstance(e["extra_data"], dict) else {}),
        )

    return graph


def _multi_get_routes_by_poly(args):
    """
    Call `get_routes_by_poly` with a packed argument sequence.

    Single-argument adapter used as the worker callable for `thread_map`, which passes one item per task.

    Parameters:
        args (Iterable): Positional arguments for `get_routes_by_poly`, in call order.

    Returns:
        Whatever `get_routes_by_poly` returns for those arguments.
    """
    routes = get_routes_by_poly(*args)
    return routes


def _multi_parse_overpass_to_edgenode(args):
    """
    Call `parse_overpass_to_edgenode` with a packed argument sequence.

    Single-argument adapter used as the worker callable for `process_map`, which passes one item per task.

    Parameters:
        args (Iterable): Positional arguments for `parse_overpass_to_edgenode`, in call order.

    Returns:
        Whatever `parse_overpass_to_edgenode` returns for those arguments.
    """
    parsed = parse_overpass_to_edgenode(*args)
    return parsed


def _get_public_transport_graph(
    osm_id: int | None,
    territory: Polygon | MultiPolygon | gpd.GeoDataFrame | None,
    transport_types: list[str],
    osm_edge_tags: list[str] | None,
    clip_by_territory: bool = False,
    keep_edge_geometry: bool = True,
) -> nx.DiGraph:
    """
    Build a directed public-transport graph for one or several transport types inside a territory.

    The function:
    1) resolves a boundary polygon (EPSG:4326);
    2) downloads routes via Overpass in parallel;
    3) (for subway) additionally parses stop areas/groups and stations, producing entrances/transfers;
    4) parses each route into edge/node rows (parallel for large inputs);
    5) assembles a single `nx.DiGraph` via `_graph_data_to_nx`, computing missing edge length/time;
    6) optionally clips the graph by the territory.

    Parameters:
        osm_id (int | None): OSM relation/area id of the territory; used if `territory` is not provided.
        territory (Polygon | MultiPolygon | gpd.GeoDataFrame | None): Boundary geometry in EPSG:4326.
        transport_types (list[str]): Transport types to include (e.g., `["bus", "tram", "subway"]`).
        osm_edge_tags (list[str] | None): Which route/member tags to retain on edges/nodes. If None,
            `config.transport_useful_edges_attr` is used.
        clip_by_territory (bool): If True, clip the final graph to the (projected) boundary.
        keep_edge_geometry (bool): If True, store shapely `geometry` on edges.

    Returns:
        (nx.DiGraph): Directed PT graph. Graph attributes set by this function:
            - `graph["crs"]` (int/EPSG of the local projected CRS),
            - `graph["type"]` = "public_trasport" (sic).
        When no routes are downloaded, or none could be parsed, an empty `nx.DiGraph` without these
        attributes is returned and a warning is logged.
    """

    polygon = get_4326_boundary(osm_id=osm_id, territory=territory)

    args_list = [(polygon, transport) for transport in transport_types]

    # When subway is requested, the Overpass response is expected to also contain station data.
    platform_stop_data_use = "subway" in transport_types

    if not config.enable_tqdm_bar:
        logger.debug("Downloading pt routes")
    overpass_data = pd.concat(
        thread_map(
            _multi_get_routes_by_poly,
            args_list,
            desc="Downloading public transport routes from OSM",
            disable=not config.enable_tqdm_bar,
        ),
        ignore_index=True,
    ).reset_index(drop=True)

    if overpass_data.shape[0] == 0:
        logger.warning("No routes found for public transport.")
        # Return a directed graph, consistent with the documented return type and the other empty path.
        return nx.DiGraph()

    local_crs = estimate_crs_for_bounds(*polygon.bounds).to_epsg()

    # OSM tags to keep from the route relation.
    if osm_edge_tags is None:
        needed_tags = set(config.transport_useful_edges_attr)
    else:
        needed_tags = set(osm_edge_tags)

    # Separate station-related rows from route rows when subway data is present.
    if platform_stop_data_use:
        for column in ["is_stop_area", "is_stop_area_group", "is_station"]:
            overpass_data[column] = overpass_data[column].astype("boolean").fillna(False)
        routes_data = overpass_data[
            ~((overpass_data["is_stop_area"]) | (overpass_data["is_stop_area_group"]) | (overpass_data["is_station"]))
        ].copy()
        stop_areas = overpass_data[overpass_data["is_stop_area"]].copy()
        stop_areas_group = overpass_data[overpass_data["is_stop_area_group"]].copy()
        stations_data = overpass_data[overpass_data["is_station"]].copy()
        add_data = parse_overpass_subway_data(stop_areas, stop_areas_group, stations_data, local_crs)
    else:
        routes_data = overpass_data.copy()
        add_data = None

    if not config.enable_tqdm_bar:
        logger.debug("Parsing public transport routes")
    # The threshold is evaluated on all downloaded rows, including stop-area/station rows.
    if overpass_data.shape[0] > 100:
        # Many rows: parse the routes in parallel worker processes.
        parsed_routes = process_map(
            _multi_parse_overpass_to_edgenode,
            [(row, local_crs, needed_tags) for _, row in routes_data.iterrows()],
            desc="Parsing public transport routes",
            chunksize=1,
            disable=not config.enable_tqdm_bar,
        )
    else:
        tqdm.pandas(desc="Parsing public transport routes", disable=not config.enable_tqdm_bar)
        parsed_routes = routes_data.progress_apply(
            lambda x: parse_overpass_to_edgenode(x, local_crs, needed_tags), axis=1
        ).tolist()

    # Drop routes the parser rejected (None) in BOTH branches, so the emptiness check below is reliable
    # and `pd.concat` never receives a list consisting solely of None.
    edgenode_for_routes = [data for data in parsed_routes if data is not None]

    if len(edgenode_for_routes) == 0:
        logger.warning("No routes were parsed for public transport.")
        return nx.DiGraph()
    graph_df = pd.concat(edgenode_for_routes, ignore_index=True)
    to_return = _graph_data_to_nx(graph_df, keep_geometry=keep_edge_geometry, additional_data=add_data)
    to_return.graph["crs"] = local_crs
    to_return.graph["type"] = "public_trasport"

    if clip_by_territory:
        # Clipping is done in the local projected CRS, so reproject the boundary first.
        polygon = gpd.GeoSeries([polygon], crs=4326).to_crs(local_crs).union_all()
        to_return = clip_nx_graph(to_return, polygon)

    logger.debug("Done!")
    return to_return


def get_single_public_transport_graph(
    public_transport_type: str | PublicTrasport,
    *,
    osm_id: int | None = None,
    territory: Polygon | MultiPolygon | gpd.GeoDataFrame | None = None,
    clip_by_territory: bool = False,
    keep_edge_geometry: bool = True,
    osm_edge_tags: list[str] | None = None,  # overrides default tags
):
    """
    Build a directed graph for a single public-transport mode within a given territory.

    The function resolves a boundary (by `osm_id` or `territory`), downloads OpenStreetMap routes inside
    that boundary, converts them into a projected `nx.DiGraph`, and computes per-edge length (meters) and
    travel time (minutes). When the mode is **subway**, additional station context is incorporated:
    entrances/exits and inter-station transfers are added; station metadata (e.g., name, depth) is attached
    to nodes where available.

    Parameters:
        public_transport_type (str | PublicTrasport): One mode, e.g. `"bus"`, `"tram"`, `"trolleybus"`,
            `"subway"`. You may pass a `PublicTrasport` enum or its string value.
        osm_id (int | None): OSM relation/area id of the territory. Provide this or `territory`.
        territory (Polygon | MultiPolygon | gpd.GeoDataFrame | None): Boundary geometry in EPSG:4326
            (or a GeoDataFrame). Used when `osm_id` is not given.
        clip_by_territory (bool): If True, the resulting graph is clipped to the boundary (in the local CRS).
        keep_edge_geometry (bool): If True, edge shapes (`shapely` geometries in local CRS) are stored on edges.
        osm_edge_tags (list[str] | None): Subset of OSM tags to retain on edges/nodes. If None, a sensible
            default is used from configuration; only requested keys are joined from OSM element tags.

    Returns:
        (nx.DiGraph): Directed PT graph with:
            - node attrs: `x`, `y` (floats, local CRS), `type`, `route`, `ref_id`, plus merged station
              `extra_data` (if any);
            - edge attrs: `type`, `route`, `length_meter`, `time_min`, optional `geometry`, and selected
              OSM tags.
        Graph attrs: `graph["crs"]` (EPSG int of the local projected CRS),
        `graph["type"]` = `"public_trasport"`.

    Notes:
        - Lengths and times are computed in a **local projected CRS** estimated from the boundary;
          per-edge speeds are taken from mode-specific defaults (and, for subway connectors, from
          connector-type defaults).
    """
    # `Enum.value` is an attribute, not a method: calling it would raise TypeError for enum inputs.
    if isinstance(public_transport_type, PublicTrasport):
        public_transport_type = public_transport_type.value

    return _get_public_transport_graph(
        osm_id=osm_id,
        territory=territory,
        transport_types=[public_transport_type],
        osm_edge_tags=osm_edge_tags,
        clip_by_territory=clip_by_territory,
        keep_edge_geometry=keep_edge_geometry,
    )
def get_all_public_transport_graph(
    *,
    osm_id: int | None = None,
    territory: Polygon | MultiPolygon | gpd.GeoDataFrame | None = None,
    clip_by_territory: bool = False,
    keep_edge_geometry: bool = True,
    transport_types: list[PublicTrasport] = None,
    osm_edge_tags: list[str] | None = None,  # overrides default tags
) -> nx.Graph:
    """
    Build one combined directed graph covering several public-transport modes within a territory.

    Routes for the requested modes (by default: **tram**, **bus**, **trolleybus**, **subway**) are
    collected and converted into a single projected graph with per-edge length and time. Edges from
    different modes coexist in the same `nx.DiGraph`; node ids are unified across modes. For the
    **subway** mode, station context is added (entrances/exits and inter-station transfers), and
    available station metadata is merged into node attrs.

    Parameters:
        osm_id (int | None): OSM relation/area id of the territory. Provide this or `territory`.
        territory (Polygon | MultiPolygon | gpd.GeoDataFrame | None): Boundary geometry in EPSG:4326
            (or a GeoDataFrame).
        clip_by_territory (bool): If True, clip the final graph to the boundary (in the local CRS).
        keep_edge_geometry (bool): If True, retain `shapely` geometries on edges.
        transport_types (list[PublicTrasport] | None): List of modes to include. Defaults to
            `[PublicTrasport.TRAM, PublicTrasport.BUS, PublicTrasport.TROLLEYBUS, PublicTrasport.SUBWAY]`.
            All items must be `PublicTrasport` enums.
        osm_edge_tags (list[str] | None): Which OSM tags to keep on edges/nodes. If None, a default
            subset from configuration is used; only these keys are joined from OSM.

    Returns:
        (nx.DiGraph): Combined directed PT graph. Typical attributes:
            - node attrs: `x`, `y` (local CRS), `type`, `route`, `ref_id`, station `extra_data` where
              applicable;
            - edge attrs: `type`, `route`, `length_meter`, `time_min`, optional `geometry`, plus selected
              OSM tags.
        Graph attrs: `graph["crs"]` (EPSG int), `graph["type"]` = `"public_trasport"`.

    Raises:
        ValueError: If any item of `transport_types` is not a `PublicTrasport` member.

    Notes:
        Each mode's ways are downloaded inside the boundary and transformed into directed edges;
        per-edge speeds are taken from mode-specific defaults (and, for subway connectors, from
        connector-type defaults).
    """
    if transport_types is not None:
        # Reject the first item that is not an enum member before doing any network work.
        for candidate in transport_types:
            if isinstance(candidate, PublicTrasport):
                continue
            raise ValueError(f"transport_type {candidate} is not a valid transport type.")
    else:
        transport_types = [
            PublicTrasport.TRAM,
            PublicTrasport.BUS,
            PublicTrasport.TROLLEYBUS,
            PublicTrasport.SUBWAY,
        ]

    # The internal builder works with the plain enum values.
    mode_values = []
    for mode in transport_types:
        if isinstance(mode, PublicTrasport):
            mode_values.append(mode.value)

    builder_kwargs = {
        "osm_id": osm_id,
        "territory": territory,
        "transport_types": mode_values,
        "osm_edge_tags": osm_edge_tags,
        "clip_by_territory": clip_by_territory,
        "keep_edge_geometry": keep_edge_geometry,
    }
    return _get_public_transport_graph(**builder_kwargs)