af_key: fix buffer overread in verify_address_len()
[pandora-kernel.git] / net / l2tp / l2tp_netlink.c
index 93a41a0..12c0e8a 100644 (file)
@@ -40,7 +40,8 @@ static struct genl_family l2tp_nl_family = {
 /* Accessed under genl lock */
 static const struct l2tp_nl_cmd_ops *l2tp_nl_cmd_ops[__L2TP_PWTYPE_MAX];
 
-static struct l2tp_session *l2tp_nl_session_find(struct genl_info *info)
+static struct l2tp_session *l2tp_nl_session_get(struct genl_info *info,
+                                               bool do_ref)
 {
        u32 tunnel_id;
        u32 session_id;
@@ -51,14 +52,17 @@ static struct l2tp_session *l2tp_nl_session_find(struct genl_info *info)
 
        if (info->attrs[L2TP_ATTR_IFNAME]) {
                ifname = nla_data(info->attrs[L2TP_ATTR_IFNAME]);
-               session = l2tp_session_find_by_ifname(net, ifname);
+               session = l2tp_session_get_by_ifname(net, ifname, do_ref);
        } else if ((info->attrs[L2TP_ATTR_SESSION_ID]) &&
                   (info->attrs[L2TP_ATTR_CONN_ID])) {
                tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
                session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]);
-               tunnel = l2tp_tunnel_find(net, tunnel_id);
-               if (tunnel)
-                       session = l2tp_session_find(net, tunnel, session_id);
+               tunnel = l2tp_tunnel_get(net, tunnel_id);
+               if (tunnel) {
+                       session = l2tp_session_get(net, tunnel, session_id,
+                                                  do_ref);
+                       l2tp_tunnel_dec_refcount(tunnel);
+               }
        }
 
        return session;
@@ -180,14 +184,16 @@ static int l2tp_nl_cmd_tunnel_delete(struct sk_buff *skb, struct genl_info *info
        }
        tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
 
-       tunnel = l2tp_tunnel_find(net, tunnel_id);
-       if (tunnel == NULL) {
+       tunnel = l2tp_tunnel_get(net, tunnel_id);
+       if (!tunnel) {
                ret = -ENODEV;
                goto out;
        }
 
        (void) l2tp_tunnel_delete(tunnel);
 
+       l2tp_tunnel_dec_refcount(tunnel);
+
 out:
        return ret;
 }
@@ -205,8 +211,8 @@ static int l2tp_nl_cmd_tunnel_modify(struct sk_buff *skb, struct genl_info *info
        }
        tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
 
-       tunnel = l2tp_tunnel_find(net, tunnel_id);
-       if (tunnel == NULL) {
+       tunnel = l2tp_tunnel_get(net, tunnel_id);
+       if (!tunnel) {
                ret = -ENODEV;
                goto out;
        }
@@ -214,6 +220,8 @@ static int l2tp_nl_cmd_tunnel_modify(struct sk_buff *skb, struct genl_info *info
        if (info->attrs[L2TP_ATTR_DEBUG])
                tunnel->debug = nla_get_u32(info->attrs[L2TP_ATTR_DEBUG]);
 
+       l2tp_tunnel_dec_refcount(tunnel);
+
 out:
        return ret;
 }
@@ -287,34 +295,37 @@ static int l2tp_nl_cmd_tunnel_get(struct sk_buff *skb, struct genl_info *info)
 
        if (!info->attrs[L2TP_ATTR_CONN_ID]) {
                ret = -EINVAL;
-               goto out;
+               goto err;
        }
 
        tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
 
-       tunnel = l2tp_tunnel_find(net, tunnel_id);
-       if (tunnel == NULL) {
-               ret = -ENODEV;
-               goto out;
-       }
-
        msg = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL);
        if (!msg) {
                ret = -ENOMEM;
-               goto out;
+               goto err;
+       }
+
+       tunnel = l2tp_tunnel_get(net, tunnel_id);
+       if (!tunnel) {
+               ret = -ENODEV;
+               goto err_nlmsg;
        }
 
        ret = l2tp_nl_tunnel_send(msg, info->snd_pid, info->snd_seq,
                                  NLM_F_ACK, tunnel);
        if (ret < 0)
