use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use nix::{
sys::{socket, uio},
Result as NixResult,
};
use crate::wire::{ArgumentType, Message, MessageParseError, MessageWriteError};
/// Maximum number of file descriptors handled per send/receive: sizes the outgoing fd buffer
/// and the ancillary-data space reserved by `Socket::rcv_msg` (the incoming fd buffer is twice this).
pub const MAX_FDS_OUT: usize = 28;
/// Capacity in bytes of the outgoing data buffer (the incoming data buffer is twice this).
pub const MAX_BYTES_OUT: usize = 4096;
/// A socket identified by a raw file descriptor, supporting non-blocking message sends and
/// receives with file-descriptor passing (`SCM_RIGHTS`).
///
/// The descriptor is owned: it is closed when the `Socket` is dropped.
pub struct Socket {
    // The owned socket descriptor.
    fd: RawFd,
}
impl Socket {
pub fn send_msg(&self, bytes: &[u8], fds: &[RawFd]) -> NixResult<()> {
let iov = [uio::IoVec::from_slice(bytes)];
if !fds.is_empty() {
let cmsgs = [socket::ControlMessage::ScmRights(fds)];
socket::sendmsg(self.fd, &iov, &cmsgs, socket::MsgFlags::MSG_DONTWAIT, None)?;
} else {
socket::sendmsg(self.fd, &iov, &[], socket::MsgFlags::MSG_DONTWAIT, None)?;
};
Ok(())
}
pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> NixResult<(usize, usize)> {
let mut cmsg = cmsg_space!([RawFd; MAX_FDS_OUT]);
let iov = [uio::IoVec::from_mut_slice(buffer)];
let msg =
socket::recvmsg(self.fd, &iov[..], Some(&mut cmsg), socket::MsgFlags::MSG_DONTWAIT)?;
let mut fd_count = 0;
let received_fds = msg.cmsgs().flat_map(|cmsg| match cmsg {
socket::ControlMessageOwned::ScmRights(s) => s,
_ => Vec::new(),
});
for (fd, place) in received_fds.zip(fds.iter_mut()) {
fd_count += 1;
*place = fd;
}
Ok((msg.bytes, fd_count))
}
}
impl FromRawFd for Socket {
unsafe fn from_raw_fd(fd: RawFd) -> Socket {
Socket { fd }
}
}
impl AsRawFd for Socket {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl IntoRawFd for Socket {
    /// Releases ownership of the descriptor to the caller.
    ///
    /// The `Socket` destructor is skipped, so the returned fd stays open and the caller becomes
    /// responsible for closing it.
    fn into_raw_fd(self) -> RawFd {
        let fd = self.fd;
        // Without this, `Drop for Socket` would run at the end of this function and close `fd`
        // before the caller ever used it.
        ::std::mem::forget(self);
        fd
    }
}
impl Drop for Socket {
    /// Closes the owned descriptor.
    fn drop(&mut self) {
        let fd = self.fd;
        // A destructor has no way to report a close failure, so the result is discarded.
        let _ = ::nix::unistd::close(fd);
    }
}
/// A [`Socket`] combined with buffers for incoming and outgoing message words and file
/// descriptors.
pub struct BufferedSocket {
    socket: Socket,
    // Received words not yet parsed into messages.
    in_data: Buffer<u32>,
    // Received descriptors not yet attached to parsed messages.
    in_fds: Buffer<RawFd>,
    // Serialized outgoing message words awaiting `flush`.
    out_data: Buffer<u32>,
    // Outgoing descriptors awaiting `flush`; `flush` closes them after a successful send.
    out_fds: Buffer<RawFd>,
}
impl BufferedSocket {
    /// Wraps `socket` with empty incoming and outgoing buffers.
    ///
    /// The incoming buffers hold twice as much as the outgoing ones — presumably so that an
    /// unparsed leftover can coexist with a full fresh read.
    pub fn new(socket: Socket) -> BufferedSocket {
        BufferedSocket {
            socket,
            in_data: Buffer::new(2 * MAX_BYTES_OUT / 4),
            in_fds: Buffer::new(2 * MAX_FDS_OUT),
            out_data: Buffer::new(MAX_BYTES_OUT / 4),
            out_fds: Buffer::new(MAX_FDS_OUT),
        }
    }

    /// Returns a mutable reference to the wrapped socket.
    pub fn get_socket(&mut self) -> &mut Socket {
        &mut self.socket
    }

    /// Consumes the wrapper and returns the socket, discarding all buffered data.
    ///
    /// NOTE(review): descriptors still pending in the buffers are not closed here.
    pub fn into_socket(self) -> Socket {
        self.socket
    }

