(module zmq

        (zmq-default-context zmq-io-threads zmq-version
         make-context terminate-context context?
         make-socket socket? close-socket bind-socket connect-socket
         socket-option-set! socket-option socket-fd socket-pointer
         send-message receive-message receive-message*
         make-poll-item poll poll-item-socket
         poll-item-fd poll-item-in? poll-item-out? poll-item-error?
         curve-keypair)

        (import scheme (chicken base) (chicken foreign)
                (chicken bitwise) (chicken memory) (chicken blob)
                (chicken memory representation) (chicken gc)
                (chicken format)
                srfi-1 srfi-4 srfi-18 srfi-13 foreigners)

(import-for-syntax srfi-1)

(foreign-declare "#include <zmq.h>")
(foreign-declare "#include <errno.h>")

;; Scheme-side wrapper around a ZeroMQ context.
;;   pointer - raw context handle as returned by zmq_init (see make-context)
;;   sockets - a mutex whose "specific" slot holds the list of raw socket
;;             pointers created on this context (see make-socket)
(define-record context pointer sockets)
;; Foreign type used when a raw context handle is passed to or from C.
(define-foreign-type context c-pointer)

;; Foreign type for a pointer to a zmq_msg_t structure.
(define-foreign-type message (c-pointer "zmq_msg_t"))

;; Scheme-side wrapper around a ZeroMQ socket.
;;   pointer - raw socket handle; set to #f by close-socket once closed
;;   mutex   - locked by send-message / receive-message around their work
;;   message - malloc'ed zmq_msg_t buffer reused for every send/receive
(define-record socket pointer mutex message)
;; Foreign type used when a raw socket handle is passed to or from C.
(define-foreign-type socket c-pointer)

;; Foreign enum `socket-type' (C type int): maps the Scheme symbols below
;; to the corresponding ZMQ_* socket type constants.  The converters are
;; named socket-type->int and int->socket-type.  make-socket passes one
;; of these symbols as the `type' argument of zmq_socket.
(define-foreign-enum-type (socket-type int)
  (socket-type->int int->socket-type)
  ((pair) ZMQ_PAIR)
  ((pub) ZMQ_PUB)
  ((sub) ZMQ_SUB)
  ((req) ZMQ_REQ)
  ((rep) ZMQ_REP)
  ((xreq) ZMQ_XREQ)
  ((xrep) ZMQ_XREP)
  ((pull) ZMQ_PULL)
  ((push) ZMQ_PUSH))

;; Foreign enum `socket-option' (C type int): maps option symbols to the
;; ZMQ_* option constants accepted by zmq_setsockopt / zmq_getsockopt.
;;
;; Each curve/* key option appears twice: the plain symbol and a `-blob'
;; variant resolve to the same C constant.  socket-option-set! uses the
;; symbol to decide whether the value is passed as a string or as a blob.
(define-foreign-enum-type (socket-option int)
  (socket-option->int int->socket-option)
  ((affinity) ZMQ_AFFINITY)
  ((sndhwm) ZMQ_SNDHWM)
  ((rcvhwm) ZMQ_RCVHWM)
  ((identity) ZMQ_IDENTITY)
  ((subscribe) ZMQ_SUBSCRIBE)
  ((unsubscribe) ZMQ_UNSUBSCRIBE)
  ((rate) ZMQ_RATE)
  ((recovery-ivl) ZMQ_RECOVERY_IVL)
  ((sndbuf) ZMQ_SNDBUF)
  ((rcvbuf) ZMQ_RCVBUF)
  ((rcvmore) ZMQ_RCVMORE)
  ((fd) ZMQ_FD)
  ((curve/pubkey) ZMQ_CURVE_PUBLICKEY)
  ((curve/pubkey-blob) ZMQ_CURVE_PUBLICKEY)
  ((curve/secretkey) ZMQ_CURVE_SECRETKEY)
  ((curve/secretkey-blob) ZMQ_CURVE_SECRETKEY)
  ((curve/server) ZMQ_CURVE_SERVER)
  ((curve/serverkey) ZMQ_CURVE_SERVERKEY)
  ((curve/serverkey-blob) ZMQ_CURVE_SERVERKEY))


;; Association list grouping option symbols by the kind of value they
;; carry.  The socket-option getter consults the `integer' and `boolean'
;; entries to pick the return type; the `string' entry is not looked up
;; anywhere in this file.
(define socket-options
  '((integer sndhwm rcvhwm affinity rate recovery-ivl sndbuf rcvbuf)
    (boolean rcvmore)
    (string subscribe unsubscribe identity)))

;; Foreign enum `socket-flag' (C type int) for the send/receive flags.
;; The second name in each entry (zmq/noblock, zmq/sndmore) is used by
;; send-message and receive-message as a variable holding the integer
;; value of the constant, so the flags can be combined with bitwise-ior.
(define-foreign-enum-type (socket-flag int)
  (socket-flag->int int->socket-flag)
  ((noblock zmq/noblock) ZMQ_NOBLOCK)
  ((sndmore zmq/sndmore) ZMQ_SNDMORE))

;; Foreign enum `poll-flag' (C type short) for the event bits stored in
;; the `events' / `revents' fields of a zmq_pollitem_t.  The second name
;; in each entry (zmq/pollin, zmq/pollout, zmq/pollerr) is used by the
;; poll-item procedures as a variable holding the constant's integer
;; value.  The converters follow the naming scheme of the other enums in
;; this file: poll-flag->short and short->poll-flag.
(define-foreign-enum-type (poll-flag short)
  (poll-flag->short short->poll-flag)
  ((in zmq/pollin) ZMQ_POLLIN)
  ((out zmq/pollout) ZMQ_POLLOUT)
  ((err zmq/pollerr) ZMQ_POLLERR))

;; Scheme-side poll item record.
;;   pointer - pointer to the foreign zmq_pollitem_t structure
;;   socket  - the socket record being polled, or #f when polling a raw fd
;;   in, out - the values given for the `in' / `out' keywords of
;;             make-poll-item
(define-record poll-item pointer socket in out)
;; Accessors for the C structure zmq_pollitem_t.  All generated names are
;; prefixed with `%' so they do not clash with the record accessors above
;; (e.g. poll-item-socket is the record accessor, %poll-item-socket reads
;; the C field).
(define-foreign-record-type (poll-item zmq_pollitem_t)
  (constructor: make-foreign-poll-item)
  (destructor: free-foreign-poll-item)
  (socket socket %poll-item-socket %poll-item-socket-set!)
  (int fd %poll-item-fd %poll-item-fd-set!)
  (short events %poll-item-events %poll-item-events-set!)
  (short revents %poll-item-revents %poll-item-revents-set!))

;; Foreign enum `errno' (C type int) covering the two error codes that
;; receive-message treats as "no message" rather than as a failure:
;; EAGAIN and ETERM.
(define-foreign-enum-type (errno int)
  (errno->int int->errno)
  ((again) EAGAIN)
  ((term) ETERM))

;; helpers

;; Signal a Scheme error for the failure currently recorded in the C
;; `errno' variable.  LOCATION is passed to `error' as the error
;; location; the message is the text returned by zmq_strerror for the
;; error code, and the numeric code itself is attached as an argument.
(define (zmq-error location)
  (let* ((code (foreign-value errno int))
         (describe (foreign-lambda c-string zmq_strerror int))
         (text (describe code)))
    (error location text code)))

;; Read the C `errno' variable and convert it through the `errno'
;; foreign enum type declared in this file (symbols `again' / `term'
;; for EAGAIN / ETERM).
(define errno
  (lambda ()
    (foreign-value errno errno)))

;; Signal an error reporting that VALUE does not have EXPECTED-TYPE.
;; VALUE is rendered with ~S (written form), EXPECTED-TYPE with ~A.
(define (type-error value expected-type)
  (let ((complaint (format "invalid value: ~S (expected ~A)"
                           value
                           expected-type)))
    (error complaint)))

;; Return the version reported by zmq_version as a three-element list
;; (major minor patch).
(define (zmq-version)
  (let ((query-version (foreign-lambda void zmq_version
                                       (c-pointer int)
                                       (c-pointer int)
                                       (c-pointer int))))
    ;; zmq_version writes its results through three int out-parameters.
    (let-location ((major int)
                   (minor int)
                   (patch int))
      (query-version (location major) (location minor) (location patch))
      (list major minor patch))))

;; contexts

;; Parameter: the io-threads argument handed to make-context when
;; zmq-default-context/initialize creates the default context.
;; Defaults to 1.
(define zmq-io-threads (make-parameter 1))

;; Parameter holding the default context.  Starts out as #f and is
;; filled in on demand by zmq-default-context/initialize, which
;; make-socket uses when no context argument is supplied.
(define zmq-default-context (make-parameter #f))

;; Return the default context, creating it first when the
;; zmq-default-context parameter is still #f.  A newly created context
;; uses (zmq-io-threads) I/O threads and is stored in the parameter.
(define (zmq-default-context/initialize)
  (let ((existing (zmq-default-context)))
    (cond (existing existing)
          (else
           (zmq-default-context (make-context (zmq-io-threads)))
           (zmq-default-context)))))

;; Keep the record constructor generated by define-record reachable
;; under a private name before `make-context' is redefined below.
(define %make-context make-context)

;; Create a ZeroMQ context with IO-THREADS I/O threads (zmq_init) and
;; return it wrapped in a `context' record.  Signals an error via
;; zmq-error when zmq_init yields a null pointer.
;;
;; The record's `sockets' slot is a mutex whose specific value starts as
;; the empty list; make-socket pushes raw socket pointers onto it.  A
;; finalizer closes every pointer on that list and then terminates the
;; context.
;; NOTE(review): close-socket never removes a pointer from that list, so
;; a socket closed earlier still appears on it when the finalizer runs -
;; confirm whether closing it a second time is safe.
(define (make-context io-threads)
  (let* ((registry (make-mutex))
         (handle ((foreign-lambda context zmq_init int) io-threads))
         (ctx (%make-context handle registry)))
    (cond
     ((context-pointer ctx)
      (mutex-specific-set! registry '())
      ;; set-finalizer! returns its first argument, i.e. the new record.
      (set-finalizer! ctx
                      (lambda (dying)
                        (for-each close-socket
                                  (mutex-specific (context-sockets dying)))
                        (terminate-context dying))))
     (else
      (zmq-error 'make-context)))))

;; Terminate the context wrapped by CTX (zmq_term).
;;
;; Returns #t on success and signals an error via zmq-error when
;; zmq_term reports failure.  After a successful termination the
;; record's pointer is cleared, so a later call - in particular the one
;; made by the finalizer installed in make-context - is a no-op that
;; returns #t instead of handing an already-terminated handle to
;; zmq_term again.  This mirrors how close-socket clears the socket
;; pointer.
(define (terminate-context ctx)
  (let ((handle (context-pointer ctx)))
    (cond
     ;; Already terminated: nothing left to do.
     ((not handle) #t)
     ((zero? ((foreign-lambda int zmq_term context) handle))
      (context-pointer-set! ctx #f)
      #t)
     ;; Keep the pointer on failure so the caller may retry.
     (else (zmq-error 'terminate-context)))))

;; messages

;; Initialise the zmq_msg_t pointed to by MESSAGE and return MESSAGE.
;;
;; Without DATA the message is initialised empty (zmq_msg_init).  With
;; DATA - which must be a string or a blob, otherwise type-error is
;; signalled - the bytes are copied into a freshly allocated C buffer
;; that is handed to zmq_msg_init_data together with C_free as the
;; deallocation function.  A non-zero status from either C call is
;; reported through zmq-error.
(define (initialize-message message #!optional data)
  (let ((status
         (cond
          ((not data)
           ((foreign-lambda int zmq_msg_init message) message))
          ((not (or (string? data) (blob? data)))
           (type-error data "string or blob"))
          (else
           (let* ((len (number-of-bytes data))
                  (cdata (allocate len))
                  (copy-bytes (foreign-lambda void "C_memcpy"
                                              c-pointer scheme-pointer int))
                  (init-with-data (foreign-lambda int
                                                  zmq_msg_init_data
                                                  message
                                                  c-pointer
                                                  unsigned-int
                                                  c-pointer
                                                  c-pointer)))
             ;; The payload must live outside the Scheme heap because
             ;; the message keeps a pointer to it.
             (copy-bytes cdata data len)
             (init-with-data message
                             cdata
                             len
                             (foreign-value "C_free" c-pointer)
                             #f))))))
    (if (zero? status)
        message
        (zmq-error 'initialize-message))))

;; Release MESSAGE with zmq_msg_close.  Returns #t on success and
;; signals an error via zmq-error otherwise.
(define (close-message message)
  (let ((status ((foreign-lambda int zmq_msg_close message) message)))
    (if (zero? status)
        #t
        (zmq-error 'close-message))))

;; Size of MESSAGE's payload as reported by zmq_msg_size.
(define message-size
  (foreign-lambda unsigned-integer zmq_msg_size message))

;; Extract the payload of MESSAGE in the representation selected by TYPE:
;;   'string     - a fresh string holding a copy of the payload bytes
;;   'blob       - a fresh blob holding a copy of the payload bytes
;;   a procedure - called with the payload pointer and its size; its
;;                 result is returned as-is
;; Any other TYPE signals an "invalid message data type" error.
(define (message-data message type)
  (let* ((size (message-size message))
         (ptr ((foreign-lambda c-pointer zmq_msg_data message) message))
         ;; Copy the payload into BUFFER and hand BUFFER back.
         (fill (lambda (buffer)
                 (move-memory! ptr buffer size)
                 buffer)))
    (cond ((procedure? type) (type ptr size))
          ((eq? type 'string) (fill (make-string size)))
          ((eq? type 'blob) (fill (make-blob size)))
          (else (error 'message-data "invalid message data type" type)))))

;; sockets

;; Keep the record constructor generated by define-record reachable
;; under a private name before `make-socket' is redefined below.
(define %make-socket make-socket)

;; Create a socket of the given TYPE (a socket-type symbol) on CONTEXT,
;; which defaults to the lazily created default context.  Signals an
;; error via zmq-error when zmq_socket yields a null pointer.
;;
;; The returned record carries its own mutex and a malloc'ed zmq_msg_t
;; buffer.  The raw socket pointer is also pushed onto the context's
;; socket list (guarded by the context's mutex).  A finalizer frees the
;; message buffer and closes the socket.
(define (make-socket type #!optional (context (zmq-default-context/initialize)))
  (let ((raw ((foreign-lambda socket zmq_socket context socket-type)
              (context-pointer context) type)))
    (cond
     ((not raw)
      (zmq-error 'make-socket))
     (else
      (let ((registry (context-sockets context))
            (sock (%make-socket raw
                                (make-mutex)
                                (allocate (foreign-value "sizeof(zmq_msg_t)" int)))))
        ;; Register the raw pointer with the owning context.
        (mutex-lock! registry)
        (mutex-specific-set! registry (cons raw (mutex-specific registry)))
        (mutex-unlock! registry)
        ;; set-finalizer! returns its first argument, i.e. the new record.
        (set-finalizer! sock
                        (lambda (dying)
                          (free (socket-message dying))
                          (close-socket dying))))))))

;; Close SOCKET, which may be a socket record or a raw socket pointer;
;; anything else signals a type error.  A record whose pointer is already
;; #f is left alone.  After a successful zmq_close on a record its
;; pointer slot is set to #f; a failing zmq_close is reported through
;; zmq-error.
(define (close-socket socket)
  (let* ((record? (socket? socket))
         (raw (cond (record? (socket-pointer socket))
                    ((pointer? socket) socket)
                    (else (type-error socket 'socket)))))
    (if raw
        (cond ((not (zero? ((foreign-lambda int zmq_close socket) raw)))
               (zmq-error 'close-socket))
              (record?
               (socket-pointer-set! socket #f))))))

;; Bind SOCKET to ENDPOINT (a string) with zmq_bind.  Returns #t on
;; success and signals an error via zmq-error otherwise.
(define (bind-socket socket endpoint)
  (let* ((bind (foreign-lambda int zmq_bind socket c-string))
         (status (bind (socket-pointer socket) endpoint)))
    (if (zero? status)
        #t
        (zmq-error 'bind-socket))))

;; Connect SOCKET to ENDPOINT (a string) with zmq_connect.  Returns #t on
;; success and signals an error via zmq-error otherwise.
(define (connect-socket socket endpoint)
  (let* ((connect (foreign-lambda int zmq_connect socket c-string))
         (status (connect (socket-pointer socket) endpoint)))
    (if (zero? status)
        #t
        (zmq-error 'connect-socket))))

;; Set OPTION (a socket-option symbol) on SOCKET to VALUE using
;; zmq_setsockopt.  Returns #t on success; a failing call is reported
;; through zmq-error, a VALUE of the wrong kind through type-error, and
;; an unsupported OPTION through a plain error.
;;
;; The value is marshalled according to the option:
;;   affinity                    - exact integer, passed as uint64_t
;;   rcvhwm sndhwm sndbuf rcvbuf
;;   rate recovery-ivl
;;   curve/server                - integer, passed as int
;;   identity subscribe
;;   unsubscribe curve/pubkey
;;   curve/secretkey
;;   curve/serverkey             - string, passed with its byte length
;;   curve/*-blob                - blob, passed with its size
;;
;; zmq_setsockopt requires the option length to match the option's C
;; type, so ZMQ_AFFINITY (uint64_t) cannot share the `int' code path.
(define (socket-option-set! socket option value)
  (or (zero? (case option
               ((affinity)
                (if (integer? value)
                    ((foreign-lambda* int
                                      ((socket socket)
                                       (socket-option option)
                                       (unsigned-integer64 value))
                                      "uint64_t mask = value;
                                       C_return(zmq_setsockopt(socket, option, &mask, sizeof(mask)));")
                     (socket-pointer socket) option value)
                    (type-error value 'integer)))

               ((rcvhwm sndhwm sndbuf rcvbuf rate recovery-ivl curve/server)
                (if (integer? value)
                    ;; On failure the C code calls back into Scheme to
                    ;; raise the error directly, hence the "safe" lambda.
                    ((foreign-safe-lambda* int
                                           ((scheme-object error)
                                            (scheme-object error_location)
                                            (socket socket)
                                            (socket-option option)
                                            (int value))
                                           "size_t size = sizeof(value);
                                            int status = zmq_setsockopt(socket, option, &value, size);
                                           if (status == 0) {
                                             C_return(0);
                                           } else {
                                             C_save(error_location);
                                             C_callback(error, 1);
                                           }")
                     zmq-error 'socket-option-set! (socket-pointer socket) option value)
                    (type-error value 'integer)))

               ((identity subscribe unsubscribe curve/pubkey curve/secretkey curve/serverkey)
                (if (string? value)
                    (let ((status ((foreign-lambda int zmq_setsockopt socket socket-option c-string unsigned-int)
                                   (socket-pointer socket) option value (number-of-bytes value))))
                      (if (not (zero? status)) (zmq-error 'socket-option-set!) status))
                    (type-error value 'string)))

               ((curve/pubkey-blob curve/secretkey-blob curve/serverkey-blob)
                (if (blob? value)
                    (let ((status ((foreign-lambda int zmq_setsockopt socket socket-option blob unsigned-int)
                                   (socket-pointer socket) option value (blob-size value))))
                      (if (not (zero? status)) (zmq-error 'socket-option-set!) status))
                    (type-error value 'blob)))

               (else (error (format "unknown socket option: ~A" option)))))
      (zmq-error 'socket-option-set!)))

;; (%socket-option LOCATION F-TYPE C-TYPE SOCKET OPTION)
;;
;; Expands into a call that reads OPTION from SOCKET with zmq_getsockopt.
;;   LOCATION - expression yielding the location passed to zmq-error
;;   F-TYPE   - foreign return type of the generated foreign lambda
;;   C-TYPE   - literal string naming the C type of the temporary that
;;              receives the option value (it is spliced into the C code
;;              at expansion time, so it cannot be a runtime value)
;;   SOCKET   - expression yielding a socket record
;;   OPTION   - expression yielding a socket-option symbol
;;
;; The generated C code returns the value when zmq_getsockopt reports
;; status 0; otherwise it calls back into Scheme, invoking zmq-error with
;; LOCATION.  foreign-safe-lambda* is used because of that callback.
(define-syntax %socket-option
  (er-macro-transformer
  (lambda (e r c)
    (let ((location (second e))
          (f-type (third e))
          (c-type (fourth e))
          (socket (fifth e))
          (option (sixth e)))
      `((,(r 'foreign-safe-lambda*) ,f-type ((scheme-object error)
                                             (scheme-object error_location)
                                             (socket socket)
                                             (socket-option option))
         ,(string-append c-type " value;
                                  size_t size = sizeof(value);
                                  int status = zmq_getsockopt(socket, option, &value, &size);
                                  if (status == 0) {
                                    C_return(value);
                                  } else {
                                    C_save(error_location);
                                    C_callback(error, 1);
                                  }"))
        ,(r 'zmq-error) ,location (,(r 'socket-pointer) ,socket) ,option)))))

;; Return the value of SOCKET's `fd' option (ZMQ_FD), read as a C int.
;; receive-message* waits on this descriptor between receive attempts.
;; Errors from zmq_getsockopt are raised with location `socket-fd'.
(define (socket-fd socket)
  (%socket-option 'socket-fd int "int" socket 'fd))

;; (socket-option SOCKET OPTION) - read OPTION from SOCKET.
;;
;;   identity - returned as a string (at most 255 bytes are read)
;;   affinity - returned as an integer, read as uint64_t, which is the
;;              option's C type; reading it through an `int' buffer
;;              makes zmq_getsockopt fail
;;   options listed under `integer' in socket-options - returned as an
;;              integer, read as int
;;   options listed under `boolean' in socket-options - returned as a
;;              boolean
;; Any other option signals a "not retrievable" error.  Failures of
;; zmq_getsockopt are reported through zmq-error.
(define socket-option
  ;; One scratch buffer shared by all `identity' lookups; the result is
  ;; copied out of it with substring before returning.
  (let ((routing-id (make-string 255)))
    (lambda (socket option)
      (case option
        ((identity)
         (let-location
          ((size unsigned-integer64 255))
          (if (zero? ((foreign-lambda int zmq_getsockopt socket socket-option scheme-pointer
                                      (c-pointer unsigned-integer64))
                      (socket-pointer socket) option routing-id (location size)))
              ;; zmq_getsockopt stores the actual length back into `size'.
              (substring routing-id 0 size)
              (zmq-error 'socket-option))))

        ;; Must come before the generic integer case below, because
        ;; `affinity' is also listed under `integer' in socket-options.
        ((affinity)
         (%socket-option 'socket-option unsigned-integer64 "uint64_t" socket option))

        (else
         (cond

          ((memq option (alist-ref 'integer socket-options))
           (%socket-option 'socket-option int "int" socket option))

          ((memq option (alist-ref 'boolean socket-options))
           (%socket-option 'socket-option bool "int" socket option))

          (else
           (error (format "socket option ~A is not retrievable" option)))))))))


;; communication

;; Send DATA (a string or blob) on SOCKET.
;;
;; Keywords:
;;   non-blocking - when true, ZMQ_NOBLOCK is added to the send flags
;;   send-more    - when true, ZMQ_SNDMORE is added to the send flags
;;
;; The socket's mutex is held while its shared message buffer is in use.
;; dynamic-wind guarantees the mutex is released even when
;; initialize-message (for example on DATA of the wrong type) or
;; close-message signals an error; otherwise the socket would stay
;; locked forever.  A negative result from zmq_msg_send is reported
;; through zmq-error after the mutex has been released.  The return
;; value is unspecified on success.
(define (send-message socket data #!key non-blocking send-more)
  (let ((lock (socket-mutex socket))
        (result #f))
    (dynamic-wind
        (lambda () (mutex-lock! lock))
        (lambda ()
          (let ((message (initialize-message (socket-message socket) data)))
            (set! result
                  ((foreign-lambda int zmq_msg_send message socket int)
                   message
                   (socket-pointer socket)
                   (bitwise-ior (if non-blocking zmq/noblock 0)
                                (if send-more zmq/sndmore 0))))
            (close-message message)))
        (lambda () (mutex-unlock! lock)))
    (if (< result 0) (zmq-error 'send-message))))


;; Receive one message from SOCKET.
;;
;; Keywords:
;;   non-blocking - when true, ZMQ_NOBLOCK is passed to zmq_msg_recv
;;   as           - representation of the result; see message-data
;;                  ('string, 'blob or a procedure).  Defaults to 'string.
;;
;; Returns the message payload, or #f when zmq_msg_recv fails with
;; EAGAIN or ETERM.  Any other failure is reported through zmq-error.
;;
;; The socket's message buffer is shared by all operations on the
;; socket, so it is closed *before* the mutex is released; closing it
;; afterwards would let another thread re-initialise the buffer in
;; between.  dynamic-wind guarantees that the mutex is released and -
;; once a message has been received - that the message is closed even
;; when the conversion in message-data signals an error.  errno is
;; inspected immediately after the failed call, before any further C
;; function runs.
(define (receive-message socket #!key non-blocking (as 'string))
  (let* ((lock (socket-mutex socket))
         (fatal? #f)
         (data
          (dynamic-wind
              (lambda () (mutex-lock! lock))
              (lambda ()
                (let* ((message (initialize-message (socket-message socket)))
                       (result ((foreign-lambda int zmq_msg_recv message socket int)
                                message
                                (socket-pointer socket)
                                (if non-blocking zmq/noblock 0))))
                  (if (>= result 0)
                      (dynamic-wind
                          void
                          (lambda () (message-data message as))
                          (lambda () (close-message message)))
                      (begin
                        (set! fatal? (not (memq (errno) '(again term))))
                        (close-message message)
                        #f))))
              (lambda () (mutex-unlock! lock)))))
    (if fatal?
        (zmq-error 'receive-message)
        data)))

;; Receive a message from SOCKET without blocking the whole process:
;; try a non-blocking receive and, whenever that yields #f, suspend the
;; current thread until the socket's file descriptor is readable, then
;; try again.  AS selects the result representation (see message-data).
(define (receive-message* socket #!key (as 'string))
  (let retry ()
    (let ((received (receive-message socket non-blocking: #t as: as)))
      (cond (received received)
            (else
             (thread-wait-for-i/o! (socket-fd socket) #:input)
             (retry))))))

;; polling

;; Keep the record constructor generated by define-record reachable
;; under a private name before `make-poll-item' is redefined below.
(define %make-poll-item make-poll-item)

;; Create a poll item for SOCKET/FD, which is either a socket record or
;; an integer file descriptor.
;;
;; Keywords:
;;   in  - when true, poll for ZMQ_POLLIN
;;   out - when true, poll for ZMQ_POLLOUT
;;
;; The foreign zmq_pollitem_t comes from an allocator that does not
;; clear the memory, so every field zmq_poll looks at is written
;; explicitly.  In particular the `socket' field is set to NULL for a
;; file-descriptor item: zmq_poll ignores `fd' whenever `socket' is
;; non-NULL, so leftover garbage there would be used as a socket handle.
;; A finalizer frees the foreign structure.
(define (make-poll-item socket/fd #!key in out)
  (let* ((for-socket? (socket? socket/fd))
         (raw (make-foreign-poll-item))
         (item (%make-poll-item raw
                                (and for-socket? socket/fd)
                                in out)))
    (if for-socket?
        (%poll-item-socket-set! raw (socket-pointer socket/fd))
        (begin
          (%poll-item-socket-set! raw #f)
          (%poll-item-fd-set! raw socket/fd)))

    (%poll-item-events-set! raw
                            (bitwise-ior (if in zmq/pollin 0)
                                         (if out zmq/pollout 0)))

    (%poll-item-revents-set! raw 0)

    ;; set-finalizer! returns its first argument, i.e. the new record.
    (set-finalizer! item (lambda (i)
                           (free-foreign-poll-item (poll-item-pointer i))))))

;; Read the `fd' field of the foreign structure behind ITEM.
(define poll-item-fd
  (lambda (item)
    (let ((raw (poll-item-pointer item)))
      (%poll-item-fd raw))))

;; Read the `revents' field of the foreign structure behind ITEM.
(define poll-item-revents
  (lambda (item)
    (let ((raw (poll-item-pointer item)))
      (%poll-item-revents raw))))

;; #t when the ZMQ_POLLIN bit is set in ITEM's revents, #f otherwise.
(define (poll-item-in? item)
  (let ((hits (bitwise-and zmq/pollin (poll-item-revents item))))
    (not (zero? hits))))

;; #t when the ZMQ_POLLOUT bit is set in ITEM's revents, #f otherwise.
(define (poll-item-out? item)
  (let ((hits (bitwise-and zmq/pollout (poll-item-revents item))))
    (not (zero? hits))))

;; #t when the ZMQ_POLLERR bit is set in ITEM's revents, #f otherwise.
(define (poll-item-error? item)
  (let ((hits (bitwise-and zmq/pollerr (poll-item-revents item))))
    (not (zero? hits))))

;; (%poll-sockets POLL-ITEM-REF LENGTH TIMEOUT) - low-level wrapper
;; around zmq_poll.
;;   POLL-ITEM-REF - Scheme procedure called once per index 0..LENGTH-1;
;;                   it must return a pointer to a zmq_pollitem_t
;;   LENGTH        - number of items to poll
;;   TIMEOUT       - passed unchanged to zmq_poll
;;
;; zmq_poll needs one contiguous array, while each Scheme poll item owns
;; a separately allocated structure.  The C code therefore collects the
;; pointers, copies the structures into a stack array, polls, and - when
;; zmq_poll did not return -1 - copies each `revents' field back into
;; the original structure.  Returns zmq_poll's result.
;; foreign-safe-lambda* is required because the C code calls back into
;; Scheme.
(define %poll-sockets
  (foreign-safe-lambda* int
                        ((scheme-object poll_item_ref)
                         (unsigned-int length)
                         (long timeout))
                        "zmq_pollitem_t items[length];
                         zmq_pollitem_t *item_ptrs[length];
                         int i;

                         for (i = 0; i < length; i++) {
                           C_save(C_fix(i));
                           item_ptrs[i] = (zmq_pollitem_t *)C_pointer_address(C_callback(poll_item_ref, 1));
                         }

                         for (i = 0; i < length; i++) {
                           items[i] = *item_ptrs[i];
                         }

                         int rc = zmq_poll(items, length, timeout);

                         if (rc != -1) {
                           for (i = 0; i < length; i++) {
                             (*item_ptrs[i]).revents = items[i].revents;
                           }
                         }

                         C_return(rc);"))

;; Poll the items in the list POLL-ITEMS and return zmq_poll's result.
;;
;; TIMEOUT/BLOCK selects the timeout handed to zmq_poll:
;;   #f          - 0
;;   #t          - -1
;;   other value - passed through as the timeout
;;
;; Signals an error for an empty list, and reports a result of -1
;; through zmq-error.  After a successful call the items' revents can be
;; inspected with poll-item-in?, poll-item-out? and poll-item-error?.
(define (poll poll-items timeout/block)
  (if (null? poll-items)
      (error 'poll "null list passed for poll-items")
      ;; The C side asks for each item by index; a vector keeps every
      ;; lookup constant-time instead of walking the list once per item.
      (let* ((items (list->vector poll-items))
             (result (%poll-sockets (lambda (i)
                                      (poll-item-pointer (vector-ref items i)))
                                    (vector-length items)
                                    (case timeout/block
                                      ((#f) 0)
                                      ((#t) -1)
                                      (else timeout/block)))))
        (if (= result -1)
            (zmq-error 'poll)
            result))))

;; Generate a CURVE key pair with zmq_curve_keypair and return two
;; values: the public key and the secret key, each as a string read
;; from a 41-byte, NUL-terminated C buffer.
;;
;; The status of zmq_curve_keypair is checked: when it is non-zero the
;; failure is reported through zmq-error instead of returning whatever
;; happened to be in the buffers.  The buffers are zero-initialised so
;; that building the result strings is well defined even on failure.
(define (curve-keypair)
  (let-values (((rc pk sk)
                ((foreign-primitive ()
                     "char public_key [41] = {0};
                      char secret_key [41] = {0};
                      int rc = zmq_curve_keypair (public_key, secret_key);

                      C_word* pkbuf = C_alloc(41);
                      C_word* skbuf = C_alloc(41);
                      C_word pkstr;
                      C_word skstr;

                      pkstr = C_string2(&pkbuf, public_key);
                      skstr = C_string2(&skbuf, secret_key);

                      C_word vals[5] = { C_SCHEME_UNDEFINED, C_k, C_fix(rc), pkstr, skstr };
                      C_values(5, vals);\n"))))
    (if (zero? rc)
        (values pk sk)
        (zmq-error 'curve-keypair))))
)
