Source code for cgeom.algorithms._intersection

"""Line segment intersection using Bentley-Ottmann sweep line and brute force."""

import heapq
from enum import IntEnum

import numpy as np

# ---------------------------------------------------------------------------
# Event types — ordering matters for the heap (left < intersection < right)
# ---------------------------------------------------------------------------


class _EventType(IntEnum):
    LEFT = 0
    INTERSECTION = 1
    RIGHT = 2


# ---------------------------------------------------------------------------
# Geometric helpers
# ---------------------------------------------------------------------------

_EPS = 1e-9


def _segment_intersection(s1, s2):
    """Parametric intersection of two segments. Returns [x, y] or None."""
    x1, y1 = s1[0]
    x2, y2 = s1[1]
    x3, y3 = s2[0]
    x4, y4 = s2[1]

    denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)

    if abs(denom) < _EPS:
        # Parallel or collinear — check for overlap
        return _collinear_intersection(s1, s2)

    t = ((x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)) / denom
    u = -((x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)) / denom

    if -_EPS <= t <= 1 + _EPS and -_EPS <= u <= 1 + _EPS:
        ix = x1 + t * (x2 - x1)
        iy = y1 + t * (y2 - y1)
        return [ix, iy]
    return None


def _collinear_intersection(s1, s2):
    """Check if two collinear segments overlap, return one endpoint of overlap or None."""
    x1, y1 = s1[0]
    x2, y2 = s1[1]
    x3, y3 = s2[0]
    x4, y4 = s2[1]

    # Check if segments are on the same line
    cross = (x3 - x1) * (y2 - y1) - (y3 - y1) * (x2 - x1)
    if abs(cross) > _EPS:
        return None

    # Project onto the longer axis
    dx = abs(x2 - x1)
    dy = abs(y2 - y1)
    if dx >= dy:
        a_lo, a_hi = min(x1, x2), max(x1, x2)
        b_lo, b_hi = min(x3, x4), max(x3, x4)
    else:
        a_lo, a_hi = min(y1, y2), max(y1, y2)
        b_lo, b_hi = min(y3, y4), max(y3, y4)

    lo = max(a_lo, b_lo)
    hi = min(a_hi, b_hi)

    if lo > hi + _EPS:
        return None

    # Return the start of overlap as the intersection point
    if dx >= dy:
        t = (lo - x1) / (x2 - x1) if abs(x2 - x1) > _EPS else 0
        return [x1 + t * (x2 - x1), y1 + t * (y2 - y1)]
    else:
        t = (lo - y1) / (y2 - y1) if abs(y2 - y1) > _EPS else 0
        return [x1 + t * (x2 - x1), y1 + t * (y2 - y1)]


def _y_at_x(seg, x):
    """Y-coordinate where segment crosses the sweep line at x."""
    x1, y1 = seg[0]
    x2, y2 = seg[1]
    if abs(x2 - x1) < _EPS:
        return min(y1, y2)
    t = (x - x1) / (x2 - x1)
    return y1 + t * (y2 - y1)


# ---------------------------------------------------------------------------
# Sweep-line status structure
# ---------------------------------------------------------------------------


class _SweepLineStatus:
    """Maintains sorted list of active segments ordered by y at current sweep x."""

    def __init__(self):
        self._segments = []  # list of segment indices
        self._sweep_x = 0.0
        self._seg_data = None  # reference to segments array

    def set_context(self, segments):
        self._seg_data = segments

    def set_sweep_x(self, x):
        self._sweep_x = x

    def _sort_key(self, seg_idx):
        return _y_at_x(self._seg_data[seg_idx], self._sweep_x)

    def insert(self, seg_idx):
        key = self._sort_key(seg_idx)
        # Binary search for insertion point
        lo, hi = 0, len(self._segments)
        while lo < hi:
            mid = (lo + hi) // 2
            if self._sort_key(self._segments[mid]) < key - _EPS:
                lo = mid + 1
            else:
                hi = mid
        self._segments.insert(lo, seg_idx)

    def remove(self, seg_idx):
        try:
            self._segments.remove(seg_idx)
        except ValueError:
            pass

    def _find_pos(self, seg_idx):
        for i, s in enumerate(self._segments):
            if s == seg_idx:
                return i
        return -1

    def swap(self, seg_a, seg_b):
        pos_a = self._find_pos(seg_a)
        pos_b = self._find_pos(seg_b)
        if pos_a >= 0 and pos_b >= 0:
            self._segments[pos_a] = seg_b
            self._segments[pos_b] = seg_a

    def neighbors(self, seg_idx):
        """Return (above, below) neighbor indices or None."""
        pos = self._find_pos(seg_idx)
        if pos < 0:
            return None, None
        above = self._segments[pos - 1] if pos > 0 else None
        below = self._segments[pos + 1] if pos < len(self._segments) - 1 else None
        return above, below


