use alloc::vec::Vec;
use core::cell::{Cell, RefCell};
use core::slice;
use core::u64;
use crate::private::units::*;
use crate::message;
use crate::message::{Allocator, ReaderSegments};
use crate::{Error, OutputSegments, Result};
/// Index of a segment within an arena's segment list (used directly to index
/// `segments` below, and returned as the first element of `allocate_anywhere`).
pub type SegmentId = u32;
/// Budget of words that may still be read while traversing a message.
///
/// The budget lives in a `Cell` so that `can_read` can decrement it through a
/// shared `&self` reference.
pub struct ReadLimiter {
// Remaining number of words that may be read before `can_read` starts failing.
pub limit: Cell<u64>,
}
impl ReadLimiter {
pub fn new(limit: u64) -> ReadLimiter {
ReadLimiter { limit: Cell::new(limit) }
}
#[inline]
pub fn can_read(&self, amount: u64) -> Result<()> {
let current = self.limit.get();
if amount > current {
Err(Error::failed(format!("read limit exceeded")))
} else {
self.limit.set(current - amount);
Ok(())
}
}
}
/// Read access to a message's segments, used when following pointers.
///
/// `ReaderArenaImpl` validates pointers against real segment bounds.
/// `BuilderArenaImpl` and `NullArena` perform no validation in their
/// `check_offset` / `contains_interval`.
pub trait ReaderArena {
/// Returns the start pointer of segment `id` and its length in words.
fn get_segment(&self, id: u32) -> Result<(*const u8, u32)>;
/// Returns `start` moved by `offset_in_words` words; implementations may reject
/// a result that falls outside segment `segment_id`.
fn check_offset(&self, segment_id: u32, start: *const u8, offset_in_words: i32) -> Result<*const u8>;
/// Checks that `size` words starting at `start` lie inside segment `segment_id`.
/// `ReaderArenaImpl` also charges `size` against its read limiter.
fn contains_interval(&self, segment_id: u32, start: *const u8, size: usize) -> Result<()>;
/// Charges `virtual_amount` words against the read budget, if the arena keeps one.
fn amplified_read(&self, virtual_amount: u64) -> Result<()>;
}
/// Read-only arena over caller-supplied segments `S`.
///
/// Every `contains_interval` and `amplified_read` call is charged against
/// `read_limiter`.
pub struct ReaderArenaImpl<S> {
// Source of the raw segment bytes.
segments: S,
// Remaining traversal budget, initialized from `ReaderOptions::traversal_limit_in_words`.
read_limiter: ReadLimiter,
}
impl <S> ReaderArenaImpl <S> where S: ReaderSegments {
    /// Wraps `segments` for reading.
    ///
    /// The traversal budget starts at `options.traversal_limit_in_words`.
    pub fn new(segments: S, options: message::ReaderOptions) -> Self {
        ReaderArenaImpl {
            read_limiter: ReadLimiter::new(options.traversal_limit_in_words),
            segments,
        }
    }

