Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 3 additions & 35 deletions crates/k_means/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,9 @@ use rabitq::bit::block::BlockCode;
use rabitq::packing::{any_pack, padding_pack};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rayon::iter::{IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use simd::Floating;

/// Applies `f` to every element of `x` in parallel.
///
/// A rayon thread pool configured with `num_threads` is built in scoped mode,
/// which is what lets the job borrow `x` and `f` instead of requiring
/// `'static` data. The call returns once the job closure has finished.
///
/// # Panics
///
/// Panics with "failed to build thread pool" if the pool cannot be built.
pub fn preprocess<T: Send>(num_threads: usize, x: &mut [T], f: impl Fn(&mut T) + Sync) {
    let builder = rayon::ThreadPoolBuilder::new().num_threads(num_threads);
    let outcome = builder.build_scoped(
        // Each worker thread simply runs its main loop.
        |worker| worker.run(),
        // The job is installed into the pool so `par_iter_mut` uses it.
        move |pool| pool.install(|| x.par_iter_mut().for_each(&f)),
    );
    outcome.expect("failed to build thread pool")
}

pub fn k_means(
num_threads: usize,
mut check: impl FnMut(usize),
Expand All @@ -57,7 +43,7 @@ pub fn k_means(
if n >= 1024 && c >= 1024 {
rabitq_index(n, c, samples, centroids)
} else {
flat_index(dims, n, c, samples, centroids)
flat_index(n, c, samples, centroids)
}
};
let mut lloyd_k_means =
Expand All @@ -75,18 +61,6 @@ pub fn k_means(
}
}

/// Returns the index of the centroid for which `f32::reduce_sum_of_d2(vector, centroid)`
/// is smallest (presumably the squared L2 distance; confirm against the `simd` crate).
///
/// The comparison is `<=`, so when several centroids tie the one with the
/// highest index wins. A NaN value never satisfies the comparison; if every
/// value is NaN the result is `0`.
///
/// # Panics
///
/// Panics if `centroids` is empty.
pub fn k_means_lookup(vector: &[f32], centroids: &[Vec<f32>]) -> usize {
    assert_ne!(centroids.len(), 0);
    let mut best_index = 0;
    let mut best_distance = f32::INFINITY;
    for (index, centroid) in centroids.iter().enumerate() {
        let distance = f32::reduce_sum_of_d2(vector, centroid);
        // `<=` rather than `<`: later centroids replace earlier ones on a tie.
        if distance <= best_distance {
            best_distance = distance;
            best_index = index;
        }
    }
    best_index
}

fn quick_centers(
c: usize,
dims: usize,
Expand Down Expand Up @@ -175,13 +149,7 @@ fn rabitq_index(n: usize, c: usize, samples: &[Vec<f32>], centroids: &[Vec<f32>]
.collect::<Vec<_>>()
}

fn flat_index(
_dims: usize,
n: usize,
c: usize,
samples: &[Vec<f32>],
centroids: &[Vec<f32>],
) -> Vec<usize> {
fn flat_index(n: usize, c: usize, samples: &[Vec<f32>], centroids: &[Vec<f32>]) -> Vec<usize> {
(0..n)
.into_par_iter()
.map(|i| {
Expand Down
12 changes: 3 additions & 9 deletions crates/rabitq/src/bit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,9 @@ pub fn code(vector: &[f32]) -> Code {
CodeMetadata {
dis_u_2: sum_of_x_2,
factor_cnt: {
let cnt_pos = vector
.iter()
.map(|x| x.is_sign_positive() as i32)
.sum::<i32>();
let cnt_neg = vector
.iter()
.map(|x| x.is_sign_negative() as i32)
.sum::<i32>();
(cnt_pos - cnt_neg) as f32
let cnt_pos = vector.iter().filter(|x| x.is_sign_positive()).count();
let cnt_neg = vector.iter().filter(|x| x.is_sign_negative()).count();
cnt_pos as f32 - cnt_neg as f32
},
factor_ip: sum_of_x_2 / sum_of_abs_x,
factor_err: {
Expand Down
21 changes: 11 additions & 10 deletions crates/rabitq/src/packing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,15 @@ pub fn any_pack<T: Default>(mut x: impl Iterator<Item = T>) -> [T; 32] {
std::array::from_fn(|_| x.next()).map(|x| x.unwrap_or_default())
}

pub fn pack_to_u4(signs: &[bool]) -> Vec<u8> {
fn f(x: [bool; 4]) -> u8 {
x[0] as u8 | (x[1] as u8) << 1 | (x[2] as u8) << 2 | (x[3] as u8) << 3
}
let mut result = Vec::with_capacity(signs.len().div_ceil(4));
for i in 0..signs.len().div_ceil(4) {
let x = std::array::from_fn(|j| signs.get(i * 4 + j).copied().unwrap_or_default());
result.push(f(x));
}
result
/// Packs booleans into 4-bit values, one per output byte.
///
/// Every group of four consecutive flags becomes one `u8` whose bit `k`
/// (counting from the least significant bit) is set when the `k`-th flag of
/// the group is `true`; the upper four bits are always zero. A final group
/// with fewer than four flags is treated as if padded with `false`. An empty
/// input yields an empty vector.
pub fn pack_to_u4(input: &[bool]) -> Vec<u8> {
    input
        .chunks(4)
        .map(|group| {
            // A short trailing group contributes no bits for the missing
            // flags, which is the same as padding it with `false`.
            group
                .iter()
                .enumerate()
                .fold(0u8, |nibble, (bit, &flag)| nibble | (u8::from(flag) << bit))
        })
        .collect()
}
2 changes: 1 addition & 1 deletion crates/vchordrq/src/bulkdelete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub fn bulkdelete<R: RelationRead + RelationWrite, O: Operator>(
for first in state {
tape::read_h1_tape::<R, _, _>(
by_next(index, first).inspect(|_| check()),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); 32])),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); _])),
|(), _, _, first, _| results.push(first),
);
}
Expand Down
2 changes: 1 addition & 1 deletion crates/vchordrq/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ where
for first in state {
tape::read_h1_tape::<R, _, _>(
by_next(index, first).inspect(|guard| trace.push(guard.id())),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); 32])),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); _])),
|(), _, _, first, _| {
results.push(first);
},
Expand Down
20 changes: 9 additions & 11 deletions crates/vchordrq/src/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,25 +77,23 @@ pub fn insert<'b, R: RelationRead + RelationWrite, O: Operator>(
let epsilon = 1.9;

