diff --git a/Cargo.lock b/Cargo.lock index 515afb0..4bd7264 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -247,7 +247,7 @@ checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" [[package]] name = "minecraft-protocol" version = "0.1.0" -source = "git+https://github.com/eihwaz/minecraft-protocol?rev=09f95625eff272794590e1ef73b2a1b673f47f50#09f95625eff272794590e1ef73b2a1b673f47f50" +source = "git+https://github.com/timvisee/minecraft-protocol?rev=c578492#c57849246166add5ad45ef36b3bdebd8b744883d" dependencies = [ "byteorder", "minecraft-protocol-derive", @@ -260,7 +260,7 @@ dependencies = [ [[package]] name = "minecraft-protocol-derive" version = "0.0.0" -source = "git+https://github.com/eihwaz/minecraft-protocol?rev=09f95625eff272794590e1ef73b2a1b673f47f50#09f95625eff272794590e1ef73b2a1b673f47f50" +source = "git+https://github.com/timvisee/minecraft-protocol?rev=c578492#c57849246166add5ad45ef36b3bdebd8b744883d" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 339b3d6..8d1c18c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,5 +6,5 @@ edition = "2018" [dependencies] bytes = "1.1" futures = "0.3" -minecraft-protocol = { git = "https://github.com/eihwaz/minecraft-protocol", rev = "09f95625eff272794590e1ef73b2a1b673f47f50" } +minecraft-protocol = { git = "https://github.com/timvisee/minecraft-protocol", rev = "c578492" } tokio = { version = "1", features = ["full"] } diff --git a/src/main.rs b/src/main.rs index cba28b8..08f79b2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,19 +9,24 @@ use std::error::Error; use bytes::BytesMut; use futures::future::poll_fn; use futures::FutureExt; +use futures::TryFutureExt; use minecraft_protocol::data::chat::{Message, Payload}; use minecraft_protocol::data::server_status::*; use minecraft_protocol::decoder::Decoder; use minecraft_protocol::encoder::Encoder; +use minecraft_protocol::version::v1_14_4::handshake::Handshake; use minecraft_protocol::version::v1_14_4::status::{PingRequest, PingResponse, StatusResponse}; use tokio::io; +use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::io::ReadBuf; +use tokio::net::tcp::ReadHalf; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::unbounded_channel; use config::*; -use protocol::{Client, RawPacket}; +use protocol::{Client, ClientState, RawPacket}; #[tokio::main] async fn main() -> Result<(), ()> { @@ -51,6 +56,57 @@ async fn main() -> Result<(), ()> { Ok(()) } +/// Read raw packet from stream. +async fn read_packet<'a>( + buf: &mut BytesMut, + stream: &mut ReadHalf<'a>, +) -> Result)>, ()> { + // // Wait until socket is readable + // if stream.readable().await.is_err() { + // eprintln!("Socket not readable!"); + // return Ok(None); + // } + + // Keep reading until we have at least 2 bytes + while buf.len() < 2 { + // Read packet from socket + let mut tmp = Vec::with_capacity(64); + stream.read_buf(&mut tmp).await; + if tmp.is_empty() { + return Ok(None); + } + buf.extend(tmp); + } + + // Attempt to read packet length + let (consumed, len) = match types::read_var_int(&buf) { + Ok(result) => result, + Err(err) => { + eprintln!("Failed to read packet length, should retry!"); + eprintln!("{:?}", (&buf).as_ref()); + return Err(err); + } + }; + + // Keep reading until we have all packet bytes + while buf.len() < consumed + len as usize { + // Read packet from socket + let mut tmp = Vec::with_capacity(64); + stream.read_buf(&mut tmp).await; + if tmp.is_empty() { + return Ok(None); + } + + buf.extend(tmp); + } + + // Parse packet + let raw = buf.split_to(consumed + len as usize); + let packet = RawPacket::decode(&raw)?; + + Ok(Some((packet, raw.to_vec()))) +} + /// Proxy the given inbound stream to a target address. // TODO: do not drop error here, return Box async fn proxy(mut client: Client, mut inbound: TcpStream, addr_target: String) -> Result<(), ()> { @@ -62,131 +118,186 @@ async fn proxy(mut client: Client, mut inbound: TcpStream, addr_target: String) let (client_send_queue, mut client_to_send) = unbounded_channel::>(); let client_to_server = async { - // Wait for readable state - while ri.readable().await.is_ok() { - // Poll until we have data available - let mut poll_buf = [0; 10]; - let mut poll_buf = ReadBuf::new(&mut poll_buf); - let read = poll_fn(|cx| ri.poll_peek(cx, &mut poll_buf)) - .await - .map_err(|_| ())?; - if read == 0 { - continue; + // Incoming buffer + let mut buf = BytesMut::new(); + + loop { + // In login state, proxy raw data + if client.state() == ClientState::Login { + eprintln!("STARTED FULL PROXY"); + + wo.writable().await; + + // Forward remaining buffer + wo.write_all(&buf).await.map_err(|_| ())?; + buf.clear(); + + // Forward rest of data + io::copy(&mut ri, &mut wo).await.map_err(|_| ())?; + break; } - // Read packet from socket - let mut buf = Vec::with_capacity(64); - let read = ri.try_read_buf(&mut buf).map_err(|_| ())?; - if read == 0 { - continue; - } - - eprintln!("PACKET {:?}", buf.as_slice()); - - match RawPacket::decode(buf.as_mut_slice()) { - Ok(packet) => { - eprintln!("PACKET ID: {}", packet.id); - eprintln!("PACKET DATA: {:?}", packet.data); - - // Hijack server status packet - if packet.id == protocol::STATUS_PACKET_ID_STATUS { - // Build status response - let server_status = ServerStatus { - version: ServerVersion { - name: String::from("1.16.5"), - protocol: 754, - }, - description: Message::new(Payload::text(LABEL_SERVER_SLEEPING)), - players: OnlinePlayers { - online: 0, - max: 0, - sample: vec![], - }, - }; - let server_status = StatusResponse { server_status }; - - let mut data = Vec::new(); - server_status.encode(&mut data).map_err(|_| ())?; - - let response = RawPacket::new(0, data).encode()?; - client_send_queue - .send(response) - .expect("failed to queue status response"); - continue; - } - - // Hijack ping packet - if packet.id == protocol::STATUS_PACKET_ID_PING { - let ping = - PingRequest::decode(&mut packet.data.as_slice()).map_err(|_| ())?; - let response = packet.encode()?; - client_send_queue - .send(response) - .expect("failed to queue ping response"); - continue; - } + // Read packet from stream + let (packet, raw) = match read_packet(&mut buf, &mut ri).await { + Ok(Some(packet)) => packet, + Ok(None) => { + eprintln!("Closing connection, could not read more"); + break; } - Err(err) => { - eprintln!("Failed to parse packet: {:?}", err); - return Err(err); + Err(_) => { + // Forward raw packet to server + wo.write_all(&buf).await.expect("failed to write to server"); + buf.clear(); + continue; + } + }; + + // Show packet details + eprintln!("PACKET {:?}", raw.as_slice()); + eprintln!("PACKET ID: {}", packet.id); + eprintln!("PACKET DATA: {:?}", packet.data); + + // Hijack handshake + if client.state() == ClientState::Handshake + && packet.id == protocol::STATUS_PACKET_ID_STATUS + { + if let Ok(handshake) = Handshake::decode(&mut packet.data.as_slice()) { + eprintln!("# PACKET HANDSHAKE"); + eprintln!("SWITCHING CLIENT STATE: {}", handshake.next_state); + + // TODO: do not panic here + client.set_state( + ClientState::from_id(handshake.next_state) + .expect("unknown next client state"), + ); + } else { + eprintln!("HANDSHAKE ERROR"); } } - // Forward data to server - wo.write_all(&buf).await.expect("failed to write to server"); + // Hijack server status packet + if client.state() == ClientState::Status + && packet.id == protocol::STATUS_PACKET_ID_STATUS + { + eprintln!("# PACKET STATUS"); + + // Build status response + let server_status = ServerStatus { + version: ServerVersion { + name: String::from("1.16.5"), + protocol: 754, + }, + description: Message::new(Payload::text(LABEL_SERVER_SLEEPING)), + players: OnlinePlayers { + online: 0, + max: 0, + sample: vec![], + }, + }; + let server_status = StatusResponse { server_status }; + + let mut data = Vec::new(); + server_status.encode(&mut data).map_err(|_| ())?; + + let response = RawPacket::new(0, data).encode()?; + client_send_queue + .send(response) + .expect("failed to queue status response"); + continue; + } + + // Hijack ping packet + if client.state() == ClientState::Status && packet.id == protocol::STATUS_PACKET_ID_PING + { + eprintln!("# PACKET PING"); + client_send_queue + .send(raw) + .expect("failed to queue ping response"); + continue; + } + + // Forward raw packet to server + wo.write_all(&raw).await.expect("failed to write to server"); } - // io::copy(&mut ri, &mut wo).await?; - wo.shutdown().await.map_err(|_| ()) }; let server_to_client = async { // Server packts to send to client, add to client sending queue let proxy = async { - // Wait for readable state - while ro.readable().await.is_ok() { - // Poll until we have data available - let mut poll_buf = [0; 10]; - let mut poll_buf = ReadBuf::new(&mut poll_buf); - let read = poll_fn(|cx| ro.poll_peek(cx, &mut poll_buf)) - .await - .map_err(|_| ())?; - if read == 0 { - continue; + // Incoming buffer + let mut buf = BytesMut::new(); + + loop { + // In login state, simply proxy all + if client.state() == ClientState::Login { + // if true { + // if true { + eprintln!("STARTED FULL PROXY"); + + // // Wait until socket is readable + // if ro.readable().await.is_err() { + // eprintln!("Socket not readable!"); + // break; + // } + + // Forward remaining data + client_send_queue.send(buf.to_vec()); + buf.clear(); + + // Keep reading until we have at least 2 bytes + loop { + // Read packet from socket + let mut tmp = Vec::new(); + ro.read_buf(&mut tmp).await; + if tmp.is_empty() { + break; + } + client_send_queue.send(tmp); + } + + // Forward raw packet to server + // wi.writable().await; + // io::copy(&mut ro, &mut wi).await.map_err(|_| ())?; + break; } - // Read packet from socket - let mut buf = Vec::with_capacity(64); - let read = ro.try_read_buf(&mut buf).map_err(|_| ())?; - if read == 0 { - continue; - } + // Read packet from stream + let (packet, raw) = match read_packet(&mut buf, &mut ro).await { + Ok(Some(packet)) => packet, + Ok(None) => { + eprintln!("Closing connection, could not read more"); + break; + } + Err(_) => { + // Forward raw packet to server + client_send_queue.send(buf.to_vec()); + continue; + } + }; - client_send_queue.send(buf); + client_send_queue.send(raw); } - // io::copy(&mut ri, &mut wo).await?; - Ok(()) }; // Push client sending queue to client - let other = async { - loop { - let msg = poll_fn(|cx| client_to_send.poll_recv(cx)) - .await - .expect("failed to poll_fn"); + let send_queue = async { + wi.writable().await.map_err(|_| ())?; - wi.write_all(msg.as_ref()) - .await - .expect("failed to write to client"); + while let Some(msg) = client_to_send.recv().await { + // eprintln!("TO CLIENT: {:?}", &msg); + wi.write_all(msg.as_ref()).await.map_err(|_| ())?; } Ok(()) }; - tokio::try_join!(proxy, other)?; + tokio::try_join!(proxy, send_queue)?; + + io::copy(&mut ro, &mut wi).await.map_err(|_| ())?; wi.shutdown().await.map_err(|_| ()) }; diff --git a/src/protocol.rs b/src/protocol.rs index 06b567a..7073e59 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1,3 +1,5 @@ +use std::sync::Mutex; + use crate::types; pub const STATUS_PACKET_ID_STATUS: i32 = 0; @@ -8,10 +10,22 @@ pub const STATUS_PACKET_ID_PING: i32 = 1; #[derive(Debug, Default)] pub struct Client { /// Current client state. - pub state: ClientState, + pub state: Mutex, } -#[derive(Debug, Copy, Clone)] +impl Client { + /// Get client state. + pub fn state(&self) -> ClientState { + *self.state.lock().unwrap() + } + + /// Set client state. + pub fn set_state(&self, state: ClientState) { + *self.state.lock().unwrap() = state; + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum ClientState { /// Initial client state. Handshake, @@ -78,6 +92,11 @@ impl RawPacket { let (read, len) = types::read_var_int(buf)?; buf = &buf[read..][..len as usize]; + Self::decode_data(len, buf) + } + + /// Decode packet from raw buffer without the length header. + pub fn decode_data(len: i32, mut buf: &[u8]) -> Result { // Read packet ID, select buf let (read, packet_id) = types::read_var_int(buf)?; buf = &buf[read..];