#!/usr/bin/python3

# SPDX-License-Identifier: Apache-2.0
# Copyright 2019 IBM Corp.

from argparse import ArgumentParser
from itertools import islice, cycle
from collections import namedtuple
from enum import Enum
from scapy.all import rdpcap
import struct
import json
import sys

# Record types passed between the parsing stages.
# RawMessage: endianness byte, the 12-byte fixed header and the remaining bytes.
RawMessage = namedtuple("RawMessage", ["endian", "header", "data"])
# FixedHeader: the six values unpacked from the fixed header.
FixedHeader = namedtuple(
    "FixedHeader",
    ["endian", "type", "flags", "version", "length", "cookie"],
)
# CookedHeader: a FixedHeader plus the list of decoded header Fields.
CookedHeader = namedtuple("CookedHeader", ["fixed", "fields"])
# CookedMessage: a CookedHeader plus the list of decoded body values.
CookedMessage = namedtuple("CookedMessage", ["header", "body"])
# TypeProperty: one row of TypePropertyLookup.
TypeProperty = namedtuple("TypeProperty", ["field", "type", "nature"])
# TypeContainer: a node of a parsed signature (type plus contained types).
TypeContainer = namedtuple("TypeContainer", ["type", "members"])
# Field: a header field code paired with its decoded value.
Field = namedtuple("Field", ["type", "data"])

class MessageEndian(Enum):
    """Endianness markers; the values are the ASCII codes of 'l' and 'B'.

    The first byte of a packet is compared against these values (through
    StructEndianLookup) to choose the byte order used for decoding.
    """
    # ASCII 'l': little-endian
    LITTLE = ord('l')
    # ASCII 'B': big-endian
    BIG = ord('B')

# Maps a message's endianness byte to the matching `struct` byte-order prefix.
StructEndianLookup = {
    marker.value: prefix
    for marker, prefix in (
        (MessageEndian.LITTLE, "<"),
        (MessageEndian.BIG, ">"),
    )
}

class MessageType(Enum):
    """Message type codes, compared against the type byte of the fixed header."""
    INVALID = 0
    METHOD_CALL = 1
    METHOD_RETURN = 2
    ERROR = 3
    SIGNAL = 4

class MessageFlags(Enum):
    """Individual flag bits, one bit position per member."""
    NO_REPLY_EXPECTED = 1 << 0
    NO_AUTO_START = 1 << 1
    ALLOW_INTERACTIVE_AUTHORIZATION = 1 << 2

class MessageFieldType(Enum):
    """Header field codes.

    parse_fields() converts the leading byte of every decoded header field
    into a member of this enum.
    """
    INVALID = 0
    PATH = 1
    INTERFACE = 2
    MEMBER = 3
    ERROR_NAME = 4
    REPLY_SERIAL = 5
    DESTINATION = 6
    SENDER = 7
    SIGNATURE = 8
    UNIX_FDS = 9

class DBusType(Enum):
    """DBus type codes.

    Apart from INVALID, each value is the character code of the type's
    signature character, which is what parse_signature() reads from a
    signature string.
    """
    INVALID = 0
    BYTE = ord('y')
    BOOLEAN = ord('b')
    INT16 = ord('n')
    UINT16 = ord('q')
    INT32 = ord('i')
    UINT32 = ord('u')
    INT64 = ord('x')
    UINT64 = ord('t')
    DOUBLE = ord('d')
    STRING = ord('s')
    OBJECT_PATH = ord('o')
    SIGNATURE = ord('g')
    ARRAY = ord('a')
    # Only the opening bracket of a struct is a type code; the closing one is
    # listed in DBusContainerTerminatorLookup.
    STRUCT = ord('(')
    VARIANT = ord('v')
    # Likewise only the opening brace of a dict entry.
    DICT_ENTRY = ord('{')
    UNIX_FD = ord('h')

# Maps the opening signature character of a bracketed container to the
# character that closes it.
DBusContainerTerminatorLookup = {
    chr(opener.value): closer
    for opener, closer in (
        (DBusType.STRUCT, ')'),
        (DBusType.DICT_ENTRY, '}'),
    )
}