-               goto err_out;
+               goto err_nlmsg_tunnel;
+
+       l2tp_tunnel_dec_refcount(tunnel);
 
        return genlmsg_unicast(net, msg, info->snd_pid);
 
-err_out:
+err_nlmsg_tunnel:
+       l2tp_tunnel_dec_refcount(tunnel);
+err_nlmsg:
        nlmsg_free(msg);
-
-out:
+err:
        return ret;
 }
 
@@ -350,7 +361,6 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
        u32 peer_session_id;
        int ret = 0;
        struct l2tp_tunnel *tunnel;
-       struct l2tp_session *session;
        struct l2tp_session_cfg cfg = { 0, };
        struct net *net = genl_info_net(info);
 
@@ -358,8 +368,9 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
                ret = -EINVAL;
                goto out;
        }
+
        tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
-       tunnel = l2tp_tunnel_find(net, tunnel_id);
+       tunnel = l2tp_tunnel_get(net, tunnel_id);
        if (!tunnel) {
                ret = -ENODEV;
                goto out;
@@ -367,29 +378,24 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
 
        if (!info->attrs[L2TP_ATTR_SESSION_ID]) {
                ret = -EINVAL;
-               goto out;
+               goto out_tunnel;
        }
        session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]);
-       session = l2tp_session_find(net, tunnel, session_id);
-       if (session) {
-               ret = -EEXIST;
-               goto out;
-       }
 
        if (!info->attrs[L2TP_ATTR_PEER_SESSION_ID]) {
                ret = -EINVAL;
-               goto out;
+               goto out_tunnel;
        }
        peer_session_id = nla_get_u32(info->attrs[L2TP_ATTR_PEER_SESSION_ID]);
 
        if (!info->attrs[L2TP_ATTR_PW_TYPE]) {
                ret = -EINVAL;
-               goto out;
+               goto out_tunnel;
        }
        cfg.pw_type = nla_get_u16(info->attrs[L2TP_ATTR_PW_TYPE]);
        if (cfg.pw_type >= __L2TP_PWTYPE_MAX) {
                ret = -EINVAL;
-               goto out;
+               goto out_tunnel;
        }
 
        if (tunnel->version > 2) {
@@ -411,7 +417,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
                        u16 len = nla_len(info->attrs[L2TP_ATTR_COOKIE]);
                        if (len > 8) {
                                ret = -EINVAL;
-                               goto out;
+                               goto out_tunnel;
                        }
                        cfg.cookie_len = len;
                        memcpy(&cfg.cookie[0], nla_data(info->attrs[L2TP_ATTR_COOKIE]), len);
@@ -420,7 +426,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
                        u16 len = nla_len(info->attrs[L2TP_ATTR_PEER_COOKIE]);
                        if (len > 8) {
                                ret = -EINVAL;
-                               goto out;
+                               goto out_tunnel;
                        }
                        cfg.peer_cookie_len = len;
                        memcpy(&cfg.peer_cookie[0], nla_data(info->attrs[L2TP_ATTR_PEER_COOKIE]), len);
@@ -456,7 +462,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
        if ((l2tp_nl_cmd_ops[cfg.pw_type] == NULL) ||
            (l2tp_nl_cmd_ops[cfg.pw_type]->session_create == NULL)) {
                ret = -EPROTONOSUPPORT;
-               goto out;
+               goto out_tunnel;
        }
 
        /* Check that pseudowire-specific params are present */
@@ -466,7 +472,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
        case L2TP_PWTYPE_ETH_VLAN:
                if (!info->attrs[L2TP_ATTR_VLAN_ID]) {
                        ret = -EINVAL;
-                       goto out;
+                       goto out_tunnel;
                }
                break;
        case L2TP_PWTYPE_ETH:
@@ -480,11 +486,13 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
                break;
        }
 
