Source code for gtirb.byteinterval

import itertools
import typing
from uuid import UUID

from intervaltree import IntervalTree
from sortedcontainers import SortedDict

from .block import ByteBlock, CodeBlock, DataBlock
from .node import Node, _NodeMessage
from .proto import ByteInterval_pb2, SymbolicExpression_pb2
from .symbolicexpression import SymAddrAddr, SymAddrConst, SymbolicExpression
from .util import (
    DictLike,
    DictWrapper,
    SetWrapper,
    _IndexedAttribute,
    _nodes_at_interval_tree,
    _nodes_on_interval_tree,
    _offset_interval,
    get_desired_range,
)

if typing.TYPE_CHECKING:  # pragma: no cover
    # Ignore flake8 "imported but unused" errors.
    from .ir import IR  # noqa: F401
    from .module import Module  # noqa: F401
    from .section import Section  # noqa: F401


SymbolicExpressionElement = typing.Tuple[
    "ByteInterval", int, SymbolicExpression
]


class ByteInterval(Node):
    """A contiguous region of bytes in a binary.

    A ByteInterval defines a relative ordering for a group of
    :class:`ByteBlock`\\s, optionally at a fixed address in memory. It also
    stores the bytes associated with these blocks.

    If two blocks are in two different ByteIntervals, then it should be
    considered safe (that is, preserving of program semantics) to move one
    block relative to the other in memory. If two blocks are in the same
    ByteInterval, then it should be considered unknown if moving the two
    blocks relative to one another in memory is a safe operation.

    :ivar ~.address: The fixed address of this interval, if present. If this
        field is present, it may indicate the original address at which this
        interval was located in memory, or it may indicate that this
        interval's address is fixed and must not be changed. If this field is
        not present, it indicates that the interval is free to be moved
        around in memory while preserving program semantics.
    :ivar ~.size: The size of this interval in bytes. If this number is
        greater than ``initialized_size``, this indicates that the high
        addresses taken up by this interval consist of uninitialized bytes.
        This often occurs in BSS sections, where data is zero-initialized
        rather than stored as zeroes in the binary.
    :ivar ~.contents: The bytes stored in this interval.
    :ivar ~.blocks: A set of all :class:`ByteBlock`\\s in this interval.
    :ivar ~.symbolic_expressions: A mapping, from offset in the interval, to
        a :class:`SymbolicExpression` in the interval.
    """

    class _BlockSet(SetWrapper[ByteBlock]):
        def __init__(
            self, node: "ByteInterval", *args: typing.Iterable[ByteBlock]
        ):
            super().__init__()
            self._node = node
            self.update(*args)

        def add(self, v: ByteBlock) -> None:
            # We're defining add in terms of update so that we can optimize
            # adding multiple blocks at once.
            self.update((v,))

        def update(self, *iterables: typing.Iterable[ByteBlock]) -> None:
            node_ir = self._node.ir
            new_items = set(*iterables) - self._data
            for v in new_items:
                if v._byte_interval is not None:
                    v._byte_interval.blocks.discard(v)
                v._byte_interval = self._node
                if node_ir is not None:
                    v._add_to_uuid_cache(node_ir._local_uuid_cache)
            self._node._index_add_multiple(self._data, new_items)
            self._data.update(new_items)

        def discard(self, v: ByteBlock) -> None:
            if v not in self:
                return
            self._node._index_discard(v)
            v._byte_interval = None
            if self._node.ir is not None:
                v._remove_from_uuid_cache(self._node.ir._local_uuid_cache)
            return super().discard(v)

    class _SymbolicExprDict(DictWrapper[int, SymbolicExpression]):
        def __init__(
            self,
            interval: "ByteInterval",
            *args: DictLike[int, SymbolicExpression],
        ):
            # Do not call the DictWrapper constructor: _SymbolicExprDict
            # needs to store data in a SortedDict, not a regular dict.
            self._interval = interval
            self._data: "SortedDict[int, SymbolicExpression]" = SortedDict()
            temp: typing.Dict[int, SymbolicExpression] = dict(*args)
            for i, v in temp.items():
                self[i] = v

        def __repr__(self) -> str:
            # We can't just return the repr of self._data because it will
            # create a SortedDict and that isn't part of our public
            # interface.
            items = (
                "{!r}: {!r}".format(key, value)
                for key, value in self._data.items()
            )
            return "{" + ", ".join(items) + "}"

    address = _IndexedAttribute[typing.Optional[int]]()(
        lambda self: self.section
    )
    size = _IndexedAttribute[int]()(lambda self: self.section)
    def __init__(
        self,
        *,
        address: typing.Optional[int] = None,
        size: typing.Optional[int] = None,
        initialized_size: typing.Optional[int] = None,
        contents: typing.ByteString = b"",
        blocks: typing.Iterable[ByteBlock] = (),
        symbolic_expressions: DictLike[int, SymbolicExpression] = {},
        uuid: typing.Optional[UUID] = None,
        section: typing.Optional["Section"] = None,
    ):
        """
        :param address: The fixed address of this interval, if present.
        :param size: The size of this interval in bytes.
        :param initialized_size: The number of initialized bytes in this
            interval.
        :param contents: The bytes stored in this interval.
        :param blocks: A set of all :class:`ByteBlock`\\s in this interval.
        :param symbolic_expressions: A mapping, from offset in the interval,
            to a :class:`SymbolicExpression` in the interval.
        :param uuid: The UUID of this ``ByteInterval``, or None if a new UUID
            needs to be generated via :func:`uuid.uuid4`. Defaults to None.
        :param section: The :class:`Section` this interval belongs to.
        """

        if size is None:
            size = len(contents)
        if initialized_size is None:
            initialized_size = len(contents)
        if initialized_size > size:
            raise ValueError("initialized_size must be <= size!")

        super().__init__(uuid=uuid)
        self._interval_tree: "IntervalTree[int, ByteBlock]" = IntervalTree()
        self._section: typing.Optional["Section"] = None
        self.address = address
        self.size = size
        self.contents = bytearray(contents)
        self.initialized_size = initialized_size
        self.blocks: SetWrapper[ByteBlock] = ByteInterval._BlockSet(
            self, blocks
        )
        self._symbolic_expressions = ByteInterval._SymbolicExprDict(
            self, symbolic_expressions
        )
        self._proto_interval: typing.Optional[
            ByteInterval_pb2.ByteInterval
        ] = None

        # Use the property setter to ensure correct invariants.
        self.section = section
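    # A minimal construction sketch (illustrative values, not from the gtirb
    # test suite): size and initialized_size default to len(contents), and a
    # size larger than initialized_size models trailing uninitialized
    # (BSS-like) bytes that are not stored in ``contents``.
    #
    #   bi = ByteInterval(contents=b"\x90\x90", address=0x401000)
    #   assert bi.size == 2 and bi.initialized_size == 2
    #
    #   bss = ByteInterval(size=0x100, contents=b"")
    #   assert bss.initialized_size == 0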
    def _index_add_multiple(
        self,
        old_blocks: typing.Collection[ByteBlock],
        new_blocks: typing.Collection[ByteBlock],
    ) -> None:
        if len(old_blocks) < len(new_blocks):
            self._interval_tree = IntervalTree(
                _offset_interval(block)
                for block in itertools.chain(old_blocks, new_blocks)
            )
        else:
            for block in new_blocks:
                self._index_add(block)

    def _index_add(self, block: ByteBlock) -> None:
        self._interval_tree.add(_offset_interval(block))

    def _index_discard(self, block: ByteBlock) -> None:
        self._interval_tree.discard(_offset_interval(block))

    @property
    def initialized_size(self) -> int:
        """The number of initialized bytes in this interval.

        Not all bytes in this interval may correspond to bytes physically
        stored in the underlying file format. This can occur, for example, in
        BSS sections, which are zero-initialized at loadtime, but these
        zeroes are not stored in the file itself. If this number is smaller
        than ``size``, this indicates that any bytes past this number are
        uninitialized bytes with values determined at loadtime. As such, all
        bytes past this number in this interval's byte vector are truncated
        when saving to file.
        """

        return len(self.contents)

    @initialized_size.setter
    def initialized_size(self, value: int) -> None:
        if value > len(self.contents):
            self.contents += b"\0" * (value - len(self.contents))
        elif value < len(self.contents):
            self.contents = self.contents[:value]

    @classmethod
    def _decode_protobuf(
        cls,
        proto_interval: _NodeMessage,
        uuid: UUID,
        ir: typing.Optional["IR"],
    ) -> "ByteInterval":
        assert ir
        assert isinstance(proto_interval, ByteInterval_pb2.ByteInterval)

        def decode_block(proto_block: ByteInterval_pb2.Block) -> ByteBlock:
            block: ByteBlock
            if proto_block.HasField("code"):
                block = CodeBlock._from_protobuf(proto_block.code, ir)
            elif proto_block.HasField("data"):
                block = DataBlock._from_protobuf(proto_block.data, ir)
            else:
                raise TypeError(
                    "Unknown type inside proto block: %s"
                    % proto_block.WhichOneof("value")
                )
            block.offset = proto_block.offset
            return block

        # We do not decode symbolic expressions yet, because symbols have
        # not yet been decoded at this point.
        result = cls(
            address=proto_interval.address
            if proto_interval.has_address
            else None,
            size=proto_interval.size,
            contents=proto_interval.contents,
            uuid=uuid,
            blocks=(decode_block(b) for b in proto_interval.blocks),
        )
        result._add_to_uuid_cache(ir._local_uuid_cache)
        # We store the proto interval here so we can use it later, when
        # _decode_symbolic_expressions is called.
        result._proto_interval = proto_interval
        # Return the new BI.
        return result

    def _decode_symbolic_expressions(self, ir: "IR") -> None:
        """Called by modules after symbols are decoded, but before the module
        is done decoding.
        """

        assert self._proto_interval

        def decode_symbolic_expression(
            proto_expr: SymbolicExpression_pb2.SymbolicExpression,
        ) -> SymbolicExpression:
            if proto_expr.HasField("addr_const"):
                return SymAddrConst._from_protobuf(
                    proto_expr.addr_const, ir.get_by_uuid
                )
            elif proto_expr.HasField("addr_addr"):
                return SymAddrAddr._from_protobuf(
                    proto_expr.addr_addr, ir.get_by_uuid
                )
            else:
                raise TypeError(
                    "Unknown type inside proto sym expr: %s"
                    % proto_expr.WhichOneof("value")
                )

        self.symbolic_expressions = {}
        for i, v in self._proto_interval.symbolic_expressions.items():
            expr = decode_symbolic_expression(v)
            for f in v.attribute_flags:
                try:
                    expr.attributes.add(SymbolicExpression.Attribute(f))
                except ValueError:
                    expr.attributes.add(f)
            self.symbolic_expressions[i] = expr
        del self._proto_interval

    def _to_protobuf(self) -> ByteInterval_pb2.ByteInterval:
        proto_interval = ByteInterval_pb2.ByteInterval()
        proto_interval.uuid = self.uuid.bytes
        if self.address is None:
            proto_interval.has_address = False
        else:
            proto_interval.has_address = True
            proto_interval.address = self.address
        proto_interval.size = self.size
        proto_interval.contents = bytes(self.contents)

        # Cannot insert blocks using proto_interval.blocks.append() in a loop
        # because append() isn't supported in older versions of protobuf. Use
        # a generator expression and extend() instead.
        def to_proto_block(block: ByteBlock) -> ByteInterval_pb2.Block:
            proto_block = ByteInterval_pb2.Block()
            proto_block.offset = block.offset
            if isinstance(block, CodeBlock):
                proto_block.code.CopyFrom(block._to_protobuf())
            elif isinstance(block, DataBlock):
                proto_block.data.CopyFrom(block._to_protobuf())
            else:
                raise TypeError(
                    "Unknown block type in interval: %s" % type(block)
                )
            return proto_block

        proto_interval.blocks.extend(to_proto_block(b) for b in self.blocks)

        for k, v in self.symbolic_expressions.items():
            sym_exp = SymbolicExpression_pb2.SymbolicExpression()
            if isinstance(v, SymAddrConst):
                sym_exp.addr_const.CopyFrom(v._to_protobuf())
            elif isinstance(v, SymAddrAddr):
                sym_exp.addr_addr.CopyFrom(v._to_protobuf())
            else:
                raise ValueError(
                    "Unexpected sym expr type in interval: %s" % type(v)
                )
            attrs = (
                a.value if isinstance(a, SymbolicExpression.Attribute) else a
                for a in v.attributes
            )
            sym_exp.attribute_flags.extend(attrs)
            proto_interval.symbolic_expressions[k].CopyFrom(sym_exp)
        return proto_interval

    @property
    def section(self) -> typing.Optional["Section"]:
        """The :class:`Section` this interval belongs to."""

        return self._section

    @section.setter
    def section(self, value: typing.Optional["Section"]) -> None:
        if self._section is not None:
            self._section.byte_intervals.discard(self)
        if value is not None:
            value.byte_intervals.add(self)

    @property
    def symbolic_expressions(
        self,
    ) -> typing.MutableMapping[int, SymbolicExpression]:
        return self._symbolic_expressions

    @symbolic_expressions.setter
    def symbolic_expressions(
        self, value: typing.Dict[int, SymbolicExpression]
    ) -> None:
        self._symbolic_expressions.clear()
        self._symbolic_expressions.update(value)
    def deep_eq(self, other: object) -> bool:
        # Do not move __eq__. See docstring for Node.deep_eq for more info.
        if not isinstance(other, ByteInterval):
            return False
        return (
            self.uuid == other.uuid
            and self.address == other.address
            and self.contents == other.contents
            and self.size == other.size
            and len(self.blocks) == len(other.blocks)
            and all(
                self_node.deep_eq(other_node)
                for self_node, other_node in zip(
                    sorted(self.blocks, key=lambda b: b.uuid),
                    sorted(other.blocks, key=lambda b: b.uuid),
                )
            )
            and len(self.symbolic_expressions)
            == len(other.symbolic_expressions)
            and all(
                self_kv[0] == other_kv[0] and self_kv[1].deep_eq(other_kv[1])
                for self_kv, other_kv in zip(
                    sorted(
                        self.symbolic_expressions.items(),
                        key=lambda kv: kv[0],
                    ),
                    sorted(
                        other.symbolic_expressions.items(),
                        key=lambda kv: kv[0],
                    ),
                )
            )
        )
    def __repr__(self) -> str:
        return (
            "ByteInterval("
            "uuid={uuid!r}, "
            "address={address}, "
            "size={size}, "
            "contents={contents!r}, "
            "blocks={blocks!r}, "
            "symbolic_expressions={symbolic_expressions!r}, "
            ")".format(
                uuid=self.uuid,
                address=self.address,
                size=self.size,
                contents=self.contents,
                blocks=self.blocks,
                symbolic_expressions=self.symbolic_expressions,
            )
        )
    def byte_blocks_on(
        self, addrs: typing.Union[int, range]
    ) -> typing.Iterable[ByteBlock]:
        """Finds all the byte blocks that overlap an address or range of
        addresses.

        :param addrs: Either a ``range`` object or a single address.
        """

        if self.address is None:
            return ()

        return _nodes_on_interval_tree(
            self._interval_tree, addrs, -self.address
        )
    def byte_blocks_at(
        self, addrs: typing.Union[int, range]
    ) -> typing.Iterable[ByteBlock]:
        """Finds all the byte blocks that begin at an address or range of
        addresses.

        :param addrs: Either a ``range`` object or a single address.
        """

        if self.address is None:
            return ()

        return _nodes_at_interval_tree(
            self._interval_tree, addrs, -self.address
        )
    def code_blocks_on(
        self, addrs: typing.Union[int, range]
    ) -> typing.Iterable[CodeBlock]:
        """Finds all the code blocks that overlap an address or range of
        addresses.

        :param addrs: Either a ``range`` object or a single address.
        """

        return (
            b for b in self.byte_blocks_on(addrs) if isinstance(b, CodeBlock)
        )
    def code_blocks_at(
        self, addrs: typing.Union[int, range]
    ) -> typing.Iterable[CodeBlock]:
        """Finds all the code blocks that begin at an address or range of
        addresses.

        :param addrs: Either a ``range`` object or a single address.
        """

        return (
            b for b in self.byte_blocks_at(addrs) if isinstance(b, CodeBlock)
        )
    def data_blocks_on(
        self, addrs: typing.Union[int, range]
    ) -> typing.Iterable[DataBlock]:
        """Finds all the data blocks that overlap an address or range of
        addresses.

        :param addrs: Either a ``range`` object or a single address.
        """

        return (
            b for b in self.byte_blocks_on(addrs) if isinstance(b, DataBlock)
        )
    def data_blocks_at(
        self, addrs: typing.Union[int, range]
    ) -> typing.Iterable[DataBlock]:
        """Finds all the data blocks that begin at an address or range of
        addresses.

        :param addrs: Either a ``range`` object or a single address.
        """

        return (
            b for b in self.byte_blocks_at(addrs) if isinstance(b, DataBlock)
        )
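    # Sketch of the *_on vs. *_at distinction (addresses are assumptions):
    # given a block at offset 0 of size 4 in an interval at address 0x1000,
    # byte_blocks_on(0x1002) yields the block because it overlaps that
    # address, while byte_blocks_at(0x1002) yields nothing because the block
    # begins at 0x1000. Both return an empty iterable if the interval has no
    # address.
    #
    #   blk = CodeBlock(size=4)
    #   bi = ByteInterval(address=0x1000, size=4, blocks=[blk])
    #   assert list(bi.byte_blocks_on(0x1002)) == [blk]
    #   assert list(bi.byte_blocks_at(0x1002)) == []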
    def symbolic_expressions_at(
        self, addrs: typing.Union[int, range]
    ) -> typing.Iterable[SymbolicExpressionElement]:
        """Finds all the symbolic expressions that begin at an address or
        range of addresses.

        :param addrs: Either a ``range`` object or a single address.
        :returns: Yields ``(interval, offset, symexpr)`` tuples for every
            symbolic expression in the range.
        """

        if self.address is None:
            return

        addrs = get_desired_range(addrs)
        for i in self._symbolic_expressions._data.irange(
            addrs.start - self.address,
            addrs.stop - self.address,
            inclusive=(True, False),
        ):
            if self.address + i in addrs:
                yield (self, i, self.symbolic_expressions[i])
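    # Sketch of symbolic_expressions_at (the symbol and addresses below are
    # made up): expressions are keyed by interval offset but queried by
    # address, and each hit is reported as an (interval, offset, expression)
    # triple.
    #
    #   bi = ByteInterval(address=0x2000, size=8)
    #   bi.symbolic_expressions[4] = SymAddrConst(0, some_symbol)
    #   assert list(bi.symbolic_expressions_at(range(0x2000, 0x2008))) == [
    #       (bi, 4, bi.symbolic_expressions[4])
    #   ]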
    def _add_to_uuid_cache(self, cache: typing.Dict[UUID, Node]) -> None:
        """Update the UUID cache when this node is added."""

        cache[self.uuid] = self
        for block in self.blocks:
            block._add_to_uuid_cache(cache)

    def _remove_from_uuid_cache(self, cache: typing.Dict[UUID, Node]) -> None:
        """Update the UUID cache when this node is removed."""

        del cache[self.uuid]
        for block in self.blocks:
            block._remove_from_uuid_cache(cache)

    @property
    def module(self) -> typing.Optional["Module"]:
        """Get the module this node ultimately belongs to."""

        if self.section is None:
            return None
        return self.section.module

    @property
    def ir(self) -> typing.Optional["IR"]:
        """Get the IR this node ultimately belongs to."""

        if self.module is None:
            return None
        return self.module.ir
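# Sketch of parent navigation (names are assumptions, not gtirb test code):
# a ByteInterval only reports a module/IR once it is attached to a Section
# that is itself attached to a Module inside an IR; detached intervals report
# None for all three.
#
#   bi = ByteInterval(size=4)
#   assert bi.section is None and bi.module is None and bi.ir is None
#   section.byte_intervals.add(bi)  # some existing, attached Section
#   assert bi.section is section and bi.ir is section.ir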