l2tp: protect sock pointer of struct pppol2tp_session with RCU
[pandora-kernel.git] / net / l2tp / l2tp_ppp.c
index e400acf..37bea01 100644 (file)
 struct pppol2tp_session {
        int                     owner;          /* pid that opened the socket */
 
-       struct sock             *sock;          /* Pointer to the session
+       struct mutex            sk_lock;        /* Protects .sk */
+       struct sock __rcu       *sk;            /* Pointer to the session
                                                 * PPPoX socket */
+       struct sock             *__sk;          /* Copy of .sk, for cleanup */
+       struct rcu_head         rcu;            /* For asynchronous release */
        struct sock             *tunnel_sock;   /* Pointer to the tunnel UDP
                                                 * socket */
        int                     flags;          /* accessed by PPPIOCGFLAGS.
@@ -142,6 +145,24 @@ static const struct ppp_channel_ops pppol2tp_chan_ops = {
 
 static const struct proto_ops pppol2tp_ops;
 
+/* Retrieves the pppol2tp socket associated to a session.
+ * A reference is held on the returned socket, so this function must be paired
+ * with sock_put().
+ */
+static struct sock *pppol2tp_session_get_sock(struct l2tp_session *session)
+{
+       struct pppol2tp_session *ps = l2tp_session_priv(session);
+       struct sock *sk;
+
+       rcu_read_lock();
+       sk = rcu_dereference(ps->sk);
+       if (sk)
+               sock_hold(sk);
+       rcu_read_unlock();
+
+       return sk;
+}
+
 /* Helpers to obtain tunnel/session contexts from sockets.
  */
 static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk)
@@ -229,7 +250,8 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int
        /* If the socket is bound, send it in to PPP's input queue. Otherwise
         * queue it on the session socket.
         */
-       sk = ps->sock;
+       rcu_read_lock();
+       sk = rcu_dereference(ps->sk);
        if (sk == NULL)
                goto no_sock;
 
@@ -265,31 +287,17 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int
                session->stats.rx_errors++;
                kfree_skb(skb);
        }
+       rcu_read_unlock();
 
        return;
 
 no_sock:
+       rcu_read_unlock();
        PRINTK(session->debug, PPPOL2TP_MSG_DATA, KERN_INFO,
               "%s: no socket\n", session->name);
        kfree_skb(skb);
 }
 
-static void pppol2tp_session_sock_hold(struct l2tp_session *session)
-{
-       struct pppol2tp_session *ps = l2tp_session_priv(session);
-
-       if (ps->sock)
-               sock_hold(ps->sock);
-}
-
-static void pppol2tp_session_sock_put(struct l2tp_session *session)
-{
-       struct pppol2tp_session *ps = l2tp_session_priv(session);
-
-       if (ps->sock)
-               sock_put(ps->sock);
-}
-
 /************************************************************************
  * Transmit handling
  ***********************************************************************/
@@ -459,15 +467,17 @@ abort:
  */
 static void pppol2tp_session_close(struct l2tp_session *session)
 {
-       struct pppol2tp_session *ps = l2tp_session_priv(session);
-       struct sock *sk = ps->sock;
-       struct socket *sock = sk->sk_socket;
+       struct sock *sk;
 
        BUG_ON(session->magic != L2TP_SESSION_MAGIC);
 
 
-       if (sock)
-               inet_shutdown(sock, 2);
+       sk = pppol2tp_session_get_sock(session);
+       if (sk) {
+               if (sk->sk_socket)
+                       inet_shutdown(sk->sk_socket, SEND_SHUTDOWN);
+               sock_put(sk);
+       }
 
        /* Don't let the session go away before our socket does */
        l2tp_session_inc_refcount(session);
@@ -495,6 +505,14 @@ out:
        return;
 }
 
+static void pppol2tp_put_sk(struct rcu_head *head)
+{
+       struct pppol2tp_session *ps;
+
+       ps = container_of(head, typeof(*ps), rcu);
+       sock_put(ps->__sk);
+}
+
 /* Called when the PPPoX socket (session) is closed.
  */
 static int pppol2tp_release(struct socket *sock)
