#pragma warning(disable:4786)
#include "Socket.h"
#include <tstring>
#ifdef LINUX
#include <sys/types.h>
#include <sys/socket.h>
#include <fcntl.h>
#include <netdb.h>
#include <unistd.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#endif
#include <errno.h>
#include "Exception.h"
#include <stdio.h>
#include <string.h>
#include "Trace.h"
#include "Thread.h"
#include "WideChar.h"
#include "debug.h"

#ifndef LINUX
#define close(a) closesocket(a)
typedef int socklen_t;
#define MSG_NOSIGNAL 0
#endif

using std::_tstring;

/**
 * Construct an empty, not-good address.
 *
 * The underlying sockaddr is zeroed so that copying or printing a
 * default-constructed address never reads uninitialised bytes.
 */
Socket::Address::Address() :
  _good( false)
{
  memset(&addr, 0, sizeof(addr));
}

/**
 * Parses the internet address. Example addresses are:
 *
 * - 192.168.1.3:1024
 *
 * - www.google.com:80
 *
 * The text before the first ':' is resolved with gethostbyname(); the text
 * after it is parsed as the port. If there is no ':', the host cannot be
 * resolved, the resolver returns no addresses, or the port is outside
 * 0..65535, the address is left not-good (good() returns false).
 */
Socket::Address::Address( const std::_tstring& str ) :
    _good(false)
{
  // Start from an all-zero sockaddr_in so sin_zero and unused bytes are clean
  // even when parsing fails part-way.
  memset(&addr, 0, sizeof(addr));

  _tstring::size_type colon = str.find(_T(":"));
  if ( colon == _tstring::npos ) {
    return;
  }

  _tstring host = str.substr(0, colon);
  int port = atoi(NarrowChar(str.substr(colon + 1).c_str()));

  // htons() would silently truncate anything outside the 16-bit range.
  if ( port < 0 || port > 65535 ) {
    return;
  }

  // look up the host name.
  hostent* hoste = ::gethostbyname( NarrowChar( host.c_str() ) );

  // A resolver entry with an empty address list must not be dereferenced.
  if ( hoste == NULL || hoste->h_addr_list[0] == NULL ) {
    return;
  }

  addr.sin_family = AF_INET;
  addr.sin_port = htons(port);

  memcpy(&addr.sin_addr, hoste->h_addr_list[0], sizeof(addr.sin_addr));

  _good = true;
}

/**
 * Copy constructor: duplicates the raw socket address and its validity flag.
 */
Socket::Address::Address( const Socket::Address& other ) :
  _good( other._good )
{
  addr = other.addr;
}

/**
 * Copy assignment: takes over the raw socket address and validity flag.
 *
 * Self-assignment is a no-op. The previous memcpy-based version would have
 * copied a region onto itself, which memcpy does not allow.
 */
Socket::Address&
Socket::Address::operator=( const Socket::Address& other )
{
  if ( this != &other ) {
    addr = other.addr;
    _good = other._good;
  }
  return *this;
}

/**
 * Convert the address to a string of the form "a.b.c.d:port".
 *
 * The port is converted from network to host byte order with ntohs().
 * inet_ntoa() always yields a narrow string. In _UNICODE builds it is widened
 * first instead of being handed to a wide printf "%s", which some runtimes
 * (e.g. MSVC) interpret as a wide-string argument.
 */
std::_tstring
Socket::Address::toString() const 
{
#ifdef _UNICODE
  _tstring result = std::_tstring(WideChar(inet_ntoa(addr.sin_addr)));
#else
  _tstring result = inet_ntoa(addr.sin_addr);
#endif

  _TCHAR port[16];
  _stprintf(port, _T(":%d"), (int)ntohs(addr.sin_port));
  result += port;
  return result;
}

/**
 * Returns true if the object represents a valid address.
 */
bool
Socket::Address::good() const
{
  return _good;
}


/**
 * Create a new, unconnected socket.
 *
 * On non-LINUX builds the first construction also starts WinSock 2.2.
 *
 * @param p Socket::STREAM for a TCP socket or Socket::DGRAM for a UDP socket.
 * @throws Exception if WinSock cannot be started, the "tcp" protocol entry
 *         cannot be looked up, or socket() fails.
 */
