protocol: Add get_flash_info

Change-Id: Iff6c452c1399bc8e8f65219779c6a0f2375c68e6
Signed-off-by: Andrew Jeffery <andrew@aj.id.au>
diff --git a/mbox.h b/mbox.h
index 3676825..5fd074b 100644
--- a/mbox.h
+++ b/mbox.h
@@ -133,7 +133,6 @@
 };
 
 struct mbox_context;
-
 typedef int (*mboxd_mbox_handler)(struct mbox_context *, union mbox_regs *,
 				  struct mbox_msg *);
 
diff --git a/protocol.c b/protocol.c
index b0e9e23..066c5b2 100644
--- a/protocol.c
+++ b/protocol.c
@@ -53,6 +53,15 @@
 	return lpc_map_memory(context);
 }
 
+int protocol_v1_get_flash_info(struct mbox_context *context,
+			       struct protocol_get_flash_info *io)
+{
+	io->resp.v1.flash_size = context->flash_size;
+	io->resp.v1.erase_size = context->mtd_info.erasesize;
+
+	return 0;
+}
+
 /*
  * get_suggested_timeout() - get the suggested timeout value in seconds
  * @context:	The mbox context pointer
@@ -106,14 +115,27 @@
 	return lpc_map_memory(context);
 }
 
+int protocol_v2_get_flash_info(struct mbox_context *context,
+			       struct protocol_get_flash_info *io)
+{
+	io->resp.v2.flash_size =
+		context->flash_size >> context->block_size_shift;
+	io->resp.v2.erase_size =
+		context->mtd_info.erasesize >> context->block_size_shift;
+
+	return 0;
+}
+
 static const struct protocol_ops protocol_ops_v1 = {
 	.reset = protocol_v1_reset,
 	.get_info = protocol_v1_get_info,
+	.get_flash_info = protocol_v1_get_flash_info,
 };
 
 static const struct protocol_ops protocol_ops_v2 = {
 	.reset = protocol_v1_reset,
 	.get_info = protocol_v2_get_info,
+	.get_flash_info = protocol_v2_get_flash_info,
 };
 
 static const struct protocol_ops *protocol_ops_map[] = {
diff --git a/protocol.h b/protocol.h
index d7827c0..9503812 100644
--- a/protocol.h
+++ b/protocol.h
@@ -29,10 +29,27 @@
 	} resp;
 };
 
+struct protocol_get_flash_info {
+	struct {
+		union {
+			struct {
+				uint32_t flash_size;
+				uint32_t erase_size;
+			} v1;
+			struct {
+				uint16_t flash_size;
+				uint16_t erase_size;
+			} v2;
+		};
+	} resp;
+};
+
 struct protocol_ops {
 	int (*reset)(struct mbox_context *context);
 	int (*get_info)(struct mbox_context *context,
 			struct protocol_get_info *io);
+	int (*get_flash_info)(struct mbox_context *context,
+			      struct protocol_get_flash_info *io);
 };
 
 int protocol_init(struct mbox_context *context);
@@ -44,9 +61,13 @@
 int protocol_v1_reset(struct mbox_context *context);
 int protocol_v1_get_info(struct mbox_context *context,
 			 struct protocol_get_info *io);
+int protocol_v1_get_flash_info(struct mbox_context *context,
+			       struct protocol_get_flash_info *io);
 
 /* Protocol v2 */
 int protocol_v2_get_info(struct mbox_context *context,
 			 struct protocol_get_info *io);
+int protocol_v2_get_flash_info(struct mbox_context *context,
+			       struct protocol_get_flash_info *io);
 
 #endif /* PROTOCOL_H */
diff --git a/transport_mbox.c b/transport_mbox.c
index be2eba9..873f952 100644
--- a/transport_mbox.c
+++ b/transport_mbox.c
@@ -238,19 +238,24 @@
 int mbox_handle_flash_info(struct mbox_context *context,
 				  union mbox_regs *req, struct mbox_msg *resp)
 {
+	struct protocol_get_flash_info io;
+	int rc;
+
+	rc = context->protocol->get_flash_info(context, &io);
+	if (rc < 0) {
+		return mbox_xlate_errno(context, rc);
+	}
+
 	switch (context->version) {
 	case API_VERSION_1:
 		/* Both Sizes in Bytes */
-		put_u32(&resp->args[0], context->flash_size);
-		put_u32(&resp->args[4], context->mtd_info.erasesize);
+		put_u32(&resp->args[0], io.resp.v1.flash_size);
+		put_u32(&resp->args[4], io.resp.v1.erase_size);
 		break;
 	case API_VERSION_2:
 		/* Both Sizes in Block Size */
-		put_u16(&resp->args[0],
-			context->flash_size >> context->block_size_shift);
-		put_u16(&resp->args[2],
-			context->mtd_info.erasesize >>
-					context->block_size_shift);
+		put_u16(&resp->args[0], io.resp.v2.flash_size);
+		put_u16(&resp->args[2], io.resp.v2.erase_size);
 		break;
 	default:
 		MSG_ERR("API Version Not Valid - Invalid System State\n");