py: Implement bit-shift and not operations for mpz.

Implement not, shl and shr in mpz library.  Add function to create mpzs
on the stack, used for memory efficiency when rhs is a small int.
Factor out code to parse base-prefix of number into a dedicated function.
This commit is contained in:
Damien George
2014-03-01 19:50:50 +00:00
parent 793838a919
commit 06201ff3d6
9 changed files with 273 additions and 135 deletions

203
py/mpz.c
View File

@@ -10,19 +10,27 @@
#if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
// DIG_SIZE is the number of value bits stored in each digit; shift amounts
// are split into whole digits and leftover bits by dividing by it.
// NOTE(review): DIG_SIZE is defined twice in a row here (a literal 15 and
// MPZ_DIG_SIZE) -- confirm that only the MPZ_DIG_SIZE form is meant to remain.
#define DIG_SIZE (15)
#define DIG_SIZE (MPZ_DIG_SIZE)
// mask selecting the low DIG_SIZE bits, i.e. the value bits of a single digit
#define DIG_MASK ((1 << DIG_SIZE) - 1)
/*
definition of normalise:
?
mpz is an arbitrary precision integer type with a public API.
mpn functions act on non-negative integers represented by an array of generalised
digits (eg a word per digit). You also need to specify separately the length of the
array. There is no public API for mpn. Rather, the functions are used by mpz to
implement its features.
Integer values are stored little endian (first digit is first in memory).
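Definition of normalise: ? (presumably a digit array is normalised when its most-significant digit is non-zero, so that its length counts no leading zero digits -- TODO confirm)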
*/
/* compares i with j
returns sign(i - j)
assumes i, j are normalised
*/
int mpn_cmp(const mpz_dig_t *idig, uint ilen, const mpz_dig_t *jdig, uint jlen) {
STATIC int mpn_cmp(const mpz_dig_t *idig, uint ilen, const mpz_dig_t *jdig, uint jlen) {
if (ilen < jlen) { return -1; }
if (ilen > jlen) { return 1; }
@@ -37,39 +45,46 @@ int mpn_cmp(const mpz_dig_t *idig, uint ilen, const mpz_dig_t *jdig, uint jlen)
/* computes i = j << n
returns number of digits in i
assumes enough memory in i; assumes normalised j
assumes enough memory in i; assumes normalised j; assumes n > 0
can have i, j pointing to same memory
*/
/* unfinished
uint mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) {
uint n_whole = n / DIG_SIZE;
STATIC uint mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) {
uint n_whole = (n + DIG_SIZE - 1) / DIG_SIZE;
uint n_part = n % DIG_SIZE;
idig += jlen + n_whole + 1;
// start from the high end of the digit arrays
idig += jlen + n_whole - 1;
jdig += jlen - 1;
for (uint i = jlen; i > 0; --i, ++idig, ++jdig) {
mpz_dbl_dig_t d = *jdig;
if (i > 1) {
d |= jdig[1] << DIG_SIZE;
}
d <<= n_part;
*idig = d & DIG_MASK;
// shift the digits
mpz_dbl_dig_t d = 0;
for (uint i = jlen; i > 0; i--, idig--, jdig--) {
d |= *jdig;
*idig = d >> (DIG_SIZE - n_part);
d <<= DIG_SIZE;
}
if (idig[-1] == 0) {
--jlen;
// store remaining bits
*idig = d >> (DIG_SIZE - n_part);
idig -= n_whole - 1;
memset(idig, 0, n_whole - 1);
// work out length of result
jlen += n_whole;
if (idig[jlen - 1] == 0) {
jlen--;
}
// return length of result
return jlen;
}
*/
/* computes i = j >> n
returns number of digits in i
assumes enough memory in i; assumes normalised j
assumes enough memory in i; assumes normalised j; assumes n > 0
can have i, j pointing to same memory
*/
uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) {
STATIC uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) {
uint n_whole = n / DIG_SIZE;
uint n_part = n % DIG_SIZE;
@@ -80,7 +95,7 @@ uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) {
jdig += n_whole;
jlen -= n_whole;
for (uint i = jlen; i > 0; --i, ++idig, ++jdig) {
for (uint i = jlen; i > 0; i--, idig++, jdig++) {
mpz_dbl_dig_t d = *jdig;
if (i > 1) {
d |= jdig[1] << DIG_SIZE;
@@ -90,7 +105,7 @@ uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) {
}
if (idig[-1] == 0) {
--jlen;
jlen--;
}
return jlen;
@@ -101,7 +116,7 @@ uint mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, uint n) {
assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
can have i, j, k pointing to same memory
*/
uint mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) {
STATIC uint mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) {
mpz_dig_t *oidig = idig;
mpz_dbl_dig_t carry = 0;
@@ -131,7 +146,7 @@ uint mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t
assumes enough memory in i; assumes normalised j, k; assumes j >= k
can have i, j, k pointing to same memory
*/
uint mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) {
STATIC uint mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t *kdig, uint klen) {
mpz_dig_t *oidig = idig;
mpz_dbl_dig_signed_t borrow = 0;
@@ -159,7 +174,7 @@ uint mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, uint jlen, const mpz_dig_t
returns number of digits in i
assumes enough memory in i; assumes normalised i; assumes dmul != 0
*/
uint mpn_mul_dig_add_dig(mpz_dig_t *idig, uint ilen, mpz_dig_t dmul, mpz_dig_t dadd) {
STATIC uint mpn_mul_dig_add_dig(mpz_dig_t *idig, uint ilen, mpz_dig_t dmul, mpz_dig_t dadd) {
mpz_dig_t *oidig = idig;
mpz_dbl_dig_t carry = dadd;
@@ -181,7 +196,7 @@ uint mpn_mul_dig_add_dig(mpz_dig_t *idig, uint ilen, mpz_dig_t dmul, mpz_dig_t d
assumes enough memory in i; assumes i is zeroed; assumes normalised j, k
can have j, k point to same memory
*/
uint mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, mpz_dig_t *kdig, uint klen) {
STATIC uint mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, mpz_dig_t *kdig, uint klen) {
mpz_dig_t *oidig = idig;
uint ilen = 0;
@@ -214,7 +229,7 @@ uint mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, uint jlen, mpz_dig_t *kdig, uint
modifies den_dig memory, but restores it to original state at end
*/
void mpn_div(mpz_dig_t *num_dig, machine_uint_t *num_len, mpz_dig_t *den_dig, machine_uint_t den_len, mpz_dig_t *quo_dig, machine_uint_t *quo_len) {
STATIC void mpn_div(mpz_dig_t *num_dig, machine_uint_t *num_len, mpz_dig_t *den_dig, machine_uint_t den_len, mpz_dig_t *quo_dig, machine_uint_t *quo_len) {
mpz_dig_t *orig_num_dig = num_dig;
mpz_dig_t *orig_quo_dig = quo_dig;
mpz_dig_t norm_shift = 0;
@@ -343,9 +358,7 @@ void mpn_div(mpz_dig_t *num_dig, machine_uint_t *num_len, mpz_dig_t *den_dig, ma
}
}
// MIN_ALLOC is the smallest number of digits mpz_need_dig will ask for.
// NOTE(review): MIN_ALLOC is defined twice in this span (4, then 2) -- confirm
// which value is intended and drop the other.
#define MIN_ALLOC (4)
// ALIGN_ALLOC is used by the rounding expression in mpz_need_dig
#define ALIGN_ALLOC (2)
// number of digits requested by mpz_set_from_int to hold a machine_int_t
#define NUM_DIG_FOR_INT (sizeof(machine_int_t) * 8 / DIG_SIZE + 1)
#define MIN_ALLOC (2)
static const uint log_base2_floor[] = {
0,
@@ -359,13 +372,10 @@ static const uint log_base2_floor[] = {
4, 4, 4, 5
};
/* returns true when i lies strictly between -(2**DIG_SIZE) and 2**DIG_SIZE,
   i.e. when its magnitude fits within DIG_SIZE bits
*/
bool mpz_int_is_sml_int(machine_int_t i) {
    const machine_int_t bound = 1 << DIG_SIZE;
    return i > -bound && i < bound;
}
/* initialises z to the value zero
   no digit storage is attached: dig is NULL and alloc is 0, and fixed_dig is
   cleared so the digit array may later be (re)allocated on demand
*/
void mpz_init_zero(mpz_t *z) {
    z->neg = 0;
    z->fixed_dig = 0;
    z->alloc = 0;
    z->len = 0;
    z->dig = NULL;
}
@@ -375,8 +385,17 @@ void mpz_init_from_int(mpz_t *z, machine_int_t val) {
mpz_set_from_int(z, val);
}
/* initialises z to val, using the caller-supplied array dig (of alloc digits)
   as its digit storage
   the storage is marked fixed: mpz_deinit will not free it and mpz_need_dig
   will not reallocate it, so alloc must be large enough for val
*/
void mpz_init_fixed_from_int(mpz_t *z, mpz_dig_t *dig, uint alloc, machine_int_t val) {
    // attach the caller's buffer and flag it as not owned by z
    z->dig = dig;
    z->alloc = alloc;
    z->fixed_dig = 1;

    // start out as non-negative zero, then load the requested value
    z->len = 0;
    z->neg = 0;
    mpz_set_from_int(z, val);
}
/* releases the digit storage owned by z
   does nothing when z is NULL or when the digits live in a fixed,
   caller-supplied buffer (fixed_dig set), since that memory is not z's to free
*/
void mpz_deinit(mpz_t *z) {
    if (z != NULL && !z->fixed_dig) {
        m_del(mpz_dig_t, z->dig, z->alloc);
    }
}
@@ -407,23 +426,26 @@ void mpz_free(mpz_t *z) {
}
/* makes sure z has room for at least need digits (and never fewer than
   MIN_ALLOC), growing the digit array when it is missing or too small
   a fixed digit buffer cannot be grown: that case trips an assert and, when
   asserts are compiled out, returns with z unchanged
*/
STATIC void mpz_need_dig(mpz_t *z, uint need) {
    if (need < MIN_ALLOC) {
        need = MIN_ALLOC;
    }

    if (z->dig == NULL || z->alloc < need) {
        if (z->fixed_dig) {
            // cannot reallocate fixed buffers
            assert(0);
            return;
        }
        z->dig = m_renew(mpz_dig_t, z->dig, z->alloc, need);
        z->alloc = need;
    }
}
mpz_t *mpz_clone(const mpz_t *src) {
mpz_t *z = m_new_obj(mpz_t);
z->alloc = src->alloc;
z->neg = src->neg;
z->fixed_dig = 0;
z->alloc = src->alloc;
z->len = src->len;
if (src->dig == NULL) {
z->dig = NULL;
@@ -434,6 +456,9 @@ mpz_t *mpz_clone(const mpz_t *src) {
return z;
}
/* sets dest = src
can have dest, src the same
*/
void mpz_set(mpz_t *dest, const mpz_t *src) {
mpz_need_dig(dest, src->len);
dest->neg = src->neg;
@@ -442,7 +467,7 @@ void mpz_set(mpz_t *dest, const mpz_t *src) {
}
void mpz_set_from_int(mpz_t *z, machine_int_t val) {
mpz_need_dig(z, NUM_DIG_FOR_INT);
mpz_need_dig(z, MPZ_NUM_DIG_FOR_INT);
if (val < 0) {
z->neg = 1;
@@ -527,6 +552,9 @@ int mpz_cmp(const mpz_t *z1, const mpz_t *z2) {
return cmp;
}
#if 0
// obsolete
// compares mpz with an integer that fits within DIG_SIZE bits
int mpz_cmp_sml_int(const mpz_t *z, machine_int_t sml_int) {
int cmp;
if (z->neg == 0) {
@@ -554,6 +582,7 @@ int mpz_cmp_sml_int(const mpz_t *z, machine_int_t sml_int) {
if (cmp > 0) return 1;
return 0;
}
#endif
#if 0
these functions are unused
@@ -631,50 +660,71 @@ void mpz_neg_inpl(mpz_t *dest, const mpz_t *z) {
dest->neg = 1 - dest->neg;
}
#if 0
not finished
/* computes dest = ~z (= -z - 1)
   can have dest, z the same
*/
void mpz_not_inpl(mpz_t *dest, const mpz_t *z) {
    if (dest != z) {
        mpz_set(dest, z);
    }
    if (dest->len == 0) {
        // ~0 == -1; written by hand because mpn_add assumes jlen >= klen and
        // a zero value has no digits
        mpz_need_dig(dest, 1);
        dest->dig[0] = 1;
        dest->len = 1;
        dest->neg = 1;
    } else if (dest->neg) {
        // dest == -x with x >= 1, so ~dest == x - 1, which is non-negative
        dest->neg = 0;
        mpz_dig_t k = 1;
        dest->len = mpn_sub(dest->dig, dest->dig, dest->len, &k, 1);
    } else {
        // dest == x with x >= 1, so ~dest == -(x + 1); the addition can carry
        // into a new top digit, so make room for it first
        mpz_need_dig(dest, dest->len + 1);
        mpz_dig_t k = 1;
        dest->len = mpn_add(dest->dig, dest->dig, dest->len, &k, 1);
        dest->neg = 1;
    }
}
/* computes dest = lhs << rhs
   a negative rhs shifts right by -rhs instead
   can have dest, lhs the same
*/
void mpz_shl_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs) {
    if (lhs->len == 0 || rhs == 0) {
        // shifting zero, or shifting by nothing, is a plain copy
        mpz_set(dest, lhs);
    } else if (rhs < 0) {
        mpz_shr_inpl(dest, lhs, -rhs);
    } else {
        // the result needs one extra digit per DIG_SIZE bits shifted, rounded up
        mpz_need_dig(dest, lhs->len + (rhs + DIG_SIZE - 1) / DIG_SIZE);
        dest->len = mpn_shl(dest->dig, lhs->dig, lhs->len, rhs);
        dest->neg = lhs->neg;
    }
}
/* computes dest = lhs >> rhs
   the shift is arithmetic: negative values round towards negative infinity
   a negative rhs shifts left by -rhs instead
   can have dest, lhs the same
*/
void mpz_shr_inpl(mpz_t *dest, const mpz_t *lhs, machine_int_t rhs) {
    if (lhs->len == 0 || rhs == 0) {
        // shifting zero, or shifting by nothing, is a plain copy
        mpz_set(dest, lhs);
    } else if (rhs < 0) {
        mpz_shl_inpl(dest, lhs, -rhs);
    } else {
        // Decide whether rounding is needed BEFORE shifting: when dest and lhs
        // are the same object the shift overwrites the digits (and length)
        // that this check has to inspect.
        mpz_dig_t round_up = 0;
        if (lhs->neg) {
            // the magnitude is shifted, which truncates towards zero; to round
            // a negative value towards negative infinity the magnitude must be
            // bumped by one whenever any non-zero bit is shifted out
            uint n_whole = rhs / DIG_SIZE;
            uint n_part = rhs % DIG_SIZE;
            for (uint i = 0; i < lhs->len && i < n_whole; i++) {
                if (lhs->dig[i] != 0) {
                    round_up = 1;
                    break;
                }
            }
            if (n_whole < lhs->len && (lhs->dig[n_whole] & ((1 << n_part) - 1)) != 0) {
                round_up = 1;
            }
        }

        mpz_need_dig(dest, lhs->len);
        dest->len = mpn_shr(dest->dig, lhs->dig, lhs->len, rhs);
        dest->neg = lhs->neg;

        if (round_up) {
            if (dest->len == 0) {
                // every digit was shifted out, so the answer is -1; set it by
                // hand because mpn_add assumes jlen >= klen
                dest->dig[0] = 1;
                dest->len = 1;
            } else {
                dest->len = mpn_add(dest->dig, dest->dig, dest->len, &round_up, 1);
            }
        }
    }
}
#endif
/* computes dest = lhs + rhs
can have dest, lhs, rhs the same
@@ -931,12 +981,11 @@ machine_int_t mpz_as_int(const mpz_t *i) {
machine_int_t val = 0;
mpz_dig_t *d = i->dig + i->len;
while (--d >= i->dig)
{
while (--d >= i->dig) {
machine_int_t oldval = val;
val = (val << DIG_SIZE) | *d;
if (val < oldval)
{
if (val < oldval) {
// TODO need better handling of conversion overflow
if (i->neg == 0) {
return 0x7fffffff;
} else {