diagnose_axes

diagnose_axes(embeddings, label_columns, k=15, sample_n=5000, seed=42)

Detect which categorical axes are under-served by the embedding.

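For each axis, computes the k-NN purity (the fraction of each point's k nearest neighbors that share its label) and compares it with the Herfindahl random baseline (the sum of squared label shares, i.e. the chance that two randomly drawn items share a label). A low lift (purity divided by baseline) means the embedding does not actively separate this axis, which makes the axis a candidate for promotion to an explicit text field before re-embedding.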
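A minimal NumPy sketch of the per-axis statistics. It makes two assumptions: neighbors are found by brute-force squared Euclidean distance (the metric `diagnose_axes` actually uses is not stated here), and lift is purity divided by the Herfindahl baseline. The helpers `knn_purity`, `herfindahl_baseline`, and `lift` are hypothetical and are not part of the documented API.

```python
import numpy as np


def knn_purity(embeddings, labels, k=15):
    """Mean fraction of each point's k nearest neighbors (self excluded) that share its label."""
    x = np.asarray(embeddings, dtype=np.float64)
    labels = np.asarray(labels)
    # Brute-force squared Euclidean distances |a|^2 + |b|^2 - 2 a.b (metric is an assumption).
    sq = np.sum(x * x, axis=1)
    dist = sq[:, None] + sq[None, :] - 2.0 * (x @ x.T)
    np.fill_diagonal(dist, np.inf)  # a point is never its own neighbor
    # Column indices of the k smallest distances in each row (unordered within the k).
    nbrs = np.argpartition(dist, k, axis=1)[:, :k]
    same = labels[nbrs] == labels[:, None]
    return float(same.mean())


def herfindahl_baseline(labels):
    """Sum of squared label shares: the chance two random draws share a label."""
    _, counts = np.unique(np.asarray(labels), return_counts=True)
    p = counts / counts.sum()
    return float(np.sum(p ** 2))


def lift(embeddings, labels, k=15):
    # Assumption: lift is the ratio of observed purity to the random baseline.
    return knn_purity(embeddings, labels, k) / herfindahl_baseline(labels)


if __name__ == "__main__":
    # Two well-separated 1-D clusters whose labels follow the clusters exactly.
    emb = np.concatenate([np.arange(10.0), np.arange(100.0, 110.0)])[:, None]
    lab = np.array([0] * 10 + [1] * 10)
    print(knn_purity(emb, lab, k=3), herfindahl_baseline(lab), lift(emb, lab, k=3))
    # → 1.0 0.5 2.0
```

In this toy case, purity reaches 1.0 and lift reaches 1 / 0.5 = 2.0, which is the ceiling for two balanced classes. An axis whose labels were scattered across both clusters would have a purity close to its baseline and a lift that tends toward 1.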

Parameters

embeddings : (n, d) float32
    Embedding matrix.
label_columns : dict[str, array-like]
    Mapping of axis name → per-item labels (length n).
k : int
    Number of neighbors for k-NN purity (default 15).
sample_n : int
    Subsample to this many points for speed (0 = no subsampling). Default 5000; brute-force k-NN on 5k points takes about 1 s.
seed : int
    Random seed for subsampling.
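The `sample_n` and `seed` parameters imply that one random subset of rows is drawn and applied to the embeddings and to every label column, so each label stays aligned with its vector. The sketch below is a hypothetical helper, `subsample`, built on three assumptions: it uses NumPy's `default_rng`, it samples without replacement, and it treats `sample_n >= n` the same as 0 ("keep everything"). Because of these choices, the real function may pick a different subset for the same seed.

```python
import numpy as np


def subsample(embeddings, label_columns, sample_n=5000, seed=42):
    """Hypothetical helper: apply one shared random row subset to embeddings and all labels."""
    embeddings = np.asarray(embeddings)
    n = embeddings.shape[0]
    labels = {name: np.asarray(col) for name, col in label_columns.items()}
    # Every label column must have one entry per embedding row (length n).
    for name, col in labels.items():
        if col.shape[0] != n:
            raise ValueError(f"label column {name!r} has length {col.shape[0]}, expected {n}")
    # sample_n=0 disables subsampling (documented); sample_n >= n is treated the same (assumption).
    if sample_n == 0 or sample_n >= n:
        return embeddings, labels
    # One index set, drawn without replacement, keeps rows and labels aligned.
    rng = np.random.default_rng(seed)
    idx = rng.choice(n, size=sample_n, replace=False)
    return embeddings[idx], {name: col[idx] for name, col in labels.items()}
```

Drawing a single index array is what keeps purity meaningful. If the embeddings and the labels were sampled independently, the neighbor labels would no longer belong to the neighbor vectors.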

Returns

diagnostics : list[AxisDiagnostic]
    Sorted by lift, ascending (worst-performing axis first).
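The following is an illustrative stand-in for `AxisDiagnostic` and the ascending-lift ordering. The Examples confirm only `name`, `lift`, and `knn_purity`. The `baseline` field and the purity-over-baseline definition of `lift` are assumptions, and `rank_axes` is a hypothetical helper.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AxisDiagnostic:
    """Illustrative stand-in, not the library's class."""

    name: str
    knn_purity: float
    baseline: float  # hypothetical field: Herfindahl random baseline for this axis

    @property
    def lift(self) -> float:
        # Assumption: lift = k-NN purity / random baseline (reported as e.g. "2.2x").
        return self.knn_purity / self.baseline


def rank_axes(diags: list[AxisDiagnostic]) -> list[AxisDiagnostic]:
    """Sort by lift, ascending, so the worst-served axis comes first."""
    return sorted(diags, key=lambda d: d.lift)


if __name__ == "__main__":
    # Made-up numbers for illustration only.
    diags = [
        AxisDiagnostic("axis_a", knn_purity=0.90, baseline=0.10),
        AxisDiagnostic("axis_b", knn_purity=0.90, baseline=0.45),
    ]
    for d in rank_axes(diags):
        print(f"{d.name}: lift={d.lift:.1f}x purity={d.knn_purity:.3f}")
    # → axis_b: lift=2.0x purity=0.900
    # → axis_a: lift=9.0x purity=0.900
```

Both axes have the same purity, yet `axis_b` ranks first because its purity is only about twice what chance would give. In the returned list, the first element is therefore the first candidate for promotion to an explicit text field.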

Examples

>>> diags = diagnose_axes(embeddings, {"gmdn": gmdn_labels, "polarity": pol_labels})
>>> for d in diags:
...     print(f"{d.name}: lift={d.lift:.1f}x purity={d.knn_purity:.3f}")
polarity: lift=2.2x purity=0.978
gmdn: lift=23.0x purity=0.929