ipv6: tcp: add rcu locking in tcp_v6_send_synack()
[pandora-kernel.git] / net / ipv6 / tcp_ipv6.c
index 2dea4bb..b20f7a0 100644 (file)
@@ -131,6 +131,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
        struct ipv6_pinfo *np = inet6_sk(sk);
        struct tcp_sock *tp = tcp_sk(sk);
        struct in6_addr *saddr = NULL, *final_p, final;
+       struct ipv6_txoptions *opt;
        struct rt6_info *rt;
        struct flowi6 fl6;
        struct dst_entry *dst;
@@ -252,7 +253,8 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
        fl6.fl6_dport = usin->sin6_port;
        fl6.fl6_sport = inet->inet_sport;
 
-       final_p = fl6_update_dst(&fl6, np->opt, &final);
+       opt = rcu_dereference_protected(np->opt, sock_owned_by_user(sk));
+       final_p = fl6_update_dst(&fl6, opt, &final);
 
        security_sk_classify_flow(sk, flowi6_to_flowi(&fl6));
 
@@ -295,9 +297,9 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
        }
 
        icsk->icsk_ext_hdr_len = 0;
-       if (np->opt)
-               icsk->icsk_ext_hdr_len = (np->opt->opt_flen +
-                                         np->opt->opt_nflen);
+       if (opt)
+               icsk->icsk_ext_hdr_len = opt->opt_flen +
+                                        opt->opt_nflen;
 
        tp->rx_opt.mss_clamp = IPV6_MIN_MTU - sizeof(struct tcphdr) - sizeof(struct ipv6hdr);
 
