socketft.cpp

00001 // socketft.cpp - written and placed in the public domain by Wei Dai
00002 
00003 #include "pch.h"
00004 #include "socketft.h"
00005 
00006 #ifdef SOCKETS_AVAILABLE
00007 
00008 #include "wait.h"
00009 
00010 #ifdef USE_BERKELEY_STYLE_SOCKETS
00011 #include <errno.h>
00012 #include <netdb.h>
00013 #include <unistd.h>
00014 #include <arpa/inet.h>
00015 #include <netinet/in.h>
00016 #include <sys/ioctl.h>
00017 #endif
00018 
00019 NAMESPACE_BEGIN(CryptoPP)
00020 
// Portable names for the socket error codes this file tests for or reports, so the
// code below does not care whether errors come from WSAGetLastError() or errno.
#ifndef USE_WINDOWS_STYLE_SOCKETS
static const int SOCKET_EINVAL = EINVAL;
static const int SOCKET_EWOULDBLOCK = EWOULDBLOCK;
#else
// Winsock takes address lengths as plain int.
typedef int socklen_t;
static const int SOCKET_EINVAL = WSAEINVAL;
static const int SOCKET_EWOULDBLOCK = WSAEWOULDBLOCK;
#endif
00029 
00030 Socket::Err::Err(socket_t s, const std::string& operation, int error)
00031         : OS_Error(IO_ERROR, "Socket: " + operation + " operation failed with error " + IntToString(error), operation, error)
00032         , m_s(s)
00033 {
00034 }
00035 
00036 Socket::~Socket()
00037 {
00038         if (m_own)
00039         {
00040                 try
00041                 {
00042                         CloseSocket();
00043                 }
00044                 catch (...)
00045                 {
00046                 }
00047         }
00048 }
00049 
00050 void Socket::AttachSocket(socket_t s, bool own)
00051 {
00052         if (m_own)
00053                 CloseSocket();
00054 
00055         m_s = s;
00056         m_own = own;
00057         SocketChanged();
00058 }
00059 
00060 socket_t Socket::DetachSocket()
00061 {
00062         socket_t s = m_s;
00063         m_s = INVALID_SOCKET;
00064         SocketChanged();
00065         return s;
00066 }
00067 
00068 void Socket::Create(int nType)
00069 {
00070         assert(m_s == INVALID_SOCKET);
00071         m_s = socket(AF_INET, nType, 0);
00072         CheckAndHandleError("socket", m_s);
00073         m_own = true;
00074         SocketChanged();
00075 }
00076 
00077 void Socket::CloseSocket()
00078 {
00079         if (m_s != INVALID_SOCKET)
00080         {
00081 #ifdef USE_WINDOWS_STYLE_SOCKETS
00082                 CheckAndHandleError_int("closesocket", closesocket(m_s));
00083 #else
00084                 CheckAndHandleError_int("close", close(m_s));
00085 #endif
00086                 m_s = INVALID_SOCKET;
00087                 SocketChanged();
00088         }
00089 }
00090 
00091 void Socket::Bind(unsigned int port, const char *addr)
00092 {
00093         sockaddr_in sa;
00094         memset(&sa, 0, sizeof(sa));
00095         sa.sin_family = AF_INET;
00096 
00097         if (addr == NULL)
00098                 sa.sin_addr.s_addr = htonl(INADDR_ANY);
00099         else
00100         {
00101                 unsigned long result = inet_addr(addr);
00102                 if (result == static_cast<unsigned long>(-1))   // Solaris doesn't have INADDR_NONE
00103                 {
00104                         SetLastError(SOCKET_EINVAL);
00105                         CheckAndHandleError_int("inet_addr", SOCKET_ERROR);
00106                 }
00107                 sa.sin_addr.s_addr = result;
00108         }
00109 
00110         sa.sin_port = htons((u_short)port);
00111 
00112         Bind((sockaddr *)&sa, sizeof(sa));
00113 }
00114 
00115 void Socket::Bind(const sockaddr *psa, socklen_t saLen)
00116 {
00117         assert(m_s != INVALID_SOCKET);
00118         // cygwin workaround: needs const_cast
00119         CheckAndHandleError_int("bind", bind(m_s, const_cast<sockaddr *>(psa), saLen));
00120 }
00121 
00122 void Socket::Listen(int backlog)
00123 {
00124         assert(m_s != INVALID_SOCKET);
00125         CheckAndHandleError_int("listen", listen(m_s, backlog));
00126 }
00127 
00128 bool Socket::Connect(const char *addr, unsigned int port)
00129 {
00130         assert(addr != NULL);
00131 
00132         sockaddr_in sa;
00133         memset(&sa, 0, sizeof(sa));
00134         sa.sin_family = AF_INET;
00135         sa.sin_addr.s_addr = inet_addr(addr);
00136 
00137         if (sa.sin_addr.s_addr == -1)   // Solaris doesn't have INADDR_NONE
00138         {
00139                 hostent *lphost = gethostbyname(addr);
00140                 if (lphost == NULL)
00141                 {
00142                         SetLastError(SOCKET_EINVAL);
00143                         CheckAndHandleError_int("gethostbyname", SOCKET_ERROR);
00144                 }
00145 
00146                 sa.sin_addr.s_addr = ((in_addr *)lphost->h_addr)->s_addr;
00147         }
00148 
00149         sa.sin_port = htons((u_short)port);
00150 
00151         return Connect((const sockaddr *)&sa, sizeof(sa));
00152 }
00153 
00154 bool Socket::Connect(const sockaddr* psa, socklen_t saLen)
00155 {
00156         assert(m_s != INVALID_SOCKET);
00157         int result = connect(m_s, const_cast<sockaddr*>(psa), saLen);
00158         if (result == SOCKET_ERROR && GetLastError() == SOCKET_EWOULDBLOCK)
00159                 return false;
00160         CheckAndHandleError_int("connect", result);
00161         return true;
00162 }
00163 
00164 bool Socket::Accept(Socket& target, sockaddr *psa, socklen_t *psaLen)
00165 {
00166         assert(m_s != INVALID_SOCKET);
00167         socket_t s = accept(m_s, psa, psaLen);
00168         if (s == INVALID_SOCKET && GetLastError() == SOCKET_EWOULDBLOCK)
00169                 return false;
00170         CheckAndHandleError_int("accept", s);
00171         target.AttachSocket(s, true);
00172         return true;
00173 }
00174 
00175 void Socket::GetSockName(sockaddr *psa, socklen_t *psaLen)
00176 {
00177         assert(m_s != INVALID_SOCKET);
00178         CheckAndHandleError_int("getsockname", getsockname(m_s, psa, psaLen));
00179 }
00180 
00181 unsigned int Socket::Send(const byte* buf, unsigned int bufLen, int flags)
00182 {
00183         assert(m_s != INVALID_SOCKET);
00184         int result = send(m_s, (const char *)buf, bufLen, flags);
00185         CheckAndHandleError_int("send", result);
00186         return result;
00187 }
00188 
00189 unsigned int Socket::Receive(byte* buf, unsigned int bufLen, int flags)
00190 {
00191         assert(m_s != INVALID_SOCKET);
00192         int result = recv(m_s, (char *)buf, bufLen, flags);
00193         CheckAndHandleError_int("recv", result);
00194         return result;
00195 }
00196 
00197 void Socket::ShutDown(int how)
00198 {
00199         assert(m_s != INVALID_SOCKET);
00200         int result = shutdown(m_s, how);
00201         CheckAndHandleError_int("shutdown", result);
00202 }
00203 
00204 void Socket::IOCtl(long cmd, unsigned long *argp)
00205 {
00206         assert(m_s != INVALID_SOCKET);
00207 #ifdef USE_WINDOWS_STYLE_SOCKETS
00208         CheckAndHandleError_int("ioctlsocket", ioctlsocket(m_s, cmd, argp));
00209 #else
00210         CheckAndHandleError_int("ioctl", ioctl(m_s, cmd, argp));
00211 #endif
00212 }
00213 
00214 bool Socket::SendReady(const timeval *timeout)
00215 {
00216         fd_set fds;
00217         FD_ZERO(&fds);
00218         FD_SET(m_s, &fds);
00219         int ready;
00220         if (timeout == NULL)
00221                 ready = select(m_s+1, NULL, &fds, NULL, NULL);
00222         else
00223         {
00224                 timeval timeoutCopy = *timeout; // select() modified timeout on Linux
00225                 ready = select(m_s+1, NULL, &fds, NULL, &timeoutCopy);
00226         }
00227         CheckAndHandleError_int("select", ready);
00228         return ready > 0;
00229 }
00230 
00231 bool Socket::ReceiveReady(const timeval *timeout)
00232 {
00233         fd_set fds;
00234         FD_ZERO(&fds);
00235         FD_SET(m_s, &fds);
00236         int ready;
00237         if (timeout == NULL)
00238                 ready = select(m_s+1, &fds, NULL, NULL, NULL);
00239         else
00240         {
00241                 timeval timeoutCopy = *timeout; // select() modified timeout on Linux
00242                 ready = select(m_s+1, &fds, NULL, NULL, &timeoutCopy);
00243         }
00244         CheckAndHandleError_int("select", ready);
00245         return ready > 0;
00246 }
00247 
00248 unsigned int Socket::PortNameToNumber(const char *name, const char *protocol)
00249 {
00250         int port = atoi(name);
00251         if (IntToString(port) == name)
00252                 return port;
00253 
00254         servent *se = getservbyname(name, protocol);
00255         if (!se)
00256                 throw Err(INVALID_SOCKET, "getservbyname", SOCKET_EINVAL);
00257         return ntohs(se->s_port);
00258 }
00259 
00260 void Socket::StartSockets()
00261 {
00262 #ifdef USE_WINDOWS_STYLE_SOCKETS
00263         WSADATA wsd;
00264         int result = WSAStartup(0x0002, &wsd);
00265         if (result != 0)
00266                 throw Err(INVALID_SOCKET, "WSAStartup", result);
00267 #endif
00268 }
00269 
00270 void Socket::ShutdownSockets()
00271 {
00272 #ifdef USE_WINDOWS_STYLE_SOCKETS
00273         int result = WSACleanup();
00274         if (result != 0)
00275                 throw Err(INVALID_SOCKET, "WSACleanup", result);
00276 #endif
00277 }
00278 
00279 int Socket::GetLastError()
00280 {
00281 #ifdef USE_WINDOWS_STYLE_SOCKETS
00282         return WSAGetLastError();
00283 #else
00284         return errno;
00285 #endif
00286 }
00287 
00288 void Socket::SetLastError(int errorCode)
00289 {
00290 #ifdef USE_WINDOWS_STYLE_SOCKETS
00291         WSASetLastError(errorCode);
00292 #else
00293         errno = errorCode;
00294 #endif
00295 }
00296 
00297 void Socket::HandleError(const char *operation) const
00298 {
00299         int err = GetLastError();
00300         throw Err(m_s, operation, err);
00301 }
00302 
00303 #ifdef USE_WINDOWS_STYLE_SOCKETS
00304 
00305 SocketReceiver::SocketReceiver(Socket &s)
00306         : m_s(s), m_resultPending(false), m_eofReceived(false)
00307 {
00308         m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
00309         m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
00310         memset(&m_overlapped, 0, sizeof(m_overlapped));
00311         m_overlapped.hEvent = m_event;
00312 }
00313 
00314 bool SocketReceiver::Receive(byte* buf, unsigned int bufLen)
00315 {
00316         assert(!m_resultPending && !m_eofReceived);
00317 
00318         DWORD flags = 0;
00319         // don't queue too much at once, or we might use up non-paged memory
00320         WSABUF wsabuf = {STDMIN(bufLen, 128U*1024U), (char *)buf};
00321         if (WSARecv(m_s, &wsabuf, 1, &m_lastResult, &flags, &m_overlapped, NULL) == 0)
00322         {
00323                 if (m_lastResult == 0)
00324                         m_eofReceived = true;
00325         }
00326         else
00327         {
00328                 switch (WSAGetLastError())
00329                 {
00330                 default:
00331                         m_s.CheckAndHandleError_int("WSARecv", SOCKET_ERROR);
00332                 case WSAEDISCON:
00333                         m_lastResult = 0;
00334                         m_eofReceived = true;
00335                         break;
00336                 case WSA_IO_PENDING:
00337                         m_resultPending = true;
00338                 }
00339         }
00340         return !m_resultPending;
00341 }
00342 
00343 void SocketReceiver::GetWaitObjects(WaitObjectContainer &container)
00344 {
00345         if (m_resultPending)
00346                 container.AddHandle(m_event);
00347         else if (!m_eofReceived)
00348                 container.SetNoWait();
00349 }
00350 
00351 unsigned int SocketReceiver::GetReceiveResult()
00352 {
00353         if (m_resultPending)
00354         {
00355                 DWORD flags = 0;
00356                 if (WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags))
00357                 {
00358                         if (m_lastResult == 0)
00359                                 m_eofReceived = true;
00360                 }
00361                 else
00362                 {
00363                         switch (WSAGetLastError())
00364                         {
00365                         default:
00366                                 m_s.CheckAndHandleError("WSAGetOverlappedResult", FALSE);
00367                         case WSAEDISCON:
00368                                 m_lastResult = 0;
00369                                 m_eofReceived = true;
00370                         }
00371                 }
00372                 m_resultPending = false;
00373         }
00374         return m_lastResult;
00375 }
00376 
00377 // *************************************************************
00378 
00379 SocketSender::SocketSender(Socket &s)
00380         : m_s(s), m_resultPending(false), m_lastResult(0)
00381 {
00382         m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
00383         m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
00384         memset(&m_overlapped, 0, sizeof(m_overlapped));
00385         m_overlapped.hEvent = m_event;
00386 }
00387 
00388 void SocketSender::Send(const byte* buf, unsigned int bufLen)
00389 {
00390         DWORD written = 0;
00391         // don't queue too much at once, or we might use up non-paged memory
00392         WSABUF wsabuf = {STDMIN(bufLen, 128U*1024U), (char *)buf};
00393         if (WSASend(m_s, &wsabuf, 1, &written, 0, &m_overlapped, NULL) == 0)
00394         {
00395                 m_resultPending = false;
00396                 m_lastResult = written;
00397         }
00398         else
00399         {
00400                 if (WSAGetLastError() != WSA_IO_PENDING)
00401                         m_s.CheckAndHandleError_int("WSASend", SOCKET_ERROR);
00402 
00403                 m_resultPending = true;
00404         }
00405 }
00406 
00407 void SocketSender::GetWaitObjects(WaitObjectContainer &container)
00408 {
00409         if (m_resultPending)
00410                 container.AddHandle(m_event);
00411         else
00412                 container.SetNoWait();
00413 }
00414 
00415 unsigned int SocketSender::GetSendResult()
00416 {
00417         if (m_resultPending)
00418         {
00419                 DWORD flags = 0;
00420                 BOOL result = WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags);
00421                 m_s.CheckAndHandleError("WSAGetOverlappedResult", result);
00422                 m_resultPending = false;
00423         }
00424         return m_lastResult;
00425 }
00426 
00427 #endif
00428 
00429 #ifdef USE_BERKELEY_STYLE_SOCKETS
00430 
00431 SocketReceiver::SocketReceiver(Socket &s)
00432         : m_s(s), m_eofReceived(false), m_lastResult(0)
00433 {
00434 }
00435 
00436 void SocketReceiver::GetWaitObjects(WaitObjectContainer &container)
00437 {
00438         if (!m_eofReceived)
00439                 container.AddReadFd(m_s);
00440 }
00441 
00442 bool SocketReceiver::Receive(byte* buf, unsigned int bufLen)
00443 {
00444         m_lastResult = m_s.Receive(buf, bufLen);
00445         if (bufLen > 0 && m_lastResult == 0)
00446                 m_eofReceived = true;
00447         return true;
00448 }
00449 
00450 unsigned int SocketReceiver::GetReceiveResult()
00451 {
00452         return m_lastResult;
00453 }
00454 
00455 SocketSender::SocketSender(Socket &s)
00456         : m_s(s), m_lastResult(0)
00457 {
00458 }
00459 
00460 void SocketSender::Send(const byte* buf, unsigned int bufLen)
00461 {
00462         m_lastResult = m_s.Send(buf, bufLen);
00463 }
00464 
00465 unsigned int SocketSender::GetSendResult()
00466 {
00467         return m_lastResult;
00468 }
00469 
00470 void SocketSender::GetWaitObjects(WaitObjectContainer &container)
00471 {
00472         container.AddWriteFd(m_s);
00473 }
00474 
00475 #endif
00476 
00477 NAMESPACE_END
00478 
00479 #endif  // #ifdef SOCKETS_AVAILABLE

Generated on Fri Dec 16 03:04:18 2005 for Crypto++ by  doxygen 1.4.5