inet_diag: Rename inet_diag_req into inet_diag_req_v2
[pandora-kernel.git] / net / ipv4 / inet_diag.c
index ccee270..cf23a7c 100644 (file)
@@ -33,6 +33,7 @@
 #include <linux/stddef.h>
 
 #include <linux/inet_diag.h>
+#include <linux/sock_diag.h>
 
 static const struct inet_diag_handler **inet_diag_table;
 
@@ -45,24 +46,22 @@ struct inet_diag_entry {
        u16 userlocks;
 };
 
-static struct sock *idiagnl;
-
 #define INET_DIAG_PUT(skb, attrtype, attrlen) \
        RTA_DATA(__RTA_PUT(skb, attrtype, attrlen))
 
 static DEFINE_MUTEX(inet_diag_table_mutex);
 
-static const struct inet_diag_handler *inet_diag_lock_handler(int type)
+static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
 {
-       if (!inet_diag_table[type])
-               request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
-                              NETLINK_INET_DIAG, type);
+       if (!inet_diag_table[proto])
+               request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK,
+                              NETLINK_SOCK_DIAG, AF_INET, proto);
 
        mutex_lock(&inet_diag_table_mutex);
-       if (!inet_diag_table[type])
+       if (!inet_diag_table[proto])
                return ERR_PTR(-ENOENT);
 
-       return inet_diag_table[type];
+       return inet_diag_table[proto];
 }
 
 static inline void inet_diag_unlock_handler(
@@ -71,21 +70,21 @@ static inline void inet_diag_unlock_handler(
        mutex_unlock(&inet_diag_table_mutex);
 }
 
-static int inet_csk_diag_fill(struct sock *sk,
-                             struct sk_buff *skb,
-                             int ext, u32 pid, u32 seq, u16 nlmsg_flags,
+int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
+                             struct sk_buff *skb, struct inet_diag_req_v2 *req,
+                             u32 pid, u32 seq, u16 nlmsg_flags,
                              const struct nlmsghdr *unlh)
 {
        const struct inet_sock *inet = inet_sk(sk);
-       const struct inet_connection_sock *icsk = inet_csk(sk);
        struct inet_diag_msg *r;
        struct nlmsghdr  *nlh;
        void *info = NULL;
        struct inet_diag_meminfo  *minfo = NULL;
        unsigned char    *b = skb_tail_pointer(skb);
        const struct inet_diag_handler *handler;
+       int ext = req->idiag_ext;
 
-       handler = inet_diag_table[unlh->nlmsg_type];
+       handler = inet_diag_table[req->sdiag_protocol];
        BUG_ON(handler == NULL);
 
        nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r));
