[NETLINK]: Support dynamic number of multicast groups per netlink family
authorPatrick McHardy <kaber@trash.net>
Mon, 15 Aug 2005 19:29:13 +0000 (12:29 -0700)
committerDavid S. Miller <davem@sunset.davemloft.net>
Mon, 29 Aug 2005 23:01:02 +0000 (16:01 -0700)
Signed-off-by: Patrick McHardy <kaber@trash.net>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/netlink/af_netlink.c

index 444ed22..58d4ca4 100644 (file)
 #include <net/scm.h>
 
 #define Nprintk(a...)
+#define NLGRPSZ(x)     (ALIGN(x, sizeof(unsigned long) * 8) / 8)
 
 struct netlink_sock {
        /* struct sock has to be the first member of netlink_sock */
        struct sock             sk;
        u32                     pid;
-       unsigned int            groups;
        u32                     dst_pid;
        u32                     dst_group;
+       u32                     flags;
+       u32                     subscriptions;
+       u32                     ngroups;
+       unsigned long           *groups;
        unsigned long           state;
        wait_queue_head_t       wait;
        struct netlink_callback *cb;
        spinlock_t              cb_lock;
        void                    (*data_ready)(struct sock *sk, int bytes);
        struct module           *module;
-       u32                     flags;
 };
 
 #define NETLINK_KERNEL_SOCKET  0x1
@@ -101,6 +104,7 @@ struct netlink_table {
        struct nl_pid_hash hash;
        struct hlist_head mc_list;
        unsigned int nl_nonroot;
+       unsigned int groups;
        struct module *module;
        int registered;
 };
@@ -138,6 +142,7 @@ static void netlink_sock_destruct(struct sock *sk)
        BUG_TRAP(!atomic_read(&sk->sk_rmem_alloc));
        BUG_TRAP(!atomic_read(&sk->sk_wmem_alloc));
        BUG_TRAP(!nlk_sk(sk)->cb);
