netfilter: ipset: References are protected by rwlock instead of mutex
[pandora-kernel.git] / net / netfilter / ipset / ip_set_core.c
1 /* Copyright (C) 2000-2002 Joakim Axelsson <gozem@linux.nu>
2  *                         Patrick Schaaf <bof@bof.de>
3  * Copyright (C) 2003-2011 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation.
8  */
9
10 /* Kernel module for IP set management */
11
12 #include <linux/init.h>
13 #include <linux/module.h>
14 #include <linux/moduleparam.h>
15 #include <linux/ip.h>
16 #include <linux/skbuff.h>
17 #include <linux/spinlock.h>
18 #include <linux/netlink.h>
19 #include <linux/rculist.h>
20 #include <linux/version.h>
21 #include <net/netlink.h>
22
23 #include <linux/netfilter.h>
24 #include <linux/netfilter/nfnetlink.h>
25 #include <linux/netfilter/ipset/ip_set.h>
26
27 static LIST_HEAD(ip_set_type_list);             /* all registered set types */
28 static DEFINE_MUTEX(ip_set_type_mutex);         /* protects ip_set_type_list */
29 static DEFINE_RWLOCK(ip_set_ref_lock);          /* protects the set refs */
30
31 static struct ip_set **ip_set_list;             /* all individual sets */
32 static ip_set_id_t ip_set_max = CONFIG_IP_SET_MAX; /* max number of sets */
33
34 #define STREQ(a, b)     (strncmp(a, b, IPSET_MAXNAMELEN) == 0)
35
36 static unsigned int max_sets;
37
38 module_param(max_sets, int, 0600);
39 MODULE_PARM_DESC(max_sets, "maximal number of sets");
40 MODULE_LICENSE("GPL");
41 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
42 MODULE_DESCRIPTION("core IP set support");
43 MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_IPSET);
44
45 /*
46  * The set types are implemented in modules and registered set types
47  * can be found in ip_set_type_list. Adding/deleting types is
48  * serialized by ip_set_type_mutex.
49  */
50
51 static inline void
52 ip_set_type_lock(void)
53 {
54         mutex_lock(&ip_set_type_mutex);
55 }
56
57 static inline void
58 ip_set_type_unlock(void)
59 {
60         mutex_unlock(&ip_set_type_mutex);
61 }
62
63 /* Register and deregister settype */
64
65 static struct ip_set_type *
66 find_set_type(const char *name, u8 family, u8 revision)
67 {
68         struct ip_set_type *type;
69
70         list_for_each_entry_rcu(type, &ip_set_type_list, list)
71                 if (STREQ(type->name, name) &&
72                     (type->family == family || type->family == AF_UNSPEC) &&
73                     type->revision == revision)
74                         return type;
75         return NULL;
76 }
77
78 /* Unlock, try to load a set type module and lock again */
79 static int
80 try_to_load_type(const char *name)
81 {
82         nfnl_unlock();
83         pr_debug("try to load ip_set_%s\n", name);
84         if (request_module("ip_set_%s", name) < 0) {
85                 pr_warning("Can't find ip_set type %s\n", name);
86                 nfnl_lock();
87                 return -IPSET_ERR_FIND_TYPE;
88         }
89         nfnl_lock();
90         return -EAGAIN;
91 }
92
93 /* Find a set type and reference it */
94 static int
95 find_set_type_get(const char *name, u8 family, u8 revision,
96                   struct ip_set_type **found)
97 {
98         struct ip_set_type *type;
99         int err;
100
101         rcu_read_lock();
102         *found = find_set_type(name, family, revision);
103         if (*found) {
104                 err = !try_module_get((*found)->me) ? -EFAULT : 0;
105                 goto unlock;
106         }
107         /* Make sure the type is loaded but we don't support the revision */
108         list_for_each_entry_rcu(type, &ip_set_type_list, list)
109                 if (STREQ(type->name, name)) {
110                         err = -IPSET_ERR_FIND_TYPE;
111                         goto unlock;
112                 }
113         rcu_read_unlock();
114
115         return try_to_load_type(name);
116
117 unlock:
118         rcu_read_unlock();
119         return err;
120 }
121
122 /* Find a given set type by name and family.
123  * If we succeeded, the supported minimal and maximum revisions are
124  * filled out.
125  */
126 static int
127 find_set_type_minmax(const char *name, u8 family, u8 *min, u8 *max)
128 {
129         struct ip_set_type *type;
130         bool found = false;
131
132         *min = 255; *max = 0;
133         rcu_read_lock();
134         list_for_each_entry_rcu(type, &ip_set_type_list, list)
135                 if (STREQ(type->name, name) &&
136                     (type->family == family || type->family == AF_UNSPEC)) {
137                         found = true;
138                         if (type->revision < *min)
139                                 *min = type->revision;
140                         if (type->revision > *max)
141                                 *max = type->revision;
142                 }
143         rcu_read_unlock();
144         if (found)
145                 return 0;
146
147         return try_to_load_type(name);
148 }
149
150 #define family_name(f)  ((f) == AF_INET ? "inet" : \
151                          (f) == AF_INET6 ? "inet6" : "any")
152
153 /* Register a set type structure. The type is identified by
154  * the unique triple of name, family and revision.
155  */
156 int
157 ip_set_type_register(struct ip_set_type *type)
158 {
159         int ret = 0;
160
161         if (type->protocol != IPSET_PROTOCOL) {
162                 pr_warning("ip_set type %s, family %s, revision %u uses "
163                            "wrong protocol version %u (want %u)\n",
164                            type->name, family_name(type->family),
165                            type->revision, type->protocol, IPSET_PROTOCOL);
166                 return -EINVAL;
167         }
168
169         ip_set_type_lock();
170         if (find_set_type(type->name, type->family, type->revision)) {
171                 /* Duplicate! */
172                 pr_warning("ip_set type %s, family %s, revision %u "
173                            "already registered!\n", type->name,
174                            family_name(type->family), type->revision);
175                 ret = -EINVAL;
176                 goto unlock;
177         }
178         list_add_rcu(&type->list, &ip_set_type_list);
179         pr_debug("type %s, family %s, revision %u registered.\n",
180                  type->name, family_name(type->family), type->revision);
181 unlock:
182         ip_set_type_unlock();
183         return ret;
184 }
185 EXPORT_SYMBOL_GPL(ip_set_type_register);
186
187 /* Unregister a set type. There's a small race with ip_set_create */
188 void
189 ip_set_type_unregister(struct ip_set_type *type)
190 {
191         ip_set_type_lock();
192         if (!find_set_type(type->name, type->family, type->revision)) {
193                 pr_warning("ip_set type %s, family %s, revision %u "
194                            "not registered\n", type->name,
195                            family_name(type->family), type->revision);
196                 goto unlock;
197         }
198         list_del_rcu(&type->list);
199         pr_debug("type %s, family %s, revision %u unregistered.\n",
200                  type->name, family_name(type->family), type->revision);
201 unlock:
202         ip_set_type_unlock();
203
204         synchronize_rcu();
205 }
206 EXPORT_SYMBOL_GPL(ip_set_type_unregister);
207
208 /* Utility functions */
209 void *
210 ip_set_alloc(size_t size)
211 {
212         void *members = NULL;
213
214         if (size < KMALLOC_MAX_SIZE)
215                 members = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
216
217         if (members) {
218                 pr_debug("%p: allocated with kmalloc\n", members);
219                 return members;
220         }
221
222         members = vzalloc(size);
223         if (!members)
224                 return NULL;
225         pr_debug("%p: allocated with vmalloc\n", members);
226
227         return members;
228 }
229 EXPORT_SYMBOL_GPL(ip_set_alloc);
230
231 void
232 ip_set_free(void *members)
233 {
234         pr_debug("%p: free with %s\n", members,
235                  is_vmalloc_addr(members) ? "vfree" : "kfree");
236         if (is_vmalloc_addr(members))
237                 vfree(members);
238         else
239                 kfree(members);
240 }
241 EXPORT_SYMBOL_GPL(ip_set_free);
242
243 static inline bool
244 flag_nested(const struct nlattr *nla)
245 {
246         return nla->nla_type & NLA_F_NESTED;
247 }
248
249 static const struct nla_policy ipaddr_policy[IPSET_ATTR_IPADDR_MAX + 1] = {
250         [IPSET_ATTR_IPADDR_IPV4]        = { .type = NLA_U32 },
251         [IPSET_ATTR_IPADDR_IPV6]        = { .type = NLA_BINARY,
252                                             .len = sizeof(struct in6_addr) },
253 };
254
255 int
256 ip_set_get_ipaddr4(struct nlattr *nla,  __be32 *ipaddr)
257 {
258         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX+1];
259
260         if (unlikely(!flag_nested(nla)))
261                 return -IPSET_ERR_PROTOCOL;
262         if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy))
263                 return -IPSET_ERR_PROTOCOL;
264         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV4)))
265                 return -IPSET_ERR_PROTOCOL;
266
267         *ipaddr = nla_get_be32(tb[IPSET_ATTR_IPADDR_IPV4]);
268         return 0;
269 }
270 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr4);
271
272 int
273 ip_set_get_ipaddr6(struct nlattr *nla, union nf_inet_addr *ipaddr)
274 {
275         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX+1];
276
277         if (unlikely(!flag_nested(nla)))
278                 return -IPSET_ERR_PROTOCOL;
279
280         if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy))
281                 return -IPSET_ERR_PROTOCOL;
282         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV6)))
283                 return -IPSET_ERR_PROTOCOL;
284
285         memcpy(ipaddr, nla_data(tb[IPSET_ATTR_IPADDR_IPV6]),
286                 sizeof(struct in6_addr));
287         return 0;
288 }
289 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr6);
290
291 /*
292  * Creating/destroying/renaming/swapping affect the existence and
293  * the properties of a set. All of these can be executed from userspace
294  * only and serialized by the nfnl mutex indirectly from nfnetlink.
295  *
296  * Sets are identified by their index in ip_set_list and the index
297  * is used by the external references (set/SET netfilter modules).
298  *
299  * The set behind an index may change by swapping only, from userspace.
300  */
301
302 static inline void
303 __ip_set_get(ip_set_id_t index)
304 {
305         write_lock_bh(&ip_set_ref_lock);
306         ip_set_list[index]->ref++;
307         write_unlock_bh(&ip_set_ref_lock);
308 }
309
310 static inline void
311 __ip_set_put(ip_set_id_t index)
312 {
313         write_lock_bh(&ip_set_ref_lock);
314         BUG_ON(ip_set_list[index]->ref == 0);
315         ip_set_list[index]->ref--;
316         write_unlock_bh(&ip_set_ref_lock);
317 }
318
319 /*
320  * Add, del and test set entries from kernel.
321  *
322  * The set behind the index must exist and must be referenced
323  * so it can't be destroyed (or changed) under our foot.
324  */
325
326 int
327 ip_set_test(ip_set_id_t index, const struct sk_buff *skb,
328             u8 family, u8 dim, u8 flags)
329 {
330         struct ip_set *set = ip_set_list[index];
331         int ret = 0;
332
333         BUG_ON(set == NULL);
334         pr_debug("set %s, index %u\n", set->name, index);
335
336         if (dim < set->type->dimension ||
337             !(family == set->family || set->family == AF_UNSPEC))
338                 return 0;
339
340         read_lock_bh(&set->lock);
341         ret = set->variant->kadt(set, skb, IPSET_TEST, family, dim, flags);
342         read_unlock_bh(&set->lock);
343
344         if (ret == -EAGAIN) {
345                 /* Type requests element to be completed */
346                 pr_debug("element must be competed, ADD is triggered\n");
347                 write_lock_bh(&set->lock);
348                 set->variant->kadt(set, skb, IPSET_ADD, family, dim, flags);
349                 write_unlock_bh(&set->lock);
350                 ret = 1;
351         }
352
353         /* Convert error codes to nomatch */
354         return (ret < 0 ? 0 : ret);
355 }
356 EXPORT_SYMBOL_GPL(ip_set_test);
357
358 int
359 ip_set_add(ip_set_id_t index, const struct sk_buff *skb,
360            u8 family, u8 dim, u8 flags)
361 {
362         struct ip_set *set = ip_set_list[index];
363         int ret;
364
365         BUG_ON(set == NULL);
366         pr_debug("set %s, index %u\n", set->name, index);
367
368         if (dim < set->type->dimension ||
369             !(family == set->family || set->family == AF_UNSPEC))
370                 return 0;
371
372         write_lock_bh(&set->lock);
373         ret = set->variant->kadt(set, skb, IPSET_ADD, family, dim, flags);
374         write_unlock_bh(&set->lock);
375
376         return ret;
377 }
378 EXPORT_SYMBOL_GPL(ip_set_add);
379
380 int
381 ip_set_del(ip_set_id_t index, const struct sk_buff *skb,
382            u8 family, u8 dim, u8 flags)
383 {
384         struct ip_set *set = ip_set_list[index];
385         int ret = 0;
386
387         BUG_ON(set == NULL);
388         pr_debug("set %s, index %u\n", set->name, index);
389
390         if (dim < set->type->dimension ||
391             !(family == set->family || set->family == AF_UNSPEC))
392                 return 0;
393
394         write_lock_bh(&set->lock);
395         ret = set->variant->kadt(set, skb, IPSET_DEL, family, dim, flags);
396         write_unlock_bh(&set->lock);
397
398         return ret;
399 }
400 EXPORT_SYMBOL_GPL(ip_set_del);
401
402 /*
403  * Find set by name, reference it once. The reference makes sure the
404  * thing pointed to, does not go away under our feet.
405  *
406  */
407 ip_set_id_t
408 ip_set_get_byname(const char *name, struct ip_set **set)
409 {
410         ip_set_id_t i, index = IPSET_INVALID_ID;
411         struct ip_set *s;
412
413         for (i = 0; i < ip_set_max; i++) {
414                 s = ip_set_list[i];
415                 if (s != NULL && STREQ(s->name, name)) {
416                         __ip_set_get(i);
417                         index = i;
418                         *set = s;
419                 }
420         }
421
422         return index;
423 }
424 EXPORT_SYMBOL_GPL(ip_set_get_byname);
425
426 /*
427  * If the given set pointer points to a valid set, decrement
428  * reference count by 1. The caller shall not assume the index
429  * to be valid, after calling this function.
430  *
431  */
432 void
433 ip_set_put_byindex(ip_set_id_t index)
434 {
435         if (ip_set_list[index] != NULL)
436                 __ip_set_put(index);
437 }
438 EXPORT_SYMBOL_GPL(ip_set_put_byindex);
439
440 /*
441  * Get the name of a set behind a set index.
442  * We assume the set is referenced, so it does exist and
443  * can't be destroyed. The set cannot be renamed due to
444  * the referencing either.
445  *
446  */
447 const char *
448 ip_set_name_byindex(ip_set_id_t index)
449 {
450         const struct ip_set *set = ip_set_list[index];
451
452         BUG_ON(set == NULL);
453         BUG_ON(set->ref == 0);
454
455         /* Referenced, so it's safe */
456         return set->name;
457 }
458 EXPORT_SYMBOL_GPL(ip_set_name_byindex);
459
460 /*
461  * Routines to call by external subsystems, which do not
462  * call nfnl_lock for us.
463  */
464
465 /*
466  * Find set by name, reference it once. The reference makes sure the
467  * thing pointed to, does not go away under our feet.
468  *
469  * The nfnl mutex is used in the function.
470  */
471 ip_set_id_t
472 ip_set_nfnl_get(const char *name)
473 {
474         struct ip_set *s;
475         ip_set_id_t index;
476
477         nfnl_lock();
478         index = ip_set_get_byname(name, &s);
479         nfnl_unlock();
480
481         return index;
482 }
483 EXPORT_SYMBOL_GPL(ip_set_nfnl_get);
484
485 /*
486  * Find set by index, reference it once. The reference makes sure the
487  * thing pointed to, does not go away under our feet.
488  *
489  * The nfnl mutex is used in the function.
490  */
491 ip_set_id_t
492 ip_set_nfnl_get_byindex(ip_set_id_t index)
493 {
494         if (index > ip_set_max)
495                 return IPSET_INVALID_ID;
496
497         nfnl_lock();
498         if (ip_set_list[index])
499                 __ip_set_get(index);
500         else
501                 index = IPSET_INVALID_ID;
502         nfnl_unlock();
503
504         return index;
505 }
506 EXPORT_SYMBOL_GPL(ip_set_nfnl_get_byindex);
507
508 /*
509  * If the given set pointer points to a valid set, decrement
510  * reference count by 1. The caller shall not assume the index
511  * to be valid, after calling this function.
512  *
513  * The nfnl mutex is used in the function.
514  */
515 void
516 ip_set_nfnl_put(ip_set_id_t index)
517 {
518         nfnl_lock();
519         ip_set_put_byindex(index);
520         nfnl_unlock();
521 }
522 EXPORT_SYMBOL_GPL(ip_set_nfnl_put);
523
524 /*
525  * Communication protocol with userspace over netlink.
526  *
527  * The commands are serialized by the nfnl mutex.
528  */
529
530 static inline bool
531 protocol_failed(const struct nlattr * const tb[])
532 {
533         return !tb[IPSET_ATTR_PROTOCOL] ||
534                nla_get_u8(tb[IPSET_ATTR_PROTOCOL]) != IPSET_PROTOCOL;
535 }
536
537 static inline u32
538 flag_exist(const struct nlmsghdr *nlh)
539 {
540         return nlh->nlmsg_flags & NLM_F_EXCL ? 0 : IPSET_FLAG_EXIST;
541 }
542
543 static struct nlmsghdr *
544 start_msg(struct sk_buff *skb, u32 pid, u32 seq, unsigned int flags,
545           enum ipset_cmd cmd)
546 {
547         struct nlmsghdr *nlh;
548         struct nfgenmsg *nfmsg;
549
550         nlh = nlmsg_put(skb, pid, seq, cmd | (NFNL_SUBSYS_IPSET << 8),
551                         sizeof(*nfmsg), flags);
552         if (nlh == NULL)
553                 return NULL;
554
555         nfmsg = nlmsg_data(nlh);
556         nfmsg->nfgen_family = AF_INET;
557         nfmsg->version = NFNETLINK_V0;
558         nfmsg->res_id = 0;
559
560         return nlh;
561 }
562
563 /* Create a set */
564
565 static const struct nla_policy ip_set_create_policy[IPSET_ATTR_CMD_MAX + 1] = {
566         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
567         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
568                                     .len = IPSET_MAXNAMELEN - 1 },
569         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
570                                     .len = IPSET_MAXNAMELEN - 1},
571         [IPSET_ATTR_REVISION]   = { .type = NLA_U8 },
572         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
573         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
574 };
575
576 static ip_set_id_t
577 find_set_id(const char *name)
578 {
579         ip_set_id_t i, index = IPSET_INVALID_ID;
580         const struct ip_set *set;
581
582         for (i = 0; index == IPSET_INVALID_ID && i < ip_set_max; i++) {
583                 set = ip_set_list[i];
584                 if (set != NULL && STREQ(set->name, name))
585                         index = i;
586         }
587         return index;
588 }
589
590 static inline struct ip_set *
591 find_set(const char *name)
592 {
593         ip_set_id_t index = find_set_id(name);
594
595         return index == IPSET_INVALID_ID ? NULL : ip_set_list[index];
596 }
597
598 static int
599 find_free_id(const char *name, ip_set_id_t *index, struct ip_set **set)
600 {
601         ip_set_id_t i;
602
603         *index = IPSET_INVALID_ID;
604         for (i = 0;  i < ip_set_max; i++) {
605                 if (ip_set_list[i] == NULL) {
606                         if (*index == IPSET_INVALID_ID)
607                                 *index = i;
608                 } else if (STREQ(name, ip_set_list[i]->name)) {
609                         /* Name clash */
610                         *set = ip_set_list[i];
611                         return -EEXIST;
612                 }
613         }
614         if (*index == IPSET_INVALID_ID)
615                 /* No free slot remained */
616                 return -IPSET_ERR_MAX_SETS;
617         return 0;
618 }
619
620 static int
621 ip_set_create(struct sock *ctnl, struct sk_buff *skb,
622               const struct nlmsghdr *nlh,
623               const struct nlattr * const attr[])
624 {
625         struct ip_set *set, *clash = NULL;
626         ip_set_id_t index = IPSET_INVALID_ID;
627         struct nlattr *tb[IPSET_ATTR_CREATE_MAX+1] = {};
628         const char *name, *typename;
629         u8 family, revision;
630         u32 flags = flag_exist(nlh);
631         int ret = 0;
632
633         if (unlikely(protocol_failed(attr) ||
634                      attr[IPSET_ATTR_SETNAME] == NULL ||
635                      attr[IPSET_ATTR_TYPENAME] == NULL ||
636                      attr[IPSET_ATTR_REVISION] == NULL ||
637                      attr[IPSET_ATTR_FAMILY] == NULL ||
638                      (attr[IPSET_ATTR_DATA] != NULL &&
639                       !flag_nested(attr[IPSET_ATTR_DATA]))))
640                 return -IPSET_ERR_PROTOCOL;
641
642         name = nla_data(attr[IPSET_ATTR_SETNAME]);
643         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
644         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
645         revision = nla_get_u8(attr[IPSET_ATTR_REVISION]);
646         pr_debug("setname: %s, typename: %s, family: %s, revision: %u\n",
647                  name, typename, family_name(family), revision);
648
649         /*
650          * First, and without any locks, allocate and initialize
651          * a normal base set structure.
652          */
653         set = kzalloc(sizeof(struct ip_set), GFP_KERNEL);
654         if (!set)
655                 return -ENOMEM;
656         rwlock_init(&set->lock);
657         strlcpy(set->name, name, IPSET_MAXNAMELEN);
658         set->family = family;
659
660         /*
661          * Next, check that we know the type, and take
662          * a reference on the type, to make sure it stays available
663          * while constructing our new set.
664          *
665          * After referencing the type, we try to create the type
666          * specific part of the set without holding any locks.
667          */
668         ret = find_set_type_get(typename, family, revision, &(set->type));
669         if (ret)
670                 goto out;
671
672         /*
673          * Without holding any locks, create private part.
674          */
675         if (attr[IPSET_ATTR_DATA] &&
676             nla_parse_nested(tb, IPSET_ATTR_CREATE_MAX, attr[IPSET_ATTR_DATA],
677                              set->type->create_policy)) {
678                 ret = -IPSET_ERR_PROTOCOL;
679                 goto put_out;
680         }
681
682         ret = set->type->create(set, tb, flags);
683         if (ret != 0)
684                 goto put_out;
685
686         /* BTW, ret==0 here. */
687
688         /*
689          * Here, we have a valid, constructed set and we are protected
690          * by the nfnl mutex. Find the first free index in ip_set_list
691          * and check clashing.
692          */
693         if ((ret = find_free_id(set->name, &index, &clash)) != 0) {
694                 /* If this is the same set and requested, ignore error */
695                 if (ret == -EEXIST &&
696                     (flags & IPSET_FLAG_EXIST) &&
697                     STREQ(set->type->name, clash->type->name) &&
698                     set->type->family == clash->type->family &&
699                     set->type->revision == clash->type->revision &&
700                     set->variant->same_set(set, clash))
701                         ret = 0;
702                 goto cleanup;
703         }
704
705         /*
706          * Finally! Add our shiny new set to the list, and be done.
707          */
708         pr_debug("create: '%s' created with index %u!\n", set->name, index);
709         ip_set_list[index] = set;
710
711         return ret;
712
713 cleanup:
714         set->variant->destroy(set);
715 put_out:
716         module_put(set->type->me);
717 out:
718         kfree(set);
719         return ret;
720 }
721
722 /* Destroy sets */
723
724 static const struct nla_policy
725 ip_set_setname_policy[IPSET_ATTR_CMD_MAX + 1] = {
726         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
727         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
728                                     .len = IPSET_MAXNAMELEN - 1 },
729 };
730
731 static void
732 ip_set_destroy_set(ip_set_id_t index)
733 {
734         struct ip_set *set = ip_set_list[index];
735
736         pr_debug("set: %s\n",  set->name);
737         ip_set_list[index] = NULL;
738
739         /* Must call it without holding any lock */
740         set->variant->destroy(set);
741         module_put(set->type->me);
742         kfree(set);
743 }
744
745 static int
746 ip_set_destroy(struct sock *ctnl, struct sk_buff *skb,
747                const struct nlmsghdr *nlh,
748                const struct nlattr * const attr[])
749 {
750         ip_set_id_t i;
751         int ret = 0;
752
753         if (unlikely(protocol_failed(attr)))
754                 return -IPSET_ERR_PROTOCOL;
755
756         /* Commands are serialized and references are
757          * protected by the ip_set_ref_lock.
758          * External systems (i.e. xt_set) must call
759          * ip_set_put|get_nfnl_* functions, that way we
760          * can safely check references here.
761          *
762          * list:set timer can only decrement the reference
763          * counter, so if it's already zero, we can proceed
764          * without holding the lock.
765          */
766         read_lock_bh(&ip_set_ref_lock);
767         if (!attr[IPSET_ATTR_SETNAME]) {
768                 for (i = 0; i < ip_set_max; i++) {
769                         if (ip_set_list[i] != NULL && ip_set_list[i]->ref) {
770                                 ret = IPSET_ERR_BUSY;
771                                 goto out;
772                         }
773                 }
774                 read_unlock_bh(&ip_set_ref_lock);
775                 for (i = 0; i < ip_set_max; i++) {
776                         if (ip_set_list[i] != NULL)
777                                 ip_set_destroy_set(i);
778                 }
779         } else {
780                 i = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
781                 if (i == IPSET_INVALID_ID) {
782                         ret = -ENOENT;
783                         goto out;
784                 } else if (ip_set_list[i]->ref) {
785                         ret = -IPSET_ERR_BUSY;
786                         goto out;
787                 }
788                 read_unlock_bh(&ip_set_ref_lock);
789
790                 ip_set_destroy_set(i);
791         }
792         return 0;
793 out:
794         read_unlock_bh(&ip_set_ref_lock);
795         return ret;
796 }
797
798 /* Flush sets */
799
800 static void
801 ip_set_flush_set(struct ip_set *set)
802 {
803         pr_debug("set: %s\n",  set->name);
804
805         write_lock_bh(&set->lock);
806         set->variant->flush(set);
807         write_unlock_bh(&set->lock);
808 }
809
810 static int
811 ip_set_flush(struct sock *ctnl, struct sk_buff *skb,
812              const struct nlmsghdr *nlh,
813              const struct nlattr * const attr[])
814 {
815         ip_set_id_t i;
816
817         if (unlikely(protocol_failed(attr)))
818                 return -EPROTO;
819
820         if (!attr[IPSET_ATTR_SETNAME]) {
821                 for (i = 0; i < ip_set_max; i++)
822                         if (ip_set_list[i] != NULL)
823                                 ip_set_flush_set(ip_set_list[i]);
824         } else {
825                 i = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
826                 if (i == IPSET_INVALID_ID)
827                         return -ENOENT;
828
829                 ip_set_flush_set(ip_set_list[i]);
830         }
831
832         return 0;
833 }
834
835 /* Rename a set */
836
837 static const struct nla_policy
838 ip_set_setname2_policy[IPSET_ATTR_CMD_MAX + 1] = {
839         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
840         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
841                                     .len = IPSET_MAXNAMELEN - 1 },
842         [IPSET_ATTR_SETNAME2]   = { .type = NLA_NUL_STRING,
843                                     .len = IPSET_MAXNAMELEN - 1 },
844 };
845
846 static int
847 ip_set_rename(struct sock *ctnl, struct sk_buff *skb,
848               const struct nlmsghdr *nlh,
849               const struct nlattr * const attr[])
850 {
851         struct ip_set *set;
852         const char *name2;
853         ip_set_id_t i;
854         int ret = 0;
855
856         if (unlikely(protocol_failed(attr) ||
857                      attr[IPSET_ATTR_SETNAME] == NULL ||
858                      attr[IPSET_ATTR_SETNAME2] == NULL))
859                 return -IPSET_ERR_PROTOCOL;
860
861         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
862         if (set == NULL)
863                 return -ENOENT;
864
865         read_lock_bh(&ip_set_ref_lock);
866         if (set->ref != 0) {
867                 ret = -IPSET_ERR_REFERENCED;
868                 goto out;
869         }
870
871         name2 = nla_data(attr[IPSET_ATTR_SETNAME2]);
872         for (i = 0; i < ip_set_max; i++) {
873                 if (ip_set_list[i] != NULL &&
874                     STREQ(ip_set_list[i]->name, name2)) {
875                         ret = -IPSET_ERR_EXIST_SETNAME2;
876                         goto out;
877                 }
878         }
879         strncpy(set->name, name2, IPSET_MAXNAMELEN);
880
881 out:
882         read_unlock_bh(&ip_set_ref_lock);
883         return ret;
884 }
885
886 /* Swap two sets so that name/index points to the other.
887  * References and set names are also swapped.
888  *
889  * The commands are serialized by the nfnl mutex and references are
890  * protected by the ip_set_ref_lock. The kernel interfaces
891  * do not hold the mutex but the pointer settings are atomic
892  * so the ip_set_list always contains valid pointers to the sets.
893  */
894
895 static int
896 ip_set_swap(struct sock *ctnl, struct sk_buff *skb,
897             const struct nlmsghdr *nlh,
898             const struct nlattr * const attr[])
899 {
900         struct ip_set *from, *to;
901         ip_set_id_t from_id, to_id;
902         char from_name[IPSET_MAXNAMELEN];
903
904         if (unlikely(protocol_failed(attr) ||
905                      attr[IPSET_ATTR_SETNAME] == NULL ||
906                      attr[IPSET_ATTR_SETNAME2] == NULL))
907                 return -IPSET_ERR_PROTOCOL;
908
909         from_id = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
910         if (from_id == IPSET_INVALID_ID)
911                 return -ENOENT;
912
913         to_id = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME2]));
914         if (to_id == IPSET_INVALID_ID)
915                 return -IPSET_ERR_EXIST_SETNAME2;
916
917         from = ip_set_list[from_id];
918         to = ip_set_list[to_id];
919
920         /* Features must not change.
921          * Not an artifical restriction anymore, as we must prevent
922          * possible loops created by swapping in setlist type of sets. */
923         if (!(from->type->features == to->type->features &&
924               from->type->family == to->type->family))
925                 return -IPSET_ERR_TYPE_MISMATCH;
926
927         strncpy(from_name, from->name, IPSET_MAXNAMELEN);
928         strncpy(from->name, to->name, IPSET_MAXNAMELEN);
929         strncpy(to->name, from_name, IPSET_MAXNAMELEN);
930
931         write_lock_bh(&ip_set_ref_lock);
932         swap(from->ref, to->ref);
933         ip_set_list[from_id] = to;
934         ip_set_list[to_id] = from;
935         write_unlock_bh(&ip_set_ref_lock);
936
937         return 0;
938 }
939
940 /* List/save set data */
941
942 #define DUMP_INIT       0L
943 #define DUMP_ALL        1L
944 #define DUMP_ONE        2L
945 #define DUMP_LAST       3L
946
947 static int
948 ip_set_dump_done(struct netlink_callback *cb)
949 {
950         if (cb->args[2]) {
951                 pr_debug("release set %s\n", ip_set_list[cb->args[1]]->name);
952                 ip_set_put_byindex((ip_set_id_t) cb->args[1]);
953         }
954         return 0;
955 }
956
957 static inline void
958 dump_attrs(struct nlmsghdr *nlh)
959 {
960         const struct nlattr *attr;
961         int rem;
962
963         pr_debug("dump nlmsg\n");
964         nlmsg_for_each_attr(attr, nlh, sizeof(struct nfgenmsg), rem) {
965                 pr_debug("type: %u, len %u\n", nla_type(attr), attr->nla_len);
966         }
967 }
968
969 static int
970 dump_init(struct netlink_callback *cb)
971 {
972         struct nlmsghdr *nlh = nlmsg_hdr(cb->skb);
973         int min_len = NLMSG_SPACE(sizeof(struct nfgenmsg));
974         struct nlattr *cda[IPSET_ATTR_CMD_MAX+1];
975         struct nlattr *attr = (void *)nlh + min_len;
976         ip_set_id_t index;
977
978         /* Second pass, so parser can't fail */
979         nla_parse(cda, IPSET_ATTR_CMD_MAX,
980                   attr, nlh->nlmsg_len - min_len, ip_set_setname_policy);
981
982         /* cb->args[0] : dump single set/all sets
983          *         [1] : set index
984          *         [..]: type specific
985          */
986
987         if (!cda[IPSET_ATTR_SETNAME]) {
988                 cb->args[0] = DUMP_ALL;
989                 return 0;
990         }
991
992         index = find_set_id(nla_data(cda[IPSET_ATTR_SETNAME]));
993         if (index == IPSET_INVALID_ID)
994                 return -ENOENT;
995
996         cb->args[0] = DUMP_ONE;
997         cb->args[1] = index;
998         return 0;
999 }
1000
1001 static int
1002 ip_set_dump_start(struct sk_buff *skb, struct netlink_callback *cb)
1003 {
1004         ip_set_id_t index = IPSET_INVALID_ID, max;
1005         struct ip_set *set = NULL;
1006         struct nlmsghdr *nlh = NULL;
1007         unsigned int flags = NETLINK_CB(cb->skb).pid ? NLM_F_MULTI : 0;
1008         int ret = 0;
1009
1010         if (cb->args[0] == DUMP_INIT) {
1011                 ret = dump_init(cb);
1012                 if (ret < 0) {
1013                         nlh = nlmsg_hdr(cb->skb);
1014                         /* We have to create and send the error message
1015                          * manually :-( */
1016                         if (nlh->nlmsg_flags & NLM_F_ACK)
1017                                 netlink_ack(cb->skb, nlh, ret);
1018                         return ret;
1019                 }
1020         }
1021
1022         if (cb->args[1] >= ip_set_max)
1023                 goto out;
1024
1025         pr_debug("args[0]: %ld args[1]: %ld\n", cb->args[0], cb->args[1]);
1026         max = cb->args[0] == DUMP_ONE ? cb->args[1] + 1 : ip_set_max;
1027         for (; cb->args[1] < max; cb->args[1]++) {
1028                 index = (ip_set_id_t) cb->args[1];
1029                 set = ip_set_list[index];
1030                 if (set == NULL) {
1031                         if (cb->args[0] == DUMP_ONE) {
1032                                 ret = -ENOENT;
1033                                 goto out;
1034                         }
1035                         continue;
1036                 }
1037                 /* When dumping all sets, we must dump "sorted"
1038                  * so that lists (unions of sets) are dumped last.
1039                  */
1040                 if (cb->args[0] != DUMP_ONE &&
1041                     !((cb->args[0] == DUMP_ALL) ^
1042                       (set->type->features & IPSET_DUMP_LAST)))
1043                         continue;
1044                 pr_debug("List set: %s\n", set->name);
1045                 if (!cb->args[2]) {
1046                         /* Start listing: make sure set won't be destroyed */
1047                         pr_debug("reference set\n");
1048                         __ip_set_get(index);
1049                 }
1050                 nlh = start_msg(skb, NETLINK_CB(cb->skb).pid,
1051                                 cb->nlh->nlmsg_seq, flags,
1052                                 IPSET_CMD_LIST);
1053                 if (!nlh) {
1054                         ret = -EMSGSIZE;
1055                         goto release_refcount;
1056                 }
1057                 NLA_PUT_U8(skb, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1058                 NLA_PUT_STRING(skb, IPSET_ATTR_SETNAME, set->name);
1059                 switch (cb->args[2]) {
1060                 case 0:
1061                         /* Core header data */
1062                         NLA_PUT_STRING(skb, IPSET_ATTR_TYPENAME,
1063                                        set->type->name);
1064                         NLA_PUT_U8(skb, IPSET_ATTR_FAMILY,
1065                                    set->family);
1066                         NLA_PUT_U8(skb, IPSET_ATTR_REVISION,
1067                                    set->type->revision);
1068                         ret = set->variant->head(set, skb);
1069                         if (ret < 0)
1070                                 goto release_refcount;
1071                         /* Fall through and add elements */
1072                 default:
1073                         read_lock_bh(&set->lock);
1074                         ret = set->variant->list(set, skb, cb);
1075                         read_unlock_bh(&set->lock);
1076                         if (!cb->args[2]) {
1077                                 /* Set is done, proceed with next one */
1078                                 if (cb->args[0] == DUMP_ONE)
1079                                         cb->args[1] = IPSET_INVALID_ID;
1080                                 else
1081                                         cb->args[1]++;
1082                         }
1083                         goto release_refcount;
1084                 }
1085         }
1086         goto out;
1087
1088 nla_put_failure:
1089         ret = -EFAULT;
1090 release_refcount:
1091         /* If there was an error or set is done, release set */
1092         if (ret || !cb->args[2]) {
1093                 pr_debug("release set %s\n", ip_set_list[index]->name);
1094                 ip_set_put_byindex(index);
1095         }
1096
1097         /* If we dump all sets, continue with dumping last ones */
1098         if (cb->args[0] == DUMP_ALL && cb->args[1] >= max && !cb->args[2])
1099                 cb->args[0] = DUMP_LAST;
1100
1101 out:
1102         if (nlh) {
1103                 nlmsg_end(skb, nlh);
1104                 pr_debug("nlmsg_len: %u\n", nlh->nlmsg_len);
1105                 dump_attrs(nlh);
1106         }
1107
1108         return ret < 0 ? ret : skb->len;
1109 }
1110
1111 static int
1112 ip_set_dump(struct sock *ctnl, struct sk_buff *skb,
1113             const struct nlmsghdr *nlh,
1114             const struct nlattr * const attr[])
1115 {
1116         if (unlikely(protocol_failed(attr)))
1117                 return -IPSET_ERR_PROTOCOL;
1118
1119         return netlink_dump_start(ctnl, skb, nlh,
1120                                   ip_set_dump_start,
1121                                   ip_set_dump_done);
1122 }
1123
1124 /* Add, del and test */
1125
1126 static const struct nla_policy ip_set_adt_policy[IPSET_ATTR_CMD_MAX + 1] = {
1127         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1128         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1129                                     .len = IPSET_MAXNAMELEN - 1 },
1130         [IPSET_ATTR_LINENO]     = { .type = NLA_U32 },
1131         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
1132         [IPSET_ATTR_ADT]        = { .type = NLA_NESTED },
1133 };
1134
1135 static int
1136 call_ad(struct sock *ctnl, struct sk_buff *skb, struct ip_set *set,
1137         struct nlattr *tb[], enum ipset_adt adt,
1138         u32 flags, bool use_lineno)
1139 {
1140         int ret, retried = 0;
1141         u32 lineno = 0;
1142         bool eexist = flags & IPSET_FLAG_EXIST;
1143
1144         do {
1145                 write_lock_bh(&set->lock);
1146                 ret = set->variant->uadt(set, tb, adt, &lineno, flags);
1147                 write_unlock_bh(&set->lock);
1148         } while (ret == -EAGAIN &&
1149                  set->variant->resize &&
1150                  (ret = set->variant->resize(set, retried++)) == 0);
1151
1152         if (!ret || (ret == -IPSET_ERR_EXIST && eexist))
1153                 return 0;
1154         if (lineno && use_lineno) {
1155                 /* Error in restore/batch mode: send back lineno */
1156                 struct nlmsghdr *rep, *nlh = nlmsg_hdr(skb);
1157                 struct sk_buff *skb2;
1158                 struct nlmsgerr *errmsg;
1159                 size_t payload = sizeof(*errmsg) + nlmsg_len(nlh);
1160                 int min_len = NLMSG_SPACE(sizeof(struct nfgenmsg));
1161                 struct nlattr *cda[IPSET_ATTR_CMD_MAX+1];
1162                 struct nlattr *cmdattr;
1163                 u32 *errline;
1164
1165                 skb2 = nlmsg_new(payload, GFP_KERNEL);
1166                 if (skb2 == NULL)
1167                         return -ENOMEM;
1168                 rep = __nlmsg_put(skb2, NETLINK_CB(skb).pid,
1169                                   nlh->nlmsg_seq, NLMSG_ERROR, payload, 0);
1170                 errmsg = nlmsg_data(rep);
1171                 errmsg->error = ret;
1172                 memcpy(&errmsg->msg, nlh, nlh->nlmsg_len);
1173                 cmdattr = (void *)&errmsg->msg + min_len;
1174
1175                 nla_parse(cda, IPSET_ATTR_CMD_MAX,
1176                           cmdattr, nlh->nlmsg_len - min_len,
1177                           ip_set_adt_policy);
1178
1179                 errline = nla_data(cda[IPSET_ATTR_LINENO]);
1180
1181                 *errline = lineno;
1182
1183                 netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1184                 /* Signal netlink not to send its ACK/errmsg.  */
1185                 return -EINTR;
1186         }
1187
1188         return ret;
1189 }
1190
1191 static int
1192 ip_set_uadd(struct sock *ctnl, struct sk_buff *skb,
1193             const struct nlmsghdr *nlh,
1194             const struct nlattr * const attr[])
1195 {
1196         struct ip_set *set;
1197         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1198         const struct nlattr *nla;
1199         u32 flags = flag_exist(nlh);
1200         bool use_lineno;
1201         int ret = 0;
1202
1203         if (unlikely(protocol_failed(attr) ||
1204                      attr[IPSET_ATTR_SETNAME] == NULL ||
1205                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1206                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1207                      (attr[IPSET_ATTR_DATA] != NULL &&
1208                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1209                      (attr[IPSET_ATTR_ADT] != NULL &&
1210                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1211                        attr[IPSET_ATTR_LINENO] == NULL))))
1212                 return -IPSET_ERR_PROTOCOL;
1213
1214         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1215         if (set == NULL)
1216                 return -ENOENT;
1217
1218         use_lineno = !!attr[IPSET_ATTR_LINENO];
1219         if (attr[IPSET_ATTR_DATA]) {
1220                 if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1221                                      attr[IPSET_ATTR_DATA],
1222                                      set->type->adt_policy))
1223                         return -IPSET_ERR_PROTOCOL;
1224                 ret = call_ad(ctnl, skb, set, tb, IPSET_ADD, flags,
1225                               use_lineno);
1226         } else {
1227                 int nla_rem;
1228
1229                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1230                         memset(tb, 0, sizeof(tb));
1231                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1232                             !flag_nested(nla) ||
1233                             nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1234                                              set->type->adt_policy))
1235                                 return -IPSET_ERR_PROTOCOL;
1236                         ret = call_ad(ctnl, skb, set, tb, IPSET_ADD,
1237                                       flags, use_lineno);
1238                         if (ret < 0)
1239                                 return ret;
1240                 }
1241         }
1242         return ret;
1243 }
1244
1245 static int
1246 ip_set_udel(struct sock *ctnl, struct sk_buff *skb,
1247             const struct nlmsghdr *nlh,
1248             const struct nlattr * const attr[])
1249 {
1250         struct ip_set *set;
1251         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1252         const struct nlattr *nla;
1253         u32 flags = flag_exist(nlh);
1254         bool use_lineno;
1255         int ret = 0;
1256
1257         if (unlikely(protocol_failed(attr) ||
1258                      attr[IPSET_ATTR_SETNAME] == NULL ||
1259                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1260                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1261                      (attr[IPSET_ATTR_DATA] != NULL &&
1262                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1263                      (attr[IPSET_ATTR_ADT] != NULL &&
1264                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1265                        attr[IPSET_ATTR_LINENO] == NULL))))
1266                 return -IPSET_ERR_PROTOCOL;
1267
1268         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1269         if (set == NULL)
1270                 return -ENOENT;
1271
1272         use_lineno = !!attr[IPSET_ATTR_LINENO];
1273         if (attr[IPSET_ATTR_DATA]) {
1274                 if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1275                                      attr[IPSET_ATTR_DATA],
1276                                      set->type->adt_policy))
1277                         return -IPSET_ERR_PROTOCOL;
1278                 ret = call_ad(ctnl, skb, set, tb, IPSET_DEL, flags,
1279                               use_lineno);
1280         } else {
1281                 int nla_rem;
1282
1283                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1284                         memset(tb, 0, sizeof(*tb));
1285                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1286                             !flag_nested(nla) ||
1287                             nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1288                                              set->type->adt_policy))
1289                                 return -IPSET_ERR_PROTOCOL;
1290                         ret = call_ad(ctnl, skb, set, tb, IPSET_DEL,
1291                                       flags, use_lineno);
1292                         if (ret < 0)
1293                                 return ret;
1294                 }
1295         }
1296         return ret;
1297 }
1298
1299 static int
1300 ip_set_utest(struct sock *ctnl, struct sk_buff *skb,
1301              const struct nlmsghdr *nlh,
1302              const struct nlattr * const attr[])
1303 {
1304         struct ip_set *set;
1305         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1306         int ret = 0;
1307
1308         if (unlikely(protocol_failed(attr) ||
1309                      attr[IPSET_ATTR_SETNAME] == NULL ||
1310                      attr[IPSET_ATTR_DATA] == NULL ||
1311                      !flag_nested(attr[IPSET_ATTR_DATA])))
1312                 return -IPSET_ERR_PROTOCOL;
1313
1314         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1315         if (set == NULL)
1316                 return -ENOENT;
1317
1318         if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA],
1319                              set->type->adt_policy))
1320                 return -IPSET_ERR_PROTOCOL;
1321
1322         read_lock_bh(&set->lock);
1323         ret = set->variant->uadt(set, tb, IPSET_TEST, NULL, 0);
1324         read_unlock_bh(&set->lock);
1325         /* Userspace can't trigger element to be re-added */
1326         if (ret == -EAGAIN)
1327                 ret = 1;
1328
1329         return ret < 0 ? ret : ret > 0 ? 0 : -IPSET_ERR_EXIST;
1330 }
1331
1332 /* Get headed data of a set */
1333
1334 static int
1335 ip_set_header(struct sock *ctnl, struct sk_buff *skb,
1336               const struct nlmsghdr *nlh,
1337               const struct nlattr * const attr[])
1338 {
1339         const struct ip_set *set;
1340         struct sk_buff *skb2;
1341         struct nlmsghdr *nlh2;
1342         ip_set_id_t index;
1343         int ret = 0;
1344
1345         if (unlikely(protocol_failed(attr) ||
1346                      attr[IPSET_ATTR_SETNAME] == NULL))
1347                 return -IPSET_ERR_PROTOCOL;
1348
1349         index = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
1350         if (index == IPSET_INVALID_ID)
1351                 return -ENOENT;
1352         set = ip_set_list[index];
1353
1354         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1355         if (skb2 == NULL)
1356                 return -ENOMEM;
1357
1358         nlh2 = start_msg(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, 0,
1359                          IPSET_CMD_HEADER);
1360         if (!nlh2)
1361                 goto nlmsg_failure;
1362         NLA_PUT_U8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1363         NLA_PUT_STRING(skb2, IPSET_ATTR_SETNAME, set->name);
1364         NLA_PUT_STRING(skb2, IPSET_ATTR_TYPENAME, set->type->name);
1365         NLA_PUT_U8(skb2, IPSET_ATTR_FAMILY, set->family);
1366         NLA_PUT_U8(skb2, IPSET_ATTR_REVISION, set->type->revision);
1367         nlmsg_end(skb2, nlh2);
1368
1369         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1370         if (ret < 0)
1371                 return ret;
1372
1373         return 0;
1374
1375 nla_put_failure:
1376         nlmsg_cancel(skb2, nlh2);
1377 nlmsg_failure:
1378         kfree_skb(skb2);
1379         return -EMSGSIZE;
1380 }
1381
1382 /* Get type data */
1383
1384 static const struct nla_policy ip_set_type_policy[IPSET_ATTR_CMD_MAX + 1] = {
1385         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1386         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
1387                                     .len = IPSET_MAXNAMELEN - 1 },
1388         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
1389 };
1390
1391 static int
1392 ip_set_type(struct sock *ctnl, struct sk_buff *skb,
1393             const struct nlmsghdr *nlh,
1394             const struct nlattr * const attr[])
1395 {
1396         struct sk_buff *skb2;
1397         struct nlmsghdr *nlh2;
1398         u8 family, min, max;
1399         const char *typename;
1400         int ret = 0;
1401
1402         if (unlikely(protocol_failed(attr) ||
1403                      attr[IPSET_ATTR_TYPENAME] == NULL ||
1404                      attr[IPSET_ATTR_FAMILY] == NULL))
1405                 return -IPSET_ERR_PROTOCOL;
1406
1407         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
1408         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
1409         ret = find_set_type_minmax(typename, family, &min, &max);
1410         if (ret)
1411                 return ret;
1412
1413         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1414         if (skb2 == NULL)
1415                 return -ENOMEM;
1416
1417         nlh2 = start_msg(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, 0,
1418                          IPSET_CMD_TYPE);
1419         if (!nlh2)
1420                 goto nlmsg_failure;
1421         NLA_PUT_U8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1422         NLA_PUT_STRING(skb2, IPSET_ATTR_TYPENAME, typename);
1423         NLA_PUT_U8(skb2, IPSET_ATTR_FAMILY, family);
1424         NLA_PUT_U8(skb2, IPSET_ATTR_REVISION, max);
1425         NLA_PUT_U8(skb2, IPSET_ATTR_REVISION_MIN, min);
1426         nlmsg_end(skb2, nlh2);
1427
1428         pr_debug("Send TYPE, nlmsg_len: %u\n", nlh2->nlmsg_len);
1429         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1430         if (ret < 0)
1431                 return ret;
1432
1433         return 0;
1434
1435 nla_put_failure:
1436         nlmsg_cancel(skb2, nlh2);
1437 nlmsg_failure:
1438         kfree_skb(skb2);
1439         return -EMSGSIZE;
1440 }
1441
1442 /* Get protocol version */
1443
1444 static const struct nla_policy
1445 ip_set_protocol_policy[IPSET_ATTR_CMD_MAX + 1] = {
1446         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1447 };
1448
1449 static int
1450 ip_set_protocol(struct sock *ctnl, struct sk_buff *skb,
1451                 const struct nlmsghdr *nlh,
1452                 const struct nlattr * const attr[])
1453 {
1454         struct sk_buff *skb2;
1455         struct nlmsghdr *nlh2;
1456         int ret = 0;
1457
1458         if (unlikely(attr[IPSET_ATTR_PROTOCOL] == NULL))
1459                 return -IPSET_ERR_PROTOCOL;
1460
1461         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1462         if (skb2 == NULL)
1463                 return -ENOMEM;
1464
1465         nlh2 = start_msg(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, 0,
1466                          IPSET_CMD_PROTOCOL);
1467         if (!nlh2)
1468                 goto nlmsg_failure;
1469         NLA_PUT_U8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1470         nlmsg_end(skb2, nlh2);
1471
1472         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1473         if (ret < 0)
1474                 return ret;
1475
1476         return 0;
1477
1478 nla_put_failure:
1479         nlmsg_cancel(skb2, nlh2);
1480 nlmsg_failure:
1481         kfree_skb(skb2);
1482         return -EMSGSIZE;
1483 }
1484
1485 static const struct nfnl_callback ip_set_netlink_subsys_cb[IPSET_MSG_MAX] = {
1486         [IPSET_CMD_CREATE]      = {
1487                 .call           = ip_set_create,
1488                 .attr_count     = IPSET_ATTR_CMD_MAX,
1489                 .policy         = ip_set_create_policy,
1490         },
1491         [IPSET_CMD_DESTROY]     = {
1492                 .call           = ip_set_destroy,
1493                 .attr_count     = IPSET_ATTR_CMD_MAX,
1494                 .policy         = ip_set_setname_policy,
1495         },
1496         [IPSET_CMD_FLUSH]       = {
1497                 .call           = ip_set_flush,
1498                 .attr_count     = IPSET_ATTR_CMD_MAX,
1499                 .policy         = ip_set_setname_policy,
1500         },
1501         [IPSET_CMD_RENAME]      = {
1502                 .call           = ip_set_rename,
1503                 .attr_count     = IPSET_ATTR_CMD_MAX,
1504                 .policy         = ip_set_setname2_policy,
1505         },
1506         [IPSET_CMD_SWAP]        = {
1507                 .call           = ip_set_swap,
1508                 .attr_count     = IPSET_ATTR_CMD_MAX,
1509                 .policy         = ip_set_setname2_policy,
1510         },
1511         [IPSET_CMD_LIST]        = {
1512                 .call           = ip_set_dump,
1513                 .attr_count     = IPSET_ATTR_CMD_MAX,
1514                 .policy         = ip_set_setname_policy,
1515         },
1516         [IPSET_CMD_SAVE]        = {
1517                 .call           = ip_set_dump,
1518                 .attr_count     = IPSET_ATTR_CMD_MAX,
1519                 .policy         = ip_set_setname_policy,
1520         },
1521         [IPSET_CMD_ADD] = {
1522                 .call           = ip_set_uadd,
1523                 .attr_count     = IPSET_ATTR_CMD_MAX,
1524                 .policy         = ip_set_adt_policy,
1525         },
1526         [IPSET_CMD_DEL] = {
1527                 .call           = ip_set_udel,
1528                 .attr_count     = IPSET_ATTR_CMD_MAX,
1529                 .policy         = ip_set_adt_policy,
1530         },
1531         [IPSET_CMD_TEST]        = {
1532                 .call           = ip_set_utest,
1533                 .attr_count     = IPSET_ATTR_CMD_MAX,
1534                 .policy         = ip_set_adt_policy,
1535         },
1536         [IPSET_CMD_HEADER]      = {
1537                 .call           = ip_set_header,
1538                 .attr_count     = IPSET_ATTR_CMD_MAX,
1539                 .policy         = ip_set_setname_policy,
1540         },
1541         [IPSET_CMD_TYPE]        = {
1542                 .call           = ip_set_type,
1543                 .attr_count     = IPSET_ATTR_CMD_MAX,
1544                 .policy         = ip_set_type_policy,
1545         },
1546         [IPSET_CMD_PROTOCOL]    = {
1547                 .call           = ip_set_protocol,
1548                 .attr_count     = IPSET_ATTR_CMD_MAX,
1549                 .policy         = ip_set_protocol_policy,
1550         },
1551 };
1552
1553 static struct nfnetlink_subsystem ip_set_netlink_subsys __read_mostly = {
1554         .name           = "ip_set",
1555         .subsys_id      = NFNL_SUBSYS_IPSET,
1556         .cb_count       = IPSET_MSG_MAX,
1557         .cb             = ip_set_netlink_subsys_cb,
1558 };
1559
1560 /* Interface to iptables/ip6tables */
1561
1562 static int
1563 ip_set_sockfn_get(struct sock *sk, int optval, void __user *user, int *len)
1564 {
1565         unsigned *op;
1566         void *data;
1567         int copylen = *len, ret = 0;
1568
1569         if (!capable(CAP_NET_ADMIN))
1570                 return -EPERM;
1571         if (optval != SO_IP_SET)
1572                 return -EBADF;
1573         if (*len < sizeof(unsigned))
1574                 return -EINVAL;
1575
1576         data = vmalloc(*len);
1577         if (!data)
1578                 return -ENOMEM;
1579         if (copy_from_user(data, user, *len) != 0) {
1580                 ret = -EFAULT;
1581                 goto done;
1582         }
1583         op = (unsigned *) data;
1584
1585         if (*op < IP_SET_OP_VERSION) {
1586                 /* Check the version at the beginning of operations */
1587                 struct ip_set_req_version *req_version = data;
1588                 if (req_version->version != IPSET_PROTOCOL) {
1589                         ret = -EPROTO;
1590                         goto done;
1591                 }
1592         }
1593
1594         switch (*op) {
1595         case IP_SET_OP_VERSION: {
1596                 struct ip_set_req_version *req_version = data;
1597
1598                 if (*len != sizeof(struct ip_set_req_version)) {
1599                         ret = -EINVAL;
1600                         goto done;
1601                 }
1602
1603                 req_version->version = IPSET_PROTOCOL;
1604                 ret = copy_to_user(user, req_version,
1605                                    sizeof(struct ip_set_req_version));
1606                 goto done;
1607         }
1608         case IP_SET_OP_GET_BYNAME: {
1609                 struct ip_set_req_get_set *req_get = data;
1610
1611                 if (*len != sizeof(struct ip_set_req_get_set)) {
1612                         ret = -EINVAL;
1613                         goto done;
1614                 }
1615                 req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
1616                 nfnl_lock();
1617                 req_get->set.index = find_set_id(req_get->set.name);
1618                 nfnl_unlock();
1619                 goto copy;
1620         }
1621         case IP_SET_OP_GET_BYINDEX: {
1622                 struct ip_set_req_get_set *req_get = data;
1623
1624                 if (*len != sizeof(struct ip_set_req_get_set) ||
1625                     req_get->set.index >= ip_set_max) {
1626                         ret = -EINVAL;
1627                         goto done;
1628                 }
1629                 nfnl_lock();
1630                 strncpy(req_get->set.name,
1631                         ip_set_list[req_get->set.index]
1632                                 ? ip_set_list[req_get->set.index]->name : "",
1633                         IPSET_MAXNAMELEN);
1634                 nfnl_unlock();
1635                 goto copy;
1636         }
1637         default:
1638                 ret = -EBADMSG;
1639                 goto done;
1640         }       /* end of switch(op) */
1641
1642 copy:
1643         ret = copy_to_user(user, data, copylen);
1644
1645 done:
1646         vfree(data);
1647         if (ret > 0)
1648                 ret = 0;
1649         return ret;
1650 }
1651
1652 static struct nf_sockopt_ops so_set __read_mostly = {
1653         .pf             = PF_INET,
1654         .get_optmin     = SO_IP_SET,
1655         .get_optmax     = SO_IP_SET + 1,
1656         .get            = &ip_set_sockfn_get,
1657         .owner          = THIS_MODULE,
1658 };
1659
1660 static int __init
1661 ip_set_init(void)
1662 {
1663         int ret;
1664
1665         if (max_sets)
1666                 ip_set_max = max_sets;
1667         if (ip_set_max >= IPSET_INVALID_ID)
1668                 ip_set_max = IPSET_INVALID_ID - 1;
1669
1670         ip_set_list = kzalloc(sizeof(struct ip_set *) * ip_set_max,
1671                               GFP_KERNEL);
1672         if (!ip_set_list) {
1673                 pr_err("ip_set: Unable to create ip_set_list\n");
1674                 return -ENOMEM;
1675         }
1676
1677         ret = nfnetlink_subsys_register(&ip_set_netlink_subsys);
1678         if (ret != 0) {
1679                 pr_err("ip_set: cannot register with nfnetlink.\n");
1680                 kfree(ip_set_list);
1681                 return ret;
1682         }
1683         ret = nf_register_sockopt(&so_set);
1684         if (ret != 0) {
1685                 pr_err("SO_SET registry failed: %d\n", ret);
1686                 nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
1687                 kfree(ip_set_list);
1688                 return ret;
1689         }
1690
1691         pr_notice("ip_set: protocol %u\n", IPSET_PROTOCOL);
1692         return 0;
1693 }
1694
1695 static void __exit
1696 ip_set_fini(void)
1697 {
1698         /* There can't be any existing set */
1699         nf_unregister_sockopt(&so_set);
1700         nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
1701         kfree(ip_set_list);
1702         pr_debug("these are the famous last words\n");
1703 }
1704
1705 module_init(ip_set_init);
1706 module_exit(ip_set_fini);