patches and low-level development discussion
* [PATCH wayland] Support virtio_wl display sockets
@ 2020-07-08 15:16 Alyssa Ross
  2020-07-08 15:43 ` Cole Helbling
                   ` (2 more replies)
  0 siblings, 3 replies; 5+ messages in thread
From: Alyssa Ross @ 2020-07-08 15:16 UTC
  To: devel

This patch adds libvirtio_wl, which provides reimplementations of
sendmsg(2) and recvmsg(2) for virtio_wl socket fds.  It also includes
tests for as much of libvirtio_wl's functionality as can reasonably be
tested without a working virtio_wl connection.  (Testing further would
require running the tests in a VM so that they could talk to the
virtio_wl kernel driver, which would be prohibitively complex.)

virtio_wl socket fds do not actually refer to sockets, but to special
virtio_wl files.  Whenever a display socket operation calls sendmsg()
or recvmsg() on a file descriptor and gets an ENOTSOCK error, it
retries the operation with the equivalent libvirtio_wl function, in
case the file descriptor is a virtio_wl fd.  This is the least invasive
way to implement virtio_wl support: if a normal socket is being used,
there is no change in behaviour.
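A minimal sketch of that ENOTSOCK fallback pattern, outside of Wayland. The names sendmsg_with_fallback, sendmsg_fallback_fn and writev_fallback are hypothetical. writev_fallback is only a stand-in for virtio_wl_sendmsg(), so the sketch runs without a virtio_wl device:

```c
#define _GNU_SOURCE
#include <errno.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/uio.h>
#include <unistd.h>

/* Signature of the fallback used when fd turns out not to be a socket.
 * In the patch this role is played by virtio_wl_sendmsg(). */
typedef ssize_t (*sendmsg_fallback_fn)(int fd, const struct msghdr *msg,
				       int flags);

/* Hypothetical stand-in fallback for this sketch only: it writes the
 * payload with writev(2), which works on pipes and other non-sockets. */
static int fallback_calls;

static ssize_t
writev_fallback(int fd, const struct msghdr *msg, int flags)
{
	(void)flags;
	fallback_calls++;
	return writev(fd, msg->msg_iov, (int)msg->msg_iovlen);
}

/* Try the real sendmsg(2) first.  The same message is handed to the
 * fallback only when the kernel reports ENOTSOCK, so ordinary sockets
 * behave exactly as before.  EINTR from either path is retried. */
static ssize_t
sendmsg_with_fallback(int fd, const struct msghdr *msg, int flags,
		      sendmsg_fallback_fn fallback)
{
	ssize_t len;

	do {
		len = sendmsg(fd, msg, flags);
		if (len == -1 && errno == ENOTSOCK)
			len = fallback(fd, msg, flags);
	} while (len == -1 && errno == EINTR);

	return len;
}
```

In the patch itself, the same check is written inline in wl_connection_flush() and wl_os_recvmsg_cloexec(), not as a separate helper.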

Because virtio_wl doesn't implement every socket feature, some
workarounds are currently required to accommodate everything Wayland
expects of sockets:

* virtio_wl_recvmsg and virtio_wl_sendmsg implement the MSG_DONTWAIT
  flag by setting O_NONBLOCK on the fd, attempting the VIRTWL_IOCTL_RECV
  or VIRTWL_IOCTL_SEND operation, and then restoring the fd's original
  flags.  This is race-prone, because O_NONBLOCK is shared by everything
  using the same open file description, but there is no reasonable
  alternative at present.

* virtio_wl_sendmsg requires MSG_NOSIGNAL to be set, but then ignores
  it.  From reading the code, I believe virtio_wl never generates
  SIGPIPE signals or returns EPIPE, so there is nothing for the flag to
  suppress.  I could be wrong about this, though.

* virtio_wl does not support credential passing -- what would it even
  mean, considering the other end of the connection is on
  another (virtual) machine?  So wl_client's ucred member will have
  pid, uid, and gid all set to -1 for a client connected over
  virtio_wl.

* virtio_wl sockets do not support accept(2), so a fallback is used.
  A proxy program on the host calls accept(2) on a host socket.  When it
  receives a connection, it attaches the connection socket to the VM,
  then sends the 32-byte name of the connected socket over the Wayland
  display socket.  Wayland reads this name, opens the named virtio_wl
  context with virtio_wl_connect(), and uses the result as the client
  connection socket.
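The MSG_DONTWAIT workaround described in the first item above amounts to the following. This is a sketch that assumes an ordinary fd and uses read(2) as a stand-in for the VIRTWL_IOCTL_RECV ioctl. read_dontwait is a hypothetical name:

```c
#define _GNU_SOURCE
#include <errno.h>
#include <fcntl.h>
#include <unistd.h>

/* Emulate a per-call MSG_DONTWAIT on an fd that has no such flag:
 * temporarily set O_NONBLOCK, do the operation, then restore the
 * original file status flags.  Anything else sharing the open file
 * description can observe or change O_NONBLOCK in between, which is
 * the race described above. */
static ssize_t
read_dontwait(int fd, void *buf, size_t len)
{
	int fl = fcntl(fd, F_GETFL);
	if (fl == -1)
		return -1;

	if (!(fl & O_NONBLOCK) && fcntl(fd, F_SETFL, fl | O_NONBLOCK) == -1)
		return -1;

	ssize_t rv = read(fd, buf, len);
	int saved_errno = errno;

	/* Restore the caller's flags only if we changed them.  If that
	 * fails, report an error even though the read may have succeeded,
	 * as the patch does. */
	if (!(fl & O_NONBLOCK) && fcntl(fd, F_SETFL, fl) == -1)
		return -1;

	errno = saved_errno;
	return rv;
}
```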
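The accept(2) fallback in the last item above requires the compositor to read the complete connection name before it can look the connection up. The sketch below shows that read loop. It assumes a 32-byte name, the size of the buffer wl_virtio_accept_cloexec() reads, and uses a hypothetical helper, read_conn_name(). The subsequent virtio_wl_connect() call is omitted, and the choice of EPROTO for a short name is an assumption:

```c
#define _GNU_SOURCE
#include <errno.h>
#include <unistd.h>

/* Length of a virtio_wl context name, as read by the patch. */
#define NAME_LEN 32

/* Read exactly NAME_LEN bytes from fd into name, retrying short reads
 * and EINTR.  EOF before the whole name has arrived is reported as an
 * error (EPROTO, chosen arbitrarily here) instead of looping forever.
 * The result is not NUL-terminated: all NAME_LEN bytes are significant. */
static int
read_conn_name(int fd, char name[NAME_LEN])
{
	size_t off = 0;

	while (off < NAME_LEN) {
		ssize_t n = read(fd, name + off, NAME_LEN - off);
		if (n == -1 && errno == EINTR)
			continue;
		if (n == -1)
			return -1;
		if (n == 0) {
			errno = EPROTO;
			return -1;
		}
		off += (size_t)n;
	}
	return 0;
}
```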

Additionally, virtio_wl memfd-like file descriptors don't support
mremap(2), so if mremap fails with EFAULT, Wayland will munmap(2) the
pool and then mmap(2) the fd again.  This should be mostly okay,
because Wayland only ever calls mremap with MREMAP_MAYMOVE, so the
mapping could already move, but it is still race-prone: the region is
briefly unmapped between the two calls.  To make this possible, memfds
are no longer closed after being mmapped, but are kept in the
wl_shm_pool struct so that they can be passed to mmap() if required.
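The munmap(2)/mmap(2) replacement for mremap(2) can be sketched as follows. The sketch assumes a MAP_SHARED, read-write mapping of a regular file or memfd that the client has already grown. remap_by_fd is a hypothetical name, and unlike the patch, it passes NULL rather than the old address as the mmap hint:

```c
#define _GNU_SOURCE
#include <sys/mman.h>
#include <sys/types.h>
#include <unistd.h>

/* Resize a shared, read-write mapping of fd from old_size to new_size
 * without mremap(2): drop the old mapping, then map the file again at
 * its new size.  With MAP_SHARED the contents live in the file, so
 * nothing is lost.  As with MREMAP_MAYMOVE, the new address may differ
 * from old.  Returns MAP_FAILED on error; if mmap() is what failed, the
 * old mapping has already been removed and must not be used. */
static void *
remap_by_fd(void *old, size_t old_size, size_t new_size, int fd)
{
	if (munmap(old, old_size) == -1)
		return MAP_FAILED;

	/* Between munmap() and mmap(), another thread touching old may
	 * fault: this is the window that makes the approach race-prone. */
	return mmap(NULL, new_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
}
```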
---
 src/connection.c       |   4 +
 src/meson.build        |  16 +-
 src/virtio_wl.c        | 344 +++++++++++++++++++++++
 src/virtio_wl.h        |  23 ++
 src/wayland-os.c       |  23 ++
 src/wayland-os.h       |   3 +
 src/wayland-server.c   |  20 +-
 src/wayland-shm.c      |  24 +-
 tests/meson.build      |   1 +
 tests/virtio_wl-test.c | 615 +++++++++++++++++++++++++++++++++++++++++
 10 files changed, 1064 insertions(+), 9 deletions(-)
 create mode 100644 src/virtio_wl.c
 create mode 100644 src/virtio_wl.h
 create mode 100644 tests/virtio_wl-test.c

diff --git a/src/connection.c b/src/connection.c
index d0c7d9f..1bfbd8c 100644
--- a/src/connection.c
+++ b/src/connection.c
@@ -40,6 +40,7 @@
 #include <time.h>
 #include <ffi.h>
 
+#include "virtio_wl.h"
 #include "wayland-util.h"
 #include "wayland-private.h"
 #include "wayland-os.h"
@@ -314,6 +315,9 @@ wl_connection_flush(struct wl_connection *connection)
 		do {
 			len = sendmsg(connection->fd, &msg,
 				      MSG_NOSIGNAL | MSG_DONTWAIT);
+			if (len == -1 && errno == ENOTSOCK)
+				len = virtio_wl_sendmsg(connection->fd, &msg,
+							MSG_NOSIGNAL | MSG_DONTWAIT);
 		} while (len == -1 && errno == EINTR);
 
 		if (len == -1)
diff --git a/src/meson.build b/src/meson.build
index 2d1485c..cca1d93 100644
--- a/src/meson.build
+++ b/src/meson.build
@@ -71,13 +71,26 @@ if get_option('libraries')
 	mathlib_dep = cc.find_library('m', required: false)
 	threads_dep = dependency('threads', required: false)
 
+	virtio_wl = library(
+		'virtio_wl',
+		sources: [
+			'virtio_wl.c',
+		],
+		version: '0.1.0',
+		install: true,
+	)
+
+	virtio_wl_dep = declare_dependency(
+		link_with: virtio_wl,
+	)
+
 	wayland_private = static_library(
 		'wayland-private',
 		sources: [
 			'connection.c',
 			'wayland-os.c'
 		],
-		dependencies: [ ffi_dep, ]
+		dependencies: [ ffi_dep, virtio_wl_dep ]
 	)
 
 	wayland_private_dep = declare_dependency(
@@ -230,5 +243,6 @@ if get_option('libraries')
 		'wayland-server-core.h',
 		'wayland-client.h',
 		'wayland-client-core.h',
+		'virtio_wl.h',
 	])
 endif
diff --git a/src/virtio_wl.c b/src/virtio_wl.c
new file mode 100644
index 0000000..a1ee8ee
--- /dev/null
+++ b/src/virtio_wl.c
@@ -0,0 +1,344 @@
+/* Copyright 2020 Alyssa Ross
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
+
+#define _POSIX_C_SOURCE 200809L
+
+#include <errno.h>
+#include <fcntl.h>
+#include <limits.h>
+#include <linux/virtwl.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include "virtio_wl.h"
+
+// This is essentially vendored reusable library code, so I consider
+// it exempt from the Wayland style guide. :)
+
+#if defined(__GNUC__) && __GNUC__ >= 4
+#define VIRTIO_WL_EXPORT __attribute__ ((visibility("default")))
+#else
+#define VIRTIO_WL_EXPORT
+#endif
+
+static int
+set_nonblocking(int fd)
+{
+	int fl = fcntl(fd, F_GETFL);
+	if (fl == -1)
+		return -1;
+	if (!(fl & O_NONBLOCK))
+		if (fcntl(fd, F_SETFL, fl | O_NONBLOCK) == -1)
+			return -1;
+	return fl;
+}
+
+// Returns the total size of all buffers in an iovec.
+// A return value of -1 means that the total overflowed.
+static ssize_t
+iov_len(const struct iovec iov[], size_t n)
+{
+	size_t len = 0;
+	for (size_t i = 0; i < n; i++) {
+		if (SSIZE_MAX - len < iov[i].iov_len)
+			return -1;
+		len += iov[i].iov_len;
+	}
+	return len;
+}
+
+// Copies up to buflen bytes from an iovec array into a buffer.
+// Use iov_len to find the buffer size needed to hold all the data.
+// Returns the number of bytes copied.
+static size_t
+iov_flatten(void *buf, size_t buflen, const struct iovec *iov, size_t iovlen)
+{
+	size_t off = 0;
+
+	for (size_t index = 0; index < iovlen && off < buflen; index++) {
+		const struct iovec *i = &iov[index];
+		size_t rem = buflen - off;
+		size_t len = i->iov_len < rem ? i->iov_len : rem;
+
+		memcpy((unsigned char *)buf + off, i->iov_base, len);
+		off += len;
+	}
+
+	return off;
+}
+
+// Copies from a buffer into an iovec array.
+// Returns number of bytes copied.
+static size_t
+iov_fill(struct iovec *iov, size_t iovlen, const void *buf, size_t buflen)
+{
+	size_t off = 0;
+
+	for (size_t index = 0; index < iovlen && off < buflen; index++) {
+		struct iovec *i = &iov[index];
+		size_t rem = buflen - off;
+		size_t len = i->iov_len < rem ? i->iov_len : rem;
+
+		memcpy(i->iov_base, (const unsigned char *)buf + off, len);
+		off += len;
+	}
+
+	return off;
+}
+
+static ssize_t
+cmsg_to_fdbuf(int *buf, size_t buflen, const struct msghdr *msg)
+{
+	size_t next_fd = 0;
+
+	for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg);
+	     cmsg != NULL; cmsg = CMSG_NXTHDR(msg, cmsg)) {
+		size_t size = cmsg->cmsg_len - CMSG_LEN(0);
+
+		// Check the cmsg can be handled.
+		if (cmsg->cmsg_level != SOL_SOCKET ||
+		    cmsg->cmsg_type != SCM_RIGHTS ||
+		    size % sizeof(int) != 0) {
+			errno = EINVAL;
+			return -1;
+		}
+
+		size_t rem = sizeof(int) * (buflen - next_fd);
+		size_t len = size > rem ? rem : size;
+
+		// Copy the fds to the buffer.
+		memcpy(buf + next_fd, CMSG_DATA(cmsg), len);
+		next_fd += len / sizeof(int);
+
+		if (size > rem)
+			break;
+	}
+
+	return next_fd > SSIZE_MAX ? SSIZE_MAX : next_fd;
+}
+
+static size_t
+fdbuf_to_cmsg(struct msghdr *msg, const int *buf, size_t buflen)
+{
+	// Check msg->msg_control is long enough to fit at least one fd.
+	if (msg->msg_controllen < CMSG_SPACE(sizeof(int))) {
+
+		// If there's at least one fd in buf, set MSG_CTRUNC.
+		size_t i = 0;
+		while (i < buflen && !(msg->msg_flags & MSG_CTRUNC))
+			if (buf[i++] != -1)
+				msg->msg_flags |= MSG_CTRUNC;
+
+		return 0;
+	}
+
+	// cmsg(3):
+	// > When initializing a buffer that will contain a series of cmsghdr
+	// > structures (e.g., to be sent with sendmsg(2)), that buffer should
+	// > first be zero-initialized to ensure the correct operation of
+	// > CMSG_NXTHDR().
+	memset(msg->msg_control, 0, msg->msg_controllen);
+
+	// Set up the cmsg.
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+
+	// Copy as many fds as fit into cmsg.
+	size_t len = 0;
+	for (size_t i = 0; i < buflen; i++) {
+		if (buf[i] == -1)
+			continue;
+
+		if (CMSG_LEN((len + 1) * sizeof(int)) > msg->msg_controllen) {
+			msg->msg_flags |= MSG_CTRUNC;
+			break;
+		}
+
+		memcpy(CMSG_DATA(cmsg) + sizeof(int) * len, &buf[i], sizeof(int));
+		len++;
+	}
+
+	cmsg->cmsg_len = CMSG_LEN(len * sizeof(int));
+	return len;
+}
+
+VIRTIO_WL_EXPORT int
+virtio_wl_connect(const char *name, uint32_t flags)
+{
+	static int wl_fd = -1;
+	if (wl_fd < 0)
+		wl_fd = open("/dev/wl0", O_RDWR | O_CLOEXEC);
+	if (wl_fd < 0)
+		return wl_fd;
+
+	struct virtwl_ioctl_new new_ctx = {
+		.type = name ? VIRTWL_IOCTL_NEW_CTX_NAMED : VIRTWL_IOCTL_NEW_CTX,
+		.fd = -1,
+		.flags = flags,
+	};
+	// Device assumes name 32 bytes long if not null terminated.
+	if (name)
+		strncpy(new_ctx.name, name, sizeof(new_ctx.name));
+
+	if (ioctl(wl_fd, VIRTWL_IOCTL_NEW, &new_ctx))
+		return -1;
+
+	return new_ctx.fd;
+}
+
+VIRTIO_WL_EXPORT int
+virtio_wl_send_raw(int sockfd, struct virtwl_ioctl_txn *ioctl_txn)
+{
+	return ioctl(sockfd, VIRTWL_IOCTL_SEND, ioctl_txn);
+}
+
+static int
+msghdr_to_txn(const struct msghdr *msg, struct virtwl_ioctl_txn **txn)
+{
+	// Make sure that if there's an error and the caller tries to free *txn
+	// anyway it doesn't end up freeing an invalid address.
+	*txn = NULL;
+
+	// Figure out how big the txn needs to be.
+	ssize_t len = iov_len(msg->msg_iov, msg->msg_iovlen);
+	if (len < 0 || len > UINT32_MAX) {
+		errno = ENOMEM;
+		return -1;
+	}
+
+	// Allocate the txn.
+	*txn = malloc(sizeof(**txn) + len);
+	if (!*txn)
+		return -1;
+
+	// Set the len member of the txn.
+	(*txn)->len = len;
+
+	// Copy data from the iovec into the transaction.
+	iov_flatten((*txn)->data, len, msg->msg_iov, msg->msg_iovlen);
+
+	// Copy file descriptors to the txn.
+	ssize_t fd_count = cmsg_to_fdbuf((*txn)->fds, VIRTWL_SEND_MAX_ALLOCS, msg);
+	if (fd_count == -1) {
+		free(*txn);
+		*txn = NULL;
+		return -1;
+	}
+
+	// Fill the rest of the fd buffer with -1.
+	while (fd_count < VIRTWL_SEND_MAX_ALLOCS)
+		(*txn)->fds[fd_count++] = -1;
+
+	return 0;
+}
+
+static void
+txn_to_msghdr(struct msghdr *msg, const struct virtwl_ioctl_txn *txn)
+{
+	// Copy txn data to iovecs.
+	iov_fill(msg->msg_iov, msg->msg_iovlen, txn->data, txn->len);
+
+	// Copy fds to cmsg.
+	fdbuf_to_cmsg(msg, txn->fds, VIRTWL_SEND_MAX_ALLOCS);
+}
+
+VIRTIO_WL_EXPORT ssize_t
+virtio_wl_sendmsg(int sockfd, const struct msghdr *msg, int flags)
+{
+	int sockfl;
+
+	if ((flags & ~MSG_DONTWAIT) != MSG_NOSIGNAL) {
+		errno = EINVAL;
+		return -1;
+	}
+
+	struct virtwl_ioctl_txn *txn;
+	if (msghdr_to_txn(msg, &txn) == -1)
+		return -1;
+
+	if (flags & MSG_DONTWAIT) {
+		sockfl = set_nonblocking(sockfd);
+		if (sockfl == -1) {
+			free(txn);
+			return -1;
+		}
+	}
+
+	int rv = virtio_wl_send_raw(sockfd, txn);
+
+	if ((flags & MSG_DONTWAIT) && !(sockfl & O_NONBLOCK)) {
+		if (fcntl(sockfd, F_SETFL, sockfl) == -1) {
+			free(txn);
+			return -1;
+		}
+	}
+
+	if (rv != -1)
+		rv = txn->len;
+
+	free(txn);
+	return rv;
+}
+
+VIRTIO_WL_EXPORT int
+virtio_wl_recv_raw(int sockfd, struct virtwl_ioctl_txn *ioctl_txn)
+{
+	return ioctl(sockfd, VIRTWL_IOCTL_RECV, ioctl_txn);
+}
+
+VIRTIO_WL_EXPORT ssize_t
+virtio_wl_recvmsg(int sockfd, struct msghdr *msg, int flags)
+{
+	int sockfl;
+
+	if ((flags & ~MSG_DONTWAIT) != MSG_CMSG_CLOEXEC) {
+		errno = EINVAL;
+		return -1;
+	}
+
+	ssize_t len = iov_len(msg->msg_iov, msg->msg_iovlen);
+	if (len < 0 || len > UINT32_MAX) {
+		errno = ENOMEM;
+		return -1;
+	}
+
+	struct virtwl_ioctl_txn *txn = malloc(sizeof(*txn) + len);
+	if (!txn)
+		return -1;
+	memset(txn->fds, -1, sizeof txn->fds); // mark every fd slot unused (-1)
+	txn->len = len;
+
+	if (flags & MSG_DONTWAIT) {
+		sockfl = set_nonblocking(sockfd);
+		if (sockfl == -1) {
+			free(txn);
+			return -1;
+		}
+	}
+
+	ssize_t rv = virtio_wl_recv_raw(sockfd, txn);
+
+	if ((flags & MSG_DONTWAIT) && !(sockfl & O_NONBLOCK))
+		if (fcntl(sockfd, F_SETFL, sockfl) == -1)
+			rv = -1;
+
+	if (rv == -1) {
+		for (size_t i = 0; i < VIRTWL_SEND_MAX_ALLOCS; i++)
+			if (txn->fds[i] != -1)
+				close(txn->fds[i]);
+	} else {
+		txn_to_msghdr(msg, txn);
+		rv = txn->len;
+	}
+
+	free(txn);
+	return rv;
+}
diff --git a/src/virtio_wl.h b/src/virtio_wl.h
new file mode 100644
index 0000000..80bcd38
--- /dev/null
+++ b/src/virtio_wl.h
@@ -0,0 +1,23 @@
+/* Copyright 2020 Alyssa Ross
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
+
+#ifndef UTIL_VIRTIO_WL_H
+#define UTIL_VIRTIO_WL_H
+
+#include <stdint.h>
+#include <sys/types.h>
+struct virtwl_ioctl_txn;
+struct msghdr;
+
+int virtio_wl_connect(const char *name, uint32_t flags);
+
+ssize_t virtio_wl_sendmsg(int sockfd, const struct msghdr *, int flags);
+int virtio_wl_send_raw(int sockfd, struct virtwl_ioctl_txn *);
+
+ssize_t virtio_wl_recvmsg(int sockfd, struct msghdr *, int flags);
+int virtio_wl_recv_raw(int sockfd, struct virtwl_ioctl_txn *);
+
+#endif
diff --git a/src/wayland-os.c b/src/wayland-os.c
index 93b6f5f..c92aaae 100644
--- a/src/wayland-os.c
+++ b/src/wayland-os.c
@@ -31,6 +31,9 @@
 #include <fcntl.h>
 #include <errno.h>
 #include <sys/epoll.h>
+#include "virtio_wl.h"
+#include <string.h>
+#include <stdio.h>
 
 #include "../config.h"
 #include "wayland-os.h"
@@ -126,6 +129,10 @@ wl_os_recvmsg_cloexec(int sockfd, struct msghdr *msg, int flags)
 	len = recvmsg(sockfd, msg, flags | MSG_CMSG_CLOEXEC);
 	if (len >= 0)
 		return len;
+
+	if (errno == ENOTSOCK)
+		return virtio_wl_recvmsg(sockfd, msg, flags | MSG_CMSG_CLOEXEC);
+
 	if (errno != EINVAL)
 		return -1;
 
@@ -165,3 +172,19 @@ wl_os_accept_cloexec(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
 	fd = accept(sockfd, addr, addrlen);
 	return set_cloexec_or_close(fd);
 }
+
+int
+wl_virtio_accept_cloexec(int sockfd)
+{
+	char name[32];
+	size_t offset = 0;
+
+	while (offset < sizeof name) {
+		ssize_t size = read(sockfd, name + offset, sizeof name - offset);
+		if (size <= 0)
+			return -1;
+		offset += size;
+	}
+
+	return virtio_wl_connect(name, 0);
+}
diff --git a/src/wayland-os.h b/src/wayland-os.h
index f51efaa..899ac8e 100644
--- a/src/wayland-os.h
+++ b/src/wayland-os.h
@@ -41,6 +41,9 @@ wl_os_epoll_create_cloexec(void);
 int
 wl_os_accept_cloexec(int sockfd, struct sockaddr *addr, socklen_t *addrlen);
 
+int
+wl_virtio_accept_cloexec(int sockfd);
+
 
 /*
  * The following are for wayland-os.c and the unit tests.
diff --git a/src/wayland-server.c b/src/wayland-server.c
index 3f48dfe..0146ca8 100644
--- a/src/wayland-server.c
+++ b/src/wayland-server.c
@@ -531,8 +531,18 @@ wl_client_create(struct wl_display *display, int fd)
 
 	len = sizeof client->ucred;
 	if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED,
-		       &client->ucred, &len) < 0)
-		goto err_source;
+		       &client->ucred, &len) < 0) {
+		if (errno == ENOTSOCK) {
+			// Probably a virtio_wl socket.
+			// We don't have credential information, so fill with
+			// values that should never match real ones.
+			client->ucred.pid = -1;
+			client->ucred.uid = -1;
+			client->ucred.gid = -1;
+		} else {
+			goto err_source;
+		}
+	}
 
 	client->connection = wl_connection_create(fd);
 	if (client->connection == NULL)
@@ -1419,6 +1429,9 @@ socket_data(int fd, uint32_t mask, void *data)
 	length = sizeof name;
 	client_fd = wl_os_accept_cloexec(fd, (struct sockaddr *) &name,
 					 &length);
+	if (client_fd < 0 && errno == ENOTSOCK)
+		client_fd = wl_virtio_accept_cloexec(fd);
+
 	if (client_fd < 0)
 		wl_log("failed to accept: %s\n", strerror(errno));
 	else
@@ -1603,7 +1616,8 @@ wl_display_add_socket_fd(struct wl_display *display, int sock_fd)
 	struct stat buf;
 
 	/* Require a valid fd or fail */
-	if (sock_fd < 0 || fstat(sock_fd, &buf) < 0 || !S_ISSOCK(buf.st_mode)) {
+	if (sock_fd < 0 || fstat(sock_fd, &buf) < 0 ||
+	    ((buf.st_mode & S_IFMT) && !S_ISSOCK(buf.st_mode))) {
 		return -1;
 	}
 
diff --git a/src/wayland-shm.c b/src/wayland-shm.c
index b85e5a7..636ee7a 100644
--- a/src/wayland-shm.c
+++ b/src/wayland-shm.c
@@ -61,6 +61,7 @@ struct wl_shm_pool {
 	int internal_refcount;
 	int external_refcount;
 	char *data;
+	int fd;
 	int32_t size;
 	int32_t new_size;
 	bool sigbus_is_impossible;
@@ -91,14 +92,27 @@ shm_pool_finish_resize(struct wl_shm_pool *pool)
 
 	data = mremap(pool->data, pool->size, pool->new_size, MREMAP_MAYMOVE);
 	if (data == MAP_FAILED) {
-		wl_resource_post_error(pool->resource,
-				       WL_SHM_ERROR_INVALID_FD,
-				       "failed mremap");
-		return;
+		if (errno != EFAULT)
+			goto fail;
+
+		if (munmap(pool->data, pool->size) == -1)
+			goto fail;
+
+		data = mmap(pool->data, pool->new_size, PROT_READ | PROT_WRITE,
+			    MAP_SHARED, pool->fd, 0);
+		if (data == MAP_FAILED)
+			goto fail;
 	}
 
 	pool->data = data;
 	pool->size = pool->new_size;
+
+	return;
+
+ fail:
+	wl_resource_post_error(pool->resource,
+			       WL_SHM_ERROR_INVALID_FD,
+			       "failed mremap");
 }
 
 static void
@@ -291,6 +305,7 @@ shm_create_pool(struct wl_client *client, struct wl_resource *resource,
 	pool->external_refcount = 0;
 	pool->size = size;
 	pool->new_size = size;
+	pool->fd = fd;
 	pool->data = mmap(NULL, size,
 			  PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
 	if (pool->data == MAP_FAILED) {
@@ -300,7 +315,6 @@ shm_create_pool(struct wl_client *client, struct wl_resource *resource,
 				       strerror(errno));
 		goto err_free;
 	}
-	close(fd);
 
 	pool->resource =
 		wl_resource_create(client, &wl_shm_pool_interface, 1, id);
diff --git a/tests/meson.build b/tests/meson.build
index 224f48d..6ee2b49 100644
--- a/tests/meson.build
+++ b/tests/meson.build
@@ -148,6 +148,7 @@ tests = {
 		'headers-protocol-core-test.c',
 	],
 	'os-wrappers-test': [],
+	'virtio_wl-test': [],
 }
 
 foreach test_name, test_extra_sources: tests
diff --git a/tests/virtio_wl-test.c b/tests/virtio_wl-test.c
new file mode 100644
index 0000000..16c7a93
--- /dev/null
+++ b/tests/virtio_wl-test.c
@@ -0,0 +1,615 @@
+#define _GNU_SOURCE
+#include <assert.h>
+
+#include "test-runner.h"
+#include "virtio_wl.c"
+
+TEST(iov_len_fit)
+{
+	int f1[] = { 1, 2 };
+	int f2[] = { 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof(f1) },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof(f2) },
+	};
+
+	assert(iov_len(iov, sizeof(iov) / sizeof(*iov)) == 3 * sizeof(int));
+}
+
+TEST(iov_len_overflow)
+{
+	struct iovec iov[] = {
+		{ .iov_base = NULL, .iov_len = SIZE_MAX },
+		{ .iov_base = NULL, .iov_len = SIZE_MAX },
+	};
+
+	assert(iov_len(iov, sizeof(iov) / sizeof(*iov)) == -1);
+}
+
+TEST(iov_flatten_test)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof(f1) },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof(f2) },
+	};
+
+	int buf[2];
+	iov_flatten(buf, sizeof(buf), iov, sizeof(iov) / sizeof(*iov));
+	assert(buf[0] == 1);
+	assert(buf[1] == 2);
+}
+
+TEST(iov_flatten_zero_iov)
+{
+	unsigned char buf[] = { 0xFF };
+	assert(iov_flatten(buf, 1, NULL, 0) == 0);
+}
+
+TEST(iov_flatten_null_iov)
+{
+	unsigned char buf[] = { 0xFF };
+	struct iovec iov[] = { { .iov_base = NULL, .iov_len = 0 } };
+
+	assert(iov_flatten(buf, sizeof buf, iov, sizeof(iov) / sizeof(*iov)) == 0);
+}
+
+TEST(iov_flatten_zero_buf)
+{
+	unsigned char f1[] = { 0 };
+	struct iovec iov[] = { { .iov_base = f1, .iov_len = sizeof f1 } };
+
+	assert(iov_flatten(NULL, 0, iov, sizeof(iov) / sizeof(*iov)) == 0);
+}
+
+TEST(iov_flatten_short_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7, 6 };
+
+	assert(iov_flatten(buf, sizeof buf, iov, sizeof(iov) / sizeof(*iov)) == 12);
+	assert(buf[0] == 1);
+	assert(buf[1] == 2);
+	assert(buf[2] == 3);
+	assert(buf[3] == 6);
+}
+
+TEST(iov_flatten_exact_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7 };
+
+	assert(iov_flatten(buf, sizeof buf, iov, sizeof(iov) / sizeof(*iov)) == 12);
+	assert(buf[0] == 1);
+	assert(buf[1] == 2);
+	assert(buf[2] == 3);
+}
+
+TEST(iov_flatten_long_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3, 4 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7 };
+
+	assert(iov_flatten(buf, sizeof buf, iov, sizeof(iov) / sizeof(*iov)) == 12);
+	assert(buf[0] == 1);
+	assert(buf[1] == 2);
+	assert(buf[2] == 3);
+}
+
+TEST(iov_fill_zero_iov)
+{
+	unsigned char buf[] = { 0xFF };
+	assert(iov_fill(NULL, 0, buf, 1) == 0);
+}
+
+TEST(iov_fill_null_iov)
+{
+	unsigned char buf[] = { 0xFF };
+	struct iovec iov[] = { { .iov_base = NULL, .iov_len = 0 } };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), buf, sizeof buf) == 0);
+}
+
+TEST(iov_fill_zero_buf)
+{
+	unsigned char f1[] = { 0 };
+	struct iovec iov[] = { { .iov_base = f1, .iov_len = sizeof f1 } };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), NULL, 0) == 0);
+}
+
+TEST(iov_fill_short_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7, 6 };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), buf, sizeof buf) == 12);
+	assert(f1[0] == 9);
+	assert(f2[0] == 8);
+	assert(f2[1] == 7);
+}
+
+TEST(iov_fill_exact)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7 };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), buf, sizeof buf) == 12);
+	assert(f1[0] == 9);
+	assert(f2[0] == 8);
+	assert(f2[1] == 7);
+}
+
+TEST(iov_fill_long_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3, 4 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7 };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), buf, sizeof buf) == 12);
+	assert(f1[0] == 9);
+	assert(f2[0] == 8);
+	assert(f2[1] == 7);
+	assert(f2[2] == 4);
+}
+
+TEST(cmsg_to_fdbuf_zero_buf)
+{
+	int fds[] = { 0 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	assert(cmsg_to_fdbuf(NULL, 0, &msg) == 0);
+}
+
+TEST(cmsg_to_fdbuf_zero_cmsg)
+{
+	int fds[2] = { 0 };
+	struct msghdr msg = { 0 };
+	assert(cmsg_to_fdbuf(fds, sizeof(fds) / sizeof(*fds), &msg) == 0);
+}
+
+TEST(cmsg_to_fdbuf_small_buf)
+{
+	int fds[] = { 0, 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	int buf[] = { 0xFF };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == 1);
+	assert(buf[0] == 0);
+}
+
+TEST(cmsg_to_fdbuf_exact)
+{
+	int fds[] = { 0, 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == 2);
+	assert(buf[0] == 0);
+	assert(buf[1] == 1);
+}
+
+TEST(cmsg_to_fdbuf_small_cmsg)
+{
+	int fds[] = { 0 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == 1);
+	assert(buf[0] == 0);
+	assert(buf[1] == 0xFE);
+}
+
+TEST(cmsg_to_fdbuf_multiple_cmsg)
+{
+	int f1[] = { 0 };
+	int f2[] = { 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof f1) + CMSG_SPACE(sizeof f2)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg1 = CMSG_FIRSTHDR(&msg);
+	cmsg1->cmsg_level = SOL_SOCKET;
+	cmsg1->cmsg_type = SCM_RIGHTS;
+	cmsg1->cmsg_len = CMSG_LEN(sizeof f1);
+	memcpy(CMSG_DATA(cmsg1), f1, sizeof f1);
+
+	struct cmsghdr *cmsg2 = CMSG_NXTHDR(&msg, cmsg1);
+	cmsg2->cmsg_level = SOL_SOCKET;
+	cmsg2->cmsg_type = SCM_RIGHTS;
+	cmsg2->cmsg_len = CMSG_LEN(sizeof f2);
+	memcpy(CMSG_DATA(cmsg2), f2, sizeof f2);
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == 2);
+	assert(buf[0] == 0);
+	assert(buf[1] == 1);
+}
+
+TEST(cmsg_to_fdbuf_other_cmsg)
+{
+	struct ucred cred = {
+		.pid = getpid(),
+		.uid = getuid(),
+		.gid = getgid(),
+	};
+
+	int fds[] = { 0, 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof cred) + CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg1 = CMSG_FIRSTHDR(&msg);
+	cmsg1->cmsg_level = SOL_SOCKET;
+	cmsg1->cmsg_type = SCM_CREDENTIALS;
+	cmsg1->cmsg_len = CMSG_LEN(sizeof cred);
+	memcpy(CMSG_DATA(cmsg1), &cred, sizeof cred);
+
+	struct cmsghdr *cmsg2 = CMSG_NXTHDR(&msg, cmsg1);
+	cmsg2->cmsg_level = SOL_SOCKET;
+	cmsg2->cmsg_type = SCM_RIGHTS;
+	cmsg2->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg2), fds, sizeof fds);
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == -1);
+	assert(errno == EINVAL);
+}
+
+TEST(cmsg_to_fdbuf_misaligned)
+{
+	int fds[] = { 0, 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds + 1)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof(fds) + 1);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+	CMSG_DATA(cmsg)[sizeof(fds)] = 0;
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == -1);
+	assert(errno == EINVAL);
+}
+
+TEST(fdbuf_to_cmsg_null_cmsg)
+{
+	struct msghdr msg = { 0 };
+	int fds[] = { 1 };
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 0);
+	assert(msg.msg_flags & MSG_CTRUNC);
+}
+
+TEST(fdbuf_to_cmsg_empty_both)
+{
+	struct msghdr msg = { 0 };
+	int fds[] = { -1 };
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 0);
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(fdbuf_to_cmsg_null_buf)
+{
+	union {
+		char buf[CMSG_SPACE(sizeof(int))];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	assert(fdbuf_to_cmsg(&msg, NULL, 0) == 0);
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(fdbuf_to_cmsg_tiny_msg)
+{
+	union {
+		char buf[1];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	int fds[] = { 1 };
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 0);
+	assert(msg.msg_flags & MSG_CTRUNC);
+}
+
+TEST(fdbuf_to_cmsg_small_msg)
+{
+	union {
+		char buf[CMSG_SPACE(sizeof(int))];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	int fds[] = { 1, 2, 3 };
+	assert(CMSG_SPACE(sizeof fds) != sizeof u.buf);
+
+	// CMSG_SPACE(sizeof(int)) can equal CMSG_SPACE(sizeof(int) * 2)
+	size_t n = fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds));
+	assert(CMSG_SPACE(n * sizeof(int)) == sizeof u.buf);
+	assert(msg.msg_flags & MSG_CTRUNC);
+}
+
+TEST(fdbuf_to_cmsg_equal)
+{
+	int fds[] = { 1, 2 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 2);
+	assert(!memcmp(CMSG_DATA(CMSG_FIRSTHDR(&msg)), fds, sizeof fds));
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(fdbuf_to_cmsg_negative_one)
+{
+	int fds[] = { 1, -1, -1, 2 };
+
+	union {
+		char buf[CMSG_SPACE(2 * sizeof(int))];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 2);
+
+	unsigned char *data = CMSG_DATA(CMSG_FIRSTHDR(&msg));
+	assert(!memcmp(data, &fds[0], sizeof(int)));
+	assert(!memcmp(data + sizeof(int), &fds[3], sizeof(int)));
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(fdbuf_to_cmsg_small_buf)
+{
+	int fds[] = { 1 };
+
+	union {
+		char buf[CMSG_SPACE(2 * sizeof fds)];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 1);
+	assert(!memcmp(CMSG_DATA(CMSG_FIRSTHDR(&msg)), fds, sizeof fds));
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(msghdr_to_txn_test)
+{
+	struct virtwl_ioctl_txn *txn;
+
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	const int fds[] = { 0, 1, 2 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_iov = iov;
+	msg.msg_iovlen = sizeof(iov) / sizeof(*iov);
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	assert(!msghdr_to_txn(&msg, &txn));
+
+	assert(txn->len == sizeof(f1) + sizeof(f2));
+	assert(!memcmp(txn->fds, fds, sizeof fds));
+	for (size_t i = sizeof(fds) / sizeof(*fds);
+	     i < VIRTWL_SEND_MAX_ALLOCS; i++)
+		assert(txn->fds[i] == -1);
+	assert(!memcmp(txn->data, f1, sizeof f1));
+	assert(!memcmp(txn->data + sizeof f1, f2, sizeof f2));
+}
+
+TEST(txn_to_msghdr_test)
+{
+	const int data[] = { 1, 2, 3 };
+
+	size_t len = sizeof(data);
+	struct virtwl_ioctl_txn *txn = malloc(sizeof(*txn) + len);
+	assert(txn);
+
+	txn->len = len;
+
+	const int fds[] = { 0, 1, 2 };
+	memcpy(txn->fds, fds, sizeof fds);
+	for (int i = sizeof(fds) / sizeof(*fds); i < VIRTWL_SEND_MAX_ALLOCS; i++)
+		txn->fds[i] = -1;
+
+	memcpy(txn->data, data, sizeof data);
+
+	int buf1[1], buf2[2];
+
+	struct iovec iov[] = {
+		{ .iov_base = buf1, .iov_len = sizeof buf1 },
+		{ .iov_base = buf2, .iov_len = sizeof buf2 },
+	};
+
+	unsigned char cmsg_buf[VIRTWL_SEND_MAX_ALLOCS];
+
+	struct msghdr msg = { 0 };
+	msg.msg_iov = iov;
+	msg.msg_iovlen = sizeof(iov) / sizeof(*iov);
+	msg.msg_control = cmsg_buf;
+	msg.msg_controllen = sizeof cmsg_buf;
+
+	txn_to_msghdr(&msg, txn);
+
+	assert(!memcmp(buf1, data, sizeof buf1));
+	assert(!memcmp(buf2, (unsigned char *)data + sizeof buf1, sizeof buf2));
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	assert(cmsg->cmsg_level == SOL_SOCKET);
+	assert(cmsg->cmsg_type == SCM_RIGHTS);
+	assert(cmsg->cmsg_len == CMSG_LEN(3 * sizeof(int)));
+	for (int i = 0; i < 3; i++)
+		assert(!memcmp(CMSG_DATA(cmsg) + i * sizeof(int), &fds[i], sizeof(int)));
+
+	assert(!CMSG_NXTHDR(&msg, cmsg));
+}
-- 
2.26.2

* Re: [PATCH wayland] Support virtio_wl display sockets
  2020-07-08 15:16 [PATCH wayland] Support virtio_wl display sockets Alyssa Ross
@ 2020-07-08 15:43 ` Cole Helbling
  2020-07-08 21:07 ` [PATCH wayland v2] " Alyssa Ross
  2020-07-26  6:03 ` [PATCH wayland v3] " Alyssa Ross
  2 siblings, 0 replies; 5+ messages in thread
From: Cole Helbling @ 2020-07-08 15:43 UTC (permalink / raw)
  To: Alyssa Ross, devel

Alyssa Ross <hi@alyssa.is> writes:

>  src/connection.c       |   4 +
>  src/meson.build        |  16 +-
>  src/virtio_wl.c        | 344 +++++++++++++++++++++++
>  src/virtio_wl.h        |  23 ++
>  src/wayland-os.c       |  23 ++
>  src/wayland-os.h       |   3 +
>  src/wayland-server.c   |  20 +-
>  src/wayland-shm.c      |  24 +-
>  tests/meson.build      |   1 +
>  tests/virtio_wl-test.c | 615 +++++++++++++++++++++++++++++++++++++++++
>  10 files changed, 1064 insertions(+), 9 deletions(-)
>  create mode 100644 src/virtio_wl.c
>  create mode 100644 src/virtio_wl.h
>  create mode 100644 tests/virtio_wl-test.c
>
> diff --git a/src/virtio_wl.c b/src/virtio_wl.c
> new file mode 100644
> index 0000000..a1ee8ee
> --- /dev/null
> +++ b/src/virtio_wl.c

(snip)

> +static size_t
> +fdbuf_to_cmsg(struct msghdr *msg, const int *buf, size_t buflen)
> +{
> +	// Check msg->msg_control is long enough to fit at least one fd.
> +	if (msg->msg_controllen < CMSG_SPACE(sizeof(int))) {
> +
> +		// If there's at least one fd in buf, set MSG_CTRUNC.
> +		size_t i = 0;
> +		while (i < buflen && !(msg->msg_flags & MSG_CTRUNC))
> +			if (buf[i++] != -1)
> +				msg->msg_flags |= MSG_CTRUNC;
> +
> +		return 0;
> +	}
> +
> +	// cmsg(3):
> +	// > When initializing a buffer that will contain a series of cmsghdr
> +        // > structures (e.g., to be sent with sendmsg(2)), that buffer should
> +        // > first be zero-initialized to en‐ sure the correct operation of
> +        // > CMSG_NXTHDR().

This comment looks weird -- both the misalignment (tabs vs spaces, it
seems) and "en- sure" (probably reflow mistake).

> +	memset(msg->msg_control, 0, msg->msg_controllen);
> +
> +	// Set up the cmsg.
> +	struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg);
> +	cmsg->cmsg_level = SOL_SOCKET;
> +	cmsg->cmsg_type = SCM_RIGHTS;
> +
> +	// Copy as many fds as fit into cmsg.
> +	size_t len = 0;
> +	for (size_t i = 0; i < buflen; i++) {
> +		if (buf[i] == -1)
> +			continue;
> +
> +		if (CMSG_LEN((len + 1) * sizeof(int)) > msg->msg_controllen) {
> +			msg->msg_flags |= MSG_CTRUNC;
> +			break;
> +		}
> +
> +		memcpy(CMSG_DATA(cmsg) + sizeof(int) * len, &buf[i], sizeof(int));
> +		len++;
> +	}
> +
> +	cmsg->cmsg_len = CMSG_LEN(len * sizeof(int));
> +	return len;
> +}

I can't comment on the code itself, as I'm unfamiliar with both Wayland
and C. :D

Cole

* [PATCH wayland v2] Support virtio_wl display sockets
  2020-07-08 15:16 [PATCH wayland] Support virtio_wl display sockets Alyssa Ross
  2020-07-08 15:43 ` Cole Helbling
@ 2020-07-08 21:07 ` Alyssa Ross
  2020-07-26  6:03 ` [PATCH wayland v3] " Alyssa Ross
  2 siblings, 0 replies; 5+ messages in thread
From: Alyssa Ross @ 2020-07-08 21:07 UTC (permalink / raw)
  To: devel; +Cc: Cole Helbling

This patch adds libvirtio_wl, which exposes reimplementations of
sendmsg(2) and recvmsg(2) for virtio_wl socket fds.  Tests are
included for as much libvirtio_wl functionality as can reasonably be
tested without a working virtio_wl connection.  (Testing further
would require running the tests in a VM so that they could talk to
the virtio_wl kernel driver, which would be prohibitively complex.)

virtio_wl socket fds do not actually point to sockets, but to special
virtio_wl files.  Whenever a display socket operation calls sendmsg()
or recvmsg() on a file descriptor and gets an ENOTSOCK error, it
retries the operation with the equivalent libvirtio_wl function, in
case the file descriptor is a virtio_wl fd.  This is the least
invasive way to implement virtio_wl support: if a normal socket is
being used, there is no change in behaviour.
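A minimal sketch of this retry-on-ENOTSOCK pattern follows.  It is
illustrative only: fallback_sendmsg() is a hypothetical stand-in for
virtio_wl_sendmsg() that just write(2)s the first iovec, so the idea
can be tried with a pipe (on which sendmsg(2) fails with ENOTSOCK)
instead of a real virtio_wl fd.

```c
#define _GNU_SOURCE
#include <errno.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/uio.h>
#include <unistd.h>

/* Counts how often the fallback path was taken (for illustration). */
static int fallback_calls;

/* Hypothetical stand-in for virtio_wl_sendmsg(): it simply writes the
 * first iovec with write(2) so the sketch runs without /dev/wl0. */
static ssize_t
fallback_sendmsg(int fd, const struct msghdr *msg, int flags)
{
	(void)flags;
	fallback_calls++;
	if (msg->msg_iovlen == 0)
		return 0;
	return write(fd, msg->msg_iov[0].iov_base, msg->msg_iov[0].iov_len);
}

/* Try sendmsg(2) first; retry with the fallback only when the fd turns
 * out not to be a socket.  Ordinary sockets never reach the fallback,
 * so their behaviour is unchanged. */
static ssize_t
send_with_fallback(int fd, const struct msghdr *msg, int flags)
{
	ssize_t len;

	do {
		len = sendmsg(fd, msg, flags);
		if (len == -1 && errno == ENOTSOCK)
			len = fallback_sendmsg(fd, msg, flags);
	} while (len == -1 && errno == EINTR);

	return len;
}
```

With one end of a socketpair(2) the fallback is never called; with the
write end of a pipe, sendmsg(2) fails with ENOTSOCK and the fallback
performs the send instead.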

Because virtio_wl doesn't implement every socket feature, some
workarounds are currently required to accommodate everything Wayland
expects of sockets:

* virtio_wl_recvmsg implements the MSG_DONTWAIT flag by setting
  O_NONBLOCK on the fd, attempting the VIRTWL_IOCTL_RECV operation,
  and then restoring the fd's original flags.  This is obviously
  race-prone, but there is no reasonable alternative at present.

* virtio_wl_sendmsg requires MSG_NOSIGNAL to be set, and ignores it.
  This is because, from reading the code, I believe virtio_wl neither
  generates SIGPIPE signals nor ever returns EPIPE.  I could be wrong
  about this, though.

* virtio_wl does not support credential passing -- what would it even
  mean, considering the other end of the connection is on
  another (virtual) machine?  So wl_client's ucred member will have
  pid, uid, and gid all set to -1 for a client connected over
  virtio_wl.

* virtio_wl sockets do not support accept(2), so a fallback is used.
  A proxy program on the host calls accept(2) on a host socket.  When
  it receives a connection, it attaches the connection socket to the
  VM, then sends the 32-byte name of the connected socket over the
  Wayland display socket.  Wayland then reads this name, looks up the
  connection socket, and uses that as the client connection socket.
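The MSG_DONTWAIT workaround in the first point can be sketched roughly
as below, assuming only POSIX fcntl(2).  This is not the patch's code:
read(2) on an ordinary fd stands in for the VIRTWL_IOCTL_RECV ioctl,
and read_dontwait() is a hypothetical name.

```c
#define _POSIX_C_SOURCE 200809L
#include <errno.h>
#include <fcntl.h>
#include <sys/types.h>
#include <unistd.h>

/* Emulate MSG_DONTWAIT for an fd whose receive operation takes no
 * per-call flags: set O_NONBLOCK, perform the operation, then put the
 * original file status flags back.  read(2) stands in for the
 * VIRTWL_IOCTL_RECV ioctl in this sketch. */
static ssize_t
read_dontwait(int fd, void *buf, size_t len)
{
	int fl = fcntl(fd, F_GETFL);
	if (fl == -1)
		return -1;

	/* Only touch the flags if the fd is currently blocking. */
	int changed = !(fl & O_NONBLOCK);
	if (changed && fcntl(fd, F_SETFL, fl | O_NONBLOCK) == -1)
		return -1;

	ssize_t rv = read(fd, buf, len);
	int saved_errno = errno;

	/* Restore the original flags.  If that fails, report an error even
	 * though data may already have been consumed, as the patch does. */
	if (changed && fcntl(fd, F_SETFL, fl) == -1)
		return -1;

	errno = saved_errno;
	return rv;
}
```

Because O_NONBLOCK belongs to the open file description rather than to
the individual call, anything else sharing that description can see
the temporary change, which is the race described above.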

Additionally, virtio_wl memfd-like file descriptors don't support
mremap(2), so for these fds Wayland will munmap(2) the pool's mapping
and then mmap(2) the memfd again.  This should be mostly okay, because
Wayland only ever calls mremap with MREMAP_MAYMOVE (so the pool's
address is already allowed to change), but it is still race-prone.
To make this possible, memfds are no longer closed after being
mmapped, but are kept in the wl_shm_pool struct so that they can be
passed to mmap() again if required.
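The munmap/mmap fallback can be sketched as follows, assuming a plain
MAP_SHARED file mapping.  remap_by_unmap() and make_backing_file() are
hypothetical names used only for illustration; the latter creates an
unlinked temporary file standing in for a client's pool fd.

```c
#define _POSIX_C_SOURCE 200809L
#include <stddef.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <unistd.h>

/* Hypothetical helper: an unlinked temporary file of the given size,
 * standing in for the memfd a client would pass to wl_shm. */
static int
make_backing_file(off_t size)
{
	char path[] = "/tmp/shm-pool-XXXXXX";
	int fd = mkstemp(path);
	if (fd == -1)
		return -1;
	unlink(path);
	if (ftruncate(fd, size) == -1) {
		close(fd);
		return -1;
	}
	return fd;
}

/* Resize a MAP_SHARED mapping of fd without mremap(2): unmap the old
 * range, then map the file again at the new size.  As with
 * MREMAP_MAYMOVE, the result may be at a different address, so callers
 * must switch to the returned pointer.  Returns MAP_FAILED on error. */
static void *
remap_by_unmap(void *old, size_t old_size, size_t new_size, int fd)
{
	if (munmap(old, old_size) == -1)
		return MAP_FAILED;

	/* old is only an address hint (no MAP_FIXED).  Until mmap()
	 * returns, the old range is unmapped; another thread touching it
	 * would fault, which is why this approach is race-prone. */
	return mmap(old, new_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
}
```

Keeping the fd open is what makes the second mmap() possible; after a
close(2) only the existing mapping would remain.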

This patch appears to be enough to reliably run Alacritty on the host
system, with a Wayland compositor running in a VM.  It is not capable
of running Firefox, which usually fails to start, and occasionally
starts partially (sometimes even getting far enough to draw a window)
before freezing.
---
Fixed the formatting issue pointed out by Cole, and added a paragraph
at the end of the commit message describing application support.

 src/connection.c       |   4 +
 src/meson.build        |  16 +-
 src/virtio_wl.c        | 344 +++++++++++++++++++++++
 src/virtio_wl.h        |  23 ++
 src/wayland-os.c       |  23 ++
 src/wayland-os.h       |   3 +
 src/wayland-server.c   |  20 +-
 src/wayland-shm.c      |  24 +-
 tests/meson.build      |   1 +
 tests/virtio_wl-test.c | 615 +++++++++++++++++++++++++++++++++++++++++
 10 files changed, 1064 insertions(+), 9 deletions(-)
 create mode 100644 src/virtio_wl.c
 create mode 100644 src/virtio_wl.h
 create mode 100644 tests/virtio_wl-test.c

diff --git a/src/connection.c b/src/connection.c
index d0c7d9f..1bfbd8c 100644
--- a/src/connection.c
+++ b/src/connection.c
@@ -40,6 +40,7 @@
 #include <time.h>
 #include <ffi.h>
 
+#include "virtio_wl.h"
 #include "wayland-util.h"
 #include "wayland-private.h"
 #include "wayland-os.h"
@@ -314,6 +315,9 @@ wl_connection_flush(struct wl_connection *connection)
 		do {
 			len = sendmsg(connection->fd, &msg,
 				      MSG_NOSIGNAL | MSG_DONTWAIT);
+			if (len == -1 && errno == ENOTSOCK)
+				len = virtio_wl_sendmsg(connection->fd, &msg,
+							MSG_NOSIGNAL | MSG_DONTWAIT);
 		} while (len == -1 && errno == EINTR);
 
 		if (len == -1)
diff --git a/src/meson.build b/src/meson.build
index 2d1485c..cca1d93 100644
--- a/src/meson.build
+++ b/src/meson.build
@@ -71,13 +71,26 @@ if get_option('libraries')
 	mathlib_dep = cc.find_library('m', required: false)
 	threads_dep = dependency('threads', required: false)
 
+	virtio_wl = library(
+		'virtio_wl',
+		sources: [
+			'virtio_wl.c',
+		],
+		version: '0.1.0',
+		install: true,
+	)
+
+	virtio_wl_dep = declare_dependency(
+		link_with: virtio_wl,
+	)
+
 	wayland_private = static_library(
 		'wayland-private',
 		sources: [
 			'connection.c',
 			'wayland-os.c'
 		],
-		dependencies: [ ffi_dep, ]
+		dependencies: [ ffi_dep, virtio_wl_dep ]
 	)
 
 	wayland_private_dep = declare_dependency(
@@ -230,5 +243,6 @@ if get_option('libraries')
 		'wayland-server-core.h',
 		'wayland-client.h',
 		'wayland-client-core.h',
+		'virtio_wl.h',
 	])
 endif
diff --git a/src/virtio_wl.c b/src/virtio_wl.c
new file mode 100644
index 0000000..3454950
--- /dev/null
+++ b/src/virtio_wl.c
@@ -0,0 +1,344 @@
+/* Copyright 2020 Alyssa Ross
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
+
+#define _POSIX_C_SOURCE 200809L
+
+#include <errno.h>
+#include <fcntl.h>
+#include <limits.h>
+#include <linux/virtwl.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include "virtio_wl.h"
+
+// This is essentially vendored reusable library code, so I consider
+// it exempt from the Wayland style guide. :)
+
+#if defined(__GNUC__) && __GNUC__ >= 4
+#define VIRTIO_WL_EXPORT __attribute__ ((visibility("default")))
+#else
+#define VIRTIO_WL_EXPORT
+#endif
+
+static int
+set_nonblocking(int fd)
+{
+	int fl = fcntl(fd, F_GETFL);
+	if (fl == -1)
+		return -1;
+	if (!(fl & O_NONBLOCK))
+		if (fcntl(fd, F_SETFL, fl | O_NONBLOCK) == -1)
+			return -1;
+	return fl;
+}
+
+// Returns the total size of all buffers in an iovec.
+// A return value of -1 means that the total overflowed.
+static ssize_t
+iov_len(const struct iovec iov[], size_t n)
+{
+	size_t len = 0;
+	for (size_t i = 0; i < n; i++) {
+		if (SSIZE_MAX - len < iov[i].iov_len)
+			return -1;
+		len += iov[i].iov_len;
+	}
+	return len;
+}
+
+// Copies from an iovec array into a buffer.
+// The buffer is assumed to be large enough to hold all the data.
+// This length can be calculated with the iov_len function.
+static size_t
+iov_flatten(void *buf, size_t buflen, const struct iovec *iov, size_t iovlen)
+{
+	size_t off = 0;
+
+	for (size_t index = 0; index < iovlen && off < buflen; index++) {
+		const struct iovec *i = &iov[index];
+		size_t rem = buflen - off;
+		size_t len = i->iov_len < rem ? i->iov_len : rem;
+
+		memcpy((unsigned char *)buf + off, i->iov_base, len);
+		off += len;
+	}
+
+	return off;
+}
+
+// Copies from a buffer into an iovec array.
+// Returns number of bytes copied.
+static size_t
+iov_fill(struct iovec *iov, size_t iovlen, const void *buf, size_t buflen)
+{
+	size_t off = 0;
+
+	for (size_t index = 0; index < iovlen && off < buflen; index++) {
+		struct iovec *i = &iov[index];
+		size_t rem = buflen - off;
+		size_t len = i->iov_len < rem ? i->iov_len : rem;
+
+		memcpy(i->iov_base, (const unsigned char *)buf + off, len);
+		off += len;
+	}
+
+	return off;
+}
+
+static ssize_t
+cmsg_to_fdbuf(int *buf, size_t buflen, const struct msghdr *msg)
+{
+	size_t next_fd = 0;
+
+	for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg);
+	     cmsg != NULL; cmsg = CMSG_NXTHDR(msg, cmsg)) {
+		size_t size = cmsg->cmsg_len - CMSG_LEN(0);
+
+		// Check the cmsg can be handled.
+		if (cmsg->cmsg_level != SOL_SOCKET ||
+		    cmsg->cmsg_type != SCM_RIGHTS ||
+		    size % sizeof(int) != 0) {
+			errno = EINVAL;
+			return -1;
+		}
+
+		size_t rem = sizeof(int) * (buflen - next_fd);
+		size_t len = size > rem ? rem : size;
+
+		// Copy the fds to the buffer.
+		memcpy(buf + next_fd, CMSG_DATA(cmsg), len);
+		next_fd += len / sizeof(int);
+
+		if (size > rem)
+			break;
+	}
+
+	return next_fd > SSIZE_MAX ? SSIZE_MAX : next_fd;
+}
+
+static size_t
+fdbuf_to_cmsg(struct msghdr *msg, const int *buf, size_t buflen)
+{
+	// Check msg->msg_control is long enough to fit at least one fd.
+	if (msg->msg_controllen < CMSG_SPACE(sizeof(int))) {
+
+		// If there's at least one fd in buf, set MSG_CTRUNC.
+		size_t i = 0;
+		while (i < buflen && !(msg->msg_flags & MSG_CTRUNC))
+			if (buf[i++] != -1)
+				msg->msg_flags |= MSG_CTRUNC;
+
+		return 0;
+	}
+
+	// cmsg(3):
+	// > When initializing a buffer that will contain a series of cmsghdr
+	// > structures (e.g., to be sent with sendmsg(2)), that buffer should
+	// > first be zero-initialized to ensure the correct operation of
+	// > CMSG_NXTHDR().
+	memset(msg->msg_control, 0, msg->msg_controllen);
+
+	// Set up the cmsg.
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+
+	// Copy as many fds as fit into cmsg.
+	size_t len = 0;
+	for (size_t i = 0; i < buflen; i++) {
+		if (buf[i] == -1)
+			continue;
+
+		if (CMSG_LEN((len + 1) * sizeof(int)) > msg->msg_controllen) {
+			msg->msg_flags |= MSG_CTRUNC;
+			break;
+		}
+
+		memcpy(CMSG_DATA(cmsg) + sizeof(int) * len, &buf[i], sizeof(int));
+		len++;
+	}
+
+	cmsg->cmsg_len = CMSG_LEN(len * sizeof(int));
+	return len;
+}
+
+VIRTIO_WL_EXPORT int
+virtio_wl_connect(const char *name, uint32_t flags)
+{
+	static int wl_fd = -1;
+	if (wl_fd < 0)
+		wl_fd = open("/dev/wl0", O_RDWR | O_CLOEXEC);
+	if (wl_fd < 0)
+		return wl_fd;
+
+	struct virtwl_ioctl_new new_ctx = {
+		.type = name ? VIRTWL_IOCTL_NEW_CTX_NAMED : VIRTWL_IOCTL_NEW_CTX,
+		.fd = -1,
+		.flags = flags,
+	};
+	// Device assumes name 32 bytes long if not null terminated.
+	if (name)
+		strncpy(new_ctx.name, name, sizeof(new_ctx.name));
+
+	if (ioctl(wl_fd, VIRTWL_IOCTL_NEW, &new_ctx))
+		return -1;
+
+	return new_ctx.fd;
+}
+
+VIRTIO_WL_EXPORT int
+virtio_wl_send_raw(int sockfd, struct virtwl_ioctl_txn *ioctl_txn)
+{
+	return ioctl(sockfd, VIRTWL_IOCTL_SEND, ioctl_txn);
+}
+
+static int
+msghdr_to_txn(const struct msghdr *msg, struct virtwl_ioctl_txn **txn)
+{
+	// Make sure that if there's an error and the caller tries to free *txn
+	// anyway it doesn't end up freeing an invalid address.
+	*txn = NULL;
+
+	// Figure out how big the txn needs to be.
+	ssize_t len = iov_len(msg->msg_iov, msg->msg_iovlen);
+	if (len < 0 || len > UINT32_MAX) {
+		errno = ENOMEM;
+		return -1;
+	}
+
+	// Allocate the txn.
+	*txn = malloc(sizeof(**txn) + len);
+	if (!*txn)
+		return -1;
+
+	// Set the len member of the txn.
+	(*txn)->len = len;
+
+	// Copy data from the iovec into the transaction.
+	iov_flatten((*txn)->data, len, msg->msg_iov, msg->msg_iovlen);
+
+	// Copy file descriptors to the txn.
+	ssize_t fd_count = cmsg_to_fdbuf((*txn)->fds, VIRTWL_SEND_MAX_ALLOCS, msg);
+	if (fd_count == -1) {
+		free(*txn);
+		*txn = NULL;
+		return -1;
+	}
+
+	// Fill the rest of the fd buffer with -1.
+	while (fd_count < VIRTWL_SEND_MAX_ALLOCS)
+		(*txn)->fds[fd_count++] = -1;
+
+	return 0;
+}
+
+static void
+txn_to_msghdr(struct msghdr *msg, const struct virtwl_ioctl_txn *txn)
+{
+	// Copy txn data to iovecs.
+	iov_fill(msg->msg_iov, msg->msg_iovlen, txn->data, txn->len);
+
+	// Copy fds to cmsg.
+	fdbuf_to_cmsg(msg, txn->fds, VIRTWL_SEND_MAX_ALLOCS);
+}
+
+VIRTIO_WL_EXPORT ssize_t
+virtio_wl_sendmsg(int sockfd, const struct msghdr *msg, int flags)
+{
+	int sockfl;
+
+	if ((flags & ~MSG_DONTWAIT) != MSG_NOSIGNAL) {
+		errno = EINVAL;
+		return -1;
+	}
+
+	struct virtwl_ioctl_txn *txn;
+	if (msghdr_to_txn(msg, &txn) == -1)
+		return -1;
+
+	if (flags & MSG_DONTWAIT) {
+		sockfl = set_nonblocking(sockfd);
+		if (sockfl == -1) {
+			free(txn);
+			return -1;
+		}
+	}
+
+	int rv = virtio_wl_send_raw(sockfd, txn);
+
+	if ((flags & MSG_DONTWAIT) && !(sockfl & O_NONBLOCK)) {
+		if (fcntl(sockfd, F_SETFL, sockfl) == -1) {
+			free(txn);
+			return -1;
+		}
+	}
+
+	if (rv != -1)
+		rv = txn->len;
+
+	free(txn);
+	return rv;
+}
+
+VIRTIO_WL_EXPORT int
+virtio_wl_recv_raw(int sockfd, struct virtwl_ioctl_txn *ioctl_txn)
+{
+	return ioctl(sockfd, VIRTWL_IOCTL_RECV, ioctl_txn);
+}
+
+VIRTIO_WL_EXPORT ssize_t
+virtio_wl_recvmsg(int sockfd, struct msghdr *msg, int flags)
+{
+	int sockfl;
+
+	if ((flags & ~MSG_DONTWAIT) != MSG_CMSG_CLOEXEC) {
+		errno = EINVAL;
+		return -1;
+	}
+
+	ssize_t len = iov_len(msg->msg_iov, msg->msg_iovlen);
+	if (len < 0 || len > UINT32_MAX) {
+		errno = ENOMEM;
+		return -1;
+	}
+
+	struct virtwl_ioctl_txn *txn = malloc(sizeof(*txn) + len);
+	if (!txn)
+		return -1;
+
+	txn->len = len;
+	memset(txn->fds, 0xff, sizeof txn->fds); // fds not filled in stay -1
+	if (flags & MSG_DONTWAIT) {
+		sockfl = set_nonblocking(sockfd);
+		if (sockfl == -1) {
+			free(txn);
+			return -1;
+		}
+	}
+
+	ssize_t rv = virtio_wl_recv_raw(sockfd, txn);
+
+	if ((flags & MSG_DONTWAIT) && !(sockfl & O_NONBLOCK))
+		if (fcntl(sockfd, F_SETFL, sockfl) == -1)
+			rv = -1;
+
+	if (rv == -1) {
+		for (size_t i = 0; i < VIRTWL_SEND_MAX_ALLOCS; i++)
+			if (txn->fds[i] != -1)
+				close(txn->fds[i]);
+	} else {
+		txn_to_msghdr(msg, txn);
+		rv = txn->len;
+	}
+
+	free(txn);
+	return rv;
+}
diff --git a/src/virtio_wl.h b/src/virtio_wl.h
new file mode 100644
index 0000000..80bcd38
--- /dev/null
+++ b/src/virtio_wl.h
@@ -0,0 +1,23 @@
+/* Copyright 2020 Alyssa Ross
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
+
+#ifndef UTIL_VIRTIO_WL_H
+#define UTIL_VIRTIO_WL_H
+
+#include <stdint.h>
+#include <sys/types.h>
+struct virtwl_ioctl_txn;
+struct msghdr;
+
+int virtio_wl_connect(const char *name, uint32_t flags);
+
+ssize_t virtio_wl_sendmsg(int sockfd, const struct msghdr *, int flags);
+int virtio_wl_send_raw(int sockfd, struct virtwl_ioctl_txn *);
+
+ssize_t virtio_wl_recvmsg(int sockfd, struct msghdr *, int flags);
+int virtio_wl_recv_raw(int sockfd, struct virtwl_ioctl_txn *);
+
+#endif
diff --git a/src/wayland-os.c b/src/wayland-os.c
index 93b6f5f..c92aaae 100644
--- a/src/wayland-os.c
+++ b/src/wayland-os.c
@@ -31,6 +31,9 @@
 #include <fcntl.h>
 #include <errno.h>
 #include <sys/epoll.h>
+#include <virtio_wl.h>
+#include <string.h>
+#include <stdio.h>
 
 #include "../config.h"
 #include "wayland-os.h"
@@ -126,6 +129,10 @@ wl_os_recvmsg_cloexec(int sockfd, struct msghdr *msg, int flags)
 	len = recvmsg(sockfd, msg, flags | MSG_CMSG_CLOEXEC);
 	if (len >= 0)
 		return len;
+
+	if (errno == ENOTSOCK)
+		return virtio_wl_recvmsg(sockfd, msg, flags | MSG_CMSG_CLOEXEC);
+
 	if (errno != EINVAL)
 		return -1;
 
@@ -165,3 +172,19 @@ wl_os_accept_cloexec(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
 	fd = accept(sockfd, addr, addrlen);
 	return set_cloexec_or_close(fd);
 }
+
+int
+wl_virtio_accept_cloexec(int sockfd)
+{
+	char name[32];
+	size_t offset = 0;
+
+	while (offset < sizeof name) {
+		ssize_t size = read(sockfd, name + offset, sizeof name - offset);
+		if (size <= 0) // error, or EOF before a full name arrived
+			return -1;
+		offset += size;
+	}
+
+	return virtio_wl_connect(name, 0);
+}
diff --git a/src/wayland-os.h b/src/wayland-os.h
index f51efaa..899ac8e 100644
--- a/src/wayland-os.h
+++ b/src/wayland-os.h
@@ -41,6 +41,9 @@ wl_os_epoll_create_cloexec(void);
 int
 wl_os_accept_cloexec(int sockfd, struct sockaddr *addr, socklen_t *addrlen);
 
+int
+wl_virtio_accept_cloexec(int sockfd);
+
 
 /*
  * The following are for wayland-os.c and the unit tests.
diff --git a/src/wayland-server.c b/src/wayland-server.c
index 3f48dfe..0146ca8 100644
--- a/src/wayland-server.c
+++ b/src/wayland-server.c
@@ -531,8 +531,18 @@ wl_client_create(struct wl_display *display, int fd)
 
 	len = sizeof client->ucred;
 	if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED,
-		       &client->ucred, &len) < 0)
-		goto err_source;
+		       &client->ucred, &len) < 0) {
+		if (errno == ENOTSOCK) {
+			// Probably a virtio_wl socket.
+			// We don't have credential information, so fill with
+			// values that should never match real ones.
+			client->ucred.pid = -1;
+			client->ucred.uid = -1;
+			client->ucred.gid = -1;
+		} else {
+			goto err_source;
+		}
+	}
 
 	client->connection = wl_connection_create(fd);
 	if (client->connection == NULL)
@@ -1419,6 +1429,9 @@ socket_data(int fd, uint32_t mask, void *data)
 	length = sizeof name;
 	client_fd = wl_os_accept_cloexec(fd, (struct sockaddr *) &name,
 					 &length);
+	if (client_fd < 0 && errno == ENOTSOCK)
+		client_fd = wl_virtio_accept_cloexec(fd);
+
 	if (client_fd < 0)
 		wl_log("failed to accept: %s\n", strerror(errno));
 	else
@@ -1603,7 +1616,8 @@ wl_display_add_socket_fd(struct wl_display *display, int sock_fd)
 	struct stat buf;
 
 	/* Require a valid fd or fail */
-	if (sock_fd < 0 || fstat(sock_fd, &buf) < 0 || !S_ISSOCK(buf.st_mode)) {
+	if (sock_fd < 0 || fstat(sock_fd, &buf) < 0 ||
+	    ((buf.st_mode & S_IFMT) && !S_ISSOCK(buf.st_mode))) {
 		return -1;
 	}
 
diff --git a/src/wayland-shm.c b/src/wayland-shm.c
index b85e5a7..636ee7a 100644
--- a/src/wayland-shm.c
+++ b/src/wayland-shm.c
@@ -61,6 +61,7 @@ struct wl_shm_pool {
 	int internal_refcount;
 	int external_refcount;
 	char *data;
+	int fd;
 	int32_t size;
 	int32_t new_size;
 	bool sigbus_is_impossible;
@@ -91,14 +92,27 @@ shm_pool_finish_resize(struct wl_shm_pool *pool)
 
 	data = mremap(pool->data, pool->size, pool->new_size, MREMAP_MAYMOVE);
 	if (data == MAP_FAILED) {
-		wl_resource_post_error(pool->resource,
-				       WL_SHM_ERROR_INVALID_FD,
-				       "failed mremap");
-		return;
+		if (errno != EFAULT)
+			goto fail;
+
+		if (munmap(pool->data, pool->size) == -1)
+			goto fail;
+
+		data = mmap(pool->data, pool->new_size, PROT_READ | PROT_WRITE,
+			    MAP_SHARED, pool->fd, 0);
+		if (data == MAP_FAILED)
+			goto fail;
 	}
 
 	pool->data = data;
 	pool->size = pool->new_size;
+
+	return;
+
+ fail:
+	wl_resource_post_error(pool->resource,
+			       WL_SHM_ERROR_INVALID_FD,
+			       "failed mremap");
 }
 
 static void
@@ -291,6 +305,7 @@ shm_create_pool(struct wl_client *client, struct wl_resource *resource,
 	pool->external_refcount = 0;
 	pool->size = size;
 	pool->new_size = size;
+	pool->fd = fd;
 	pool->data = mmap(NULL, size,
 			  PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
 	if (pool->data == MAP_FAILED) {
@@ -300,7 +315,6 @@ shm_create_pool(struct wl_client *client, struct wl_resource *resource,
 				       strerror(errno));
 		goto err_free;
 	}
-	close(fd);
 
 	pool->resource =
 		wl_resource_create(client, &wl_shm_pool_interface, 1, id);
diff --git a/tests/meson.build b/tests/meson.build
index 224f48d..6ee2b49 100644
--- a/tests/meson.build
+++ b/tests/meson.build
@@ -148,6 +148,7 @@ tests = {
 		'headers-protocol-core-test.c',
 	],
 	'os-wrappers-test': [],
+	'virtio_wl-test': [],
 }
 
 foreach test_name, test_extra_sources: tests
diff --git a/tests/virtio_wl-test.c b/tests/virtio_wl-test.c
new file mode 100644
index 0000000..16c7a93
--- /dev/null
+++ b/tests/virtio_wl-test.c
@@ -0,0 +1,615 @@
+#define _GNU_SOURCE
+#include <assert.h>
+
+#include "test-runner.h"
+#include "virtio_wl.c"
+
+TEST(iov_len_fit)
+{
+	int f1[] = { 1, 2 };
+	int f2[] = { 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof(f1) },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof(f2) },
+	};
+
+	assert(iov_len(iov, sizeof(iov) / sizeof(*iov)) == 3 * sizeof(int));
+}
+
+TEST(iov_len_overflow)
+{
+	struct iovec iov[] = {
+		{ .iov_base = NULL, .iov_len = SIZE_MAX },
+		{ .iov_base = NULL, .iov_len = SIZE_MAX },
+	};
+
+	assert(iov_len(iov, sizeof(iov) / sizeof(*iov)) == -1);
+}
+
+TEST(iov_flatten_test)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof(f1) },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof(f2) },
+	};
+
+	int buf[2];
+	iov_flatten(buf, sizeof(buf), iov, sizeof(iov) / sizeof(*iov));
+	assert(buf[0] == 1);
+	assert(buf[1] == 2);
+}
+
+TEST(iov_flatten_zero_iov)
+{
+	unsigned char buf[] = { 0xFF };
+	assert(iov_flatten(buf, 1, NULL, 0) == 0);
+}
+
+TEST(iov_flatten_null_iov)
+{
+	unsigned char buf[] = { 0xFF };
+	struct iovec iov[] = { { .iov_base = NULL, .iov_len = 0 } };
+
+	assert(iov_flatten(buf, sizeof buf, iov, sizeof(iov) / sizeof(*iov)) == 0);
+}
+
+TEST(iov_flatten_zero_buf)
+{
+	unsigned char f1[] = { 0 };
+	struct iovec iov[] = { { .iov_base = f1, .iov_len = sizeof f1 } };
+
+	assert(iov_flatten(NULL, 0, iov, sizeof(iov) / sizeof(*iov)) == 0);
+}
+
+TEST(iov_flatten_short_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7, 6 };
+
+	assert(iov_flatten(buf, sizeof buf, iov, sizeof(iov) / sizeof(*iov)) == 12);
+	assert(buf[0] == 1);
+	assert(buf[1] == 2);
+	assert(buf[2] == 3);
+	assert(buf[3] == 6);
+}
+
+TEST(iov_flatten_exact_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7 };
+
+	assert(iov_flatten(buf, sizeof buf, iov, sizeof(iov) / sizeof(*iov)) == 12);
+	assert(f1[0] == 1);
+	assert(f2[0] == 2);
+	assert(f2[1] == 3);
+}
+
+TEST(iov_flatten_long_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3, 4 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7 };
+
+	assert(iov_flatten(buf, sizeof buf, iov, sizeof(iov) / sizeof(*iov)) == 12);
+	assert(buf[0] == 1);
+	assert(buf[1] == 2);
+	assert(buf[2] == 3);
+}
+
+TEST(iov_fill_zero_iov)
+{
+	unsigned char buf[] = { 0xFF };
+	assert(iov_fill(NULL, 0, buf, 1) == 0);
+}
+
+TEST(iov_fill_null_iov)
+{
+	unsigned char buf[] = { 0xFF };
+	struct iovec iov[] = { { .iov_base = NULL, .iov_len = 0 } };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), buf, sizeof buf) == 0);
+}
+
+TEST(iov_fill_zero_buf)
+{
+	unsigned char f1[] = { 0 };
+	struct iovec iov[] = { { .iov_base = f1, .iov_len = sizeof f1 } };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), NULL, 0) == 0);
+}
+
+TEST(iov_fill_short_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7, 6 };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), buf, sizeof buf) == 12);
+	assert(f1[0] == 9);
+	assert(f2[0] == 8);
+	assert(f2[1] == 7);
+}
+
+TEST(iov_fill_exact)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7 };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), buf, sizeof buf) == 12);
+	assert(f1[0] == 9);
+	assert(f2[0] == 8);
+	assert(f2[1] == 7);
+}
+
+TEST(iov_fill_long_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3, 4 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7 };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), buf, sizeof buf) == 12);
+	assert(f1[0] == 9);
+	assert(f2[0] == 8);
+	assert(f2[1] == 7);
+	assert(f2[2] == 4);
+}
+
+TEST(cmsg_to_fdbuf_zero_buf)
+{
+	int fds[] = { 0 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	assert(cmsg_to_fdbuf(NULL, 0, &msg) == 0);
+}
+
+TEST(cmsg_to_fdbuf_zero_cmsg)
+{
+	int fds[2] = { 0 };
+	struct msghdr msg = { 0 };
+	assert(cmsg_to_fdbuf(fds, sizeof(fds) / sizeof(*fds), &msg) == 0);
+}
+
+TEST(cmsg_to_fdbuf_small_buf)
+{
+	int fds[] = { 0, 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	int buf[] = { 0xFF };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == 1);
+	assert(buf[0] == 0);
+}
+
+TEST(cmsg_to_fdbuf_exact)
+{
+	int fds[] = { 0, 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == 2);
+	assert(buf[0] == 0);
+	assert(buf[1] == 1);
+}
+
+TEST(cmsg_to_fdbuf_small_cmsg)
+{
+	int fds[] = { 0 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == 1);
+	assert(buf[0] == 0);
+	assert(buf[1] == 0xFE);
+}
+
+TEST(cmsg_to_fdbuf_multiple_cmsg)
+{
+	int f1[] = { 0 };
+	int f2[] = { 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof f1) + CMSG_SPACE(sizeof f2)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg1 = CMSG_FIRSTHDR(&msg);
+	cmsg1->cmsg_level = SOL_SOCKET;
+	cmsg1->cmsg_type = SCM_RIGHTS;
+	cmsg1->cmsg_len = CMSG_LEN(sizeof f1);
+	memcpy(CMSG_DATA(cmsg1), f1, sizeof f1);
+
+	struct cmsghdr *cmsg2 = CMSG_NXTHDR(&msg, cmsg1);
+	cmsg2->cmsg_level = SOL_SOCKET;
+	cmsg2->cmsg_type = SCM_RIGHTS;
+	cmsg2->cmsg_len = CMSG_LEN(sizeof f2);
+	memcpy(CMSG_DATA(cmsg2), f2, sizeof f2);
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == 2);
+	assert(buf[0] == 0);
+	assert(buf[1] == 1);
+}
+
+TEST(cmsg_to_fdbuf_other_cmsg)
+{
+	struct ucred cred = {
+		.pid = getpid(),
+		.uid = getuid(),
+		.gid = getgid(),
+	};
+
+	int fds[] = { 0, 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof cred) + CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg1 = CMSG_FIRSTHDR(&msg);
+	cmsg1->cmsg_level = SOL_SOCKET;
+	cmsg1->cmsg_type = SCM_CREDENTIALS;
+	cmsg1->cmsg_len = CMSG_LEN(sizeof cred);
+	memcpy(CMSG_DATA(cmsg1), &cred, sizeof cred);
+
+	struct cmsghdr *cmsg2 = CMSG_NXTHDR(&msg, cmsg1);
+	cmsg2->cmsg_level = SOL_SOCKET;
+	cmsg2->cmsg_type = SCM_RIGHTS;
+	cmsg2->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg2), fds, sizeof fds);
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == -1);
+	assert(errno == EINVAL);
+}
+
+TEST(cmsg_to_fdbuf_misaligned)
+{
+	int fds[] = { 0, 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds + 1)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof(fds) + 1);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+	CMSG_DATA(cmsg)[sizeof(fds)] = 0;
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == -1);
+	assert(errno == EINVAL);
+}
+
+TEST(fdbuf_to_cmsg_null_cmsg)
+{
+	struct msghdr msg = { 0 };
+	int fds[] = { 1 };
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 0);
+	assert(msg.msg_flags & MSG_CTRUNC);
+}
+
+TEST(fdbuf_to_cmsg_empty_both)
+{
+	struct msghdr msg = { 0 };
+	int fds[] = { -1 };
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 0);
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(fdbuf_to_cmsg_null_buf)
+{
+	union {
+		char buf[CMSG_SPACE(sizeof(int))];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	assert(fdbuf_to_cmsg(&msg, NULL, 0) == 0);
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(fdbuf_to_cmsg_tiny_msg)
+{
+	union {
+		char buf[1];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	int fds[] = { 1 };
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 0);
+	assert(msg.msg_flags & MSG_CTRUNC);
+}
+
+TEST(fdbuf_to_cmsg_small_msg)
+{
+	union {
+		char buf[CMSG_SPACE(sizeof(int))];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	int fds[] = { 1, 2, 3 };
+	assert(CMSG_SPACE(sizeof fds) != sizeof u.buf);
+
+	// CMSG_SPACE(sizeof(int)) can equal CMSG_SPACE(sizeof(int) * 2)
+	size_t n = fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds));
+	assert(CMSG_SPACE(n * sizeof(int)) == sizeof u.buf);
+	assert(msg.msg_flags & MSG_CTRUNC);
+}
+
+TEST(fdbuf_to_cmsg_equal)
+{
+	int fds[] = { 1, 2 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 2);
+	assert(!memcmp(CMSG_DATA(CMSG_FIRSTHDR(&msg)), fds, sizeof fds));
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(fdbuf_to_cmsg_negative_one)
+{
+	int fds[] = { 1, -1, -1, 2 };
+
+	union {
+		char buf[CMSG_SPACE(2 * sizeof(int))];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 2);
+
+	unsigned char *data = CMSG_DATA(CMSG_FIRSTHDR(&msg));
+	assert(!memcmp(data, &fds[0], sizeof(int)));
+	assert(!memcmp(data + sizeof(int), &fds[3], sizeof(int)));
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(fdbuf_to_cmsg_small_buf)
+{
+	int fds[] = { 1 };
+
+	union {
+		char buf[CMSG_SPACE(2 * sizeof fds)];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 1);
+	assert(!memcmp(CMSG_DATA(CMSG_FIRSTHDR(&msg)), fds, sizeof fds));
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(msghdr_to_txn_test)
+{
+	struct virtwl_ioctl_txn *txn;
+
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	const int fds[] = { 0, 1, 2 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_iov = iov;
+	msg.msg_iovlen = sizeof(iov) / sizeof(*iov);
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	assert(!msghdr_to_txn(&msg, &txn));
+
+	assert(txn->len == sizeof(f1) + sizeof(f2));
+	assert(!memcmp(txn->fds, fds, sizeof fds));
+	for (size_t i = sizeof(fds) / sizeof(*fds);
+	     i < VIRTWL_SEND_MAX_ALLOCS; i++)
+		assert(txn->fds[i] == -1);
+	assert(!memcmp(txn->data, f1, sizeof f1));
+	assert(!memcmp(txn->data + sizeof f1, f2, sizeof f2));
+}
+
+TEST(txn_to_msghdr_test)
+{
+	const int data[] = { 1, 2, 3 };
+
+	size_t len = sizeof(data);
+	struct virtwl_ioctl_txn *txn = malloc(sizeof(*txn) + len);
+	assert(txn);
+
+	txn->len = len;
+
+	const int fds[] = { 0, 1, 2 };
+	memcpy(txn->fds, fds, sizeof fds);
+	for (int i = sizeof(fds) / sizeof(*fds); i < VIRTWL_SEND_MAX_ALLOCS; i++)
+		txn->fds[i] = -1;
+
+	memcpy(txn->data, data, sizeof data);
+
+	int buf1[1], buf2[2];
+
+	struct iovec iov[] = {
+		{ .iov_base = buf1, .iov_len = sizeof buf1 },
+		{ .iov_base = buf2, .iov_len = sizeof buf2 },
+	};
+
+	unsigned char cmsg_buf[VIRTWL_SEND_MAX_ALLOCS];
+
+	struct msghdr msg = { 0 };
+	msg.msg_iov = iov;
+	msg.msg_iovlen = sizeof(iov) / sizeof(*iov);
+	msg.msg_control = cmsg_buf;
+	msg.msg_controllen = sizeof cmsg_buf;
+
+	txn_to_msghdr(&msg, txn);
+
+	assert(!memcmp(buf1, data, sizeof buf1));
+	assert(!memcmp(buf2, (unsigned char *)data + sizeof buf1, sizeof buf2));
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	assert(cmsg->cmsg_level == SOL_SOCKET);
+	assert(cmsg->cmsg_type == SCM_RIGHTS);
+	assert(cmsg->cmsg_len == CMSG_LEN(3 * sizeof(int)));
+	for (int i = 0; i < 3; i++)
+		assert(!memcmp(CMSG_DATA(cmsg) + i * sizeof(int), &fds[i], sizeof(int)));
+
+	assert(!CMSG_NXTHDR(&msg, cmsg));
+}
-- 
2.26.2

* [PATCH wayland v3] Support virtio_wl display sockets
  2020-07-08 15:16 [PATCH wayland] Support virtio_wl display sockets Alyssa Ross
  2020-07-08 15:43 ` Cole Helbling
  2020-07-08 21:07 ` [PATCH wayland v2] " Alyssa Ross
@ 2020-07-26  6:03 ` Alyssa Ross
  2020-07-27 18:37   ` Cole Helbling
  2 siblings, 1 reply; 5+ messages in thread
From: Alyssa Ross @ 2020-07-26  6:03 UTC
  To: devel; +Cc: Cole Helbling

This patch adds libvirtio_wl, which exposes reimplementations of
sendmsg(2) and recvmsg(2) for virtio_wl socket fds.  Tests are
included for as much libvirtio_wl functionality as can reasonably be
tested without a working virtio_wl connection.  (Testing further
would require running the tests in a VM so that they could talk to
the virtio_wl kernel driver, which would be prohibitively complex.)

virtio_wl socket fds do not actually point to sockets, but to special
virtio_wl files.  Whenever a display socket operation calls sendmsg()
or recvmsg() on a file descriptor and receives an ENOTSOCK error, it
retries the operation with the equivalent libvirtio_wl function, in
case the file descriptor is a virtio_wl fd.  This is the least invasive
way to implement virtio_wl support -- if a normal socket is being
used, there will be no change in behaviour.
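For illustration, the retry pattern described above can be sketched on
its own.  This is a minimal sketch, not code from the patch:
sendmsg_with_fallback() and sendmsg_fallback_fn are hypothetical
names, and the fallback argument stands in for virtio_wl_sendmsg().

```c
#include <errno.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/uio.h>

/* Hypothetical type for the fallback; in the patch, virtio_wl_sendmsg()
 * plays this role. */
typedef ssize_t (*sendmsg_fallback_fn)(int fd, const struct msghdr *msg,
				       int flags);

/* Try a real sendmsg(2) first.  Only when the kernel reports that fd is
 * not a socket (ENOTSOCK) is the fallback tried, so ordinary sockets
 * behave exactly as before.  EINTR is retried, as in
 * wl_connection_flush(). */
static ssize_t
sendmsg_with_fallback(int fd, const struct msghdr *msg, int flags,
		      sendmsg_fallback_fn fallback)
{
	ssize_t len;

	do {
		len = sendmsg(fd, msg, flags);
		if (len == -1 && errno == ENOTSOCK)
			len = fallback(fd, msg, flags);
	} while (len == -1 && errno == EINTR);

	return len;
}
```

On Linux, calling sendmsg(2) on a pipe fails with ENOTSOCK, which is a
convenient way to exercise the fallback path without a virtio_wl
device.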

Because virtio_wl doesn't implement every socket feature, some
workarounds are currently required to accommodate everything Wayland
expects of sockets:

* virtio_wl_recvmsg implements the MSG_DONTWAIT flag by setting
  O_NONBLOCK on the fd, attempting the VIRTWL_IOCTL_RECV operation,
  and then restoring the fd's original flags.  This is obviously
  race-prone, but there is no reasonable alternative at present.

* virtio_wl_sendmsg requires MSG_NOSIGNAL to be set, and ignores it.
  This is because, from my reading of the code, virtio_wl never
  generates SIGPIPE signals or returns EPIPE.  I could be wrong about
  this, though.

* virtio_wl does not support credential passing -- what would it even
  mean, considering the other end of the connection is on
  another (virtual) machine?  So wl_client's ucred member will have
  pid, uid, and gid all set to -1 for a client connected over
  virtio_wl.

* virtio_wl sockets do not support accept(2), so a fallback is used.
  A proxy program on the host calls accept(2) on a host socket.  When
  it receives a connection, it attaches the connection socket to the
  VM, then sends the name of the connected socket over the Wayland
  display socket.  Wayland then receives this name, looks up the
  connection socket, and uses that as the client connection socket.
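Two of these workarounds can be sketched in isolation.  The helpers
below are hypothetical and only approximate what the patch does:
call_nonblocking() emulates MSG_DONTWAIT by toggling O_NONBLOCK around
a single operation (standing in for the VIRTWL_IOCTL_RECV ioctl), and
read_exact() reads a fixed-size connection name the way the accept(2)
fallback does, assuming a 32-byte name as in the patch.  Reporting an
early EOF as ECONNRESET is an arbitrary choice made for the sketch.

```c
#include <errno.h>
#include <fcntl.h>
#include <stddef.h>
#include <sys/types.h>
#include <unistd.h>

/* Hypothetical stand-in for a blocking operation such as the
 * VIRTWL_IOCTL_RECV ioctl. */
typedef ssize_t (*blocking_op_fn)(int fd, void *buf, size_t len);

/* Emulate MSG_DONTWAIT by setting O_NONBLOCK around one call, then
 * restoring the original flags.  O_NONBLOCK belongs to the open file
 * description, so anything else sharing it sees the change while the
 * call runs; that is the race described above. */
static ssize_t
call_nonblocking(int fd, blocking_op_fn op, void *buf, size_t len)
{
	int fl = fcntl(fd, F_GETFL);
	if (fl == -1)
		return -1;
	if (!(fl & O_NONBLOCK) && fcntl(fd, F_SETFL, fl | O_NONBLOCK) == -1)
		return -1;

	ssize_t rv = op(fd, buf, len);
	int saved_errno = errno;

	if (!(fl & O_NONBLOCK) && fcntl(fd, F_SETFL, fl) == -1)
		return -1;

	errno = saved_errno;  /* report the operation's error, if any */
	return rv;
}

/* Read exactly len bytes, e.g. a fixed-size connection name.  EOF
 * before len bytes arrive is reported as an error instead of looping
 * forever. */
static int
read_exact(int fd, void *buf, size_t len)
{
	size_t off = 0;

	while (off < len) {
		ssize_t n = read(fd, (char *)buf + off, len - off);
		if (n == -1 && errno == EINTR)
			continue;
		if (n == -1)
			return -1;
		if (n == 0) {
			errno = ECONNRESET;
			return -1;
		}
		off += (size_t)n;
	}
	return 0;
}
```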

Additionally, virtio_wl memfd-like file descriptors don't support
mremap(2), so for virtio_wl sockets Wayland will munmap(2) the memfd,
and then mmap(2) it again.  This should be at least mostly okay
because Wayland only ever calls mremap with MREMAP_MAYMOVE, but it is
still race-prone.  To be able to do this, memfds are no longer closed
after being mmapped, but are kept around in the wl_shm_pool struct, so
that they can be passed to mmap() if required.
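A minimal sketch of the mremap fallback, assuming a MAP_SHARED mapping
whose fd stays open for the life of the pool.  resize_shared_mapping()
and its force_fallback parameter are hypothetical; the parameter only
exists so the munmap/mmap path can be tried on an ordinary file.  Like
the patch, it falls back only when mremap(2) fails with EFAULT.

```c
#define _GNU_SOURCE
#include <errno.h>
#include <stdbool.h>
#include <stddef.h>
#include <sys/mman.h>

/* Resize a shared mapping of fd from old_size to new_size bytes.
 * The mapping may move either way, so callers must not keep pointers
 * into the old mapping.  Returns MAP_FAILED on error. */
static void *
resize_shared_mapping(void *old, size_t old_size, size_t new_size, int fd,
		      bool force_fallback)
{
	if (!force_fallback) {
		void *data = mremap(old, old_size, new_size, MREMAP_MAYMOVE);
		/* Only fall back when mremap reports EFAULT. */
		if (data != MAP_FAILED || errno != EFAULT)
			return data;
	}

	if (munmap(old, old_size) == -1)
		return MAP_FAILED;

	/* The old address is only a hint: without MAP_FIXED the kernel
	 * may place the new mapping elsewhere, and another thread could
	 * map something into the gap in the meantime. */
	return mmap(old, new_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
}
```

Because the data lives in the shared file rather than in the mapping,
unmapping and remapping the fd preserves the pool contents.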

This patch appears to be enough to reliably run Alacritty on the host
system, with a Wayland compositor running in a VM.  It is not capable
of running Firefox, which usually fails to start, and occasionally
starts partially (sometimes even getting far enough to draw a window)
before freezing.
---
Cole pointed out on IRC that previous versions of this patch
unnecessarily included string.h and stdio.h in src/wayland-os.c.
These includes have now been removed.

 src/connection.c       |   4 +
 src/meson.build        |  16 +-
 src/virtio_wl.c        | 344 +++++++++++++++++++++++
 src/virtio_wl.h        |  23 ++
 src/wayland-os.c       |  21 ++
 src/wayland-os.h       |   3 +
 src/wayland-server.c   |  20 +-
 src/wayland-shm.c      |  24 +-
 tests/meson.build      |   1 +
 tests/virtio_wl-test.c | 615 +++++++++++++++++++++++++++++++++++++++++
 10 files changed, 1062 insertions(+), 9 deletions(-)
 create mode 100644 src/virtio_wl.c
 create mode 100644 src/virtio_wl.h
 create mode 100644 tests/virtio_wl-test.c

diff --git a/src/connection.c b/src/connection.c
index d0c7d9f..1bfbd8c 100644
--- a/src/connection.c
+++ b/src/connection.c
@@ -40,6 +40,7 @@
 #include <time.h>
 #include <ffi.h>
 
+#include "virtio_wl.h"
 #include "wayland-util.h"
 #include "wayland-private.h"
 #include "wayland-os.h"
@@ -314,6 +315,9 @@ wl_connection_flush(struct wl_connection *connection)
 		do {
 			len = sendmsg(connection->fd, &msg,
 				      MSG_NOSIGNAL | MSG_DONTWAIT);
+			if (len == -1 && errno == ENOTSOCK)
+				len = virtio_wl_sendmsg(connection->fd, &msg,
+							MSG_NOSIGNAL | MSG_DONTWAIT);
 		} while (len == -1 && errno == EINTR);
 
 		if (len == -1)
diff --git a/src/meson.build b/src/meson.build
index 2d1485c..cca1d93 100644
--- a/src/meson.build
+++ b/src/meson.build
@@ -71,13 +71,26 @@ if get_option('libraries')
 	mathlib_dep = cc.find_library('m', required: false)
 	threads_dep = dependency('threads', required: false)
 
+	virtio_wl = library(
+		'virtio_wl',
+		sources: [
+			'virtio_wl.c',
+		],
+		version: '0.1.0',
+		install: true,
+	)
+
+	virtio_wl_dep = declare_dependency(
+		link_with: virtio_wl,
+	)
+
 	wayland_private = static_library(
 		'wayland-private',
 		sources: [
 			'connection.c',
 			'wayland-os.c'
 		],
-		dependencies: [ ffi_dep, ]
+		dependencies: [ ffi_dep, virtio_wl_dep ]
 	)
 
 	wayland_private_dep = declare_dependency(
@@ -230,5 +243,6 @@ if get_option('libraries')
 		'wayland-server-core.h',
 		'wayland-client.h',
 		'wayland-client-core.h',
+		'virtio_wl.h',
 	])
 endif
diff --git a/src/virtio_wl.c b/src/virtio_wl.c
new file mode 100644
index 0000000..3454950
--- /dev/null
+++ b/src/virtio_wl.c
@@ -0,0 +1,344 @@
+/* Copyright 2020 Alyssa Ross
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
+
+#define _POSIX_C_SOURCE 200809L
+
+#include <errno.h>
+#include <fcntl.h>
+#include <limits.h>
+#include <linux/virtwl.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include "virtio_wl.h"
+
+// This is essentially vendored reusable library code, so I consider
+// it exempt from the Wayland style guide. :)
+
+#if defined(__GNUC__) && __GNUC__ >= 4
+#define VIRTIO_WL_EXPORT __attribute__ ((visibility("default")))
+#else
+#define VIRTIO_WL_EXPORT
+#endif
+
+static int
+set_nonblocking(int fd)
+{
+	int fl = fcntl(fd, F_GETFL);
+	if (fl == -1)
+		return -1;
+	if (!(fl & O_NONBLOCK))
+		if (fcntl(fd, F_SETFL, fl | O_NONBLOCK) == -1)
+			return -1;
+	return fl;
+}
+
+// Returns the total size of all buffers in an iovec.
+// A return value of -1 means that the total overflowed.
+static ssize_t
+iov_len(const struct iovec iov[], size_t n)
+{
+	size_t len = 0;
+	for (size_t i = 0; i < n; i++) {
+		if (SSIZE_MAX - len < iov[i].iov_len)
+			return -1;
+		len += iov[i].iov_len;
+	}
+	return len;
+}
+
+// Copies from an iovec array into a buffer, stopping after buflen bytes.
+// Returns the number of bytes copied.  A buffer large enough for all
+// the data can be sized with the iov_len function.
+static size_t
+iov_flatten(void *buf, size_t buflen, const struct iovec *iov, size_t iovlen)
+{
+	size_t off = 0;
+
+	for (size_t index = 0; index < iovlen && off < buflen; index++) {
+		const struct iovec *i = &iov[index];
+		size_t rem = buflen - off;
+		size_t len = i->iov_len < rem ? i->iov_len : rem;
+
+		memcpy((unsigned char *)buf + off, i->iov_base, len);
+		off += len;
+	}
+
+	return off;
+}
+
+// Copies from a buffer into an iovec array.
+// Returns number of bytes copied.
+static size_t
+iov_fill(struct iovec *iov, size_t iovlen, const void *buf, size_t buflen)
+{
+	size_t off = 0;
+
+	for (size_t index = 0; index < iovlen && off < buflen; index++) {
+		struct iovec *i = &iov[index];
+		size_t rem = buflen - off;
+		size_t len = i->iov_len < rem ? i->iov_len : rem;
+
+		memcpy(i->iov_base, (const unsigned char *)buf + off, len);
+		off += len;
+	}
+
+	return off;
+}
+
+static ssize_t
+cmsg_to_fdbuf(int *buf, size_t buflen, const struct msghdr *msg)
+{
+	size_t next_fd = 0;
+
+	for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg);
+	     cmsg != NULL; cmsg = CMSG_NXTHDR(msg, cmsg)) {
+		size_t size = cmsg->cmsg_len - CMSG_LEN(0);
+
+		// Check the cmsg can be handled.
+		if (cmsg->cmsg_level != SOL_SOCKET ||
+		    cmsg->cmsg_type != SCM_RIGHTS ||
+		    size % sizeof(int) != 0) {
+			errno = EINVAL;
+			return -1;
+		}
+
+		size_t rem = sizeof(int) * (buflen - next_fd);
+		size_t len = size > rem ? rem : size;
+
+		// Copy the fds to the buffer.
+		memcpy(buf + next_fd, CMSG_DATA(cmsg), len);
+		next_fd += len / sizeof(int);
+
+		if (size > rem)
+			break;
+	}
+
+	return next_fd > SSIZE_MAX ? SSIZE_MAX : next_fd;
+}
+
+static size_t
+fdbuf_to_cmsg(struct msghdr *msg, const int *buf, size_t buflen)
+{
+	// Check msg->msg_control is long enough to fit at least one fd.
+	if (msg->msg_controllen < CMSG_SPACE(sizeof(int))) {
+
+		// If there's at least one fd in buf, set MSG_CTRUNC.
+		size_t i = 0;
+		while (i < buflen && !(msg->msg_flags & MSG_CTRUNC))
+			if (buf[i++] != -1)
+				msg->msg_flags |= MSG_CTRUNC;
+
+		return 0;
+	}
+
+	// cmsg(3):
+	// > When initializing a buffer that will contain a series of cmsghdr
+	// > structures (e.g., to be sent with sendmsg(2)), that buffer should
+	// > first be zero-initialized to ensure the correct operation of
+	// > CMSG_NXTHDR().
+	memset(msg->msg_control, 0, msg->msg_controllen);
+
+	// Set up the cmsg.
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+
+	// Copy as many fds as fit into cmsg.
+	size_t len = 0;
+	for (size_t i = 0; i < buflen; i++) {
+		if (buf[i] == -1)
+			continue;
+
+		if (CMSG_LEN((len + 1) * sizeof(int)) > msg->msg_controllen) {
+			msg->msg_flags |= MSG_CTRUNC;
+			break;
+		}
+
+		memcpy(CMSG_DATA(cmsg) + sizeof(int) * len, &buf[i], sizeof(int));
+		len++;
+	}
+
+	cmsg->cmsg_len = CMSG_LEN(len * sizeof(int));
+	return len;
+}
+
+VIRTIO_WL_EXPORT int
+virtio_wl_connect(const char *name, uint32_t flags)
+{
+	static int wl_fd = -1;
+	if (wl_fd < 0)
+		wl_fd = open("/dev/wl0", O_RDWR | O_CLOEXEC);
+	if (wl_fd < 0)
+		return wl_fd;
+
+	struct virtwl_ioctl_new new_ctx = {
+		.type = name ? VIRTWL_IOCTL_NEW_CTX_NAMED : VIRTWL_IOCTL_NEW_CTX,
+		.fd = -1,
+		.flags = flags,
+	};
+	// Device assumes name is 32 bytes long if not NUL-terminated.
+	if (name)
+		strncpy(new_ctx.name, name, sizeof(new_ctx.name));
+
+	if (ioctl(wl_fd, VIRTWL_IOCTL_NEW, &new_ctx))
+		return -1;
+
+	return new_ctx.fd;
+}
+
+VIRTIO_WL_EXPORT int
+virtio_wl_send_raw(int sockfd, struct virtwl_ioctl_txn *ioctl_txn)
+{
+	return ioctl(sockfd, VIRTWL_IOCTL_SEND, ioctl_txn);
+}
+
+static int
+msghdr_to_txn(const struct msghdr *msg, struct virtwl_ioctl_txn **txn)
+{
+	// Make sure that if there's an error and the caller tries to free *txn
+	// anyway it doesn't end up freeing an invalid address.
+	*txn = NULL;
+
+	// Figure out how big the txn needs to be.
+	ssize_t len = iov_len(msg->msg_iov, msg->msg_iovlen);
+	if (len < 0 || len > UINT32_MAX) {
+		errno = ENOMEM;
+		return -1;
+	}
+
+	// Allocate the txn.
+	*txn = malloc(sizeof(**txn) + len);
+	if (!*txn)
+		return -1;
+
+	// Set the len member of the txn.
+	(*txn)->len = len;
+
+	// Copy data from the iovec into the transaction.
+	iov_flatten((*txn)->data, len, msg->msg_iov, msg->msg_iovlen);
+
+	// Copy file descriptors to the txn.
+	ssize_t fd_count = cmsg_to_fdbuf((*txn)->fds, VIRTWL_SEND_MAX_ALLOCS, msg);
+	if (fd_count == -1) {
+		free(*txn);
+		*txn = NULL;
+		return -1;
+	}
+
+	// Fill the rest of the fd buffer with -1.
+	while (fd_count < VIRTWL_SEND_MAX_ALLOCS)
+		(*txn)->fds[fd_count++] = -1;
+
+	return 0;
+}
+
+static void
+txn_to_msghdr(struct msghdr *msg, const struct virtwl_ioctl_txn *txn)
+{
+	// Copy txn data to iovecs.
+	iov_fill(msg->msg_iov, msg->msg_iovlen, txn->data, txn->len);
+
+	// Copy fds to cmsg.
+	fdbuf_to_cmsg(msg, txn->fds, VIRTWL_SEND_MAX_ALLOCS);
+}
+
+VIRTIO_WL_EXPORT ssize_t
+virtio_wl_sendmsg(int sockfd, const struct msghdr *msg, int flags)
+{
+	int sockfl;
+
+	if ((flags & ~MSG_DONTWAIT) != MSG_NOSIGNAL) {
+		errno = EINVAL;
+		return -1;
+	}
+
+	struct virtwl_ioctl_txn *txn;
+	if (msghdr_to_txn(msg, &txn) == -1)
+		return -1;
+
+	if (flags & MSG_DONTWAIT) {
+		sockfl = set_nonblocking(sockfd);
+		if (sockfl == -1) {
+			free(txn);
+			return -1;
+		}
+	}
+
+	int rv = virtio_wl_send_raw(sockfd, txn);
+
+	if ((flags & MSG_DONTWAIT) && !(sockfl & O_NONBLOCK)) {
+		if (fcntl(sockfd, F_SETFL, sockfl) == -1) {
+			free(txn);
+			return -1;
+		}
+	}
+
+	if (rv != -1)
+		rv = txn->len;
+
+	free(txn);
+	return rv;
+}
+
+VIRTIO_WL_EXPORT int
+virtio_wl_recv_raw(int sockfd, struct virtwl_ioctl_txn *ioctl_txn)
+{
+	return ioctl(sockfd, VIRTWL_IOCTL_RECV, ioctl_txn);
+}
+
+VIRTIO_WL_EXPORT ssize_t
+virtio_wl_recvmsg(int sockfd, struct msghdr *msg, int flags)
+{
+	int sockfl;
+
+	if ((flags & ~MSG_DONTWAIT) != MSG_CMSG_CLOEXEC) {
+		errno = EINVAL;
+		return -1;
+	}
+
+	ssize_t len = iov_len(msg->msg_iov, msg->msg_iovlen);
+	if (len < 0 || len > UINT32_MAX) {
+		errno = ENOMEM;
+		return -1;
+	}
+
+	struct virtwl_ioctl_txn *txn = malloc(sizeof(*txn) + len);
+	if (!txn)
+		return -1;
+
+	txn->len = len;
+
+	if (flags & MSG_DONTWAIT) {
+		sockfl = set_nonblocking(sockfd);
+		if (sockfl == -1) {
+			free(txn);
+			return -1;
+		}
+	}
+
+	ssize_t rv = virtio_wl_recv_raw(sockfd, txn);
+	// If restoring the flags fails after a successful receive, close
+	// the received fds rather than leaking them.
+	if ((flags & MSG_DONTWAIT) && !(sockfl & O_NONBLOCK) &&
+	    fcntl(sockfd, F_SETFL, sockfl) == -1 && rv != -1) {
+		for (size_t i = 0; i < VIRTWL_SEND_MAX_ALLOCS; i++)
+			if (txn->fds[i] != -1)
+				close(txn->fds[i]);
+		rv = -1;
+	}
+	if (rv != -1) {
+		txn_to_msghdr(msg, txn);
+		rv = txn->len;
+	}
+
+	free(txn);
+	return rv;
+}
diff --git a/src/virtio_wl.h b/src/virtio_wl.h
new file mode 100644
index 0000000..80bcd38
--- /dev/null
+++ b/src/virtio_wl.h
@@ -0,0 +1,23 @@
+/* Copyright 2020 Alyssa Ross
+ *
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
+
+#ifndef UTIL_VIRTIO_WL_H
+#define UTIL_VIRTIO_WL_H
+
+#include <stdint.h>
+#include <sys/types.h>
+struct virtwl_ioctl_txn;
+struct msghdr;
+
+int virtio_wl_connect(const char *name, uint32_t flags);
+
+ssize_t virtio_wl_sendmsg(int sockfd, const struct msghdr *, int flags);
+int virtio_wl_send_raw(int sockfd, struct virtwl_ioctl_txn *);
+
+ssize_t virtio_wl_recvmsg(int sockfd, struct msghdr *, int flags);
+int virtio_wl_recv_raw(int sockfd, struct virtwl_ioctl_txn *);
+
+#endif
diff --git a/src/wayland-os.c b/src/wayland-os.c
index 93b6f5f..18c59b3 100644
--- a/src/wayland-os.c
+++ b/src/wayland-os.c
@@ -31,6 +31,7 @@
 #include <fcntl.h>
 #include <errno.h>
 #include <sys/epoll.h>
+#include <virtio_wl.h>
 
 #include "../config.h"
 #include "wayland-os.h"
@@ -126,6 +127,10 @@ wl_os_recvmsg_cloexec(int sockfd, struct msghdr *msg, int flags)
 	len = recvmsg(sockfd, msg, flags | MSG_CMSG_CLOEXEC);
 	if (len >= 0)
 		return len;
+
+	if (errno == ENOTSOCK)
+		return virtio_wl_recvmsg(sockfd, msg, flags | MSG_CMSG_CLOEXEC);
+
 	if (errno != EINVAL)
 		return -1;
 
@@ -165,3 +170,19 @@ wl_os_accept_cloexec(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
 	fd = accept(sockfd, addr, addrlen);
 	return set_cloexec_or_close(fd);
 }
+
+int
+wl_virtio_accept_cloexec(int sockfd)
+{
+	char name[32];
+	size_t offset = 0;
+
+	while (offset < sizeof name) {
+		ssize_t size = read(sockfd, name + offset, sizeof name - offset);
+		if (size <= 0)
+			return -1;
+		offset += size;
+	}
+
+	return virtio_wl_connect(name, 0);
+}
diff --git a/src/wayland-os.h b/src/wayland-os.h
index f51efaa..899ac8e 100644
--- a/src/wayland-os.h
+++ b/src/wayland-os.h
@@ -41,6 +41,9 @@ wl_os_epoll_create_cloexec(void);
 int
 wl_os_accept_cloexec(int sockfd, struct sockaddr *addr, socklen_t *addrlen);
 
+int
+wl_virtio_accept_cloexec(int sockfd);
+
 
 /*
  * The following are for wayland-os.c and the unit tests.
diff --git a/src/wayland-server.c b/src/wayland-server.c
index 3f48dfe..0146ca8 100644
--- a/src/wayland-server.c
+++ b/src/wayland-server.c
@@ -531,8 +531,18 @@ wl_client_create(struct wl_display *display, int fd)
 
 	len = sizeof client->ucred;
 	if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED,
-		       &client->ucred, &len) < 0)
-		goto err_source;
+		       &client->ucred, &len) < 0) {
+		if (errno == ENOTSOCK) {
+			/* Probably a virtio_wl socket.  We don't have
+			 * credential information, so fill with values
+			 * that should never match real ones. */
+			client->ucred.pid = -1;
+			client->ucred.uid = -1;
+			client->ucred.gid = -1;
+		} else {
+			goto err_source;
+		}
+	}
 
 	client->connection = wl_connection_create(fd);
 	if (client->connection == NULL)
@@ -1419,6 +1429,9 @@ socket_data(int fd, uint32_t mask, void *data)
 	length = sizeof name;
 	client_fd = wl_os_accept_cloexec(fd, (struct sockaddr *) &name,
 					 &length);
+	if (client_fd < 0 && errno == ENOTSOCK)
+		client_fd = wl_virtio_accept_cloexec(fd);
+
 	if (client_fd < 0)
 		wl_log("failed to accept: %s\n", strerror(errno));
 	else
@@ -1603,7 +1616,8 @@ wl_display_add_socket_fd(struct wl_display *display, int sock_fd)
 	struct stat buf;
 
 	/* Require a valid fd or fail */
-	if (sock_fd < 0 || fstat(sock_fd, &buf) < 0 || !S_ISSOCK(buf.st_mode)) {
+	if (sock_fd < 0 || fstat(sock_fd, &buf) < 0 ||
+	    ((buf.st_mode & S_IFMT) && !S_ISSOCK(buf.st_mode))) {
 		return -1;
 	}
 
diff --git a/src/wayland-shm.c b/src/wayland-shm.c
index b85e5a7..636ee7a 100644
--- a/src/wayland-shm.c
+++ b/src/wayland-shm.c
@@ -61,6 +61,7 @@ struct wl_shm_pool {
 	int internal_refcount;
 	int external_refcount;
 	char *data;
+	int fd;
 	int32_t size;
 	int32_t new_size;
 	bool sigbus_is_impossible;
@@ -91,14 +92,27 @@ shm_pool_finish_resize(struct wl_shm_pool *pool)
 
 	data = mremap(pool->data, pool->size, pool->new_size, MREMAP_MAYMOVE);
 	if (data == MAP_FAILED) {
-		wl_resource_post_error(pool->resource,
-				       WL_SHM_ERROR_INVALID_FD,
-				       "failed mremap");
-		return;
+		if (errno != EFAULT)
+			goto fail;
+
+		if (munmap(pool->data, pool->size) == -1)
+			goto fail;
+
+		data = mmap(pool->data, pool->new_size, PROT_READ | PROT_WRITE,
+			    MAP_SHARED, pool->fd, 0);
+		if (data == MAP_FAILED)
+			goto fail;
 	}
 
 	pool->data = data;
 	pool->size = pool->new_size;
+
+	return;
+
+ fail:
+	wl_resource_post_error(pool->resource,
+			       WL_SHM_ERROR_INVALID_FD,
+			       "failed mremap");
 }
 
 static void
@@ -291,6 +305,7 @@ shm_create_pool(struct wl_client *client, struct wl_resource *resource,
 	pool->external_refcount = 0;
 	pool->size = size;
 	pool->new_size = size;
+	pool->fd = fd;
 	pool->data = mmap(NULL, size,
 			  PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
 	if (pool->data == MAP_FAILED) {
@@ -300,7 +315,6 @@ shm_create_pool(struct wl_client *client, struct wl_resource *resource,
 				       strerror(errno));
 		goto err_free;
 	}
-	close(fd);
 
 	pool->resource =
 		wl_resource_create(client, &wl_shm_pool_interface, 1, id);
diff --git a/tests/meson.build b/tests/meson.build
index 224f48d..6ee2b49 100644
--- a/tests/meson.build
+++ b/tests/meson.build
@@ -148,6 +148,7 @@ tests = {
 		'headers-protocol-core-test.c',
 	],
 	'os-wrappers-test': [],
+	'virtio_wl-test': [],
 }
 
 foreach test_name, test_extra_sources: tests
diff --git a/tests/virtio_wl-test.c b/tests/virtio_wl-test.c
new file mode 100644
index 0000000..16c7a93
--- /dev/null
+++ b/tests/virtio_wl-test.c
@@ -0,0 +1,615 @@
+#define _GNU_SOURCE
+#include <assert.h>
+
+#include "test-runner.h"
+#include "virtio_wl.c"
+
+TEST(iov_len_fit)
+{
+	int f1[] = { 1, 2 };
+	int f2[] = { 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof(f1) },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof(f2) },
+	};
+
+	assert(iov_len(iov, sizeof(iov) / sizeof(*iov)) == 3 * sizeof(int));
+}
+
+TEST(iov_len_overflow)
+{
+	struct iovec iov[] = {
+		{ .iov_base = NULL, .iov_len = SIZE_MAX },
+		{ .iov_base = NULL, .iov_len = SIZE_MAX },
+	};
+
+	assert(iov_len(iov, sizeof(iov) / sizeof(*iov)) == -1);
+}
+
+TEST(iov_flatten_test)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof(f1) },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof(f2) },
+	};
+
+	int buf[2];
+	iov_flatten(buf, sizeof(buf), iov, sizeof(iov) / sizeof(*iov));
+	assert(buf[0] == 1);
+	assert(buf[1] == 2);
+}
+
+TEST(iov_flatten_zero_iov)
+{
+	unsigned char buf[] = { 0xFF };
+	assert(iov_flatten(buf, 1, NULL, 0) == 0);
+}
+
+TEST(iov_flatten_null_iov)
+{
+	unsigned char buf[] = { 0xFF };
+	struct iovec iov[] = { { .iov_base = NULL, .iov_len = 0 } };
+
+	assert(iov_flatten(buf, sizeof buf, iov, sizeof(iov) / sizeof(*iov)) == 0);
+}
+
+TEST(iov_flatten_zero_buf)
+{
+	unsigned char f1[] = { 0 };
+	struct iovec iov[] = { { .iov_base = f1, .iov_len = sizeof f1 } };
+
+	assert(iov_flatten(NULL, 0, iov, sizeof(iov) / sizeof(*iov)) == 0);
+}
+
+TEST(iov_flatten_short_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7, 6 };
+
+	assert(iov_flatten(buf, sizeof buf, iov, sizeof(iov) / sizeof(*iov)) == 12);
+	assert(buf[0] == 1);
+	assert(buf[1] == 2);
+	assert(buf[2] == 3);
+	assert(buf[3] == 6);
+}
+
+TEST(iov_flatten_exact_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7 };
+
+	assert(iov_flatten(buf, sizeof buf, iov, sizeof(iov) / sizeof(*iov)) == 12);
+	assert(buf[0] == 1);
+	assert(buf[1] == 2);
+	assert(buf[2] == 3);
+}
+
+TEST(iov_flatten_long_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3, 4 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7 };
+
+	assert(iov_flatten(buf, sizeof buf, iov, sizeof(iov) / sizeof(*iov)) == 12);
+	assert(buf[0] == 1);
+	assert(buf[1] == 2);
+	assert(buf[2] == 3);
+}
+
+TEST(iov_fill_zero_iov)
+{
+	unsigned char buf[] = { 0xFF };
+	assert(iov_fill(NULL, 0, buf, 1) == 0);
+}
+
+TEST(iov_fill_null_iov)
+{
+	unsigned char buf[] = { 0xFF };
+	struct iovec iov[] = { { .iov_base = NULL, .iov_len = 0 } };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), buf, sizeof buf) == 0);
+}
+
+TEST(iov_fill_zero_buf)
+{
+	unsigned char f1[] = { 0 };
+	struct iovec iov[] = { { .iov_base = f1, .iov_len = sizeof f1 } };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), NULL, 0) == 0);
+}
+
+TEST(iov_fill_short_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7, 6 };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), buf, sizeof buf) == 12);
+	assert(f1[0] == 9);
+	assert(f2[0] == 8);
+	assert(f2[1] == 7);
+}
+
+TEST(iov_fill_exact)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7 };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), buf, sizeof buf) == 12);
+	assert(f1[0] == 9);
+	assert(f2[0] == 8);
+	assert(f2[1] == 7);
+}
+
+TEST(iov_fill_long_iov)
+{
+	int f1[] = { 1 };
+	int f2[] = { 2, 3, 4 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = NULL, .iov_len = 0 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	int buf[] = { 9, 8, 7 };
+
+	assert(iov_fill(iov, sizeof(iov) / sizeof(*iov), buf, sizeof buf) == 12);
+	assert(f1[0] == 9);
+	assert(f2[0] == 8);
+	assert(f2[1] == 7);
+	assert(f2[2] == 4);
+}
+
+TEST(cmsg_to_fdbuf_zero_buf)
+{
+	int fds[] = { 0 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	assert(cmsg_to_fdbuf(NULL, 0, &msg) == 0);
+}
+
+TEST(cmsg_to_fdbuf_zero_cmsg)
+{
+	int fds[2] = { 0 };
+	struct msghdr msg = { 0 };
+	assert(cmsg_to_fdbuf(fds, sizeof fds, &msg) == 0);
+}
+
+TEST(cmsg_to_fdbuf_small_buf)
+{
+	int fds[] = { 0, 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	int buf[] = { 0xFF };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == 1);
+	assert(buf[0] == 0);
+}
+
+TEST(cmsg_to_fdbuf_exact)
+{
+	int fds[] = { 0, 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == 2);
+	assert(buf[0] == 0);
+	assert(buf[1] == 1);
+}
+
+TEST(cmsg_to_fdbuf_small_cmsg)
+{
+	int fds[] = { 0 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == 1);
+	assert(buf[0] == 0);
+	assert(buf[1] == 0xFE);
+}
+
+TEST(cmsg_to_fdbuf_multiple_cmsg)
+{
+	int f1[] = { 0 };
+	int f2[] = { 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof f1) + CMSG_SPACE(sizeof f2)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg1 = CMSG_FIRSTHDR(&msg);
+	cmsg1->cmsg_level = SOL_SOCKET;
+	cmsg1->cmsg_type = SCM_RIGHTS;
+	cmsg1->cmsg_len = CMSG_LEN(sizeof f1);
+	memcpy(CMSG_DATA(cmsg1), f1, sizeof f1);
+
+	struct cmsghdr *cmsg2 = CMSG_NXTHDR(&msg, cmsg1);
+	cmsg2->cmsg_level = SOL_SOCKET;
+	cmsg2->cmsg_type = SCM_RIGHTS;
+	cmsg2->cmsg_len = CMSG_LEN(sizeof f2);
+	memcpy(CMSG_DATA(cmsg2), f2, sizeof f2);
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == 2);
+	assert(buf[0] == 0);
+	assert(buf[1] == 1);
+}
+
+TEST(cmsg_to_fdbuf_other_cmsg)
+{
+	struct ucred cred = {
+		.pid = getpid(),
+		.uid = getuid(),
+		.gid = getgid(),
+	};
+
+	int fds[] = { 0, 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof cred) + CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg1 = CMSG_FIRSTHDR(&msg);
+	cmsg1->cmsg_level = SOL_SOCKET;
+	cmsg1->cmsg_type = SCM_CREDENTIALS;
+	cmsg1->cmsg_len = CMSG_LEN(sizeof cred);
+	memcpy(CMSG_DATA(cmsg1), &cred, sizeof cred);
+
+	struct cmsghdr *cmsg2 = CMSG_NXTHDR(&msg, cmsg1);
+	cmsg2->cmsg_level = SOL_SOCKET;
+	cmsg2->cmsg_type = SCM_RIGHTS;
+	cmsg2->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg2), fds, sizeof fds);
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == -1);
+	assert(errno == EINVAL);
+}
+
+TEST(cmsg_to_fdbuf_misaligned)
+{
+	int fds[] = { 0, 1 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds + 1)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof(fds) + 1);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+	CMSG_DATA(cmsg)[sizeof(fds)] = 0;
+
+	int buf[] = { 0xFF, 0xFE };
+	assert(cmsg_to_fdbuf(buf, sizeof(buf) / sizeof(*buf), &msg) == -1);
+	assert(errno == EINVAL);
+}
+
+TEST(fdbuf_to_cmsg_null_cmsg)
+{
+	struct msghdr msg = { 0 };
+	int fds[] = { 1 };
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 0);
+	assert(msg.msg_flags & MSG_CTRUNC);
+}
+
+TEST(fdbuf_to_cmsg_empty_both)
+{
+	struct msghdr msg = { 0 };
+	int fds[] = { -1 };
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 0);
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(fdbuf_to_cmsg_null_buf)
+{
+	union {
+		char buf[CMSG_SPACE(sizeof(int))];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	assert(fdbuf_to_cmsg(&msg, NULL, 0) == 0);
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(fdbuf_to_cmsg_tiny_msg)
+{
+	union {
+		char buf[1];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	int fds[] = { 1 };
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 0);
+	assert(msg.msg_flags & MSG_CTRUNC);
+}
+
+TEST(fdbuf_to_cmsg_small_msg)
+{
+	union {
+		char buf[CMSG_SPACE(sizeof(int))];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	int fds[] = { 1, 2, 3 };
+	assert(CMSG_SPACE(sizeof fds) != sizeof u.buf);
+
+	// Alignment padding can make CMSG_SPACE(sizeof(int)) equal to
+	// CMSG_SPACE(2 * sizeof(int)), so one or two fds may fit.
+	size_t n = fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds));
+	assert(CMSG_SPACE(n * sizeof(int)) == sizeof u.buf);
+	assert(msg.msg_flags & MSG_CTRUNC);
+}
+
+TEST(fdbuf_to_cmsg_equal)
+{
+	int fds[] = { 1, 2 };
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 2);
+	assert(!memcmp(CMSG_DATA(CMSG_FIRSTHDR(&msg)), fds, sizeof fds));
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(fdbuf_to_cmsg_negative_one)
+{
+	int fds[] = { 1, -1, -1, 2 };
+
+	union {
+		char buf[CMSG_SPACE(2 * sizeof(int))];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 2);
+
+	unsigned char *data = CMSG_DATA(CMSG_FIRSTHDR(&msg));
+	assert(!memcmp(data, &fds[0], sizeof(int)));
+	assert(!memcmp(data + sizeof(int), &fds[3], sizeof(int)));
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(fdbuf_to_cmsg_small_buf)
+{
+	int fds[] = { 1 };
+
+	union {
+		char buf[CMSG_SPACE(2 * sizeof fds)];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	assert(fdbuf_to_cmsg(&msg, fds, sizeof(fds) / sizeof(*fds)) == 1);
+	assert(!memcmp(CMSG_DATA(CMSG_FIRSTHDR(&msg)), fds, sizeof fds));
+	assert(!(msg.msg_flags & MSG_CTRUNC));
+}
+
+TEST(msghdr_to_txn_test)
+{
+	struct virtwl_ioctl_txn *txn;
+
+	int f1[] = { 1 };
+	int f2[] = { 2, 3 };
+
+	const int fds[] = { 0, 1, 2 };
+
+	struct iovec iov[] = {
+		{ .iov_base = f1, .iov_len = sizeof f1 },
+		{ .iov_base = f2, .iov_len = sizeof f2 },
+	};
+
+	union {
+		char buf[CMSG_SPACE(sizeof fds)];
+		struct cmsghdr align;
+	} u;
+
+	struct msghdr msg = { 0 };
+	msg.msg_iov = iov;
+	msg.msg_iovlen = sizeof(iov) / sizeof(*iov);
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	cmsg->cmsg_level = SOL_SOCKET;
+	cmsg->cmsg_type = SCM_RIGHTS;
+	cmsg->cmsg_len = CMSG_LEN(sizeof fds);
+	memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
+
+	assert(!msghdr_to_txn(&msg, &txn));
+
+	assert(txn->len == sizeof(f1) + sizeof(f2));
+	assert(!memcmp(txn->fds, fds, sizeof fds));
+	for (size_t i = sizeof(fds) / sizeof(*fds);
+	     i < VIRTWL_SEND_MAX_ALLOCS; i++)
+		assert(txn->fds[i] == -1);
+	assert(!memcmp(txn->data, f1, sizeof f1));
+	assert(!memcmp(txn->data + sizeof f1, f2, sizeof f2));
+}
+
+TEST(txn_to_msghdr_test)
+{
+	const int data[] = { 1, 2, 3 };
+
+	size_t len = sizeof(data);
+	struct virtwl_ioctl_txn *txn = malloc(sizeof(*txn) + len);
+	assert(txn);
+
+	txn->len = len;
+
+	const int fds[] = { 0, 1, 2 };
+	memcpy(txn->fds, fds, sizeof fds);
+	for (int i = sizeof(fds) / sizeof(*fds); i < VIRTWL_SEND_MAX_ALLOCS; i++)
+		txn->fds[i] = -1;
+
+	memcpy(txn->data, data, sizeof data);
+
+	int buf1[1], buf2[2];
+
+	struct iovec iov[] = {
+		{ .iov_base = buf1, .iov_len = sizeof buf1 },
+		{ .iov_base = buf2, .iov_len = sizeof buf2 },
+	};
+
+	union {
+		char buf[CMSG_SPACE(VIRTWL_SEND_MAX_ALLOCS * sizeof(int))];
+		struct cmsghdr align;
+	} u = { 0 };
+
+	struct msghdr msg = { 0 };
+	msg.msg_iov = iov;
+	msg.msg_iovlen = sizeof(iov) / sizeof(*iov);
+	msg.msg_control = u.buf;
+	msg.msg_controllen = sizeof u.buf;
+
+	txn_to_msghdr(&msg, txn);
+
+	assert(!memcmp(buf1, data, sizeof buf1));
+	assert(!memcmp(buf2, (unsigned char *)data + sizeof buf1, sizeof buf2));
+
+	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+	assert(cmsg->cmsg_level == SOL_SOCKET);
+	assert(cmsg->cmsg_type == SCM_RIGHTS);
+	assert(cmsg->cmsg_len == CMSG_LEN(sizeof fds));
+	assert(!memcmp(CMSG_DATA(cmsg), fds, sizeof fds));
+
+	assert(!CMSG_NXTHDR(&msg, cmsg));
+
+	free(txn);
+}
-- 
2.27.0
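
The iov_fill tests above pin down its contract: bytes from the buffer are copied into the iovec entries in order, zero-length entries are skipped, and the return value is the number of bytes copied, bounded by whichever of the buffer or the iovec space runs out first. A minimal sketch consistent with those assertions follows; iov_fill_sketch is a hypothetical name, and the version in src/virtio_wl.c may differ in signature and detail.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>
#include <sys/uio.h>

/*
 * Hypothetical sketch: copy up to len bytes of buf into the iovec
 * array in order, skipping empty entries, and return the number of
 * bytes copied.  Copying stops when either buf or the iovec space
 * runs out.
 */
static size_t iov_fill_sketch(struct iovec *iov, size_t iovlen,
			      const void *buf, size_t len)
{
	const unsigned char *src = buf;
	size_t copied = 0;

	assert(buf != NULL || len == 0);

	for (size_t i = 0; i < iovlen && copied < len; i++) {
		size_t n = iov[i].iov_len;

		if (n > len - copied)
			n = len - copied;
		if (n == 0)
			continue; /* e.g. { .iov_base = NULL, .iov_len = 0 } */

		memcpy(iov[i].iov_base, src + copied, n);
		copied += n;
	}

	return copied;
}
```

With the iov_fill_long_iov inputs it copies 12 bytes and leaves the last element of f2 as 4, as that test expects.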

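The fdbuf_to_cmsg tests likewise imply three rules: -1 entries in the fd buffer are skipped, as many fds as fit are packed into one SCM_RIGHTS message (alignment padding can make room for more ints than the buffer was sized for), and MSG_CTRUNC is set when any fd is dropped. The sketch below follows those rules under a hypothetical name; resetting msg_controllen to the space actually used mirrors recvmsg(2) and is an assumption, since the tests do not check it.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>
#include <sys/socket.h>

/*
 * Hypothetical sketch: pack the fds in fds[0..nfds) that are not -1
 * into a single SCM_RIGHTS control message, return how many were
 * written, and set MSG_CTRUNC if some did not fit.
 */
static size_t fdbuf_to_cmsg_sketch(struct msghdr *msg, const int *fds,
				   size_t nfds)
{
	size_t valid = 0, fit = 0;

	assert(fds != NULL || nfds == 0);

	/* -1 marks an unused slot, so only count real fds. */
	for (size_t i = 0; i < nfds; i++)
		if (fds[i] != -1)
			valid++;

	/* Count how many fds fit; padding may leave room for extra ints. */
	while (fit < valid &&
	       CMSG_SPACE((fit + 1) * sizeof(int)) <= msg->msg_controllen)
		fit++;

	if (fit < valid)
		msg->msg_flags |= MSG_CTRUNC;
	if (fit == 0) {
		msg->msg_controllen = 0;
		return 0;
	}

	struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg);
	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type = SCM_RIGHTS;
	cmsg->cmsg_len = CMSG_LEN(fit * sizeof(int));

	unsigned char *data = CMSG_DATA(cmsg);
	size_t written = 0;
	for (size_t i = 0; i < nfds && written < fit; i++) {
		if (fds[i] == -1)
			continue;
		memcpy(data + written * sizeof(int), &fds[i], sizeof(int));
		written++;
	}

	/* Like recvmsg(2), report the control space actually used. */
	msg->msg_controllen = CMSG_SPACE(fit * sizeof(int));
	return fit;
}
```
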
* Re: [PATCH wayland v3] Support virtio_wl display sockets
  2020-07-26  6:03 ` [PATCH wayland v3] " Alyssa Ross
@ 2020-07-27 18:37   ` Cole Helbling
  0 siblings, 0 replies; 5+ messages in thread
From: Cole Helbling @ 2020-07-27 18:37 UTC (permalink / raw)
  To: Alyssa Ross, devel; +Cc: Cole Helbling

On Sat Jul 25, 2020 at 11:03 PM PDT, Alyssa Ross wrote:
> This patch adds libvirtio_wl, which exposes reimplementations of
> sendmsg(2) and recvmsg(2) for virtio_wl socket fds. Tests for as much
> libvirtio_wl functionality as is reasonably possible to test without
> requiring a working virtio_wl connection are included. (Testing
> further would require running tests in a VM so that they could talk to
> the virtio_wl kernel driver, which would be prohibitively complex.)
>
> virtio_wl socket fds do not actually point to sockets, but to special
> virtio_wl files. Whenever a display socket operation calls sendmsg()
> or recvmsg() with a file descriptor and receives an ENOTSOCK error, it
> retries the operation with the equivalent libvirtio_wl function, in
> case the file descriptor is a virtio_wl fd. This is the least invasive
> way to implement virtio_wl support -- if a normal socket is being
> used, there will be no change in behaviour.
>
> Because virtio_wl doesn't implement every socket feature, some
> workarounds are currently required to accommodate everything Wayland
> expects of sockets:
>
> * virtio_wl_recvmsg implements the MSG_DONTWAIT flag by setting
> O_NONBLOCK on the fd, attempting the VIRTWL_IOCTL_RECV operation,
> and then restoring the fd's original flags. This is obviously
> race-prone, but there is no reasonable alternative at present.
>
> * virtio_wl_sendmsg requires MSG_NOSIGNAL to be set, and ignores it.
> This is because, from reading the code, I believe virtio_wl does
> not generate SIGPIPE signals, nor does it ever return EPIPE. I
> could be wrong about this, though.
>
> * virtio_wl does not support credential passing -- what would it even
> mean, considering the other end of the connection is on
> another (virtual) machine? So wl_client's ucred member will have
> pid, uid, and gid all set to -1 for a client connected over
> virtio_wl.
>
> * virtio_wl sockets do not support accept(2), so a fallback is used.
> A proxy program on the host calls accept(2) on a host socket. When it
> receives a connection, it attaches the connection socket to the VM,
> then sends the name of the connected socket over the Wayland display
> socket. Wayland then receives this name, looks up the connection
> socket, and uses that as the client connection socket.
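
The MSG_DONTWAIT workaround from the first bullet above amounts to toggling O_NONBLOCK around a single operation. A minimal sketch, assuming the blocking operation is passed in as a callback; run_nonblocking, fd_op, read_op and struct read_req are hypothetical, and in a virtio_wl recvmsg replacement the callback would presumably issue the VIRTWL_IOCTL_RECV ioctl rather than read(2).

```c
#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <stddef.h>
#include <sys/types.h>
#include <unistd.h>

/* A blocking operation to run once in non-blocking mode. */
typedef ssize_t (*fd_op)(int fd, void *arg);

/*
 * Emulate MSG_DONTWAIT on an fd that does not support the flag: set
 * O_NONBLOCK, run op, then restore the original file status flags
 * while preserving op's errno.  A failure to restore the flags is not
 * reported in this sketch.
 */
static ssize_t run_nonblocking(int fd, fd_op op, void *arg)
{
	assert(op != NULL);

	int flags = fcntl(fd, F_GETFL);
	if (flags < 0)
		return -1;
	if (!(flags & O_NONBLOCK) &&
	    fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0)
		return -1;

	ssize_t ret = op(fd, arg);
	int saved_errno = errno;

	if (!(flags & O_NONBLOCK))
		fcntl(fd, F_SETFL, flags);

	errno = saved_errno;
	return ret;
}

/* Hypothetical example op: a plain read(2) into a caller buffer. */
struct read_req {
	void *buf;
	size_t len;
};

static ssize_t read_op(int fd, void *arg)
{
	struct read_req *req = arg;

	return read(fd, req->buf, req->len);
}
```

The race mentioned in the patch is visible here: O_NONBLOCK belongs to the open file description, so any other thread using the same fd between the two F_SETFL calls also sees non-blocking behaviour.
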
>
> Additionally, virtio_wl memfd-like file descriptors don't support
> mremap(2), so for virtio_wl sockets Wayland will munmap(2) the memfd,
> and then mmap(2) it again. This should be at least mostly okay
> because Wayland only ever calls mremap with MREMAP_MAYMOVE, but it is
> still race-prone. To be able to do this, memfds are no longer closed
> after being mmapped, but are kept around in the wl_shm_pool struct, so
> that they can be passed to mmap() if required.
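
The mremap(2) replacement described above can be sketched as unmapping the old range and mapping the still-open fd again at the new size. remap_by_fd is a hypothetical helper rather than the patch's code, and it assumes a MAP_SHARED mapping that starts at offset 0 of the fd, as a wl_shm_pool mapping does.

```c
#include <assert.h>
#include <stddef.h>
#include <sys/mman.h>

/*
 * Hypothetical replacement for mremap(old, old_size, new_size,
 * MREMAP_MAYMOVE) on a MAP_SHARED mapping of fd at offset 0: unmap the
 * old range, then map fd again at the new size.  Returns MAP_FAILED on
 * error, like mmap(2).
 */
static void *remap_by_fd(void *old, size_t old_size, size_t new_size,
			 int prot, int fd)
{
	assert(fd >= 0);

	if (munmap(old, old_size) < 0)
		return MAP_FAILED;

	/*
	 * The new mapping may land at a different address, as with
	 * MREMAP_MAYMOVE.  Unlike mremap(), if this mmap() fails the
	 * old mapping is already gone.
	 */
	return mmap(NULL, new_size, prot, MAP_SHARED, fd, 0);
}
```

Because the data lives in the file rather than in the mapping, anything written through the old address is visible through the new one; only pointers into the old range become invalid.
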
>
> This patch appears to be enough to reliably run Alacritty on the host
> system, with a Wayland compositor running in a VM. It is not capable
> of running Firefox, which most of the time fails to start, and
> occasionally will partially start (sometimes even getting far enough
> to draw a window) before freezing.
> ---
> Cole pointed out on IRC that previous versions of this patch
> unnecessarily included string.h and stdio.h in src/wayland-os.c.
> These includes have now been removed.
>
> src/connection.c | 4 +
> src/meson.build | 16 +-
> src/virtio_wl.c | 344 +++++++++++++++++++++++
> src/virtio_wl.h | 23 ++
> src/wayland-os.c | 21 ++
> src/wayland-os.h | 3 +
> src/wayland-server.c | 20 +-
> src/wayland-shm.c | 24 +-
> tests/meson.build | 1 +
> tests/virtio_wl-test.c | 615 +++++++++++++++++++++++++++++++++++++++++
> 10 files changed, 1062 insertions(+), 9 deletions(-)
> create mode 100644 src/virtio_wl.c
> create mode 100644 src/virtio_wl.h
> create mode 100644 tests/virtio_wl-test.c

Style and friends LGTM. As before, I'm not that great at grokking C, so
I'll just hope you know what you are doing. :P

Reviewed-by: Cole Helbling <cole.e.helbling@outlook.com>

Thread overview: 5+ messages
2020-07-08 15:16 [PATCH wayland] Support virtio_wl display sockets Alyssa Ross
2020-07-08 15:43 ` Cole Helbling
2020-07-08 21:07 ` [PATCH wayland v2] " Alyssa Ross
2020-07-26  6:03 ` [PATCH wayland v3] " Alyssa Ross
2020-07-27 18:37   ` Cole Helbling