A vector database that uses the local file system for storage.
npm install vectra

MEMORY:
SOURCE: src\index.ts
DETAILS:
export * from './FileFetcher';
export * from './GPT3Tokenizer';
export * from './ItemSelector';
export * from './LocalIndex';
export * from './LocalDocument';
export * from './LocalDocumentIndex';
export * from './LocalDocumentResult';
export * from './OpenAIEmbeddings';
export * from './TextSplitter';
export * from './types';
export * from './WebFetcher';
MEMORY:
SOURCE: src\LocalIndex.spec.ts
DETAILS:
import assert from 'node:assert'
import sinon from 'sinon'
import { LocalIndex } from './LocalIndex'
import { IndexItem } from './types'
import fs from 'fs/promises'
import path from 'path'
describe('LocalIndex', () => {
const testIndexDir = path.join(__dirname, 'test_index');
const basicIndexItems: Partial
{ id: '1', vector: [1, 2, 3] },
{ id: '2', vector: [2, 3, 4] },
{ id: '3', vector: [3, 4, 5] }
];
beforeEach(async () => {
await fs.rm(testIndexDir, { recursive: true, force: true });
});
afterEach(async () => {
await fs.rm(testIndexDir, { recursive: true, force: true });
sinon.restore();
});
it('should create a new index', async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
const created = await index.isIndexCreated();
assert.equal(created, true);
assert.equal(index.folderPath, testIndexDir);
});
it('blocks concurrent operations when lock is held', async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
await index.beginUpdate(); // grab lock for a big update!
await assert.rejects(async () => {
await index.beginUpdate(); // try to grab lock again. should fail!
}, new Error('Update already in progress'))
})
describe('createIndex', () => {
it('checks for existing index on creation', async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex(); // create first index.json
// create without deleteIfExists. Will reject
await assert.rejects(async () => {
await index.createIndex()
}, new Error('Index already exists'))
// create with deleteIfExists. Should remove old data
await index.insertItem({id:'1', vector: [1,2,3]})
const lengthBefore = (await index.listItems()).length
assert.equal(lengthBefore, 1)
await index.createIndex({deleteIfExists: true, version: 2, metadata_config: {}})
const lengthAfter = (await index.listItems()).length
assert.equal(lengthAfter, 0)
})
it('delete index if file creation fails', async () => {
const index = new LocalIndex(testIndexDir);
sinon.stub(fs, 'writeFile').rejects(new Error('fs error'))
await assert.rejects(async () => {
await index.createIndex();
}, new Error('Error creating index'))
await assert.rejects(async () => {
await index.listItems();
})
})
})
describe('deleteItem', () => {
it('does nothing when id not found', async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
await index.beginUpdate();
await index.insertItem(basicIndexItems[0])
await index.insertItem(basicIndexItems[1])
await index.insertItem(basicIndexItems[2])
await index.endUpdate();
await assert.doesNotReject(async () => {
await index.deleteItem('dne');
})
assert.equal((await index.listItems()).length, 3)
})
it('leaves existing empty index when last el deleted', async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
await index.insertItem(basicIndexItems[0]);
await index.deleteItem(basicIndexItems[0].id ?? '');
assert.equal(await index.isIndexCreated(), true);
assert.equal((await index.listItems()).length, 0);
})
it('removes elements from any position', async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
await index.batchInsertItems([
{id: '1', vector: []},
{id: '2', vector: []},
{id: '3', vector: []},
{id: '4', vector: []},
{id: '5', vector: []},
]);
await index.beginUpdate();
await index.deleteItem('1');
await index.deleteItem('3');
await index.deleteItem('5');
await index.endUpdate();
assert.deepStrictEqual(await index.listItems(), [{id: '2', vector: [], metadata: {}, norm: 0}, {id: '4', vector: [], metadata: {}, norm: 0}])
})
})
describe('endUpdate', () => {
it('throws an error if no update has begun', async () => {
const index = new LocalIndex(testIndexDir);
await assert.rejects(async () => {
await index.endUpdate();
}, new Error('No update in progress'));
})
it('throws an error if the index could not be saved', async () => {
const index = new LocalIndex(testIndexDir, 'index.json');
await index.createIndex();
await index.beginUpdate();
sinon.stub(fs, 'writeFile').rejects(new Error('fs error'))
await assert.rejects(async () => {
await index.endUpdate();
}, new Error('Error saving index: Error: fs error'))
})
})
describe('getIndexStats', () => {
it('reports empty index correctly', async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
assert.deepStrictEqual(await index.getIndexStats(), {
version: 1,
metadata_config: {},
items: 0
})
})
it('correctly reports non-empty index stats', async () => {
const index = new LocalIndex(testIndexDir)
await index.createIndex({version: 1, metadata_config: {indexed: []}})
await index.batchInsertItems(basicIndexItems);
assert.deepStrictEqual(await index.getIndexStats(), {
version: 1,
metadata_config: {indexed: []},
items: 3
})
})
})
describe('getItem', () => {
it('returns undefined when item not found', async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
assert.equal(await index.getItem('1'), undefined)
})
it('returns requested item', async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
await index.batchInsertItems(basicIndexItems);
const item2 = await index.getItem('2');
assert.equal(item2?.id, basicIndexItems[1].id)
assert.equal(item2?.vector, basicIndexItems[1].vector)
assert.equal((await index.listItems()).length, 3)
})
})
describe('batchInsertItems', () => {
it('should insert provided items', async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
const newItems = await index.batchInsertItems(basicIndexItems);
assert.equal(newItems.length, 3);
const retrievedItems = await index.listItems();
assert.equal(retrievedItems.length, 3);
});
it('on id collision - cancel batch insert & bubble up error', async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
await index.insertItem({ id: '2', vector: [9, 9, 9] });
// ensures insert error is bubbled up to batchIndexItems caller
await assert.rejects(
async () => {
await index.batchInsertItems(basicIndexItems);
},
{
name: 'Error',
message: 'Item with id 2 already exists'
}
);
// ensures no partial update is applied
const storedItems = await index.listItems();
assert.equal(storedItems.length, 1);
});
});
describe('listItemsByMetadata', () => {
it('returns items matching metadata filter', async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
await index.batchInsertItems([
{id: '1', vector: [], metadata: {category: 'food'}},
{id: '2', vector: [], metadata: {category: 'food'}},
{id: '3', vector: [], metadata: {category: 'electronics'}},
{id: '4', vector: [], metadata: {category: 'drink'}},
{id: '5', vector: [], metadata: {category: 'food'}},
]);
const foodItems = await index.listItemsByMetadata({category: {'$eq': 'food'}})
assert.deepStrictEqual(foodItems.map((item) => item.id), ["1", "2", "5"])
const drinkItems = await index.listItemsByMetadata({category: {'$eq': 'drink'}})
assert.deepStrictEqual(drinkItems.map((item) => item.id), ["4"])
const clothingItems = await index.listItemsByMetadata({category: {'$eq': 'clothes'}})
assert.deepStrictEqual(clothingItems, [])
})
it('returns nothing when no items in index', async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
const items = await index.listItemsByMetadata({});
assert.deepStrictEqual(items, []);
})
});
describe("queryItems", () => {
it("returns empty array on empty index search", async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
const result = await index.queryItems([1, 2, 3], "", 10);
assert.deepStrictEqual(result, []);
});
it("returns bad match when no better match exists", async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
await index.insertItem({ id: "1", vector: [0.9, 0, 0, 0, 0] });
const result = await index.queryItems([0, 0, 0, 0, 0.1], "", 1);
assert.equal(result[0]?.score, 0);
assert.equal(result[0]?.item.id, "1");
});
it("returns all vectors when fewer than topK exist", async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
await index.batchInsertItems(basicIndexItems);
const result = await index.queryItems([0, 0, 1], "", 10);
assert.equal(result.length, 3);
assert.deepStrictEqual(
result.map(({ item }) => item.id),
basicIndexItems.map((item) => item.id),
);
});
it("filters by metadata when filter provided", async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex();
await index.batchInsertItems([
{ id: "1", vector: [1, 0, 0], metadata: { category: "food" } },
{ id: "2", vector: [0, 0, 1], metadata: { category: "drink" } },
]);
const bestGeneralMatch = await index.queryItems([1, 0, 0], "", 1);
const bestDrinkMatch = await index.queryItems([1, 0, 0], "", 1, {
category: { $eq: "drink" },
});
assert.equal(bestGeneralMatch[0].item.id, "1");
assert.equal(bestDrinkMatch[0].item.id, "2");
});
it("reads item metadata file when provided", async () => {
const index = new LocalIndex(testIndexDir);
await index.createIndex({version: 1, metadata_config: {indexed: ['category']}});
await index.batchInsertItems([
{ id: "1", vector: [1, 0, 0] },
{ id: "2", vector: [0, 0, 1], metadata: {category: 'drink'} },
]);
sinon
.stub(fs, "readFile")
.resolves(JSON.stringify({ category: "drink" }));
const bestDrinkMatch = await index.queryItems([1, 0, 0], "", 2, {category: {'$eq': 'drink'}});
assert.notEqual(bestDrinkMatch[0].item.metadataFile, undefined);
assert.equal(bestDrinkMatch[0].item.id, "2");
});
});
});
MEMORY:
SOURCE: src\OpenAIEmbeddings.ts
DETAILS:
import axios, { AxiosInstance, AxiosResponse, AxiosRequestConfig } from 'axios';
import { EmbeddingsModel, EmbeddingsResponse } from "./types";
import { CreateEmbeddingRequest, CreateEmbeddingResponse, OpenAICreateEmbeddingRequest } from "./internals";
import { Colorize } from "./internals";
export interface BaseOpenAIEmbeddingsOptions {
/**
* Optional. Number of embedding dimensions to return.
*/
dimensions?: number;
/**
* Optional. Whether to log requests to the console.
* @remarks
* This is useful for debugging prompts and defaults to false.
*/
logRequests?: boolean;
/**
* Optional. Maximum number of tokens that can be sent to the embedding model.
*/
maxTokens?: number;
/**
* Optional. Retry policy to use when calling the OpenAI API.
* @remarks
* The default retry policy is [2000, 5000] which means that the first retry will be after
* 2 seconds and the second retry will be after 5 seconds.
*/
retryPolicy?: number[];
/**
* Optional. Request options to use when calling the OpenAI API.
*/
requestConfig?: AxiosRequestConfig;
}
/**
* Options for configuring an OpenAIEmbeddings to generate embeddings using an OSS hosted model.
*/
export interface OSSEmbeddingsOptions extends BaseOpenAIEmbeddingsOptions {
/**
* Model to use for completion.
*/
ossModel: string;
/**
* Optional. Endpoint to use when calling the OpenAI API.
* @remarks
* For Azure OpenAI this is the deployment endpoint.
*/
ossEndpoint: string;
}
/**
* Options for configuring an OpenAIEmbeddings to generate embeddings using an OpenAI hosted model.
*/
export interface OpenAIEmbeddingsOptions extends BaseOpenAIEmbeddingsOptions {
/**
* API key to use when calling the OpenAI API.
* @remarks
* A new API key can be created at https://platform.openai.com/account/api-keys.
*/
apiKey: string;
/**
* Model to use for completion.
* @remarks
* For Azure OpenAI this is the name of the deployment to use.
*/
model: string;
/**
* Optional. Organization to use when calling the OpenAI API.
*/
organization?: string;
/**
* Optional. Endpoint to use when calling the OpenAI API.
*/
endpoint?: string;
}
/**
* Options for configuring an OpenAIEmbeddings to generate embeddings using an Azure OpenAI hosted model.
*/
export interface AzureOpenAIEmbeddingsOptions extends BaseOpenAIEmbeddingsOptions {
/**
* API key to use when making requests to Azure OpenAI.
*/
azureApiKey: string;
/**
* Deployment endpoint to use.
*/
azureEndpoint: string;
/**
* Name of the Azure OpenAI deployment (model) to use.
*/
azureDeployment: string;
/**
* Optional. Version of the API being called. Defaults to 2023-05-15.
*/
azureApiVersion?: string;
}
/**
* A PromptCompletionModel for calling OpenAI and Azure OpenAI hosted models.
* @remarks
*/
export class OpenAIEmbeddings implements EmbeddingsModel {
private readonly _httpClient: AxiosInstance;
private readonly _clientType: ClientType;
private readonly UserAgent = 'AlphaWave';
public readonly maxTokens;
/**
* Options the client was configured with.
*/
public readonly options: OSSEmbeddingsOptions|OpenAIEmbeddingsOptions|AzureOpenAIEmbeddingsOptions;
/**
* Creates a new OpenAIClient instance.
* @param options Options for configuring an OpenAIClient.
*/
public constructor(options: OSSEmbeddingsOptions|OpenAIEmbeddingsOptions|AzureOpenAIEmbeddingsOptions) {
this.maxTokens = options.maxTokens ?? 500;
// Check for azure config
if ((options as AzureOpenAIEmbeddingsOptions).azureApiKey) {
this._clientType = ClientType.AzureOpenAI;
this.options = Object.assign({
retryPolicy: [2000, 5000],
azureApiVersion: '2023-05-15',
}, options) as AzureOpenAIEmbeddingsOptions;
// Cleanup and validate endpoint
let endpoint = this.options.azureEndpoint.trim();
if (endpoint.endsWith('/')) {
endpoint = endpoint.substring(0, endpoint.length - 1);
}
if (!endpoint.toLowerCase().startsWith('https://')) {
throw new Error(Client created with an invalid endpoint of '${endpoint}'. The endpoint must be a valid HTTPS url.);
}
this.options.azureEndpoint = endpoint;
} else if ((options as OSSEmbeddingsOptions).ossModel) {
this._clientType = ClientType.OSS;
this.options = Object.assign({
retryPolicy: [2000, 5000]
}, options) as OSSEmbeddingsOptions;
} else {
this._clientType = ClientType.OpenAI;
this.options = Object.assign({
retryPolicy: [2000, 5000]
}, options) as OpenAIEmbeddingsOptions;
}
// Create client
this._httpClient = axios.create({
validateStatus: (status) => status < 400 || status == 429
});
}
/**
* Creates embeddings for the given inputs using the OpenAI API.
* @param model Name of the model to use (or deployment for Azure).
* @param inputs Text inputs to create embeddings for.
* @returns A EmbeddingsResponse with a status and the generated embeddings or a message when an error occurs.
*/
public async createEmbeddings(inputs: string | string[]): Promise
if (this.options.logRequests) {
console.log(Colorize.title('EMBEDDINGS REQUEST:'));
console.log(Colorize.output(inputs));
}
const startTime = Date.now();
const response = await this.createEmbeddingRequest({
input: inputs,
});
if (this.options.logRequests) {
console.log(Colorize.title('RESPONSE:'));
console.log(Colorize.value('status', response.status));
console.log(Colorize.value('duration', Date.now() - startTime, 'ms'));
console.log(Colorize.output(response.data));
}
// Process response
if (response.status < 300) {
return { status: 'success', output: response.data.data.sort((a, b) => a.index - b.index).map((item) => item.embedding) };
} else if (response.status == 429) {
return { status: 'rate_limited', message: The embeddings API returned a rate limit error. }
} else {
return { status: 'error', message: The embeddings API returned an error status of ${response.status}: ${response.statusText} };
}
}
/**
* @private
*/
protected createEmbeddingRequest(request: CreateEmbeddingRequest): Promise
if (this.options.dimensions) {
request.dimensions = this.options.dimensions;
}
if (this._clientType == ClientType.AzureOpenAI) {
const options = this.options as AzureOpenAIEmbeddingsOptions;
const url = ${options.azureEndpoint}/openai/deployments/${options.azureDeployment}/embeddings?api-version=${options.azureApiVersion!};
return this.post(url, request);
} else if (this._clientType == ClientType.OSS) {
const options = this.options as OSSEmbeddingsOptions;
const url = ${options.ossEndpoint}/v1/embeddings;
(request as OpenAICreateEmbeddingRequest).model = options.ossModel;
return this.post(url, request);
} else {
const options = this.options as OpenAIEmbeddingsOptions;
const url = ${options.endpoint ?? 'https://api.openai.com'}/v1/embeddings;
(request as OpenAICreateEmbeddingRequest).model = options.model;
return this.post(url, request);
}
}
/**
* @private
*/
protected async post
// Initialize request config
const requestConfig: AxiosRequestConfig = Object.assign({}, this.options.requestConfig);
// Initialize request headers
if (!requestConfig.headers) {
requestConfig.headers = {};
}
if (!requestConfig.headers['Content-Type']) {
requestConfig.headers['Content-Type'] = 'application/json';
}
if (!requestConfig.headers['User-Agent']) {
requestConfig.headers['User-Agent'] = this.UserAgent;
}
if (this._clientType == ClientType.AzureOpenAI) {
const options = this.options as AzureOpenAIEmbeddingsOptions;
requestConfig.headers['api-key'] = options.azureApiKey;
} else if (this._clientType == ClientType.OpenAI) {
const options = this.options as OpenAIEmbeddingsOptions;
requestConfig.headers['Authorization'] = Bearer ${options.apiKey};
if (options.organization) {
requestConfig.headers['OpenAI-Organization'] = options.organization;
}
}
// Send request
const response = await this._httpClient.post(url, body, requestConfig);
// Check for rate limit error
if (response.status == 429 && Array.isArray(this.options.retryPolicy) && retryCount < this.options.retryPolicy.length) {
const delay = this.options.retryPolicy[retryCount];
await new Promise((resolve) => setTimeout(resolve, delay));
return this.post(url, body, retryCount + 1);
} else {
return response;
}
}
}
// Identifies which backend flavor an OpenAIEmbeddings client targets.
// Chosen once in the constructor from the shape of the options object
// (azureApiKey -> AzureOpenAI, ossModel -> OSS, otherwise OpenAI).
enum ClientType {
OpenAI,
AzureOpenAI,
OSS
}
MEMORY:
SOURCE: src\ItemSelector.ts
DETAILS:
import { MetadataFilter, MetadataTypes } from './types';
export class ItemSelector {
/**
* Returns the similarity between two vectors using the cosine similarity.
* @param vector1 Vector 1
* @param vector2 Vector 2
* @returns Similarity between the two vectors
*/
public static cosineSimilarity(vector1: number[], vector2: number[]) {
// Return the quotient of the dot product and the product of the norms
return this.dotProduct(vector1, vector2) / (this.normalize(vector1) * this.normalize(vector2));
}
/**
* Normalizes a vector.
* @remarks
* The norm of a vector is the square root of the sum of the squares of the elements.
* The LocalIndex pre-normalizes all vectors to improve performance.
* @param vector Vector to normalize
* @returns Normalized vector
*/
public static normalize(vector: number[]) {
// Initialize a variable to store the sum of the squares
let sum = 0;
// Loop through the elements of the array
for (let i = 0; i < vector.length; i++) {
// Square the element and add it to the sum
sum += vector[i] * vector[i];
}
// Return the square root of the sum
return Math.sqrt(sum);
}
/**
* Returns the similarity between two vectors using cosine similarity.
* @remarks
* The LocalIndex pre-normalizes all vectors to improve performance.
* This method uses the pre-calculated norms to improve performance.
* @param vector1 Vector 1
* @param norm1 Norm of vector 1
* @param vector2 Vector 2
* @param norm2 Norm of vector 2
* @returns Similarity between the two vectors
*/
public static normalizedCosineSimilarity(vector1: number[], norm1: number, vector2: number[], norm2: number) {
// Return the quotient of the dot product and the product of the norms
return this.dotProduct(vector1, vector2) / (norm1 * norm2);
}
/**
* Applies a filter to the metadata of an item.
* @param metadata Metadata of the item
* @param filter Filter to apply
* @returns True if the item matches the filter, false otherwise
*/
public static select(metadata: Record
if (filter === undefined || filter === null) {
return true;
}
for (const key in filter) {
switch (key) {
case '$and':
if (!filter[key]!.every((f: MetadataFilter) => this.select(metadata, f))) {
return false;
}
break;
case '$or':
if (!filter[key]!.some((f: MetadataFilter) => this.select(metadata, f))) {
return false;
}
break;
default:
const value = filter[key];
if (value === undefined || value === null) {
return false;
} else if (typeof value == 'object') {
if (!this.metadataFilter(metadata[key], value as MetadataFilter)) {
return false;
}
} else {
if (metadata[key] !== value) {
return false;
}
}
break;
}
}
return true;
}
private static dotProduct(arr1: number[], arr2: number[]) {
// Initialize a variable to store the sum of the products
let sum = 0;
// Loop through the elements of the arrays
for (let i = 0; i < arr1.length; i++) {
// Multiply the corresponding elements and add them to the sum
sum += arr1[i] * arr2[i];
}
// Return the sum
return sum;
}
private static metadataFilter(value: MetadataTypes, filter: MetadataFilter): boolean {
if (value === undefined || value === null) {
return false;
}
for (const key in filter) {
switch (key) {
case '$eq':
if (value !== filter[key]) {
return false;
}
break;
case '$ne':
if (value === filter[key]) {
return false;
}
break;
case '$gt':
if (typeof value != 'number' || value <= filter[key]!) {
return false;
}
break;
case '$gte':
if (typeof value != 'number' || value < filter[key]!) {
return false;
}
break;
case '$lt':
if (typeof value != 'number' || value >= filter[key]!) {
return false;
}
break;
case '$lte':
if (typeof value != 'number' || value > filter[key]!) {
return false;
}
break;
case '$in':
if (typeof value == 'boolean') {
return false;
} else if(typeof value == 'string' && !filter[key]!.includes(value)){
return false
} else if(!filter[key]!.some(val => typeof val == 'string' && val.includes(value as string))){
return false
}
break;
case '$nin':
if (typeof value == 'boolean') {
return false;
}
else if (typeof value == 'string' && filter[key]!.includes(value)) {
return false;
}
else if (filter[key]!.some(val => typeof val == 'string' && val.includes(value as string))) {
return false;
}
break;
default:
return value === filter[key];
}
}
return true;
}
}
MEMORY:
SOURCE: src\TextSplitter.ts
DETAILS:
import { GPT3Tokenizer } from "./GPT3Tokenizer";
import { TextChunk, Tokenizer } from "./types";
const ALPHANUMERIC_CHARS = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789';
export interface TextSplitterConfig {
separators: string[];
keepSeparators: boolean;
chunkSize: number;
chunkOverlap: number;
tokenizer: Tokenizer;
docType?: string;
export class TextSplitter {
private readonly _config: TextSplitterConfig;
public constructor(config?: Partial
this._config = Object.assign({
keepSeparators: false,
chunkSize: 400,
chunkOverlap: 40,
} as TextSplitterConfig, config);
// Create a default tokenizer if none is provided
if (!this._config.tokenizer) {
this._config.tokenizer = new GPT3Tokenizer();
// Use default separators if none are provided
if (!this._config.separators || this._config.separators.length === 0) {
this._config.separators = this.getSeparators(this._config.docType);
// Validate the config settings
if (this._config.chunkSize < 1) {
throw new Error("chunkSize must be >= 1");
} else if (this._config.chunkOverlap < 0) {
throw new Error("chunkOverlap must be >= 0");
} else if (this._config.chunkOverlap > this._config.chunkSize) {
throw new Error("chunkOverlap must be <= chunkSize");
public split(text: string): TextChunk[] {
// Get basic chunks
const chunks = this.recursiveSplit(text, this._config.separators, 0);
const that = this;
function getOverlapTokens(tokens?: number[]): number[] {
if (tokens != undefined) {
const len = tokens.length > that._config.chunkOverlap ? that._config.chunkOverlap : tokens.length;
return tokens.slice(0, len);
} else {
return [];
// Add overlap tokens and text to the start and end of each chunk
if (this._config.chunkOverlap > 0) {
for (let i = 1; i < chunks.length; i++) {
const previousChunk = chunks[i - 1];
const chunk = chunks[i];
const nextChunk = i < chunks.length - 1 ? chunks[i + 1] : undefined;
chunk.startOverlap = getOverlapTokens(previousChunk.tokens.reverse()).reverse();
chunk.endOverlap = getOverlapTokens(nextChunk?.tokens);
return chunks;
private recursiveSplit(text: string, separators: string[], startPos: number): TextChunk[] {
const chunks: TextChunk[] = [];
if (text.length > 0) {
// Split text into parts
let parts: string[];
let separator = '';
const nextSeparators = separators.length > 1 ? separators.slice(1) : [];
if (separators.length > 0) {
// Split by separator
separator = separators[0];
parts = separator == ' ' ? this.splitBySpaces(text) : text.split(separator);
} else {
// Cut text in half
const half = Math.floor(text.length / 2);
parts = [text.substring(0, half), text.substring(half)];
// Iterate over parts
for (let i = 0; i < parts.length; i++) {
const lastChunk = (i === parts.length - 1);
// Get chunk text and endPos
let chunk = parts[i];
const endPos = (startPos + (chunk.length - 1)) + (lastChunk ? 0 : separator.length);
if (this._config.keepSeparators && !lastChunk) {
chunk += separator;
// Ensure chunk contains text
if (!this.containsAlphanumeric(chunk)) {
continue;
// Optimization to avoid encoding really large chunks
if (chunk.length / 6 > this._config.chunkSize) {
// Break the text into smaller chunks
const subChunks = this.recursiveSplit(chunk, nextSeparators, startPos);
chunks.push(...subChunks);
} else {
// Encode chunk text
const tokens = this._config.tokenizer.encode(chunk);
if (tokens.length > this._config.chunkSize) {
// Break the text into smaller chunks
const subChunks = this.recursiveSplit(chunk, nextSeparators, startPos);
chunks.push(...subChunks);
} else {
// Append chunk to output
chunks.push({
text: chunk,
tokens: tokens,
startPos: startPos,
endPos: endPos,
startOverlap: [],
endOverlap: [],
});
// Update startPos
startPos = endPos + 1;
return this.combineChunks(chunks);
private combineChunks(chunks: TextChunk[]): TextChunk[] {
const combinedChunks: TextChunk[] = [];
let currentChunk: TextChunk|undefined;
let currentLength = 0;
const separator = this._config.keepSeparators ? '' : ' ';
for (let i = 0; i < chunks.length; i++) {
const chunk = chunks[i];
if (currentChunk) {
const length = currentChunk.tokens.length + chunk.tokens.length;
if (length > this._config.chunkSize) {
combinedChunks.push(currentChunk);
currentChunk = chunk;
currentLength = chunk.tokens.length;
} else {
currentChunk.text += separator + chunk.text;
currentChunk.endPos = chunk.endPos;
currentChunk.tokens.push(...chunk.tokens);
currentLength += chunk.tokens.length;
}
} else {
currentChunk = chunk;
currentLength = chunk.tokens.length;
if (currentChunk) {
combinedChunks.push(currentChunk);
return combinedChunks;
/**
 * Returns true when the text contains at least one ASCII letter or digit.
 * Used by recursiveSplit to discard chunks made only of separators/punctuation.
 * @param text Text to inspect
 * @returns True if any character is alphanumeric
 */
private containsAlphanumeric(text: string): boolean {
    // Scan until the first alphanumeric character; bail out early when found.
    for (const ch of text) {
        if (ALPHANUMERIC_CHARS.includes(ch)) {
            return true;
        }
    }
    return false;
}
private splitBySpaces(text: string): string[] {
// Split text by tokens and return parts
const parts: string[] = [];
let tokens = this._config.tokenizer.encode(text);
do {
if (tokens.length <= this._config.chunkSize) {
parts.push(this._config.tokenizer.decode(tokens));
break;
} else {
const span = tokens.splice(0, this._config.chunkSize);
parts.push(this._config.tokenizer.decode(span));
} while (true);
return parts;
}
private getSeparators(docType?: string): string[] {
switch (docType ?? '') {
case "cpp":
return [
// Split along class definitions
"\nclass ",
// Split along function definitions
"\nvoid ",
"\nint ",
"\nfloat ",
"\ndouble ",
// Split along control flow statements
"\nif ",
"\nfor ",
"\nwhile ",
"\nswitch ",
"\ncase ",
// Split by the normal type of lines
"\n\n",
"\n",
];
case "go":
return [
// Split along function definitions
"\nfunc ",
"\nvar ",
"\nconst ",
"\ntype ",
// Split along control flow statements
"\nif ",
"\nfor ",
"\nswitch ",
"\ncase ",
// Split by the normal type of lines
"\n\n",
"\n",
];
case "java":
case "c#":
case "csharp":
case "cs":
case "ts":
case "tsx":
case "typescript":
return [
// split along regions
"// LLM-REGION",
"/* LLM-REGION",
"/** LLM-REGION",
// Split along class definitions
"\nclass ",
// Split along method definitions
"\npublic ",
"\nprotected ",
"\nprivate ",
"\nstatic ",
// Split along control flow statements
"\nif ",
"\nfor ",
"\nwhile ",
"\nswitch ",
"\ncase ",
// Split by the normal type of lines
"\n\n",
"\n",
" "
];
case "js":
case "jsx":
case "javascript":
return [
// split along regions
"// LLM-REGION",
"/* LLM-REGION",
"/** LLM-REGION",
// Split along class definitions
"\nclass ",
// Split along function definitions
"\nfunction ",
"\nconst ",
"\nlet ",
"\nvar ",
"\nclass ",
// Split along control flow statements
"\nif ",
"\nfor ",
"\nwhile ",
"\nswitch ",
"\ncase ",
"\ndefault ",
// Split by the normal type of lines
"\n\n",
"\n",
];
case "php":
return [
// Split along function definitions
"\nfunction ",
// Split along class definitions
"\nclass ",
// Split along control flow statements
"\nif ",
"\nforeach ",
"\nwhile ",
"\ndo ",
"\nswitch ",
"\ncase ",
// Split by the normal type of lines
"\n\n",
"\n",
];
case "proto":
return [
// Split along message definitions
"\nmessage ",
// Split along service definitions
"\nservice ",
// Split along enum definitions
"\nenum ",
// Split along option definitions
"\noption ",
// Split along import statements
"\nimport ",
// Split along syntax declarations
"\nsyntax ",
// Split by the normal type of lines
"\n\n",
"\n",
];
case "python":
case "py":
return [
// First, try to split along class definitions
"\nclass ",
"\ndef ",
"\n\tdef ",
// Now split by the normal type of lines
"\n\n",
"\n",
];
case "rst":
return [
// Split along section titles
"\n===\n",
"\n---\n",
"\n*\n",
// Split along directive markers
"\n.. ",
// Split by the normal type of lines
"\n\n",
"\n",
];
case "ruby":
return [
// Split along method definitions
"\ndef ",
"\nclass ",
// Split along control flow statements
"\nif ",
"\nunless ",
"\nwhile ",
"\nfor ",
"\ndo ",
"\nbegin ",
"\nrescue ",
// Split by the normal type of lines
"\n\n",
"\n",
];
case "rust":
return [
// Split along function definitions
"\nfn ",
"\nconst ",
"\nlet ",
// Split along control flow statements
"\nif ",
"\nwhile ",
"\nfor ",
"\nloop ",
"\nmatch ",
"\nconst ",
// Split by the normal type of lines
"\n\n",
"\n",
];
case "scala":
return [
// Split along class definitions
"\nclass ",
"\nobject ",
// Split along method definitions
"\ndef ",
"\nval ",
"\nvar ",
// Split along control flow statements
"\nif ",
"\nfor ",
"\nwhile ",
"\nmatch ",
"\ncase ",
// Split by the normal type of lines
"\n\n",
"\n",
];
case "swift":
return [
// Split along function definitions
"\nfunc ",
// Split along class definitions
"\nclass ",
"\nstruct ",
"\nenum ",
// Split along control flow statements
"\nif ",
"\nfor ",
"\nwhile ",
"\ndo ",
"\nswitch ",
"\ncase ",
// Split by the normal type of lines
"\n\n",
"\n",
];
case "md":
case "markdown":
return [
// First, try to split along Markdown headings (starting with level 2)
"\n## ",
"\n### ",
"\n#### ",
"\n##### ",
"\n###### ",
// Note the alternative syntax for headings (below) is not handled here
// Heading level 2
// End of code block
"``\n\n",
// Horizontal lines
"\n\n*\n\n",
"\n\n---\n\n",
"\n\n___\n\n",
// Note that this splitter doesn't handle horizontal lines defined
// by three or more of *, ---, or ___, but this is not handled
// Github tables
"",
// "",
// "",
// " "\n\n",
"\n",
];
case "latex":
return [
// First, try to split along Latex sections
"\n\\chapter{",
"\n\\section{",
"\n\\subsection{",
"\n\\subsubsection{",
// Now split by environments
"\n\\begin{enumerate}",
"\n\\begin{itemize}",
"\n\\begin{description}",
"\n\\begin{list}",
"\n\\begin{quote}",
"\n\\begin{quotation}",
"\n\\begin{verse}",
"\n\\begin{verbatim}",
// Now split by math environments
"\n\\begin{align}",
// Now split by the normal type of lines
"\n\n",
"\n",
];
case "html":
return [
// First, try to split along HTML tags
"",
"",
"
"
",
"
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"
"