Utilities for working with hierarchies as tensors

accumulate_hierarchy(predictions, hierarchy_index, reduce_op, identity_value)

Performs a reduce operation along a hierarchical structure.

This function applies a reduction operation (e.g., torch.sum, torch.max) along each ancestral path in a hierarchy. The implementation is fully vectorized. It first gathers the initial values for all nodes along each path, replaces padded values with the identity_value, and then applies the reduce_op along the path dimension.

Parameters:

predictions : Tensor
    A tensor of shape [B, D, N], where B is the batch size, D is the number of detections, and N is the number of classes.
hierarchy_index : Tensor
    An int tensor of shape [N, M] encoding the hierarchy, where N is the number of classes and M is the maximum hierarchy depth. Each row i contains the path from node i to its root. Parent node IDs are to the right of child node IDs. A value of -1 is used for padding.
reduce_op : Callable[[Tensor, int], Tensor]
    A function that performs a reduction operation along a dimension, such as torch.sum or torch.max. It must accept a tensor and a dim argument, and return a tensor.
identity_value : float | int
    The identity value for the reduction operation. For example, 0.0 for torch.sum or -torch.inf for torch.max.

Returns:

Tensor
    A tensor of shape [B, D, N], the same shape as predictions, containing the values aggregated along each ancestral path (the path dimension M of the gathered values is reduced away).

Examples:

>>> hierarchy_index = torch.tensor([
...     [ 0,  1,  2],
...     [ 1,  2, -1],
...     [ 2, -1, -1],
...     [ 3,  4, -1],
...     [ 4, -1, -1]
... ], dtype=torch.int64)
>>> # Predictions for 5 classes: [0., 10., 20., 30., 40.]
>>> predictions = torch.arange(0, 50, 10, dtype=torch.float32).view(1, 1, 5)
>>>
>>> # Example 1: Hierarchical Sum
>>> # Path 0: [0, 1, 2] -> 0. + 10. + 20. = 30.
>>> # Path 1: [1, 2]   -> 10. + 20. = 30.
>>> # Path 2: [2]      -> 20. = 20.
>>> # Path 3: [3, 4]   -> 30. + 40. = 70.
>>> # Path 4: [4]      -> 40. = 40.
>>> sum_preds = accumulate_hierarchy(predictions, hierarchy_index, torch.sum, 0.0)
>>> print(sum_preds.squeeze())
tensor([30., 30., 20., 70., 40.])
>>>
>>> # Example 2: Hierarchical Max
>>> # Path 0: [0, 1, 2] -> max(0., 10., 20.) = 20.
>>> # Path 1: [1, 2]   -> max(10., 20.) = 20.
>>> # Path 2: [2]      -> max(20.) = 20.
>>> # Path 3: [3, 4]   -> max(30., 40.) = 40.
>>> # Path 4: [4]      -> max(40.) = 40.
>>> max_op = lambda x, dim: torch.max(x, dim=dim).values
>>> max_preds = accumulate_hierarchy(predictions, hierarchy_index, max_op, -torch.inf)
>>> print(max_preds.squeeze())
tensor([20., 20., 20., 40., 40.])
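
For orientation, a sketch that derives the index tensor from a {child: parent} dictionary via build_hierarchy_index_tensor (documented below) rather than writing it by hand; the cast to int64 simply matches the hand-written hierarchy_index above:

>>> hierarchy = {0: 1, 1: 2, 3: 4}
>>> index = build_hierarchy_index_tensor(hierarchy).to(torch.int64)
>>> torch.equal(index, hierarchy_index)
True
>>> accumulate_hierarchy(predictions, index, torch.sum, 0.0).squeeze()
tensor([30., 30., 20., 70., 40.])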
Source code in hierarchical_loss/hierarchy_tensor_utils.py
def accumulate_hierarchy(
    predictions: torch.Tensor,
    hierarchy_index: torch.Tensor,
    reduce_op: Callable[[torch.Tensor, int], torch.Tensor],
    identity_value: float | int,
) -> torch.Tensor:
    """Performs a reduce operation along a hierarchical structure.

    This function applies a reduction operation (e.g., `torch.sum`,
    `torch.max`) along each ancestral path in a hierarchy. The implementation
    is fully vectorized. It first gathers the initial values for all
    nodes along each path, replaces padded values with the `identity_value`,
    and then applies the `reduce_op` along the path dimension.

    Parameters
    ----------
    predictions : torch.Tensor
        A tensor of shape `[B, D, N]`, where `B` is the batch size, `D` is the
        number of detections, and `N` is the number of classes.
    hierarchy_index : torch.Tensor
        An int tensor of shape `[N, M]` encoding the hierarchy, where `N` is the
        number of classes and `M` is the maximum hierarchy depth. Each row `i`
        contains the path from node `i` to its root. Parent node IDs are to
        the right of child node IDs. A value of -1 is used for padding.
    reduce_op : Callable[[torch.Tensor, int], torch.Tensor]
        A function that performs a reduction operation along a dimension,
        such as `torch.sum` or `torch.max`. It must accept a tensor
        and a `dim` argument, and return a tensor.
    identity_value : float | int
        The identity value for the reduction operation. For example,
        `0.0` for `torch.sum` or `-torch.inf` for `torch.max`.

    Returns
    -------
    torch.Tensor
        A tensor of shape `[B, D, N]`, the same shape as `predictions`,
        containing the values aggregated along each ancestral path (the
        path dimension `M` of the gathered values is reduced away).

    Examples
    --------
    >>> hierarchy_index = torch.tensor([
    ...     [ 0,  1,  2],
    ...     [ 1,  2, -1],
    ...     [ 2, -1, -1],
    ...     [ 3,  4, -1],
    ...     [ 4, -1, -1]
    ... ], dtype=torch.int64)
    >>> # Predictions for 5 classes: [0., 10., 20., 30., 40.]
    >>> predictions = torch.arange(0, 50, 10, dtype=torch.float32).view(1, 1, 5)
    >>>
    >>> # Example 1: Hierarchical Sum
    >>> # Path 0: [0, 1, 2] -> 0. + 10. + 20. = 30.
    >>> # Path 1: [1, 2]   -> 10. + 20. = 30.
    >>> # Path 2: [2]      -> 20. = 20.
    >>> # Path 3: [3, 4]   -> 30. + 40. = 70.
    >>> # Path 4: [4]      -> 40. = 40.
    >>> sum_preds = accumulate_hierarchy(predictions, hierarchy_index, torch.sum, 0.0)
    >>> print(sum_preds.squeeze())
    tensor([30., 30., 20., 70., 40.])
    >>>
    >>> # Example 2: Hierarchical Max
    >>> # Path 0: [0, 1, 2] -> max(0., 10., 20.) = 20.
    >>> # Path 1: [1, 2]   -> max(10., 20.) = 20.
    >>> # Path 2: [2]      -> max(20.) = 20.
    >>> # Path 3: [3, 4]   -> max(30., 40.) = 40.
    >>> # Path 4: [4]      -> max(40.) = 40.
    >>> max_op = lambda x, dim: torch.max(x, dim=dim).values
    >>> max_preds = accumulate_hierarchy(predictions, hierarchy_index, max_op, -torch.inf)
    >>> print(max_preds.squeeze())
    tensor([20., 20., 20., 40., 40.])
    """
    B, D, N = predictions.shape
    M = hierarchy_index.shape[1]

    # 1. GATHER: Collect prediction values for each node in each path.
    # Create a mask for valid indices (non -1)
    valid_mask = hierarchy_index != -1

    # Create a "safe" index tensor to prevent out-of-bounds errors from -1.
    # We replace -1 with a valid index (e.g., 0) and will zero out its
    # contribution later using the mask.
    safe_index = hierarchy_index.masked_fill(~valid_mask, 0)

    # Use advanced indexing to gather values. `predictions[:, :, safe_index]`
    # creates a tensor of shape [B, D, N, M].
    path_values = predictions[:, :, safe_index]

    # Replace the invalid, padded values with the appropriate identity value.
    # The valid_mask broadcasts from [N, M] to [B, D, N, M].
    path_values = torch.where(
        valid_mask,
        path_values,
        identity_value
    )

    # 2. ACCUMULATE: Apply the reduction operation along the path dimension.
    final_values = reduce_op(path_values, -1)

    return final_values

build_hierarchy_index_tensor(hierarchy, device=None)

Creates a 2D tensor mapping each node to its full ancestor path.

This function translates a {child: parent} dictionary hierarchy into a 2D tensor. The hierarchy must be dense: the node IDs (keys and values together) must cover 0 to C-1, where C is the number of nodes. Each row i of the tensor corresponds to node i and contains the full ancestor path starting with the node itself: [node_id, parent_id, grandparent_id, ..., root_id].

The paths are right-padded with -1 to the length of the longest ancestor path in the hierarchy.

This tensor is used as an index for hierarchical accumulation operations.

Parameters:

hierarchy : dict[int, int]
    A tree in {child: parent} format. Node IDs must be non-negative integers that can be used as tensor indices.
device : device | str | None, optional
    The desired device for the output tensor. If None, uses the default PyTorch device. By default None.

Returns:

Tensor
    A 2D tensor of shape (C, M), where M is the maximum hierarchy depth and C is the number of categories (nodes in the tree). tensor[i] contains the ancestor path for node i, padded with -1.

Examples:

>>> hierarchy = {0: 1, 1: 2, 3: 4}
>>> # Nodes found: {0, 1, 2, 3, 4} -> len=5
>>> # Max depth: 3 (for node 0)
>>> build_hierarchy_index_tensor(hierarchy)
tensor([[ 0,  1,  2],
        [ 1,  2, -1],
        [ 2, -1, -1],
        [ 3,  4, -1],
        [ 4, -1, -1]], dtype=torch.int32)
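
If your node IDs are not already dense, a plain-Python remapping sketch like the following (the names here are illustrative, not part of the library) can build a dense tree first:

>>> sparse = {10: 20, 20: 30, 40: 30}
>>> ids = sorted(set(sparse) | set(sparse.values()))
>>> to_dense = {node: i for i, node in enumerate(ids)}  # {10: 0, 20: 1, 30: 2, 40: 3}
>>> dense = {to_dense[c]: to_dense[p] for c, p in sparse.items()}
>>> dense
{0: 1, 1: 2, 3: 2}
>>> build_hierarchy_index_tensor(dense)
tensor([[ 0,  1,  2],
        [ 1,  2, -1],
        [ 2, -1, -1],
        [ 3,  2, -1]], dtype=torch.int32)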
Source code in hierarchical_loss/hierarchy_tensor_utils.py
def build_hierarchy_index_tensor(
    hierarchy: dict[int, int], device: torch.device | str | None = None
) -> torch.Tensor:
    """Creates a 2D tensor mapping each node to its full ancestor path.

    This function translates a {child: parent} dictionary hierarchy into a
    2D tensor. The hierarchy must be dense: the node IDs (keys and values
    together) must cover 0 to C-1, where C is the number of nodes.
    Each row `i` of the tensor corresponds to node `i`. The
    row contains the full ancestor path starting with the node itself:
    `[node_id, parent_id, grandparent_id, ..., root_id]`.

    The paths are right-padded with -1 to the length of the longest
    ancestor path in the hierarchy.

    This tensor is used as an index for hierarchical accumulation operations.

    Parameters
    ----------
    hierarchy : dict[int, int]
        A tree in {child: parent} format. Node IDs must be non-negative
        integers that can be used as tensor indices.
    device : torch.device | str | None, optional
        The desired device for the output tensor. If `None`, uses the
        default PyTorch device. By default `None`.

    Returns
    -------
    torch.Tensor
        A 2D tensor of shape `(C, M)`, where `M` is the maximum hierarchy
        depth and `C` is the number of categories (nodes in the tree).
        `tensor[i]` contains the ancestor path for node `i`, padded with -1.

    Examples
    --------
    >>> hierarchy = {0: 1, 1: 2, 3: 4}
    >>> # Nodes found: {0, 1, 2, 3, 4} -> len=5
    >>> # Max depth: 3 (for node 0)
    >>> build_hierarchy_index_tensor(hierarchy)
    tensor([[ 0,  1,  2],
            [ 1,  2, -1],
            [ 2, -1, -1],
            [ 3,  4, -1],
            [ 4, -1, -1]], dtype=torch.int32)
    """
    lens = get_ancestor_chain_lens(hierarchy)
    index_tensor = torch.full((len(lens), max(lens.values())), -1, dtype=torch.int32, device=device)
    preorder_apply(hierarchy, set_indices, index_tensor)
    return index_tensor

build_hierarchy_sibling_mask(parent_tensor, device=None)

Creates a boolean mask identifying sibling groups from a parent tensor.

This function is used to prepare a mask for utils.logsumexp_over_siblings. It takes the 1D parent tensor (where parent_tensor[i] = parent_id) and creates a 2D mask.

Each column g in the mask represents a unique sibling group (i.e., a unique parent, including -1 for the root group). A node i will have True in column g if its parent is the parent corresponding to sibling group g.

Parameters:

parent_tensor : Tensor
    A 1D tensor of shape (C,), where C is the number of classes. parent_tensor[i] contains the integer ID of the parent of node i, or -1 for root nodes. See the build_parent_tensor function.
device : device | str | None, optional
    The desired device for the output tensor. If None, uses the default PyTorch device. By default None.

Returns:

Tensor
    A 2D boolean tensor of shape (C, G), where G is the number of unique parent groups (including roots). mask[i, g] is True if node i belongs to sibling group g.

Examples:

>>> # Node parents: 0->1, 1->2, 2->-1, 3->2, 4->-1, 5->6, 6->-1
>>> parent_tensor = torch.tensor([ 1,  2, -1,  2, -1,  6, -1])
>>> # Unique parents (groups): -1, 1, 2, 6
>>> build_hierarchy_sibling_mask(parent_tensor)
tensor([[False,  True, False, False],
        [False, False,  True, False],
        [ True, False, False, False],
        [False, False,  True, False],
        [ True, False, False, False],
        [False, False, False,  True],
        [ True, False, False, False]])
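
The documented consumer of this mask is utils.logsumexp_over_siblings, whose exact signature is not shown on this page; the helper below is only a hypothetical sketch of how such a mask can be consumed, assuming 1D scores of shape (C,):

>>> def sibling_logsumexp_sketch(scores, sibling_mask):
...     # Broadcast scores across group columns, mask non-members to -inf,
...     # reduce each column, then map every node back to its group's result.
...     expanded = scores.unsqueeze(1).expand_as(sibling_mask)      # (C, G)
...     masked = expanded.masked_fill(~sibling_mask, float("-inf"))
...     group_lse = torch.logsumexp(masked, dim=0)                  # (G,)
...     group_of_node = sibling_mask.long().argmax(dim=1)           # (C,)
...     return group_lse[group_of_node]                             # (C,)
>>> scores = torch.zeros(7)  # logsumexp of k zeros is log(k)
>>> sibling_logsumexp_sketch(scores, build_hierarchy_sibling_mask(parent_tensor))
tensor([0.0000, 0.6931, 1.0986, 0.6931, 1.0986, 0.0000, 1.0986])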
Source code in hierarchical_loss/hierarchy_tensor_utils.py
def build_hierarchy_sibling_mask(
    parent_tensor: torch.Tensor, device: torch.device | str | None = None
) -> torch.Tensor:
    """Creates a boolean mask identifying sibling groups from a parent tensor.

    This function is used to prepare a mask for `utils.logsumexp_over_siblings`.
    It takes the 1D parent tensor (where `parent_tensor[i] = parent_id`)
    and creates a 2D mask.

    Each column `g` in the mask represents a unique sibling group (i.e., a
    unique parent, including -1 for the root group). A node `i` will have
    `True` in column `g` if its parent is the parent corresponding to
    sibling group `g`.

    Parameters
    ----------
    parent_tensor : torch.Tensor
        A 1D tensor of shape `(C,)`, where `C` is the number of classes.
        `parent_tensor[i]` contains the integer ID of the parent of node `i`,
        or -1 for root nodes.  See the `build_parent_tensor` function.
    device : torch.device | str | None, optional
        The desired device for the output tensor. If `None`, uses the
        default PyTorch device. By default `None`.

    Returns
    -------
    torch.Tensor
        A 2D boolean tensor of shape `(C, G)`, where `G` is the number of
        unique parent groups (including roots). `mask[i, g]` is `True`
        if node `i` belongs to sibling group `g`.

    Examples
    --------
    >>> # Node parents: 0->1, 1->2, 2->-1, 3->2, 4->-1, 5->6, 6->-1
    >>> parent_tensor = torch.tensor([ 1,  2, -1,  2, -1,  6, -1])
    >>> # Unique parents (groups): -1, 1, 2, 6
    >>> build_hierarchy_sibling_mask(parent_tensor)
    tensor([[False,  True, False, False],
            [False, False,  True, False],
            [ True, False, False, False],
            [False, False,  True, False],
            [ True, False, False, False],
            [False, False, False,  True],
            [ True, False, False, False]])
    """
    C = parent_tensor.shape[0]

    # Identify all unique parents (which uniquely define child groups), including -1 for roots
    unique_parents, inverse_indices = torch.unique(parent_tensor, return_inverse=True)
    G = len(unique_parents)

    sibling_mask = torch.zeros(C, G, dtype=torch.bool, device=device)

    # Assign each node to the column of its parent group
    sibling_mask[torch.arange(C), inverse_indices] = True

    return sibling_mask

build_parent_tensor(tree, device=None)

Converts a {child: parent} dictionary tree into a 1D parent tensor.

This function creates a 1D tensor where the value at each index i is the ID of that node's parent. Nodes that are not children (i.e., roots) will have a value of -1.

The size of the tensor is determined by the maximum node ID present in the tree (in either keys or values).

Parameters:

tree : dict[int, int]
    A tree in {child: parent} format. Node IDs are assumed to be non-negative integers.
device : device | str | None, optional
    The desired device for the output tensor. If None, uses the default PyTorch device. By default None.

Returns:

Tensor
    A 1D tensor of shape (C,), where C is max(all_node_ids) + 1. parent_tensor[i] contains the ID of the parent of node i, or -1 if i is a root or not in the hierarchy.

Examples:

>>> tree = {0: 1, 1: 2, 3: 2, 5: 6}
>>> # Max node ID is 6, so tensor size is 7
>>> build_parent_tensor(tree)
tensor([ 1,  2, -1,  2, -1,  6, -1])
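
The resulting parent tensor plugs directly into build_hierarchy_sibling_mask (documented above); reusing the same tree:

>>> build_hierarchy_sibling_mask(build_parent_tensor(tree)).shape
torch.Size([7, 4])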
Source code in hierarchical_loss/hierarchy_tensor_utils.py
def build_parent_tensor(
    tree: dict[int, int], device: torch.device | str | None = None
) -> torch.Tensor:
    """Converts a {child: parent} dictionary tree into a 1D parent tensor.

    This function creates a 1D tensor where the value at each index `i`
    is the ID of that node's parent. Nodes that are not children (i.e.,
    roots) will have a value of -1.

    The size of the tensor is determined by the maximum node ID present
    in the tree (in either keys or values).

    Parameters
    ----------
    tree : dict[int, int]
        A tree in {child: parent} format. Node IDs are assumed to be
        non-negative integers.
    device : torch.device | str | None, optional
        The desired device for the output tensor. If `None`, uses the
        default PyTorch device. By default `None`.

    Returns
    -------
    torch.Tensor
        A 1D tensor of shape `(C,)`, where `C` is `max(all_node_ids) + 1`.
        `parent_tensor[i]` contains the ID of the parent of node `i`,
        or -1 if `i` is a root or not in the hierarchy.

    Examples
    --------
    >>> tree = {0: 1, 1: 2, 3: 2, 5: 6}
    >>> # Max node ID is 6, so tensor size is 7
    >>> build_parent_tensor(tree)
    tensor([ 1,  2, -1,  2, -1,  6, -1])
    """
    nodes = set(tree.keys()) | set(tree.values())
    C = max(nodes) + 1

    parent_tensor = torch.full((C,), -1, dtype=torch.long, device=device)

    for child, parent in tree.items():
        parent_tensor[child] = parent

    return parent_tensor

expand_target_hierarchy(target, hierarchy_index)

Expands a one-hot target tensor up the hierarchy.

This function takes a target tensor that is "one-hot" along the class dimension (i.e., contains a single non-zero value) and propagates that value to all ancestors of the target class. The implementation is fully vectorized.

Parameters:

target : Tensor
    A tensor of shape [B, D, N], where B is the batch size, D is the number of detections, and N is the number of classes. It is assumed to be one-hot along the last dimension.
hierarchy_index : Tensor
    An int tensor of shape [N, M] encoding the hierarchy, where N is the number of classes and M is the maximum hierarchy depth. Each row i contains the path from node i to its root.

Returns:

Tensor
    A new tensor with the same shape as target where the target value has been propagated up the hierarchy.

Examples:

>>> import torch
>>> hierarchy_index = torch.tensor([
...     [ 0,  1,  2],
...     [ 1,  2, -1],
...     [ 2, -1, -1],
...     [ 3,  4, -1],
...     [ 4, -1, -1]
... ], dtype=torch.int64)
>>> # Target is one-hot at index 0
>>> target = torch.tensor([0.4, 0., 0., 0., 0.]).view(1, 1, 5)
>>> expanded_target = expand_target_hierarchy(target, hierarchy_index)
>>> print(expanded_target.squeeze())
tensor([0.4000, 0.4000, 0.4000, 0.0000, 0.0000])
>>> target = torch.tensor([0., 0., 0., 0.3, 0.]).view(1, 1, 5)
>>> expanded_target = expand_target_hierarchy(target, hierarchy_index)
>>> print(expanded_target.squeeze())
tensor([0.0000, 0.0000, 0.0000, 0.3000, 0.3000])
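
The expansion is batched per detection; a sketch with two detections (D=2), reusing hierarchy_index from above:

>>> target = torch.tensor([[[0.4, 0., 0., 0., 0.],
...                         [0., 0., 0., 0.3, 0.]]])  # B=1, D=2
>>> expand_target_hierarchy(target, hierarchy_index).squeeze(0)
tensor([[0.4000, 0.4000, 0.4000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.3000, 0.3000]])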
Source code in hierarchical_loss/hierarchy_tensor_utils.py
def expand_target_hierarchy(
    target: torch.Tensor, hierarchy_index: torch.Tensor
) -> torch.Tensor:
    """Expands a one-hot target tensor up the hierarchy.

    This function takes a target tensor that is "one-hot" along the class
    dimension (i.e., contains a single non-zero value) and propagates that
    value to all ancestors of the target class. The implementation is fully
    vectorized.

    Parameters
    ----------
    target : torch.Tensor
        A tensor of shape `[B, D, N]`, where `B` is the batch size, `D` is the
        number of detections, and `N` is the number of classes. It is assumed
        to be one-hot along the last dimension.
    hierarchy_index : torch.Tensor
        An int tensor of shape `[N, M]` encoding the hierarchy, where `N` is the
        number of classes and `M` is the maximum hierarchy depth. Each row `i`
        contains the path from node `i` to its root.

    Returns
    -------
    torch.Tensor
        A new tensor with the same shape as `target` where the target value has
        been propagated up the hierarchy.

    Examples
    --------
    >>> import torch
    >>> hierarchy_index = torch.tensor([
    ...     [ 0,  1,  2],
    ...     [ 1,  2, -1],
    ...     [ 2, -1, -1],
    ...     [ 3,  4, -1],
    ...     [ 4, -1, -1]
    ... ], dtype=torch.int64)
    >>> # Target is one-hot at index 0
    >>> target = torch.tensor([0.4, 0., 0., 0., 0.]).view(1, 1, 5)
    >>> expanded_target = expand_target_hierarchy(target, hierarchy_index)
    >>> print(expanded_target.squeeze())
    tensor([0.4000, 0.4000, 0.4000, 0.0000, 0.0000])
    >>> target = torch.tensor([0., 0., 0., 0.3, 0.]).view(1, 1, 5)
    >>> expanded_target = expand_target_hierarchy(target, hierarchy_index)
    >>> print(expanded_target.squeeze())
    tensor([0.0000, 0.0000, 0.0000, 0.3000, 0.3000])
    """
    M = hierarchy_index.shape[1]

    # Find the single non-zero value and its index in the target tensor.
    hot_value, hot_index = torch.max(target, dim=-1)

    # Gather the ancestral paths corresponding to the hot indices.
    # The shape will be [B, D, M].
    paths = hierarchy_index[hot_index]

    # Create a mask for valid indices (non -1) to handle padded paths.
    valid_mask = paths != -1

    # Create a "safe" index tensor to prevent out-of-bounds errors from -1.
    # We replace -1 with a valid index (e.g., 0) and will zero out its
    # contribution later using a masked source.
    safe_paths = paths.masked_fill(~valid_mask, 0)
    safe_paths_ints = safe_paths.to(torch.int64)

    # Prepare the source tensor for the scatter operation.
    # It should have the same value (`hot_value`) for all valid path members.
    src_values = hot_value.unsqueeze(-1).expand(-1, -1, M)
    masked_src = src_values * valid_mask.to(src_values.dtype)

    # Create an output tensor and scatter the hot value into all ancestral positions.
    expanded_target = torch.zeros_like(target)
    expanded_target.scatter_(dim=-1, index=safe_paths_ints, src=masked_src)

    return expanded_target

find_closest_permitted_parent(node, tree, permitted_nodes)

Finds the first ancestor of a node that is in a permitted set.

This function walks up the ancestral chain of a node (using the {child: parent} tree) and returns the first ancestor it finds that is present in the permitted_nodes set.

If no ancestor is in the set, or if the node is not in the tree to begin with, it returns None. Note that the node itself is never considered: the walk starts at the node's parent.

Parameters:

node : Hashable
    The ID of the node to start searching from.
tree : dict[Hashable, Hashable]
    A tree in {child: parent} format.
permitted_nodes : set[Hashable]
    A set of node IDs that are considered "permitted".

Returns:

Hashable | None
    The ID of the closest permitted ancestor, or None if none is found.

Examples:

>>> tree = {1: 2, 2: 3, 3: 4, 4: 5}
>>> permitted = {0, 2, 5}
>>> find_closest_permitted_parent(1, tree, permitted) # 1 -> 2 (permitted)
2
>>> find_closest_permitted_parent(0, tree, permitted) # 0 is not in tree keys, returns None
>>> tree[0] = 1 # Add 0 to the tree
>>> find_closest_permitted_parent(0, tree, permitted) # 0 -> 1 -> 2 (permitted)
2
>>> tree = {10: 20, 20: 30, 30: 40}
>>> find_closest_permitted_parent(10, tree, {50, 60}) # No permitted ancestors, returns None
Source code in hierarchical_loss/tree_utils.py
def find_closest_permitted_parent(
    node: Hashable,
    tree: dict[Hashable, Hashable],
    permitted_nodes: set[Hashable],
) -> Hashable | None:
    """Finds the first ancestor of a node that is in a permitted set.

    This function walks up the ancestral chain of a node (using the
    {child: parent} tree) and returns the first ancestor it finds
    that is present in the `permitted_nodes` set.

    If no ancestor is in the set, or if the node is not in the tree to
    begin with, it returns None. Note that the node itself is never
    considered: the walk starts at the node's parent.

    Parameters
    ----------
    node : Hashable
        The ID of the node to start searching from.
    tree : dict[Hashable, Hashable]
        A tree in {child: parent} format.
    permitted_nodes : set[Hashable]
        A set of node IDs that are considered "permitted".

    Returns
    -------
    Hashable | None
        The ID of the closest permitted ancestor, or None if none is found.

    Examples
    --------
    >>> tree = {1: 2, 2: 3, 3: 4, 4: 5}
    >>> permitted = {0, 2, 5}
    >>> find_closest_permitted_parent(1, tree, permitted) # 1 -> 2 (permitted)
    2
    >>> find_closest_permitted_parent(0, tree, permitted) # 0 is not in tree keys, returns None
    >>> tree[0] = 1 # Add 0 to the tree
    >>> find_closest_permitted_parent(0, tree, permitted) # 0 -> 1 -> 2 (permitted)
    2
    >>> tree = {10: 20, 20: 30, 30: 40}
    >>> find_closest_permitted_parent(10, tree, {50, 60}) # No permitted ancestors, returns None
    """
    if node not in tree:
        return None
    parent = tree[node]
    while parent not in permitted_nodes:
        if parent in tree:
            parent = tree[parent]
        else:
            return None
    return parent

get_ancestor_chain_lens(tree)

Get the lengths of ancestor chains in a { child: parent } dictionary tree.

Examples:

>>> get_ancestor_chain_lens({ 0:1, 1:2, 2:3, 4:5, 5:6, 7:8 })
{3: 1, 2: 2, 1: 3, 0: 4, 6: 1, 5: 2, 4: 3, 8: 1, 7: 2}

Parameters:

tree : dict[Hashable, Hashable]
    A tree in { child: parent } format.

Returns:

lengths : dict[Hashable, int]
    The lengths of the path to the root from each node { node: length }.

Source code in hierarchical_loss/tree_utils.py
def get_ancestor_chain_lens(tree: dict[Hashable, Hashable]) -> dict[Hashable, int]:
    """Get the lengths of ancestor chains in a { child: parent } dictionary tree.

    Examples
    --------
    >>> get_ancestor_chain_lens({ 0:1, 1:2, 2:3, 4:5, 5:6, 7:8 })
    {3: 1, 2: 2, 1: 3, 0: 4, 6: 1, 5: 2, 4: 3, 8: 1, 7: 2}

    Parameters
    ----------
    tree: dict[Hashable, Hashable]
        A tree in { child: parent } format.

    Returns
    -------
    lengths: dict[Hashable, int]
        The lengths of the path to the root from each node { node: length }

    """
    return preorder_apply(tree, _increment_chain_len)

get_roots(tree)

Finds all root nodes in a {child: parent} tree.

A root node is defined as any node that is not a child of another node in the tree (i.e., its ancestor chain length is 1).

Parameters:

tree : dict[Hashable, Hashable]
    A tree in {child: parent} format.

Returns:

list[Hashable]
    A list of all root nodes.

Examples:

>>> tree = {0: 1, 1: 2, 3: 2, 5: 6}
>>> get_roots(tree) # Roots are 2 and 6
[2, 6]
Source code in hierarchical_loss/tree_utils.py
def get_roots(tree: dict[Hashable, Hashable]) -> list[Hashable]:
    """Finds all root nodes in a {child: parent} tree.

    A root node is defined as any node that is not a child of another
    node in the tree (i.e., its ancestor chain length is 1).

    Parameters
    ----------
    tree : dict[Hashable, Hashable]
        A tree in {child: parent} format.

    Returns
    -------
    list[Hashable]
        A list of all root nodes.

    Examples
    --------
    >>> tree = {0: 1, 1: 2, 3: 2, 5: 6}
    >>> get_roots(tree) # Roots are 2 and 6
    [2, 6]
    """
    ancestor_chain_lens = get_ancestor_chain_lens(tree)
    return [node for node in ancestor_chain_lens if ancestor_chain_lens[node] == 1]

hierarchically_index_flat_scores(flat_scores, target_indices, hierarchy_index_tensor, hierarchy_mask, device=None)

Gathers scores from a flat score tensor along specified hierarchical paths.

This function takes a batch of flat scores (B, P, C) and a batch of target category indices (B, P). For each target index, it looks up the full ancestral path in hierarchy_index_tensor (C, H) and gathers the corresponding scores from flat_scores.

It then applies the hierarchy_mask to the gathered scores, zeroing out entries where the mask is True.

Parameters:

flat_scores : Tensor
    A tensor of flat scores with shape (B, P, C), where B is batch size, P is number of proposals, and C is number of categories.
target_indices : Tensor
    A long tensor of shape (B, P) containing the leaf category index for each proposal.
hierarchy_index_tensor : Tensor
    A long tensor of shape (C, H) mapping each category c to its ancestral path. H is the max hierarchy depth.
hierarchy_mask : Tensor
    A boolean invalidity mask of shape (C, H). True indicates an invalid entry (e.g., padding) that should be zeroed out. This mask is indexed by target_indices and applied to the gathered scores.
device : device | str | None, optional
    The desired device for torch.arange. If None, uses the default PyTorch device. By default None.

Returns:

Tensor
    A tensor of shape (B, P, H) containing the gathered scores along each target's ancestral path, after masking.

Examples:

>>> import torch
>>> # B=1, P=1, C=2
>>> flat_scores = torch.tensor([[[10., 20.]]])
>>> # Target index is 0
>>> target_indices = torch.tensor([[0]])
>>> # C=2, H=3. Path 0 is [0, 1, -1]
>>> hierarchy_index_tensor = torch.tensor([[0, 1, -1], [1, -1, -1]], dtype=torch.long)
>>> # Create an invalidity mask (True where path is -1)
>>> invalidity_mask = (hierarchy_index_tensor == -1)
>>> print(invalidity_mask)
tensor([[False, False,  True],
        [False,  True,  True]])
>>>
>>> # The function will gather scores for path [0, 1, -1] -> [10., 20., 20.]
>>> # (Note: the -1 wraps around to the last category; the mask zeroes it out)
>>> # It will apply the mask for index 0: [False, False, True]
>>> # Result: [10., 20., 0.]
>>> hierarchically_index_flat_scores(
...     flat_scores, target_indices, hierarchy_index_tensor, invalidity_mask
... )
tensor([[[10., 20.,  0.]]])
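
The gather is independent per proposal; a sketch with two proposals (P=2), reusing hierarchy_index_tensor and invalidity_mask from above:

>>> flat_scores = torch.tensor([[[10., 20.], [30., 40.]]])  # B=1, P=2, C=2
>>> target_indices = torch.tensor([[0, 1]])
>>> hierarchically_index_flat_scores(
...     flat_scores, target_indices, hierarchy_index_tensor, invalidity_mask
... )
tensor([[[10., 20.,  0.],
         [40.,  0.,  0.]]])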
Source code in hierarchical_loss/hierarchy_tensor_utils.py
def hierarchically_index_flat_scores(
    flat_scores: torch.Tensor,
    target_indices: torch.Tensor,
    hierarchy_index_tensor: torch.Tensor,
    hierarchy_mask: torch.Tensor,
    device: torch.device | str | None = None,
) -> torch.Tensor:
    """Gathers scores from a flat score tensor along specified hierarchical paths.

    This function takes a batch of flat scores (`B, P, C`) and a batch of
    target category indices (`B, P`). For each target index, it looks up
    the full ancestral path in `hierarchy_index_tensor` (`C, H`) and
    gathers the corresponding scores from `flat_scores`.

    It then applies the `hierarchy_mask` to the gathered scores, zeroing
    out entries where the mask is `True`.

    Parameters
    ----------
    flat_scores : torch.Tensor
        A tensor of flat scores with shape `(B, P, C)`, where `B` is
        batch size, `P` is number of proposals, and `C` is number
        of categories.
    target_indices : torch.Tensor
        A long tensor of shape `(B, P)` containing the leaf category
        index for each proposal.
    hierarchy_index_tensor : torch.Tensor
        A long tensor of shape `(C, H)` mapping each category `c` to its
        ancestral path. `H` is the max hierarchy depth.
    hierarchy_mask : torch.Tensor
        A boolean **invalidity** mask of shape `(C, H)`. `True` indicates
        an invalid entry (e.g., padding) that should be zeroed out.
        This mask is indexed by `target_indices` and applied to the
        gathered scores.
    device : torch.device | str | None, optional
        The desired device for `torch.arange`. If `None`, uses the
        default PyTorch device. By default `None`.

    Returns
    -------
    torch.Tensor
        A tensor of shape `(B, P, H)` containing the gathered scores
        along each target's ancestral path, after masking.

    Examples
    --------
    >>> import torch
    >>> # B=1, P=1, C=2
    >>> flat_scores = torch.tensor([[[10., 20.]]])
    >>> # Target index is 0
    >>> target_indices = torch.tensor([[0]])
    >>> # C=2, H=3. Path 0 is [0, 1, -1]
    >>> hierarchy_index_tensor = torch.tensor([[0, 1, -1], [1, -1, -1]], dtype=torch.long)
    >>> # Create an invalidity mask (True where path is -1)
    >>> invalidity_mask = (hierarchy_index_tensor == -1)
    >>> print(invalidity_mask)
    tensor([[False, False,  True],
            [False,  True,  True]])
    >>>
    >>> # The function will gather scores for path [0, 1, -1] -> [10., 20., 20.]
    >>> # (Note: the -1 wraps around to the last category; the mask zeroes it out)
    >>> # It will apply the mask for index 0: [False, False, True]
    >>> # Result: [10., 20., 0.]
    >>> hierarchically_index_flat_scores(
    ...     flat_scores, target_indices, hierarchy_index_tensor, invalidity_mask
    ... )
    tensor([[[10., 20.,  0.]]])
    """
    batch_size, n_proposals, n_categories = flat_scores.shape
    hierarchy_size = hierarchy_index_tensor.shape[1]

    hierarchy_indices = hierarchy_index_tensor[target_indices]
    flat_mask = hierarchy_mask[target_indices]

    # Construct batch indices
    batch_indices = torch.arange(batch_size, device=device).view(batch_size, 1, 1).expand(batch_size, n_proposals, hierarchy_size) # (B, N, H)
    proposal_indices = torch.arange(n_proposals, device=device).view(1, n_proposals, 1).expand(batch_size, n_proposals, hierarchy_size) # (B, N, H)

    gathered_scores = flat_scores[batch_indices, proposal_indices, hierarchy_indices]  # (B, N, H)

    # Mask out invalid entries
    masked_scores = gathered_scores.masked_fill(flat_mask, 0.)

    return masked_scores

invert_childparent_tree(tree)

Converts a {child: parent} tree into a nested {parent: {child: ...}} tree.

This function inverts the standard {child: parent} structure, creating a nested dictionary that starts from the root(s). It uses preorder_apply to traverse the tree top-down and build the nested structure.

Parameters:

tree : dict[Hashable, Hashable]
    A tree in {child: parent} format.

Returns:

dict
    A nested dictionary representing the tree in a top-down format, e.g., {root: {child: {grandchild: {}}}}.

Examples:

>>> tree = {0: 1, 1: 2, 3: 2, 5: 6} # 0->1->2, 3->2, 5->6
>>> invert_childparent_tree(tree)
{2: {1: {0: {}}, 3: {}}, 6: {5: {}}}
Source code in hierarchical_loss/tree_utils.py
def invert_childparent_tree(tree: dict[Hashable, Hashable]) -> dict:
    """Converts a {child: parent} tree into a nested {parent: {child: ...}} tree.

    This function inverts the standard {child: parent} structure, creating
    a nested dictionary that starts from the root(s). It uses
    `preorder_apply` to traverse the tree top-down and build the
    nested structure.

    Parameters
    ----------
    tree : dict[Hashable, Hashable]
        A tree in {child: parent} format.

    Returns
    -------
    dict
        A nested dictionary representing the tree in a top-down format,
        e.g., `{root: {child: {grandchild: {}}}}`.

    Examples
    --------
    >>> tree = {0: 1, 1: 2, 3: 2, 5: 6} # 0->1->2, 3->2, 5->6
    >>> invert_childparent_tree(tree)
    {2: {1: {0: {}}, 3: {}}, 6: {5: {}}}
    """
    parentchild_tree = {}
    preorder_apply(tree, _append_to_parentchild_tree, parentchild_tree)
    return parentchild_tree

preorder_apply(tree, f, *args)

Applies a function to all nodes in a tree in a pre-order (top-down) fashion.

This function works by first finding an ancestor path (from leaf to root). It then applies the function f to the root (or highest unvisited node) and iterates down the path, applying f to each child and passing in the result from its parent. This top-down application is a pre-order traversal.

It uses memoization (the visited dict) to ensure that f is applied to each node only once, even in multi-branch trees.

Parameters:

tree : dict[Hashable, Hashable]
    The hierarchy tree, in {child: parent} format.
f : Callable
    The function to apply to each node. Its signature must be f(node: Hashable, parent_result: Any, *args: Any) -> Any.
*args : Any
    Additional positional arguments to be passed to every call of f.

Returns:

dict[Hashable, Any]
    A dictionary mapping each node ID to the result of f(node, ...).

Examples:

>>> # Example: Calculate node depth (pre-order calculation)
>>> tree = {0: 1, 1: 2, 3: 2} # 0->1->2, 3->2
>>> def f(node, parent_depth):
...     # parent_depth is the result from the parent node
...     return 1 if parent_depth is None else parent_depth + 1
...
>>> preorder_apply(tree, f)
{2: 1, 1: 2, 0: 3, 3: 2}
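
A further sketch showing *args, here threading a collector dict (illustrative, not a library helper) through every call:

>>> def collect_path(node, parent_path, out):
...     path = [node] if parent_path is None else [node] + parent_path
...     out[node] = path
...     return path
...
>>> paths = {}
>>> preorder_apply({0: 1, 1: 2}, collect_path, paths)
{2: [2], 1: [1, 2], 0: [0, 1, 2]}
>>> paths[0]
[0, 1, 2]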
Source code in hierarchical_loss/tree_utils.py
def preorder_apply(tree: dict[Hashable, Hashable], f: Callable, *args: Any) -> dict[Hashable, Any]:
    """Applies a function to all nodes in a tree in a pre-order (top-down) fashion.

    This function works by first finding an ancestor path (from leaf to root).
    It then applies the function `f` to the root (or highest unvisited node)
    and iterates *down* the path, applying `f` to each child and passing in
    the result from its parent. This top-down application is a pre-order
    traversal.

    It uses memoization (the `visited` dict) to ensure that `f` is
    applied to each node only once, even in multi-branch trees.

    Parameters
    ----------
    tree : dict[Hashable, Hashable]
        The hierarchy tree, in {child: parent} format.
    f : Callable
        The function to apply to each node. Its signature must be
        `f(node: Hashable, parent_result: Any, *args: Any) -> Any`.
    *args: Any
        Additional positional arguments to be passed to every call of `f`.

    Returns
    -------
    dict[Hashable, Any]
        A dictionary mapping each node ID to the result of `f(node, ...)`.

    Examples
    --------
    >>> # Example: Calculate node depth (pre-order calculation)
    >>> tree = {0: 1, 1: 2, 3: 2} # 0->1->2, 3->2
    >>> def f(node, parent_depth):
    ...     # parent_depth is the result from the parent node
    ...     return 1 if parent_depth is None else parent_depth + 1
    ...
    >>> preorder_apply(tree, f)
    {2: 1, 1: 2, 0: 3, 3: 2}
    """
    visited = {}
    for node in tree:
        path = [node]
        while (node in tree) and (node not in visited):
            node = tree[node]
            path.append(node)
        if node not in visited:
            visited[node] = f(node, None, *args)
        for i in range(-2, -len(path) - 1, -1):
            visited[path[i]] = f(path[i], visited[path[i+1]], *args)
    return visited

set_indices(index, parent_index, tensor)

A helper function for preorder_apply to build an ancestor path tensor.

This function populates a single row of the tensor (the row specified by index). It sets the first element of the row to index itself. If a parent_index is provided, it copies the parent's ancestor path (its row, excluding the last element) into the current node's row (starting from the second element).

This creates the desired row format: [node_id, parent_id, grand_parent_id, ...].

It is designed to be used with tree_utils.preorder_apply, where:

- index is the node
- parent_index is the parent_result (the return value from the parent's call, which is the parent's index)
- tensor is passed through *args

Parameters:

index : int
    The node ID, which corresponds to the row index in the tensor.
parent_index : int or None
    The ID of the parent node, or None if the node is a root.
tensor : Tensor
    The 2D tensor being populated with ancestor paths.

Returns:

int
    The index of the current node, to be passed as the parent_index to its children during the pre-order traversal.
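
Examples:

A doctest-style sketch (the tensor shape is chosen by hand here; build_hierarchy_index_tensor normally sizes it via get_ancestor_chain_lens):

>>> import torch
>>> tree = {0: 1, 1: 2}  # 0 -> 1 -> 2
>>> index_tensor = torch.full((3, 3), -1, dtype=torch.int32)
>>> preorder_apply(tree, set_indices, index_tensor)
{2: 2, 1: 1, 0: 0}
>>> index_tensor
tensor([[ 0,  1,  2],
        [ 1,  2, -1],
        [ 2, -1, -1]], dtype=torch.int32)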

Source code in hierarchical_loss/hierarchy_tensor_utils.py
def set_indices(index: int, parent_index: int | None, tensor: torch.Tensor) -> int:
    """A helper function for `preorder_apply` to build an ancestor path tensor.

    This function populates a single row of the `tensor` (the row specified
    by `index`). It sets the first element of the row to `index` itself.
    If a `parent_index` is provided, it copies the parent's ancestor path
    (its row, excluding the last element) into the current node's row
    (starting from the second element).

    This creates the desired row format: [node_id, parent_id, grand_parent_id, ...].

    It is designed to be used with `tree_utils.preorder_apply`, where:
    - `index` is the `node`
    - `parent_index` is the `parent_result` (the return value from the
      parent's call, which is the parent's index)
    - `tensor` is passed through `*args`

    Parameters
    ----------
    index : int
        The node ID, which corresponds to the row index in the tensor.
    parent_index : int or None
        The ID of the parent node, or None if the node is a root.
    tensor : torch.Tensor
        The 2D tensor being populated with ancestor paths.

    Returns
    -------
    int
        The `index` of the current node, to be passed as the
        `parent_index` to its children during the pre-order traversal.
    """
    tensor[index, 0] = index
    if parent_index is not None:
        tensor[index, 1:] = tensor[parent_index, :-1]
    return index

tree_walk(tree, node)

Walks up the ancestor chain from a starting node.

This generator yields the starting node first, then its parent, its grandparent, and so on, until a root (a node not present as a key in the tree) is reached.

Parameters:

tree : dict[Hashable, Hashable]
    The hierarchy tree, in {child: parent} format.
node : Hashable
    The node to start the walk from.

Yields:

Hashable
    Node IDs in the ancestor chain, starting with the given node.

Examples:

>>> tree = {0: 1, 1: 2, 3: 4, 4: 2}
>>> list(tree_walk(tree, 0))
[0, 1, 2]
>>> list(tree_walk(tree, 3))
[3, 4, 2]
>>> list(tree_walk(tree, 2))
[2]
Source code in hierarchical_loss/tree_utils.py
def tree_walk(tree: dict[Hashable, Hashable], node: Hashable) -> Iterator[Hashable]:
    """Walks up the ancestor chain from a starting node.

    This generator yields the starting node first, then its parent,
    its grandparent, and so on, until a root (a node not
    present as a key in the tree) is reached.

    Parameters
    ----------
    tree : dict[Hashable, Hashable]
        The hierarchy tree, in {child: parent} format.
    node : Hashable
        The node to start the walk from.

    Yields
    ------
    Hashable
        Node IDs in the ancestor chain, starting with the given node.

    Examples
    --------
    >>> tree = {0: 1, 1: 2, 3: 4, 4: 2}
    >>> list(tree_walk(tree, 0))
    [0, 1, 2]
    >>> list(tree_walk(tree, 3))
    [3, 4, 2]
    >>> list(tree_walk(tree, 2))
    [2]
    """
    yield node
    while node in tree:
        node = tree[node]
        yield node

trim_childparent_tree(tree, permitted_nodes)

Trims a {child: parent} tree to only include permitted nodes.

This function first remaps every node in the tree to its closest permitted ancestor. It then filters this map, keeping only the entries where the node (the key) is also in the permitted_nodes set.

The result is a new {child: parent} tree containing only permitted nodes, each mapped to its closest permitted ancestor, itself a permitted node. Nodes whose closest permitted ancestor is None are dropped.

Parameters:

tree : dict[Hashable, Hashable]
    A tree in {child: parent} format.
permitted_nodes : set[Hashable]
    A set of node IDs to keep.

Returns:

dict[Hashable, Hashable | None]
    A new {child: parent} tree containing only permitted nodes, each re-mapped to its closest permitted ancestor.

Examples:

>>> tree = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5} # 0->1->2->3->4->5
>>> permitted = {0, 2, 5} # 0, 2, and 5 are permitted
>>> trim_childparent_tree(tree, permitted)
{0: 2, 2: 5}
Source code in hierarchical_loss/tree_utils.py
def trim_childparent_tree(
    tree: dict[Hashable, Hashable], permitted_nodes: set[Hashable]
) -> dict[Hashable, Hashable | None]:
    """Trims a {child: parent} tree to only include permitted nodes.

    This function first remaps every node in the tree to its closest
    permitted ancestor. It then filters this map, keeping only the
    entries where the node (the key) is *also* in the `permitted_nodes`
    set.

    The result is a new {child: parent} tree containing *only*
    permitted nodes, each mapped to its closest permitted ancestor
    (itself a permitted node; nodes whose closest permitted ancestor
    is None are dropped).

    Parameters
    ----------
    tree : dict[Hashable, Hashable]
        A tree in {child: parent} format.
    permitted_nodes : set[Hashable]
        A set of node IDs to keep.

    Returns
    -------
    dict[Hashable, Hashable | None]
        A new {child: parent} tree containing only permitted nodes,
        each re-mapped to its closest permitted ancestor.

    Examples
    --------
    >>> tree = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5} # 0->1->2->3->4->5
    >>> permitted = {0, 2, 5} # 0, 2, and 5 are permitted
    >>> trim_childparent_tree(tree, permitted)
    {0: 2, 2: 5}
    """
    new_tree = {}
    for node in tree:
        closest_permitted_parent = find_closest_permitted_parent(node, tree, permitted_nodes)
        new_tree[node] = closest_permitted_parent
    for node in list(new_tree.keys()):
        if new_tree[node] is None or (node not in permitted_nodes):
            del new_tree[node]
    return new_tree