[CIFS] fix oops on second mount to same server when null auth is used
[pandora-kernel.git] / fs / cifs / asn1.c
1 /*
2  * The ASN.1/BER parsing code is derived from ip_nat_snmp_basic.c which was in
3  * turn derived from the gxsnmp package by Gregory McLean & Jochen Friedrich
4  *
5  * Copyright (c) 2000 RP Internet (www.rpi.net.au).
6  *
7  * This program is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation; either version 2 of the License, or
10  * (at your option) any later version.
11  * This program is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  * You should have received a copy of the GNU General Public License
16  * along with this program; if not, write to the Free Software
17  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
18  */
19
20 #include <linux/module.h>
21 #include <linux/types.h>
22 #include <linux/kernel.h>
23 #include <linux/mm.h>
24 #include <linux/slab.h>
25 #include "cifspdu.h"
26 #include "cifsglob.h"
27 #include "cifs_debug.h"
28 #include "cifsproto.h"
29
30 /*****************************************************************************
31  *
32  * Basic ASN.1 decoding routines (gxsnmp author Dirk Wisse)
33  *
34  *****************************************************************************/
35
/* Identifier class: the top two bits of the identifier octet */
#define ASN1_UNI        0       /* Universal */
#define ASN1_APL        1       /* Application */
#define ASN1_CTX        2       /* Context */
#define ASN1_PRV        3       /* Private */

/* Tag number: the low five bits of the identifier octet */
#define ASN1_EOC        0       /* End Of Contents or N/A */
#define ASN1_BOL        1       /* Boolean */
#define ASN1_INT        2       /* Integer */
#define ASN1_BTS        3       /* Bit String */
#define ASN1_OTS        4       /* Octet String */
#define ASN1_NUL        5       /* Null */
#define ASN1_OJI        6       /* Object Identifier  */
#define ASN1_OJD        7       /* Object Description */
#define ASN1_EXT        8       /* External */
#define ASN1_SEQ        16      /* Sequence */
#define ASN1_SET        17      /* Set */
#define ASN1_NUMSTR     18      /* Numerical String */
#define ASN1_PRNSTR     19      /* Printable String */
#define ASN1_TEXSTR     20      /* Teletext String */
#define ASN1_VIDSTR     21      /* Video String */
#define ASN1_IA5STR     22      /* IA5 String */
#define ASN1_UNITIM     23      /* Universal Time */
#define ASN1_GENTIM     24      /* General Time */
#define ASN1_GRASTR     25      /* Graphical String */
#define ASN1_VISSTR     26      /* Visible String */
#define ASN1_GENSTR     27      /* General String */

/* Encoding method: bit 5 of the identifier octet */
#define ASN1_PRI        0       /* Primitive */
#define ASN1_CON        1       /* Constructed */
68
/*
 * Decoder error codes, stored in the error member of struct asn1_ctx.
 */
#define ASN1_ERR_NOERROR                0       /* nothing went wrong */
#define ASN1_ERR_DEC_EMPTY              2       /* ran out of input octets */
#define ASN1_ERR_DEC_EOC_MISMATCH       3       /* end-of-contents octet was not zero */
#define ASN1_ERR_DEC_LENGTH_MISMATCH    4       /* decoding did not stop at the expected end */
#define ASN1_ERR_DEC_BADVALUE           5       /* value could not be represented */
77
/* Number of components in each of the object identifiers below */
#define SPNEGO_OID_LEN 7
#define NTLMSSP_OID_LEN  10
#define KRB5_OID_LEN  7
#define MSKRB5_OID_LEN  7

/* Object identifiers that decode_negTokenInit() compares against */
static unsigned long SPNEGO_OID[SPNEGO_OID_LEN] = {
        1, 3, 6, 1, 5, 5, 2
};
static unsigned long NTLMSSP_OID[NTLMSSP_OID_LEN] = {
        1, 3, 6, 1, 4, 1, 311, 2, 2, 10
};
static unsigned long KRB5_OID[KRB5_OID_LEN] = {
        1, 2, 840, 113554, 1, 2, 2
};
static unsigned long MSKRB5_OID[MSKRB5_OID_LEN] = {
        1, 2, 840, 48018, 1, 2, 2
};
86
/*
 * Decoder state: a cursor over the buffer handed to asn1_open().
 */
struct asn1_ctx {
        int error;              /* one of the ASN1_ERR_* codes */
        unsigned char *pointer; /* next octet to be decoded */
        unsigned char *begin;   /* first octet of the buffer */
        unsigned char *end;     /* one past the last octet of the buffer */
};
96
/*
 * Counted octet string; the data is not null terminated.
 */
