compute_fisher_weights

compute_fisher_weights(embeddings, labels, min_count=50)

Compute sqrt(Fisher ratio) per dimension, L2-normalized.

Fisher ratio = between-class variance / within-class variance. Dimensions whose class means are most separated relative to their within-class spread get the highest weight.
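One common way to write the per-dimension ratio is shown below. This assumes class-size-weighted between-class scatter and summed within-class scatter; the docstring does not say which variance estimator is used. Here n_c is the number of rows in class c, mu_{c,j} is the mean of dimension j within class c, mu_j is the mean of dimension j over all retained rows, and d is the embedding dimension. Dividing numerator and denominator by the same row count leaves F_j unchanged.

```latex
F_j = \frac{\sum_{c} n_c \left(\mu_{c,j} - \mu_j\right)^2}
           {\sum_{c} \sum_{i \in c} \left(x_{i,j} - \mu_{c,j}\right)^2},
\qquad
w_j = \frac{\sqrt{F_j}}{\sqrt{\sum_{k=1}^{d} F_k}}
```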

Parameters

embeddings : (n, d) float32
    Embedding matrix.
labels : (n,) str or int array
    Coarse category labels, one per row.
min_count : int, default 50
    Classes with fewer than this many members are excluded.

Returns

weights : (d,) float32
    L2-normalized sqrt(Fisher ratio) weights. Uniform if fewer than two classes remain after filtering.
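A minimal NumPy sketch of the behavior described above, not the library's actual implementation. It makes three assumptions that the docstring does not state: between-class scatter is weighted by class size, "uniform" means every weight equals 1/sqrt(d) (so the vector is still unit-norm), and a small epsilon guards against zero within-class variance.

```python
import numpy as np


def compute_fisher_weights(embeddings, labels, min_count=50):
    """Sketch: per-dimension sqrt(Fisher ratio), L2-normalized."""
    X = np.asarray(embeddings, dtype=np.float64)
    labels = np.asarray(labels)
    d = X.shape[1]
    uniform = np.full(d, 1.0 / np.sqrt(d), dtype=np.float32)  # assumed fallback

    # Drop classes with fewer than min_count members.
    classes, counts = np.unique(labels, return_counts=True)
    kept = classes[counts >= min_count]
    if len(kept) < 2:
        return uniform
    mask = np.isin(labels, kept)
    X, labels = X[mask], labels[mask]

    grand_mean = X.mean(axis=0)
    between = np.zeros(d)
    within = np.zeros(d)
    for c in kept:
        Xc = X[labels == c]
        mu_c = Xc.mean(axis=0)
        # Class-size-weighted spread of class means around the grand mean.
        between += len(Xc) * (mu_c - grand_mean) ** 2
        # Spread of rows around their own class mean.
        within += ((Xc - mu_c) ** 2).sum(axis=0)

    w = np.sqrt(between / np.maximum(within, 1e-12))  # epsilon is an assumption
    norm = np.linalg.norm(w)
    if norm == 0:
        return uniform
    return (w / norm).astype(np.float32)


# Usage: only dimension 0 separates the two classes.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3)).astype(np.float32)
y = np.repeat(["a", "b"], 100)
X[y == "b", 0] += 5.0
w = compute_fisher_weights(X, y)
print(w, np.argmax(w))
```

Because each class is shifted only along dimension 0, that dimension's between-class scatter dominates and it receives nearly all of the weight. Raising `min_count` above 100 in this example removes both classes and returns the uniform fallback.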