    /// Sends all buffered outgoing words and descriptors in one non-blocking `sendmsg`.
    ///
    /// On success, the sent descriptors are closed and both outgoing buffers are emptied. The
    /// descriptors are presumably duplicates made by `Message::write_to_buffers`. Does nothing
    /// when no data is buffered.
    ///
    /// # Errors
    /// Propagates the send error (e.g. `EAGAIN`). The buffers are left intact so the flush can
    /// be retried.
    pub fn flush(&mut self) -> NixResult<()> {
        {
            let words = self.out_data.get_contents();
            if words.is_empty() {
                return Ok(());
            }
            // SAFETY: `words` is an initialized `[u32]`. Viewing its storage as `4 * len` bytes is
            // sound because `u8` has alignment 1 and every byte is initialized.
            let bytes = unsafe {
                ::std::slice::from_raw_parts(words.as_ptr() as *const u8, words.len() * 4)
            };
            let fds = self.out_fds.get_contents();
            self.socket.send_msg(bytes, fds)?;
            for &fd in fds {
                let _ = ::nix::unistd::close(fd);
            }
        }
        self.out_data.clear();
        self.out_fds.clear();
        Ok(())
    }

    /// Serializes `msg` into the free space of the outgoing buffers.
    ///
    /// Returns `Ok(false)` without advancing the buffers when there is not enough room.
    ///
    /// # Errors
    /// Returns the error from a failed descriptor duplication.
    fn attempt_write_message(&mut self, msg: &Message) -> NixResult<bool> {
        match msg.write_to_buffers(
            self.out_data.get_writable_storage(),
            self.out_fds.get_writable_storage(),
        ) {
            Ok((bytes_out, fds_out)) => {
                self.out_data.advance(bytes_out);
                self.out_fds.advance(fds_out);
                Ok(true)
            }
            Err(MessageWriteError::BufferTooSmall) => Ok(false),
            Err(MessageWriteError::DupFdFailed(e)) => Err(e),
        }
    }

    /// Buffers `msg` for sending, flushing first if it does not fit in the outgoing buffers.
    ///
    /// # Errors
    /// Returns the flush error if flushing fails. Returns `E2BIG` if `msg` does not fit even
    /// after flushing.
    pub fn write_message(&mut self, msg: &Message) -> NixResult<()> {
        if !self.attempt_write_message(msg)? {
            self.flush()?;
            if !self.attempt_write_message(msg)? {
                return Err(::nix::Error::Sys(::nix::errno::Errno::E2BIG));
            }
        }
        Ok(())
    }

    /// Performs one non-blocking read from the socket into the free space of the incoming
    /// buffers.
    ///
    /// Incoming buffers with no unconsumed content are reset first, so the read can use their
    /// whole capacity.
    ///
    /// # Errors
    /// Propagates `recvmsg` errors (e.g. `EAGAIN`). Returns `EPIPE` when the read yields zero
    /// bytes, which is treated as the peer having closed the connection.
    pub fn fill_incoming_buffers(&mut self) -> NixResult<()> {
        if !self.in_data.has_content() {
            self.in_data.clear();
        }
        if !self.in_fds.has_content() {
            self.in_fds.clear();
        }
        let (in_bytes, in_fds) = {
            let words = self.in_data.get_writable_storage();
            // SAFETY: the byte view covers exactly the exclusively borrowed free words, and any
            // byte pattern is a valid `u32`. The pointer is derived with `as_mut_ptr`; `as_ptr`
            // would carry read-only provenance and writing through it would be undefined behavior.
            let bytes = unsafe {
                ::std::slice::from_raw_parts_mut(words.as_mut_ptr() as *mut u8, words.len() * 4)
            };
            let fds = self.in_fds.get_writable_storage();
            self.socket.rcv_msg(bytes, fds)?
        };
        if in_bytes == 0 {
            return Err(::nix::Error::Sys(::nix::errno::Errno::EPIPE));
        }
        // Round the byte count up to whole words.
        // NOTE(review): if a read ends mid-word, the next read does not complete that word;
        // presumably peers always send whole 4-byte words — confirm.
        self.in_data.advance((in_bytes + 3) / 4);
        self.in_fds.advance(in_fds);
        Ok(())
    }