Socket::Socket(Socket::Type p) :
  _socket( 0 ),
  _state( IDLE ),
  trace(_T("Socket"))
{
#ifndef LINUX
  // One-time WinSock start-up shared by every Socket instance.
  static bool winsockReady = false;
  if ( !winsockReady ) {
    WSADATA wsaData;
    if ( WSAStartup( MAKEWORD( 2, 2 ), &wsaData ) != 0 ) {
      /* Tell the user that we could not find a usable */
      /* WinSock DLL.                                  */
      throw Exception( _T("Could not find useable winsock dll") );
    }
    winsockReady = true;
  }
#endif

  // Only the success of the lookup matters; the entry itself is not used.
  if ( getprotobyname("tcp") == 0 ) {
    throw Exception(_T("Socket::Socket(): getprotobyname failed."));
  }

  const bool isStream = ( p == STREAM );
  _socket = socket( PF_INET,
                    isStream ? SOCK_STREAM : SOCK_DGRAM,
                    isStream ? IPPROTO_TCP : IPPROTO_UDP );

  if ( _socket == -1 ) {
    throw Exception((_tstring)_T("Socket::Socket(): socket() failed: ") +
                    getLastError());
  }

  // Enable SO_REUSEADDR; the result of setsockopt() is deliberately ignored.
  int reuse = 1;
  setsockopt( _socket, SOL_SOCKET, SO_REUSEADDR,
              (const char*)&reuse, sizeof(reuse) );
}

/**
 * Wrap an already-open socket handle (for example one returned by
 * ::accept()).
 *
 * The new object starts in the CONNECTED state and owns the handle: its
 * destructor shuts the connection down and closes the handle. The trace
 * channel is named "Socket", like the one in the other constructor.
 *
 * @param desc An open socket handle.
 */
Socket::Socket( SOCKET desc ) :
  _socket ( desc ),
  _state ( CONNECTED ),
  trace(_T("Socket"))
{
  trace << "Socket::Socket()" << endl;
}

/**
 * Tear the socket down: shut down any active connection, then release the
 * underlying handle (closesocket() on non-LINUX builds via the close macro).
 */
Socket::~Socket()
{
  shutdown();
  close(_socket);
}

/**
 * Perform the accept operation. Accept returns a new socket that represents
 * the connection.
 *
 * @return A heap-allocated Socket in the CONNECTED state (the caller must
 *         delete it), or 0 if ::accept() failed; the caller may try again.
 */
Socket*
Socket::accept()
{
  sockaddr peer;
  socklen_t peerLen = sizeof(peer);

  // Keep the handle in its native SOCKET type: on 64-bit WinSock a SOCKET
  // does not fit in an int.
  SOCKET conn = ::accept( _socket, &peer, &peerLen );

  if ( conn == (SOCKET)-1 ) {
    return 0;
  }

  return new Socket( conn );
}

/**
 * Perform the select operation on the socket. It waits for the socket to
 * become readable for at most the given number of milliseconds.
 *
 * @param ms maximum time to wait in milliseconds.
 * @return true if select() reported the socket readable; false if the time
 *         passed without data or select() failed.
 */
bool 
Socket::waitForData(int ms )
{
  fd_set rfds;
  FD_ZERO(&rfds);
  FD_SET(_socket, &rfds);

  // timeval is seconds plus *microseconds*; the remainder must be scaled.
  struct timeval tv;
  tv.tv_sec = ms / 1000;
  tv.tv_usec = (ms % 1000) * 1000;

  int retval = ::select(_socket + 1, &rfds, NULL, NULL, &tv);

  // > 0: readable, 0: timed out, -1: error. Only the first means "data".
  return retval > 0;
}

/**
 * Binds a socket to listen to the specified port number.
 *
 * @param port The port number, between 1 and 65535.
 */
void
Socket::bind( int port )
{
  sockaddr_in addr;
  addr.sin_family = AF_INET;
  addr.sin_port = htons(port);
  //addr.sin_addr = INADDR_ANY;
  
#ifdef LINUX
  inet_aton("0.0.0.0", &addr.sin_addr);
#else
  memset( &addr.sin_addr, 0, sizeof(addr.sin_addr) );
#endif

  trace << "socket::bind()" << endl;

  int result = ::bind( _socket, (sockaddr*)&addr, sizeof( sockaddr_in ) );

  if ( result == -1 ) {
    trace << "Couldn't bind to " << port << ": " << getLastError() << endl;
    throw Exception((_tstring)_T("Could not bind to port: ") + getLastError());
  }

  trace << "bind::bind() finish" << endl;

}

/**
 * Connect to the given address. Valid addresses are of the form hostname:port.
 *
 * @param hostport "host:port" string; NULL or a string without ':' yields false.
 * @return the result of connect(host, port), or false if the text is invalid.
 */