@@ -520,10 +538,23 @@ static int pppol2tp_release(struct socket *sock)
 
        session = pppol2tp_sock_to_session(sk);
 
-       /* Purge any queued data */
        if (session != NULL) {
+               struct pppol2tp_session *ps;
+
                l2tp_session_queue_purge(session);
-               sock_put(sk);
+
+               ps = l2tp_session_priv(session);
+               mutex_lock(&ps->sk_lock);
+               ps->__sk = rcu_dereference_protected(ps->sk,
+                                                    lockdep_is_held(&ps->sk_lock));
+               RCU_INIT_POINTER(ps->sk, NULL);
+               mutex_unlock(&ps->sk_lock);
+               call_rcu(&ps->rcu, pppol2tp_put_sk);
+
+               /* Rely on the sock_put() call at the end of the function for
+                * dropping the reference held by pppol2tp_sock_to_session().
+                * The last reference will be dropped by pppol2tp_put_sk().
+                */
        }
        skb_queue_purge(&sk->sk_receive_queue);
        skb_queue_purge(&sk->sk_write_queue);
@@ -593,12 +624,14 @@ out:
 static void pppol2tp_show(struct seq_file *m, void *arg)
 {
        struct l2tp_session *session = arg;
-       struct pppol2tp_session *ps = l2tp_session_priv(session);
+       struct sock *sk;
+
+       sk = pppol2tp_session_get_sock(session);
+       if (sk) {
+               struct pppox_sock *po = pppox_sk(sk);
 
-       if (ps) {
-               struct pppox_sock *po = pppox_sk(ps->sock);
-               if (po)
-                       seq_printf(m, "   interface %s\n", ppp_dev_name(&po->chan));
+               seq_printf(m, "   interface %s\n", ppp_dev_name(&po->chan));
+               sock_put(sk);
        }
 }
 #endif
@@ -712,13 +745,17 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
                /* Using a pre-existing session is fine as long as it hasn't
                 * been connected yet.
                 */
-               if (ps->sock) {
+               mutex_lock(&ps->sk_lock);
+               if (rcu_dereference_protected(ps->sk,
+                                             lockdep_is_held(&ps->sk_lock))) {
+                       mutex_unlock(&ps->sk_lock);
                        error = -EEXIST;
                        goto end;
                }
 
                /* consistency checks */
                if (ps->tunnel_sock != tunnel->sock) {
+                       mutex_unlock(&ps->sk_lock);
                        error = -EEXIST;
                        goto end;
                }
@@ -735,19 +772,21 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
                        goto end;
                }
 
+               ps = l2tp_session_priv(session);
+               mutex_init(&ps->sk_lock);
                l2tp_session_inc_refcount(session);
+
+               mutex_lock(&ps->sk_lock);
                error = l2tp_session_register(session, tunnel);
                if (error < 0) {
+                       mutex_unlock(&ps->sk_lock);
                        kfree(session);
                        goto end;
                }
                drop_refcnt = true;
        }
 
-       /* Associate session with its PPPoL2TP socket */
-       ps = l2tp_session_priv(session);
        ps->owner            = current->pid;
-       ps->sock             = sk;
        ps->tunnel_sock = tunnel->sock;
 
        session->recv_skb       = pppol2tp_recv;
@@ -756,12 +795,6 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
        session->show           = pppol2tp_show;
 #endif
 
-       /* We need to know each time a skb is dropped from the reorder
-        * queue.
-        */
-       session->ref = pppol2tp_session_sock_hold;
-       session->deref = pppol2tp_session_sock_put;
-
        /* If PMTU discovery was enabled, use the MTU that was discovered */
        dst = sk_dst_get(tunnel->sock);
        if (dst != NULL) {
@@ -795,12 +828,17 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
        po->chan.mtu     = session->mtu;
 
        error = ppp_register_net_channel(sock_net(sk), &po->chan);
-       if (error)
+       if (error) {
+               mutex_unlock(&ps->sk_lock);
                goto end;
+       }
 
 out_no_ppp:
        /* This is how we get the session context from the socket. */
        sk->sk_user_data = session;
+       rcu_assign_pointer(ps->sk, sk);
+       mutex_unlock(&ps->sk_lock);
+
        sk->sk_state = PPPOX_CONNECTED;
        PRINTK(session->debug, PPPOL2TP_MSG_CONTROL, KERN_INFO,
               "%s: created\n", session->name);
@@ -848,6 +886,7 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel,
        }
 
        ps = l2tp_session_priv(session);
