summary refs log blame commit diff
path: root/msg_socket/src/lib.rs
blob: be8607018bd946632ba6e5b3540d72f12f6cb97c (plain) (tree)
1
2
3
4
5
6
7
8
9



                                                                         

                  
                               
                             
                                        

                               
 


                                    
                               



                                                                                             
 
                                




                                                                                            
                                                                           
                                                

                                          
                              



                                                      
                        





                                                      
                                                     





                            

                                                                   
                                                                    

                                

 

                                                    
                        




                                 
                                               








                                                      
                        




                                   
                                                 






                            

                                                                               



                  





                                                                  

                                                         



                  





                                            

                                                           



                  





                                              






                                                                            

                                      



                                                                            
                         
                                                 

                                                                                            
                                                        
                                                                  
                                          
         

              
 
                                                                  
                                       

                                           

                                                                                               
                 
                                                    



                                                                                    

             
 
                                                                


                                           
                                                   







                                                  
                                                      
                                                                                         
                                            





                                              















                                                    





                                                                      
 





                                                  


                                                              
                               



                                                                                   
                                                                            


















































                                                                                            
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

mod msg_on_socket;

use std::io::{IoSlice, Result};
use std::marker::PhantomData;
use std::os::unix::io::{AsRawFd, RawFd};
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::Stream;
use libc::{EWOULDBLOCK, O_NONBLOCK};

use cros_async::add_read_waker;
use sys_util::{
    add_fd_flags, clear_fd_flags, error, handle_eintr, net::UnixSeqpacket, Error as SysError,
    ScmSocket,
};

pub use crate::msg_on_socket::*;
pub use msg_on_socket_derive::*;

/// Create a pair of socket. Request is send in one direction while response is in the other
/// direction.
pub fn pair<Request: MsgOnSocket, Response: MsgOnSocket>(
) -> Result<(MsgSocket<Request, Response>, MsgSocket<Response, Request>)> {
    let (sock1, sock2) = UnixSeqpacket::pair()?;
    let requester = MsgSocket::new(sock1);
    let responder = MsgSocket::new(sock2);
    Ok((requester, responder))
}

/// Bidirection sock that support both send and recv.
pub struct MsgSocket<I: MsgOnSocket, O: MsgOnSocket> {
    sock: UnixSeqpacket,
    _i: PhantomData<I>,
    _o: PhantomData<O>,
}

impl<I: MsgOnSocket, O: MsgOnSocket> MsgSocket<I, O> {
    // Create a new MsgSocket.
    pub fn new(s: UnixSeqpacket) -> MsgSocket<I, O> {
        MsgSocket {
            sock: s,
            _i: PhantomData,
            _o: PhantomData,
        }
    }

    // Creates an async receiver that implements `futures::Stream`.
    pub fn async_receiver(&self) -> MsgResult<AsyncReceiver<I, O>> {
        AsyncReceiver::new(self)
    }
}

/// One direction socket that only supports sending.
pub struct Sender<M: MsgOnSocket> {
    sock: UnixSeqpacket,
    _m: PhantomData<M>,
}

impl<M: MsgOnSocket> Sender<M> {
    /// Create a new sender sock.
    pub fn new(s: UnixSeqpacket) -> Sender<M> {
        Sender {
            sock: s,
            _m: PhantomData,
        }
    }
}

/// One direction socket that only supports receiving.
pub struct Receiver<M: MsgOnSocket> {
    sock: UnixSeqpacket,
    _m: PhantomData<M>,
}

impl<M: MsgOnSocket> Receiver<M> {
    /// Create a new receiver sock.
    pub fn new(s: UnixSeqpacket) -> Receiver<M> {
        Receiver {
            sock: s,
            _m: PhantomData,
        }
    }
}

impl<I: MsgOnSocket, O: MsgOnSocket> AsRef<UnixSeqpacket> for MsgSocket<I, O> {
    fn as_ref(&self) -> &UnixSeqpacket {
        &self.sock
    }
}

impl<I: MsgOnSocket, O: MsgOnSocket> AsRawFd for MsgSocket<I, O> {
    fn as_raw_fd(&self) -> RawFd {
        self.sock.as_raw_fd()
    }
}

impl<M: MsgOnSocket> AsRef<UnixSeqpacket> for Sender<M> {
    fn as_ref(&self) -> &UnixSeqpacket {
        &self.sock
    }
}

impl<M: MsgOnSocket> AsRawFd for Sender<M> {
    fn as_raw_fd(&self) -> RawFd {
        self.sock.as_raw_fd()
    }
}

impl<M: MsgOnSocket> AsRef<UnixSeqpacket> for Receiver<M> {
    fn as_ref(&self) -> &UnixSeqpacket {
        &self.sock
    }
}

impl<M: MsgOnSocket> AsRawFd for Receiver<M> {
    fn as_raw_fd(&self) -> RawFd {
        self.sock.as_raw_fd()
    }
}

pub trait UnixSeqpacketExt {
    fn send_msg_on_socket<M: MsgOnSocket>(&self, msg: &M) -> MsgResult<()>;
    fn recv_msg_on_socket<M: MsgOnSocket>(&self) -> MsgResult<M>;
}

