/* blob: 67d163abb7790672d5dba5721162f00177c16130 */
/**
* Copyright © 2016 IBM Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <err.h>
#include <errno.h>
#include <getopt.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <termios.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/un.h>
#include "console-server.h"
#include "config.h"
/* Exit status main() returns when the session was ended by the escape sequence */
#define EXIT_ESCAPE 2
/*
 * Outcome of one processing step. main()'s poll loop keeps running only while
 * the step result is PROCESS_OK.
 */
enum process_rc {
/* Data was handled; keep going */
PROCESS_OK = 0,
/* A read or write failed, or the escape type was not recognised */
PROCESS_ERR,
/* read() returned 0 on the local input or on the server socket */
PROCESS_EXIT,
/* The escape sequence was detected in the local input */
PROCESS_ESC,
};
enum esc_type {
ESC_TYPE_SSH,
ESC_TYPE_STR,
};
/* Detector state for ESC_TYPE_SSH; kept in the client so it survives across reads */
struct ssh_esc_state {
/* '\r' or '~' while a match is in progress, '\0' otherwise */
uint8_t state;
};
/* Detector state for ESC_TYPE_STR; kept in the client so it survives across reads */
struct str_esc_state {
/* NUL-terminated escape sequence to look for */
const uint8_t *str;
/* Index into str of the next byte expected to match */
size_t pos;
};
/* State for one client session */
struct console_client {
/* Socket connected to the console server */
int console_sd;
/* Local input descriptor; client_tty_init() sets it to stdin */
int fd_in;
/* Local output descriptor; client_tty_init() sets it to stdout */
int fd_out;
/* Set from isatty(fd_in); client_fini() restores orig_termios only when true */
bool is_tty;
/* Terminal settings captured before the terminal is reconfigured */
struct termios orig_termios;
/* Which member of esc_state is in use */
enum esc_type esc_type;
union {
struct ssh_esc_state ssh;
struct str_esc_state str;
} esc_state;
};
/*
 * Forward terminal input to the console server while watching for the
 * SSH-style escape: a carriage return, then '~', then '.'.
 *
 * The detector state lives in client->esc_state.ssh, so a sequence that is
 * split across several reads is still recognised. A '~' that directly follows
 * a carriage return is withheld from the server.
 *
 * Returns PROCESS_ESC as soon as the sequence completes, PROCESS_ERR when a
 * write to the server fails, and PROCESS_OK otherwise.
 */
static enum process_rc process_ssh_tty(struct console_client *client,
				       const uint8_t *buf, size_t len)
{
	struct ssh_esc_state *tracker = &client->esc_state.ssh;
	size_t unsent = 0; /* offset of the first byte not yet forwarded */
	size_t pos;

	for (pos = 0; pos < len; pos++) {
		const uint8_t byte = buf[pos];

		if (byte == '\r') {
			tracker->state = '\r';
			continue;
		}

		if (byte == '~' && tracker->state == '\r') {
			tracker->state = '~';
			/* Flush what precedes the tilde, then step over it */
			if (write_buf_to_fd(client->console_sd, buf + unsent,
					    pos - unsent) < 0) {
				return PROCESS_ERR;
			}
			unsent = pos + 1;
			continue;
		}

		if (byte == '.' && tracker->state == '~') {
			return PROCESS_ESC;
		}

		/* Anything else, including an unprimed '~' or '.', resets */
		tracker->state = '\0';
	}

	if (write_buf_to_fd(client->console_sd, buf + unsent, len - unsent) <
	    0) {
		return PROCESS_ERR;
	}

	return PROCESS_OK;
}
/*
 * Advance the escape-string matcher by one input byte.
 *
 * esc is the NUL-terminated escape sequence, matched is how many of its
 * leading bytes equal the most recent input, and c is the next input byte.
 * Returns the new match length: the longest prefix of esc that is a suffix of
 * the input once c is appended.
 */
static size_t str_esc_advance(const uint8_t *esc, size_t matched, uint8_t c)
{
	size_t k;

	/* Already at the terminator: never step past the end of esc */
	if (esc[matched] == '\0') {
		return matched;
	}

	if (esc[matched] == c) {
		return matched + 1;
	}

	/*
	 * Mismatch. The last `matched` input bytes equal esc[0..matched), so
	 * look for the longest shorter prefix that still lines up with the
	 * tail of that input followed by c. Resetting straight to zero would
	 * miss overlapping starts, e.g. escape "ab" in input "aab".
	 */
	for (k = matched; k > 0; --k) {
		if (esc[k - 1] == c &&
		    memcmp(esc, esc + (matched - (k - 1)), k - 1) == 0) {
			return k;
		}
	}

	return 0;
}

/*
 * Forward terminal input to the console server while watching for the
 * configured escape string (client->esc_state.str.str).
 *
 * Match progress is kept in client->esc_state.str.pos, so a sequence split
 * across reads is still recognised. Bytes up to and including the one that
 * completes the escape string are forwarded; anything after it in buf is
 * dropped.
 *
 * Returns PROCESS_ESC when the escape string completes, PROCESS_ERR when the
 * write to the server fails, and PROCESS_OK otherwise.
 */