type State = (Reverse<Distance>, AlwaysEqual<f32>, AlwaysEqual<u32>);
let mut state: State = if !is_residual {
let first = meta_tuple.first();
// it's safe to leave it a fake value
(
Reverse(Distance::ZERO),
AlwaysEqual(0.0),
AlwaysEqual(first),
)
} else {
let mut state: State = if is_residual {
let prefetch =
BorrowedIter::from_slice(meta_tuple.centroid_prefetch(), |x| bump.alloc_slice(x));
let head = meta_tuple.centroid_head();
let norm = meta_tuple.centroid_norm();
let first = meta_tuple.first();
let distance = vectors::read_for_h1_tuple::<R, O, _>(
prefetch.map(|id| index.read(id)),
head,
LAccess::new(O::Vector::unpack(vector), O::DistanceAccessor::default()),
);
let norm = meta_tuple.centroid_norm();
let first = meta_tuple.first();
(Reverse(distance), AlwaysEqual(norm), AlwaysEqual(first))
} else {
// fast path
let distance = Distance::ZERO;
let norm = meta_tuple.centroid_norm();
let first = meta_tuple.first();
(Reverse(distance), AlwaysEqual(norm), AlwaysEqual(first))
};

Expand Down
2 changes: 1 addition & 1 deletion crates/vchordrq/src/maintain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ where
for first in state {
tape::read_h1_tape::<R, _, _>(
by_next(index, first).inspect(|_| check()),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); 32])),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); _])),
|(), _, _, first, _| results.push(first),
);
}
Expand Down
6 changes: 3 additions & 3 deletions crates/vchordrq/src/prewarm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ where
for first in state {
tape::read_h1_tape::<R, _, _>(
by_next(index, first).inspect(|_| counter += 1),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); 32])),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); _])),
|(), head, _, first, prefetch| {
vectors::read_for_h1_tuple::<R, O, _>(
prefetch.iter().map(|&id| index.read(id)),
Expand Down Expand Up @@ -97,15 +97,15 @@ where
tape::read_directory_tape::<R>(by_next(index, jump_tuple.directory_first()));
tape::read_frozen_tape::<R, _, _>(
by_directory(&mut prefetch_h0_tuples, directory).inspect(|_| counter += 1),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); 32])),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); _])),
id_2(|_, _, _, _| {
results.push(());
}),
);
} else {
tape::read_frozen_tape::<R, _, _>(
by_next(index, jump_tuple.frozen_first()).inspect(|_| counter += 1),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); 32])),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); _])),
id_2(|_, _, _, _| {
results.push(());
}),
Expand Down
95 changes: 68 additions & 27 deletions crates/vchordrq/src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
//
// Copyright (c) 2025 TensorChord Inc.

