Skip to content

Semantic Router

mailtag.semantic_router.SemanticRouter

Routes text to categories based on embedding similarity.

Uses pre-computed category embeddings to classify emails by the cosine similarity between an email's embedding and each category's centroid (the mean embedding of that category's examples).

Source code in src/mailtag/semantic_router.py
class SemanticRouter:
    """Routes text to categories based on embedding similarity.

    Uses pre-computed category embeddings to instantly classify
    emails based on semantic similarity to category centroids.

    Each category is represented by one vector: the mean of its example
    embeddings. Routing unit-normalizes the query and compares it with the
    unit-normalized category vectors (cosine similarity), accepting the best
    match only if it reaches ``score_threshold``.
    """

    def __init__(
        self,
        embedder: MLXEmbedder,
        score_threshold: float = 0.75,
    ):
        """Initialize the semantic router.

        Args:
            embedder: MLXEmbedder instance for generating embeddings
            score_threshold: Minimum similarity score to accept a route (0.0-1.0)
        """
        self.embedder = embedder
        self.score_threshold = score_threshold
        # Category name -> raw (un-normalized) centroid vector.
        self.category_embeddings: dict[str, np.ndarray] = {}
        # Row i of ``_embedding_matrix`` corresponds to ``categories[i]``.
        self.categories: list[str] = []
        # Stacked, row-normalized centroids; None while no categories exist.
        self._embedding_matrix: np.ndarray | None = None
        logger.info(f"SemanticRouter initialized with threshold: {score_threshold}")

    @staticmethod
    def _unit_normalize(vectors: np.ndarray) -> np.ndarray:
        """Scale vectors to unit length along the last axis.

        All-zero vectors are divided by 1 instead of 0, so they stay zero
        (and score 0.0 against everything) rather than turning into NaN.
        """
        norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
        return vectors / np.where(norms == 0, 1.0, norms)

    @classmethod
    def _stack_unit_rows(
        cls, embeddings: dict[str, np.ndarray], categories: list[str]
    ) -> np.ndarray | None:
        """Stack ``embeddings`` in ``categories`` order and row-normalize them.

        Returns None when there is nothing to stack. ``np.stack`` raises
        ValueError if the vectors have different shapes.
        """
        if not embeddings or not categories:
            return None
        return cls._unit_normalize(np.stack([embeddings[cat] for cat in categories]))

    @staticmethod
    def _read_npz(path: Path) -> dict[str, np.ndarray]:
        """Eagerly read every named array from an ``.npz`` archive and close it.

        Pickled (object) arrays are refused: category embeddings are plain
        numeric vectors, and unpickling a file can execute arbitrary code.

        Raises:
            ValueError: the file is a bare ``.npy`` array or holds pickled data.
            I/O and archive errors from ``np.load`` propagate unchanged.
        """
        data = np.load(path, allow_pickle=False)
        if not hasattr(data, "files"):
            # A bare .npy file loads as one ndarray, not a name -> array mapping.
            raise ValueError("expected an .npz archive of named arrays")
        with data:
            return {key: data[key] for key in data.files}

    def load_embeddings(self, path: Path | str) -> bool:
        """Load pre-computed category embeddings from file.

        The router's state is replaced only if the whole file loads and
        validates; on failure the previous embeddings are kept untouched.

        Args:
            path: Path to the .npz file containing embeddings. If it does not
                exist and its name lacks ``.npz``, ``<path>.npz`` is tried as
                well, matching the suffix ``np.savez`` appends when saving.

        Returns:
            True if loaded successfully, False otherwise
        """
        import zipfile  # zipfile.BadZipFile is raised for corrupt archives

        path = Path(path)
        if not path.exists():
            candidate = None if path.name.endswith(".npz") else path.with_name(path.name + ".npz")
            if candidate is None or not candidate.exists():
                logger.warning(f"Embeddings file not found: {path}")
                return False
            path = candidate

        try:
            loaded = self._read_npz(path)
            shapes = {vector.shape for vector in loaded.values()}
            if len(shapes) > 1 or any(len(shape) != 1 for shape in shapes):
                raise ValueError(f"expected equal-length 1-D embeddings, got shapes {sorted(shapes)}")
            categories = list(loaded)
            matrix = self._stack_unit_rows(loaded, categories)
        except (OSError, EOFError, KeyError, ValueError, zipfile.BadZipFile) as e:
            logger.error(f"Failed to load embeddings from {path}: {e}")
            return False

        # Commit all three attributes together so they never disagree.
        self.category_embeddings = loaded
        self.categories = categories
        self._embedding_matrix = matrix
        logger.info(f"Loaded embeddings for {len(self.categories)} categories from {path}")
        return True

    def save_embeddings(self, path: Path | str) -> bool:
        """Save category embeddings to file.

        Args:
            path: Path to save the .npz file. ``.npz`` is appended when the
                name does not already end with it (``np.savez`` would append
                it anyway; doing it here keeps the logged path accurate).

        Returns:
            True if saved successfully, False otherwise (including when the
            parent directory cannot be created)
        """
        try:
            path = Path(path)
            if not path.name.endswith(".npz"):
                path = path.with_name(path.name + ".npz")
            path.parent.mkdir(parents=True, exist_ok=True)
            np.savez(path, **self.category_embeddings)
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"Failed to save embeddings to {path}: {e}")
            return False
        logger.info(f"Saved embeddings for {len(self.categories)} categories to {path}")
        return True

    def _build_embedding_matrix(self) -> None:
        """Rebuild the row-normalized matrix used for cosine similarity."""
        self._embedding_matrix = self._stack_unit_rows(self.category_embeddings, self.categories)

    def build_from_examples(self, category_examples: dict[str, list[str]]) -> None:
        """Build category embeddings from example texts.

        New categories are added and existing ones with the same name are
        replaced; categories not mentioned are kept.

        Args:
            category_examples: Dict mapping category names to lists of example texts
        """
        logger.info(f"Building embeddings for {len(category_examples)} categories...")

        built: dict[str, np.ndarray] = {}
        for category, examples in category_examples.items():
            if not examples:
                logger.warning(f"No examples for category '{category}', skipping")
                continue

            # Centroid (mean) of the example embeddings represents the category.
            built[category] = self.embedder.encode_documents(examples).mean(axis=0)
            logger.debug(f"Built embedding for '{category}' from {len(examples)} examples")

        # Merge only after every category has been encoded, so an encoder
        # failure part-way leaves the existing embeddings/matrix consistent.
        self.category_embeddings.update(built)
        self.categories = list(self.category_embeddings.keys())
        self._build_embedding_matrix()
        logger.info(f"Built embeddings for {len(self.categories)} categories")

    def build_from_validated_db(
        self,
        validated_db: dict[str, str],
        min_examples: int = 1,
        max_examples_per_category: int = 10,
    ) -> None:
        """Build category embeddings from validated classification database.

        Each sender becomes a synthetic example text
        ``"Email from <domain> categorized as <category>"``, where the domain
        is the text after the last ``@`` (or the whole sender if it has none).

        Args:
            validated_db: Dict mapping sender email to category
            min_examples: Minimum senders required for a category to be built
            max_examples_per_category: Maximum senders used per category (the
                first ones in ``validated_db`` iteration order)
        """
        category_senders: dict[str, list[str]] = {}
        for sender, category in validated_db.items():
            category_senders.setdefault(category, []).append(sender)

        category_examples: dict[str, list[str]] = {}
        for category, senders in category_senders.items():
            if len(senders) < min_examples:
                continue
            examples = []
            for sender in senders[:max_examples_per_category]:
                domain = sender.rsplit("@", 1)[-1]
                examples.append(f"Email from {domain} categorized as {category}")
            category_examples[category] = examples

        self.build_from_examples(category_examples)

    def route(self, text: str) -> tuple[str, float]:
        """Route text to the most similar category.

        Args:
            text: Input text to classify

        Returns:
            Tuple of (category, similarity_score). When the best similarity is
            below ``score_threshold`` the category is "" and the score is that
            best similarity; ("", 0.0) is returned only when no category
            embeddings are loaded.
        """
        if not self.categories or self._embedding_matrix is None:
            logger.warning("No category embeddings loaded, cannot route")
            return "", 0.0

        query = self._unit_normalize(self.embedder.encode_query(text))
        similarities = np.dot(self._embedding_matrix, query)

        best_idx = np.argmax(similarities)
        best_score = similarities[best_idx]
        best_category = self.categories[best_idx]

        if best_score >= self.score_threshold:
            logger.debug(f"Routed to '{best_category}' with score {best_score:.3f}")
            return best_category, float(best_score)
        logger.debug(f"No route found (best: '{best_category}' with score {best_score:.3f})")
        return "", float(best_score)

    def route_batch(self, texts: list[str]) -> list[tuple[str, float]]:
        """Route multiple texts to categories in a single batch (more efficient than per-item).

        Args:
            texts: List of input texts to classify

        Returns:
            List of (category, similarity_score) tuples, one per input text.
            For texts whose best similarity is below the threshold the
            category is "" and the score is that best similarity; every entry
            is ("", 0.0) when no category embeddings are loaded.
        """
        if not self.categories or self._embedding_matrix is None:
            return [("", 0.0)] * len(texts)

        if not texts:
            return []

        queries = self._unit_normalize(self.embedder.encode(texts, prefix="search_query: "))

        # (n_queries, n_categories) similarity matrix, then best per row.
        similarities = np.dot(queries, self._embedding_matrix.T)
        best_indices = np.argmax(similarities, axis=1)
        best_scores = similarities[np.arange(len(texts)), best_indices]

        return [
            (self.categories[idx] if score >= self.score_threshold else "", float(score))
            for idx, score in zip(best_indices, best_scores)
        ]

    def route_with_alternatives(self, text: str, top_k: int = 3) -> list[tuple[str, float]]:
        """Route text and return top-k alternatives.

        Unlike ``route``, no threshold is applied: the best ``top_k``
        categories are returned whatever their scores.

        Args:
            text: Input text to classify
            top_k: Number of top alternatives to return; values <= 0 yield an
                empty list, values above the category count return all.

        Returns:
            List of (category, score) tuples sorted by score descending
        """
        # top_k <= 0 must be handled explicitly: ``[-0:]`` slices the whole
        # array and would return every category instead of none.
        if top_k <= 0 or not self.categories or self._embedding_matrix is None:
            return []

        query = self._unit_normalize(self.embedder.encode_query(text))
        similarities = np.dot(self._embedding_matrix, query)

        top_indices = np.argsort(similarities)[-top_k:][::-1]
        return [(self.categories[idx], float(similarities[idx])) for idx in top_indices]

    def add_category(self, category: str, examples: list[str]) -> None:
        """Add a new category with examples (replacing one with the same name).

        Args:
            category: Category name
            examples: List of example texts
        """
        if not examples:
            logger.warning(f"No examples provided for category '{category}'")
            return

        embeddings = self.embedder.encode_documents(examples)
        self.category_embeddings[category] = embeddings.mean(axis=0)

        if category not in self.categories:
            self.categories.append(category)

        self._build_embedding_matrix()
        logger.info(f"Added category '{category}' with {len(examples)} examples")

    def remove_category(self, category: str) -> bool:
        """Remove a category.

        Args:
            category: Category name to remove

        Returns:
            True if removed, False if not found
        """
        if category not in self.category_embeddings:
            return False

        del self.category_embeddings[category]
        self.categories.remove(category)
        self._build_embedding_matrix()
        logger.info(f"Removed category '{category}'")
        return True

    @property
    def num_categories(self) -> int:
        """Return the number of loaded categories."""
        return len(self.categories)

    def get_category_info(self) -> dict[str, dict]:
        """Get information about loaded categories.

        Returns:
            Dict mapping each category name to ``{"embedding_dim": <length of its vector>}``
        """
        return {cat: {"embedding_dim": emb.shape[0]} for cat, emb in self.category_embeddings.items()}

num_categories property

Return the number of loaded categories.

__init__(embedder, score_threshold=0.75)

Initialize the semantic router.

Parameters:

Name Type Description Default
embedder MLXEmbedder

MLXEmbedder instance for generating embeddings

required
score_threshold float

Minimum similarity score to accept a route (0.0-1.0)

0.75
Source code in src/mailtag/semantic_router.py
def __init__(
    self,
    embedder: MLXEmbedder,
    score_threshold: float = 0.75,
):
    """Create an empty router bound to an embedder.

    Args:
        embedder: MLXEmbedder instance used to turn text into vectors
        score_threshold: Lowest similarity score (0.0-1.0) at which a route is accepted
    """
    # Nothing is routable until embeddings are loaded or built.
    self.categories: list[str] = []
    self.category_embeddings: dict[str, np.ndarray] = {}
    self._embedding_matrix: np.ndarray | None = None
    self.embedder, self.score_threshold = embedder, score_threshold
    logger.info(f"SemanticRouter initialized with threshold: {score_threshold}")

route(text)

Route text to the most similar category.

Parameters:

Name Type Description Default
text str

Input text to classify

required

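Returns:

Type Description
tuple[str, float]

Tuple of (category, similarity_score). If the best similarity is below the threshold, category is "" and similarity_score is that best score; ("", 0.0) is returned only when no category embeddings are loaded.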

Source code in src/mailtag/semantic_router.py
def route(self, text: str) -> tuple[str, float]:
    """Route text to the most similar category.

    Args:
        text: Input text to classify

    Returns:
        Tuple of (category, similarity_score). When the best similarity is
        below ``score_threshold`` the category is "" and the score is that
        best similarity; ("", 0.0) is returned only when no category
        embeddings are loaded.
    """
    if not self.categories or self._embedding_matrix is None:
        logger.warning("No category embeddings loaded, cannot route")
        return "", 0.0

    # Scale the query to unit length for cosine similarity. An all-zero
    # embedding is divided by 1 so it scores 0.0 instead of producing NaN.
    query_embedding = self.embedder.encode_query(text)
    query_length = np.linalg.norm(query_embedding)
    query_norm = query_embedding / (query_length if query_length else 1.0)

    # Compute similarities with all categories
    similarities = np.dot(self._embedding_matrix, query_norm)

    # Find best match
    best_idx = np.argmax(similarities)
    best_score = similarities[best_idx]
    best_category = self.categories[best_idx]

    if best_score >= self.score_threshold:
        logger.debug(f"Routed to '{best_category}' with score {best_score:.3f}")
        return best_category, float(best_score)
    logger.debug(f"No route found (best: '{best_category}' with score {best_score:.3f})")
    return "", float(best_score)

route_batch(texts)

Route multiple texts to categories in a single batch (more efficient than per-item).

Parameters:

Name Type Description Default
texts list[str]

List of input texts to classify

required

Returns:

Type Description
list[tuple[str, float]]

List of (category, similarity_score) tuples, one per input text. For texts whose best similarity is below the threshold, category is "" and similarity_score is that best score; every entry is ("", 0.0) when no category embeddings are loaded.

Source code in src/mailtag/semantic_router.py
def route_batch(self, texts: list[str]) -> list[tuple[str, float]]:
    """Route multiple texts to categories in a single batch (more efficient than per-item).

    Args:
        texts: List of input texts to classify

    Returns:
        List of (category, similarity_score) tuples, one per input text.
        For texts whose best similarity is below the threshold the category
        is "" and the score is that best similarity; every entry is
        ("", 0.0) when no category embeddings are loaded.
    """
    if not self.categories or self._embedding_matrix is None:
        return [("", 0.0)] * len(texts)

    if not texts:
        return []

    # Batch-encode all queries at once
    query_embeddings = self.embedder.encode(texts, prefix="search_query: ")
    # Normalize rows; an all-zero row is divided by 1 so it scores 0.0, not NaN.
    norms = np.linalg.norm(query_embeddings, axis=1, keepdims=True)
    query_embeddings = query_embeddings / np.where(norms == 0, 1.0, norms)

    # (n_queries, n_categories) similarities, then the best category per row.
    similarities = np.dot(query_embeddings, self._embedding_matrix.T)
    best_indices = np.argmax(similarities, axis=1)
    best_scores = similarities[np.arange(len(texts)), best_indices]

    return [
        (self.categories[idx] if score >= self.score_threshold else "", float(score))
        for idx, score in zip(best_indices, best_scores)
    ]

route_with_alternatives(text, top_k=3)

Route text and return top-k alternatives.

Parameters:

Name Type Description Default
text str

Input text to classify

required
top_k int

Number of top alternatives to return

3

Returns:

Type Description
list[tuple[str, float]]

List of (category, score) tuples sorted by score descending

Source code in src/mailtag/semantic_router.py
def route_with_alternatives(self, text: str, top_k: int = 3) -> list[tuple[str, float]]:
    """Route text and return top-k alternatives.

    No threshold is applied: the best ``top_k`` categories are returned
    whatever their scores.

    Args:
        text: Input text to classify
        top_k: Number of top alternatives to return; values <= 0 yield an
            empty list, values above the category count return all.

    Returns:
        List of (category, score) tuples sorted by score descending
    """
    # top_k <= 0 must be handled explicitly: ``[-0:]`` slices the whole array
    # and would return every category instead of none.
    if top_k <= 0 or not self.categories or self._embedding_matrix is None:
        return []

    # Unit-normalize the query; an all-zero embedding is divided by 1 (not 0).
    query_embedding = self.embedder.encode_query(text)
    query_length = np.linalg.norm(query_embedding)
    query_norm = query_embedding / (query_length if query_length else 1.0)
    similarities = np.dot(self._embedding_matrix, query_norm)

    # argsort is ascending; take the last top_k and reverse for descending order.
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    return [(self.categories[idx], float(similarities[idx])) for idx in top_indices]

build_from_examples(category_examples)

Build category embeddings from example texts.

Parameters:

Name Type Description Default
category_examples dict[str, list[str]]

Dict mapping category names to lists of example texts

required
Source code in src/mailtag/semantic_router.py
def build_from_examples(self, category_examples: dict[str, list[str]]) -> None:
    """Build category embeddings from example texts.

    New categories are added and existing ones with the same name are
    replaced; categories not mentioned are kept.

    Args:
        category_examples: Dict mapping category names to lists of example texts
    """
    logger.info(f"Building embeddings for {len(category_examples)} categories...")

    built: dict[str, np.ndarray] = {}
    for category, examples in category_examples.items():
        if not examples:
            logger.warning(f"No examples for category '{category}', skipping")
            continue

        # Use the centroid (mean) of the example embeddings as the category embedding.
        built[category] = self.embedder.encode_documents(examples).mean(axis=0)
        logger.debug(f"Built embedding for '{category}' from {len(examples)} examples")

    # Merge only after every category has been encoded, so an encoder failure
    # part-way leaves category_embeddings, categories and the matrix in sync.
    self.category_embeddings.update(built)
    self.categories = list(self.category_embeddings.keys())
    self._build_embedding_matrix()
    logger.info(f"Built embeddings for {len(self.categories)} categories")

load_embeddings(path)

Load pre-computed category embeddings from file.

Parameters:

Name Type Description Default
path Path | str

Path to the .npz file containing embeddings

required

Returns:

Type Description
bool

True if loaded successfully, False otherwise

Source code in src/mailtag/semantic_router.py
def load_embeddings(self, path: Path | str) -> bool:
    """Load pre-computed category embeddings from file.

    The router's state is replaced only if the whole file loads and
    validates; on failure the previous embeddings are kept untouched.

    Args:
        path: Path to the .npz file containing embeddings. If it does not
            exist and its name lacks ``.npz``, ``<path>.npz`` is tried as well,
            matching the suffix ``np.savez`` appends when saving.

    Returns:
        True if loaded successfully, False otherwise
    """
    import zipfile  # zipfile.BadZipFile is raised for corrupt archives

    path = Path(path)
    if not path.exists():
        candidate = None if path.name.endswith(".npz") else path.with_name(path.name + ".npz")
        if candidate is None or not candidate.exists():
            logger.warning(f"Embeddings file not found: {path}")
            return False
        path = candidate

    try:
        # Embeddings are plain numeric vectors, so pickled (object) arrays are
        # refused: unpickling an untrusted file can execute arbitrary code.
        data = np.load(path, allow_pickle=False)
        if not hasattr(data, "files"):
            # A bare .npy file loads as one ndarray, not a name -> array mapping.
            raise ValueError("expected an .npz archive of named arrays")
        with data:  # closes the underlying zip file
            loaded = {key: data[key] for key in data.files}
        shapes = {vector.shape for vector in loaded.values()}
        if len(shapes) > 1 or any(len(shape) != 1 for shape in shapes):
            raise ValueError(f"expected equal-length 1-D embeddings, got shapes {sorted(shapes)}")
    except (OSError, EOFError, KeyError, ValueError, zipfile.BadZipFile) as e:
        logger.error(f"Failed to load embeddings from {path}: {e}")
        return False

    self.category_embeddings = loaded
    self.categories = list(loaded)
    self._build_embedding_matrix()
    logger.info(f"Loaded embeddings for {len(self.categories)} categories from {path}")
    return True

save_embeddings(path)

Save category embeddings to file.

Parameters:

Name Type Description Default
path Path | str

Path to save the .npz file

required

Returns:

Type Description
bool

True if saved successfully, False otherwise

Source code in src/mailtag/semantic_router.py
def save_embeddings(self, path: Path | str) -> bool:
    """Save category embeddings to file.

    Args:
        path: Path to save the .npz file. ``.npz`` is appended when the name
            does not already end with it (``np.savez`` would append it anyway;
            doing it here keeps the logged path accurate).

    Returns:
        True if saved successfully, False otherwise (including when the
        parent directory cannot be created)
    """
    try:
        path = Path(path)
        if not path.name.endswith(".npz"):
            path = path.with_name(path.name + ".npz")
        # Inside the try so a failing mkdir honors the "return False" contract.
        path.parent.mkdir(parents=True, exist_ok=True)
        np.savez(path, **self.category_embeddings)
    except (OSError, TypeError, ValueError) as e:
        logger.error(f"Failed to save embeddings to {path}: {e}")
        return False
    logger.info(f"Saved embeddings for {len(self.categories)} categories to {path}")
    return True

add_category(category, examples)

Add a new category with examples.

Parameters:

Name Type Description Default
category str

Category name

required
examples list[str]

List of example texts

required
Source code in src/mailtag/semantic_router.py
def add_category(self, category: str, examples: list[str]) -> None:
    """Register one category (or overwrite it) from its example texts.

    Args:
        category: Category name
        examples: List of example texts
    """
    if examples:
        # The mean of the example embeddings is the category's vector.
        self.category_embeddings[category] = self.embedder.encode_documents(examples).mean(axis=0)
        if category not in self.categories:
            self.categories.append(category)
        self._build_embedding_matrix()
        logger.info(f"Added category '{category}' with {len(examples)} examples")
    else:
        logger.warning(f"No examples provided for category '{category}'")

remove_category(category)

Remove a category.

Parameters:

Name Type Description Default
category str

Category name to remove

required

Returns:

Type Description
bool

True if removed, False if not found

Source code in src/mailtag/semantic_router.py
def remove_category(self, category: str) -> bool:
    """Drop a category and rebuild the similarity matrix.

    Args:
        category: Category name to remove

    Returns:
        True if the category existed and was removed, False otherwise
    """
    if category in self.category_embeddings:
        self.category_embeddings.pop(category)
        self.categories.remove(category)
        self._build_embedding_matrix()
        logger.info(f"Removed category '{category}'")
        return True
    return False