File GpuIndexCagra.h
-
namespace faiss
-
namespace gpu
Enums
-
enum class graph_build_algo
Values:
-
enumerator IVF_PQ
Use IVF-PQ to build all-neighbors knn graph.
-
enumerator NN_DESCENT
Use NN-Descent to build all-neighbors knn graph.
-
enum class codebook_gen
A type for specifying how PQ codebooks are created.
Values:
-
enumerator PER_SUBSPACE
-
enumerator PER_CLUSTER
-
struct IVFPQBuildCagraConfig
Public Members
-
uint32_t n_lists = 1024
The number of inverted lists (clusters).
Hint: the number of vectors per cluster (n_rows/n_lists) should be approximately 1,000 to 10,000.
-
uint32_t kmeans_n_iters = 20
The number of iterations searching for kmeans centers (index building).
-
double kmeans_trainset_fraction = 0.5
The fraction of data to use during iterative kmeans building.
-
uint32_t pq_bits = 8
The bit length of the vector element after compression by PQ.
Possible values: [4, 5, 6, 7, 8].
Hint: the smaller the ‘pq_bits’, the smaller the index size and the better the search performance, but the lower the recall.
-
uint32_t pq_dim = 0
The dimensionality of the vector after compression by PQ. When zero, an optimal value is selected using a heuristic.
NB: pq_dim * pq_bits must be a multiple of 8.
Hint: a smaller ‘pq_dim’ results in a smaller index size and better search performance, but lower recall. If ‘pq_bits’ is 8, ‘pq_dim’ can be set to any number, but multiples of 8 are desirable for good performance. If ‘pq_bits’ is not 8, ‘pq_dim’ should be a multiple of 8. For good performance, it is desirable that ‘pq_dim’ is a multiple of 32. Ideally, ‘pq_dim’ should also be a divisor of the dataset dim.
-
codebook_gen codebook_kind = codebook_gen::PER_SUBSPACE
How PQ codebooks are created.
-
bool force_random_rotation = false
Apply a random rotation matrix to the input data and queries even if dim % pq_dim == 0.
Note: if dim is not a multiple of pq_dim, a random rotation is always applied to the input data and queries to transform the working space from dim to rot_dim, which may be slightly larger than the original space and is a multiple of pq_dim (rot_dim % pq_dim == 0). However, this transform is not necessary when dim is a multiple of pq_dim (dim == rot_dim, hence no need to add “extra” data columns / features).
By default, if dim == rot_dim, the rotation transform is initialized with the identity matrix. When force_random_rotation == true, a random orthogonal transform matrix is generated regardless of the values of dim and pq_dim.
-
bool conservative_memory_allocation = false
By default, the algorithm allocates more space than necessary for individual clusters (list_data). This allows amortizing the cost of memory allocation and reduces the number of data copies during repeated calls to extend (extending the database).
The alternative is the conservative allocation behavior: when enabled, the algorithm always allocates the minimum amount of memory required to store the given number of records. Set this flag to true if you prefer to use as little GPU memory for the database as possible.
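As a rough illustration of how these build parameters fit together, here is a minimal sketch; the field names are the ones documented above, but the cluster count and PQ sizes are illustrative assumptions for a ~1M x 128-d dataset, not recommended values:

    #include <faiss/gpu/GpuIndexCagra.h>

    // Sketch: IVF-PQ build parameters for a hypothetical 1M x 128-d dataset.
    faiss::gpu::IVFPQBuildCagraConfig build_params;
    build_params.n_lists = 1000;                 // ~1,000 vectors per cluster for 1M rows
    build_params.kmeans_n_iters = 20;            // k-means iterations for the cluster centers
    build_params.kmeans_trainset_fraction = 0.5; // train k-means on half of the data
    build_params.pq_bits = 8;                    // bits per PQ code element
    build_params.pq_dim = 64;                    // multiple of 8 (and of 32), divides 128
    build_params.codebook_kind = faiss::gpu::codebook_gen::PER_SUBSPACE;
    build_params.conservative_memory_allocation = false;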
-
struct IVFPQSearchCagraConfig
Public Members
-
uint32_t n_probes = 20
The number of clusters to search.
-
cudaDataType_t lut_dtype = CUDA_R_32F
Data type of the lookup table to be created dynamically at search time.
Possible values: [CUDA_R_32F, CUDA_R_16F, CUDA_R_8U]
The use of low-precision types reduces the amount of shared memory required at search time, so fast shared memory kernels can be used even for datasets with large dimensionality. Note that the recall is slightly degraded when a low-precision type is selected.
-
cudaDataType_t internal_distance_dtype = CUDA_R_32F
Storage data type for distance/similarity computed at search time.
Possible values: [CUDA_R_16F, CUDA_R_32F]
If the performance limiter at search time is device memory access, selecting FP16 will improve performance slightly.
-
double preferred_shmem_carveout = 1.0
Preferred fraction of SM’s unified memory / L1 cache to be used as shared memory.
Possible values: [0.0 - 1.0] as a fraction of the sharedMemPerMultiprocessor.
One wants to increase the carveout to ensure good GPU occupancy for the main search kernel, but not to keep it too high so that some memory is left to be used as L1 cache. Note that this value is interpreted only as a hint. Moreover, a GPU usually allows only a fixed set of cache configurations, so the provided value is rounded up to the nearest configuration. Refer to the NVIDIA tuning guide for the target GPU architecture.
Note: this is a low-level tuning parameter that can have drastic negative effects on search performance if tweaked incorrectly.
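A companion sketch for the search-time side of the IVF-PQ graph build; the chosen values are illustrative assumptions, not tuned recommendations:

    #include <faiss/gpu/GpuIndexCagra.h>

    // Sketch: lower-precision types save shared memory at a small recall cost.
    faiss::gpu::IVFPQSearchCagraConfig ivfpq_search_params;
    ivfpq_search_params.n_probes = 30;                        // clusters probed per query
    ivfpq_search_params.lut_dtype = CUDA_R_16F;               // half-precision lookup table
    ivfpq_search_params.internal_distance_dtype = CUDA_R_16F; // half-precision distances
    ivfpq_search_params.preferred_shmem_carveout = 1.0;       // hint: favor shared memory over L1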
-
struct GpuIndexCagraConfig : public faiss::gpu::GpuIndexConfig
Public Members
-
size_t intermediate_graph_degree = 128
Degree of input graph for pruning.
-
size_t graph_degree = 64
Degree of output graph.
-
graph_build_algo build_algo = graph_build_algo::IVF_PQ
ANN algorithm to build knn graph.
-
size_t nn_descent_niter = 20
Number of iterations to run when building with NN_DESCENT.
-
IVFPQBuildCagraConfig *ivf_pq_params = nullptr
-
IVFPQSearchCagraConfig *ivf_pq_search_params = nullptr
-
float refine_rate = 2.0f
-
bool store_dataset = true
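Putting the pieces together, a hedged sketch of a complete build configuration; build_params and ivfpq_search_params are the objects from the sketches above, the graph degrees simply restate the defaults, and the device ordinal is an illustrative choice:

    #include <faiss/gpu/GpuIndexCagra.h>

    // Sketch: CAGRA build configuration using the IVF-PQ graph build path.
    faiss::gpu::GpuIndexCagraConfig config;
    config.device = 0;                      // GPU ordinal, inherited from GpuIndexConfig
    config.intermediate_graph_degree = 128; // input graph degree before pruning
    config.graph_degree = 64;               // final (pruned) graph degree
    config.build_algo = faiss::gpu::graph_build_algo::IVF_PQ;
    config.ivf_pq_params = &build_params;                // must outlive the index build
    config.ivf_pq_search_params = &ivfpq_search_params;  // likewise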
-
struct SearchParametersCagra : public faiss::SearchParameters
Public Members
-
size_t max_queries = 0
Maximum number of queries to search at the same time (batch size). Auto select when 0.
-
size_t itopk_size = 64
Number of intermediate search results retained during the search.
This is the main knob to adjust the trade-off between accuracy and search speed. Higher values improve search accuracy.
-
size_t max_iterations = 0
Upper limit of search iterations. Auto select when 0.
-
search_algo algo = search_algo::AUTO
Which search implementation to use.
-
size_t team_size = 0
Number of threads used to calculate a single distance. 4, 8, 16, or 32.
-
size_t search_width = 1
Number of graph nodes to select as the starting point for the search in each iteration (also known as the search width).
-
size_t min_iterations = 0
Lower limit of search iterations.
-
size_t thread_block_size = 0
Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0.
-
size_t hashmap_min_bitlen = 0
Lower limit of hashmap bit length. More than 8.
-
float hashmap_max_fill_rate = 0.5
Upper limit of hashmap fill rate. More than 0.1, less than 0.9.
-
uint32_t num_random_samplings = 1
Number of iterations of initial random seed node selection. 1 or more.
-
uint64_t seed = 0x128394
Bit mask used for initial random seed node selection.
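A short sketch of tuning these knobs and passing them through the regular search call; itopk_size = 128 and search_width = 4 are illustrative choices that trade speed for recall:

    #include <faiss/gpu/GpuIndexCagra.h>

    // Sketch: per-query search parameters for a GpuIndexCagra.
    faiss::gpu::SearchParametersCagra search_params;
    search_params.itopk_size = 128;   // retain more intermediate results than the default 64
    search_params.search_width = 4;   // expand from 4 graph nodes per iteration
    search_params.max_iterations = 0; // 0 = auto select

    // The struct is passed as the SearchParameters* argument of faiss::Index::search:
    // index.search(nq, queries, k, distances, labels, &search_params);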
-
struct GpuIndexCagra : public faiss::gpu::GpuIndex
Public Functions
-
GpuIndexCagra(GpuResourcesProvider *provider, int dims, faiss::MetricType metric = faiss::METRIC_L2, GpuIndexCagraConfig config = GpuIndexCagraConfig())
-
void copyFrom(const faiss::IndexHNSWCagra *index)
Initialize ourselves from the given CPU index; will overwrite all data in ourselves
-
void copyTo(faiss::IndexHNSWCagra *index) const
Copy ourselves to the given CPU index; will overwrite all data in the index instance
-
virtual void reset() override
removes all elements from the database.
Protected Functions
-
virtual bool addImplRequiresIDs_() const override
Does addImpl_ require IDs? If so, and no IDs are provided, we will generate them sequentially based on the order in which the IDs are added
-
virtual void addImpl_(idx_t n, const float *x, const idx_t *ids) override
Overridden to actually perform the add. All data is guaranteed to be resident on our device.
-
virtual void searchImpl_(idx_t n, const float *x, int k, float *distances, idx_t *labels, const SearchParameters *search_params) const override
Called from GpuIndex for search.
Protected Attributes
-
const GpuIndexCagraConfig cagraConfig_
Our configuration options.
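For orientation, a hedged end-to-end sketch: build the CAGRA graph on the GPU, search it, then copy it to a CPU faiss::IndexHNSWCagra with copyTo. It assumes the graph is built through the standard faiss::Index::train entry point and that IndexHNSWCagra takes (d, M) in its constructor; the random data and M = 32 are purely illustrative.

    #include <faiss/IndexHNSW.h>
    #include <faiss/gpu/GpuIndexCagra.h>
    #include <faiss/gpu/StandardGpuResources.h>
    #include <random>
    #include <vector>

    int main() {
        int d = 128;              // vector dimensionality (illustrative)
        faiss::idx_t nb = 100000; // database size (illustrative)

        // Random database vectors, purely for illustration.
        std::vector<float> xb(nb * d);
        std::mt19937 rng(123);
        std::uniform_real_distribution<float> dist(0.f, 1.f);
        for (auto& v : xb) v = dist(rng);

        faiss::gpu::StandardGpuResources res; // a GpuResourcesProvider
        faiss::gpu::GpuIndexCagraConfig config;
        config.graph_degree = 64;
        config.intermediate_graph_degree = 128;

        faiss::gpu::GpuIndexCagra index(&res, d, faiss::METRIC_L2, config);
        index.train(nb, xb.data()); // builds the CAGRA graph from these vectors

        // Search: 10 nearest neighbours for the first 5 database vectors.
        faiss::idx_t nq = 5, k = 10;
        std::vector<float> distances(nq * k);
        std::vector<faiss::idx_t> labels(nq * k);
        faiss::gpu::SearchParametersCagra params;
        params.itopk_size = 64;
        index.search(nq, xb.data(), k, distances.data(), labels.data(), &params);

        // Convert to a CPU HNSW-CAGRA index for serving without a GPU.
        faiss::IndexHNSWCagra cpu_index(d, 32); // M = 32 is an assumption
        index.copyTo(&cpu_index);
        return 0;
    }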