multi_level_fisher_weights

multi_level_fisher_weights(embeddings, graph, item_labels, min_count=50)

Compute Fisher weights averaged across multiple DAG depths.

At each depth of the graph, the item labels are resolved to that depth via graph.items_at_depth(), and compute_fisher_weights is called on the resolved labels. The per-depth weight vectors are then combined by a weighted average, with each depth weighted by its number of surviving classes, and the result is re-normalized to unit L2 norm.
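The combination step on its own can be sketched as below: stack the per-depth vectors, take a mean weighted by each depth's surviving-class count, then rescale to unit L2 norm. `combine_depth_weights` is a hypothetical helper written for illustration, not part of this API, and it assumes the per-depth vectors and class counts are already available.

```python
import numpy as np


def combine_depth_weights(per_depth_weights, class_counts):
    """HYPOTHETICAL helper: weighted average of per-depth (d,) weight vectors,
    weighted by surviving-class count, followed by L2 re-normalization."""
    W = np.stack([np.asarray(w, dtype=np.float64) for w in per_depth_weights])  # (depths, d)
    counts = np.asarray(class_counts, dtype=np.float64)                         # (depths,)
    if counts.sum() <= 0:
        raise ValueError("at least one depth must have surviving classes")
    combined = (counts[:, None] * W).sum(axis=0) / counts.sum()  # weighted mean over depths
    norm = np.linalg.norm(combined)
    if norm > 0:
        combined = combined / norm  # unit L2 norm
    return combined.astype(np.float32)


# Depth A (3 classes) favors dim 0, depth B (1 class) favors dim 1.
print(combine_depth_weights([[1.0, 0.0], [0.0, 1.0]], [3, 1]))  # ≈ [0.9487, 0.3162]
```

Because the final L2 normalization removes any overall scale, dividing by the total count only matters for readability; the direction of the result is set entirely by the relative class counts.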

Falls back to single-level Fisher weights if the graph has depth <= 1.
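The per-depth loop and the depth <= 1 fallback can be sketched end to end as follows. Everything beyond the names in this reference is hypothetical: `ToyGraph` stands in for CategoryGraph, its `max_depth` attribute and the `items_at_depth(item_labels, depth)` signature are assumptions, and `fisher_score_sketch` replaces compute_fisher_weights with a classic per-dimension Fisher ratio (between-class over within-class variance) that also reports how many classes survived `min_count`. The real functions may behave differently.

```python
import numpy as np


def fisher_score_sketch(embeddings, labels, min_count=50):
    """HYPOTHETICAL stand-in for compute_fisher_weights.

    Per-dimension Fisher ratio (between-class / within-class variance) over
    classes with at least `min_count` items. Returns (unit-norm weights,
    number of surviving classes), or (None, 0) if fewer than 2 classes survive.
    """
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    keep = classes[counts >= min_count]
    if len(keep) < 2:
        return None, 0
    mask = np.isin(labels, keep)
    X, y = embeddings[mask].astype(np.float64), labels[mask]
    mu = X.mean(axis=0)
    between = np.zeros(X.shape[1])
    within = np.zeros(X.shape[1])
    for c in keep:
        Xc = X[y == c]
        mc = Xc.mean(axis=0)
        between += len(Xc) * (mc - mu) ** 2     # spread of class means
        within += ((Xc - mc) ** 2).sum(axis=0)  # spread inside each class
    w = between / (within + 1e-12)
    norm = np.linalg.norm(w)
    return (w / norm if norm > 0 else w), len(keep)


class ToyGraph:
    """HYPOTHETICAL minimal graph: leaf label -> ancestor path (coarsest first,
    ending at the leaf). All paths are assumed to have the same length."""

    def __init__(self, paths):
        self.paths = paths
        self.max_depth = max(len(p) for p in paths.values())

    def items_at_depth(self, item_labels, depth):
        # Resolve each leaf label to its ancestor at `depth` (1 = coarsest).
        return np.array([self.paths[lab][depth - 1] for lab in item_labels])


def multi_level_fisher_sketch(embeddings, graph, item_labels, min_count=50):
    if graph.max_depth <= 1:
        # Fallback: plain single-level Fisher weights on the leaf labels.
        w, _ = fisher_score_sketch(embeddings, item_labels, min_count)
        if w is None:
            raise ValueError("fewer than two classes with min_count items")
        return w.astype(np.float32)

    vecs, n_classes = [], []
    for depth in range(1, graph.max_depth + 1):
        labels = graph.items_at_depth(item_labels, depth)
        w, k = fisher_score_sketch(embeddings, labels, min_count)
        if k > 0:  # skip depths where too few classes survive
            vecs.append(w)
            n_classes.append(k)
    if not vecs:
        raise ValueError("no depth has two or more surviving classes")

    # Average weighted by surviving-class count, then re-normalize to unit L2 norm.
    combined = np.average(np.stack(vecs), axis=0, weights=n_classes)
    return (combined / np.linalg.norm(combined)).astype(np.float32)
```

Weighting by surviving-class count means a finer depth with many well-populated classes pulls the combined vector harder than a coarse depth with only a few, while depths where `min_count` leaves fewer than two classes drop out entirely.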

Parameters

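embeddings : (n, d) float32
    Item embeddings, one row per item.
graph : CategoryGraph
    Category DAG whose levels are resolved via graph.items_at_depth().
item_labels : (n,) str array
    Per-item leaf labels (finest level).
min_count : int, default 50
    Passed through to compute_fisher_weights.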

Returns

weights : (d,) float32
    L2-normalized combined Fisher weights.