Technology · April 18, 2024

Knowledge Graphs for RAG without a GraphDB

Explore how to build and use knowledge graph extraction and retrieval for question answering
Ben Chambers
Ben Chambers, ML CTO, DataStax
Knowledge Graphs for RAG without a GraphDB
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document

# Example: turn a short passage into graph documents with LLMGraphTransformer
# and print the extracted nodes and relationships.

# Prompt used by LLMGraphTransformer is tuned for GPT-4.
llm = ChatOpenAI(temperature=0, model_name="gpt-4")

# Wrap the chat model in the transformer that produces graph documents.
llm_transformer = LLMGraphTransformer(llm=llm)

# Sample passage the graph is extracted from.
text = """
Marie Curie, was a Polish and naturalised-French physicist and chemist who 
conducted pioneering research on radioactivity.
She was the first woman to win a Nobel Prize, the first person to win a 
Nobel Prize twice, and the only person to win a Nobel Prize in two 
scientific fields.
Her husband, Pierre Curie, was a co-winner of her first Nobel Prize, making 
them the first-ever married couple to win the Nobel Prize and launching 
the Curie family legacy of five Nobel Prizes.
She was, in 1906, the first woman to become a professor at the University of 
Paris.
"""
# A single Document is passed in, so only graph_documents[0] is inspected below.
documents = [Document(page_content=text)]
graph_documents = llm_transformer.convert_to_graph_documents(documents)
print(f"Nodes:{graph_documents[0].nodes}")
print(f"Relationships:{graph_documents[0].relationships}")
# Prompt template asking the model to pull lookup entities out of a question.
# Placeholders filled at run time:
#   {question}             - the user's question
#   {format_instructions}  - output-format instructions for the model
# Each adjacent literal ends with a space or newline so the implicit
# concatenation does not glue words together.
QUERY_ENTITY_EXTRACT_PROMPT = (
    "A question is provided below. Given the question, extract up to 5 "
    "entity names and types from the text. Focus on extracting the key entities "
    "that we can use to best lookup answers to the question. Avoid stopwords.\n"
    "---------------------\n"
    "{question}\n"
    "---------------------\n"
    "{format_instructions}\n"
)

def extract_entities(llm, keyword_extraction_prompt=None, node_types=None):
    """Build a runnable that extracts lookup entities from a question.

    Args:
        llm: Chat model that the rendered prompt is piped into.
        keyword_extraction_prompt: Prompt template text. The chain supplies
            the ``question`` input and a ``format_instructions`` value, so
            the template is expected to use those two placeholders. When
            ``None``, ``QUERY_ENTITY_EXTRACT_PROMPT`` is used.
        node_types: Value forwarded to ``optional_enum_field`` for the
            node ``type`` field. NOTE(review): ``None`` presumably leaves the
            type unconstrained; confirm against ``optional_enum_field``.

    Returns:
        A runnable that takes a mapping with a ``"question"`` key and
        produces a list of ``(id, type)`` tuples, one per extracted node.
    """
    # Resolve the default lazily so the module-level prompt constant is only
    # looked up when the function is actually called.
    if keyword_extraction_prompt is None:
        keyword_extraction_prompt = QUERY_ENTITY_EXTRACT_PROMPT
    prompt = ChatPromptTemplate.from_messages([keyword_extraction_prompt])

    class SimpleNode(BaseModel):
        """Represents a node in a graph with associated properties."""

        id: str = Field(description="Name or human-readable unique identifier.")
        type: str = optional_enum_field(node_types,
                                        description="The type or label of the node.")

    class SimpleNodeList(BaseModel):
        """Represents a list of simple nodes."""

        nodes: List[SimpleNode]

    output_parser = JsonOutputParser(pydantic_object=SimpleNodeList)
    return (
        # Inject the parser's format instructions alongside the question.
        RunnablePassthrough.assign(
            format_instructions=lambda _: output_parser.get_format_instructions(),
        )
        | prompt
        | llm
        | output_parser
        # The parser yields plain dicts; reduce each node to an (id, type) pair.
        | RunnableLambda(
            lambda node_list: [(n["id"], n["type"]) for n in node_list["nodes"]])
    )
# Example showing extracted entities (nodes)
extract_entities(llm).invoke({"question": "Who is Marie Curie?"})

# Output (illustrative, kept as a comment so the snippet stays valid Python):
# [Marie Curie(Person)]
def _combine_relations(relations):
    """Render every relation with ``repr`` and join them one per line."""
    rendered = [repr(relation) for relation in relations]
    return "\n".join(rendered)

