ipv4: Avoid reading user iov twice after raw_probe_proto_opt
[pandora-kernel.git] / net / sunrpc / xprtsock.c
index 2d78d95..3a82fec 100644 (file)
@@ -254,7 +254,6 @@ struct sock_xprt {
        void                    (*old_data_ready)(struct sock *, int);
        void                    (*old_state_change)(struct sock *);
        void                    (*old_write_space)(struct sock *);
-       void                    (*old_error_report)(struct sock *);
 };
 
 /*
@@ -391,8 +390,10 @@ static int xs_send_kvec(struct socket *sock, struct sockaddr *addr, int addrlen,
        return kernel_sendmsg(sock, &msg, NULL, 0, 0);
 }
 
-static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned int base, int more)
+static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned int base, int more, bool zerocopy)
 {
+       ssize_t (*do_sendpage)(struct socket *sock, struct page *page,
+                       int offset, size_t size, int flags);
        struct page **ppage;
        unsigned int remainder;
        int err, sent = 0;
@@ -401,6 +402,9 @@ static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned i
        base += xdr->page_base;
        ppage = xdr->pages + (base >> PAGE_SHIFT);
        base &= ~PAGE_MASK;
+       do_sendpage = sock->ops->sendpage;
+       if (!zerocopy)
+               do_sendpage = sock_no_sendpage;
        for(;;) {
                unsigned int len = min_t(unsigned int, PAGE_SIZE - base, remainder);
                int flags = XS_SENDMSG_FLAGS;
@@ -408,7 +412,7 @@ static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned i
                remainder -= len;
                if (remainder != 0 || more)
                        flags |= MSG_MORE;
-               err = sock->ops->sendpage(sock, *ppage, base, len, flags);
+               err = do_sendpage(sock, *ppage, base, len, flags);
                if (remainder == 0 || err != len)
                        break;
                sent += err;
@@ -429,9 +433,10 @@ static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned i
  * @addrlen: UDP only -- length of destination address
  * @xdr: buffer containing this request
  * @base: starting position in the buffer
+ * @zerocopy: true if it is safe to use sendpage()
  *
  */
-static int xs_sendpages(struct socket *sock, struct sockaddr *addr, int addrlen, struct xdr_buf *xdr, unsigned int base)
+static int xs_sendpages(struct socket *sock, struct sockaddr *addr, int addrlen, struct xdr_buf *xdr, unsigned int base, bool zerocopy)
 {
        unsigned int remainder = xdr->len - base;
        int err, sent = 0;
@@ -459,7 +464,7 @@ static int xs_sendpages(struct socket *sock, struct sockaddr *addr, int addrlen,
        if (base < xdr->page_len) {
                unsigned int len = xdr->page_len - base;
                remainder -= len;
-               err = xs_send_pagedata(sock, xdr, base, remainder != 0);
+               err = xs_send_pagedata(sock, xdr, base, remainder != 0, zerocopy);
                if (remainder == 0 || err != len)
                        goto out;
                sent += err;
@@ -496,7 +501,8 @@ static int xs_nospace(struct rpc_task *task)
        struct rpc_rqst *req = task->tk_rqstp;
        struct rpc_xprt *xprt = req->rq_xprt;
        struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
-       int ret = 0;
+       struct sock *sk = transport->inet;
+       int ret = -EAGAIN;
 
        dprintk("RPC: %5u xmit incomplete (%u left of %u)\n",
                        task->tk_pid, req->rq_slen - req->rq_bytes_sent,
@@ -508,13 +514,12 @@ static int xs_nospace(struct rpc_task *task)
        /* Don't race with disconnect */
        if (xprt_connected(xprt)) {
                if (test_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags)) {
-                       ret = -EAGAIN;
                        /*
                         * Notify TCP that we're limited by the application
                         * window size
                         */
                        set_bit(SOCK_NOSPACE, &transport->sock->flags);
-                       transport->inet->sk_write_pending++;
+                       sk->sk_write_pending++;
                        /* ...and wait for more buffer space */
                        xprt_wait_for_buffer_space(task, xs_nospace_callback);
                }
@@ -524,6 +529,9 @@ static int xs_nospace(struct rpc_task *task)
        }
 
        spin_unlock_bh(&xprt->transport_lock);
+
+       /* Race breaker in case memory is freed before above code is called */
+       sk->sk_write_space(sk);
        return ret;
 }
 
@@ -563,7 +571,7 @@ static int xs_local_send_request(struct rpc_task *task)
                        req->rq_svec->iov_base, req->rq_svec->iov_len);
 
        status = xs_sendpages(transport->sock, NULL, 0,
-                                               xdr, req->rq_bytes_sent);
+                                               xdr, req->rq_bytes_sent, true);
        dprintk("RPC:       %s(%u) = %d\n",
                        __func__, xdr->len - req->rq_bytes_sent, status);
        if (likely(status >= 0)) {
@@ -619,7 +627,7 @@ static int xs_udp_send_request(struct rpc_task *task)
        status = xs_sendpages(transport->sock,
                              xs_addr(xprt),
                              xprt->addrlen, xdr,
-                             req->rq_bytes_sent);
+                             req->rq_bytes_sent, true);
 
        dprintk("RPC:       xs_udp_send_request(%u) = %d\n",
                        xdr->len - req->rq_bytes_sent, status);
