
Utilities for working with paths through a hierarchy

batch_filter_empty_paths(predicted_boxes, predicted_paths, predicted_path_scores)

Applies empty path filtering to a batch of predictions.

This function maps the filter_empty_paths function over a batch of predicted boxes, paths, and scores.

Parameters:

predicted_boxes : list[Tensor] (required)
    A batch of bounding box tensors.
predicted_paths : list[list[list[int]]] (required)
    A batch of predicted path lists.
predicted_path_scores : list[list[Tensor]] (required)
    A batch of predicted path score lists.

Returns:

list[tuple[Tensor, list[list[int]], list[Tensor]]]
    A list of tuples, where each tuple contains the filtered boxes, paths, and scores for an item in the batch.

Examples:

>>> import torch
>>> boxes_batch = [torch.tensor([[482.27, 395.77, 241.98, 359.60, 258.38], [8.11, 156.87, 152.91, 335.40, 24.81], [610.42, 429.38, 307.70, 382.68, 413.79], [103.86, 200.93, 197.57, 352.40, 197.61]]), torch.tensor([[482.27, 395.77, 241.98, 359.60, 258.38], [8.11, 156.87, 152.91, 335.40, 24.81], [610.42, 429.38, 307.70, 382.68, 413.79], [103.86, 200.93, 197.57, 352.40, 197.61]])]
>>> paths_batch = [[[4], [4, 6], [4, 5], [], []], [[4], [4, 6], [4, 5], [], []]]
>>> scores_batch = [[torch.tensor([0.9896]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([]), torch.tensor([])], [torch.tensor([0.9896]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([]), torch.tensor([])]]
>>> result = batch_filter_empty_paths(boxes_batch, paths_batch, scores_batch)
>>> len(result)
2
>>> result[0][0] # boxes for first batch item
tensor([[482.2700, 395.7700, 241.9800],
        [  8.1100, 156.8700, 152.9100],
        [610.4200, 429.3800, 307.7000],
        [103.8600, 200.9300, 197.5700]])
>>> result[0][1] # paths for first batch item
[[4], [4, 6], [4, 5]]
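
The batch helpers are typically chained after path truncation. The sketch below is an assumed end-to-end usage, not a workflow prescribed by the library: the composition is inferred from the signatures on this page, and the box tensor is random dummy data.

>>> import torch
>>> paths_batch = [[[4], [4, 6], [4, 5], [4, 7], [4, 2]]]
>>> scores_batch = [[torch.tensor([0.9896]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([0.5412, 0.4371]), torch.tensor([0.5001, 0.0830])]]
>>> boxes_batch = [torch.rand(4, 5)]  # one (4, N) box tensor per batch item; dummy values
>>> truncated = batch_truncate_paths_marginals(paths_batch, scores_batch, 0.589)
>>> t_paths = [paths for paths, scores in truncated]
>>> t_scores = [scores for paths, scores in truncated]
>>> filtered = batch_filter_empty_paths(boxes_batch, t_paths, t_scores)
>>> [len(paths) for boxes, paths, scores in filtered]  # two paths truncated to empty were dropped
[3]
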
Source code in hierarchical_loss/path_utils.py
def batch_filter_empty_paths(predicted_boxes: list[torch.Tensor], predicted_paths: list[list[list[int]]], predicted_path_scores: list[list[torch.Tensor]]) -> list[tuple[torch.Tensor, list[list[int]], list[torch.Tensor]]]:
    """Applies empty path filtering to a batch of predictions.

    This function maps the `filter_empty_paths` function over a batch of
    predicted boxes, paths, and scores.

    Parameters
    ----------
    predicted_boxes : list[torch.Tensor]
        A batch of bounding box tensors.
    predicted_paths : list[list[list[int]]]
        A batch of predicted path lists.
    predicted_path_scores : list[list[torch.Tensor]]
        A batch of predicted path score lists.

    Returns
    -------
    list[tuple[torch.Tensor, list[list[int]], list[torch.Tensor]]]
        A list of tuples, where each tuple contains the filtered boxes,
        paths, and scores for an item in the batch.

    Examples
    --------
    >>> import torch
    >>> boxes_batch = [torch.tensor([[482.27, 395.77, 241.98, 359.60, 258.38], [8.11, 156.87, 152.91, 335.40, 24.81], [610.42, 429.38, 307.70, 382.68, 413.79], [103.86, 200.93, 197.57, 352.40, 197.61]]), torch.tensor([[482.27, 395.77, 241.98, 359.60, 258.38], [8.11, 156.87, 152.91, 335.40, 24.81], [610.42, 429.38, 307.70, 382.68, 413.79], [103.86, 200.93, 197.57, 352.40, 197.61]])]
    >>> paths_batch = [[[4], [4, 6], [4, 5], [], []], [[4], [4, 6], [4, 5], [], []]]
    >>> scores_batch = [[torch.tensor([0.9896]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([]), torch.tensor([])], [torch.tensor([0.9896]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([]), torch.tensor([])]]
    >>> result = batch_filter_empty_paths(boxes_batch, paths_batch, scores_batch)
    >>> len(result)
    2
    >>> result[0][0] # boxes for first batch item
    tensor([[482.2700, 395.7700, 241.9800],
            [  8.1100, 156.8700, 152.9100],
            [610.4200, 429.3800, 307.7000],
            [103.8600, 200.9300, 197.5700]])
    >>> result[0][1] # paths for first batch item
    [[4], [4, 6], [4, 5]]
    """
    B = len(predicted_paths)
    return list(itertools.starmap(filter_empty_paths, zip(predicted_boxes, predicted_paths, predicted_path_scores)))

batch_truncate_paths_conditionals(predicted_paths, predicted_path_scores, threshold=0.25)

Applies conditional probability truncation to a batch of path lists.

This function maps the truncate_paths_conditionals function over a batch of predicted paths and scores.

Parameters:

predicted_paths : list[list[list[int]]] (required)
    A batch of path lists. Each item in the outer list corresponds to an item in the batch.
predicted_path_scores : list[list[Tensor]] (required)
    A batch of score lists, corresponding to predicted_paths.
threshold : float (default: 0.25)
    The probability threshold to use for truncation.

Returns:

list[tuple[list[list[int]], list[Tensor]]]
    A list of tuples, where each tuple contains the truncated paths and scores for an item in the batch.

Examples:

>>> import torch
>>> paths_batch = [[[4, 2], [4, 6], [4, 5], [4, 7], [4, 2]], [[4, 2], [4, 6], [4, 5], [4, 7], [4, 2]]]
>>> scores_batch = [[torch.tensor([0.9896, 0.5891]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([0.5412, 0.4371]), torch.tensor([0.5001, 0.0830])], [torch.tensor([0.9896, 0.5891]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([0.5412, 0.4371]), torch.tensor([0.5001, 0.0830])]]
>>> batch_truncate_paths_conditionals(paths_batch, scores_batch, 0.589)
[([[4, 2], [4, 6], [4, 5], [], []], [tensor([0.9896, 0.5891]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765]), tensor([]), tensor([])]), ([[4, 2], [4, 6], [4, 5], [], []], [tensor([0.9896, 0.5891]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765]), tensor([]), tensor([])])]
Source code in hierarchical_loss/path_utils.py
def batch_truncate_paths_conditionals(predicted_paths: list[list[list[int]]], predicted_path_scores: list[list[torch.Tensor]], threshold: float = 0.25) -> list[tuple[list[list[int]], list[torch.Tensor]]]:
    """Applies conditional probability truncation to a batch of path lists.

    This function maps the `truncate_paths_conditionals` function over a
    batch of predicted paths and scores.

    Parameters
    ----------
    predicted_paths : list[list[list[int]]]
        A batch of path lists. Each item in the outer list corresponds to
        an item in the batch.
    predicted_path_scores : list[list[torch.Tensor]]
        A batch of score lists, corresponding to `predicted_paths`.
    threshold : float, optional
        The probability threshold to use for truncation, by default 0.25.

    Returns
    -------
    list[tuple[list[list[int]], list[torch.Tensor]]]
        A list of tuples, where each tuple contains the truncated paths and
        scores for an item in the batch.

    Examples
    --------
    >>> import torch
    >>> paths_batch = [[[4, 2], [4, 6], [4, 5], [4, 7], [4, 2]], [[4, 2], [4, 6], [4, 5], [4, 7], [4, 2]]]
    >>> scores_batch = [[torch.tensor([0.9896, 0.5891]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([0.5412, 0.4371]), torch.tensor([0.5001, 0.0830])], [torch.tensor([0.9896, 0.5891]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([0.5412, 0.4371]), torch.tensor([0.5001, 0.0830])]]
    >>> batch_truncate_paths_conditionals(paths_batch, scores_batch, 0.589)
    [([[4, 2], [4, 6], [4, 5], [], []], [tensor([0.9896, 0.5891]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765]), tensor([]), tensor([])]), ([[4, 2], [4, 6], [4, 5], [], []], [tensor([0.9896, 0.5891]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765]), tensor([]), tensor([])])]
    """
    B = len(predicted_paths)
    return list(itertools.starmap(truncate_paths_conditionals, zip(predicted_paths, predicted_path_scores, itertools.repeat(threshold, B))))

batch_truncate_paths_marginals(predicted_paths, predicted_path_scores, threshold=0.25)

Applies marginal probability truncation to a batch of path lists.

This function maps the truncate_paths_marginals function over a batch of predicted paths and scores.

Parameters:

predicted_paths : list[list[list[int]]] (required)
    A batch of path lists. Each item in the outer list corresponds to an item in the batch.
predicted_path_scores : list[list[Tensor]] (required)
    A batch of score lists, corresponding to predicted_paths.
threshold : float (default: 0.25)
    The probability threshold to use for truncation.

Returns:

list[tuple[list[list[int]], list[Tensor]]]
    A list of tuples, where each tuple contains the truncated paths and scores for an item in the batch.

Examples:

>>> import torch
>>> paths_batch = [[[4, 2], [4, 6], [4, 5], [4, 7], [4, 2]], [[4, 2], [4, 6], [4, 5], [4, 7], [4, 2]]]
>>> scores_batch = [[torch.tensor([0.9896, 0.5891]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([0.5412, 0.4371]), torch.tensor([0.5001, 0.0830])], [torch.tensor([0.9896, 0.5891]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([0.5412, 0.4371]), torch.tensor([0.5001, 0.0830])]]
>>> batch_truncate_paths_marginals(paths_batch, scores_batch, 0.589)
[([[4], [4, 6], [4, 5], [], []], [tensor([0.9896]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765]), tensor([]), tensor([])]), ([[4], [4, 6], [4, 5], [], []], [tensor([0.9896]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765]), tensor([]), tensor([])])]
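
Note the contrast with batch_truncate_paths_conditionals on the same inputs: conditional truncation keeps [4, 2] because both conditionals (0.9896 and 0.5891) clear the 0.589 threshold, whereas marginal truncation cuts it back to [4] because the cumulative product 0.9896 × 0.5891 ≈ 0.583 falls below it.
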
Source code in hierarchical_loss/path_utils.py
def batch_truncate_paths_marginals(predicted_paths: list[list[list[int]]], predicted_path_scores: list[list[torch.Tensor]], threshold: float = 0.25) -> list[tuple[list[list[int]], list[torch.Tensor]]]:
    """Applies marginal probability truncation to a batch of path lists.

    This function maps the `truncate_paths_marginals` function over a
    batch of predicted paths and scores.

    Parameters
    ----------
    predicted_paths : list[list[list[int]]]
        A batch of path lists. Each item in the outer list corresponds to
        an item in the batch.
    predicted_path_scores : list[list[torch.Tensor]]
        A batch of score lists, corresponding to `predicted_paths`.
    threshold : float, optional
        The probability threshold to use for truncation, by default 0.25.

    Returns
    -------
    list[tuple[list[list[int]], list[torch.Tensor]]]
        A list of tuples, where each tuple contains the truncated paths and
        scores for an item in the batch.

    Examples
    --------
    >>> import torch
    >>> paths_batch = [[[4, 2], [4, 6], [4, 5], [4, 7], [4, 2]], [[4, 2], [4, 6], [4, 5], [4, 7], [4, 2]]]
    >>> scores_batch = [[torch.tensor([0.9896, 0.5891]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([0.5412, 0.4371]), torch.tensor([0.5001, 0.0830])], [torch.tensor([0.9896, 0.5891]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([0.5412, 0.4371]), torch.tensor([0.5001, 0.0830])]]
    >>> batch_truncate_paths_marginals(paths_batch, scores_batch, 0.589)
    [([[4], [4, 6], [4, 5], [], []], [tensor([0.9896]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765]), tensor([]), tensor([])]), ([[4], [4, 6], [4, 5], [], []], [tensor([0.9896]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765]), tensor([]), tensor([])])]
    """
    B = len(predicted_paths)
    return list(itertools.starmap(truncate_paths_marginals, zip(predicted_paths, predicted_path_scores, itertools.repeat(threshold, B))))

construct_parent_childset_tree(tree)

Converts a {child: parent} tree into a {parent: set[children]} tree.

This function inverts the standard {child: parent} structure, creating a dictionary for navigating the tree top-down.

Parameters:

tree : dict[Hashable, Hashable] (required)
    A tree in {child: parent} format.

Returns:

dict
    A nested dictionary representing the tree in a top-down format, e.g., {parent: set[children]}.

Examples:

>>> childparent_tree = {0:1, 1:2, 3:2, 4:5}
>>> construct_parent_childset_tree(childparent_tree)
{1: {0}, 2: {1, 3}, 5: {4}}
Source code in hierarchical_loss/path_utils.py
def construct_parent_childset_tree(tree: dict[Hashable, Hashable]) -> dict[Hashable, set]:
    """Converts a {child: parent} tree into a {parent: set[children]} tree.

    This function inverts the standard {child: parent} structure, creating
    a dictionary for navigating the tree top-down.

    Parameters
    ----------
    tree : dict[Hashable, Hashable]
        A tree in {child: parent} format.

    Returns
    -------
    dict
        A nested dictionary representing the tree in a top-down format,
        e.g., `{parent: set[children]}`.

    Examples
    --------
    >>> childparent_tree = {0:1, 1:2, 3:2, 4:5}
    >>> construct_parent_childset_tree(childparent_tree)
    {1: {0}, 2: {1, 3}, 5: {4}}
    """
    parent_childset_tree = {}
    for child, parent in tree.items():
        if parent not in parent_childset_tree:
            parent_childset_tree[parent] = set()
        parent_childset_tree[parent].add(child)
    return parent_childset_tree

construct_parent_childtensor_tree(tree, device=None)

Converts a {child: parent} tree into a {parent: tensor[children]} tree.

This function inverts the standard {child: parent} structure, creating a dictionary for navigating the tree top-down.

Parameters:

tree : dict[Hashable, Hashable] (required)
    A tree in {child: parent} format.
device : optional (default: None)
    The device on which to allocate the child tensors; passed to torch.tensor.

Returns:

dict
    A nested dictionary representing the tree in a top-down format, e.g., {parent: tensor[children]}.

Examples:

>>> childparent_tree = {0:1, 1:2, 3:2, 4:5}
>>> construct_parent_childtensor_tree(childparent_tree)
{1: tensor([0]), 2: tensor([1, 3]), 5: tensor([4])}
Source code in hierarchical_loss/path_utils.py
def construct_parent_childtensor_tree(tree: dict[Hashable, Hashable], device=None) -> dict[Hashable, torch.Tensor]:
    """Converts a {child: parent} tree into a {parent: tensor[children]} tree.

    This function inverts the standard {child: parent} structure, creating
    a dictionary for navigating the tree top-down.

    Parameters
    ----------
    tree : dict[Hashable, Hashable]
        A tree in {child: parent} format.

    Returns
    -------
    dict
        A nested dictionary representing the tree in a top-down format,
        e.g., `{parent: tensor[children]}`.

    Examples
    --------
    >>> childparent_tree = {0:1, 1:2, 3:2, 4:5}
    >>> construct_parent_childtensor_tree(childparent_tree)
    {1: tensor([0]), 2: tensor([1, 3]), 5: tensor([4])}
    """
    childset_tree = construct_parent_childset_tree(tree)
    return {k: torch.tensor(list(v), device=device) for k, v in childset_tree.items()}

filter_empty_paths(predicted_boxes, predicted_paths, predicted_path_scores)

Filters out predictions with empty paths.

After truncation, some paths may become empty. This function removes those empty paths along with their corresponding scores and bounding boxes.

Parameters:

predicted_boxes : Tensor (required)
    A 2D tensor of bounding box predictions, where columns correspond to individual predictions (e.g., shape [4, N]).
predicted_paths : list[list[int]] (required)
    A list of predicted paths.
predicted_path_scores : list[Tensor] (required)
    A list of predicted path scores.

Returns:

tuple[Tensor, list[list[int]], list[Tensor]]
    A tuple containing the filtered boxes, paths, and scores, with empty path predictions removed.

Examples:

>>> import torch
>>> boxes = torch.tensor([[482.27, 395.77, 241.98, 359.60, 258.38], [8.11, 156.87, 152.91, 335.40, 24.81], [610.42, 429.38, 307.70, 382.68, 413.79], [103.86, 200.93, 197.57, 352.40, 197.61]])
>>> paths = [[4], [4, 6], [4, 5], [], []]
>>> scores = [torch.tensor([0.9896]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([]), torch.tensor([])]
>>> f_boxes, f_paths, f_scores = filter_empty_paths(boxes, paths, scores)
>>> f_boxes
tensor([[482.2700, 395.7700, 241.9800],
        [  8.1100, 156.8700, 152.9100],
        [610.4200, 429.3800, 307.7000],
        [103.8600, 200.9300, 197.5700]])
>>> f_paths
[[4], [4, 6], [4, 5]]
>>> f_scores
[tensor([0.9896]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765])]
Source code in hierarchical_loss/path_utils.py
def filter_empty_paths(predicted_boxes: torch.Tensor, predicted_paths: list[list[int]], predicted_path_scores: list[torch.Tensor]) -> tuple[torch.Tensor, list[list[int]], list[torch.Tensor]]:
    """Filters out predictions with empty paths.

    After truncation, some paths may become empty. This function removes
    those empty paths along with their corresponding scores and bounding
    boxes.

    Parameters
    ----------
    predicted_boxes : torch.Tensor
        A 2D tensor of bounding box predictions, where columns correspond
        to individual predictions (e.g., shape [4, N]).
    predicted_paths : list[list[int]]
        A list of predicted paths.
    predicted_path_scores : list[torch.Tensor]
        A list of predicted path scores.

    Returns
    -------
    tuple[torch.Tensor, list[list[int]], list[torch.Tensor]]
        A tuple containing the filtered boxes, paths, and scores,
        with empty path predictions removed.

    Examples
    --------
    >>> import torch
    >>> boxes = torch.tensor([[482.27, 395.77, 241.98, 359.60, 258.38], [8.11, 156.87, 152.91, 335.40, 24.81], [610.42, 429.38, 307.70, 382.68, 413.79], [103.86, 200.93, 197.57, 352.40, 197.61]])
    >>> paths = [[4], [4, 6], [4, 5], [], []]
    >>> scores = [torch.tensor([0.9896]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([]), torch.tensor([])]
    >>> f_boxes, f_paths, f_scores = filter_empty_paths(boxes, paths, scores)
    >>> f_boxes
    tensor([[482.2700, 395.7700, 241.9800],
            [  8.1100, 156.8700, 152.9100],
            [610.4200, 429.3800, 307.7000],
            [103.8600, 200.9300, 197.5700]])
    >>> f_paths
    [[4], [4, 6], [4, 5]]
    >>> f_scores
    [tensor([0.9896]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765])]
    """
    keep_idx = [i for i, path in enumerate(predicted_paths) if len(path) > 0]
    return (
        predicted_boxes[:,keep_idx],
        [predicted_paths[k] for k in keep_idx],
        [predicted_path_scores[k] for k in keep_idx]
    )

find_closest_permitted_parent(node, tree, permitted_nodes)

Finds the first ancestor of a node that is in a permitted set.

This function walks up the ancestral chain of a node (using the {child: parent} tree) and returns the first ancestor it finds that is present in the permitted_nodes set.

If no ancestor is in the permitted set, or if the node is not in the tree to begin with, it returns None. Only strict ancestors are checked; the node's own membership in the set is not considered.

Parameters:

node : Hashable (required)
    The ID of the node to start searching from.
tree : dict[Hashable, Hashable] (required)
    A tree in {child: parent} format.
permitted_nodes : set[Hashable] (required)
    A set of node IDs that are considered "permitted".

Returns:

Hashable | None
    The ID of the closest permitted ancestor, or None if none is found.

Examples:

>>> tree = {1: 2, 2: 3, 3: 4, 4: 5}
>>> permitted = {0, 2, 5}
>>> find_closest_permitted_parent(1, tree, permitted) # 1 -> 2 (permitted)
2
>>> find_closest_permitted_parent(0, tree, permitted) # 0 is not in tree keys, returns None
>>> tree[0] = 1 # Add 0 to the tree
>>> find_closest_permitted_parent(0, tree, permitted) # 0 -> 1 -> 2 (permitted)
2
>>> tree = {10: 20, 20: 30, 30: 40}
>>> find_closest_permitted_parent(10, tree, {50, 60}) # No permitted ancestors, returns None
Source code in hierarchical_loss/tree_utils.py
def find_closest_permitted_parent(
    node: Hashable,
    tree: dict[Hashable, Hashable],
    permitted_nodes: set[Hashable],
) -> Hashable | None:
    """Finds the first ancestor of a node that is in a permitted set.

    This function walks up the ancestral chain of a node (using the
    {child: parent} tree) and returns the first ancestor it finds
    that is present in the `permitted_nodes` set.

    If no ancestor is in the permitted set, or if the node is not in
    the tree to begin with, it returns None. Only strict ancestors are
    checked; the node's own membership in the set is not considered.

    Parameters
    ----------
    node : Hashable
        The ID of the node to start searching from.
    tree : dict[Hashable, Hashable]
        A tree in {child: parent} format.
    permitted_nodes : set[Hashable]
        A set of node IDs that are considered "permitted".

    Returns
    -------
    Hashable | None
        The ID of the closest permitted ancestor, or None if none is found.

    Examples
    --------
    >>> tree = {1: 2, 2: 3, 3: 4, 4: 5}
    >>> permitted = {0, 2, 5}
    >>> find_closest_permitted_parent(1, tree, permitted) # 1 -> 2 (permitted)
    2
    >>> find_closest_permitted_parent(0, tree, permitted) # 0 is not in tree keys, returns None
    >>> tree[0] = 1 # Add 0 to the tree
    >>> find_closest_permitted_parent(0, tree, permitted) # 0 -> 1 -> 2 (permitted)
    2
    >>> tree = {10: 20, 20: 30, 30: 40}
    >>> find_closest_permitted_parent(10, tree, {50, 60}) # No permitted ancestors, returns None
    """
    if node not in tree:
        return None
    parent = tree[node]
    while parent not in permitted_nodes:
        if parent in tree:
            parent = tree[parent]
        else:
            return None
    return parent

get_ancestor_chain_lens(tree)

Get lengths of ancestor chains in a { child: parent } dictionary tree

Examples:

>>> get_ancestor_chain_lens({ 0:1, 1:2, 2:3, 4:5, 5:6, 7:8 })
{3: 1, 2: 2, 1: 3, 0: 4, 6: 1, 5: 2, 4: 3, 8: 1, 7: 2}

Parameters:

tree : dict[Hashable, Hashable] (required)
    A tree in { child: parent } format.

Returns:

lengths : dict[Hashable, int]
    The lengths of the path to the root from each node { node: length }.

Source code in hierarchical_loss/tree_utils.py
def get_ancestor_chain_lens(tree: dict[Hashable, Hashable]) -> dict[Hashable, int]:
    '''
    Get lengths of ancestor chains in a { child: parent } dictionary tree

    Examples
    --------
    >>> get_ancestor_chain_lens({ 0:1, 1:2, 2:3, 4:5, 5:6, 7:8 })
    {3: 1, 2: 2, 1: 3, 0: 4, 6: 1, 5: 2, 4: 3, 8: 1, 7: 2}

    Parameters
    ----------
    tree: dict[Hashable, Hashable]
        A tree in { child: parent } format.

    Returns
    -------
    lengths: dict[Hashable, int]
        The lengths of the path to the root from each node { node: length }

    '''
    return preorder_apply(tree, _increment_chain_len)

get_roots(tree)

Finds all root nodes in a {child: parent} tree.

A root node is defined as any node that is not a child of another node in the tree (i.e., its ancestor chain length is 1).

Parameters:

tree : dict[Hashable, Hashable] (required)
    A tree in {child: parent} format.

Returns:

list[Hashable]
    A list of all root nodes.

Examples:

>>> tree = {0: 1, 1: 2, 3: 2, 5: 6}
>>> get_roots(tree) # Roots are 2 and 6
[2, 6]
Source code in hierarchical_loss/tree_utils.py
def get_roots(tree: dict[Hashable, Hashable]) -> list[Hashable]:
    """Finds all root nodes in a {child: parent} tree.

    A root node is defined as any node that is not a child of another
    node in the tree (i.e., its ancestor chain length is 1).

    Parameters
    ----------
    tree : dict[Hashable, Hashable]
        A tree in {child: parent} format.

    Returns
    -------
    list[Hashable]
        A list of all root nodes.

    Examples
    --------
    >>> tree = {0: 1, 1: 2, 3: 2, 5: 6}
    >>> get_roots(tree) # Roots are 2 and 6
    [2, 6]
    """
    ancestor_chain_lens = get_ancestor_chain_lens(tree)
    return [node for node in ancestor_chain_lens if ancestor_chain_lens[node] == 1]

invert_childparent_tree(tree)

Converts a {child: parent} tree into a nested {parent: {child: ...}} tree.

This function inverts the standard {child: parent} structure, creating a nested dictionary that starts from the root(s). It uses preorder_apply to traverse the tree top-down and build the nested structure.

Parameters:

tree : dict[Hashable, Hashable] (required)
    A tree in {child: parent} format.

Returns:

dict
    A nested dictionary representing the tree in a top-down format, e.g., {root: {child: {grandchild: {}}}}.

Examples:

>>> tree = {0: 1, 1: 2, 3: 2, 5: 6} # 0->1->2, 3->2, 5->6
>>> invert_childparent_tree(tree)
{2: {1: {0: {}}, 3: {}}, 6: {5: {}}}
Source code in hierarchical_loss/tree_utils.py
def invert_childparent_tree(tree: dict[Hashable, Hashable]) -> dict:
    """Converts a {child: parent} tree into a nested {parent: {child: ...}} tree.

    This function inverts the standard {child: parent} structure, creating
    a nested dictionary that starts from the root(s). It uses
    `preorder_apply` to traverse the tree top-down and build the
    nested structure.

    Parameters
    ----------
    tree : dict[Hashable, Hashable]
        A tree in {child: parent} format.

    Returns
    -------
    dict
        A nested dictionary representing the tree in a top-down format,
        e.g., `{root: {child: {grandchild: {}}}}`.

    Examples
    --------
    >>> tree = {0: 1, 1: 2, 3: 2, 5: 6} # 0->1->2, 3->2, 5->6
    >>> invert_childparent_tree(tree)
    {2: {1: {0: {}}, 3: {}}, 6: {5: {}}}
    """
    parentchild_tree = {}
    preorder_apply(tree, _append_to_parentchild_tree, parentchild_tree)
    return parentchild_tree

optimal_hierarchical_path(class_scores, inverted_tree, roots)

Finds optimal paths and extracts their corresponding scores.

This function wraps get_optimal_ancestral_chain to find the single best greedy path for each detection, and then gathers the raw scores associated with each node in those paths.

Parameters:

class_scores : list[Tensor] (required)
    A list of confidence tensors, one per batch item. Each tensor should have shape (C, N), where C is the number of classes and N is the number of detections.
inverted_tree : dict[int, Tensor] (required)
    The class hierarchy in {parent_id: tensor([child1, child2, ...])} format.
roots : Tensor (required)
    A 1D tensor containing the integer IDs of the root nodes.

Returns:

tuple[list[list[list[int]]], list[list[Tensor]]]
    A tuple containing two items:
    1. optimal_paths: The nested list of paths, as returned by get_optimal_ancestral_chain.
    2. optimal_path_scores: A nested list of the same structure, but containing 1D tensors of the scores for each path.

Examples:

>>> hierarchy = {1: 0, 2: 0, 3: 1, 4: 1, 5: 2, 6: 2}
>>> # C=7 classes, N=2 detections, B=1 batch item
>>> # Scores are shaped (C, N)
>>> scores = torch.tensor([
...     [10., 10.],  # 0 (Root)
...     [ 5.,  1.],  # 1 (Child of 0)
...     [ 1.,  5.],  # 2 (Child of 0)
...     [ 2.,  0.],  # 3 (Child of 1)
...     [ 8.,  0.],  # 4 (Child of 1)
...     [ 0.,  8.],  # 5 (Child of 2)
...     [ 0.,  2.]   # 6 (Child of 2)
... ], dtype=torch.float32)
>>> class_scores = [scores]
>>> inverted_tree = construct_parent_childtensor_tree(hierarchy, device=class_scores[0].device)
>>> roots = torch.tensor(get_roots(hierarchy), device=class_scores[0].device)
>>> paths, path_scores = optimal_hierarchical_path(class_scores, inverted_tree, roots)
>>> paths
[[[0, 1, 4], [0, 2, 5]]]
>>> path_scores
[[tensor([10.,  5.,  8.]), tensor([10.,  5.,  8.])]]
Source code in hierarchical_loss/path_utils.py
def optimal_hierarchical_path(class_scores: list[torch.Tensor], inverted_tree: dict[int, torch.Tensor], roots: torch.Tensor) -> tuple[list[list[list[int]]], list[list[torch.Tensor]]]:
    """Finds optimal paths and extracts their corresponding scores.

    This function wraps `get_optimal_ancestral_chain` to find the
    single best greedy path for each detection, and then gathers
    the raw scores associated with each node in those paths.

    Parameters
    ----------
    class_scores : list[torch.Tensor]
        A list of confidence tensors, one per batch item. Each tensor
        should have shape (C, N), where C is the number of classes
        and N is the number of detections.
    inverted_tree : dict[int, torch.Tensor]
        The class hierarchy in `{parent_id: tensor([child1, child2, ...])}`
        format.
    roots : torch.Tensor
        A 1D tensor containing the integer IDs of the root nodes.

    Returns
    -------
    tuple[list[list[list[int]]], list[list[torch.Tensor]]]
        A tuple containing two items:
        1. `optimal_paths`: The nested list of paths, as returned
           by `get_optimal_ancestral_chain`.
        2. `optimal_path_scores`: A nested list of the same structure,
           but containing 1D tensors of the scores for each path.

    Examples
    --------
    >>> hierarchy = {1: 0, 2: 0, 3: 1, 4: 1, 5: 2, 6: 2}
    >>> # C=7 classes, N=2 detections, B=1 batch item
    >>> # Scores are shaped (C, N)
    >>> scores = torch.tensor([
    ...     [10., 10.],  # 0 (Root)
    ...     [ 5.,  1.],  # 1 (Child of 0)
    ...     [ 1.,  5.],  # 2 (Child of 0)
    ...     [ 2.,  0.],  # 3 (Child of 1)
    ...     [ 8.,  0.],  # 4 (Child of 1)
    ...     [ 0.,  8.],  # 5 (Child of 2)
    ...     [ 0.,  2.]   # 6 (Child of 2)
    ... ], dtype=torch.float32)
    >>> class_scores = [scores]
    >>> inverted_tree = construct_parent_childtensor_tree(hierarchy, device=class_scores[0].device)
    >>> roots = torch.tensor(get_roots(hierarchy), device=class_scores[0].device)
    >>> paths, path_scores = optimal_hierarchical_path(class_scores, inverted_tree, roots)
    >>> paths
    [[[0, 1, 4], [0, 2, 5]]]
    >>> path_scores
    [[tensor([10.,  5.,  8.]), tensor([10.,  5.,  8.])]]
    """
    bpaths = []
    bscores = []
    for b, confidence in enumerate(class_scores):
        paths = []
        scores = []
        for i in range(confidence.shape[1]):
            confidence_row = confidence[..., i]
            path = []
            path_score = []
            siblings = roots
            while siblings is not None:
                best = confidence_row.index_select(0, siblings).argmax()
                best_node_id = int(siblings[best])
                path.append(best_node_id)
                path_score.append(confidence_row[best_node_id])
                siblings = inverted_tree[best_node_id] if best_node_id in inverted_tree else None
            paths.append(path)
            scores.append(torch.stack(path_score))
        bpaths.append(paths)
        bscores.append(scores)
    return bpaths, bscores

optimal_hierarchical_paths(class_scores, hierarchy)

Deprecated since version 0.X.X: This function is deprecated as it re-computes the hierarchy on every call, causing a performance bottleneck. Use a Hierarchy object to pre-compute the inverted_tree and roots, and then call optimal_hierarchical_path directly.
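
A minimal migration sketch using the helpers documented on this page (it mirrors the deprecated function's own body; the Hierarchy object mentioned above is not shown here, only the pre-compute step it would wrap):

>>> import torch
>>> hierarchy = {1: 0, 2: 0, 3: 1, 4: 1, 5: 2, 6: 2}
>>> class_scores = [torch.rand(7, 2)]  # (C, N) confidences for one batch item; dummy values
>>> # Pre-compute these once, instead of on every call:
>>> inverted_tree = construct_parent_childtensor_tree(hierarchy, device=class_scores[0].device)
>>> roots = torch.tensor(get_roots(hierarchy), device=class_scores[0].device)
>>> paths, path_scores = optimal_hierarchical_path(class_scores, inverted_tree, roots)
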

Source code in hierarchical_loss/path_utils.py
def optimal_hierarchical_paths(class_scores: list[torch.Tensor], hierarchy: dict[int, int]) -> tuple[list[list[list[int]]], list[list[torch.Tensor]]]:
    """
    .. deprecated:: 0.X.X
       This function is deprecated as it re-computes the hierarchy
       on every call, causing a performance bottleneck.
       Use a `Hierarchy` object to pre-compute the `inverted_tree`
       and `roots`, and then call `optimal_hierarchical_path` directly.
    """
    inverted_tree = construct_parent_childtensor_tree(hierarchy, device=class_scores[0].device)
    roots = torch.tensor(get_roots(hierarchy), device=class_scores[0].device)
    return optimal_hierarchical_path(class_scores, inverted_tree, roots)

preorder_apply(tree, f, *args)

Applies a function to all nodes in a tree in a pre-order (top-down) fashion.

This function works by first finding an ancestor path (from leaf to root). It then applies the function f to the root (or highest unvisited node) and iterates down the path, applying f to each child and passing in the result from its parent. This top-down application is a pre-order traversal.

It uses memoization (the visited dict) to ensure that f is applied to each node only once, even in multi-branch trees.

Parameters:

tree : dict[Hashable, Hashable] (required)
    The hierarchy tree, in {child: parent} format.
f : Callable (required)
    The function to apply to each node. Its signature must be f(node: Hashable, parent_result: Any, *args: Any) -> Any.
*args : Any
    Additional positional arguments to be passed to every call of f.

Returns:

dict[Hashable, Any]
    A dictionary mapping each node ID to the result of f(node, ...).

Examples:

>>> # Example: Calculate node depth (pre-order calculation)
>>> tree = {0: 1, 1: 2, 3: 2} # 0->1->2, 3->2
>>> def f(node, parent_depth):
...     # parent_depth is the result from the parent node
...     return 1 if parent_depth is None else parent_depth + 1
...
>>> preorder_apply(tree, f)
{2: 1, 1: 2, 0: 3, 3: 2}
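
A further illustrative sketch (not from the library's own docstring): because f receives its parent's result, the same traversal can accumulate full root-to-node paths.

>>> tree = {0: 1, 1: 2, 3: 2}
>>> def collect_path(node, parent_path):
...     # parent_path is the result already computed for this node's parent
...     return [node] if parent_path is None else parent_path + [node]
...
>>> preorder_apply(tree, collect_path)
{2: [2], 1: [2, 1], 0: [2, 1, 0], 3: [2, 3]}
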
Source code in hierarchical_loss/tree_utils.py
def preorder_apply(tree: dict[Hashable, Hashable], f: Callable, *args: Any) -> dict[Hashable, Any]:
    """Applies a function to all nodes in a tree in a pre-order (top-down) fashion.

    This function works by first finding an ancestor path (from leaf to root).
    It then applies the function `f` to the root (or highest unvisited node)
    and iterates *down* the path, applying `f` to each child and passing in
    the result from its parent. This top-down application is a pre-order
    traversal.

    It uses memoization (the `visited` dict) to ensure that `f` is
    applied to each node only once, even in multi-branch trees.

    Parameters
    ----------
    tree : dict[Hashable, Hashable]
        The hierarchy tree, in {child: parent} format.
    f : Callable
        The function to apply to each node. Its signature must be
        `f(node: Hashable, parent_result: Any, *args: Any) -> Any`.
    *args: Any
        Additional positional arguments to be passed to every call of `f`.

    Returns
    -------
    dict[Hashable, Any]
        A dictionary mapping each node ID to the result of `f(node, ...)`.

    Examples
    --------
    >>> # Example: Calculate node depth (pre-order calculation)
    >>> tree = {0: 1, 1: 2, 3: 2} # 0->1->2, 3->2
    >>> def f(node, parent_depth):
    ...     # parent_depth is the result from the parent node
    ...     return 1 if parent_depth is None else parent_depth + 1
    ...
    >>> preorder_apply(tree, f)
    {2: 1, 1: 2, 0: 3, 3: 2}
    """
    visited = {}
    for node in tree:
        path = [node]
        while (node in tree) and (node not in visited):
            node = tree[node]
            path.append(node)
        if node not in visited:
            visited[node] = f(node, None, *args)
        for i in range(-2, -len(path) - 1, -1):
            visited[path[i]] = f(path[i], visited[path[i+1]], *args)
    return visited

tree_walk(tree, node)

Walks up the ancestor chain from a starting node.

This generator yields the starting node first, then its parent, its grandparent, and so on, until a root (a node not present as a key in the tree) is reached.

Parameters:

tree : dict[Hashable, Hashable] (required)
    The hierarchy tree, in {child: parent} format.
node : Hashable (required)
    The node to start the walk from.

Yields:

Iterator[Hashable]
    An iterator of node IDs in the ancestor chain, starting with the given node.

Examples:

>>> tree = {0: 1, 1: 2, 3: 4, 4: 2}
>>> list(tree_walk(tree, 0))
[0, 1, 2]
>>> list(tree_walk(tree, 3))
[3, 4, 2]
>>> list(tree_walk(tree, 2))
[2]
Source code in hierarchical_loss/tree_utils.py
def tree_walk(tree: dict[Hashable, Hashable], node: Hashable) -> Iterator[Hashable]:
    """Walks up the ancestor chain from a starting node.

    This generator yields the starting node first, then its parent,
    its grandparent, and so on, until a root (a node not
    present as a key in the tree) is reached.

    Parameters
    ----------
    tree : dict[Hashable, Hashable]
        The hierarchy tree, in {child: parent} format.
    node : Hashable
        The node to start the walk from.

    Yields
    ------
    Iterator[Hashable]
        An iterator of node IDs in the ancestor chain, starting
        with the given node.

    Examples
    --------
    >>> tree = {0: 1, 1: 2, 3: 4, 4: 2}
    >>> list(tree_walk(tree, 0))
    [0, 1, 2]
    >>> list(tree_walk(tree, 3))
    [3, 4, 2]
    >>> list(tree_walk(tree, 2))
    [2]
    """
    yield node
    while node in tree:
        node = tree[node]
        yield node

trim_childparent_tree(tree, permitted_nodes)

Trims a {child: parent} tree to only include permitted nodes.

This function first remaps every node in the tree to its closest permitted ancestor. It then filters this map, keeping only the entries whose key is in the permitted_nodes set and whose closest permitted ancestor exists.

The result is a new {child: parent} tree containing only permitted nodes, each mapped to its closest permitted ancestor (itself a permitted node; entries with no permitted ancestor are dropped).

Parameters:

tree : dict[Hashable, Hashable] (required)
    A tree in {child: parent} format.
permitted_nodes : set[Hashable] (required)
    A set of node IDs to keep.

Returns:

dict[Hashable, Hashable | None]
    A new {child: parent} tree containing only permitted nodes, each re-mapped to its closest permitted ancestor.

Examples:

>>> tree = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5} # 0->1->2->3->4->5
>>> permitted = {0, 2, 5} # 0, 2, and 5 are permitted
>>> trim_childparent_tree(tree, permitted)
{0: 2, 2: 5}
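
A second illustrative example (not part of the library's docstring): a permitted node with no permitted ancestor is dropped entirely.

>>> trim_childparent_tree({0: 1, 1: 2}, {0, 1})
{0: 1}
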
Source code in hierarchical_loss/tree_utils.py
def trim_childparent_tree(
    tree: dict[Hashable, Hashable], permitted_nodes: set[Hashable]
) -> dict[Hashable, Hashable | None]:
    """Trims a {child: parent} tree to only include permitted nodes.

    This function first remaps every node in the tree to its closest
    permitted ancestor. It then filters this map, keeping only the
    entries whose key is *also* in the `permitted_nodes` set and whose
    closest permitted ancestor exists.

    The result is a new {child: parent} tree containing *only*
    permitted nodes, each mapped to its closest permitted ancestor
    (itself a permitted node; entries with no permitted ancestor
    are dropped).

    Parameters
    ----------
    tree : dict[Hashable, Hashable]
        A tree in {child: parent} format.
    permitted_nodes : set[Hashable]
        A set of node IDs to keep.

    Returns
    -------
    dict[Hashable, Hashable | None]
        A new {child: parent} tree containing only permitted nodes,
        each re-mapped to its closest permitted ancestor.

    Examples
    --------
    >>> tree = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5} # 0->1->2->3->4->5
    >>> permitted = {0, 2, 5} # 0, 2, and 5 are permitted
    >>> trim_childparent_tree(tree, permitted)
    {0: 2, 2: 5}
    """
    new_tree = {}
    for node in tree:
        closest_permitted_parent = find_closest_permitted_parent(node, tree, permitted_nodes)
        new_tree[node] = closest_permitted_parent
    for node in list(new_tree.keys()):
        if new_tree[node] is None or (node not in permitted_nodes):
            del new_tree[node]
    return new_tree

truncate_path_conditionals(path, score, threshold=0.25)

Truncates a path based on a conditional probability threshold.

This function iterates through a path and its corresponding conditional probabilities, stopping at the first element where the probability is below the given threshold.

Parameters:

path : list[int] (required)
    A list of category indices representing the path.
score : Tensor (required)
    A 1D tensor where each element is the conditional probability of the corresponding category in the path.
threshold : float (default: 0.25)
    The probability threshold below which to truncate.

Returns:

tuple[list[int], Tensor]
    A tuple containing the truncated path and its corresponding scores.

Examples:

>>> import torch
>>> path = [4, 7]
>>> score = torch.tensor([0.5412, 0.4371])
>>> truncate_path_conditionals(path, score, threshold=0.589)
([], tensor([]))
>>> path = [4, 2]
>>> score = torch.tensor([0.9896, 0.5891])
>>> truncate_path_conditionals(path, score, threshold=0.589)
([4, 2], tensor([0.9896, 0.5891]))
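
In the first example the very first conditional, 0.5412, is already below the 0.589 threshold, so nothing survives; in the second, both conditionals (0.9896 and 0.5891) clear the threshold, so the full path is kept.
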
Source code in hierarchical_loss/path_utils.py
def truncate_path_conditionals(path: list[int], score: torch.Tensor, threshold: float = 0.25) -> tuple[list[int], torch.Tensor]:
    """Truncates a path based on a conditional probability threshold.

    This function iterates through a path and its corresponding conditional
    probabilities, stopping at the first element where the probability
    is below the given threshold.

    Parameters
    ----------
    path : list[int]
        A list of category indices representing the path.
    score : torch.Tensor
        A 1D tensor where each element is the conditional probability
        of the corresponding category in the path.
    threshold : float, optional
        The probability threshold below which to truncate, by default 0.25.

    Returns
    -------
    tuple[list[int], torch.Tensor]
        A tuple containing the truncated path and its corresponding scores.

    Examples
    --------
    >>> import torch
    >>> path = [4, 7]
    >>> score = torch.tensor([0.5412, 0.4371])
    >>> truncate_path_conditionals(path, score, threshold=0.589)
    ([], tensor([]))
    >>> path = [4, 2]
    >>> score = torch.tensor([0.9896, 0.5891])
    >>> truncate_path_conditionals(path, score, threshold=0.589)
    ([4, 2], tensor([0.9896, 0.5891]))
    """
    truncated_path, truncated_score = [], []
    for category, p in zip(path, score):
        if p < threshold:
            break
        truncated_path.append(category)
    return truncated_path, score[:len(truncated_path)]

truncate_path_marginals(path, score, threshold=0.25)

Truncates a path based on a marginal probability threshold.

This function iterates through a path, calculating the cumulative product (marginal probability) of the scores. It stops at the first element where this cumulative product falls below the given threshold.

Parameters:

path : list[int] (required)
    A list of category indices representing the path.
score : Tensor (required)
    A 1D tensor where each element is the conditional probability of the corresponding category in the path.
threshold : float (default: 0.25)
    The probability threshold below which to truncate.

Returns:

tuple[list[int], Tensor]
    A tuple containing the truncated path and its corresponding scores.

Examples:

>>> import torch
>>> path = [4, 2]
>>> score = torch.tensor([0.9896, 0.5891])
>>> truncate_path_marginals(path, score, threshold=0.589)
([4], tensor([0.9896]))
>>> path = [4, 6]
>>> score = torch.tensor([0.9246, 0.7684])
>>> truncate_path_marginals(path, score, threshold=0.589)
([4, 6], tensor([0.9246, 0.7684]))
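
In the first example the cumulative products are 0.9896 and then 0.9896 × 0.5891 ≈ 0.583; the second falls below the 0.589 threshold, so the path is cut back to [4]. In the second example both cumulative products (0.9246 and 0.9246 × 0.7684 ≈ 0.710) stay above the threshold, so the full path is kept.
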
Source code in hierarchical_loss/path_utils.py
def truncate_path_marginals(path: list[int], score: torch.Tensor, threshold: float = 0.25) -> tuple[list[int], torch.Tensor]:
    """Truncates a path based on a marginal probability threshold.

    This function iterates through a path, calculating the cumulative
    product (marginal probability) of the scores. It stops at the first
    element where this cumulative product falls below the given threshold.

    Parameters
    ----------
    path : list[int]
        A list of category indices representing the path.
    score : torch.Tensor
        A 1D tensor where each element is the conditional probability
        of the corresponding category in the path.
    threshold : float, optional
        The probability threshold below which to truncate, by default 0.25.

    Returns
    -------
    tuple[list[int], torch.Tensor]
        A tuple containing the truncated path and its corresponding scores.

    Examples
    --------
    >>> import torch
    >>> path = [4, 2]
    >>> score = torch.tensor([0.9896, 0.5891])
    >>> truncate_path_marginals(path, score, threshold=0.589)
    ([4], tensor([0.9896]))
    >>> path = [4, 6]
    >>> score = torch.tensor([0.9246, 0.7684])
    >>> truncate_path_marginals(path, score, threshold=0.589)
    ([4, 6], tensor([0.9246, 0.7684]))
    """
    truncated_path, truncated_score = [], []
    marginal_p = 1
    for category, p in zip(path, score):
        marginal_p *= p
        if marginal_p < threshold:
            break
        truncated_path.append(category)
    return truncated_path, score[:len(truncated_path)]

truncate_paths_conditionals(predicted_paths, predicted_path_scores, threshold=0.25)

Applies conditional probability truncation to a list of paths.

This function iterates through lists of paths and scores, applying the truncate_path_conditionals function to each path-score pair.

Parameters:

predicted_paths : list[list[int]] (required)
    A list of paths, where each path is a list of category indices.
predicted_path_scores : list[Tensor] (required)
    A list of 1D tensors, each corresponding to a path in predicted_paths.
threshold : float (default: 0.25)
    The probability threshold to pass to the truncation function.

Returns:

tuple[list[list[int]], list[Tensor]]
    A tuple containing the list of truncated paths and the list of their corresponding truncated scores.

Examples:

>>> import torch
>>> paths = [[4, 2], [4, 6], [4, 5], [4, 7], [4, 2]]
>>> scores = [torch.tensor([0.9896, 0.5891]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([0.5412, 0.4371]), torch.tensor([0.5001, 0.0830])]
>>> tpaths, tscores = truncate_paths_conditionals(paths, scores, threshold=0.589)
>>> tpaths
[[4, 2], [4, 6], [4, 5], [], []]
>>> tscores
[tensor([0.9896, 0.5891]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765]), tensor([]), tensor([])]
Source code in hierarchical_loss/path_utils.py
def truncate_paths_conditionals(predicted_paths: list[list[int]], predicted_path_scores: list[torch.Tensor], threshold: float = 0.25) -> tuple[list[list[int]], list[torch.Tensor]]:
    """Applies conditional probability truncation to a list of paths.

    This function iterates through lists of paths and scores, applying
    the `truncate_path_conditionals` function to each path-score pair.

    Parameters
    ----------
    predicted_paths : list[list[int]]
        A list of paths, where each path is a list of category indices.
    predicted_path_scores : list[torch.Tensor]
        A list of 1D tensors, each corresponding to a path in `predicted_paths`.
    threshold : float, optional
        The probability threshold to pass to the truncation function,
        by default 0.25.

    Returns
    -------
    tuple[list[list[int]], list[torch.Tensor]]
        A tuple containing the list of truncated paths and the list of
        their corresponding truncated scores.

    Examples
    --------
    >>> import torch
    >>> paths = [[4, 2], [4, 6], [4, 5], [4, 7], [4, 2]]
    >>> scores = [torch.tensor([0.9896, 0.5891]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([0.5412, 0.4371]), torch.tensor([0.5001, 0.0830])]
    >>> tpaths, tscores = truncate_paths_conditionals(paths, scores, threshold=0.589)
    >>> tpaths
    [[4, 2], [4, 6], [4, 5], [], []]
    >>> tscores
    [tensor([0.9896, 0.5891]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765]), tensor([]), tensor([])]
    """
    tpaths, tscores = [], []
    for paths, scores in zip(predicted_paths, predicted_path_scores):
        tpath, tscore = truncate_path_conditionals(paths, scores, threshold=threshold)
        tpaths.append(tpath), tscores.append(tscore)
    return tpaths, tscores

truncate_paths_marginals(predicted_paths, predicted_path_scores, threshold=0.25)

Applies marginal probability truncation to a list of paths.

This function iterates through lists of paths and scores, applying the truncate_path_marginals function to each path-score pair.

Parameters:

predicted_paths : list[list[int]] (required)
    A list of paths, where each path is a list of category indices.
predicted_path_scores : list[Tensor] (required)
    A list of 1D tensors, each corresponding to a path in predicted_paths.
threshold : float (default: 0.25)
    The probability threshold to pass to the truncation function.

Returns:

tuple[list[list[int]], list[Tensor]]
    A tuple containing the list of truncated paths and the list of their corresponding truncated scores.

Examples:

>>> import torch
>>> paths = [[4, 2], [4, 6], [4, 5], [4, 7], [4, 2]]
>>> scores = [torch.tensor([0.9896, 0.5891]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([0.5412, 0.4371]), torch.tensor([0.5001, 0.0830])]
>>> tpaths, tscores = truncate_paths_marginals(paths, scores, threshold=0.589)
>>> tpaths
[[4], [4, 6], [4, 5], [], []]
>>> tscores
[tensor([0.9896]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765]), tensor([]), tensor([])]
Source code in hierarchical_loss/path_utils.py
def truncate_paths_marginals(predicted_paths: list[list[int]], predicted_path_scores: list[torch.Tensor], threshold: float = 0.25) -> tuple[list[list[int]], list[torch.Tensor]]:
    """Applies marginal probability truncation to a list of paths.

    This function iterates through lists of paths and scores, applying
    the `truncate_path_marginals` function to each path-score pair.

    Parameters
    ----------
    predicted_paths : list[list[int]]
        A list of paths, where each path is a list of category indices.
    predicted_path_scores : list[torch.Tensor]
        A list of 1D tensors, each corresponding to a path in `predicted_paths`.
    threshold : float, optional
        The probability threshold to pass to the truncation function,
        by default 0.25.

    Returns
    -------
    tuple[list[list[int]], list[torch.Tensor]]
        A tuple containing the list of truncated paths and the list of
        their corresponding truncated scores.

    Examples
    --------
    >>> import torch
    >>> paths = [[4, 2], [4, 6], [4, 5], [4, 7], [4, 2]]
    >>> scores = [torch.tensor([0.9896, 0.5891]), torch.tensor([0.9246, 0.7684]), torch.tensor([0.8949, 0.8765]), torch.tensor([0.5412, 0.4371]), torch.tensor([0.5001, 0.0830])]
    >>> tpaths, tscores = truncate_paths_marginals(paths, scores, threshold=0.589)
    >>> tpaths
    [[4], [4, 6], [4, 5], [], []]
    >>> tscores
    [tensor([0.9896]), tensor([0.9246, 0.7684]), tensor([0.8949, 0.8765]), tensor([]), tensor([])]
    """
    tpaths, tscores = [], []
    for paths, scores in zip(predicted_paths, predicted_path_scores):
        tpath, tscore = truncate_path_marginals(paths, scores, threshold=threshold)
        tpaths.append(tpath), tscores.append(tscore)
    return tpaths, tscores