class DBusTypeCategory(Enum):
    """Groups of DBusType values that are decoded the same way.

    Each member's value is a set of type codes; the parsers test membership
    with `code in DBusTypeCategory.X.value`.
    """
    # Decoded by parse_fixed()
    FIXED = {
        DBusType.BYTE.value,
        DBusType.BOOLEAN.value,
        DBusType.INT16.value,
        DBusType.UINT16.value,
        DBusType.INT32.value,
        DBusType.UINT32.value,
        DBusType.INT64.value,
        DBusType.UINT64.value,
        DBusType.DOUBLE.value,
        DBusType.UNIX_FD.value
    }
    # Decoded by parse_string()
    STRING = {
        DBusType.STRING.value,
        DBusType.OBJECT_PATH.value,
        DBusType.SIGNATURE.value,
    }
    # Decoded by parse_container()
    CONTAINER = {
        DBusType.ARRAY.value,
        DBusType.STRUCT.value,
        DBusType.VARIANT.value,
        DBusType.DICT_ENTRY.value,
    }
    # Codes that parse_signature() refuses to accept
    RESERVED = {
        DBusType.INVALID.value,
    }

# Decoding properties for every type, keyed by the type's signature character
# code. Each TypeProperty carries:
#   field  - the DBusType the entry describes
#   type   - the `struct` format character used to unpack the value, or None
#            for containers, which have no direct `struct` representation
#   nature - for fixed types, the size in bytes (AlignedStream also aligns to
#            it); for structs, dict entries and variants, the alignment in
#            bytes; for string-like types and arrays, the DBusType of the
#            length prefix
TypePropertyLookup = {
    DBusType.BYTE.value : TypeProperty(DBusType.BYTE, 'B', 1),
    # DBus booleans are 32 bit, with only the LSB used. Extract as 'I'.
    DBusType.BOOLEAN.value : TypeProperty(DBusType.BOOLEAN, 'I', 4),
    DBusType.INT16.value : TypeProperty(DBusType.INT16, 'h', 2),
    DBusType.UINT16.value : TypeProperty(DBusType.UINT16, 'H', 2),
    DBusType.INT32.value : TypeProperty(DBusType.INT32, 'i', 4),
    DBusType.UINT32.value : TypeProperty(DBusType.UINT32, 'I', 4),
    DBusType.INT64.value : TypeProperty(DBusType.INT64, 'q', 8),
    DBusType.UINT64.value : TypeProperty(DBusType.UINT64, 'Q', 8),
    DBusType.DOUBLE.value : TypeProperty(DBusType.DOUBLE, 'd', 8),
    DBusType.STRING.value : TypeProperty(DBusType.STRING, 's', DBusType.UINT32),
    DBusType.OBJECT_PATH.value : TypeProperty(DBusType.OBJECT_PATH, 's', DBusType.UINT32),
    DBusType.SIGNATURE.value : TypeProperty(DBusType.SIGNATURE, 's', DBusType.BYTE),
    DBusType.ARRAY.value : TypeProperty(DBusType.ARRAY, None, DBusType.UINT32),
    DBusType.STRUCT.value : TypeProperty(DBusType.STRUCT, None, 8),
    DBusType.VARIANT.value : TypeProperty(DBusType.VARIANT, None, 1),
    DBusType.DICT_ENTRY.value : TypeProperty(DBusType.DICT_ENTRY, None, 8),
    # A UNIX_FD is a 32-bit unsigned index into the out-of-band file
    # descriptor array, so it unpacks as 'I' and is 4 bytes wide.
    DBusType.UNIX_FD.value : TypeProperty(DBusType.UNIX_FD, 'I', 4),
}

def parse_signature(sigstream):
    """Parse one complete type from a signature and return its type tree.

    sigstream is an iterator over the characters of a signature. Exactly one
    complete type is consumed per call, recursing for containers.

    Returns a TypeContainer(type, members) where members is None for basic
    types and variants, a single TypeContainer for arrays, and a list of
    TypeContainers for structs and dict entries. When the character read is a
    container terminator (')' or '}'), the StopIteration class itself is
    returned as a sentinel.

    Raises StopIteration (from next()) when sigstream is exhausted, and
    AssertionError for reserved or unknown type characters.
    """
    sig = ord(next(sigstream))
    assert sig not in DBusTypeCategory.RESERVED.value
    if sig in DBusTypeCategory.FIXED.value or sig in DBusTypeCategory.STRING.value:
        ty = TypePropertyLookup[sig].field, None
    elif sig in DBusTypeCategory.CONTAINER.value:
        if sig == DBusType.ARRAY.value:
            ty = DBusType.ARRAY, parse_signature(sigstream)
        elif sig == DBusType.STRUCT.value or sig == DBusType.DICT_ENTRY.value:
            # Collect member types until the closing bracket sentinel appears
            collected = []
            member = parse_signature(sigstream)
            while member is not StopIteration:
                collected.append(member)
                member = parse_signature(sigstream)
            # Label the container with the type that opened it, so a dict
            # entry is not reported as a struct
            ty = TypePropertyLookup[sig].field, collected
        elif sig == DBusType.VARIANT.value:
            # The contained type is only known once the variant's own
            # signature has been read from the message data
            ty = TypePropertyLookup[sig].field, None
        else:
            assert False
    else:
        assert chr(sig) in DBusContainerTerminatorLookup.values()
        return StopIteration

    return TypeContainer._make(ty)

