compute_fisher_weights
compute_fisher_weights(embeddings, labels, min_count=50)

Compute sqrt(Fisher ratio) per dimension, L2-normalized.
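For each dimension, the Fisher ratio is the between-class variance divided by the within-class variance. Dimensions along which the class means lie far apart, relative to the spread inside each class, get the highest weights.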
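The per-dimension ratio and the normalization can be written out as follows. This is a sketch under stated assumptions: the between-class term is assumed to weight each class mean by its size n_c, and the notation (classes c = 1..C surviving the `min_count` filter, row index set I_c, N = Σ n_c) is introduced here for illustration. The docstring does not specify the exact estimator.

```latex
\begin{aligned}
\mu_{cj} &= \frac{1}{n_c}\sum_{i \in I_c} x_{ij},
  \qquad \mu_j = \frac{1}{N}\sum_{c=1}^{C} n_c\,\mu_{cj} \\
S_B(j) &= \frac{1}{N}\sum_{c=1}^{C} n_c\,(\mu_{cj} - \mu_j)^2 \\
S_W(j) &= \frac{1}{N}\sum_{c=1}^{C}\sum_{i \in I_c} (x_{ij} - \mu_{cj})^2 \\
F_j &= \frac{S_B(j)}{S_W(j)},
  \qquad w_j = \frac{\sqrt{F_j}}{\sqrt{\sum_{k=1}^{d} F_k}}
\end{aligned}
```

The 1/N factors cancel in F_j. Because the L2 norm of the vector (√F_1, …, √F_d) is √(Σ_k F_k), the normalized weights satisfy Σ_j w_j² = 1.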
Parameters
embeddings : (n, d) float32
    Embedding matrix.
labels : (n,) str or int array
    Coarse category labels (one per row).
min_count : int, default 50
    Classes with fewer than this many members are excluded.
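The `min_count` filter can be illustrated with NumPy alone. `drop_small_classes` is a hypothetical helper written for this example, not part of the documented API. It assumes that rows belonging to an excluded class are dropped entirely before any statistics are computed.

```python
import numpy as np

def drop_small_classes(embeddings, labels, min_count=50):
    """Hypothetical helper: keep only rows whose class has >= min_count members."""
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    # Count members per distinct label (works for str or int labels).
    classes, counts = np.unique(labels, return_counts=True)
    surviving = classes[counts >= min_count]
    # Boolean row mask: True where the row's label is a surviving class.
    mask = np.isin(labels, surviving)
    return embeddings[mask], labels[mask]
```

For example, `drop_small_classes(X, ["a", "a", "a", "b"], min_count=2)` keeps only the three rows labeled `"a"`.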
Returns
weights : (d,) float32
    L2-normalized sqrt(Fisher ratio) weights. Uniform if fewer than two classes survive filtering.
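The sketch below follows the documented behavior end to end, but it is not the library's implementation. `fisher_weights_sketch` is a hypothetical name. The class-size-weighted between-class term, the `eps` guard against zero within-class variance, and the choice of 1/sqrt(d) as the "uniform" L2-normalized vector are all assumptions.

```python
import numpy as np

def fisher_weights_sketch(embeddings, labels, min_count=50, eps=1e-12):
    """Hypothetical sketch: L2-normalized sqrt(Fisher ratio) per dimension."""
    X = np.asarray(embeddings, dtype=np.float64)
    y = np.asarray(labels)
    d = X.shape[1]
    # Assumed "uniform" result: equal weights with unit L2 norm.
    uniform = np.full(d, 1.0 / np.sqrt(d), dtype=np.float32)

    # Exclude classes with fewer than min_count members.
    classes, counts = np.unique(y, return_counts=True)
    kept = classes[counts >= min_count]
    if kept.size < 2:
        return uniform
    mask = np.isin(y, kept)
    X, y = X[mask], y[mask]

    grand_mean = X.mean(axis=0)
    between = np.zeros(d)  # sum_c n_c * (mu_c - mu)^2, per dimension
    within = np.zeros(d)   # sum of squared deviations from each class mean
    for c in kept:
        Xc = X[y == c]
        mu_c = Xc.mean(axis=0)
        between += len(Xc) * (mu_c - grand_mean) ** 2
        within += ((Xc - mu_c) ** 2).sum(axis=0)

    # Common 1/N factors cancel, so the sums are used directly.
    # eps (an assumption) avoids division by zero on constant dimensions.
    ratio = between / (within + eps)
    w = np.sqrt(ratio)
    norm = np.linalg.norm(w)
    if norm == 0:  # no dimension separates the classes
        return uniform
    return (w / norm).astype(np.float32)
```

Usage: `w = fisher_weights_sketch(X, y)`. In this sketch, a dimension whose class means are well separated gets a larger entry in `w` than a dimension that only carries noise.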