use crate::closure_lifetime_binder::id_2;
use crate::closure_lifetime_binder::{id_0, id_1, id_2};
use crate::linked_vec::LinkedVec;
use crate::operator::*;
use crate::tape::{by_directory, by_next};
use crate::tuples::*;
use crate::{Opaque, Page, tape, vectors};
use algo::accessor::LAccess;
use algo::accessor::{FunctionalAccessor, LAccess};
use algo::prefetcher::{Prefetcher, PrefetcherHeapFamily, PrefetcherSequenceFamily};
use algo::{BorrowedIter, Bump, PackedRefMut4, PackedRefMut8, RelationRead};
use always_equal::AlwaysEqual;
Expand Down Expand Up @@ -51,6 +51,7 @@ where
let dims = meta_tuple.dims();
let is_residual = meta_tuple.is_residual();
let height_of_root = meta_tuple.height_of_root();
let cells = meta_tuple.cells().to_vec();
assert_eq!(dims, vector.dims(), "unmatched dimensions");
if height_of_root as usize != 1 + probes.len() {
panic!(
Expand All @@ -59,27 +60,26 @@ where
probes.len()
);
}
debug_assert_eq!(cells[(height_of_root - 1) as usize], 1);

type State = Vec<(Reverse<Distance>, AlwaysEqual<f32>, AlwaysEqual<u32>)>;
let mut state: State = if !is_residual {
let first = meta_tuple.first();
// it's safe to leave it a fake value
vec![(
Reverse(Distance::ZERO),
AlwaysEqual(0.0),
AlwaysEqual(first),
)]
} else {
let mut state: State = if is_residual {
let prefetch =
BorrowedIter::from_slice(meta_tuple.centroid_prefetch(), |x| bump.alloc_slice(x));
let head = meta_tuple.centroid_head();
let norm = meta_tuple.centroid_norm();
let first = meta_tuple.first();
let distance = vectors::read_for_h1_tuple::<R, O, _>(
prefetch.map(|id| index.read(id)),
head,
LAccess::new(O::Vector::unpack(vector), O::DistanceAccessor::default()),
);
let norm = meta_tuple.centroid_norm();
let first = meta_tuple.first();
vec![(Reverse(distance), AlwaysEqual(norm), AlwaysEqual(first))]
} else {
// fast path
let distance = Distance::ZERO;
let norm = meta_tuple.centroid_norm();
let first = meta_tuple.first();
vec![(Reverse(distance), AlwaysEqual(norm), AlwaysEqual(first))]
};

Expand Down Expand Up @@ -124,7 +124,27 @@ where
};

