Add basic client state, cleanup

This commit is contained in:
timvisee 2021-11-07 14:53:19 +01:00
parent 923e172d0d
commit efae87af7d
No known key found for this signature in database
GPG Key ID: B8DB720BC383E172
3 changed files with 99 additions and 69 deletions

View File

@ -9,9 +9,11 @@ use std::error::Error;
use bytes::BytesMut;
use futures::future::poll_fn;
use futures::FutureExt;
use minecraft_protocol::data::chat::{Message, Payload};
use minecraft_protocol::data::server_status::*;
use minecraft_protocol::decoder::Decoder;
use minecraft_protocol::encoder::Encoder;
use minecraft_protocol::version::v1_14_4::status::{PingRequest, PingResponse};
use minecraft_protocol::version::v1_14_4::status::{PingRequest, PingResponse, StatusResponse};
use tokio::io;
use tokio::io::AsyncWriteExt;
use tokio::io::ReadBuf;
@ -19,12 +21,14 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::unbounded_channel;
use config::*;
use protocol::RawPacket;
use protocol::{Client, RawPacket};
#[tokio::main]
async fn main() -> Result<(), ()> {
println!("Public address: {}", ADDRESS_PUBLIC);
println!("Proxy address: {}", ADDRESS_PROXY);
println!(
"Proxying public {} to internal {}",
ADDRESS_PUBLIC, ADDRESS_PROXY
);
// Listen for new connections
// TODO: do not drop error here
@ -32,7 +36,10 @@ async fn main() -> Result<(), ()> {
// Proxy all incomming connections
while let Ok((inbound, _)) = listener.accept().await {
let transfer = proxy(inbound, ADDRESS_PROXY.to_string()).map(|r| {
let client = Client::default();
eprintln!("New client");
let transfer = proxy(client, inbound, ADDRESS_PROXY.to_string()).map(|r| {
if let Err(e) = r {
println!("Failed to proxy: {:?}", e);
}
@ -46,8 +53,7 @@ async fn main() -> Result<(), ()> {
/// Proxy the given inbound stream to a target address.
// TODO: do not drop error here, return Box<dyn Error>
async fn proxy(mut inbound: TcpStream, addr_target: String) -> Result<(), ()> {
// TODO: do not drop error here
async fn proxy(mut client: Client, mut inbound: TcpStream, addr_target: String) -> Result<(), ()> {
let mut outbound = TcpStream::connect(addr_target).await.map_err(|_| ())?;
let (mut ri, mut wi) = inbound.split();
@ -61,7 +67,6 @@ async fn proxy(mut inbound: TcpStream, addr_target: String) -> Result<(), ()> {
// Poll until we have data available
let mut poll_buf = [0; 10];
let mut poll_buf = ReadBuf::new(&mut poll_buf);
// TODO: do not drop error here!
let read = poll_fn(|cx| ri.poll_peek(cx, &mut poll_buf))
.await
.map_err(|_| ())?;
@ -69,18 +74,13 @@ async fn proxy(mut inbound: TcpStream, addr_target: String) -> Result<(), ()> {
continue;
}
// TODO: remove
// eprintln!("READ {}", read);
// Read packet from socket
let mut buf = Vec::with_capacity(64);
// TODO: do not drop error here
let read = ri.try_read_buf(&mut buf).map_err(|_| ())?;
if read == 0 {
continue;
}
// PING PACKET TEST
eprintln!("PACKET {:?}", buf.as_slice());
match RawPacket::decode(buf.as_mut_slice()) {
@ -88,14 +88,8 @@ async fn proxy(mut inbound: TcpStream, addr_target: String) -> Result<(), ()> {
eprintln!("PACKET ID: {}", packet.id);
eprintln!("PACKET DATA: {:?}", packet.data);
if packet.id == 0 {
// Catch status packet
eprintln!("PACKET STATUS");
use minecraft_protocol::data::chat::{Message, Payload};
use minecraft_protocol::data::server_status::*;
use minecraft_protocol::version::v1_14_4::status::*;
// Hijack server status packet
if packet.id == protocol::STATUS_PACKET_ID_STATUS {
// Build status response
let server_status = ServerStatus {
version: ServerVersion {
@ -109,56 +103,45 @@ async fn proxy(mut inbound: TcpStream, addr_target: String) -> Result<(), ()> {
sample: vec![],
},
};
let server_status = StatusResponse { server_status };
let status_response = StatusResponse { server_status };
let mut vec = Vec::new();
status_response.encode(&mut vec).unwrap();
let status_packet = RawPacket::new(0, vec);
let response = status_packet.encode()?;
let mut data = Vec::new();
server_status.encode(&mut data).map_err(|_| ())?;
let response = RawPacket::new(0, data).encode()?;
client_send_queue
.send(response)
.expect("failed to queue status response");
continue;
}
if packet.id == 1 {
// Catch ping packet
if let Ok(ping) = PingRequest::decode(&mut packet.data.as_slice()) {
eprintln!("PACKET PING: {}", ping.time);
let response = packet.encode()?;
client_send_queue
.send(response)
.expect("failed to queue ping response");
continue;
} else {
eprintln!("PACKET PING PARSE ERROR!");
}
// Hijack ping packet
if packet.id == protocol::STATUS_PACKET_ID_PING {
let ping =
PingRequest::decode(&mut packet.data.as_slice()).map_err(|_| ())?;
let response = packet.encode()?;
client_send_queue
.send(response)
.expect("failed to queue ping response");
continue;
}
}
Err(()) => eprintln!("ERROR PARSING PACKET"),
Err(err) => {
eprintln!("Failed to parse packet: {:?}", err);
return Err(err);
}
}
// Forward data to server
wo.write_all(&buf).await.expect("failed to write to server");
// io::copy(&mut ri, &mut wo).await?;
}
// io::copy(&mut ri, &mut wo).await?;
// TODO: do not drop error here
wo.shutdown().await.map_err(|_| ())
};
let server_to_client = async {
// let proxy = io::copy(&mut ro, &mut wi);
// Server packts to send to client, add to client sending queue
let proxy = async {
// Wait for readable state
@ -166,7 +149,6 @@ async fn proxy(mut inbound: TcpStream, addr_target: String) -> Result<(), ()> {
// Poll until we have data available
let mut poll_buf = [0; 10];
let mut poll_buf = ReadBuf::new(&mut poll_buf);
// TODO: do not drop error here
let read = poll_fn(|cx| ro.poll_peek(cx, &mut poll_buf))
.await
.map_err(|_| ())?;
@ -174,28 +156,18 @@ async fn proxy(mut inbound: TcpStream, addr_target: String) -> Result<(), ()> {
continue;
}
// TODO: remove
// eprintln!("READ {}", read);
// Read packet from socket
let mut buf = Vec::with_capacity(64);
// TODO: do not drop error here
let read = ro.try_read_buf(&mut buf).map_err(|_| ())?;
if read == 0 {
continue;
}
assert_eq!(buf.len(), read);
client_send_queue.send(buf);
// Forward data to server
// TODO: do not drop error here
// wo.write_all(&buf).await.map_err(|_| ())?;
// io::copy(&mut ri, &mut wo).await?;
}
// io::copy(&mut ri, &mut wo).await?;
Ok(())
};
@ -216,7 +188,6 @@ async fn proxy(mut inbound: TcpStream, addr_target: String) -> Result<(), ()> {
tokio::try_join!(proxy, other)?;
// TODO: do not drop error here
wi.shutdown().await.map_err(|_| ())
};

View File

@ -1,5 +1,60 @@
use crate::types;
pub const STATUS_PACKET_ID_STATUS: i32 = 0;
pub const STATUS_PACKET_ID_PING: i32 = 1;
/// Client state.
// TODO: add encryption/compression state
#[derive(Debug, Default)]
pub struct Client {
/// Current client state.
pub state: ClientState,
}
#[derive(Debug, Copy, Clone)]
pub enum ClientState {
/// Initial client state.
Handshake,
/// State to query server status.
Status,
/// State to login to server.
Login,
/// State for playing.
Play,
}
impl ClientState {
/// From state ID.
pub fn from_id(id: i32) -> Option<Self> {
match id {
// 0 => Self::Handshake,
1 => Some(Self::Status),
2 => Some(Self::Login),
// 2 => Self::Play,
_ => None,
}
}
/// Get state ID.
pub fn to_id(self) -> i32 {
match self {
Self::Handshake => unimplemented!(),
Self::Status => 1,
Self::Login => 2,
Self::Play => unimplemented!(),
}
}
}
impl Default for ClientState {
fn default() -> Self {
Self::Handshake
}
}
/// Raw Minecraft packet.
///
/// Having a packet ID and a raw data byte array.
@ -18,14 +73,14 @@ impl RawPacket {
}
/// Decode packet from raw buffer.
pub fn decode(mut buf: &mut [u8]) -> Result<Self, ()> {
pub fn decode(mut buf: &[u8]) -> Result<Self, ()> {
// Read length
let (read, len) = types::read_var_int(buf)?;
buf = &mut buf[read..][..len as usize];
buf = &buf[read..][..len as usize];
// Read packet ID, select buf
let (read, packet_id) = types::read_var_int(buf)?;
buf = &mut buf[read..];
buf = &buf[read..];
Ok(Self::new(packet_id, buf.to_vec()))
}

View File

@ -1,5 +1,9 @@
/// Try to read var-int from data stream.
pub fn read_var_int(buf: &mut [u8]) -> Result<(usize, i32), ()> {
use std::io::Read;
use bytes::BytesMut;
/// Try to read var-int from data buffer.
pub fn read_var_int(buf: &[u8]) -> Result<(usize, i32), ()> {
for len in 1..=5.min(buf.len()) {
// Find var-int byte size
let extra_byte = (buf[len - 1] & (1 >> 7)) > 0;
@ -8,7 +12,7 @@ pub fn read_var_int(buf: &mut [u8]) -> Result<(usize, i32), ()> {
}
// Select var-int bytes
let buf = &mut buf[..len];
let buf = &buf[..len];
// Parse var-int, return result
return match minecraft_protocol::decoder::var_int::decode(&mut buf.as_ref()) {