Merge git://git.kernel.org/pub/scm/linux/kernel/git/mason/btrfs-unstable
[pandora-kernel.git] / fs / ncpfs / sock.c
index 8783eb7..e37df8d 100644 (file)
@@ -8,13 +8,13 @@
  *
  */
 
-#include <linux/config.h>
 
 #include <linux/time.h>
 #include <linux/errno.h>
 #include <linux/socket.h>
 #include <linux/fcntl.h>
 #include <linux/stat.h>
+#include <linux/string.h>
 #include <asm/uaccess.h>
 #include <linux/in.h>
 #include <linux/net.h>
@@ -56,10 +56,11 @@ static int _send(struct socket *sock, const void *buff, int len)
 struct ncp_request_reply {
        struct list_head req;
        wait_queue_head_t wq;
-       struct ncp_reply_header* reply_buf;
+       atomic_t refs;
+       unsigned char* reply_buf;
        size_t datalen;
        int result;
-       enum { RQ_DONE, RQ_INPROGRESS, RQ_QUEUED, RQ_IDLE } status;
+       enum { RQ_DONE, RQ_INPROGRESS, RQ_QUEUED, RQ_IDLE, RQ_ABANDONED } status;
        struct kvec* tx_ciov;
        size_t tx_totallen;
        size_t tx_iovlen;
@@ -68,6 +69,32 @@ struct ncp_request_reply {
        u_int32_t sign[6];
 };
 
+static inline struct ncp_request_reply* ncp_alloc_req(void)
+{
+       struct ncp_request_reply *req;
+
+       req = kmalloc(sizeof(struct ncp_request_reply), GFP_KERNEL);
+       if (!req)
+               return NULL;
+
+       init_waitqueue_head(&req->wq);
+       atomic_set(&req->refs, (1));
+       req->status = RQ_IDLE;
+
+       return req;
+}
+
+static void ncp_req_get(struct ncp_request_reply *req)
+{
+       atomic_inc(&req->refs);
+}
+
+static void ncp_req_put(struct ncp_request_reply *req)
+{
+       if (atomic_dec_and_test(&req->refs))
+               kfree(req);
+}
+
 void ncp_tcp_data_ready(struct sock *sk, int len)
 {
        struct ncp_server *server = sk->sk_user_data;
@@ -102,14 +129,17 @@ void ncpdgram_timeout_call(unsigned long v)
        schedule_work(&server->timeout_tq);
 }
 
-static inline void ncp_finish_request(struct ncp_request_reply *req, int result)
+static inline void ncp_finish_request(struct ncp_server *server, struct ncp_request_reply *req, int result)
 {
        req->result = result;
+       if (req->status != RQ_ABANDONED)
+               memcpy(req->reply_buf, server->rxbuf, req->datalen);
        req->status = RQ_DONE;
        wake_up_all(&req->wq);
+       ncp_req_put(req);
 }
 
-static void __abort_ncp_connection(struct ncp_server *server, struct ncp_request_reply *aborted, int err)
+static void __abort_ncp_connection(struct ncp_server *server)
 {
        struct ncp_request_reply *req;
 
@@ -119,31 +149,19 @@ static void __abort_ncp_connection(struct ncp_server *server, struct ncp_request
                req = list_entry(server->tx.requests.next, struct ncp_request_reply, req);
                
                list_del_init(&req->req);
-               if (req == aborted) {
-                       ncp_finish_request(req, err);
-               } else {
-                       ncp_finish_request(req, -EIO);
-               }
+               ncp_finish_request(server, req, -EIO);
        }
        req = server->rcv.creq;
        if (req) {
                server->rcv.creq = NULL;
-               if (req == aborted) {
-                       ncp_finish_request(req, err);
-               } else {
-                       ncp_finish_request(req, -EIO);
-               }
+               ncp_finish_request(server, req, -EIO);
                server->rcv.ptr = NULL;
                server->rcv.state = 0;
        }
        req = server->tx.creq;
        if (req) {
                server->tx.creq = NULL;
-               if (req == aborted) {
-                       ncp_finish_request(req, err);
-               } else {
-                       ncp_finish_request(req, -EIO);
-               }
+               ncp_finish_request(server, req, -EIO);
        }
 }
 
@@ -161,10 +179,12 @@ static inline void __ncp_abort_request(struct ncp_server *server, struct ncp_req
                        break;
                case RQ_QUEUED:
                        list_del_init(&req->req);
-                       ncp_finish_request(req, err);
+                       ncp_finish_request(server, req, err);
                        break;
                case RQ_INPROGRESS:
-                       __abort_ncp_connection(server, req, err);
+                       req->status = RQ_ABANDONED;
+                       break;
+               case RQ_ABANDONED:
                        break;
        }
 }
