[pandora-kernel.git] drivers/vhost/net.c
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index e224a92..7ef84c1 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -12,6 +12,7 @@
 #include <linux/virtio_net.h>
 #include <linux/miscdevice.h>
 #include <linux/module.h>
+#include <linux/moduleparam.h>
 #include <linux/mutex.h>
 #include <linux/workqueue.h>
 #include <linux/rcupdate.h>
 
 #include "vhost.h"
 
+static int experimental_zcopytx;
+module_param(experimental_zcopytx, int, 0444);
+MODULE_PARM_DESC(experimental_zcopytx, "Enable Experimental Zero Copy TX");
+
 /* Max number of bytes transferred before requeueing the job.
  * Using this limit prevents one virtqueue from starving others. */
 #define VHOST_NET_WEIGHT 0x80000
 
+/* Max number of outstanding zerocopy TX buffers before TX is throttled */
+#define VHOST_MAX_PEND 128
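+/* Packets shorter than this are copied instead of sent zero copy */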
+#define VHOST_GOODCOPY_LEN 256
+
 enum {
        VHOST_NET_VQ_RX = 0,
        VHOST_NET_VQ_TX = 1,
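The chunk above introduces the zero-copy TX knobs: experimental_zcopytx is a module parameter (mode 0444, so it can only be set at load time and is read-only through sysfs afterwards), and the two constants bound the feature's cost. The sketch below is a minimal userspace illustration of the two thresholds; pick_tx_path is made up and folds into one helper checks that the real handle_tx() performs at different points, and only the constant values come from the patch.

#include <stdbool.h>
#include <stdio.h>

#define VHOST_MAX_PEND     128   /* cap on not-yet-completed zerocopy buffers */
#define VHOST_GOODCOPY_LEN 256   /* below this, copying beats page pinning */

/* Hypothetical helper mirroring the checks handle_tx() makes. */
static const char *pick_tx_path(size_t len, int pending, bool zcopy_enabled)
{
        if (!zcopy_enabled)
                return "copy";
        if (pending > VHOST_MAX_PEND)
                return "throttle";      /* stop TX until completions drain */
        if (len < VHOST_GOODCOPY_LEN)
                return "copy";          /* small packet: copying is cheaper */
        return "zerocopy";
}

int main(void)
{
        printf("%s\n", pick_tx_path(64, 0, true));     /* copy */
        printf("%s\n", pick_tx_path(1500, 0, true));   /* zerocopy */
        printf("%s\n", pick_tx_path(1500, 200, true)); /* throttle */
        return 0;
}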
@@ -54,6 +63,12 @@ struct vhost_net {
        enum vhost_net_poll_state tx_poll_state;
 };
 
+static bool vhost_sock_zcopy(struct socket *sock)
+{
+       return unlikely(experimental_zcopytx) &&
+               sock_flag(sock->sk, SOCK_ZEROCOPY);
+}
+
 /* Pop first len bytes from iovec. Return number of segments used. */
 static int move_iovec_hdr(struct iovec *from, struct iovec *to,
                          size_t len, int iov_count)
@@ -129,6 +144,8 @@ static void handle_tx(struct vhost_net *net)
        int err, wmem;
        size_t hdr_size;
        struct socket *sock;
+       struct vhost_ubuf_ref *uninitialized_var(ubufs);
+       bool zcopy;
 
        /* TODO: check that we are running from vhost_worker? */
        sock = rcu_dereference_check(vq->private_data, 1);
@@ -149,8 +166,13 @@ static void handle_tx(struct vhost_net *net)
        if (wmem < sock->sk->sk_sndbuf / 2)
                tx_poll_stop(net);
        hdr_size = vq->vhost_hlen;
+       zcopy = vhost_sock_zcopy(sock);
 
        for (;;) {
+               /* Release buffers whose DMA has already completed first */
+               if (zcopy)
+                       vhost_zerocopy_signal_used(vq);
+
                head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
                                         ARRAY_SIZE(vq->iov),
                                         &out, &in,
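Each pass through the TX loop now starts by releasing buffers whose DMA has finished. vhost_zerocopy_signal_used() itself is implemented outside this file; the sketch below is a simplified userspace model of the bookkeeping it implies: upend_idx marks slots handed out for zerocopy sends, done_idx chases it, and a slot whose len has been set to the DMA-done marker can be returned to the guest, but only in submission order. All names and values in the sketch (struct ring, signal_used, DMA_DONE, RING_SIZE) are stand-ins, not kernel identifiers.

#include <stdio.h>

#define RING_SIZE 8                 /* stands in for UIO_MAXIOV */
#define DMA_DONE  ((unsigned)-1)    /* stands in for VHOST_DMA_DONE_LEN */

struct ring {
        unsigned len[RING_SIZE];    /* per-slot state: packet len or DMA_DONE */
        int done_idx;               /* next slot to release to the used ring */
        int upend_idx;              /* next slot to hand out for a new TX */
};

/* Simplified model of "release completed buffers first": walk from done_idx
 * and stop at the first slot whose DMA has not finished, so completions are
 * surfaced to the guest strictly in submission order. */
static int signal_used(struct ring *r)
{
        int released = 0;

        while (r->done_idx != r->upend_idx && r->len[r->done_idx] == DMA_DONE) {
                r->done_idx = (r->done_idx + 1) % RING_SIZE;
                released++;
        }
        return released;
}

int main(void)
{
        struct ring r = { .done_idx = 0, .upend_idx = 3 };

        r.len[0] = DMA_DONE;   /* slot 0 completed */
        r.len[1] = 1500;       /* slot 1 still in flight */
        r.len[2] = DMA_DONE;   /* slot 2 completed, but blocked behind slot 1 */
        printf("released %d\n", signal_used(&r));   /* prints 1 */
        return 0;
}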
@@ -160,12 +182,25 @@ static void handle_tx(struct vhost_net *net)
                        break;
                /* Nothing new?  Wait for eventfd to tell us they refilled. */
                if (head == vq->num) {
+                       int num_pends;
+
                        wmem = atomic_read(&sock->sk->sk_wmem_alloc);
                        if (wmem >= sock->sk->sk_sndbuf * 3 / 4) {
                                tx_poll_start(net, sock);
                                set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
                                break;
                        }
+                       /* If there are too many outstanding DMAs, queue the
+                        * work for later; handle upend_idx wrap around when
+                        * computing the count. */
+                       num_pends = likely(vq->upend_idx >= vq->done_idx) ?
+                                   (vq->upend_idx - vq->done_idx) :
+                                   (vq->upend_idx + UIO_MAXIOV - vq->done_idx);
+                       if (unlikely(num_pends > VHOST_MAX_PEND)) {
+                               tx_poll_start(net, sock);
+                               set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
+                               break;
+                       }
                        if (unlikely(vhost_enable_notify(&net->dev, vq))) {
                                vhost_disable_notify(&net->dev, vq);
                                continue;
@@ -188,9 +223,40 @@ static void handle_tx(struct vhost_net *net)
                               iov_length(vq->hdr, s), hdr_size);
                        break;
                }
+               /* use msg_control to pass vhost zerocopy ubuf info to skb */
+               if (zcopy) {
+                       vq->heads[vq->upend_idx].id = head;
+                       if (len < VHOST_GOODCOPY_LEN) {
+                               /* copied packets don't need to wait for DMA done */
+                               vq->heads[vq->upend_idx].len =
+                                                       VHOST_DMA_DONE_LEN;
+                               msg.msg_control = NULL;
+                               msg.msg_controllen = 0;
+                               ubufs = NULL;
+                       } else {
+                               struct ubuf_info *ubuf;
+                               ubuf = vq->ubuf_info + vq->upend_idx;
+
+                               vq->heads[vq->upend_idx].len = len;
+                               ubuf->callback = vhost_zerocopy_callback;
+                               ubuf->arg = vq->ubufs;
+                               ubuf->desc = vq->upend_idx;
+                               msg.msg_control = ubuf;
+                               msg.msg_controllen = sizeof(ubuf);
+                               ubufs = vq->ubufs;
+                               kref_get(&ubufs->kref);
+                       }
+                       vq->upend_idx = (vq->upend_idx + 1) % UIO_MAXIOV;
+               }
                /* TODO: Check specific error and bomb out unless ENOBUFS? */
                err = sock->ops->sendmsg(NULL, sock, &msg, len);
                if (unlikely(err < 0)) {
+                       if (zcopy) {
+                               if (ubufs)
+                                       vhost_ubuf_put(ubufs);
+                               vq->upend_idx = ((unsigned)vq->upend_idx - 1) %
+                                       UIO_MAXIOV;
+                       }
                        vhost_discard_vq_desc(vq, 1);
                        tx_poll_start(net, sock);
                        break;
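For a zerocopy packet the code above records the descriptor head in heads[upend_idx], points a ubuf_info at vhost_zerocopy_callback so the lower device can report DMA completion, takes an extra reference on vq->ubufs via kref_get() (paired with vhost_ubuf_put() on the failure path), and advances upend_idx. If sendmsg() fails, the slot is handed back by stepping upend_idx backwards; the cast to unsigned is what makes the decrement wrap correctly when the index is 0. A standalone check of that index math:

#include <assert.h>
#include <stdio.h>

#define UIO_MAXIOV 1024

int main(void)
{
        int upend_idx = 0;

        /* Submission: claim the current slot, then advance with wrap. */
        upend_idx = (upend_idx + 1) % UIO_MAXIOV;
        assert(upend_idx == 1);

        /* sendmsg() failed: roll the index back.  The cast to unsigned makes
         * "0 - 1" land on the last slot instead of a negative value, because
         * UIO_MAXIOV is a power of two that divides 2^32. */
        upend_idx = 0;
        upend_idx = ((unsigned)upend_idx - 1) % UIO_MAXIOV;
        assert(upend_idx == UIO_MAXIOV - 1);

        printf("rollback from 0 lands on slot %d\n", upend_idx);
        return 0;
}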
@@ -198,7 +264,8 @@ static void handle_tx(struct vhost_net *net)
                if (err != len)
                        pr_debug("Truncated TX packet: "
                                 " len %d != %zd\n", err, len);
-               vhost_add_used_and_signal(&net->dev, vq, head, 0);
+               if (!zcopy)
+                       vhost_add_used_and_signal(&net->dev, vq, head, 0);
                total_len += len;
                if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
                        vhost_poll_queue(&vq->poll);
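With zerocopy the used-ring entry is deliberately not added here: the guest must not touch those pages again until DMA has finished, so completion is reported later, by the vhost_zerocopy_signal_used() call at the top of the loop, once the ubuf_info callback has marked the slot. The callback is not part of this diff; the sketch below only models the contract this file sets up (ubuf->desc records the slot, ubuf->arg carries vq->ubufs, and the callback is expected to mark the slot DMA-done and drop the reference taken at submission). Every identifier below is a simplified stand-in, not a kernel symbol.

#include <stdio.h>

#define DMA_DONE 1u   /* stand-in for VHOST_DMA_DONE_LEN, whatever its value */

/* Simplified stand-ins for the state the completion callback touches. */
struct ubufs_ref { int refcount; };

struct fake_vq {
        unsigned heads_len[8];      /* per-slot length / completion marker */
        struct ubufs_ref *ubufs;
};

struct fake_ubuf_info {
        struct fake_vq *vq;         /* reachable via ubuf->arg in the patch */
        int desc;                   /* slot index, set to upend_idx above */
};

/* Model of the completion contract: mark the slot so the next handle_tx()
 * pass can add the used entry, then drop the submission-time reference. */
static void zerocopy_done(struct fake_ubuf_info *ubuf)
{
        ubuf->vq->heads_len[ubuf->desc] = DMA_DONE;
        ubuf->vq->ubufs->refcount--;
}

int main(void)
{
        struct ubufs_ref ref = { .refcount = 1 };
        struct fake_vq vq = { .ubufs = &ref };
        struct fake_ubuf_info ubuf = { .vq = &vq, .desc = 3 };

        zerocopy_done(&ubuf);
        printf("slot 3 marked=%u, refs left=%d\n", vq.heads_len[3], ref.refcount);
        return 0;
}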
@@ -252,9 +319,13 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
                        r = -ENOBUFS;
                        goto err;
                }
-               d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
+               r = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
                                      ARRAY_SIZE(vq->iov) - seg, &out,
                                      &in, log, log_num);
+               if (unlikely(r < 0))
+                       goto err;
+
+               d = r;
                if (d == vq->num) {
                        r = 0;
                        goto err;
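vhost_get_vq_desc() can also return a negative errno, and the old code stored that straight into d, which in the mainline version of this function is declared unsigned (an assumption here, since the declaration is outside the hunk); the error then looked like an implausible but non-negative head and was never caught. Capturing the result in the signed r first lets the error propagate through the err: path, which also gives back any heads already taken via vhost_discard_vq_desc(). A tiny demonstration of the signed/unsigned trap:

#include <errno.h>
#include <stdio.h>

int main(void)
{
        int r = -EFAULT;        /* an error return from vhost_get_vq_desc() */
        unsigned int d = r;     /* old code: the error becomes a huge index */
        unsigned int vq_num = 256;

        if (d == vq_num)
                puts("looks like 'no buffers'");
        else
                printf("looks like head %u -- an error mistaken for data\n", d);

        if (r < 0)              /* new code: test the signed value first */
                puts("error detected and propagated");
        return 0;
}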
@@ -279,6 +350,12 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
        *iovcount = seg;
        if (unlikely(log))
                *log_num = nlogs;
+
+       /* Detect overrun */
+       if (unlikely(datalen > 0)) {
+               r = UIO_MAXIOV + 1;
+               goto err;
+       }
        return headcount;
 err:
        vhost_discard_vq_desc(vq, headcount);
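The overrun check gives get_rx_bufs() a fourth kind of result: since a real headcount can never exceed UIO_MAXIOV, the value UIO_MAXIOV + 1 works as an in-band marker for "the buffers we gathered are still too small for this packet", distinct from a negative errno and from 0 meaning "no buffers yet". The marker is returned through the err: path, so the gathered heads are discarded back to the virtqueue first. A small decoder of that convention (classify is illustrative only):

#include <stdio.h>

#define UIO_MAXIOV 1024

/* Decode the convention get_rx_bufs() uses after this change:
 *   < 0             -> error
 *   0               -> no buffers available yet
 *   1..UIO_MAXIOV   -> number of heads gathered for this packet
 *   UIO_MAXIOV + 1  -> buffers too small for the packet (overrun) */
static const char *classify(int headcount)
{
        if (headcount < 0)
                return "error: stop until next kick";
        if (headcount == 0)
                return "wait for the guest to refill";
        if (headcount > UIO_MAXIOV)
                return "overrun: truncate and discard";
        return "deliver packet";
}

int main(void)
{
        printf("%s\n", classify(-14));
        printf("%s\n", classify(0));
        printf("%s\n", classify(3));
        printf("%s\n", classify(UIO_MAXIOV + 1));
        return 0;
}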
@@ -305,7 +382,8 @@ static void handle_rx(struct vhost_net *net)
                .hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE
        };
        size_t total_len = 0;
-       int err, headcount, mergeable;
+       int err, mergeable;
+       s16 headcount;
        size_t vhost_hlen, sock_hlen;
        size_t vhost_len, sock_len;
        /* TODO: check that we are running from vhost_worker? */
@@ -332,6 +410,14 @@ static void handle_rx(struct vhost_net *net)
                /* On error, stop handling until the next kick. */
                if (unlikely(headcount < 0))
                        break;
+               /* On overrun, truncate and discard */
+               if (unlikely(headcount > UIO_MAXIOV)) {
+                       msg.msg_iovlen = 1;
+                       err = sock->ops->recvmsg(NULL, sock, &msg,
+                                                1, MSG_DONTWAIT | MSG_TRUNC);
+                       pr_debug("Discarded rx packet: len %zd\n", sock_len);
+                       continue;
+               }
                /* OK, now we need to know about added descriptors. */
                if (!headcount) {
                        if (unlikely(vhost_enable_notify(&net->dev, vq))) {
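On overrun the packet cannot be delivered to the guest, so handle_rx() pulls it out of the backend socket and drops it: a single 1-byte iovec read with MSG_TRUNC consumes the whole datagram, and MSG_DONTWAIT keeps the vhost worker from sleeping. The same trick works from userspace; the sketch below shows it on an ordinary UDP socket (purely illustrative, nothing here is vhost code, and it blocks in recv() instead of using MSG_DONTWAIT so the demo is deterministic).

/* Drop the next queued datagram, however large, by reading one byte with
 * MSG_TRUNC: the kernel discards the rest but reports the real length. */
#include <stdio.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <unistd.h>

int main(void)
{
        int s = socket(AF_INET, SOCK_DGRAM, 0);
        struct sockaddr_in addr = { .sin_family = AF_INET,
                                    .sin_port = htons(0),
                                    .sin_addr.s_addr = htonl(INADDR_LOOPBACK) };
        socklen_t alen = sizeof(addr);
        char big[2000], byte;

        bind(s, (struct sockaddr *)&addr, sizeof(addr));
        getsockname(s, (struct sockaddr *)&addr, &alen);   /* learn the port */

        memset(big, 'x', sizeof(big));
        sendto(s, big, sizeof(big), 0, (struct sockaddr *)&addr, sizeof(addr));

        /* Returns the full datagram length even though only 1 byte is kept. */
        ssize_t n = recv(s, &byte, 1, MSG_TRUNC);
        printf("discarded a %zd-byte datagram\n", n);

        close(s);
        return 0;
}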
@@ -603,6 +689,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
 {
        struct socket *sock, *oldsock;
        struct vhost_virtqueue *vq;
+       struct vhost_ubuf_ref *ubufs, *oldubufs = NULL;
        int r;
 
        mutex_lock(&n->dev.mutex);
@@ -632,13 +719,31 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
        oldsock = rcu_dereference_protected(vq->private_data,
                                            lockdep_is_held(&vq->mutex));
        if (sock != oldsock) {
+               ubufs = vhost_ubuf_alloc(vq, sock && vhost_sock_zcopy(sock));
+               if (IS_ERR(ubufs)) {
+                       r = PTR_ERR(ubufs);
+                       goto err_ubufs;
+               }
+               oldubufs = vq->ubufs;
+               vq->ubufs = ubufs;
                vhost_net_disable_vq(n, vq);
                rcu_assign_pointer(vq->private_data, sock);
                vhost_net_enable_vq(n, vq);
+
+               r = vhost_init_used(vq);
+               if (r)
+                       goto err_vq;
        }
 
        mutex_unlock(&vq->mutex);
 
+       if (oldubufs) {
+               vhost_ubuf_put_and_wait(oldubufs);
+               mutex_lock(&vq->mutex);
+               vhost_zerocopy_signal_used(vq);
+               mutex_unlock(&vq->mutex);
+       }
+
        if (oldsock) {
                vhost_net_flush_vq(n, index);
                fput(oldsock->file);
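The backend-swap path ties a vhost_ubuf_ref to each socket: a fresh one is allocated before the new socket is installed (zerocopy-capable only if vhost_sock_zcopy() approves), and the old one is released with vhost_ubuf_put_and_wait() after the vq mutex is dropped, so the function blocks until every zerocopy DMA still aimed at the old socket has completed before its used entries are flushed. Below is a toy model of that put-and-wait lifetime rule, with pthreads standing in for the kref and wait queue; none of these names are the kernel's.

#include <pthread.h>
#include <stdio.h>
#include <unistd.h>

struct ubufs_model {
        pthread_mutex_t lock;
        pthread_cond_t  cond;
        int             refs;
};

static void model_put(struct ubufs_model *u)
{
        pthread_mutex_lock(&u->lock);
        if (--u->refs == 0)
                pthread_cond_signal(&u->cond);
        pthread_mutex_unlock(&u->lock);
}

static void model_put_and_wait(struct ubufs_model *u)
{
        pthread_mutex_lock(&u->lock);
        u->refs--;                          /* drop the creator's reference */
        while (u->refs != 0)                /* wait for in-flight completions */
                pthread_cond_wait(&u->cond, &u->lock);
        pthread_mutex_unlock(&u->lock);
}

static void *dma_completion(void *arg)
{
        usleep(10000);                      /* pretend the device is still busy */
        model_put(arg);                     /* completion drops its reference */
        return NULL;
}

int main(void)
{
        struct ubufs_model u = { PTHREAD_MUTEX_INITIALIZER,
                                 PTHREAD_COND_INITIALIZER, 2 };
        pthread_t t;

        pthread_create(&t, NULL, dma_completion, &u);
        model_put_and_wait(&u);             /* safe to flush/free after this */
        puts("all zerocopy completions for the old backend drained");
        pthread_join(&t, NULL);
        return 0;
}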
@@ -647,6 +752,8 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
        mutex_unlock(&n->dev.mutex);
        return 0;
 
+err_ubufs:
+       fput(sock->file);
 err_vq:
        mutex_unlock(&vq->mutex);
 err:
@@ -776,6 +883,8 @@ static struct miscdevice vhost_net_misc = {
 
 static int vhost_net_init(void)
 {
+       if (experimental_zcopytx)
+               vhost_enable_zcopy(VHOST_NET_VQ_TX);
        return misc_register(&vhost_net_misc);
 }
 module_init(vhost_net_init);
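vhost_net_init() reads the parameter exactly once, before misc_register(), so zero-copy TX is a load-time decision (e.g. modprobe vhost_net experimental_zcopytx=1); flipping it later is impossible anyway because the parameter is 0444. A minimal, self-contained module sketch of the same load-time gating pattern follows; the module and identifiers below are hypothetical, not vhost_net.

#include <linux/init.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/moduleparam.h>

static int experimental_feature;
module_param(experimental_feature, int, 0444);   /* read-only via sysfs */
MODULE_PARM_DESC(experimental_feature, "Enable the experimental path");

static int __init demo_init(void)
{
        /* Consulted once, at load time, exactly like experimental_zcopytx. */
        if (experimental_feature)
                pr_info("demo: experimental path enabled\n");
        return 0;
}

static void __exit demo_exit(void)
{
}

module_init(demo_init);
module_exit(demo_exit);
MODULE_LICENSE("GPL");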