serial: Give write callbacks a consistent behaviour

Require that the write callbacks return either the number of bytes
written or a negative error code. From there, ensure the return value
behaviour is the same for the fd and custom handler paths.

Signed-off-by: Andrew Jeffery <andrew@aj.id.au>
Change-Id: Id9b90cf4e5c5815dc6385c3d493e2bbbd8e47616
diff --git a/serial.c b/serial.c
index 312744b..5e37507 100644
--- a/serial.c
+++ b/serial.c
@@ -1,6 +1,7 @@
 /* SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later */
 
 #include <assert.h>
+#include <errno.h>
 #include <stdbool.h>
 #include <stdlib.h>
 #include <string.h>
@@ -22,18 +23,36 @@
 
 #define pr_fmt(x) "serial: " x
 
-/* Post-condition: All bytes written or an error has occurred */
-#define mctp_write_all(fn, dst, src, len)				\
-({									\
-	ssize_t wrote;							\
-	while (len) {							\
-		wrote = fn(dst, src, len);				\
-		if (wrote < 0)						\
-			break;						\
-		len -= wrote;						\
-	}								\
-	len ? -1 : 0;							\
-})
+/*
+ * @fn: A function that will copy data from the buffer at src into the dst object
+ * @dst: An opaque object to pass as state to fn
+ * @src: A pointer to the buffer of data to copy to dst
+ * @len: The length of the data pointed to by src
+ * @return: 0 on succes, negative error code on failure
+ *
+ * Pre-condition: fn returns a write count or a negative error code
+ * Post-condition: All bytes written or an error has occurred
+ */
+#define mctp_write_all(fn, dst, src, len)                                      \
+	({                                                                     \
+		typeof(src) __src = src;                                       \
+		ssize_t wrote;                                                 \
+		while (len) {                                                  \
+			wrote = fn(dst, __src, len);                           \
+			if (wrote < 0)                                         \
+				break;                                         \
+			__src += wrote;                                        \
+			len -= wrote;                                          \
+		}                                                              \
+		len ? wrote : 0;                                               \
+	})
+
+static ssize_t mctp_serial_write(int fildes, const void *buf, size_t nbyte)
+{
+	ssize_t wrote;
+
+	return ((wrote = write(fildes, buf, nbyte)) < 0) ? -errno : wrote;
+}
 
 #include "libmctp.h"
 #include "libmctp-alloc.h"
@@ -150,9 +169,10 @@
 	len += sizeof(*hdr) + sizeof(*tlr);
 
 	if (!serial->tx_fn)
-		return mctp_write_all(write, serial->fd, serial->txbuf, len);
+		return mctp_write_all(mctp_serial_write, serial->fd,
+				      &serial->txbuf[0], len);
 
-	return mctp_write_all(serial->tx_fn, serial->tx_fn_data, serial->txbuf,
+	return mctp_write_all(serial->tx_fn, serial->tx_fn_data, &serial->txbuf[0],
 			      len);
 }
 
diff --git a/tests/test_serial.c b/tests/test_serial.c
index aa282c9..5bd266f 100644
--- a/tests/test_serial.c
+++ b/tests/test_serial.c
@@ -15,6 +15,7 @@
 #endif
 
 #include <assert.h>
+#include <errno.h>
 #include <fcntl.h>
 #include <stdbool.h>
 #include <stdio.h>