/* Copyright 2020 Alyssa Ross * * This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at https://mozilla.org/MPL/2.0/. */ #define _POSIX_C_SOURCE 200809L #include #include #include #include #include #include #include #include #include #include #include "virtio_wl.h" // This is essentially vendored reusable library code, so I consider // it exempt from the Wayland style guide. :) #if defined(__GNUC__) && __GNUC__ >= 4 #define VIRTIO_WL_EXPORT __attribute__ ((visibility("default"))) #else #define VIRTIO_WL_EXPORT #endif static int set_nonblocking(int fd) { int fl = fcntl(fd, F_GETFL); if (fl == -1) return -1; if (!(fl & O_NONBLOCK)) if (fcntl(fd, F_SETFL, fl | O_NONBLOCK) == -1) return -1; return fl; } // Returns the total size of all buffers in an iovec. // A return value of -1 means that the total overflowed. static ssize_t iov_len(const struct iovec iov[], size_t n) { size_t len = 0; for (size_t i = 0; i < n; i++) { if (SSIZE_MAX - len < iov[i].iov_len) return -1; len += iov[i].iov_len; } return len; } // Copies from an iovec array into a buffer. // The buffer is assumed to be large enough to hold all the data. // This length can be calculated with the iov_len function. static size_t iov_flatten(void *buf, size_t buflen, const struct iovec *iov, size_t iovlen) { size_t off = 0; for (size_t index = 0; index < iovlen && off < buflen; index++) { const struct iovec *i = &iov[index]; size_t rem = buflen - off; size_t len = i->iov_len < rem ? i->iov_len : rem; memcpy((unsigned char *)buf + off, i->iov_base, len); off += len; } return off; } // Copies from a buffer into an iovec array. // Returns number of bytes copied. static size_t iov_fill(struct iovec *iov, size_t iovlen, const void *buf, size_t buflen) { size_t off = 0; for (size_t index = 0; index < iovlen && off < buflen; index++) { struct iovec *i = &iov[index]; size_t rem = buflen - off; size_t len = i->iov_len < rem ? i->iov_len : rem; memcpy(i->iov_base, (const unsigned char *)buf + off, len); off += len; } return off; } static ssize_t cmsg_to_fdbuf(int *buf, size_t buflen, const struct msghdr *msg) { size_t next_fd = 0; for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL; cmsg = CMSG_NXTHDR(msg, cmsg)) { size_t size = cmsg->cmsg_len - CMSG_LEN(0); // Check the cmsg can be handled. if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS || size % sizeof(int) != 0) { errno = EINVAL; return -1; } size_t rem = sizeof(int) * (buflen - next_fd); size_t len = size > rem ? rem : size; // Copy the fds to the buffer. memcpy(buf + next_fd, CMSG_DATA(cmsg), len); next_fd += len / sizeof(int); if (size > rem) break; } return next_fd > SSIZE_MAX ? SSIZE_MAX : next_fd; } static size_t fdbuf_to_cmsg(struct msghdr *msg, const int *buf, size_t buflen) { // Check msg->msg_control is long enough to fit at least one fd. if (msg->msg_controllen < CMSG_SPACE(sizeof(int))) { // If there's at least one fd in buf, set MSG_CTRUNC. size_t i = 0; while (i < buflen && !(msg->msg_flags & MSG_CTRUNC)) if (buf[i++] != -1) msg->msg_flags |= MSG_CTRUNC; return 0; } // cmsg(3): // > When initializing a buffer that will contain a series of cmsghdr // > structures (e.g., to be sent with sendmsg(2)), that buffer should // > first be zero-initialized to en‐ sure the correct operation of // > CMSG_NXTHDR(). memset(msg->msg_control, 0, msg->msg_controllen); // Set up the cmsg. struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg); cmsg->cmsg_level = SOL_SOCKET; cmsg->cmsg_type = SCM_RIGHTS; // Copy as many fds as fit into cmsg. size_t len = 0; for (size_t i = 0; i < buflen; i++) { if (buf[i] == -1) continue; if (CMSG_LEN((len + 1) * sizeof(int)) > msg->msg_controllen) { msg->msg_flags |= MSG_CTRUNC; break; } memcpy(CMSG_DATA(cmsg) + sizeof(int) * len, &buf[i], sizeof(int)); len++; } cmsg->cmsg_len = CMSG_LEN(len * sizeof(int)); return len; } VIRTIO_WL_EXPORT int virtio_wl_connect(const char *name, uint32_t flags) { static int wl_fd = -1; if (wl_fd < 0) wl_fd = open("/dev/wl0", O_RDWR | O_CLOEXEC); if (wl_fd < 0) return wl_fd; struct virtwl_ioctl_new new_ctx = { .type = name ? VIRTWL_IOCTL_NEW_CTX_NAMED : VIRTWL_IOCTL_NEW_CTX, .fd = -1, .flags = flags, }; // Device assumes name 32 bytes long if not null terminated. if (name) strncpy(new_ctx.name, name, sizeof(new_ctx.name)); if (ioctl(wl_fd, VIRTWL_IOCTL_NEW, &new_ctx)) return -1; return new_ctx.fd; } VIRTIO_WL_EXPORT int virtio_wl_send_raw(int sockfd, struct virtwl_ioctl_txn *ioctl_txn) { return ioctl(sockfd, VIRTWL_IOCTL_SEND, ioctl_txn); } static int msghdr_to_txn(const struct msghdr *msg, struct virtwl_ioctl_txn **txn) { // Make sure that if there's an error and the caller tries to free *txn // anyway it doesn't end up freeing an invalid address. *txn = NULL; // Figure out how big the txn needs to be. ssize_t len = iov_len(msg->msg_iov, msg->msg_iovlen); if (len < 0 || len > UINT32_MAX) { errno = ENOMEM; return -1; } // Allocate the txn. *txn = malloc(sizeof(**txn) + len); if (!*txn) return -1; // Set the len member of the txn. (*txn)->len = len; // Copy data from the iovec into the transaction. iov_flatten((*txn)->data, len, msg->msg_iov, msg->msg_iovlen); // Copy file descriptors to the txn. ssize_t fd_count = cmsg_to_fdbuf((*txn)->fds, VIRTWL_SEND_MAX_ALLOCS, msg); if (fd_count == -1) { free(*txn); *txn = NULL; return -1; } // Fill the rest of the fd buffer with -1. while (fd_count < VIRTWL_SEND_MAX_ALLOCS) (*txn)->fds[fd_count++] = -1; return 0; } static void txn_to_msghdr(struct msghdr *msg, const struct virtwl_ioctl_txn *txn) { // Copy txn data to iovecs. iov_fill(msg->msg_iov, msg->msg_iovlen, txn->data, txn->len); // Copy fds to cmsg. fdbuf_to_cmsg(msg, txn->fds, VIRTWL_SEND_MAX_ALLOCS); } VIRTIO_WL_EXPORT ssize_t virtio_wl_sendmsg(int sockfd, const struct msghdr *msg, int flags) { int sockfl; if ((flags & ~MSG_DONTWAIT) != MSG_NOSIGNAL) { errno = EINVAL; return -1; } struct virtwl_ioctl_txn *txn; if (msghdr_to_txn(msg, &txn) == -1) return -1; if (flags & MSG_DONTWAIT) { sockfl = set_nonblocking(sockfd); if (sockfl == -1) { free(txn); return -1; } } int rv = virtio_wl_send_raw(sockfd, txn); if ((flags & MSG_DONTWAIT) && !(sockfl & O_NONBLOCK)) { if (fcntl(sockfd, F_SETFL, sockfl) == -1) { free(txn); return -1; } } if (rv != -1) rv = txn->len; free(txn); return rv; } VIRTIO_WL_EXPORT int virtio_wl_recv_raw(int sockfd, struct virtwl_ioctl_txn *ioctl_txn) { return ioctl(sockfd, VIRTWL_IOCTL_RECV, ioctl_txn); } VIRTIO_WL_EXPORT ssize_t virtio_wl_recvmsg(int sockfd, struct msghdr *msg, int flags) { int sockfl; if ((flags & ~MSG_DONTWAIT) != MSG_CMSG_CLOEXEC) { errno = EINVAL; return -1; } ssize_t len = iov_len(msg->msg_iov, msg->msg_iovlen); if (len < 0 || len > UINT32_MAX) { errno = ENOMEM; return -1; } struct virtwl_ioctl_txn *txn = malloc(sizeof(*txn) + len); if (!txn) return -1; txn->len = len; if (flags & MSG_DONTWAIT) { sockfl = set_nonblocking(sockfd); if (sockfl == -1) { free(txn); return -1; } } ssize_t rv = virtio_wl_recv_raw(sockfd, txn); if ((flags & MSG_DONTWAIT) && !(sockfl & O_NONBLOCK)) if (fcntl(sockfd, F_SETFL, sockfl) == -1) rv = -1; if (rv == -1) { for (size_t i = 0; i < VIRTWL_SEND_MAX_ALLOCS; i++) if (txn->fds[i] == -1) close(txn->fds[i]); } else { txn_to_msghdr(msg, txn); rv = txn->len; } free(txn); return rv; }