diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs index c47731ed5125..5e60791d1564 100644 --- a/arrow-arith/src/numeric.rs +++ b/arrow-arith/src/numeric.rs @@ -738,6 +738,8 @@ fn decimal_op( // Follow the Hive decimal arithmetic rules // https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + // And the Calcite rules + // https://github.com/apache/arrow/blob/36ddbb531cac9b9e512dfa3776d1d64db588209f/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java#L46 let array: PrimitiveArray = match op { Op::Add | Op::AddWrapping | Op::Sub | Op::SubWrapping => { // max(s1, s2) @@ -794,9 +796,8 @@ fn decimal_op( } Op::Div => { - // Follow postgres and MySQL adding a fixed scale increment of 4 - // s1 + 4 - let result_scale = s1.saturating_add(4).min(T::MAX_SCALE); + // max(6, s1 + p2 + 1) + let result_scale = 6.max(s1 + *p2 as i8 + 1).min(T::MAX_SCALE); let mul_pow = result_scale - s1 + s2; // p1 - s1 + s2 + result_scale @@ -1126,10 +1127,17 @@ mod tests { ); let result = div(&a, &b).unwrap(); - assert_eq!(result.data_type(), &DataType::Decimal128(17, 7)); + assert_eq!(result.data_type(), &DataType::Decimal128(26, 16)); assert_eq!( result.as_primitive::().values(), - &[27777, 0, 162078, 11133333, -1300000, 402] + &[ + 27777777777777, + 0, + 162078651685393, + 11133333333333333, + -1300000000000000, + 402684563758 + ] ); let result = rem(&a, &b).unwrap();