embed_with_diagnostics
embed_with_diagnostics(
    embeddings,
    text_col,
    label_columns,
    embed_fn,
    lift_threshold=3.0,
    prefix='',
)

Two-pass embedding with axis diagnostics.
- Run diagnose_axes on the provided embeddings.
- If any axis has lift < lift_threshold, rebuild the text with explicit labeled fields for the under-served axes.
- Re-embed the structured text via embed_fn.
- Re-diagnose to confirm improvement.
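The four steps above can be sketched as plain control flow. This is an illustration under stated assumptions, not the library's implementation: the `diagnose` argument is a hypothetical stand-in for diagnose_axes that returns a dict of axis name to lift (the real function returns AxisDiagnostic objects whose fields are not shown here), and the "axis: label | ... | text" layout of the structured text is likewise assumed.

```python
import numpy as np


def two_pass_sketch(embeddings, texts, label_columns, embed_fn, diagnose,
                    lift_threshold=3.0, prefix=""):
    """Illustrative two-pass flow; `diagnose` is a hypothetical stand-in."""
    # Pass 1: measure lift per axis on the baseline embeddings.
    before = diagnose(embeddings, label_columns)  # assumed shape: {axis: lift}
    weak = [axis for axis, lift in before.items() if lift < lift_threshold]
    if not weak:
        # No axis promoted: originals are returned unchanged.
        return embeddings, before, before, list(texts)

    # Rebuild each text with explicit labeled fields (layout is an assumption).
    structured = [
        prefix + " | ".join(
            [f"{axis}: {label_columns[axis][i]}" for axis in weak] + [text]
        )
        for i, text in enumerate(texts)
    ]

    # Pass 2: re-embed the structured text, then re-diagnose.
    new_embeddings = np.asarray(embed_fn(structured), dtype=np.float32)
    after = diagnose(new_embeddings, label_columns)
    return new_embeddings, before, after, structured


# Toy stand-ins so the sketch runs end to end (both are hypothetical).
def toy_embed(texts):
    # Two features per text: its length and the number of "field:" markers.
    return np.array([[len(t), t.count(":")] for t in texts], dtype=np.float32)


def toy_diagnose(emb, label_columns):
    # Pretend lift grows with the number of labeled fields in the text.
    lift = 1.0 + 2.0 * float(emb[:, 1].mean())
    return {axis: lift for axis in label_columns}


texts = ["small red shirt", "large blue coat"]
labels = {"color": ["red", "blue"], "size": ["S", "L"]}
emb, before, after, final_texts = two_pass_sketch(
    toy_embed(texts), texts, labels, toy_embed, toy_diagnose
)
```

In this toy run both axes start below the threshold, so both are promoted into the text and the second diagnosis reports a higher lift.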
Parameters
embeddings : (n, d) float32
    Baseline embeddings (already computed by the caller).
text_col : list[str]
    Original text strings used for embedding (one per row).
label_columns : dict[str, array-like]
    Mapping of axis name → per-item labels (length n).
embed_fn : callable
    fn(texts: list[str]) -> np.ndarray; re-embeds a list of texts.
lift_threshold : float
    Axes with lift below this are promoted to explicit text fields.
prefix : str
    Prefix prepended to each structured text (e.g. "search_document: ").
Returns
embeddings : (n, d) float32
    Final embeddings (original if no axes promoted, re-embedded otherwise).
before_diags : list[AxisDiagnostic]
    Diagnostics from the first pass.
after_diags : list[AxisDiagnostic]
    Diagnostics from the second pass (same as before if no re-embedding).
texts : list[str]
    Final text strings (original or structured).
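Any callable that takes a list of strings and returns one (n, d) array should satisfy the embed_fn contract described under Parameters. The sketch below is a toy hashed bag-of-words embedder, useful only for exercising the interface without a real model; `make_hash_embed_fn` and its `dim` argument are hypothetical names introduced for this example.

```python
import hashlib

import numpy as np


def make_hash_embed_fn(dim=64):
    """Return a toy embed_fn: list[str] -> (n, dim) float32 array."""
    def embed_fn(texts):
        out = np.zeros((len(texts), dim), dtype=np.float32)
        for row, text in enumerate(texts):
            for token in text.lower().split():
                # Stable hash, so the same token always lands in the same column.
                digest = hashlib.md5(token.encode("utf-8")).digest()
                col = int.from_bytes(digest[:4], "little") % dim
                out[row, col] += 1.0
        # L2-normalise each row; all-zero rows (empty texts) are left as zeros.
        norms = np.linalg.norm(out, axis=1, keepdims=True)
        np.divide(out, norms, out=out, where=norms > 0)
        return out
    return embed_fn


embed_fn = make_hash_embed_fn(dim=32)
vectors = embed_fn(["search_document: color: red | small red shirt", ""])
```

A real embed_fn would typically wrap a model's encode call; the only requirement stated in this section is the input and output shape.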