l2tp: prevent creation of sessions on terminated tunnels
[pandora-kernel.git] / net / l2tp / l2tp_core.c
index 3954b18..5fa9068 100644 (file)
@@ -112,10 +112,9 @@ struct l2tp_net {
 };
 
 static void l2tp_session_set_header_len(struct l2tp_session *session, int version);
-static void l2tp_tunnel_free(struct l2tp_tunnel *tunnel);
 static void l2tp_tunnel_closeall(struct l2tp_tunnel *tunnel);
 
-static inline struct l2tp_net *l2tp_pernet(struct net *net)
+static inline struct l2tp_net *l2tp_pernet(const struct net *net)
 {
        BUG_ON(!net);
 
@@ -123,33 +122,6 @@ static inline struct l2tp_net *l2tp_pernet(struct net *net)
 }
 
 
-/* Tunnel reference counts. Incremented per session that is added to
- * the tunnel.
- */
-static inline void l2tp_tunnel_inc_refcount_1(struct l2tp_tunnel *tunnel)
-{
-       atomic_inc(&tunnel->ref_count);
-}
-
-static inline void l2tp_tunnel_dec_refcount_1(struct l2tp_tunnel *tunnel)
-{
-       if (atomic_dec_and_test(&tunnel->ref_count))
-               l2tp_tunnel_free(tunnel);
-}
-#ifdef L2TP_REFCNT_DEBUG
-#define l2tp_tunnel_inc_refcount(_t) do { \
-               printk(KERN_DEBUG "l2tp_tunnel_inc_refcount: %s:%d %s: cnt=%d\n", __func__, __LINE__, (_t)->name, atomic_read(&_t->ref_count)); \
-               l2tp_tunnel_inc_refcount_1(_t);                         \
-       } while (0)
-#define l2tp_tunnel_dec_refcount(_t) do { \
-               printk(KERN_DEBUG "l2tp_tunnel_dec_refcount: %s:%d %s: cnt=%d\n", __func__, __LINE__, (_t)->name, atomic_read(&_t->ref_count)); \
-               l2tp_tunnel_dec_refcount_1(_t);                         \
-       } while (0)
-#else
-#define l2tp_tunnel_inc_refcount(t) l2tp_tunnel_inc_refcount_1(t)
-#define l2tp_tunnel_dec_refcount(t) l2tp_tunnel_dec_refcount_1(t)
-#endif
-
 /* Session hash global list for L2TPv3.
  * The session_id SHOULD be random according to RFC3931, but several
  * L2TP implementations use incrementing session_ids.  So we do a real
@@ -225,10 +197,31 @@ struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunn
 }
 EXPORT_SYMBOL_GPL(l2tp_session_find);
 
+/* Lookup a tunnel. A new reference is held on the returned tunnel. */
+struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id)
+{
+       const struct l2tp_net *pn = l2tp_pernet(net);
+       struct l2tp_tunnel *tunnel;
+
+       rcu_read_lock_bh();
+       list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
+               if (tunnel->tunnel_id == tunnel_id) {
+                       l2tp_tunnel_inc_refcount(tunnel);
+                       rcu_read_unlock_bh();
+
+                       return tunnel;
+               }
+       }
+       rcu_read_unlock_bh();
+
+       return NULL;
+}
+EXPORT_SYMBOL_GPL(l2tp_tunnel_get);
+
 /* Like l2tp_session_find() but takes a reference on the returned session.
  * Optionally calls session->ref() too if do_ref is true.
  */
