crypto: algif_skcipher - Require setkey before accept(2)
[pandora-kernel.git] / crypto / algif_skcipher.c
1 /*
2  * algif_skcipher: User-space interface for skcipher algorithms
3  *
4  * This file provides the user-space API for symmetric key ciphers.
5  *
6  * Copyright (c) 2010 Herbert Xu <herbert@gondor.apana.org.au>
7  *
8  * This program is free software; you can redistribute it and/or modify it
9  * under the terms of the GNU General Public License as published by the Free
10  * Software Foundation; either version 2 of the License, or (at your option)
11  * any later version.
12  *
13  */
14
15 #include <crypto/scatterwalk.h>
16 #include <crypto/skcipher.h>
17 #include <crypto/if_alg.h>
18 #include <linux/init.h>
19 #include <linux/list.h>
20 #include <linux/kernel.h>
21 #include <linux/mm.h>
22 #include <linux/module.h>
23 #include <linux/net.h>
24 #include <net/sock.h>
25
26 struct skcipher_sg_list {
27         struct list_head list;
28
29         int cur;
30
31         struct scatterlist sg[0];
32 };
33
34 struct skcipher_tfm {
35         struct crypto_ablkcipher *skcipher;
36         bool has_key;
37 };
38
39 struct skcipher_ctx {
40         struct list_head tsgl;
41         struct af_alg_sgl rsgl;
42
43         void *iv;
44
45         struct af_alg_completion completion;
46
47         unsigned used;
48
49         unsigned int len;
50         bool more;
51         bool merge;
52         bool enc;
53
54         struct ablkcipher_request req;
55 };
56
57 #define MAX_SGL_ENTS ((4096 - sizeof(struct skcipher_sg_list)) / \
58                       sizeof(struct scatterlist) - 1)
59
60 static inline int skcipher_sndbuf(struct sock *sk)
61 {
62         struct alg_sock *ask = alg_sk(sk);
63         struct skcipher_ctx *ctx = ask->private;
64
65         return max_t(int, max_t(int, sk->sk_sndbuf & PAGE_MASK, PAGE_SIZE) -
66                           ctx->used, 0);
67 }
68
69 static inline bool skcipher_writable(struct sock *sk)
70 {
71         return PAGE_SIZE <= skcipher_sndbuf(sk);
72 }
73
74 static int skcipher_alloc_sgl(struct sock *sk)
75 {
76         struct alg_sock *ask = alg_sk(sk);
77         struct skcipher_ctx *ctx = ask->private;
78         struct skcipher_sg_list *sgl;
79         struct scatterlist *sg = NULL;
80
81         sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
82         if (!list_empty(&ctx->tsgl))
83                 sg = sgl->sg;
84
85         if (!sg || sgl->cur >= MAX_SGL_ENTS) {
86                 sgl = sock_kmalloc(sk, sizeof(*sgl) +
87                                        sizeof(sgl->sg[0]) * (MAX_SGL_ENTS + 1),
88                                    GFP_KERNEL);
89                 if (!sgl)
90                         return -ENOMEM;
91
92                 sg_init_table(sgl->sg, MAX_SGL_ENTS + 1);
93                 sgl->cur = 0;
94
95                 if (sg)
96                         scatterwalk_sg_chain(sg, MAX_SGL_ENTS + 1, sgl->sg);
97
98                 list_add_tail(&sgl->list, &ctx->tsgl);
99         }
100
101         return 0;
102 }
103
104 static void skcipher_pull_sgl(struct sock *sk, int used)
105 {
106         struct alg_sock *ask = alg_sk(sk);
107         struct skcipher_ctx *ctx = ask->private;
108         struct skcipher_sg_list *sgl;
109         struct scatterlist *sg;
110         int i;
111
112         while (!list_empty(&ctx->tsgl)) {
113                 sgl = list_first_entry(&ctx->tsgl, struct skcipher_sg_list,
114                                        list);
115                 sg = sgl->sg;
116
117                 for (i = 0; i < sgl->cur; i++) {
118                         int plen = min_t(int, used, sg[i].length);
119
120                         if (!sg_page(sg + i))
121                                 continue;
122
123                         sg[i].length -= plen;
124                         sg[i].offset += plen;
125
126                         used -= plen;
127                         ctx->used -= plen;
128
129                         if (sg[i].length)
130                                 return;
131
132                         put_page(sg_page(sg + i));
133                         sg_assign_page(sg + i, NULL);
134                 }
135
136                 list_del(&sgl->list);
137                 sock_kfree_s(sk, sgl,
138                              sizeof(*sgl) + sizeof(sgl->sg[0]) *
139                                             (MAX_SGL_ENTS + 1));
140         }
141
142         if (!ctx->used)
143                 ctx->merge = 0;
144 }
145
146 static void skcipher_free_sgl(struct sock *sk)
147 {
148         struct alg_sock *ask = alg_sk(sk);
149         struct skcipher_ctx *ctx = ask->private;
150
151         skcipher_pull_sgl(sk, ctx->used);
152 }
153
154 static int skcipher_wait_for_wmem(struct sock *sk, unsigned flags)
155 {
156         long timeout;
157         DEFINE_WAIT(wait);
158         int err = -ERESTARTSYS;
159
160         if (flags & MSG_DONTWAIT)
161                 return -EAGAIN;
162
163         set_bit(SOCK_ASYNC_NOSPACE, &sk->sk_socket->flags);
164
165         for (;;) {
166                 if (signal_pending(current))
167                         break;
168                 prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
169                 timeout = MAX_SCHEDULE_TIMEOUT;
170                 if (sk_wait_event(sk, &timeout, skcipher_writable(sk))) {
171                         err = 0;
172                         break;
173                 }
174         }
175         finish_wait(sk_sleep(sk), &wait);
176
177         return err;
178 }
179
180 static void skcipher_wmem_wakeup(struct sock *sk)
181 {
182         struct socket_wq *wq;
183
184         if (!skcipher_writable(sk))
185                 return;
186
187         rcu_read_lock();
188         wq = rcu_dereference(sk->sk_wq);
189         if (wq_has_sleeper(wq))
190                 wake_up_interruptible_sync_poll(&wq->wait, POLLIN |
191                                                            POLLRDNORM |
192                                                            POLLRDBAND);
193         sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
194         rcu_read_unlock();
195 }
196
197 static int skcipher_wait_for_data(struct sock *sk, unsigned flags)
198 {
199         struct alg_sock *ask = alg_sk(sk);
200         struct skcipher_ctx *ctx = ask->private;
201         long timeout;
202         DEFINE_WAIT(wait);
203         int err = -ERESTARTSYS;
204
205         if (flags & MSG_DONTWAIT) {
206                 return -EAGAIN;
207         }
208
209         set_bit(SOCK_ASYNC_WAITDATA, &sk->sk_socket->flags);
210
211         for (;;) {
212                 if (signal_pending(current))
213                         break;
214                 prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
215                 timeout = MAX_SCHEDULE_TIMEOUT;
216                 if (sk_wait_event(sk, &timeout, ctx->used)) {
217                         err = 0;
218                         break;
219                 }
220         }
221         finish_wait(sk_sleep(sk), &wait);
222
223         clear_bit(SOCK_ASYNC_WAITDATA, &sk->sk_socket->flags);
224
225         return err;
226 }
227
228 static void skcipher_data_wakeup(struct sock *sk)
229 {
230         struct alg_sock *ask = alg_sk(sk);
231         struct skcipher_ctx *ctx = ask->private;
232         struct socket_wq *wq;
233
234         if (!ctx->used)
235                 return;
236
237         rcu_read_lock();
238         wq = rcu_dereference(sk->sk_wq);
239         if (wq_has_sleeper(wq))
240                 wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
241                                                            POLLRDNORM |
242                                                            POLLRDBAND);
243         sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT);
244         rcu_read_unlock();
245 }
246
247 static int skcipher_sendmsg(struct kiocb *unused, struct socket *sock,
248                             struct msghdr *msg, size_t size)
249 {
250         struct sock *sk = sock->sk;
251         struct alg_sock *ask = alg_sk(sk);
252         struct skcipher_ctx *ctx = ask->private;
253         struct crypto_ablkcipher *tfm = crypto_ablkcipher_reqtfm(&ctx->req);
254         unsigned ivsize = crypto_ablkcipher_ivsize(tfm);
255         struct skcipher_sg_list *sgl;
256         struct af_alg_control con = {};
257         long copied = 0;
258         bool enc = 0;
259         int err;
260         int i;
261
262         if (msg->msg_controllen) {
263                 err = af_alg_cmsg_send(msg, &con);
264                 if (err)
265                         return err;
266
267                 switch (con.op) {
268                 case ALG_OP_ENCRYPT:
269                         enc = 1;
270                         break;
271                 case ALG_OP_DECRYPT:
272                         enc = 0;
273                         break;
274                 default:
275                         return -EINVAL;
276                 }
277
278                 if (con.iv && con.iv->ivlen != ivsize)
279                         return -EINVAL;
280         }
281
282         err = -EINVAL;
283
284         lock_sock(sk);
285         if (!ctx->more && ctx->used)
286                 goto unlock;
287
288         if (!ctx->used) {
289                 ctx->enc = enc;
290                 if (con.iv)
291                         memcpy(ctx->iv, con.iv->iv, ivsize);
292         }
293
294         while (size) {
295                 struct scatterlist *sg;
296                 unsigned long len = size;
297                 int plen;
298
299                 if (ctx->merge) {
300                         sgl = list_entry(ctx->tsgl.prev,
301                                          struct skcipher_sg_list, list);
302                         sg = sgl->sg + sgl->cur - 1;
303                         len = min_t(unsigned long, len,
304                                     PAGE_SIZE - sg->offset - sg->length);
305
306                         err = memcpy_fromiovec(page_address(sg_page(sg)) +
307                                                sg->offset + sg->length,
308                                                msg->msg_iov, len);
309                         if (err)
310                                 goto unlock;
311
312                         sg->length += len;
313                         ctx->merge = (sg->offset + sg->length) &
314                                      (PAGE_SIZE - 1);
315
316                         ctx->used += len;
317                         copied += len;
318                         size -= len;
319                         continue;
320                 }
321
322                 if (!skcipher_writable(sk)) {
323                         err = skcipher_wait_for_wmem(sk, msg->msg_flags);
324                         if (err)
325                                 goto unlock;
326                 }
327
328                 len = min_t(unsigned long, len, skcipher_sndbuf(sk));
329
330                 err = skcipher_alloc_sgl(sk);
331                 if (err)
332                         goto unlock;
333
334                 sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
335                 sg = sgl->sg;
336                 do {
337                         i = sgl->cur;
338                         plen = min_t(int, len, PAGE_SIZE);
339
340                         sg_assign_page(sg + i, alloc_page(GFP_KERNEL));
341                         err = -ENOMEM;
342                         if (!sg_page(sg + i))
343                                 goto unlock;
344
345                         err = memcpy_fromiovec(page_address(sg_page(sg + i)),
346                                                msg->msg_iov, plen);
347                         if (err) {
348                                 __free_page(sg_page(sg + i));
349                                 sg_assign_page(sg + i, NULL);
350                                 goto unlock;
351                         }
352
353                         sg[i].length = plen;
354                         len -= plen;
355                         ctx->used += plen;
356                         copied += plen;
357                         size -= plen;
358                         sgl->cur++;
359                 } while (len && sgl->cur < MAX_SGL_ENTS);
360
361                 ctx->merge = plen & (PAGE_SIZE - 1);
362         }
363
364         err = 0;
365
366         ctx->more = msg->msg_flags & MSG_MORE;
367         if (!ctx->more && !list_empty(&ctx->tsgl))
368                 sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
369
370 unlock:
371         skcipher_data_wakeup(sk);
372         release_sock(sk);
373
374         return copied ?: err;
375 }
376
377 static ssize_t skcipher_sendpage(struct socket *sock, struct page *page,
378                                  int offset, size_t size, int flags)
379 {
380         struct sock *sk = sock->sk;
381         struct alg_sock *ask = alg_sk(sk);
382         struct skcipher_ctx *ctx = ask->private;
383         struct skcipher_sg_list *sgl;
384         int err = -EINVAL;
385
386         if (flags & MSG_SENDPAGE_NOTLAST)
387                 flags |= MSG_MORE;
388
389         lock_sock(sk);
390         if (!ctx->more && ctx->used)
391                 goto unlock;
392
393         if (!size)
394                 goto done;
395
396         if (!skcipher_writable(sk)) {
397                 err = skcipher_wait_for_wmem(sk, flags);
398                 if (err)
399                         goto unlock;
400         }
401
402         err = skcipher_alloc_sgl(sk);
403         if (err)
404                 goto unlock;
405
406         ctx->merge = 0;
407         sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
408
409         get_page(page);
410         sg_set_page(sgl->sg + sgl->cur, page, size, offset);
411         sgl->cur++;
412         ctx->used += size;
413
414 done:
415         ctx->more = flags & MSG_MORE;
416         if (!ctx->more && !list_empty(&ctx->tsgl))
417                 sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
418
419 unlock:
420         skcipher_data_wakeup(sk);
421         release_sock(sk);
422
423         return err ?: size;
424 }
425
426 static int skcipher_recvmsg(struct kiocb *unused, struct socket *sock,
427                             struct msghdr *msg, size_t ignored, int flags)
428 {
429         struct sock *sk = sock->sk;
430         struct alg_sock *ask = alg_sk(sk);
431         struct skcipher_ctx *ctx = ask->private;
432         unsigned bs = crypto_ablkcipher_blocksize(crypto_ablkcipher_reqtfm(
433                 &ctx->req));
434         struct skcipher_sg_list *sgl;
435         struct scatterlist *sg;
436         unsigned long iovlen;
437         struct iovec *iov;
438         int err = -EAGAIN;
439         int used;
440         long copied = 0;
441
442         lock_sock(sk);
443         for (iov = msg->msg_iov, iovlen = msg->msg_iovlen; iovlen > 0;
444              iovlen--, iov++) {
445                 unsigned long seglen = iov->iov_len;
446                 char __user *from = iov->iov_base;
447
448                 while (seglen) {
449                         sgl = list_first_entry(&ctx->tsgl,
450                                                struct skcipher_sg_list, list);
451                         sg = sgl->sg;
452
453                         while (!sg->length)
454                                 sg++;
455
456                         used = ctx->used;
457                         if (!used) {
458                                 err = skcipher_wait_for_data(sk, flags);
459                                 if (err)
460                                         goto unlock;
461                         }
462
463                         used = min_t(unsigned long, used, seglen);
464
465                         used = af_alg_make_sg(&ctx->rsgl, from, used, 1);
466                         err = used;
467                         if (err < 0)
468                                 goto unlock;
469
470                         if (ctx->more || used < ctx->used)
471                                 used -= used % bs;
472
473                         err = -EINVAL;
474                         if (!used)
475                                 goto free;
476
477                         ablkcipher_request_set_crypt(&ctx->req, sg,
478                                                      ctx->rsgl.sg, used,
479                                                      ctx->iv);
480
481                         err = af_alg_wait_for_completion(
482                                 ctx->enc ?
483                                         crypto_ablkcipher_encrypt(&ctx->req) :
484                                         crypto_ablkcipher_decrypt(&ctx->req),
485                                 &ctx->completion);
486
487 free:
488                         af_alg_free_sg(&ctx->rsgl);
489
490                         if (err)
491                                 goto unlock;
492
493                         copied += used;
494                         from += used;
495                         seglen -= used;
496                         skcipher_pull_sgl(sk, used);
497                 }
498         }
499
500         err = 0;
501
502 unlock:
503         skcipher_wmem_wakeup(sk);
504         release_sock(sk);
505
506         return copied ?: err;
507 }
508
509
510 static unsigned int skcipher_poll(struct file *file, struct socket *sock,
511                                   poll_table *wait)
512 {
513         struct sock *sk = sock->sk;
514         struct alg_sock *ask = alg_sk(sk);
515         struct skcipher_ctx *ctx = ask->private;
516         unsigned int mask;
517
518         sock_poll_wait(file, sk_sleep(sk), wait);
519         mask = 0;
520
521         if (ctx->used)
522                 mask |= POLLIN | POLLRDNORM;
523
524         if (skcipher_writable(sk))
525                 mask |= POLLOUT | POLLWRNORM | POLLWRBAND;
526
527         return mask;
528 }
529
530 static struct proto_ops algif_skcipher_ops = {
531         .family         =       PF_ALG,
532
533         .connect        =       sock_no_connect,
534         .socketpair     =       sock_no_socketpair,
535         .getname        =       sock_no_getname,
536         .ioctl          =       sock_no_ioctl,
537         .listen         =       sock_no_listen,
538         .shutdown       =       sock_no_shutdown,
539         .getsockopt     =       sock_no_getsockopt,
540         .mmap           =       sock_no_mmap,
541         .bind           =       sock_no_bind,
542         .accept         =       sock_no_accept,
543         .setsockopt     =       sock_no_setsockopt,
544
545         .release        =       af_alg_release,
546         .sendmsg        =       skcipher_sendmsg,
547         .sendpage       =       skcipher_sendpage,
548         .recvmsg        =       skcipher_recvmsg,
549         .poll           =       skcipher_poll,
550 };
551
552 static void *skcipher_bind(const char *name, u32 type, u32 mask)
553 {
554         struct skcipher_tfm *tfm;
555         struct crypto_ablkcipher *skcipher;
556
557         tfm = kzalloc(sizeof(*tfm), GFP_KERNEL);
558         if (!tfm)
559                 return ERR_PTR(-ENOMEM);
560
561         skcipher = crypto_alloc_ablkcipher(name, type, mask);
562         if (IS_ERR(skcipher)) {
563                 kfree(tfm);
564                 return ERR_CAST(skcipher);
565         }
566
567         tfm->skcipher = skcipher;
568
569         return tfm;
570 }
571
572 static void skcipher_release(void *private)
573 {
574         struct skcipher_tfm *tfm = private;
575
576         crypto_free_ablkcipher(tfm->skcipher);
577         kfree(tfm);
578 }
579
580 static int skcipher_setkey(void *private, const u8 *key, unsigned int keylen)
581 {
582         struct skcipher_tfm *tfm = private;
583         int err;
584
585         err = crypto_ablkcipher_setkey(tfm->skcipher, key, keylen);
586         tfm->has_key = !err;
587
588         return err;
589 }
590
591 static void skcipher_sock_destruct(struct sock *sk)
592 {
593         struct alg_sock *ask = alg_sk(sk);
594         struct skcipher_ctx *ctx = ask->private;
595         struct crypto_ablkcipher *tfm = crypto_ablkcipher_reqtfm(&ctx->req);
596
597         skcipher_free_sgl(sk);
598         sock_kfree_s(sk, ctx->iv, crypto_ablkcipher_ivsize(tfm));
599         sock_kfree_s(sk, ctx, ctx->len);
600         af_alg_release_parent(sk);
601 }
602
603 static int skcipher_accept_parent(void *private, struct sock *sk)
604 {
605         struct skcipher_ctx *ctx;
606         struct alg_sock *ask = alg_sk(sk);
607         struct skcipher_tfm *tfm = private;
608         struct crypto_ablkcipher *skcipher = tfm->skcipher;
609         unsigned int len = sizeof(*ctx) + crypto_ablkcipher_reqsize(skcipher);
610
611         if (!tfm->has_key)
612                 return -ENOKEY;
613
614         ctx = sock_kmalloc(sk, len, GFP_KERNEL);
615         if (!ctx)
616                 return -ENOMEM;
617
618         ctx->iv = sock_kmalloc(sk, crypto_ablkcipher_ivsize(skcipher),
619                                GFP_KERNEL);
620         if (!ctx->iv) {
621                 sock_kfree_s(sk, ctx, len);
622                 return -ENOMEM;
623         }
624
625         memset(ctx->iv, 0, crypto_ablkcipher_ivsize(skcipher));
626
627         INIT_LIST_HEAD(&ctx->tsgl);
628         ctx->len = len;
629         ctx->used = 0;
630         ctx->more = 0;
631         ctx->merge = 0;
632         ctx->enc = 0;
633         af_alg_init_completion(&ctx->completion);
634
635         ask->private = ctx;
636
637         ablkcipher_request_set_tfm(&ctx->req, skcipher);
638         ablkcipher_request_set_callback(&ctx->req, CRYPTO_TFM_REQ_MAY_BACKLOG,
639                                         af_alg_complete, &ctx->completion);
640
641         sk->sk_destruct = skcipher_sock_destruct;
642
643         return 0;
644 }
645
646 static const struct af_alg_type algif_type_skcipher = {
647         .bind           =       skcipher_bind,
648         .release        =       skcipher_release,
649         .setkey         =       skcipher_setkey,
650         .accept         =       skcipher_accept_parent,
651         .ops            =       &algif_skcipher_ops,
652         .name           =       "skcipher",
653         .owner          =       THIS_MODULE
654 };
655
656 static int __init algif_skcipher_init(void)
657 {
658         return af_alg_register_type(&algif_type_skcipher);
659 }
660
661 static void __exit algif_skcipher_exit(void)
662 {
663         int err = af_alg_unregister_type(&algif_type_skcipher);
664         BUG_ON(err);
665 }
666
667 module_init(algif_skcipher_init);
668 module_exit(algif_skcipher_exit);
669 MODULE_LICENSE("GPL");