How to Scale GraphRAG with Neo4j for Efficient Document Querying

Note (9/12/24): This update replaces the code examples that used the end-of-life (EOL) py2neo package with the official neo4j Python driver.

In this tutorial, I will walk through an example implementation of a scalable GraphRAG architecture that uses Neo4j to store and manage graph data extracted from documents. We will process documents, extract entities and relationships with OpenAI's GPT-4o models, and store them in a Neo4j graph, making it easier to handle large datasets and answer queries with graph algorithms such as centrality. Centrality measures identify the most important nodes in a graph based on their connections, which helps retrieve the most relevant information quickly and accurately. In this example, we emphasize centrality-based retrieval over community-based retrieval to improve the relevance of query responses. You can follow along with the complete source code on GitHub.

This guide will cover:

  • Setting up Neo4j with Docker

  • Using a class-based design for document processing and graph management

  • Using centrality measures to improve query responses

  • Reindexing the graph as new data is added

Prerequisites

Ensure you have the following:

  • Python 3.9+

  • Docker

  • Necessary libraries: openai, neo4j, python-dotenv

You can install these with:

pip install openai neo4j python-dotenv

Additionally, we will use Docker to run a Neo4j instance for managing graph data.
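The python-dotenv package loads a .env file into the process environment at startup; after that, the connection settings are ordinary environment variables. A minimal sketch (the variable names are assumptions; match them to your own .env file):

```python
import os

# Hypothetical variable names; load_dotenv() from python-dotenv would
# populate these from a .env file before this point.
uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
user = os.environ.get("NEO4J_USER", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "password")
```

Keeping credentials in .env (and out of version control) lets the same code run against local and remote instances unchanged.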

Project Overview

This project is structured using an object-oriented approach, with distinct classes for managing key components. The system processes documents, extracts entities and relationships, and stores them in Neo4j. Using centrality measures, we prioritize the most important entities in the graph, helping improve the accuracy and relevance of query responses.

Project Structure

  • app.py: Entry point that orchestrates the document processing and querying workflow.

  • GraphManager (in graph_manager.py): Manages Neo4j operations such as building graphs, recalculating centrality measures, and managing updates.

  • QueryHandler (in query_handler.py): Handles user queries and utilizes GPT models to provide responses based on graph data and centrality measures.

  • DocumentProcessor (in document_processor.py): Splits documents into chunks, extracts entities and relationships, and summarizes them.

  • GraphDatabase (in graph_database.py): Manages the connection to the Neo4j database.

  • Logger (in logger.py): Provides logging utilities to track the application's progress.

Setting Up Neo4j with Docker

To set up Neo4j locally, run the following commands to build and start the Docker container:

sh build.sh
sh start.sh

This will run a Neo4j instance locally, accessible via http://localhost:7474 (Neo4j Browser) and bolt://localhost:7687 (driver connections).
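If you prefer a single command over the repository's scripts, something like the following works (a sketch: the image tag, credentials, and container name are assumptions, so align them with build.sh and start.sh; the Graph Data Science plugin is needed because the centrality queries later call gds.* procedures):

```shell
# Sketch: run Neo4j with the Graph Data Science plugin enabled.
# Image tag and NEO4J_AUTH credentials are assumptions; adjust as needed.
docker run -d \
  --name neo4j-graphrag \
  -p 7474:7474 -p 7687:7687 \
  -e NEO4J_AUTH=neo4j/password \
  -e NEO4J_PLUGINS='["graph-data-science"]' \
  neo4j:5
```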

Connecting to Neo4j from Python

We will use the official neo4j Python driver to connect to the Neo4j database. The GraphDatabaseConnection class in graph_database.py handles this connection:

from neo4j import GraphDatabase


class GraphDatabaseConnection:
    def __init__(self, uri, user, password):
        if not uri or not user or not password:
            raise ValueError(
                "URI, user, and password must be provided to initialize the DatabaseConnection.")
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        self.driver.close()

    def get_session(self):
        return self.driver.session()

    def clear_database(self):
        with self.get_session() as session:
            session.run("MATCH (n) DETACH DELETE n")

Document Processing with DocumentProcessor

The DocumentProcessor class is responsible for processing documents by splitting them into chunks, extracting key entities and relationships, and summarizing them using OpenAI’s GPT models.

Example: Document Processing

from logger import Logger


class DocumentProcessor:
    logger = Logger("DocumentProcessor").get_logger()

    def __init__(self, client, model):
        self.client = client
        self.model = model

    def split_documents(self, documents, chunk_size=600, overlap_size=100):
        chunks = []
        for document in documents:
            for i in range(0, len(document), chunk_size - overlap_size):
                chunk = document[i:i + chunk_size]
                chunks.append(chunk)
        self.logger.debug("Documents split into %d chunks", len(chunks))
        return chunks

    def extract_elements(self, chunks):
        elements = []
        for index, chunk in enumerate(chunks):
            self.logger.debug(
                f"Extracting elements and relationship strength from chunk {index + 1}")
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system",
                        "content": "Extract entities, relationships, and their strength from the following text. Use common terms such as 'related to', 'depends on', 'influences', etc., for relationships, and estimate a strength between 0.0 (very weak) and 1.0 (very strong). Format: Parsed relationship: Entity1 -> Relationship -> Entity2 [strength: X.X]. Do not include any other text in your response. Use this exact format: Parsed relationship: Entity1 -> Relationship -> Entity2 [strength: X.X]."},
                    {"role": "user", "content": chunk}
                ]
            )
            entities_and_relations = response.choices[0].message.content
            elements.append(entities_and_relations)
        self.logger.debug("Elements extracted")
        return elements

    def summarize_elements(self, elements):
        summaries = []
        for index, element in enumerate(elements):
            self.logger.debug(f"Summarizing element {index + 1}")
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "Summarize the following entities and relationships in a structured format. Use common terms such as 'related to', 'depends on', 'influences', etc., for relationships. Use '->' to represent relationships after the 'Relationships:' word."},
                    {"role": "user", "content": element}
                ]
            )
            summary = response.choices[0].message.content
            summaries.append(summary)
        self.logger.debug("Summaries created")
        return summaries
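The chunking arithmetic in split_documents can be illustrated standalone: each chunk starts chunk_size - overlap_size characters after the previous one, so consecutive chunks share overlap_size characters of context.

```python
def split_text(text, chunk_size=600, overlap_size=100):
    # Each new chunk starts (chunk_size - overlap_size) characters
    # after the previous one, so neighbours share overlap_size chars.
    step = chunk_size - overlap_size
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]

chunks = split_text("x" * 1200)
# 1200 chars with step 500 -> chunks starting at offsets 0, 500, 1000
```

The overlap means an entity mentioned near a chunk boundary still appears intact in at least one chunk, at the cost of some duplicated extraction work.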

Graph Management with GraphManager

Once the entities and relationships are extracted, they are stored in the graph using the GraphManager class. This class handles building and reprojecting the graph, calculating centrality measures, and managing updates when new data is added.

Example: Building the Graph

The build_graph method within the GraphManager class is responsible for creating nodes and relationships based on document summaries:

from graph_database import GraphDatabaseConnection
from logger import Logger
import re


class GraphManager:
    logger = Logger('GraphManager').get_logger()

    def __init__(self, db_connection: GraphDatabaseConnection):
        self.db_connection = db_connection
        self.db_connection.clear_database()

    def build_graph(self, summaries):
        if self.db_connection is None:
            self.logger.error("Graph database connection is not available.")
            return

        entities = {}

        with self.db_connection.get_session() as session:
            for summary in summaries:
                lines = summary.split("\n")
                entities_section = False
                relationships_section = False

                for line in lines:
                    if line.startswith("### Entities:") or line.startswith("**Entities:**") or line.startswith("Entities:"):
                        entities_section = True
                        relationships_section = False
                        continue
                    elif line.startswith("### Relationships:") or line.startswith("**Relationships:**") or line.startswith("Relationships:"):
                        entities_section = False
                        relationships_section = True
                        continue

                    if entities_section and line.strip():
                        if line[0].isdigit() and '.' in line:
                            entity_name = line.split(".", 1)[1].strip()
                        else:
                            entity_name = line.strip()
                        entity_name = self.normalize_entity_name(
                            entity_name.replace("**", ""))
                        self.logger.debug(f"Creating node: {entity_name}")
                        session.run(
                            "MERGE (e:Entity {name: $name})", name=entity_name)
                        entities[entity_name] = entity_name

                    elif relationships_section and line.strip():
                        parts = line.split("->")
                        if len(parts) >= 2:
                            source = self.normalize_entity_name(
                                parts[0].strip())
                            target = self.normalize_entity_name(
                                parts[-1].strip())

                            relationship_part = parts[1].strip()
                            relation_name = self.sanitize_relationship_name(
                                relationship_part.split("[")[0].strip())
                            strength = re.search(
                                r"\[strength:\s*(\d\.\d)\]", relationship_part)
                            weight = float(strength.group(
                                1)) if strength else 1.0

                            self.logger.debug(
                                f"Parsed relationship: {source} -> {relation_name} -> {target} [weight: {weight}]")
                            if source in entities and target in entities:
                                if relation_name:
                                    self.logger.debug(
                                        f"Creating relationship: {source} -> {relation_name} -> {target} with weight {weight}")
                                    session.run(
                                        "MATCH (a:Entity {name: $source}), (b:Entity {name: $target}) "
                                        "MERGE (a)-[r:" + relation_name +
                                        " {weight: $weight}]->(b)",
                                        source=source, target=target, weight=weight
                                    )
                                else:
                                    self.logger.debug(
                                        f"Skipping relationship: {source} -> {relation_name} -> {target} (relation name is empty)")
                            else:
                                self.logger.debug(
                                    f"Skipping relationship: {source} -> {relation_name} -> {target} (one or both entities not found)")

# NOTE: More methods in the class, see the full code for details
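The split-and-regex parsing inside build_graph can be exercised standalone. A minimal sketch with a hypothetical line in the extraction prompt's format (the sample line and the target cleanup are illustrative):

```python
import re

# Hypothetical line in the extraction prompt's format
line = "Entity1 -> depends on -> Entity2 [strength: 0.8]"

parts = [part.strip() for part in line.split("->")]
source = parts[0]
relation = parts[1].split("[")[0].strip()
match = re.search(r"\[strength:\s*(\d\.\d)\]", line)
weight = float(match.group(1)) if match else 1.0  # default when no tag
target = re.sub(r"\s*\[strength:[^\]]*\]", "", parts[-1]).strip()
```

Because the model's output format can drift, defaulting the weight to 1.0 and skipping lines that do not split cleanly keeps a single malformed line from aborting the whole build.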

Centrality Measures for Enhanced Query Responses

The GraphManager also calculates centrality measures such as degree, betweenness, and closeness centrality. These measures help prioritize key entities in the graph, improving the relevance of query responses.
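To make the idea concrete, degree centrality for a tiny graph can be computed by hand: it is simply the number of edges touching each node. A pure-Python sketch of the concept (the project itself delegates this to Neo4j GDS procedures):

```python
from collections import Counter

# Toy graph: degree centrality = number of edges touching each node.
edges = [("A", "B"), ("A", "C"), ("A", "D"), ("B", "C")]
degree = Counter()
for source, target in edges:
    degree[source] += 1
    degree[target] += 1

ranked = degree.most_common()
# "A" touches three edges, so it ranks first
```

Betweenness and closeness follow the same ranking idea but score nodes by shortest-path structure rather than raw edge counts.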

Example: Calculating Centrality

The calculate_centrality_measures method calculates centrality for each node in the graph:

def calculate_centrality_measures(self, graph_name="entityGraph"):
    self.reproject_graph(graph_name)

    with self.db_connection.get_session() as session:
        check_query = "CALL gds.graph.exists($graph_name) YIELD exists"
        exists_result = session.run(
            check_query, graph_name=graph_name).single()["exists"]

        if not exists_result:
            raise Exception(
                f"Graph projection '{graph_name}' does not exist.")

        degree_centrality_query = """
        CALL gds.degree.stream($graph_name)
        YIELD nodeId, score
        RETURN gds.util.asNode(nodeId).name AS entityName, score
        ORDER BY score DESC
        LIMIT 10
        """
        degree_centrality_result = session.run(
            degree_centrality_query, graph_name=graph_name).data()

        betweenness_centrality_query = """
        CALL gds.betweenness.stream($graph_name)
        YIELD nodeId, score
        RETURN gds.util.asNode(nodeId).name AS entityName, score
        ORDER BY score DESC
        LIMIT 10
        """
        betweenness_centrality_result = session.run(
            betweenness_centrality_query, graph_name=graph_name).data()

        closeness_centrality_query = """
        CALL gds.closeness.stream($graph_name)
        YIELD nodeId, score
        RETURN gds.util.asNode(nodeId).name AS entityName, score
        ORDER BY score DESC
        LIMIT 10
        """
        closeness_centrality_result = session.run(
            closeness_centrality_query, graph_name=graph_name).data()

        centrality_data = {
            "degree": degree_centrality_result,
            "betweenness": betweenness_centrality_result,
            "closeness": closeness_centrality_result
        }

        return centrality_data

Handling Queries with QueryHandler

The QueryHandler class uses the results from the centrality measures to generate more relevant and accurate responses to user queries by leveraging OpenAI’s GPT models.

Example: Handling Queries

from graph_manager import GraphManager
from openai import OpenAI
from logger import Logger


class QueryHandler:
    logger = Logger("QueryHandler").get_logger()

    def __init__(self, graph_manager: GraphManager, client: OpenAI, model: str):
        self.graph_manager = graph_manager
        self.client = client
        self.model = model

    def ask_question(self, query):
        centrality_data = self.graph_manager.calculate_centrality_measures()
        centrality_summary = self.graph_manager.summarize_centrality_measures(
            centrality_data)

        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "Use the centrality measures to answer the following query."},
                {"role": "user", "content": f"Query: {query} Centrality Summary: {centrality_summary}"}
            ]
        )
        self.logger.debug("Query answered: %s",
                          response.choices[0].message.content)
        final_answer = response.choices[0].message.content
        return final_answer

By focusing on the most central entities, the system generates better, more context-aware answers.
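summarize_centrality_measures is not shown above; a plausible sketch of what such a method might produce, assuming centrality_data has the shape returned by calculate_centrality_measures (see the full code for the actual implementation):

```python
def summarize_centrality_measures(centrality_data):
    # Flatten the per-measure result lists into a text block that the
    # chat prompt can consume directly.
    lines = ["Centrality Measures:"]
    for measure, results in centrality_data.items():
        lines.append(f"{measure.capitalize()} centrality:")
        for record in results:
            lines.append(f"  {record['entityName']}: {record['score']:.2f}")
    return "\n".join(lines)

summary = summarize_centrality_measures(
    {"degree": [{"entityName": "Entity1", "score": 3.0}]})
```

Passing a compact text summary rather than raw query results keeps the prompt small and easy for the model to ground its answer in.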


Reindexing with New Documents

When new documents are added, the graph is reindexed to update entities, relationships, and centrality measures. The reindex_with_new_documents function in the root app.py handles this process:

Example: Reindexing

def reindex_with_new_documents(new_documents, graph_manager: GraphManager):
    chunks = document_processor.split_documents(new_documents)
    elements_file = 'data/new_elements_data.pkl'
    summaries_file = 'data/new_summaries_data.pkl'

    elements = load_or_run(
        elements_file, document_processor.extract_elements, chunks)
    summaries = load_or_run(
        summaries_file, document_processor.summarize_elements, elements)

    graph_manager.build_graph(summaries)
    graph_manager.reproject_graph()

This keeps the graph up to date as new data arrives and ensures centrality measures are recalculated.
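load_or_run is a small caching helper defined in app.py; its implementation is not shown here, but a plausible sketch of the pattern its name suggests looks like this (pickle file paths as in the snippet above):

```python
import os
import pickle

def load_or_run(file_path, func, *args):
    """Return cached results from file_path if present;
    otherwise call func(*args) and cache the result."""
    if os.path.exists(file_path):
        with open(file_path, "rb") as f:
            return pickle.load(f)
    result = func(*args)
    os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
    with open(file_path, "wb") as f:
        pickle.dump(result, f)
    return result
```

Caching the extraction and summarization outputs avoids repeating expensive GPT calls when the pipeline is rerun on the same documents.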


Running the Application

After setting up the environment, run the application:

python app.py

This will:

  1. Index the initial documents.

  2. Process a user query to extract the main themes.

  3. Reindex the graph with new documents.

  4. Answer another query based on the updated graph.


Conclusion

By using Neo4j and taking a class-based approach with clear separation of concerns, we have built a scalable and efficient GraphRAG pipeline. The system can handle larger datasets, leverage graph algorithms to enhance query responses, and continuously update the graph as new data is added.

This design allows you to further extend the system, incorporating additional algorithms or larger datasets, and tailor it to specific business needs.