netfilter: ipset: small potential read beyond the end of buffer
[pandora-kernel.git] / net / netfilter / ipset / ip_set_core.c
1 /* Copyright (C) 2000-2002 Joakim Axelsson <gozem@linux.nu>
2  *                         Patrick Schaaf <bof@bof.de>
3  * Copyright (C) 2003-2011 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation.
8  */
9
10 /* Kernel module for IP set management */
11
12 #include <linux/init.h>
13 #include <linux/module.h>
14 #include <linux/moduleparam.h>
15 #include <linux/ip.h>
16 #include <linux/skbuff.h>
17 #include <linux/spinlock.h>
18 #include <linux/netlink.h>
19 #include <linux/rculist.h>
20 #include <net/netlink.h>
21
22 #include <linux/netfilter.h>
23 #include <linux/netfilter/x_tables.h>
24 #include <linux/netfilter/nfnetlink.h>
25 #include <linux/netfilter/ipset/ip_set.h>
26
27 static LIST_HEAD(ip_set_type_list);             /* all registered set types */
28 static DEFINE_MUTEX(ip_set_type_mutex);         /* protects ip_set_type_list */
29 static DEFINE_RWLOCK(ip_set_ref_lock);          /* protects the set refs */
30
31 static struct ip_set **ip_set_list;             /* all individual sets */
32 static ip_set_id_t ip_set_max = CONFIG_IP_SET_MAX; /* max number of sets */
33
34 #define STREQ(a, b)     (strncmp(a, b, IPSET_MAXNAMELEN) == 0)
35
36 static unsigned int max_sets;
37
38 module_param(max_sets, int, 0600);
39 MODULE_PARM_DESC(max_sets, "maximal number of sets");
40 MODULE_LICENSE("GPL");
41 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
42 MODULE_DESCRIPTION("core IP set support");
43 MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_IPSET);
44
45 /*
46  * The set types are implemented in modules and registered set types
47  * can be found in ip_set_type_list. Adding/deleting types is
48  * serialized by ip_set_type_mutex.
49  */
50
51 static inline void
52 ip_set_type_lock(void)
53 {
54         mutex_lock(&ip_set_type_mutex);
55 }
56
57 static inline void
58 ip_set_type_unlock(void)
59 {
60         mutex_unlock(&ip_set_type_mutex);
61 }
62
63 /* Register and deregister settype */
64
65 static struct ip_set_type *
66 find_set_type(const char *name, u8 family, u8 revision)
67 {
68         struct ip_set_type *type;
69
70         list_for_each_entry_rcu(type, &ip_set_type_list, list)
71                 if (STREQ(type->name, name) &&
72                     (type->family == family || type->family == AF_UNSPEC) &&
73                     revision >= type->revision_min &&
74                     revision <= type->revision_max)
75                         return type;
76         return NULL;
77 }
78
79 /* Unlock, try to load a set type module and lock again */
80 static int
81 try_to_load_type(const char *name)
82 {
83         nfnl_unlock();
84         pr_debug("try to load ip_set_%s\n", name);
85         if (request_module("ip_set_%s", name) < 0) {
86                 pr_warning("Can't find ip_set type %s\n", name);
87                 nfnl_lock();
88                 return -IPSET_ERR_FIND_TYPE;
89         }
90         nfnl_lock();
91         return -EAGAIN;
92 }
93
94 /* Find a set type and reference it */
95 static int
96 find_set_type_get(const char *name, u8 family, u8 revision,
97                   struct ip_set_type **found)
98 {
99         struct ip_set_type *type;
100         int err;
101
102         rcu_read_lock();
103         *found = find_set_type(name, family, revision);
104         if (*found) {
105                 err = !try_module_get((*found)->me) ? -EFAULT : 0;
106                 goto unlock;
107         }
108         /* Make sure the type is loaded but we don't support the revision */
109         list_for_each_entry_rcu(type, &ip_set_type_list, list)
110                 if (STREQ(type->name, name)) {
111                         err = -IPSET_ERR_FIND_TYPE;
112                         goto unlock;
113                 }
114         rcu_read_unlock();
115
116         return try_to_load_type(name);
117
118 unlock:
119         rcu_read_unlock();
120         return err;
121 }
122
123 /* Find a given set type by name and family.
124  * If we succeeded, the supported minimal and maximum revisions are
125  * filled out.
126  */
127 static int
128 find_set_type_minmax(const char *name, u8 family, u8 *min, u8 *max)
129 {
130         struct ip_set_type *type;
131         bool found = false;
132
133         *min = 255; *max = 0;
134         rcu_read_lock();
135         list_for_each_entry_rcu(type, &ip_set_type_list, list)
136                 if (STREQ(type->name, name) &&
137                     (type->family == family || type->family == AF_UNSPEC)) {
138                         found = true;
139                         if (type->revision_min < *min)
140                                 *min = type->revision_min;
141                         if (type->revision_max > *max)
142                                 *max = type->revision_max;
143                 }
144         rcu_read_unlock();
145         if (found)
146                 return 0;
147
148         return try_to_load_type(name);
149 }
150
151 #define family_name(f)  ((f) == AF_INET ? "inet" : \
152                          (f) == AF_INET6 ? "inet6" : "any")
153
154 /* Register a set type structure. The type is identified by
155  * the unique triple of name, family and revision.
156  */
157 int
158 ip_set_type_register(struct ip_set_type *type)
159 {
160         int ret = 0;
161
162         if (type->protocol != IPSET_PROTOCOL) {
163                 pr_warning("ip_set type %s, family %s, revision %u:%u uses "
164                            "wrong protocol version %u (want %u)\n",
165                            type->name, family_name(type->family),
166                            type->revision_min, type->revision_max,
167                            type->protocol, IPSET_PROTOCOL);
168                 return -EINVAL;
169         }
170
171         ip_set_type_lock();
172         if (find_set_type(type->name, type->family, type->revision_min)) {
173                 /* Duplicate! */
174                 pr_warning("ip_set type %s, family %s with revision min %u "
175                            "already registered!\n", type->name,
176                            family_name(type->family), type->revision_min);
177                 ret = -EINVAL;
178                 goto unlock;
179         }
180         list_add_rcu(&type->list, &ip_set_type_list);
181         pr_debug("type %s, family %s, revision %u:%u registered.\n",
182                  type->name, family_name(type->family),
183                  type->revision_min, type->revision_max);
184 unlock:
185         ip_set_type_unlock();
186         return ret;
187 }
188 EXPORT_SYMBOL_GPL(ip_set_type_register);
189
190 /* Unregister a set type. There's a small race with ip_set_create */
191 void
192 ip_set_type_unregister(struct ip_set_type *type)
193 {
194         ip_set_type_lock();
195         if (!find_set_type(type->name, type->family, type->revision_min)) {
196                 pr_warning("ip_set type %s, family %s with revision min %u "
197                            "not registered\n", type->name,
198                            family_name(type->family), type->revision_min);
199                 goto unlock;
200         }
201         list_del_rcu(&type->list);
202         pr_debug("type %s, family %s with revision min %u unregistered.\n",
203                  type->name, family_name(type->family), type->revision_min);
204 unlock:
205         ip_set_type_unlock();
206
207         synchronize_rcu();
208 }
209 EXPORT_SYMBOL_GPL(ip_set_type_unregister);
210
211 /* Utility functions */
212 void *
213 ip_set_alloc(size_t size)
214 {
215         void *members = NULL;
216
217         if (size < KMALLOC_MAX_SIZE)
218                 members = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
219
220         if (members) {
221                 pr_debug("%p: allocated with kmalloc\n", members);
222                 return members;
223         }
224
225         members = vzalloc(size);
226         if (!members)
227                 return NULL;
228         pr_debug("%p: allocated with vmalloc\n", members);
229
230         return members;
231 }
232 EXPORT_SYMBOL_GPL(ip_set_alloc);
233
234 void
235 ip_set_free(void *members)
236 {
237         pr_debug("%p: free with %s\n", members,
238                  is_vmalloc_addr(members) ? "vfree" : "kfree");
239         if (is_vmalloc_addr(members))
240                 vfree(members);
241         else
242                 kfree(members);
243 }
244 EXPORT_SYMBOL_GPL(ip_set_free);
245
246 static inline bool
247 flag_nested(const struct nlattr *nla)
248 {
249         return nla->nla_type & NLA_F_NESTED;
250 }
251
252 static const struct nla_policy ipaddr_policy[IPSET_ATTR_IPADDR_MAX + 1] = {
253         [IPSET_ATTR_IPADDR_IPV4]        = { .type = NLA_U32 },
254         [IPSET_ATTR_IPADDR_IPV6]        = { .type = NLA_BINARY,
255                                             .len = sizeof(struct in6_addr) },
256 };
257
258 int
259 ip_set_get_ipaddr4(struct nlattr *nla,  __be32 *ipaddr)
260 {
261         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX+1];
262
263         if (unlikely(!flag_nested(nla)))
264                 return -IPSET_ERR_PROTOCOL;
265         if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy))
266                 return -IPSET_ERR_PROTOCOL;
267         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV4)))
268                 return -IPSET_ERR_PROTOCOL;
269
270         *ipaddr = nla_get_be32(tb[IPSET_ATTR_IPADDR_IPV4]);
271         return 0;
272 }
273 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr4);
274
275 int
276 ip_set_get_ipaddr6(struct nlattr *nla, union nf_inet_addr *ipaddr)
277 {
278         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX+1];
279
280         if (unlikely(!flag_nested(nla)))
281                 return -IPSET_ERR_PROTOCOL;
282
283         if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy))
284                 return -IPSET_ERR_PROTOCOL;
285         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV6)))
286                 return -IPSET_ERR_PROTOCOL;
287
288         memcpy(ipaddr, nla_data(tb[IPSET_ATTR_IPADDR_IPV6]),
289                 sizeof(struct in6_addr));
290         return 0;
291 }
292 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr6);
293
294 /*
295  * Creating/destroying/renaming/swapping affect the existence and
296  * the properties of a set. All of these can be executed from userspace
297  * only and serialized by the nfnl mutex indirectly from nfnetlink.
298  *
299  * Sets are identified by their index in ip_set_list and the index
300  * is used by the external references (set/SET netfilter modules).
301  *
302  * The set behind an index may change by swapping only, from userspace.
303  */
304
305 static inline void
306 __ip_set_get(ip_set_id_t index)
307 {
308         write_lock_bh(&ip_set_ref_lock);
309         ip_set_list[index]->ref++;
310         write_unlock_bh(&ip_set_ref_lock);
311 }
312
313 static inline void
314 __ip_set_put(ip_set_id_t index)
315 {
316         write_lock_bh(&ip_set_ref_lock);
317         BUG_ON(ip_set_list[index]->ref == 0);
318         ip_set_list[index]->ref--;
319         write_unlock_bh(&ip_set_ref_lock);
320 }
321
322 /*
323  * Add, del and test set entries from kernel.
324  *
325  * The set behind the index must exist and must be referenced
326  * so it can't be destroyed (or changed) under our foot.
327  */
328
329 int
330 ip_set_test(ip_set_id_t index, const struct sk_buff *skb,
331             const struct xt_action_param *par,
332             const struct ip_set_adt_opt *opt)
333 {
334         struct ip_set *set = ip_set_list[index];
335         int ret = 0;
336
337         BUG_ON(set == NULL);
338         pr_debug("set %s, index %u\n", set->name, index);
339
340         if (opt->dim < set->type->dimension ||
341             !(opt->family == set->family || set->family == AF_UNSPEC))
342                 return 0;
343
344         read_lock_bh(&set->lock);
345         ret = set->variant->kadt(set, skb, par, IPSET_TEST, opt);
346         read_unlock_bh(&set->lock);
347
348         if (ret == -EAGAIN) {
349                 /* Type requests element to be completed */
350                 pr_debug("element must be competed, ADD is triggered\n");
351                 write_lock_bh(&set->lock);
352                 set->variant->kadt(set, skb, par, IPSET_ADD, opt);
353                 write_unlock_bh(&set->lock);
354                 ret = 1;
355         }
356
357         /* Convert error codes to nomatch */
358         return (ret < 0 ? 0 : ret);
359 }
360 EXPORT_SYMBOL_GPL(ip_set_test);
361
362 int
363 ip_set_add(ip_set_id_t index, const struct sk_buff *skb,
364            const struct xt_action_param *par,
365            const struct ip_set_adt_opt *opt)
366 {
367         struct ip_set *set = ip_set_list[index];
368         int ret;
369
370         BUG_ON(set == NULL);
371         pr_debug("set %s, index %u\n", set->name, index);
372
373         if (opt->dim < set->type->dimension ||
374             !(opt->family == set->family || set->family == AF_UNSPEC))
375                 return 0;
376
377         write_lock_bh(&set->lock);
378         ret = set->variant->kadt(set, skb, par, IPSET_ADD, opt);
379         write_unlock_bh(&set->lock);
380
381         return ret;
382 }
383 EXPORT_SYMBOL_GPL(ip_set_add);
384
385 int
386 ip_set_del(ip_set_id_t index, const struct sk_buff *skb,
387            const struct xt_action_param *par,
388            const struct ip_set_adt_opt *opt)
389 {
390         struct ip_set *set = ip_set_list[index];
391         int ret = 0;
392
393         BUG_ON(set == NULL);
394         pr_debug("set %s, index %u\n", set->name, index);
395
396         if (opt->dim < set->type->dimension ||
397             !(opt->family == set->family || set->family == AF_UNSPEC))
398                 return 0;
399
400         write_lock_bh(&set->lock);
401         ret = set->variant->kadt(set, skb, par, IPSET_DEL, opt);
402         write_unlock_bh(&set->lock);
403
404         return ret;
405 }
406 EXPORT_SYMBOL_GPL(ip_set_del);
407
408 /*
409  * Find set by name, reference it once. The reference makes sure the
410  * thing pointed to, does not go away under our feet.
411  *
412  */
413 ip_set_id_t
414 ip_set_get_byname(const char *name, struct ip_set **set)
415 {
416         ip_set_id_t i, index = IPSET_INVALID_ID;
417         struct ip_set *s;
418
419         for (i = 0; i < ip_set_max; i++) {
420                 s = ip_set_list[i];
421                 if (s != NULL && STREQ(s->name, name)) {
422                         __ip_set_get(i);
423                         index = i;
424                         *set = s;
425                 }
426         }
427
428         return index;
429 }
430 EXPORT_SYMBOL_GPL(ip_set_get_byname);
431
432 /*
433  * If the given set pointer points to a valid set, decrement
434  * reference count by 1. The caller shall not assume the index
435  * to be valid, after calling this function.
436  *
437  */
438 void
439 ip_set_put_byindex(ip_set_id_t index)
440 {
441         if (ip_set_list[index] != NULL)
442                 __ip_set_put(index);
443 }
444 EXPORT_SYMBOL_GPL(ip_set_put_byindex);
445
446 /*
447  * Get the name of a set behind a set index.
448  * We assume the set is referenced, so it does exist and
449  * can't be destroyed. The set cannot be renamed due to
450  * the referencing either.
451  *
452  */
453 const char *
454 ip_set_name_byindex(ip_set_id_t index)
455 {
456         const struct ip_set *set = ip_set_list[index];
457
458         BUG_ON(set == NULL);
459         BUG_ON(set->ref == 0);
460
461         /* Referenced, so it's safe */
462         return set->name;
463 }
464 EXPORT_SYMBOL_GPL(ip_set_name_byindex);
465
466 /*
467  * Routines to call by external subsystems, which do not
468  * call nfnl_lock for us.
469  */
470
471 /*
472  * Find set by name, reference it once. The reference makes sure the
473  * thing pointed to, does not go away under our feet.
474  *
475  * The nfnl mutex is used in the function.
476  */
477 ip_set_id_t
478 ip_set_nfnl_get(const char *name)
479 {
480         struct ip_set *s;
481         ip_set_id_t index;
482
483         nfnl_lock();
484         index = ip_set_get_byname(name, &s);
485         nfnl_unlock();
486
487         return index;
488 }
489 EXPORT_SYMBOL_GPL(ip_set_nfnl_get);
490
491 /*
492  * Find set by index, reference it once. The reference makes sure the
493  * thing pointed to, does not go away under our feet.
494  *
495  * The nfnl mutex is used in the function.
496  */
497 ip_set_id_t
498 ip_set_nfnl_get_byindex(ip_set_id_t index)
499 {
500         if (index > ip_set_max)
501                 return IPSET_INVALID_ID;
502
503         nfnl_lock();
504         if (ip_set_list[index])
505                 __ip_set_get(index);
506         else
507                 index = IPSET_INVALID_ID;
508         nfnl_unlock();
509
510         return index;
511 }
512 EXPORT_SYMBOL_GPL(ip_set_nfnl_get_byindex);
513
514 /*
515  * If the given set pointer points to a valid set, decrement
516  * reference count by 1. The caller shall not assume the index
517  * to be valid, after calling this function.
518  *
519  * The nfnl mutex is used in the function.
520  */
521 void
522 ip_set_nfnl_put(ip_set_id_t index)
523 {
524         nfnl_lock();
525         ip_set_put_byindex(index);
526         nfnl_unlock();
527 }
528 EXPORT_SYMBOL_GPL(ip_set_nfnl_put);
529
530 /*
531  * Communication protocol with userspace over netlink.
532  *
533  * The commands are serialized by the nfnl mutex.
534  */
535
536 static inline bool
537 protocol_failed(const struct nlattr * const tb[])
538 {
539         return !tb[IPSET_ATTR_PROTOCOL] ||
540                nla_get_u8(tb[IPSET_ATTR_PROTOCOL]) != IPSET_PROTOCOL;
541 }
542
543 static inline u32
544 flag_exist(const struct nlmsghdr *nlh)
545 {
546         return nlh->nlmsg_flags & NLM_F_EXCL ? 0 : IPSET_FLAG_EXIST;
547 }
548
549 static struct nlmsghdr *
550 start_msg(struct sk_buff *skb, u32 pid, u32 seq, unsigned int flags,
551           enum ipset_cmd cmd)
552 {
553         struct nlmsghdr *nlh;
554         struct nfgenmsg *nfmsg;
555
556         nlh = nlmsg_put(skb, pid, seq, cmd | (NFNL_SUBSYS_IPSET << 8),
557                         sizeof(*nfmsg), flags);
558         if (nlh == NULL)
559                 return NULL;
560
561         nfmsg = nlmsg_data(nlh);
562         nfmsg->nfgen_family = AF_INET;
563         nfmsg->version = NFNETLINK_V0;
564         nfmsg->res_id = 0;
565
566         return nlh;
567 }
568
569 /* Create a set */
570
571 static const struct nla_policy ip_set_create_policy[IPSET_ATTR_CMD_MAX + 1] = {
572         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
573         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
574                                     .len = IPSET_MAXNAMELEN - 1 },
575         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
576                                     .len = IPSET_MAXNAMELEN - 1},
577         [IPSET_ATTR_REVISION]   = { .type = NLA_U8 },
578         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
579         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
580 };
581
582 static ip_set_id_t
583 find_set_id(const char *name)
584 {
585         ip_set_id_t i, index = IPSET_INVALID_ID;
586         const struct ip_set *set;
587
588         for (i = 0; index == IPSET_INVALID_ID && i < ip_set_max; i++) {
589                 set = ip_set_list[i];
590                 if (set != NULL && STREQ(set->name, name))
591                         index = i;
592         }
593         return index;
594 }
595
596 static inline struct ip_set *
597 find_set(const char *name)
598 {
599         ip_set_id_t index = find_set_id(name);
600
601         return index == IPSET_INVALID_ID ? NULL : ip_set_list[index];
602 }
603
604 static int
605 find_free_id(const char *name, ip_set_id_t *index, struct ip_set **set)
606 {
607         ip_set_id_t i;
608
609         *index = IPSET_INVALID_ID;
610         for (i = 0;  i < ip_set_max; i++) {
611                 if (ip_set_list[i] == NULL) {
612                         if (*index == IPSET_INVALID_ID)
613                                 *index = i;
614                 } else if (STREQ(name, ip_set_list[i]->name)) {
615                         /* Name clash */
616                         *set = ip_set_list[i];
617                         return -EEXIST;
618                 }
619         }
620         if (*index == IPSET_INVALID_ID)
621                 /* No free slot remained */
622                 return -IPSET_ERR_MAX_SETS;
623         return 0;
624 }
625
626 static int
627 ip_set_create(struct sock *ctnl, struct sk_buff *skb,
628               const struct nlmsghdr *nlh,
629               const struct nlattr * const attr[])
630 {
631         struct ip_set *set, *clash = NULL;
632         ip_set_id_t index = IPSET_INVALID_ID;
633         struct nlattr *tb[IPSET_ATTR_CREATE_MAX+1] = {};
634         const char *name, *typename;
635         u8 family, revision;
636         u32 flags = flag_exist(nlh);
637         int ret = 0;
638
639         if (unlikely(protocol_failed(attr) ||
640                      attr[IPSET_ATTR_SETNAME] == NULL ||
641                      attr[IPSET_ATTR_TYPENAME] == NULL ||
642                      attr[IPSET_ATTR_REVISION] == NULL ||
643                      attr[IPSET_ATTR_FAMILY] == NULL ||
644                      (attr[IPSET_ATTR_DATA] != NULL &&
645                       !flag_nested(attr[IPSET_ATTR_DATA]))))
646                 return -IPSET_ERR_PROTOCOL;
647
648         name = nla_data(attr[IPSET_ATTR_SETNAME]);
649         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
650         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
651         revision = nla_get_u8(attr[IPSET_ATTR_REVISION]);
652         pr_debug("setname: %s, typename: %s, family: %s, revision: %u\n",
653                  name, typename, family_name(family), revision);
654
655         /*
656          * First, and without any locks, allocate and initialize
657          * a normal base set structure.
658          */
659         set = kzalloc(sizeof(struct ip_set), GFP_KERNEL);
660         if (!set)
661                 return -ENOMEM;
662         rwlock_init(&set->lock);
663         strlcpy(set->name, name, IPSET_MAXNAMELEN);
664         set->family = family;
665         set->revision = revision;
666
667         /*
668          * Next, check that we know the type, and take
669          * a reference on the type, to make sure it stays available
670          * while constructing our new set.
671          *
672          * After referencing the type, we try to create the type
673          * specific part of the set without holding any locks.
674          */
675         ret = find_set_type_get(typename, family, revision, &(set->type));
676         if (ret)
677                 goto out;
678
679         /*
680          * Without holding any locks, create private part.
681          */
682         if (attr[IPSET_ATTR_DATA] &&
683             nla_parse_nested(tb, IPSET_ATTR_CREATE_MAX, attr[IPSET_ATTR_DATA],
684                              set->type->create_policy)) {
685                 ret = -IPSET_ERR_PROTOCOL;
686                 goto put_out;
687         }
688
689         ret = set->type->create(set, tb, flags);
690         if (ret != 0)
691                 goto put_out;
692
693         /* BTW, ret==0 here. */
694
695         /*
696          * Here, we have a valid, constructed set and we are protected
697          * by the nfnl mutex. Find the first free index in ip_set_list
698          * and check clashing.
699          */
700         if ((ret = find_free_id(set->name, &index, &clash)) != 0) {
701                 /* If this is the same set and requested, ignore error */
702                 if (ret == -EEXIST &&
703                     (flags & IPSET_FLAG_EXIST) &&
704                     STREQ(set->type->name, clash->type->name) &&
705                     set->type->family == clash->type->family &&
706                     set->type->revision_min == clash->type->revision_min &&
707                     set->type->revision_max == clash->type->revision_max &&
708                     set->variant->same_set(set, clash))
709                         ret = 0;
710                 goto cleanup;
711         }
712
713         /*
714          * Finally! Add our shiny new set to the list, and be done.
715          */
716         pr_debug("create: '%s' created with index %u!\n", set->name, index);
717         ip_set_list[index] = set;
718
719         return ret;
720
721 cleanup:
722         set->variant->destroy(set);
723 put_out:
724         module_put(set->type->me);
725 out:
726         kfree(set);
727         return ret;
728 }
729
730 /* Destroy sets */
731
732 static const struct nla_policy
733 ip_set_setname_policy[IPSET_ATTR_CMD_MAX + 1] = {
734         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
735         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
736                                     .len = IPSET_MAXNAMELEN - 1 },
737 };
738
739 static void
740 ip_set_destroy_set(ip_set_id_t index)
741 {
742         struct ip_set *set = ip_set_list[index];
743
744         pr_debug("set: %s\n",  set->name);
745         ip_set_list[index] = NULL;
746
747         /* Must call it without holding any lock */
748         set->variant->destroy(set);
749         module_put(set->type->me);
750         kfree(set);
751 }
752
753 static int
754 ip_set_destroy(struct sock *ctnl, struct sk_buff *skb,
755                const struct nlmsghdr *nlh,
756                const struct nlattr * const attr[])
757 {
758         ip_set_id_t i;
759         int ret = 0;
760
761         if (unlikely(protocol_failed(attr)))
762                 return -IPSET_ERR_PROTOCOL;
763
764         /* Commands are serialized and references are
765          * protected by the ip_set_ref_lock.
766          * External systems (i.e. xt_set) must call
767          * ip_set_put|get_nfnl_* functions, that way we
768          * can safely check references here.
769          *
770          * list:set timer can only decrement the reference
771          * counter, so if it's already zero, we can proceed
772          * without holding the lock.
773          */
774         read_lock_bh(&ip_set_ref_lock);
775         if (!attr[IPSET_ATTR_SETNAME]) {
776                 for (i = 0; i < ip_set_max; i++) {
777                         if (ip_set_list[i] != NULL && ip_set_list[i]->ref) {
778                                 ret = -IPSET_ERR_BUSY;
779                                 goto out;
780                         }
781                 }
782                 read_unlock_bh(&ip_set_ref_lock);
783                 for (i = 0; i < ip_set_max; i++) {
784                         if (ip_set_list[i] != NULL)
785                                 ip_set_destroy_set(i);
786                 }
787         } else {
788                 i = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
789                 if (i == IPSET_INVALID_ID) {
790                         ret = -ENOENT;
791                         goto out;
792                 } else if (ip_set_list[i]->ref) {
793                         ret = -IPSET_ERR_BUSY;
794                         goto out;
795                 }
796                 read_unlock_bh(&ip_set_ref_lock);
797
798                 ip_set_destroy_set(i);
799         }
800         return 0;
801 out:
802         read_unlock_bh(&ip_set_ref_lock);
803         return ret;
804 }
805
806 /* Flush sets */
807
808 static void
809 ip_set_flush_set(struct ip_set *set)
810 {
811         pr_debug("set: %s\n",  set->name);
812
813         write_lock_bh(&set->lock);
814         set->variant->flush(set);
815         write_unlock_bh(&set->lock);
816 }
817
818 static int
819 ip_set_flush(struct sock *ctnl, struct sk_buff *skb,
820              const struct nlmsghdr *nlh,
821              const struct nlattr * const attr[])
822 {
823         ip_set_id_t i;
824
825         if (unlikely(protocol_failed(attr)))
826                 return -IPSET_ERR_PROTOCOL;
827
828         if (!attr[IPSET_ATTR_SETNAME]) {
829                 for (i = 0; i < ip_set_max; i++)
830                         if (ip_set_list[i] != NULL)
831                                 ip_set_flush_set(ip_set_list[i]);
832         } else {
833                 i = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
834                 if (i == IPSET_INVALID_ID)
835                         return -ENOENT;
836
837                 ip_set_flush_set(ip_set_list[i]);
838         }
839
840         return 0;
841 }
842
843 /* Rename a set */
844
845 static const struct nla_policy
846 ip_set_setname2_policy[IPSET_ATTR_CMD_MAX + 1] = {
847         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
848         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
849                                     .len = IPSET_MAXNAMELEN - 1 },
850         [IPSET_ATTR_SETNAME2]   = { .type = NLA_NUL_STRING,
851                                     .len = IPSET_MAXNAMELEN - 1 },
852 };
853
854 static int
855 ip_set_rename(struct sock *ctnl, struct sk_buff *skb,
856               const struct nlmsghdr *nlh,
857               const struct nlattr * const attr[])
858 {
859         struct ip_set *set;
860         const char *name2;
861         ip_set_id_t i;
862         int ret = 0;
863
864         if (unlikely(protocol_failed(attr) ||
865                      attr[IPSET_ATTR_SETNAME] == NULL ||
866                      attr[IPSET_ATTR_SETNAME2] == NULL))
867                 return -IPSET_ERR_PROTOCOL;
868
869         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
870         if (set == NULL)
871                 return -ENOENT;
872
873         read_lock_bh(&ip_set_ref_lock);
874         if (set->ref != 0) {
875                 ret = -IPSET_ERR_REFERENCED;
876                 goto out;
877         }
878
879         name2 = nla_data(attr[IPSET_ATTR_SETNAME2]);
880         for (i = 0; i < ip_set_max; i++) {
881                 if (ip_set_list[i] != NULL &&
882                     STREQ(ip_set_list[i]->name, name2)) {
883                         ret = -IPSET_ERR_EXIST_SETNAME2;
884                         goto out;
885                 }
886         }
887         strncpy(set->name, name2, IPSET_MAXNAMELEN);
888
889 out:
890         read_unlock_bh(&ip_set_ref_lock);
891         return ret;
892 }
893
894 /* Swap two sets so that name/index points to the other.
895  * References and set names are also swapped.
896  *
897  * The commands are serialized by the nfnl mutex and references are
898  * protected by the ip_set_ref_lock. The kernel interfaces
899  * do not hold the mutex but the pointer settings are atomic
900  * so the ip_set_list always contains valid pointers to the sets.
901  */
902
903 static int
904 ip_set_swap(struct sock *ctnl, struct sk_buff *skb,
905             const struct nlmsghdr *nlh,
906             const struct nlattr * const attr[])
907 {
908         struct ip_set *from, *to;
909         ip_set_id_t from_id, to_id;
910         char from_name[IPSET_MAXNAMELEN];
911
912         if (unlikely(protocol_failed(attr) ||
913                      attr[IPSET_ATTR_SETNAME] == NULL ||
914                      attr[IPSET_ATTR_SETNAME2] == NULL))
915                 return -IPSET_ERR_PROTOCOL;
916
917         from_id = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
918         if (from_id == IPSET_INVALID_ID)
919                 return -ENOENT;
920
921         to_id = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME2]));
922         if (to_id == IPSET_INVALID_ID)
923                 return -IPSET_ERR_EXIST_SETNAME2;
924
925         from = ip_set_list[from_id];
926         to = ip_set_list[to_id];
927
928         /* Features must not change.
929          * Not an artificial restriction anymore, as we must prevent
930          * possible loops created by swapping in setlist type of sets. */
931         if (!(from->type->features == to->type->features &&
932               from->type->family == to->type->family))
933                 return -IPSET_ERR_TYPE_MISMATCH;
934
935         strncpy(from_name, from->name, IPSET_MAXNAMELEN);
936         strncpy(from->name, to->name, IPSET_MAXNAMELEN);
937         strncpy(to->name, from_name, IPSET_MAXNAMELEN);
938
939         write_lock_bh(&ip_set_ref_lock);
940         swap(from->ref, to->ref);
941         ip_set_list[from_id] = to;
942         ip_set_list[to_id] = from;
943         write_unlock_bh(&ip_set_ref_lock);
944
945         return 0;
946 }
947
948 /* List/save set data */
949
950 #define DUMP_INIT       0
951 #define DUMP_ALL        1
952 #define DUMP_ONE        2
953 #define DUMP_LAST       3
954
955 #define DUMP_TYPE(arg)          (((u32)(arg)) & 0x0000FFFF)
956 #define DUMP_FLAGS(arg)         (((u32)(arg)) >> 16)
957
958 static int
959 ip_set_dump_done(struct netlink_callback *cb)
960 {
961         if (cb->args[2]) {
962                 pr_debug("release set %s\n", ip_set_list[cb->args[1]]->name);
963                 ip_set_put_byindex((ip_set_id_t) cb->args[1]);
964         }
965         return 0;
966 }
967
968 static inline void
969 dump_attrs(struct nlmsghdr *nlh)
970 {
971         const struct nlattr *attr;
972         int rem;
973
974         pr_debug("dump nlmsg\n");
975         nlmsg_for_each_attr(attr, nlh, sizeof(struct nfgenmsg), rem) {
976                 pr_debug("type: %u, len %u\n", nla_type(attr), attr->nla_len);
977         }
978 }
979
980 static int
981 dump_init(struct netlink_callback *cb)
982 {
983         struct nlmsghdr *nlh = nlmsg_hdr(cb->skb);
984         int min_len = NLMSG_SPACE(sizeof(struct nfgenmsg));
985         struct nlattr *cda[IPSET_ATTR_CMD_MAX+1];
986         struct nlattr *attr = (void *)nlh + min_len;
987         u32 dump_type;
988         ip_set_id_t index;
989
990         /* Second pass, so parser can't fail */
991         nla_parse(cda, IPSET_ATTR_CMD_MAX,
992                   attr, nlh->nlmsg_len - min_len, ip_set_setname_policy);
993
994         /* cb->args[0] : dump single set/all sets
995          *         [1] : set index
996          *         [..]: type specific
997          */
998
999         if (cda[IPSET_ATTR_SETNAME]) {
1000                 index = find_set_id(nla_data(cda[IPSET_ATTR_SETNAME]));
1001                 if (index == IPSET_INVALID_ID)
1002                         return -ENOENT;
1003
1004                 dump_type = DUMP_ONE;
1005                 cb->args[1] = index;
1006         } else
1007                 dump_type = DUMP_ALL;
1008
1009         if (cda[IPSET_ATTR_FLAGS]) {
1010                 u32 f = ip_set_get_h32(cda[IPSET_ATTR_FLAGS]);
1011                 dump_type |= (f << 16);
1012         }
1013         cb->args[0] = dump_type;
1014
1015         return 0;
1016 }
1017
1018 static int
1019 ip_set_dump_start(struct sk_buff *skb, struct netlink_callback *cb)
1020 {
1021         ip_set_id_t index = IPSET_INVALID_ID, max;
1022         struct ip_set *set = NULL;
1023         struct nlmsghdr *nlh = NULL;
1024         unsigned int flags = NETLINK_CB(cb->skb).pid ? NLM_F_MULTI : 0;
1025         u32 dump_type, dump_flags;
1026         int ret = 0;
1027
1028         if (!cb->args[0]) {
1029                 ret = dump_init(cb);
1030                 if (ret < 0) {
1031                         nlh = nlmsg_hdr(cb->skb);
1032                         /* We have to create and send the error message
1033                          * manually :-( */
1034                         if (nlh->nlmsg_flags & NLM_F_ACK)
1035                                 netlink_ack(cb->skb, nlh, ret);
1036                         return ret;
1037                 }
1038         }
1039
1040         if (cb->args[1] >= ip_set_max)
1041                 goto out;
1042
1043         dump_type = DUMP_TYPE(cb->args[0]);
1044         dump_flags = DUMP_FLAGS(cb->args[0]);
1045         max = dump_type == DUMP_ONE ? cb->args[1] + 1 : ip_set_max;
1046 dump_last:
1047         pr_debug("args[0]: %u %u args[1]: %ld\n",
1048                  dump_type, dump_flags, cb->args[1]);
1049         for (; cb->args[1] < max; cb->args[1]++) {
1050                 index = (ip_set_id_t) cb->args[1];
1051                 set = ip_set_list[index];
1052                 if (set == NULL) {
1053                         if (dump_type == DUMP_ONE) {
1054                                 ret = -ENOENT;
1055                                 goto out;
1056                         }
1057                         continue;
1058                 }
1059                 /* When dumping all sets, we must dump "sorted"
1060                  * so that lists (unions of sets) are dumped last.
1061                  */
1062                 if (dump_type != DUMP_ONE &&
1063                     ((dump_type == DUMP_ALL) ==
1064                      !!(set->type->features & IPSET_DUMP_LAST)))
1065                         continue;
1066                 pr_debug("List set: %s\n", set->name);
1067                 if (!cb->args[2]) {
1068                         /* Start listing: make sure set won't be destroyed */
1069                         pr_debug("reference set\n");
1070                         __ip_set_get(index);
1071                 }
1072                 nlh = start_msg(skb, NETLINK_CB(cb->skb).pid,
1073                                 cb->nlh->nlmsg_seq, flags,
1074                                 IPSET_CMD_LIST);
1075                 if (!nlh) {
1076                         ret = -EMSGSIZE;
1077                         goto release_refcount;
1078                 }
1079                 NLA_PUT_U8(skb, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1080                 NLA_PUT_STRING(skb, IPSET_ATTR_SETNAME, set->name);
1081                 if (dump_flags & IPSET_FLAG_LIST_SETNAME)
1082                         goto next_set;
1083                 switch (cb->args[2]) {
1084                 case 0:
1085                         /* Core header data */
1086                         NLA_PUT_STRING(skb, IPSET_ATTR_TYPENAME,
1087                                        set->type->name);
1088                         NLA_PUT_U8(skb, IPSET_ATTR_FAMILY,
1089                                    set->family);
1090                         NLA_PUT_U8(skb, IPSET_ATTR_REVISION,
1091                                    set->revision);
1092                         ret = set->variant->head(set, skb);
1093                         if (ret < 0)
1094                                 goto release_refcount;
1095                         if (dump_flags & IPSET_FLAG_LIST_HEADER)
1096                                 goto next_set;
1097                         /* Fall through and add elements */
1098                 default:
1099                         read_lock_bh(&set->lock);
1100                         ret = set->variant->list(set, skb, cb);
1101                         read_unlock_bh(&set->lock);
1102                         if (!cb->args[2])
1103                                 /* Set is done, proceed with next one */
1104                                 goto next_set;
1105                         goto release_refcount;
1106                 }
1107         }
1108         /* If we dump all sets, continue with dumping last ones */
1109         if (dump_type == DUMP_ALL) {
1110                 dump_type = DUMP_LAST;
1111                 cb->args[0] = dump_type | (dump_flags << 16);
1112                 cb->args[1] = 0;
1113                 goto dump_last;
1114         }
1115         goto out;
1116
1117 nla_put_failure:
1118         ret = -EFAULT;
1119 next_set:
1120         if (dump_type == DUMP_ONE)
1121                 cb->args[1] = IPSET_INVALID_ID;
1122         else
1123                 cb->args[1]++;
1124 release_refcount:
1125         /* If there was an error or set is done, release set */
1126         if (ret || !cb->args[2]) {
1127                 pr_debug("release set %s\n", ip_set_list[index]->name);
1128                 ip_set_put_byindex(index);
1129         }
1130 out:
1131         if (nlh) {
1132                 nlmsg_end(skb, nlh);
1133                 pr_debug("nlmsg_len: %u\n", nlh->nlmsg_len);
1134                 dump_attrs(nlh);
1135         }
1136
1137         return ret < 0 ? ret : skb->len;
1138 }
1139
1140 static int
1141 ip_set_dump(struct sock *ctnl, struct sk_buff *skb,
1142             const struct nlmsghdr *nlh,
1143             const struct nlattr * const attr[])
1144 {
1145         if (unlikely(protocol_failed(attr)))
1146                 return -IPSET_ERR_PROTOCOL;
1147
1148         return netlink_dump_start(ctnl, skb, nlh,
1149                                   ip_set_dump_start,
1150                                   ip_set_dump_done, 0);
1151 }
1152
1153 /* Add, del and test */
1154
1155 static const struct nla_policy ip_set_adt_policy[IPSET_ATTR_CMD_MAX + 1] = {
1156         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1157         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1158                                     .len = IPSET_MAXNAMELEN - 1 },
1159         [IPSET_ATTR_LINENO]     = { .type = NLA_U32 },
1160         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
1161         [IPSET_ATTR_ADT]        = { .type = NLA_NESTED },
1162 };
1163
1164 static int
1165 call_ad(struct sock *ctnl, struct sk_buff *skb, struct ip_set *set,
1166         struct nlattr *tb[], enum ipset_adt adt,
1167         u32 flags, bool use_lineno)
1168 {
1169         int ret;
1170         u32 lineno = 0;
1171         bool eexist = flags & IPSET_FLAG_EXIST, retried = false;
1172
1173         do {
1174                 write_lock_bh(&set->lock);
1175                 ret = set->variant->uadt(set, tb, adt, &lineno, flags, retried);
1176                 write_unlock_bh(&set->lock);
1177                 retried = true;
1178         } while (ret == -EAGAIN &&
1179                  set->variant->resize &&
1180                  (ret = set->variant->resize(set, retried)) == 0);
1181
1182         if (!ret || (ret == -IPSET_ERR_EXIST && eexist))
1183                 return 0;
1184         if (lineno && use_lineno) {
1185                 /* Error in restore/batch mode: send back lineno */
1186                 struct nlmsghdr *rep, *nlh = nlmsg_hdr(skb);
1187                 struct sk_buff *skb2;
1188                 struct nlmsgerr *errmsg;
1189                 size_t payload = sizeof(*errmsg) + nlmsg_len(nlh);
1190                 int min_len = NLMSG_SPACE(sizeof(struct nfgenmsg));
1191                 struct nlattr *cda[IPSET_ATTR_CMD_MAX+1];
1192                 struct nlattr *cmdattr;
1193                 u32 *errline;
1194
1195                 skb2 = nlmsg_new(payload, GFP_KERNEL);
1196                 if (skb2 == NULL)
1197                         return -ENOMEM;
1198                 rep = __nlmsg_put(skb2, NETLINK_CB(skb).pid,
1199                                   nlh->nlmsg_seq, NLMSG_ERROR, payload, 0);
1200                 errmsg = nlmsg_data(rep);
1201                 errmsg->error = ret;
1202                 memcpy(&errmsg->msg, nlh, nlh->nlmsg_len);
1203                 cmdattr = (void *)&errmsg->msg + min_len;
1204
1205                 nla_parse(cda, IPSET_ATTR_CMD_MAX,
1206                           cmdattr, nlh->nlmsg_len - min_len,
1207                           ip_set_adt_policy);
1208
1209                 errline = nla_data(cda[IPSET_ATTR_LINENO]);
1210
1211                 *errline = lineno;
1212
1213                 netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1214                 /* Signal netlink not to send its ACK/errmsg.  */
1215                 return -EINTR;
1216         }
1217
1218         return ret;
1219 }
1220
1221 static int
1222 ip_set_uadd(struct sock *ctnl, struct sk_buff *skb,
1223             const struct nlmsghdr *nlh,
1224             const struct nlattr * const attr[])
1225 {
1226         struct ip_set *set;
1227         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1228         const struct nlattr *nla;
1229         u32 flags = flag_exist(nlh);
1230         bool use_lineno;
1231         int ret = 0;
1232
1233         if (unlikely(protocol_failed(attr) ||
1234                      attr[IPSET_ATTR_SETNAME] == NULL ||
1235                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1236                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1237                      (attr[IPSET_ATTR_DATA] != NULL &&
1238                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1239                      (attr[IPSET_ATTR_ADT] != NULL &&
1240                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1241                        attr[IPSET_ATTR_LINENO] == NULL))))
1242                 return -IPSET_ERR_PROTOCOL;
1243
1244         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1245         if (set == NULL)
1246                 return -ENOENT;
1247
1248         use_lineno = !!attr[IPSET_ATTR_LINENO];
1249         if (attr[IPSET_ATTR_DATA]) {
1250                 if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1251                                      attr[IPSET_ATTR_DATA],
1252                                      set->type->adt_policy))
1253                         return -IPSET_ERR_PROTOCOL;
1254                 ret = call_ad(ctnl, skb, set, tb, IPSET_ADD, flags,
1255                               use_lineno);
1256         } else {
1257                 int nla_rem;
1258
1259                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1260                         memset(tb, 0, sizeof(tb));
1261                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1262                             !flag_nested(nla) ||
1263                             nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1264                                              set->type->adt_policy))
1265                                 return -IPSET_ERR_PROTOCOL;
1266                         ret = call_ad(ctnl, skb, set, tb, IPSET_ADD,
1267                                       flags, use_lineno);
1268                         if (ret < 0)
1269                                 return ret;
1270                 }
1271         }
1272         return ret;
1273 }
1274
1275 static int
1276 ip_set_udel(struct sock *ctnl, struct sk_buff *skb,
1277             const struct nlmsghdr *nlh,
1278             const struct nlattr * const attr[])
1279 {
1280         struct ip_set *set;
1281         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1282         const struct nlattr *nla;
1283         u32 flags = flag_exist(nlh);
1284         bool use_lineno;
1285         int ret = 0;
1286
1287         if (unlikely(protocol_failed(attr) ||
1288                      attr[IPSET_ATTR_SETNAME] == NULL ||
1289                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1290                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1291                      (attr[IPSET_ATTR_DATA] != NULL &&
1292                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1293                      (attr[IPSET_ATTR_ADT] != NULL &&
1294                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1295                        attr[IPSET_ATTR_LINENO] == NULL))))
1296                 return -IPSET_ERR_PROTOCOL;
1297
1298         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1299         if (set == NULL)
1300                 return -ENOENT;
1301
1302         use_lineno = !!attr[IPSET_ATTR_LINENO];
1303         if (attr[IPSET_ATTR_DATA]) {
1304                 if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1305                                      attr[IPSET_ATTR_DATA],
1306                                      set->type->adt_policy))
1307                         return -IPSET_ERR_PROTOCOL;
1308                 ret = call_ad(ctnl, skb, set, tb, IPSET_DEL, flags,
1309                               use_lineno);
1310         } else {
1311                 int nla_rem;
1312
1313                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1314                         memset(tb, 0, sizeof(*tb));
1315                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1316                             !flag_nested(nla) ||
1317                             nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1318                                              set->type->adt_policy))
1319                                 return -IPSET_ERR_PROTOCOL;
1320                         ret = call_ad(ctnl, skb, set, tb, IPSET_DEL,
1321                                       flags, use_lineno);
1322                         if (ret < 0)
1323                                 return ret;
1324                 }
1325         }
1326         return ret;
1327 }
1328
1329 static int
1330 ip_set_utest(struct sock *ctnl, struct sk_buff *skb,
1331              const struct nlmsghdr *nlh,
1332              const struct nlattr * const attr[])
1333 {
1334         struct ip_set *set;
1335         struct nlattr *tb[IPSET_ATTR_ADT_MAX+1] = {};
1336         int ret = 0;
1337
1338         if (unlikely(protocol_failed(attr) ||
1339                      attr[IPSET_ATTR_SETNAME] == NULL ||
1340                      attr[IPSET_ATTR_DATA] == NULL ||
1341                      !flag_nested(attr[IPSET_ATTR_DATA])))
1342                 return -IPSET_ERR_PROTOCOL;
1343
1344         set = find_set(nla_data(attr[IPSET_ATTR_SETNAME]));
1345         if (set == NULL)
1346                 return -ENOENT;
1347
1348         if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA],
1349                              set->type->adt_policy))
1350                 return -IPSET_ERR_PROTOCOL;
1351
1352         read_lock_bh(&set->lock);
1353         ret = set->variant->uadt(set, tb, IPSET_TEST, NULL, 0, 0);
1354         read_unlock_bh(&set->lock);
1355         /* Userspace can't trigger element to be re-added */
1356         if (ret == -EAGAIN)
1357                 ret = 1;
1358
1359         return ret < 0 ? ret : ret > 0 ? 0 : -IPSET_ERR_EXIST;
1360 }
1361
1362 /* Get headed data of a set */
1363
1364 static int
1365 ip_set_header(struct sock *ctnl, struct sk_buff *skb,
1366               const struct nlmsghdr *nlh,
1367               const struct nlattr * const attr[])
1368 {
1369         const struct ip_set *set;
1370         struct sk_buff *skb2;
1371         struct nlmsghdr *nlh2;
1372         ip_set_id_t index;
1373         int ret = 0;
1374
1375         if (unlikely(protocol_failed(attr) ||
1376                      attr[IPSET_ATTR_SETNAME] == NULL))
1377                 return -IPSET_ERR_PROTOCOL;
1378
1379         index = find_set_id(nla_data(attr[IPSET_ATTR_SETNAME]));
1380         if (index == IPSET_INVALID_ID)
1381                 return -ENOENT;
1382         set = ip_set_list[index];
1383
1384         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1385         if (skb2 == NULL)
1386                 return -ENOMEM;
1387
1388         nlh2 = start_msg(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, 0,
1389                          IPSET_CMD_HEADER);
1390         if (!nlh2)
1391                 goto nlmsg_failure;
1392         NLA_PUT_U8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1393         NLA_PUT_STRING(skb2, IPSET_ATTR_SETNAME, set->name);
1394         NLA_PUT_STRING(skb2, IPSET_ATTR_TYPENAME, set->type->name);
1395         NLA_PUT_U8(skb2, IPSET_ATTR_FAMILY, set->family);
1396         NLA_PUT_U8(skb2, IPSET_ATTR_REVISION, set->revision);
1397         nlmsg_end(skb2, nlh2);
1398
1399         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1400         if (ret < 0)
1401                 return ret;
1402
1403         return 0;
1404
1405 nla_put_failure:
1406         nlmsg_cancel(skb2, nlh2);
1407 nlmsg_failure:
1408         kfree_skb(skb2);
1409         return -EMSGSIZE;
1410 }
1411
1412 /* Get type data */
1413
1414 static const struct nla_policy ip_set_type_policy[IPSET_ATTR_CMD_MAX + 1] = {
1415         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1416         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
1417                                     .len = IPSET_MAXNAMELEN - 1 },
1418         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
1419 };
1420
1421 static int
1422 ip_set_type(struct sock *ctnl, struct sk_buff *skb,
1423             const struct nlmsghdr *nlh,
1424             const struct nlattr * const attr[])
1425 {
1426         struct sk_buff *skb2;
1427         struct nlmsghdr *nlh2;
1428         u8 family, min, max;
1429         const char *typename;
1430         int ret = 0;
1431
1432         if (unlikely(protocol_failed(attr) ||
1433                      attr[IPSET_ATTR_TYPENAME] == NULL ||
1434                      attr[IPSET_ATTR_FAMILY] == NULL))
1435                 return -IPSET_ERR_PROTOCOL;
1436
1437         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
1438         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
1439         ret = find_set_type_minmax(typename, family, &min, &max);
1440         if (ret)
1441                 return ret;
1442
1443         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1444         if (skb2 == NULL)
1445                 return -ENOMEM;
1446
1447         nlh2 = start_msg(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, 0,
1448                          IPSET_CMD_TYPE);
1449         if (!nlh2)
1450                 goto nlmsg_failure;
1451         NLA_PUT_U8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1452         NLA_PUT_STRING(skb2, IPSET_ATTR_TYPENAME, typename);
1453         NLA_PUT_U8(skb2, IPSET_ATTR_FAMILY, family);
1454         NLA_PUT_U8(skb2, IPSET_ATTR_REVISION, max);
1455         NLA_PUT_U8(skb2, IPSET_ATTR_REVISION_MIN, min);
1456         nlmsg_end(skb2, nlh2);
1457
1458         pr_debug("Send TYPE, nlmsg_len: %u\n", nlh2->nlmsg_len);
1459         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1460         if (ret < 0)
1461                 return ret;
1462
1463         return 0;
1464
1465 nla_put_failure:
1466         nlmsg_cancel(skb2, nlh2);
1467 nlmsg_failure:
1468         kfree_skb(skb2);
1469         return -EMSGSIZE;
1470 }
1471
1472 /* Get protocol version */
1473
1474 static const struct nla_policy
1475 ip_set_protocol_policy[IPSET_ATTR_CMD_MAX + 1] = {
1476         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1477 };
1478
1479 static int
1480 ip_set_protocol(struct sock *ctnl, struct sk_buff *skb,
1481                 const struct nlmsghdr *nlh,
1482                 const struct nlattr * const attr[])
1483 {
1484         struct sk_buff *skb2;
1485         struct nlmsghdr *nlh2;
1486         int ret = 0;
1487
1488         if (unlikely(attr[IPSET_ATTR_PROTOCOL] == NULL))
1489                 return -IPSET_ERR_PROTOCOL;
1490
1491         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1492         if (skb2 == NULL)
1493                 return -ENOMEM;
1494
1495         nlh2 = start_msg(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, 0,
1496                          IPSET_CMD_PROTOCOL);
1497         if (!nlh2)
1498                 goto nlmsg_failure;
1499         NLA_PUT_U8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
1500         nlmsg_end(skb2, nlh2);
1501
1502         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
1503         if (ret < 0)
1504                 return ret;
1505
1506         return 0;
1507
1508 nla_put_failure:
1509         nlmsg_cancel(skb2, nlh2);
1510 nlmsg_failure:
1511         kfree_skb(skb2);
1512         return -EMSGSIZE;
1513 }
1514
1515 static const struct nfnl_callback ip_set_netlink_subsys_cb[IPSET_MSG_MAX] = {
1516         [IPSET_CMD_CREATE]      = {
1517                 .call           = ip_set_create,
1518                 .attr_count     = IPSET_ATTR_CMD_MAX,
1519                 .policy         = ip_set_create_policy,
1520         },
1521         [IPSET_CMD_DESTROY]     = {
1522                 .call           = ip_set_destroy,
1523                 .attr_count     = IPSET_ATTR_CMD_MAX,
1524                 .policy         = ip_set_setname_policy,
1525         },
1526         [IPSET_CMD_FLUSH]       = {
1527                 .call           = ip_set_flush,
1528                 .attr_count     = IPSET_ATTR_CMD_MAX,
1529                 .policy         = ip_set_setname_policy,
1530         },
1531         [IPSET_CMD_RENAME]      = {
1532                 .call           = ip_set_rename,
1533                 .attr_count     = IPSET_ATTR_CMD_MAX,
1534                 .policy         = ip_set_setname2_policy,
1535         },
1536         [IPSET_CMD_SWAP]        = {
1537                 .call           = ip_set_swap,
1538                 .attr_count     = IPSET_ATTR_CMD_MAX,
1539                 .policy         = ip_set_setname2_policy,
1540         },
1541         [IPSET_CMD_LIST]        = {
1542                 .call           = ip_set_dump,
1543                 .attr_count     = IPSET_ATTR_CMD_MAX,
1544                 .policy         = ip_set_setname_policy,
1545         },
1546         [IPSET_CMD_SAVE]        = {
1547                 .call           = ip_set_dump,
1548                 .attr_count     = IPSET_ATTR_CMD_MAX,
1549                 .policy         = ip_set_setname_policy,
1550         },
1551         [IPSET_CMD_ADD] = {
1552                 .call           = ip_set_uadd,
1553                 .attr_count     = IPSET_ATTR_CMD_MAX,
1554                 .policy         = ip_set_adt_policy,
1555         },
1556         [IPSET_CMD_DEL] = {
1557                 .call           = ip_set_udel,
1558                 .attr_count     = IPSET_ATTR_CMD_MAX,
1559                 .policy         = ip_set_adt_policy,
1560         },
1561         [IPSET_CMD_TEST]        = {
1562                 .call           = ip_set_utest,
1563                 .attr_count     = IPSET_ATTR_CMD_MAX,
1564                 .policy         = ip_set_adt_policy,
1565         },
1566         [IPSET_CMD_HEADER]      = {
1567                 .call           = ip_set_header,
1568                 .attr_count     = IPSET_ATTR_CMD_MAX,
1569                 .policy         = ip_set_setname_policy,
1570         },
1571         [IPSET_CMD_TYPE]        = {
1572                 .call           = ip_set_type,
1573                 .attr_count     = IPSET_ATTR_CMD_MAX,
1574                 .policy         = ip_set_type_policy,
1575         },
1576         [IPSET_CMD_PROTOCOL]    = {
1577                 .call           = ip_set_protocol,
1578                 .attr_count     = IPSET_ATTR_CMD_MAX,
1579                 .policy         = ip_set_protocol_policy,
1580         },
1581 };
1582
1583 static struct nfnetlink_subsystem ip_set_netlink_subsys __read_mostly = {
1584         .name           = "ip_set",
1585         .subsys_id      = NFNL_SUBSYS_IPSET,
1586         .cb_count       = IPSET_MSG_MAX,
1587         .cb             = ip_set_netlink_subsys_cb,
1588 };
1589
1590 /* Interface to iptables/ip6tables */
1591
1592 static int
1593 ip_set_sockfn_get(struct sock *sk, int optval, void __user *user, int *len)
1594 {
1595         unsigned *op;
1596         void *data;
1597         int copylen = *len, ret = 0;
1598
1599         if (!capable(CAP_NET_ADMIN))
1600                 return -EPERM;
1601         if (optval != SO_IP_SET)
1602                 return -EBADF;
1603         if (*len < sizeof(unsigned))
1604                 return -EINVAL;
1605
1606         data = vmalloc(*len);
1607         if (!data)
1608                 return -ENOMEM;
1609         if (copy_from_user(data, user, *len) != 0) {
1610                 ret = -EFAULT;
1611                 goto done;
1612         }
1613         op = (unsigned *) data;
1614
1615         if (*op < IP_SET_OP_VERSION) {
1616                 /* Check the version at the beginning of operations */
1617                 struct ip_set_req_version *req_version = data;
1618
1619                 if (*len < sizeof(struct ip_set_req_version)) {
1620                         ret = -EINVAL;
1621                         goto done;
1622                 }
1623
1624                 if (req_version->version != IPSET_PROTOCOL) {
1625                         ret = -EPROTO;
1626                         goto done;
1627                 }
1628         }
1629
1630         switch (*op) {
1631         case IP_SET_OP_VERSION: {
1632                 struct ip_set_req_version *req_version = data;
1633
1634                 if (*len != sizeof(struct ip_set_req_version)) {
1635                         ret = -EINVAL;
1636                         goto done;
1637                 }
1638
1639                 req_version->version = IPSET_PROTOCOL;
1640                 ret = copy_to_user(user, req_version,
1641                                    sizeof(struct ip_set_req_version));
1642                 goto done;
1643         }
1644         case IP_SET_OP_GET_BYNAME: {
1645                 struct ip_set_req_get_set *req_get = data;
1646
1647                 if (*len != sizeof(struct ip_set_req_get_set)) {
1648                         ret = -EINVAL;
1649                         goto done;
1650                 }
1651                 req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
1652                 nfnl_lock();
1653                 req_get->set.index = find_set_id(req_get->set.name);
1654                 nfnl_unlock();
1655                 goto copy;
1656         }
1657         case IP_SET_OP_GET_BYINDEX: {
1658                 struct ip_set_req_get_set *req_get = data;
1659
1660                 if (*len != sizeof(struct ip_set_req_get_set) ||
1661                     req_get->set.index >= ip_set_max) {
1662                         ret = -EINVAL;
1663                         goto done;
1664                 }
1665                 nfnl_lock();
1666                 strncpy(req_get->set.name,
1667                         ip_set_list[req_get->set.index]
1668                                 ? ip_set_list[req_get->set.index]->name : "",
1669                         IPSET_MAXNAMELEN);
1670                 nfnl_unlock();
1671                 goto copy;
1672         }
1673         default:
1674                 ret = -EBADMSG;
1675                 goto done;
1676         }       /* end of switch(op) */
1677
1678 copy:
1679         ret = copy_to_user(user, data, copylen);
1680
1681 done:
1682         vfree(data);
1683         if (ret > 0)
1684                 ret = 0;
1685         return ret;
1686 }
1687
1688 static struct nf_sockopt_ops so_set __read_mostly = {
1689         .pf             = PF_INET,
1690         .get_optmin     = SO_IP_SET,
1691         .get_optmax     = SO_IP_SET + 1,
1692         .get            = &ip_set_sockfn_get,
1693         .owner          = THIS_MODULE,
1694 };
1695
1696 static int __init
1697 ip_set_init(void)
1698 {
1699         int ret;
1700
1701         if (max_sets)
1702                 ip_set_max = max_sets;
1703         if (ip_set_max >= IPSET_INVALID_ID)
1704                 ip_set_max = IPSET_INVALID_ID - 1;
1705
1706         ip_set_list = kzalloc(sizeof(struct ip_set *) * ip_set_max,
1707                               GFP_KERNEL);
1708         if (!ip_set_list)
1709                 return -ENOMEM;
1710
1711         ret = nfnetlink_subsys_register(&ip_set_netlink_subsys);
1712         if (ret != 0) {
1713                 pr_err("ip_set: cannot register with nfnetlink.\n");
1714                 kfree(ip_set_list);
1715                 return ret;
1716         }
1717         ret = nf_register_sockopt(&so_set);
1718         if (ret != 0) {
1719                 pr_err("SO_SET registry failed: %d\n", ret);
1720                 nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
1721                 kfree(ip_set_list);
1722                 return ret;
1723         }
1724
1725         pr_notice("ip_set: protocol %u\n", IPSET_PROTOCOL);
1726         return 0;
1727 }
1728
1729 static void __exit
1730 ip_set_fini(void)
1731 {
1732         /* There can't be any existing set */
1733         nf_unregister_sockopt(&so_set);
1734         nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
1735         kfree(ip_set_list);
1736         pr_debug("these are the famous last words\n");
1737 }
1738
1739 module_init(ip_set_init);
1740 module_exit(ip_set_fini);