class AlignedStream:
    """Byte iterator that tracks its absolute offset within a message.

    buf is the data to iterate over and offset is the absolute position of
    its first byte; alignment is computed from the absolute position.
    """

    def __init__(self, buf, offset=0):
        self.offset = offset
        self.stream = iter(buf)
        # Remember the construction arguments for drain() and dump()
        self.stash = (buf, offset)

    def align(self, tc):
        """Skip padding so the stream sits on the boundary required by tc."""
        code = tc.type.value
        assert code in TypePropertyLookup
        prop = TypePropertyLookup[code]
        # String-like types align to their length prefix
        if prop.field.value in DBusTypeCategory.STRING.value:
            prop = TypePropertyLookup[prop.nature.value]
        # A UINT32 length prefix (as arrays have) aligns like a UINT32
        if prop.nature == DBusType.UINT32:
            prop = TypePropertyLookup[DBusType.UINT32.value]
        boundary = prop.nature
        misalignment = self.offset & (boundary - 1)
        padding = (boundary - misalignment) % boundary
        bytes(islice(self.stream, padding))
        self.offset += padding

    def take(self, size):
        """Return a lazy iterator over the next size bytes.

        The offset is advanced by size immediately, whether or not that many
        bytes are actually available.
        """
        chunk = islice(self.stream, size)
        self.offset += size
        return chunk

    def autotake(self, tc):
        """Align for the fixed-size type tc, then take its bytes."""
        code = tc.type.value
        assert code in DBusTypeCategory.FIXED.value
        assert code in TypePropertyLookup
        self.align(tc)
        return self.take(TypePropertyLookup[code].nature)

    def drain(self):
        """Consume the rest of the stream.

        Returns (remaining bytes, absolute offset at which they started).
        Raises MalformedPacketError when the offset reached does not equal
        the length of the buffer provided at construction.
        """
        start = self.offset
        remaining = bytes(self.stream)
        self.offset = start + len(remaining)
        consumed = self.offset - self.stash[1]
        provided = len(self.stash[0])
        if consumed != provided:
            print("(self.offset - self.stash[1]): %d, len(self.stash[0]): %d"
                % (consumed, provided), file=sys.stderr)
            raise MalformedPacketError
        return remaining, start

    def dump(self):
        """Print the stream state to stderr.

        This drains the stream, so it may raise MalformedPacketError.
        """
        print("AlignedStream: absolute offset: {}".format(self.offset), file=sys.stderr)
        print("AlignedStream: relative offset: {}".format(self.offset - self.stash[1]),
                file=sys.stderr)
        remaining, _ = self.drain()
        print("AlignedStream: remaining buffer:\n{}".format(remaining), file=sys.stderr)
        print("AlignedStream: provided buffer:\n{}".format(self.stash[0]), file=sys.stderr)

    def dump_assert(self, condition):
        """Assert condition, dumping the stream state first if it is false."""
        if not condition:
            self.dump()
            assert condition

def parse_fixed(endian, stream, tc):
    """Decode one fixed-size value of type tc.type from stream.

    endian is the `struct` byte-order prefix ("<" or ">"). The stream is
    aligned for the type before its bytes are taken.

    Returns the unpacked number, or a bool for BOOLEAN. If the bytes cannot
    be unpacked, diagnostics are printed to stderr and
    stream.dump_assert(False) raises.
    """
    assert tc.type.value in TypePropertyLookup
    prop = TypePropertyLookup[tc.type.value]
    val = bytes(stream.autotake(tc))
    try:
        val = struct.unpack("{}{}".format(endian, prop.type), val)[0]
    except struct.error as e:
        print(e, file=sys.stderr)
        print("parse_fixed: Error unpacking {}".format(val), file=sys.stderr)
        print("parse_fixed: Attempting to unpack type {} with properties {}".format(tc.type, prop),
              file=sys.stderr)
        stream.dump_assert(False)
    else:
        # prop.type is the struct format character; the DBus type lives in
        # prop.field, so that is what identifies a boolean
        return bool(val) if prop.field == DBusType.BOOLEAN else val

