embed_with_diagnostics

embed_with_diagnostics(
    embeddings,
    text_col,
    label_columns,
    embed_fn,
    lift_threshold=3.0,
    prefix='',
)

Two-pass embedding with axis diagnostics.

  1. Run diagnose_axes on the provided embeddings.
  2. If any axis has lift < lift_threshold, rebuild each text with explicit labeled fields for those under-served axes; otherwise return the baseline embeddings unchanged.
  3. Re-embed the structured texts via embed_fn.
  4. Re-diagnose to confirm improvement.
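The four steps above can be sketched as follows. This is a minimal, self-contained illustration under stated assumptions, not this library's implementation: `diagnose_axes` and `AxisDiagnostic` are not defined in this reference, so the sketch substitutes a hypothetical `knn_lift` (k-nearest-neighbour label agreement divided by the agreement expected from label frequencies alone) and a hypothetical `AxisDiag` record. The `"axis: label | ... | text"` layout of the structured text and the `toy_embed` hashing embedder are likewise assumptions made for illustration.

```python
import zlib
from dataclasses import dataclass

import numpy as np


@dataclass
class AxisDiag:
    """Hypothetical stand-in for AxisDiagnostic: an axis name and its lift."""
    axis: str
    lift: float


def knn_lift(emb, labels, k=5):
    """Hypothetical lift: k-NN label agreement rate / chance agreement rate."""
    labels = np.asarray(labels)
    k = min(k, len(labels) - 1)
    x = emb / (np.linalg.norm(emb, axis=1, keepdims=True) + 1e-12)
    sim = x @ x.T
    np.fill_diagonal(sim, -np.inf)          # never count an item as its own neighbour
    nn = np.argsort(-sim, axis=1)[:, :k]    # indices of the k most similar items
    agree = (labels[nn] == labels[:, None]).mean()
    _, counts = np.unique(labels, return_counts=True)
    chance = ((counts / counts.sum()) ** 2).sum()  # agreement expected from frequencies
    return float(agree / chance)


def two_pass_embed(embeddings, texts, label_columns, embed_fn,
                   lift_threshold=3.0, prefix=""):
    # Pass 1: diagnose every axis on the caller's baseline embeddings.
    before = [AxisDiag(a, knn_lift(embeddings, l)) for a, l in label_columns.items()]
    weak = [d.axis for d in before if d.lift < lift_threshold]
    if not weak:  # nothing to promote: keep the original embeddings and texts
        return embeddings, before, before, list(texts)
    # Assumed field layout: "axis: label | ... | original text".
    structured = [
        prefix + " | ".join(f"{a}: {label_columns[a][i]}" for a in weak) + " | " + t
        for i, t in enumerate(texts)
    ]
    # Pass 2: re-embed the structured texts, then re-diagnose every axis.
    new_emb = np.asarray(embed_fn(structured), dtype=np.float32)
    after = [AxisDiag(a, knn_lift(new_emb, l)) for a, l in label_columns.items()]
    return new_emb, before, after, structured


def toy_embed(texts, d=64):
    """Toy embed_fn: hashed bag-of-words counts, standing in for a real model."""
    out = np.zeros((len(texts), d), dtype=np.float32)
    for i, t in enumerate(texts):
        for tok in t.lower().split():
            out[i, zlib.crc32(tok.encode()) % d] += 1.0
    return out


texts = [f"item{i}" for i in range(8)]
labels = {"color": ["red", "blue"] * 4}
emb, before, after, final_texts = two_pass_embed(
    toy_embed(texts), texts, labels, toy_embed, prefix="search_document: ")
```

Under this stand-in measure, lift is bounded by 1 / chance, so with two balanced labels it can never exceed 2.0 and the axis is always promoted at the default threshold of 3.0. The real `diagnose_axes` may define lift differently, so the threshold should not be calibrated against this sketch.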

Parameters

embeddings : (n, d) float32
    Baseline embeddings (already computed by the caller).
text_col : list[str]
    Original text strings used for embedding (one per row).
label_columns : dict[str, array-like]
    Mapping of axis name → per-item labels (length n).
embed_fn : callable
    fn(texts: list[str]) -> np.ndarray; re-embeds a list of texts.
lift_threshold : float, default 3.0
    Axes with lift below this value are promoted to explicit text fields.
prefix : str, default ''
    Prefix prepended to each structured text (e.g. "search_document: ").
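Any callable that maps a list of strings to one embedding row per string satisfies the `embed_fn` contract. The sketch below wraps a batch model call into that shape, batching the inputs and checking that each batch yields a 2-D array with one row per text. `make_embed_fn` and the `model_call` it wraps are hypothetical names used only for illustration; substitute whatever client your embedding model provides.

```python
from typing import Callable, Sequence

import numpy as np


def make_embed_fn(
    model_call: Callable[[list[str]], Sequence[Sequence[float]]],
    batch_size: int = 32,
) -> Callable[[list[str]], np.ndarray]:
    """Adapt a (hypothetical) batch model call to fn(texts) -> (n, d) float32."""

    def embed_fn(texts: list[str]) -> np.ndarray:
        if not texts:
            raise ValueError("embed_fn needs at least one text")
        chunks = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            vecs = np.asarray(model_call(batch), dtype=np.float32)
            # Require one row per input so row i stays aligned with text i.
            if vecs.ndim != 2 or vecs.shape[0] != len(batch):
                raise ValueError(f"expected ({len(batch)}, d) array, got {vecs.shape}")
            chunks.append(vecs)
        return np.vstack(chunks)

    return embed_fn


# Fake model for demonstration: embeds each text as [length, 1.0].
fake_model = lambda batch: [[float(len(t)), 1.0] for t in batch]
embed_fn = make_embed_fn(fake_model, batch_size=2)
out = embed_fn(["a", "bb", "ccc", "dddd", "eeeee"])
```

Because `prefix` is prepended to the structured texts before they reach `embed_fn`, a wrapper like this should presumably not add the same task prefix itself, or it would be applied twice.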

Returns

embeddings : (n, d) float32
    Final embeddings (original if no axes were promoted, re-embedded otherwise).
before_diags : list[AxisDiagnostic]
    Diagnostics from the first pass.
after_diags : list[AxisDiagnostic]
    Diagnostics from the second pass (same as before_diags if no re-embedding occurred).
texts : list[str]
    Final text strings (original or structured).
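Step 4 ("re-diagnose to confirm improvement") amounts to comparing before_diags with after_diags axis by axis. The sketch below assumes, hypothetically, that each AxisDiagnostic exposes an `axis` name and a numeric `lift`; the `Diag` tuple and `lift_changes` helper are illustrative stand-ins, not part of this API.

```python
from typing import NamedTuple, Sequence


class Diag(NamedTuple):
    """Hypothetical stand-in: assumes each AxisDiagnostic exposes .axis and .lift."""
    axis: str
    lift: float


def lift_changes(before: Sequence[Diag], after: Sequence[Diag]) -> dict[str, float]:
    """Per-axis change in lift from the first pass to the second (after - before)."""
    before_by_axis = {d.axis: d.lift for d in before}
    return {d.axis: d.lift - before_by_axis[d.axis]
            for d in after if d.axis in before_by_axis}


before = [Diag("color", 1.2), Diag("size", 4.5)]
after = [Diag("color", 3.8), Diag("size", 4.4)]
changes = lift_changes(before, after)
```

When no axis was promoted, after_diags matches before_diags, so every change is zero.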