for i in 1..height_of_root {
state = step(state).take(probes[i as usize - 1] as _).collect();
let partial_scan = probes[i as usize - 1] < cells[(height_of_root - 1 - i) as usize];
if partial_scan || is_residual {
state = step(state).take(probes[i as usize - 1] as _).collect();
} else {
// fast path
let mut results = LinkedVec::new();
for (Reverse(_), AlwaysEqual(_), AlwaysEqual(first)) in state {
tape::read_h1_tape::<R, _, _>(
by_next(index, first),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); _])),
|(), _, norm, first, _| {
results.push((
Reverse(Distance::ZERO),
AlwaysEqual(norm),
AlwaysEqual(first),
));
},
);
}
state = results.into_vec();
}
}

let mut results = LinkedVec::<(_, AlwaysEqual<_>)>::new();
Expand Down Expand Up @@ -192,6 +212,7 @@ where
let dims = meta_tuple.dims();
let is_residual = meta_tuple.is_residual();
let height_of_root = meta_tuple.height_of_root();
let cells = meta_tuple.cells().to_vec();
assert_eq!(dims, vector.dims(), "unmatched dimensions");
if height_of_root as usize != 1 + probes.len() {
panic!(
Expand All @@ -200,27 +221,26 @@ where
probes.len()
);
}
debug_assert_eq!(cells[(height_of_root - 1) as usize], 1);

type State = Vec<(Reverse<Distance>, AlwaysEqual<f32>, AlwaysEqual<u32>)>;
let mut state: State = if !is_residual {
let first = meta_tuple.first();
// it's safe to leave it a fake value
vec![(
Reverse(Distance::ZERO),
AlwaysEqual(0.0),
AlwaysEqual(first),
)]
} else {
let mut state: State = if is_residual {
let prefetch =
BorrowedIter::from_slice(meta_tuple.centroid_prefetch(), |x| bump.alloc_slice(x));
let head = meta_tuple.centroid_head();
let norm = meta_tuple.centroid_norm();
let first = meta_tuple.first();
let distance = vectors::read_for_h1_tuple::<R, O, _>(
prefetch.map(|id| index.read(id)),
head,
LAccess::new(O::Vector::unpack(vector), O::DistanceAccessor::default()),
);
let norm = meta_tuple.centroid_norm();
let first = meta_tuple.first();
vec![(Reverse(distance), AlwaysEqual(norm), AlwaysEqual(first))]
} else {
// fast path
let distance = Distance::ZERO;
let norm = meta_tuple.centroid_norm();
let first = meta_tuple.first();
vec![(Reverse(distance), AlwaysEqual(norm), AlwaysEqual(first))]
};

Expand Down Expand Up @@ -266,8 +286,29 @@ where

let mut it = None;
for i in 1..height_of_root {
let it = it.insert(step(state));
state = it.take(probes[i as usize - 1] as _).collect();
let partial_scan = probes[i as usize - 1] < cells[(height_of_root - 1 - i) as usize];
let needs_sort = i + 1 == height_of_root && threshold != 0;
if partial_scan || is_residual || needs_sort {
let it = it.insert(step(state));
state = it.take(probes[i as usize - 1] as _).collect();
} else {
// fast path
let mut results = LinkedVec::new();
for (Reverse(_), AlwaysEqual(_), AlwaysEqual(first)) in state {
tape::read_h1_tape::<R, _, _>(
by_next(index, first),
|| FunctionalAccessor::new((), id_0(|_, _| ()), id_1(|_, _| [(); _])),
|(), _, norm, first, _| {
results.push((
Reverse(Distance::ZERO),
AlwaysEqual(norm),
AlwaysEqual(first),
));
},
);
}
state = results.into_vec();
}
}

let mut results = LinkedVec::<(_, AlwaysEqual<_>)>::new();
Expand Down
2 changes: 1 addition & 1 deletion crates/vchordrq/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl VectorKind {
}
}

#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Validate)]
#[serde(deny_unknown_fields)]
#[validate(schema(function = "Self::validate_self"))]
pub struct VectorOptions {
Expand Down
Loading
Loading