use alloc::string::ToString;
use alloc::vec::Vec;
use core::convert::TryInto;
use crate::io::{Read, Write};
use crate::message;
use crate::private::units::BYTES_PER_WORD;
use crate::{Error, Result};
/// Message segments that borrow their bytes from a caller-supplied slice
/// instead of owning a copy.
pub struct SliceSegments<'a> {
/// Borrowed bytes that every segment is sliced out of.
words: &'a [u8],
/// One `(start, end)` pair per segment. `get_segment` multiplies both values by
/// `BYTES_PER_WORD` before indexing `words`, so they are offsets counted in words.
segment_indices : Vec<(usize, usize)>,
}
impl <'a> message::ReaderSegments for SliceSegments<'a> {
    /// Returns the bytes of segment `id`, or `None` when `id` is not a valid
    /// segment index.
    fn get_segment<'b>(&'b self, id: u32) -> Option<&'b [u8]> {
        if id >= self.segment_indices.len() as u32 {
            return None;
        }
        // Indices are stored in words; scale them to byte offsets into the borrowed slice.
        let (first_word, end_word) = self.segment_indices[id as usize];
        let byte_range = (first_word * BYTES_PER_WORD)..(end_word * BYTES_PER_WORD);
        Some(&self.words[byte_range])
    }

    /// Number of segments recorded for this message.
    fn len(&self) -> usize {
        self.segment_indices.len()
    }
}
pub fn read_message_from_flat_slice<'a>(slice: &mut &'a [u8],
options: message::ReaderOptions)
-> Result<message::Reader<SliceSegments<'a>>> {
let all_bytes = *slice;
let mut bytes = *slice;
let orig_bytes_len = bytes.len();
let segment_lengths_builder = match read_segment_table(&mut bytes, options)? {
Some(b) => b,
None => return Err(Error::failed("empty slice".to_string())),
};
let segment_table_bytes_len = orig_bytes_len - bytes.len();
assert_eq!(segment_table_bytes_len % BYTES_PER_WORD, 0);
let num_words = segment_lengths_builder.total_words();
let body_bytes = &all_bytes[segment_table_bytes_len..];
if num_words > (body_bytes.len() / BYTES_PER_WORD) {
Err(Error::failed(
format!("Message ends prematurely. Header claimed {} words, but message only has {} words.",
num_words, body_bytes.len() / BYTES_PER_WORD)))
} else {
*slice = &body_bytes[(num_words * BYTES_PER_WORD)..];
Ok(message::Reader::new(segment_lengths_builder.into_slice_segments(body_bytes), options))
}
}
/// Message segments stored in a single buffer owned by this value.
///
/// The `Deref`/`DerefMut` impls expose that buffer as a byte slice.
pub struct OwnedSegments {
/// One `(start, end)` pair per segment. `get_segment` multiplies both values by
/// `BYTES_PER_WORD` before slicing, so they are offsets counted in words.
segment_indices : Vec<(usize, usize)>,
/// Backing storage for all segments, held as words.
owned_space: Vec<crate::Word>,
}
impl core::ops::Deref for OwnedSegments {
    type Target = [u8];