def parse_string(endian, stream, tc):
    """Decode a string-like value (STRING, OBJECT_PATH or SIGNATURE).

    endian is the `struct` byte-order prefix. The length prefix is decoded
    with the type given by the nature of tc.type in TypePropertyLookup, then
    that many bytes plus the terminating NUL are consumed.

    Returns the decoded text without the terminator. If fewer bytes than
    expected are available, diagnostics are printed to stderr and
    stream.dump_assert() raises.
    """
    assert tc.type.value in TypePropertyLookup
    prop = TypePropertyLookup[tc.type.value]
    size = parse_fixed(endian, stream, TypeContainer(prop.nature, None))
    # Every DBus string, the empty one included, is followed by a NUL byte
    # that is not counted in its length. It must always be consumed so that
    # whatever follows is read from the right offset.
    val = bytes(stream.take(size + 1))
    try:
        stream.dump_assert(len(val) == size + 1)
    except AssertionError:
        print("parse_string: Error unpacking string of length {} from {}".format(size, val),
                file=sys.stderr)
        raise
    return val[:size].decode()

def parse_type(endian, stream, tc):
    """Decode one value of type tc from stream, dispatching on its category.

    Types outside the fixed, string and container categories fail through
    stream.dump_assert(False).
    """
    code = tc.type.value
    if code in DBusTypeCategory.FIXED.value:
        parser = parse_fixed
    elif code in DBusTypeCategory.STRING.value:
        parser = parse_string
    elif code in DBusTypeCategory.CONTAINER.value:
        parser = parse_container
    else:
        stream.dump_assert(False)

    return parser(endian, stream, tc)

def parse_array(endian, stream, tc):
    """Decode an array whose elements have type tc; returns them as a list.

    The UINT32 length prefix counts the bytes of element data, measured from
    the first element after it has been aligned.
    """
    length = parse_fixed(endian, stream, TypeContainer(DBusType.UINT32, None))
    stream.align(tc)
    end = stream.offset + length
    elements = []
    while stream.offset < end:
        elements.append(parse_type(endian, stream, tc))
        # Only pad towards another element, never past the end of the array
        if stream.offset < end:
            stream.align(tc)
    return elements

def parse_struct(endian, stream, tcs):
    """Decode a struct with member types tcs; returns the values as a list.

    The stream is first aligned as for a STRUCT.
    """
    stream.align(TypeContainer(DBusType.STRUCT, None))
    return [parse_type(endian, stream, member) for member in tcs]

def parse_variant(endian, stream):
    """Decode a variant: read its embedded signature, then the value it describes."""
    signature = parse_string(endian, stream, TypeContainer(DBusType.SIGNATURE, None))
    contained = parse_signature(iter(signature))
    return parse_type(endian, stream, contained)

def parse_container(endian, stream, tc):
    """Decode the container described by tc with the matching parser.

    Structs and dict entries share parse_struct(). Any other type fails
    through stream.dump_assert(False).
    """
    kind = tc.type
    if kind == DBusType.ARRAY:
        return parse_array(endian, stream, tc.members)
    if kind in (DBusType.STRUCT, DBusType.DICT_ENTRY):
        return parse_struct(endian, stream, tc.members)
    if kind == DBusType.VARIANT:
        return parse_variant(endian, stream)
    stream.dump_assert(False)

def parse_fields(endian, stream):
    """Decode the header field array; returns a list of Field records.

    The fields are an array of (byte, variant) structs, signature "a(yv)".
    """
    layout = parse_signature(iter("a(yv)"))
    entries = parse_container(endian, stream, layout)
    # The header ends after its alignment padding to an 8-boundary.
    # https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-messages
    stream.align(TypeContainer(DBusType.STRUCT, None))
    return [Field(MessageFieldType(entry[0]), entry[1]) for entry in entries]

class MalformedPacketError(Exception):
    """Raised when a packet's length is inconsistent with its content."""

