Merge branch 'staging-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/gregkh...
[pandora-kernel.git] / net / netfilter / ipset / ip_set_core.c
1 /* Copyright (C) 2000-2002 Joakim Axelsson <gozem@linux.nu>
2  *                         Patrick Schaaf <bof@bof.de>
3  * Copyright (C) 2003-2011 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation.
8  */
9
10 /* Kernel module for IP set management */
11
12 #include <linux/init.h>
13 #include <linux/module.h>
14 #include <linux/moduleparam.h>
15 #include <linux/ip.h>
16 #include <linux/skbuff.h>
17 #include <linux/spinlock.h>
18 #include <linux/netlink.h>
19 #include <linux/rculist.h>
20 #include <linux/version.h>
21 #include <net/netlink.h>
22
23 #include <linux/netfilter.h>
24 #include <linux/netfilter/nfnetlink.h>
25 #include <linux/netfilter/ipset/ip_set.h>
26
27 static LIST_HEAD(ip_set_type_list);             /* all registered set types */
28 static DEFINE_MUTEX(ip_set_type_mutex);         /* protects ip_set_type_list */
29 static DEFINE_RWLOCK(ip_set_ref_lock);          /* protects the set refs */
30
31 static struct ip_set **ip_set_list;             /* all individual sets */
32 static ip_set_id_t ip_set_max = CONFIG_IP_SET_MAX; /* max number of sets */
33
34 #define STREQ(a, b)     (strncmp(a, b, IPSET_MAXNAMELEN) == 0)
35
36 static unsigned int max_sets;
37
38 module_param(max_sets, int, 0600);
39 MODULE_PARM_DESC(max_sets, "maximal number of sets");
40 MODULE_LICENSE("GPL");
41 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
42 MODULE_DESCRIPTION("core IP set support");
43 MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_IPSET);
44
45 /*
46  * The set types are implemented in modules and registered set types
47  * can be found in ip_set_type_list. Adding/deleting types is
48  * serialized by ip_set_type_mutex.
49  */
50
51 static inline void
52 ip_set_type_lock(void)
53 {
54         mutex_lock(&ip_set_type_mutex);
55 }
56
57 static inline void
58 ip_set_type_unlock(void)
59 {
60         mutex_unlock(&ip_set_type_mutex);
61 }
62
63 /* Register and deregister settype */
64
65 static struct ip_set_type *
66 find_set_type(const char *name, u8 family, u8 revision)
67 {
68         struct ip_set_type *type;
69
70         list_for_each_entry_rcu(type, &ip_set_type_list, list)
71                 if (STREQ(type->name, name) &&
72                     (type->family == family || type->family == AF_UNSPEC) &&
73                     type->revision == revision)
74                         return type;
75         return NULL;
76 }
77
78 /* Unlock, try to load a set type module and lock again */
79 static int
80 try_to_load_type(const char *name)
81 {
82         nfnl_unlock();
83         pr_debug("try to load ip_set_%s\n", name);
84         if (request_module("ip_set_%s", name) < 0) {
85                 pr_warning("Can't find ip_set type %s\n", name);
86                 nfnl_lock();
87                 return -IPSET_ERR_FIND_TYPE;
88         }
89         nfnl_lock();
90         return -EAGAIN;
91 }
92
93 /* Find a set type and reference it */
94 static int
95 find_set_type_get(const char *name, u8 family, u8 revision,
96                   struct ip_set_type **found)
97 {
98         struct ip_set_type *type;
99         int err;
100
101         rcu_read_lock();
102         *found = find_set_type(name, family, revision);
103         if (*found) {
104                 err = !try_module_get((*found)->me) ? -EFAULT : 0;
105                 goto unlock;
106         }
107         /* Make sure the type is loaded but we don't support the revision */
108         list_for_each_entry_rcu(type, &ip_set_type_list, list)
109                 if (STREQ(type->name, name)) {
110                         err = -IPSET_ERR_FIND_TYPE;
111                         goto unlock;
112                 }
113         rcu_read_unlock();
114
115         return try_to_load_type(name);
116
117 unlock:
118         rcu_read_unlock();
119         return err;
120 }
121
122 /* Find a given set type by name and family.
123  * If we succeeded, the supported minimal and maximum revisions are
124  * filled out.
125  */
126 static int
127 find_set_type_minmax(const char *name, u8 family, u8 *min, u8 *max)
128 {
129         struct ip_set_type *type;
130         bool found = false;
131
132         *min = 255; *max = 0;
133         rcu_read_lock();
134         list_for_each_entry_rcu(type, &ip_set_type_list, list)
135                 if (STREQ(type->name, name) &&
136                     (type->family == family || type->family == AF_UNSPEC)) {
137                         found = true;
138                         if (type->revision < *min)
139                                 *min = type->revision;
140                         if (type->revision > *max)
141                                 *max = type->revision;
142                 }
143         rcu_read_unlock();
144         if (found)
145                 return 0;
146
147         return try_to_load_type(name);
148 }
149
150 #define family_name(f)  ((f) == AF_INET ? "inet" : \
151                          (f) == AF_INET6 ? "inet6" : "any")
152
153 /* Register a set type structure. The type is identified by
154  * the unique triple of name, family and revision.
155  */
156 int
157 ip_set_type_register(struct ip_set_type *type)
158 {
159         int ret = 0;
160
161         if (type->protocol != IPSET_PROTOCOL) {
162                 pr_warning("ip_set type %s, family %s, revision %u uses "
163                            "wrong protocol version %u (want %u)\n",
164                            type->name, family_name(type->family),
165                            type->revision, type->protocol, IPSET_PROTOCOL);
166                 return -EINVAL;
167         }
168
169         ip_set_type_lock();
170         if (find_set_type(type->name, type->family, type->revision)) {
171                 /* Duplicate! */
172                 pr_warning("ip_set type %s, family %s, revision %u "
173                            "already registered!\n", type->name,
174                            family_name(type->family), type->revision);
175                 ret = -EINVAL;
176                 goto unlock;
177         }
178         list_add_rcu(&type->list, &ip_set_type_list);
179         pr_debug("type %s, family %s, revision %u registered.\n",
180                  type->name, family_name(type->family), type->revision);
181 unlock:
182         ip_set_type_unlock();
183         return ret;
184 }
185 EXPORT_SYMBOL_GPL(ip_set_type_register);
186
187 /* Unregister a set type. There's a small race with ip_set_create */
188 void
189 ip_set_type_unregister(struct ip_set_type *type)
190 {
191         ip_set_type_lock();
192         if (!find_set_type(type->name, type->family, type->revision)) {
193                 pr_warning("ip_set type %s, family %s, revision %u "
194                            "not registered\n", type->name,
195                            family_name(type->family), type->revision);
196                 goto unlock;
197         }
198         list_del_rcu(&type->list);
199         pr_debug("type %s, family %s, revision %u unregistered.\n",
200                  type->name, family_name(type->family), type->revision);
201 unlock:
202         ip_set_type_unlock();
203
204         synchronize_rcu();
205 }
206 EXPORT_SYMBOL_GPL(ip_set_type_unregister);
207
208 /* Utility functions */
209 void *
210 ip_set_alloc(size_t size)
211 {
212         void *members = NULL;
213
214         if (size < KMALLOC_MAX_SIZE)
215                 members = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
216
217         if (members) {
218                 pr_debug("%p: allocated with kmalloc\n", members);
219                 return members;
220         }
221
222         members = vzalloc(size);
223         if (!members)
224                 return NULL;
225         pr_debug("%p: allocated with vmalloc\n", members);
226
227         return members;
228 }
229 EXPORT_SYMBOL_GPL(ip_set_alloc);
230
231 void
232 ip_set_free(void *members)
233 {
234         pr_debug("%p: free with %s\n", members,
235                  is_vmalloc_addr(members) ? "vfree" : "kfree");
236         if (is_vmalloc_addr(members))
237                 vfree(members);
238         else
239                 kfree(members);
240 }
241 EXPORT_SYMBOL_GPL(ip_set_free);
242
243 static inline bool
244 flag_nested(const struct nlattr *nla)
245 {
246         return nla->nla_type & NLA_F_NESTED;
247 }
248
249 static const struct nla_policy ipaddr_policy[IPSET_ATTR_IPADDR_MAX + 1] = {
250         [IPSET_ATTR_IPADDR_IPV4]        = { .type = NLA_U32 },
251         [IPSET_ATTR_IPADDR_IPV6]        = { .type = NLA_BINARY,
252                                             .len = sizeof(struct in6_addr) },
253 };
254
255 int
256 ip_set_get_ipaddr4(struct nlattr *nla,  __be32 *ipaddr)
257 {
258         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX+1];
259
260         if (unlikely(!flag_nested(nla)))
261                 return -IPSET_ERR_PROTOCOL;
262         if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy))
263                 return -IPSET_ERR_PROTOCOL;
264         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV4)))
265                 return -IPSET_ERR_PROTOCOL;
266
267         *ipaddr = nla_get_be32(tb[IPSET_ATTR_IPADDR_IPV4]);
268         return 0;
269 }
270 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr4);
271
272 int
273 ip_set_get_ipaddr6(struct nlattr *nla, union nf_inet_addr *ipaddr)
274 {
275         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX+1];
276
277         if (unlikely(!flag_nested(nla)))
278                 return -IPSET_ERR_PROTOCOL;
279
280         if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy))
281                 return -IPSET_ERR_PROTOCOL;
282         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV6)))
283                 return -IPSET_ERR_PROTOCOL;
284
285         memcpy(ipaddr, nla_data(tb[IPSET_ATTR_IPADDR_IPV6]),
286                 sizeof(struct in6_addr));
287         return 0;
288 }
289 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr6);
290
291 /*
292  * Creating/destroying/renaming/swapping affect the existence and
293  * the properties of a set. All of these can be executed from userspace
294  * only and serialized by the nfnl mutex indirectly from nfnetlink.
295  *
296  * Sets are identified by their index in ip_set_list and the index
297  * is used by the external references (set/SET netfilter modules).
298  *
299  * The set behind an index may change by swapping only, from userspace.
300  */
301
302 static inline void
303 __ip_set_get(ip_set_id_t index)
304 {
305         write_lock_bh(&ip_set_ref_lock);
306         ip_set_list[index]->ref++;
307         write_unlock_bh(&ip_set_ref_lock);
308 }
309
310 static inline void
311 __ip_set_put(ip_set_id_t index)
312 {
313         write_lock_bh(&ip_set_ref_lock);
314         BUG_ON(ip_set_list[index]->ref == 0);
315         ip_set_list[index]->ref--;
316         write_unlock_bh(&ip_set_ref_lock);
317 }
318
319 /*
320  * Add, del and test set entries from kernel.
321  *
322  * The set behind the index must exist and must be referenced
323  * so it can't be destroyed (or changed) under our foot.
324  */
325
326 int
327 ip_set_test(ip_set_id_t index, const struct sk_buff *skb,
328             u8 family, u8 dim, u8 flags)
329 {
330         struct ip_set *set = ip_set_list[index];
331         int ret = 0;
332
333         BUG_ON(set == NULL);
334         pr_debug("set %s, index %u\n", set->name, index);
335
336         if (dim < set->type->dimension ||
337             !(family == set->family || set->family == AF_UNSPEC))
338                 return 0;
339
340         read_lock_bh(&set->lock);
341         ret = set->variant->kadt(set, skb, IPSET_TEST, family, dim, flags);
342         read_unlock_bh(&set->lock);
343
344         if (ret == -EAGAIN) {
345                 /* Type requests element to be completed */
346                 pr_debug("element must be competed, ADD is triggered\n");
347                 write_lock_bh(&set->lock);
348                 set->variant->kadt(set, skb, IPSET_ADD, family, dim, flags);
349                 write_unlock_bh(&set->lock);
350                 ret = 1;
351         }
352
353         /* Convert error codes to nomatch */
354         return (ret < 0 ? 0 : ret);
355 }
356 EXPORT_SYMBOL_GPL(ip_set_test);
357
358 int
359 ip_set_add(ip_set_id_t index, const struct sk_buff *skb,
360            u8 family, u8 dim, u8 flags)
361 {
362         struct ip_set *set = ip_set_list[index];
363         int ret;
364
365         BUG_ON(set == NULL);
366         pr_debug("set %s, index %u\n", set->name, index);
367
368         if (dim < set->type->dimension ||
369             !(family == set->family || set->family == AF_UNSPEC))
370                 return 0;
371
372         write_lock_bh(&set->lock);
373         ret = set->variant->kadt(set, skb, IPSET_ADD, family, dim, flags);
374         write_unlock_bh(&set->lock);
375
376         return ret;
377 }
378 EXPORT_SYMBOL_GPL(ip_set_add);
379
380 int
381 ip_set_del(ip_set_id_t index, const struct sk_buff *skb,
382            u8 family, u8 dim, u8 flags)
383 {
384         struct ip_set *set = ip_set_list[index];
385         int ret = 0;
386
387         BUG_ON(set == NULL);
388         pr_debug("set %s, index %u\n", set->name, index);
389
390         if (dim < set->type->dimension ||
391             !(family == set->family || set->family == AF_UNSPEC))
392                 return 0;
393
394         write_lock_bh(&set->lock);
395         ret = set->variant->kadt(set, skb, IPSET_DEL, family, dim, flags);
396         write_unlock_bh(&set->lock);
397
398         return ret;
399 }
400 EXPORT_SYMBOL_GPL(ip_set_del);
401
402 /*
403  * Find set by name, reference it once. The reference makes sure the
404  * thing pointed to, does not go away under our feet.
405  *
406  */
407 ip_set_id_t
408 ip_set_get_byname(const char *name, struct ip_set **set)
409 {
410         ip_set_id_t i, index = IPSET_INVALID_ID;
411         struct ip_set *s;
412
413         for (i = 0; i < ip_set_max; i++) {
414                 s = ip_set_list[i];
415                 if (s != NULL && STREQ(s->name, name)) {
416                         __ip_set_get(i);
417                         index = i;
418                         *set = s;
419                 }
420         }
421
422         return index;
423 }
424 EXPORT_SYMBOL_GPL(ip_set_get_byname);
425
426 /*
427  * If the given set pointer points to a valid set, decrement
428  * reference count by 1. The caller shall not assume the index
429  * to be valid, after calling this function.
430  *
431  */
432 void
433 ip_set_put_byindex(ip_set_id_t index)
434 {
435         if (ip_set_list[index] != NULL)
436                 __ip_set_put(index);
437 }
438 EXPORT_SYMBOL_GPL(ip_set_put_byindex);
439
440 /*
441  * Get the name of a set behind a set index.
442  * We assume the set is referenced, so it does exist and
443  * can't be destroyed. The set cannot be renamed due to
444  * the referencing either.
445  *
446  */
447 const char *
448 ip_set_name_byindex(ip_set_id_t index)
449 {
450         const struct ip_set *set = ip_set_list[index];
451
452         BUG_ON(set == NULL);
453         BUG_ON(set->ref == 0);
454
455         /* Referenced, so it's safe */
456         return set->name;
457 }
458 EXPORT_SYMBOL_GPL(ip_set_name_byindex);
459
460 /*
461  * Routines to call by external subsystems, which do not
462  * call nfnl_lock for us.
463  */
464
465 /*
466  * Find set by name, reference it once. The reference makes sure the
467  * thing pointed to, does not go away under our feet.
468  *
469  * The nfnl mutex is used in the function.
470  */
471 ip_set_id_t
472 ip_set_nfnl_get(const char *name)
473 {
474         struct ip_set *s;
475         ip_set_id_t index;
476
477         nfnl_lock();
478         index = ip_set_get_byname(name, &s);
479         nfnl_unlock();
480
481         return index;
482 }
483 EXPORT_SYMBOL_GPL(ip_set_nfnl_get);
484
485 /*
486  * Find set by index, reference it once. The reference makes sure the
487  * thing pointed to, does not go away under our feet.
488  *
489  * The nfnl mutex is used in the function.
490  */
491 ip_set_id_t
492 ip_set_nfnl_get_byindex(ip_set_id_t index)
493 {
494         if (index > ip_set_max)
495                 return IPSET_INVALID_ID;
496
497         nfnl_lock();
498         if (ip_set_list[index])
499                 __ip_set_get(index);
500         else
501                 index = IPSET_INVALID_ID;
502         nfnl_unlock();
503
504         return index;
505 }
506 EXPORT_SYMBOL_GPL(ip_set_nfnl_get_byindex);
507
508 /*
509  * If the given set pointer points to a valid set, decrement
510  * reference count by 1. The caller shall not assume the index
511  * to be valid, after calling this function.
512  *
513  * The nfnl mutex is used in the function.
514  */
515 void
516 ip_set_nfnl_put(ip_set_id_t index)
517 {
518         nfnl_lock();
519         ip_set_put_byindex(index);
520         nfnl_unlock();
521 }
522 EXPORT_SYMBOL_GPL(ip_set_nfnl_put);
523
524 /*
525  * Communication protocol with userspace over netlink.
526  *
527  * The commands are serialized by the nfnl mutex.
528  */
529
530 static inline bool
531 protocol_failed(const struct nlattr * const tb[])
532 {
533         return !tb[IPSET_ATTR_PROTOCOL] ||
534                nla_get_u8(tb[IPSET_ATTR_PROTOCOL]) != IPSET_PROTOCOL;
535 }
536
537 static inline u32
538 flag_exist(const struct nlmsghdr *nlh)
539 {
540         return nlh->nlmsg_flags & NLM_F_EXCL ? 0 : IPSET_FLAG_EXIST;
541 }
542
543 static struct nlmsghdr *
544 start_msg(struct sk_buff *skb, u32 pid, u32 seq, unsigned int flags,
545           enum ipset_cmd cmd)
546 {
547         struct nlmsghdr *nlh;
548         struct nfgenmsg *nfmsg;
549
550         nlh = nlmsg_put(skb, pid, seq, cmd | (NFNL_SUBSYS_IPSET << 8),
551                         sizeof(*nfmsg), flags);
552         if (nlh == NULL)
553                 return NULL;
554
555         nfmsg = nlmsg_data(nlh);
556         nfmsg->nfgen_family = AF_INET;
557         nfmsg->version = NFNETLINK_V0;
558         nfmsg->res_id = 0;
559
560         return nlh;
561 }
562
563 /* Create a set */
564
565 static const struct nla_policy ip_set_create_policy[IPSET_ATTR_CMD_MAX + 1] = {
566         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
567         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
568                                     .len = IPSET_MAXNAMELEN - 1 },
569         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
570                                     .len = IPSET_MAXNAMELEN - 1},
571         [IPSET_ATTR_REVISION]   = { .type = NLA_U8 },
572         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
573         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
574 };
575
576 static ip_set_id_t
577 find_set_id(const char *name)
578 {
579         ip_set_id_t i, index = IPSET_INVALID_ID;
580         const struct ip_set *set;
581
582         for (i = 0; index == IPSET_INVALID_ID && i < ip_set_max; i++) {
583                 set = ip_set_list[i];
584                 if (set != NULL && STREQ(set->name, name))
585                         index = i;
586         }
587         return index;
588 }
589
590 static inline struct ip_set *
591 find_set(const char *name)
592 {
593         ip_set_id_t index = find_set_id(name);
594
595         return index == IPSET_INVALID_ID ? NULL : ip_set_list[index];
596 }
597
598 static int
599 find_free_id(const char *name, ip_set_id_t *index, struct ip_set **set)
600 {
601         ip_set_id_t i;
602
603         *index = IPSET_INVALID_ID;
604         for (i = 0;  i < ip_set_max; i++) {
605                 if (ip_set_list[i] == NULL) {
606                         if (*index == IPSET_INVALID_ID)
607                                 *index = i;
608                 } else if (STREQ(name, ip_set_list[i]->name)) {
609                         /* Name clash */
610                         *set = ip_set_list[i];
611                         return -EEXIST;
612                 }
613         }
614         if (*index == IPSET_INVALID_ID)
615                 /* No free slot remained */
616                 return -IPSET_ERR_MAX_SETS;
617         return 0;
618 }
619
620 static int
621 ip_set_create(struct sock *ctnl, struct sk_buff *skb,
622               const struct nlmsghdr *nlh,
623               const struct nlattr * const attr[])
624 {
625         struct ip_set *set, *clash = NULL;
626         ip_set_id_t index = IPSET_INVALID_ID;
627         struct nlattr *tb[IPSET_ATTR_CREATE_MAX+1] = {};
628         const char *name, *typename;
629         u8 family, revision;
630         u32 flags = flag_exist(nlh);
631         int ret = 0;
632
633         if (unlikely(protocol_failed(attr) ||
634                      attr[IPSET_ATTR_SETNAME] == NULL ||
635                      attr[IPSET_ATTR_TYPENAME] == NULL ||
636                      attr[IPSET_ATTR_REVISION] == NULL ||
637                      attr[IPSET_ATTR_FAMILY] == NULL ||
638                      (attr[IPSET_ATTR_DATA] != NULL &&
639                       !flag_nested(attr[IPSET_ATTR_DATA]))))
640                 return -IPSET_ERR_PROTOCOL;
641
642         name = nla_data(attr[IPSET_ATTR_SETNAME]);
643         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
644         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
645         revision = nla_get_u8(attr[IPSET_ATTR_REVISION]);
646         pr_debug("setname: %s, typename: %s, family: %s, revision: %u\n",
647                  name, typename, family_name(family), revision);
648
649         /*
650          * First, and without any locks, allocate and initialize
651          * a normal base set structure.
652          */
653         set = kzalloc(sizeof(struct ip_set), GFP_KERNEL);
654         if (!set)
655                 return -ENOMEM;
656         rwlock_init(&set->lock);
657         strlcpy(set->name, name, IPSET_MAXNAMELEN);
658         set->family = family;
659
660         /*
661          * Next, check that we know the type, and take
662          * a reference on the type, to make sure it stays available
663          * while constructing our new set.
664          *
665          * After referencing the type, we try to create the type
666          * specific part of the set without holding any locks.
667          */
668         ret = find_set_type_get(typename, family, revision, &(set->type));
669         if (ret)
670                 goto out;
671
672         /*
673          * Without holding any locks, create private part.
674          */
675         if (attr[IPSET_ATTR_DATA] &&
676             nla_parse_nested(tb, IPSET_ATTR_CREATE_MAX, attr[IPSET_ATTR_DATA],
677                              set->type->create_policy)) {
678                 ret = -IPSET_ERR_PROTOCOL;
679                 goto put_out;
680         }
681
682         ret = set->type->create(set, tb, flags);
683         if (ret != 0)
684                 goto put_out;
685
686         /* BTW, ret==0 here. */
687
688         /*
689          * Here, we have a valid, constructed set and we are protected
690          * by the nfnl mutex. Find the first free index in ip_set_list
691          * and check clashing.
692          */
693         if ((ret = find_free_id(set->name, &index, &clash)) != 0) {
694                 /* If this is the same set and requested, ignore error */
695                 if (ret == -EEXIST &&
696                     (flags & IPSET_FLAG_EXIST) &&
697                     STREQ(set->type->name, clash->type->name) &&
698                     set->type->family == clash->type->family &&
699                     set->type->revision == clash->type->revision &&
700                     set->variant->same_set(set, clash))
701                         ret = 0;
702                 goto cleanup;
703         }
704
705         /*
706          * Finally! Add our shiny new set to the list, and be done.
707          */
708         pr_debug("create: '%s' created with index %u!\n", set->name, index);
709         ip_set_list[index] = set;
710
711         return ret;
712
713 cleanup:
714         set->variant->destroy(set);
715 put_out:
716         module_put(set->type->me);
717 out:
718         kfree(set);
719         return ret;
720 }
721
722 /* Destroy sets */
723
724 static const struct nla_policy
725 ip_set_setname_policy[IPSET_ATTR_CMD_MAX + 1] = {
726         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
727         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
728                                     .len = IPSET_MAXNAMELEN - 1 },
729 };
730
731 static void
732 ip_set_destroy_set(ip_set_id_t index)
733 {
734         struct ip_set *set = ip_set_list[index];
735
736         pr_debug("set: %s\n",  set->name);
737         ip_set_list[index] = NULL;
738
739         /* Must call it without holding any lock */
740         set->variant->destroy(set);
741         module_put(set->type->me);
742         kfree(set);
743 }
744
745 static int
746 ip_set_destroy(struct sock *ctnl, struct sk_buff *skb,
747                const struct nlmsghdr *nlh,
748                const struct nlattr * const attr[])
749 {
750         ip_set_id_t i;
751         int ret = 0;
752
753         if (unlikely(protocol_failed(attr)))
754                 return -IPSET_ERR_PROTOCOL;
755
756         /* Commands are serialized and references are
757          * protected by the ip_set_ref_lock.
758          * External systems (i.e. xt_set) must call
759          * ip_set_put|get_nfnl_* functions, that way we
760          * can safely check references here.
761          *
762          * list:set timer can only decrement the reference
763          * counter, so if it's already zero, we can proceed
764          * without holding the lock.
765          */
766         read_lock_bh(&ip_set_ref_lock);
767         if (!attr[IPSET_ATTR_SETNAME]) {
768                 for (i = 0; i < ip_set_max; i++) {
769                         if (ip_set_list[i] != NULL && ip_set_list[i]->ref) {
770                                 ret = IPSET_ERR_BUSY;
771                                 goto out;
772                         }
773                 }
774                 read_unlock_bh(&ip_set_ref_lock);
775                 for (i = 0; i < ip_set_max; i++) {
776                         if (ip_set_list[i] != NULL)
777                                 ip_set_destroy_set(i);
778                 }
779         } else {
780                 i = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
781                 if (i == IPSET_INVALID_ID) {
782                         ret = -ENOENT;
783                         goto out;
784                 } else if (ip_set_list[i]->ref) {
785                         ret = -IPSET_ERR_BUSY;
786                         goto out;
787                 }
788                 read_unlock_bh(&ip_set_ref_lock);
789
790                 ip_set_destroy_set(i);
791         }
792         return 0;
793 out:
794         read_unlock_bh(&ip_set_ref_lock);
795         return ret;
796 }
797
798 /* Flush sets */
799
800 static void
801 ip_set_flush_set(struct ip_set *set)
802 {
803         pr_debug("set: %s\n",  set->name);
804
805         write_lock_bh(&set->lock);
806         set->variant->flush(set);
807         write_unlock_bh(&set->lock);
808 }
809
810 static int
811 ip_set_flush(struct sock *ctnl, struct sk_buff *skb,
812              const struct nlmsghdr *nlh,
813              const struct nlattr * const attr[])
814 {
815         ip_set_id_t i;
816
817         if (unlikely(protocol_failed(attr)))
818                 return -EPROTO;
819
820         if (!attr[IPSET_ATTR_SETNAME]) {
821                 for (i = 0; i < ip_set_max; i++)
822                         if (ip_set_list[i] != NULL)
823                                 ip_set_flush_set(ip_set_list[i]);
824         } else {
825                 i = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
826                 if (i == IPSET_INVALID_ID)
827                         return -ENOENT;
828
829                 ip_set_flush_set(ip_set_list[i]);
830         }
831
832         return 0;
833 }
834
835 /* Rename a set */
836
837 static const struct nla_policy
838 ip_set_setname2_policy[IPSET_ATTR_CMD_MAX + 1] = {
839         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
840         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
841                                     .len = IPSET_MAXNAMELEN - 1 },
842         [IPSET_ATTR_SETNAME2]   = { .type = NLA_NUL_STRING,
843                                     .len = IPSET_MAXNAMELEN - 1 },
844 };
845
846 static int
847 ip_set_rename(struct sock *ctnl, struct sk_buff *skb,
848               const struct nlmsghdr *nlh,
849               const struct nlattr * const attr[])
850 {
851         struct ip_set *set;
852         const char *name2;
853         ip_set_id_t i;
854         int ret = 0;
855
856         if (unlikely(protocol_failed(attr) ||
857                      attr[IPSET_ATTR_SETNAME] == NULL ||
858                      attr[IPSET_ATTR_SETNAME2] == NULL))
859                 return -IPSET_ERR_PROTOCOL;
860
861         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
862         if (set == NULL)
863                 return -ENOENT;
864
865         read_lock_bh(&ip_set_ref_lock);
866         if (set->ref != 0) {
867                 ret = -IPSET_ERR_REFERENCED;
868                 goto out;
869         }
870
871         name2 = nla_data(attr[IPSET_ATTR_SETNAME2]);
872         for (i = 0; i < ip_set_max; i++) {
873                 if (ip_set_list[i] != NULL &&
874                     STREQ(ip_set_list[i]->name, name2)) {
875                         ret = -IPSET_ERR_EXIST_SETNAME2;
876                         goto out;
877                 }
878         }
879         strncpy(set->name, name2, IPSET_MAXNAMELEN);
880
881 out:
882         read_unlock_bh(&ip_set_ref_lock);
883         return ret;
884 }
885
886 /* Swap two sets so that name/index points to the other.
887  * References and set names are also swapped.
888  *
889  * The commands are serialized by the nfnl mutex and references are
890  * protected by the ip_set_ref_lock. The kernel interfaces
891  * do not hold the mutex but the pointer settings are atomic
892  * so the ip_set_list always contains valid pointers to the sets.
893  */
894
895 static int
896 ip_set_swap(struct sock *ctnl, struct sk_buff *skb,
897             const struct nlmsghdr *nlh,
898             const struct nlattr * const attr[])
899 {
900         struct ip_set *from, *to;
901         ip_set_id_t from_id, to_id;
902         char from_name[IPSET_MAXNAMELEN];
903
904         if (unlikely(protocol_failed(attr) ||
905                      attr[IPSET_ATTR_SETNAME] == NULL ||
906                      attr[IPSET_ATTR_SETNAME2] == NULL))
907                 return -IPSET_ERR_PROTOCOL;
908
909         from_id = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
910         if (from_id == IPSET_INVALID_ID)
911                 return -ENOENT;
912
913         to_id = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME2]));
914         if (to_id == IPSET_INVALID_ID)
915                 return -IPSET_ERR_EXIST_SETNAME2;
916
917         from = ip_set_list[from_id];
918         to = ip_set_list[to_id];
919
920         /* Features must not change.
921          * Not an artificial restriction anymore, as we must prevent
922          * possible loops created by swapping in setlist type of sets. */
923         if (!(from->type->features == to->type->features &&
924               from->type->family == to->type->family))
925                 return -IPSET_ERR_TYPE_MISMATCH;
926
927         strncpy(from_name, from->name, IPSET_MAXNAMELEN);
928         strncpy(from->name, to->name, IPSET_MAXNAMELEN);
929         strncpy(to->name, from_name, IPSET_MAXNAMELEN);
930
931         write_lock_bh(&ip_set_ref_lock);
932         swap(from->ref, to->ref);
933         ip_set_list[from_id] = to;
934         ip_set_list[to_id] = from;
935         write_unlock_bh(&ip_set_ref_lock);
936
937         return 0;
938 }
939
940 /* List/save set data */
941
942 #define DUMP_INIT       0L
943 #define DUMP_ALL        1L
944 #define DUMP_ONE        2L
945 #define DUMP_LAST       3L
946
947 static int
948 ip_set_dump_done(struct netlink_callback *cb)
949 {
950         if (cb->args[2]) {
951                 pr_debug("release set %s\n", ip_set_list[cb->args[1]]->name);
952                 ip_set_put_byindex((ip_set_id_t) cb->args[1]);
953         }
954         return 0;
955 }
956
957 static inline void
958 dump_attrs(struct nlmsghdr *nlh)
959 {
960         const struct nlattr *attr;
961         int rem;
962
963         pr_debug("dump nlmsg\n");
964         nlmsg_for_each_attr(attr, nlh, sizeof(struct nfgenmsg), rem) {
965                 pr_debug("type: %u, len %u\n", nla_type(attr), attr->nla_len);
966         }
967 }
968
969 static int
970 dump_init(struct netlink_callback *cb)
971 {
972         struct nlmsghdr *nlh = nlmsg_hdr(cb->skb);
973         int min_len = NLMSG_SPACE(sizeof(struct nfgenmsg));
974         struct nlattr *cda[IPSET_ATTR_CMD_MAX+1];
975         struct nlattr *attr = (void *)nlh + min_len;
976         ip_set_id_t index;
977
978         /* Second pass, so parser can't fail */
979         nla_parse(cda, IPSET_ATTR_CMD_MAX,
980                   attr, nlh->nlmsg_len - min_len, ip_set_setname_policy);
981
982         /* cb->args[0] : dump single set/all sets
983          *         [1] : set index
984          *         [..]: type specific
985          */
986
987         if (!cda[IPSET_ATTR_SETNAME]) {
988                 cb->args[0] = DUMP_ALL;
989                 return 0;
990         }
991
992         index = find_set_id(nla_data(cda[IPSET_ATTR_SETNAME]));
993         if (index == IPSET_INVALID_ID)
994                 return -ENOENT;
995
996         cb->args[0] = DUMP_ONE;
997         cb->args[1] = index;
998         return 0;
999 }
1000
1001 static int
1002 ip_set_dump_start(struct sk_buff *skb, struct netlink_callback *cb)
1003 {
1004         ip_set_id_t index = IPSET_INVALID_ID, max;
1005         struct ip_set *set = NULL;
1006         struct nlmsghdr *nlh = NULL;
1007         unsigned int flags = NETLINK_CB(cb->skb).pid ? NLM_F_MULTI : 0;
1008         int ret = 0;
1009
1010         if (cb->args[0] == DUMP_INIT) {
1011                 ret = dump_init(cb);
1012                 if (ret < 0) {
1013                         nlh = nlmsg_hdr(cb->skb);
1014                         /* We have to create and send the error message
1015                          * manually :-( */
1016                         if (nlh->nlmsg_flags & NLM_F_ACK)
1017                                 netlink_ack(cb->skb, nlh, ret);
1018                         return ret;
1019                 }
1020         }
1021
1022         if (cb->args[1] >= ip_set_max)
1023                 goto out;
1024
1025         max = cb->args[0] == DUMP_ONE ? cb->args[1] + 1 : ip_set_max;
1026 dump_last:
1027         pr_debug("args[0]: %ld args[1]: %ld\n", cb->args[0], cb->args[1]);
1028         for (; cb->args[1] < max; cb->args[1]++) {
1029                 index = (ip_set_id_t) cb->args[1];
1030                 set = ip_set_list[index];
1031                 if (set == NULL) {
1032                         if (cb->args[0] == DUMP_ONE) {
1033                                 ret = -ENOENT;
1034                                 goto out;
1035                         }
1036                         continue;
1037                 }
1038                 /* When dumping all sets, we must dump "sorted"
1039                  * so that lists (unions of sets) are dumped last.
1040                  */
1041                 if (cb->args[0] != DUMP_ONE &&
1042                     ((cb->args[0] == DUMP_ALL) ==
1043                      !!(set->type->features & IPSET_DUMP_LAST)))
1044                         continue;
1045                 pr_debug("List set: %s\n", set->name);
1046                 if (!cb->args[2]) {
1047                         /* Start listing: make sure set won't be destroyed */
1048                         pr_debug("reference set\n");
1049                         __ip_set_get(index);
1050                 }
1051                 nlh = start_msg(skb, NETLINK_CB(cb->skb).pid,
1052                                 cb->nlh->nlmsg_seq, flags,
1053                                 IPSET_CMD_LIST);
1054                 if (!nlh) {
1055                         ret = -EMSGSIZE;
1056                         goto release_refcount;
1057                 }
1058                 NLA_PUT_U8(skb, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1059                 NLA_PUT_STRING(skb, IPSET_ATTR_SETNAME, set->name);
1060                 switch (cb->args[2]) {
1061                 case 0:
1062                         /* Core header data */
1063                         NLA_PUT_STRING(skb, IPSET_ATTR_TYPENAME,
1064                                        set->type->name);
1065                         NLA_PUT_U8(skb, IPSET_ATTR_FAMILY,
1066                                    set->family);
1067                         NLA_PUT_U8(skb, IPSET_ATTR_REVISION,
1068                                    set->type->revision);
1069                         ret = set->variant->head(set, skb);
1070                         if (ret < 0)
1071                                 goto release_refcount;
1072                         /* Fall through and add elements */
1073                 default:
1074                         read_lock_bh(&set->lock);
1075                         ret = set->variant->list(set, skb, cb);
1076                         read_unlock_bh(&set->lock);
1077                         if (!cb->args[2]) {
1078                                 /* Set is done, proceed with next one */
1079                                 if (cb->args[0] == DUMP_ONE)
1080                                         cb->args[1] = IPSET_INVALID_ID;
1081                                 else
1082                                         cb->args[1]++;
1083                         }
1084                         goto release_refcount;
1085                 }
1086         }
1087         /* If we dump all sets, continue with dumping last ones */
1088         if (cb->args[0] == DUMP_ALL) {
1089                 cb->args[0] = DUMP_LAST;
1090                 cb->args[1] = 0;
1091                 goto dump_last;
1092         }
1093         goto out;
1094
1095 nla_put_failure:
1096         ret = -EFAULT;
1097 release_refcount:
1098         /* If there was an error or set is done, release set */
1099         if (ret || !cb->args[2]) {
1100                 pr_debug("release set %s\n", ip_set_list[index]->name);
1101                 ip_set_put_byindex(index);
1102         }
1103 out:
1104         if (nlh) {
1105                 nlmsg_end(skb, nlh);
1106                 pr_debug("nlmsg_len: %u\n", nlh->nlmsg_len);
1107                 dump_attrs(nlh);
1108         }
1109
1110         return ret < 0 ? ret : skb->len;
1111 }
1112
1113 static int
1114 ip_set_dump(struct sock *ctnl, struct sk_buff *skb,
1115             const struct nlmsghdr *nlh,
1116             const struct nlattr * const attr[])
1117 {
1118         if (unlikely(protocol_failed(attr)))
1119                 return -IPSET_ERR_PROTOCOL;
1120
1121         return netlink_dump_start(ctnl, skb, nlh,
1122                                   ip_set_dump_start,
1123                                   ip_set_dump_done);
1124 }
1125
1126 /* Add, del and test */
1127
1128 static const struct nla_policy ip_set_adt_policy[IPSET_ATTR_CMD_MAX + 1] = {
1129         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1130         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1131                                     .len = IPSET_MAXNAMELEN - 1 },
1132         [IPSET_ATTR_LINENO]     = { .type = NLA_U32 },
1133         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
1134         [IPSET_ATTR_ADT]        = { .type = NLA_NESTED },
1135 };
1136
1137 static int
1138 call_ad(struct sock *ctnl, struct sk_buff *skb, struct ip_set *set,
1139         struct nlattr *tb[], enum ipset_adt adt,
1140         u32 flags, bool use_lineno)
1141 {
1142         int ret, retried = 0;
1143         u32 lineno = 0;
1144         bool eexist = flags & IPSET_FLAG_EXIST;
1145
1146         do {
1147                 write_lock_bh(&set->lock);
1148                 ret = set->variant->uadt(set, tb, adt, &lineno, flags);
1149                 write_unlock_bh(&set->lock);
1150         } while (ret == -EAGAIN &&
1151                  set->variant->resize &&
1152                  (ret = set->variant->resize(set, retried++)) == 0);
1153
1154         if (!ret || (ret == -IPSET_ERR_EXIST && eexist))
1155                 return 0;
1156         if (lineno && use_lineno) {
1157                 /* Error in restore/batch mode: send back lineno */
1158                 struct nlmsghdr *rep, *nlh = nlmsg_hdr(skb);
1159                 struct sk_buff *skb2;
1160                 struct nlmsgerr *errmsg;
1161                 size_t payload = sizeof(*errmsg) + nlmsg_len(nlh);
1162                 int min_len = NLMSG_SPACE(sizeof(struct nfgenmsg));
1163                 struct nlattr *cda[IPSET_ATTR_CMD_MAX+1];
1164                 struct nlattr *cmdattr;
1165                 u32 *errline;
1166
1167                 skb2 = nlmsg_new(payload, GFP_KERNEL);
1168                 if (skb2 == NULL)
1169                         return -ENOMEM;
1170                 rep = __nlmsg_put(skb2, NETLINK_CB(skb).pid,
1171                                   nlh->nlmsg_seq, NLMSG_ERROR, payload, 0);
1172                 errmsg = nlmsg_data(rep);
1173                 errmsg->error = ret;
1174                 memcpy(&errmsg->msg, nlh, nlh->nlmsg_len);
1175                 cmdattr = (void *)&errmsg->msg + min_len;
1176
1177                 nla_parse(cda, IPSET_ATTR_CMD_MAX,
1178                           cmdattr, nlh->nlmsg_len - min_len,
1179                           ip_set_adt_policy);
1180
1181                 errline = nla_data(cda[IPSET_ATTR_LINENO]);
1182
1183                 *errline = lineno;
1184
1185                 netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1186                 /* Signal netlink not to send its ACK/errmsg.  */
1187                 return -EINTR;
1188         }
1189
1190         return ret;
1191 }
1192
1193 static int
1194 ip_set_uadd(struct sock *ctnl, struct sk_buff *skb,
1195             const struct nlmsghdr *nlh,
1196             const struct nlattr * const attr[])
1197 {
1198         struct ip_set *set;
1199         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1200         const struct nlattr *nla;
1201         u32 flags = flag_exist(nlh);
1202         bool use_lineno;
1203         int ret = 0;
1204
1205         if (unlikely(protocol_failed(attr) ||
1206                      attr[IPSET_ATTR_SETNAME] == NULL ||
1207                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1208                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1209                      (attr[IPSET_ATTR_DATA] != NULL &&
1210                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1211                      (attr[IPSET_ATTR_ADT] != NULL &&
1212                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1213                        attr[IPSET_ATTR_LINENO] == NULL))))
1214                 return -IPSET_ERR_PROTOCOL;
1215
1216         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1217         if (set == NULL)
1218                 return -ENOENT;
1219
1220         use_lineno = !!attr[IPSET_ATTR_LINENO];
1221         if (attr[IPSET_ATTR_DATA]) {
1222                 if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1223                                      attr[IPSET_ATTR_DATA],
1224                                      set->type->adt_policy))
1225                         return -IPSET_ERR_PROTOCOL;
1226                 ret = call_ad(ctnl, skb, set, tb, IPSET_ADD, flags,
1227                               use_lineno);
1228         } else {
1229                 int nla_rem;
1230
1231                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1232                         memset(tb, 0, sizeof(tb));
1233                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1234                             !flag_nested(nla) ||
1235                             nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1236                                              set->type->adt_policy))
1237                                 return -IPSET_ERR_PROTOCOL;
1238                         ret = call_ad(ctnl, skb, set, tb, IPSET_ADD,
1239                                       flags, use_lineno);
1240                         if (ret < 0)
1241                                 return ret;
1242                 }
1243         }
1244         return ret;
1245 }
1246
1247 static int
1248 ip_set_udel(struct sock *ctnl, struct sk_buff *skb,
1249             const struct nlmsghdr *nlh,
1250             const struct nlattr * const attr[])
1251 {
1252         struct ip_set *set;
1253         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1254         const struct nlattr *nla;
1255         u32 flags = flag_exist(nlh);
1256         bool use_lineno;
1257         int ret = 0;
1258
1259         if (unlikely(protocol_failed(attr) ||
1260                      attr[IPSET_ATTR_SETNAME] == NULL ||
1261                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1262                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1263                      (attr[IPSET_ATTR_DATA] != NULL &&
1264                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1265                      (attr[IPSET_ATTR_ADT] != NULL &&
1266                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1267                        attr[IPSET_ATTR_LINENO] == NULL))))
1268                 return -IPSET_ERR_PROTOCOL;
1269
1270         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1271         if (set == NULL)
1272                 return -ENOENT;
1273
1274         use_lineno = !!attr[IPSET_ATTR_LINENO];
1275         if (attr[IPSET_ATTR_DATA]) {
1276                 if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1277                                      attr[IPSET_ATTR_DATA],
1278                                      set->type->adt_policy))
1279                         return -IPSET_ERR_PROTOCOL;
1280                 ret = call_ad(ctnl, skb, set, tb, IPSET_DEL, flags,
1281                               use_lineno);
1282         } else {
1283                 int nla_rem;
1284
1285                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1286                         memset(tb, 0, sizeof(*tb));
1287                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1288                             !flag_nested(nla) ||
1289                             nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1290                                              set->type->adt_policy))
1291                                 return -IPSET_ERR_PROTOCOL;
1292                         ret = call_ad(ctnl, skb, set, tb, IPSET_DEL,
1293                                       flags, use_lineno);
1294                         if (ret < 0)
1295                                 return ret;
1296                 }
1297         }
1298         return ret;
1299 }
1300
1301 static int
1302 ip_set_utest(struct sock *ctnl, struct sk_buff *skb,
1303              const struct nlmsghdr *nlh,
1304              const struct nlattr * const attr[])
1305 {
1306         struct ip_set *set;
1307         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1308         int ret = 0;
1309
1310         if (unlikely(protocol_failed(attr) ||
1311                      attr[IPSET_ATTR_SETNAME] == NULL ||
1312                      attr[IPSET_ATTR_DATA] == NULL ||
1313                      !flag_nested(attr[IPSET_ATTR_DATA])))
1314                 return -IPSET_ERR_PROTOCOL;
1315
1316         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1317         if (set == NULL)
1318                 return -ENOENT;
1319
1320         if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA],
1321                              set->type->adt_policy))
1322                 return -IPSET_ERR_PROTOCOL;
1323
1324         read_lock_bh(&set->lock);
1325         ret = set->variant->uadt(set, tb, IPSET_TEST, NULL, 0);
1326         read_unlock_bh(&set->lock);
1327         /* Userspace can't trigger element to be re-added */
1328         if (ret == -EAGAIN)
1329                 ret = 1;
1330
1331         return ret < 0 ? ret : ret > 0 ? 0 : -IPSET_ERR_EXIST;
1332 }
1333
1334 /* Get headed data of a set */
1335
1336 static int
1337 ip_set_header(struct sock *ctnl, struct sk_buff *skb,
1338               const struct nlmsghdr *nlh,
1339               const struct nlattr * const attr[])
1340 {
1341         const struct ip_set *set;
1342         struct sk_buff *skb2;
1343         struct nlmsghdr *nlh2;
1344         ip_set_id_t index;
1345         int ret = 0;
1346
1347         if (unlikely(protocol_failed(attr) ||
1348                      attr[IPSET_ATTR_SETNAME] == NULL))
1349                 return -IPSET_ERR_PROTOCOL;
1350
1351         index = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
1352         if (index == IPSET_INVALID_ID)
1353                 return -ENOENT;
1354         set = ip_set_list[index];
1355
1356         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1357         if (skb2 == NULL)
1358                 return -ENOMEM;
1359
1360         nlh2 = start_msg(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, 0,
1361                          IPSET_CMD_HEADER);
1362         if (!nlh2)
1363                 goto nlmsg_failure;
1364         NLA_PUT_U8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1365         NLA_PUT_STRING(skb2, IPSET_ATTR_SETNAME, set->name);
1366         NLA_PUT_STRING(skb2, IPSET_ATTR_TYPENAME, set->type->name);
1367         NLA_PUT_U8(skb2, IPSET_ATTR_FAMILY, set->family);
1368         NLA_PUT_U8(skb2, IPSET_ATTR_REVISION, set->type->revision);
1369         nlmsg_end(skb2, nlh2);
1370
1371         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1372         if (ret < 0)
1373                 return ret;
1374
1375         return 0;
1376
1377 nla_put_failure:
1378         nlmsg_cancel(skb2, nlh2);
1379 nlmsg_failure:
1380         kfree_skb(skb2);
1381         return -EMSGSIZE;
1382 }
1383
1384 /* Get type data */
1385
1386 static const struct nla_policy ip_set_type_policy[IPSET_ATTR_CMD_MAX + 1] = {
1387         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1388         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
1389                                     .len = IPSET_MAXNAMELEN - 1 },
1390         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
1391 };
1392
1393 static int
1394 ip_set_type(struct sock *ctnl, struct sk_buff *skb,
1395             const struct nlmsghdr *nlh,
1396             const struct nlattr * const attr[])
1397 {
1398         struct sk_buff *skb2;
1399         struct nlmsghdr *nlh2;
1400         u8 family, min, max;
1401         const char *typename;
1402         int ret = 0;
1403
1404         if (unlikely(protocol_failed(attr) ||
1405                      attr[IPSET_ATTR_TYPENAME] == NULL ||
1406                      attr[IPSET_ATTR_FAMILY] == NULL))
1407                 return -IPSET_ERR_PROTOCOL;
1408
1409         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
1410         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
1411         ret = find_set_type_minmax(typename, family, &min, &max);
1412         if (ret)
1413                 return ret;
1414
1415         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1416         if (skb2 == NULL)
1417                 return -ENOMEM;
1418
1419         nlh2 = start_msg(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, 0,
1420                          IPSET_CMD_TYPE);
1421         if (!nlh2)
1422                 goto nlmsg_failure;
1423         NLA_PUT_U8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1424         NLA_PUT_STRING(skb2, IPSET_ATTR_TYPENAME, typename);
1425         NLA_PUT_U8(skb2, IPSET_ATTR_FAMILY, family);
1426         NLA_PUT_U8(skb2, IPSET_ATTR_REVISION, max);
1427         NLA_PUT_U8(skb2, IPSET_ATTR_REVISION_MIN, min);
1428         nlmsg_end(skb2, nlh2);
1429
1430         pr_debug("Send TYPE, nlmsg_len: %u\n", nlh2->nlmsg_len);
1431         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1432         if (ret < 0)
1433                 return ret;
1434
1435         return 0;
1436
1437 nla_put_failure:
1438         nlmsg_cancel(skb2, nlh2);
1439 nlmsg_failure:
1440         kfree_skb(skb2);
1441         return -EMSGSIZE;
1442 }
1443
1444 /* Get protocol version */
1445
1446 static const struct nla_policy
1447 ip_set_protocol_policy[IPSET_ATTR_CMD_MAX + 1] = {
1448         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1449 };
1450
1451 static int
1452 ip_set_protocol(struct sock *ctnl, struct sk_buff *skb,
1453                 const struct nlmsghdr *nlh,
1454                 const struct nlattr * const attr[])
1455 {
1456         struct sk_buff *skb2;
1457         struct nlmsghdr *nlh2;
1458         int ret = 0;
1459
1460         if (unlikely(attr[IPSET_ATTR_PROTOCOL] == NULL))
1461                 return -IPSET_ERR_PROTOCOL;
1462
1463         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1464         if (skb2 == NULL)
1465                 return -ENOMEM;
1466
1467         nlh2 = start_msg(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, 0,
1468                          IPSET_CMD_PROTOCOL);
1469         if (!nlh2)
1470                 goto nlmsg_failure;
1471         NLA_PUT_U8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1472         nlmsg_end(skb2, nlh2);
1473
1474         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1475         if (ret < 0)
1476                 return ret;
1477
1478         return 0;
1479
1480 nla_put_failure:
1481         nlmsg_cancel(skb2, nlh2);
1482 nlmsg_failure:
1483         kfree_skb(skb2);
1484         return -EMSGSIZE;
1485 }
1486
1487 static const struct nfnl_callback ip_set_netlink_subsys_cb[IPSET_MSG_MAX] = {
1488         [IPSET_CMD_CREATE]      = {
1489                 .call           = ip_set_create,
1490                 .attr_count     = IPSET_ATTR_CMD_MAX,
1491                 .policy         = ip_set_create_policy,
1492         },
1493         [IPSET_CMD_DESTROY]     = {
1494                 .call           = ip_set_destroy,
1495                 .attr_count     = IPSET_ATTR_CMD_MAX,
1496                 .policy         = ip_set_setname_policy,
1497         },
1498         [IPSET_CMD_FLUSH]       = {
1499                 .call           = ip_set_flush,
1500                 .attr_count     = IPSET_ATTR_CMD_MAX,
1501                 .policy         = ip_set_setname_policy,
1502         },
1503         [IPSET_CMD_RENAME]      = {
1504                 .call           = ip_set_rename,
1505                 .attr_count     = IPSET_ATTR_CMD_MAX,
1506                 .policy         = ip_set_setname2_policy,
1507         },
1508         [IPSET_CMD_SWAP]        = {
1509                 .call           = ip_set_swap,
1510                 .attr_count     = IPSET_ATTR_CMD_MAX,
1511                 .policy         = ip_set_setname2_policy,
1512         },
1513         [IPSET_CMD_LIST]        = {
1514                 .call           = ip_set_dump,
1515                 .attr_count     = IPSET_ATTR_CMD_MAX,
1516                 .policy         = ip_set_setname_policy,
1517         },
1518         [IPSET_CMD_SAVE]        = {
1519                 .call           = ip_set_dump,
1520                 .attr_count     = IPSET_ATTR_CMD_MAX,
1521                 .policy         = ip_set_setname_policy,
1522         },
1523         [IPSET_CMD_ADD] = {
1524                 .call           = ip_set_uadd,
1525                 .attr_count     = IPSET_ATTR_CMD_MAX,
1526                 .policy         = ip_set_adt_policy,
1527         },
1528         [IPSET_CMD_DEL] = {
1529                 .call           = ip_set_udel,
1530                 .attr_count     = IPSET_ATTR_CMD_MAX,
1531                 .policy         = ip_set_adt_policy,
1532         },
1533         [IPSET_CMD_TEST]        = {
1534                 .call           = ip_set_utest,
1535                 .attr_count     = IPSET_ATTR_CMD_MAX,
1536                 .policy         = ip_set_adt_policy,
1537         },
1538         [IPSET_CMD_HEADER]      = {
1539                 .call           = ip_set_header,
1540                 .attr_count     = IPSET_ATTR_CMD_MAX,
1541                 .policy         = ip_set_setname_policy,
1542         },
1543         [IPSET_CMD_TYPE]        = {
1544                 .call           = ip_set_type,
1545                 .attr_count     = IPSET_ATTR_CMD_MAX,
1546                 .policy         = ip_set_type_policy,
1547         },
1548         [IPSET_CMD_PROTOCOL]    = {
1549                 .call           = ip_set_protocol,
1550                 .attr_count     = IPSET_ATTR_CMD_MAX,
1551                 .policy         = ip_set_protocol_policy,
1552         },
1553 };
1554
1555 static struct nfnetlink_subsystem ip_set_netlink_subsys __read_mostly = {
1556         .name           = "ip_set",
1557         .subsys_id      = NFNL_SUBSYS_IPSET,
1558         .cb_count       = IPSET_MSG_MAX,
1559         .cb             = ip_set_netlink_subsys_cb,
1560 };
1561
1562 /* Interface to iptables/ip6tables */
1563
1564 static int
1565 ip_set_sockfn_get(struct sock *sk, int optval, void __user *user, int *len)
1566 {
1567         unsigned *op;
1568         void *data;
1569         int copylen = *len, ret = 0;
1570
1571         if (!capable(CAP_NET_ADMIN))
1572                 return -EPERM;
1573         if (optval != SO_IP_SET)
1574                 return -EBADF;
1575         if (*len < sizeof(unsigned))
1576                 return -EINVAL;
1577
1578         data = vmalloc(*len);
1579         if (!data)
1580                 return -ENOMEM;
1581         if (copy_from_user(data, user, *len) != 0) {
1582                 ret = -EFAULT;
1583                 goto done;
1584         }
1585         op = (unsigned *) data;
1586
1587         if (*op < IP_SET_OP_VERSION) {
1588                 /* Check the version at the beginning of operations */
1589                 struct ip_set_req_version *req_version = data;
1590                 if (req_version->version != IPSET_PROTOCOL) {
1591                         ret = -EPROTO;
1592                         goto done;
1593                 }
1594         }
1595
1596         switch (*op) {
1597         case IP_SET_OP_VERSION: {
1598                 struct ip_set_req_version *req_version = data;
1599
1600                 if (*len != sizeof(struct ip_set_req_version)) {
1601                         ret = -EINVAL;
1602                         goto done;
1603                 }
1604
1605                 req_version->version = IPSET_PROTOCOL;
1606                 ret = copy_to_user(user, req_version,
1607                                    sizeof(struct ip_set_req_version));
1608                 goto done;
1609         }
1610         case IP_SET_OP_GET_BYNAME: {
1611                 struct ip_set_req_get_set *req_get = data;
1612
1613                 if (*len != sizeof(struct ip_set_req_get_set)) {
1614                         ret = -EINVAL;
1615                         goto done;
1616                 }
1617                 req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
1618                 nfnl_lock();
1619                 req_get->set.index = find_set_id(req_get->set.name);
1620                 nfnl_unlock();
1621                 goto copy;
1622         }
1623         case IP_SET_OP_GET_BYINDEX: {
1624                 struct ip_set_req_get_set *req_get = data;
1625
1626                 if (*len != sizeof(struct ip_set_req_get_set) ||
1627                     req_get->set.index >= ip_set_max) {
1628                         ret = -EINVAL;
1629                         goto done;
1630                 }
1631                 nfnl_lock();
1632                 strncpy(req_get->set.name,
1633                         ip_set_list[req_get->set.index]
1634                                 ? ip_set_list[req_get->set.index]->name : "",
1635                         IPSET_MAXNAMELEN);
1636                 nfnl_unlock();
1637                 goto copy;
1638         }
1639         default:
1640                 ret = -EBADMSG;
1641                 goto done;
1642         }       /* end of switch(op) */
1643
1644 copy:
1645         ret = copy_to_user(user, data, copylen);
1646
1647 done:
1648         vfree(data);
1649         if (ret > 0)
1650                 ret = 0;
1651         return ret;
1652 }
1653
1654 static struct nf_sockopt_ops so_set __read_mostly = {
1655         .pf             = PF_INET,
1656         .get_optmin     = SO_IP_SET,
1657         .get_optmax     = SO_IP_SET + 1,
1658         .get            = &ip_set_sockfn_get,
1659         .owner          = THIS_MODULE,
1660 };
1661
1662 static int __init
1663 ip_set_init(void)
1664 {
1665         int ret;
1666
1667         if (max_sets)
1668                 ip_set_max = max_sets;
1669         if (ip_set_max >= IPSET_INVALID_ID)
1670                 ip_set_max = IPSET_INVALID_ID - 1;
1671
1672         ip_set_list = kzalloc(sizeof(struct ip_set *) * ip_set_max,
1673                               GFP_KERNEL);
1674         if (!ip_set_list) {
1675                 pr_err("ip_set: Unable to create ip_set_list\n");
1676                 return -ENOMEM;
1677         }
1678
1679         ret = nfnetlink_subsys_register(&ip_set_netlink_subsys);
1680         if (ret != 0) {
1681                 pr_err("ip_set: cannot register with nfnetlink.\n");
1682                 kfree(ip_set_list);
1683                 return ret;
1684         }
1685         ret = nf_register_sockopt(&so_set);
1686         if (ret != 0) {
1687                 pr_err("SO_SET registry failed: %d\n", ret);
1688                 nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
1689                 kfree(ip_set_list);
1690                 return ret;
1691         }
1692
1693         pr_notice("ip_set: protocol %u\n", IPSET_PROTOCOL);
1694         return 0;
1695 }
1696
1697 static void __exit
1698 ip_set_fini(void)
1699 {
1700         /* There can't be any existing set */
1701         nf_unregister_sockopt(&so_set);
1702         nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
1703         kfree(ip_set_list);
1704         pr_debug("these are the famous last words\n");
1705 }
1706
1707 module_init(ip_set_init);
1708 module_exit(ip_set_fini);