@@ -178,7 +198,7 @@ static inline void ncp_abort_request(struct ncp_server *server, struct ncp_reque
 
 static inline void __ncptcp_abort(struct ncp_server *server)
 {
-       __abort_ncp_connection(server, NULL, 0);
+       __abort_ncp_connection(server);
 }
 
 static int ncpdgram_send(struct socket *sock, struct ncp_request_reply *req)
@@ -295,6 +315,11 @@ static void ncptcp_start_request(struct ncp_server *server, struct ncp_request_r
 
 static inline void __ncp_start_request(struct ncp_server *server, struct ncp_request_reply *req)
 {
+       /* we copy the data so that we do not depend on the caller
+          staying alive */
+       memcpy(server->txbuf, req->tx_iov[1].iov_base, req->tx_iov[1].iov_len);
+       req->tx_iov[1].iov_base = server->txbuf;
+
        if (server->ncp_sock->type == SOCK_STREAM)
                ncptcp_start_request(server, req);
        else
@@ -309,6 +334,7 @@ static int ncp_add_request(struct ncp_server *server, struct ncp_request_reply *
                printk(KERN_ERR "ncpfs: tcp: Server died\n");
                return -EIO;
        }
+       ncp_req_get(req);
        if (server->tx.creq || server->rcv.creq) {
                req->status = RQ_QUEUED;
                list_add_tail(&req->req, &server->tx.requests);
@@ -351,9 +377,10 @@ static void info_server(struct ncp_server *server, unsigned int id, const void *
        }
 }
 
-void ncpdgram_rcv_proc(void *s)
+void ncpdgram_rcv_proc(struct work_struct *work)
 {
-       struct ncp_server *server = s;
+       struct ncp_server *server =
+               container_of(work, struct ncp_server, rcv.tq);
        struct socket* sock;
        
        sock = server->ncp_sock;
@@ -409,7 +436,7 @@ void ncpdgram_rcv_proc(void *s)
                                        server->timeout_last = NCP_MAX_RPC_TIMEOUT;
                                        mod_timer(&server->timeout_tm, jiffies + NCP_MAX_RPC_TIMEOUT);
                                } else if (reply.type == NCP_REPLY) {
-                                       result = _recv(sock, (void*)req->reply_buf, req->datalen, MSG_DONTWAIT);
+                                       result = _recv(sock, server->rxbuf, req->datalen, MSG_DONTWAIT);
 #ifdef CONFIG_NCPFS_PACKET_SIGNING
                                        if (result >= 0 && server->sign_active && req->tx_type != NCP_DEALLOC_SLOT_REQUEST) {
                                                if (result < 8 + 8) {
@@ -419,7 +446,7 @@ void ncpdgram_rcv_proc(void *s)
                                                        
                                                        result -= 8;
                                                        hdrl = sock->sk->sk_family == AF_INET ? 8 : 6;
-                                                       if (sign_verify_reply(server, ((char*)req->reply_buf) + hdrl, result - hdrl, cpu_to_le32(result), ((char*)req->reply_buf) + result)) {
+                                                       if (sign_verify_reply(server, server->rxbuf + hdrl, result - hdrl, cpu_to_le32(result), server->rxbuf + result)) {
                                                                printk(KERN_INFO "ncpfs: Signature violation\n");
                                                                result = -EIO;
                                                        }
@@ -428,7 +455,7 @@ void ncpdgram_rcv_proc(void *s)
 #endif
                                        del_timer(&server->timeout_tm);
                                        server->rcv.creq = NULL;
-                                       ncp_finish_request(req, result);
+                                       ncp_finish_request(server, req, result);
                                        __ncp_next_request(server);
                                        mutex_unlock(&server->rcv.creq_mutex);
                                        continue;
@@ -469,20 +496,15 @@ static void __ncpdgram_timeout_proc(struct ncp_server *server)
        }
 }
 
-void ncpdgram_timeout_proc(void *s)
+void ncpdgram_timeout_proc(struct work_struct *work)
 {
-       struct ncp_server *server = s;
+       struct ncp_server *server =
+               container_of(work, struct ncp_server, timeout_tq);
        mutex_lock(&server->rcv.creq_mutex);
        __ncpdgram_timeout_proc(server);
        mutex_unlock(&server->rcv.creq_mutex);
 }
 
-static inline void ncp_init_req(struct ncp_request_reply* req)
-{
-       init_waitqueue_head(&req->wq);
-       req->status = RQ_IDLE;
-}
-
 static int do_tcp_rcv(struct ncp_server *server, void *buffer, size_t len)
 {
        int result;
@@ -600,8 +622,8 @@ skipdata:;
                                        goto skipdata;
                                }
                                req->datalen = datalen - 8;
-                               req->reply_buf->type = NCP_REPLY;
-                               server->rcv.ptr = (unsigned char*)(req->reply_buf) + 2;
+                               ((struct ncp_reply_header*)server->rxbuf)->type = NCP_REPLY;
+                               server->rcv.ptr = server->rxbuf + 2;
                                server->rcv.len = datalen - 10;
                                server->rcv.state = 1;
                                break;
@@ -614,12 +636,12 @@ skipdata:;
                        case 1:
                                req = server->rcv.creq;
                                if (req->tx_type != NCP_ALLOC_SLOT_REQUEST) {
-                                       if (req->reply_buf->sequence != server->sequence) {
+                                       if (((struct ncp_reply_header*)server->rxbuf)->sequence != server->sequence) {
                                                printk(KERN_ERR "ncpfs: tcp: Bad sequence number\n");
                                                __ncp_abort_request(server, req, -EIO);
                                                return -EIO;
                                        }
-                                       if ((req->reply_buf->conn_low | (req->reply_buf->conn_high << 8)) != server->connection) {
+                                       if ((((struct ncp_reply_header*)server->rxbuf)->conn_low | (((struct ncp_reply_header*)server->rxbuf)->conn_high << 8)) != server->connection) {
                                                printk(KERN_ERR "ncpfs: tcp: Connection number mismatch\n");
                                                __ncp_abort_request(server, req, -EIO);
                                                return -EIO;
@@ -627,14 +649,14 @@ skipdata:;
                                }
 #ifdef CONFIG_NCPFS_PACKET_SIGNING                             
                                if (server->sign_active && req->tx_type != NCP_DEALLOC_SLOT_REQUEST) {
-                                       if (sign_verify_reply(server, (unsigned char*)(req->reply_buf) + 6, req->datalen - 6, cpu_to_be32(req->datalen + 16), &server->rcv.buf.type)) {
+                                       if (sign_verify_reply(server, server->rxbuf + 6, req->datalen - 6, cpu_to_be32(req->datalen + 16), &server->rcv.buf.type)) {
                                                printk(KERN_ERR "ncpfs: tcp: Signature violation\n");
                                                __ncp_abort_request(server, req, -EIO);
                                                return -EIO;
                                        }
                                }
 #endif                         
-                               ncp_finish_request(req, req->datalen);
+                               ncp_finish_request(server, req, req->datalen);
                        nextreq:;
                                __ncp_next_request(server);
                        case 2:
@@ -644,7 +666,7 @@ skipdata:;
                                server->rcv.state = 0;
                                break;
                        case 3:
-                               ncp_finish_request(server->rcv.creq, -EIO);
+                               ncp_finish_request(server, server->rcv.creq, -EIO);
                                goto nextreq;
                        case 5:
                                info_server(server, 0, server->unexpected_packet.data, server->unexpected_packet.len);
@@ -653,18 +675,20 @@ skipdata:;
        }
 }
 
-void ncp_tcp_rcv_proc(void *s)
+void ncp_tcp_rcv_proc(struct work_struct *work)
 {
-       struct ncp_server *server = s;
+       struct ncp_server *server =
+               container_of(work, struct ncp_server, rcv.tq);
 
        mutex_lock(&server->rcv.creq_mutex);
        __ncptcp_rcv_proc(server);
        mutex_unlock(&server->rcv.creq_mutex);
 }
 
-void ncp_tcp_tx_proc(void *s)
+void ncp_tcp_tx_proc(struct work_struct *work)
 {
-       struct ncp_server *server = s;
+       struct ncp_server *server =
+               container_of(work, struct ncp_server, tx.tq);
        
        mutex_lock(&server->rcv.creq_mutex);
        __ncptcp_try_send(server);
@@ -672,28 +696,39 @@ void ncp_tcp_tx_proc(void *s)
 }
 
 static int do_ncp_rpc_call(struct ncp_server *server, int size,
-               struct ncp_reply_header* reply_buf, int max_reply_size)
+               unsigned char* reply_buf, int max_reply_size)
 {
        int result;
-       struct ncp_request_reply req;
-
-       ncp_init_req(&req);
-       req.reply_buf = reply_buf;
-       req.datalen = max_reply_size;
-       req.tx_iov[1].iov_base = server->packet;
-       req.tx_iov[1].iov_len = size;
-       req.tx_iovlen = 1;
-       req.tx_totallen = size;
-       req.tx_type = *(u_int16_t*)server->packet;
-
-       result = ncp_add_request(server, &req);
-       if (result < 0) {
-               return result;
-       }
-       if (wait_event_interruptible(req.wq, req.status == RQ_DONE)) {
-               ncp_abort_request(server, &req, -EIO);
+       struct ncp_request_reply *req;
+
+       req = ncp_alloc_req();
+       if (!req)
+               return -ENOMEM;
+
+       req->reply_buf = reply_buf;
+       req->datalen = max_reply_size;
+       req->tx_iov[1].iov_base = server->packet;
+       req->tx_iov[1].iov_len = size;
+       req->tx_iovlen = 1;
+       req->tx_totallen = size;
+       req->tx_type = *(u_int16_t*)server->packet;
+
+       result = ncp_add_request(server, req);
+       if (result < 0)
+               goto out;
+
+       if (wait_event_interruptible(req->wq, req->status == RQ_DONE)) {
+               ncp_abort_request(server, req, -EINTR);
+               result = -EINTR;
+               goto out;
        }
-       return req.result;
+
+       result = req->result;
+
+out:
+       ncp_req_put(req);
+
+       return result;
 }
 
 /*
@@ -748,11 +783,6 @@ static int ncp_do_request(struct ncp_server *server, int size,
 
        DDPRINTK("do_ncp_rpc_call returned %d\n", result);
 
-       if (result < 0) {
-               /* There was a problem with I/O, so the connections is
-                * no longer usable. */
-               ncp_invalidate_conn(server);
-       }
        return result;
 }