🍪 Chunking 2M+ files a day for Code Search using Syntax Trees
Kevin Lu – July 3rd, 2023
Initializing any vector store requires chunking large documents for effective search.
Why can’t we just embed entire files? Consider our main API endpoint’s file:
- Imports
- Constants declarations
- Helper functions
- Business logic for each webhook endpoint
If I search “GitHub Action run”, it should match the section corresponding to the switch case block checking for the “check_runs completed” event. However, that block is only 20 lines out of 400+ lines of code, so even a perfect search algorithm would only consider this a 5% match. But if we chunked the 400 lines into 20 chunks of 20 lines, it would match the correct switch case block.
How do we get 20-line chunks? Naively we could break up the 400-line file evenly into 20-line chunks.
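Concretely, a naive fixed-size splitter is only a few lines (a throwaway sketch, not what we ship):

def naive_chunk(text: str, lines_per_chunk: int = 20) -> list[str]:
    # Split a file into fixed-size blocks of lines, ignoring structure entirely.
    lines = text.splitlines()
    return [
        "\n".join(lines[i:i + lines_per_chunk])
        for i in range(0, len(lines), lines_per_chunk)
    ]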
But this would yield terrible results: semantically related code would not stay together and context would be lost. For example, function headers would be separated from their implementations and docstrings.
Our current code chunking algorithm chunks 2M+ files a day and is open-sourced!
Constraints 🚧
Most chunkers for RAG-based (retrieval-augmented generation) agents cap chunks by token count. For simplicity, we decided to cap by character count instead, with a maximum of 1500 characters.
This is because the average token-to-character ratio for code is about 1:5 and embedding models typically cap input at 512 tokens. Further, 1500 characters correspond to roughly 40 lines, which works out to a small-to-medium-sized function or class.
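As a sanity check on those numbers (using the rough ratios just cited):

MAX_CHARS = 1500
CHARS_PER_TOKEN = 5           # rough average for code, per the 1:5 ratio above
EMBEDDING_TOKEN_LIMIT = 512   # typical embedding model input cap

approx_tokens = MAX_CHARS / CHARS_PER_TOKEN   # ~300 tokens, comfortably under 512
approx_lines = MAX_CHARS / 40                 # ~40 lines at ~40 characters per line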
The problem then becomes getting as close to 1500 characters as possible while still ensuring the chunks are semantically coherent and relevant context stays together.
Out of the Box Solution 📦
The easiest out-of-the-box solution for code chunking is Langchain’s recursive chunker. At a high level:
- Break the text using the top-level delimiter (first classes, then function definitions, then methods, etc.)
- Loop through the sections and greedily concatenate them until a chunk would break the character limit. For sections that are too big, recursively chunk them starting with the next-level delimiter.
delimiters = ["\nclass ", "\ndef ", "\n\tdef ", "\n\n", "\n", " ", ""]

def chunk(text: str, delimiter_index: int = 0, MAX_CHARS: int = 1500) -> list[str]:
    delimiter = delimiters[delimiter_index]
    new_chunks = []
    current_chunk = ""
    for section in text.split(delimiter):
        if len(section) > MAX_CHARS:
            # Section is too big, recursively chunk this section
            new_chunks.append(current_chunk)
            current_chunk = ""
            new_chunks.extend(chunk(section, delimiter_index + 1, MAX_CHARS))
        elif len(current_chunk) + len(section) > MAX_CHARS:
            # Current chunk is max size
            new_chunks.append(current_chunk)
            current_chunk = section
        else:
            # Concatenate section to current_chunk
            current_chunk += section
    if current_chunk:
        # Flush whatever is left over as the final chunk
        new_chunks.append(current_chunk)
    return new_chunks
For each language, we would also use different delimiters.
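For illustration, such a table might look like the following (the JS and HTML lists here are hypothetical stand-ins, not Langchain’s actual separators); in practice, chunk would take the delimiter list as a parameter rather than reading a global:

# Hypothetical per-language delimiter lists, in the same spirit as the Python one above.
language_delimiters = {
    "python":     ["\nclass ", "\ndef ", "\n\tdef ", "\n\n", "\n", " ", ""],
    "javascript": ["\nclass ", "\nfunction ", "\nconst ", "\n\n", "\n", " ", ""],
    "html":       ["<body", "<div", "<p", "\n\n", "\n", " ", ""],
}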
Examples
For full files of the examples, see https://gist.github.com/kevinlu1248/ded3ea33dcd8a9bd08078f4c64eb9268.
Example #1
Based on our on_check_suite.py file for handling GitHub Action runs. A bad split separates a string concatenation declaration from its contents. ❌
...
def on_check_suite(request: CheckRunCompleted):
    logger.info(f"Received check run completed event for {request.repository.full_name}")
    g = get_github_client(request.installation.id)
    repo = g.get_repo(request.repository.full_name)
    if not get_gha_enabled(repo):
        logger.info(f"Skipping github action for {request.repository.full_name} because it is not enabled")
        return None
    pr = repo.get_pull(request.check_run.pull_requests[0].number)
    num_pr_commits = len(list(pr.get_commits()))
    if num_pr_commits > 20:
        logger.info(f"Skipping github action for PR with {num_pr_commits} commits")
        return None
    logger.info(f"Running github action for PR with {num_pr_commits} commits")
    logs = download_logs(
        request.repository.full_name,
        request.check_run.run_id,
        request.installation.id
    )
    if not logs:
        return None
    logs = clean_logs(logs)
    extractor = GHAExtractor()
    logger.info(f"Extracting logs from {request.repository.full_name}, logs: {logs}")
    problematic_logs = extractor.gha_extract(logs)
    if problematic_logs.count("\n") > 15:
        problematic_logs += "
========================================
There are a lot of errors. This is likely a larger issue with the PR and not a small linting/type-checking issue."
    comments = list(pr.get_issue_comments())
    if len(comments) >= 2 and problematic_logs == comments[-1].body and comments[-2].body == comments[-1].body:
        comment = pr.as_issue().create_comment(log_message.format(error_logs=problematic_logs) + "\n\nI'm getting the same errors 3 times in a row, so I will stop working on fixing this PR.")
        logger.warning("Skipping logs because it is duplicated")
        raise Exception("Duplicate error logs")
    print(problematic_logs)
    comment = pr.as_issue().create_comment(log_message.format(error_logs=problematic_logs))
    on_comment(
        repo_full_name=request.repository.full_name,
        repo_description=request.repository.description,
        comment=problematic_logs,
        pr_path=None,
        pr_line_position=None,
        username=request.sender.login,
        installation_id=request.installation.id,
        pr_number=request.check_run.pull_requests[0].number,
        comment_id=comment.id,
        repo=repo,
    )
    return {"success": True}
Example #2
Based on the BaseIndex.ts file from LlamaIndex declaring the abstract base class for vector stores. A bad split separates a class method from its header. ❌
...
export class IndexDict extends IndexStruct {
  nodesDict: Record<string, BaseNode> = {};
  docStore: Record<string, Document> = {}; // FIXME: this should be implemented in storageContext
  type: IndexStructType = IndexStructType.SIMPLE_DICT;
========================================
  getSummary(): string {
    if (this.summary === undefined) {
      throw new Error("summary field of the index dict is not set");
    }
    return this.summary;
  }

  addNode(node: BaseNode, textId?: string) {
    const vectorId = textId ?? node.id_;
    this.nodesDict[vectorId] = node;
  }

  toJson(): Record<string, unknown> {
    return {
      ...super.toJson(),
      nodesDict: this.nodesDict,
      type: this.type,
    };
  }
}
...
Problems 🤔
However, this approach comes with some major problems:
- This chunker works decently for Python but breaks curly-bracket-heavy languages like JS and XML-based languages like HTML in unexpected ways.
  - Further, str.split does not work well for these more complex syntaxes like JS and HTML.
  - E.g. even for Python, it broke the problematic logs line by splitting problematic_logs += " from the rest of the string.
- Only 16 languages are currently supported, with no support for JSX, TypeScript, EJS, or C#.
  - JSX/TSX makes up the majority of our userbase.
- Langchain deletes important delimiters such as “def” and “class” (see the snippet after this list).
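The delimiter deletion issue in the last bullet is easy to reproduce with plain str.split:

code = "class Foo:\n    pass\n\nclass Bar:\n    pass\n"
sections = code.split("\nclass ")
# ['class Foo:\n    pass\n', 'Bar:\n    pass\n']
# The second section has lost its leading "class " keyword, so the chunk that gets
# embedded no longer reads as a class definition.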
Our Solution 🧠
The inherent problem is that iterative str.split with different delimiters is a primitive method for approximating concrete syntax trees (CSTs).
To solve this, we decided to just use CST parsers. But how do we get CST parsers for a large number of languages? Thankfully, the tree-sitter library provides a standardized way to access 113 CST parsers for programming languages, and it is fast (written in C) and dependency-free.
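As a quick illustration of what tree-sitter gives us (a minimal sketch; it assumes the tree_sitter_languages helper package for prebuilt grammars, which is separate from core tree-sitter), the top-level children of the parse tree are exactly the units we want to bundle:

from tree_sitter_languages import get_parser  # assumed helper package with prebuilt grammars

parser = get_parser("python")
source = open("on_check_suite.py").read()
tree = parser.parse(source.encode("utf-8"))

for child in tree.root_node.children:
    # Prints node types such as import_from_statement or function_definition,
    # along with their starting and ending line numbers.
    print(child.type, child.start_point[0], child.end_point[0])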
The new algorithm is fairly similar to the Langchain algorithm and is as follows:
- To chunk a parent node, we iterate through its children and greedily bundle them together. For each child node:
- If the current chunk is too big, we add that to our list of chunks and empty the bundle
- If the next child node is too big, we recursively chunk the child node and add it to the list of chunks
- Otherwise, concatenate the current chunk with the child node
- Post-process the final result by combining single-line chunks with the next chunk (see the sketch after the code below).
  - This guarantees that there are no chunks that are too small, since tiny chunks yield less meaningful search results.
from tree_sitter import Node

def chunk_node(node: Node, text: str, MAX_CHARS: int = 1500) -> list[str]:
    new_chunks = []
    current_chunk = ""
    for child in node.children:
        if child.end_byte - child.start_byte > MAX_CHARS:
            # Child is too big, recursively chunk the child
            new_chunks.append(current_chunk)
            current_chunk = ""
            new_chunks.extend(chunk_node(child, text, MAX_CHARS))
        elif len(current_chunk) + (child.end_byte - child.start_byte) > MAX_CHARS:
            # Current chunk is max size, start a new chunk with this child
            new_chunks.append(current_chunk)
            current_chunk = text[child.start_byte:child.end_byte]
        else:
            # Concatenate the child's source text to the current chunk
            current_chunk += text[child.start_byte:child.end_byte]
    if current_chunk:
        # Flush whatever is left over as the final chunk
        new_chunks.append(current_chunk)
    return new_chunks
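To tie it together, a rough driver might look like the following (again assuming tree_sitter_languages for the parser; the coalescing pass is only a sketch of the single-line post-processing step described above, and the open-sourced implementation may differ):

from tree_sitter_languages import get_parser  # assumed helper package with prebuilt grammars

def coalesce_chunks(chunks: list[str], min_lines: int = 2) -> list[str]:
    # Merge chunks that are only a line or two long into the next chunk,
    # so no chunk is too small to be a meaningful search result.
    result, buffer = [], ""
    for chunk in chunks:
        buffer += chunk
        if buffer.count("\n") >= min_lines:
            result.append(buffer)
            buffer = ""
    if buffer:
        result.append(buffer)
    return result

def chunk_file(text: str, language: str = "python", max_chars: int = 1500) -> list[str]:
    parser = get_parser(language)
    tree = parser.parse(text.encode("utf-8"))
    chunks = chunk_node(tree.root_node, text, max_chars)
    return coalesce_chunks([c for c in chunks if c.strip()])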
Example
Full chunks can be found at https://gist.github.com/kevinlu1248/49a72a1978868775109c5627677dc512.
Example #1
Based on our on_check_suite.py file for handling GitHub Action runs. The splits are now correct, breaking before the if statement instead of separating the if statement from its body. ✅
...
def on_check_suite(request: CheckRunCompleted):
    logger.info(f"Received check run completed event for {request.repository.full_name}")
    g = get_github_client(request.installation.id)
    repo = g.get_repo(request.repository.full_name)
    if not get_gha_enabled(repo):
        logger.info(f"Skipping github action for {request.repository.full_name} because it is not enabled")
        return None
    pr = repo.get_pull(request.check_run.pull_requests[0].number)
    num_pr_commits = len(list(pr.get_commits()))
    if num_pr_commits > 20:
        logger.info(f"Skipping github action for PR with {num_pr_commits} commits")
        return None
    logger.info(f"Running github action for PR with {num_pr_commits} commits")
    logs = download_logs(
        request.repository.full_name,
        request.check_run.run_id,
        request.installation.id
    )
    if not logs:
        return None
    logs = clean_logs(logs)
    extractor = GHAExtractor()
    logger.info(f"Extracting logs from {request.repository.full_name}, logs: {logs}")
    problematic_logs = extractor.gha_extract(logs)
    if problematic_logs.count("\n") > 15:
        problematic_logs += "\n\nThere are a lot of errors. This is likely a larger issue with the PR and not a small linting/type-checking issue."
    comments = list(pr.get_issue_comments())
==========
    if len(comments) >= 2 and problematic_logs == comments[-1].body and comments[-2].body == comments[-1].body:
        comment = pr.as_issue().create_comment(log_message.format(error_logs=problematic_logs) + "\n\nI'm getting the same errors 3 times in a row, so I will stop working on fixing this PR.")
        logger.warning("Skipping logs because it is duplicated")
        raise Exception("Duplicate error logs")
    print(problematic_logs)
    comment = pr.as_issue().create_comment(log_message.format(error_logs=problematic_logs))
    on_comment(
        repo_full_name=request.repository.full_name,
        repo_description=request.repository.description,
        com