Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

datatest.cpp

00001 #include "factory.h" 00002 #include "integer.h" 00003 #include "filters.h" 00004 #include "hex.h" 00005 #include "randpool.h" 00006 #include "files.h" 00007 #include "trunhash.h" 00008 #include <iostream> 00009 #include <memory> 00010 00011 USING_NAMESPACE(CryptoPP) 00012 USING_NAMESPACE(std) 00013 00014 RandomPool & GlobalRNG(); 00015 void RegisterFactories(); 00016 00017 typedef std::map<std::string, std::string> TestData; 00018 00019 class TestFailure : public Exception 00020 { 00021 public: 00022 TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {} 00023 }; 00024 00025 static const TestData *s_currentTestData = NULL; 00026 00027 void OutputTestData(const TestData &v) 00028 { 00029 for (TestData::const_iterator i = v.begin(); i != v.end(); ++i) 00030 { 00031 cerr << i->first << ": " << i->second << endl; 00032 } 00033 } 00034 00035 void SignalTestFailure() 00036 { 00037 OutputTestData(*s_currentTestData); 00038 throw TestFailure(); 00039 } 00040 00041 void SignalTestError() 00042 { 00043 OutputTestData(*s_currentTestData); 00044 throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test"); 00045 } 00046 00047 class TestDataNameValuePairs : public NameValuePairs 00048 { 00049 public: 00050 TestDataNameValuePairs(const TestData &data) : m_data(data) {} 00051 00052 virtual bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const 00053 { 00054 TestData::const_iterator i = m_data.find(name); 00055 if (i == m_data.end()) 00056 return false; 00057 00058 const std::string &value = i->second; 00059 00060 if (valueType == typeid(int)) 00061 *reinterpret_cast<int *>(pValue) = atoi(value.c_str()); 00062 else if (valueType == typeid(Integer)) 00063 *reinterpret_cast<Integer *>(pValue) = Integer((std::string(value) + "h").c_str()); 00064 else 00065 throw ValueTypeMismatch(name, typeid(std::string), valueType); 00066 00067 return true; 00068 } 00069 00070 private: 00071 const TestData &m_data; 00072 }; 00073 00074 const std::string & GetRequiredDatum(const TestData &data, const char *name) 00075 { 00076 TestData::const_iterator i = data.find(name); 00077 if (i == data.end()) 00078 SignalTestError(); 00079 return i->second; 00080 } 00081 00082 void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target) 00083 { 00084 std::string s1 = GetRequiredDatum(data, name), s2; 00085 00086 int repeat = 1; 00087 if (s1[0] == 'r') 00088 { 00089 repeat = atoi(s1.c_str()+1); 00090 s1 = s1.substr(s1.find(' ')+1); 00091 } 00092 00093 if (s1[0] == '\"') 00094 s2 = s1.substr(1, s1.find('\"', 1)-1); 00095 else if (s1.substr(0, 2) == "0x") 00096 StringSource(s1.substr(2), true, new HexDecoder(new StringSink(s2))); 00097 else 00098 StringSource(s1, true, new HexDecoder(new StringSink(s2))); 00099 00100 while (repeat--) 00101 target.Put((const byte *)s2.data(), s2.size()); 00102 } 00103 00104 std::string GetDecodedDatum(const TestData &data, const char *name) 00105 { 00106 std::string s; 00107 PutDecodedDatumInto(data, name, StringSink(s).Ref()); 00108 return s; 00109 } 00110 00111 void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv) 00112 { 00113 if (!pub.Validate(GlobalRNG(), 3)) 00114 SignalTestFailure(); 00115 if (!priv.Validate(GlobalRNG(), 3)) 00116 SignalTestFailure(); 00117 00118 /* EqualityComparisonFilter comparison; 00119 pub.Save(ChannelSwitch(comparison, "0")); 00120 pub.AssignFrom(priv); 00121 pub.Save(ChannelSwitch(comparison, "1")); 00122 comparison.ChannelMessageSeriesEnd("0"); 00123 comparison.ChannelMessageSeriesEnd("1"); 00124 */ 00125 } 00126 00127 void TestSignatureScheme(TestData &v) 00128 { 00129 std::string name = GetRequiredDatum(v, "Name"); 00130 std::string test = GetRequiredDatum(v, "Test"); 00131 00132 std::auto_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str())); 00133 std::auto_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str())); 00134 00135 TestDataNameValuePairs pairs(v); 00136 std::string keyFormat = GetRequiredDatum(v, "KeyFormat"); 00137 00138 if (keyFormat == "DER") 00139 verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref()); 00140 else if (keyFormat == "Component") 00141 verifier->AccessMaterial().AssignFrom(pairs); 00142 00143 if (test == "Verify" || test == "NotVerify") 00144 { 00145 VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN); 00146 PutDecodedDatumInto(v, "Signature", verifierFilter); 00147 PutDecodedDatumInto(v, "Message", verifierFilter); 00148 verifierFilter.MessageEnd(); 00149 if (verifierFilter.GetLastResult() == (test == "NotVerify")) 00150 SignalTestFailure(); 00151 } 00152 else if (test == "PublicKeyValid") 00153 { 00154 if (!verifier->GetMaterial().Validate(GlobalRNG(), 3)) 00155 SignalTestFailure(); 00156 } 00157 else 00158 goto privateKeyTests; 00159 00160 return; 00161 00162 privateKeyTests: 00163 if (keyFormat == "DER") 00164 signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref()); 00165 else if (keyFormat == "Component") 00166 signer->AccessMaterial().AssignFrom(pairs); 00167 00168 if (test == "KeyPairValidAndConsistent") 00169 { 00170 TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial()); 00171 } 00172 else if (test == "Sign") 00173 { 00174 SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout))); 00175 StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f)); 00176 SignalTestFailure(); 00177 } 00178 else if (test == "DeterministicSign") 00179 { 00180 SignalTestError(); 00181 assert(false); // TODO: implement 00182 } 00183 else if (test == "RandomSign") 00184 { 00185 SignalTestError(); 00186 assert(false); // TODO: implement 00187 } 00188 else if (test == "GenerateKey") 00189 { 00190 SignalTestError(); 00191 assert(false); 00192 } 00193 else 00194 { 00195 SignalTestError(); 00196 assert(false); 00197 } 00198 } 00199 00200 void TestEncryptionScheme(TestData &v) 00201 { 00202 std::string name = GetRequiredDatum(v, "Name"); 00203 std::string test = GetRequiredDatum(v, "Test"); 00204 00205 std::auto_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str())); 00206 std::auto_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str())); 00207 00208 std::string keyFormat = GetRequiredDatum(v, "KeyFormat"); 00209 00210 if (keyFormat == "DER") 00211 { 00212 decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref()); 00213 encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref()); 00214 } 00215 else if (keyFormat == "Component") 00216 { 00217 TestDataNameValuePairs pairs(v); 00218 decryptor->AccessMaterial().AssignFrom(pairs); 00219 encryptor->AccessMaterial().AssignFrom(pairs); 00220 } 00221 00222 if (test == "DecryptMatch") 00223 { 00224 std::string decrypted, expected = GetDecodedDatum(v, "Plaintext"); 00225 StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(decrypted))); 00226 if (decrypted != expected) 00227 SignalTestFailure(); 00228 } 00229 else if (test == "KeyPairValidAndConsistent") 00230 { 00231 TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial()); 00232 } 00233 else 00234 { 00235 SignalTestError(); 00236 assert(false); 00237 } 00238 } 00239 00240 void TestDigestOrMAC(TestData &v, bool testDigest) 00241 { 00242 std::string name = GetRequiredDatum(v, "Name"); 00243 std::string test = GetRequiredDatum(v, "Test"); 00244 00245 member_ptr<MessageAuthenticationCode> mac; 00246 member_ptr<HashTransformation> hash; 00247 HashTransformation *pHash = NULL; 00248 00249 if (testDigest) 00250 { 00251 hash.reset(ObjectFactoryRegistry<HashTransformation>::Registry().CreateObject(name.c_str())); 00252 pHash = hash.get(); 00253 } 00254 else 00255 { 00256 mac.reset(ObjectFactoryRegistry<MessageAuthenticationCode>::Registry().CreateObject(name.c_str())); 00257 pHash = mac.get(); 00258 std::string key = GetDecodedDatum(v, "Key"); 00259 mac->SetKey((const byte *)key.c_str(), key.size()); 00260 } 00261 00262 if (test == "Verify" || test == "VerifyTruncated" || test == "NotVerify") 00263 { 00264 int digestSize = pHash->DigestSize(); 00265 if (test == "VerifyTruncated") 00266 digestSize = atoi(GetRequiredDatum(v, "TruncatedSize").c_str()); 00267 TruncatedHashModule thash(*pHash, digestSize); 00268 HashVerificationFilter verifierFilter(thash, NULL, HashVerificationFilter::HASH_AT_BEGIN); 00269 PutDecodedDatumInto(v, "Digest", verifierFilter); 00270 PutDecodedDatumInto(v, "Message", verifierFilter); 00271 verifierFilter.MessageEnd(); 00272 if (verifierFilter.GetLastResult() == (test == "NotVerify")) 00273 SignalTestFailure(); 00274 } 00275 else 00276 { 00277 SignalTestError(); 00278 assert(false); 00279 } 00280 } 00281 00282 bool GetField(std::istream &is, std::string &name, std::string &value) 00283 { 00284 name.resize(0); // GCC workaround: 2.95.3 doesn't have clear() 00285 is >> name; 00286 if (name.empty()) 00287 return false; 00288 00289 if (name[name.size()-1] != ':') 00290 SignalTestError(); 00291 name.erase(name.size()-1); 00292 00293 while (is.peek() == ' ') 00294 is.ignore(1); 00295 00296 // VC60 workaround: getline bug 00297 char buffer[128]; 00298 value.resize(0); // GCC workaround: 2.95.3 doesn't have clear() 00299 bool continueLine; 00300 00301 do 00302 { 00303 do 00304 { 00305 is.get(buffer, sizeof(buffer)); 00306 value += buffer; 00307 } 00308 while (buffer[0] != 0); 00309 is.clear(); 00310 is.ignore(); 00311 00312 if (value[value.size()-1] == '\\') 00313 { 00314 value.resize(value.size()-1); 00315 continueLine = true; 00316 } 00317 else 00318 continueLine = false; 00319 00320 std::string::size_type i = value.find('#'); 00321 if (i != std::string::npos) 00322 value.erase(i); 00323 } 00324 while (continueLine); 00325 00326 return true; 00327 } 00328 00329 void OutputPair(const NameValuePairs &v, const char *name) 00330 { 00331 Integer x; 00332 bool b = v.GetValue(name, x); 00333 assert(b); 00334 cout << name << ": \\\n "; 00335 x.Encode(HexEncoder(new FileSink(cout), false, 64, "\\\n ").Ref(), x.MinEncodedSize()); 00336 cout << endl; 00337 } 00338 00339 void OutputNameValuePairs(const NameValuePairs &v) 00340 { 00341 std::string names = v.GetValueNames(); 00342 string::size_type i = 0; 00343 while (i < names.size()) 00344 { 00345 string::size_type j = names.find_first_of (';', i); 00346 00347 if (j == string::npos) 00348 return; 00349 else 00350 { 00351 std::string name = names.substr(i, j-i); 00352 if (name.find(':') == string::npos) 00353 OutputPair(v, name.c_str()); 00354 } 00355 00356 i = j + 1; 00357 } 00358 } 00359 00360 void TestDataFile(const std::string &filename, unsigned int &totalTests, unsigned int &failedTests) 00361 { 00362 std::ifstream file(filename.c_str()); 00363 TestData v; 00364 s_currentTestData = &v; 00365 std::string name, value, lastAlgName; 00366 00367 while (file) 00368 { 00369 while (file.peek() == '#') 00370 file.ignore(INT_MAX, '\n'); 00371 00372 if (file.peek() == '\n') 00373 v.clear(); 00374 00375 if (!GetField(file, name, value)) 00376 break; 00377 v[name] = value; 00378 00379 if (name == "Test") 00380 { 00381 bool failed = true; 00382 std::string algType = GetRequiredDatum(v, "AlgorithmType"); 00383 00384 if (lastAlgName != GetRequiredDatum(v, "Name")) 00385 { 00386 lastAlgName = GetRequiredDatum(v, "Name"); 00387 cout << "\nTesting " << algType.c_str() << " algorithm " << lastAlgName.c_str() << ".\n"; 00388 } 00389 00390 try 00391 { 00392 if (algType == "Signature") 00393 TestSignatureScheme(v); 00394 else if (algType == "AsymmetricCipher") 00395 TestEncryptionScheme(v); 00396 else if (algType == "MessageDigest") 00397 TestDigestOrMAC(v, true); 00398 else if (algType == "MAC") 00399 TestDigestOrMAC(v, false); 00400 else if (algType == "FileList") 00401 TestDataFile(GetRequiredDatum(v, "Test"), totalTests, failedTests); 00402 else 00403 SignalTestError(); 00404 failed = false; 00405 } 00406 catch (TestFailure &) 00407 { 00408 cout << "\nTest failed.\n"; 00409 } 00410 catch (CryptoPP::Exception &e) 00411 { 00412 cout << "\nCryptoPP::Exception caught: " << e.what() << endl; 00413 } 00414 catch (std::exception &e) 00415 { 00416 cout << "\nstd::exception caught: " << e.what() << endl; 00417 } 00418 00419 if (failed) 00420 { 00421 cout << "Skipping to next test.\n"; 00422 failedTests++; 00423 } 00424 else 00425 cout << "." << flush; 00426 00427 totalTests++; 00428 } 00429 } 00430 } 00431 00432 bool RunTestDataFile(const char *filename) 00433 { 00434 RegisterFactories(); 00435 unsigned int totalTests = 0, failedTests = 0; 00436 TestDataFile(filename, totalTests, failedTests); 00437 cout << "\nTests complete. Total tests = " << totalTests << ". Failed tests = " << failedTests << ".\n"; 00438 if (failedTests != 0) 00439 cout << "SOME TESTS FAILED!\n"; 00440 return failedTests == 0; 00441 }

Generated on Fri Aug 13 09:56:53 2004 for Crypto++ by doxygen 1.3.7