diff --git a/clients/ts-sdk/src/__tests__/constants.ts b/clients/ts-sdk/src/__tests__/constants.ts index 19c0e0c272..4097f086b4 100644 --- a/clients/ts-sdk/src/__tests__/constants.ts +++ b/clients/ts-sdk/src/__tests__/constants.ts @@ -8,11 +8,11 @@ export const TRIEVE = new TrieveSDK({ organizationId: "de73679c-707f-4fc2-853e-994c910d944c", }); -// export const TRIEVELOCAL = new TrieveSDK({ +// export const TRIEVE = new TrieveSDK({ // baseUrl: "http://localhost:8090", -// organizationId: "f8bd8fc0-0f66-48d3-92be-a2e5f74c7952", -// datasetId: "506effe8-c460-4ec0-8576-6dd6ef1a57f0", -// apiKey: "tr-x5I18LghGCL22qXYQS6BrlJbqAOPI7Tz", +// organizationId: "967d4740-d8f0-4f3a-8a62-3c1297e5f6c4", +// datasetId: "88fb2a53-17bd-4311-9763-051dc5c9c476", +// apiKey: "tr-5OiU6tPsjgcMz0AeujPbKlBJFqeXVJ9G", // }); export const EXAMPLE_TOPIC_ID = "f85984e1-7818-4971-b300-2f462fe1a5a2"; diff --git a/clients/ts-sdk/src/functions/organization/organization.test.ts b/clients/ts-sdk/src/functions/organization/organization.test.ts index e5d69ab5c2..7a7bf0fba5 100644 --- a/clients/ts-sdk/src/functions/organization/organization.test.ts +++ b/clients/ts-sdk/src/functions/organization/organization.test.ts @@ -3,132 +3,141 @@ import { TRIEVE } from "../../__tests__/constants"; import { TrieveSDK } from "../../sdk"; import { CreateApiKeyResponse, ReturnQueuedChunk } from "../../types.gen"; -describe("Organization Tests", async () => { - let trieve: TrieveSDK; - beforeAll(() => { - trieve = TRIEVE; - }); - - test("create an api key and verify it works", async () => { - const apiKeyResponse = await trieve.createOrganizationApiKey({ - role: 1, - name: "test suite key", +describe( + "Organization Tests", + async () => { + let trieve: TrieveSDK; + beforeAll(() => { + trieve = TRIEVE; }); - expectTypeOf(apiKeyResponse).toEqualTypeOf(); + test("create an api key and verify it works", async () => { + const apiKeyResponse = await trieve.createOrganizationApiKey({ + role: 1, + name: "test suite key", + }); - const newTrieve = new TrieveSDK({ - apiKey: apiKeyResponse.api_key, - datasetId: trieve.datasetId, - }); + expectTypeOf(apiKeyResponse).toEqualTypeOf(); - const queuedChunk = await newTrieve.createChunk({ - chunk_html: "testing hello world", - tracking_id: "1234", - tag_set: ["test"], - }); + const newTrieve = new TrieveSDK({ + apiKey: apiKeyResponse.api_key, + datasetId: trieve.datasetId, + baseUrl: trieve.trieve.baseUrl, + }); - expectTypeOf(queuedChunk).toEqualTypeOf(); + const queuedChunk = await newTrieve.createChunk({ + chunk_html: "testing hello world", + tracking_id: "1234", + tag_set: ["test"], + }); - newTrieve.deleteChunkByTrackingId({ - trackingId: "1234", - }); - }); - - test("create an expired api key and verify it does not work", async () => { - const apiKeyResponse = await trieve.createOrganizationApiKey({ - expires_at: new Date(new Date().setDate(new Date().getDate() - 1)) - .toISOString() - .slice(0, 19) - .replace("T", " "), - role: 1, - name: "test suite key", + expectTypeOf(queuedChunk).toEqualTypeOf(); + + newTrieve.deleteChunkByTrackingId({ + trackingId: "1234", + }); }); - expectTypeOf(apiKeyResponse).toEqualTypeOf(); + test("create an expired api key and verify it does not work", async () => { + const apiKeyResponse = await trieve.createOrganizationApiKey({ + expires_at: new Date(new Date().setDate(new Date().getDate() - 1)) + .toISOString() + .slice(0, 19) + .replace("T", " "), + role: 1, + name: "test suite key", + }); + + expectTypeOf(apiKeyResponse).toEqualTypeOf(); + console.log(apiKeyResponse); + + let errorOccurred = false; - let errorOccurred = false; + const newTrieve = new TrieveSDK({ + apiKey: apiKeyResponse.api_key, + datasetId: trieve.datasetId, + baseUrl: trieve.trieve.baseUrl, + }); - const newTrieve = new TrieveSDK({ - apiKey: apiKeyResponse.api_key, - datasetId: trieve.datasetId, + try { + await newTrieve.createChunk({ + chunk_html: "testing hello world", + tracking_id: "should_never_work", + tag_set: ["test"], + }); + + newTrieve.deleteChunkByTrackingId({ + trackingId: "should_never_work", + }); + console.log("should not have worked"); + } catch (e) { + errorOccurred = true; + } + + expect(errorOccurred).toBe(true); }); - try { - await newTrieve.createChunk({ - chunk_html: "testing hello world", - tracking_id: "should_never_work", - tag_set: ["test"], + test("create an api key with a filter for test and verify it excludes chunks without the tag", async () => { + const apiKeyResponse = await trieve.createOrganizationApiKey({ + role: 1, + name: "test suite key", + default_params: { + filters: { + must: [ + { + field: "tag_set", + match_all: ["test"], + }, + ], + }, + }, }); - newTrieve.deleteChunkByTrackingId({ - trackingId: "should_never_work", + expectTypeOf(apiKeyResponse).toEqualTypeOf(); + + const newTrieve = new TrieveSDK({ + apiKey: apiKeyResponse.api_key, + datasetId: trieve.datasetId, + baseUrl: trieve.trieve.baseUrl, }); - } catch (e) { - errorOccurred = true; - } - - expect(errorOccurred).toBe(true); - }); - - test("create an api key with a filter for test and verify it excludes chunks without the tag", async () => { - const apiKeyResponse = await trieve.createOrganizationApiKey({ - role: 1, - name: "test suite key", - default_params: { + + const queuedChunks = await newTrieve.createChunk([ + { + chunk_html: "testing hello world", + tracking_id: "not_test", + tag_set: ["not_test"], + }, + { + chunk_html: "testing hello world", + tracking_id: "test", + tag_set: ["test"], + }, + ]); + + expectTypeOf(queuedChunks).toEqualTypeOf(); + + await new Promise((r) => setTimeout(r, 10000)); + + const chunksResp = await newTrieve.scroll({ + page_size: 100, filters: { must: [ { field: "tag_set", - match_all: ["test"], + match_all: ["not_test"], }, ], }, - }, - }); - - expectTypeOf(apiKeyResponse).toEqualTypeOf(); - - const newTrieve = new TrieveSDK({ - apiKey: apiKeyResponse.api_key, - datasetId: trieve.datasetId, - }); - - const queuedChunks = await newTrieve.createChunk([ - { - chunk_html: "testing hello world", - tracking_id: "not_test", - tag_set: ["not_test"], - }, - { - chunk_html: "testing hello world", - tracking_id: "test", - tag_set: ["test"], - }, - ]); - - expectTypeOf(queuedChunks).toEqualTypeOf(); - - await new Promise((r) => setTimeout(r, 10000)); - - const chunksResp = await newTrieve.scroll({ - page_size: 100, - filters: { - must: [ - { - field: "tag_set", - match_all: ["not_test"], - }, - ], - }, - }); + }); - for (const chunk of chunksResp.chunks) { - expect(chunk.tag_set).toContain("test"); - } + for (const chunk of chunksResp.chunks) { + expect(chunk.tag_set).toContain("test"); + } - newTrieve.deleteChunkByTrackingId({ - trackingId: "1234", + newTrieve.deleteChunkByTrackingId({ + trackingId: "1234", + }); }); - }); -}); + }, + { timeout: 100000 }, +); diff --git a/server/migrations/2024-11-28-041452_add_on_delete_cascade_for_org_api_key/down.sql b/server/migrations/2024-11-28-041452_add_on_delete_cascade_for_org_api_key/down.sql new file mode 100644 index 0000000000..fed8379b96 --- /dev/null +++ b/server/migrations/2024-11-28-041452_add_on_delete_cascade_for_org_api_key/down.sql @@ -0,0 +1,5 @@ +-- This file should undo anything in `up.sql` +ALTER TABLE organization_api_key + DROP CONSTRAINT organization_api_key_organization_id_fkey, + ADD CONSTRAINT organization_api_key_organization_id_fkey + FOREIGN KEY (organization_id) REFERENCES organizations(id); \ No newline at end of file diff --git a/server/migrations/2024-11-28-041452_add_on_delete_cascade_for_org_api_key/up.sql b/server/migrations/2024-11-28-041452_add_on_delete_cascade_for_org_api_key/up.sql new file mode 100644 index 0000000000..bc828ee3fd --- /dev/null +++ b/server/migrations/2024-11-28-041452_add_on_delete_cascade_for_org_api_key/up.sql @@ -0,0 +1,5 @@ +-- Your SQL goes here +ALTER TABLE organization_api_key + DROP CONSTRAINT organization_api_key_organization_id_fkey, + ADD CONSTRAINT organization_api_key_organization_id_fkey + FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; \ No newline at end of file diff --git a/server/src/data/models.rs b/server/src/data/models.rs index fc832edaf2..79ffdb8a84 100644 --- a/server/src/data/models.rs +++ b/server/src/data/models.rs @@ -6661,6 +6661,8 @@ pub struct SortOptions { pub sort_by: Option, /// Location lets you rank your results by distance from a location. If not specified, this has no effect. Bias allows you to determine how much of an effect the location of chunks will have on the search results. If not specified, this defaults to 0.0. We recommend setting this to 1.0 for a gentle reranking of the results, >3.0 for a strong reranking of the results. pub location_bias: Option, + /// Recency Bias lets you determine how much of an effect the recency of chunks will have on the search results. If not specified, this defaults to 0.0. We recommend setting this to 1.0 for a gentle reranking of the results, >3.0 for a strong reranking of the results. + pub recency_bias: Option, /// Set use_weights to true to use the weights of the chunks in the result set in order to sort them. If not specified, this defaults to true. pub use_weights: Option, /// Tag weights is a JSON object which can be used to boost the ranking of chunks with certain tags. This is useful for when you want to be able to bias towards chunks with a certain tag on the fly. The keys are the tag names and the values are the weights. diff --git a/server/src/operators/organization_operator.rs b/server/src/operators/organization_operator.rs index 68ecf8739b..59a74787c2 100644 --- a/server/src/operators/organization_operator.rs +++ b/server/src/operators/organization_operator.rs @@ -777,6 +777,11 @@ pub async fn get_assumed_user_by_organization_api_key( let api_key: OrganizationApiKey = organization_api_key_columns::organization_api_key .filter(organization_api_key_columns::api_key_hash.eq(hash_function(api_key))) + .filter( + organization_api_key_columns::expires_at + .is_null() + .or(organization_api_key_columns::expires_at.ge(diesel::dsl::now.nullable())), + ) .select(OrganizationApiKey::as_select()) .first::(&mut conn) .await diff --git a/server/src/operators/search_operator.rs b/server/src/operators/search_operator.rs index 30c36525d1..3a27a3dd40 100644 --- a/server/src/operators/search_operator.rs +++ b/server/src/operators/search_operator.rs @@ -17,9 +17,9 @@ use super::typo_operator::correct_query; use crate::data::models::{ convert_to_date_time, ChunkGroup, ChunkGroupAndFileId, ChunkMetadata, ChunkMetadataStringTagSet, ChunkMetadataTypes, ConditionType, ContentChunkMetadata, Dataset, - DatasetConfiguration, GeoInfoWithBias, HasIDCondition, QdrantChunkMetadata, QdrantSortBy, - QueryTypes, ReRankOptions, RedisPool, ScoreChunk, ScoreChunkDTO, SearchMethod, - SlimChunkMetadata, SortByField, SortBySearchType, UnifiedId, + DatasetConfiguration, HasIDCondition, QdrantChunkMetadata, QdrantSortBy, QueryTypes, + ReRankOptions, RedisPool, ScoreChunk, ScoreChunkDTO, SearchMethod, SlimChunkMetadata, + SortByField, SortBySearchType, SortOptions, UnifiedId, }; use crate::handlers::chunk_handler::{ AutocompleteReqPayload, ChunkFilter, CountChunkQueryResponseBody, CountChunksReqPayload, @@ -1494,12 +1494,16 @@ pub async fn retrieve_chunks_from_point_ids( pub fn rerank_chunks( chunks: Vec, - tag_weights: Option>, - use_weights: Option, - query_location: Option, + sort_options: Option, ) -> Vec { let mut reranked_chunks = Vec::new(); - if use_weights.unwrap_or(true) { + if sort_options.is_none() { + return chunks; + } + + let sort_options = sort_options.unwrap(); + + if sort_options.use_weights.unwrap_or(true) { chunks.into_iter().for_each(|mut chunk| { if chunk.metadata[0].metadata().weight == 0.0 { chunk.score *= 1.0; @@ -1512,8 +1516,54 @@ pub fn rerank_chunks( reranked_chunks = chunks; } - if query_location.is_some() && query_location.unwrap().bias > 0.0 { - let info_with_bias = query_location.unwrap(); + if sort_options.recency_bias.is_some() && sort_options.recency_bias.unwrap() > 0.0 { + let recency_weight = sort_options.recency_bias.unwrap(); + let min_timestamp = reranked_chunks + .iter() + .filter_map(|chunk| chunk.metadata[0].metadata().time_stamp) + .min(); + let max_timestamp = reranked_chunks + .iter() + .filter_map(|chunk| chunk.metadata[0].metadata().time_stamp) + .max(); + let max_score = reranked_chunks + .iter() + .map(|chunk| chunk.score) + .max_by(|a, b| a.partial_cmp(b).unwrap()); + let min_score = reranked_chunks + .iter() + .map(|chunk| chunk.score) + .min_by(|a, b| a.partial_cmp(b).unwrap()); + + if let (Some(min), Some(max)) = (min_timestamp, max_timestamp) { + let min_duration = chrono::Utc::now().signed_duration_since(min.and_utc()); + let max_duration = chrono::Utc::now().signed_duration_since(max.and_utc()); + + reranked_chunks = reranked_chunks + .iter_mut() + .map(|chunk| { + if let Some(time_stamp) = chunk.metadata[0].metadata().time_stamp { + let duration = + chrono::Utc::now().signed_duration_since(time_stamp.and_utc()); + let normalized_recency_score = (duration.num_seconds() as f32 + - min_duration.num_seconds() as f32) + / (max_duration.num_seconds() as f32 + - min_duration.num_seconds() as f32); + + let normalized_chunk_score = (chunk.score - min_score.unwrap_or(0.0)) + / (max_score.unwrap_or(1.0) - min_score.unwrap_or(0.0)); + + chunk.score = (normalized_chunk_score * (1.0 / recency_weight) as f64) + + (recency_weight * normalized_recency_score) as f64 + } + chunk.clone() + }) + .collect::>(); + } + } + + if sort_options.location_bias.is_some() && sort_options.location_bias.unwrap().bias > 0.0 { + let info_with_bias = sort_options.location_bias.unwrap(); let query_location = info_with_bias.location; let location_bias = info_with_bias.bias; let distances = reranked_chunks @@ -1550,7 +1600,7 @@ pub fn rerank_chunks( .collect::>(); } - if let Some(tag_weights) = tag_weights { + if let Some(tag_weights) = sort_options.tag_weights { reranked_chunks = reranked_chunks .iter_mut() .map(|chunk| { @@ -1580,12 +1630,15 @@ pub fn rerank_chunks( pub fn rerank_groups( groups: Vec, - tag_weights: Option>, - use_weights: Option, - query_location: Option, + sort_options: Option, ) -> Vec { let mut reranked_groups = Vec::new(); - if use_weights.unwrap_or(true) { + if sort_options.is_none() { + return groups; + } + + let sort_options = sort_options.unwrap(); + if sort_options.use_weights.unwrap_or(true) { groups.into_iter().for_each(|mut group| { let first_chunk = group.metadata.get_mut(0).unwrap(); if first_chunk.metadata[0].metadata().weight == 0.0 { @@ -1599,8 +1652,55 @@ pub fn rerank_groups( reranked_groups = groups; } - if query_location.is_some() && query_location.unwrap().bias > 0.0 { - let info_with_bias = query_location.unwrap(); + if sort_options.recency_bias.is_some() && sort_options.recency_bias.unwrap() > 0.0 { + let recency_weight = sort_options.recency_bias.unwrap(); + let min_timestamp = reranked_groups + .iter() + .filter_map(|group| group.metadata[0].metadata[0].metadata().time_stamp) + .min(); + let max_timestamp = reranked_groups + .iter() + .filter_map(|group| group.metadata[0].metadata[0].metadata().time_stamp) + .max(); + let max_score = reranked_groups + .iter() + .map(|group| group.metadata[0].score) + .max_by(|a, b| a.partial_cmp(b).unwrap()); + let min_score = reranked_groups + .iter() + .map(|group| group.metadata[0].score) + .min_by(|a, b| a.partial_cmp(b).unwrap()); + + if let (Some(min), Some(max)) = (min_timestamp, max_timestamp) { + let min_duration = chrono::Utc::now().signed_duration_since(min.and_utc()); + let max_duration = chrono::Utc::now().signed_duration_since(max.and_utc()); + + reranked_groups = reranked_groups + .iter_mut() + .map(|group| { + let first_chunk = group.metadata.get_mut(0).unwrap(); + if let Some(time_stamp) = first_chunk.metadata[0].metadata().time_stamp { + let duration = + chrono::Utc::now().signed_duration_since(time_stamp.and_utc()); + let normalized_recency_score = (duration.num_seconds() as f32 + - min_duration.num_seconds() as f32) + / (max_duration.num_seconds() as f32 + - min_duration.num_seconds() as f32); + + let normalized_chunk_score = (first_chunk.score - min_score.unwrap_or(0.0)) + / (max_score.unwrap_or(1.0) - min_score.unwrap_or(0.0)); + + first_chunk.score = (normalized_chunk_score * (1.0 / recency_weight) as f64) + + (recency_weight * normalized_recency_score) as f64 + } + group.clone() + }) + .collect::>(); + } + } + + if sort_options.location_bias.is_some() && sort_options.location_bias.unwrap().bias > 0.0 { + let info_with_bias = sort_options.location_bias.unwrap(); let query_location = info_with_bias.location; let location_bias = info_with_bias.bias; let distances = reranked_groups @@ -1638,7 +1738,7 @@ pub fn rerank_groups( .collect::>(); } - if let Some(tag_weights) = tag_weights { + if let Some(tag_weights) = sort_options.tag_weights { reranked_groups = reranked_groups .iter_mut() .map(|group| { @@ -1917,21 +2017,7 @@ pub async fn search_chunks_query( result_chunks.score_chunks }; - result_chunks.score_chunks = rerank_chunks( - rerank_chunks_input, - data.sort_options - .as_ref() - .map(|d| d.tag_weights.clone()) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.use_weights) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.location_bias) - .unwrap_or_default(), - ); + result_chunks.score_chunks = rerank_chunks(rerank_chunks_input, data.sort_options); timer.add("reranking"); @@ -2075,21 +2161,7 @@ pub async fn search_hybrid_chunks( cross_encoder_results.retain(|chunk| chunk.score >= score_threshold.into()); } - rerank_chunks( - cross_encoder_results, - data.sort_options - .as_ref() - .map(|d| d.tag_weights.clone()) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.use_weights) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.location_bias) - .unwrap_or_default(), - ) + rerank_chunks(cross_encoder_results, data.sort_options) }; reranked_chunks.truncate(data.page_size.unwrap_or(10) as usize); @@ -2241,21 +2313,7 @@ pub async fn search_groups_query( result_chunks.score_chunks }; - result_chunks.score_chunks = rerank_chunks( - rerank_chunks_input, - data.sort_options - .as_ref() - .map(|d| d.tag_weights.clone()) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.use_weights) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.location_bias) - .unwrap_or_default(), - ); + result_chunks.score_chunks = rerank_chunks(rerank_chunks_input, data.sort_options); Ok(SearchWithinGroupResults { bookmarks: result_chunks.score_chunks, @@ -2396,21 +2454,7 @@ pub async fn search_hybrid_groups( config, ) .await?; - let score_chunks = rerank_chunks( - cross_encoder_results, - data.sort_options - .as_ref() - .map(|d| d.tag_weights.clone()) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.use_weights) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.location_bias) - .unwrap_or_default(), - ); + let score_chunks = rerank_chunks(cross_encoder_results, data.sort_options); score_chunks .iter() @@ -2426,21 +2470,7 @@ pub async fn search_hybrid_groups( ) .await?; - rerank_chunks( - cross_encoder_results, - data.sort_options - .as_ref() - .map(|d| d.tag_weights.clone()) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.use_weights) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.location_bias) - .unwrap_or_default(), - ) + rerank_chunks(cross_encoder_results, data.sort_options) }; if let Some(score_threshold) = data.score_threshold { @@ -2565,21 +2595,7 @@ pub async fn search_over_groups_query( timer.add("fetched from postgres"); - result_chunks.group_chunks = rerank_groups( - result_chunks.group_chunks, - data.sort_options - .as_ref() - .map(|d| d.tag_weights.clone()) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.use_weights) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.location_bias) - .unwrap_or_default(), - ); + result_chunks.group_chunks = rerank_groups(result_chunks.group_chunks, data.sort_options); result_chunks.corrected_query = corrected_query.map(|c| c.query); @@ -2791,21 +2807,7 @@ pub async fn hybrid_search_over_groups( }); } - reranked_chunks = rerank_groups( - reranked_chunks, - data.sort_options - .as_ref() - .map(|d| d.tag_weights.clone()) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.use_weights) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.location_bias) - .unwrap_or_default(), - ); + reranked_chunks = rerank_groups(reranked_chunks, data.sort_options); let result_chunks = DeprecatedSearchOverGroupsResponseBody { group_chunks: reranked_chunks, @@ -2940,36 +2942,8 @@ pub async fn autocomplete_chunks_query( (result_chunks.score_chunks.as_slice(), empty_vec) }; - let mut reranked_chunks = rerank_chunks( - before_increase.to_vec(), - data.sort_options - .as_ref() - .map(|d| d.tag_weights.clone()) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.use_weights) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.location_bias) - .unwrap_or_default(), - ); - reranked_chunks.extend(rerank_chunks( - after_increase.to_vec(), - data.sort_options - .as_ref() - .map(|d| d.tag_weights.clone()) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.use_weights) - .unwrap_or_default(), - data.sort_options - .as_ref() - .map(|d| d.location_bias) - .unwrap_or_default(), - )); + let mut reranked_chunks = rerank_chunks(before_increase.to_vec(), data.sort_options.clone()); + reranked_chunks.extend(rerank_chunks(after_increase.to_vec(), data.sort_options)); result_chunks.score_chunks = reranked_chunks;