@@ -97,25 +96,13 @@ static int inet_csk_diag_fill(struct sock *sk,
        if (ext & (1 << (INET_DIAG_MEMINFO - 1)))
                minfo = INET_DIAG_PUT(skb, INET_DIAG_MEMINFO, sizeof(*minfo));
 
-       if (ext & (1 << (INET_DIAG_INFO - 1)))
-               info = INET_DIAG_PUT(skb, INET_DIAG_INFO,
-                                    handler->idiag_info_size);
-
-       if ((ext & (1 << (INET_DIAG_CONG - 1))) && icsk->icsk_ca_ops) {
-               const size_t len = strlen(icsk->icsk_ca_ops->name);
-
-               strcpy(INET_DIAG_PUT(skb, INET_DIAG_CONG, len + 1),
-                      icsk->icsk_ca_ops->name);
-       }
-
        r->idiag_family = sk->sk_family;
        r->idiag_state = sk->sk_state;
        r->idiag_timer = 0;
        r->idiag_retrans = 0;
 
        r->id.idiag_if = sk->sk_bound_dev_if;
-       r->id.idiag_cookie[0] = (u32)(unsigned long)sk;
-       r->id.idiag_cookie[1] = (u32)(((unsigned long)sk >> 31) >> 1);
+       sock_diag_save_cookie(sk, r->id.idiag_cookie);
 
        r->id.idiag_sport = inet->inet_sport;
        r->id.idiag_dport = inet->inet_dport;
@@ -128,20 +115,36 @@ static int inet_csk_diag_fill(struct sock *sk,
        if (ext & (1 << (INET_DIAG_TOS - 1)))
                RTA_PUT_U8(skb, INET_DIAG_TOS, inet->tos);
 
-#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
+#if IS_ENABLED(CONFIG_IPV6)
        if (r->idiag_family == AF_INET6) {
                const struct ipv6_pinfo *np = inet6_sk(sk);
 
+               *(struct in6_addr *)r->id.idiag_src = np->rcv_saddr;
+               *(struct in6_addr *)r->id.idiag_dst = np->daddr;
                if (ext & (1 << (INET_DIAG_TCLASS - 1)))
                        RTA_PUT_U8(skb, INET_DIAG_TCLASS, np->tclass);
-
-               ipv6_addr_copy((struct in6_addr *)r->id.idiag_src,
-                              &np->rcv_saddr);
-               ipv6_addr_copy((struct in6_addr *)r->id.idiag_dst,
-                              &np->daddr);
        }
 #endif
 
+       r->idiag_uid = sock_i_uid(sk);
+       r->idiag_inode = sock_i_ino(sk);
+
+       if (minfo) {
+               minfo->idiag_rmem = sk_rmem_alloc_get(sk);
+               minfo->idiag_wmem = sk->sk_wmem_queued;
+               minfo->idiag_fmem = sk->sk_forward_alloc;
+               minfo->idiag_tmem = sk_wmem_alloc_get(sk);
+       }
+
+       if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
+               if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
+                       goto rtattr_failure;
+
+       if (icsk == NULL) {
+               r->idiag_rqueue = r->idiag_wqueue = 0;
+               goto out;
+       }
+
 #define EXPIRES_IN_MS(tmo)  DIV_ROUND_UP((tmo - jiffies) * 1000, HZ)
 
        if (icsk->icsk_pending == ICSK_TIME_RETRANS) {
@@ -162,14 +165,14 @@ static int inet_csk_diag_fill(struct sock *sk,
        }
 #undef EXPIRES_IN_MS
 
-       r->idiag_uid = sock_i_uid(sk);
-       r->idiag_inode = sock_i_ino(sk);
+       if (ext & (1 << (INET_DIAG_INFO - 1)))
+               info = INET_DIAG_PUT(skb, INET_DIAG_INFO, sizeof(struct tcp_info));
 
-       if (minfo) {
-               minfo->idiag_rmem = sk_rmem_alloc_get(sk);
-               minfo->idiag_wmem = sk->sk_wmem_queued;
-               minfo->idiag_fmem = sk->sk_forward_alloc;
-               minfo->idiag_tmem = sk_wmem_alloc_get(sk);
+       if ((ext & (1 << (INET_DIAG_CONG - 1))) && icsk->icsk_ca_ops) {
+               const size_t len = strlen(icsk->icsk_ca_ops->name);
+
+               strcpy(INET_DIAG_PUT(skb, INET_DIAG_CONG, len + 1),
+                      icsk->icsk_ca_ops->name);
        }
 
        handler->idiag_get_info(sk, r, info);
@@ -178,6 +181,7 @@ static int inet_csk_diag_fill(struct sock *sk,
            icsk->icsk_ca_ops && icsk->icsk_ca_ops->get_info)
                icsk->icsk_ca_ops->get_info(sk, ext, skb);
 
+out:
        nlh->nlmsg_len = skb_tail_pointer(skb) - b;
        return skb->len;
 
@@ -186,10 +190,20 @@ nlmsg_failure:
        nlmsg_trim(skb, b);
        return -EMSGSIZE;
 }
+EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
+
+static int inet_csk_diag_fill(struct sock *sk,
+                             struct sk_buff *skb, struct inet_diag_req_v2 *req,
+                             u32 pid, u32 seq, u16 nlmsg_flags,
+                             const struct nlmsghdr *unlh)
+{
+       return inet_sk_diag_fill(sk, inet_csk(sk),
+                       skb, req, pid, seq, nlmsg_flags, unlh);
+}
 
 static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,
-                              struct sk_buff *skb, int ext, u32 pid,
-                              u32 seq, u16 nlmsg_flags,
+                              struct sk_buff *skb, struct inet_diag_req_v2 *req,
+                              u32 pid, u32 seq, u16 nlmsg_flags,
                               const struct nlmsghdr *unlh)
 {
        long tmo;
@@ -210,8 +224,7 @@ static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,
        r->idiag_family       = tw->tw_family;
        r->idiag_retrans      = 0;
        r->id.idiag_if        = tw->tw_bound_dev_if;
-       r->id.idiag_cookie[0] = (u32)(unsigned long)tw;
-       r->id.idiag_cookie[1] = (u32)(((unsigned long)tw >> 31) >> 1);
+       sock_diag_save_cookie(tw, r->id.idiag_cookie);
        r->id.idiag_sport     = tw->tw_sport;
        r->id.idiag_dport     = tw->tw_dport;
        r->id.idiag_src[0]    = tw->tw_rcv_saddr;
@@ -223,15 +236,13 @@ static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,
        r->idiag_wqueue       = 0;
        r->idiag_uid          = 0;
        r->idiag_inode        = 0;
-#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
+#if IS_ENABLED(CONFIG_IPV6)
        if (tw->tw_family == AF_INET6) {
                const struct inet6_timewait_sock *tw6 =
                                                inet6_twsk((struct sock *)tw);
 
-               ipv6_addr_copy((struct in6_addr *)r->id.idiag_src,
-                              &tw6->tw_v6_rcv_saddr);
-               ipv6_addr_copy((struct in6_addr *)r->id.idiag_dst,
-                              &tw6->tw_v6_daddr);
+               *(struct in6_addr *)r->id.idiag_src = tw6->tw_v6_rcv_saddr;
+               *(struct in6_addr *)r->id.idiag_dst = tw6->tw_v6_daddr;
        }
 #endif
        nlh->nlmsg_len = skb_tail_pointer(skb) - previous_tail;
@@ -242,42 +253,31 @@ nlmsg_failure:
 }
 
 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
-                       int ext, u32 pid, u32 seq, u16 nlmsg_flags,
+                       struct inet_diag_req_v2 *r, u32 pid, u32 seq, u16 nlmsg_flags,
                        const struct nlmsghdr *unlh)
 {
        if (sk->sk_state == TCP_TIME_WAIT)
                return inet_twsk_diag_fill((struct inet_timewait_sock *)sk,
-                                          skb, ext, pid, seq, nlmsg_flags,
+                                          skb, r, pid, seq, nlmsg_flags,
                                           unlh);
-       return inet_csk_diag_fill(sk, skb, ext, pid, seq, nlmsg_flags, unlh);
+       return inet_csk_diag_fill(sk, skb, r, pid, seq, nlmsg_flags, unlh);
 }
 
-static int inet_diag_get_exact(struct sk_buff *in_skb,
-                              const struct nlmsghdr *nlh)
+int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *in_skb,
+               const struct nlmsghdr *nlh, struct inet_diag_req_v2 *req)
 {
        int err;
        struct sock *sk;
-       struct inet_diag_req *req = NLMSG_DATA(nlh);
        struct sk_buff *rep;
-       struct inet_hashinfo *hashinfo;
-       const struct inet_diag_handler *handler;
 
-       handler = inet_diag_lock_handler(nlh->nlmsg_type);
-       if (IS_ERR(handler)) {
-               err = PTR_ERR(handler);
-               goto unlock;
-       }
-
-       hashinfo = handler->idiag_hashinfo;
        err = -EINVAL;
-
-       if (req->idiag_family == AF_INET) {
+       if (req->sdiag_family == AF_INET) {
                sk = inet_lookup(&init_net, hashinfo, req->id.idiag_dst[0],
                                 req->id.idiag_dport, req->id.idiag_src[0],
                                 req->id.idiag_sport, req->id.idiag_if);
        }
-#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
-       else if (req->idiag_family == AF_INET6) {
+#if IS_ENABLED(CONFIG_IPV6)
+       else if (req->sdiag_family == AF_INET6) {
                sk = inet6_lookup(&init_net, hashinfo,
                                  (struct in6_addr *)req->id.idiag_dst,
                                  req->id.idiag_dport,
@@ -287,29 +287,26 @@ static int inet_diag_get_exact(struct sk_buff *in_skb,
        }
 #endif
        else {
-               goto unlock;
+               goto out_nosk;
        }
 
        err = -ENOENT;
        if (sk == NULL)
-               goto unlock;
+               goto out_nosk;
 
-       err = -ESTALE;
-       if ((req->id.idiag_cookie[0] != INET_DIAG_NOCOOKIE ||
-            req->id.idiag_cookie[1] != INET_DIAG_NOCOOKIE) &&
-           ((u32)(unsigned long)sk != req->id.idiag_cookie[0] ||
-            (u32)((((unsigned long)sk) >> 31) >> 1) != req->id.idiag_cookie[1]))
+       err = sock_diag_check_cookie(sk, req->id.idiag_cookie);
+       if (err)
                goto out;
 
        err = -ENOMEM;
        rep = alloc_skb(NLMSG_SPACE((sizeof(struct inet_diag_msg) +
                                     sizeof(struct inet_diag_meminfo) +
-                                    handler->idiag_info_size + 64)),
+                                    sizeof(struct tcp_info) + 64)),
                        GFP_KERNEL);
        if (!rep)
                goto out;
 
-       err = sk_diag_fill(sk, rep, req->idiag_ext,
+       err = sk_diag_fill(sk, rep, req,
                           NETLINK_CB(in_skb).pid,
                           nlh->nlmsg_seq, 0, nlh);
        if (err < 0) {
@@ -317,7 +314,7 @@ static int inet_diag_get_exact(struct sk_buff *in_skb,
                kfree_skb(rep);
                goto out;
        }
-       err = netlink_unicast(idiagnl, rep, NETLINK_CB(in_skb).pid,
+       err = netlink_unicast(sock_diag_nlsk, rep, NETLINK_CB(in_skb).pid,
                              MSG_DONTWAIT);
        if (err > 0)
                err = 0;
@@ -329,8 +326,25 @@ out:
                else
                        sock_put(sk);
        }
-unlock:
+out_nosk:
+       return err;
+}
+EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
+
+static int inet_diag_get_exact(struct sk_buff *in_skb,
+                              const struct nlmsghdr *nlh,
+                              struct inet_diag_req_v2 *req)
+{
+       const struct inet_diag_handler *handler;
+       int err;
+
+       handler = inet_diag_lock_handler(req->sdiag_protocol);
+       if (IS_ERR(handler))
+               err = PTR_ERR(handler);
+       else
+               err = handler->dump_one(in_skb, nlh, req);
        inet_diag_unlock_handler(handler);
+
        return err;
 }
 
@@ -361,9 +375,12 @@ static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
 }
 
 
-static int inet_diag_bc_run(const void *bc, int len,
-                           const struct inet_diag_entry *entry)
+static int inet_diag_bc_run(const struct nlattr *_bc,
+               const struct inet_diag_entry *entry)
 {
+       const void *bc = nla_data(_bc);
+       int len = nla_len(_bc);
+
        while (len > 0) {
                int yes = 1;
                const struct inet_diag_bc_op *op = bc;
@@ -437,6 +454,35 @@ static int inet_diag_bc_run(const void *bc, int len,
        return len == 0;
 }
 
+int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
+{
+       struct inet_diag_entry entry;
+       struct inet_sock *inet = inet_sk(sk);
+
+       if (bc == NULL)
+               return 1;
+
+       entry.family = sk->sk_family;
+#if IS_ENABLED(CONFIG_IPV6)
+       if (entry.family == AF_INET6) {
+               struct ipv6_pinfo *np = inet6_sk(sk);
+
+               entry.saddr = np->rcv_saddr.s6_addr32;
+               entry.daddr = np->daddr.s6_addr32;
+       } else
+#endif
+       {
+               entry.saddr = &inet->inet_rcv_saddr;
+               entry.daddr = &inet->inet_daddr;
+       }
+       entry.sport = inet->inet_num;
+       entry.dport = ntohs(inet->inet_dport);
+       entry.userlocks = sk->sk_userlocks;
+
+       return inet_diag_bc_run(bc, &entry);
+}
+EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
+
 static int valid_cc(const void *bc, int len, int cc)
 {
        while (len >= 0) {
@@ -493,57 +539,29 @@ static int inet_diag_bc_audit(const void *bytecode, int bytecode_len)
 
 static int inet_csk_diag_dump(struct sock *sk,
                              struct sk_buff *skb,
-                             struct netlink_callback *cb)
+                             struct netlink_callback *cb,
+                             struct inet_diag_req_v2 *r,
+                             const struct nlattr *bc)
 {
-       struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
-
-       if (nlmsg_attrlen(cb->nlh, sizeof(*r))) {
-               struct inet_diag_entry entry;
-               const struct nlattr *bc = nlmsg_find_attr(cb->nlh,
-                                                         sizeof(*r),
-                                                         INET_DIAG_REQ_BYTECODE);
-               struct inet_sock *inet = inet_sk(sk);
-
-               entry.family = sk->sk_family;
-#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
-               if (entry.family == AF_INET6) {
-                       struct ipv6_pinfo *np = inet6_sk(sk);
-
-                       entry.saddr = np->rcv_saddr.s6_addr32;
-                       entry.daddr = np->daddr.s6_addr32;
-               } else
-#endif
-               {
-                       entry.saddr = &inet->inet_rcv_saddr;
-                       entry.daddr = &inet->inet_daddr;
-               }
-               entry.sport = inet->inet_num;
-               entry.dport = ntohs(inet->inet_dport);
-               entry.userlocks = sk->sk_userlocks;
+       if (!inet_diag_bc_sk(bc, sk))
+               return 0;
 
-               if (!inet_diag_bc_run(nla_data(bc), nla_len(bc), &entry))
-                       return 0;
-       }
-
-       return inet_csk_diag_fill(sk, skb, r->idiag_ext,
+       return inet_csk_diag_fill(sk, skb, r,
                                  NETLINK_CB(cb->skb).pid,
                                  cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
 }
 
 static int inet_twsk_diag_dump(struct inet_timewait_sock *tw,
                               struct sk_buff *skb,
-                              struct netlink_callback *cb)
+                              struct netlink_callback *cb,
+                              struct inet_diag_req_v2 *r,
+                              const struct nlattr *bc)
 {
-       struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
-
-       if (nlmsg_attrlen(cb->nlh, sizeof(*r))) {
+       if (bc != NULL) {
                struct inet_diag_entry entry;
-               const struct nlattr *bc = nlmsg_find_attr(cb->nlh,
-                                                         sizeof(*r),
-                                                         INET_DIAG_REQ_BYTECODE);
 
                entry.family = tw->tw_family;
-#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
+#if IS_ENABLED(CONFIG_IPV6)
                if (tw->tw_family == AF_INET6) {
                        struct inet6_timewait_sock *tw6 =
                                                inet6_twsk((struct sock *)tw);
@@ -559,11 +577,11 @@ static int inet_twsk_diag_dump(struct inet_timewait_sock *tw,
                entry.dport = ntohs(tw->tw_dport);
                entry.userlocks = 0;
 
-               if (!inet_diag_bc_run(nla_data(bc), nla_len(bc), &entry))
+               if (!inet_diag_bc_run(bc, &entry))
                        return 0;
        }
 
-       return inet_twsk_diag_fill(tw, skb, r->idiag_ext,
+       return inet_twsk_diag_fill(tw, skb, r,
                                   NETLINK_CB(cb->skb).pid,
                                   cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
 }
@@ -589,8 +607,7 @@ static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk,
        r->idiag_retrans = req->retrans;
 
        r->id.idiag_if = sk->sk_bound_dev_if;
-       r->id.idiag_cookie[0] = (u32)(unsigned long)req;
-       r->id.idiag_cookie[1] = (u32)(((unsigned long)req >> 31) >> 1);
+       sock_diag_save_cookie(req, r->id.idiag_cookie);
 
        tmo = req->expires - jiffies;
        if (tmo < 0)
@@ -605,12 +622,10 @@ static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk,
        r->idiag_wqueue = 0;
        r->idiag_uid = sock_i_uid(sk);
        r->idiag_inode = 0;
-#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
+#if IS_ENABLED(CONFIG_IPV6)
        if (r->idiag_family == AF_INET6) {
-               ipv6_addr_copy((struct in6_addr *)r->id.idiag_src,
-                              &inet6_rsk(req)->loc_addr);
-               ipv6_addr_copy((struct in6_addr *)r->id.idiag_dst,
-                              &inet6_rsk(req)->rmt_addr);
+               *(struct in6_addr *)r->id.idiag_src = inet6_rsk(req)->loc_addr;
+               *(struct in6_addr *)r->id.idiag_dst = inet6_rsk(req)->rmt_addr;
        }
 #endif
        nlh->nlmsg_len = skb_tail_pointer(skb) - b;
@@ -623,13 +638,13 @@ nlmsg_failure:
 }
 
 static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
-                              struct netlink_callback *cb)
+                              struct netlink_callback *cb,
+                              struct inet_diag_req_v2 *r,
+                              const struct nlattr *bc)
 {
        struct inet_diag_entry entry;
-       struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
        struct inet_connection_sock *icsk = inet_csk(sk);
        struct listen_sock *lopt;
-       const struct nlattr *bc = NULL;
        struct inet_sock *inet = inet_sk(sk);
        int j, s_j;
        int reqnum, s_reqnum;
@@ -649,9 +664,7 @@ static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
        if (!lopt || !lopt->qlen)
                goto out;
 
-       if (nlmsg_attrlen(cb->nlh, sizeof(*r))) {
-               bc = nlmsg_find_attr(cb->nlh, sizeof(*r),
-                                    INET_DIAG_REQ_BYTECODE);
+       if (bc != NULL) {
                entry.sport = inet->inet_num;
                entry.userlocks = sk->sk_userlocks;
        }
@@ -671,21 +684,20 @@ static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
 
                        if (bc) {
                                entry.saddr =
-#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
+#if IS_ENABLED(CONFIG_IPV6)
                                        (entry.family == AF_INET6) ?
                                        inet6_rsk(req)->loc_addr.s6_addr32 :
 #endif
                                        &ireq->loc_addr;
                                entry.daddr =
-#if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
+#if IS_ENABLED(CONFIG_IPV6)
                                        (entry.family == AF_INET6) ?
                                        inet6_rsk(req)->rmt_addr.s6_addr32 :
 #endif
                                        &ireq->rmt_addr;
                                entry.dport = ntohs(ireq->rmt_port);
 
-                               if (!inet_diag_bc_run(nla_data(bc),
-                                                     nla_len(bc), &entry))
+                               if (!inet_diag_bc_run(bc, &entry))
                                        continue;
                        }
 
@@ -708,19 +720,11 @@ out:
        return err;
 }
 
-static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
+void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
+               struct netlink_callback *cb, struct inet_diag_req_v2 *r, struct nlattr *bc)
 {
        int i, num;
        int s_i, s_num;
-       struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
-       const struct inet_diag_handler *handler;
-       struct inet_hashinfo *hashinfo;
-
-       handler = inet_diag_lock_handler(cb->nlh->nlmsg_type);
-       if (IS_ERR(handler))
-               goto unlock;
-
-       hashinfo = handler->idiag_hashinfo;
 
        s_i = cb->args[1];
        s_num = num = cb->args[2];
@@ -745,6 +749,10 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
                                        continue;
                                }
 
+                               if (r->sdiag_family != AF_UNSPEC &&
+                                               sk->sk_family != r->sdiag_family)
+                                       goto next_listen;
+
                                if (r->id.idiag_sport != inet->inet_sport &&
                                    r->id.idiag_sport)
                                        goto next_listen;
@@ -754,7 +762,7 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
                                    cb->args[3] > 0)
                                        goto syn_recv;
 
-                               if (inet_csk_diag_dump(sk, skb, cb) < 0) {
+                               if (inet_csk_diag_dump(sk, skb, cb, r, bc) < 0) {
                                        spin_unlock_bh(&ilb->lock);
                                        goto done;
                                }
@@ -763,7 +771,7 @@ syn_recv:
                                if (!(r->idiag_states & TCPF_SYN_RECV))
                                        goto next_listen;
 
-                               if (inet_diag_dump_reqs(skb, sk, cb) < 0) {
+                               if (inet_diag_dump_reqs(skb, sk, cb, r, bc) < 0) {
                                        spin_unlock_bh(&ilb->lock);
                                        goto done;
                                }
@@ -785,7 +793,7 @@ skip_listen_ht:
        }
 
        if (!(r->idiag_states & ~(TCPF_LISTEN | TCPF_SYN_RECV)))
-               goto unlock;
+               goto out;
 
        for (i = s_i; i <= hashinfo->ehash_mask; i++) {
                struct inet_ehash_bucket *head = &hashinfo->ehash[i];
@@ -810,13 +818,16 @@ skip_listen_ht:
                                goto next_normal;
                        if (!(r->idiag_states & (1 << sk->sk_state)))
                                goto next_normal;
+                       if (r->sdiag_family != AF_UNSPEC &&
+                                       sk->sk_family != r->sdiag_family)
+                               goto next_normal;
                        if (r->id.idiag_sport != inet->inet_sport &&
                            r->id.idiag_sport)
                                goto next_normal;
                        if (r->id.idiag_dport != inet->inet_dport &&
                            r->id.idiag_dport)
                                goto next_normal;
-                       if (inet_csk_diag_dump(sk, skb, cb) < 0) {
+                       if (inet_csk_diag_dump(sk, skb, cb, r, bc) < 0) {
                                spin_unlock_bh(lock);
                                goto done;
                        }
@@ -832,13 +843,16 @@ next_normal:
 
                                if (num < s_num)
                                        goto next_dying;
+                               if (r->sdiag_family != AF_UNSPEC &&
+                                               tw->tw_family != r->sdiag_family)
+                                       goto next_dying;
                                if (r->id.idiag_sport != tw->tw_sport &&
                                    r->id.idiag_sport)
                                        goto next_dying;
                                if (r->id.idiag_dport != tw->tw_dport &&
                                    r->id.idiag_dport)
                                        goto next_dying;
-                               if (inet_twsk_diag_dump(tw, skb, cb) < 0) {
+                               if (inet_twsk_diag_dump(tw, skb, cb, r, bc) < 0) {
                                        spin_unlock_bh(lock);
                                        goto done;
                                }
@@ -852,14 +866,84 @@ next_dying:
 done:
        cb->args[1] = i;
        cb->args[2] = num;
-unlock:
+out:
+       ;
+}
+EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
+
+static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
+               struct inet_diag_req_v2 *r, struct nlattr *bc)
+{
+       const struct inet_diag_handler *handler;
+
+       handler = inet_diag_lock_handler(r->sdiag_protocol);
+       if (!IS_ERR(handler))
+               handler->dump(skb, cb, r, bc);
        inet_diag_unlock_handler(handler);
+
        return skb->len;
 }
 
-static int inet_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
+static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
 {
-       int hdrlen = sizeof(struct inet_diag_req);
+       struct nlattr *bc = NULL;
+       int hdrlen = sizeof(struct inet_diag_req_v2);
+
+       if (nlmsg_attrlen(cb->nlh, hdrlen))
+               bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
+
+       return __inet_diag_dump(skb, cb, (struct inet_diag_req_v2 *)NLMSG_DATA(cb->nlh), bc);
+}
+
+static inline int inet_diag_type2proto(int type)
+{
+       switch (type) {
+       case TCPDIAG_GETSOCK:
+               return IPPROTO_TCP;
+       case DCCPDIAG_GETSOCK:
+               return IPPROTO_DCCP;
+       default:
+               return 0;
+       }
+}
+
+static int inet_diag_dump_compat(struct sk_buff *skb, struct netlink_callback *cb)
+{
+       struct inet_diag_req_compat *rc = NLMSG_DATA(cb->nlh);
+       struct inet_diag_req_v2 req;
+       struct nlattr *bc = NULL;
+       int hdrlen = sizeof(struct inet_diag_req_compat);
+
+       req.sdiag_family = AF_UNSPEC; /* compatibility */
+       req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
+       req.idiag_ext = rc->idiag_ext;
+       req.idiag_states = rc->idiag_states;
+       req.id = rc->id;
+
+       if (nlmsg_attrlen(cb->nlh, hdrlen))
+               bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
+
+       return __inet_diag_dump(skb, cb, &req, bc);
+}
+
+static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
+                              const struct nlmsghdr *nlh)
+{
+       struct inet_diag_req_compat *rc = NLMSG_DATA(nlh);
+       struct inet_diag_req_v2 req;
+
+       req.sdiag_family = rc->idiag_family;
+       req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
+       req.idiag_ext = rc->idiag_ext;
+       req.idiag_states = rc->idiag_states;
+       req.id = rc->id;
+
+       return inet_diag_get_exact(in_skb, nlh, &req);
+}
+
+static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
+{
+       int hdrlen = sizeof(struct inet_diag_req_compat);
 
        if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
            nlmsg_len(nlh) < hdrlen)
@@ -877,28 +961,54 @@ static int inet_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                                return -EINVAL;
                }
 
-               return netlink_dump_start(idiagnl, skb, nlh,
-                                         inet_diag_dump, NULL, 0);
+               return netlink_dump_start(sock_diag_nlsk, skb, nlh,
+                                         inet_diag_dump_compat, NULL, 0);
        }
 
-       return inet_diag_get_exact(skb, nlh);
+       return inet_diag_get_exact_compat(skb, nlh);
 }
 
-static DEFINE_MUTEX(inet_diag_mutex);
-
-static void inet_diag_rcv(struct sk_buff *skb)
+static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
 {
-       mutex_lock(&inet_diag_mutex);
-       netlink_rcv_skb(skb, &inet_diag_rcv_msg);
-       mutex_unlock(&inet_diag_mutex);
+       int hdrlen = sizeof(struct inet_diag_req_v2);
+
+       if (nlmsg_len(h) < hdrlen)
+               return -EINVAL;
+
+       if (h->nlmsg_flags & NLM_F_DUMP) {
+               if (nlmsg_attrlen(h, hdrlen)) {
+                       struct nlattr *attr;
+                       attr = nlmsg_find_attr(h, hdrlen,
+                                              INET_DIAG_REQ_BYTECODE);
+                       if (attr == NULL ||
+                           nla_len(attr) < sizeof(struct inet_diag_bc_op) ||
+                           inet_diag_bc_audit(nla_data(attr), nla_len(attr)))
+                               return -EINVAL;
+               }
+
+               return netlink_dump_start(sock_diag_nlsk, skb, h,
+                                         inet_diag_dump, NULL, 0);
+       }
+
+       return inet_diag_get_exact(skb, h, (struct inet_diag_req_v2 *)NLMSG_DATA(h));
 }
 
+static struct sock_diag_handler inet_diag_handler = {
+       .family = AF_INET,
+       .dump = inet_diag_handler_dump,
+};
+
+static struct sock_diag_handler inet6_diag_handler = {
+       .family = AF_INET6,
+       .dump = inet_diag_handler_dump,
+};
+
 int inet_diag_register(const struct inet_diag_handler *h)
 {
        const __u16 type = h->idiag_type;
        int err = -EINVAL;
 
-       if (type >= INET_DIAG_GETSOCK_MAX)
+       if (type >= IPPROTO_MAX)
                goto out;
 
        mutex_lock(&inet_diag_table_mutex);
@@ -917,7 +1027,7 @@ void inet_diag_unregister(const struct inet_diag_handler *h)
 {
        const __u16 type = h->idiag_type;
 
-       if (type >= INET_DIAG_GETSOCK_MAX)
+       if (type >= IPPROTO_MAX)
                return;
 
        mutex_lock(&inet_diag_table_mutex);
@@ -928,7 +1038,7 @@ EXPORT_SYMBOL_GPL(inet_diag_unregister);
 
 static int __init inet_diag_init(void)
 {
-       const int inet_diag_table_size = (INET_DIAG_GETSOCK_MAX *
+       const int inet_diag_table_size = (IPPROTO_MAX *
                                          sizeof(struct inet_diag_handler *));
        int err = -ENOMEM;
 
@@ -936,25 +1046,35 @@ static int __init inet_diag_init(void)
        if (!inet_diag_table)
                goto out;
 
-       idiagnl = netlink_kernel_create(&init_net, NETLINK_INET_DIAG, 0,
-                                       inet_diag_rcv, NULL, THIS_MODULE);
-       if (idiagnl == NULL)
-               goto out_free_table;
-       err = 0;
+       err = sock_diag_register(&inet_diag_handler);
+       if (err)
+               goto out_free_nl;
+
+       err = sock_diag_register(&inet6_diag_handler);
+       if (err)
+               goto out_free_inet;
+
+       sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
 out:
        return err;
-out_free_table:
+
+out_free_inet:
+       sock_diag_unregister(&inet_diag_handler);
+out_free_nl:
        kfree(inet_diag_table);
        goto out;
 }
 
 static void __exit inet_diag_exit(void)
 {
-       netlink_kernel_release(idiagnl);
+       sock_diag_unregister(&inet6_diag_handler);
+       sock_diag_unregister(&inet_diag_handler);
+       sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
        kfree(inet_diag_table);
 }
 
 module_init(inet_diag_init);
 module_exit(inet_diag_exit);
 MODULE_LICENSE("GPL");
-MODULE_ALIAS_NET_PF_PROTO(PF_NETLINK, NETLINK_INET_DIAG);
+MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
+MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);