Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/jmorris...
[pandora-kernel.git] / net / netlink / af_netlink.c
index 1f15821..5681ce3 100644 (file)
@@ -62,6 +62,7 @@
 #include <net/netlink.h>
 
 #define NLGRPSZ(x)     (ALIGN(x, sizeof(unsigned long) * 8) / 8)
+#define NLGRPLONGS(x)  (NLGRPSZ(x)/sizeof(unsigned long))
 
 struct netlink_sock {
        /* struct sock has to be the first member of netlink_sock */
@@ -314,10 +315,12 @@ netlink_update_listeners(struct sock *sk)
        unsigned long mask;
        unsigned int i;
 
-       for (i = 0; i < NLGRPSZ(tbl->groups)/sizeof(unsigned long); i++) {
+       for (i = 0; i < NLGRPLONGS(tbl->groups); i++) {
                mask = 0;
-               sk_for_each_bound(sk, node, &tbl->mc_list)
-                       mask |= nlk_sk(sk)->groups[i];
+               sk_for_each_bound(sk, node, &tbl->mc_list) {
+                       if (i < NLGRPLONGS(nlk_sk(sk)->ngroups))
+                               mask |= nlk_sk(sk)->groups[i];
+               }
                tbl->listeners[i] = mask;
        }
        /* this function is only called with the netlink table "grabbed", which
@@ -555,26 +558,37 @@ netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions)
        nlk->subscriptions = subscriptions;
 }
 
-static int netlink_alloc_groups(struct sock *sk)
+static int netlink_realloc_groups(struct sock *sk)
 {
        struct netlink_sock *nlk = nlk_sk(sk);
        unsigned int groups;
+       unsigned long *new_groups;
        int err = 0;
 
-       netlink_lock_table();
+       netlink_table_grab();
+
        groups = nl_table[sk->sk_protocol].groups;
-       if (!nl_table[sk->sk_protocol].registered)
+       if (!nl_table[sk->sk_protocol].registered) {
                err = -ENOENT;
-       netlink_unlock_table();
+               goto out_unlock;
+       }
 
-       if (err)
-               return err;
+       if (nlk->ngroups >= groups)
+               goto out_unlock;
 
-       nlk->groups = kzalloc(NLGRPSZ(groups), GFP_KERNEL);
-       if (nlk->groups == NULL)
-               return -ENOMEM;
+       new_groups = krealloc(nlk->groups, NLGRPSZ(groups), GFP_ATOMIC);
+       if (new_groups == NULL) {
+               err = -ENOMEM;
+               goto out_unlock;
+       }
+       memset((char*)new_groups + NLGRPSZ(nlk->ngroups), 0,
+              NLGRPSZ(groups) - NLGRPSZ(nlk->ngroups));
+
+       nlk->groups = new_groups;
        nlk->ngroups = groups;
-       return 0;
+ out_unlock:
+       netlink_table_ungrab();
+       return err;
 }
 
 static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
@@ -591,11 +605,9 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len
        if (nladdr->nl_groups) {
                if (!netlink_capable(sock, NL_NONROOT_RECV))
                        return -EPERM;
-               if (nlk->groups == NULL) {
-                       err = netlink_alloc_groups(sk);
-                       if (err)
-                               return err;
-               }
+               err = netlink_realloc_groups(sk);
+               if (err)
+                       return err;
        }
 
        if (nlk->pid) {
@@ -839,10 +851,18 @@ retry:
 int netlink_has_listeners(struct sock *sk, unsigned int group)
 {
        int res = 0;
+       unsigned long *listeners;
 
        BUG_ON(!(nlk_sk(sk)->flags & NETLINK_KERNEL_SOCKET));
+
+       rcu_read_lock();
+       listeners = rcu_dereference(nl_table[sk->sk_protocol].listeners);
+
        if (group - 1 < nl_table[sk->sk_protocol].groups)
-               res = test_bit(group - 1, nl_table[sk->sk_protocol].listeners);
+               res = test_bit(group - 1, listeners);
+
+       rcu_read_unlock();
+
        return res;
 }
 EXPORT_SYMBOL_GPL(netlink_has_listeners);
@@ -1007,18 +1027,36 @@ void netlink_set_err(struct sock *ssk, u32 pid, u32 group, int code)
        read_unlock(&nl_table_lock);
 }
 
+/* must be called with netlink table grabbed */
+static void netlink_update_socket_mc(struct netlink_sock *nlk,
+                                    unsigned int group,
+                                    int is_new)
+{
+       int old, new = !!is_new, subscriptions;
+
+       old = test_bit(group - 1, nlk->groups);
+       subscriptions = nlk->subscriptions - old + new;
+       if (new)
+               __set_bit(group - 1, nlk->groups);
+       else
+               __clear_bit(group - 1, nlk->groups);
+       netlink_update_subscriptions(&nlk->sk, subscriptions);
+       netlink_update_listeners(&nlk->sk);
+}
+
 static int netlink_setsockopt(struct socket *sock, int level, int optname,
                              char __user *optval, int optlen)
 {
        struct sock *sk = sock->sk;
        struct netlink_sock *nlk = nlk_sk(sk);
-       int val = 0, err;
+       unsigned int val = 0;
+       int err;
 
        if (level != SOL_NETLINK)
                return -ENOPROTOOPT;
 
        if (optlen >= sizeof(int) &&
-           get_user(val, (int __user *)optval))
+           get_user(val, (unsigned int __user *)optval))
                return -EFAULT;
 
        switch (optname) {
@@ -1031,27 +1069,16 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
                break;
        case NETLINK_ADD_MEMBERSHIP:
        case NETLINK_DROP_MEMBERSHIP: {
-               unsigned int subscriptions;
-               int old, new = optname == NETLINK_ADD_MEMBERSHIP ? 1 : 0;
-
                if (!netlink_capable(sock, NL_NONROOT_RECV))
                        return -EPERM;
-               if (nlk->groups == NULL) {
-                       err = netlink_alloc_groups(sk);
-                       if (err)
-                               return err;
-               }
+               err = netlink_realloc_groups(sk);
+               if (err)
+                       return err;
                if (!val || val - 1 >= nlk->ngroups)
                        return -EINVAL;
                netlink_table_grab();
-               old = test_bit(val - 1, nlk->groups);
-               subscriptions = nlk->subscriptions - old + new;
-               if (new)
-                       __set_bit(val - 1, nlk->groups);
-               else
-                       __clear_bit(val - 1, nlk->groups);
-               netlink_update_subscriptions(sk, subscriptions);
-               netlink_update_listeners(sk);
+               netlink_update_socket_mc(nlk, val,
+                                        optname == NETLINK_ADD_MEMBERSHIP);
                netlink_table_ungrab();
                err = 0;
                break;
@@ -1327,6 +1354,71 @@ out_sock_release:
        return NULL;
 }
 
+/**
+ * netlink_change_ngroups - change number of multicast groups
+ *
+ * This changes the number of multicast groups that are available
+ * on a certain netlink family. Note that it is not possible to
+ * change the number of groups to below 32. Also note that it does
+ * not implicitly call netlink_clear_multicast_users() when the
+ * number of groups is reduced.
+ *
+ * @sk: The kernel netlink socket, as returned by netlink_kernel_create().
+ * @groups: The new number of groups.
+ */
+int netlink_change_ngroups(struct sock *sk, unsigned int groups)
+{
+       unsigned long *listeners, *old = NULL;
+       struct netlink_table *tbl = &nl_table[sk->sk_protocol];
+       int err = 0;
+
+       if (groups < 32)
+               groups = 32;
+
+       netlink_table_grab();
+       if (NLGRPSZ(tbl->groups) < NLGRPSZ(groups)) {
+               listeners = kzalloc(NLGRPSZ(groups), GFP_ATOMIC);
+               if (!listeners) {
+                       err = -ENOMEM;
+                       goto out_ungrab;
+               }
+               old = tbl->listeners;
+               memcpy(listeners, old, NLGRPSZ(tbl->groups));
+               rcu_assign_pointer(tbl->listeners, listeners);
+       }
+       tbl->groups = groups;
+
+ out_ungrab:
+       netlink_table_ungrab();
+       synchronize_rcu();
+       kfree(old);
+       return err;
+}
+EXPORT_SYMBOL(netlink_change_ngroups);
+
+/**
+ * netlink_clear_multicast_users - kick off multicast listeners
+ *
+ * This function removes all listeners from the given group.
+ * @ksk: The kernel netlink socket, as returned by
+ *     netlink_kernel_create().
+ * @group: The multicast group to clear.
+ */
+void netlink_clear_multicast_users(struct sock *ksk, unsigned int group)
+{
+       struct sock *sk;
+       struct hlist_node *node;
+       struct netlink_table *tbl = &nl_table[ksk->sk_protocol];
+
+       netlink_table_grab();
+
+       sk_for_each_bound(sk, node, &tbl->mc_list)
+               netlink_update_socket_mc(nlk_sk(sk), group, 0);
+
+       netlink_table_ungrab();
+}
+EXPORT_SYMBOL(netlink_clear_multicast_users);
+
 void netlink_set_nonroot(int protocol, unsigned int flags)
 {
        if ((unsigned int)protocol < MAX_LINKS)
@@ -1713,7 +1805,7 @@ static int netlink_seq_show(struct seq_file *seq, void *v)
        return 0;
 }
 
-static struct seq_operations netlink_seq_ops = {
+static const struct seq_operations netlink_seq_ops = {
        .start  = netlink_seq_start,
        .next   = netlink_seq_next,
        .stop   = netlink_seq_stop,