]> git.donarmstrong.com Git - bamtools.git/blobdiff - src/api/internal/io/TcpSocketEngine_win_p.cpp
Fixed: premature EOF issues & updated Windows implementation
[bamtools.git] / src / api / internal / io / TcpSocketEngine_win_p.cpp
index 6438b124c768b50ece40e0a07bcb6d0604afa718..c4d9b47d700361b8974ac20fa36133012caaa692 100644 (file)
@@ -1,53 +1,28 @@
+// ***************************************************************************
+// TcpSocketEngine_win_p.cpp (c) 2011 Derek Barnett
+// Marth Lab, Department of Biology, Boston College
+// ---------------------------------------------------------------------------
+// Last modified: 8 December 2011 (DB)
+// ---------------------------------------------------------------------------
+// Provides low-level implementation of TCP I/O for all Windows systems
+// ***************************************************************************
+
 #include "api/internal/io/TcpSocketEngine_p.h"
 #include "api/internal/io/NetWin_p.h"
 using namespace BamTools;
 using namespace BamTools::Internal;
 
+#include <cstring>
 #include <iostream>
+#include <sstream>
 using namespace std;
 
-// ------------------------
-// static utility methods
-// ------------------------
-
-namespace BamTools {
-namespace Internal {
-
-static inline
-void getPortAndAddress(const sockaddr* s, uint16_t& port, HostAddress& address) {
-
-    // IPv6
-    if (s->sa_family == AF_INET6) {
-        sockaddr_in6* ip6 = (sockaddr_in6*)s;
-        port = ntohs(ip6->sin6_port);
-        IPv6Address tmp;
-        memcpy(&tmp, &ip6->sin6_addr.in6_addr, sizeof(tmp));
-        address.SetAddress(tmp);
-        return;
-    }
-
-    // IPv4
-    if ( s->sa_family == AF_INET ) {
-        sockaddr_in* ip4 = (sockaddr_in*)s;
-        port = ntohl(ip4->sin_port);
-        address.SetAddress( ntohl(ip4->sin_addr) );
-        return;
-    }
-
-    // should be unreachable
-    BT_ASSERT_X(false, "TcpSocketEngine::getPortAndAddress() : unknown network protocol ");
-    return false;
-}
-
-} // namespace Internal
-} // namespace BamTools
-
 // --------------------------------
 // TcpSocketEngine implementation
 // --------------------------------
 
 void TcpSocketEngine::nativeClose(void) {
-    close(m_socketDescriptor);
+    closesocket(m_socketDescriptor);
 }
 
 bool TcpSocketEngine::nativeConnect(const HostAddress& address, const uint16_t port) {
@@ -88,55 +63,54 @@ bool TcpSocketEngine::nativeConnect(const HostAddress& address, const uint16_t p
     else BT_ASSERT_X(false, "TcpSocketEngine::nativeConnect() : unknown network protocol");
 
     // attempt conenction
-    int connectResult = connect(socketDescriptor, sockAddrPtr, sockAddrSize);
+    const int connectResult = WSAConnect(m_socketDescriptor, sockAddrPtr, sockAddrSize, 0, 0, 0, 0);
 
-    // if hit error
-    if ( connectResult == -1 ) {
+    // if failed, handle error
+    if ( connectResult == SOCKET_ERROR ) {
 
-        // see what error was encountered
-        switch ( errno ) {
+        // ensure state is set before checking error code
+        m_socketState = TcpSocket::UnconnectedState;
 
-            case EISCONN:
-                m_socketState = TcpSocket::ConnectedState;
+        // set error type/message depending on errorCode
+        const int errorCode = WSAGetLastError();
+        switch ( errorCode ) {
+            case WSANOTINITIALISED:
+                m_socketError = TcpSocket::UnknownSocketError;
+                m_errorString = "Windows socket functionality not properly initialized";
                 break;
-            case ECONNREFUSED:
-            case EINVAL:
+            case WSAEISCONN:
+                m_socketState = TcpSocket::ConnectedState; // socket already connected
+                break;
+            case WSAECONNREFUSED:
+            case WSAEINVAL:
                 m_socketError = TcpSocket::ConnectionRefusedError;
-                m_socketState = TcpSocket::UnconnectedState;
                 m_errorString = "connection refused";
                 break;
-            case ETIMEDOUT:
+            case WSAETIMEDOUT:
                 m_socketError = TcpSocket::NetworkError;
                 m_errorString = "connection timed out";
                 break;
-            case EHOSTUNREACH:
+            case WSAEHOSTUNREACH:
                 m_socketError = TcpSocket::NetworkError;
-                m_socketState = TcpSocket::UnconnectedState;
                 m_errorString = "host unreachable";
                 break;
-            case ENETUNREACH:
+            case WSAENETUNREACH:
                 m_socketError = TcpSocket::NetworkError;
-                m_socketState = TcpSocket::UnconnectedState;
                 m_errorString = "network unreachable";
                 break;
-            case EADDRINUSE:
-                m_socketError = TcpSocket::NetworkError;
+            case WSAEADDRINUSE:
+                m_socketError = TcpSocket::SocketResourceError;
                 m_errorString = "address already in use";
                 break;
-            case EACCES:
-            case EPERM:
+            case WSAEACCES:
                 m_socketError = TcpSocket::SocketAccessError;
-                m_socketState = TcpSocket::UnconnectedState;
                 m_errorString = "permission denied";
-            case EAFNOSUPPORT:
-            case EBADF:
-            case EFAULT:
-            case ENOTSOCK:
-                m_socketState = TcpSocket::UnconnectedState;
+                break;
             default:
                 break;
         }
 
+        // double check that we're not in 'connected' state; if so, return failure
         if ( m_socketState != TcpSocket::ConnectedState )
             return false;
     }
@@ -153,31 +127,35 @@ bool TcpSocketEngine::nativeCreateSocket(HostAddress::NetworkProtocol protocol)
     const int protocolNum = ( (protocol == HostAddress::IPv6Protocol) ? AF_INET6 : AF_INET );
 
     // attempt to create socket
-    int socketFd = socket(protocolNum, SOCK_STREAM, IPPROTO_TCP);
+    SOCKET socketFd = WSASocket(protocolNum, SOCK_STREAM, IPPROTO_TCP, 0, 0, WSA_FLAG_OVERLAPPED);
 
     // if we fetched an invalid socket descriptor
-    if ( socketFd <= 0 ) {
-
-        // see what error we got
-        switch ( errno ) {
-            case EPROTONOSUPPORT:
-            case EAFNOSUPPORT:
-            case EINVAL:
+    if ( socketFd == INVALID_SOCKET ) {
+
+        // set error type/message depending on error code
+        const int errorCode = WSAGetLastError();
+        switch ( errorCode ) {
+            case WSANOTINITIALISED:
+                m_socketError = TcpSocket::UnknownSocketError;
+                m_errorString = "Windows socket functionality not properly initialized";
+                break;
+            case WSAEAFNOSUPPORT:
+            case WSAESOCKTNOSUPPORT:
+            case WSAEPROTOTYPE:
+            case WSAEINVAL:
                 m_socketError = TcpSocket::UnsupportedSocketOperationError;
                 m_errorString = "protocol not supported";
                 break;
-            case ENFILE:
-            case EMFILE:
-            case ENOBUFS:
-            case ENOMEM:
+            case WSAEMFILE:
+            case WSAENOBUFS:
                 m_socketError = TcpSocket::SocketResourceError;
                 m_errorString = "out of resources";
                 break;
-            case EACCES:
-                m_socketError = TcpSocket::SocketAccessError;
-                m_errorString = "permission denied";
-                break;
             default:
+                m_socketError = TcpSocket::UnknownSocketError;
+                stringstream errStream("");
+                errStream << "WSA ErrorCode: " << errorCode;
+                m_errorString = errStream.str();
                 break;
         }
 
@@ -186,118 +164,78 @@ bool TcpSocketEngine::nativeCreateSocket(HostAddress::NetworkProtocol protocol)
     }
 
     // otherwise, store our socket FD & return success
-    m_socketDescriptor = socketFd;
+    m_socketDescriptor = static_cast<int>(socketFd);
     return true;
 }
 
-bool TcpSocketEngine::nativeFetchConnectionParameters(void) {
-
-    // reset addresses/ports
-    m_localAddress.Clear();
-    m_remoteAddress.Clear();
-    m_localPort  = 0;
-    m_remotePort = 0;
+int64_t TcpSocketEngine::nativeNumBytesAvailable(void) const {
 
-    // skip (return failure) if invalid socket FD
-    if ( m_socketDescriptor == -1 )
-        return false;
-
-    sockaddr sa;
-    BT_SOCKLEN_T sockAddrSize = sizeof(sa);
+    int64_t numBytes(0);
+    int64_t dummy(0);
+    DWORD bytesWritten(0);
 
-    // fetch local address info
-    memset(&sa, 0, sizeof(sa));
-    if ( getsockname(m_socketDescriptor, &sa, &sockAddrSize) == 0 ) {
-        getPortAndAddress(&sa, m_localPort, m_localAddress);
-    }
-    else if ( errno == EBADF ) {
-        m_socketError = TcpSocket::UnsupportedSocketOperationError;
-        m_errorString = "invalid socket descriptor";
-        return false;
-    }
-
-    // fetch remote address
-    if ( getpeername(m_socketDescriptor, &sa, &sockAddrSize) == 0 )
-        getPortAndAddress(&sa, m_remotePort, m_remoteAddress);
-
-    // return success
-    return true;
-}
-
-size_t TcpSocketEngine::nativeNumBytesAvailable(void) const {
-
-    // fetch number of bytes, return 0 on error
-    int numBytes(0);
-    if ( ioctl(m_socketDescriptor, FIONREAD, (char*)&numBytes) < 0 )
-        return 0;
-    return static_cast<size_t>(numBytes);
+    const int ioctlResult = WSAIoctl( m_socketDescriptor, FIONREAD
+                                    , &dummy, sizeof(dummy)
+                                    , &numBytes, sizeof(numBytes)
+                                    , &bytesWritten, 0, 0
+                                    );
+    return ( ioctlResult == SOCKET_ERROR ? -1 : numBytes );
 }
 
 int64_t TcpSocketEngine::nativeRead(char* dest, size_t max) {
 
+    // skip if invalid socket
     if ( !IsValid() )
         return -1;
 
-    ssize_t ret = read(m_socketDescriptor, dest, max);
-    if ( ret < 0 ) {
-        ret = -1;
-        switch ( errno ) {
-            case EAGAIN :
-                // No data was available for reading
-                ret = -2;
-                break;
-            case ECONNRESET :
-                ret = 0;
-                break;
-            default:
-                break;
-        }
-    }
+    // set up our WSA output buffer
+    WSABUF buf;
+    buf.buf = dest;
+    buf.len = max;
+
+    // attempt to read bytes
+    DWORD flags = 0;
+    DWORD bytesRead = 0;
+    const int readResult = WSARecv(m_socketDescriptor, &buf, 1, &bytesRead, &flags, 0, 0);
+    if ( readResult == SOCKET_ERROR )
+        return -1;
 
-    return static_cast<int64_t>(ret);
+    // return number of bytes read
+    return static_cast<int64_t>(bytesRead);
 }
 
 // negative value for msecs will block (forever) until
 int TcpSocketEngine::nativeSelect(int msecs, bool isRead) const {
 
-    // set up FD set
     fd_set fds;
     FD_ZERO(&fds);
     FD_SET(m_socketDescriptor, &fds);
 
-    // setup our timeout
     timeval tv;
     tv.tv_sec  = msecs / 1000;
     tv.tv_usec = (msecs % 1000) * 1000;
 
     // do 'select'
-    int ret;
     if ( isRead )
-        ret = select(m_socketDescriptor + 1, &fds, 0, 0, (msecs < 0 ? 0 : &tv));
+        return select(0, &fds, 0, 0, (msecs < 0 ? 0 : &tv));
     else
-        ret = select(m_socketDescriptor + 1, 0, &fds, 0, (msecs < 0 ? 0 : &tv));
-    return ret;
+        return select(0, 0, &fds, 0, (msecs < 0 ? 0 : &tv));
 }
 
 int64_t TcpSocketEngine::nativeWrite(const char* data, size_t length) {
 
-    ssize_t writtenBytes = write(m_socketDescriptor, data, length);
-    if ( writtenBytes < 0 ) {
-        switch (errno) {
-            case EPIPE:
-            case ECONNRESET:
-                writtenBytes = -1;
-                m_socketError = TcpSocket::RemoteHostClosedError;
-                m_errorString = "remote host closed connection";
-                Close();
-                break;
-            case EAGAIN:
-                writtenBytes = 0;
-                break;
-            default:
-                break;
-        }
-    }
+    // setup our WSA write buffer
+    WSABUF buf;
+    buf.buf = (char*)data;
+    buf.len = length;
+
+    // attempt to write bytes
+    DWORD flags = 0;
+    DWORD bytesWritten = 0;
+    const int writeResult = WSASend(m_socketDescriptor, &buf, 1, &bytesWritten, flags, 0, 0);
+    if ( writeResult == SOCKET_ERROR )
+        return -1;
 
-    return static_cast<int64_t>(writtenBytes);
+    // return number of bytes written
+    return static_cast<int64_t>(bytesWritten);
 }