bool
Socket::connect( const _TCHAR* hostport )
{
  // Building a _tstring from NULL is undefined behaviour.
  if ( hostport == NULL ) {
    return false;
  }

  _tstring hostname(hostport);
  _tstring::size_type i = hostname.find(_T(":"));

  if ( i == _tstring::npos ) {
    trace << "Socket: hostname contained no port: " << hostport << endl; 
    return false;
  }

  _tstring host = hostname.substr(0, i);
  int port = atoi(NarrowChar(hostname.substr(i + 1).c_str()));

  return connect(host.c_str(), port);
}

/**
 * Connect to the given hostname and port.
 *
 * If the socket is currently CONNECTED, it is shut down first.
 * NOTE(review): a TCP socket that has been shut down generally cannot be
 * connected again; reconnecting probably needs a fresh socket - TODO confirm.
 *
 * @param host Host name or dotted-quad address; NULL yields false.
 * @param port Remote port number.
 * @return true if ::connect() succeeded (the socket becomes CONNECTED);
 *         false if the host could not be resolved or the connect failed.
 */
bool
Socket::connect( const _TCHAR* host, int port )
{
  if ( host == NULL ) {
    return false;
  }

  // if we are connected, disconnect first.
  if ( _state == CONNECTED ) 
  {
    shutdown();
  }

  // look up the host name.
  hostent* hoste = ::gethostbyname( NarrowChar(host) );

  // A resolver entry with an empty address list must not be dereferenced.
  if ( hoste == NULL || hoste->h_addr_list[0] == NULL ) {
    return false;
  }

  // Zero the whole structure so sin_zero is clean.
  sockaddr_in addr;
  memset( &addr, 0, sizeof(addr) );
  addr.sin_family = AF_INET;
  addr.sin_port = htons(port);
  memcpy(&addr.sin_addr, hoste->h_addr_list[0], sizeof(addr.sin_addr));

  trace << "Connecting to " << inet_ntoa(addr.sin_addr) << ":" 
        << port << endl;

  // attempt a connection.
  int result = ::connect( _socket, (sockaddr*)&addr, sizeof(addr));

  if ( result == -1 ) {
    return false;
  }

  _state = CONNECTED;

  return true;
}

/**
 * Perform the listen operation on the socket. Throws Exception on failure.
 */
void
Socket::listen()
{
  const unsigned int backlog = SOMAXCONN;

  if ( -1 == ::listen ( _socket, backlog ) ) {
    throw Exception(_T("Socket::listen() failed."));
  }
}

/**
 * Read up to @p length bytes from the socket into @p data.
 *
 * @return The raw result of ::recv(); -1 indicates failure.
 */
int
Socket::recv( void* data, int length )
{
  char* into = static_cast<char*>(data);
  return ::recv( _socket, into, length, MSG_NOSIGNAL );
}

/**
 * Receive one datagram into @p data and record the sender in @p pAddr.
 *
 * On success, pAddr is filled with the sender's address and marked good.
 * On failure, its validity flag is left as it was.
 *
 * @return The raw result of ::recvfrom(); -1 indicates failure.
 */
int
Socket::recvfrom( Address* pAddr, void* data, int length )
{
  socklen_t fromLen = sizeof(pAddr->addr);
  sockaddr* from = (sockaddr*)&pAddr->addr;

  const int received = ::recvfrom( _socket, static_cast<char*>(data), length,
                                   MSG_NOSIGNAL, from, &fromLen );

  if ( received == -1 ) {
    return received;
  }

  pAddr->_good = true;
  return received;
}

/**
 * Write @p length bytes from @p data to the connected peer.
 *
 * MSG_NOSIGNAL is passed; it is defined as 0 on non-LINUX builds.
 *
 * @return The raw result of ::send(); -1 indicates failure.
 */
int
Socket::send( void* data, int length )
{
  const char* payload = static_cast<const char*>(data);
  return ::send( _socket, payload, length, MSG_NOSIGNAL );
}

/**
 * Perform sendto operation.
 */
int 
Socket::sendto( const Address* paddr, void* data, int length ) 
{
  if ( !paddr->good()) {
    trace << "Socket::sendto: Address is bad." << endl;
    return -1;
  }

  int result = ::sendto(_socket, (const char*)data, length, 0, 
			(sockaddr*)&paddr->addr, 
			sizeof(paddr->addr));

  trace << " sent " << result << " bytes." << endl;

  if ( result == -1 ) {
    trace << getLastError() << endl;
  }

  return result;
}

/**
 * Shut down both directions of the connection, unless the socket is IDLE.
 *
 * Afterwards the socket is marked IDLE. A second call, such as the one the
 * destructor makes after an explicit shutdown(), is then a no-op, and
 * connected() reports false.
 */