    /// Exposes the whole owned word buffer as a byte slice.
    fn deref(&self) -> &[u8] {
        let words = self.owned_space.as_slice();
        crate::Word::words_to_bytes(words)
    }
}
impl core::ops::DerefMut for OwnedSegments {
    /// Exposes the whole owned word buffer as a mutable byte slice.
    fn deref_mut(&mut self) -> &mut [u8] {
        let words = self.owned_space.as_mut_slice();
        crate::Word::words_to_bytes_mut(words)
    }
}
impl crate::message::ReaderSegments for OwnedSegments {
    /// Returns the bytes of segment `id`, or `None` when `id` is not a valid
    /// segment index.
    fn get_segment<'a>(&'a self, id: u32) -> Option<&'a [u8]> {
        if id >= self.segment_indices.len() as u32 {
            return None;
        }
        // Go through `Deref` to view the owned buffer as bytes, then cut out the
        // segment using its word offsets scaled to bytes.
        let all_bytes: &'a [u8] = self;
        let (first_word, end_word) = self.segment_indices[id as usize];
        Some(&all_bytes[(first_word * BYTES_PER_WORD)..(end_word * BYTES_PER_WORD)])
    }

    /// Number of segments recorded for this message.
    fn len(&self) -> usize {
        self.segment_indices.len()
    }
}
/// Accumulates segment lengths and turns them into the `(start, end)` ranges used by
/// `OwnedSegments` and `SliceSegments`.
pub struct SegmentLengthsBuilder {
/// `(start, end)` range of every segment pushed so far, counted in words.
segment_indices: Vec<(usize, usize)>,
/// Sum of all pushed segment lengths, in words.
total_words: usize,
}
impl SegmentLengthsBuilder {
pub fn with_capacity(capacity: usize) -> Self {
Self {
segment_indices: Vec::with_capacity(capacity),
total_words: 0,
}
}
pub fn push_segment(&mut self, length_in_words: usize) {
self.segment_indices.push((self.total_words, self.total_words + length_in_words));
self.total_words += length_in_words;
}
pub fn into_owned_segments(self) -> OwnedSegments {
let owned_space = crate::Word::allocate_zeroed_vec(self.total_words);
OwnedSegments {
segment_indices: self.segment_indices,
owned_space,
}
}
pub fn into_slice_segments(self, slice: &[u8]) -> SliceSegments {
assert!(self.total_words * BYTES_PER_WORD <= slice.len());
SliceSegments {
words: slice,
segment_indices: self.segment_indices,
}
}
pub fn total_words(&self) -> usize {
self.total_words
}
pub fn to_segment_indices(self) -> Vec<(usize, usize)> {
self.segment_indices
}
}
pub fn read_message<R>(mut read: R, options: message::ReaderOptions) -> Result<message::Reader<OwnedSegments>>
where R: Read {
let owned_segments_builder = match read_segment_table(&mut read, options)? {
Some(b) => b,
None => return Err(Error::failed("Premature end of file".to_string())),
};
read_segments(&mut read, owned_segments_builder.into_owned_segments(), options)
}
/// Like `read_message`, but returns `Ok(None)` instead of an error when the segment
/// table reader reports that no input was available.
///
/// # Errors
/// Propagates any error from decoding the segment table or reading the segment bytes.
pub fn try_read_message<R>(mut read: R, options: message::ReaderOptions) -> Result<Option<message::Reader<OwnedSegments>>>
where R: Read {
    match read_segment_table(&mut read, options)? {
        None => Ok(None),
        Some(lengths) => {
            let message = read_segments(&mut read, lengths.into_owned_segments(), options)?;
            Ok(Some(message))
        }
    }
}
fn read_segment_table<R>(read: &mut R,
options: message::ReaderOptions)
-> Result<Option<SegmentLengthsBuilder>>
where R: Read
{
let mut buf: [u8; 8] = [0; 8];
{
let n = read.read(&mut buf[..])?;
if n == 0 {
return Ok(None)
} else if n < 8 {
read.read_exact(&mut buf[n..])?;
}
}
let segment_count = u32::from_le_bytes(buf[0..4].try_into().unwrap()).wrapping_add(1) as usize;
if segment_count >= 512 {
return Err(Error::failed(format!("Too many segments: {}", segment_count)))
} else if segment_count == 0 {
return Err(Error::failed(format!("Too few segments: {}", segment_count)))
}
let mut segment_lengths_builder = SegmentLengthsBuilder::with_capacity(segment_count);
segment_lengths_builder.push_segment(u32::from_le_bytes(buf[4..8].try_into().unwrap()) as usize);
if segment_count > 1 {
if segment_count < 4 {
read.read_exact(&mut buf)?;
for idx in 0..(segment_count - 1) {
let segment_len =
u32::from_le_bytes(buf[(idx * 4)..(idx + 1) * 4].try_into().unwrap()) as usize;
segment_lengths_builder.push_segment(segment_len);
}
} else {
let mut segment_sizes = vec![0u8; (segment_count & !1) * 4];
read.read_exact(&mut segment_sizes[..])?;
for idx in 0..(segment_count - 1) {
let segment_len =
u32::from_le_bytes(segment_sizes[(idx * 4)..(idx + 1) * 4].try_into().unwrap()) as usize;
segment_lengths_builder.push_segment(segment_len);
}
}
}
if segment_lengths_builder.total_words() as u64 > options.traversal_limit_in_words {
return Err(Error::failed(
format!("Message has {} words, which is too large. To increase the limit on the \
receiving end, see capnp::message::ReaderOptions.", segment_lengths_builder.total_words())))
}
Ok(Some(segment_lengths_builder))
}
/// Fills the pre-sized buffer of `owned_segments` with bytes from `read` and wraps the
/// result in a `message::Reader`.
///
/// # Errors
/// Propagates the error from `read_exact` when the buffer cannot be filled completely.
fn read_segments<R>(read: &mut R,
                    mut owned_segments: OwnedSegments,
                    options: message::ReaderOptions)
                    -> Result<message::Reader<OwnedSegments>>