@@ -690,6 +698,7 @@ static int xs_tcp_send_request(struct rpc_task *task)
        struct rpc_xprt *xprt = req->rq_xprt;
        struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
        struct xdr_buf *xdr = &req->rq_snd_buf;
+       bool zerocopy = true;
        int status;
 
        xs_encode_stream_record_marker(&req->rq_snd_buf);
@@ -697,13 +706,20 @@ static int xs_tcp_send_request(struct rpc_task *task)
        xs_pktdump("packet data:",
                                req->rq_svec->iov_base,
                                req->rq_svec->iov_len);
+       /* Don't use zero copy if this is a resend. If the RPC call
+        * completes while the socket holds a reference to the pages,
+        * then we may end up resending corrupted data.
+        */
+       if (task->tk_flags & RPC_TASK_SENT)
+               zerocopy = false;
 
        /* Continue transmitting the packet/record. We must be careful
         * to cope with writespace callbacks arriving _after_ we have
         * called sendmsg(). */
        while (1) {
                status = xs_sendpages(transport->sock,
-                                       NULL, 0, xdr, req->rq_bytes_sent);
+                                       NULL, 0, xdr, req->rq_bytes_sent,
+                                       zerocopy);
 
                dprintk("RPC:       xs_tcp_send_request(%u) = %d\n",
                                xdr->len - req->rq_bytes_sent, status);
@@ -738,10 +754,10 @@ static int xs_tcp_send_request(struct rpc_task *task)
                dprintk("RPC:       sendmsg returned unrecognized error %d\n",
                        -status);
        case -ECONNRESET:
-       case -EPIPE:
                xs_tcp_shutdown(xprt);
        case -ECONNREFUSED:
        case -ENOTCONN:
+       case -EPIPE:
                clear_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags);
        }
 
@@ -782,7 +798,6 @@ static void xs_save_old_callbacks(struct sock_xprt *transport, struct sock *sk)
        transport->old_data_ready = sk->sk_data_ready;
        transport->old_state_change = sk->sk_state_change;
        transport->old_write_space = sk->sk_write_space;
-       transport->old_error_report = sk->sk_error_report;
 }
 
 static void xs_restore_old_callbacks(struct sock_xprt *transport, struct sock *sk)
@@ -790,13 +805,13 @@ static void xs_restore_old_callbacks(struct sock_xprt *transport, struct sock *s
        sk->sk_data_ready = transport->old_data_ready;
        sk->sk_state_change = transport->old_state_change;
        sk->sk_write_space = transport->old_write_space;
-       sk->sk_error_report = transport->old_error_report;
 }
 
 static void xs_reset_transport(struct sock_xprt *transport)
 {
        struct socket *sock = transport->sock;
        struct sock *sk = transport->inet;
+       struct rpc_xprt *xprt = &transport->xprt;
 
        if (sk == NULL)
                return;
@@ -810,6 +825,7 @@ static void xs_reset_transport(struct sock_xprt *transport)
        sk->sk_user_data = NULL;
 
        xs_restore_old_callbacks(transport, sk);
+       xprt_clear_connected(xprt);
        write_unlock_bh(&sk->sk_callback_lock);
 
        sk->sk_no_check = 0;
@@ -1029,6 +1045,16 @@ static void xs_udp_data_ready(struct sock *sk, int len)
        read_unlock_bh(&sk->sk_callback_lock);
 }
 