def parse_header(raw, ignore_error):
    """Decode the fixed header and header fields of a RawMessage.

    Returns (CookedHeader, AlignedStream over the remaining bytes). Unless
    ignore_error is set, raises MalformedPacketError when the header declares
    a length larger than the number of bytes remaining.
    """
    assert raw.endian in StructEndianLookup.keys()
    endian = StructEndianLookup[raw.endian]
    fixed = FixedHeader(*struct.unpack(endian + "BBBBLL", raw.header))
    # Offsets continue from the end of the fixed header so that alignment is
    # computed relative to the start of the message
    astream = AlignedStream(raw.data, len(raw.header))
    fields = parse_fields(endian, astream)
    data, offset = astream.drain()
    truncated = fixed.length > len(data)
    if ignore_error == False and truncated:
        raise MalformedPacketError
    return CookedHeader(fixed, fields), AlignedStream(data, offset)

def parse_body(header, stream):
    """Decode the message body described by the header's SIGNATURE field.

    Returns the decoded values as a list, which is empty when the header has
    no SIGNATURE field. Only the first SIGNATURE field is used.
    """
    assert header.fixed.endian in StructEndianLookup
    endian = StructEndianLookup[header.fixed.endian]
    signature = next(
        (field.data for field in header.fields
            if field.type == MessageFieldType.SIGNATURE),
        None)
    values = []
    if signature is None:
        return values
    types = iter(signature)
    try:
        # parse_signature() raises StopIteration once the signature has been
        # consumed, which ends the loop
        while True:
            values.append(parse_type(endian, stream, parse_signature(types)))
    except StopIteration:
        pass
    return values

def parse_message(raw):
    """Decode a RawMessage into a CookedMessage.

    When an AssertionError escapes the parsers, whatever has been decoded so
    far (the header, if available, then the raw message) is printed to stderr
    before the error is re-raised.
    """
    try:
        header, data = parse_header(raw, False)
        try:
            body = parse_body(header, data)
        except AssertionError:
            print(header, file=sys.stderr)
            raise
        return CookedMessage(header, body)
    except AssertionError:
        print(raw, file=sys.stderr)
        raise

def parse_packet(packet):
    """Decode one captured packet into a CookedMessage.

    The first byte gives the endianness and the first 12 bytes form the fixed
    header. A packet reported as malformed is decoded a second time with only
    its header, and gets an empty body.
    """
    data = bytes(packet)
    raw = RawMessage(endian=data[0], header=data[:12], data=data[12:])
    try:
        return parse_message(raw)
    except MalformedPacketError:
        print("Got malformed packet: {}".format(raw), file=sys.stderr)
        # For a message that is so large that its payload data could not be parsed,
        # just parse its header, then set its data field to empty.
        header, _ = parse_header(raw, True)
        return CookedMessage(header, [])

CallEnvelope = namedtuple("CallEnvelope", ["cookie", "origin"])
def parse_session(session, matchers, track_calls):
    """Yield (packet.time, CookedMessage) for every packet that passes the filters.

    matchers is a list of rule lists; a message passes when every rule of at
    least one list accepts it, or when matchers is empty. With track_calls
    set, a METHOD_RETURN that does not pass on its own is still yielded when
    it answers a METHOD_CALL that passed earlier. Packets that raise
    MalformedPacketError are skipped.
    """
    def header_field(message, kind):
        # Indexing the list keeps the IndexError for a message without the field
        return [f for f in message.header.fields if f.type == kind][0]

    pending = set()
    for packet in session:
        try:
            cooked = parse_packet(packet)
            if matchers and not any(all(r(cooked) for r in m) for m in matchers):
                # Not matched directly: only the reply to a tracked call passes
                if not track_calls:
                    continue
                if cooked.header.fixed.type != MessageType.METHOD_RETURN.value:
                    continue
                reply = CallEnvelope(
                    header_field(cooked, MessageFieldType.REPLY_SERIAL).data,
                    header_field(cooked, MessageFieldType.DESTINATION).data)
                if reply not in pending:
                    continue
                pending.remove(reply)
            elif matchers and cooked.header.fixed.type == MessageType.METHOD_CALL.value:
                # Remember the matched call so that its reply can pass later
                sender = header_field(cooked, MessageFieldType.SENDER)
                pending.add(CallEnvelope(cooked.header.fixed.cookie, sender.data))
            yield packet.time, cooked
        except MalformedPacketError:
            pass

