From d39cde35da45e05bce82562af1bd3a2e37d9b425 Mon Sep 17 00:00:00 2001 From: skeptrune Date: Tue, 3 Dec 2024 16:47:59 -0800 Subject: [PATCH] feature: MMR takes 2x page size such that final results may differ --- .../src/components/Layouts/MainLayout.tsx | 42 +++++ server/src/data/models.rs | 2 +- server/src/handlers/message_handler.rs | 4 +- server/src/operators/qdrant_operator.rs | 12 +- server/src/operators/search_operator.rs | 147 +++++++++++------- 5 files changed, 148 insertions(+), 59 deletions(-) diff --git a/frontends/chat/src/components/Layouts/MainLayout.tsx b/frontends/chat/src/components/Layouts/MainLayout.tsx index 9ffae64493..02f8e2c7f8 100644 --- a/frontends/chat/src/components/Layouts/MainLayout.tsx +++ b/frontends/chat/src/components/Layouts/MainLayout.tsx @@ -79,6 +79,8 @@ const MainLayout = (props: LayoutProps) => { >(null); const [useImages, setUseImages] = createSignal(null); + const [useMmr, setUseMmr] = createSignal(false); + const [mmrLambda, setMmrLambda] = createSignal(0.5); const [useGroupSearch, setUseGroupSearch] = createSignal( null, ); @@ -232,6 +234,12 @@ const MainLayout = (props: LayoutProps) => { use_images: useImages(), }, }, + sort_options: { + mmr: { + use_mmr: useMmr(), + mmr_lambda: mmrLambda(), + }, + }, no_result_message: noResultMessage(), use_group_search: useGroupSearch(), search_type: searchType(), @@ -344,6 +352,13 @@ const MainLayout = (props: LayoutProps) => { filters: getFiltersFromStorage(dataset.dataset.id), concat_user_messages_query: concatUserMessagesQuery(), page_size: pageSize(), + sort_options: { + mmr: { + use_mmr: useMmr(), + mmr_lambda: mmrLambda(), + }, + }, + use_group_search: useGroupSearch(), search_query: searchQuery() != "" ? searchQuery() : undefined, score_threshold: minScore(), @@ -514,6 +529,33 @@ const MainLayout = (props: LayoutProps) => { }} /> +
+
+ +
+ { + setUseMmr(e.target.checked); + }} + /> +
+
+ + { + setMmrLambda(parseFloat(e.target.value)); + }} + /> +
, /// Sort Options lets you specify different methods to rerank the chunks in the result set. If not specified, this defaults to the score of the chunks. pub sort_options: Option, - /// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata. - pub filters: Option, /// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0. pub score_threshold: Option, + /// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata. + pub filters: Option, /// LLM options to use for the completion. If not specified, this defaults to the dataset's LLM options. pub llm_options: Option, /// Context options to use for the completion. If not specified, all options will default to false. diff --git a/server/src/operators/qdrant_operator.rs b/server/src/operators/qdrant_operator.rs index 281230c39a..7fad25f900 100644 --- a/server/src/operators/qdrant_operator.rs +++ b/server/src/operators/qdrant_operator.rs @@ -873,7 +873,11 @@ pub async fn search_over_groups_qdrant_query( QueryPointGroups { collection_name: qdrant_collection.to_string(), - limit: Some(query.limit * page), + limit: if use_mmr && query.limit < 20 { + Some(query.limit * 2) + } else { + Some(query.limit * page) + }, prefetch, using: vector_name, query: Some(qdrant_query), @@ -1101,7 +1105,11 @@ pub async fn search_qdrant_query( QueryPoints { collection_name: qdrant_collection.to_string(), - limit: Some(query.limit), + limit: if use_mmr && query.limit < 20 { + Some(query.limit * 2) + } else { + Some(query.limit * page) + }, offset: Some(offset), prefetch, using: vector_name, diff --git a/server/src/operators/search_operator.rs b/server/src/operators/search_operator.rs index 59f8e973e5..44ae57ae34 100644 --- a/server/src/operators/search_operator.rs +++ b/server/src/operators/search_operator.rs @@ -45,6 +45,7 @@ use qdrant_client::qdrant::Filter; use qdrant_client::qdrant::{Condition, HasIdCondition, PointId}; use serde::{Deserialize, Serialize}; use simple_server_timing_header::Timer; +use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use utoipa::ToSchema; @@ -485,8 +486,13 @@ pub fn apply_mmr( let (first_idx_pos, &first_idx) = remaining_indices .iter() .enumerate() - .max_by(|(_, &a), (_, &b)| docs[a].score().partial_cmp(&docs[b].score()).unwrap()) - .unwrap(); + .max_by(|(_, &a), (_, &b)| { + docs[a] + .score() + .partial_cmp(&docs[b].score()) + .unwrap_or(Ordering::Equal) + }) + .unwrap_or((0, &0)); selected_indices.push(first_idx); remaining_indices.remove(first_idx_pos); @@ -502,10 +508,15 @@ pub fn apply_mmr( let max_similarity = selected_indices .iter() .map(|&sel_idx| { - cosine_similarity( - docs[idx].embedding().as_ref().unwrap().as_slice(), - docs[sel_idx].embedding().as_ref().unwrap().as_slice(), - ) + let idx_embedding = match docs[idx].embedding() { + Some(embedding) => embedding, + None => return 0.0, + }; + let sel_idx_embedding = match docs[sel_idx].embedding() { + Some(embedding) => embedding, + None => return 0.0, + }; + cosine_similarity(idx_embedding.as_slice(), sel_idx_embedding.as_slice()) }) .fold(f32::NEG_INFINITY, |a, b| a.max(b)); @@ -523,7 +534,7 @@ pub fn apply_mmr( selected_indices.push(remaining_indices[best_idx_pos]); remaining_indices.remove(best_idx_pos); } - log::info!("Selected indices: {:?}", selected_indices); + // Return document IDs in selection order selected_indices .iter() @@ -540,7 +551,7 @@ pub async fn retrieve_qdrant_points_query( ) -> Result { let page = if page == 0 { 1 } else { page }; - let use_mmr = mmr_options.is_some() && mmr_options.as_ref().unwrap().use_mmr; + let use_mmr = mmr_options.is_some_and(|mmr| mmr.use_mmr && mmr.mmr_lambda.unwrap_or(0.5) > 0.0); let (point_ids, count, batch_lengths) = search_qdrant_query( page, @@ -1090,7 +1101,7 @@ pub async fn retrieve_group_qdrant_points_query( config: &DatasetConfiguration, ) -> Result { let page = if page == 0 { 1 } else { page }; - let use_mmr = mmr_options.is_some() && mmr_options.as_ref().unwrap().use_mmr; + let use_mmr = mmr_options.is_some_and(|mmr| mmr.use_mmr && mmr.mmr_lambda.unwrap_or(0.5) > 0.0); let (point_ids, count) = search_over_groups_qdrant_query( page, qdrant_searches.clone(), @@ -1332,7 +1343,7 @@ pub async fn retrieve_chunks_for_groups( score: search_result.score.into(), }) }) - .sorted_by(|a, b| b.score.partial_cmp(&a.score).unwrap()) + .sorted_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)) .collect_vec(); let group_data = groups.iter().find(|group| group.id == group_search_result.group_id); @@ -1587,11 +1598,11 @@ pub fn rerank_chunks( sort_options: Option, ) -> Vec { let mut reranked_chunks = Vec::new(); - if sort_options.is_none() { - return chunks; - } - let sort_options = sort_options.unwrap(); + let sort_options = match sort_options { + Some(options) => options, + None => return chunks, + }; if sort_options.use_weights.unwrap_or(true) { chunks.into_iter().for_each(|mut chunk| { @@ -1606,8 +1617,8 @@ pub fn rerank_chunks( reranked_chunks = chunks; } - if sort_options.recency_bias.is_some() && sort_options.recency_bias.unwrap() > 0.0 { - let recency_weight = sort_options.recency_bias.unwrap(); + if sort_options.recency_bias.is_some_and(|r| r > 0.0) { + let recency_weight = sort_options.recency_bias.unwrap_or(0.0); let min_timestamp = reranked_chunks .iter() .filter_map(|chunk| chunk.metadata[0].metadata().time_stamp) @@ -1619,11 +1630,11 @@ pub fn rerank_chunks( let max_score = reranked_chunks .iter() .map(|chunk| chunk.score) - .max_by(|a, b| a.partial_cmp(b).unwrap()); + .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); let min_score = reranked_chunks .iter() .map(|chunk| chunk.score) - .min_by(|a, b| a.partial_cmp(b).unwrap()); + .min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); if let (Some(min), Some(max)) = (min_timestamp, max_timestamp) { let min_duration = chrono::Utc::now().signed_duration_since(min.and_utc()); @@ -1652,7 +1663,7 @@ pub fn rerank_chunks( } } - if sort_options.location_bias.is_some() && sort_options.location_bias.unwrap().bias > 0.0 { + if sort_options.location_bias.is_some_and(|g| g.bias > 0.0) { let info_with_bias = sort_options.location_bias.unwrap(); let query_location = info_with_bias.location; let location_bias = info_with_bias.bias; @@ -1660,16 +1671,20 @@ pub fn rerank_chunks( .iter() .filter_map(|chunk| chunk.metadata[0].metadata().location) .map(|location| query_location.haversine_distance_to(&location)); - let max_distance = distances.clone().max_by(|a, b| a.partial_cmp(b).unwrap()); - let min_distance = distances.clone().min_by(|a, b| a.partial_cmp(b).unwrap()); + let max_distance = distances + .clone() + .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); + let min_distance = distances + .clone() + .min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); let max_score = reranked_chunks .iter() .map(|chunk| chunk.score) - .max_by(|a, b| a.partial_cmp(b).unwrap()); + .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); let min_score = reranked_chunks .iter() .map(|chunk| chunk.score) - .min_by(|a, b| a.partial_cmp(b).unwrap()); + .min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); reranked_chunks = reranked_chunks .iter_mut() @@ -1710,24 +1725,31 @@ pub fn rerank_chunks( .collect::>(); } - if sort_options.mmr.is_some() - && sort_options - .mmr - .as_ref() - .map(|m| m.use_mmr) - .unwrap_or(false) + if sort_options + .mmr + .as_ref() + .is_some_and(|mmr| mmr.use_mmr && mmr.mmr_lambda.unwrap_or(0.5) > 0.0) { - let lambda = sort_options.mmr.unwrap().mmr_lambda.unwrap_or(0.3); + let lambda = sort_options.mmr.unwrap().mmr_lambda.unwrap_or(0.5); let max_result = search_results.len(); let reranked_results = apply_mmr(search_results, lambda, max_result); reranked_chunks = reranked_chunks .iter_mut() .map(|chunk| { - let search_result = reranked_results + let search_result = match reranked_results .iter() .find(|result| result.point_id == chunk.metadata[0].qdrant_point_id()) - .unwrap(); + { + Some(result) => result, + None => { + log::error!( + "Failed to find search result for qdrant_point_id for rerank_chunks mmr: {:?}", + chunk.metadata[0].qdrant_point_id() + ); + return chunk.clone(); + } + }; chunk.score = search_result.score.into(); chunk.clone() }) @@ -1749,11 +1771,12 @@ pub fn rerank_groups( sort_options: Option, ) -> Vec { let mut reranked_groups = Vec::new(); - if sort_options.is_none() { - return groups; - } - let sort_options = sort_options.unwrap(); + let sort_options = match sort_options { + Some(options) => options, + None => return groups, + }; + if sort_options.use_weights.unwrap_or(true) { groups.into_iter().for_each(|mut group| { let first_chunk = group.metadata.get_mut(0).unwrap(); @@ -1768,7 +1791,7 @@ pub fn rerank_groups( reranked_groups = groups; } - if sort_options.recency_bias.is_some() && sort_options.recency_bias.unwrap() > 0.0 { + if sort_options.recency_bias.is_some_and(|r| r > 0.0) { let recency_weight = sort_options.recency_bias.unwrap(); let min_timestamp = reranked_groups .iter() @@ -1781,11 +1804,11 @@ pub fn rerank_groups( let max_score = reranked_groups .iter() .map(|group| group.metadata[0].score) - .max_by(|a, b| a.partial_cmp(b).unwrap()); + .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); let min_score = reranked_groups .iter() .map(|group| group.metadata[0].score) - .min_by(|a, b| a.partial_cmp(b).unwrap()); + .min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); if let (Some(min), Some(max)) = (min_timestamp, max_timestamp) { let min_duration = chrono::Utc::now().signed_duration_since(min.and_utc()); @@ -1815,7 +1838,7 @@ pub fn rerank_groups( } } - if sort_options.location_bias.is_some() && sort_options.location_bias.unwrap().bias > 0.0 { + if sort_options.location_bias.is_some_and(|g| g.bias > 0.0) { let info_with_bias = sort_options.location_bias.unwrap(); let query_location = info_with_bias.location; let location_bias = info_with_bias.bias; @@ -1823,16 +1846,20 @@ pub fn rerank_groups( .iter() .filter_map(|group| group.metadata[0].metadata[0].metadata().location) .map(|location| query_location.haversine_distance_to(&location)); - let max_distance = distances.clone().max_by(|a, b| a.partial_cmp(b).unwrap()); - let min_distance = distances.clone().min_by(|a, b| a.partial_cmp(b).unwrap()); + let max_distance = distances + .clone() + .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); + let min_distance = distances + .clone() + .min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); let max_score = reranked_groups .iter() .map(|group| group.metadata[0].score) - .max_by(|a, b| a.partial_cmp(b).unwrap()); + .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); let min_score = reranked_groups .iter() .map(|group| group.metadata[0].score) - .min_by(|a, b| a.partial_cmp(b).unwrap()); + .min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); reranked_groups = reranked_groups .iter_mut() @@ -1875,14 +1902,12 @@ pub fn rerank_groups( .collect::>(); } - if sort_options.mmr.is_some() - && sort_options - .mmr - .as_ref() - .map(|m| m.use_mmr) - .unwrap_or(false) + if sort_options + .mmr + .as_ref() + .is_some_and(|mmr| mmr.use_mmr && mmr.mmr_lambda.unwrap_or(0.5) > 0.0) { - let lambda = sort_options.mmr.unwrap().mmr_lambda.unwrap_or(0.3); + let lambda = sort_options.mmr.unwrap().mmr_lambda.unwrap_or(0.5); let max_result = search_results.len(); let reranked_results = apply_mmr(search_results, lambda, max_result); @@ -2166,6 +2191,9 @@ pub async fn search_chunks_query( search_chunk_query_results.search_results, data.sort_options, ); + result_chunks + .score_chunks + .truncate(data.page_size.unwrap_or(10) as usize); timer.add("reranking"); @@ -2472,6 +2500,9 @@ pub async fn search_groups_query( search_semantic_chunk_query_results.search_results, data.sort_options, ); + result_chunks + .score_chunks + .truncate(data.page_size.unwrap_or(10) as usize); Ok(SearchWithinGroupResults { bookmarks: result_chunks.score_chunks, @@ -2613,11 +2644,12 @@ pub async fn search_hybrid_groups( config, ) .await?; - let score_chunks = rerank_chunks( + let mut score_chunks: Vec = rerank_chunks( cross_encoder_results, qdrant_results.search_results, data.sort_options, ); + score_chunks.truncate(data.page_size.unwrap_or(10) as usize); score_chunks .iter() @@ -2633,11 +2665,13 @@ pub async fn search_hybrid_groups( ) .await?; - rerank_chunks( + let mut score_chunks: Vec = rerank_chunks( cross_encoder_results, qdrant_results.search_results, data.sort_options, - ) + ); + score_chunks.truncate(data.page_size.unwrap_or(10) as usize); + score_chunks }; if let Some(score_threshold) = data.score_threshold { @@ -2768,6 +2802,9 @@ pub async fn search_over_groups_query( search_over_groups_qdrant_result.search_results, data.sort_options, ); + result_chunks + .group_chunks + .truncate(data.page_size.unwrap_or(10) as usize); result_chunks.corrected_query = corrected_query.map(|c| c.query); @@ -2985,6 +3022,7 @@ pub async fn hybrid_search_over_groups( qdrant_results.search_results, data.sort_options, ); + reranked_chunks.truncate(data.page_size.unwrap_or(10) as usize); let result_chunks = DeprecatedSearchOverGroupsResponseBody { group_chunks: reranked_chunks, @@ -3135,6 +3173,7 @@ pub async fn autocomplete_chunks_query( search_chunk_query_results.search_results, data.sort_options, )); + reranked_chunks.truncate(data.page_size.unwrap_or(10) as usize); result_chunks.score_chunks = reranked_chunks;