    /// Consumes the arena and gives back the underlying segments.
    pub fn into_segments(self) -> S {
        let ReaderArenaImpl { segments, .. } = self;
        segments
    }
}
impl <S> ReaderArena for ReaderArenaImpl<S> where S: ReaderSegments {
fn get_segment<'a>(&'a self, id: u32) -> Result<(*const u8, u32)> {
match self.segments.get_segment(id) {
Some(seg) => {
#[cfg(not(feature = "unaligned"))]
{
if seg.as_ptr() as usize % BYTES_PER_WORD != 0 {
return Err(Error::failed(
format!("Detected unaligned segment. You must either ensure all of your \
segments are 8-byte aligned, or you must enable the \"unaligned\" \
feature in the capnp crate")))
}
}
Ok((seg.as_ptr(), (seg.len() / BYTES_PER_WORD) as u32))
}
None => Err(Error::failed(format!("Invalid segment id: {}", id))),
}
}
fn check_offset(&self, segment_id: u32, start: *const u8, offset_in_words: i32) -> Result<*const u8> {
let (segment_start, segment_len) = self.get_segment(segment_id)?;
let this_start: usize = segment_start as usize;
let this_size: usize = segment_len as usize * BYTES_PER_WORD;
let offset: i64 = offset_in_words as i64 * BYTES_PER_WORD as i64;
let start_idx = start as usize;
if start_idx < this_start || ((start_idx - this_start) as i64 + offset) as usize > this_size {
Err(Error::failed(format!("message contained out-of-bounds pointer")))
} else {
unsafe { Ok(start.offset(offset as isize)) }
}
}
fn contains_interval(&self, id: u32, start: *const u8, size_in_words: usize) -> Result<()> {
let (segment_start, segment_len) = self.get_segment(id)?;
let this_start: usize = segment_start as usize;
let this_size: usize = segment_len as usize * BYTES_PER_WORD;
let start = start as usize;
let size = size_in_words * BYTES_PER_WORD;
if !(start >= this_start && start - this_start + size <= this_size) {
Err(Error::failed(format!("message contained out-of-bounds pointer")))
} else {
self.read_limiter.can_read(size_in_words as u64)
}
}
fn amplified_read(&self, virtual_amount: u64) -> Result<()> {
self.read_limiter.can_read(virtual_amount)
}
}
/// Arena that can hand out space for building a message, in addition to the
/// read access provided by `ReaderArena`.
pub trait BuilderArena: ReaderArena {
/// Tries to reserve `amount` words at the end of segment `segment_id`; returns
/// the word offset of the reservation, or `None` if it does not fit.
fn allocate(&self, segment_id: u32, amount: WordCount32) -> Option<u32>;
/// Reserves `amount` words in any segment (possibly a new one) and returns
/// `(segment_id, word_offset)`.
fn allocate_anywhere(&self, amount: u32) -> (SegmentId, u32);
/// Returns the writable start pointer of segment `id` and its capacity in words.
fn get_segment_mut(&self, id: u32) -> (*mut u8, u32);
/// Views this arena as a `ReaderArena` trait object.
fn as_reader<'a>(&'a self) -> &'a dyn ReaderArena;
}
/// One segment owned by a builder arena.
struct BuilderSegment {
// Start of the segment's memory, as returned by the allocator.
ptr: *mut u8,
// Total size of the segment, in words.
capacity: u32,
// Number of words already handed out from the front of the segment.
allocated: u32,
}
/// Mutable state of a `BuilderArenaImpl`: the allocator and its segments.
pub struct BuilderArenaImplInner<A> where A: Allocator {
// Wrapped in `Option` so `BuilderArenaImpl::into_allocator` can move it out;
// once taken, `deallocate_all` no longer frees anything.
allocator: Option<A>,
// Segments in allocation order; a segment's index is its `SegmentId`.
segments: Vec<BuilderSegment>,
}
/// Growable arena used while building a message, backed by allocator `A`.
///
/// The state sits in a `RefCell` because the `BuilderArena` methods take
/// `&self` but need to mutate the segment list.
pub struct BuilderArenaImpl<A> where A: Allocator {
inner: RefCell<BuilderArenaImplInner<A>>
}
impl <A> BuilderArenaImpl<A> where A: Allocator {
pub fn new(allocator: A) -> Self {
BuilderArenaImpl {
inner: RefCell::new(BuilderArenaImplInner {
allocator: Some(allocator),
segments: Vec::new(),
}),
}
}
pub fn allocate_segment(&self, minimum_size: u32) -> Result<()> {
self.inner.borrow_mut().allocate_segment(minimum_size)
}
pub fn get_segments_for_output<'a>(&'a self) -> OutputSegments<'a> {
let reff = self.inner.borrow();
if reff.segments.len() == 1 {
let seg = &reff.segments[0];
let slice = unsafe { slice::from_raw_parts(seg.ptr as *const _, seg.allocated as usize * BYTES_PER_WORD) };
OutputSegments::SingleSegment([slice])
} else {
let mut v = Vec::with_capacity(reff.segments.len());
for ref seg in &reff.segments {
let slice = unsafe { slice::from_raw_parts(seg.ptr as *const _, seg.allocated as usize * BYTES_PER_WORD) };
v.push(slice);
}
OutputSegments::MultiSegment(v)
}
}
pub fn len(&self) -> usize {
self.inner.borrow().segments.len()
}
pub fn into_allocator(self) -> A {
let mut inner = self.inner.into_inner();
inner.deallocate_all();
inner.allocator.take().unwrap()
}
}
impl <A> ReaderArena for BuilderArenaImpl<A> where A: Allocator {
fn get_segment(&self, id: u32) -> Result<(*const u8, u32)> {
let borrow = self.inner.borrow();
let seg = &borrow.segments[id as usize];
Ok((seg.ptr, seg.allocated))
}
fn check_offset(&self, _segment_id: u32, start: *const u8, offset_in_words: i32) -> Result<*const u8> {
unsafe { Ok(start.offset((offset_in_words as i64 * BYTES_PER_WORD as i64) as isize)) }
}
fn contains_interval(&self, _id: u32, _start: *const u8, _size: usize) -> Result<()> {
Ok(())
}
fn amplified_read(&self, _virtual_amount: u64) -> Result<()> {
Ok(())
}
}
impl <A> BuilderArenaImplInner<A> where A: Allocator {
    /// Obtains a new segment of at least `minimum_size` words from the
    /// allocator and appends it with nothing allocated yet.
    ///
    /// # Panics
    /// If the allocator has already been taken out of `self.allocator`.
    fn allocate_segment(&mut self, minimum_size: WordCount32) -> Result<()> {
        let seg = match self.allocator {
            Some(ref mut a) => a.allocate_segment(minimum_size),
            None => unreachable!(),
        };
        self.segments.push(BuilderSegment { ptr: seg.0, capacity: seg.1, allocated: 0 });
        Ok(())
    }

    /// Reserves `amount` words at the end of segment `segment_id`.
    ///
    /// Returns the word offset of the reservation, or `None` if the segment's
    /// remaining capacity is too small. Panics if `segment_id` is out of range.
    fn allocate(&mut self, segment_id: u32, amount: WordCount32) -> Option<u32> {
        let seg = &mut self.segments[segment_id as usize];
        if amount > seg.capacity - seg.allocated {
            None
        } else {
            let result = seg.allocated;
            seg.allocated += amount;
            Some(result)
        }
    }