# ---------------------------------------------------------------------------
# Main class
# ---------------------------------------------------------------------------


[docs] class SegmentIntersection: """Find intersection points among a set of 2D line segments. Provides both a Bentley-Ottmann sweep line method and a brute-force O(n^2) pairwise check. Attributes: segments (np.ndarray): Array of segments with shape (n, 2, 2). """
[docs] def __init__(self, segments): """Initialize with a set of 2D line segments. Args: segments: Collection of segments as a list, tuple, or numpy array. Each segment is [[x1, y1], [x2, y2]]. Raises: pydantic.ValidationError: If fewer than 2 segments, zero-length segment, non-numeric, or wrong shape. """ from cgeom.elements.models import SegmentIntersectionInput validated = SegmentIntersectionInput(segments=segments) self.segments = np.array(validated.segments)
[docs] def find_intersections(self): """Find all intersection points using Bentley-Ottmann sweep line. Returns: list[list[float]]: Unique intersection points as [[x, y], ...]. """ segs = self._normalize_segments() n = len(segs) if n < 2: return [] # Build event queue events = [] for i in range(n): lx, ly = segs[i][0] rx, ry = segs[i][1] heapq.heappush(events, (lx, _EventType.LEFT, ly, i, -1)) heapq.heappush(events, (rx, _EventType.RIGHT, ry, i, -1)) status = _SweepLineStatus() status.set_context(segs) found_points = {} # (round_x, round_y) -> [x, y] found_pairs = set() # frozenset(i, j) to avoid duplicate events def _check_and_add(seg_a, seg_b): if seg_a is None or seg_b is None: return pair = frozenset((seg_a, seg_b)) if pair in found_pairs: return pt = _segment_intersection(segs[seg_a], segs[seg_b]) if pt is not None: rkey = (round(pt[0], 9), round(pt[1], 9)) if rkey not in found_points: found_points[rkey] = pt found_pairs.add(pair) heapq.heappush( events, (pt[0], _EventType.INTERSECTION, pt[1], seg_a, seg_b) ) while events: x, etype, y, s1, s2 = heapq.heappop(events) status.set_sweep_x(x) if etype == _EventType.LEFT: status.insert(s1) above, below = status.neighbors(s1) _check_and_add(s1, above) _check_and_add(s1, below) elif etype == _EventType.RIGHT: above, below = status.neighbors(s1) _check_and_add(above, below) status.remove(s1) elif etype == _EventType.INTERSECTION: status.swap(s1, s2) # After swap, check new neighbors above_s1, below_s1 = status.neighbors(s1) above_s2, below_s2 = status.neighbors(s2) _check_and_add(s1, above_s1) _check_and_add(s1, below_s1) _check_and_add(s2, above_s2) _check_and_add(s2, below_s2) return list(found_points.values())
[docs] def find_intersections_brute_force(self): """Find all intersection points using O(n^2) pairwise check. Returns: list[list[float]]: Unique intersection points as [[x, y], ...]. """ segs = self.segments.tolist() n = len(segs) found = {} for i in range(n): for j in range(i + 1, n): pt = _segment_intersection(segs[i], segs[j]) if pt is not None: rkey = (round(pt[0], 9), round(pt[1], 9)) if rkey not in found: found[rkey] = pt return list(found.values())
[docs] def get_intersection_pairs(self): """Return intersection details with segment indices. Returns: list[tuple[int, int, list[float]]]: Each entry is (seg_i, seg_j, [x, y]). """ segs = self.segments.tolist() n = len(segs) pairs = [] for i in range(n): for j in range(i + 1, n): pt = _segment_intersection(segs[i], segs[j]) if pt is not None: pairs.append((i, j, pt)) return pairs
[docs] def get_segments(self): """Return segments as a plain list. Returns: list[list[list[float]]]: Segments as [[[x1,y1],[x2,y2]], ...]. """ return self.segments.tolist()
[docs] def plot(self, title="Segment Intersections"): """Deprecated: use cgeom.visualization.plot_intersections() instead.""" import warnings warnings.warn( "SegmentIntersection.plot() is deprecated. " "Use cgeom.visualization.plot_intersections(si_obj, title) instead.", DeprecationWarning, stacklevel=2, ) from cgeom.visualization import plot_intersections plot_intersections(self, title)
def _normalize_segments(self): """Return segments normalized left-to-right (or bottom-to-top for vertical).""" segs = self.segments.tolist() normalized = [] for seg in segs: (x1, y1), (x2, y2) = seg if x1 > x2 + _EPS or (abs(x1 - x2) < _EPS and y1 > y2): normalized.append([[x2, y2], [x1, y1]]) else: normalized.append([[x1, y1], [x2, y2]]) return normalized