impl UnixSeqpacketExt for UnixSeqpacket {
    fn send_msg_on_socket<M: MsgOnSocket>(&self, msg: &M) -> MsgResult<()> {
        let msg_size = msg.msg_size();
        let fd_size = msg.fd_count();
        let mut msg_buffer: Vec<u8> = vec![0; msg_size];
        let mut fd_buffer: Vec<RawFd> = vec![0; fd_size];

        let fd_size = msg.write_to_buffer(&mut msg_buffer, &mut fd_buffer)?;
        if fd_size == 0 {
            handle_eintr!(self.send(&msg_buffer))
                .map_err(|e| MsgError::Send(SysError::new(e.raw_os_error().unwrap_or(0))))?;
        } else {
            let ioslice = IoSlice::new(&msg_buffer[..]);
            self.send_with_fds(&[ioslice], &fd_buffer[0..fd_size])
                .map_err(MsgError::Send)?;
        }
        Ok(())
    }

    fn recv_msg_on_socket<M: MsgOnSocket>(&self) -> MsgResult<M> {
        let (msg_buffer, fd_buffer) = {
            if M::uses_fd() {
                self.recv_as_vec_with_fds()
                    .map_err(|e| MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0))))?
            } else {
                (
                    self.recv_as_vec().map_err(|e| {
                        MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0)))
                    })?,
                    vec![],
                )
            }
        };

        if msg_buffer.len() == 0 && M::fixed_size() != Some(0) {
            return Err(MsgError::RecvZero);
        }

        if let Some(fixed_size) = M::fixed_size() {
            if fixed_size != msg_buffer.len() {
                return Err(MsgError::BadRecvSize {
                    expected: fixed_size,
                    actual: msg_buffer.len(),
                });
            }
        }

        // Safe because fd buffer is read from socket.
        let (v, read_fd_size) = unsafe { M::read_from_buffer(&msg_buffer, &fd_buffer)? };
        if fd_buffer.len() != read_fd_size {
            return Err(MsgError::NotExpectFd);
        }
        Ok(v)
    }
}

/// Types that could send a message.
pub trait MsgSender: AsRef<UnixSeqpacket> {
    type M: MsgOnSocket;
    fn send(&self, msg: &Self::M) -> MsgResult<()> {
        self.as_ref().send_msg_on_socket(msg)
    }
}

/// Types that could receive a message.
pub trait MsgReceiver: AsRef<UnixSeqpacket> {
    type M: MsgOnSocket;
    fn recv(&self) -> MsgResult<Self::M> {
        self.as_ref().recv_msg_on_socket()
    }
}

impl<I: MsgOnSocket, O: MsgOnSocket> MsgSender for MsgSocket<I, O> {
    type M = I;
}
impl<I: MsgOnSocket, O: MsgOnSocket> MsgReceiver for MsgSocket<I, O> {
    type M = O;
}

impl<I: MsgOnSocket> MsgSender for Sender<I> {
    type M = I;
}
impl<O: MsgOnSocket> MsgReceiver for Receiver<O> {
    type M = O;
}

/// Asynchronous adaptor for `MsgSocket`.
pub struct AsyncReceiver<'a, I: MsgOnSocket, O: MsgOnSocket> {
    inner: &'a MsgSocket<I, O>,
    done: bool, // Have hit an error and the Stream should return null when polled.
}

impl<'a, I: MsgOnSocket, O: MsgOnSocket> AsyncReceiver<'a, I, O> {
    fn new(msg_socket: &MsgSocket<I, O>) -> MsgResult<AsyncReceiver<I, O>> {
        add_fd_flags(msg_socket.as_raw_fd(), O_NONBLOCK).map_err(MsgError::SettingFdFlags)?;
        Ok(AsyncReceiver {
            inner: msg_socket,
            done: false,
        })
    }
}

impl<'a, I: MsgOnSocket, O: MsgOnSocket> Drop for AsyncReceiver<'a, I, O> {
    fn drop(&mut self) {
        if let Err(e) = clear_fd_flags(self.inner.as_raw_fd(), O_NONBLOCK) {
            error!(
                "Failed to restore non-blocking behavior to message socket: {}",
                e
            );
        }
    }
}

impl<'a, I: MsgOnSocket, O: MsgOnSocket> Stream for AsyncReceiver<'a, I, O> {
    type Item = MsgResult<O>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        if self.done {
            return Poll::Ready(None);
        }

        let ret = match self.inner.recv() {
            Ok(msg) => Ok(Poll::Ready(Some(Ok(msg)))),
            Err(MsgError::Recv(e)) => {
                if e.errno() == EWOULDBLOCK {
                    add_read_waker(self.inner.as_raw_fd(), cx.waker().clone())
                        .map(|_| Poll::Pending)
                        .map_err(MsgError::AddingWaker)
                } else {
                    Err(MsgError::Recv(e))
                }
            }
            Err(e) => Err(e),
        };

        match ret {
            Ok(p) => p,
            Err(e) => {
                // Indicate something went wrong and no more events will be provided.
                self.done = true;
                Poll::Ready(Some(Err(e)))
            }
        }
    }
}