+/*
+ * Helper function to force a TCP close if the server is sending
+ * junk and/or it has put us in CLOSE_WAIT
+ */
+static void xs_tcp_force_close(struct rpc_xprt *xprt)
+{
+       set_bit(XPRT_CONNECTION_CLOSE, &xprt->state);
+       xprt_force_disconnect(xprt);
+}
+
 static inline void xs_tcp_read_fraghdr(struct rpc_xprt *xprt, struct xdr_skb_reader *desc)
 {
        struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
@@ -1055,7 +1081,7 @@ static inline void xs_tcp_read_fraghdr(struct rpc_xprt *xprt, struct xdr_skb_rea
        /* Sanity check of the record length */
        if (unlikely(transport->tcp_reclen < 8)) {
                dprintk("RPC:       invalid TCP record fragment length\n");
-               xprt_force_disconnect(xprt);
+               xs_tcp_force_close(xprt);
                return;
        }
        dprintk("RPC:       reading TCP record fragment of length %d\n",
@@ -1136,7 +1162,7 @@ static inline void xs_tcp_read_calldir(struct sock_xprt *transport,
                break;
        default:
                dprintk("RPC:       invalid request message type\n");
-               xprt_force_disconnect(&transport->xprt);
+               xs_tcp_force_close(&transport->xprt);
        }
        xs_tcp_check_fraghdr(transport);
 }
@@ -1456,12 +1482,19 @@ static void xs_tcp_cancel_linger_timeout(struct rpc_xprt *xprt)
        xprt_clear_connecting(xprt);
 }
 
-static void xs_sock_mark_closed(struct rpc_xprt *xprt)
+static void xs_sock_reset_connection_flags(struct rpc_xprt *xprt)
 {
        smp_mb__before_clear_bit();
+       clear_bit(XPRT_CONNECTION_ABORT, &xprt->state);
+       clear_bit(XPRT_CONNECTION_CLOSE, &xprt->state);
        clear_bit(XPRT_CLOSE_WAIT, &xprt->state);
        clear_bit(XPRT_CLOSING, &xprt->state);
        smp_mb__after_clear_bit();
+}
+
+static void xs_sock_mark_closed(struct rpc_xprt *xprt)
+{
+       xs_sock_reset_connection_flags(xprt);
        /* Mark transport as closed and wake up all pending tasks */
        xprt_disconnect_done(xprt);
 }
@@ -1516,8 +1549,9 @@ static void xs_tcp_state_change(struct sock *sk)
                break;
        case TCP_CLOSE_WAIT:
                /* The server initiated a shutdown of the socket */
-               xprt_force_disconnect(xprt);
                xprt->connect_cookie++;
+               clear_bit(XPRT_CONNECTED, &xprt->state);
+               xs_tcp_force_close(xprt);
        case TCP_CLOSING:
                /*
                 * If the server closed down the connection, make sure that
@@ -1541,25 +1575,6 @@ static void xs_tcp_state_change(struct sock *sk)
        read_unlock_bh(&sk->sk_callback_lock);
 }
 
-/**
- * xs_error_report - callback mainly for catching socket errors
- * @sk: socket
- */
-static void xs_error_report(struct sock *sk)
-{
-       struct rpc_xprt *xprt;
-
-       read_lock_bh(&sk->sk_callback_lock);
-       if (!(xprt = xprt_from_sock(sk)))
-               goto out;
-       dprintk("RPC:       %s client %p...\n"
-                       "RPC:       error %d\n",
-                       __func__, xprt, sk->sk_err);
-       xprt_wake_pending_tasks(xprt, -EAGAIN);
-out:
-       read_unlock_bh(&sk->sk_callback_lock);
-}
-
 static void xs_write_space(struct sock *sk)
 {
        struct socket *sock;
@@ -1859,7 +1874,6 @@ static int xs_local_finish_connecting(struct rpc_xprt *xprt,
                sk->sk_user_data = xprt;
                sk->sk_data_ready = xs_local_data_ready;
                sk->sk_write_space = xs_udp_write_space;
-               sk->sk_error_report = xs_error_report;
                sk->sk_allocation = GFP_ATOMIC;
 
                xprt_clear_connected(xprt);
@@ -1896,6 +1910,8 @@ static void xs_local_setup_socket(struct work_struct *work)
        if (xprt->shutdown)
                goto out;
 
+       current->flags |= PF_FSTRANS;
+
        clear_bit(XPRT_CONNECTION_ABORT, &xprt->state);
        status = __sock_create(xprt->xprt_net, AF_LOCAL,
                                        SOCK_STREAM, 0, &sock, 1);
@@ -1929,6 +1945,7 @@ static void xs_local_setup_socket(struct work_struct *work)
 out:
        xprt_clear_connecting(xprt);
        xprt_wake_pending_tasks(xprt, status);
+       current->flags &= ~PF_FSTRANS;
 }
 
 static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
@@ -1945,7 +1962,6 @@ static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
                sk->sk_user_data = xprt;
                sk->sk_data_ready = xs_udp_data_ready;
                sk->sk_write_space = xs_udp_write_space;
-               sk->sk_error_report = xs_error_report;
                sk->sk_no_check = UDP_CSUM_NORCV;
                sk->sk_allocation = GFP_ATOMIC;
 
@@ -1971,6 +1987,8 @@ static void xs_udp_setup_socket(struct work_struct *work)
        if (xprt->shutdown)
                goto out;
 
+       current->flags |= PF_FSTRANS;
+
        /* Start by resetting any existing state */
        xs_reset_transport(transport);
        sock = xs_create_sock(xprt, transport,
@@ -1989,6 +2007,7 @@ static void xs_udp_setup_socket(struct work_struct *work)
 out:
        xprt_clear_connecting(xprt);
        xprt_wake_pending_tasks(xprt, status);
+       current->flags &= ~PF_FSTRANS;
 }
 
 /*
@@ -2010,10 +2029,8 @@ static void xs_abort_connection(struct sock_xprt *transport)
        any.sa_family = AF_UNSPEC;
        result = kernel_connect(transport->sock, &any, sizeof(any), 0);
        if (!result)
-               xs_sock_mark_closed(&transport->xprt);
-       else
-               dprintk("RPC:       AF_UNSPEC connect return code %d\n",
-                               result);
+               xs_sock_reset_connection_flags(&transport->xprt);
+       dprintk("RPC:       AF_UNSPEC connect return code %d\n", result);
 }
 
 static void xs_tcp_reuse_connection(struct sock_xprt *transport)
@@ -2058,7 +2075,6 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
                sk->sk_data_ready = xs_tcp_data_ready;
                sk->sk_state_change = xs_tcp_state_change;
                sk->sk_write_space = xs_tcp_write_space;
-               sk->sk_error_report = xs_error_report;
                sk->sk_allocation = GFP_ATOMIC;
 
                /* socket options */
@@ -2114,6 +2130,8 @@ static void xs_tcp_setup_socket(struct work_struct *work)
        if (xprt->shutdown)
                goto out;
 
+       current->flags |= PF_FSTRANS;
+
        if (!sock) {
                clear_bit(XPRT_CONNECTION_ABORT, &xprt->state);
                sock = xs_create_sock(xprt, transport,
@@ -2152,8 +2170,7 @@ static void xs_tcp_setup_socket(struct work_struct *work)
                /* We're probably in TIME_WAIT. Get rid of existing socket,
                 * and retry
                 */
-               set_bit(XPRT_CONNECTION_CLOSE, &xprt->state);
-               xprt_force_disconnect(xprt);
+               xs_tcp_force_close(xprt);
                break;
        case -ECONNREFUSED:
        case -ECONNRESET:
@@ -2163,6 +2180,7 @@ static void xs_tcp_setup_socket(struct work_struct *work)
        case -EINPROGRESS:
        case -EALREADY:
                xprt_clear_connecting(xprt);
+               current->flags &= ~PF_FSTRANS;
                return;
        case -EINVAL:
                /* Happens, for instance, if the user specified a link
@@ -2175,6 +2193,7 @@ out_eagain:
 out:
        xprt_clear_connecting(xprt);
        xprt_wake_pending_tasks(xprt, status);
+       current->flags &= ~PF_FSTRANS;
 }
 
 /**
@@ -2413,6 +2432,7 @@ static void bc_destroy(struct rpc_xprt *xprt)
 static struct rpc_xprt_ops xs_local_ops = {
        .reserve_xprt           = xprt_reserve_xprt,
        .release_xprt           = xs_tcp_release_xprt,
+       .alloc_slot             = xprt_alloc_slot,
        .rpcbind                = xs_local_rpcbind,
        .set_port               = xs_local_set_port,
        .connect                = xs_connect,
@@ -2429,6 +2449,7 @@ static struct rpc_xprt_ops xs_udp_ops = {
        .set_buffer_size        = xs_udp_set_buffer_size,
        .reserve_xprt           = xprt_reserve_xprt_cong,
        .release_xprt           = xprt_release_xprt_cong,
+       .alloc_slot             = xprt_alloc_slot,
        .rpcbind                = rpcb_getport_async,
        .set_port               = xs_set_port,
        .connect                = xs_connect,
@@ -2446,6 +2467,7 @@ static struct rpc_xprt_ops xs_udp_ops = {
 static struct rpc_xprt_ops xs_tcp_ops = {
        .reserve_xprt           = xprt_reserve_xprt,
        .release_xprt           = xs_tcp_release_xprt,
+       .alloc_slot             = xprt_lock_and_alloc_slot,
        .rpcbind                = rpcb_getport_async,
        .set_port               = xs_set_port,
        .connect                = xs_connect,
@@ -2465,6 +2487,7 @@ static struct rpc_xprt_ops xs_tcp_ops = {
 static struct rpc_xprt_ops bc_tcp_ops = {
        .reserve_xprt           = xprt_reserve_xprt,
        .release_xprt           = xprt_release_xprt,
+       .alloc_slot             = xprt_alloc_slot,
        .buf_alloc              = bc_malloc,
        .buf_free               = bc_free,
        .send_request           = bc_send_request,