# Prompt template used to answer the original question from retrieved triples.
# Placeholders filled at run time:
#   {question} - the original user question
#   {context}  - the knowledge-graph triples, already rendered as text
# Adjacent literals are concatenated implicitly, so each sentence carries its
# own trailing space; without it the sentences run together
# ("...given below.This question...").
ANSWER_PROMPT = (
    "The original question is given below. "
    "This question has been used to retrieve information from a knowledge graph. "
    "The matching triples are shown below. "
    "Use the information in the triples to answer the original question.\n\n"
    "Original Question: {question}\n\n"
    "Knowledge Graph Triples:\n{context}\n\n"
    "Response:"
)

# Question-answering pipeline. Each stage adds one key to the running mapping:
#   question -> entities -> triples -> context, then prompt -> llm.
chain = (
    {"question": RunnablePassthrough()}
    # extract_entities is provided by the Cassandra knowledge graph library
    # and extracts entities as shown above.
    | RunnablePassthrough.assign(entities=extract_entities(llm))
    | RunnablePassthrough.assign(
        # graph_store.as_runnable() is provided by the CassandraGraphStore
        # and takes one or more entities and retrieves the relevant
        # sub-graph(s).
        triples=itemgetter("entities") | graph_store.as_runnable()
    )
    | RunnablePassthrough.assign(
        # Flatten the retrieved triples into the text block the prompt expects.
        context=itemgetter("triples") | RunnableLambda(_combine_relations)
    )
    | ChatPromptTemplate.from_messages([ANSWER_PROMPT])
    | llm
)
chain.invoke("Who is Marie Curie?")

# Output (illustrative, kept as a comment so the snippet stays valid Python):
# AIMessage(
#   content="Marie Curie is a Polish and French chemist, physicist, and professor who "
#           "researched radioactivity. She was married to Pierre Curie and has worked at "
#           "the University of Paris. She is also a recipient of the Nobel Prize.",
#   response_metadata={
#     'token_usage': {'completion_tokens': 47, 'prompt_tokens': 213, 'total_tokens': 260},
#     'model_name': 'gpt-4',
#     ...
#   }
# )
def fetch_relation(tg: asyncio.TaskGroup, depth: int, source: Node) -> asyncio.Task:
    """Start an asynchronous relation query for ``source`` and schedule it.

    Args:
        tg: Task group that owns the created task.
        depth: Depth value handed to ``AsyncPagedQuery`` for this query.
        source: Node whose ``name`` and ``type`` are bound as the query
            parameters.

    Returns:
        The task created for ``paged_query.next()`` (not the
        ``AsyncPagedQuery`` itself).
    """
    # `session` and `query` come from the enclosing scope.
    paged_query = AsyncPagedQuery(
        depth, session.execute_async(query, (source.name, source.type))
    )
    return tg.create_task(paged_query.next())

# Graph traversal fragment: gathers relations into `results` by repeatedly
# scheduling fetch_relation tasks and following the targets they return.
# NOTE(review): `start`, `steps`, and the `async def` that the final `return`
# belongs to are not part of this excerpt -- confirm against the full source.
results = set()
async with asyncio.TaskGroup() as tg:
    # Accept a single Node as well as a collection of Nodes.
    if isinstance(start, Node):
        start = [start]

    # Smallest depth at which each node has been seen; start nodes are depth 0.
    discovered = {t: 0 for t in start}
    # One in-flight task per start node, each created with depth 1.
    pending = {fetch_relation(tg, 1, source) for source in start}

    while pending:
        # Resume as soon as any task finishes; unfinished ones stay in `pending`.
        done, pending = await asyncio.wait(pending, 
return_when=asyncio.FIRST_COMPLETED)
        for future in done:
            # Each finished task yields (depth, relations, more).
            depth, relations, more = future.result()
            for relation in relations:
                   results.add(relation)

            # Schedule the future for more results from the same query.
            # NOTE(review): `more` is presumably the paged query when further
            # pages remain, and None otherwise -- confirm.
            if more is not None:
                pending.add(tg.create_task(more.next()))

            # Schedule futures for the next step.
            if depth < steps:
                # We've found a path of length `depth` to each of the targets.
                # We need to update `discovered` to include the shortest path.
                # And build `to_visit` to be all of the targets for which this is
                # the new shortest path.

                to_visit = set()
                for r in relations:
                    # Unseen targets default to steps + 1, so any depth below
                    # `steps` counts as an improvement.
                    previous = discovered.get(r.target, steps + 1)
                    if depth < previous:
                        discovered[r.target] = depth
                        to_visit.add(r.target)

                for source in to_visit:
                    pending.add(fetch_relation(tg, depth + 1, source))

return results
Discover more
DataStax Astra DB · Retrieval-augmented generation
Share

One-stop Data API for Production GenAI

Astra DB gives JavaScript developers a complete data API and out-of-the-box integrations that make it easier to build production RAG apps with high relevancy and low latency.