Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 94 additions & 28 deletions vortex-datafusion/src/convert/exprs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,11 @@ pub struct ProcessedProjection {
pub(crate) fn make_vortex_predicate(
expr_convertor: &dyn ExpressionConvertor,
predicate: &[Arc<dyn PhysicalExpr>],
input_schema: &Schema,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we replace the current convert?

) -> DFResult<Option<Expression>> {
let exprs = predicate
.iter()
.map(|e| expr_convertor.convert(e.as_ref()))
.map(|e| expr_convertor.convert_with_schema(e.as_ref(), input_schema))
.collect::<DFResult<Vec<_>>>()?;

Ok(and_collect(exprs))
Expand All @@ -70,6 +71,17 @@ pub trait ExpressionConvertor: Send + Sync {
/// Try and convert a DataFusion [`PhysicalExpr`] into a Vortex [`Expression`].
fn convert(&self, expr: &dyn PhysicalExpr) -> DFResult<Expression>;

/// Try and convert a DataFusion [`PhysicalExpr`] into a Vortex [`Expression`],
/// using the input schema when DataFusion expression metadata affects semantics.
fn convert_with_schema(
&self,
expr: &dyn PhysicalExpr,
input_schema: &Schema,
) -> DFResult<Expression> {
let _ = input_schema;
self.convert(expr)
}

/// Split a projection into Vortex expressions that can be pushed down and leftover
/// DataFusion projections that need to be evaluated after the scan.
fn split_projection(
Expand Down Expand Up @@ -112,7 +124,11 @@ pub struct DefaultExpressionConvertor {}

impl DefaultExpressionConvertor {
/// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
fn try_convert_scalar_function(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
fn try_convert_scalar_function(
&self,
scalar_fn: &ScalarFunctionExpr,
input_schema: Option<&Schema>,
) -> DFResult<Expression> {
if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
{
// DataFusion's GetFieldFunc flattens nested field access into a single call
Expand All @@ -124,7 +140,7 @@ impl DefaultExpressionConvertor {
.split_first()
.ok_or_else(|| exec_datafusion_err!("get_field missing source expression"))?;

let mut result = self.convert(source_expr.as_ref())?;
let mut result = self.convert_impl(source_expr.as_ref(), input_schema)?;
for expr in field_names {
let field_name = expr
.downcast_ref::<df_expr::Literal>()
Expand All @@ -147,7 +163,17 @@ impl DefaultExpressionConvertor {
}

/// Attempts to convert a DataFusion CaseExpr to a Vortex expression.
#[cfg(test)]
fn try_convert_case_expr(&self, case_expr: &df_expr::CaseExpr) -> DFResult<Expression> {
self.try_convert_case_expr_with_schema(case_expr, None)
}

/// Attempts to convert a DataFusion CaseExpr to a Vortex expression.
fn try_convert_case_expr_with_schema(
&self,
case_expr: &df_expr::CaseExpr,
input_schema: Option<&Schema>,
) -> DFResult<Expression> {
// DataFusion CaseExpr has:
// - expr(): Optional base expression (for "CASE expr WHEN ..." form)
// - when_then_expr(): Vec of (when, then) pairs
Expand All @@ -170,33 +196,31 @@ impl DefaultExpressionConvertor {
// Convert all when/then pairs to (condition, value) tuples
let mut pairs = Vec::with_capacity(when_then_pairs.len());
for (when_expr, then_expr) in when_then_pairs {
let condition = self.convert(when_expr.as_ref())?;
let value = self.convert(then_expr.as_ref())?;
let condition = self.convert_impl(when_expr.as_ref(), input_schema)?;
let value = self.convert_impl(then_expr.as_ref(), input_schema)?;
pairs.push((condition, value));
}

// Convert optional else expression
let else_value = case_expr
.else_expr()
.map(|e| self.convert(e.as_ref()))
.map(|e| self.convert_impl(e.as_ref(), input_schema))
.transpose()?;

// Build a single n-ary CASE WHEN expression from DataFusion WHEN/THEN pairs
Ok(nested_case_when(pairs, else_value))
}
}

impl ExpressionConvertor for DefaultExpressionConvertor {
fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
can_be_pushed_down_impl(expr, schema)
}

fn convert(&self, df: &dyn PhysicalExpr) -> DFResult<Expression> {
fn convert_impl(
&self,
df: &dyn PhysicalExpr,
input_schema: Option<&Schema>,
) -> DFResult<Expression> {
// TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep
// for that node, up to any `and` or `or` node.
if let Some(binary_expr) = df.downcast_ref::<df_expr::BinaryExpr>() {
let left = self.convert(binary_expr.left().as_ref())?;
let right = self.convert(binary_expr.right().as_ref())?;
let left = self.convert_impl(binary_expr.left().as_ref(), input_schema)?;
let right = self.convert_impl(binary_expr.right().as_ref(), input_schema)?;
let operator = try_operator_from_df(binary_expr.op())?;

return Ok(Binary.new_expr(operator, [left, right]));
Expand All @@ -207,8 +231,8 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
}

if let Some(like) = df.downcast_ref::<df_expr::LikeExpr>() {
let child = self.convert(like.expr().as_ref())?;
let pattern = self.convert(like.pattern().as_ref())?;
let child = self.convert_impl(like.expr().as_ref(), input_schema)?;
let pattern = self.convert_impl(like.pattern().as_ref(), input_schema)?;
return Ok(Like.new_expr(
LikeOptions {
negated: like.negated(),
Expand All @@ -224,23 +248,26 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
}

if let Some(cast_expr) = df.downcast_ref::<df_expr::CastExpr>() {
let cast_dtype = DType::from_arrow(cast_expr.target_field().as_ref());
let child = self.convert(cast_expr.expr().as_ref())?;
let mut cast_dtype = DType::from_arrow(cast_expr.target_field().as_ref());
if let Some(input_schema) = input_schema {
cast_dtype = cast_dtype.with_nullability(cast_expr.nullable(input_schema)?.into());

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inside CastExpr::nullable there's this comment:

A cast is nullable if either the child is nullable or the
target field allows nulls. This conservative rule prevents
optimizers from assuming a non-null result when a null input could
still propagate. return_field() continues to expose the exact
target metadata separately.

So I think this is correct? Did you run into it with some real world example or just a random AI thing?

}
let child = self.convert_impl(cast_expr.expr().as_ref(), input_schema)?;
return Ok(cast(child, cast_dtype));
}

if let Some(is_null_expr) = df.downcast_ref::<df_expr::IsNullExpr>() {
let arg = self.convert(is_null_expr.arg().as_ref())?;
let arg = self.convert_impl(is_null_expr.arg().as_ref(), input_schema)?;
return Ok(is_null(arg));
}

if let Some(is_not_null_expr) = df.downcast_ref::<df_expr::IsNotNullExpr>() {
let arg = self.convert(is_not_null_expr.arg().as_ref())?;
let arg = self.convert_impl(is_not_null_expr.arg().as_ref(), input_schema)?;
return Ok(is_not_null(arg));
}

if let Some(in_list) = df.downcast_ref::<df_expr::InListExpr>() {
let value = self.convert(in_list.expr().as_ref())?;
let value = self.convert_impl(in_list.expr().as_ref(), input_schema)?;
let list_elements: Vec<_> = in_list
.list()
.iter()
Expand All @@ -264,17 +291,35 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
}

if let Some(scalar_fn) = df.downcast_ref::<ScalarFunctionExpr>() {
return self.try_convert_scalar_function(scalar_fn);
return self.try_convert_scalar_function(scalar_fn, input_schema);
}

if let Some(case_expr) = df.downcast_ref::<df_expr::CaseExpr>() {
return self.try_convert_case_expr(case_expr);
return self.try_convert_case_expr_with_schema(case_expr, input_schema);
}

Err(exec_datafusion_err!(
"Couldn't convert DataFusion physical {df} expression to a vortex expression"
))
}
}

impl ExpressionConvertor for DefaultExpressionConvertor {
fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
can_be_pushed_down_impl(expr, schema)
}

fn convert(&self, df: &dyn PhysicalExpr) -> DFResult<Expression> {
self.convert_impl(df, None)
}

fn convert_with_schema(
&self,
df: &dyn PhysicalExpr,
input_schema: &Schema,
) -> DFResult<Expression> {
self.convert_impl(df, Some(input_schema))
}

fn split_projection(
&self,
Expand Down Expand Up @@ -325,7 +370,7 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
if matches!(r, TreeNodeRecursion::Continue) {
scan_projection.push((
projection_expr.alias.clone(),
self.convert(projection_expr.expr.as_ref())?,
self.convert_with_schema(projection_expr.expr.as_ref(), input_schema)?,
));
leftover_projection.push(ProjectionExpr {
expr: Arc::new(df_expr::Column::new_with_schema(
Expand Down Expand Up @@ -558,6 +603,7 @@ mod tests {
use datafusion_physical_plan::expressions as df_expr;
use insta::assert_snapshot;
use rstest::rstest;
use vortex::scalar_fn::fns::cast::Cast;

use super::*;
use crate::common_tests::TestSessionContext;
Expand Down Expand Up @@ -585,24 +631,30 @@ mod tests {
#[test]
fn test_make_vortex_predicate_empty() {
let expr_convertor = DefaultExpressionConvertor::default();
let result = make_vortex_predicate(&expr_convertor, &[]).unwrap();
let schema = Schema::empty();
let result = make_vortex_predicate(&expr_convertor, &[], &schema).unwrap();
assert!(result.is_none());
}

#[test]
fn test_make_vortex_predicate_single() {
let expr_convertor = DefaultExpressionConvertor::default();
let schema = Schema::new(vec![Field::new("test", DataType::Boolean, false)]);
let col_expr = Arc::new(df_expr::Column::new("test", 0)) as Arc<dyn PhysicalExpr>;
let result = make_vortex_predicate(&expr_convertor, &[col_expr]).unwrap();
let result = make_vortex_predicate(&expr_convertor, &[col_expr], &schema).unwrap();
assert!(result.is_some());
}

#[test]
fn test_make_vortex_predicate_multiple() {
let expr_convertor = DefaultExpressionConvertor::default();
let schema = Schema::new(vec![
Field::new("col1", DataType::Boolean, false),
Field::new("col2", DataType::Boolean, false),
]);
let col1 = Arc::new(df_expr::Column::new("col1", 0)) as Arc<dyn PhysicalExpr>;
let col2 = Arc::new(df_expr::Column::new("col2", 1)) as Arc<dyn PhysicalExpr>;
let result = make_vortex_predicate(&expr_convertor, &[col1, col2]).unwrap();
let result = make_vortex_predicate(&expr_convertor, &[col1, col2], &schema).unwrap();
assert!(result.is_some());
// Result should be an AND expression combining the two columns
}
Expand Down Expand Up @@ -667,6 +719,20 @@ mod tests {
assert_snapshot!(result.display_tree().to_string(), @"vortex.literal(42i32)");
}

#[test]
fn test_expr_from_df_cast_uses_physical_nullable_result() {
let input_schema = Schema::new(vec![Field::new("text", DataType::Utf8View, true)]);
let child = Arc::new(df_expr::Column::new("text", 0)) as Arc<dyn PhysicalExpr>;
let target = Arc::new(Field::new("text", DataType::Utf8View, false));
let cast_expr = df_expr::CastExpr::new_with_target_field(child, target, None);

let result = DefaultExpressionConvertor::default()
.convert_with_schema(&cast_expr, &input_schema)
.unwrap();

assert_eq!(result.as_::<Cast>(), &DType::Utf8(Nullability::Nullable));
}

#[test]
fn test_expr_from_df_binary() {
let left = Arc::new(df_expr::Column::new("left", 0)) as Arc<dyn PhysicalExpr>;
Expand Down
3 changes: 2 additions & 1 deletion vortex-datafusion/src/persistent/opener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,8 @@ impl FileOpener for VortexOpener {
)));
}

make_vortex_predicate(expr_convertor.as_ref(), &pushed).transpose()
make_vortex_predicate(expr_convertor.as_ref(), &pushed, &this_file_schema)
.transpose()
})
.transpose()?;

Expand Down
8 changes: 8 additions & 0 deletions vortex-datafusion/src/persistent/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,14 @@ mod tests {
self.inner.convert(expr)
}

fn convert_with_schema(
&self,
expr: &dyn PhysicalExpr,
input_schema: &Schema,
) -> DFResult<vortex::expr::Expression> {
self.inner.convert_with_schema(expr, input_schema)
}

fn split_projection(
&self,
source_projection: ProjectionExprs,
Expand Down
2 changes: 1 addition & 1 deletion vortex-datafusion/src/v2/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ impl DataSource for VortexDataSource {
.collect();

// Convert to Vortex conjunction.
let vortex_pred = make_vortex_predicate(&convertor, &pushable)?;
let vortex_pred = make_vortex_predicate(&convertor, &pushable, input_schema)?;

// Combine with existing filter.
let new_filter = match (&self.filter, vortex_pred) {
Expand Down
Loading