serial: Give write callbacks a consistent behaviour
Require that the write callbacks return either the number of bytes
written or a negative error code. From there, ensure the return value
behaviour is the same for the fd and custom handler paths.
Signed-off-by: Andrew Jeffery <andrew@aj.id.au>
Change-Id: Id9b90cf4e5c5815dc6385c3d493e2bbbd8e47616
diff --git a/serial.c b/serial.c
index 312744b..5e37507 100644
--- a/serial.c
+++ b/serial.c
@@ -1,6 +1,7 @@
/* SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later */
#include <assert.h>
+#include <errno.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
@@ -22,18 +23,36 @@
#define pr_fmt(x) "serial: " x
-/* Post-condition: All bytes written or an error has occurred */
-#define mctp_write_all(fn, dst, src, len) \
-({ \
- ssize_t wrote; \
- while (len) { \
- wrote = fn(dst, src, len); \
- if (wrote < 0) \
- break; \
- len -= wrote; \
- } \
- len ? -1 : 0; \
-})
+/*
+ * @fn: A function that will copy data from the buffer at src into the dst object
+ * @dst: An opaque object to pass as state to fn
+ * @src: A pointer to the buffer of data to copy to dst
+ * @len: The length of the data pointed to by src
+ * @return: 0 on succes, negative error code on failure
+ *
+ * Pre-condition: fn returns a write count or a negative error code
+ * Post-condition: All bytes written or an error has occurred
+ */
+#define mctp_write_all(fn, dst, src, len) \
+ ({ \
+ typeof(src) __src = src; \
+ ssize_t wrote; \
+ while (len) { \
+ wrote = fn(dst, __src, len); \
+ if (wrote < 0) \
+ break; \
+ __src += wrote; \
+ len -= wrote; \
+ } \
+ len ? wrote : 0; \
+ })
+
+static ssize_t mctp_serial_write(int fildes, const void *buf, size_t nbyte)
+{
+ ssize_t wrote;
+
+ return ((wrote = write(fildes, buf, nbyte)) < 0) ? -errno : wrote;
+}
#include "libmctp.h"
#include "libmctp-alloc.h"
@@ -150,9 +169,10 @@
len += sizeof(*hdr) + sizeof(*tlr);
if (!serial->tx_fn)
- return mctp_write_all(write, serial->fd, serial->txbuf, len);
+ return mctp_write_all(mctp_serial_write, serial->fd,
+ &serial->txbuf[0], len);
- return mctp_write_all(serial->tx_fn, serial->tx_fn_data, serial->txbuf,
+ return mctp_write_all(serial->tx_fn, serial->tx_fn_data, &serial->txbuf[0],
len);
}
diff --git a/tests/test_serial.c b/tests/test_serial.c
index aa282c9..5bd266f 100644
--- a/tests/test_serial.c
+++ b/tests/test_serial.c
@@ -15,6 +15,7 @@
#endif
#include <assert.h>
+#include <errno.h>
#include <fcntl.h>
#include <stdbool.h>
#include <stdio.h>