    /// Reserves `amount` words in the first existing segment with room. If none
    /// has room, a new segment is allocated and used instead.
    ///
    /// Returns `(segment_id, word_offset)`. Panics if the freshly allocated
    /// segment still cannot hold `amount` words.
    fn allocate_anywhere(&mut self, amount: u32) -> (SegmentId, u32) {
        let allocated_len = self.segments.len() as u32;
        for segment_id in 0..allocated_len {
            if let Some(idx) = self.allocate(segment_id, amount) {
                return (segment_id, idx);
            }
        }
        self.allocate_segment(amount).expect("allocate new segment");
        (allocated_len,
         self.allocate(allocated_len, amount).expect("use freshly-allocated segment"))
    }

    /// Returns every segment to the allocator.
    ///
    /// Segments are drained from `self.segments` as they are released. A later
    /// call (for example from `Drop`) therefore cannot free the same memory
    /// twice. Does nothing if the allocator has been taken.
    fn deallocate_all(&mut self) {
        if let Some(ref mut a) = self.allocator {
            for seg in self.segments.drain(..) {
                a.deallocate_segment(seg.ptr, seg.capacity, seg.allocated);
            }
        }
    }

    /// Returns the start pointer and total capacity (in words) of segment `id`.
    /// Panics if `id` is out of range.
    fn get_segment_mut(&mut self, id: u32) -> (*mut u8, u32) {
        let seg = &self.segments[id as usize];
        (seg.ptr, seg.capacity)
    }
}
impl <A> BuilderArena for BuilderArenaImpl<A> where A: Allocator {
    /// Forwards to the inner state, reserving `amount` words in `segment_id`.
    fn allocate(&self, segment_id: u32, amount: WordCount32) -> Option<u32> {
        let mut inner = self.inner.borrow_mut();
        inner.allocate(segment_id, amount)
    }

    /// Forwards to the inner state, reserving `amount` words wherever they fit.
    fn allocate_anywhere(&self, amount: u32) -> (SegmentId, u32) {
        let mut inner = self.inner.borrow_mut();
        inner.allocate_anywhere(amount)
    }

    /// Forwards to the inner state for the writable pointer and capacity of `id`.
    fn get_segment_mut(&self, id: u32) -> (*mut u8, u32) {
        let mut inner = self.inner.borrow_mut();
        inner.get_segment_mut(id)
    }

    /// Upcasts to the read-only arena interface.
    fn as_reader<'a>(&'a self) -> &'a dyn ReaderArena {
        self as &dyn ReaderArena
    }
}
impl <A> Drop for BuilderArenaImplInner<A> where A: Allocator {
    /// Hands any remaining segments back to the allocator, if it is still present.
    fn drop(&mut self) {
        Self::deallocate_all(self);
    }
}
/// Arena with no segments: segment lookups fail, `allocate` returns `None`, and
/// `allocate_anywhere` panics. Pointer offsets are applied without bounds checks.
pub struct NullArena;
impl ReaderArena for NullArena {
    /// A null arena has no segments, so every lookup fails.
    fn get_segment(&self, _id: u32) -> Result<(*const u8, u32)> {
        Err(Error::failed(format!("tried to read from null arena")))
    }

    /// Moves `start` by `offset_in_words` words without any bounds checking.
    fn check_offset(&self, _segment_id: u32, start: *const u8, offset_in_words: i32) -> Result<*const u8> {
        // Scale in a signed type so negative offsets work. Casting the i32 to
        // usize first would turn a negative offset into a huge value whose
        // multiplication by BYTES_PER_WORD overflows (a panic in debug builds).
        let offset_in_bytes = offset_in_words as i64 * BYTES_PER_WORD as i64;
        unsafe { Ok(start.offset(offset_in_bytes as isize)) }
    }

    /// Always succeeds: there are no segments to check against.
    fn contains_interval(&self, _id: u32, _start: *const u8, _size: usize) -> Result<()> {
        Ok(())
    }

    /// Always succeeds: a null arena keeps no read budget.
    fn amplified_read(&self, _virtual_amount: u64) -> Result<()> {
        Ok(())
    }
}
impl BuilderArena for NullArena {
    /// There are no segments to allocate from, so this never succeeds.
    fn allocate(&self, _segment_id: u32, _amount: WordCount32) -> Option<u32> {
        Option::None
    }

    /// Allocation from a null arena is a logic error and panics.
    fn allocate_anywhere(&self, _amount: u32) -> (SegmentId, u32) {
        panic!("tried to allocate from a null arena")
    }

    /// Reports an empty segment: a null pointer with zero capacity.
    fn get_segment_mut(&self, _id: u32) -> (*mut u8, u32) {
        let no_memory: *mut u8 = core::ptr::null_mut();
        let no_capacity: u32 = 0;
        (no_memory, no_capacity)
    }

    /// Upcasts to the read-only arena interface.
    fn as_reader<'a>(&'a self) -> &'a dyn ReaderArena {
        self as &dyn ReaderArena
    }
}