diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index bb3247ca3c3c..770aca1f7cd9 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -221,12 +221,34 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), ) => true, (Struct(from_fields), Struct(to_fields)) => { - from_fields.len() == to_fields.len() - && from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| { + if from_fields.len() != to_fields.len() { + return false; + } + + // fast path, all field names are in the same order and same number of fields + if from_fields + .iter() + .zip(to_fields.iter()) + .all(|(f1, f2)| f1.name() == f2.name()) + { + return from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| { // Assume that nullability between two structs are compatible, if not, // cast kernel will return error. can_cast_types(f1.data_type(), f2.data_type()) - }) + }); + } + + // slow path, we match the fields by name + to_fields.iter().all(|to_field| { + from_fields + .iter() + .find(|from_field| from_field.name() == to_field.name()) + .is_some_and(|from_field| { + // Assume that nullability between two structs are compatible, if not, + // cast kernel will return error. + can_cast_types(from_field.data_type(), to_field.data_type()) + }) + }) } (Struct(_), _) => false, (_, Struct(_)) => false, @@ -1169,14 +1191,46 @@ pub fn cast_with_options( cast_options, ) } - (Struct(_), Struct(to_fields)) => { + (Struct(from_fields), Struct(to_fields)) => { let array = array.as_struct(); - let fields = array - .columns() - .iter() - .zip(to_fields.iter()) - .map(|(l, field)| cast_with_options(l, field.data_type(), cast_options)) - .collect::, ArrowError>>()?; + + // Fast path: if field names are in the same order, we can just zip and cast + let fields_match_order = from_fields.len() == to_fields.len() + && from_fields + .iter() + .zip(to_fields.iter()) + .all(|(f1, f2)| f1.name() == f2.name()); + + let fields = if fields_match_order { + // Fast path: cast columns in order + array + .columns() + .iter() + .zip(to_fields.iter()) + .map(|(column, field)| { + cast_with_options(column, field.data_type(), cast_options) + }) + .collect::, ArrowError>>()? + } else { + // Slow path: match fields by name and reorder + to_fields + .iter() + .map(|to_field| { + let from_field_idx = from_fields + .iter() + .position(|from_field| from_field.name() == to_field.name()) + .ok_or_else(|| { + ArrowError::CastError(format!( + "Field '{}' not found in source struct", + to_field.name() + )) + })?; + let column = array.column(from_field_idx); + cast_with_options(column, to_field.data_type(), cast_options) + }) + .collect::, ArrowError>>()? + }; + let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?; Ok(Arc::new(array) as ArrayRef) } @@ -10836,11 +10890,11 @@ mod tests { let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); let struct_array = StructArray::from(vec![ ( - Arc::new(Field::new("b", DataType::Boolean, false)), + Arc::new(Field::new("a", DataType::Boolean, false)), boolean.clone() as ArrayRef, ), ( - Arc::new(Field::new("c", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, false)), int.clone() as ArrayRef, ), ]); @@ -10884,11 +10938,11 @@ mod tests { let int = Arc::new(Int32Array::from(vec![Some(42), None, Some(19), None])); let struct_array = StructArray::from(vec![ ( - Arc::new(Field::new("b", DataType::Boolean, false)), + Arc::new(Field::new("a", DataType::Boolean, false)), boolean.clone() as ArrayRef, ), ( - Arc::new(Field::new("c", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), int.clone() as ArrayRef, ), ]); @@ -10918,11 +10972,11 @@ mod tests { let int = Arc::new(Int32Array::from(vec![i32::MAX, 25, 1, 100])); let struct_array = StructArray::from(vec![ ( - Arc::new(Field::new("b", DataType::Boolean, false)), + Arc::new(Field::new("a", DataType::Boolean, false)), boolean.clone() as ArrayRef, ), ( - Arc::new(Field::new("c", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, false)), int.clone() as ArrayRef, ), ]); @@ -10977,6 +11031,165 @@ mod tests { ); } + #[test] + fn test_cast_struct_with_different_field_order() { + // Test slow path: fields are in different order + let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); + let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); + let string = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "qux"])); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("a", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("b", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Utf8, false)), + string.clone() as ArrayRef, + ), + ]); + + // Target has fields in different order: c, a, b instead of a, b, c + let to_type = DataType::Struct( + vec![ + Field::new("c", DataType::Utf8, false), + Field::new("a", DataType::Utf8, false), // Boolean to Utf8 + Field::new("b", DataType::Utf8, false), // Int32 to Utf8 + ] + .into(), + ); + + let result = cast(&struct_array, &to_type).unwrap(); + let result_struct = result.as_struct(); + + assert_eq!(result_struct.data_type(), &to_type); + assert_eq!(result_struct.num_columns(), 3); + + // Verify field "c" (originally position 2, now position 0) remains Utf8 + let c_column = result_struct.column(0).as_string::(); + assert_eq!( + c_column.into_iter().flatten().collect::>(), + vec!["foo", "bar", "baz", "qux"] + ); + + // Verify field "a" (originally position 0, now position 1) was cast from Boolean to Utf8 + let a_column = result_struct.column(1).as_string::(); + assert_eq!( + a_column.into_iter().flatten().collect::>(), + vec!["false", "false", "true", "true"] + ); + + // Verify field "b" (originally position 1, now position 2) was cast from Int32 to Utf8 + let b_column = result_struct.column(2).as_string::(); + assert_eq!( + b_column.into_iter().flatten().collect::>(), + vec!["42", "28", "19", "31"] + ); + } + + #[test] + fn test_cast_struct_with_missing_field() { + // Test that casting fails when target has a field not present in source + let boolean = Arc::new(BooleanArray::from(vec![false, true])); + let struct_array = StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + )]); + + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), // Field "b" doesn't exist in source + ] + .into(), + ); + + let result = cast(&struct_array, &to_type); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "Cast error: Field 'b' not found in source struct" + ); + } + + #[test] + fn test_cast_struct_with_subset_of_fields() { + // Test casting to a struct with fewer fields (selecting a subset) + let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); + let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); + let string = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "qux"])); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("a", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("b", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Utf8, false)), + string.clone() as ArrayRef, + ), + ]); + + // Target has only fields "c" and "a", omitting "b" + let to_type = DataType::Struct( + vec![ + Field::new("c", DataType::Utf8, false), + Field::new("a", DataType::Utf8, false), + ] + .into(), + ); + + let result = cast(&struct_array, &to_type).unwrap(); + let result_struct = result.as_struct(); + + assert_eq!(result_struct.data_type(), &to_type); + assert_eq!(result_struct.num_columns(), 2); + + // Verify field "c" remains Utf8 + let c_column = result_struct.column(0).as_string::(); + assert_eq!( + c_column.into_iter().flatten().collect::>(), + vec!["foo", "bar", "baz", "qux"] + ); + + // Verify field "a" was cast from Boolean to Utf8 + let a_column = result_struct.column(1).as_string::(); + assert_eq!( + a_column.into_iter().flatten().collect::>(), + vec!["false", "false", "true", "true"] + ); + } + + #[test] + fn test_can_cast_struct_with_missing_field() { + // Test that can_cast_types returns false when target has a field not in source + let from_type = DataType::Struct( + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ] + .into(), + ); + + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Int64, false), + Field::new("c", DataType::Boolean, false), // Field "c" not in source + ] + .into(), + ); + + assert!(!can_cast_types(&from_type, &to_type)); + } + fn run_decimal_cast_test_case_between_multiple_types(t: DecimalCastTestConfig) { run_decimal_cast_test_case::(t.clone()); run_decimal_cast_test_case::(t.clone());