+       BUG_TRAP(!nlk_sk(sk)->groups);
 }
 
 /* This lock without WQ_FLAG_EXCLUSIVE is good on UP and it is _very_ bad on SMP.
@@ -333,7 +338,7 @@ static void netlink_remove(struct sock *sk)
        netlink_table_grab();
        if (sk_del_node_init(sk))
                nl_table[sk->sk_protocol].hash.entries--;
-       if (nlk_sk(sk)->groups)
+       if (nlk_sk(sk)->subscriptions)
                __sk_del_bind_node(sk);
        netlink_table_ungrab();
 }
@@ -369,6 +374,8 @@ static int __netlink_create(struct socket *sock, int protocol)
 static int netlink_create(struct socket *sock, int protocol)
 {
        struct module *module = NULL;
+       struct netlink_sock *nlk;
+       unsigned int groups;
        int err = 0;
 
        sock->state = SS_UNCONNECTED;
@@ -392,15 +399,23 @@ static int netlink_create(struct socket *sock, int protocol)
                module = nl_table[protocol].module;
        else
                err = -EPROTONOSUPPORT;
+       groups = nl_table[protocol].groups;
        netlink_unlock_table();
 
-       if (err)
-               goto out;
+       if (err || (err = __netlink_create(sock, protocol) < 0))
+               goto out_module;
+
+       nlk = nlk_sk(sock->sk);
 
-       if ((err = __netlink_create(sock, protocol) < 0))
+       nlk->groups = kmalloc(NLGRPSZ(groups), GFP_KERNEL);
+       if (nlk->groups == NULL) {
+               err = -ENOMEM;
                goto out_module;
+       }
+       memset(nlk->groups, 0, NLGRPSZ(groups));
+       nlk->ngroups = groups;
 
-       nlk_sk(sock->sk)->module = module;
+       nlk->module = module;
 out:
        return err;
 
@@ -437,7 +452,7 @@ static int netlink_release(struct socket *sock)
 
        skb_queue_purge(&sk->sk_write_queue);
 
-       if (nlk->pid && !nlk->groups) {
+       if (nlk->pid && !nlk->subscriptions) {
                struct netlink_notify n = {
                                                .protocol = sk->sk_protocol,
                                                .pid = nlk->pid,
@@ -455,6 +470,9 @@ static int netlink_release(struct socket *sock)
                netlink_table_ungrab();
        }
 
+       kfree(nlk->groups);
+       nlk->groups = NULL;
+
        sock_put(sk);
        return 0;
 }
@@ -503,6 +521,18 @@ static inline int netlink_capable(struct socket *sock, unsigned int flag)
               capable(CAP_NET_ADMIN);
 } 
 
+static void
+netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions)
+{
+       struct netlink_sock *nlk = nlk_sk(sk);
+
+       if (nlk->subscriptions && !subscriptions)
+               __sk_del_bind_node(sk);
+       else if (!nlk->subscriptions && subscriptions)
+               sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list);
+       nlk->subscriptions = subscriptions;
+}
+
 static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
 {
        struct sock *sk = sock->sk;
@@ -528,15 +558,14 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len
                        return err;
        }
 
-       if (!nladdr->nl_groups && !nlk->groups)
+       if (!nladdr->nl_groups && !(u32)nlk->groups[0])
                return 0;
 
        netlink_table_grab();
-       if (nlk->groups && !nladdr->nl_groups)
-               __sk_del_bind_node(sk);
-       else if (!nlk->groups && nladdr->nl_groups)
-               sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list);
-       nlk->groups = nladdr->nl_groups;
+       netlink_update_subscriptions(sk, nlk->subscriptions +
+                                        hweight32(nladdr->nl_groups) -
+                                        hweight32(nlk->groups[0]));
+       nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups; 
        netlink_table_ungrab();
 
        return 0;
@@ -590,7 +619,7 @@ static int netlink_getname(struct socket *sock, struct sockaddr *addr, int *addr
                nladdr->nl_groups = netlink_group_mask(nlk->dst_group);
        } else {
                nladdr->nl_pid = nlk->pid;
-               nladdr->nl_groups = nlk->groups
+               nladdr->nl_groups = nlk->groups[0];
        }
        return 0;
 }
@@ -791,7 +820,8 @@ static inline int do_one_broadcast(struct sock *sk,
        if (p->exclude_sk == sk)
                goto out;
 
-       if (nlk->pid == p->pid || !(nlk->groups & netlink_group_mask(p->group)))
+       if (nlk->pid == p->pid || p->group - 1 >= nlk->ngroups ||
+           !test_bit(p->group - 1, nlk->groups))
                goto out;
 
        if (p->failure) {
@@ -887,7 +917,8 @@ static inline int do_one_set_err(struct sock *sk,
        if (sk == p->exclude_sk)
                goto out;
 
-       if (nlk->pid == p->pid || !(nlk->groups & netlink_group_mask(p->group)))
+       if (nlk->pid == p->pid || p->group - 1 >= nlk->ngroups ||
+           !test_bit(p->group - 1, nlk->groups))
                goto out;
 
        sk->sk_err = p->code;
@@ -1112,6 +1143,7 @@ netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct
        nlk->flags |= NETLINK_KERNEL_SOCKET;
 
        netlink_table_grab();
+       nl_table[unit].groups = 32;
        nl_table[unit].module = module;
        nl_table[unit].registered = 1;
        netlink_table_ungrab();
@@ -1358,7 +1390,8 @@ static int netlink_seq_show(struct seq_file *seq, void *v)
                           s,
                           s->sk_protocol,
                           nlk->pid,
-                          nlk->groups,
+                          nlk->flags & NETLINK_KERNEL_SOCKET ?
+                               0 : (unsigned int)nlk->groups[0],
                           atomic_read(&s->sk_rmem_alloc),
                           atomic_read(&s->sk_wmem_alloc),
                           nlk->cb,