sctp: use the right sk after waking up from wait_buf sleep
[pandora-kernel.git] / net / sctp / socket.c
index ccbd6d0..f4477d7 100644 (file)
@@ -71,6 +71,7 @@
 #include <linux/crypto.h>
 #include <linux/slab.h>
 #include <linux/compat.h>
+#include <linux/nsproxy.h>
 
 #include <net/ip.h>
 #include <net/icmp.h>
@@ -92,8 +93,8 @@
 /* Forward declarations for internal helper functions. */
 static int sctp_writeable(struct sock *sk);
 static void sctp_wfree(struct sk_buff *skb);
-static int sctp_wait_for_sndbuf(struct sctp_association *, long *timeo_p,
-                               size_t msg_len);
+static int sctp_wait_for_sndbuf(struct sctp_association *asoc, long *timeo_p,
+                               size_t msg_len, struct sock **orig_sk);
 static int sctp_wait_for_packet(struct sock * sk, int *err, long *timeo_p);
 static int sctp_wait_for_connect(struct sctp_association *, long *timeo_p);
 static int sctp_wait_for_accept(struct sock *sk, long timeo);
@@ -264,7 +265,7 @@ static struct sctp_transport *sctp_addr_id2transport(struct sock *sk,
        if (id_asoc && (id_asoc != addr_asoc))
                return NULL;
 
-       sctp_get_pf_specific(sk->sk_family)->addr_v4map(sctp_sk(sk),
+       sctp_get_pf_specific(sk->sk_family)->addr_to_user(sctp_sk(sk),
                                                (union sctp_addr *)addr);
 
        return transport;
@@ -409,7 +410,7 @@ SCTP_STATIC int sctp_do_bind(struct sock *sk, union sctp_addr *addr, int len)
        /* Copy back into socket for getsockname() use. */
        if (!ret) {
                inet_sk(sk)->inet_sport = htons(inet_sk(sk)->inet_num);
-               af->to_sk_saddr(addr, sk);
+               sp->pf->to_sk_saddr(addr, sk);
        }
 
        return ret;
@@ -1060,7 +1061,6 @@ static int __sctp_connect(struct sock* sk,
        struct sctp_association *asoc2;
        struct sctp_transport *transport;
        union sctp_addr to;
-       struct sctp_af *af;
        sctp_scope_t scope;
        long timeo;
        int err = 0;
@@ -1088,6 +1088,8 @@ static int __sctp_connect(struct sock* sk,
        /* Walk through the addrs buffer and count the number of addresses. */
        addr_buf = kaddrs;
        while (walk_size < addrs_size) {
+               struct sctp_af *af;
+
                if (walk_size + sizeof(sa_family_t) > addrs_size) {
                        err = -EINVAL;
                        goto out_free;
@@ -1211,8 +1213,7 @@ static int __sctp_connect(struct sock* sk,
 
        /* Initialize sk's dport and daddr for getpeername() */
        inet_sk(sk)->inet_dport = htons(asoc->peer.port);
-       af = sctp_get_af_specific(sa_addr->sa.sa_family);
-       af->to_sk_daddr(sa_addr, sk);
+       sp->pf->to_sk_daddr(sa_addr, sk);
        sk->sk_err = 0;
 
        /* in-kernel sockets don't generally have a file allocated to them
@@ -1910,7 +1911,8 @@ SCTP_STATIC int sctp_sendmsg(struct kiocb *iocb, struct sock *sk,
 
        timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
        if (!sctp_wspace(asoc)) {
-               err = sctp_wait_for_sndbuf(asoc, &timeo, msg_len);
+               /* sk can be changed by peel off when waiting for buf. */
+               err = sctp_wait_for_sndbuf(asoc, &timeo, msg_len, &sk);
                if (err)
                        goto out_free;
        }
@@ -4079,7 +4081,7 @@ static int sctp_getsockopt_sctp_status(struct sock *sk, int len,
        memcpy(&status.sstat_primary.spinfo_address, &transport->ipaddr,
                        transport->af_specific->sockaddr_len);
        /* Map ipv4 address into v4-mapped-on-v6 address.  */
-       sctp_get_pf_specific(sk->sk_family)->addr_v4map(sctp_sk(sk),
+       sctp_get_pf_specific(sk->sk_family)->addr_to_user(sctp_sk(sk),
                (union sctp_addr *)&status.sstat_primary.spinfo_address);
        status.sstat_primary.spinfo_state = transport->state;
        status.sstat_primary.spinfo_cwnd = transport->cwnd;
@@ -4238,10 +4240,14 @@ SCTP_STATIC int sctp_do_peeloff(struct sctp_association *asoc,
                                struct socket **sockp)
 {
        struct sock *sk = asoc->base.sk;
+       struct sctp_sock *sp = sctp_sk(sk);
        struct socket *sock;
-       struct sctp_af *af;
        int err = 0;
 
+       /* Do not peel off from one netns to another one. */
+       if (!net_eq(current->nsproxy->net_ns, sock_net(sk)))
+               return -EINVAL;
+
        /* An association cannot be branched off from an already peeled-off
         * socket, nor is this supported for tcp style sockets.
         */
@@ -4258,8 +4264,7 @@ SCTP_STATIC int sctp_do_peeloff(struct sctp_association *asoc,
        /* Make peeled-off sockets more like 1-1 accepted sockets.
         * Set the daddr and initialize id to something more random
         */
-       af = sctp_get_af_specific(asoc->peer.primary_addr.sa.sa_family);
-       af->to_sk_daddr(&asoc->peer.primary_addr, sk);
+       sp->pf->to_sk_daddr(&asoc->peer.primary_addr, sk);
 
        /* Populate the fields of the newsk from the oldsk and migrate the
         * asoc to the newsk.
@@ -4634,8 +4639,8 @@ static int sctp_getsockopt_peer_addrs(struct sock *sk, int len,
        list_for_each_entry(from, &asoc->peer.transport_addr_list,
                                transports) {
                memcpy(&temp, &from->ipaddr, sizeof(temp));
-               sctp_get_pf_specific(sk->sk_family)->addr_v4map(sp, &temp);
-               addrlen = sctp_get_af_specific(temp.sa.sa_family)->sockaddr_len;
+               addrlen = sctp_get_pf_specific(sk->sk_family)
+                             ->addr_to_user(sp, &temp);
                if (space_left < addrlen)
                        return -ENOMEM;
                if (copy_to_user(to, &temp, addrlen))
@@ -4678,9 +4683,9 @@ static int sctp_copy_laddrs(struct sock *sk, __u16 port, void *to,
                if (!temp.v4.sin_port)
                        temp.v4.sin_port = htons(port);
 
-               sctp_get_pf_specific(sk->sk_family)->addr_v4map(sctp_sk(sk),
-                                                               &temp);
-               addrlen = sctp_get_af_specific(temp.sa.sa_family)->sockaddr_len;
+               addrlen = sctp_get_pf_specific(sk->sk_family)
+                             ->addr_to_user(sctp_sk(sk), &temp);
+
                if (space_left < addrlen) {
                        cnt =  -ENOMEM;
                        break;
@@ -4768,8 +4773,8 @@ static int sctp_getsockopt_local_addrs(struct sock *sk, int len,
         */
        list_for_each_entry(addr, &bp->address_list, list) {
                memcpy(&temp, &addr->a, sizeof(temp));
-               sctp_get_pf_specific(sk->sk_family)->addr_v4map(sp, &temp);
-               addrlen = sctp_get_af_specific(temp.sa.sa_family)->sockaddr_len;
+               addrlen = sctp_get_pf_specific(sk->sk_family)
+                             ->addr_to_user(sp, &temp);
                if (space_left < addrlen) {
                        err =  -ENOMEM; /*fixme: right error?*/
                        goto out;
@@ -4828,7 +4833,7 @@ static int sctp_getsockopt_primary_addr(struct sock *sk, int len,
        memcpy(&prim.ssp_addr, &asoc->peer.primary_path->ipaddr,
                asoc->peer.primary_path->af_specific->sockaddr_len);
 
-       sctp_get_pf_specific(sk->sk_family)->addr_v4map(sp,
+       sctp_get_pf_specific(sk->sk_family)->addr_to_user(sp,
                        (union sctp_addr *)&prim.ssp_addr);
 
        if (put_user(len, optlen))
@@ -6460,7 +6465,7 @@ void sctp_sock_rfree(struct sk_buff *skb)
 
 /* Helper function to wait for space in the sndbuf.  */
 static int sctp_wait_for_sndbuf(struct sctp_association *asoc, long *timeo_p,
-                               size_t msg_len)
+                               size_t msg_len, struct sock **orig_sk)
 {
        struct sock *sk = asoc->base.sk;
        int err = 0;
@@ -6492,14 +6497,18 @@ static int sctp_wait_for_sndbuf(struct sctp_association *asoc, long *timeo_p,
                 */
                sctp_release_sock(sk);
                current_timeo = schedule_timeout(current_timeo);
-               if (sk != asoc->base.sk)
-                       goto do_error;
                sctp_lock_sock(sk);
+               if (sk != asoc->base.sk) {
+                       release_sock(sk);
+                       sk = asoc->base.sk;
+                       lock_sock(sk);
+               }
 
                *timeo_p = current_timeo;
        }
 
 out:
+       *orig_sk = sk;
        finish_wait(&asoc->wait, &wait);
 
        /* Release the association's refcnt.  */