
I'm trying to implement a fast primality test for Rust's u32 and u64 datatypes. As part of it, I need to compute (n*n)%d where n and d are u32 (or u64, respectively).

While the result can easily fit in the datatype, I'm at a loss for how to compute this. As far as I know there is no processor primitive for this.

For u32 we can fake it -- cast up to u64, so that the product won't overflow, then take the modulus, then cast back down to u32, knowing this won't overflow. However since I don't have a u128 datatype (as far as I know) this trick won't work for u64.
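For reference, the u32 widening trick could look roughly like this. This is a minimal sketch, and the name `square_mod_u32` is only for illustration:

```rust
/// Computes (n * n) % d for u32 by widening to u64 so the square cannot overflow.
fn square_mod_u32(n: u32, d: u32) -> u32 {
    // (2^32 - 1)^2 < 2^64, so the product of two widened u32 values always fits in a u64.
    let sq = u64::from(n) * u64::from(n);
    // The remainder is < d <= u32::MAX, so narrowing back to u32 loses nothing.
    (sq % u64::from(d)) as u32
}
```

For example, `square_mod_u32(123456, 1000)` reduces 123456 to 456 and returns 456² mod 1000 = 936.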

So for u64, the most obvious way I can think of to accomplish this is to somehow compute x*y to get a pair (carry, product) of u64, so we capture the amount of overflow instead of just losing it (or panicking, or whatever).
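One way to get that (carry, product) pair without a 128-bit type is schoolbook multiplication on 32-bit halves. The sketch below is illustrative only (`widening_mul_u64` is a hypothetical name). It yields the full product as (high, low) words, but it does not perform the reduction modulo d, which would still require a 128-by-64 division step:

```rust
/// Full 64x64 -> 128-bit product returned as (high, low) u64 words,
/// built from four 32x32 -> 64-bit partial products.
fn widening_mul_u64(x: u64, y: u64) -> (u64, u64) {
    const MASK: u64 = 0xFFFF_FFFF;
    let (x_lo, x_hi) = (x & MASK, x >> 32);
    let (y_lo, y_hi) = (y & MASK, y >> 32);

    // Each partial product of two 32-bit values fits in 64 bits.
    let ll = x_lo * y_lo;
    let lh = x_lo * y_hi;
    let hl = x_hi * y_lo;
    let hh = x_hi * y_hi;

    // Middle column: top half of ll plus the low halves of lh and hl.
    // At most 3 * (2^32 - 1), so this sum cannot overflow.
    let mid = (ll >> 32) + (lh & MASK) + (hl & MASK);

    // Low word: low 32 bits of mid on top of the low 32 bits of ll.
    let lo = (mid << 32) | (ll & MASK);
    // High word: everything that carried past bit 63.
    let hi = hh + (lh >> 32) + (hl >> 32) + (mid >> 32);
    (hi, lo)
}
```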

Is there a way to do this? Or another standard way to solve the problem?


4 Answers


Richard Rast pointed out that the Wikipedia version works only with 63-bit integers. I extended the code provided by Boiethios to work with the full range of 64-bit unsigned integers.

fn mul_mod64(mut x: u64, mut y: u64, m: u64) -> u64 {
    let msb = 0x8000_0000_0000_0000;
    let mut d = 0;
    let mp2 = m >> 1;
    x %= m;
    y %= m;

    if m & msb == 0 {
        for _ in 0..64 {
            d = if d > mp2 {
                (d << 1) - m
            } else {
                d << 1
            };
            if x & msb != 0 {
                d += y;
            }
            if d >= m {
                d -= m;
            }
            x <<= 1;
        }
        d
    } else {
        for _ in 0..64 {
            d = if d > mp2 {
                d.wrapping_shl(1).wrapping_sub(m)
            } else {
                // the case d == m && x == 0 is taken care of 
                // after the end of the loop
                d << 1
            };
            if x & msb != 0 {
                let (mut d1, overflow) = d.overflowing_add(y);
                if overflow {
                    d1 = d1.wrapping_sub(m);
                }
                d = if d1 >= m { d1 - m } else { d1 };
            }
            x <<= 1;
        }
        if d >= m { d - m } else { d }
    }
}

#[test]
fn test_mul_mod64() {
    let half = 1 << 16;
    let max = u64::MAX;

    assert_eq!(mul_mod64(0, 0, 2), 0);
    assert_eq!(mul_mod64(1, 0, 2), 0);
    assert_eq!(mul_mod64(0, 1, 2), 0);
    assert_eq!(mul_mod64(1, 1, 2), 1);
    assert_eq!(mul_mod64(42, 1, 2), 0);
    assert_eq!(mul_mod64(1, 42, 2), 0);
    assert_eq!(mul_mod64(42, 42, 2), 0);
    assert_eq!(mul_mod64(42, 42, 42), 0);
    assert_eq!(mul_mod64(42, 42, 41), 1);
    assert_eq!(mul_mod64(1239876, 2948635, 234897), 163320);

    assert_eq!(mul_mod64(1239876, 2948635, half), 18476);
    assert_eq!(mul_mod64(half, half, half), 0);
    assert_eq!(mul_mod64(half+1, half+1, half), 1);

    assert_eq!(mul_mod64(max, max, max), 0);
    assert_eq!(mul_mod64(1239876, 2948635, max), 3655941769260);
    assert_eq!(mul_mod64(1239876, max, max), 0);
    assert_eq!(mul_mod64(1239876, max-1, max), max-1239876);
    assert_eq!(mul_mod64(max, 2948635, max), 0);
    assert_eq!(mul_mod64(max-1, 2948635, max), max-2948635);
    assert_eq!(mul_mod64(max-1, max-1, max), 1);
    assert_eq!(mul_mod64(2, max/2, max-1), 0);
}
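The idea behind the code above is "double and add": scan the bits of x from the top, double the running remainder, and add y whenever the current bit is set, reducing modulo m after each step. The two branches exist only to avoid overflow. A possibly easier-to-follow restatement (not the answer's exact code; `add_mod` and `mul_mod_double_add` are hypothetical names) moves all the overflow handling into one addition helper:

```rust
/// (a + b) % m for a, b < m, without overflow even when m uses all 64 bits.
fn add_mod(a: u64, b: u64, m: u64) -> u64 {
    // a + b >= m  <=>  a >= m - b; m - b cannot underflow because b < m.
    if a >= m - b { a - (m - b) } else { a + b }
}

/// (x * y) % m via left-to-right binary multiplication.
fn mul_mod_double_add(x: u64, y: u64, m: u64) -> u64 {
    assert!(m != 0, "modulus must be non-zero");
    let (x, y) = (x % m, y % m);
    let mut acc = 0u64;
    for bit in (0..64).rev() {
        // acc = 2 * acc mod m
        acc = add_mod(acc, acc, m);
        // acc = acc + y mod m, if this bit of x is set
        if (x >> bit) & 1 == 1 {
            acc = add_mod(acc, y, m);
        }
    }
    acc
}
```

It costs the same 64 iterations; the only difference is that every addition goes through the overflow-safe `add_mod`.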

4 Comments

@mcarton what is half half of?
@Shepmaster now that I think about it, nothing :)
I regret to say that I have not gone through the work of understanding this code :( but I have tested it thoroughly and it works very well. Thank you!
"generic" version using a macro to implement for all unsigned integer types: play.rust-lang.org/…

Here's an alternative approach, now that Rust has a u128 datatype (stable since Rust 1.26):

fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    let (a, b, m) = (a as u128, b as u128, m as u128);
    ((a * b) % m) as u64
}

This approach just leans on LLVM's 128-bit integer arithmetic.

The thing I like about this version is that it's really easy to convince yourself that the solution is correct for the entire domain. Since a and b are u64s the product is guaranteed to fit in a u128, and since m is a u64 the downcast at the end is guaranteed to be safe.

I don't know how performance compares to other approaches, but I would be pretty surprised if it were dramatically slower. If you really care about performance you're going to want to run some benchmarks and try a few alternatives in any case.
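Since the question's goal is a primality test, the usual next step is modular exponentiation (as used in Fermat or Miller-Rabin tests), built on top of this `mul_mod`. A sketch using square-and-multiply, where `pow_mod` is a hypothetical name:

```rust
/// (a * b) % m using 128-bit intermediates, as in the answer above.
fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    ((a as u128 * b as u128) % m as u128) as u64
}

/// base^exp % m by square-and-multiply; every product goes through mul_mod,
/// so no intermediate value overflows.
fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
    if m == 1 {
        return 0;
    }
    let mut result = 1u64;
    base %= m;
    while exp > 0 {
        // Multiply in the current power of base when this exponent bit is set.
        if exp & 1 == 1 {
            result = mul_mod(result, base, m);
        }
        // Square for the next bit.
        base = mul_mod(base, base, m);
        exp >>= 1;
    }
    result
}
```

For a prime p, Fermat's little theorem gives `pow_mod(2, p - 1, p) == 1`, which is a quick sanity check even for p = 2^64 - 59.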



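Use simple mathematics:

$$(n \cdot n) \bmod d = \big((n \bmod d) \cdot (n \bmod d)\big) \bmod d$$

To see that this is indeed true, set $n = kd + r$ with $0 \le r < d$, so that $r = n \bmod d$:

$$n^2 \bmod d = (k^2 d^2 + 2kdr + r^2) \bmod d = r^2 \bmod d = \big((n \bmod d) \cdot (n \bmod d)\big) \bmod d$$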
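Applied to u32, reducing first keeps the square small only while d is small. A minimal sketch (the name `square_mod_reduced` is hypothetical) that uses `checked_mul` to report when the reduced square still does not fit:

```rust
/// (n * n) % d computed in u32 after first reducing n modulo d.
/// Returns None when (n % d)^2 does not fit in a u32, which can only
/// happen when d > 2^16 (because n % d <= d - 1).
fn square_mod_reduced(n: u32, d: u32) -> Option<u32> {
    let r = n % d;
    r.checked_mul(r).map(|sq| sq % d)
}
```

For example, with n = 100000 and d = 200000 the reduced value is still 100000, and its square exceeds `u32::MAX`, so the function returns `None`.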

3 Comments

strictly speaking n * n % d = (n % d) * (n % d) % d
This can still overflow if d is large.
@interjay Specifically, it can overflow if d > 2^16

red75prime added a useful comment. Here is Rust code that computes the product of two numbers modulo m, adapted from Wikipedia:

fn mul_mod(mut x: u64, mut y: u64, m: u64) -> u64 {
    // As in the Wikipedia original, all arguments must fit in 63 bits (< 2^63).
    let mut d = 0_u64;
    let mp2 = m >> 1;
    x %= m;
    y %= m;

    for _ in 0..64 {
        d = if d > mp2 {
            (d << 1) - m
        } else {
            d << 1
        };
        if x & 0x8000_0000_0000_0000_u64 != 0 {
            d += y;
        }
        // `>=`, not `>`: with `>`, d == m could survive to the return value
        if d >= m {
            d -= m;
        }
        x <<= 1;
    }
    d
}

4 Comments

@red75prime If you want to post your own answer, I'll delete mine.
No, it's fine. The problem is that this algorithm isn't correct. I found one bug: `if d > m ...` should be `if d >= m ...`. Another one causes an "attempt to subtract with overflow" panic in `(d << 1) - m`; I haven't found why yet.
Also it seems to give incorrect answers. 11552001120680995*15777587326414455 (mod 18442563521290148565) should be 844062957336182220, algorithm gives 13054753449364403936
Note that the Wikipedia section indicates that all arguments (x, y and m) must be at most 63 bits (that is, $< 2^{63}$). @red75prime the modulus in your example uses all 64 bits.
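Bugs like these are easiest to catch by comparing against a reference. Below is a sketch of a randomized cross-check of the 63-bit routine (with the `d >= m` fix) against 128-bit arithmetic, restricted to moduli below 2^63 as Wikipedia requires. The names `mul_mod_63`, `mul_mod_ref`, `xorshift64` and `find_mismatch` are hypothetical, and the xorshift generator is only there to avoid external crates:

```rust
/// The 63-bit routine from the answer above with the `d >= m` fix.
fn mul_mod_63(mut x: u64, mut y: u64, m: u64) -> u64 {
    const MSB: u64 = 0x8000_0000_0000_0000;
    // Precondition from the Wikipedia version: the modulus must fit in 63 bits.
    assert!(m != 0 && m < MSB, "m must be in 1..2^63");
    let mut d = 0_u64;
    let mp2 = m >> 1;
    x %= m;
    y %= m;
    for _ in 0..64 {
        // Double d modulo m; d < 2^63, so d << 1 does not lose bits.
        d = if d > mp2 { (d << 1) - m } else { d << 1 };
        // Add y if the current top bit of x is set.
        if x & MSB != 0 {
            d += y;
        }
        // `>=` (not `>`) keeps d strictly below m.
        if d >= m {
            d -= m;
        }
        x <<= 1;
    }
    d
}

/// Reference result via 128-bit arithmetic.
fn mul_mod_ref(x: u64, y: u64, m: u64) -> u64 {
    ((x as u128 * y as u128) % m as u128) as u64
}

/// Minimal xorshift64 pseudo-random generator (state must be non-zero).
fn xorshift64(state: &mut u64) -> u64 {
    let mut s = *state;
    s ^= s << 13;
    s ^= s >> 7;
    s ^= s << 17;
    *state = s;
    s
}

/// Returns the first (x, y, m) where the two implementations disagree, if any.
fn find_mismatch(trials: usize) -> Option<(u64, u64, u64)> {
    let mut state = 0x9E37_79B9_7F4A_7C15_u64;
    for _ in 0..trials {
        let x = xorshift64(&mut state);
        let y = xorshift64(&mut state);
        // Keep m in 1..2^63, the range the routine supports.
        let m = (xorshift64(&mut state) >> 1).max(1);
        if mul_mod_63(x, y, m) != mul_mod_ref(x, y, m) {
            return Some((x, y, m));
        }
    }
    None
}
```

With the original `d > m` comparison, a small case such as `mul_mod_63(2, 3, 6)` would return 6 instead of 0; with `>=` it returns 0.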
