Code Search with Vector Embeddings: A Transformer's Approach

In today's fast-paced development world, navigating a large codebase can be a daunting task. Wouldn't it be great if you could search your codebase using natural language queries? In this blog post, we'll walk through a basic Python script that does just that, leveraging the power of transformer models.

Laying the Groundwork

The core challenge we're addressing is transforming raw code snippets from our codebase into meaningful vector representations, known as embeddings. These embeddings capture the essence of the code in a format that can be compared for similarity. By doing so, when we pose a natural language query, the system can sift through these embeddings, identify the most similar code snippets, and present them as relevant "answers" to our query.
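To make this concrete, here is a toy illustration with made-up three-dimensional vectors (real embeddings have hundreds of dimensions). Because the vectors are unit length, a simple dot product acts as the cosine similarity we'll rely on later:

import numpy as np

# Hypothetical, already-normalized embeddings: one per code snippet
snippet_embeddings = np.array([
    [0.8, 0.6, 0.0],   # e.g. a function that validates a sudoku board
    [0.0, 0.6, 0.8],   # e.g. a component that renders a menu
])
query_embedding = np.array([0.7, 0.7, 0.1])  # "Where are the rules of sudoku defined?"

# For unit-length vectors, the dot product is the cosine similarity
similarities = snippet_embeddings @ query_embedding
print(similarities)           # the first snippet scores higher
print(similarities.argmax())  # -> 0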

Setup

Before you can run the script, you'll need to set up your environment. Here's a step-by-step guide:

  1. Python Environment: Ensure you have Python 3.7 or newer installed. You can check your Python version with python --version.

  2. Install Required Libraries: You can install all the necessary libraries using pip:

     pip install numpy torch transformers
    
  3. Clone the Repository: The entire code, along with some sample codebases to test on, is available on GitHub. Clone the repository to your local machine:

     git clone git@github.com:stephenc222/example-vectorize-codebase.git
    
  4. Run the Script: Navigate to the directory containing the script and run:

     python app.py
    

For a more detailed walkthrough, including potential customizations and optimizations, check out the companion GitHub repository.

The Building Blocks

Our script uses the following libraries:

  • os and numpy: the standard-library os module for file handling, and numpy for numerical operations.

  • torch and torch.nn.functional: PyTorch libraries for tensor operations.

  • transformers: The Hugging Face library, which provides pre-trained transformer models.

import os
import numpy as np
from torch import Tensor
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

Loading the Codebase

The load_codebase function recursively navigates through the specified directory, filtering out unwanted files and directories. It then reads the content of the allowed files and appends them to a list of code snippets.

CODEBASE_DIR = "./example-codebase"
IGNORED_DIRECTORIES = ["node_modules", "public/build"]
IGNORED_FILES = ["package-lock.json", "yarn.lock"]
ALLOWED_EXTENSIONS = [".ts", ".tsx"]

IMAGE_EXTENSIONS = [
    ".png",
    ".jpg",
    ".jpeg",
    ".gif",
    ".bmp",
    ".svg",
    ".ico",
]


def load_codebase(directory):
    snippets = []
    for filename in os.listdir(directory):
        # Skip hidden files and directories
        if filename.startswith('.'):
            continue

        filepath = os.path.join(directory, filename)

        if os.path.isdir(filepath):
            # If it's a directory, recursively load its contents
            snippets.extend(load_codebase(filepath))
        else:
            if any(ignored in filepath for ignored in IGNORED_DIRECTORIES):
                continue
            if filename in IGNORED_FILES:
                continue
            if not any(filepath.endswith(ext) for ext in ALLOWED_EXTENSIONS):
                continue

            with open(filepath, 'r') as file:
                content = file.read().strip()
                if content:  # Check if content is not empty
                    snippets.append(content)
    return snippets

Generating Embeddings with Transformers

The heart of our script is the generate_embeddings function. Here it is:


def generate_embeddings(snippets):
    prefix = "query: "  # Assuming all code snippets are queries
    input_texts = [prefix + snippet for snippet in snippets]

    tokenizer = AutoTokenizer.from_pretrained('thenlper/gte-base')
    model = AutoModel.from_pretrained('thenlper/gte-base')

    batch_dict = tokenizer(input_texts, max_length=512,
                           padding=True, truncation=True, return_tensors='pt')
    outputs = model(**batch_dict)
    embeddings = average_pool(
        outputs.last_hidden_state, batch_dict['attention_mask'])

    # Normalize so a dot product between embeddings equals cosine similarity
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.detach().numpy()

The average_pool function it uses:

def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

And here is how it all works:

  • Tokenization: We prefix each code snippet with "query: " and tokenize it using the AutoTokenizer from the Hugging Face library. This prepares our text for the transformer model.

  • Model Inference: We use a pre-trained transformer model (AutoModel) to generate embeddings for our tokenized code snippets. The model returns the last hidden states for each token.

  • Pooling: Since we want a single vector representation for each code snippet, we use the average_pool function to average out the token embeddings. This gives us a fixed-size vector for each code snippet.

  • Normalization: Finally, we normalize the embeddings to ensure they have a magnitude of 1. This is crucial for calculating cosine similarities later on; a quick illustration follows this list.
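As a quick illustration of that last step, F.normalize scales each row to unit length (toy values below, purely for illustration):

import torch
import torch.nn.functional as F

v = torch.tensor([[3.0, 4.0]])
print(F.normalize(v, p=2, dim=1))  # tensor([[0.6000, 0.8000]]) -- length 1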

Finding the Nearest Neighbors

The find_k_nearest_neighbors function calculates the cosine similarity between the query embedding and all code snippet embeddings. Since our embeddings are normalized, a simple dot product gives us the cosine similarity. The function then returns the indices of the top-k most similar code snippets.

def find_k_nearest_neighbors(query_embedding, embeddings, k=5):
    # Using cosine similarity as embeddings are normalized
    similarities = np.dot(embeddings, query_embedding.T)
    sorted_indices = similarities.argsort(axis=0)[-k:][::-1]
    return sorted_indices.squeeze()

Bringing It All Together

In the __main__ block, we:

  • Load our codebase using load_codebase.

  • Generate embeddings for all code snippets using generate_embeddings.

  • Generate an embedding for our query.

  • Find the nearest neighbors using find_k_nearest_neighbors.

Finally, we print out the top matches to see the most relevant pieces of code for our query.

if __name__ == "__main__":
    snippets = load_codebase(CODEBASE_DIR)
    embeddings = generate_embeddings(snippets)

    # example query
    query = "Where are the rules of sudoku defined?"
    query_embedding = generate_embeddings([query])
    nearest_neighbors = find_k_nearest_neighbors(query_embedding, embeddings)
    top_matches = nearest_neighbors[:2]
    print("Query:", query)
    print("Top Matches:")
    for index in top_matches:
        # print the first 500 characters to illustrate the found match
        print(f"- Matched Code:\n{snippets[index][:500]}...\n")

Next Steps

While our example provides a foundational understanding of code search with vector embeddings, it's essential to recognize that we've worked with a relatively small codebase. In real-world scenarios, especially with extensive and complex codebases, there are additional considerations and optimizations to be made.

Finetuning Embedding Models

While pre-trained models offer a great starting point, they might not always capture the nuances of specific domains or applications. By finetuning these models on domain-specific data, we can achieve better performance. For instance, if you have a codebase primarily in a specific programming language or related to a particular domain (like web development or data science), finetuning your model on similar code snippets can enhance its understanding and, consequently, its search accuracy.
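As a rough sketch of what that could look like, here is one way to finetune an embedding model on (query, code) pairs using the sentence-transformers library. Note that this library is not part of the script above, and the training pairs shown are hypothetical placeholders; you would build them from your own codebase and queries:

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

# Load the same base model used in the script
model = SentenceTransformer('thenlper/gte-base')

# Hypothetical (natural language query, matching code snippet) pairs
train_examples = [
    InputExample(texts=["Where are the rules of sudoku defined?",
                        "export function isValidMove(board, row, col, value) { ... }"]),
    InputExample(texts=["How is the game board rendered?",
                        "export function Board({ cells }) { return <div>...</div> }"]),
]

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
# Treats the other snippets in each batch as negatives for each query
train_loss = losses.MultipleNegativesRankingLoss(model)

model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=1, warmup_steps=10)
model.save("gte-base-finetuned-codebase")

With even a few hundred such pairs mined from your own project, the finetuned model can rank domain-specific code noticeably better than the off-the-shelf checkpoint.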

Vector Databases

As the size of the codebase grows, storing and searching through embeddings in memory becomes inefficient. This is where vector databases come into play. Tools like Milvus, Faiss, and others are designed to handle large-scale vector data and provide efficient similarity search capabilities. I've also written about how to use SQLite to store vector embeddings. By integrating a vector database, you can scale your code search tool to handle much larger codebases without compromising on search speed.
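As a minimal sketch of what that integration could look like with Faiss (installed separately with pip install faiss-cpu), assuming snippets, embeddings, and query_embedding are the normalized numpy arrays produced by the script's main block:

import faiss
import numpy as np

# Build an index over the snippet embeddings. Inner product equals cosine
# similarity here because the embeddings are already normalized.
dimension = embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)
index.add(embeddings.astype(np.float32))

# Search for the 5 nearest snippets to the query
scores, indices = index.search(query_embedding.astype(np.float32), 5)
for score, idx in zip(scores[0], indices[0]):
    print(f"score={score:.3f}\n{snippets[idx][:200]}...\n")

For codebases that outgrow a single machine's memory, you would swap the flat index for one of Faiss's approximate indexes, or move to a hosted vector database.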

Chunking the Codebase

Another challenge with large codebases is memory consumption during the embedding generation phase. One way to address this is by "chunking" the codebase. Instead of processing the entire codebase at once, you can divide it into smaller chunks and process each chunk separately. This ensures that only a part of the codebase is in memory at any given time, reducing memory overhead and making the embedding generation process more manageable.

By implementing chunking, you can sequentially generate embeddings for each chunk and then store them in the vector database. This approach not only conserves memory but also allows for parallel processing, where multiple chunks can be processed simultaneously on different cores or machines.
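A minimal sketch of that idea, reusing the generate_embeddings function from above (the chunk size of 32 is an arbitrary choice you would tune to your memory budget):

import numpy as np

def generate_embeddings_in_chunks(snippets, chunk_size=32):
    """Embed the codebase a chunk at a time instead of all at once."""
    chunks = []
    for start in range(0, len(snippets), chunk_size):
        chunk = snippets[start:start + chunk_size]
        chunks.append(generate_embeddings(chunk))
        # In a real pipeline, each chunk's embeddings could be written to a
        # vector database here instead of being kept in memory.
    return np.vstack(chunks)

In practice you would also load the tokenizer and model once, outside the loop, rather than letting generate_embeddings reload them for every chunk as it does in this simple example.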

Conclusion

With the power of transformer models, we've built a simple tool that demonstrates searching a codebase using natural language queries. This approach can be a game-changer for large projects, making it easier for developers to find relevant code snippets and understand the codebase faster. The next time you're lost in a sea of code, remember that transformers might just be the compass you need!

The entire codebase for this blog post is available on my GitHub repository. Feel free to fork, star, or open issues!