pandora: defconfig: update
[pandora-kernel.git] / net / packet / af_packet.c
index 3a5a58f..34cc788 100644 (file)
@@ -1294,10 +1294,6 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
 
        mutex_lock(&fanout_mutex);
 
-       err = -EINVAL;
-       if (!po->running)
-               goto out;
-
        err = -EALREADY;
        if (po->fanout)
                goto out;
@@ -1334,7 +1330,10 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
                list_add(&match->list, &fanout_list);
        }
        err = -EINVAL;
-       if (match->type == type &&
+
+       spin_lock(&po->bind_lock);
+       if (po->running &&
+           match->type == type &&
            match->prot_hook.type == po->prot_hook.type &&
            match->prot_hook.dev == po->prot_hook.dev) {
                err = -ENOSPC;
@@ -1346,6 +1345,13 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
                        err = 0;
                }
        }
+       spin_unlock(&po->bind_lock);
+
+       if (err && !atomic_read(&match->sk_ref)) {
+               list_del(&match->list);
+               kfree(match);
+       }
+
 out:
        mutex_unlock(&fanout_mutex);
        return err;
@@ -2225,6 +2231,7 @@ static int packet_snd(struct socket *sock,
        int offset = 0;
        int vnet_hdr_len;
        struct packet_sock *po = pkt_sk(sk);
+       bool has_vnet_hdr = false;
        unsigned short gso_type = 0;
 
        /*
@@ -2257,6 +2264,7 @@ static int packet_snd(struct socket *sock,
                reserve = dev->hard_header_len;
        if (po->has_vnet_hdr) {
                vnet_hdr_len = sizeof(vnet_hdr);
+               has_vnet_hdr = true;
 
                err = -EINVAL;
                if (len < vnet_hdr_len)
@@ -2348,7 +2356,7 @@ static int packet_snd(struct socket *sock,
        skb->priority = sk->sk_priority;
        skb->mark = sk->sk_mark;
 
-       if (po->has_vnet_hdr) {
+       if (has_vnet_hdr) {
                if (vnet_hdr.flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) {
                        if (!skb_partial_csum_set(skb, vnet_hdr.csum_start,
                                                  vnet_hdr.csum_offset)) {
@@ -2468,33 +2476,77 @@ static int packet_release(struct socket *sock)
  *     Attach a packet hook.
  */
 
-static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protocol)
+static int packet_do_bind(struct sock *sk, const char *name, int ifindex,
+                         __be16 protocol)
 {
        struct packet_sock *po = pkt_sk(sk);
+       struct net_device *dev_curr;
+       struct net_device *dev = NULL;
+       int ret = 0;
+       bool unlisted = false;
+
+       lock_sock(sk);
+
+       spin_lock(&po->bind_lock);
+       rcu_read_lock();
 
        if (po->fanout) {
-               if (dev)
-                       dev_put(dev);
+               ret = -EINVAL;
+               goto out_unlock;
+       }
 
-               return -EINVAL;
+       if (name) {
+               dev = dev_get_by_name_rcu(sock_net(sk), name);
+               if (!dev) {
+                       ret = -ENODEV;
+                       goto out_unlock;
+               }
+       } else if (ifindex) {
+               dev = dev_get_by_index_rcu(sock_net(sk), ifindex);
+               if (!dev) {
+                       ret = -ENODEV;
+                       goto out_unlock;
+               }
        }
 
-       lock_sock(sk);
+       if (dev)
+               dev_hold(dev);
 
-       spin_lock(&po->bind_lock);
-       unregister_prot_hook(sk, true);
+       dev_curr = po->prot_hook.dev;
+
+       if (po->running) {
+               rcu_read_unlock();
+               /* prevents packet_notifier() from calling
+                * register_prot_hook()
+                */
+               po->num = 0;
+               __unregister_prot_hook(sk, true);
+               rcu_read_lock();
+               dev_curr = po->prot_hook.dev;
+               if (dev)
+                       unlisted = !dev_get_by_index_rcu(sock_net(sk),
+                                                        dev->ifindex);
+       }
+       BUG_ON(po->running);
        po->num = protocol;
        po->prot_hook.type = protocol;
-       if (po->prot_hook.dev)
-               dev_put(po->prot_hook.dev);
-       po->prot_hook.dev = dev;
 
-       po->ifindex = dev ? dev->ifindex : 0;
+       if (unlikely(unlisted)) {
+               dev_put(dev);
+               po->prot_hook.dev = NULL;
+               po->ifindex = -1;
+       } else {
+               po->prot_hook.dev = dev;
+               po->ifindex = dev ? dev->ifindex : 0;
+       }
+
+       if (dev_curr)
+               dev_put(dev_curr);
 
        if (protocol == 0)
                goto out_unlock;
 
-       if (!dev || (dev->flags & IFF_UP)) {
+       if (!unlisted && (!dev || (dev->flags & IFF_UP))) {
                register_prot_hook(sk);
        } else {
                sk->sk_err = ENETDOWN;
@@ -2503,9 +2555,10 @@ static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protoc
        }
 
 out_unlock:
+       rcu_read_unlock();
        spin_unlock(&po->bind_lock);
        release_sock(sk);
-       return 0;
+       return ret;
 }
 
 /*
@@ -2517,8 +2570,6 @@ static int packet_bind_spkt(struct socket *sock, struct sockaddr *uaddr,
 {
        struct sock *sk = sock->sk;
        char name[15];
-       struct net_device *dev;
-       int err = -ENODEV;
 
        /*
         *      Check legality
@@ -2528,19 +2579,13 @@ static int packet_bind_spkt(struct socket *sock, struct sockaddr *uaddr,
                return -EINVAL;
        strlcpy(name, uaddr->sa_data, sizeof(name));
 
-       dev = dev_get_by_name(sock_net(sk), name);
-       if (dev)
-               err = packet_do_bind(sk, dev, pkt_sk(sk)->num);
-       return err;
+       return packet_do_bind(sk, name, 0, pkt_sk(sk)->num);
 }
 
 static int packet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
 {
        struct sockaddr_ll *sll = (struct sockaddr_ll *)uaddr;
        struct sock *sk = sock->sk;
-       struct net_device *dev = NULL;
-       int err;
-
 
        /*
         *      Check legality
@@ -2551,16 +2596,8 @@ static int packet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len
        if (sll->sll_family != AF_PACKET)
                return -EINVAL;
 
-       if (sll->sll_ifindex) {
-               err = -ENODEV;
-               dev = dev_get_by_index(sock_net(sk), sll->sll_ifindex);
-               if (dev == NULL)
-                       goto out;
-       }
-       err = packet_do_bind(sk, dev, sll->sll_protocol ? : pkt_sk(sk)->num);
-
-out:
-       return err;
+       return packet_do_bind(sk, NULL, sll->sll_ifindex,
+                             sll->sll_protocol ? : pkt_sk(sk)->num);
 }
 
 static struct proto packet_proto = {
@@ -3132,12 +3169,19 @@ packet_setsockopt(struct socket *sock, int level, int optname, char __user *optv
 
                if (optlen != sizeof(val))
                        return -EINVAL;
-               if (po->rx_ring.pg_vec || po->tx_ring.pg_vec)
-                       return -EBUSY;
                if (copy_from_user(&val, optval, sizeof(val)))
                        return -EFAULT;
-               po->tp_reserve = val;
-               return 0;
+               if (val > INT_MAX)
+                       return -EINVAL;
+               lock_sock(sk);
+               if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) {
+                       ret = -EBUSY;
+               } else {
+                       po->tp_reserve = val;
+                       ret = 0;
+               }
+               release_sock(sk);
+               return ret;
        }
        case PACKET_LOSS:
        {
@@ -3649,6 +3693,8 @@ static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u,
                rb->frames_per_block = req->tp_block_size/req->tp_frame_size;
                if (unlikely(rb->frames_per_block <= 0))
                        goto out;
+               if (unlikely(req->tp_block_size > UINT_MAX / req->tp_block_nr))
+                       goto out;
                if (unlikely((rb->frames_per_block * req->tp_block_nr) !=
                                        req->tp_frame_nr))
                        goto out;