SUNRPC: Prevent races in xs_abort_connection()
[pandora-kernel.git] / net / sunrpc / xprtsock.c
index 55472c4..aada5df 100644 (file)
@@ -737,10 +737,10 @@ static int xs_tcp_send_request(struct rpc_task *task)
                dprintk("RPC:       sendmsg returned unrecognized error %d\n",
                        -status);
        case -ECONNRESET:
-       case -EPIPE:
                xs_tcp_shutdown(xprt);
        case -ECONNREFUSED:
        case -ENOTCONN:
+       case -EPIPE:
                clear_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags);
        }
 
@@ -1028,6 +1028,16 @@ static void xs_udp_data_ready(struct sock *sk, int len)
        read_unlock_bh(&sk->sk_callback_lock);
 }
 
+/*
+ * Helper function to force a TCP close if the server is sending
+ * junk and/or it has put us in CLOSE_WAIT
+ */
+static void xs_tcp_force_close(struct rpc_xprt *xprt)
+{
+       set_bit(XPRT_CONNECTION_CLOSE, &xprt->state);
+       xprt_force_disconnect(xprt);
+}
+
 static inline void xs_tcp_read_fraghdr(struct rpc_xprt *xprt, struct xdr_skb_reader *desc)
 {
        struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
@@ -1054,7 +1064,7 @@ static inline void xs_tcp_read_fraghdr(struct rpc_xprt *xprt, struct xdr_skb_rea
        /* Sanity check of the record length */
        if (unlikely(transport->tcp_reclen < 8)) {
                dprintk("RPC:       invalid TCP record fragment length\n");
-               xprt_force_disconnect(xprt);
+               xs_tcp_force_close(xprt);
                return;
        }
        dprintk("RPC:       reading TCP record fragment of length %d\n",
@@ -1135,7 +1145,7 @@ static inline void xs_tcp_read_calldir(struct sock_xprt *transport,
                break;
        default:
                dprintk("RPC:       invalid request message type\n");
-               xprt_force_disconnect(&transport->xprt);
+               xs_tcp_force_close(&transport->xprt);
        }
        xs_tcp_check_fraghdr(transport);
 }
@@ -1455,12 +1465,19 @@ static void xs_tcp_cancel_linger_timeout(struct rpc_xprt *xprt)
        xprt_clear_connecting(xprt);
 }
 
-static void xs_sock_mark_closed(struct rpc_xprt *xprt)
+static void xs_sock_reset_connection_flags(struct rpc_xprt *xprt)
 {
        smp_mb__before_clear_bit();
+       clear_bit(XPRT_CONNECTION_ABORT, &xprt->state);
+       clear_bit(XPRT_CONNECTION_CLOSE, &xprt->state);
        clear_bit(XPRT_CLOSE_WAIT, &xprt->state);
        clear_bit(XPRT_CLOSING, &xprt->state);
        smp_mb__after_clear_bit();
+}
+
+static void xs_sock_mark_closed(struct rpc_xprt *xprt)
+{
+       xs_sock_reset_connection_flags(xprt);
        /* Mark transport as closed and wake up all pending tasks */
        xprt_disconnect_done(xprt);
 }
@@ -1515,8 +1532,9 @@ static void xs_tcp_state_change(struct sock *sk)
                break;
        case TCP_CLOSE_WAIT:
                /* The server initiated a shutdown of the socket */
-               xprt_force_disconnect(xprt);
                xprt->connect_cookie++;
+               clear_bit(XPRT_CONNECTED, &xprt->state);
+               xs_tcp_force_close(xprt);
        case TCP_CLOSING:
                /*
                 * If the server closed down the connection, make sure that
@@ -1895,6 +1913,8 @@ static void xs_local_setup_socket(struct work_struct *work)
        if (xprt->shutdown)
                goto out;
 
+       current->flags |= PF_FSTRANS;
+
        clear_bit(XPRT_CONNECTION_ABORT, &xprt->state);
        status = __sock_create(xprt->xprt_net, AF_LOCAL,
                                        SOCK_STREAM, 0, &sock, 1);
@@ -1928,6 +1948,7 @@ static void xs_local_setup_socket(struct work_struct *work)
 out:
        xprt_clear_connecting(xprt);
        xprt_wake_pending_tasks(xprt, status);
+       current->flags &= ~PF_FSTRANS;
 }
 
 static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
@@ -1970,6 +1991,8 @@ static void xs_udp_setup_socket(struct work_struct *work)
        if (xprt->shutdown)
                goto out;
 
+       current->flags |= PF_FSTRANS;
+
        /* Start by resetting any existing state */
        xs_reset_transport(transport);
        sock = xs_create_sock(xprt, transport,
@@ -1988,6 +2011,7 @@ static void xs_udp_setup_socket(struct work_struct *work)
 out:
        xprt_clear_connecting(xprt);
        xprt_wake_pending_tasks(xprt, status);
+       current->flags &= ~PF_FSTRANS;
 }
 
 /*
@@ -2009,10 +2033,8 @@ static void xs_abort_connection(struct sock_xprt *transport)
        any.sa_family = AF_UNSPEC;
        result = kernel_connect(transport->sock, &any, sizeof(any), 0);
        if (!result)
-               xs_sock_mark_closed(&transport->xprt);
-       else
-               dprintk("RPC:       AF_UNSPEC connect return code %d\n",
-                               result);
+               xs_sock_reset_connection_flags(&transport->xprt);
+       dprintk("RPC:       AF_UNSPEC connect return code %d\n", result);
 }
 
 static void xs_tcp_reuse_connection(struct sock_xprt *transport)
@@ -2113,6 +2135,8 @@ static void xs_tcp_setup_socket(struct work_struct *work)
        if (xprt->shutdown)
                goto out;
 
+       current->flags |= PF_FSTRANS;
+
        if (!sock) {
                clear_bit(XPRT_CONNECTION_ABORT, &xprt->state);
                sock = xs_create_sock(xprt, transport,
@@ -2151,8 +2175,7 @@ static void xs_tcp_setup_socket(struct work_struct *work)
                /* We're probably in TIME_WAIT. Get rid of existing socket,
                 * and retry
                 */
-               set_bit(XPRT_CONNECTION_CLOSE, &xprt->state);
-               xprt_force_disconnect(xprt);
+               xs_tcp_force_close(xprt);
                break;
        case -ECONNREFUSED:
        case -ECONNRESET:
@@ -2162,6 +2185,7 @@ static void xs_tcp_setup_socket(struct work_struct *work)
        case -EINPROGRESS:
        case -EALREADY:
                xprt_clear_connecting(xprt);
+               current->flags &= ~PF_FSTRANS;
                return;
        case -EINVAL:
                /* Happens, for instance, if the user specified a link
@@ -2174,6 +2198,7 @@ out_eagain:
 out:
        xprt_clear_connecting(xprt);
        xprt_wake_pending_tasks(xprt, status);
+       current->flags &= ~PF_FSTRANS;
 }
 
 /**
@@ -2412,6 +2437,7 @@ static void bc_destroy(struct rpc_xprt *xprt)
 static struct rpc_xprt_ops xs_local_ops = {
        .reserve_xprt           = xprt_reserve_xprt,
        .release_xprt           = xs_tcp_release_xprt,
+       .alloc_slot             = xprt_alloc_slot,
        .rpcbind                = xs_local_rpcbind,
        .set_port               = xs_local_set_port,
        .connect                = xs_connect,
@@ -2428,6 +2454,7 @@ static struct rpc_xprt_ops xs_udp_ops = {
        .set_buffer_size        = xs_udp_set_buffer_size,
        .reserve_xprt           = xprt_reserve_xprt_cong,
        .release_xprt           = xprt_release_xprt_cong,
+       .alloc_slot             = xprt_alloc_slot,
        .rpcbind                = rpcb_getport_async,
        .set_port               = xs_set_port,
        .connect                = xs_connect,
@@ -2445,6 +2472,7 @@ static struct rpc_xprt_ops xs_udp_ops = {
 static struct rpc_xprt_ops xs_tcp_ops = {
        .reserve_xprt           = xprt_reserve_xprt,
        .release_xprt           = xs_tcp_release_xprt,
+       .alloc_slot             = xprt_lock_and_alloc_slot,
        .rpcbind                = rpcb_getport_async,
        .set_port               = xs_set_port,
        .connect                = xs_connect,
@@ -2464,6 +2492,7 @@ static struct rpc_xprt_ops xs_tcp_ops = {
 static struct rpc_xprt_ops bc_tcp_ops = {
        .reserve_xprt           = xprt_reserve_xprt,
        .release_xprt           = xprt_release_xprt,
+       .alloc_slot             = xprt_alloc_slot,
        .buf_alloc              = bc_malloc,
        .buf_free               = bc_free,
        .send_request           = bc_send_request,