packet: Add fanout support.
[pandora-kernel.git] / net / packet / af_packet.c
index 925f715..3350f1d 100644 (file)
@@ -187,9 +187,11 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg);
 
 static void packet_flush_mclist(struct sock *sk);
 
+struct packet_fanout;
 struct packet_sock {
        /* struct sock has to be the first member of packet_sock */
        struct sock             sk;
+       struct packet_fanout    *fanout;
        struct tpacket_stats    stats;
        struct packet_ring_buffer       rx_ring;
        struct packet_ring_buffer       tx_ring;
@@ -212,6 +214,24 @@ struct packet_sock {
        struct packet_type      prot_hook ____cacheline_aligned_in_smp;
 };
 
+#define PACKET_FANOUT_MAX      256
+
+struct packet_fanout {
+#ifdef CONFIG_NET_NS
+       struct net              *net;
+#endif
+       unsigned int            num_members;
+       u16                     id;
+       u8                      type;
+       u8                      pad;
+       atomic_t                rr_cur;
+       struct list_head        list;
+       struct sock             *arr[PACKET_FANOUT_MAX];
+       spinlock_t              lock;
+       atomic_t                sk_ref;
+       struct packet_type      prot_hook ____cacheline_aligned_in_smp;
+};
+
 struct packet_skb_cb {
        unsigned int origlen;
        union {
@@ -222,6 +242,64 @@ struct packet_skb_cb {
 
 #define PACKET_SKB_CB(__skb)   ((struct packet_skb_cb *)((__skb)->cb))
 
+static inline struct packet_sock *pkt_sk(struct sock *sk)
+{
+       return (struct packet_sock *)sk;
+}
+
+static void __fanout_unlink(struct sock *sk, struct packet_sock *po);
+static void __fanout_link(struct sock *sk, struct packet_sock *po);
+
+/* register_prot_hook must be invoked with the po->bind_lock held,
+ * or from a context in which asynchronous accesses to the packet
+ * socket is not possible (packet_create()).
+ */
+static void register_prot_hook(struct sock *sk)
+{
+       struct packet_sock *po = pkt_sk(sk);
+       if (!po->running) {
+               if (po->fanout)
+                       __fanout_link(sk, po);
+               else
+                       dev_add_pack(&po->prot_hook);
+               sock_hold(sk);
+               po->running = 1;
+       }
+}
+
+/* {,__}unregister_prot_hook() must be invoked with the po->bind_lock
+ * held.   If the sync parameter is true, we will temporarily drop
+ * the po->bind_lock and do a synchronize_net to make sure no
+ * asynchronous packet processing paths still refer to the elements
+ * of po->prot_hook.  If the sync parameter is false, it is the
+ * callers responsibility to take care of this.
+ */
+static void __unregister_prot_hook(struct sock *sk, bool sync)
+{
+       struct packet_sock *po = pkt_sk(sk);
+
+       po->running = 0;
+       if (po->fanout)
+               __fanout_unlink(sk, po);
+       else
+               __dev_remove_pack(&po->prot_hook);
+       __sock_put(sk);
+
+       if (sync) {
+               spin_unlock(&po->bind_lock);
+               synchronize_net();
+               spin_lock(&po->bind_lock);
+       }
+}
+
+static void unregister_prot_hook(struct sock *sk, bool sync)
+{
+       struct packet_sock *po = pkt_sk(sk);
+
+       if (po->running)
+               __unregister_prot_hook(sk, sync);
+}
+
 static inline __pure struct page *pgv_to_page(void *addr)
 {
        if (is_vmalloc_addr(addr))
@@ -324,11 +402,6 @@ static inline void packet_increment_head(struct packet_ring_buffer *buff)
        buff->head = buff->head != buff->frame_max ? buff->head+1 : 0;
 }
 
-static inline struct packet_sock *pkt_sk(struct sock *sk)
-{
-       return (struct packet_sock *)sk;
-}
-
 static void packet_sock_destruct(struct sock *sk)
 {
        skb_queue_purge(&sk->sk_error_queue);
@@ -344,6 +417,201 @@ static void packet_sock_destruct(struct sock *sk)
        sk_refcnt_debug_dec(sk);
 }
 
+static int fanout_rr_next(struct packet_fanout *f, unsigned int num)
+{
+       int x = atomic_read(&f->rr_cur) + 1;
+
+       if (x >= num)
+               x = 0;
+
+       return x;
+}
+
+static struct sock *fanout_demux_hash(struct packet_fanout *f, struct sk_buff *skb, unsigned int num)
+{
+       u32 idx, hash = skb->rxhash;
+
+       idx = ((u64)hash * num) >> 32;
+
+       return f->arr[idx];
+}
+
+static struct sock *fanout_demux_lb(struct packet_fanout *f, struct sk_buff *skb, unsigned int num)
+{
+       int cur, old;
+
+       cur = atomic_read(&f->rr_cur);
+       while ((old = atomic_cmpxchg(&f->rr_cur, cur,
+                                    fanout_rr_next(f, num))) != cur)
+               cur = old;
+       return f->arr[cur];
+}
+
+static int packet_rcv_fanout_hash(struct sk_buff *skb, struct net_device *dev,
+                                 struct packet_type *pt, struct net_device *orig_dev)
+{
+       struct packet_fanout *f = pt->af_packet_priv;
+       unsigned int num = f->num_members;
+       struct packet_sock *po;
+       struct sock *sk;
+
+       if (!net_eq(dev_net(dev), read_pnet(&f->net)) ||
+           !num) {
+               kfree_skb(skb);
+               return 0;
+       }
+
+       skb_get_rxhash(skb);
+
+       sk = fanout_demux_hash(f, skb, num);
+       po = pkt_sk(sk);
+
+       return po->prot_hook.func(skb, dev, &po->prot_hook, orig_dev);
+}
+
+static int packet_rcv_fanout_lb(struct sk_buff *skb, struct net_device *dev,
+                               struct packet_type *pt, struct net_device *orig_dev)
+{
+       struct packet_fanout *f = pt->af_packet_priv;
+       unsigned int num = f->num_members;
+       struct packet_sock *po;
+       struct sock *sk;
+
+       if (!net_eq(dev_net(dev), read_pnet(&f->net)) ||
+           !num) {
+               kfree_skb(skb);
+               return 0;
+       }
+
+       sk = fanout_demux_lb(f, skb, num);
+       po = pkt_sk(sk);
+
+       return po->prot_hook.func(skb, dev, &po->prot_hook, orig_dev);
+}
+
+static DEFINE_MUTEX(fanout_mutex);
+static LIST_HEAD(fanout_list);
+
+static void __fanout_link(struct sock *sk, struct packet_sock *po)
+{
+       struct packet_fanout *f = po->fanout;
+
+       spin_lock(&f->lock);
+       f->arr[f->num_members] = sk;
+       smp_wmb();
+       f->num_members++;
+       spin_unlock(&f->lock);
+}
+
+static void __fanout_unlink(struct sock *sk, struct packet_sock *po)
+{
+       struct packet_fanout *f = po->fanout;
+       int i;
+
+       spin_lock(&f->lock);
+       for (i = 0; i < f->num_members; i++) {
+               if (f->arr[i] == sk)
+                       break;
+       }
+       BUG_ON(i >= f->num_members);
+       f->arr[i] = f->arr[f->num_members - 1];
+       f->num_members--;
+       spin_unlock(&f->lock);
+}
+
+static int fanout_add(struct sock *sk, u16 id, u8 type)
+{
+       struct packet_sock *po = pkt_sk(sk);
+       struct packet_fanout *f, *match;
+       int err;
+
+       switch (type) {
+       case PACKET_FANOUT_HASH:
+       case PACKET_FANOUT_LB:
+               break;
+       default:
+               return -EINVAL;
+       }
+
+       if (!po->running)
+               return -EINVAL;
+
+       if (po->fanout)
+               return -EALREADY;
+
+       mutex_lock(&fanout_mutex);
+       match = NULL;
+       list_for_each_entry(f, &fanout_list, list) {
+               if (f->id == id &&
+                   read_pnet(&f->net) == sock_net(sk)) {
+                       match = f;
+                       break;
+               }
+       }
+       if (!match) {
+               match = kzalloc(sizeof(*match), GFP_KERNEL);
+               if (match) {
+                       write_pnet(&match->net, sock_net(sk));
+                       match->id = id;
+                       match->type = type;
+                       atomic_set(&match->rr_cur, 0);
+                       INIT_LIST_HEAD(&match->list);
+                       spin_lock_init(&match->lock);
+                       atomic_set(&match->sk_ref, 0);
+                       match->prot_hook.type = po->prot_hook.type;
+                       match->prot_hook.dev = po->prot_hook.dev;
+                       switch (type) {
+                       case PACKET_FANOUT_HASH:
+                               match->prot_hook.func = packet_rcv_fanout_hash;
+                               break;
+                       case PACKET_FANOUT_LB:
+                               match->prot_hook.func = packet_rcv_fanout_lb;
+                               break;
+                       }
+                       match->prot_hook.af_packet_priv = match;
+                       dev_add_pack(&match->prot_hook);
+                       list_add(&match->list, &fanout_list);
+               }
+       }
+       err = -ENOMEM;
+       if (match) {
+               err = -EINVAL;
+               if (match->type == type &&
+                   match->prot_hook.type == po->prot_hook.type &&
+                   match->prot_hook.dev == po->prot_hook.dev) {
+                       err = -ENOSPC;
+                       if (atomic_read(&match->sk_ref) < PACKET_FANOUT_MAX) {
+                               __dev_remove_pack(&po->prot_hook);
+                               po->fanout = match;
+                               atomic_inc(&match->sk_ref);
+                               __fanout_link(sk, po);
+                               err = 0;
+                       }
+               }
+       }
+       mutex_unlock(&fanout_mutex);
+       return err;
+}
+
+static void fanout_release(struct sock *sk)
+{
+       struct packet_sock *po = pkt_sk(sk);
+       struct packet_fanout *f;
+
+       f = po->fanout;
+       if (!f)
+               return;
+
+       po->fanout = NULL;
+
+       mutex_lock(&fanout_mutex);
+       if (atomic_dec_and_test(&f->sk_ref)) {
+               list_del(&f->list);
+               dev_remove_pack(&f->prot_hook);
+               kfree(f);
+       }
+       mutex_unlock(&fanout_mutex);
+}
 
 static const struct proto_ops packet_ops;
 
@@ -798,7 +1066,13 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                        getnstimeofday(&ts);
                h.h2->tp_sec = ts.tv_sec;
                h.h2->tp_nsec = ts.tv_nsec;
-               h.h2->tp_vlan_tci = vlan_tx_tag_get(skb);
+               if (vlan_tx_tag_present(skb)) {
+                       h.h2->tp_vlan_tci = vlan_tx_tag_get(skb);
+                       status |= TP_STATUS_VLAN_VALID;
+               } else {
+                       h.h2->tp_vlan_tci = 0;
+               }
+               h.h2->tp_padding = 0;
                hdrlen = sizeof(*h.h2);
                break;
        default:
@@ -969,7 +1243,8 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
        struct sk_buff *skb;
        struct net_device *dev;
        __be16 proto;
-       int ifindex, err, reserve = 0;
+       bool need_rls_dev = false;
+       int err, reserve = 0;
        void *ph;
        struct sockaddr_ll *saddr = (struct sockaddr_ll *)msg->msg_name;
        int tp_len, size_max;
@@ -981,7 +1256,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
 
        err = -EBUSY;
        if (saddr == NULL) {
-               ifindex = po->ifindex;
+               dev = po->prot_hook.dev;
                proto   = po->num;
                addr    = NULL;
        } else {
@@ -992,12 +1267,12 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
                                        + offsetof(struct sockaddr_ll,
                                                sll_addr)))
                        goto out;
-               ifindex = saddr->sll_ifindex;
                proto   = saddr->sll_protocol;
                addr    = saddr->sll_addr;
+               dev = dev_get_by_index(sock_net(&po->sk), saddr->sll_ifindex);
+               need_rls_dev = true;
        }
 
-       dev = dev_get_by_index(sock_net(&po->sk), ifindex);
        err = -ENXIO;
        if (unlikely(dev == NULL))
                goto out;
@@ -1083,7 +1358,8 @@ out_status:
        __packet_set_status(po, ph, status);
        kfree_skb(skb);
 out_put:
-       dev_put(dev);
+       if (need_rls_dev)
+               dev_put(dev);
 out:
        mutex_unlock(&po->pg_vec_lock);
        return err;
@@ -1121,8 +1397,9 @@ static int packet_snd(struct socket *sock,
        struct sk_buff *skb;
        struct net_device *dev;
        __be16 proto;
+       bool need_rls_dev = false;
        unsigned char *addr;
-       int ifindex, err, reserve = 0;
+       int err, reserve = 0;
        struct virtio_net_hdr vnet_hdr = { 0 };
        int offset = 0;
        int vnet_hdr_len;
@@ -1134,7 +1411,7 @@ static int packet_snd(struct socket *sock,
         */
 
        if (saddr == NULL) {
-               ifindex = po->ifindex;
+               dev = po->prot_hook.dev;
                proto   = po->num;
                addr    = NULL;
        } else {
@@ -1143,13 +1420,12 @@ static int packet_snd(struct socket *sock,
                        goto out;
                if (msg->msg_namelen < (saddr->sll_halen + offsetof(struct sockaddr_ll, sll_addr)))
                        goto out;
-               ifindex = saddr->sll_ifindex;
                proto   = saddr->sll_protocol;
                addr    = saddr->sll_addr;
+               dev = dev_get_by_index(sock_net(sk), saddr->sll_ifindex);
+               need_rls_dev = true;
        }
 
-
-       dev = dev_get_by_index(sock_net(sk), ifindex);
        err = -ENXIO;
        if (dev == NULL)
                goto out_unlock;
@@ -1280,14 +1556,15 @@ static int packet_snd(struct socket *sock,
        if (err > 0 && (err = net_xmit_errno(err)) != 0)
                goto out_unlock;
 
-       dev_put(dev);
+       if (need_rls_dev)
+               dev_put(dev);
 
        return len;
 
 out_free:
        kfree_skb(skb);
 out_unlock:
-       if (dev)
+       if (dev && need_rls_dev)
                dev_put(dev);
 out:
        return err;
@@ -1328,14 +1605,10 @@ static int packet_release(struct socket *sock)
        spin_unlock_bh(&net->packet.sklist_lock);
 
        spin_lock(&po->bind_lock);
-       if (po->running) {
-               /*
-                * Remove from protocol table
-                */
-               po->running = 0;
-               po->num = 0;
-               __dev_remove_pack(&po->prot_hook);
-               __sock_put(sk);
+       unregister_prot_hook(sk, false);
+       if (po->prot_hook.dev) {
+               dev_put(po->prot_hook.dev);
+               po->prot_hook.dev = NULL;
        }
        spin_unlock(&po->bind_lock);
 
@@ -1349,6 +1622,8 @@ static int packet_release(struct socket *sock)
        if (po->tx_ring.pg_vec)
                packet_set_ring(sk, &req, 1, 1);
 
+       fanout_release(sk);
+
        synchronize_net();
        /*
         *      Now the socket is dead. No more input will appear.
@@ -1372,24 +1647,18 @@ static int packet_release(struct socket *sock)
 static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protocol)
 {
        struct packet_sock *po = pkt_sk(sk);
-       /*
-        *      Detach an existing hook if present.
-        */
+
+       if (po->fanout)
+               return -EINVAL;
 
        lock_sock(sk);
 
        spin_lock(&po->bind_lock);
-       if (po->running) {
-               __sock_put(sk);
-               po->running = 0;
-               po->num = 0;
-               spin_unlock(&po->bind_lock);
-               dev_remove_pack(&po->prot_hook);
-               spin_lock(&po->bind_lock);
-       }
-
+       unregister_prot_hook(sk, true);
        po->num = protocol;
        po->prot_hook.type = protocol;
+       if (po->prot_hook.dev)
+               dev_put(po->prot_hook.dev);
        po->prot_hook.dev = dev;
 
        po->ifindex = dev ? dev->ifindex : 0;
@@ -1398,9 +1667,7 @@ static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protoc
                goto out_unlock;
 
        if (!dev || (dev->flags & IFF_UP)) {
-               dev_add_pack(&po->prot_hook);
-               sock_hold(sk);
-               po->running = 1;
+               register_prot_hook(sk);
        } else {
                sk->sk_err = ENETDOWN;
                if (!sock_flag(sk, SOCK_DEAD))
@@ -1434,10 +1701,8 @@ static int packet_bind_spkt(struct socket *sock, struct sockaddr *uaddr,
        strlcpy(name, uaddr->sa_data, sizeof(name));
 
        dev = dev_get_by_name(sock_net(sk), name);
-       if (dev) {
+       if (dev)
                err = packet_do_bind(sk, dev, pkt_sk(sk)->num);
-               dev_put(dev);
-       }
        return err;
 }
 
@@ -1465,8 +1730,6 @@ static int packet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len
                        goto out;
        }
        err = packet_do_bind(sk, dev, sll->sll_protocol ? : pkt_sk(sk)->num);
-       if (dev)
-               dev_put(dev);
 
 out:
        return err;
@@ -1531,9 +1794,7 @@ static int packet_create(struct net *net, struct socket *sock, int protocol,
 
        if (proto) {
                po->prot_hook.type = proto;
-               dev_add_pack(&po->prot_hook);
-               sock_hold(sk);
-               po->running = 1;
+               register_prot_hook(sk);
        }
 
        spin_lock_bh(&net->packet.sklist_lock);
@@ -1675,6 +1936,8 @@ static int packet_recvmsg(struct kiocb *iocb, struct socket *sock,
                        vnet_hdr.flags = VIRTIO_NET_HDR_F_NEEDS_CSUM;
                        vnet_hdr.csum_start = skb_checksum_start_offset(skb);
                        vnet_hdr.csum_offset = skb->csum_offset;
+               } else if (skb->ip_summed == CHECKSUM_UNNECESSARY) {
+                       vnet_hdr.flags = VIRTIO_NET_HDR_F_DATA_VALID;
                } /* else everything is zero */
 
                err = memcpy_toiovec(msg->msg_iov, (void *)&vnet_hdr,
@@ -1725,8 +1988,13 @@ static int packet_recvmsg(struct kiocb *iocb, struct socket *sock,
                aux.tp_snaplen = skb->len;
                aux.tp_mac = 0;
                aux.tp_net = skb_network_offset(skb);
-               aux.tp_vlan_tci = vlan_tx_tag_get(skb);
-
+               if (vlan_tx_tag_present(skb)) {
+                       aux.tp_vlan_tci = vlan_tx_tag_get(skb);
+                       aux.tp_status |= TP_STATUS_VLAN_VALID;
+               } else {
+                       aux.tp_vlan_tci = 0;
+               }
+               aux.tp_padding = 0;
                put_cmsg(msg, SOL_PACKET, PACKET_AUXDATA, sizeof(aux), &aux);
        }
 
@@ -2091,6 +2359,17 @@ packet_setsockopt(struct socket *sock, int level, int optname, char __user *optv
                po->tp_tstamp = val;
                return 0;
        }
+       case PACKET_FANOUT:
+       {
+               int val;
+
+               if (optlen != sizeof(val))
+                       return -EINVAL;
+               if (copy_from_user(&val, optval, sizeof(val)))
+                       return -EFAULT;
+
+               return fanout_add(sk, val & 0xffff, val >> 16);
+       }
        default:
                return -ENOPROTOOPT;
        }
@@ -2189,6 +2468,15 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
                val = po->tp_tstamp;
                data = &val;
                break;
+       case PACKET_FANOUT:
+               if (len > sizeof(int))
+                       len = sizeof(int);
+               val = (po->fanout ?
+                      ((u32)po->fanout->id |
+                       ((u32)po->fanout->type << 16)) :
+                      0);
+               data = &val;
+               break;
        default:
                return -ENOPROTOOPT;
        }
@@ -2222,15 +2510,15 @@ static int packet_notifier(struct notifier_block *this, unsigned long msg, void
                        if (dev->ifindex == po->ifindex) {
                                spin_lock(&po->bind_lock);
                                if (po->running) {
-                                       __dev_remove_pack(&po->prot_hook);
-                                       __sock_put(sk);
-                                       po->running = 0;
+                                       __unregister_prot_hook(sk, false);
                                        sk->sk_err = ENETDOWN;
                                        if (!sock_flag(sk, SOCK_DEAD))
                                                sk->sk_error_report(sk);
                                }
                                if (msg == NETDEV_UNREGISTER) {
                                        po->ifindex = -1;
+                                       if (po->prot_hook.dev)
+                                               dev_put(po->prot_hook.dev);
                                        po->prot_hook.dev = NULL;
                                }
                                spin_unlock(&po->bind_lock);
@@ -2239,11 +2527,8 @@ static int packet_notifier(struct notifier_block *this, unsigned long msg, void
                case NETDEV_UP:
                        if (dev->ifindex == po->ifindex) {
                                spin_lock(&po->bind_lock);
-                               if (po->num && !po->running) {
-                                       dev_add_pack(&po->prot_hook);
-                                       sock_hold(sk);
-                                       po->running = 1;
-                               }
+                               if (po->num)
+                                       register_prot_hook(sk);
                                spin_unlock(&po->bind_lock);
                        }
                        break;
@@ -2510,10 +2795,8 @@ static int packet_set_ring(struct sock *sk, struct tpacket_req *req,
        was_running = po->running;
        num = po->num;
        if (was_running) {
-               __dev_remove_pack(&po->prot_hook);
                po->num = 0;
-               po->running = 0;
-               __sock_put(sk);
+               __unregister_prot_hook(sk, false);
        }
        spin_unlock(&po->bind_lock);
 
@@ -2544,11 +2827,9 @@ static int packet_set_ring(struct sock *sk, struct tpacket_req *req,
        mutex_unlock(&po->pg_vec_lock);
 
        spin_lock(&po->bind_lock);
-       if (was_running && !po->running) {
-               sock_hold(sk);
-               po->running = 1;
+       if (was_running) {
                po->num = num;
-               dev_add_pack(&po->prot_hook);
+               register_prot_hook(sk);
        }
        spin_unlock(&po->bind_lock);