diff --git a/src/main.rs b/src/main.rs index f8df2ef..0df2a4f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ extern crate log; pub(crate) mod config; pub(crate) mod monitor; -pub(crate) mod protocol; +pub(crate) mod proto; pub(crate) mod server; pub(crate) mod types; @@ -20,13 +20,11 @@ use minecraft_protocol::version::v1_14_4::handshake::Handshake; use minecraft_protocol::version::v1_14_4::login::LoginDisconnect; use minecraft_protocol::version::v1_14_4::status::StatusResponse; use tokio::io; -use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; -use tokio::net::tcp::ReadHalf; use tokio::net::{TcpListener, TcpStream}; use config::*; -use protocol::{Client, ClientState, RawPacket}; +use proto::{Client, ClientState, RawPacket}; use server::ServerState; #[tokio::main] @@ -35,31 +33,23 @@ async fn main() -> Result<(), ()> { let _ = dotenv::dotenv(); pretty_env_logger::init(); - info!( - "Proxying public {} to internal {}", - ADDRESS_PUBLIC, ADDRESS_PROXY, - ); - let server_state = Arc::new(ServerState::default()); // Listen for new connections // TODO: do not drop error here - let listener = TcpListener::bind(ADDRESS_PUBLIC).await.map_err(|_| ())?; + let listener = TcpListener::bind(ADDRESS_PUBLIC).await.map_err(|err| { + error!("Failed to start: {}", err); + () + })?; - // Spawn server monitor - let addr = ADDRESS_PROXY.parse().expect("invalid server IP"); - tokio::spawn(monitor::monitor_server(addr, server_state.clone())); + info!( + "Proxying egress {} to ingress {}", + ADDRESS_PUBLIC, ADDRESS_PROXY, + ); - let sub = server_state.clone(); - tokio::spawn(async move { - loop { - tokio::signal::ctrl_c().await.unwrap(); - if !sub.kill_server() { - // TODO: gracefully kill itself instead - std::process::exit(1) - } - } - }); + // Spawn server monitor and signal handler + tokio::spawn(server_monitor(server_state.clone())); + tokio::spawn(signal_handler(server_state.clone())); // Proxy all incomming connections while let Ok((inbound, _)) = listener.accept().await { @@ -67,7 +57,7 @@ async fn main() -> Result<(), ()> { if !server_state.online() { // When server is not online, spawn a status server - let transfer = status_server(client, inbound, server_state.clone()).map(|r| { + let transfer = serve_status(client, inbound, server_state.clone()).map(|r| { if let Err(err) = r { error!("Failed to serve status: {:?}", err); } @@ -78,7 +68,7 @@ async fn main() -> Result<(), ()> { // When server is online, proxy all let transfer = proxy(inbound, ADDRESS_PROXY.to_string()).map(|r| { if let Err(err) = r { - error!("Failed to proxy: {:?}", err); + error!("Failed to proxy: {}", err); } }); @@ -89,70 +79,26 @@ async fn main() -> Result<(), ()> { Ok(()) } -/// Read raw packet from stream. -pub async fn read_packet<'a>( - buf: &mut BytesMut, - stream: &mut ReadHalf<'a>, -) -> Result)>, ()> { - // Keep reading until we have at least 2 bytes - while buf.len() < 2 { - // Read packet from socket - let mut tmp = Vec::with_capacity(64); - match stream.read_buf(&mut tmp).await { - Ok(_) => {} - Err(err) if err.kind() == io::ErrorKind::ConnectionReset => return Ok(None), - Err(err) => { - dbg!(err); - return Err(()); - } +/// Signal handler task. +pub async fn signal_handler(server_state: Arc) { + loop { + tokio::signal::ctrl_c().await.unwrap(); + if !server_state.kill_server() { + // TODO: gracefully kill itself instead + std::process::exit(1) } - - if tmp.is_empty() { - return Ok(None); - } - buf.extend(tmp); } +} - // Attempt to read packet length - let (consumed, len) = match types::read_var_int(&buf) { - Ok(result) => result, - Err(err) => { - error!("Failed to read packet length, should retry!"); - error!("{:?}", (&buf).as_ref()); - return Err(err); - } - }; - - // Keep reading until we have all packet bytes - while buf.len() < consumed + len as usize { - // Read packet from socket - let mut tmp = Vec::with_capacity(64); - match stream.read_buf(&mut tmp).await { - Ok(_) => {} - Err(err) if err.kind() == io::ErrorKind::ConnectionReset => return Ok(None), - Err(err) => { - dbg!(err); - return Err(()); - } - } - - if tmp.is_empty() { - return Ok(None); - } - - buf.extend(tmp); - } - - // Parse packet - let raw = buf.split_to(consumed + len as usize); - let packet = RawPacket::decode(&raw)?; - - Ok(Some((packet, raw.to_vec()))) +/// Server monitor task. +pub async fn server_monitor(state: Arc) { + let addr = ADDRESS_PROXY.parse().expect("invalid server IP"); + monitor::monitor_server(addr, state).await } /// Proxy the given inbound stream to a target address. // TODO: do not drop error here, return Box -async fn status_server( +async fn serve_status( client: Client, mut inbound: TcpStream, server: Arc, @@ -164,7 +110,7 @@ async fn status_server( loop { // Read packet from stream - let (packet, raw) = match read_packet(&mut buf, &mut reader).await { + let (packet, raw) = match proto::read_packet(&mut buf, &mut reader).await { Ok(Some(packet)) => packet, Ok(None) => break, Err(_) => { @@ -174,9 +120,7 @@ async fn status_server( }; // Hijack login start - if client.state() == ClientState::Login - && packet.id == protocol::LOGIN_PACKET_ID_LOGIN_START - { + if client.state() == ClientState::Login && packet.id == proto::LOGIN_PACKET_ID_LOGIN_START { let packet = LoginDisconnect { reason: Message::new(Payload::text(LABEL_SERVER_STARTING_MESSAGE)), }; @@ -198,9 +142,7 @@ async fn status_server( } // Hijack handshake - if client.state() == ClientState::Handshake - && packet.id == protocol::STATUS_PACKET_ID_STATUS - { + if client.state() == ClientState::Handshake && packet.id == proto::STATUS_PACKET_ID_STATUS { match Handshake::decode(&mut packet.data.as_slice()) { Ok(handshake) => { // TODO: do not panic here @@ -214,7 +156,7 @@ async fn status_server( } // Hijack server status packet - if client.state() == ClientState::Status && packet.id == protocol::STATUS_PACKET_ID_STATUS { + if client.state() == ClientState::Status && packet.id == proto::STATUS_PACKET_ID_STATUS { // Build status response // TODO: grab latest protocol version from online server! let description = if server.starting() { @@ -245,7 +187,7 @@ async fn status_server( } // Hijack ping packet - if client.state() == ClientState::Status && packet.id == protocol::STATUS_PACKET_ID_PING { + if client.state() == ClientState::Status && packet.id == proto::STATUS_PACKET_ID_PING { writer.write_all(&raw).await.map_err(|_| ())?; continue; } @@ -267,10 +209,9 @@ async fn status_server( } /// Proxy the inbound stream to a target address. -// TODO: do not drop error here, return Box async fn proxy(mut inbound: TcpStream, addr_target: String) -> Result<(), Box> { // Set up connection to server - // TODO: on connect fail, ping server and redirect to status_server if offline + // TODO: on connect fail, ping server and redirect to serve_status if offline let mut outbound = TcpStream::connect(addr_target).await?; let (mut ri, mut wi) = inbound.split(); diff --git a/src/monitor.rs b/src/monitor.rs index 8cde794..024f510 100644 --- a/src/monitor.rs +++ b/src/monitor.rs @@ -13,7 +13,7 @@ use rand::Rng; use tokio::io::AsyncWriteExt; use tokio::net::TcpStream; -use crate::protocol::{self, ClientState, RawPacket}; +use crate::proto::{self, ClientState, RawPacket}; use crate::server::ServerState; /// Minecraft protocol version used when polling server status. @@ -72,7 +72,7 @@ async fn send_handshake(stream: &mut TcpStream, addr: SocketAddr) -> Result<(), let mut packet = Vec::new(); handshake.encode(&mut packet).map_err(|_| ())?; - let raw = RawPacket::new(protocol::HANDSHAKE_PACKET_ID_HANDSHAKE, packet) + let raw = RawPacket::new(proto::HANDSHAKE_PACKET_ID_HANDSHAKE, packet) .encode() .map_err(|_| ())?; @@ -93,7 +93,7 @@ async fn send_ping(stream: &mut TcpStream) -> Result { let mut packet = Vec::new(); ping.encode(&mut packet).map_err(|_| ())?; - let raw = RawPacket::new(protocol::STATUS_PACKET_ID_PING, packet) + let raw = RawPacket::new(proto::STATUS_PACKET_ID_PING, packet) .encode() .map_err(|_| ())?; @@ -110,14 +110,14 @@ async fn wait_for_ping(stream: &mut TcpStream, token: u64) -> Result<(), ()> { loop { // Read packet from stream - let (packet, _raw) = match crate::read_packet(&mut buf, &mut reader).await { + let (packet, _raw) = match proto::read_packet(&mut buf, &mut reader).await { Ok(Some(packet)) => packet, Ok(None) => break, Err(_) => continue, }; // Catch ping response - if packet.id == protocol::STATUS_PACKET_ID_PING { + if packet.id == proto::STATUS_PACKET_ID_PING { let ping = PingResponse::decode(&mut packet.data.as_slice()).map_err(|_| ())?; // Ensure ping token is correct diff --git a/src/protocol.rs b/src/proto.rs similarity index 59% rename from src/protocol.rs rename to src/proto.rs index f633aec..35f1609 100644 --- a/src/protocol.rs +++ b/src/proto.rs @@ -1,5 +1,10 @@ use std::sync::Mutex; +use bytes::BytesMut; +use tokio::io; +use tokio::io::AsyncReadExt; +use tokio::net::tcp::ReadHalf; + use crate::types; pub const HANDSHAKE_PACKET_ID_HANDSHAKE: i32 = 0; @@ -113,3 +118,63 @@ impl RawPacket { return Ok(packet); } } + +/// Read raw packet from stream. +pub async fn read_packet<'a>( + buf: &mut BytesMut, + stream: &mut ReadHalf<'a>, +) -> Result)>, ()> { + // Keep reading until we have at least 2 bytes + while buf.len() < 2 { + // Read packet from socket + let mut tmp = Vec::with_capacity(64); + match stream.read_buf(&mut tmp).await { + Ok(_) => {} + Err(err) if err.kind() == io::ErrorKind::ConnectionReset => return Ok(None), + Err(err) => { + dbg!(err); + return Err(()); + } + } + + if tmp.is_empty() { + return Ok(None); + } + buf.extend(tmp); + } + + // Attempt to read packet length + let (consumed, len) = match types::read_var_int(&buf) { + Ok(result) => result, + Err(err) => { + error!("Malformed packet, could not read packet length"); + return Err(err); + } + }; + + // Keep reading until we have all packet bytes + while buf.len() < consumed + len as usize { + // Read packet from socket + let mut tmp = Vec::with_capacity(64); + match stream.read_buf(&mut tmp).await { + Ok(_) => {} + Err(err) if err.kind() == io::ErrorKind::ConnectionReset => return Ok(None), + Err(err) => { + dbg!(err); + return Err(()); + } + } + + if tmp.is_empty() { + return Ok(None); + } + + buf.extend(tmp); + } + + // Parse packet + let raw = buf.split_to(consumed + len as usize); + let packet = RawPacket::decode(&raw)?; + + Ok(Some((packet, raw.to_vec()))) +}