struct asn1_octstr {
        unsigned char *data;    /* the octets */
        unsigned int len;       /* how many octets data holds */
};
104
105 static void
106 asn1_open(struct asn1_ctx *ctx, unsigned char *buf, unsigned int len)
107 {
108         ctx->begin = buf;
109         ctx->end = buf + len;
110         ctx->pointer = buf;
111         ctx->error = ASN1_ERR_NOERROR;
112 }
113
114 static unsigned char
115 asn1_octet_decode(struct asn1_ctx *ctx, unsigned char *ch)
116 {
117         if (ctx->pointer >= ctx->end) {
118                 ctx->error = ASN1_ERR_DEC_EMPTY;
119                 return 0;
120         }
121         *ch = *(ctx->pointer)++;
122         return 1;
123 }
124
/*
 * Decode a multi-octet tag number into @tag.
 *
 * Each octet contributes its low seven bits, most significant group first;
 * an octet with the top bit clear ends the number.  Returns 1 on success and
 * 0 if the input ran out (in which case @tag holds the bits read so far).
 */
static unsigned char
asn1_tag_decode(struct asn1_ctx *ctx, unsigned int *tag)
{
        unsigned char octet;

        *tag = 0;

        for (;;) {
                if (!asn1_octet_decode(ctx, &octet))
                        return 0;

                *tag = (*tag << 7) | (octet & 0x7F);

                if (!(octet & 0x80))
                        return 1;
        }
}
140
/*
 * Decode an identifier: class into @cls, primitive/constructed flag into
 * @con and tag number into @tag.
 *
 * A tag field of 0x1F means the real tag number follows in subsequent
 * octets, which asn1_tag_decode() reads.  Returns 1 on success, 0 if the
 * input ran out.
 */