static enum process_rc process_str_tty(struct console_client *client,
				       const uint8_t *buf, size_t len)
{
	struct str_esc_state *esc_state = &client->esc_state.str;
	enum process_rc prc = PROCESS_OK;
	size_t i;

	for (i = 0; i < len; ++i) {
		esc_state->pos = str_esc_advance(esc_state->str,
						 esc_state->pos, buf[i]);
		if (esc_state->str[esc_state->pos] == '\0') {
			prc = PROCESS_ESC;
			/* Include the completing byte in what gets forwarded */
			++i;
			break;
		}
	}

	if (write_buf_to_fd(client->console_sd, buf, i) < 0) {
		return PROCESS_ERR;
	}

	return prc;
}
/*
 * Read one chunk of local input and hand it to the escape detector selected
 * by client->esc_type, which also forwards it to the console server.
 *
 * Returns PROCESS_ERR if the read fails or the escape type is unknown,
 * PROCESS_EXIT when the read returns 0, and otherwise whatever the detector
 * returns.
 */
static enum process_rc process_tty(struct console_client *client)
{
	uint8_t input[4096];
	ssize_t nread = read(client->fd_in, input, sizeof(input));

	if (nread < 0) {
		return PROCESS_ERR;
	}

	if (nread == 0) {
		return PROCESS_EXIT;
	}

	if (client->esc_type == ESC_TYPE_SSH) {
		return process_ssh_tty(client, input, nread);
	}

	if (client->esc_type == ESC_TYPE_STR) {
		return process_str_tty(client, input, nread);
	}

	return PROCESS_ERR;
}
/*
 * Read one chunk of data from the console server and copy it to the local
 * output descriptor.
 *
 * Returns PROCESS_ERR if the read or the write fails, PROCESS_EXIT when the
 * server closes the connection, and PROCESS_OK otherwise.
 */
static int process_console(struct console_client *client)
{
	uint8_t data[4096];
	ssize_t nread = read(client->console_sd, data, sizeof(data));

	if (nread == 0) {
		fprintf(stderr, "Connection closed\n");
		return PROCESS_EXIT;
	}

	if (nread < 0) {
		warn("Can't read from server");
		return PROCESS_ERR;
	}

	if (write_buf_to_fd(client->fd_out, data, nread) != 0) {
		return PROCESS_ERR;
	}

	return PROCESS_OK;
}
/*
 * Set up our local file descriptors for IO: use stdin/stdout, and if stdin is
 * a TTY, save its current settings and switch it to raw mode so every byte
 * typed is passed through unprocessed.
 *
 * client->is_tty tells client_fini() whether there are saved settings to
 * restore, so it is cleared again if the settings could not be read.
 *
 * Returns 0 on success (including when stdin is not a TTY) and -1 if the
 * terminal attributes could not be read or changed.
 */
static int client_tty_init(struct console_client *client)
{
	struct termios raw;

	client->fd_in = STDIN_FILENO;
	client->fd_out = STDOUT_FILENO;
	client->is_tty = isatty(client->fd_in);
	if (!client->is_tty) {
		return 0;
	}

	if (tcgetattr(client->fd_in, &client->orig_termios)) {
		warn("Can't get terminal attributes for console");
		/*
		 * Nothing valid was saved; make sure client_fini() does not
		 * push these contents back onto the terminal.
		 */
		client->is_tty = false;
		return -1;
	}

	raw = client->orig_termios;
	cfmakeraw(&raw);

	if (tcsetattr(client->fd_in, TCSANOW, &raw)) {
		warn("Can't set terminal attributes for console");
		return -1;
	}

	return 0;
}
/*
 * Create the client socket and connect it to the console server.
 *
 * The socket address is produced by console_socket_path() from the id that
 * config_resolve_console_id() returns for console_id.
 *
 * Returns 0 with client->console_sd connected, or -1 on failure. On failure a
 * diagnostic has been printed and any socket that was opened has been closed.
 */
static int client_init(struct console_client *client, struct config *config,
		       const char *console_id)
{
	const char *resolved_id = NULL;
	struct sockaddr_un addr;
	socket_path_t path;
	int saved_errno;
	ssize_t len;
	int rc;

	client->console_sd = socket(AF_UNIX, SOCK_STREAM, 0);
	/* socket() reports failure as -1; 0 is a valid descriptor */
	if (client->console_sd < 0) {
		warn("Can't open socket");
		return -1;
	}

	/* Get the console id */
	resolved_id = config_resolve_console_id(config, console_id);

	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;
	len = console_socket_path(addr.sun_path, resolved_id);
	if (len < 0) {
		if (errno) {
			/* warn() already appends strerror(errno) */
			warn("Failed to configure socket");
		} else {
			/* No errno to report, so don't print one */
			warnx("Socket name length exceeds buffer limits");
		}
		goto cleanup;
	}

	rc = connect(client->console_sd, (struct sockaddr *)&addr,
		     sizeof(addr) - sizeof(addr.sun_path) + len);
	if (!rc) {
		return 0;
	}

	/* Keep connect()'s errno intact for warn() across the helper call */
	saved_errno = errno;
	console_socket_path_readable(&addr, len, path);
	errno = saved_errno;
	warn("Can't connect to console server '@%s'", path);

cleanup:
	close(client->console_sd);
	return -1;
}
/*
 * Tear down the client: put back the terminal settings saved in
 * client->orig_termios when client->is_tty is set, then close the server
 * socket.
 */