    /// Parses a single message from the incoming buffers and consumes its words and descriptors.
    ///
    /// `signature` maps `(object_id, opcode)` to the message's argument types, or `None` if the
    /// pair is unknown.
    ///
    /// # Errors
    /// - `MissingData` if fewer than two header words are buffered.
    /// - `Malformed` if `signature` returns `None`.
    /// - Otherwise, any error from `Message::from_raw`.
    ///
    /// Nothing is consumed on error.
    pub fn read_one_message<F>(&mut self, mut signature: F) -> Result<Message, MessageParseError>
    where
        F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
    {
        let (msg, read_data, read_fd) = {
            let data = self.in_data.get_contents();
            let fds = self.in_fds.get_contents();
            if data.len() < 2 {
                return Err(MessageParseError::MissingData);
            }
            let object_id = data[0];
            // The opcode lives in the low 16 bits of the second header word.
            let opcode = (data[1] & 0x0000_FFFF) as u16;
            let sig = signature(object_id, opcode).ok_or(MessageParseError::Malformed)?;
            let (msg, rest_data, rest_fds) = Message::from_raw(data, sig, fds)?;
            (msg, data.len() - rest_data.len(), fds.len() - rest_fds.len())
        };
        self.in_data.offset(read_data);
        self.in_fds.offset(read_fd);
        Ok(msg)
    }

    /// Parses and dispatches buffered messages, reading from the socket as needed.
    ///
    /// `signature` maps `(object_id, opcode)` to the argument types of that message. `callback`
    /// receives each parsed message and returns `false` to stop dispatching.
    ///
    /// Returns:
    /// - `Ok(Ok(n))` with the number of dispatched messages, either once the socket has no more
    ///   data (after at least one dispatch) or as soon as the callback asks to stop.
    /// - `Ok(Err(MessageParseError::Malformed))` when a malformed message is encountered.
    ///
    /// # Errors
    /// Returns socket read errors. This includes `EAGAIN` when nothing could be dispatched
    /// before the socket ran dry.
    pub fn read_messages<F1, F2>(
        &mut self,
        mut signature: F1,
        mut callback: F2,
    ) -> NixResult<Result<usize, MessageParseError>>
    where
        F1: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
        F2: FnMut(Message) -> bool,
    {
        let mut dispatched = 0;
        loop {
            // Dispatch everything already buffered. `None` means the callback asked to stop;
            // `Some(e)` is the parse error that ended the batch.
            let parse_error = loop {
                match self.read_one_message(&mut signature) {
                    Ok(msg) => {
                        let keep_going = callback(msg);
                        dispatched += 1;
                        if !keep_going {
                            break None;
                        }
                    }
                    Err(e) => break Some(e),
                }
            };

            // Keep any partial message at the front so the next read has maximal room.
            self.in_data.move_to_front();
            self.in_fds.move_to_front();

            match parse_error {
                // Honor an early stop even if the buffer happens to be empty; otherwise we would
                // read the socket again and keep invoking the callback.
                None => return Ok(Ok(dispatched)),
                Some(MessageParseError::Malformed) => {
                    return Ok(Err(MessageParseError::Malformed));
                }
                // Incomplete message: fetch more data and try again.
                Some(_) => {}
            }

            match self.fill_incoming_buffers() {
                Ok(()) => (),
                Err(e @ ::nix::Error::Sys(::nix::errno::Errno::EAGAIN)) => {
                    // The socket is drained; report EAGAIN only if nothing was dispatched.
                    if dispatched == 0 {
                        return Err(e);
                    } else {
                        return Ok(Ok(dispatched));
                    }
                }
                Err(e) => return Err(e),
            }
        }
    }
}
/// A fixed-capacity buffer of `Copy` elements with a read cursor.
///
/// The buffer is divided into two regions:
/// - `storage[offset..occupied]` holds pending elements (written but not yet consumed).
/// - `storage[occupied..]` is free space available for writing.
struct Buffer<T: Copy> {
    storage: Vec<T>,
    // One past the last written element.
    occupied: usize,
    // Index of the first unconsumed element.
    offset: usize,
}

impl<T: Copy + Default> Buffer<T> {
    /// Creates a buffer of `size` default-initialized elements, all of them free.
    fn new(size: usize) -> Buffer<T> {
        Buffer { storage: vec![T::default(); size], occupied: 0, offset: 0 }
    }

    /// Returns `true` if some written elements have not been consumed yet.
    fn has_content(&self) -> bool {
        self.occupied > self.offset
    }

    /// Marks the next `count` free elements as written.
    fn advance(&mut self, count: usize) {
        self.occupied += count;
    }

    /// Marks the next `count` pending elements as consumed.
    fn offset(&mut self, count: usize) {
        self.offset += count;
    }

    /// Empties the buffer, making its whole capacity free again.
    fn clear(&mut self) {
        self.occupied = 0;
        self.offset = 0;
    }

    /// Returns the pending (written, unconsumed) elements.
    fn get_contents(&self) -> &[T] {
        &self.storage[self.offset..self.occupied]
    }