static unsigned char
asn1_id_decode(struct asn1_ctx *ctx,
               unsigned int *cls, unsigned int *con, unsigned int *tag)
{
        unsigned char id;

        if (!asn1_octet_decode(ctx, &id))
                return 0;

        *cls = (id >> 6) & 0x03;
        *con = (id >> 5) & 0x01;
        *tag = id & 0x1F;

        if (*tag != 0x1F)
                return 1;

        return asn1_tag_decode(ctx, tag) ? 1 : 0;
}
160
161 static unsigned char
162 asn1_length_decode(struct asn1_ctx *ctx, unsigned int *def, unsigned int *len)
163 {
164         unsigned char ch, cnt;
165
166         if (!asn1_octet_decode(ctx, &ch))
167                 return 0;
168
169         if (ch == 0x80)
170                 *def = 0;
171         else {
172                 *def = 1;
173
174                 if (ch < 0x80)
175                         *len = ch;
176                 else {
177                         cnt = (unsigned char) (ch & 0x7F);
178                         *len = 0;
179
180                         while (cnt > 0) {
181                                 if (!asn1_octet_decode(ctx, &ch))
182                                         return 0;
183                                 *len <<= 8;
184                                 *len |= ch;
185                                 cnt--;
186                         }
187                 }
188         }
189         return 1;
190 }
191
192 static unsigned char
193 asn1_header_decode(struct asn1_ctx *ctx,
194                    unsigned char **eoc,
195                    unsigned int *cls, unsigned int *con, unsigned int *tag)
196 {
197         unsigned int def = 0;
198         unsigned int len = 0;
199
200         if (!asn1_id_decode(ctx, cls, con, tag))
201                 return 0;
202
203         if (!asn1_length_decode(ctx, &def, &len))
204                 return 0;
205
206         if (def)
207                 *eoc = ctx->pointer + len;
208         else
209                 *eoc = NULL;
210         return 1;
211 }
212
213 static unsigned char
214 asn1_eoc_decode(struct asn1_ctx *ctx, unsigned char *eoc)
215 {
216         unsigned char ch;
217
218         if (eoc == NULL) {
219                 if (!asn1_octet_decode(ctx, &ch))
220                         return 0;
221
222                 if (ch != 0x00) {
223                         ctx->error = ASN1_ERR_DEC_EOC_MISMATCH;
224                         return 0;
225                 }
226
227                 if (!asn1_octet_decode(ctx, &ch))
228                         return 0;
229
230                 if (ch != 0x00) {
231                         ctx->error = ASN1_ERR_DEC_EOC_MISMATCH;
232                         return 0;
233                 }
234                 return 1;
235         } else {
236                 if (ctx->pointer != eoc) {
237                         ctx->error = ASN1_ERR_DEC_LENGTH_MISMATCH;
238                         return 0;
239                 }
240                 return 1;
241         }
242 }
243
244 /* static unsigned char asn1_null_decode(struct asn1_ctx *ctx,
245                                       unsigned char *eoc)
246 {
247         ctx->pointer = eoc;
248         return 1;
249 }
250
251 static unsigned char asn1_long_decode(struct asn1_ctx *ctx,
252                                       unsigned char *eoc, long *integer)
253 {
254         unsigned char ch;
255         unsigned int len;
256
257         if (!asn1_octet_decode(ctx, &ch))
258                 return 0;
259
260         *integer = (signed char) ch;
261         len = 1;
262
263         while (ctx->pointer < eoc) {
264                 if (++len > sizeof(long)) {
265                         ctx->error = ASN1_ERR_DEC_BADVALUE;
266                         return 0;
267                 }
268
269                 if (!asn1_octet_decode(ctx, &ch))
270                         return 0;
271
272                 *integer <<= 8;
273                 *integer |= ch;
274         }
275         return 1;
276 }
277
278 static unsigned char asn1_uint_decode(struct asn1_ctx *ctx,
279                                       unsigned char *eoc,
280                                       unsigned int *integer)
281 {
282         unsigned char ch;
283         unsigned int len;
284
285         if (!asn1_octet_decode(ctx, &ch))
286                 return 0;
287
288         *integer = ch;
289         if (ch == 0)
290                 len = 0;
291         else
292                 len = 1;
293
294         while (ctx->pointer < eoc) {
295                 if (++len > sizeof(unsigned int)) {
296                         ctx->error = ASN1_ERR_DEC_BADVALUE;
297                         return 0;
298                 }
299
300                 if (!asn1_octet_decode(ctx, &ch))
301                         return 0;
302
303                 *integer <<= 8;
304                 *integer |= ch;
305         }
306         return 1;
307 }
308
309 static unsigned char asn1_ulong_decode(struct asn1_ctx *ctx,
310                                        unsigned char *eoc,
311                                        unsigned long *integer)
312 {
313         unsigned char ch;
314         unsigned int len;
315
316         if (!asn1_octet_decode(ctx, &ch))
317                 return 0;
318
319         *integer = ch;
320         if (ch == 0)
321                 len = 0;
322         else
323                 len = 1;
324
325         while (ctx->pointer < eoc) {
326                 if (++len > sizeof(unsigned long)) {
327                         ctx->error = ASN1_ERR_DEC_BADVALUE;
328                         return 0;
329                 }
330
331                 if (!asn1_octet_decode(ctx, &ch))
332                         return 0;
333
334                 *integer <<= 8;
335                 *integer |= ch;
336         }
337         return 1;
338 }
339
340 static unsigned char
341 asn1_octets_decode(struct asn1_ctx *ctx,
342                    unsigned char *eoc,
343                    unsigned char **octets, unsigned int *len)
344 {
345         unsigned char *ptr;
346
347         *len = 0;
348
349         *octets = kmalloc(eoc - ctx->pointer, GFP_ATOMIC);
350         if (*octets == NULL) {
351                 return 0;
352         }
353
354         ptr = *octets;
355         while (ctx->pointer < eoc) {
356                 if (!asn1_octet_decode(ctx, (unsigned char *) ptr++)) {
357                         kfree(*octets);
358                         *octets = NULL;
359                         return 0;
360                 }
361                 (*len)++;
362         }
363         return 1;
364 } */
365
/*
 * Decode one object identifier component into @subid.
 *
 * Each octet contributes its low seven bits, most significant group first;
 * an octet with the top bit clear ends the component.  Returns 1 on success
 * and 0 if the input ran out (in which case @subid holds the bits read so
 * far).
 */