@@ -481,7 +483,6 @@ static int tcp_v6_send_synack(struct sock *sk, struct request_sock *req,
        struct inet6_request_sock *treq = inet6_rsk(req);
        struct ipv6_pinfo *np = inet6_sk(sk);
        struct sk_buff * skb;
-       struct ipv6_txoptions *opt = NULL;
        struct in6_addr * final_p, final;
        struct flowi6 fl6;
        struct dst_entry *dst;
@@ -498,8 +499,7 @@ static int tcp_v6_send_synack(struct sock *sk, struct request_sock *req,
        fl6.fl6_sport = inet_rsk(req)->loc_port;
        security_req_classify_flow(req, flowi6_to_flowi(&fl6));
 
-       opt = np->opt;
-       final_p = fl6_update_dst(&fl6, opt, &final);
+       final_p = fl6_update_dst(&fl6, rcu_dereference(np->opt), &final);
 
        dst = ip6_dst_lookup_flow(sk, &fl6, final_p, false);
        if (IS_ERR(dst)) {
@@ -513,13 +513,14 @@ static int tcp_v6_send_synack(struct sock *sk, struct request_sock *req,
                __tcp_v6_send_check(skb, &treq->loc_addr, &treq->rmt_addr);
 
                ipv6_addr_copy(&fl6.daddr, &treq->rmt_addr);
-               err = ip6_xmit(sk, skb, &fl6, opt, np->tclass);
+               rcu_read_lock();
+               err = ip6_xmit(sk, skb, &fl6, rcu_dereference(np->opt),
+                              np->tclass);
+               rcu_read_unlock();
                err = net_xmit_eval(err);
        }
 
 done:
-       if (opt && opt != np->opt)
-               sock_kfree_s(sk, opt, opt->tot_len);
        dst_release(dst);
        return err;
 }
@@ -592,7 +593,7 @@ static int tcp_v6_md5_do_add(struct sock *sk, const struct in6_addr *peer,
                        sk_nocaps_add(sk, NETIF_F_GSO_MASK);
                }
                if (tp->md5sig_info->entries6 == 0 &&
-                       tcp_alloc_md5sig_pool(sk) == NULL) {
+                   !tcp_alloc_md5sig_pool()) {
                        kfree(newkey);
                        return -ENOMEM;
                }
@@ -602,8 +603,6 @@ static int tcp_v6_md5_do_add(struct sock *sk, const struct in6_addr *peer,
 
                        if (!keys) {
                                kfree(newkey);
-                               if (tp->md5sig_info->entries6 == 0)
-                                       tcp_free_md5sig_pool();
                                return -ENOMEM;
                        }
 
@@ -649,7 +648,6 @@ static int tcp_v6_md5_do_del(struct sock *sk, const struct in6_addr *peer)
                                kfree(tp->md5sig_info->keys6);
                                tp->md5sig_info->keys6 = NULL;
                                tp->md5sig_info->alloced6 = 0;
-                               tcp_free_md5sig_pool();
                        } else {
                                /* shrink the database */
                                if (tp->md5sig_info->entries6 != i)
@@ -673,7 +671,6 @@ static void tcp_v6_clear_md5_list (struct sock *sk)
                for (i = 0; i < tp->md5sig_info->entries6; i++)
                        kfree(tp->md5sig_info->keys6[i].base.key);
                tp->md5sig_info->entries6 = 0;
-               tcp_free_md5sig_pool();
        }
 
        kfree(tp->md5sig_info->keys6);
@@ -684,7 +681,6 @@ static void tcp_v6_clear_md5_list (struct sock *sk)
                for (i = 0; i < tp->md5sig_info->entries4; i++)
                        kfree(tp->md5sig_info->keys4[i].base.key);
                tp->md5sig_info->entries4 = 0;
-               tcp_free_md5sig_pool();
        }
 
        kfree(tp->md5sig_info->keys4);
@@ -1048,7 +1044,8 @@ static void tcp_v6_send_response(struct sk_buff *skb, u32 seq, u32 ack, u32 win,
        __tcp_v6_send_check(buff, &fl6.saddr, &fl6.daddr);
 
        fl6.flowi6_proto = IPPROTO_TCP;
-       fl6.flowi6_oif = inet6_iif(skb);
+       if (ipv6_addr_type(&fl6.daddr) & IPV6_ADDR_LINKLOCAL)
+               fl6.flowi6_oif = inet6_iif(skb);
        fl6.fl6_dport = t1->dest;
        fl6.fl6_sport = t1->source;
        security_skb_classify_flow(skb, flowi6_to_flowi(&fl6));
@@ -1084,7 +1081,7 @@ static void tcp_v6_send_reset(struct sock *sk, struct sk_buff *skb)
 
 #ifdef CONFIG_TCP_MD5SIG
        if (sk)
-               key = tcp_v6_md5_do_lookup(sk, &ipv6_hdr(skb)->daddr);
+               key = tcp_v6_md5_do_lookup(sk, &ipv6_hdr(skb)->saddr);
 #endif
 
        if (th->ack)
@@ -1253,7 +1250,7 @@ static int tcp_v6_conn_request(struct sock *sk, struct sk_buff *skb)
        ipv6_addr_copy(&treq->rmt_addr, &ipv6_hdr(skb)->saddr);
        ipv6_addr_copy(&treq->loc_addr, &ipv6_hdr(skb)->daddr);
        if (!want_cookie || tmp_opt.tstamp_ok)
-               TCP_ECN_create_request(req, tcp_hdr(skb));
+               TCP_ECN_create_request(req, skb);
 
        treq->iif = sk->sk_bound_dev_if;
 
@@ -1412,7 +1409,6 @@ static struct sock * tcp_v6_syn_recv_sock(struct sock *sk, struct sk_buff *skb,
        }
 
        treq = inet6_rsk(req);
-       opt = np->opt;
 
        if (sk_acceptq_is_full(sk))
                goto out_overflow;
@@ -1480,20 +1476,23 @@ static struct sock * tcp_v6_syn_recv_sock(struct sock *sk, struct sk_buff *skb,
           but we make one more one thing there: reattach optmem
           to newsk.
         */
+       opt = rcu_dereference(np->opt);
        if (opt) {
-               newnp->opt = ipv6_dup_options(newsk, opt);
-               if (opt != np->opt)
-                       sock_kfree_s(sk, opt, opt->tot_len);
+               opt = ipv6_dup_options(newsk, opt);
+               RCU_INIT_POINTER(newnp->opt, opt);
        }
-
-       inet_csk(newsk)->icsk_ext_hdr_len = 0;
-       if (newnp->opt)
-               inet_csk(newsk)->icsk_ext_hdr_len = (newnp->opt->opt_nflen +
-                                                    newnp->opt->opt_flen);
+       inet_csk(newsk)->icsk_ext_hdr_len = 0;
+       if (opt)
+               inet_csk(newsk)->icsk_ext_hdr_len = opt->opt_nflen +
+                                                   opt->opt_flen;
 
        tcp_mtup_init(newsk);
        tcp_sync_mss(newsk, dst_mtu(dst));
        newtp->advmss = dst_metric_advmss(dst);
+       if (tcp_sk(sk)->rx_opt.user_mss &&
+           tcp_sk(sk)->rx_opt.user_mss < newtp->advmss)
+               newtp->advmss = tcp_sk(sk)->rx_opt.user_mss;
+
        tcp_initialize_rcv_mss(newsk);
        if (tcp_rsk(req)->snt_synack)
                tcp_valid_rtt_meas(newsk,
@@ -1519,7 +1518,8 @@ static struct sock * tcp_v6_syn_recv_sock(struct sock *sk, struct sk_buff *skb,
 #endif
 
        if (__inet_inherit_port(sk, newsk) < 0) {
-               sock_put(newsk);
+               inet_csk_prepare_forced_close(newsk);
+               tcp_done(newsk);
                goto out;
        }
        __inet6_hash(newsk, NULL);
@@ -1529,8 +1529,6 @@ static struct sock * tcp_v6_syn_recv_sock(struct sock *sk, struct sk_buff *skb,
 out_overflow:
        NET_INC_STATS_BH(sock_net(sk), LINUX_MIB_LISTENOVERFLOWS);
 out_nonewsk:
-       if (opt && opt != np->opt)
-               sock_kfree_s(sk, opt, opt->tot_len);
        dst_release(dst);
 out:
        NET_INC_STATS_BH(sock_net(sk), LINUX_MIB_LISTENDROPS);
@@ -2190,6 +2188,17 @@ void tcp6_proc_exit(struct net *net)
 }
 #endif
 
+static void tcp_v6_clear_sk(struct sock *sk, int size)
+{
+       struct inet_sock *inet = inet_sk(sk);
+
+       /* we do not want to clear pinet6 field, because of RCU lookups */
+       sk_prot_clear_nulls(sk, offsetof(struct inet_sock, pinet6));
+
+       size -= offsetof(struct inet_sock, pinet6) + sizeof(inet->pinet6);
+       memset(&inet->pinet6 + 1, 0, size);
+}
+
 struct proto tcpv6_prot = {
        .name                   = "TCPv6",
        .owner                  = THIS_MODULE,
@@ -2229,6 +2238,7 @@ struct proto tcpv6_prot = {
        .compat_setsockopt      = compat_tcp_setsockopt,
        .compat_getsockopt      = compat_tcp_getsockopt,
 #endif
+       .clear_sk               = tcp_v6_clear_sk,
 };
 
 static const struct inet6_protocol tcpv6_protocol = {