Skip to content

Hierarchy class

Hierarchy

Centralizes all hierarchy logic, mapping, and tensor creation.

Source code in hierarchical_loss/hierarchy.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class Hierarchy:
    """
    Centralizes all hierarchy logic, mapping, and tensor creation.
    """
    def __init__(self, 
                 raw_tree: dict[Hashable, Hashable], 
                 node_to_idx_map: dict[Hashable, int] | None = None,
                 device: torch.device | str | None = None):

        all_nodes = set(raw_tree.keys()) | set(raw_tree.values())

        # 1. Build the translation maps
        if node_to_idx_map:
            # Use the provided map
            self.node_to_idx = node_to_idx_map
            # Verify all nodes are accounted for
            for node in all_nodes:
                if node not in self.node_to_idx:
                    raise ValueError(f"Node '{node}' from raw_tree is missing from the provided node_to_idx_map.")
            # Verify node indices are sequential and dense
            idx_vals = node_to_idx_map.values()
            min_idx, max_idx, n_idx = min(idx_vals), max(idx_vals), len(idx_vals)
            if min_idx !=0 or (max_idx != n_idx-1):
                raise ValueError(f"node_to_idx_map must have contiguous sequential indices")
        else:
            # Auto-generate a dense map
            self.node_to_idx = {node: i for i, node in enumerate(all_nodes)}

        self.idx_to_node = {i: n for n, i in self.node_to_idx.items()}
        self.num_classes = len(self.node_to_idx)

        # 2. Create the core index-based tree
        self.index_tree = dict_keyvalue_replace(raw_tree, self.node_to_idx)

        # 3. Pre-compute all tensor and dict representations
        # These are now cached for the lifetime of the object.
        self.parent_tensor = build_parent_tensor(self.index_tree, device=device)
        self.index_tensor = build_hierarchy_index_tensor(self.index_tree, device=device)
        self.hierarchy_mask = self.index_tensor == -1
        self.sibling_mask = build_hierarchy_sibling_mask(self.parent_tensor, device=device)
        self.roots = torch.tensor(get_roots(self.index_tree), device=device)
        self.parent_child_tensor_tree = construct_parent_childtensor_tree(self.index_tree, device=device)

    def to(self, device: torch.device | str):
        """Moves all computed tensors to the specified device."""
        self.parent_tensor = self.parent_tensor.to(device)
        self.index_tensor = self.index_tensor.to(device)
        self.hierarchy_mask = self.hierarchy_mask.to(device)
        self.sibling_mask = self.sibling_mask.to(device)
        self.roots = self.roots.to(device) 
        self.parent_child_tensor_tree = {k: v.to(device) for k, v in self.parent_child_tensor_tree.items()}
        return self

to(device)

Moves all computed tensors to the specified device.

Source code in hierarchical_loss/hierarchy.py
55
56
57
58
59
60
61
62
63
def to(self, device: torch.device | str):
    """Moves all computed tensors to the specified device."""
    self.parent_tensor = self.parent_tensor.to(device)
    self.index_tensor = self.index_tensor.to(device)
    self.hierarchy_mask = self.hierarchy_mask.to(device)
    self.sibling_mask = self.sibling_mask.to(device)
    self.roots = self.roots.to(device) 
    self.parent_child_tensor_tree = {k: v.to(device) for k, v in self.parent_child_tensor_tree.items()}
    return self