//Knowledgedump.org - Function for modular exponentiation by squaring, to reduce integer overflows.

#ifndef MODULAR_EXP_H	//Include guard
#define MODULAR_EXP_H

#include <iostream>


//Forward declarations, so modular_exp() can call mod_exp() before its definition.
//Both are declared inline because they are defined in this header: without inline,
//including the header from more than one translation unit causes duplicate-symbol link errors.
//Checks the input arguments, then computes (base^exp) % mod via mod_exp().
inline long long modular_exp(long long base, long long exp, long long mod);
//Computation helper used by modular_exp() (exponentiation by squaring).
inline long long mod_exp(long long base, long long exp, long long mod);



//Computes (base^exp) % mod: validates the arguments, then delegates to mod_exp().
//Only exponents >= 0 are allowed, since integer types are used (no modular inverse is computed).
//Invalid input (exp < 0, or mod in {-1, 0, 1}) prints a message to std::cout and returns 0.
//The arguments are validated before any shortcut, so invalid input is reported for every base.
//x^0 is 1 for every base, including 0^0 (same convention as std::pow).
//Results follow the sign rules of C++ '%': for a negative base and an odd exp the result
//can be negative (e.g. modular_exp(-2, 3, 5) == -3).
//The multiplication is done in long long by mod_exp(), so very large moduli can overflow.
//inline: the definition lives in a header, so it must be inline to avoid multiple-definition
//link errors when the header is included from several translation units.
inline long long modular_exp(long long base, long long exp, long long mod) {
	if (exp < 0) {
		std::cout << "Invalid exponent input." << std::endl;
		return 0;
	}
	if ((mod == 0) || (mod == 1) || (mod == -1)) {
		std::cout << "Invalid modulo input." << std::endl;
		return 0;
	}

	//|mod| >= 2 here, so 1 % mod == 1; this also covers base == 0.
	if (exp == 0) { return 1; }
	//Trivial bases need no computation (exp > 0 from here on).
	if (base == 0) { return 0; }
	if (base == 1) { return 1; }

	return mod_exp(base, exp, mod);
}


//Computation helper for modular_exp(): returns (base^exp) % mod using iterative
//exponentiation by squaring (right-to-left binary method), one loop pass per bit of exp.
//The base is reduced modulo mod first, so every product multiplies two values of
//magnitude < |mod|. This keeps intermediates in range for any base, as long as
//(|mod|-1)^2 fits in long long (|mod| up to about 3.03e9).
//Results follow C++ '%' sign rules: for a negative base and an odd exp the result
//can be negative (e.g. mod_exp(-2, 3, 5) == -3).
//Preconditions: mod must not be 0 (division by zero). exp == 0 yields 1 % mod.
//A negative exp is not supported and yields 0 (the same value modular_exp() returns
//for an invalid exponent) instead of falling off the end of the function.
//inline: defined in a header, so it must be inline to avoid multiple-definition link errors.
inline long long mod_exp(long long base, long long exp, long long mod) {
	if (exp < 0) { return 0; }

	long long result = 1 % mod;	//1 % mod handles |mod| == 1, where every residue is 0.
	long long factor = base % mod;	//Holds base^(2^k) % mod for the current bit k of exp.
	while (exp > 0) {
		if (exp % 2 == 1) {
			result = (result * factor) % mod;
		}
		exp /= 2;
		//Skip the squaring after the last bit: its result would never be used.
		if (exp > 0) {
			factor = (factor * factor) % mod;
		}
	}
	return result;
}


#endif		//Include guard