00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #include <iostream>
00024 #include <fstream>
00025 #include <sys/types.h>
00026 #include "socket.hh"
00027
00028 namespace Network
00029 {
00030
00031 Socket::Socket(SOCKET_KIND kind, SOCKET_VERSION version) :
00032 _kind(kind), _version(version), _state_timeout(0),
00033 _socket(0), _recv_flags(kind), _proto_kind(text), _empty_lines(false),
00034 _buffer(""), _tls(false)
00035 {
00036 _delim.push_back("\0");
00037 #ifdef LIBSOCKET_WIN
00038 WSADATA wsadata;
00039 if (WSAStartup(MAKEWORD(1, 1), &wsadata) != 0)
00040 throw WSAStartupError("WSAStartup failed", HERE);
00041 #endif
00042 #ifndef IPV6_ENABLED
00043 if (version == V6)
00044 throw Ipv6SupportError("lib was not compiled with ipv6 support", HERE);
00045 #endif
00046 }
00047
00048 Socket::Socket(SOCKET_KIND kind, PROTO_KIND pkind, SOCKET_VERSION version) :
00049 _kind(kind), _version(version), _state_timeout(0),
00050 _socket(0), _recv_flags(kind), _proto_kind(pkind), _empty_lines(false),
00051 _buffer(""), _tls(false)
00052 {
00053 _delim.push_back("\0");
00054 #ifdef LIBSOCKET_WIN
00055 WSADATA wsadata;
00056 if (WSAStartup(MAKEWORD(1, 1), &wsadata) != 0)
00057 throw WSAStartupError("WSAStartup failed", HERE);
00058 #endif
00059 #ifndef IPV6_ENABLED
00060 if (version == V6)
00061 throw Ipv6SupportError("lib was not compiled with ipv6 support", HERE);
00062 #endif
00063 }
00064
00065 Socket::~Socket()
00066 {
00067 }
00068
00069 void Socket::enable_tls()
00070 {
00071 #ifdef TLS
00072 int ret;
00073
00074 if (_kind != TCP)
00075 throw TLSError("You need to have a TCP connection", HERE);
00076 if (!connected())
00077 throw NoConnection("You need to have a connection", HERE);
00078
00079 gnutls_transport_set_ptr(_session, (gnutls_transport_ptr)_socket);
00080 ret = gnutls_handshake(_session);
00081 if (ret < 0)
00082 {
00083 close(_socket);
00084 gnutls_deinit(_session);
00085 throw TLSError(gnutls_strerror(ret), HERE);
00086 }
00087 #else
00088 throw TLSSupportError("lib was not compiled with TLS support", HERE);
00089 #endif
00090 }
00091
00092 void Socket::init_tls(GnuTLSKind kind,
00093 unsigned size, const std::string &certfile,
00094 const std::string &keyfile,
00095 const std::string &trustfile,
00096 const std::string &crlfile)
00097 {
00098 #ifdef TLS
00099 static bool init = false;
00100 static gnutls_dh_params dh_params;
00101 const int protocol_tls[] = { GNUTLS_TLS1, 0 };
00102 const int protocol_ssl[] = { GNUTLS_SSL3, 0 };
00103 const int cert_type_priority[] = { GNUTLS_CRT_X509,
00104 GNUTLS_CRT_OPENPGP, 0 };
00105
00106 if (!init)
00107 {
00108 gnutls_global_init();
00109 init = true;
00110 }
00111 _tls = true;
00112 _tls_main = true;
00113 gnutls_certificate_allocate_credentials(&_x509_cred);
00114 if (keyfile.size() > 0 && certfile.size() > 0)
00115 {
00116 std::ifstream key(keyfile.c_str()), cert(certfile.c_str());
00117 if (!key.is_open() || !cert.is_open())
00118 throw InvalidFile("key or cert invalid", HERE);
00119 key.close();
00120 cert.close();
00121
00122 _nbbits = size;
00123 if (trustfile.size() > 0)
00124 gnutls_certificate_set_x509_trust_file(_x509_cred, trustfile.c_str(),
00125 GNUTLS_X509_FMT_PEM);
00126 if (crlfile.size() > 0)
00127 gnutls_certificate_set_x509_crl_file(_x509_cred, crlfile.c_str(),
00128 GNUTLS_X509_FMT_PEM);
00129 gnutls_certificate_set_x509_key_file(_x509_cred, certfile.c_str(),
00130 keyfile.c_str(),
00131 GNUTLS_X509_FMT_PEM);
00132 gnutls_dh_params_init(&dh_params);
00133 gnutls_dh_params_generate2(dh_params, _nbbits);
00134 gnutls_certificate_set_dh_params(_x509_cred, dh_params);
00135
00136 if (gnutls_init(&_session, GNUTLS_SERVER))
00137 throw TLSError("gnutls_init failed", HERE);
00138 }
00139 else
00140 {
00141 if (gnutls_init(&_session, GNUTLS_CLIENT))
00142 throw TLSError("gnutls_init failed", HERE);
00143 }
00144
00145 gnutls_set_default_priority(_session);
00146 if (kind == TLS)
00147 gnutls_protocol_set_priority(_session, protocol_tls);
00148 else
00149 gnutls_protocol_set_priority(_session, protocol_ssl);
00150
00151 if (keyfile.size() > 0 && certfile.size() > 0)
00152 {
00153 gnutls_credentials_set(_session, GNUTLS_CRD_CERTIFICATE, _x509_cred);
00154 gnutls_certificate_server_set_request(_session, GNUTLS_CERT_REQUEST);
00155 gnutls_dh_set_prime_bits(_session, _nbbits);
00156 }
00157 else
00158 {
00159 gnutls_certificate_type_set_priority(_session, cert_type_priority);
00160 gnutls_credentials_set(_session, GNUTLS_CRD_CERTIFICATE, _x509_cred);
00161 }
00162 #else
00163 throw TLSSupportError("lib was not compiled with TLS support", HERE);
00164 #endif
00165 }
00166
00167 void Socket::_close(int socket) const
00168 {
00169 #ifndef LIBSOCKET_WIN
00170 if (socket < 0 || close(socket) < 0)
00171 throw CloseError("Close Error", HERE);
00172 socket = 0;
00173 #endif
00174 #ifdef TLS
00175 if (_tls)
00176 {
00177 std::cout << "Deletion..." << std::endl;
00178 gnutls_deinit(_session);
00179 if (_tls_main)
00180 {
00181 gnutls_certificate_free_credentials(_x509_cred);
00182 gnutls_global_deinit();
00183 }
00184 }
00185 #endif
00186 }
00187
00188 void Socket::_listen(int socket) const
00189 {
00190 if (socket < 0 || listen(socket, 5) < 0)
00191 throw ListenError("Listen Error", HERE);
00192 }
00193
00194 void Socket::_write_str(int socket, const std::string& str) const
00195 {
00196 int res = 1;
00197 unsigned int count = 0;
00198 const char *buf;
00199
00200 buf = str.c_str();
00201 if (socket < 0)
00202 throw NoConnection("No Socket", HERE);
00203 while (res && count < str.size())
00204 {
00205 #ifdef IPV6_ENABLED
00206 if (V4 == _version)
00207 #endif
00208 #ifdef TLS
00209 if (_tls)
00210 res = gnutls_record_send(_session, buf + count, str.size() - count);
00211 else
00212 #endif
00213 res = sendto(socket, buf + count, str.size() - count, SENDTO_FLAGS,
00214 (const struct sockaddr*)&_addr, sizeof(_addr));
00215 #ifdef IPV6_ENABLED
00216 else
00217 res = sendto(socket, buf + count, str.size() - count, SENDTO_FLAGS,
00218 (const struct sockaddr*)&_addr6, sizeof(_addr6));
00219 #endif
00220 if (res <= 0)
00221 throw ConnectionClosed("Connection Closed", HERE);
00222 count += res;
00223 }
00224 }
00225
00226 void Socket::_write_str_bin(int socket, const std::string& str) const
00227 {
00228 int res = 1;
00229 unsigned int count = 0;
00230 #ifdef LIBSOCKET_WIN
00231 char* buf = new char[str.size() + 2];
00232 #else
00233 char buf[str.size() + 2];
00234 #endif
00235 buf[0] = str.size() / 256;
00236 buf[1] = str.size() % 256;
00237 memcpy(buf + 2, str.c_str(), str.size());
00238 if (socket < 0)
00239 throw NoConnection("No Socket", HERE);
00240 while (res && count < str.size() + 2)
00241 {
00242 #ifdef IPV6_ENABLED
00243 if (V4 == _version)
00244 #endif
00245 #ifdef TLS
00246 if (_tls)
00247 res = gnutls_record_send(_session, buf + count, str.size() + 2 - count);
00248 else
00249 #endif
00250 res = sendto(socket, buf + count, str.size() + 2 - count,
00251 SENDTO_FLAGS,
00252 (const struct sockaddr*)&_addr, sizeof(_addr));
00253 #ifdef IPV6_ENABLED
00254 else
00255 res = sendto(socket, buf + count, str.size() + 2 - count,
00256 \ SENDTO_FLAGS,
00257 (const struct sockaddr*)&_addr6, sizeof(_addr6));
00258 #endif
00259 if (res <= 0)
00260 throw ConnectionClosed("Connection Closed", HERE);
00261 count += res;
00262 }
00263 #ifdef LIBSOCKET_WIN
00264 delete[] buf;
00265 #endif
00266 }
00267
00268 void Socket::_set_timeout(bool enable, int socket, int timeout)
00269 {
00270 fd_set fdset;
00271 struct timeval timetowait;
00272 int res;
00273
00274 if (enable)
00275 timetowait.tv_sec = timeout;
00276 else
00277 timetowait.tv_sec = 65535;
00278 timetowait.tv_usec = 0;
00279 FD_ZERO(&fdset);
00280 FD_SET(socket, &fdset);
00281 if (enable)
00282 res = select(socket + 1, &fdset, NULL, NULL, &timetowait);
00283 else
00284 res = select(socket + 1, &fdset, NULL, NULL, NULL);
00285 if (res < 0)
00286 throw SelectError("Select error", HERE);
00287 if (res == 0)
00288 throw Timeout("Timeout on socket", HERE);
00289 }
00290
00291 void Socket::write(const std::string& str)
00292 {
00293 if (_proto_kind == binary)
00294 _write_str_bin(_socket, str);
00295 else
00296 _write_str(_socket, str);
00297 }
00298
00299 bool Socket::connected() const
00300 {
00301 return _socket != 0;
00302 }
00303
00304 void Socket::allow_empty_lines()
00305 {
00306 _empty_lines = true;
00307 }
00308
00309 int Socket::get_socket()
00310 {
00311 return _socket;
00312 }
00313
00314 void Socket::add_delim(const std::string& delim)
00315 {
00316 _delim.push_back(delim);
00317 }
00318
00319 void Socket::del_delim(const std::string& delim)
00320 {
00321 std::list<std::string>::iterator it, it2;
00322
00323 for (it = _delim.begin(); it != _delim.end(); )
00324 {
00325 if (*it == delim)
00326 {
00327 it2 = it++;
00328 _delim.erase(it2);
00329 }
00330 else
00331 it++;
00332 }
00333 }
00334
00335 std::pair<int, int> Socket::_find_delim(const std::string& str, int start) const
00336 {
00337 int i = -1;
00338 int pos = -1, size = 0;
00339 std::list<std::string>::const_iterator it;
00340
00341
00342 if (_delim.size() > 0)
00343 {
00344 it = _delim.begin();
00345 while (it != _delim.end())
00346 {
00347 if (*it == "")
00348 i = str.find('\0', start);
00349 else
00350 i = str.find(*it, start);
00351 if ((i >= 0) && ((unsigned int)i < str.size()) &&
00352 (pos < 0 || i < pos))
00353 {
00354 pos = i;
00355 size = it->size() ? it->size() : 1;
00356 }
00357 it++;
00358 }
00359 }
00360 return std::pair<int, int>(pos, size);
00361 }
00362
00363 Socket& operator<<(Socket& s, const std::string& str)
00364 {
00365 s.write(str);
00366 return s;
00367 }
00368
00369 Socket& operator>>(Socket& s, std::string& str)
00370 {
00371 str = s.read();
00372 return s;
00373 }
00374 }