Skip to content

chroma

ChromaDB implementation for vector database.

Classes:

  • ChromaDB

    ChromaDB implementation for vector database.

ChromaDB

ChromaDB(client: ClientAPI, collection_name: str = 'rago')

Bases: DBBase

ChromaDB implementation for vector database.

Methods:

  • embed

    Embed the documents into the database.

  • search

    Search the collection with an encoded query and return the matching distances and ids.

Source code in src/rago/augmented/db/chroma.py
15
16
17
18
19
20
21
22
23
def __init__(
    self,
    client: ClientAPI,
    collection_name: str = 'rago',
) -> None:
    """Store the client and collection name, then run the setup step.

    Parameters
    ----------
    client : ClientAPI
        Client object kept on the instance as ``self.client``.
    collection_name : str
        Name kept on the instance as ``self.collection_name``;
        defaults to ``'rago'``.
    """
    self.collection_name = collection_name
    self.client = client
    # ``_setup`` is defined elsewhere in the class. It is called last so
    # that both attributes above already exist when it runs.
    # NOTE(review): presumably it creates/opens ``self.collection`` - confirm.
    self._setup()

embed

embed(documents: Any) -> None

Embed the documents into the database.

Source code in src/rago/augmented/db/chroma.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def embed(self, documents: Any) -> None:
    """Embed the documents into the database.

    Parameters
    ----------
    documents : tuple
        A 2-tuple ``(texts, vectors)``: the document strings and, in the
        same order, one embedding (a sequence of floats) per document.

    Raises
    ------
    ValueError
        If ``documents`` is not a 2-tuple, or if the number of texts
        differs from the number of embeddings.
    """
    if not isinstance(documents, tuple) or len(documents) != 2:
        raise ValueError(
            'documents format must be: (List[str], List[List[float]])'
        )

    documents_list: List[str] = documents[0]
    embeddings_list: List[List[float]] = documents[1]

    # Each document needs exactly one embedding. Fail early with a clear
    # message instead of handing a mismatched batch to the collection,
    # since the ids below are generated from the document count alone.
    if len(documents_list) != len(embeddings_list):
        raise ValueError(
            'documents and embeddings must have the same length: '
            f'got {len(documents_list)} documents and '
            f'{len(embeddings_list)} embeddings'
        )

    # The collection receives the vectors as a float32 numpy array.
    embeddings = np.array(embeddings_list, dtype=np.float32)

    # Ids are the positional indexes of the documents, as strings.
    # NOTE(review): ids restart at '0' on every call, so a second call
    # reuses ids already used by the first - confirm this is intended.
    self.collection.add(
        documents=documents_list,
        embeddings=embeddings,
        ids=[str(i) for i in range(len(documents_list))],
    )

search

search(
    query_encoded: Any, top_k: int = 2
) -> Tuple[List[float], List[str]]

Search the collection with an encoded query and return the matching distances and ids.

Source code in src/rago/augmented/db/chroma.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def search(
    self, query_encoded: Any, top_k: int = 2
) -> Tuple[List[float], List[str]]:
    """Query the collection with an encoded query.

    Parameters
    ----------
    query_encoded : Any
        The query embedding; it is wrapped in an outer list and converted
        to float32 before being sent to the collection.
    top_k : int
        Passed to the collection as ``n_results``.

    Returns
    -------
    tuple
        ``(distances, ids)`` taken from the first row of the query
        result; each is an empty list when the result has no usable
        entry for it.
    """
    # The collection is given a batch holding this single query.
    query_batch = np.array([query_encoded], dtype=np.float32).tolist()
    response = self.collection.query(
        query_embeddings=query_batch,
        n_results=top_k,
    )

    def first_row(key: str) -> Any:
        # A missing key falls back to one empty row; a falsy value
        # (e.g. None or an empty list) yields an empty list instead.
        rows = response.get(key, [[]])
        if not rows:
            return []
        return rows[0]

    nearest_distances: List[float] = first_row('distances')
    nearest_ids: List[str] = first_row('ids')
    return nearest_distances, nearest_ids