f4a46c0d25f34e2d1d07bc5d41657865b8e9b21f
[pandora-kernel.git] / net / netfilter / ipset / ip_set_list_set.c
1 /* Copyright (C) 2008-2011 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
2  *
3  * This program is free software; you can redistribute it and/or modify
4  * it under the terms of the GNU General Public License version 2 as
5  * published by the Free Software Foundation.
6  */
7
8 /* Kernel module implementing an IP set type: the list:set type */
9
10 #include <linux/module.h>
11 #include <linux/ip.h>
12 #include <linux/skbuff.h>
13 #include <linux/errno.h>
14
15 #include <linux/netfilter/ipset/ip_set.h>
16 #include <linux/netfilter/ipset/ip_set_timeout.h>
17 #include <linux/netfilter/ipset/ip_set_list.h>
18
19 MODULE_LICENSE("GPL");
20 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
21 MODULE_DESCRIPTION("list:set type of IP sets");
22 MODULE_ALIAS("ip_set_list:set");
23
24 /* Member elements without and with timeout */
25 struct set_elem {
26         ip_set_id_t id;
27 };
28
29 struct set_telem {
30         ip_set_id_t id;
31         unsigned long timeout;
32 };
33
34 /* Type structure */
35 struct list_set {
36         size_t dsize;           /* element size */
37         u32 size;               /* size of set list array */
38         u32 timeout;            /* timeout value */
39         struct timer_list gc;   /* garbage collection */
40         struct set_elem members[0]; /* the set members */
41 };
42
43 static inline struct set_elem *
44 list_set_elem(const struct list_set *map, u32 id)
45 {
46         return (struct set_elem *)((void *)map->members + id * map->dsize);
47 }
48
49 static inline struct set_telem *
50 list_set_telem(const struct list_set *map, u32 id)
51 {
52         return (struct set_telem *)((void *)map->members + id * map->dsize);
53 }
54
55 static inline bool
56 list_set_timeout(const struct list_set *map, u32 id)
57 {
58         const struct set_telem *elem = list_set_telem(map, id);
59
60         return ip_set_timeout_test(elem->timeout);
61 }
62
63 static inline bool
64 list_set_expired(const struct list_set *map, u32 id)
65 {
66         const struct set_telem *elem = list_set_telem(map, id);
67
68         return ip_set_timeout_expired(elem->timeout);
69 }
70
71 /* Set list without and with timeout */
72
73 static int
74 list_set_kadt(struct ip_set *set, const struct sk_buff *skb,
75               enum ipset_adt adt, u8 pf, u8 dim, u8 flags)
76 {
77         struct list_set *map = set->data;
78         struct set_elem *elem;
79         u32 i;
80         int ret;
81
82         for (i = 0; i < map->size; i++) {
83                 elem = list_set_elem(map, i);
84                 if (elem->id == IPSET_INVALID_ID)
85                         return 0;
86                 if (with_timeout(map->timeout) && list_set_expired(map, i))
87                         continue;
88                 switch (adt) {
89                 case IPSET_TEST:
90                         ret = ip_set_test(elem->id, skb, pf, dim, flags);
91                         if (ret > 0)
92                                 return ret;
93                         break;
94                 case IPSET_ADD:
95                         ret = ip_set_add(elem->id, skb, pf, dim, flags);
96                         if (ret == 0)
97                                 return ret;
98                         break;
99                 case IPSET_DEL:
100                         ret = ip_set_del(elem->id, skb, pf, dim, flags);
101                         if (ret == 0)
102                                 return ret;
103                         break;
104                 default:
105                         break;
106                 }
107         }
108         return -EINVAL;
109 }
110
111 static bool
112 next_id_eq(const struct list_set *map, u32 i, ip_set_id_t id)
113 {
114         const struct set_elem *elem;
115
116         if (i + 1 < map->size) {
117                 elem = list_set_elem(map, i + 1);
118                 return !!(elem->id == id &&
119                           !(with_timeout(map->timeout) &&
120                             list_set_expired(map, i + 1)));
121         }
122
123         return 0;
124 }
125
126 static void
127 list_elem_add(struct list_set *map, u32 i, ip_set_id_t id)
128 {
129         struct set_elem *e;
130
131         for (; i < map->size; i++) {
132                 e = list_set_elem(map, i);
133                 swap(e->id, id);
134                 if (e->id == IPSET_INVALID_ID)
135                         break;
136         }
137 }
138
139 static void
140 list_elem_tadd(struct list_set *map, u32 i, ip_set_id_t id,
141                unsigned long timeout)
142 {
143         struct set_telem *e;
144
145         for (; i < map->size; i++) {
146                 e = list_set_telem(map, i);
147                 swap(e->id, id);
148                 swap(e->timeout, timeout);
149                 if (e->id == IPSET_INVALID_ID)
150                         break;
151         }
152 }
153
154 static int
155 list_set_add(struct list_set *map, u32 i, ip_set_id_t id,
156              unsigned long timeout)
157 {
158         const struct set_elem *e = list_set_elem(map, i);
159
160         if (i == map->size - 1 && e->id != IPSET_INVALID_ID)
161                 /* Last element replaced: e.g. add new,before,last */
162                 ip_set_put_byindex(e->id);
163         if (with_timeout(map->timeout))
164                 list_elem_tadd(map, i, id, ip_set_timeout_set(timeout));
165         else
166                 list_elem_add(map, i, id);
167
168         return 0;
169 }
170
171 static int
172 list_set_del(struct list_set *map, u32 i)
173 {
174         struct set_elem *a = list_set_elem(map, i), *b;
175
176         ip_set_put_byindex(a->id);
177
178         for (; i < map->size - 1; i++) {
179                 b = list_set_elem(map, i + 1);
180                 a->id = b->id;
181                 if (with_timeout(map->timeout))
182                         ((struct set_telem *)a)->timeout =
183                                 ((struct set_telem *)b)->timeout;
184                 a = b;
185                 if (a->id == IPSET_INVALID_ID)
186                         break;
187         }
188         /* Last element */
189         a->id = IPSET_INVALID_ID;
190         return 0;
191 }
192
193 static int
194 list_set_uadt(struct ip_set *set, struct nlattr *tb[],
195               enum ipset_adt adt, u32 *lineno, u32 flags)
196 {
197         struct list_set *map = set->data;
198         bool with_timeout = with_timeout(map->timeout);
199         int before = 0;
200         u32 timeout = map->timeout;
201         ip_set_id_t id, refid = IPSET_INVALID_ID;
202         const struct set_elem *elem;
203         struct ip_set *s;
204         u32 i;
205         int ret = 0;
206
207         if (unlikely(!tb[IPSET_ATTR_NAME] ||
208                      !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) ||
209                      !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS)))
210                 return -IPSET_ERR_PROTOCOL;
211
212         if (tb[IPSET_ATTR_LINENO])
213                 *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]);
214
215         id = ip_set_get_byname(nla_data(tb[IPSET_ATTR_NAME]), &s);
216         if (id == IPSET_INVALID_ID)
217                 return -IPSET_ERR_NAME;
218         /* "Loop detection" */
219         if (s->type->features & IPSET_TYPE_NAME) {
220                 ret = -IPSET_ERR_LOOP;
221                 goto finish;
222         }
223
224         if (tb[IPSET_ATTR_CADT_FLAGS]) {
225                 u32 f = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]);
226                 before = f & IPSET_FLAG_BEFORE;
227         }
228
229         if (before && !tb[IPSET_ATTR_NAMEREF]) {
230                 ret = -IPSET_ERR_BEFORE;
231                 goto finish;
232         }
233
234         if (tb[IPSET_ATTR_NAMEREF]) {
235                 refid = ip_set_get_byname(nla_data(tb[IPSET_ATTR_NAMEREF]),
236                                           &s);
237                 if (refid == IPSET_INVALID_ID) {
238                         ret = -IPSET_ERR_NAMEREF;
239                         goto finish;
240                 }
241                 if (!before)
242                         before = -1;
243         }
244         if (tb[IPSET_ATTR_TIMEOUT]) {
245                 if (!with_timeout) {
246                         ret = -IPSET_ERR_TIMEOUT;
247                         goto finish;
248                 }
249                 timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]);
250         }
251
252         switch (adt) {
253         case IPSET_TEST:
254                 for (i = 0; i < map->size && !ret; i++) {
255                         elem = list_set_elem(map, i);
256                         if (elem->id == IPSET_INVALID_ID ||
257                             (before != 0 && i + 1 >= map->size))
258                                 break;
259                         else if (with_timeout && list_set_expired(map, i))
260                                 continue;
261                         else if (before > 0 && elem->id == id)
262                                 ret = next_id_eq(map, i, refid);
263                         else if (before < 0 && elem->id == refid)
264                                 ret = next_id_eq(map, i, id);
265                         else if (before == 0 && elem->id == id)
266                                 ret = 1;
267                 }
268                 break;
269         case IPSET_ADD:
270                 for (i = 0; i < map->size && !ret; i++) {
271                         elem = list_set_elem(map, i);
272                         if (elem->id == id &&
273                             !(with_timeout && list_set_expired(map, i)))
274                                 ret = -IPSET_ERR_EXIST;
275                 }
276                 if (ret == -IPSET_ERR_EXIST)
277                         break;
278                 ret = -IPSET_ERR_LIST_FULL;
279                 for (i = 0; i < map->size && ret == -IPSET_ERR_LIST_FULL; i++) {
280                         elem = list_set_elem(map, i);
281                         if (elem->id == IPSET_INVALID_ID)
282                                 ret = before != 0 ? -IPSET_ERR_REF_EXIST
283                                         : list_set_add(map, i, id, timeout);
284                         else if (elem->id != refid)
285                                 continue;
286                         else if (with_timeout && list_set_expired(map, i))
287                                 ret = -IPSET_ERR_REF_EXIST;
288                         else if (before)
289                                 ret = list_set_add(map, i, id, timeout);
290                         else if (i + 1 < map->size)
291                                 ret = list_set_add(map, i + 1, id, timeout);
292                 }
293                 break;
294         case IPSET_DEL:
295                 ret = -IPSET_ERR_EXIST;
296                 for (i = 0; i < map->size && ret == -IPSET_ERR_EXIST; i++) {
297                         elem = list_set_elem(map, i);
298                         if (elem->id == IPSET_INVALID_ID) {
299                                 ret = before != 0 ? -IPSET_ERR_REF_EXIST
300                                                   : -IPSET_ERR_EXIST;
301                                 break;
302                         } else if (with_timeout && list_set_expired(map, i))
303                                 continue;
304                         else if (elem->id == id &&
305                                  (before == 0 ||
306                                   (before > 0 &&
307                                    next_id_eq(map, i, refid))))
308                                 ret = list_set_del(map, i);
309                         else if (before < 0 &&
310                                  elem->id == refid &&
311                                  next_id_eq(map, i, id))
312                                 ret = list_set_del(map, i + 1);
313                 }
314                 break;
315         default:
316                 break;
317         }
318
319 finish:
320         if (refid != IPSET_INVALID_ID)
321                 ip_set_put_byindex(refid);
322         if (adt != IPSET_ADD || ret)
323                 ip_set_put_byindex(id);
324
325         return ip_set_eexist(ret, flags) ? 0 : ret;
326 }
327
328 static void
329 list_set_flush(struct ip_set *set)
330 {
331         struct list_set *map = set->data;
332         struct set_elem *elem;
333         u32 i;
334
335         for (i = 0; i < map->size; i++) {
336                 elem = list_set_elem(map, i);
337                 if (elem->id != IPSET_INVALID_ID) {
338                         ip_set_put_byindex(elem->id);
339                         elem->id = IPSET_INVALID_ID;
340                 }
341         }
342 }
343
344 static void
345 list_set_destroy(struct ip_set *set)
346 {
347         struct list_set *map = set->data;
348
349         if (with_timeout(map->timeout))
350                 del_timer_sync(&map->gc);
351         list_set_flush(set);
352         kfree(map);
353
354         set->data = NULL;
355 }
356
357 static int
358 list_set_head(struct ip_set *set, struct sk_buff *skb)
359 {
360         const struct list_set *map = set->data;
361         struct nlattr *nested;
362
363         nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
364         if (!nested)
365                 goto nla_put_failure;
366         NLA_PUT_NET32(skb, IPSET_ATTR_SIZE, htonl(map->size));
367         if (with_timeout(map->timeout))
368                 NLA_PUT_NET32(skb, IPSET_ATTR_TIMEOUT, htonl(map->timeout));
369         NLA_PUT_NET32(skb, IPSET_ATTR_REFERENCES,
370                       htonl(atomic_read(&set->ref) - 1));
371         NLA_PUT_NET32(skb, IPSET_ATTR_MEMSIZE,
372                       htonl(sizeof(*map) + map->size * map->dsize));
373         ipset_nest_end(skb, nested);
374
375         return 0;
376 nla_put_failure:
377         return -EMSGSIZE;
378 }
379
380 static int
381 list_set_list(const struct ip_set *set,
382               struct sk_buff *skb, struct netlink_callback *cb)
383 {
384         const struct list_set *map = set->data;
385         struct nlattr *atd, *nested;
386         u32 i, first = cb->args[2];
387         const struct set_elem *e;
388
389         atd = ipset_nest_start(skb, IPSET_ATTR_ADT);
390         if (!atd)
391                 return -EMSGSIZE;
392         for (; cb->args[2] < map->size; cb->args[2]++) {
393                 i = cb->args[2];
394                 e = list_set_elem(map, i);
395                 if (e->id == IPSET_INVALID_ID)
396                         goto finish;
397                 if (with_timeout(map->timeout) && list_set_expired(map, i))
398                         continue;
399                 nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
400                 if (!nested) {
401                         if (i == first) {
402                                 nla_nest_cancel(skb, atd);
403                                 return -EMSGSIZE;
404                         } else
405                                 goto nla_put_failure;
406                 }
407                 NLA_PUT_STRING(skb, IPSET_ATTR_NAME,
408                                ip_set_name_byindex(e->id));
409                 if (with_timeout(map->timeout)) {
410                         const struct set_telem *te =
411                                 (const struct set_telem *) e;
412                         NLA_PUT_NET32(skb, IPSET_ATTR_TIMEOUT,
413                                       htonl(ip_set_timeout_get(te->timeout)));
414                 }
415                 ipset_nest_end(skb, nested);
416         }
417 finish:
418         ipset_nest_end(skb, atd);
419         /* Set listing finished */
420         cb->args[2] = 0;
421         return 0;
422
423 nla_put_failure:
424         nla_nest_cancel(skb, nested);
425         ipset_nest_end(skb, atd);
426         if (unlikely(i == first)) {
427                 cb->args[2] = 0;
428                 return -EMSGSIZE;
429         }
430         return 0;
431 }
432
433 static bool
434 list_set_same_set(const struct ip_set *a, const struct ip_set *b)
435 {
436         const struct list_set *x = a->data;
437         const struct list_set *y = b->data;
438
439         return x->size == y->size &&
440                x->timeout == y->timeout;
441 }
442
443 static const struct ip_set_type_variant list_set = {
444         .kadt   = list_set_kadt,
445         .uadt   = list_set_uadt,
446         .destroy = list_set_destroy,
447         .flush  = list_set_flush,
448         .head   = list_set_head,
449         .list   = list_set_list,
450         .same_set = list_set_same_set,
451 };
452
453 static void
454 list_set_gc(unsigned long ul_set)
455 {
456         struct ip_set *set = (struct ip_set *) ul_set;
457         struct list_set *map = set->data;
458         struct set_telem *e;
459         u32 i;
460         
461         /* nfnl_lock should be called */
462         write_lock_bh(&set->lock);
463         for (i = 0; i < map->size; i++) {
464                 e = list_set_telem(map, i);
465                 if (e->id != IPSET_INVALID_ID && list_set_expired(map, i))
466                         list_set_del(map, i);
467         }
468         write_unlock_bh(&set->lock);
469
470         map->gc.expires = jiffies + IPSET_GC_PERIOD(map->timeout) * HZ;
471         add_timer(&map->gc);
472 }
473
474 static void
475 list_set_gc_init(struct ip_set *set)
476 {
477         struct list_set *map = set->data;
478
479         init_timer(&map->gc);
480         map->gc.data = (unsigned long) set;
481         map->gc.function = list_set_gc;
482         map->gc.expires = jiffies + IPSET_GC_PERIOD(map->timeout) * HZ;
483         add_timer(&map->gc);
484 }
485
486 /* Create list:set type of sets */
487
488 static bool
489 init_list_set(struct ip_set *set, u32 size, size_t dsize,
490               unsigned long timeout)
491 {
492         struct list_set *map;
493         struct set_elem *e;
494         u32 i;
495
496         map = kzalloc(sizeof(*map) + size * dsize, GFP_KERNEL);
497         if (!map)
498                 return false;
499
500         map->size = size;
501         map->dsize = dsize;
502         map->timeout = timeout;
503         set->data = map;
504
505         for (i = 0; i < size; i++) {
506                 e = list_set_elem(map, i);
507                 e->id = IPSET_INVALID_ID;
508         }
509
510         return true;
511 }
512
513 static int
514 list_set_create(struct ip_set *set, struct nlattr *tb[], u32 flags)
515 {
516         u32 size = IP_SET_LIST_DEFAULT_SIZE;
517
518         if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_SIZE) ||
519                      !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT)))
520                 return -IPSET_ERR_PROTOCOL;
521
522         if (tb[IPSET_ATTR_SIZE])
523                 size = ip_set_get_h32(tb[IPSET_ATTR_SIZE]);
524         if (size < IP_SET_LIST_MIN_SIZE)
525                 size = IP_SET_LIST_MIN_SIZE;
526
527         if (tb[IPSET_ATTR_TIMEOUT]) {
528                 if (!init_list_set(set, size, sizeof(struct set_telem),
529                                    ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT])))
530                         return -ENOMEM;
531
532                 list_set_gc_init(set);
533         } else {
534                 if (!init_list_set(set, size, sizeof(struct set_elem),
535                                    IPSET_NO_TIMEOUT))
536                         return -ENOMEM;
537         }
538         set->variant = &list_set;
539         return 0;
540 }
541
542 static struct ip_set_type list_set_type __read_mostly = {
543         .name           = "list:set",
544         .protocol       = IPSET_PROTOCOL,
545         .features       = IPSET_TYPE_NAME | IPSET_DUMP_LAST,
546         .dimension      = IPSET_DIM_ONE,
547         .family         = AF_UNSPEC,
548         .revision       = 0,
549         .create         = list_set_create,
550         .create_policy  = {
551                 [IPSET_ATTR_SIZE]       = { .type = NLA_U32 },
552                 [IPSET_ATTR_TIMEOUT]    = { .type = NLA_U32 },
553         },
554         .adt_policy     = {
555                 [IPSET_ATTR_NAME]       = { .type = NLA_STRING,
556                                             .len = IPSET_MAXNAMELEN },
557                 [IPSET_ATTR_NAMEREF]    = { .type = NLA_STRING,
558                                             .len = IPSET_MAXNAMELEN },
559                 [IPSET_ATTR_TIMEOUT]    = { .type = NLA_U32 },
560                 [IPSET_ATTR_LINENO]     = { .type = NLA_U32 },
561                 [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 },
562         },
563         .me             = THIS_MODULE,
564 };
565
566 static int __init
567 list_set_init(void)
568 {
569         return ip_set_type_register(&list_set_type);
570 }
571
572 static void __exit
573 list_set_fini(void)
574 {
575         ip_set_type_unregister(&list_set_type);
576 }
577
578 module_init(list_set_init);
579 module_exit(list_set_fini);