sentence_transformer

Classes for augmentation with Hugging Face Sentence Transformer models.

Classes:

SentenceTransformerAug

SentenceTransformerAug(
    model_name: Optional[str] = None,
    db: DBBase = FaissDB(),
    top_k: Optional[int] = None,
    api_key: str = '',
    cache: Optional[Cache] = None,
    logs: dict[str, Any] = DEFAULT_LOGS,
)

Bases: AugmentedBase

Class for augmentation using a Hugging Face Sentence Transformer model.

Methods:

  • get_embedding

    Retrieve the embeddings for the given texts using a Sentence Transformer model.

  • search

    Search for the encoded query in the vector database.

Source code in src/rago/augmented/base.py
def __init__(
    self,
    model_name: Optional[str] = None,
    db: DBBase = FaissDB(),
    top_k: Optional[int] = None,
    api_key: str = '',
    cache: Optional[Cache] = None,
    logs: dict[str, Any] = DEFAULT_LOGS,
) -> None:
    """Initialize AugmentedBase."""
    if logs is DEFAULT_LOGS:
        logs = {}
    super().__init__(api_key=api_key, cache=cache, logs=logs)

    self.db = db

    self.top_k = top_k if top_k is not None else self.default_top_k
    self.model_name = (
        model_name if model_name is not None else self.default_model_name
    )
    self.model = None

    self._validate()
    self._setup()
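
A minimal usage sketch, assuming the class is importable from rago.augmented and the default FAISS backend is available; the model name below is illustrative, not necessarily the class default:

from rago.augmented import SentenceTransformerAug

# Instantiate the augmenter with an explicit Sentence Transformer model name
# and keep the default FAISS vector database backend.
aug = SentenceTransformerAug(
    model_name='sentence-transformers/all-MiniLM-L6-v2',
    top_k=3,
)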

get_embedding

get_embedding(content: list[str]) -> EmbeddingType

Retrieve the embeddings for the given texts using a Sentence Transformer model.

Source code in src/rago/augmented/sentence_transformer.py
def get_embedding(self, content: list[str]) -> EmbeddingType:
    """Retrieve the embedding for a given text using OpenAI API."""
    model = cast(SentenceTransformer, self.model)
    return model.encode(content)
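
A short usage sketch, assuming aug is the instance created in the example above; the return value is the encoder's embedding matrix, one row per input string:

docs = [
    'Paris is the capital of France.',
    'The Nile is a river in Africa.',
]
# Encode the documents; the result has shape (len(docs), embedding_dim).
embeddings = aug.get_embedding(docs)
print(embeddings.shape)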

search

search(
    query: str, documents: Any, top_k: int = 0
) -> list[str]

Search for the encoded query in the vector database.

Source code in src/rago/augmented/sentence_transformer.py
def search(self, query: str, documents: Any, top_k: int = 0) -> list[str]:
    """Search an encoded query into vector database."""
    if not self.model:
        raise Exception('The model was not created.')

    document_encoded = self.get_embedding(documents)
    query_encoded = self.get_embedding([query])
    top_k = top_k or self.top_k or self.default_top_k or 1

    self.db.embed(document_encoded)

    scores, indices = self.db.search(query_encoded, top_k=top_k)

    retrieved_docs = [documents[i] for i in indices]

    self.logs['indices'] = indices
    self.logs['scores'] = scores
    self.logs['search_params'] = {
        'query_encoded': query_encoded,
        'top_k': top_k,
    }

    return retrieved_docs
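
A hedged end-to-end sketch of search, reusing the aug instance and docs list from the examples above:

# Retrieve the single most relevant document for the query.
results = aug.search(
    query='Which city is the capital of France?',
    documents=docs,
    top_k=1,
)
print(results)              # expected to contain the document about Paris
print(aug.logs['scores'])   # scores recorded during the search
print(aug.logs['indices'])  # indices of the retrieved documents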