static unsigned char
asn1_subid_decode(struct asn1_ctx *ctx, unsigned long *subid)
{
        unsigned char octet;

        *subid = 0;

        for (;;) {
                if (!asn1_octet_decode(ctx, &octet))
                        return 0;

                *subid = (*subid << 7) | (octet & 0x7F);

                if (!(octet & 0x80))
                        return 1;
        }
}
382
383 static int
384 asn1_oid_decode(struct asn1_ctx *ctx,
385                 unsigned char *eoc, unsigned long **oid, unsigned int *len)
386 {
387         unsigned long subid;
388         unsigned int size;
389         unsigned long *optr;
390
391         size = eoc - ctx->pointer + 1;
392         *oid = kmalloc(size * sizeof(unsigned long), GFP_ATOMIC);
393         if (*oid == NULL)
394                 return 0;
395
396         optr = *oid;
397
398         if (!asn1_subid_decode(ctx, &subid)) {
399                 kfree(*oid);
400                 *oid = NULL;
401                 return 0;
402         }
403
404         if (subid < 40) {
405                 optr[0] = 0;
406                 optr[1] = subid;
407         } else if (subid < 80) {
408                 optr[0] = 1;
409                 optr[1] = subid - 40;
410         } else {
411                 optr[0] = 2;
412                 optr[1] = subid - 80;
413         }
414
415         *len = 2;
416         optr += 2;
417
418         while (ctx->pointer < eoc) {
419                 if (++(*len) > size) {
420                         ctx->error = ASN1_ERR_DEC_BADVALUE;
421                         kfree(*oid);
422                         *oid = NULL;
423                         return 0;
424                 }
425
426                 if (!asn1_subid_decode(ctx, optr++)) {
427                         kfree(*oid);
428                         *oid = NULL;
429                         return 0;
430                 }
431         }
432         return 1;
433 }
434
/*
 * Compare two object identifiers component by component.
 *
 * Returns 1 when both have the same number of components and every
 * component matches, 0 otherwise.
 */