-       ret = -EPROTONOSUPPORT;
-       if (l2tp_nl_cmd_ops[cfg.pw_type]->session_create)
-               ret = (*l2tp_nl_cmd_ops[cfg.pw_type]->session_create)(net, tunnel_id,
-                       session_id, peer_session_id, &cfg);
+       ret = l2tp_nl_cmd_ops[cfg.pw_type]->session_create(net, tunnel,
+                                                          session_id,
+                                                          peer_session_id,
+                                                          &cfg);
 
+out_tunnel:
+       l2tp_tunnel_dec_refcount(tunnel);
 out:
        return ret;
 }
@@ -495,7 +503,7 @@ static int l2tp_nl_cmd_session_delete(struct sk_buff *skb, struct genl_info *inf
        struct l2tp_session *session;
        u16 pw_type;
 
-       session = l2tp_nl_session_find(info);
+       session = l2tp_nl_session_get(info, true);
        if (session == NULL) {
                ret = -ENODEV;
                goto out;
@@ -506,6 +514,10 @@ static int l2tp_nl_cmd_session_delete(struct sk_buff *skb, struct genl_info *inf
                if (l2tp_nl_cmd_ops[pw_type] && l2tp_nl_cmd_ops[pw_type]->session_delete)
                        ret = (*l2tp_nl_cmd_ops[pw_type]->session_delete)(session);
 
+       if (session->deref)
+               session->deref(session);
+       l2tp_session_dec_refcount(session);
+
 out:
        return ret;
 }
@@ -515,7 +527,7 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf
        int ret = 0;
        struct l2tp_session *session;
 
-       session = l2tp_nl_session_find(info);
+       session = l2tp_nl_session_get(info, false);
        if (session == NULL) {
                ret = -ENODEV;
                goto out;
@@ -545,6 +557,8 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf
        if (info->attrs[L2TP_ATTR_MRU])
                session->mru = nla_get_u16(info->attrs[L2TP_ATTR_MRU]);
 
+       l2tp_session_dec_refcount(session);
+
 out:
        return ret;
 }
@@ -615,29 +629,34 @@ static int l2tp_nl_cmd_session_get(struct sk_buff *skb, struct genl_info *info)
        struct sk_buff *msg;
        int ret;
 
-       session = l2tp_nl_session_find(info);
+       session = l2tp_nl_session_get(info, false);
        if (session == NULL) {
                ret = -ENODEV;
-               goto out;
+               goto err;
        }
 
        msg = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL);
        if (!msg) {
                ret = -ENOMEM;
-               goto out;
+               goto err_ref;
        }
 
        ret = l2tp_nl_session_send(msg, info->snd_pid, info->snd_seq,
                                   0, session);
        if (ret < 0)
-               goto err_out;
+               goto err_ref_msg;
 
-       return genlmsg_unicast(genl_info_net(info), msg, info->snd_pid);
+       ret = genlmsg_unicast(genl_info_net(info), msg, info->snd_pid);
 
-err_out:
-       nlmsg_free(msg);
+       l2tp_session_dec_refcount(session);
 
-out:
+       return ret;
+
+err_ref_msg:
+       nlmsg_free(msg);
+err_ref:
+       l2tp_session_dec_refcount(session);
+err:
        return ret;
 }
 
@@ -656,7 +675,7 @@ static int l2tp_nl_cmd_session_dump(struct sk_buff *skb, struct netlink_callback
                                goto out;
                }
 
-               session = l2tp_session_find_nth(tunnel, si);
+               session = l2tp_session_get_nth(tunnel, si, false);
                if (session == NULL) {
                        ti++;
                        tunnel = NULL;
@@ -666,8 +685,11 @@ static int l2tp_nl_cmd_session_dump(struct sk_buff *skb, struct netlink_callback
 
                if (l2tp_nl_session_send(skb, NETLINK_CB(cb->skb).pid,
                                         cb->nlh->nlmsg_seq, NLM_F_MULTI,
-                                        session) <= 0)
+                                        session) <= 0) {
+                       l2tp_session_dec_refcount(session);
                        break;
+               }
+               l2tp_session_dec_refcount(session);
 
                si++;
        }