Technology • April 18, 2024
Knowledge Graphs for RAG without a GraphDB
Explore how to build and use knowledge graph extraction and retrieval for question answering
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document

# The extraction prompt used by LLMGraphTransformer is tuned for GPT-4,
# so that is the model configured here.
llm = ChatOpenAI(temperature=0, model_name="gpt-4")
llm_transformer = LLMGraphTransformer(llm=llm)

# Sample passage from which the knowledge graph is extracted.
text = """ Marie Curie, was a Polish and naturalised-French physicist and chemist who conducted pioneering research on radioactivity. She was the first woman to win a Nobel Prize, the first person to win a Nobel Prize twice, and the only person to win a Nobel Prize in two scientific fields. Her husband, Pierre Curie, was a co-winner of her first Nobel Prize, making them the first-ever married couple to win the Nobel Prize and launching the Curie family legacy of five Nobel Prizes. She was, in 1906, the first woman to become a professor at the University of Paris. """

# Wrap the passage in a Document and convert it into graph documents.
documents = [Document(page_content=text)]
graph_documents = llm_transformer.convert_to_graph_documents(documents)

# A single document was submitted, so only the first result is inspected.
extracted_graph = graph_documents[0]
print(f"Nodes:{extracted_graph.nodes}")
print(f"Relationships:{extracted_graph.relationships}")
# Default prompt asking the model to list the key entities of a question.
# It must expose the `question` and `format_instructions` template variables,
# both of which are supplied by the chain built in `extract_entities`.
QUERY_ENTITY_EXTRACT_PROMPT = (
    "A question is provided below. Given the question, extract up to 5 "
    "entity names and types from the text. Focus on extracting the key entities "
    "that we can use to best lookup answers to the question. Avoid stopwords.\n"
    "---------------------\n"
    "{question}\n"
    "---------------------\n"
    "{format_instructions}\n"
)


def extract_entities(
    llm,
    keyword_extraction_prompt=QUERY_ENTITY_EXTRACT_PROMPT,
    node_types=None,
):
    """Build a runnable that extracts entities from a question.

    Args:
        llm: Chat model used to perform the extraction.
        keyword_extraction_prompt: Prompt template text; it should reference
            the ``{question}`` and ``{format_instructions}`` variables.
            Defaults to ``QUERY_ENTITY_EXTRACT_PROMPT``.
        node_types: Optional list of allowed node types, forwarded to
            ``optional_enum_field``. ``None`` is presumed to mean
            "unconstrained" -- confirm against that helper.

    Returns:
        A runnable that takes ``{"question": ...}`` and yields a list of
        ``(id, type)`` tuples, one per extracted entity.
    """
    # Build the prompt once and reuse it in the chain below. Previously this
    # line referenced an undefined name and its result was discarded.
    prompt = ChatPromptTemplate.from_messages([keyword_extraction_prompt])

    class SimpleNode(BaseModel):
        """Represents a node in a graph with associated properties."""

        id: str = Field(description="Name or human-readable unique identifier.")
        type: str = optional_enum_field(
            node_types, description="The type or label of the node."
        )

    class SimpleNodeList(BaseModel):
        """Represents a list of simple nodes."""

        nodes: List[SimpleNode]

    output_parser = JsonOutputParser(pydantic_object=SimpleNodeList)
    return (
        # Inject the parser's format instructions next to the caller's input.
        RunnablePassthrough.assign(
            format_instructions=lambda _: output_parser.get_format_instructions(),
        )
        | prompt
        | llm
        | output_parser
        # Reduce the parsed JSON to plain (id, type) pairs.
        | RunnableLambda(
            lambda node_list: [(n["id"], n["type"]) for n in node_list["nodes"]]
        )
    )
# Example: run the entity extractor on a question to see the nodes it finds.
question_input = {"question": "Who is Marie Curie?"}
extract_entities(llm).invoke(question_input)
# Output: one (id, type) pair per entity, e.g. [("Marie Curie", "Person")]
def _combine_relations(relations):
    """Render the retrieved relations as text, one ``repr`` per line."""
    return "\n".join(map(repr, relations))


# Prompt for the final answer. Each sentence fragment ends with a space so the
# adjacent string literals do not run together when concatenated.
ANSWER_PROMPT = (
    "The original question is given below. "
    "This question has been used to retrieve information from a knowledge graph. "
    "The matching triples are shown below. "
    "Use the information in the triples to answer the original question.\n\n"
    "Original Question: {question}\n\n"
    "Knowledge Graph Triples:\n{context}\n\n"
    "Response:"
)

# Question -> entities -> sub-graph triples -> textual context -> answer.
chain = (
    {"question": RunnablePassthrough()}
    # extract_entities is provided by the Cassandra knowledge graph library
    # and extracts entities as shown above.
    | RunnablePassthrough.assign(entities=extract_entities(llm))
    | RunnablePassthrough.assign(
        # graph_store.as_runnable() is provided by the CassandraGraphStore
        # and takes one or more entities and retrieves the relevant sub-graph(s).
        triples=itemgetter("entities") | graph_store.as_runnable()
    )
    | RunnablePassthrough.assign(
        # Flatten the triples into the text block the prompt expects.
        context=itemgetter("triples") | RunnableLambda(_combine_relations)
    )
    | ChatPromptTemplate.from_messages([ANSWER_PROMPT])
    | llm
)
# Example: ask a question through the full retrieval chain.
question = "Who is Marie Curie?"
chain.invoke(question)
# Output
# AIMessage(
#     content="Marie Curie is a Polish and French chemist, physicist, and professor who "
#             "researched radioactivity. She was married to Pierre Curie and has worked at "
#             "the University of Paris. She is also a recipient of the Nobel Prize.",
#     response_metadata={
#         'token_usage': {'completion_tokens': 47, 'prompt_tokens': 213, 'total_tokens': 260},
#         'model_name': 'gpt-4',
#         ...
#     }
# )
def fetch_relation(tg: asyncio.TaskGroup, depth: int, source: Node) -> AsyncPagedQuery: paged_query = AsyncPagedQuery( depth, session.execute_async(query, (source.name, source.type)) ) return tg.create_task(paged_query.next()) results = set() async with asyncio.TaskGroup() as tg: if isinstance(start, Node): start = [start] discovered = {t: 0 for t in start} pending = {fetch_relation(tg, 1, source) for source in start} while pending: done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) for future in done: depth, relations, more = future.result() for relation in relations: results.add(relation) # Schedule the future for more results from the same query. if more is not None: pending.add(tg.create_task(more.next())) # Schedule futures for the next step. if depth < steps: # We've found a path of length `depth` to each of the targets. # We need to update `discovered` to include the shortest path. # And build `to_visit` to be all of the targets for which this is # the new shortest path. to_visit = set() for r in relations: previous = discovered.get(r.target, steps + 1) if depth < previous: discovered[r.target] = depth to_visit.add(r.target) for source in to_visit: pending.add(fetch_relation(tg, depth + 1, source)) return results