+       mutex_init(&ps->sk_lock);
        ps->tunnel_sock = tunnel->sock;
 
        error = l2tp_session_register(session, tunnel);
@@ -979,12 +1018,10 @@ static int pppol2tp_session_ioctl(struct l2tp_session *session,
               "%s: pppol2tp_session_ioctl(cmd=%#x, arg=%#lx)\n",
               session->name, cmd, arg);
 
-       sk = ps->sock;
+       sk = pppol2tp_session_get_sock(session);
        if (!sk)
                return -EBADR;
 
-       sock_hold(sk);
-
        switch (cmd) {
        case SIOCGIFMTU:
                err = -ENXIO;
@@ -1260,7 +1297,6 @@ static int pppol2tp_session_setsockopt(struct sock *sk,
                                       int optname, int val)
 {
        int err = 0;
-       struct pppol2tp_session *ps = l2tp_session_priv(session);
 
        switch (optname) {
        case PPPOL2TP_SO_RECVSEQ:
@@ -1280,8 +1316,8 @@ static int pppol2tp_session_setsockopt(struct sock *sk,
                }
                session->send_seq = val ? -1 : 0;
                {
-                       struct sock *ssk      = ps->sock;
-                       struct pppox_sock *po = pppox_sk(ssk);
+                       struct pppox_sock *po = pppox_sk(sk);
+
                        po->chan.hdrlen = val ? PPPOL2TP_L2TP_HDR_SIZE_SEQ :
                                PPPOL2TP_L2TP_HDR_SIZE_NOSEQ;
                }
@@ -1616,8 +1652,9 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
 {
        struct l2tp_session *session = v;
        struct l2tp_tunnel *tunnel = session->tunnel;
-       struct pppol2tp_session *ps = l2tp_session_priv(session);
-       struct pppox_sock *po = pppox_sk(ps->sock);
+       unsigned char state;
+       char user_data_ok;
+       struct sock *sk;
        u32 ip = 0;
        u16 port = 0;
 
@@ -1627,6 +1664,15 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
                port = ntohs(inet->inet_sport);
        }
 
+       sk = pppol2tp_session_get_sock(session);
+       if (sk) {
+               state = sk->sk_state;
+               user_data_ok = (session == sk->sk_user_data) ? 'Y' : 'N';
+       } else {
+               state = 0;
+               user_data_ok = 'N';
+       }
+
        seq_printf(m, "  SESSION '%s' %08X/%d %04X/%04X -> "
                   "%04X/%04X %d %c\n",
                   session->name, ip, port,
@@ -1634,9 +1680,7 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
                   session->session_id,
                   tunnel->peer_tunnel_id,
                   session->peer_session_id,
-                  ps->sock->sk_state,
-                  (session == ps->sock->sk_user_data) ?
-                  'Y' : 'N');
+                  state, user_data_ok);
        seq_printf(m, "   %d/%d/%c/%c/%s %08x %u\n",
                   session->mtu, session->mru,
                   session->recv_seq ? 'R' : '-',
@@ -1653,8 +1697,12 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
                   (unsigned long long)session->stats.rx_bytes,
                   (unsigned long long)session->stats.rx_errors);
 
-       if (po)
+       if (sk) {
+               struct pppox_sock *po = pppox_sk(sk);
+
                seq_printf(m, "   interface %s\n", ppp_dev_name(&po->chan));
+               sock_put(sk);
+       }
 }
 
 static int pppol2tp_seq_show(struct seq_file *m, void *v)