static void client_fini(struct console_client *client)
{
	const bool restore_terminal = client->is_tty;

	if (restore_terminal) {
		tcsetattr(client->fd_in, TCSANOW, &client->orig_termios);
	}

	close(client->console_sd);
}
/*
 * Console client entry point.
 *
 * Options:
 *   -c <config>           configuration file to read
 *   -e <escape sequence>  escape string; takes precedence over the
 *                         "escape-sequence" config value
 *   -i <console ID>       console to connect to
 *
 * Without an escape string the SSH-style detector is used. The client then
 * relays data between the local terminal and the console server until either
 * side closes, an error occurs, or the escape sequence is typed.
 *
 * Returns EXIT_ESCAPE when the escape sequence ended the session,
 * EXIT_FAILURE on error, and EXIT_SUCCESS otherwise.
 */
int main(int argc, char *argv[])
{
	struct console_client _client;
	struct console_client *client;
	struct pollfd pollfds[2];
	enum process_rc prc = PROCESS_OK;
	const char *config_path = NULL;
	struct config *config = NULL;
	const char *console_id = NULL;
	const uint8_t *esc = NULL;
	int rc;

	client = &_client;
	memset(client, 0, sizeof(*client));
	client->esc_type = ESC_TYPE_SSH;

	for (;;) {
		rc = getopt(argc, argv, "c:e:i:");
		if (rc == -1) {
			break;
		}

		switch (rc) {
		case 'c':
			if (optarg[0] == '\0') {
				fprintf(stderr, "Config str cannot be empty\n");
				return EXIT_FAILURE;
			}
			config_path = optarg;
			break;
		case 'e':
			if (optarg[0] == '\0') {
				fprintf(stderr, "Escape str cannot be empty\n");
				return EXIT_FAILURE;
			}
			esc = (const uint8_t *)optarg;
			break;
		case 'i':
			if (optarg[0] == '\0') {
				fprintf(stderr,
					"Socket ID str cannot be empty\n");
				return EXIT_FAILURE;
			}
			console_id = optarg;
			break;
		default:
			fprintf(stderr,
				"Usage: %s "
				"[-e <escape sequence>]"
				"[-i <console ID>]"
				"[-c <config>]\n",
				argv[0]);
			return EXIT_FAILURE;
		}
	}

	if (config_path) {
		config = config_init(config_path);
		if (!config) {
			warnx("Can't read configuration, exiting.");
			return EXIT_FAILURE;
		}

		/* The command line wins over the configuration file */
		if (!esc) {
			esc = (const uint8_t *)config_get_value(
				config, "escape-sequence");
			/*
			 * Apply the same rule as -e: an empty escape string
			 * would match immediately and end the session on the
			 * first byte typed.
			 */
			if (esc && esc[0] == '\0') {
				fprintf(stderr, "Escape str cannot be empty\n");
				rc = -1;
				goto out_config_fini;
			}
		}
	}

	if (esc) {
		client->esc_type = ESC_TYPE_STR;
		client->esc_state.str.str = esc;
	}

	rc = client_init(client, config, console_id);
	if (rc) {
		goto out_config_fini;
	}

	rc = client_tty_init(client);
	if (rc) {
		goto out_client_fini;
	}

	for (;;) {
		pollfds[0].fd = client->fd_in;
		pollfds[0].events = POLLIN;
		pollfds[1].fd = client->console_sd;
		pollfds[1].events = POLLIN;

		rc = poll(pollfds, 2, -1);
		if (rc < 0) {
			warn("Poll failure");
			break;
		}

		if (pollfds[0].revents) {
			prc = process_tty(client);
		}

		/* Only relay server output if local input was handled cleanly */
		if (prc == PROCESS_OK && pollfds[1].revents) {
			prc = process_console(client);
		}

		rc = (prc == PROCESS_ERR) ? -1 : 0;
		if (prc != PROCESS_OK) {
			break;
		}
	}

out_client_fini:
	client_fini(client);

out_config_fini:
	if (config_path) {
		config_fini(config);
	}

	if (prc == PROCESS_ESC) {
		return EXIT_ESCAPE;
	}

	return rc ? EXIT_FAILURE : EXIT_SUCCESS;
}