static int
compare_oid(unsigned long *oid1, unsigned int oid1len,
            unsigned long *oid2, unsigned int oid2len)
{
        unsigned int idx = 0;

        if (oid1len != oid2len)
                return 0;

        while (idx < oid1len) {
                if (oid1[idx] != oid2[idx])
                        return 0;
                idx++;
        }
        return 1;
}
451
452         /* BB check for endian conversion issues here */
453
454 int
455 decode_negTokenInit(unsigned char *security_blob, int length,
456                     enum securityEnum *secType)
457 {
458         struct asn1_ctx ctx;
459         unsigned char *end;
460         unsigned char *sequence_end;
461         unsigned long *oid = NULL;
462         unsigned int cls, con, tag, oidlen, rc;
463         int use_ntlmssp = FALSE;
464         int use_kerberos = FALSE;
465
466         *secType = NTLM; /* BB eventually make Kerberos or NLTMSSP the default*/
467
468         /* cifs_dump_mem(" Received SecBlob ", security_blob, length); */
469
470         asn1_open(&ctx, security_blob, length);
471
472         if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
473                 cFYI(1, ("Error decoding negTokenInit header"));
474                 return 0;
475         } else if ((cls != ASN1_APL) || (con != ASN1_CON)
476                    || (tag != ASN1_EOC)) {
477                 cFYI(1, ("cls = %d con = %d tag = %d", cls, con, tag));
478                 return 0;
479         } else {
480                 /*      remember to free obj->oid */
481                 rc = asn1_header_decode(&ctx, &end, &cls, &con, &tag);
482                 if (rc) {
483                         if ((tag == ASN1_OJI) && (cls == ASN1_PRI)) {
484                                 rc = asn1_oid_decode(&ctx, end, &oid, &oidlen);
485                                 if (rc) {
486                                         rc = compare_oid(oid, oidlen,
487                                                          SPNEGO_OID,
488                                                          SPNEGO_OID_LEN);
489                                         kfree(oid);
490                                 }
491                         } else
492                                 rc = 0;
493                 }
494
495                 if (!rc) {
496                         cFYI(1, ("Error decoding negTokenInit header"));
497                         return 0;
498                 }
499
500                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
501                         cFYI(1, ("Error decoding negTokenInit"));
502                         return 0;
503                 } else if ((cls != ASN1_CTX) || (con != ASN1_CON)
504                            || (tag != ASN1_EOC)) {
505                         cFYI(1,
506                              ("cls = %d con = %d tag = %d end = %p (%d) exit 0",
507                               cls, con, tag, end, *end));
508                         return 0;
509                 }
510
511                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
512                         cFYI(1, ("Error decoding negTokenInit"));
513                         return 0;
514                 } else if ((cls != ASN1_UNI) || (con != ASN1_CON)
515                            || (tag != ASN1_SEQ)) {
516                         cFYI(1,
517                              ("cls = %d con = %d tag = %d end = %p (%d) exit 1",
518                               cls, con, tag, end, *end));
519                         return 0;
520                 }
521
522                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
523                         cFYI(1, ("Error decoding 2nd part of negTokenInit"));
524                         return 0;
525                 } else if ((cls != ASN1_CTX) || (con != ASN1_CON)
526                            || (tag != ASN1_EOC)) {
527                         cFYI(1,
528                              ("cls = %d con = %d tag = %d end = %p (%d) exit 0",
529                               cls, con, tag, end, *end));
530                         return 0;
531                 }
532
533                 if (asn1_header_decode
534                     (&ctx, &sequence_end, &cls, &con, &tag) == 0) {
535                         cFYI(1, ("Error decoding 2nd part of negTokenInit"));
536                         return 0;
537                 } else if ((cls != ASN1_UNI) || (con != ASN1_CON)
538                            || (tag != ASN1_SEQ)) {
539                         cFYI(1,
540                              ("cls = %d con = %d tag = %d end = %p (%d) exit 1",
541                               cls, con, tag, end, *end));
542                         return 0;
543                 }
544
545                 while (!asn1_eoc_decode(&ctx, sequence_end)) {
546                         rc = asn1_header_decode(&ctx, &end, &cls, &con, &tag);
547                         if (!rc) {
548                                 cFYI(1,
549                                      ("Error decoding negTokenInit hdr exit2"));
550                                 return 0;
551                         }
552                         if ((tag == ASN1_OJI) && (con == ASN1_PRI)) {
553                                 if (asn1_oid_decode(&ctx, end, &oid, &oidlen)) {
554
555                                         cFYI(1,
556                                           ("OID len = %d oid = 0x%lx 0x%lx "
557                                            "0x%lx 0x%lx",
558                                            oidlen, *oid, *(oid + 1),
559                                            *(oid + 2), *(oid + 3)));
560
561                                         if (compare_oid(oid, oidlen,
562                                                         MSKRB5_OID,
563                                                         MSKRB5_OID_LEN))
564                                                 use_kerberos = TRUE;
565                                         else if (compare_oid(oid, oidlen,
566                                                              KRB5_OID,
567                                                              KRB5_OID_LEN))
568                                                 use_kerberos = TRUE;
569                                         else if (compare_oid(oid, oidlen,
570                                                              NTLMSSP_OID,
571                                                              NTLMSSP_OID_LEN))
572                                                 use_ntlmssp = TRUE;
573
574                                         kfree(oid);
575                                 }
576                         } else {
577                                 cFYI(1, ("Should be an oid what is going on?"));
578                         }
579                 }
580
581                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
582                         cFYI(1,
583                              ("Error decoding last part negTokenInit exit3"));
584                         return 0;
585                 } else if ((cls != ASN1_CTX) || (con != ASN1_CON)) {
586                         /* tag = 3 indicating mechListMIC */
587                         cFYI(1,
588                              ("Exit 4 cls = %d con = %d tag = %d end = %p (%d)",
589                               cls, con, tag, end, *end));
590                         return 0;
591                 }
592                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
593                         cFYI(1,
594                              ("Error decoding last part negTokenInit exit5"));
595                         return 0;
596                 } else if ((cls != ASN1_UNI) || (con != ASN1_CON)
597                            || (tag != ASN1_SEQ)) {
598                         cFYI(1, ("cls = %d con = %d tag = %d end = %p (%d)",
599                                 cls, con, tag, end, *end));
600                 }
601
602                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
603                         cFYI(1,
604                              ("Error decoding last part negTokenInit exit 7"));
605                         return 0;
606                 } else if ((cls != ASN1_CTX) || (con != ASN1_CON)) {
607                         cFYI(1,
608                              ("Exit 8 cls = %d con = %d tag = %d end = %p (%d)",
609                               cls, con, tag, end, *end));
610                         return 0;
611                 }
612                 if (asn1_header_decode(&ctx, &end, &cls, &con, &tag) == 0) {
613                         cFYI(1,
614                              ("Error decoding last part negTokenInit exit9"));
615                         return 0;
616                 } else if ((cls != ASN1_UNI) || (con != ASN1_PRI)
617                            || (tag != ASN1_GENSTR)) {
618                         cFYI(1,
619                              ("Exit10 cls = %d con = %d tag = %d end = %p (%d)",
620                               cls, con, tag, end, *end));
621                         return 0;
622                 }
623                 cFYI(1, ("Need to call asn1_octets_decode() function for %s",
624                          ctx.pointer)); /* is this UTF-8 or ASCII? */
625         }
626
627         if (use_kerberos)
628                 *secType = Kerberos;
629         else if (use_ntlmssp)
630                 *secType = NTLMSSP;
631
632         return 1;
633 }