sctp: use the right sk after waking up from wait_buf sleep
[pandora-kernel.git] / net / sctp / socket.c
index 54a7cd2..f4477d7 100644 (file)
@@ -70,6 +70,8 @@
 #include <linux/init.h>
 #include <linux/crypto.h>
 #include <linux/slab.h>
+#include <linux/compat.h>
+#include <linux/nsproxy.h>
 
 #include <net/ip.h>
 #include <net/icmp.h>
@@ -91,8 +93,8 @@
 /* Forward declarations for internal helper functions. */
 static int sctp_writeable(struct sock *sk);
 static void sctp_wfree(struct sk_buff *skb);
-static int sctp_wait_for_sndbuf(struct sctp_association *, long *timeo_p,
-                               size_t msg_len);
+static int sctp_wait_for_sndbuf(struct sctp_association *asoc, long *timeo_p,
+                               size_t msg_len, struct sock **orig_sk);
 static int sctp_wait_for_packet(struct sock * sk, int *err, long *timeo_p);
 static int sctp_wait_for_connect(struct sctp_association *, long *timeo_p);
 static int sctp_wait_for_accept(struct sock *sk, long timeo);
@@ -263,7 +265,7 @@ static struct sctp_transport *sctp_addr_id2transport(struct sock *sk,
        if (id_asoc && (id_asoc != addr_asoc))
                return NULL;
 
-       sctp_get_pf_specific(sk->sk_family)->addr_v4map(sctp_sk(sk),
+       sctp_get_pf_specific(sk->sk_family)->addr_to_user(sctp_sk(sk),
                                                (union sctp_addr *)addr);
 
        return transport;
@@ -408,7 +410,7 @@ SCTP_STATIC int sctp_do_bind(struct sock *sk, union sctp_addr *addr, int len)
        /* Copy back into socket for getsockname() use. */
        if (!ret) {
                inet_sk(sk)->inet_sport = htons(inet_sk(sk)->inet_num);
-               af->to_sk_saddr(addr, sk);
+               sp->pf->to_sk_saddr(addr, sk);
        }
 
        return ret;
@@ -814,6 +816,9 @@ static int sctp_send_asconf_del_ip(struct sock              *sk,
                        goto skip_mkasconf;
                }
 
+               if (laddr == NULL)
+                       return -EINVAL;
+
                /* We do not need RCU protection throughout this loop
                 * because this is done under a socket lock from the
                 * setsockopt call.
@@ -1056,7 +1061,6 @@ static int __sctp_connect(struct sock* sk,
        struct sctp_association *asoc2;
        struct sctp_transport *transport;
        union sctp_addr to;
-       struct sctp_af *af;
        sctp_scope_t scope;
        long timeo;
        int err = 0;
@@ -1084,6 +1088,8 @@ static int __sctp_connect(struct sock* sk,
        /* Walk through the addrs buffer and count the number of addresses. */
        addr_buf = kaddrs;
        while (walk_size < addrs_size) {
+               struct sctp_af *af;
+
                if (walk_size + sizeof(sa_family_t) > addrs_size) {
                        err = -EINVAL;
                        goto out_free;
@@ -1207,8 +1213,7 @@ static int __sctp_connect(struct sock* sk,
 
        /* Initialize sk's dport and daddr for getpeername() */
        inet_sk(sk)->inet_dport = htons(asoc->peer.port);
-       af = sctp_get_af_specific(sa_addr->sa.sa_family);
-       af->to_sk_daddr(sa_addr, sk);
+       sp->pf->to_sk_daddr(sa_addr, sk);
        sk->sk_err = 0;
 
        /* in-kernel sockets don't generally have a file allocated to them
@@ -1219,9 +1224,12 @@ static int __sctp_connect(struct sock* sk,
 
        timeo = sock_sndtimeo(sk, f_flags & O_NONBLOCK);
 
-       err = sctp_wait_for_connect(asoc, &timeo);
-       if ((err == 0 || err == -EINPROGRESS) && assoc_id)
+       if (assoc_id)
                *assoc_id = asoc->assoc_id;
+       err = sctp_wait_for_connect(asoc, &timeo);
+       /* Note: the asoc may be freed after the return of
+        * sctp_wait_for_connect.
+        */
 
        /* Don't free association on exit. */
        asoc = NULL;
@@ -1231,8 +1239,14 @@ out_free:
        SCTP_DEBUG_PRINTK("About to exit __sctp_connect() free asoc: %p"
                          " kaddrs: %p err: %d\n",
                          asoc, kaddrs, err);
-       if (asoc)
+       if (asoc) {
+               /* sctp_primitive_ASSOCIATE may have added this association
+                * To the hash table, try to unhash it, just in case, its a noop
+                * if it wasn't hashed so we're safe
+                */
+               sctp_unhash_established(asoc);
                sctp_association_free(asoc);
+       }
        return err;
 }
 
@@ -1367,11 +1381,19 @@ SCTP_STATIC int sctp_setsockopt_connectx(struct sock* sk,
 /*
  * New (hopefully final) interface for the API.
  * We use the sctp_getaddrs_old structure so that use-space library
- * can avoid any unnecessary allocations.   The only defferent part
+ * can avoid any unnecessary allocations. The only different part
  * is that we store the actual length of the address buffer into the
- * addrs_num structure member.  That way we can re-use the existing
+ * addrs_num structure member. That way we can re-use the existing
  * code.
  */
+#ifdef CONFIG_COMPAT
+struct compat_sctp_getaddrs_old {
+       sctp_assoc_t    assoc_id;
+       s32             addr_num;
+       compat_uptr_t   addrs;          /* struct sockaddr * */
+};
+#endif
+
 SCTP_STATIC int sctp_getsockopt_connectx3(struct sock* sk, int len,
                                        char __user *optval,
                                        int __user *optlen)
@@ -1380,16 +1402,30 @@ SCTP_STATIC int sctp_getsockopt_connectx3(struct sock* sk, int len,
        sctp_assoc_t assoc_id = 0;
        int err = 0;
 
-       if (len < sizeof(param))
-               return -EINVAL;
+#ifdef CONFIG_COMPAT
+       if (is_compat_task()) {
+               struct compat_sctp_getaddrs_old param32;
 
-       if (copy_from_user(&param, optval, sizeof(param)))
-               return -EFAULT;
+               if (len < sizeof(param32))
+                       return -EINVAL;
+               if (copy_from_user(&param32, optval, sizeof(param32)))
+                       return -EFAULT;
 
-       err = __sctp_setsockopt_connectx(sk,
-                       (struct sockaddr __user *)param.addrs,
-                       param.addr_num, &assoc_id);
+               param.assoc_id = param32.assoc_id;
+               param.addr_num = param32.addr_num;
+               param.addrs = compat_ptr(param32.addrs);
+       } else
+#endif
+       {
+               if (len < sizeof(param))
+                       return -EINVAL;
+               if (copy_from_user(&param, optval, sizeof(param)))
+                       return -EFAULT;
+       }
 
+       err = __sctp_setsockopt_connectx(sk, (struct sockaddr __user *)
+                                        param.addrs, param.addr_num,
+                                        &assoc_id);
        if (err == 0 || err == -EINPROGRESS) {
                if (copy_to_user(optval, &assoc_id, sizeof(assoc_id)))
                        return -EFAULT;
@@ -1492,8 +1528,7 @@ SCTP_STATIC void sctp_close(struct sock *sk, long timeout)
                        struct sctp_chunk *chunk;
 
                        chunk = sctp_make_abort_user(asoc, NULL, 0);
-                       if (chunk)
-                               sctp_primitive_ABORT(asoc, chunk);
+                       sctp_primitive_ABORT(asoc, chunk);
                } else
                        sctp_primitive_SHUTDOWN(asoc, NULL);
        }
@@ -1507,8 +1542,10 @@ SCTP_STATIC void sctp_close(struct sock *sk, long timeout)
 
        /* Supposedly, no process has access to the socket, but
         * the net layers still may.
+        * Also, sctp_destroy_sock() needs to be called with addr_wq_lock
+        * held and that should be grabbed before socket lock.
         */
-       sctp_local_bh_disable();
+       spin_lock_bh(&sctp_globals.addr_wq_lock);
        sctp_bh_lock_sock(sk);
 
        /* Hold the sock, since sk_common_release() will put sock_put()
@@ -1518,7 +1555,7 @@ SCTP_STATIC void sctp_close(struct sock *sk, long timeout)
        sk_common_release(sk);
 
        sctp_bh_unlock_sock(sk);
-       sctp_local_bh_enable();
+       spin_unlock_bh(&sctp_globals.addr_wq_lock);
 
        sock_put(sk);
 
@@ -1579,6 +1616,7 @@ SCTP_STATIC int sctp_sendmsg(struct kiocb *iocb, struct sock *sk,
        sctp_scope_t scope;
        long timeo;
        __u16 sinfo_flags = 0;
+       bool wait_connect = false;
        struct sctp_datamsg *datamsg;
        int msg_flags = msg->msg_flags;
 
@@ -1873,7 +1911,8 @@ SCTP_STATIC int sctp_sendmsg(struct kiocb *iocb, struct sock *sk,
 
        timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
        if (!sctp_wspace(asoc)) {
-               err = sctp_wait_for_sndbuf(asoc, &timeo, msg_len);
+               /* sk can be changed by peel off when waiting for buf. */
+               err = sctp_wait_for_sndbuf(asoc, &timeo, msg_len, &sk);
                if (err)
                        goto out_free;
        }
@@ -1897,13 +1936,14 @@ SCTP_STATIC int sctp_sendmsg(struct kiocb *iocb, struct sock *sk,
                err = sctp_primitive_ASSOCIATE(asoc, NULL);
                if (err < 0)
                        goto out_free;
+               wait_connect = true;
                SCTP_DEBUG_PRINTK("We associated primitively.\n");
        }
 
        /* Break the message into multiple chunks of maximum size. */
        datamsg = sctp_datamsg_from_user(asoc, sinfo, msg, msg_len);
-       if (!datamsg) {
-               err = -ENOMEM;
+       if (IS_ERR(datamsg)) {
+               err = PTR_ERR(datamsg);
                goto out_free;
        }
 
@@ -1936,14 +1976,21 @@ SCTP_STATIC int sctp_sendmsg(struct kiocb *iocb, struct sock *sk,
        else
                err = msg_len;
 
+       if (unlikely(wait_connect)) {
+               timeo = sock_sndtimeo(sk, msg_flags & MSG_DONTWAIT);
+               sctp_wait_for_connect(asoc, &timeo);
+       }
+
        /* If we are already past ASSOCIATE, the lower
         * layers are responsible for association cleanup.
         */
        goto out_unlock;
 
 out_free:
-       if (new_asoc)
+       if (new_asoc) {
+               sctp_unhash_established(asoc);
                sctp_association_free(asoc);
+       }
 out_unlock:
        sctp_release_sock(sk);
 
@@ -3367,7 +3414,7 @@ static int sctp_setsockopt_auth_key(struct sock *sk,
 
        ret = sctp_auth_set_key(sctp_sk(sk)->ep, asoc, authkey);
 out:
-       kfree(authkey);
+       kzfree(authkey);
        return ret;
 }
 
@@ -3458,6 +3505,7 @@ static int sctp_setsockopt_auto_asconf(struct sock *sk, char __user *optval,
        if ((val && sp->do_auto_asconf) || (!val && !sp->do_auto_asconf))
                return 0;
 
+       spin_lock_bh(&sctp_globals.addr_wq_lock);
        if (val == 0 && sp->do_auto_asconf) {
                list_del(&sp->auto_asconf_list);
                sp->do_auto_asconf = 0;
@@ -3466,6 +3514,7 @@ static int sctp_setsockopt_auto_asconf(struct sock *sk, char __user *optval,
                    &sctp_auto_asconf_splist);
                sp->do_auto_asconf = 1;
        }
+       spin_unlock_bh(&sctp_globals.addr_wq_lock);
        return 0;
 }
 
@@ -3901,18 +3950,28 @@ SCTP_STATIC int sctp_init_sock(struct sock *sk)
        local_bh_disable();
        percpu_counter_inc(&sctp_sockets_allocated);
        sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
+
+       /* Nothing can fail after this block, otherwise
+        * sctp_destroy_sock() will be called without addr_wq_lock held
+        */
        if (sctp_default_auto_asconf) {
+               spin_lock(&sctp_globals.addr_wq_lock);
                list_add_tail(&sp->auto_asconf_list,
                    &sctp_auto_asconf_splist);
                sp->do_auto_asconf = 1;
-       } else
+               spin_unlock(&sctp_globals.addr_wq_lock);
+       } else {
                sp->do_auto_asconf = 0;
+       }
+
        local_bh_enable();
 
        return 0;
 }
 
-/* Cleanup any SCTP per socket resources.  */
+/* Cleanup any SCTP per socket resources. Must be called with
+ * sctp_globals.addr_wq_lock held if sp->do_auto_asconf is true
+ */
 SCTP_STATIC void sctp_destroy_sock(struct sock *sk)
 {
        struct sctp_sock *sp;
@@ -3921,6 +3980,12 @@ SCTP_STATIC void sctp_destroy_sock(struct sock *sk)
 
        /* Release our hold on the endpoint. */
        sp = sctp_sk(sk);
+       /* This could happen during socket init, thus we bail out
+        * early, since the rest of the below is not setup either.
+        */
+       if (sp->ep == NULL)
+               return;
+
        if (sp->do_auto_asconf) {
                sp->do_auto_asconf = 0;
                list_del(&sp->auto_asconf_list);
@@ -4016,7 +4081,7 @@ static int sctp_getsockopt_sctp_status(struct sock *sk, int len,
        memcpy(&status.sstat_primary.spinfo_address, &transport->ipaddr,
                        transport->af_specific->sockaddr_len);
        /* Map ipv4 address into v4-mapped-on-v6 address.  */
-       sctp_get_pf_specific(sk->sk_family)->addr_v4map(sctp_sk(sk),
+       sctp_get_pf_specific(sk->sk_family)->addr_to_user(sctp_sk(sk),
                (union sctp_addr *)&status.sstat_primary.spinfo_address);
        status.sstat_primary.spinfo_state = transport->state;
        status.sstat_primary.spinfo_cwnd = transport->cwnd;
@@ -4133,9 +4198,10 @@ static int sctp_getsockopt_disable_fragments(struct sock *sk, int len,
 static int sctp_getsockopt_events(struct sock *sk, int len, char __user *optval,
                                  int __user *optlen)
 {
-       if (len < sizeof(struct sctp_event_subscribe))
+       if (len == 0)
                return -EINVAL;
-       len = sizeof(struct sctp_event_subscribe);
+       if (len > sizeof(struct sctp_event_subscribe))
+               len = sizeof(struct sctp_event_subscribe);
        if (put_user(len, optlen))
                return -EFAULT;
        if (copy_to_user(optval, &sctp_sk(sk)->subscribe, len))
@@ -4174,10 +4240,14 @@ SCTP_STATIC int sctp_do_peeloff(struct sctp_association *asoc,
                                struct socket **sockp)
 {
        struct sock *sk = asoc->base.sk;
+       struct sctp_sock *sp = sctp_sk(sk);
        struct socket *sock;
-       struct sctp_af *af;
        int err = 0;
 
+       /* Do not peel off from one netns to another one. */
+       if (!net_eq(current->nsproxy->net_ns, sock_net(sk)))
+               return -EINVAL;
+
        /* An association cannot be branched off from an already peeled-off
         * socket, nor is this supported for tcp style sockets.
         */
@@ -4194,8 +4264,7 @@ SCTP_STATIC int sctp_do_peeloff(struct sctp_association *asoc,
        /* Make peeled-off sockets more like 1-1 accepted sockets.
         * Set the daddr and initialize id to something more random
         */
-       af = sctp_get_af_specific(asoc->peer.primary_addr.sa.sa_family);
-       af->to_sk_daddr(&asoc->peer.primary_addr, sk);
+       sp->pf->to_sk_daddr(&asoc->peer.primary_addr, sk);
 
        /* Populate the fields of the newsk from the oldsk and migrate the
         * asoc to the newsk.
@@ -4570,8 +4639,8 @@ static int sctp_getsockopt_peer_addrs(struct sock *sk, int len,
        list_for_each_entry(from, &asoc->peer.transport_addr_list,
                                transports) {
                memcpy(&temp, &from->ipaddr, sizeof(temp));
-               sctp_get_pf_specific(sk->sk_family)->addr_v4map(sp, &temp);
-               addrlen = sctp_get_af_specific(temp.sa.sa_family)->sockaddr_len;
+               addrlen = sctp_get_pf_specific(sk->sk_family)
+                             ->addr_to_user(sp, &temp);
                if (space_left < addrlen)
                        return -ENOMEM;
                if (copy_to_user(to, &temp, addrlen))
@@ -4614,9 +4683,9 @@ static int sctp_copy_laddrs(struct sock *sk, __u16 port, void *to,
                if (!temp.v4.sin_port)
                        temp.v4.sin_port = htons(port);
 
-               sctp_get_pf_specific(sk->sk_family)->addr_v4map(sctp_sk(sk),
-                                                               &temp);
-               addrlen = sctp_get_af_specific(temp.sa.sa_family)->sockaddr_len;
+               addrlen = sctp_get_pf_specific(sk->sk_family)
+                             ->addr_to_user(sctp_sk(sk), &temp);
+
                if (space_left < addrlen) {
                        cnt =  -ENOMEM;
                        break;
@@ -4704,8 +4773,8 @@ static int sctp_getsockopt_local_addrs(struct sock *sk, int len,
         */
        list_for_each_entry(addr, &bp->address_list, list) {
                memcpy(&temp, &addr->a, sizeof(temp));
-               sctp_get_pf_specific(sk->sk_family)->addr_v4map(sp, &temp);
-               addrlen = sctp_get_af_specific(temp.sa.sa_family)->sockaddr_len;
+               addrlen = sctp_get_pf_specific(sk->sk_family)
+                             ->addr_to_user(sp, &temp);
                if (space_left < addrlen) {
                        err =  -ENOMEM; /*fixme: right error?*/
                        goto out;
@@ -4764,7 +4833,7 @@ static int sctp_getsockopt_primary_addr(struct sock *sk, int len,
        memcpy(&prim.ssp_addr, &asoc->peer.primary_path->ipaddr,
                asoc->peer.primary_path->af_specific->sockaddr_len);
 
-       sctp_get_pf_specific(sk->sk_family)->addr_v4map(sp,
+       sctp_get_pf_specific(sk->sk_family)->addr_to_user(sp,
                        (union sctp_addr *)&prim.ssp_addr);
 
        if (put_user(len, optlen))
@@ -5247,6 +5316,7 @@ static int sctp_getsockopt_hmac_ident(struct sock *sk, int len,
        struct sctp_hmac_algo_param *hmacs;
        __u16 data_len = 0;
        u32 num_idents;
+       int i;
 
        if (!sctp_auth_enable)
                return -EACCES;
@@ -5264,8 +5334,12 @@ static int sctp_getsockopt_hmac_ident(struct sock *sk, int len,
                return -EFAULT;
        if (put_user(num_idents, &p->shmac_num_idents))
                return -EFAULT;
-       if (copy_to_user(p->shmac_idents, hmacs->hmac_ids, data_len))
-               return -EFAULT;
+       for (i = 0; i < num_idents; i++) {
+               __u16 hmacid = ntohs(hmacs->hmac_ids[i]);
+
+               if (copy_to_user(&p->shmac_idents[i], &hmacid, sizeof(__u16)))
+                       return -EFAULT;
+       }
        return 0;
 }
 
@@ -5520,6 +5594,9 @@ SCTP_STATIC int sctp_getsockopt(struct sock *sk, int level, int optname,
        if (get_user(len, optlen))
                return -EFAULT;
 
+       if (len < 0)
+               return -EINVAL;
+
        sctp_lock_sock(sk);
 
        switch (optname) {
@@ -6179,6 +6256,7 @@ SCTP_STATIC int sctp_msghdr_parse(const struct msghdr *msg,
                        /* Minimally, validate the sinfo_flags. */
                        if (cmsgs->info->sinfo_flags &
                            ~(SCTP_UNORDERED | SCTP_ADDR_OVER |
+                             SCTP_SACK_IMMEDIATELY |
                              SCTP_ABORT | SCTP_EOF))
                                return -EINVAL;
                        break;
@@ -6387,7 +6465,7 @@ void sctp_sock_rfree(struct sk_buff *skb)
 
 /* Helper function to wait for space in the sndbuf.  */
 static int sctp_wait_for_sndbuf(struct sctp_association *asoc, long *timeo_p,
-                               size_t msg_len)
+                               size_t msg_len, struct sock **orig_sk)
 {
        struct sock *sk = asoc->base.sk;
        int err = 0;
@@ -6419,13 +6497,18 @@ static int sctp_wait_for_sndbuf(struct sctp_association *asoc, long *timeo_p,
                 */
                sctp_release_sock(sk);
                current_timeo = schedule_timeout(current_timeo);
-               BUG_ON(sk != asoc->base.sk);
                sctp_lock_sock(sk);
+               if (sk != asoc->base.sk) {
+                       release_sock(sk);
+                       sk = asoc->base.sk;
+                       lock_sock(sk);
+               }
 
                *timeo_p = current_timeo;
        }
 
 out:
+       *orig_sk = sk;
        finish_wait(&asoc->wait, &wait);
 
        /* Release the association's refcnt.  */
@@ -6669,6 +6752,22 @@ void sctp_copy_sock(struct sock *newsk, struct sock *sk,
        newinet->mc_ttl = 1;
        newinet->mc_index = 0;
        newinet->mc_list = NULL;
+
+       if (newsk->sk_flags & SK_FLAGS_TIMESTAMP)
+               net_enable_timestamp();
+}
+
+static inline void sctp_copy_descendant(struct sock *sk_to,
+                                       const struct sock *sk_from)
+{
+       int ancestor_size = sizeof(struct inet_sock) +
+                           sizeof(struct sctp_sock) -
+                           offsetof(struct sctp_sock, auto_asconf_list);
+
+       if (sk_from->sk_family == PF_INET6)
+               ancestor_size += sizeof(struct ipv6_pinfo);
+
+       __inet_sk_copy_descendant(sk_to, sk_from, ancestor_size);
 }
 
 /* Populate the fields of the newsk from the oldsk and migrate the assoc
@@ -6685,7 +6784,6 @@ static void sctp_sock_migrate(struct sock *oldsk, struct sock *newsk,
        struct sk_buff *skb, *tmp;
        struct sctp_ulpevent *event;
        struct sctp_bind_hashbucket *head;
-       struct list_head tmplist;
 
        /* Migrate socket buffer sizes and all the socket level options to the
         * new socket.
@@ -6693,12 +6791,7 @@ static void sctp_sock_migrate(struct sock *oldsk, struct sock *newsk,
        newsk->sk_sndbuf = oldsk->sk_sndbuf;
        newsk->sk_rcvbuf = oldsk->sk_rcvbuf;
        /* Brute force copy old sctp opt. */
-       if (oldsp->do_auto_asconf) {
-               memcpy(&tmplist, &newsp->auto_asconf_list, sizeof(tmplist));
-               inet_sk_copy_descendant(newsk, oldsk);
-               memcpy(&newsp->auto_asconf_list, &tmplist, sizeof(tmplist));
-       } else
-               inet_sk_copy_descendant(newsk, oldsk);
+       sctp_copy_descendant(newsk, oldsk);
 
        /* Restore the ep value that was overwritten with the above structure
         * copy.
@@ -6841,6 +6934,13 @@ struct proto sctp_prot = {
 
 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
 
+#include <net/transp_v6.h>
+static void sctp_v6_destroy_sock(struct sock *sk)
+{
+       sctp_destroy_sock(sk);
+       inet6_destroy_sock(sk);
+}
+
 struct proto sctpv6_prot = {
        .name           = "SCTPv6",
        .owner          = THIS_MODULE,
@@ -6850,7 +6950,7 @@ struct proto sctpv6_prot = {
        .accept         = sctp_accept,
        .ioctl          = sctp_ioctl,
        .init           = sctp_init_sock,
-       .destroy        = sctp_destroy_sock,
+       .destroy        = sctp_v6_destroy_sock,
        .shutdown       = sctp_shutdown,
        .setsockopt     = sctp_setsockopt,
        .getsockopt     = sctp_getsockopt,