void
Socket::shutdown()
{
  if ( _state == IDLE ) {
    return;
  }

  // 2 == SHUT_RDWR (POSIX) == SD_BOTH (WinSock): stop sends and receives.
  ::shutdown( _socket, 2 );

  _state = IDLE;
}

/**
 * Return a string representing the last socket error.
 *
 * On LINUX builds this is strerror(errno), widened in _UNICODE builds. On
 * WinSock builds, socket calls report failures through WSAGetLastError()
 * rather than errno, so that error code is formatted instead.
 */
std::_tstring Socket::getLastError()
{
#ifndef LINUX
  _TCHAR buffer[64];
  _stprintf(buffer, _T("WinSock error %d"), WSAGetLastError());
  return std::_tstring(buffer);
#else
#ifdef _UNICODE
  return std::_tstring(WideChar(strerror(errno)));
#else
  return std::_tstring(strerror(errno));
#endif
#endif
}

/**
 * @return true while the socket is in the CONNECTED state.
 */
bool
Socket::connected()
{
  return CONNECTED == _state;
}

class SocketServerTestThread : public Thread
{
public:
  SocketServerTestThread(Condition& cond, Mutex& mutex) : 
    cond(cond),
    mutex(mutex)
  {}

  Condition& cond;
  Mutex& mutex;

  virtual void run()
  {
    trace << "Server is running." << endl;
    try {
      Socket socket;
      {
	Mutex::Scope lock(mutex);
	trace << "About to bind to socket." << endl;
	socket.bind(7302);
	trace << "Server about to listen." << endl;
	socket.listen();

	trace << "Server about to accept." << endl;
	cond.signal();
      }
      socket.waitForData(5000);
      Socket* newSocket = socket.accept();

      trace << "Server about to receive data." << endl;
      newSocket->waitForData(5000);
      unsigned char buff[80];
      buff[79] = 0;
      int bytes = newSocket->recv(buff, 80);

      if ( bytes == -1 ) {
	trace << newSocket->getLastError() << endl;
      }

      trace << "Server received " << bytes << " bytes." << endl;
      trace << "Server got: " << (char*)buff << endl;
      delete newSocket;
    } catch( Exception e ) {
      trace << "Server got an exception: " << e.msg << endl;
      cond.signal();
    }
  }
};

bool
Socket::UnitTest()
{  
  Condition cond;
  Mutex mutex;
  SocketServerTestThread server(cond, mutex);

  mutex.lock(DBG);
  server.start();
  cond.wait(mutex, DBG);
  mutex.unlock(DBG);

  try {

    Socket socket;
    ::trace << "Client is connecting." << endl;
    if ( !socket.connect(_T("127.0.0.1"), 7302) ) {
      throw Exception((_tstring)_T("Client could not connect to host: ") + 
		      socket.getLastError());
    }

    ::trace << "Client about to send data." << endl;
    fflush(0);
    Thread::sleep(1000);
    const char* msg = "The quick brown fox jumped over the lazy dog.";
    int sent = socket.send((void*)msg, strlen(msg));
    if ( sent == -1 ) {
      ::trace << socket.getLastError() << endl;
    }

    fflush(0);
  } catch ( Exception e ) {
    ::trace << "Client got exception: " << e.msg << endl;
  }

  server.join();
  
  // now lets test connectionless.
  try {
    Socket sender(Socket::DGRAM);
    Socket receiver(Socket::DGRAM);
    
    receiver.bind(7301);

    Socket::Address addr(_T("localhost:7301"));
    const char* message = "Kate thinks you're cute.";
    if ( 0 >= sender.sendto(&addr, (void*)message, strlen(message)+1) ) {
      throw Exception((_tstring)_T("Socket test failed: ") +
		      _T("Could not send UDP. Error: ")
		      + sender.getLastError());
    }

    if ( !receiver.waitForData(5000)) {
      throw Exception((_tstring)_T("Socket test failed: ") +
		      _T("Did not get UDP sent to self. Error: ")
		      + receiver.getLastError());
    }

    char buff[80];
    buff[0] = 0;
    Socket::Address addrfrom;
    if ( 0 >= receiver.recvfrom(&addrfrom, buff, sizeof(buff) ) ) {
      throw Exception((_tstring)_T("Socket test failure: ") + 
		      _T("Could not read UDP packet: ") +
		      receiver.getLastError());
    }

    ::trace << "Got UDP message: " << buff << endl;
  } catch ( Exception e ) {
    ::trace << "Got exception: " << e.msg << endl;
  }

  return true;
}