-struct l2tp_session *l2tp_session_get(struct net *net,
+struct l2tp_session *l2tp_session_get(const struct net *net,
                                      struct l2tp_tunnel *tunnel,
                                      u32 session_id, bool do_ref)
 {
@@ -305,7 +298,8 @@ EXPORT_SYMBOL_GPL(l2tp_session_get_nth);
 /* Lookup a session by interface name.
  * This is very inefficient but is only used by management interfaces.
  */
-struct l2tp_session *l2tp_session_get_by_ifname(struct net *net, char *ifname,
+struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
+                                               const char *ifname,
                                                bool do_ref)
 {
        struct l2tp_net *pn = l2tp_pernet(net);
@@ -341,13 +335,21 @@ static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
        struct hlist_head *head;
        struct l2tp_net *pn;
        struct hlist_node *walk;
+       int err;
 
        head = l2tp_session_id_hash(tunnel, session->session_id);
 
        write_lock_bh(&tunnel->hlist_lock);
+       if (!tunnel->acpt_newsess) {
+               err = -ENODEV;
+               goto err_tlock;
+       }
+
        hlist_for_each_entry(session_walk, walk, head, hlist)
-               if (session_walk->session_id == session->session_id)
-                       goto exist;
+               if (session_walk->session_id == session->session_id) {
+                       err = -EEXIST;
+                       goto err_tlock;
+               }
 
        if (tunnel->version == L2TP_HDR_VER_3) {
                pn = l2tp_pernet(tunnel->l2tp_net);
@@ -355,12 +357,21 @@ static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
                                                session->session_id);
 
                spin_lock_bh(&pn->l2tp_session_hlist_lock);
+
                hlist_for_each_entry(session_walk, walk, g_head, global_hlist)
-                       if (session_walk->session_id == session->session_id)
-                               goto exist_glob;
+                       if (session_walk->session_id == session->session_id) {
+                               err = -EEXIST;
+                               goto err_tlock_pnlock;
+                       }
 
+               l2tp_tunnel_inc_refcount(tunnel);
+               sock_hold(tunnel->sock);
                hlist_add_head_rcu(&session->global_hlist, g_head);
+
                spin_unlock_bh(&pn->l2tp_session_hlist_lock);
+       } else {
+               l2tp_tunnel_inc_refcount(tunnel);
+               sock_hold(tunnel->sock);
        }
 
        hlist_add_head(&session->hlist, head);
@@ -368,12 +379,12 @@ static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
 
        return 0;
 
-exist_glob:
+err_tlock_pnlock:
        spin_unlock_bh(&pn->l2tp_session_hlist_lock);
-exist:
+err_tlock:
        write_unlock_bh(&tunnel->hlist_lock);
 
-       return -EEXIST;
+       return err;
 }
 
 /* Lookup a tunnel by id
@@ -1288,6 +1299,7 @@ static void l2tp_tunnel_closeall(struct l2tp_tunnel *tunnel)
               "%s: closing all sessions...\n", tunnel->name);
 
        write_lock_bh(&tunnel->hlist_lock);
+       tunnel->acpt_newsess = false;
        for (hash = 0; hash < L2TP_HASH_SIZE; hash++) {
 again:
                hlist_for_each_safe(walk, tmp, &tunnel->session_hlist[hash]) {
@@ -1340,7 +1352,7 @@ again:
 /* Really kill the tunnel.
  * Come here only when all sessions have been cleared from the tunnel.
  */
-static void l2tp_tunnel_free(struct l2tp_tunnel *tunnel)
+void l2tp_tunnel_free(struct l2tp_tunnel *tunnel)
 {
        struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net);
 
@@ -1358,6 +1370,7 @@ static void l2tp_tunnel_free(struct l2tp_tunnel *tunnel)
 
        atomic_dec(&l2tp_tunnel_count);
 }
+EXPORT_SYMBOL_GPL(l2tp_tunnel_free);
 
 /* Create a socket for the tunnel, if one isn't set up by
  * userspace. This is used for static tunnels where there is no
@@ -1509,6 +1522,7 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
        tunnel->magic = L2TP_TUNNEL_MAGIC;
        sprintf(&tunnel->name[0], "tunl %u", tunnel_id);
        rwlock_init(&tunnel->hlist_lock);
+       tunnel->acpt_newsess = true;
 
        /* The net we belong to */
        tunnel->l2tp_net = net;
@@ -1715,6 +1729,8 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
 
                l2tp_session_set_header_len(session, tunnel->version);
 
+               l2tp_session_inc_refcount(session);
+
                err = l2tp_session_add_to_tunnel(tunnel, session);
                if (err) {
                        kfree(session);
@@ -1722,15 +1738,6 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
                        return ERR_PTR(err);
                }
 
-               /* Bump the reference count. The session context is deleted
-                * only when this drops to zero.
-                */
-               l2tp_session_inc_refcount(session);
-               l2tp_tunnel_inc_refcount(tunnel);
-
-               /* Ensure tunnel socket isn't deleted */
-               sock_hold(tunnel->sock);
-
                /* Ignore management session in session count value */
                if (session->session_id != 0)
                        atomic_inc(&l2tp_session_count);