    /// Returns the free space after the written elements.
    fn get_writable_storage(&mut self) -> &mut [T] {
        &mut self.storage[self.occupied..]
    }

    /// Moves the pending elements to the start of the storage, maximizing the free tail.
    fn move_to_front(&mut self) {
        // `copy_within` has memmove semantics, so overlapping ranges are fine. Unlike
        // `&storage[offset]`, it does not panic when everything was consumed and
        // `offset == storage.len()`, nor on a zero-capacity buffer.
        self.storage.copy_within(self.offset..self.occupied, 0);
        self.occupied -= self.offset;
        self.offset = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wire::{Argument, ArgumentType, Message};
    use smallvec::smallvec;
    use std::ffi::CString;

    /// Creates a connected pair of buffered sockets, returned as `(client, server)`.
    fn buffered_pair() -> (BufferedSocket, BufferedSocket) {
        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        // SAFETY: `into_raw_fd` relinquishes each descriptor, so the new `Socket` is its only owner.
        let client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
        let server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });
        (client, server)
    }

    /// Returns `true` if both descriptors refer to the same file (same device and inode).
    fn same_file(a: RawFd, b: RawFd) -> bool {
        let stat1 = ::nix::sys::stat::fstat(a).unwrap();
        let stat2 = ::nix::sys::stat::fstat(b).unwrap();
        stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino
    }

    /// Asserts that two messages are equal. Fd arguments are compared by the file they refer to,
    /// since passing a descriptor over a socket gives it a new number.
    fn assert_eq_msgs(msg1: &Message, msg2: &Message) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
            if let (&Argument::Fd(fd1), &Argument::Fd(fd2)) = (arg1, arg2) {
                assert!(same_file(fd1, fd2));
            } else {
                assert_eq!(arg1, arg2);
            }
        }
    }

    #[test]
    fn write_read_cycle() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Box::new(CString::new(&b"I like trains!"[..]).unwrap())),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };
        let (mut client, mut server) = buffered_pair();
        client.write_message(&msg).unwrap();
        client.flush().unwrap();
        static SIGNATURE: &'static [ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str,
            ArgumentType::Array,
            ArgumentType::Object,
            ArgumentType::NewId,
            ArgumentType::Int,
        ];
        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();
        assert_eq!(ret, 1);
    }

    #[test]
    fn write_read_cycle_fd() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![Argument::Fd(1), Argument::Fd(0)],
        };
        let (mut client, mut server) = buffered_pair();
        client.write_message(&msg).unwrap();
        client.flush().unwrap();
        static SIGNATURE: &'static [ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];
        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 42 && opcode == 7 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();
        assert_eq!(ret, 1);
    }

    #[test]
    fn write_read_cycle_multiple() {
        let messages = [
            Message {
                sender_id: 42,
                opcode: 0,
                args: smallvec![
                    Argument::Int(42),
                    Argument::Str(Box::new(CString::new(&b"I like trains"[..]).unwrap())),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: smallvec![Argument::Fd(1), Argument::Fd(0)],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: smallvec![Argument::Uint(3), Argument::Fd(2)],
            },
        ];
        static SIGNATURES: &'static [&'static [ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];
        let (mut client, mut server) = buffered_pair();
        for msg in &messages {
            client.write_message(msg).unwrap();
        }
        client.flush().unwrap();
        let mut recv_msgs = Vec::new();
        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 42 {
                        Some(SIGNATURES[opcode as usize])
                    } else {
                        None
                    }
                },
                |message| {
                    recv_msgs.push(message);
                    true
                },
            )
            .unwrap()
            .unwrap();
        assert_eq!(ret, 3);
        assert_eq!(recv_msgs.len(), 3);
        for (msg1, msg2) in messages.iter().zip(recv_msgs.iter()) {
            assert_eq_msgs(msg1, msg2);
        }
    }

    #[test]
    fn parse_with_string_len_multiple_of_4() {
        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: smallvec![
                Argument::Uint(18),
                Argument::Str(Box::new(CString::new(&b"wl_shell"[..]).unwrap())),
                Argument::Uint(1),
            ],
        };
        let (mut client, mut server) = buffered_pair();
        client.write_message(&msg).unwrap();
        client.flush().unwrap();
        static SIGNATURE: &'static [ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str, ArgumentType::Uint];
        let ret = server
            .read_messages(
                |sender_id, opcode| {
                    if sender_id == 2 && opcode == 0 {
                        Some(SIGNATURE)
                    } else {
                        None
                    }
                },
                |message| {
                    assert_eq_msgs(&message, &msg);
                    true
                },
            )
            .unwrap()
            .unwrap();
        assert_eq!(ret, 1);
    }
}