def gen_match_type(rule):
    """Build a predicate accepting messages of the type named by rule.value.

    The name is upper-cased and looked up among the MessageType members.
    """
    wanted = MessageType.__members__[rule.value.upper()].value

    def matches(p):
        return p.header.fixed.type == wanted

    return matches

def gen_match_sender(rule):
    """Build a predicate accepting messages whose SENDER field equals rule.value."""
    wanted = Field(MessageFieldType.SENDER, rule.value)

    def matches(p):
        for field in p.header.fields:
            if field == wanted:
                return True
        return False

    return matches

def gen_match_interface(rule):
    """Build a predicate accepting messages whose INTERFACE field equals rule.value."""
    wanted = Field(MessageFieldType.INTERFACE, rule.value)

    def matches(p):
        for field in p.header.fields:
            if field == wanted:
                return True
        return False

    return matches

def gen_match_member(rule):
    """Build a predicate accepting messages whose MEMBER field equals rule.value."""
    wanted = Field(MessageFieldType.MEMBER, rule.value)

    def matches(p):
        for field in p.header.fields:
            if field == wanted:
                return True
        return False

    return matches

def gen_match_path(rule):
    """Build a predicate accepting messages whose PATH field equals rule.value."""
    wanted = Field(MessageFieldType.PATH, rule.value)

    def matches(p):
        for field in p.header.fields:
            if field == wanted:
                return True
        return False

    return matches

def gen_match_destination(rule):
    """Build a predicate accepting messages whose DESTINATION field equals rule.value."""
    wanted = Field(MessageFieldType.DESTINATION, rule.value)

    def matches(p):
        for field in p.header.fields:
            if field == wanted:
                return True
        return False

    return matches

# Match rule keys for which a gen_match_<key> function exists.
ValidMatchKeys = set("type sender interface member path destination".split())
# One key='value' term of a match expression.
MatchRule = namedtuple("MatchExpression", ["key", "value"])

# https://dbus.freedesktop.org/doc/dbus-specification.html#message-bus-routing-match-rules
def parse_match_rules(exprs):
    """Compile match expressions into lists of predicates.

    exprs is an iterable of strings such as "type='signal',member='Foo'".
    Each string becomes one list of predicates, one per comma-separated
    key=value term, built by the gen_match_<key> function for that key.

    Returns a list of predicate lists, one per expression.

    Raises ValueError for a term that has no '=' or whose key is not in
    ValidMatchKeys.
    """
    matchers = []
    for mexpr in exprs:
        rules = []
        for rexpr in mexpr.split(","):
            # Split on the first '=' only, so the value may itself contain '='
            key, sep, value = rexpr.partition("=")
            rule = MatchRule(key.strip("'"), value.strip("'"))
            if not sep or rule.key not in ValidMatchKeys:
                raise ValueError("Invalid expression: {}".format(rexpr))
            rules.append(globals()["gen_match_{}".format(rule.key)](rule))
        matchers.append(rules)
    return matchers

def packetconv(obj):
    """`default` hook for json.dumps(): serialise Enum members as their value.

    Raises TypeError for any other object.
    """
    if not isinstance(obj, Enum):
        raise TypeError
    return obj.value

def main():
    """Command-line entry point: decode a pcap file and print its messages.

    Messages passing the given match expressions are printed to stdout,
    either as one JSON document per line (--json) or as "time: message"
    followed by a blank line. A broken output pipe ends the run quietly.
    """
    parser = ArgumentParser()
    parser.add_argument("--json", action="store_true",
            help="Emit a JSON representation of the messages")
    # The flag switches call tracking off, so the help describes that effect
    parser.add_argument("--no-track-calls", action="store_true", default=False,
            help="Do not automatically show responses to matched method calls")
    parser.add_argument("file", help="The pcap file")
    parser.add_argument("expressions", nargs="*",
            help="DBus message match expressions")
    args = parser.parse_args()
    stream = rdpcap(args.file)
    matchers = parse_match_rules(args.expressions)
    session = parse_session(stream, matchers, not args.no_track_calls)
    try:
        for time, msg in session:
            if args.json:
                print(json.dumps(msg, default=packetconv))
            else:
                print("{}: {}".format(time, msg))
                print()
    except BrokenPipeError:
        # The reader (e.g. `head`) went away; there is nothing left to do
        pass

# Run the command-line tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
