Saturday, September 14, 2013

Modular Exponentiation

Objective: To calculate (a^n) % MOD in O(log n) time (that is, with O(log n) multiplications), where a, n and MOD are given non-negative integers and MOD >= 1.

Solution:
a^n = (a^(n/2))^2     if n is even
a^n = a^(n-1) * a     if n is odd
Base cases:
a^1 = a
a^0 = 1

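So we can solve this problem by recursion. Because (x*y) % MOD = ((x % MOD) * (y % MOD)) % MOD, we can reduce modulo MOD after every multiplication, so no intermediate product exceeds (MOD-1)^2. An odd n becomes even after one step and an even n is halved, so for n >= 1 the recursion makes at most 2*floor(log2(n)) + 1 calls, i.e. O(log n).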
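To make the O(log n) claim concrete, here is a small C++ sketch that lists the exponents the recursion visits, applying the same two rules (subtract 1 when n is odd, halve when n is even). The helper name exponent_chain is hypothetical and is only for illustration; it does not compute the power itself.

```cpp
#include <vector>

// Hypothetical helper for illustration: records the sequence of exponents
// the recursive algorithm visits, starting from n and stopping at a base case.
std::vector<long long> exponent_chain(long long n) {
    std::vector<long long> chain;
    while (true) {
        chain.push_back(n);
        if (n <= 1) break;      // base cases: a^1 = a, a^0 = 1
        if (n & 1) n -= 1;      // odd:  a^n = a^(n-1) * a
        else n /= 2;            // even: a^n = (a^(n/2))^2
    }
    return chain;
}
```

For n = 13 it returns 13, 12, 6, 3, 2, 1. Evaluating that chain from the bottom with a = 3 and MOD = 7 gives 3^1 = 3, 3^2 = 9 % 7 = 2, 3^3 = 2*3 = 6, 3^6 = 36 % 7 = 1, 3^12 = 1*1 = 1 and 3^13 = 1*3 = 3, so (3^13) % 7 = 3. For n = 1000 the chain has only 15 entries.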

Code:

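long long mod_exp(long long a, int n, long long MOD) {
    // assumes a >= 0, n >= 0 and MOD >= 1
    a %= MOD;                       // reduce the base so every product stays below MOD*MOD
    if (n == 0)
        return 1 % MOD;             // 1 % MOD is 0 when MOD == 1
    else if (n == 1)
        return a;
    else if (n & 1)                 // odd n: a^n = a^(n-1) * a
        return (a * mod_exp(a, n - 1, MOD)) % MOD;
    else                            // even n: a^n = (a^(n/2))^2
    {
        long long ret = mod_exp(a, n / 2, MOD);
        return (ret * ret) % MOD;
    }
}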
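Note: Intermediate products such as ret*ret can be as large as (MOD-1)^2, which overflows a 32-bit int once MOD exceeds 46341. Use long long instead of int for these products; 64-bit arithmetic is safe as long as (MOD-1)^2 fits in a signed 64-bit integer, i.e. for MOD up to about 3*10^9.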
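For moduli too large for a 64-bit product of two residues, one common workaround (not part of the original post) is binary "double-and-add" multiplication, which only ever adds numbers below MOD. The sketch below assumes unsigned 64-bit arithmetic and 1 <= MOD < 2^63; the names mul_mod and pow_mod_big are hypothetical.

```cpp
#include <cstdint>

// Hypothetical helper: (x * y) % MOD without forming x * y.
// Assumes x < MOD and MOD < 2^63, so res + x and x + x never overflow.
std::uint64_t mul_mod(std::uint64_t x, std::uint64_t y, std::uint64_t MOD) {
    std::uint64_t res = 0;
    while (y > 0) {
        if (y & 1) {                 // this bit of y is set: add the current x
            res += x;
            if (res >= MOD) res -= MOD;
        }
        x += x;                      // x becomes (original x) * 2^k mod MOD
        if (x >= MOD) x -= MOD;
        y >>= 1;
    }
    return res;
}

// Square-and-multiply with every product routed through mul_mod.
// Assumes 1 <= MOD < 2^63.
std::uint64_t pow_mod_big(std::uint64_t a, std::uint64_t n, std::uint64_t MOD) {
    std::uint64_t ret = 1 % MOD;
    a %= MOD;
    while (n > 0) {
        if (n & 1) ret = mul_mod(ret, a, MOD);
        a = mul_mod(a, a, MOD);
        n >>= 1;
    }
    return ret;
}
```

With MOD = 2^62 + 1, residues can approach 2^62, so a plain product of two of them would overflow even unsigned 64-bit arithmetic; the sketch still gives pow_mod_big(2, 63, 4611686018427387905) = 4611686018427387903, i.e. 2^62 - 1, since 2^63 = 2*MOD - 2. The price is O(log MOD) additions per multiplication. GCC and Clang also provide an unsigned __int128 type that avoids the loop, but it is a compiler extension rather than standard C++.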

Iterative Version:

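typedef long long LL;

LL modexp(LL a, int n, LL MOD) {
    // assumes a >= 0, n >= 0 and MOD >= 1
    LL ret = 1 % MOD;               // also covers a == 0: 0^0 = 1, and 0^n = 0 for n > 0
    a %= MOD;                       // reduce the base before the first squaring
    while (n > 0) {
        if (n & 1)                  // lowest remaining bit of n is set: multiply it in
            ret = (ret * a) % MOD;
        n >>= 1;                    // move to the next bit of n
        a = (a * a) % MOD;          // a now holds (original a)^(2^k) mod MOD
    }
    return ret;
}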
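The iterative version reads the bits of n from the least significant end: after k steps, a holds the original base raised to 2^k, and ret has collected the powers for the set bits seen so far, so ret * a^n (with the current a and n) always equals the original a^n modulo MOD. The sketch below is a tracing variant of the loop that records ret after each bit and checks that invariant; modexp_trace and naive_pow are hypothetical names, and naive_pow's O(n) loop makes it suitable only for small test inputs.

```cpp
#include <cassert>
#include <vector>

typedef long long LL;

// Naive O(n) reference: multiply by a, n times, reducing modulo MOD each time.
LL naive_pow(LL a, LL n, LL MOD) {
    LL r = 1 % MOD;
    for (LL i = 0; i < n; ++i) r = (r * (a % MOD)) % MOD;
    return r;
}

// Hypothetical tracing variant of the iterative loop. Records ret after each
// bit of n and asserts the invariant  ret * a^n == a0^n0 (mod MOD),
// where a0, n0 are the original arguments. Assumes a >= 0 and small MOD.
std::vector<LL> modexp_trace(LL a, LL n, LL MOD) {
    const LL target = naive_pow(a, n, MOD);   // a0^n0 mod MOD
    std::vector<LL> rets;
    LL ret = 1 % MOD;
    a %= MOD;
    while (n > 0) {
        if (n & 1) ret = (ret * a) % MOD;
        n >>= 1;
        a = (a * a) % MOD;
        // the part still "owed" is a^n, so ret * a^n must equal the answer
        assert((ret * naive_pow(a, n, MOD)) % MOD == target);
        rets.push_back(ret);
    }
    return rets;
}
```

For a = 3, n = 13 (binary 1101) and MOD = 7, the recorded values of ret are 3, 3, 5, 3, so (3^13) % 7 = 3, matching the direct computation.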
