diff --git a/Cargo.toml b/Cargo.toml index 2ffac5d..e971bbd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,4 +23,4 @@ bytes = "1.1" futures = { version = "0.3", default-features = false } minecraft-protocol = { git = "https://github.com/timvisee/minecraft-protocol", rev = "c578492" } rand = "0.8" -tokio = { version = "1", default-features = false, features = ["rt", "rt-multi-thread", "io-util", "net", "macros", "sync", "time"] } +tokio = { version = "1", default-features = false, features = ["rt", "rt-multi-thread", "io-util", "net", "macros", "time"] } diff --git a/src/main.rs b/src/main.rs index 935541f..38ccc65 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,8 @@ pub mod monitor; pub mod protocol; pub mod types; +use std::sync::Arc; + use bytes::BytesMut; use futures::FutureExt; use minecraft_protocol::data::chat::{Message, Payload}; @@ -17,9 +19,9 @@ use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::net::tcp::ReadHalf; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::mpsc::unbounded_channel; use config::*; +use monitor::ServerState; use protocol::{Client, ClientState, RawPacket}; #[tokio::main] @@ -37,23 +39,38 @@ async fn main() -> Result<(), ()> { // Spawn server monitor let addr = ADDRESS_PROXY.parse().expect("invalid server IP"); - tokio::spawn(monitor::monitor_server(addr, server_state)); + tokio::spawn(monitor::monitor_server(addr, server_state.clone())); // Proxy all incomming connections while let Ok((inbound, _)) = listener.accept().await { let client = Client::default(); eprintln!("Client connected"); - let transfer = proxy(client, inbound, ADDRESS_PROXY.to_string()).map(|r| { - if let Err(e) = r { - println!("Failed to proxy: {:?}", e); - } + if !server_state.online() { + // When server is not online, spawn a status server + let transfer = status_server(client, inbound, server_state.clone()).map(|r| { + if let Err(e) = r { + println!("Failed to proxy: {:?}", e); + } - // TODO: proxy isn't closed for disconnected clients! - eprintln!("Client disconnected"); - }); + // TODO: proxy isn't closed for disconnected clients! + eprintln!("Client disconnected"); + }); - tokio::spawn(transfer); + tokio::spawn(transfer); + } else { + // When server is online, proxy all + let transfer = proxy(inbound, ADDRESS_PROXY.to_string()).map(|r| { + if let Err(e) = r { + println!("Failed to proxy: {:?}", e); + } + + // TODO: proxy isn't closed for disconnected clients! + eprintln!("Client disconnected"); + }); + + tokio::spawn(transfer); + } } Ok(()) @@ -64,12 +81,6 @@ pub async fn read_packet<'a>( buf: &mut BytesMut, stream: &mut ReadHalf<'a>, ) -> Result<Option<(RawPacket, Vec<u8>)>, ()> { - // // Wait until socket is readable - // if stream.readable().await.is_err() { - // eprintln!("Socket not readable!"); - // return Ok(None); - // } - // Keep reading until we have at least 2 bytes while buf.len() < 2 { // Read packet from socket @@ -112,217 +123,129 @@ pub async fn read_packet<'a>( /// Proxy the given inbound stream to a target address. // TODO: do not drop error here, return Box<dyn Error> -async fn proxy(client: Client, mut inbound: TcpStream, addr_target: String) -> Result<(), ()> { - let mut outbound = TcpStream::connect(addr_target).await.map_err(|_| ())?; +async fn status_server( + client: Client, + mut inbound: TcpStream, + server: Arc<ServerState>, +) -> Result<(), ()> { + let (mut reader, mut writer) = inbound.split(); - let (mut ri, mut wi) = inbound.split(); - let (mut ro, mut wo) = outbound.split(); + // Incoming buffer + let mut buf = BytesMut::new(); - let (client_send_queue, mut client_to_send) = unbounded_channel::<Vec<u8>>(); - - let server_available = true; - - let client_to_server = async { - // Incoming buffer - let mut buf = BytesMut::new(); - - loop { - // In login state, proxy raw data - if server_available && client.state() == ClientState::Login { - eprintln!("STARTED FULL PROXY"); - - wo.writable().await.map_err(|_| ())?; - - // Forward remaining buffer - wo.write_all(&buf).await.map_err(|_| ())?; - buf.clear(); - - // Forward rest of data - io::copy(&mut ri, &mut wo).await.map_err(|_| ())?; + loop { + // Read packet from stream + let (packet, raw) = match read_packet(&mut buf, &mut reader).await { + Ok(Some(packet)) => packet, + Ok(None) => { + eprintln!("Closing connection, could not read more"); break; } + Err(_) => { + eprintln!("Closing connection, error occurred"); + break; + } + }; - // Read packet from stream - let (packet, raw) = match read_packet(&mut buf, &mut ri).await { - Ok(Some(packet)) => packet, - Ok(None) => { - eprintln!("Closing connection, could not read more"); - break; - } - Err(_) => { - // Forward raw packet to server - wo.write_all(&buf).await.expect("failed to write to server"); - buf.clear(); - continue; - } + // Hijack login start + if client.state() == ClientState::Login + && packet.id == protocol::LOGIN_PACKET_ID_LOGIN_START + { + let packet = LoginDisconnect { + reason: Message::new(Payload::text(LABEL_SERVER_STARTING_MESSAGE)), }; - // Show packet details - eprintln!("PACKET {:?}", raw.as_slice()); - eprintln!("PACKET ID: {}", packet.id); - eprintln!("PACKET DATA: {:?}", packet.data); + let mut data = Vec::new(); + packet.encode(&mut data).map_err(|_| ())?; - // Hijack login start - if client.state() == ClientState::Login - && packet.id == protocol::LOGIN_PACKET_ID_LOGIN_START - { - let packet = LoginDisconnect { - reason: Message::new(Payload::text(LABEL_SERVER_STARTING_MESSAGE)), - }; + let response = RawPacket::new(0, data).encode()?; - let mut data = Vec::new(); - packet.encode(&mut data).map_err(|_| ())?; - - let response = RawPacket::new(0, data).encode()?; - client_send_queue - .send(response) - .expect("failed to queue logout response"); - - break; - } - - // Hijack handshake - if client.state() == ClientState::Handshake - && packet.id == protocol::STATUS_PACKET_ID_STATUS - { - if let Ok(handshake) = Handshake::decode(&mut packet.data.as_slice()) { - eprintln!("# PACKET HANDSHAKE"); - eprintln!("SWITCHING CLIENT STATE: {}", handshake.next_state); + writer.write_all(&response).await.map_err(|_| ())?; + break; + } + // Hijack handshake + if client.state() == ClientState::Handshake + && packet.id == protocol::STATUS_PACKET_ID_STATUS + { + match Handshake::decode(&mut packet.data.as_slice()) { + Ok(handshake) => { // TODO: do not panic here client.set_state( ClientState::from_id(handshake.next_state) .expect("unknown next client state"), ); - } else { - eprintln!("HANDSHAKE ERROR"); } + Err(_) => break, } - - // Hijack server status packet - if client.state() == ClientState::Status - && packet.id == protocol::STATUS_PACKET_ID_STATUS - { - eprintln!("# PACKET STATUS"); - - // Build status response - let server_status = ServerStatus { - version: ServerVersion { - name: String::from("1.16.5"), - protocol: 754, - }, - description: Message::new(Payload::text(LABEL_SERVER_SLEEPING)), - players: OnlinePlayers { - online: 0, - max: 0, - sample: vec![], - }, - }; - let packet = StatusResponse { server_status }; - - let mut data = Vec::new(); - packet.encode(&mut data).map_err(|_| ())?; - - let response = RawPacket::new(0, data).encode()?; - client_send_queue - .send(response) - .expect("failed to queue status response"); - continue; - } - - // Hijack ping packet - if client.state() == ClientState::Status && packet.id == protocol::STATUS_PACKET_ID_PING - { - eprintln!("# PACKET PING"); - client_send_queue - .send(raw) - .expect("failed to queue ping response"); - continue; - } - - // Forward raw packet to server - wo.write_all(&raw).await.expect("failed to write to server"); } + // Hijack server status packet + if client.state() == ClientState::Status && packet.id == protocol::STATUS_PACKET_ID_STATUS { + // Build status response + // TODO: grab latest protocol version from online server! + let description = if server.starting() { + LABEL_SERVER_STARTING + } else { + LABEL_SERVER_SLEEPING + }; + let server_status = ServerStatus { + version: ServerVersion { + name: String::from("1.16.5"), + protocol: 754, + }, + description: Message::new(Payload::text(description)), + players: OnlinePlayers { + online: 0, + max: 0, + sample: vec![], + }, + }; + let packet = StatusResponse { server_status }; + + let mut data = Vec::new(); + packet.encode(&mut data).map_err(|_| ())?; + + let response = RawPacket::new(0, data).encode()?; + writer.write_all(&response).await.map_err(|_| ())?; + continue; + } + + // Hijack ping packet + if client.state() == ClientState::Status && packet.id == protocol::STATUS_PACKET_ID_PING { + writer.write_all(&raw).await.map_err(|_| ())?; + continue; + } + + // Show unhandled packet warning + eprintln!("Received unhandled packet:"); + eprintln!("- State: {:?}", client.state()); + eprintln!("- Packet ID: {}", packet.id); + } + + // Gracefully close connection + writer.shutdown().await.map_err(|_| ())?; + + Ok(()) +} + +/// Proxy the inbound stream to a target address. +// TODO: do not drop error here, return Box<dyn Error> +async fn proxy(mut inbound: TcpStream, addr_target: String) -> Result<(), ()> { + let mut outbound = TcpStream::connect(addr_target).await.map_err(|_| ())?; + + // TODO: on connect fail, ping server and redirect to status_server if offline + + let (mut ri, mut wi) = inbound.split(); + let (mut ro, mut wo) = outbound.split(); + + let client_to_server = async { + io::copy(&mut ri, &mut wo).await.map_err(|_| ())?; wo.shutdown().await.map_err(|_| ()) }; let server_to_client = async { - // Server packts to send to client, add to client sending queue - let proxy = async { - // Incoming buffer - let mut buf = BytesMut::new(); - - loop { - // In login state, simply proxy all - if client.state() == ClientState::Login { - // if true { - // if true { - eprintln!("STARTED FULL PROXY"); - - // // Wait until socket is readable - // if ro.readable().await.is_err() { - // eprintln!("Socket not readable!"); - // break; - // } - - // Forward remaining data - client_send_queue.send(buf.to_vec()).map_err(|_| ())?; - buf.clear(); - - // Keep reading until we have at least 2 bytes - loop { - // Read packet from socket - let mut tmp = Vec::new(); - ro.read_buf(&mut tmp).await.map_err(|_| ())?; - if tmp.is_empty() { - break; - } - client_send_queue.send(tmp).map_err(|_| ())?; - } - - // Forward raw packet to server - // wi.writable().await; - // io::copy(&mut ro, &mut wi).await.map_err(|_| ())?; - break; - } - - // Read packet from stream - let (_packet, raw) = match read_packet(&mut buf, &mut ro).await { - Ok(Some(packet)) => packet, - Ok(None) => { - eprintln!("Closing connection, could not read more"); - break; - } - Err(_) => { - // Forward raw packet to server - client_send_queue.send(buf.to_vec()).map_err(|_| ())?; - continue; - } - }; - - client_send_queue.send(raw).map_err(|_| ())?; - } - - Ok(()) - }; - - // Push client sending queue to client - let send_queue = async { - wi.writable().await.map_err(|_| ())?; - - while let Some(msg) = client_to_send.recv().await { - // eprintln!("TO CLIENT: {:?}", &msg); - wi.write_all(msg.as_ref()).await.map_err(|_| ())?; - } - - Ok(()) - }; - - tokio::try_join!(proxy, send_queue)?; - io::copy(&mut ro, &mut wi).await.map_err(|_| ())?; - wi.shutdown().await.map_err(|_| ()) }; diff --git a/src/monitor.rs b/src/monitor.rs index f21508f..51c618d 100644 --- a/src/monitor.rs +++ b/src/monitor.rs @@ -72,7 +72,7 @@ pub async fn poll_server(addr: SocketAddr) -> bool { /// Monitor server. pub async fn monitor_server(addr: SocketAddr, state: Arc<ServerState>) { loop { - eprint!("Polling {}: ", addr); + eprintln!("Polling {} ... ", addr); let online = poll_server(addr).await; state.set_online(online);