How to Scale GraphRAG with Neo4j for Efficient Document Querying
Note (9/12/24): This update replaces the code examples that used the end-of-life (EOL) py2neo package with the official neo4j Python driver.
In this tutorial, I will walk through an example implementation of a scalable GraphRAG architecture that uses Neo4j to store and manage graph data extracted from documents. We will process documents, extract entities and relationships using OpenAI's GPT-4o models, and store them in a Neo4j graph, making it easier to handle large datasets and answer queries using graph algorithms such as centrality. Centrality measures identify the most important nodes in a graph based on their connections, which helps retrieve the most relevant information quickly and accurately. In this example, we emphasize centrality-based retrieval over community-based retrieval to improve the relevance of query responses. You can follow along with the complete source code on GitHub.
This guide will cover:
Setting up Neo4j with Docker
Using the class-based design for document processing and graph management
Using centrality measures to improve query responses
Reindexing the graph as new data is added
Prerequisites
Ensure you have the following:
Python 3.9+
Docker
Necessary libraries: openai, neo4j, python-dotenv
You can install these with:
pip install openai neo4j python-dotenv
Additionally, we will use Docker to run a Neo4j instance for managing graph data.
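Since python-dotenv is in the dependency list, the Neo4j and OpenAI credentials can live in a .env file at the project root. The variable names below are illustrative; match them to whatever names the code actually reads:

```
NEO4J_URI=bolt://localhost:7687
NEO4J_USER=neo4j
NEO4J_PASSWORD=your-password
OPENAI_API_KEY=sk-...
```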
Project Overview
This project is structured using an object-oriented approach, with distinct classes for managing key components. The system processes documents, extracts entities and relationships, and stores them in Neo4j. Using centrality measures, we prioritize the most important entities in the graph, helping improve the accuracy and relevance of query responses.
Project Structure
app.py: Entry point that orchestrates the document processing and querying workflow.
GraphManager (in graph_manager.py): Manages Neo4j operations such as building graphs, recalculating centrality measures, and managing updates.
QueryHandler (in query_handler.py): Handles user queries and uses GPT models to answer them based on graph data and centrality measures.
DocumentProcessor (in document_processor.py): Splits documents into chunks, extracts entities and relationships, and summarizes them.
GraphDatabaseConnection (in graph_database.py): Manages the connection to the Neo4j database.
Logger (in logger.py): Provides logging utilities to track the application's progress.
Setting Up Neo4j with Docker
To set up Neo4j locally, run the following commands to build and start the Docker container:
sh build.sh
sh start.sh
This will run a Neo4j instance locally, accessible via http://localhost:7474 and bolt://localhost:7687.
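The build.sh and start.sh scripts are in the repo; if you prefer to run the container by hand, something along these lines should work. Note that the centrality queries later in this post rely on the Graph Data Science (GDS) plugin, so it must be enabled; the password here is a placeholder:

```shell
docker run -d \
  --name neo4j-graphrag \
  -p 7474:7474 -p 7687:7687 \
  -e NEO4J_AUTH=neo4j/your-password \
  -e NEO4J_PLUGINS='["graph-data-science"]' \
  neo4j:5
```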
Connecting to Neo4j from Python
We will use the official neo4j Python driver to connect to the Neo4j database. The GraphDatabaseConnection class in graph_database.py handles this connection:
from neo4j import GraphDatabase


class GraphDatabaseConnection:
    def __init__(self, uri, user, password):
        if not uri or not user or not password:
            raise ValueError(
                "URI, user, and password must be provided to initialize the DatabaseConnection.")
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        self.driver.close()

    def get_session(self):
        return self.driver.session()

    def clear_database(self):
        with self.get_session() as session:
            session.run("MATCH (n) DETACH DELETE n")
Document Processing with DocumentProcessor
The DocumentProcessor class processes documents by splitting them into chunks, extracting key entities and relationships, and summarizing them using OpenAI's GPT models.
Example: Document Processing
from logger import Logger


class DocumentProcessor:
    logger = Logger("DocumentProcessor").get_logger()

    def __init__(self, client, model):
        self.client = client
        self.model = model

    def split_documents(self, documents, chunk_size=600, overlap_size=100):
        chunks = []
        for document in documents:
            for i in range(0, len(document), chunk_size - overlap_size):
                chunk = document[i:i + chunk_size]
                chunks.append(chunk)
        self.logger.debug("Documents split into %d chunks", len(chunks))
        return chunks

    def extract_elements(self, chunks):
        elements = []
        for index, chunk in enumerate(chunks):
            self.logger.debug(
                f"Extracting elements and relationship strength from chunk {index + 1}")
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system",
                     "content": "Extract entities, relationships, and their strength from the following text. Use common terms such as 'related to', 'depends on', 'influences', etc., for relationships, and estimate a strength between 0.0 (very weak) and 1.0 (very strong). Do not include any other text in your response. Use this exact format: Parsed relationship: Entity1 -> Relationship -> Entity2 [strength: X.X]."},
                    {"role": "user", "content": chunk}
                ]
            )
            entities_and_relations = response.choices[0].message.content
            elements.append(entities_and_relations)
        self.logger.debug("Elements extracted")
        return elements

    def summarize_elements(self, elements):
        summaries = []
        for index, element in enumerate(elements):
            self.logger.debug(f"Summarizing element {index + 1}")
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "Summarize the following entities and relationships in a structured format. Use common terms such as 'related to', 'depends on', 'influences', etc., for relationships. Use '->' to represent relationships after the 'Relationships:' word."},
                    {"role": "user", "content": element}
                ]
            )
            summary = response.choices[0].message.content
            summaries.append(summary)
        self.logger.debug("Summaries created")
        return summaries
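The sliding-window arithmetic in split_documents is easy to get wrong, so here is the same chunking logic pulled out as a standalone function you can sanity-check without the class (this mirrors the method above; it is not imported from the project):

```python
def split_into_chunks(document, chunk_size=600, overlap_size=100):
    """Split a string into chunks of up to chunk_size characters,
    stepping forward by chunk_size - overlap_size each time so that
    consecutive chunks share overlap_size characters."""
    step = chunk_size - overlap_size
    return [document[i:i + chunk_size] for i in range(0, len(document), step)]


text = "x" * 1200
chunks = split_into_chunks(text)
# The window starts land at 0, 500, and 1000, so a 1200-character
# document yields three chunks.
print(len(chunks))      # 3
print(len(chunks[0]))   # 600
print(len(chunks[-1]))  # 200 (the tail chunk is shorter)
```

With the defaults, each chunk repeats the last 100 characters of the previous one, which keeps entities that straddle a boundary visible to the extraction step.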
Graph Management with GraphManager
Once the entities and relationships are extracted, they are stored in the graph using the GraphManager class. This class handles building and reprojecting the graph, calculating centrality measures, and managing updates when new data is added.
Example: Building the Graph
The build_graph method in the GraphManager class creates nodes and relationships from the document summaries:
import re

from graph_database import GraphDatabaseConnection
from logger import Logger


class GraphManager:
    logger = Logger("GraphManager").get_logger()

    def __init__(self, db_connection: GraphDatabaseConnection):
        self.db_connection = db_connection
        self.db_connection.clear_database()

    def build_graph(self, summaries):
        if self.db_connection is None:
            self.logger.error("Graph database connection is not available.")
            return
        entities = {}
        with self.db_connection.get_session() as session:
            for summary in summaries:
                lines = summary.split("\n")
                entities_section = False
                relationships_section = False
                for line in lines:
                    if line.startswith(("### Entities:", "**Entities:**", "Entities:")):
                        entities_section = True
                        relationships_section = False
                        continue
                    elif line.startswith(("### Relationships:", "**Relationships:**", "Relationships:")):
                        entities_section = False
                        relationships_section = True
                        continue
                    if entities_section and line.strip():
                        # Entity lines may be numbered ("1. Entity") or plain.
                        if line[0].isdigit() and "." in line:
                            entity_name = line.split(".", 1)[1].strip()
                        else:
                            entity_name = line.strip()
                        entity_name = self.normalize_entity_name(
                            entity_name.replace("**", ""))
                        self.logger.debug(f"Creating node: {entity_name}")
                        session.run(
                            "MERGE (e:Entity {name: $name})", name=entity_name)
                        entities[entity_name] = entity_name
                    elif relationships_section and line.strip():
                        parts = line.split("->")
                        if len(parts) >= 2:
                            source = self.normalize_entity_name(parts[0].strip())
                            # The strength tag trails the target entity, so
                            # strip it off before normalizing the name.
                            target = self.normalize_entity_name(
                                parts[-1].split("[")[0].strip())
                            relation_name = self.sanitize_relationship_name(
                                parts[1].split("[")[0].strip())
                            # Search the whole line so the strength tag is
                            # found wherever the model placed it.
                            strength = re.search(
                                r"\[strength:\s*(\d\.\d)\]", line)
                            weight = float(strength.group(1)) if strength else 1.0
                            self.logger.debug(
                                f"Parsed relationship: {source} -> {relation_name} -> {target} [weight: {weight}]")
                            if source in entities and target in entities:
                                if relation_name:
                                    self.logger.debug(
                                        f"Creating relationship: {source} -> {relation_name} -> {target} with weight {weight}")
                                    # relation_name is sanitized above, so it is
                                    # safe to interpolate into the Cypher query.
                                    session.run(
                                        "MATCH (a:Entity {name: $source}), (b:Entity {name: $target}) "
                                        "MERGE (a)-[r:" + relation_name + " {weight: $weight}]->(b)",
                                        source=source, target=target, weight=weight
                                    )
                                else:
                                    self.logger.debug(
                                        f"Skipping relationship: {source} -> {relation_name} -> {target} (relation name is empty)")
                            else:
                                self.logger.debug(
                                    f"Skipping relationship: {source} -> {relation_name} -> {target} (one or both entities not found)")

    # NOTE: more methods in the class; see the full source for details
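The relationship parsing inside build_graph is interleaved with database calls, which makes it hard to test in isolation. Here it is pulled out as a pure function, assuming the exact format the extraction prompt asks for (this helper is illustrative and is not part of the project code):

```python
import re


def parse_relationship_line(line):
    """Parse 'Entity1 -> Relationship -> Entity2 [strength: X.X]' into
    (source, relation, target, weight). Returns None when the line
    doesn't contain at least two '->' separators."""
    parts = line.split("->")
    if len(parts) < 3:
        return None
    source = parts[0].strip()
    relation = parts[1].strip()
    # The strength tag trails the target entity, so strip it off.
    target = parts[-1].split("[")[0].strip()
    match = re.search(r"\[strength:\s*(\d\.\d)\]", line)
    weight = float(match.group(1)) if match else 1.0
    return source, relation, target, weight


print(parse_relationship_line(
    "Neo4j -> depends on -> Docker [strength: 0.8]"))
# ('Neo4j', 'depends on', 'Docker', 0.8)
```

Defaulting the weight to 1.0 when no strength tag is present matches the behavior of build_graph above.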
Centrality Measures for Enhanced Query Responses
The GraphManager class also calculates centrality measures such as degree, betweenness, and closeness centrality. These measures help prioritize key entities in the graph, improving the relevance of query responses.
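For intuition about what these scores capture, weighted degree centrality is simply the sum of the edge weights touching a node. A toy illustration (this is not the GDS implementation, just the underlying idea):

```python
# score(node) = sum of the weights of the relationships touching it.
edges = [
    ("neo4j", "docker", 0.5),
    ("neo4j", "python", 0.75),
    ("python", "openai", 0.25),
]

scores = {}
for source, target, weight in edges:
    scores[source] = scores.get(source, 0.0) + weight
    scores[target] = scores.get(target, 0.0) + weight

ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
print(ranked[0])  # ('neo4j', 1.25) -- the most connected entity
```

Betweenness and closeness follow the same "score per node, rank, keep the top" pattern, but reward nodes that sit on many shortest paths or close to everything, respectively.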
Example: Calculating Centrality
The calculate_centrality_measures method computes these measures for the nodes in the graph:
def calculate_centrality_measures(self, graph_name="entityGraph"):
    self.reproject_graph(graph_name)
    with self.db_connection.get_session() as session:
        check_query = "CALL gds.graph.exists($graph_name) YIELD exists"
        exists_result = session.run(
            check_query, graph_name=graph_name).single()["exists"]
        if not exists_result:
            raise Exception(
                f"Graph projection '{graph_name}' does not exist.")
        degree_centrality_query = """
            CALL gds.degree.stream($graph_name)
            YIELD nodeId, score
            RETURN gds.util.asNode(nodeId).name AS entityName, score
            ORDER BY score DESC
            LIMIT 10
        """
        degree_centrality_result = session.run(
            degree_centrality_query, graph_name=graph_name).data()
        betweenness_centrality_query = """
            CALL gds.betweenness.stream($graph_name)
            YIELD nodeId, score
            RETURN gds.util.asNode(nodeId).name AS entityName, score
            ORDER BY score DESC
            LIMIT 10
        """
        betweenness_centrality_result = session.run(
            betweenness_centrality_query, graph_name=graph_name).data()
        closeness_centrality_query = """
            CALL gds.closeness.stream($graph_name)
            YIELD nodeId, score
            RETURN gds.util.asNode(nodeId).name AS entityName, score
            ORDER BY score DESC
            LIMIT 10
        """
        closeness_centrality_result = session.run(
            closeness_centrality_query, graph_name=graph_name).data()
        centrality_data = {
            "degree": degree_centrality_result,
            "betweenness": betweenness_centrality_result,
            "closeness": closeness_centrality_result
        }
        return centrality_data
Handling Queries with QueryHandler
The QueryHandler class uses the centrality results to generate more relevant and accurate responses to user queries by leveraging OpenAI's GPT models.
Example: Handling Queries
from openai import OpenAI

from graph_manager import GraphManager
from logger import Logger


class QueryHandler:
    logger = Logger("QueryHandler").get_logger()

    def __init__(self, graph_manager: GraphManager, client: OpenAI, model: str):
        self.graph_manager = graph_manager
        self.client = client
        self.model = model

    def ask_question(self, query):
        centrality_data = self.graph_manager.calculate_centrality_measures()
        centrality_summary = self.graph_manager.summarize_centrality_measures(
            centrality_data)
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "Use the centrality measures to answer the following query."},
                {"role": "user", "content": f"Query: {query} Centrality Summary: {centrality_summary}"}
            ]
        )
        self.logger.debug("Query answered: %s",
                          response.choices[0].message.content)
        final_answer = response.choices[0].message.content
        return final_answer
By focusing on the most central entities, the system generates better, more context-aware answers.
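ask_question calls summarize_centrality_measures, which is not shown above. Its job is to flatten the centrality dict into prompt-friendly text; a minimal sketch of what it might look like (the real implementation is in the repo and may differ):

```python
def summarize_centrality_measures(centrality_data):
    """Turn the dict returned by calculate_centrality_measures into a
    plain-text summary suitable for embedding in a GPT prompt."""
    lines = []
    for measure, results in centrality_data.items():
        lines.append(f"{measure.capitalize()} centrality:")
        for record in results:
            lines.append(f"  {record['entityName']}: {record['score']:.2f}")
    return "\n".join(lines)


sample = {"degree": [{"entityName": "neo4j", "score": 1.25}]}
print(summarize_centrality_measures(sample))
# Degree centrality:
#   neo4j: 1.25
```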
Reindexing with New Documents
When new documents are added, the graph is reindexed to update entities, relationships, and centrality measures. The reindex_with_new_documents function in the root app.py handles this process:
Example: Reindexing
def reindex_with_new_documents(new_documents, graph_manager: GraphManager):
    chunks = document_processor.split_documents(new_documents)
    elements_file = 'data/new_elements_data.pkl'
    summaries_file = 'data/new_summaries_data.pkl'
    elements = load_or_run(
        elements_file, document_processor.extract_elements, chunks)
    summaries = load_or_run(
        summaries_file, document_processor.summarize_elements, elements)
    graph_manager.build_graph(summaries)
    graph_manager.reproject_graph()
This ensures the graph is up-to-date with new data, and centrality measures are recalculated.
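load_or_run is a small caching helper referenced above but not shown. The idea is to pickle expensive GPT results so reruns skip the API calls; a sketch of it (the exact behavior in the repo may differ) could look like:

```python
import os
import pickle


def load_or_run(file_path, func, *args):
    """Return cached results from file_path if it exists; otherwise call
    func(*args), pickle the result for next time, and return it."""
    if os.path.exists(file_path):
        with open(file_path, "rb") as f:
            return pickle.load(f)
    result = func(*args)
    directory = os.path.dirname(file_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(file_path, "wb") as f:
        pickle.dump(result, f)
    return result
```

Deleting the .pkl files forces a full re-extraction on the next run, which is the simple way to invalidate the cache after changing the prompts.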
Running the Application
After setting up the environment, run the application:
python app.py
This will:
Index the initial documents.
Process a user query to extract the main themes.
Reindex the graph with new documents.
Answer another query based on the updated graph.
Conclusion
By using Neo4j and taking a class-based approach with clear separation of concerns, we have built a scalable and efficient GraphRAG pipeline. The system can handle larger datasets, leverage graph algorithms to enhance query responses, and continuously update the graph as new data is added.
This design allows you to further extend the system, incorporating additional algorithms or larger datasets, and tailor it to specific business needs.