Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

integer.cpp

00001 // integer.cpp - written and placed in the public domain by Wei Dai 00002 // contains public domain code contributed by Alister Lee and Leonard Janke 00003 00004 #include "pch.h" 00005 #include "integer.h" 00006 #include "modarith.h" 00007 #include "nbtheory.h" 00008 #include "asn.h" 00009 #include "oids.h" 00010 #include "words.h" 00011 #include "algparam.h" 00012 #include "pubkey.h" // for P1363_KDF2 00013 #include "sha.h" 00014 00015 #include <iostream> 00016 00017 #ifdef SSE2_INTRINSICS_AVAILABLE 00018 #include <emmintrin.h> 00019 #endif 00020 00021 #include "algebra.cpp" 00022 #include "eprecomp.cpp" 00023 00024 NAMESPACE_BEGIN(CryptoPP) 00025 00026 bool FunctionAssignIntToInteger(const std::type_info &valueType, void *pInteger, const void *pInt) 00027 { 00028 if (valueType != typeid(Integer)) 00029 return false; 00030 *reinterpret_cast<Integer *>(pInteger) = *reinterpret_cast<const int *>(pInt); 00031 return true; 00032 } 00033 00034 static int DummyAssignIntToInteger = (AssignIntToInteger = FunctionAssignIntToInteger, 0); 00035 00036 #ifdef SSE2_INTRINSICS_AVAILABLE 00037 template <class T> 00038 AllocatorBase<T>::pointer AlignedAllocator<T>::allocate(size_type n, const void *) 00039 { 00040 if (n < 4) 00041 return new T[n]; 00042 else 00043 return (T *)_mm_malloc(sizeof(T)*n, 16); 00044 00045 } 00046 00047 template <class T> 00048 void AlignedAllocator<T>::deallocate(void *p, size_type n) 00049 { 00050 memset(p, 0, n*sizeof(T)); 00051 if (n < 4) 00052 delete [] p; 00053 else 00054 _mm_free(p); 00055 } 00056 00057 template class AlignedAllocator<word>; 00058 #endif 00059 00060 #define MAKE_DWORD(lowWord, highWord) ((dword(highWord)<<WORD_BITS) | (lowWord)) 00061 00062 static int Compare(const word *A, const word *B, unsigned int N) 00063 { 00064 while (N--) 00065 if (A[N] > B[N]) 00066 return 1; 00067 else if (A[N] < B[N]) 00068 return -1; 00069 00070 return 0; 00071 } 00072 00073 static word Increment(word *A, unsigned int N, word B=1) 00074 { 00075 assert(N); 00076 word t = A[0]; 00077 A[0] = t+B; 00078 if (A[0] >= t) 00079 return 0; 00080 for (unsigned i=1; i<N; i++) 00081 if (++A[i]) 00082 return 0; 00083 return 1; 00084 } 00085 00086 static word Decrement(word *A, unsigned int N, word B=1) 00087 { 00088 assert(N); 00089 word t = A[0]; 00090 A[0] = t-B; 00091 if (A[0] <= t) 00092 return 0; 00093 for (unsigned i=1; i<N; i++) 00094 if (A[i]--) 00095 return 0; 00096 return 1; 00097 } 00098 00099 static void TwosComplement(word *A, unsigned int N) 00100 { 00101 Decrement(A, N); 00102 for (unsigned i=0; i<N; i++) 00103 A[i] = ~A[i]; 00104 } 00105 00106 static word LinearMultiply(word *C, const word *A, word B, unsigned int N) 00107 { 00108 word carry=0; 00109 for(unsigned i=0; i<N; i++) 00110 { 00111 dword p = (dword)A[i] * B + carry; 00112 C[i] = LOW_WORD(p); 00113 carry = HIGH_WORD(p); 00114 } 00115 return carry; 00116 } 00117 00118 static void AtomicInverseModPower2(word *C, word A0, word A1) 00119 { 00120 assert(A0%2==1); 00121 00122 dword A=MAKE_DWORD(A0, A1), R=A0%8; 00123 00124 for (unsigned i=3; i<2*WORD_BITS; i*=2) 00125 R = R*(2-R*A); 00126 00127 assert(R*A==1); 00128 00129 C[0] = LOW_WORD(R); 00130 C[1] = HIGH_WORD(R); 00131 } 00132 00133 // ******************************************************** 00134 00135 class Portable 00136 { 00137 public: 00138 static word Add(word *C, const word *A, const word *B, unsigned int N); 00139 static word Subtract(word *C, const word *A, const word *B, unsigned int N); 00140 00141 static inline void Multiply2(word *C, const word *A, const word *B); 00142 static inline word Multiply2Add(word *C, const word *A, const word *B); 00143 static void Multiply4(word *C, const word *A, const word *B); 00144 static void Multiply8(word *C, const word *A, const word *B); 00145 static inline unsigned int MultiplyRecursionLimit() {return 8;} 00146 00147 static inline void Multiply2Bottom(word *C, const word *A, const word *B); 00148 static void Multiply4Bottom(word *C, const word *A, const word *B); 00149 static void Multiply8Bottom(word *C, const word *A, const word *B); 00150 static inline unsigned int MultiplyBottomRecursionLimit() {return 8;} 00151 00152 static void Square2(word *R, const word *A); 00153 static void Square4(word *R, const word *A); 00154 static void Square8(word *R, const word *A) {assert(false);} 00155 static inline unsigned int SquareRecursionLimit() {return 4;} 00156 }; 00157 00158 word Portable::Add(word *C, const word *A, const word *B, unsigned int N) 00159 { 00160 assert (N%2 == 0); 00161 00162 #ifdef IS_LITTLE_ENDIAN 00163 if (sizeof(dword) == sizeof(size_t)) // dword is only register size 00164 { 00165 dword carry = 0; 00166 N >>= 1; 00167 for (unsigned int i = 0; i < N; i++) 00168 { 00169 dword a = ((const dword *)A)[i] + carry; 00170 dword c = a + ((const dword *)B)[i]; 00171 ((dword *)C)[i] = c; 00172 carry = (a < carry) | (c < a); 00173 } 00174 return (word)carry; 00175 } 00176 else 00177 #endif 00178 { 00179 word carry = 0; 00180 for (unsigned int i = 0; i < N; i+=2) 00181 { 00182 dword u = (dword) carry + A[i] + B[i]; 00183 C[i] = LOW_WORD(u); 00184 u = (dword) HIGH_WORD(u) + A[i+1] + B[i+1]; 00185 C[i+1] = LOW_WORD(u); 00186 carry = HIGH_WORD(u); 00187 } 00188 return carry; 00189 } 00190 } 00191 00192 word Portable::Subtract(word *C, const word *A, const word *B, unsigned int N) 00193 { 00194 assert (N%2 == 0); 00195 00196 #ifdef IS_LITTLE_ENDIAN 00197 if (sizeof(dword) == sizeof(size_t)) // dword is only register size 00198 { 00199 dword borrow = 0; 00200 N >>= 1; 00201 for (unsigned int i = 0; i < N; i++) 00202 { 00203 dword a = ((const dword *)A)[i]; 00204 dword b = a - borrow; 00205 dword c = b - ((const dword *)B)[i]; 00206 ((dword *)C)[i] = c; 00207 borrow = (b > a) | (c > b); 00208 } 00209 return (word)borrow; 00210 } 00211 else 00212 #endif 00213 { 00214 word borrow=0; 00215 for (unsigned i = 0; i < N; i+=2) 00216 { 00217 dword u = (dword) A[i] - B[i] - borrow; 00218 C[i] = LOW_WORD(u); 00219 u = (dword) A[i+1] - B[i+1] - (word)(0-HIGH_WORD(u)); 00220 C[i+1] = LOW_WORD(u); 00221 borrow = 0-HIGH_WORD(u); 00222 } 00223 return borrow; 00224 } 00225 } 00226 00227 void Portable::Multiply2(word *C, const word *A, const word *B) 00228 { 00229 /* 00230 word s; 00231 dword d; 00232 00233 if (A1 >= A0) 00234 if (B0 >= B1) 00235 { 00236 s = 0; 00237 d = (dword)(A1-A0)*(B0-B1); 00238 } 00239 else 00240 { 00241 s = (A1-A0); 00242 d = (dword)s*(word)(B0-B1); 00243 } 00244 else 00245 if (B0 > B1) 00246 { 00247 s = (B0-B1); 00248 d = (word)(A1-A0)*(dword)s; 00249 } 00250 else 00251 { 00252 s = 0; 00253 d = (dword)(A0-A1)*(B1-B0); 00254 } 00255 */ 00256 // this segment is the branchless equivalent of above 00257 word D[4] = {A[1]-A[0], A[0]-A[1], B[0]-B[1], B[1]-B[0]}; 00258 unsigned int ai = A[1] < A[0]; 00259 unsigned int bi = B[0] < B[1]; 00260 unsigned int di = ai & bi; 00261 dword d = (dword)D[di]*D[di+2]; 00262 D[1] = D[3] = 0; 00263 unsigned int si = ai + !bi; 00264 word s = D[si]; 00265 00266 dword A0B0 = (dword)A[0]*B[0]; 00267 C[0] = LOW_WORD(A0B0); 00268 00269 dword A1B1 = (dword)A[1]*B[1]; 00270 dword t = (dword) HIGH_WORD(A0B0) + LOW_WORD(A0B0) + LOW_WORD(d) + LOW_WORD(A1B1); 00271 C[1] = LOW_WORD(t); 00272 00273 t = A1B1 + HIGH_WORD(t) + HIGH_WORD(A0B0) + HIGH_WORD(d) + HIGH_WORD(A1B1) - s; 00274 C[2] = LOW_WORD(t); 00275 C[3] = HIGH_WORD(t); 00276 } 00277 00278 inline void Portable::Multiply2Bottom(word *C, const word *A, const word *B) 00279 { 00280 #ifdef IS_LITTLE_ENDIAN 00281 if (sizeof(dword) == sizeof(size_t)) 00282 { 00283 dword a = *(const dword *)A, b = *(const dword *)B; 00284 ((dword *)C)[0] = a*b; 00285 } 00286 else 00287 #endif 00288 { 00289 dword t = (dword)A[0]*B[0]; 00290 C[0] = LOW_WORD(t); 00291 C[1] = HIGH_WORD(t) + A[0]*B[1] + A[1]*B[0]; 00292 } 00293 } 00294 00295 word Portable::Multiply2Add(word *C, const word *A, const word *B) 00296 { 00297 word D[4] = {A[1]-A[0], A[0]-A[1], B[0]-B[1], B[1]-B[0]}; 00298 unsigned int ai = A[1] < A[0]; 00299 unsigned int bi = B[0] < B[1]; 00300 unsigned int di = ai & bi; 00301 dword d = (dword)D[di]*D[di+2]; 00302 D[1] = D[3] = 0; 00303 unsigned int si = ai + !bi; 00304 word s = D[si]; 00305 00306 dword A0B0 = (dword)A[0]*B[0]; 00307 dword t = A0B0 + C[0]; 00308 C[0] = LOW_WORD(t); 00309 00310 dword A1B1 = (dword)A[1]*B[1]; 00311 t = (dword) HIGH_WORD(t) + LOW_WORD(A0B0) + LOW_WORD(d) + LOW_WORD(A1B1) + C[1]; 00312 C[1] = LOW_WORD(t); 00313 00314 t = (dword) HIGH_WORD(t) + LOW_WORD(A1B1) + HIGH_WORD(A0B0) + HIGH_WORD(d) + HIGH_WORD(A1B1) - s + C[2]; 00315 C[2] = LOW_WORD(t); 00316 00317 t = (dword) HIGH_WORD(t) + HIGH_WORD(A1B1) + C[3]; 00318 C[3] = LOW_WORD(t); 00319 return HIGH_WORD(t); 00320 } 00321 00322 #define MulAcc(x, y) \ 00323 p = (dword)A[x] * B[y] + c; \ 00324 c = LOW_WORD(p); \ 00325 p = (dword)d + HIGH_WORD(p); \ 00326 d = LOW_WORD(p); \ 00327 e += HIGH_WORD(p); 00328 00329 #define SaveMulAcc(s, x, y) \ 00330 R[s] = c; \ 00331 p = (dword)A[x] * B[y] + d; \ 00332 c = LOW_WORD(p); \ 00333 p = (dword)e + HIGH_WORD(p); \ 00334 d = LOW_WORD(p); \ 00335 e = HIGH_WORD(p); 00336 00337 #define SquAcc(x, y) \ 00338 q = (dword)A[x] * A[y]; \ 00339 p = q + c; \ 00340 c = LOW_WORD(p); \ 00341 p = (dword)d + HIGH_WORD(p); \ 00342 d = LOW_WORD(p); \ 00343 e += HIGH_WORD(p); \ 00344 p = q + c; \ 00345 c = LOW_WORD(p); \ 00346 p = (dword)d + HIGH_WORD(p); \ 00347 d = LOW_WORD(p); \ 00348 e += HIGH_WORD(p); 00349 00350 #define SaveSquAcc(s, x, y) \ 00351 R[s] = c; \ 00352 q = (dword)A[x] * A[y]; \ 00353 p = q + d; \ 00354 c = LOW_WORD(p); \ 00355 p = (dword)e + HIGH_WORD(p); \ 00356 d = LOW_WORD(p); \ 00357 e = HIGH_WORD(p); \ 00358 p = q + c; \ 00359 c = LOW_WORD(p); \ 00360 p = (dword)d + HIGH_WORD(p); \ 00361 d = LOW_WORD(p); \ 00362 e += HIGH_WORD(p); 00363 00364 void Portable::Multiply4(word *R, const word *A, const word *B) 00365 { 00366 dword p; 00367 word c, d, e; 00368 00369 p = (dword)A[0] * B[0]; 00370 R[0] = LOW_WORD(p); 00371 c = HIGH_WORD(p); 00372 d = e = 0; 00373 00374 MulAcc(0, 1); 00375 MulAcc(1, 0); 00376 00377 SaveMulAcc(1, 2, 0); 00378 MulAcc(1, 1); 00379 MulAcc(0, 2); 00380 00381 SaveMulAcc(2, 0, 3); 00382 MulAcc(1, 2); 00383 MulAcc(2, 1); 00384 MulAcc(3, 0); 00385 00386 SaveMulAcc(3, 3, 1); 00387 MulAcc(2, 2); 00388 MulAcc(1, 3); 00389 00390 SaveMulAcc(4, 2, 3); 00391 MulAcc(3, 2); 00392 00393 R[5] = c; 00394 p = (dword)A[3] * B[3] + d; 00395 R[6] = LOW_WORD(p); 00396 R[7] = e + HIGH_WORD(p); 00397 } 00398 00399 void Portable::Square2(word *R, const word *A) 00400 { 00401 dword p, q; 00402 word c, d, e; 00403 00404 p = (dword)A[0] * A[0]; 00405 R[0] = LOW_WORD(p); 00406 c = HIGH_WORD(p); 00407 d = e = 0; 00408 00409 SquAcc(0, 1); 00410 00411 R[1] = c; 00412 p = (dword)A[1] * A[1] + d; 00413 R[2] = LOW_WORD(p); 00414 R[3] = e + HIGH_WORD(p); 00415 } 00416 00417 void Portable::Square4(word *R, const word *A) 00418 { 00419 const word *B = A; 00420 dword p, q; 00421 word c, d, e; 00422 00423 p = (dword)A[0] * A[0]; 00424 R[0] = LOW_WORD(p); 00425 c = HIGH_WORD(p); 00426 d = e = 0; 00427 00428 SquAcc(0, 1); 00429 00430 SaveSquAcc(1, 2, 0); 00431 MulAcc(1, 1); 00432 00433 SaveSquAcc(2, 0, 3); 00434 SquAcc(1, 2); 00435 00436 SaveSquAcc(3, 3, 1); 00437 MulAcc(2, 2); 00438 00439 SaveSquAcc(4, 2, 3); 00440 00441 R[5] = c; 00442 p = (dword)A[3] * A[3] + d; 00443 R[6] = LOW_WORD(p); 00444 R[7] = e + HIGH_WORD(p); 00445 } 00446 00447 void Portable::Multiply8(word *R, const word *A, const word *B) 00448 { 00449 dword p; 00450 word c, d, e; 00451 00452 p = (dword)A[0] * B[0]; 00453 R[0] = LOW_WORD(p); 00454 c = HIGH_WORD(p); 00455 d = e = 0; 00456 00457 MulAcc(0, 1); 00458 MulAcc(1, 0); 00459 00460 SaveMulAcc(1, 2, 0); 00461 MulAcc(1, 1); 00462 MulAcc(0, 2); 00463 00464 SaveMulAcc(2, 0, 3); 00465 MulAcc(1, 2); 00466 MulAcc(2, 1); 00467 MulAcc(3, 0); 00468 00469 SaveMulAcc(3, 0, 4); 00470 MulAcc(1, 3); 00471 MulAcc(2, 2); 00472 MulAcc(3, 1); 00473 MulAcc(4, 0); 00474 00475 SaveMulAcc(4, 0, 5); 00476 MulAcc(1, 4); 00477 MulAcc(2, 3); 00478 MulAcc(3, 2); 00479 MulAcc(4, 1); 00480 MulAcc(5, 0); 00481 00482 SaveMulAcc(5, 0, 6); 00483 MulAcc(1, 5); 00484 MulAcc(2, 4); 00485 MulAcc(3, 3); 00486 MulAcc(4, 2); 00487 MulAcc(5, 1); 00488 MulAcc(6, 0); 00489 00490 SaveMulAcc(6, 0, 7); 00491 MulAcc(1, 6); 00492 MulAcc(2, 5); 00493 MulAcc(3, 4); 00494 MulAcc(4, 3); 00495 MulAcc(5, 2); 00496 MulAcc(6, 1); 00497 MulAcc(7, 0); 00498 00499 SaveMulAcc(7, 1, 7); 00500 MulAcc(2, 6); 00501 MulAcc(3, 5); 00502 MulAcc(4, 4); 00503 MulAcc(5, 3); 00504 MulAcc(6, 2); 00505 MulAcc(7, 1); 00506 00507 SaveMulAcc(8, 2, 7); 00508 MulAcc(3, 6); 00509 MulAcc(4, 5); 00510 MulAcc(5, 4); 00511 MulAcc(6, 3); 00512 MulAcc(7, 2); 00513 00514 SaveMulAcc(9, 3, 7); 00515 MulAcc(4, 6); 00516 MulAcc(5, 5); 00517 MulAcc(6, 4); 00518 MulAcc(7, 3); 00519 00520 SaveMulAcc(10, 4, 7); 00521 MulAcc(5, 6); 00522 MulAcc(6, 5); 00523 MulAcc(7, 4); 00524 00525 SaveMulAcc(11, 5, 7); 00526 MulAcc(6, 6); 00527 MulAcc(7, 5); 00528 00529 SaveMulAcc(12, 6, 7); 00530 MulAcc(7, 6); 00531 00532 R[13] = c; 00533 p = (dword)A[7] * B[7] + d; 00534 R[14] = LOW_WORD(p); 00535 R[15] = e + HIGH_WORD(p); 00536 } 00537 00538 void Portable::Multiply4Bottom(word *R, const word *A, const word *B) 00539 { 00540 dword p; 00541 word c, d, e; 00542 00543 p = (dword)A[0] * B[0]; 00544 R[0] = LOW_WORD(p); 00545 c = HIGH_WORD(p); 00546 d = e = 0; 00547 00548 MulAcc(0, 1); 00549 MulAcc(1, 0); 00550 00551 SaveMulAcc(1, 2, 0); 00552 MulAcc(1, 1); 00553 MulAcc(0, 2); 00554 00555 R[2] = c; 00556 R[3] = d + A[0] * B[3] + A[1] * B[2] + A[2] * B[1] + A[3] * B[0]; 00557 } 00558 00559 void Portable::Multiply8Bottom(word *R, const word *A, const word *B) 00560 { 00561 dword p; 00562 word c, d, e; 00563 00564 p = (dword)A[0] * B[0]; 00565 R[0] = LOW_WORD(p); 00566 c = HIGH_WORD(p); 00567 d = e = 0; 00568 00569 MulAcc(0, 1); 00570 MulAcc(1, 0); 00571 00572 SaveMulAcc(1, 2, 0); 00573 MulAcc(1, 1); 00574 MulAcc(0, 2); 00575 00576 SaveMulAcc(2, 0, 3); 00577 MulAcc(1, 2); 00578 MulAcc(2, 1); 00579 MulAcc(3, 0); 00580 00581 SaveMulAcc(3, 0, 4); 00582 MulAcc(1, 3); 00583 MulAcc(2, 2); 00584 MulAcc(3, 1); 00585 MulAcc(4, 0); 00586 00587 SaveMulAcc(4, 0, 5); 00588 MulAcc(1, 4); 00589 MulAcc(2, 3); 00590 MulAcc(3, 2); 00591 MulAcc(4, 1); 00592 MulAcc(5, 0); 00593 00594 SaveMulAcc(5, 0, 6); 00595 MulAcc(1, 5); 00596 MulAcc(2, 4); 00597 MulAcc(3, 3); 00598 MulAcc(4, 2); 00599 MulAcc(5, 1); 00600 MulAcc(6, 0); 00601 00602 R[6] = c; 00603 R[7] = d + A[0] * B[7] + A[1] * B[6] + A[2] * B[5] + A[3] * B[4] + 00604 A[4] * B[3] + A[5] * B[2] + A[6] * B[1] + A[7] * B[0]; 00605 } 00606 00607 #undef MulAcc 00608 #undef SaveMulAcc 00609 #undef SquAcc 00610 #undef SaveSquAcc 00611 00612 // CodeWarrior defines _MSC_VER 00613 #if defined(_MSC_VER) && !defined(__MWERKS__) && defined(_M_IX86) && (_M_IX86<=700) 00614 00615 class PentiumOptimized : public Portable 00616 { 00617 public: 00618 static word __fastcall Add(word *C, const word *A, const word *B, unsigned int N); 00619 static word __fastcall Subtract(word *C, const word *A, const word *B, unsigned int N); 00620 static inline void Square4(word *R, const word *A) 00621 { 00622 // VC60 workaround: MSVC 6.0 has an optimization bug that makes 00623 // (dword)A*B where either A or B has been cast to a dword before 00624 // very expensive. Revisit this function when this 00625 // bug is fixed. 00626 Multiply4(R, A, A); 00627 } 00628 }; 00629 00630 typedef PentiumOptimized LowLevel; 00631 00632 __declspec(naked) word __fastcall PentiumOptimized::Add(word *C, const word *A, const word *B, unsigned int N) 00633 { 00634 __asm 00635 { 00636 push ebp 00637 push ebx 00638 push esi 00639 push edi 00640 00641 mov esi, [esp+24] ; N 00642 mov ebx, [esp+20] ; B 00643 00644 // now: ebx = B, ecx = C, edx = A, esi = N 00645 00646 sub ecx, edx // hold the distance between C & A so we can add this to A to get C 00647 xor eax, eax // clear eax 00648 00649 sub eax, esi // eax is a negative index from end of B 00650 lea ebx, [ebx+4*esi] // ebx is end of B 00651 00652 sar eax, 1 // unit of eax is now dwords; this also clears the carry flag 00653 jz loopend // if no dwords then nothing to do 00654 00655 loopstart: 00656 mov esi,[edx] // load lower word of A 00657 mov ebp,[edx+4] // load higher word of A 00658 00659 mov edi,[ebx+8*eax] // load lower word of B 00660 lea edx,[edx+8] // advance A and C 00661 00662 adc esi,edi // add lower words 00663 mov edi,[ebx+8*eax+4] // load higher word of B 00664 00665 adc ebp,edi // add higher words 00666 inc eax // advance B 00667 00668 mov [edx+ecx-8],esi // store lower word result 00669 mov [edx+ecx-4],ebp // store higher word result 00670 00671 jnz loopstart // loop until eax overflows and becomes zero 00672 00673 loopend: 00674 adc eax, 0 // store carry into eax (return result register) 00675 pop edi 00676 pop esi 00677 pop ebx 00678 pop ebp 00679 ret 8 00680 } 00681 } 00682 00683 __declspec(naked) word __fastcall PentiumOptimized::Subtract(word *C, const word *A, const word *B, unsigned int N) 00684 { 00685 __asm 00686 { 00687 push ebp 00688 push ebx 00689 push esi 00690 push edi 00691 00692 mov esi, [esp+24] ; N 00693 mov ebx, [esp+20] ; B 00694 00695 sub ecx, edx 00696 xor eax, eax 00697 00698 sub eax, esi 00699 lea ebx, [ebx+4*esi] 00700 00701 sar eax, 1 00702 jz loopend 00703 00704 loopstart: 00705 mov esi,[edx] 00706 mov ebp,[edx+4] 00707 00708 mov edi,[ebx+8*eax] 00709 lea edx,[edx+8] 00710 00711 sbb esi,edi 00712 mov edi,[ebx+8*eax+4] 00713 00714 sbb ebp,edi 00715 inc eax 00716 00717 mov [edx+ecx-8],esi 00718 mov [edx+ecx-4],ebp 00719 00720 jnz loopstart 00721 00722 loopend: 00723 adc eax, 0 00724 pop edi 00725 pop esi 00726 pop ebx 00727 pop ebp 00728 ret 8 00729 } 00730 } 00731 00732 #ifdef SSE2_INTRINSICS_AVAILABLE 00733 00734 static bool GetSSE2Capability() 00735 { 00736 word32 b; 00737 00738 __asm 00739 { 00740 mov eax, 1 00741 cpuid 00742 mov b, edx 00743 } 00744 00745 return (b & (1 << 26)) != 0; 00746 } 00747 00748 bool g_sse2DetectionDone = false, g_sse2Detected, g_sse2Enabled = true; 00749 00750 static inline bool HasSSE2() 00751 { 00752 if (g_sse2Enabled && !g_sse2DetectionDone) 00753 { 00754 g_sse2Detected = GetSSE2Capability(); 00755 g_sse2DetectionDone = true; 00756 } 00757 return g_sse2Enabled && g_sse2Detected; 00758 } 00759 00760 class P4Optimized : public PentiumOptimized 00761 { 00762 public: 00763 static word __fastcall Add(word *C, const word *A, const word *B, unsigned int N); 00764 static word __fastcall Subtract(word *C, const word *A, const word *B, unsigned int N); 00765 static void Multiply4(word *C, const word *A, const word *B); 00766 static void Multiply8(word *C, const word *A, const word *B); 00767 static inline void Square4(word *R, const word *A) 00768 { 00769 Multiply4(R, A, A); 00770 } 00771 static void Multiply8Bottom(word *C, const word *A, const word *B); 00772 }; 00773 00774 static void __fastcall P4_Mul(__m128i *C, const __m128i *A, const __m128i *B) 00775 { 00776 __m128i a3210 = _mm_load_si128(A); 00777 __m128i b3210 = _mm_load_si128(B); 00778 00779 __m128i sum; 00780 00781 __m128i z = _mm_setzero_si128(); 00782 __m128i a2b2_a0b0 = _mm_mul_epu32(a3210, b3210); 00783 C[0] = a2b2_a0b0; 00784 00785 __m128i a3120 = _mm_shuffle_epi32(a3210, _MM_SHUFFLE(3, 1, 2, 0)); 00786 __m128i b3021 = _mm_shuffle_epi32(b3210, _MM_SHUFFLE(3, 0, 2, 1)); 00787 __m128i a1b0_a0b1 = _mm_mul_epu32(a3120, b3021); 00788 __m128i a1b0 = _mm_unpackhi_epi32(a1b0_a0b1, z); 00789 __m128i a0b1 = _mm_unpacklo_epi32(a1b0_a0b1, z); 00790 C[1] = _mm_add_epi64(a1b0, a0b1); 00791 00792 __m128i a31 = _mm_srli_epi64(a3210, 32); 00793 __m128i b31 = _mm_srli_epi64(b3210, 32); 00794 __m128i a3b3_a1b1 = _mm_mul_epu32(a31, b31); 00795 C[6] = a3b3_a1b1; 00796 00797 __m128i a1b1 = _mm_unpacklo_epi32(a3b3_a1b1, z); 00798 __m128i b3012 = _mm_shuffle_epi32(b3210, _MM_SHUFFLE(3, 0, 1, 2)); 00799 __m128i a2b0_a0b2 = _mm_mul_epu32(a3210, b3012); 00800 __m128i a0b2 = _mm_unpacklo_epi32(a2b0_a0b2, z); 00801 __m128i a2b0 = _mm_unpackhi_epi32(a2b0_a0b2, z); 00802 sum = _mm_add_epi64(a1b1, a0b2); 00803 C[2] = _mm_add_epi64(sum, a2b0); 00804 00805 __m128i a2301 = _mm_shuffle_epi32(a3210, _MM_SHUFFLE(2, 3, 0, 1)); 00806 __m128i b2103 = _mm_shuffle_epi32(b3210, _MM_SHUFFLE(2, 1, 0, 3)); 00807 __m128i a3b0_a1b2 = _mm_mul_epu32(a2301, b3012); 00808 __m128i a2b1_a0b3 = _mm_mul_epu32(a3210, b2103); 00809 __m128i a3b0 = _mm_unpackhi_epi32(a3b0_a1b2, z); 00810 __m128i a1b2 = _mm_unpacklo_epi32(a3b0_a1b2, z); 00811 __m128i a2b1 = _mm_unpackhi_epi32(a2b1_a0b3, z); 00812 __m128i a0b3 = _mm_unpacklo_epi32(a2b1_a0b3, z); 00813 __m128i sum1 = _mm_add_epi64(a3b0, a1b2); 00814 sum = _mm_add_epi64(a2b1, a0b3); 00815 C[3] = _mm_add_epi64(sum, sum1); 00816 00817 __m128i a3b1_a1b3 = _mm_mul_epu32(a2301, b2103); 00818 __m128i a2b2 = _mm_unpackhi_epi32(a2b2_a0b0, z); 00819 __m128i a3b1 = _mm_unpackhi_epi32(a3b1_a1b3, z); 00820 __m128i a1b3 = _mm_unpacklo_epi32(a3b1_a1b3, z); 00821 sum = _mm_add_epi64(a2b2, a3b1); 00822 C[4] = _mm_add_epi64(sum, a1b3); 00823 00824 __m128i a1302 = _mm_shuffle_epi32(a3210, _MM_SHUFFLE(1, 3, 0, 2)); 00825 __m128i b1203 = _mm_shuffle_epi32(b3210, _MM_SHUFFLE(1, 2, 0, 3)); 00826 __m128i a3b2_a2b3 = _mm_mul_epu32(a1302, b1203); 00827 __m128i a3b2 = _mm_unpackhi_epi32(a3b2_a2b3, z); 00828 __m128i a2b3 = _mm_unpacklo_epi32(a3b2_a2b3, z); 00829 C[5] = _mm_add_epi64(a3b2, a2b3); 00830 } 00831 00832 void P4Optimized::Multiply4(word *C, const word *A, const word *B) 00833 { 00834 __m128i temp[7]; 00835 const word *w = (word *)temp; 00836 const __m64 *mw = (__m64 *)w; 00837 00838 P4_Mul(temp, (__m128i *)A, (__m128i *)B); 00839 00840 C[0] = w[0]; 00841 00842 __m64 s1, s2; 00843 00844 __m64 w1 = _m_from_int(w[1]); 00845 __m64 w4 = mw[2]; 00846 __m64 w6 = mw[3]; 00847 __m64 w8 = mw[4]; 00848 __m64 w10 = mw[5]; 00849 __m64 w12 = mw[6]; 00850 __m64 w14 = mw[7]; 00851 __m64 w16 = mw[8]; 00852 __m64 w18 = mw[9]; 00853 __m64 w20 = mw[10]; 00854 __m64 w22 = mw[11]; 00855 __m64 w26 = _m_from_int(w[26]); 00856 00857 s1 = _mm_add_si64(w1, w4); 00858 C[1] = _m_to_int(s1); 00859 s1 = _m_psrlqi(s1, 32); 00860 00861 s2 = _mm_add_si64(w6, w8); 00862 s1 = _mm_add_si64(s1, s2); 00863 C[2] = _m_to_int(s1); 00864 s1 = _m_psrlqi(s1, 32); 00865 00866 s2 = _mm_add_si64(w10, w12); 00867 s1 = _mm_add_si64(s1, s2); 00868 C[3] = _m_to_int(s1); 00869 s1 = _m_psrlqi(s1, 32); 00870 00871 s2 = _mm_add_si64(w14, w16); 00872 s1 = _mm_add_si64(s1, s2); 00873 C[4] = _m_to_int(s1); 00874 s1 = _m_psrlqi(s1, 32); 00875 00876 s2 = _mm_add_si64(w18, w20); 00877 s1 = _mm_add_si64(s1, s2); 00878 C[5] = _m_to_int(s1); 00879 s1 = _m_psrlqi(s1, 32); 00880 00881 s2 = _mm_add_si64(w22, w26); 00882 s1 = _mm_add_si64(s1, s2); 00883 C[6] = _m_to_int(s1); 00884 s1 = _m_psrlqi(s1, 32); 00885 00886 C[7] = _m_to_int(s1) + w[27]; 00887 _mm_empty(); 00888 } 00889 00890 void P4Optimized::Multiply8(word *C, const word *A, const word *B) 00891 { 00892 __m128i temp[28]; 00893 const word *w = (word *)temp; 00894 const __m64 *mw = (__m64 *)w; 00895 const word *x = (word *)temp+7*4; 00896 const __m64 *mx = (__m64 *)x; 00897 const word *y = (word *)temp+7*4*2; 00898 const __m64 *my = (__m64 *)y; 00899 const word *z = (word *)temp+7*4*3; 00900 const __m64 *mz = (__m64 *)z; 00901 00902 P4_Mul(temp, (__m128i *)A, (__m128i *)B); 00903 00904 P4_Mul(temp+7, (__m128i *)A+1, (__m128i *)B); 00905 00906 P4_Mul(temp+14, (__m128i *)A, (__m128i *)B+1); 00907 00908 P4_Mul(temp+21, (__m128i *)A+1, (__m128i *)B+1); 00909 00910 C[0] = w[0]; 00911 00912 __m64 s1, s2, s3, s4; 00913 00914 __m64 w1 = _m_from_int(w[1]); 00915 __m64 w4 = mw[2]; 00916 __m64 w6 = mw[3]; 00917 __m64 w8 = mw[4]; 00918 __m64 w10 = mw[5]; 00919 __m64 w12 = mw[6]; 00920 __m64 w14 = mw[7]; 00921 __m64 w16 = mw[8]; 00922 __m64 w18 = mw[9]; 00923 __m64 w20 = mw[10]; 00924 __m64 w22 = mw[11]; 00925 __m64 w26 = _m_from_int(w[26]); 00926 __m64 w27 = _m_from_int(w[27]); 00927 00928 __m64 x0 = _m_from_int(x[0]); 00929 __m64 x1 = _m_from_int(x[1]); 00930 __m64 x4 = mx[2]; 00931 __m64 x6 = mx[3]; 00932 __m64 x8 = mx[4]; 00933 __m64 x10 = mx[5]; 00934 __m64 x12 = mx[6]; 00935 __m64 x14 = mx[7]; 00936 __m64 x16 = mx[8]; 00937 __m64 x18 = mx[9]; 00938 __m64 x20 = mx[10]; 00939 __m64 x22 = mx[11]; 00940 __m64 x26 = _m_from_int(x[26]); 00941 __m64 x27 = _m_from_int(x[27]); 00942 00943 __m64 y0 = _m_from_int(y[0]); 00944 __m64 y1 = _m_from_int(y[1]); 00945 __m64 y4 = my[2]; 00946 __m64 y6 = my[3]; 00947 __m64 y8 = my[4]; 00948 __m64 y10 = my[5]; 00949 __m64 y12 = my[6]; 00950 __m64 y14 = my[7]; 00951 __m64 y16 = my[8]; 00952 __m64 y18 = my[9]; 00953 __m64 y20 = my[10]; 00954 __m64 y22 = my[11]; 00955 __m64 y26 = _m_from_int(y[26]); 00956 __m64 y27 = _m_from_int(y[27]); 00957 00958 __m64 z0 = _m_from_int(z[0]); 00959 __m64 z1 = _m_from_int(z[1]); 00960 __m64 z4 = mz[2]; 00961 __m64 z6 = mz[3]; 00962 __m64 z8 = mz[4]; 00963 __m64 z10 = mz[5]; 00964 __m64 z12 = mz[6]; 00965 __m64 z14 = mz[7]; 00966 __m64 z16 = mz[8]; 00967 __m64 z18 = mz[9]; 00968 __m64 z20 = mz[10]; 00969 __m64 z22 = mz[11]; 00970 __m64 z26 = _m_from_int(z[26]); 00971 00972 s1 = _mm_add_si64(w1, w4); 00973 C[1] = _m_to_int(s1); 00974 s1 = _m_psrlqi(s1, 32); 00975 00976 s2 = _mm_add_si64(w6, w8); 00977 s1 = _mm_add_si64(s1, s2); 00978 C[2] = _m_to_int(s1); 00979 s1 = _m_psrlqi(s1, 32); 00980 00981 s2 = _mm_add_si64(w10, w12); 00982 s1 = _mm_add_si64(s1, s2); 00983 C[3] = _m_to_int(s1); 00984 s1 = _m_psrlqi(s1, 32); 00985 00986 s3 = _mm_add_si64(x0, y0); 00987 s2 = _mm_add_si64(w14, w16); 00988 s1 = _mm_add_si64(s1, s3); 00989 s1 = _mm_add_si64(s1, s2); 00990 C[4] = _m_to_int(s1); 00991 s1 = _m_psrlqi(s1, 32); 00992 00993 s3 = _mm_add_si64(x1, y1); 00994 s4 = _mm_add_si64(x4, y4); 00995 s1 = _mm_add_si64(s1, w18); 00996 s3 = _mm_add_si64(s3, s4); 00997 s1 = _mm_add_si64(s1, w20); 00998 s1 = _mm_add_si64(s1, s3); 00999 C[5] = _m_to_int(s1); 01000 s1 = _m_psrlqi(s1, 32); 01001 01002 s3 = _mm_add_si64(x6, y6); 01003 s4 = _mm_add_si64(x8, y8); 01004 s1 = _mm_add_si64(s1, w22); 01005 s3 = _mm_add_si64(s3, s4); 01006 s1 = _mm_add_si64(s1, w26); 01007 s1 = _mm_add_si64(s1, s3); 01008 C[6] = _m_to_int(s1); 01009 s1 = _m_psrlqi(s1, 32); 01010 01011 s3 = _mm_add_si64(x10, y10); 01012 s4 = _mm_add_si64(x12, y12); 01013 s1 = _mm_add_si64(s1, w27); 01014 s3 = _mm_add_si64(s3, s4); 01015 s1 = _mm_add_si64(s1, s3); 01016 C[7] = _m_to_int(s1); 01017 s1 = _m_psrlqi(s1, 32); 01018 01019 s3 = _mm_add_si64(x14, y14); 01020 s4 = _mm_add_si64(x16, y16); 01021 s1 = _mm_add_si64(s1, z0); 01022 s3 = _mm_add_si64(s3, s4); 01023 s1 = _mm_add_si64(s1, s3); 01024 C[8] = _m_to_int(s1); 01025 s1 = _m_psrlqi(s1, 32); 01026 01027 s3 = _mm_add_si64(x18, y18); 01028 s4 = _mm_add_si64(x20, y20); 01029 s1 = _mm_add_si64(s1, z1); 01030 s3 = _mm_add_si64(s3, s4); 01031 s1 = _mm_add_si64(s1, z4); 01032 s1 = _mm_add_si64(s1, s3); 01033 C[9] = _m_to_int(s1); 01034 s1 = _m_psrlqi(s1, 32); 01035 01036 s3 = _mm_add_si64(x22, y22); 01037 s4 = _mm_add_si64(x26, y26); 01038 s1 = _mm_add_si64(s1, z6); 01039 s3 = _mm_add_si64(s3, s4); 01040 s1 = _mm_add_si64(s1, z8); 01041 s1 = _mm_add_si64(s1, s3); 01042 C[10] = _m_to_int(s1); 01043 s1 = _m_psrlqi(s1, 32); 01044 01045 s3 = _mm_add_si64(x27, y27); 01046 s1 = _mm_add_si64(s1, z10); 01047 s1 = _mm_add_si64(s1, z12); 01048 s1 = _mm_add_si64(s1, s3); 01049 C[11] = _m_to_int(s1); 01050 s1 = _m_psrlqi(s1, 32); 01051 01052 s3 = _mm_add_si64(z14, z16); 01053 s1 = _mm_add_si64(s1, s3); 01054 C[12] = _m_to_int(s1); 01055 s1 = _m_psrlqi(s1, 32); 01056 01057 s3 = _mm_add_si64(z18, z20); 01058 s1 = _mm_add_si64(s1, s3); 01059 C[13] = _m_to_int(s1); 01060 s1 = _m_psrlqi(s1, 32); 01061 01062 s3 = _mm_add_si64(z22, z26); 01063 s1 = _mm_add_si64(s1, s3); 01064 C[14] = _m_to_int(s1); 01065 s1 = _m_psrlqi(s1, 32); 01066 01067 C[15] = z[27] + _m_to_int(s1); 01068 _mm_empty(); 01069 } 01070 01071 void P4Optimized::Multiply8Bottom(word *C, const word *A, const word *B) 01072 { 01073 __m128i temp[21]; 01074 const word *w = (word *)temp; 01075 const __m64 *mw = (__m64 *)w; 01076 const word *x = (word *)temp+7*4; 01077 const __m64 *mx = (__m64 *)x; 01078 const word *y = (word *)temp+7*4*2; 01079 const __m64 *my = (__m64 *)y; 01080 01081 P4_Mul(temp, (__m128i *)A, (__m128i *)B); 01082 01083 P4_Mul(temp+7, (__m128i *)A+1, (__m128i *)B); 01084 01085 P4_Mul(temp+14, (__m128i *)A, (__m128i *)B+1); 01086 01087 C[0] = w[0]; 01088 01089 __m64 s1, s2, s3, s4; 01090 01091 __m64 w1 = _m_from_int(w[1]); 01092 __m64 w4 = mw[2]; 01093 __m64 w6 = mw[3]; 01094 __m64 w8 = mw[4]; 01095 __m64 w10 = mw[5]; 01096 __m64 w12 = mw[6]; 01097 __m64 w14 = mw[7]; 01098 __m64 w16 = mw[8]; 01099 __m64 w18 = mw[9]; 01100 __m64 w20 = mw[10]; 01101 __m64 w22 = mw[11]; 01102 __m64 w26 = _m_from_int(w[26]); 01103 01104 __m64 x0 = _m_from_int(x[0]); 01105 __m64 x1 = _m_from_int(x[1]); 01106 __m64 x4 = mx[2]; 01107 __m64 x6 = mx[3]; 01108 __m64 x8 = mx[4]; 01109 01110 __m64 y0 = _m_from_int(y[0]); 01111 __m64 y1 = _m_from_int(y[1]); 01112 __m64 y4 = my[2]; 01113 __m64 y6 = my[3]; 01114 __m64 y8 = my[4]; 01115 01116 s1 = _mm_add_si64(w1, w4); 01117 C[1] = _m_to_int(s1); 01118 s1 = _m_psrlqi(s1, 32); 01119 01120 s2 = _mm_add_si64(w6, w8); 01121 s1 = _mm_add_si64(s1, s2); 01122 C[2] = _m_to_int(s1); 01123 s1 = _m_psrlqi(s1, 32); 01124 01125 s2 = _mm_add_si64(w10, w12); 01126 s1 = _mm_add_si64(s1, s2); 01127 C[3] = _m_to_int(s1); 01128 s1 = _m_psrlqi(s1, 32); 01129 01130 s3 = _mm_add_si64(x0, y0); 01131 s2 = _mm_add_si64(w14, w16); 01132 s1 = _mm_add_si64(s1, s3); 01133 s1 = _mm_add_si64(s1, s2); 01134 C[4] = _m_to_int(s1); 01135 s1 = _m_psrlqi(s1, 32); 01136 01137 s3 = _mm_add_si64(x1, y1); 01138 s4 = _mm_add_si64(x4, y4); 01139 s1 = _mm_add_si64(s1, w18); 01140 s3 = _mm_add_si64(s3, s4); 01141 s1 = _mm_add_si64(s1, w20); 01142 s1 = _mm_add_si64(s1, s3); 01143 C[5] = _m_to_int(s1); 01144 s1 = _m_psrlqi(s1, 32); 01145 01146 s3 = _mm_add_si64(x6, y6); 01147 s4 = _mm_add_si64(x8, y8); 01148 s1 = _mm_add_si64(s1, w22); 01149 s3 = _mm_add_si64(s3, s4); 01150 s1 = _mm_add_si64(s1, w26); 01151 s1 = _mm_add_si64(s1, s3); 01152 C[6] = _m_to_int(s1); 01153 s1 = _m_psrlqi(s1, 32); 01154 01155 C[7] = _m_to_int(s1) + w[27] + x[10] + y[10] + x[12] + y[12]; 01156 _mm_empty(); 01157 } 01158 01159 __declspec(naked) word __fastcall P4Optimized::Add(word *C, const word *A, const word *B, unsigned int N) 01160 { 01161 __asm 01162 { 01163 sub esp, 16 01164 xor eax, eax 01165 mov [esp], edi 01166 mov [esp+4], esi 01167 mov [esp+8], ebx 01168 mov [esp+12], ebp 01169 01170 mov ebx, [esp+20] // B 01171 mov esi, [esp+24] // N 01172 01173 // now: ebx = B, ecx = C, edx = A, esi = N 01174 01175 neg esi 01176 jz loopend // if no dwords then nothing to do 01177 01178 mov edi, [edx] 01179 mov ebp, [ebx] 01180 01181 loopstart: 01182 add edi, eax 01183 jc carry1 01184 01185 xor eax, eax 01186 01187 carry1continue: 01188 add edi, ebp 01189 mov ebp, 1 01190 mov [ecx], edi 01191 mov edi, [edx+4] 01192 cmovc eax, ebp 01193 mov ebp, [ebx+4] 01194 lea ebx, [ebx+8] 01195 add edi, eax 01196 jc carry2 01197 01198 xor eax, eax 01199 01200 carry2continue: 01201 add edi, ebp 01202 mov ebp, 1 01203 cmovc eax, ebp 01204 mov [ecx+4], edi 01205 add ecx, 8 01206 mov edi, [edx+8] 01207 add edx, 8 01208 add esi, 2 01209 mov ebp, [ebx] 01210 jnz loopstart 01211 01212 loopend: 01213 mov edi, [esp] 01214 mov esi, [esp+4] 01215 mov ebx, [esp+8] 01216 mov ebp, [esp+12] 01217 add esp, 16 01218 ret 8 01219 01220 carry1: 01221 mov eax, 1 01222 jmp carry1continue 01223 01224 carry2: 01225 mov eax, 1 01226 jmp carry2continue 01227 } 01228 } 01229 01230 __declspec(naked) word __fastcall P4Optimized::Subtract(word *C, const word *A, const word *B, unsigned int N) 01231 { 01232 __asm 01233 { 01234 sub esp, 16 01235 xor eax, eax 01236 mov [esp], edi 01237 mov [esp+4], esi 01238 mov [esp+8], ebx 01239 mov [esp+12], ebp 01240 01241 mov ebx, [esp+20] // B 01242 mov esi, [esp+24] // N 01243 01244 // now: ebx = B, ecx = C, edx = A, esi = N 01245 01246 neg esi 01247 jz loopend // if no dwords then nothing to do 01248 01249 mov edi, [edx] 01250 mov ebp, [ebx] 01251 01252 loopstart: 01253 sub edi, eax 01254 jc carry1 01255 01256 xor eax, eax 01257 01258 carry1continue: 01259 sub edi, ebp 01260 mov ebp, 1 01261 mov [ecx], edi 01262 mov edi, [edx+4] 01263 cmovc eax, ebp 01264 mov ebp, [ebx+4] 01265 lea ebx, [ebx+8] 01266 sub edi, eax 01267 jc carry2 01268 01269 xor eax, eax 01270 01271 carry2continue: 01272 sub edi, ebp 01273 mov ebp, 1 01274 cmovc eax, ebp 01275 mov [ecx+4], edi 01276 add ecx, 8 01277 mov edi, [edx+8] 01278 add edx, 8 01279 add esi, 2 01280 mov ebp, [ebx] 01281 jnz loopstart 01282 01283 loopend: 01284 mov edi, [esp] 01285 mov esi, [esp+4] 01286 mov ebx, [esp+8] 01287 mov ebp, [esp+12] 01288 add esp, 16 01289 ret 8 01290 01291 carry1: 01292 mov eax, 1 01293 jmp carry1continue 01294 01295 carry2: 01296 mov eax, 1 01297 jmp carry2continue 01298 } 01299 } 01300 01301 #endif // #ifdef SSE2_INTRINSICS_AVAILABLE 01302 01303 #elif defined(__GNUC__) && defined(__i386__) 01304 01305 class PentiumOptimized : public Portable 01306 { 01307 public: 01308 #ifndef __pic__ // -fpic uses up a register, leaving too few for the asm code 01309 static word Add(word *C, const word *A, const word *B, unsigned int N); 01310 static word Subtract(word *C, const word *A, const word *B, unsigned int N); 01311 #endif 01312 static void Square4(word *R, const word *A); 01313 static void Multiply4(word *C, const word *A, const word *B); 01314 static void Multiply8(word *C, const word *A, const word *B); 01315 }; 01316 01317 typedef PentiumOptimized LowLevel; 01318 01319 // Add and Subtract assembly code originally contributed by Alister Lee 01320 01321 #ifndef __pic__ 01322 __attribute__((regparm(3))) word PentiumOptimized::Add(word *C, const word *A, const word *B, unsigned int N) 01323 { 01324 assert (N%2 == 0); 01325 01326 register word carry, temp; 01327 01328 __asm__ __volatile__( 01329 "push %%ebp;" 01330 "sub %3, %2;" 01331 "xor %0, %0;" 01332 "sub %4, %0;" 01333 "lea (%1,%4,4), %1;" 01334 "sar $1, %0;" 01335 "jz 1f;" 01336 01337 "0:;" 01338 "mov 0(%3), %4;" 01339 "mov 4(%3), %%ebp;" 01340 "mov (%1,%0,8), %5;" 01341 "lea 8(%3), %3;" 01342 "adc %5, %4;" 01343 "mov 4(%1,%0,8), %5;" 01344 "adc %5, %%ebp;" 01345 "inc %0;" 01346 "mov %4, -8(%3, %2);" 01347 "mov %%ebp, -4(%3, %2);" 01348 "jnz 0b;" 01349 01350 "1:;" 01351 "adc $0, %0;" 01352 "pop %%ebp;" 01353 01354 : "=aSD" (carry), "+r" (B), "+r" (C), "+r" (A), "+r" (N), "=r" (temp) 01355 : : "cc", "memory"); 01356 01357 return carry; 01358 } 01359 01360 __attribute__((regparm(3))) word PentiumOptimized::Subtract(word *C, const word *A, const word *B, unsigned int N) 01361 { 01362 assert (N%2 == 0); 01363 01364 register word carry, temp; 01365 01366 __asm__ __volatile__( 01367 "push %%ebp;" 01368 "sub %3, %2;" 01369 "xor %0, %0;" 01370 "sub %4, %0;" 01371 "lea (%1,%4,4), %1;" 01372 "sar $1, %0;" 01373 "jz 1f;" 01374 01375 "0:;" 01376 "mov 0(%3), %4;" 01377 "mov 4(%3), %%ebp;" 01378 "mov (%1,%0,8), %5;" 01379 "lea 8(%3), %3;" 01380 "sbb %5, %4;" 01381 "mov 4(%1,%0,8), %5;" 01382 "sbb %5, %%ebp;" 01383 "inc %0;" 01384 "mov %4, -8(%3, %2);" 01385 "mov %%ebp, -4(%3, %2);" 01386 "jnz 0b;" 01387 01388 "1:;" 01389 "adc $0, %0;" 01390 "pop %%ebp;" 01391 01392 : "=aSD" (carry), "+r" (B), "+r" (C), "+r" (A), "+r" (N), "=r" (temp) 01393 : : "cc", "memory"); 01394 01395 return carry; 01396 } 01397 #endif // __pic__ 01398 01399 // Comba square and multiply assembly code originally contributed by Leonard Janke 01400 01401 #define SqrStartup \ 01402 "push %%ebp\n\t" \ 01403 "push %%esi\n\t" \ 01404 "push %%ebx\n\t" \ 01405 "xor %%ebp, %%ebp\n\t" \ 01406 "xor %%ebx, %%ebx\n\t" \ 01407 "xor %%ecx, %%ecx\n\t" 01408 01409 #define SqrShiftCarry \ 01410 "mov %%ebx, %%ebp\n\t" \ 01411 "mov %%ecx, %%ebx\n\t" \ 01412 "xor %%ecx, %%ecx\n\t" 01413 01414 #define SqrAccumulate(i,j) \ 01415 "mov 4*"#j"(%%esi), %%eax\n\t" \ 01416 "mull 4*"#i"(%%esi)\n\t" \ 01417 "add %%eax, %%ebp\n\t" \ 01418 "adc %%edx, %%ebx\n\t" \ 01419 "adc %%ch, %%cl\n\t" \ 01420 "add %%eax, %%ebp\n\t" \ 01421 "adc %%edx, %%ebx\n\t" \ 01422 "adc %%ch, %%cl\n\t" 01423 01424 #define SqrAccumulateCentre(i) \ 01425 "mov 4*"#i"(%%esi), %%eax\n\t" \ 01426 "mull 4*"#i"(%%esi)\n\t" \ 01427 "add %%eax, %%ebp\n\t" \ 01428 "adc %%edx, %%ebx\n\t" \ 01429 "adc %%ch, %%cl\n\t" 01430 01431 #define SqrStoreDigit(X) \ 01432 "mov %%ebp, 4*"#X"(%%edi)\n\t" \ 01433 01434 #define SqrLastDiagonal(digits) \ 01435 "mov 4*("#digits"-1)(%%esi), %%eax\n\t" \ 01436 "mull 4*("#digits"-1)(%%esi)\n\t" \ 01437 "add %%eax, %%ebp\n\t" \ 01438 "adc %%edx, %%ebx\n\t" \ 01439 "mov %%ebp, 4*(2*"#digits"-2)(%%edi)\n\t" \ 01440 "mov %%ebx, 4*(2*"#digits"-1)(%%edi)\n\t" 01441 01442 #define SqrCleanup \ 01443 "pop %%ebx\n\t" \ 01444 "pop %%esi\n\t" \ 01445 "pop %%ebp\n\t" 01446 01447 void PentiumOptimized::Square4(word* Y, const word* X) 01448 { 01449 __asm__ __volatile__( 01450 SqrStartup 01451 01452 SqrAccumulateCentre(0) 01453 SqrStoreDigit(0) 01454 SqrShiftCarry 01455 01456 SqrAccumulate(1,0) 01457 SqrStoreDigit(1) 01458 SqrShiftCarry 01459 01460 SqrAccumulate(2,0) 01461 SqrAccumulateCentre(1) 01462 SqrStoreDigit(2) 01463 SqrShiftCarry 01464 01465 SqrAccumulate(3,0) 01466 SqrAccumulate(2,1) 01467 SqrStoreDigit(3) 01468 SqrShiftCarry 01469 01470 SqrAccumulate(3,1) 01471 SqrAccumulateCentre(2) 01472 SqrStoreDigit(4) 01473 SqrShiftCarry 01474 01475 SqrAccumulate(3,2) 01476 SqrStoreDigit(5) 01477 SqrShiftCarry 01478 01479 SqrLastDiagonal(4) 01480 01481 SqrCleanup 01482 01483 : 01484 : "D" (Y), "S" (X) 01485 : "eax", "ecx", "edx", "ebp", "memory" 01486 ); 01487 } 01488 01489 #define MulStartup \ 01490 "push %%ebp\n\t" \ 01491 "push %%esi\n\t" \ 01492 "push %%ebx\n\t" \ 01493 "push %%edi\n\t" \ 01494 "mov %%eax, %%ebx \n\t" \ 01495 "xor %%ebp, %%ebp\n\t" \ 01496 "xor %%edi, %%edi\n\t" \ 01497 "xor %%ecx, %%ecx\n\t" 01498 01499 #define MulShiftCarry \ 01500 "mov %%edx, %%ebp\n\t" \ 01501 "mov %%ecx, %%edi\n\t" \ 01502 "xor %%ecx, %%ecx\n\t" 01503 01504 #define MulAccumulate(i,j) \ 01505 "mov 4*"#j"(%%ebx), %%eax\n\t" \ 01506 "mull 4*"#i"(%%esi)\n\t" \ 01507 "add %%eax, %%ebp\n\t" \ 01508 "adc %%edx, %%edi\n\t" \ 01509 "adc %%ch, %%cl\n\t" 01510 01511 #define MulStoreDigit(X) \ 01512 "mov %%edi, %%edx \n\t" \ 01513 "mov (%%esp), %%edi \n\t" \ 01514 "mov %%ebp, 4*"#X"(%%edi)\n\t" \ 01515 "mov %%edi, (%%esp)\n\t" 01516 01517 #define MulLastDiagonal(digits) \ 01518 "mov 4*("#digits"-1)(%%ebx), %%eax\n\t" \ 01519 "mull 4*("#digits"-1)(%%esi)\n\t" \ 01520 "add %%eax, %%ebp\n\t" \ 01521 "adc %%edi, %%edx\n\t" \ 01522 "mov (%%esp), %%edi\n\t" \ 01523 "mov %%ebp, 4*(2*"#digits"-2)(%%edi)\n\t" \ 01524 "mov %%edx, 4*(2*"#digits"-1)(%%edi)\n\t" 01525 01526 #define MulCleanup \ 01527 "pop %%edi\n\t" \ 01528 "pop %%ebx\n\t" \ 01529 "pop %%esi\n\t" \ 01530 "pop %%ebp\n\t" 01531 01532 void PentiumOptimized::Multiply4(word* Z, const word* X, const word* Y) 01533 { 01534 __asm__ __volatile__( 01535 MulStartup 01536 MulAccumulate(0,0) 01537 MulStoreDigit(0) 01538 MulShiftCarry 01539 01540 MulAccumulate(1,0) 01541 MulAccumulate(0,1) 01542 MulStoreDigit(1) 01543 MulShiftCarry 01544 01545 MulAccumulate(2,0) 01546 MulAccumulate(1,1) 01547 MulAccumulate(0,2) 01548 MulStoreDigit(2) 01549 MulShiftCarry 01550 01551 MulAccumulate(3,0) 01552 MulAccumulate(2,1) 01553 MulAccumulate(1,2) 01554 MulAccumulate(0,3) 01555 MulStoreDigit(3) 01556 MulShiftCarry 01557 01558 MulAccumulate(3,1) 01559 MulAccumulate(2,2) 01560 MulAccumulate(1,3) 01561 MulStoreDigit(4) 01562 MulShiftCarry 01563 01564 MulAccumulate(3,2) 01565 MulAccumulate(2,3) 01566 MulStoreDigit(5) 01567 MulShiftCarry 01568 01569 MulLastDiagonal(4) 01570 01571 MulCleanup 01572 01573 : 01574 : "D" (Z), "S" (X), "a" (Y) 01575 : "%ecx", "%edx", "memory" 01576 ); 01577 } 01578 01579 void PentiumOptimized::Multiply8(word* Z, const word* X, const word* Y) 01580 { 01581 __asm__ __volatile__( 01582 MulStartup 01583 MulAccumulate(0,0) 01584 MulStoreDigit(0) 01585 MulShiftCarry 01586 01587 MulAccumulate(1,0) 01588 MulAccumulate(0,1) 01589 MulStoreDigit(1) 01590 MulShiftCarry 01591 01592 MulAccumulate(2,0) 01593 MulAccumulate(1,1) 01594 MulAccumulate(0,2) 01595 MulStoreDigit(2) 01596 MulShiftCarry 01597 01598 MulAccumulate(3,0) 01599 MulAccumulate(2,1) 01600 MulAccumulate(1,2) 01601 MulAccumulate(0,3) 01602 MulStoreDigit(3) 01603 MulShiftCarry 01604 01605 MulAccumulate(4,0) 01606 MulAccumulate(3,1) 01607 MulAccumulate(2,2) 01608 MulAccumulate(1,3) 01609 MulAccumulate(0,4) 01610 MulStoreDigit(4) 01611 MulShiftCarry 01612 01613 MulAccumulate(5,0) 01614 MulAccumulate(4,1) 01615 MulAccumulate(3,2) 01616 MulAccumulate(2,3) 01617 MulAccumulate(1,4) 01618 MulAccumulate(0,5) 01619 MulStoreDigit(5) 01620 MulShiftCarry 01621 01622 MulAccumulate(6,0) 01623 MulAccumulate(5,1) 01624 MulAccumulate(4,2) 01625 MulAccumulate(3,3) 01626 MulAccumulate(2,4) 01627 MulAccumulate(1,5) 01628 MulAccumulate(0,6) 01629 MulStoreDigit(6) 01630 MulShiftCarry 01631 01632 MulAccumulate(7,0) 01633 MulAccumulate(6,1) 01634 MulAccumulate(5,2) 01635 MulAccumulate(4,3) 01636 MulAccumulate(3,4) 01637 MulAccumulate(2,5) 01638 MulAccumulate(1,6) 01639 MulAccumulate(0,7) 01640 MulStoreDigit(7) 01641 MulShiftCarry 01642 01643 MulAccumulate(7,1) 01644 MulAccumulate(6,2) 01645 MulAccumulate(5,3) 01646 MulAccumulate(4,4) 01647 MulAccumulate(3,5) 01648 MulAccumulate(2,6) 01649 MulAccumulate(1,7) 01650 MulStoreDigit(8) 01651 MulShiftCarry 01652 01653 MulAccumulate(7,2) 01654 MulAccumulate(6,3) 01655 MulAccumulate(5,4) 01656 MulAccumulate(4,5) 01657 MulAccumulate(3,6) 01658 MulAccumulate(2,7) 01659 MulStoreDigit(9) 01660 MulShiftCarry 01661 01662 MulAccumulate(7,3) 01663 MulAccumulate(6,4) 01664 MulAccumulate(5,5) 01665 MulAccumulate(4,6) 01666 MulAccumulate(3,7) 01667 MulStoreDigit(10) 01668 MulShiftCarry 01669 01670 MulAccumulate(7,4) 01671 MulAccumulate(6,5) 01672 MulAccumulate(5,6) 01673 MulAccumulate(4,7) 01674 MulStoreDigit(11) 01675 MulShiftCarry 01676 01677 MulAccumulate(7,5) 01678 MulAccumulate(6,6) 01679 MulAccumulate(5,7) 01680 MulStoreDigit(12) 01681 MulShiftCarry 01682 01683 MulAccumulate(7,6) 01684 MulAccumulate(6,7) 01685 MulStoreDigit(13) 01686 MulShiftCarry 01687 01688 MulLastDiagonal(8) 01689 01690 MulCleanup 01691 01692 : 01693 : "D" (Z), "S" (X), "a" (Y) 01694 : "%ecx", "%edx", "memory" 01695 ); 01696 } 01697 01698 #elif defined(__GNUC__) && defined(__alpha__) 01699 01700 class AlphaOptimized : public Portable 01701 { 01702 public: 01703 static inline void Multiply2(word *C, const word *A, const word *B); 01704 static inline word Multiply2Add(word *C, const word *A, const word *B); 01705 static inline void Multiply4(word *C, const word *A, const word *B); 01706 static inline unsigned int MultiplyRecursionLimit() {return 4;} 01707 01708 static inline void Multiply4Bottom(word *C, const word *A, const word *B); 01709 static inline unsigned int MultiplyBottomRecursionLimit() {return 4;} 01710 01711 static inline void Square4(word *R, const word *A) 01712 { 01713 Multiply4(R, A, A); 01714 } 01715 }; 01716 01717 typedef AlphaOptimized LowLevel; 01718 01719 inline void AlphaOptimized::Multiply2(word *C, const word *A, const word *B) 01720 { 01721 register dword c, a = *(const dword *)A, b = *(const dword *)B; 01722 ((dword *)C)[0] = a*b; 01723 __asm__("umulh %1,%2,%0" : "=r" (c) : "r" (a), "r" (b)); 01724 ((dword *)C)[1] = c; 01725 } 01726 01727 inline word AlphaOptimized::Multiply2Add(word *C, const word *A, const word *B) 01728 { 01729 register dword c, d, e, a = *(const dword *)A, b = *(const dword *)B; 01730 c = ((dword *)C)[0]; 01731 d = a*b + c; 01732 __asm__("umulh %1,%2,%0" : "=r" (e) : "r" (a), "r" (b)); 01733 ((dword *)C)[0] = d; 01734 d = (d < c); 01735 c = ((dword *)C)[1] + d; 01736 d = (c < d); 01737 c += e; 01738 ((dword *)C)[1] = c; 01739 d |= (c < e); 01740 return d; 01741 } 01742 01743 inline void AlphaOptimized::Multiply4(word *R, const word *A, const word *B) 01744 { 01745 Multiply2(R, A, B); 01746 Multiply2(R+4, A+2, B+2); 01747 word carry = Multiply2Add(R+2, A+0, B+2); 01748 carry += Multiply2Add(R+2, A+2, B+0); 01749 Increment(R+6, 2, carry); 01750 } 01751 01752 static inline void Multiply2BottomAdd(word *C, const word *A, const word *B) 01753 { 01754 register dword a = *(const dword *)A, b = *(const dword *)B; 01755 ((dword *)C)[0] = a*b + ((dword *)C)[0]; 01756 } 01757 01758 inline void AlphaOptimized::Multiply4Bottom(word *R, const word *A, const word *B) 01759 { 01760 Multiply2(R, A, B); 01761 Multiply2BottomAdd(R+2, A+0, B+2); 01762 Multiply2BottomAdd(R+2, A+2, B+0); 01763 } 01764 01765 #else // no processor specific code available 01766 01767 typedef Portable LowLevel; 01768 01769 #endif 01770 01771 // ******************************************************** 01772 01773 #define A0 A 01774 #define A1 (A+N2) 01775 #define B0 B 01776 #define B1 (B+N2) 01777 01778 #define T0 T 01779 #define T1 (T+N2) 01780 #define T2 (T+N) 01781 #define T3 (T+N+N2) 01782 01783 #define R0 R 01784 #define R1 (R+N2) 01785 #define R2 (R+N) 01786 #define R3 (R+N+N2) 01787 01788 //VC60 workaround: compiler bug triggered without the extra dummy parameters 01789 01790 // R[2*N] - result = A*B 01791 // T[2*N] - temporary work space 01792 // A[N] --- multiplier 01793 // B[N] --- multiplicant 01794 01795 template <class P> 01796 void DoRecursiveMultiply(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy=NULL); 01797 01798 template <class P> 01799 inline void RecursiveMultiply(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy=NULL) 01800 { 01801 assert(N>=2 && N%2==0); 01802 01803 if (P::MultiplyRecursionLimit() >= 8 && N==8) 01804 P::Multiply8(R, A, B); 01805 else if (P::MultiplyRecursionLimit() >= 4 && N==4) 01806 P::Multiply4(R, A, B); 01807 else if (N==2) 01808 P::Multiply2(R, A, B); 01809 else 01810 DoRecursiveMultiply<P>(R, T, A, B, N, NULL); // VC60 workaround: needs this NULL 01811 } 01812 01813 template <class P> 01814 void DoRecursiveMultiply(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy) 01815 { 01816 const unsigned int N2 = N/2; 01817 int carry; 01818 01819 int aComp = Compare(A0, A1, N2); 01820 int bComp = Compare(B0, B1, N2); 01821 01822 switch (2*aComp + aComp + bComp) 01823 { 01824 case -4: 01825 P::Subtract(R0, A1, A0, N2); 01826 P::Subtract(R1, B0, B1, N2); 01827 RecursiveMultiply<P>(T0, T2, R0, R1, N2); 01828 P::Subtract(T1, T1, R0, N2); 01829 carry = -1; 01830 break; 01831 case -2: 01832 P::Subtract(R0, A1, A0, N2); 01833 P::Subtract(R1, B0, B1, N2); 01834 RecursiveMultiply<P>(T0, T2, R0, R1, N2); 01835 carry = 0; 01836 break; 01837 case 2: 01838 P::Subtract(R0, A0, A1, N2); 01839 P::Subtract(R1, B1, B0, N2); 01840 RecursiveMultiply<P>(T0, T2, R0, R1, N2); 01841 carry = 0; 01842 break; 01843 case 4: 01844 P::Subtract(R0, A1, A0, N2); 01845 P::Subtract(R1, B0, B1, N2); 01846 RecursiveMultiply<P>(T0, T2, R0, R1, N2); 01847 P::Subtract(T1, T1, R1, N2); 01848 carry = -1; 01849 break; 01850 default: 01851 SetWords(T0, 0, N); 01852 carry = 0; 01853 } 01854 01855 RecursiveMultiply<P>(R0, T2, A0, B0, N2); 01856 RecursiveMultiply<P>(R2, T2, A1, B1, N2); 01857 01858 // now T[01] holds (A1-A0)*(B0-B1), R[01] holds A0*B0, R[23] holds A1*B1 01859 01860 carry += P::Add(T0, T0, R0, N); 01861 carry += P::Add(T0, T0, R2, N); 01862 carry += P::Add(R1, R1, T0, N); 01863 01864 assert (carry >= 0 && carry <= 2); 01865 Increment(R3, N2, carry); 01866 } 01867 01868 // R[2*N] - result = A*A 01869 // T[2*N] - temporary work space 01870 // A[N] --- number to be squared 01871 01872 template <class P> 01873 void DoRecursiveSquare(word *R, word *T, const word *A, unsigned int N, const P *dummy=NULL); 01874 01875 template <class P> 01876 inline void RecursiveSquare(word *R, word *T, const word *A, unsigned int N, const P *dummy=NULL) 01877 { 01878 assert(N && N%2==0); 01879 if (P::SquareRecursionLimit() >= 8 && N==8) 01880 P::Square8(R, A); 01881 if (P::SquareRecursionLimit() >= 4 && N==4) 01882 P::Square4(R, A); 01883 else if (N==2) 01884 P::Square2(R, A); 01885 else 01886 DoRecursiveSquare<P>(R, T, A, N, NULL); // VC60 workaround: needs this NULL 01887 } 01888 01889 template <class P> 01890 void DoRecursiveSquare(word *R, word *T, const word *A, unsigned int N, const P *dummy) 01891 { 01892 const unsigned int N2 = N/2; 01893 01894 RecursiveSquare<P>(R0, T2, A0, N2); 01895 RecursiveSquare<P>(R2, T2, A1, N2); 01896 RecursiveMultiply<P>(T0, T2, A0, A1, N2); 01897 01898 word carry = P::Add(R1, R1, T0, N); 01899 carry += P::Add(R1, R1, T0, N); 01900 Increment(R3, N2, carry); 01901 } 01902 01903 // R[N] - bottom half of A*B 01904 // T[N] - temporary work space 01905 // A[N] - multiplier 01906 // B[N] - multiplicant 01907 01908 template <class P> 01909 void DoRecursiveMultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy=NULL); 01910 01911 template <class P> 01912 inline void RecursiveMultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy=NULL) 01913 { 01914 assert(N>=2 && N%2==0); 01915 if (P::MultiplyBottomRecursionLimit() >= 8 && N==8) 01916 P::Multiply8Bottom(R, A, B); 01917 else if (P::MultiplyBottomRecursionLimit() >= 4 && N==4) 01918 P::Multiply4Bottom(R, A, B); 01919 else if (N==2) 01920 P::Multiply2Bottom(R, A, B); 01921 else 01922 DoRecursiveMultiplyBottom<P>(R, T, A, B, N, NULL); 01923 } 01924 01925 template <class P> 01926 void DoRecursiveMultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy) 01927 { 01928 const unsigned int N2 = N/2; 01929 01930 RecursiveMultiply<P>(R, T, A0, B0, N2); 01931 RecursiveMultiplyBottom<P>(T0, T1, A1, B0, N2); 01932 P::Add(R1, R1, T0, N2); 01933 RecursiveMultiplyBottom<P>(T0, T1, A0, B1, N2); 01934 P::Add(R1, R1, T0, N2); 01935 } 01936 01937 // R[N] --- upper half of A*B 01938 // T[2*N] - temporary work space 01939 // L[N] --- lower half of A*B 01940 // A[N] --- multiplier 01941 // B[N] --- multiplicant 01942 01943 template <class P> 01944 void RecursiveMultiplyTop(word *R, word *T, const word *L, const word *A, const word *B, unsigned int N, const P *dummy=NULL) 01945 { 01946 assert(N>=2 && N%2==0); 01947 01948 if (N==4) 01949 { 01950 P::Multiply4(T, A, B); 01951 ((dword *)R)[0] = ((dword *)T)[2]; 01952 ((dword *)R)[1] = ((dword *)T)[3]; 01953 } 01954 else if (N==2) 01955 { 01956 P::Multiply2(T, A, B); 01957 ((dword *)R)[0] = ((dword *)T)[1]; 01958 } 01959 else 01960 { 01961 const unsigned int N2 = N/2; 01962 int carry; 01963 01964 int aComp = Compare(A0, A1, N2); 01965 int bComp = Compare(B0, B1, N2); 01966 01967 switch (2*aComp + aComp + bComp) 01968 { 01969 case -4: 01970 P::Subtract(R0, A1, A0, N2); 01971 P::Subtract(R1, B0, B1, N2); 01972 RecursiveMultiply<P>(T0, T2, R0, R1, N2); 01973 P::Subtract(T1, T1, R0, N2); 01974 carry = -1; 01975 break; 01976 case -2: 01977 P::Subtract(R0, A1, A0, N2); 01978 P::Subtract(R1, B0, B1, N2); 01979 RecursiveMultiply<P>(T0, T2, R0, R1, N2); 01980 carry = 0; 01981 break; 01982 case 2: 01983 P::Subtract(R0, A0, A1, N2); 01984 P::Subtract(R1, B1, B0, N2); 01985 RecursiveMultiply<P>(T0, T2, R0, R1, N2); 01986 carry = 0; 01987 break; 01988 case 4: 01989 P::Subtract(R0, A1, A0, N2); 01990 P::Subtract(R1, B0, B1, N2); 01991 RecursiveMultiply<P>(T0, T2, R0, R1, N2); 01992 P::Subtract(T1, T1, R1, N2); 01993 carry = -1; 01994 break; 01995 default: 01996 SetWords(T0, 0, N); 01997 carry = 0; 01998 } 01999 02000 RecursiveMultiply<P>(T2, R0, A1, B1, N2); 02001 02002 // now T[01] holds (A1-A0)*(B0-B1), T[23] holds A1*B1 02003 02004 word c2 = P::Subtract(R0, L+N2, L, N2); 02005 c2 += P::Subtract(R0, R0, T0, N2); 02006 word t = (Compare(R0, T2, N2) == -1); 02007 02008 carry += t; 02009 carry += Increment(R0, N2, c2+t); 02010 carry += P::Add(R0, R0, T1, N2); 02011 carry += P::Add(R0, R0, T3, N2); 02012 assert (carry >= 0 && carry <= 2); 02013 02014 CopyWords(R1, T3, N2); 02015 Increment(R1, N2, carry); 02016 } 02017 } 02018 02019 inline word Add(word *C, const word *A, const word *B, unsigned int N) 02020 { 02021 return LowLevel::Add(C, A, B, N); 02022 } 02023 02024 inline word Subtract(word *C, const word *A, const word *B, unsigned int N) 02025 { 02026 return LowLevel::Subtract(C, A, B, N); 02027 } 02028 02029 inline void Multiply(word *R, word *T, const word *A, const word *B, unsigned int N) 02030 { 02031 #ifdef SSE2_INTRINSICS_AVAILABLE 02032 if (HasSSE2()) 02033 RecursiveMultiply<P4Optimized>(R, T, A, B, N); 02034 else 02035 #endif 02036 RecursiveMultiply<LowLevel>(R, T, A, B, N); 02037 } 02038 02039 inline void Square(word *R, word *T, const word *A, unsigned int N) 02040 { 02041 #ifdef SSE2_INTRINSICS_AVAILABLE 02042 if (HasSSE2()) 02043 RecursiveSquare<P4Optimized>(R, T, A, N); 02044 else 02045 #endif 02046 RecursiveSquare<LowLevel>(R, T, A, N); 02047 } 02048 02049 inline void MultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N) 02050 { 02051 #ifdef SSE2_INTRINSICS_AVAILABLE 02052 if (HasSSE2()) 02053 RecursiveMultiplyBottom<P4Optimized>(R, T, A, B, N); 02054 else 02055 #endif 02056 RecursiveMultiplyBottom<LowLevel>(R, T, A, B, N); 02057 } 02058 02059 inline void MultiplyTop(word *R, word *T, const word *L, const word *A, const word *B, unsigned int N) 02060 { 02061 #ifdef SSE2_INTRINSICS_AVAILABLE 02062 if (HasSSE2()) 02063 RecursiveMultiplyTop<P4Optimized>(R, T, L, A, B, N); 02064 else 02065 #endif 02066 RecursiveMultiplyTop<LowLevel>(R, T, L, A, B, N); 02067 } 02068 02069 // R[NA+NB] - result = A*B 02070 // T[NA+NB] - temporary work space 02071 // A[NA] ---- multiplier 02072 // B[NB] ---- multiplicant 02073 02074 void AsymmetricMultiply(word *R, word *T, const word *A, unsigned int NA, const word *B, unsigned int NB) 02075 { 02076 if (NA == NB) 02077 { 02078 if (A == B) 02079 Square(R, T, A, NA); 02080 else 02081 Multiply(R, T, A, B, NA); 02082 02083 return; 02084 } 02085 02086 if (NA > NB) 02087 { 02088 std::swap(A, B); 02089 std::swap(NA, NB); 02090 } 02091 02092 assert(NB % NA == 0); 02093 assert((NB/NA)%2 == 0); // NB is an even multiple of NA 02094 02095 if (NA==2 && !A[1]) 02096 { 02097 switch (A[0]) 02098 { 02099 case 0: 02100 SetWords(R, 0, NB+2); 02101 return; 02102 case 1: 02103 CopyWords(R, B, NB); 02104 R[NB] = R[NB+1] = 0; 02105 return; 02106 default: 02107 R[NB] = LinearMultiply(R, B, A[0], NB); 02108 R[NB+1] = 0; 02109 return; 02110 } 02111 } 02112 02113 Multiply(R, T, A, B, NA); 02114 CopyWords(T+2*NA, R+NA, NA); 02115 02116 unsigned i; 02117 02118 for (i=2*NA; i<NB; i+=2*NA) 02119 Multiply(T+NA+i, T, A, B+i, NA); 02120 for (i=NA; i<NB; i+=2*NA) 02121 Multiply(R+i, T, A, B+i, NA); 02122 02123 if (Add(R+NA, R+NA, T+2*NA, NB-NA)) 02124 Increment(R+NB, NA); 02125 } 02126 02127 // R[N] ----- result = A inverse mod 2**(WORD_BITS*N) 02128 // T[3*N/2] - temporary work space 02129 // A[N] ----- an odd number as input 02130 02131 void RecursiveInverseModPower2(word *R, word *T, const word *A, unsigned int N) 02132 { 02133 if (N==2) 02134 AtomicInverseModPower2(R, A[0], A[1]); 02135 else 02136 { 02137 const unsigned int N2 = N/2; 02138 RecursiveInverseModPower2(R0, T0, A0, N2); 02139 T0[0] = 1; 02140 SetWords(T0+1, 0, N2-1); 02141 MultiplyTop(R1, T1, T0, R0, A0, N2); 02142 MultiplyBottom(T0, T1, R0, A1, N2); 02143 Add(T0, R1, T0, N2); 02144 TwosComplement(T0, N2); 02145 MultiplyBottom(R1, T1, R0, T0, N2); 02146 } 02147 } 02148 02149 // R[N] --- result = X/(2**(WORD_BITS*N)) mod M 02150 // T[3*N] - temporary work space 02151 // X[2*N] - number to be reduced 02152 // M[N] --- modulus 02153 // U[N] --- multiplicative inverse of M mod 2**(WORD_BITS*N) 02154 02155 void MontgomeryReduce(word *R, word *T, const word *X, const word *M, const word *U, unsigned int N) 02156 { 02157 MultiplyBottom(R, T, X, U, N); 02158 MultiplyTop(T, T+N, X, R, M, N); 02159 word borrow = Subtract(T, X+N, T, N); 02160 // defend against timing attack by doing this Add even when not needed 02161 word carry = Add(T+N, T, M, N); 02162 assert(carry || !borrow); 02163 CopyWords(R, T + (borrow ? N : 0), N); 02164 } 02165 02166 // R[N] --- result = X/(2**(WORD_BITS*N/2)) mod M 02167 // T[2*N] - temporary work space 02168 // X[2*N] - number to be reduced 02169 // M[N] --- modulus 02170 // U[N/2] - multiplicative inverse of M mod 2**(WORD_BITS*N/2) 02171 // V[N] --- 2**(WORD_BITS*3*N/2) mod M 02172 02173 void HalfMontgomeryReduce(word *R, word *T, const word *X, const word *M, const word *U, const word *V, unsigned int N) 02174 { 02175 assert(N%2==0 && N>=4); 02176 02177 #define M0 M 02178 #define M1 (M+N2) 02179 #define V0 V 02180 #define V1 (V+N2) 02181 02182 #define X0 X 02183 #define X1 (X+N2) 02184 #define X2 (X+N) 02185 #define X3 (X+N+N2) 02186 02187 const unsigned int N2 = N/2; 02188 Multiply(T0, T2, V0, X3, N2); 02189 int c2 = Add(T0, T0, X0, N); 02190 MultiplyBottom(T3, T2, T0, U, N2); 02191 MultiplyTop(T2, R, T0, T3, M0, N2); 02192 c2 -= Subtract(T2, T1, T2, N2); 02193 Multiply(T0, R, T3, M1, N2); 02194 c2 -= Subtract(T0, T2, T0, N2); 02195 int c3 = -(int)Subtract(T1, X2, T1, N2); 02196 Multiply(R0, T2, V1, X3, N2); 02197 c3 += Add(R, R, T, N); 02198 02199 if (c2>0) 02200 c3 += Increment(R1, N2); 02201 else if (c2<0) 02202 c3 -= Decrement(R1, N2, -c2); 02203 02204 assert(c3>=-1 && c3<=1); 02205 if (c3>0) 02206 Subtract(R, R, M, N); 02207 else if (c3<0) 02208 Add(R, R, M, N); 02209 02210 #undef M0 02211 #undef M1 02212 #undef V0 02213 #undef V1 02214 02215 #undef X0 02216 #undef X1 02217 #undef X2 02218 #undef X3 02219 } 02220 02221 #undef A0 02222 #undef A1 02223 #undef B0 02224 #undef B1 02225 02226 #undef T0 02227 #undef T1 02228 #undef T2 02229 #undef T3 02230 02231 #undef R0 02232 #undef R1 02233 #undef R2 02234 #undef R3 02235 02236 // do a 3 word by 2 word divide, returns quotient and leaves remainder in A 02237 static word SubatomicDivide(word *A, word B0, word B1) 02238 { 02239 // assert {A[2],A[1]} < {B1,B0}, so quotient can fit in a word 02240 assert(A[2] < B1 || (A[2]==B1 && A[1] < B0)); 02241 02242 dword p, u; 02243 word Q; 02244 02245 // estimate the quotient: do a 2 word by 1 word divide 02246 if (B1+1 == 0) 02247 Q = A[2]; 02248 else 02249 Q = word(MAKE_DWORD(A[1], A[2]) / (B1+1)); 02250 02251 // now subtract Q*B from A 02252 p = (dword) B0*Q; 02253 u = (dword) A[0] - LOW_WORD(p); 02254 A[0] = LOW_WORD(u); 02255 u = (dword) A[1] - HIGH_WORD(p) - (word)(0-HIGH_WORD(u)) - (dword)B1*Q; 02256 A[1] = LOW_WORD(u); 02257 A[2] += HIGH_WORD(u); 02258 02259 // Q <= actual quotient, so fix it 02260 while (A[2] || A[1] > B1 || (A[1]==B1 && A[0]>=B0)) 02261 { 02262 u = (dword) A[0] - B0; 02263 A[0] = LOW_WORD(u); 02264 u = (dword) A[1] - B1 - (word)(0-HIGH_WORD(u)); 02265 A[1] = LOW_WORD(u); 02266 A[2] += HIGH_WORD(u); 02267 Q++; 02268 assert(Q); // shouldn't overflow 02269 } 02270 02271 return Q; 02272 } 02273 02274 // do a 4 word by 2 word divide, returns 2 word quotient in Q0 and Q1 02275 static inline void AtomicDivide(word *Q, const word *A, const word *B) 02276 { 02277 if (!B[0] && !B[1]) // if divisor is 0, we assume divisor==2**(2*WORD_BITS) 02278 { 02279 Q[0] = A[2]; 02280 Q[1] = A[3]; 02281 } 02282 else 02283 { 02284 word T[4]; 02285 T[0] = A[0]; T[1] = A[1]; T[2] = A[2]; T[3] = A[3]; 02286 Q[1] = SubatomicDivide(T+1, B[0], B[1]); 02287 Q[0] = SubatomicDivide(T, B[0], B[1]); 02288 02289 #ifndef NDEBUG 02290 // multiply quotient and divisor and add remainder, make sure it equals dividend 02291 assert(!T[2] && !T[3] && (T[1] < B[1] || (T[1]==B[1] && T[0]<B[0]))); 02292 word P[4]; 02293 LowLevel::Multiply2(P, Q, B); 02294 Add(P, P, T, 4); 02295 assert(memcmp(P, A, 4*WORD_SIZE)==0); 02296 #endif 02297 } 02298 } 02299 02300 // for use by Divide(), corrects the underestimated quotient {Q1,Q0} 02301 static void CorrectQuotientEstimate(word *R, word *T, word *Q, const word *B, unsigned int N) 02302 { 02303 assert(N && N%2==0); 02304 02305 if (Q[1]) 02306 { 02307 T[N] = T[N+1] = 0; 02308 unsigned i; 02309 for (i=0; i<N; i+=4) 02310 LowLevel::Multiply2(T+i, Q, B+i); 02311 for (i=2; i<N; i+=4) 02312 if (LowLevel::Multiply2Add(T+i, Q, B+i)) 02313 T[i+5] += (++T[i+4]==0); 02314 } 02315 else 02316 { 02317 T[N] = LinearMultiply(T, B, Q[0], N); 02318 T[N+1] = 0; 02319 } 02320 02321 word borrow = Subtract(R, R, T, N+2); 02322 assert(!borrow && !R[N+1]); 02323 02324 while (R[N] || Compare(R, B, N) >= 0) 02325 { 02326 R[N] -= Subtract(R, R, B, N); 02327 Q[1] += (++Q[0]==0); 02328 assert(Q[0] || Q[1]); // no overflow 02329 } 02330 } 02331 02332 // R[NB] -------- remainder = A%B 02333 // Q[NA-NB+2] --- quotient = A/B 02334 // T[NA+2*NB+4] - temp work space 02335 // A[NA] -------- dividend 02336 // B[NB] -------- divisor 02337 02338 void Divide(word *R, word *Q, word *T, const word *A, unsigned int NA, const word *B, unsigned int NB) 02339 { 02340 assert(NA && NB && NA%2==0 && NB%2==0); 02341 assert(B[NB-1] || B[NB-2]); 02342 assert(NB <= NA); 02343 02344 // set up temporary work space 02345 word *const TA=T; 02346 word *const TB=T+NA+2; 02347 word *const TP=T+NA+2+NB; 02348 02349 // copy B into TB and normalize it so that TB has highest bit set to 1 02350 unsigned shiftWords = (B[NB-1]==0); 02351 TB[0] = TB[NB-1] = 0; 02352 CopyWords(TB+shiftWords, B, NB-shiftWords); 02353 unsigned shiftBits = WORD_BITS - BitPrecision(TB[NB-1]); 02354 assert(shiftBits < WORD_BITS); 02355 ShiftWordsLeftByBits(TB, NB, shiftBits); 02356 02357 // copy A into TA and normalize it 02358 TA[0] = TA[NA] = TA[NA+1] = 0; 02359 CopyWords(TA+shiftWords, A, NA); 02360 ShiftWordsLeftByBits(TA, NA+2, shiftBits); 02361 02362 if (TA[NA+1]==0 && TA[NA] <= 1) 02363 { 02364 Q[NA-NB+1] = Q[NA-NB] = 0; 02365 while (TA[NA] || Compare(TA+NA-NB, TB, NB) >= 0) 02366 { 02367 TA[NA] -= Subtract(TA+NA-NB, TA+NA-NB, TB, NB); 02368 ++Q[NA-NB]; 02369 } 02370 } 02371 else 02372 { 02373 NA+=2; 02374 assert(Compare(TA+NA-NB, TB, NB) < 0); 02375 } 02376 02377 word BT[2]; 02378 BT[0] = TB[NB-2] + 1; 02379 BT[1] = TB[NB-1] + (BT[0]==0); 02380 02381 // start reducing TA mod TB, 2 words at a time 02382 for (unsigned i=NA-2; i>=NB; i-=2) 02383 { 02384 AtomicDivide(Q+i-NB, TA+i-2, BT); 02385 CorrectQuotientEstimate(TA+i-NB, TP, Q+i-NB, TB, NB); 02386 } 02387 02388 // copy TA into R, and denormalize it 02389 CopyWords(R, TA+shiftWords, NB); 02390 ShiftWordsRightByBits(R, NB, shiftBits); 02391 } 02392 02393 static inline unsigned int EvenWordCount(const word *X, unsigned int N) 02394 { 02395 while (N && X[N-2]==0 && X[N-1]==0) 02396 N-=2; 02397 return N; 02398 } 02399 02400 // return k 02401 // R[N] --- result = A^(-1) * 2^k mod M 02402 // T[4*N] - temporary work space 02403 // A[NA] -- number to take inverse of 02404 // M[N] --- modulus 02405 02406 unsigned int AlmostInverse(word *R, word *T, const word *A, unsigned int NA, const word *M, unsigned int N) 02407 { 02408 assert(NA<=N && N && N%2==0); 02409 02410 word *b = T; 02411 word *c = T+N; 02412 word *f = T+2*N; 02413 word *g = T+3*N; 02414 unsigned int bcLen=2, fgLen=EvenWordCount(M, N); 02415 unsigned int k=0, s=0; 02416 02417 SetWords(T, 0, 3*N); 02418 b[0]=1; 02419 CopyWords(f, A, NA); 02420 CopyWords(g, M, N); 02421 02422 while (1) 02423 { 02424 word t=f[0]; 02425 while (!t) 02426 { 02427 if (EvenWordCount(f, fgLen)==0) 02428 { 02429 SetWords(R, 0, N); 02430 return 0; 02431 } 02432 02433 ShiftWordsRightByWords(f, fgLen, 1); 02434 if (c[bcLen-1]) bcLen+=2; 02435 assert(bcLen <= N); 02436 ShiftWordsLeftByWords(c, bcLen, 1); 02437 k+=WORD_BITS; 02438 t=f[0]; 02439 } 02440 02441 unsigned int i=0; 02442 while (t%2 == 0) 02443 { 02444 t>>=1; 02445 i++; 02446 } 02447 k+=i; 02448 02449 if (t==1 && f[1]==0 && EvenWordCount(f, fgLen)==2) 02450 { 02451 if (s%2==0) 02452 CopyWords(R, b, N); 02453 else 02454 Subtract(R, M, b, N); 02455 return k; 02456 } 02457 02458 ShiftWordsRightByBits(f, fgLen, i); 02459 t=ShiftWordsLeftByBits(c, bcLen, i); 02460 if (t) 02461 { 02462 c[bcLen] = t; 02463 bcLen+=2; 02464 assert(bcLen <= N); 02465 } 02466 02467 if (f[fgLen-2]==0 && g[fgLen-2]==0 && f[fgLen-1]==0 && g[fgLen-1]==0) 02468 fgLen-=2; 02469 02470 if (Compare(f, g, fgLen)==-1) 02471 { 02472 std::swap(f, g); 02473 std::swap(b, c); 02474 s++; 02475 } 02476 02477 Subtract(f, f, g, fgLen); 02478 02479 if (Add(b, b, c, bcLen)) 02480 { 02481 b[bcLen] = 1; 02482 bcLen+=2; 02483 assert(bcLen <= N); 02484 } 02485 } 02486 } 02487 02488 // R[N] - result = A/(2^k) mod M 02489 // A[N] - input 02490 // M[N] - modulus 02491 02492 void DivideByPower2Mod(word *R, const word *A, unsigned int k, const word *M, unsigned int N) 02493 { 02494 CopyWords(R, A, N); 02495 02496 while (k--) 02497 { 02498 if (R[0]%2==0) 02499 ShiftWordsRightByBits(R, N, 1); 02500 else 02501 { 02502 word carry = Add(R, R, M, N); 02503 ShiftWordsRightByBits(R, N, 1); 02504 R[N-1] += carry<<(WORD_BITS-1); 02505 } 02506 } 02507 } 02508 02509 // R[N] - result = A*(2^k) mod M 02510 // A[N] - input 02511 // M[N] - modulus 02512 02513 void MultiplyByPower2Mod(word *R, const word *A, unsigned int k, const word *M, unsigned int N) 02514 { 02515 CopyWords(R, A, N); 02516 02517 while (k--) 02518 if (ShiftWordsLeftByBits(R, N, 1) || Compare(R, M, N)>=0) 02519 Subtract(R, R, M, N); 02520 } 02521 02522 // ****************************************************************** 02523 02524 static const unsigned int RoundupSizeTable[] = {2, 2, 2, 4, 4, 8, 8, 8, 8}; 02525 02526 static inline unsigned int RoundupSize(unsigned int n) 02527 { 02528 if (n<=8) 02529 return RoundupSizeTable[n]; 02530 else if (n<=16) 02531 return 16; 02532 else if (n<=32) 02533 return 32; 02534 else if (n<=64) 02535 return 64; 02536 else return 1U << BitPrecision(n-1); 02537 } 02538 02539 Integer::Integer() 02540 : reg(2), sign(POSITIVE) 02541 { 02542 reg[0] = reg[1] = 0; 02543 } 02544 02545 Integer::Integer(const Integer& t) 02546 : reg(RoundupSize(t.WordCount())), sign(t.sign) 02547 { 02548 CopyWords(reg, t.reg, reg.size()); 02549 } 02550 02551 Integer::Integer(signed long value) 02552 : reg(2) 02553 { 02554 if (value >= 0) 02555 sign = POSITIVE; 02556 else 02557 { 02558 sign = NEGATIVE; 02559 value = -value; 02560 } 02561 reg[0] = word(value); 02562 reg[1] = word(SafeRightShift<WORD_BITS, unsigned long>(value)); 02563 } 02564 02565 Integer::Integer(Sign s, word high, word low) 02566 : reg(2), sign(s) 02567 { 02568 reg[0] = low; 02569 reg[1] = high; 02570 } 02571 02572 bool Integer::IsConvertableToLong() const 02573 { 02574 if (ByteCount() > sizeof(long)) 02575 return false; 02576 02577 unsigned long value = reg[0]; 02578 value += SafeLeftShift<WORD_BITS, unsigned long>(reg[1]); 02579 02580 if (sign==POSITIVE) 02581 return (signed long)value >= 0; 02582 else 02583 return -(signed long)value < 0; 02584 } 02585 02586 signed long Integer::ConvertToLong() const 02587 { 02588 assert(IsConvertableToLong()); 02589 02590 unsigned long value = reg[0]; 02591 value += SafeLeftShift<WORD_BITS, unsigned long>(reg[1]); 02592 return sign==POSITIVE ? value : -(signed long)value; 02593 } 02594 02595 Integer::Integer(BufferedTransformation &encodedInteger, unsigned int byteCount, Signedness s) 02596 { 02597 Decode(encodedInteger, byteCount, s); 02598 } 02599 02600 Integer::Integer(const byte *encodedInteger, unsigned int byteCount, Signedness s) 02601 { 02602 Decode(encodedInteger, byteCount, s); 02603 } 02604 02605 Integer::Integer(BufferedTransformation &bt) 02606 { 02607 BERDecode(bt); 02608 } 02609 02610 Integer::Integer(RandomNumberGenerator &rng, unsigned int bitcount) 02611 { 02612 Randomize(rng, bitcount); 02613 } 02614 02615 Integer::Integer(RandomNumberGenerator &rng, const Integer &min, const Integer &max, RandomNumberType rnType, const Integer &equiv, const Integer &mod) 02616 { 02617 if (!Randomize(rng, min, max, rnType, equiv, mod)) 02618 throw Integer::RandomNumberNotFound(); 02619 } 02620 02621 Integer Integer::Power2(unsigned int e) 02622 { 02623 Integer r((word)0, BitsToWords(e+1)); 02624 r.SetBit(e); 02625 return r; 02626 } 02627 02628 const Integer &Integer::Zero() 02629 { 02630 static const Integer zero; 02631 return zero; 02632 } 02633 02634 const Integer &Integer::One() 02635 { 02636 static const Integer one(1,2); 02637 return one; 02638 } 02639 02640 const Integer &Integer::Two() 02641 { 02642 static const Integer two(2,2); 02643 return two; 02644 } 02645 02646 bool Integer::operator!() const 02647 { 02648 return IsNegative() ? false : (reg[0]==0 && WordCount()==0); 02649 } 02650 02651 Integer& Integer::operator=(const Integer& t) 02652 { 02653 if (this != &t) 02654 { 02655 reg.New(RoundupSize(t.WordCount())); 02656 CopyWords(reg, t.reg, reg.size()); 02657 sign = t.sign; 02658 } 02659 return *this; 02660 } 02661 02662 bool Integer::GetBit(unsigned int n) const 02663 { 02664 if (n/WORD_BITS >= reg.size()) 02665 return 0; 02666 else 02667 return bool((reg[n/WORD_BITS] >> (n % WORD_BITS)) & 1); 02668 } 02669 02670 void Integer::SetBit(unsigned int n, bool value) 02671 { 02672 if (value) 02673 { 02674 reg.CleanGrow(RoundupSize(BitsToWords(n+1))); 02675 reg[n/WORD_BITS] |= (word(1) << (n%WORD_BITS)); 02676 } 02677 else 02678 { 02679 if (n/WORD_BITS < reg.size()) 02680 reg[n/WORD_BITS] &= ~(word(1) << (n%WORD_BITS)); 02681 } 02682 } 02683 02684 byte Integer::GetByte(unsigned int n) const 02685 { 02686 if (n/WORD_SIZE >= reg.size()) 02687 return 0; 02688 else 02689 return byte(reg[n/WORD_SIZE] >> ((n%WORD_SIZE)*8)); 02690 } 02691 02692 void Integer::SetByte(unsigned int n, byte value) 02693 { 02694 reg.CleanGrow(RoundupSize(BytesToWords(n+1))); 02695 reg[n/WORD_SIZE] &= ~(word(0xff) << 8*(n%WORD_SIZE)); 02696 reg[n/WORD_SIZE] |= (word(value) << 8*(n%WORD_SIZE)); 02697 } 02698 02699 unsigned long Integer::GetBits(unsigned int i, unsigned int n) const 02700 { 02701 assert(n <= sizeof(unsigned long)*8); 02702 unsigned long v = 0; 02703 for (unsigned int j=0; j<n; j++) 02704 v |= GetBit(i+j) << j; 02705 return v; 02706 } 02707 02708 Integer Integer::operator-() const 02709 { 02710 Integer result(*this); 02711 result.Negate(); 02712 return result; 02713 } 02714 02715 Integer Integer::AbsoluteValue() const 02716 { 02717 Integer result(*this); 02718 result.sign = POSITIVE; 02719 return result; 02720 } 02721 02722 void Integer::swap(Integer &a) 02723 { 02724 reg.swap(a.reg); 02725 std::swap(sign, a.sign); 02726 } 02727 02728 Integer::Integer(word value, unsigned int length) 02729 : reg(RoundupSize(length)), sign(POSITIVE) 02730 { 02731 reg[0] = value; 02732 SetWords(reg+1, 0, reg.size()-1); 02733 } 02734 02735 template <class T> 02736 static Integer StringToInteger(const T *str) 02737 { 02738 word radix; 02739 #if (defined(__GNUC__) && __GNUC__ <= 3) // GCC workaround 02740 // std::char_traits doesn't exist in GCC 2.x 02741 // std::char_traits<wchar_t>::length() not defined in GCC 3.2 02742 unsigned int length; 02743 for (length = 0; str[length] != 0; length++) {} 02744 #else 02745 unsigned int length = std::char_traits<T>::length(str); 02746 #endif 02747 02748 Integer v; 02749 02750 if (length == 0) 02751 return v; 02752 02753 switch (str[length-1]) 02754 { 02755 case 'h': 02756 case 'H': 02757 radix=16; 02758 break; 02759 case 'o': 02760 case 'O': 02761 radix=8; 02762 break; 02763 case 'b': 02764 case 'B': 02765 radix=2; 02766 break; 02767 default: 02768 radix=10; 02769 } 02770 02771 if (length > 2 && str[0] == '0' && str[1] == 'x') 02772 radix = 16; 02773 02774 for (unsigned i=0; i<length; i++) 02775 { 02776 word digit; 02777 02778 if (str[i] >= '0' && str[i] <= '9') 02779 digit = str[i] - '0'; 02780 else if (str[i] >= 'A' && str[i] <= 'F') 02781 digit = str[i] - 'A' + 10; 02782 else if (str[i] >= 'a' && str[i] <= 'f') 02783 digit = str[i] - 'a' + 10; 02784 else 02785 digit = radix; 02786 02787 if (digit < radix) 02788 { 02789 v *= radix; 02790 v += digit; 02791 } 02792 } 02793 02794 if (str[0] == '-') 02795 v.Negate(); 02796 02797 return v; 02798 } 02799 02800 Integer::Integer(const char *str) 02801 : reg(2), sign(POSITIVE) 02802 { 02803 *this = StringToInteger(str); 02804 } 02805 02806 Integer::Integer(const wchar_t *str) 02807 : reg(2), sign(POSITIVE) 02808 { 02809 *this = StringToInteger(str); 02810 } 02811 02812 unsigned int Integer::WordCount() const 02813 { 02814 return CountWords(reg, reg.size()); 02815 } 02816 02817 unsigned int Integer::ByteCount() const 02818 { 02819 unsigned wordCount = WordCount(); 02820 if (wordCount) 02821 return (wordCount-1)*WORD_SIZE + BytePrecision(reg[wordCount-1]); 02822 else 02823 return 0; 02824 } 02825 02826 unsigned int Integer::BitCount() const 02827 { 02828 unsigned wordCount = WordCount(); 02829 if (wordCount) 02830 return (wordCount-1)*WORD_BITS + BitPrecision(reg[wordCount-1]); 02831 else 02832 return 0; 02833 } 02834 02835 void Integer::Decode(const byte *input, unsigned int inputLen, Signedness s) 02836 { 02837 StringStore store(input, inputLen); 02838 Decode(store, inputLen, s); 02839 } 02840 02841 void Integer::Decode(BufferedTransformation &bt, unsigned int inputLen, Signedness s) 02842 { 02843 assert(bt.MaxRetrievable() >= inputLen); 02844 02845 byte b; 02846 bt.Peek(b); 02847 sign = ((s==SIGNED) && (b & 0x80)) ? NEGATIVE : POSITIVE; 02848 02849 while (inputLen>0 && (sign==POSITIVE ? b==0 : b==0xff)) 02850 { 02851 bt.Skip(1); 02852 inputLen--; 02853 bt.Peek(b); 02854 } 02855 02856 reg.CleanNew(RoundupSize(BytesToWords(inputLen))); 02857 02858 for (unsigned int i=inputLen; i > 0; i--) 02859 { 02860 bt.Get(b); 02861 reg[(i-1)/WORD_SIZE] |= b << ((i-1)%WORD_SIZE)*8; 02862 } 02863 02864 if (sign == NEGATIVE) 02865 { 02866 for (unsigned i=inputLen; i<reg.size()*WORD_SIZE; i++) 02867 reg[i/WORD_SIZE] |= 0xff << (i%WORD_SIZE)*8; 02868 TwosComplement(reg, reg.size()); 02869 } 02870 } 02871 02872 unsigned int Integer::MinEncodedSize(Signedness signedness) const 02873 { 02874 unsigned int outputLen = STDMAX(1U, ByteCount()); 02875 if (signedness == UNSIGNED) 02876 return outputLen; 02877 if (NotNegative() && (GetByte(outputLen-1) & 0x80)) 02878 outputLen++; 02879 if (IsNegative() && *this < -Power2(outputLen*8-1)) 02880 outputLen++; 02881 return outputLen; 02882 } 02883 02884 unsigned int Integer::Encode(byte *output, unsigned int outputLen, Signedness signedness) const 02885 { 02886 ArraySink sink(output, outputLen); 02887 return Encode(sink, outputLen, signedness); 02888 } 02889 02890 unsigned int Integer::Encode(BufferedTransformation &bt, unsigned int outputLen, Signedness signedness) const 02891 { 02892 if (signedness == UNSIGNED || NotNegative()) 02893 { 02894 for (unsigned int i=outputLen; i > 0; i--) 02895 bt.Put(GetByte(i-1)); 02896 } 02897 else 02898 { 02899 // take two's complement of *this 02900 Integer temp = Integer::Power2(8*STDMAX(ByteCount(), outputLen)) + *this; 02901 for (unsigned i=0; i<outputLen; i++) 02902 bt.Put(temp.GetByte(outputLen-i-1)); 02903 } 02904 return outputLen; 02905 } 02906 02907 void Integer::DEREncode(BufferedTransformation &bt) const 02908 { 02909 DERGeneralEncoder enc(bt, INTEGER); 02910 Encode(enc, MinEncodedSize(SIGNED), SIGNED); 02911 enc.MessageEnd(); 02912 } 02913 02914 void Integer::BERDecode(const byte *input, unsigned int len) 02915 { 02916 StringStore store(input, len); 02917 BERDecode(store); 02918 } 02919 02920 void Integer::BERDecode(BufferedTransformation &bt) 02921 { 02922 BERGeneralDecoder dec(bt, INTEGER); 02923 if (!dec.IsDefiniteLength() || dec.MaxRetrievable() < dec.RemainingLength()) 02924 BERDecodeError(); 02925 Decode(dec, dec.RemainingLength(), SIGNED); 02926 dec.MessageEnd(); 02927 } 02928 02929 void Integer::DEREncodeAsOctetString(BufferedTransformation &bt, unsigned int length) const 02930 { 02931 DERGeneralEncoder enc(bt, OCTET_STRING); 02932 Encode(enc, length); 02933 enc.MessageEnd(); 02934 } 02935 02936 void Integer::BERDecodeAsOctetString(BufferedTransformation &bt, unsigned int length) 02937 { 02938 BERGeneralDecoder dec(bt, OCTET_STRING); 02939 if (!dec.IsDefiniteLength() || dec.RemainingLength() != length) 02940 BERDecodeError(); 02941 Decode(dec, length); 02942 dec.MessageEnd(); 02943 } 02944 02945 unsigned int Integer::OpenPGPEncode(byte *output, unsigned int len) const 02946 { 02947 ArraySink sink(output, len); 02948 return OpenPGPEncode(sink); 02949 } 02950 02951 unsigned int Integer::OpenPGPEncode(BufferedTransformation &bt) const 02952 { 02953 word16 bitCount = BitCount(); 02954 bt.PutWord16(bitCount); 02955 return 2 + Encode(bt, BitsToBytes(bitCount)); 02956 } 02957 02958 void Integer::OpenPGPDecode(const byte *input, unsigned int len) 02959 { 02960 StringStore store(input, len); 02961 OpenPGPDecode(store); 02962 } 02963 02964 void Integer::OpenPGPDecode(BufferedTransformation &bt) 02965 { 02966 word16 bitCount; 02967 if (bt.GetWord16(bitCount) != 2 || bt.MaxRetrievable() < BitsToBytes(bitCount)) 02968 throw OpenPGPDecodeErr(); 02969 Decode(bt, BitsToBytes(bitCount)); 02970 } 02971 02972 void Integer::Randomize(RandomNumberGenerator &rng, unsigned int nbits) 02973 { 02974 const unsigned int nbytes = nbits/8 + 1; 02975 SecByteBlock buf(nbytes); 02976 rng.GenerateBlock(buf, nbytes); 02977 if (nbytes) 02978 buf[0] = (byte)Crop(buf[0], nbits % 8); 02979 Decode(buf, nbytes, UNSIGNED); 02980 } 02981 02982 void Integer::Randomize(RandomNumberGenerator &rng, const Integer &min, const Integer &max) 02983 { 02984 if (min > max) 02985 throw InvalidArgument("Integer: Min must be no greater than Max"); 02986 02987 Integer range = max - min; 02988 const unsigned int nbits = range.BitCount(); 02989 02990 do 02991 { 02992 Randomize(rng, nbits); 02993 } 02994 while (*this > range); 02995 02996 *this += min; 02997 } 02998 02999 bool Integer::Randomize(RandomNumberGenerator &rng, const Integer &min, const Integer &max, RandomNumberType rnType, const Integer &equiv, const Integer &mod) 03000 { 03001 return GenerateRandomNoThrow(rng, MakeParameters("Min", min)("Max", max)("RandomNumberType", rnType)("EquivalentTo", equiv)("Mod", mod)); 03002 } 03003 03004 class KDF2_RNG : public RandomNumberGenerator 03005 { 03006 public: 03007 KDF2_RNG(const byte *seed, unsigned int seedSize) 03008 : m_counter(0), m_counterAndSeed(seedSize + 4) 03009 { 03010 memcpy(m_counterAndSeed + 4, seed, seedSize); 03011 } 03012 03013 byte GenerateByte() 03014 { 03015 byte b; 03016 GenerateBlock(&b, 1); 03017 return b; 03018 } 03019 03020 void GenerateBlock(byte *output, unsigned int size) 03021 { 03022 UnalignedPutWord(BIG_ENDIAN_ORDER, m_counterAndSeed, m_counter); 03023 ++m_counter; 03024 P1363_KDF2<SHA1>::DeriveKey(output, size, m_counterAndSeed, m_counterAndSeed.size()); 03025 } 03026 03027 private: 03028 word32 m_counter; 03029 SecByteBlock m_counterAndSeed; 03030 }; 03031 03032 bool Integer::GenerateRandomNoThrow(RandomNumberGenerator &i_rng, const NameValuePairs &params) 03033 { 03034 Integer min = params.GetValueWithDefault("Min", Integer::Zero()); 03035 Integer max; 03036 if (!params.GetValue("Max", max)) 03037 { 03038 int bitLength; 03039 if (params.GetIntValue("BitLength", bitLength)) 03040 max = Integer::Power2(bitLength); 03041 else 03042 throw InvalidArgument("Integer: missing Max argument"); 03043 } 03044 if (min > max) 03045 throw InvalidArgument("Integer: Min must be no greater than Max"); 03046 03047 Integer equiv = params.GetValueWithDefault("EquivalentTo", Integer::Zero()); 03048 Integer mod = params.GetValueWithDefault("Mod", Integer::One()); 03049 03050 if (equiv.IsNegative() || equiv >= mod) 03051 throw InvalidArgument("Integer: invalid EquivalentTo and/or Mod argument"); 03052 03053 Integer::RandomNumberType rnType = params.GetValueWithDefault("RandomNumberType", Integer::ANY); 03054 03055 member_ptr<KDF2_RNG> kdf2Rng; 03056 ConstByteArrayParameter seed; 03057 if (params.GetValue("Seed", seed)) 03058 { 03059 ByteQueue bq; 03060 DERSequenceEncoder seq(bq); 03061 min.DEREncode(seq); 03062 max.DEREncode(seq); 03063 equiv.DEREncode(seq); 03064 mod.DEREncode(seq); 03065 DEREncodeUnsigned(seq, rnType); 03066 DEREncodeOctetString(seq, seed.begin(), seed.size()); 03067 seq.MessageEnd(); 03068 03069 SecByteBlock finalSeed(bq.MaxRetrievable()); 03070 bq.Get(finalSeed, finalSeed.size()); 03071 kdf2Rng.reset(new KDF2_RNG(finalSeed.begin(), finalSeed.size())); 03072 } 03073 RandomNumberGenerator &rng = kdf2Rng.get() ? (RandomNumberGenerator &)*kdf2Rng : i_rng; 03074 03075 switch (rnType) 03076 { 03077 case ANY: 03078 if (mod == One()) 03079 Randomize(rng, min, max); 03080 else 03081 { 03082 Integer min1 = min + (equiv-min)%mod; 03083 if (max < min1) 03084 return false; 03085 Randomize(rng, Zero(), (max - min1) / mod); 03086 *this *= mod; 03087 *this += min1; 03088 } 03089 return true; 03090 03091 case PRIME: 03092 { 03093 const PrimeSelector *pSelector = params.GetValueWithDefault("PointerToPrimeSelector", (const PrimeSelector *)NULL); 03094 03095 int i; 03096 i = 0; 03097 while (1) 03098 { 03099 if (++i==16) 03100 { 03101 // check if there are any suitable primes in [min, max] 03102 Integer first = min; 03103 if (FirstPrime(first, max, equiv, mod, pSelector)) 03104 { 03105 // if there is only one suitable prime, we're done 03106 *this = first; 03107 if (!FirstPrime(first, max, equiv, mod, pSelector)) 03108 return true; 03109 } 03110 else 03111 return false; 03112 } 03113 03114 Randomize(rng, min, max); 03115 if (FirstPrime(*this, STDMIN(*this+mod*PrimeSearchInterval(max), max), equiv, mod, pSelector)) 03116 return true; 03117 } 03118 } 03119 03120 default: 03121 throw InvalidArgument("Integer: invalid RandomNumberType argument"); 03122 } 03123 } 03124 03125 std::istream& operator>>(std::istream& in, Integer &a) 03126 { 03127 char c; 03128 unsigned int length = 0; 03129 SecBlock<char> str(length + 16); 03130 03131 std::ws(in); 03132 03133 do 03134 { 03135 in.read(&c, 1); 03136 str[length++] = c; 03137 if (length >= str.size()) 03138 str.Grow(length + 16); 03139 } 03140 while (in && (c=='-' || c=='x' || (c>='0' && c<='9') || (c>='a' && c<='f') || (c>='A' && c<='F') || c=='h' || c=='H' || c=='o' || c=='O' || c==',' || c=='.')); 03141 03142 if (in.gcount()) 03143 in.putback(c); 03144 str[length-1] = '\0'; 03145 a = Integer(str); 03146 03147 return in; 03148 } 03149 03150 std::ostream& operator<<(std::ostream& out, const Integer &a) 03151 { 03152 // Get relevant conversion specifications from ostream. 03153 long f = out.flags() & std::ios::basefield; // Get base digits. 03154 int base, block; 03155 char suffix; 03156 switch(f) 03157 { 03158 case std::ios::oct : 03159 base = 8; 03160 block = 8; 03161 suffix = 'o'; 03162 break; 03163 case std::ios::hex : 03164 base = 16; 03165 block = 4; 03166 suffix = 'h'; 03167 break; 03168 default : 03169 base = 10; 03170 block = 3; 03171 suffix = '.'; 03172 } 03173 03174 SecBlock<char> s(a.BitCount() / (BitPrecision(base)-1) + 1); 03175 Integer temp1=a, temp2; 03176 unsigned i=0; 03177 const char vec[]="0123456789ABCDEF"; 03178 03179 if (a.IsNegative()) 03180 { 03181 out << '-'; 03182 temp1.Negate(); 03183 } 03184 03185 if (!a) 03186 out << '0'; 03187 03188 while (!!temp1) 03189 { 03190 word digit; 03191 Integer::Divide(digit, temp2, temp1, base); 03192 s[i++]=vec[digit]; 03193 temp1=temp2; 03194 } 03195 03196 while (i--) 03197 { 03198 out << s[i]; 03199 // if (i && !(i%block)) 03200 // out << ","; 03201 } 03202 return out << suffix; 03203 } 03204 03205 Integer& Integer::operator++() 03206 { 03207 if (NotNegative()) 03208 { 03209 if (Increment(reg, reg.size())) 03210 { 03211 reg.CleanGrow(2*reg.size()); 03212 reg[reg.size()/2]=1; 03213 } 03214 } 03215 else 03216 { 03217 word borrow = Decrement(reg, reg.size()); 03218 assert(!borrow); 03219 if (WordCount()==0) 03220 *this = Zero(); 03221 } 03222 return *this; 03223 } 03224 03225 Integer& Integer::operator--() 03226 { 03227 if (IsNegative()) 03228 { 03229 if (Increment(reg, reg.size())) 03230 { 03231 reg.CleanGrow(2*reg.size()); 03232 reg[reg.size()/2]=1; 03233 } 03234 } 03235 else 03236 { 03237 if (Decrement(reg, reg.size())) 03238 *this = -One(); 03239 } 03240 return *this; 03241 } 03242 03243 void PositiveAdd(Integer &sum, const Integer &a, const Integer& b) 03244 { 03245 word carry; 03246 if (a.reg.size() == b.reg.size()) 03247 carry = Add(sum.reg, a.reg, b.reg, a.reg.size()); 03248 else if (a.reg.size() > b.reg.size()) 03249 { 03250 carry = Add(sum.reg, a.reg, b.reg, b.reg.size()); 03251 CopyWords(sum.reg+b.reg.size(), a.reg+b.reg.size(), a.reg.size()-b.reg.size()); 03252 carry = Increment(sum.reg+b.reg.size(), a.reg.size()-b.reg.size(), carry); 03253 } 03254 else 03255 { 03256 carry = Add(sum.reg, a.reg, b.reg, a.reg.size()); 03257 CopyWords(sum.reg+a.reg.size(), b.reg+a.reg.size(), b.reg.size()-a.reg.size()); 03258 carry = Increment(sum.reg+a.reg.size(), b.reg.size()-a.reg.size(), carry); 03259 } 03260 03261 if (carry) 03262 { 03263 sum.reg.CleanGrow(2*sum.reg.size()); 03264 sum.reg[sum.reg.size()/2] = 1; 03265 } 03266 sum.sign = Integer::POSITIVE; 03267 } 03268 03269 void PositiveSubtract(Integer &diff, const Integer &a, const Integer& b) 03270 { 03271 unsigned aSize = a.WordCount(); 03272 aSize += aSize%2; 03273 unsigned bSize = b.WordCount(); 03274 bSize += bSize%2; 03275 03276 if (aSize == bSize) 03277 { 03278 if (Compare(a.reg, b.reg, aSize) >= 0) 03279 { 03280 Subtract(diff.reg, a.reg, b.reg, aSize); 03281 diff.sign = Integer::POSITIVE; 03282 } 03283 else 03284 { 03285 Subtract(diff.reg, b.reg, a.reg, aSize); 03286 diff.sign = Integer::NEGATIVE; 03287 } 03288 } 03289 else if (aSize > bSize) 03290 { 03291 word borrow = Subtract(diff.reg, a.reg, b.reg, bSize); 03292 CopyWords(diff.reg+bSize, a.reg+bSize, aSize-bSize); 03293 borrow = Decrement(diff.reg+bSize, aSize-bSize, borrow); 03294 assert(!borrow); 03295 diff.sign = Integer::POSITIVE; 03296 } 03297 else 03298 { 03299 word borrow = Subtract(diff.reg, b.reg, a.reg, aSize); 03300 CopyWords(diff.reg+aSize, b.reg+aSize, bSize-aSize); 03301 borrow = Decrement(diff.reg+aSize, bSize-aSize, borrow); 03302 assert(!borrow); 03303 diff.sign = Integer::NEGATIVE; 03304 } 03305 } 03306 03307 Integer Integer::Plus(const Integer& b) const 03308 { 03309 Integer sum((word)0, STDMAX(reg.size(), b.reg.size())); 03310 if (NotNegative()) 03311 { 03312 if (b.NotNegative()) 03313 PositiveAdd(sum, *this, b); 03314 else 03315 PositiveSubtract(sum, *this, b); 03316 } 03317 else 03318 { 03319 if (b.NotNegative()) 03320 PositiveSubtract(sum, b, *this); 03321 else 03322 { 03323 PositiveAdd(sum, *this, b); 03324 sum.sign = Integer::NEGATIVE; 03325 } 03326 } 03327 return sum; 03328 } 03329 03330 Integer& Integer::operator+=(const Integer& t) 03331 { 03332 reg.CleanGrow(t.reg.size()); 03333 if (NotNegative()) 03334 { 03335 if (t.NotNegative()) 03336 PositiveAdd(*this, *this, t); 03337 else 03338 PositiveSubtract(*this, *this, t); 03339 } 03340 else 03341 { 03342 if (t.NotNegative()) 03343 PositiveSubtract(*this, t, *this); 03344 else 03345 { 03346 PositiveAdd(*this, *this, t); 03347 sign = Integer::NEGATIVE; 03348 } 03349 } 03350 return *this; 03351 } 03352 03353 Integer Integer::Minus(const Integer& b) const 03354 { 03355 Integer diff((word)0, STDMAX(reg.size(), b.reg.size())); 03356 if (NotNegative()) 03357 { 03358 if (b.NotNegative()) 03359 PositiveSubtract(diff, *this, b); 03360 else 03361 PositiveAdd(diff, *this, b); 03362 } 03363 else 03364 { 03365 if (b.NotNegative()) 03366 { 03367 PositiveAdd(diff, *this, b); 03368 diff.sign = Integer::NEGATIVE; 03369 } 03370 else 03371 PositiveSubtract(diff, b, *this); 03372 } 03373 return diff; 03374 } 03375 03376 Integer& Integer::operator-=(const Integer& t) 03377 { 03378 reg.CleanGrow(t.reg.size()); 03379 if (NotNegative()) 03380 { 03381 if (t.NotNegative()) 03382 PositiveSubtract(*this, *this, t); 03383 else 03384 PositiveAdd(*this, *this, t); 03385 } 03386 else 03387 { 03388 if (t.NotNegative()) 03389 { 03390 PositiveAdd(*this, *this, t); 03391 sign = Integer::NEGATIVE; 03392 } 03393 else 03394 PositiveSubtract(*this, t, *this); 03395 } 03396 return *this; 03397 } 03398 03399 Integer& Integer::operator<<=(unsigned int n) 03400 { 03401 const unsigned int wordCount = WordCount(); 03402 const unsigned int shiftWords = n / WORD_BITS; 03403 const unsigned int shiftBits = n % WORD_BITS; 03404 03405 reg.CleanGrow(RoundupSize(wordCount+BitsToWords(n))); 03406 ShiftWordsLeftByWords(reg, wordCount + shiftWords, shiftWords); 03407 ShiftWordsLeftByBits(reg+shiftWords, wordCount+BitsToWords(shiftBits), shiftBits); 03408 return *this; 03409 } 03410 03411 Integer& Integer::operator>>=(unsigned int n) 03412 { 03413 const unsigned int wordCount = WordCount(); 03414 const unsigned int shiftWords = n / WORD_BITS; 03415 const unsigned int shiftBits = n % WORD_BITS; 03416 03417 ShiftWordsRightByWords(reg, wordCount, shiftWords); 03418 if (wordCount > shiftWords) 03419 ShiftWordsRightByBits(reg, wordCount-shiftWords, shiftBits); 03420 if (IsNegative() && WordCount()==0) // avoid -0 03421 *this = Zero(); 03422 return *this; 03423 } 03424 03425 void PositiveMultiply(Integer &product, const Integer &a, const Integer &b) 03426 { 03427 unsigned aSize = RoundupSize(a.WordCount()); 03428 unsigned bSize = RoundupSize(b.WordCount()); 03429 03430 product.reg.CleanNew(RoundupSize(aSize+bSize)); 03431 product.sign = Integer::POSITIVE; 03432 03433 SecAlignedWordBlock workspace(aSize + bSize); 03434 AsymmetricMultiply(product.reg, workspace, a.reg, aSize, b.reg, bSize); 03435 } 03436 03437 void Multiply(Integer &product, const Integer &a, const Integer &b) 03438 { 03439 PositiveMultiply(product, a, b); 03440 03441 if (a.NotNegative() != b.NotNegative()) 03442 product.Negate(); 03443 } 03444 03445 Integer Integer::Times(const Integer &b) const 03446 { 03447 Integer product; 03448 Multiply(product, *this, b); 03449 return product; 03450 } 03451 03452 /* 03453 void PositiveDivide(Integer &remainder, Integer &quotient, 03454 const Integer &dividend, const Integer &divisor) 03455 { 03456 remainder.reg.CleanNew(divisor.reg.size()); 03457 remainder.sign = Integer::POSITIVE; 03458 quotient.reg.New(0); 03459 quotient.sign = Integer::POSITIVE; 03460 unsigned i=dividend.BitCount(); 03461 while (i--) 03462 { 03463 word overflow = ShiftWordsLeftByBits(remainder.reg, remainder.reg.size(), 1); 03464 remainder.reg[0] |= dividend[i]; 03465 if (overflow || remainder >= divisor) 03466 { 03467 Subtract(remainder.reg, remainder.reg, divisor.reg, remainder.reg.size()); 03468 quotient.SetBit(i); 03469 } 03470 } 03471 } 03472 */ 03473 03474 void PositiveDivide(Integer &remainder, Integer &quotient, 03475 const Integer &a, const Integer &b) 03476 { 03477 unsigned aSize = a.WordCount(); 03478 unsigned bSize = b.WordCount(); 03479 03480 if (!bSize) 03481 throw Integer::DivideByZero(); 03482 03483 if (a.PositiveCompare(b) == -1) 03484 { 03485 remainder = a; 03486 remainder.sign = Integer::POSITIVE; 03487 quotient = Integer::Zero(); 03488 return; 03489 } 03490 03491 aSize += aSize%2; // round up to next even number 03492 bSize += bSize%2; 03493 03494 remainder.reg.CleanNew(RoundupSize(bSize)); 03495 remainder.sign = Integer::POSITIVE; 03496 quotient.reg.CleanNew(RoundupSize(aSize-bSize+2)); 03497 quotient.sign = Integer::POSITIVE; 03498 03499 SecAlignedWordBlock T(aSize+2*bSize+4); 03500 Divide(remainder.reg, quotient.reg, T, a.reg, aSize, b.reg, bSize); 03501 } 03502 03503 void Integer::Divide(Integer &remainder, Integer &quotient, const Integer &dividend, const Integer &divisor) 03504 { 03505 PositiveDivide(remainder, quotient, dividend, divisor); 03506 03507 if (dividend.IsNegative()) 03508 { 03509 quotient.Negate(); 03510 if (remainder.NotZero()) 03511 { 03512 --quotient; 03513 remainder = divisor.AbsoluteValue() - remainder; 03514 } 03515 } 03516 03517 if (divisor.IsNegative()) 03518 quotient.Negate(); 03519 } 03520 03521 void Integer::DivideByPowerOf2(Integer &r, Integer &q, const Integer &a, unsigned int n) 03522 { 03523 q = a; 03524 q >>= n; 03525 03526 const unsigned int wordCount = BitsToWords(n); 03527 if (wordCount <= a.WordCount()) 03528 { 03529 r.reg.resize(RoundupSize(wordCount)); 03530 CopyWords(r.reg, a.reg, wordCount); 03531 SetWords(r.reg+wordCount, 0, r.reg.size()-wordCount); 03532 if (n % WORD_BITS != 0) 03533 r.reg[wordCount-1] %= (1 << (n % WORD_BITS)); 03534 } 03535 else 03536 { 03537 r.reg.resize(RoundupSize(a.WordCount())); 03538 CopyWords(r.reg, a.reg, r.reg.size()); 03539 } 03540 r.sign = POSITIVE; 03541 03542 if (a.IsNegative() && r.NotZero()) 03543 { 03544 --q; 03545 r = Power2(n) - r; 03546 } 03547 } 03548 03549 Integer Integer::DividedBy(const Integer &b) const 03550 { 03551 Integer remainder, quotient; 03552 Integer::Divide(remainder, quotient, *this, b); 03553 return quotient; 03554 } 03555 03556 Integer Integer::Modulo(const Integer &b) const 03557 { 03558 Integer remainder, quotient; 03559 Integer::Divide(remainder, quotient, *this, b); 03560 return remainder; 03561 } 03562 03563 void Integer::Divide(word &remainder, Integer &quotient, const Integer &dividend, word divisor) 03564 { 03565 if (!divisor) 03566 throw Integer::DivideByZero(); 03567 03568 assert(divisor); 03569 03570 if ((divisor & (divisor-1)) == 0) // divisor is a power of 2 03571 { 03572 quotient = dividend >> (BitPrecision(divisor)-1); 03573 remainder = dividend.reg[0] & (divisor-1); 03574 return; 03575 } 03576 03577 unsigned int i = dividend.WordCount(); 03578 quotient.reg.CleanNew(RoundupSize(i)); 03579 remainder = 0; 03580 while (i--) 03581 { 03582 quotient.reg[i] = word(MAKE_DWORD(dividend.reg[i], remainder) / divisor); 03583 remainder = word(MAKE_DWORD(dividend.reg[i], remainder) % divisor); 03584 } 03585 03586 if (dividend.NotNegative()) 03587 quotient.sign = POSITIVE; 03588 else 03589 { 03590 quotient.sign = NEGATIVE; 03591 if (remainder) 03592 { 03593 --quotient; 03594 remainder = divisor - remainder; 03595 } 03596 } 03597 } 03598 03599 Integer Integer::DividedBy(word b) const 03600 { 03601 word remainder; 03602 Integer quotient; 03603 Integer::Divide(remainder, quotient, *this, b); 03604 return quotient; 03605 } 03606 03607 word Integer::Modulo(word divisor) const 03608 { 03609 if (!divisor) 03610 throw Integer::DivideByZero(); 03611 03612 assert(divisor); 03613 03614 word remainder; 03615 03616 if ((divisor & (divisor-1)) == 0) // divisor is a power of 2 03617 remainder = reg[0] & (divisor-1); 03618 else 03619 { 03620 unsigned int i = WordCount(); 03621 03622 if (divisor <= 5) 03623 { 03624 dword sum=0; 03625 while (i--) 03626 sum += reg[i]; 03627 remainder = word(sum%divisor); 03628 } 03629 else 03630 { 03631 remainder = 0; 03632 while (i--) 03633 remainder = word(MAKE_DWORD(reg[i], remainder) % divisor); 03634 } 03635 } 03636 03637 if (IsNegative() && remainder) 03638 remainder = divisor - remainder; 03639 03640 return remainder; 03641 } 03642 03643 void Integer::Negate() 03644 { 03645 if (!!(*this)) // don't flip sign if *this==0 03646 sign = Sign(1-sign); 03647 } 03648 03649 int Integer::PositiveCompare(const Integer& t) const 03650 { 03651 unsigned size = WordCount(), tSize = t.WordCount(); 03652 03653 if (size == tSize) 03654 return CryptoPP::Compare(reg, t.reg, size); 03655 else 03656 return size > tSize ? 1 : -1; 03657 } 03658 03659 int Integer::Compare(const Integer& t) const 03660 { 03661 if (NotNegative()) 03662 { 03663 if (t.NotNegative()) 03664 return PositiveCompare(t); 03665 else 03666 return 1; 03667 } 03668 else 03669 { 03670 if (t.NotNegative()) 03671 return -1; 03672 else 03673 return -PositiveCompare(t); 03674 } 03675 } 03676 03677 Integer Integer::SquareRoot() const 03678 { 03679 if (!IsPositive()) 03680 return Zero(); 03681 03682 // overestimate square root 03683 Integer x, y = Power2((BitCount()+1)/2); 03684 assert(y*y >= *this); 03685 03686 do 03687 { 03688 x = y; 03689 y = (x + *this/x) >> 1; 03690 } while (y<x); 03691 03692 return x; 03693 } 03694 03695 bool Integer::IsSquare() const 03696 { 03697 Integer r = SquareRoot(); 03698 return *this == r.Squared(); 03699 } 03700 03701 bool Integer::IsUnit() const 03702 { 03703 return (WordCount() == 1) && (reg[0] == 1); 03704 } 03705 03706 Integer Integer::MultiplicativeInverse() const 03707 { 03708 return IsUnit() ? *this : Zero(); 03709 } 03710 03711 Integer a_times_b_mod_c(const Integer &x, const Integer& y, const Integer& m) 03712 { 03713 return x*y%m; 03714 } 03715 03716 Integer a_exp_b_mod_c(const Integer &x, const Integer& e, const Integer& m) 03717 { 03718 ModularArithmetic mr(m); 03719 return mr.Exponentiate(x, e); 03720 } 03721 03722 Integer Integer::Gcd(const Integer &a, const Integer &b) 03723 { 03724 return EuclideanDomainOf<Integer>().Gcd(a, b); 03725 } 03726 03727 Integer Integer::InverseMod(const Integer &m) const 03728 { 03729 assert(m.NotNegative()); 03730 03731 if (IsNegative() || *this>=m) 03732 return (*this%m).InverseMod(m); 03733 03734 if (m.IsEven()) 03735 { 03736 if (!m || IsEven()) 03737 return Zero(); // no inverse 03738 if (*this == One()) 03739 return One(); 03740 03741 Integer u = m.InverseMod(*this); 03742 return !u ? Zero() : (m*(*this-u)+1)/(*this); 03743 } 03744 03745 SecBlock<word> T(m.reg.size() * 4); 03746 Integer r((word)0, m.reg.size()); 03747 unsigned k = AlmostInverse(r.reg, T, reg, reg.size(), m.reg, m.reg.size()); 03748 DivideByPower2Mod(r.reg, r.reg, k, m.reg, m.reg.size()); 03749 return r; 03750 } 03751 03752 word Integer::InverseMod(const word mod) const 03753 { 03754 word g0 = mod, g1 = *this % mod; 03755 word v0 = 0, v1 = 1; 03756 word y; 03757 03758 while (g1) 03759 { 03760 if (g1 == 1) 03761 return v1; 03762 y = g0 / g1; 03763 g0 = g0 % g1; 03764 v0 += y * v1; 03765 03766 if (!g0) 03767 break; 03768 if (g0 == 1) 03769 return mod-v0; 03770 y = g1 / g0; 03771 g1 = g1 % g0; 03772 v1 += y * v0; 03773 } 03774 return 0; 03775 } 03776 03777 // ******************************************************** 03778 03779 ModularArithmetic::ModularArithmetic(BufferedTransformation &bt) 03780 { 03781 BERSequenceDecoder seq(bt); 03782 OID oid(seq); 03783 if (oid != ASN1::prime_field()) 03784 BERDecodeError(); 03785 modulus.BERDecode(seq); 03786 seq.MessageEnd(); 03787 result.reg.resize(modulus.reg.size()); 03788 } 03789 03790 void ModularArithmetic::DEREncode(BufferedTransformation &bt) const 03791 { 03792 DERSequenceEncoder seq(bt); 03793 ASN1::prime_field().DEREncode(seq); 03794 modulus.DEREncode(seq); 03795 seq.MessageEnd(); 03796 } 03797 03798 void ModularArithmetic::DEREncodeElement(BufferedTransformation &out, const Element &a) const 03799 { 03800 a.DEREncodeAsOctetString(out, MaxElementByteLength()); 03801 } 03802 03803 void ModularArithmetic::BERDecodeElement(BufferedTransformation &in, Element &a) const 03804 { 03805 a.BERDecodeAsOctetString(in, MaxElementByteLength()); 03806 } 03807 03808 const Integer& ModularArithmetic::Half(const Integer &a) const 03809 { 03810 if (a.reg.size()==modulus.reg.size()) 03811 { 03812 CryptoPP::DivideByPower2Mod(result.reg.begin(), a.reg, 1, modulus.reg, a.reg.size()); 03813 return result; 03814 } 03815 else 03816 return result1 = (a.IsEven() ? (a >> 1) : ((a+modulus) >> 1)); 03817 } 03818 03819 const Integer& ModularArithmetic::Add(const Integer &a, const Integer &b) const 03820 { 03821 if (a.reg.size()==modulus.reg.size() && b.reg.size()==modulus.reg.size()) 03822 { 03823 if (CryptoPP::Add(result.reg.begin(), a.reg, b.reg, a.reg.size()) 03824 || Compare(result.reg, modulus.reg, a.reg.size()) >= 0) 03825 { 03826 CryptoPP::Subtract(result.reg.begin(), result.reg, modulus.reg, a.reg.size()); 03827 } 03828 return result; 03829 } 03830 else 03831 { 03832 result1 = a+b; 03833 if (result1 >= modulus) 03834 result1 -= modulus; 03835 return result1; 03836 } 03837 } 03838 03839 Integer& ModularArithmetic::Accumulate(Integer &a, const Integer &b) const 03840 { 03841 if (a.reg.size()==modulus.reg.size() && b.reg.size()==modulus.reg.size()) 03842 { 03843 if (CryptoPP::Add(a.reg, a.reg, b.reg, a.reg.size()) 03844 || Compare(a.reg, modulus.reg, a.reg.size()) >= 0) 03845 { 03846 CryptoPP::Subtract(a.reg, a.reg, modulus.reg, a.reg.size()); 03847 } 03848 } 03849 else 03850 { 03851 a+=b; 03852 if (a>=modulus) 03853 a-=modulus; 03854 } 03855 03856 return a; 03857 } 03858 03859 const Integer& ModularArithmetic::Subtract(const Integer &a, const Integer &b) const 03860 { 03861 if (a.reg.size()==modulus.reg.size() && b.reg.size()==modulus.reg.size()) 03862 { 03863 if (CryptoPP::Subtract(result.reg.begin(), a.reg, b.reg, a.reg.size())) 03864 CryptoPP::Add(result.reg.begin(), result.reg, modulus.reg, a.reg.size()); 03865 return result; 03866 } 03867 else 03868 { 03869 result1 = a-b; 03870 if (result1.IsNegative()) 03871 result1 += modulus; 03872 return result1; 03873 } 03874 } 03875 03876 Integer& ModularArithmetic::Reduce(Integer &a, const Integer &b) const 03877 { 03878 if (a.reg.size()==modulus.reg.size() && b.reg.size()==modulus.reg.size()) 03879 { 03880 if (CryptoPP::Subtract(a.reg, a.reg, b.reg, a.reg.size())) 03881 CryptoPP::Add(a.reg, a.reg, modulus.reg, a.reg.size()); 03882 } 03883 else 03884 { 03885 a-=b; 03886 if (a.IsNegative()) 03887 a+=modulus; 03888 } 03889 03890 return a; 03891 } 03892 03893 const Integer& ModularArithmetic::Inverse(const Integer &a) const 03894 { 03895 if (!a) 03896 return a; 03897 03898 CopyWords(result.reg.begin(), modulus.reg, modulus.reg.size()); 03899 if (CryptoPP::Subtract(result.reg.begin(), result.reg, a.reg, a.reg.size())) 03900 Decrement(result.reg.begin()+a.reg.size(), 1, modulus.reg.size()-a.reg.size()); 03901 03902 return result; 03903 } 03904 03905 Integer ModularArithmetic::CascadeExponentiate(const Integer &x, const Integer &e1, const Integer &y, const Integer &e2) const 03906 { 03907 if (modulus.IsOdd()) 03908 { 03909 MontgomeryRepresentation dr(modulus); 03910 return dr.ConvertOut(dr.CascadeExponentiate(dr.ConvertIn(x), e1, dr.ConvertIn(y), e2)); 03911 } 03912 else 03913 return AbstractRing<Integer>::CascadeExponentiate(x, e1, y, e2); 03914 } 03915 03916 void ModularArithmetic::SimultaneousExponentiate(Integer *results, const Integer &base, const Integer *exponents, unsigned int exponentsCount) const 03917 { 03918 if (modulus.IsOdd()) 03919 { 03920 MontgomeryRepresentation dr(modulus); 03921 dr.SimultaneousExponentiate(results, dr.ConvertIn(base), exponents, exponentsCount); 03922 for (unsigned int i=0; i<exponentsCount; i++) 03923 results[i] = dr.ConvertOut(results[i]); 03924 } 03925 else 03926 AbstractRing<Integer>::SimultaneousExponentiate(results, base, exponents, exponentsCount); 03927 } 03928 03929 MontgomeryRepresentation::MontgomeryRepresentation(const Integer &m) // modulus must be odd 03930 : ModularArithmetic(m), 03931 u((word)0, modulus.reg.size()), 03932 workspace(5*modulus.reg.size()) 03933 { 03934 if (!modulus.IsOdd()) 03935 throw InvalidArgument("MontgomeryRepresentation: Montgomery representation requires an odd modulus"); 03936 03937 RecursiveInverseModPower2(u.reg, workspace, modulus.reg, modulus.reg.size()); 03938 } 03939 03940 const Integer& MontgomeryRepresentation::Multiply(const Integer &a, const Integer &b) const 03941 { 03942 word *const T = workspace.begin(); 03943 word *const R = result.reg.begin(); 03944 const unsigned int N = modulus.reg.size(); 03945 assert(a.reg.size()<=N && b.reg.size()<=N); 03946 03947 AsymmetricMultiply(T, T+2*N, a.reg, a.reg.size(), b.reg, b.reg.size()); 03948 SetWords(T+a.reg.size()+b.reg.size(), 0, 2*N-a.reg.size()-b.reg.size()); 03949 MontgomeryReduce(R, T+2*N, T, modulus.reg, u.reg, N); 03950 return result; 03951 } 03952 03953 const Integer& MontgomeryRepresentation::Square(const Integer &a) const 03954 { 03955 word *const T = workspace.begin(); 03956 word *const R = result.reg.begin(); 03957 const unsigned int N = modulus.reg.size(); 03958 assert(a.reg.size()<=N); 03959 03960 CryptoPP::Square(T, T+2*N, a.reg, a.reg.size()); 03961 SetWords(T+2*a.reg.size(), 0, 2*N-2*a.reg.size()); 03962 MontgomeryReduce(R, T+2*N, T, modulus.reg, u.reg, N); 03963 return result; 03964 } 03965 03966 Integer MontgomeryRepresentation::ConvertOut(const Integer &a) const 03967 { 03968 word *const T = workspace.begin(); 03969 word *const R = result.reg.begin(); 03970 const unsigned int N = modulus.reg.size(); 03971 assert(a.reg.size()<=N); 03972 03973 CopyWords(T, a.reg, a.reg.size()); 03974 SetWords(T+a.reg.size(), 0, 2*N-a.reg.size()); 03975 MontgomeryReduce(R, T+2*N, T, modulus.reg, u.reg, N); 03976 return result; 03977 } 03978 03979 const Integer& MontgomeryRepresentation::MultiplicativeInverse(const Integer &a) const 03980 { 03981 // return (EuclideanMultiplicativeInverse(a, modulus)<<(2*WORD_BITS*modulus.reg.size()))%modulus; 03982 word *const T = workspace.begin(); 03983 word *const R = result.reg.begin(); 03984 const unsigned int N = modulus.reg.size(); 03985 assert(a.reg.size()<=N); 03986 03987 CopyWords(T, a.reg, a.reg.size()); 03988 SetWords(T+a.reg.size(), 0, 2*N-a.reg.size()); 03989 MontgomeryReduce(R, T+2*N, T, modulus.reg, u.reg, N); 03990 unsigned k = AlmostInverse(R, T, R, N, modulus.reg, N); 03991 03992 // cout << "k=" << k << " N*32=" << 32*N << endl; 03993 03994 if (k>N*WORD_BITS) 03995 DivideByPower2Mod(R, R, k-N*WORD_BITS, modulus.reg, N); 03996 else 03997 MultiplyByPower2Mod(R, R, N*WORD_BITS-k, modulus.reg, N); 03998 03999 return result; 04000 } 04001 04002 template class AbstractRing<Integer>; 04003 04004 NAMESPACE_END

Generated on Fri Aug 13 09:56:54 2004 for Crypto++ by doxygen 1.3.7