where R: Read {
    {
        // `DerefMut` exposes the whole owned buffer as one mutable byte slice.
        let body: &mut [u8] = &mut owned_segments;
        read.read_exact(body)?;
    }
    Ok(crate::message::Reader::new(owned_segments, options))
}
/// Serializes `message` into one byte vector: the segment table followed by the bytes
/// of every segment.
pub fn write_message_to_words<A>(message: &message::Builder<A>) -> Vec<u8>
    where A: message::Allocator
{
    let output_segments = message.get_segments_for_output();
    flatten_segments(&*output_segments)
}
/// Serializes the segments in `message` into one byte vector: the segment table
/// followed by the bytes of every segment.
pub fn write_message_segments_to_words<R>(message: &R) -> Vec<u8>
    where R: message::ReaderSegments
{
    let flattened = flatten_segments(message);
    flattened
}
/// Serializes `segments` into one contiguous byte vector: the segment table first,
/// then the bytes of every segment in index order.
///
/// # Panics
/// Panics if `get_segment` returns `None` for an index below `segments.len()`, or if
/// writing the segment table fails.
fn flatten_segments<R: message::ReaderSegments + ?Sized>(segments: &R) -> Vec<u8> {
    let word_count = compute_serialized_size(segments);
    let segment_count = segments.len();
    let table_size = segment_count / 2 + 1;

    // `word_count` is measured in words but `result` stores bytes, so the reservation
    // must be scaled; reserving only `word_count` bytes forces the vector to regrow
    // while the segments are appended.
    let mut result: Vec<u8> = Vec::with_capacity(word_count * BYTES_PER_WORD);

    // Zero-filled room for the segment table, which is written in place just below.
    result.resize(table_size * BYTES_PER_WORD, 0);
    {
        let mut bytes = &mut result[..];
        write_segment_table_internal(&mut bytes, segments).expect("Failed to write segment table.");
    }

    for i in 0..segment_count {
        let segment = segments.get_segment(i as u32).unwrap();
        result.extend_from_slice(segment);
    }
    result
}
/// Writes `message` to `write`: first the segment table, then the bytes of every segment.
///
/// # Errors
/// Propagates the first error returned while writing.
pub fn write_message<W, A>(mut write: W, message: &message::Builder<A>) -> Result<()>
where W: Write, A: message::Allocator {
    let output_segments = message.get_segments_for_output();
    write_segment_table(&mut write, &output_segments)?;
    write_segments(&mut write, &output_segments)
}
/// Writes `segments` to `write`: first the segment table, then the bytes of every segment.
///
/// # Errors
/// Propagates the first error returned while writing.
pub fn write_message_segments<W, R>(mut write: W, segments: &R) -> Result<()>
where W: Write, R: message::ReaderSegments {
    let sink = &mut write;
    write_segment_table_internal(sink, segments)?;
    write_segments(sink, segments)
}
/// Writes the segment table for a plain slice of byte segments.
///
/// Thin wrapper that forwards to `write_segment_table_internal`.
fn write_segment_table<W>(write: &mut W, segments: &[&[u8]]) -> Result<()>
where W: Write {
    let table_written = write_segment_table_internal(write, segments);
    table_written
}
/// Writes the segment table describing `segments` to `write`.
///
/// All fields are little-endian `u32`s: the segment count minus one, then each
/// segment's length in words. Output is produced in 8-byte units, with a zero pad
/// after the last length when the fields alone would not fill the final unit.
///
/// # Errors
/// Propagates any error from `write_all`.
///
/// # Panics
/// Panics if `get_segment` returns `None` for an index below `segments.len()`
/// (index 0 is always queried).
fn write_segment_table_internal<W, R>(write: &mut W, segments: &R) -> Result<()>
where W: Write, R: message::ReaderSegments + ?Sized {
    let segment_count = segments.len();

    // Encodes the length of segment `index`, in words, as it appears on the wire.
    let encoded_length = |index: usize| -> [u8; 4] {
        let segment = segments.get_segment(index as u32).unwrap();
        ((segment.len() / BYTES_PER_WORD) as u32).to_le_bytes()
    };

    // First word: segment count minus one, then the length of segment 0.
    let mut first_word = [0u8; 8];
    first_word[..4].copy_from_slice(&(segment_count as u32 - 1).to_le_bytes());
    first_word[4..].copy_from_slice(&encoded_length(0));
    write.write_all(&first_word)?;

    if segment_count <= 1 {
        return Ok(());
    }

    if segment_count < 4 {
        // Two or three segments need exactly one more word. The buffer starts out
        // zeroed, so with two segments the unused upper half is the pad.
        let mut rest = [0u8; 8];
        for index in 1..segment_count {
            rest[(index - 1) * 4..index * 4].copy_from_slice(&encoded_length(index));
        }
        write.write_all(&rest)?;
    } else {
        // `segment_count & !1` slots of four bytes hold the remaining lengths plus a
        // pad slot when one is needed; the pad keeps its initial zero value.
        let mut rest = vec![0u8; (segment_count & !1) * 4];
        for (slot, index) in rest.chunks_exact_mut(4).zip(1..segment_count) {
            slot.copy_from_slice(&encoded_length(index));
        }
        write.write_all(&rest)?;
    }
    Ok(())
}
/// Writes the bytes of every segment to `write` in index order, stopping at the first
/// index for which `get_segment` returns `None`.
///
/// # Errors
/// Propagates the first error from `write_all`.
fn write_segments<W, R: message::ReaderSegments + ?Sized>(write: &mut W, segments: &R) -> Result<()>
where W: Write {
    let mut index = 0;
    while let Some(segment) = segments.get_segment(index) {
        write.write_all(segment)?;
        index += 1;
    }
    Ok(())
}
/// Returns the serialized size of `segments` in words: the segment table
/// (`len / 2 + 1` words) plus the length of every segment.
///
/// # Panics
/// Panics if `get_segment` returns `None` for an index below `segments.len()`.
fn compute_serialized_size<R: message::ReaderSegments + ?Sized>(segments: &R) -> usize {
    let segment_count = segments.len();
    let table_words = segment_count / 2 + 1;
    let body_words: usize = (0..segment_count)
        .map(|index| segments.get_segment(index as u32).unwrap().len() / BYTES_PER_WORD)
        .sum();
    table_words + body_words
}
/// Returns the number of words `message` occupies once serialized, segment table included.
pub fn compute_serialized_size_in_words<A>(message: &crate::message::Builder<A>) -> usize
    where A: crate::message::Allocator
{
    let output_segments = message.get_segments_for_output();
    compute_serialized_size(&output_segments)
}
// Unit and property tests for segment-table encoding/decoding and for whole-message
// round trips through both the owned and the borrowed (flat slice) readers.
#[cfg(test)]
pub mod test {
use alloc::vec::Vec;
use crate::io::{Write, Read};
use quickcheck::{quickcheck, TestResult};
use crate::message;
use crate::message::ReaderSegments;
use super::{read_message, try_read_message, read_message_from_flat_slice, flatten_segments,
read_segment_table, write_segment_table, write_segments};
/// Test helper: writes the segment table and then the segment bytes for `segments`,
/// panicking on any write error.
pub fn write_message_segments<W>(write: &mut W, segments: &Vec<Vec<crate::Word>>) where W: Write {
// View each word vector as bytes so it can go through the byte-slice writers.
let borrowed_segments: &[&[u8]] = &segments.iter()
.map(|segment| crate::Word::words_to_bytes(&segment[..]))
.collect::<Vec<_>>()[..];
write_segment_table(write, borrowed_segments).unwrap();
write_segments(write, borrowed_segments).unwrap();
}
// Empty input must yield `Ok(None)` from `try_read_message`, not an error.
#[test]
fn try_read_empty() {
let mut buf: &[u8] = &[];
assert!(try_read_message(&mut buf, message::ReaderOptions::new()).unwrap().is_none());
}
// Decodes hand-written tables for 1, 2, 3 and 4 segments. Each fixture starts with
// the segment count minus one, followed by the per-segment lengths in words.
#[test]
fn test_read_segment_table() {
let mut buf = vec![];
// One segment of length 0.
buf.extend([0,0,0,0,
0,0,0,0]
.iter().cloned());
let segment_lengths_builder = read_segment_table(&mut &buf[..],
message::ReaderOptions::new()).unwrap().unwrap();
assert_eq!(0, segment_lengths_builder.total_words());
assert_eq!(vec![(0,0)], segment_lengths_builder.to_segment_indices());
buf.clear();
// One segment of length 1.
buf.extend([0,0,0,0,
1,0,0,0]
.iter().cloned());
let segment_lengths_builder = read_segment_table(&mut &buf[..],
message::ReaderOptions::new()).unwrap().unwrap();
assert_eq!(1, segment_lengths_builder.total_words());
assert_eq!(vec![(0,1)], segment_lengths_builder.to_segment_indices());
buf.clear();
// Two segments of length 1 each; the final four bytes are padding.
buf.extend([1,0,0,0,
1,0,0,0,
1,0,0,0,
0,0,0,0]
.iter().cloned());
let segment_lengths_builder = read_segment_table(&mut &buf[..],
message::ReaderOptions::new()).unwrap().unwrap();
assert_eq!(2, segment_lengths_builder.total_words());
assert_eq!(vec![(0,1), (1, 2)], segment_lengths_builder.to_segment_indices());
buf.clear();
// Three segments of lengths 1, 1 and 256 (bytes 0,1,0,0 little-endian); no padding.
buf.extend([2,0,0,0,
1,0,0,0,
1,0,0,0,
0,1,0,0]
.iter().cloned());
let segment_lengths_builder = read_segment_table(&mut &buf[..],
message::ReaderOptions::new()).unwrap().unwrap();
assert_eq!(258, segment_lengths_builder.total_words());
assert_eq!(vec![(0,1), (1, 2), (2, 258)], segment_lengths_builder.to_segment_indices());
buf.clear();
// Four segments of lengths 77, 23, 1 and 99, followed by four padding bytes.
buf.extend([3,0,0,0,
77,0,0,0,
23,0,0,0,
1,0,0,0,
99,0,0,0,
0,0,0,0]
.iter().cloned());
let segment_lengths_builder = read_segment_table(&mut &buf[..],
message::ReaderOptions::new()).unwrap().unwrap();
assert_eq!(200, segment_lengths_builder.total_words());
assert_eq!(vec![(0,77), (77, 100), (100, 101), (101, 200)], segment_lengths_builder.to_segment_indices());
buf.clear();
}
/// Reader wrapper that never asks `inner` for more than `max` bytes per `read` call,
/// used to force short reads.
struct MaxRead<R> where R: Read {
inner: R,
max: usize,
}
impl <R> Read for MaxRead<R> where R: Read {
fn read(&mut self, buf: &mut [u8]) -> crate::Result<usize> {
if buf.len() <= self.max {
self.inner.read(buf)
} else {
// Truncate the request so at most `max` bytes can come back.
self.inner.read(&mut buf[0..self.max])
}
}
}
// With reads capped at 2 bytes, the first `read` of the 8-byte header comes back
// short, so this covers the path that tops the header up with `read_exact`.
#[test]
fn test_read_segment_table_max_read() {
let mut buf: Vec<u8> = vec![];
buf.extend([0,0,0,0,
1,0,0,0]
.iter().cloned());
let segment_lengths_builder = read_segment_table(&mut MaxRead { inner: &buf[..], max: 2},
message::ReaderOptions::new()).unwrap().unwrap();
assert_eq!(1, segment_lengths_builder.total_words());
assert_eq!(vec![(0,1)], segment_lengths_builder.to_segment_indices());
}
// Malformed or truncated tables must be reported as errors.
#[test]
fn test_read_invalid_segment_table() {
let mut buf = vec![];
// Count field 0x200 decodes to 513 segments, above the 512 limit.
buf.extend([0,2,0,0].iter().cloned());
buf.extend([0; 513 * 4].iter().cloned());
assert!(read_segment_table(&mut &buf[..],
message::ReaderOptions::new()).is_err());
buf.clear();
// Only 4 of the 8 header bytes are present.
buf.extend([0,0,0,0].iter().cloned());
assert!(read_segment_table(&mut &buf[..],
message::ReaderOptions::new()).is_err());
buf.clear();
// Only 7 of the 8 header bytes are present.
buf.extend([0,0,0,0].iter().cloned());
buf.extend([0; 3].iter().cloned());
assert!(read_segment_table(&mut &buf[..],
message::ReaderOptions::new()).is_err());
buf.clear();
// Count field 0xFFFF_FFFF (would wrap to zero segments); again only 4 bytes are present.
buf.extend([255,255,255,255].iter().cloned());
assert!(read_segment_table(&mut &buf[..],
message::ReaderOptions::new()).is_err());
buf.clear();
}
// Checks the exact bytes emitted for tables of 1, 2, 4 and 5 segments, including
// the zero padding after an even number of segments.
#[test]
fn test_write_segment_table() {
let mut buf = vec![];
// Segments of 0, 1 and 199 words respectively (8 bytes per word).
let segment_0 = [0u8; 0];
let segment_1 = [1u8,1,1,1,1,1,1,1];
let segment_199 = [201u8; 199 * 8];
write_segment_table(&mut buf, &[&segment_0]).unwrap();
assert_eq!(&[0,0,0,0,
0,0,0,0],
&buf[..]);
buf.clear();
write_segment_table(&mut buf, &[&segment_1]).unwrap();
assert_eq!(&[0,0,0,0,
1,0,0,0],
&buf[..]);
buf.clear();
write_segment_table(&mut buf, &[&segment_199]).unwrap();
assert_eq!(&[0,0,0,0,
199,0,0,0],
&buf[..]);
buf.clear();
// Two segments: the last four bytes are padding.
write_segment_table(&mut buf, &[&segment_0, &segment_1]).unwrap();
assert_eq!(&[1,0,0,0,
0,0,0,0,
1,0,0,0,
0,0,0,0],
&buf[..]);
buf.clear();
// Four segments: the last segment has length 0 and is followed by four padding bytes.
write_segment_table(&mut buf,
&[&segment_199, &segment_1, &segment_199, &segment_0]).unwrap();
assert_eq!(&[3,0,0,0,
199,0,0,0,
1,0,0,0,
199,0,0,0,
0,0,0,0,
0,0,0,0],
&buf[..]);
buf.clear();
// Five segments: the table ends on a word boundary without padding.
write_segment_table(&mut buf,
&[&segment_199, &segment_1, &segment_199, &segment_0, &segment_1]).unwrap();
assert_eq!(&[4,0,0,0,
199,0,0,0,
1,0,0,0,
199,0,0,0,
0,0,0,0,
1,0,0,0],
&buf[..]);
buf.clear();
}
// Property test: writing arbitrary segments and reading them back through
// `read_message` must reproduce every segment. Ignored when running under miri.
#[test]
#[cfg_attr(miri, ignore)]
fn check_round_trip() {
fn round_trip(segments: Vec<Vec<crate::Word>>) -> TestResult {
// An input without any segments is discarded rather than tested.
if segments.len() == 0 { return TestResult::discard(); }
let mut buf: Vec<u8> = Vec::new();
write_message_segments(&mut buf, &segments);
let message = read_message(&mut &buf[..], message::ReaderOptions::new()).unwrap();
let result_segments = message.into_segments();
TestResult::from_bool(segments.iter().enumerate().all(|(i, segment)| {
crate::Word::words_to_bytes(&segment[..]) == result_segments.get_segment(i as u32).unwrap()
}))
}
quickcheck(round_trip as fn(Vec<Vec<crate::Word>>) -> TestResult);
}
// Property test: same round trip, but through `flatten_segments` and the borrowing
// reader `read_message_from_flat_slice`. Ignored when running under miri.
#[test]
#[cfg_attr(miri, ignore)]
fn check_round_trip_slice_segments() {
fn round_trip(segments: Vec<Vec<crate::Word>>) -> TestResult {
// An input without any segments is discarded rather than tested.
if segments.len() == 0 { return TestResult::discard(); }
let borrowed_segments: &[&[u8]] = &segments.iter()
.map(|segment| crate::Word::words_to_bytes(&segment[..]))
.collect::<Vec<_>>()[..];
let words = flatten_segments(borrowed_segments);
let mut word_slice = &words[..];
let message = read_message_from_flat_slice(&mut word_slice, message::ReaderOptions::new()).unwrap();
// The reader must have consumed the entire flattened message.
assert!(word_slice.is_empty());
let result_segments = message.into_segments();
TestResult::from_bool(segments.iter().enumerate().all(|(i, segment)| {
crate::Word::words_to_bytes(&segment[..]) == result_segments.get_segment(i as u32).unwrap()
}))
}
quickcheck(round_trip as fn(Vec<Vec<crate::Word>>) -> TestResult);
}
// Bytes that follow a complete message must be left in the slice for the caller.
#[test]
fn read_message_from_flat_slice_with_remainder() {
let segments = vec![vec![123,0,0,0,0,0,0,0],
vec![4,0,0,0,0,0,0,0,
5,0,0,0,0,0,0,0]];
let borrowed_segments: &[&[u8]] = &segments.iter()
.map(|segment| &segment[..])
.collect::<Vec<_>>()[..];
let mut bytes = flatten_segments(borrowed_segments);
// Append two extra words after the serialized message.
let extra_bytes: &[u8] = &[9,9,9,9,9,9,9,9,8,7,6,5,4,3,2,1];
for &b in extra_bytes { bytes.push(b); }
let mut byte_slice = &bytes[..];
let message = read_message_from_flat_slice(&mut byte_slice, message::ReaderOptions::new()).unwrap();
assert_eq!(byte_slice, extra_bytes);
let result_segments = message.into_segments();
for idx in 0..segments.len() {
assert_eq!(
segments[idx],
result_segments.get_segment(idx as u32).expect("segment should exist"));
}
}
// Every proper prefix of a serialized message, down to the empty slice, must be rejected.
#[test]
fn read_message_from_flat_slice_too_short() {
let segments = vec![vec![1,0,0,0,0,0,0,0],
vec![2,0,0,0,0,0,0,0,
3,0,0,0,0,0,0,0]];
let borrowed_segments: &[&[u8]] = &segments.iter()
.map(|segment| &segment[..])
.collect::<Vec<_>>()[..];
let mut bytes = flatten_segments(borrowed_segments);
// Drop one trailing byte at a time and require an error at each length.
while !bytes.is_empty() {
bytes.pop();
assert!(read_message_from_flat_slice(&mut &bytes[..], message::ReaderOptions::new()).is_err());
}
}
// Size of a message whose root points at a five-element `u64` list.
#[test]
fn compute_serialized_size() {
const LIST_LENGTH_IN_WORDS: u32 = 5;
let mut m = message::Builder::new_default();
{
let root: crate::any_pointer::Builder = m.init_root();
let _list_builder: crate::primitive_list::Builder<u64> = root.initn_as(LIST_LENGTH_IN_WORDS);
}
// NOTE(review): the two leading 1s are presumably one segment-table word and one
// root-pointer word, assuming the message fits in a single segment; confirm.
assert_eq!(super::compute_serialized_size_in_words(&m) as u32, 1 + 1 + LIST_LENGTH_IN_WORDS)
}
}