ipv4: Avoid reading user iov twice after raw_probe_proto_opt
[pandora-kernel.git] / net / ipv4 / raw.c
index 007e2eb..f250aa7 100644 (file)
 #include <linux/netfilter.h>
 #include <linux/netfilter_ipv4.h>
 #include <linux/compat.h>
+#include <linux/uio.h>
+
+struct raw_frag_vec {
+       struct iovec *iov;
+       union {
+               struct icmphdr icmph;
+               char c[1];
+       } hdr;
+       int hlen;
+};
 
 static struct raw_hashinfo raw_v4_hashinfo = {
        .lock = __RW_LOCK_UNLOCKED(raw_v4_hashinfo.lock),
@@ -131,18 +141,20 @@ found:
  *     0 - deliver
  *     1 - block
  */
-static __inline__ int icmp_filter(struct sock *sk, struct sk_buff *skb)
+static int icmp_filter(const struct sock *sk, const struct sk_buff *skb)
 {
-       int type;
+       struct icmphdr _hdr;
+       const struct icmphdr *hdr;
 
-       if (!pskb_may_pull(skb, sizeof(struct icmphdr)))
+       hdr = skb_header_pointer(skb, skb_transport_offset(skb),
+                                sizeof(_hdr), &_hdr);
+       if (!hdr)
                return 1;
 
-       type = icmp_hdr(skb)->type;
-       if (type < 32) {
+       if (hdr->type < 32) {
                __u32 data = raw_sk(sk)->filter.data;
 
-               return ((1 << type) & data) != 0;
+               return ((1U << hdr->type) & data) != 0;
        }
 
        /* Do not block unknown ICMP types */
@@ -327,6 +339,7 @@ static int raw_send_hdrinc(struct sock *sk, struct flowi4 *fl4,
        unsigned int iphlen;
        int err;
        struct rtable *rt = *rtp;
+       int hlen, tlen;
 
        if (length > rt->dst.dev->mtu) {
                ip_local_error(sk, EMSGSIZE, fl4->daddr, inet->inet_dport,
@@ -336,12 +349,14 @@ static int raw_send_hdrinc(struct sock *sk, struct flowi4 *fl4,
        if (flags&MSG_PROBE)
                goto out;
 
+       hlen = LL_RESERVED_SPACE(rt->dst.dev);
+       tlen = rt->dst.dev->needed_tailroom;
        skb = sock_alloc_send_skb(sk,
-                                 length + LL_ALLOCATED_SPACE(rt->dst.dev) + 15,
+                                 length + hlen + tlen + 15,
                                  flags & MSG_DONTWAIT, &err);
        if (skb == NULL)
                goto error;
-       skb_reserve(skb, LL_RESERVED_SPACE(rt->dst.dev));
+       skb_reserve(skb, hlen);
 
        skb->priority = sk->sk_priority;
        skb->mark = sk->sk_mark;
@@ -378,7 +393,7 @@ static int raw_send_hdrinc(struct sock *sk, struct flowi4 *fl4,
                iph->check   = 0;
                iph->tot_len = htons(length);
                if (!iph->id)
-                       ip_select_ident(iph, &rt->dst, NULL);
+                       ip_select_ident(skb, NULL);
 
                iph->check = ip_fast_csum((unsigned char *)iph, iph->ihl);
        }
@@ -404,53 +419,57 @@ error:
        return err;
 }
 
-static int raw_probe_proto_opt(struct flowi4 *fl4, struct msghdr *msg)
+static int raw_probe_proto_opt(struct raw_frag_vec *rfv, struct flowi4 *fl4)
 {
-       struct iovec *iov;
-       u8 __user *type = NULL;
-       u8 __user *code = NULL;
-       int probed = 0;
-       unsigned int i;
+       int err;
 
-       if (!msg->msg_iov)
+       if (fl4->flowi4_proto != IPPROTO_ICMP)
                return 0;
 
-       for (i = 0; i < msg->msg_iovlen; i++) {
-               iov = &msg->msg_iov[i];
-               if (!iov)
-                       continue;
-
-               switch (fl4->flowi4_proto) {
-               case IPPROTO_ICMP:
-                       /* check if one-byte field is readable or not. */
-                       if (iov->iov_base && iov->iov_len < 1)
-                               break;
-
-                       if (!type) {
-                               type = iov->iov_base;
-                               /* check if code field is readable or not. */
-                               if (iov->iov_len > 1)
-                                       code = type + 1;
-                       } else if (!code)
-                               code = iov->iov_base;
-
-                       if (type && code) {
-                               if (get_user(fl4->fl4_icmp_type, type) ||
-                                   get_user(fl4->fl4_icmp_code, code))
-                                       return -EFAULT;
-                               probed = 1;
-                       }
-                       break;
-               default:
-                       probed = 1;
-                       break;
-               }
-               if (probed)
-                       break;
-       }
+       /* We only need the first two bytes. */
+       rfv->hlen = 2;
+
+       err = memcpy_fromiovec(rfv->hdr.c, rfv->iov, rfv->hlen);
+       if (err)
+               return err;
+
+       fl4->fl4_icmp_type = rfv->hdr.icmph.type;
+       fl4->fl4_icmp_code = rfv->hdr.icmph.code;
+
        return 0;
 }
 
+static int raw_getfrag(void *from, char *to, int offset, int len, int odd,
+                      struct sk_buff *skb)
+{
+       struct raw_frag_vec *rfv = from;
+
+       if (offset < rfv->hlen) {
+               int copy = min(rfv->hlen - offset, len);
+
+               if (skb->ip_summed == CHECKSUM_PARTIAL)
+                       memcpy(to, rfv->hdr.c + offset, copy);
+               else
+                       skb->csum = csum_block_add(
+                               skb->csum,
+                               csum_partial_copy_nocheck(rfv->hdr.c + offset,
+                                                         to, copy, 0),
+                               odd);
+
+               odd = 0;
+               offset += copy;
+               to += copy;
+               len -= copy;
+
+               if (!len)
+                       return 0;
+       }
+
+       offset -= rfv->hlen;
+
+       return ip_generic_getfrag(rfv->iov, to, offset, len, odd, skb);
+}
+
 static int raw_sendmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg,
                       size_t len)
 {
@@ -464,6 +483,7 @@ static int raw_sendmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg,
        u8  tos;
        int err;
        struct ip_options_data opt_copy;
+       struct raw_frag_vec rfv;
 
        err = -EMSGSIZE;
        if (len > 0xFFFF)
@@ -515,8 +535,10 @@ static int raw_sendmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg,
 
        if (msg->msg_controllen) {
                err = ip_cmsg_send(sock_net(sk), msg, &ipc);
-               if (err)
+               if (unlikely(err)) {
+                       kfree(ipc.opt);
                        goto out;
+               }
                if (ipc.opt)
                        free = 1;
        }
@@ -568,7 +590,10 @@ static int raw_sendmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg,
                           daddr, saddr, 0, 0);
 
        if (!inet->hdrincl) {
-               err = raw_probe_proto_opt(&fl4, msg);
+               rfv.iov = msg->msg_iov;
+               rfv.hlen = 0;
+
+               err = raw_probe_proto_opt(&rfv, &fl4);
                if (err)
                        goto done;
        }
@@ -597,8 +622,8 @@ back_from_confirm:
                if (!ipc.addr)
                        ipc.addr = fl4.daddr;
                lock_sock(sk);
-               err = ip_append_data(sk, &fl4, ip_generic_getfrag,
-                                    msg->msg_iov, len, 0,
+               err = ip_append_data(sk, &fl4, raw_getfrag,
+                                    &rfv, len, 0,
                                     &ipc, &rt, msg->msg_flags);
                if (err)
                        ip_flush_pending_frames(sk);
@@ -684,11 +709,8 @@ static int raw_recvmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg,
        if (flags & MSG_OOB)
                goto out;
 
-       if (addr_len)
-               *addr_len = sizeof(*sin);
-
        if (flags & MSG_ERRQUEUE) {
-               err = ip_recv_error(sk, msg, len);
+               err = ip_recv_error(sk, msg, len, addr_len);
                goto out;
        }
 
@@ -714,6 +736,7 @@ static int raw_recvmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg,
                sin->sin_addr.s_addr = ip_hdr(skb)->saddr;
                sin->sin_port = 0;
                memset(&sin->sin_zero, 0, sizeof(sin->sin_zero));
+               *addr_len = sizeof(*sin);
        }
        if (inet->cmsg_flags)
                ip_cmsg_recv(msg, skb);