This commit is contained in:
timvisee 2021-11-07 23:40:32 +01:00
parent 1de68e7335
commit 6ed72b7adb
No known key found for this signature in database
GPG Key ID: B8DB720BC383E172
3 changed files with 104 additions and 98 deletions

View File

@ -3,7 +3,7 @@ extern crate log;
pub(crate) mod config;
pub(crate) mod monitor;
pub(crate) mod protocol;
pub(crate) mod proto;
pub(crate) mod server;
pub(crate) mod types;
@ -20,13 +20,11 @@ use minecraft_protocol::version::v1_14_4::handshake::Handshake;
use minecraft_protocol::version::v1_14_4::login::LoginDisconnect;
use minecraft_protocol::version::v1_14_4::status::StatusResponse;
use tokio::io;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::tcp::ReadHalf;
use tokio::net::{TcpListener, TcpStream};
use config::*;
use protocol::{Client, ClientState, RawPacket};
use proto::{Client, ClientState, RawPacket};
use server::ServerState;
#[tokio::main]
@ -35,31 +33,23 @@ async fn main() -> Result<(), ()> {
let _ = dotenv::dotenv();
pretty_env_logger::init();
info!(
"Proxying public {} to internal {}",
ADDRESS_PUBLIC, ADDRESS_PROXY,
);
let server_state = Arc::new(ServerState::default());
// Listen for new connections
// TODO: do not drop error here
let listener = TcpListener::bind(ADDRESS_PUBLIC).await.map_err(|_| ())?;
let listener = TcpListener::bind(ADDRESS_PUBLIC).await.map_err(|err| {
error!("Failed to start: {}", err);
()
})?;
// Spawn server monitor
let addr = ADDRESS_PROXY.parse().expect("invalid server IP");
tokio::spawn(monitor::monitor_server(addr, server_state.clone()));
info!(
"Proxying egress {} to ingress {}",
ADDRESS_PUBLIC, ADDRESS_PROXY,
);
let sub = server_state.clone();
tokio::spawn(async move {
loop {
tokio::signal::ctrl_c().await.unwrap();
if !sub.kill_server() {
// TODO: gracefully kill itself instead
std::process::exit(1)
}
}
});
// Spawn server monitor and signal handler
tokio::spawn(server_monitor(server_state.clone()));
tokio::spawn(signal_handler(server_state.clone()));
// Proxy all incomming connections
while let Ok((inbound, _)) = listener.accept().await {
@ -67,7 +57,7 @@ async fn main() -> Result<(), ()> {
if !server_state.online() {
// When server is not online, spawn a status server
let transfer = status_server(client, inbound, server_state.clone()).map(|r| {
let transfer = serve_status(client, inbound, server_state.clone()).map(|r| {
if let Err(err) = r {
error!("Failed to serve status: {:?}", err);
}
@ -78,7 +68,7 @@ async fn main() -> Result<(), ()> {
// When server is online, proxy all
let transfer = proxy(inbound, ADDRESS_PROXY.to_string()).map(|r| {
if let Err(err) = r {
error!("Failed to proxy: {:?}", err);
error!("Failed to proxy: {}", err);
}
});
@ -89,70 +79,26 @@ async fn main() -> Result<(), ()> {
Ok(())
}
/// Read raw packet from stream.
pub async fn read_packet<'a>(
buf: &mut BytesMut,
stream: &mut ReadHalf<'a>,
) -> Result<Option<(RawPacket, Vec<u8>)>, ()> {
// Keep reading until we have at least 2 bytes
while buf.len() < 2 {
// Read packet from socket
let mut tmp = Vec::with_capacity(64);
match stream.read_buf(&mut tmp).await {
Ok(_) => {}
Err(err) if err.kind() == io::ErrorKind::ConnectionReset => return Ok(None),
Err(err) => {
dbg!(err);
return Err(());
}
/// Signal handler task.
pub async fn signal_handler(server_state: Arc<ServerState>) {
loop {
tokio::signal::ctrl_c().await.unwrap();
if !server_state.kill_server() {
// TODO: gracefully kill itself instead
std::process::exit(1)
}
if tmp.is_empty() {
return Ok(None);
}
buf.extend(tmp);
}
}
// Attempt to read packet length
let (consumed, len) = match types::read_var_int(&buf) {
Ok(result) => result,
Err(err) => {
error!("Failed to read packet length, should retry!");
error!("{:?}", (&buf).as_ref());
return Err(err);
}
};
// Keep reading until we have all packet bytes
while buf.len() < consumed + len as usize {
// Read packet from socket
let mut tmp = Vec::with_capacity(64);
match stream.read_buf(&mut tmp).await {
Ok(_) => {}
Err(err) if err.kind() == io::ErrorKind::ConnectionReset => return Ok(None),
Err(err) => {
dbg!(err);
return Err(());
}
}
if tmp.is_empty() {
return Ok(None);
}
buf.extend(tmp);
}
// Parse packet
let raw = buf.split_to(consumed + len as usize);
let packet = RawPacket::decode(&raw)?;
Ok(Some((packet, raw.to_vec())))
/// Server monitor task.
pub async fn server_monitor(state: Arc<ServerState>) {
let addr = ADDRESS_PROXY.parse().expect("invalid server IP");
monitor::monitor_server(addr, state).await
}
/// Proxy the given inbound stream to a target address.
// TODO: do not drop error here, return Box<dyn Error>
async fn status_server(
async fn serve_status(
client: Client,
mut inbound: TcpStream,
server: Arc<ServerState>,
@ -164,7 +110,7 @@ async fn status_server(
loop {
// Read packet from stream
let (packet, raw) = match read_packet(&mut buf, &mut reader).await {
let (packet, raw) = match proto::read_packet(&mut buf, &mut reader).await {
Ok(Some(packet)) => packet,
Ok(None) => break,
Err(_) => {
@ -174,9 +120,7 @@ async fn status_server(
};
// Hijack login start
if client.state() == ClientState::Login
&& packet.id == protocol::LOGIN_PACKET_ID_LOGIN_START
{
if client.state() == ClientState::Login && packet.id == proto::LOGIN_PACKET_ID_LOGIN_START {
let packet = LoginDisconnect {
reason: Message::new(Payload::text(LABEL_SERVER_STARTING_MESSAGE)),
};
@ -198,9 +142,7 @@ async fn status_server(
}
// Hijack handshake
if client.state() == ClientState::Handshake
&& packet.id == protocol::STATUS_PACKET_ID_STATUS
{
if client.state() == ClientState::Handshake && packet.id == proto::STATUS_PACKET_ID_STATUS {
match Handshake::decode(&mut packet.data.as_slice()) {
Ok(handshake) => {
// TODO: do not panic here
@ -214,7 +156,7 @@ async fn status_server(
}
// Hijack server status packet
if client.state() == ClientState::Status && packet.id == protocol::STATUS_PACKET_ID_STATUS {
if client.state() == ClientState::Status && packet.id == proto::STATUS_PACKET_ID_STATUS {
// Build status response
// TODO: grab latest protocol version from online server!
let description = if server.starting() {
@ -245,7 +187,7 @@ async fn status_server(
}
// Hijack ping packet
if client.state() == ClientState::Status && packet.id == protocol::STATUS_PACKET_ID_PING {
if client.state() == ClientState::Status && packet.id == proto::STATUS_PACKET_ID_PING {
writer.write_all(&raw).await.map_err(|_| ())?;
continue;
}
@ -267,10 +209,9 @@ async fn status_server(
}
/// Proxy the inbound stream to a target address.
// TODO: do not drop error here, return Box<dyn Error>
async fn proxy(mut inbound: TcpStream, addr_target: String) -> Result<(), Box<dyn Error>> {
// Set up connection to server
// TODO: on connect fail, ping server and redirect to status_server if offline
// TODO: on connect fail, ping server and redirect to serve_status if offline
let mut outbound = TcpStream::connect(addr_target).await?;
let (mut ri, mut wi) = inbound.split();

View File

@ -13,7 +13,7 @@ use rand::Rng;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use crate::protocol::{self, ClientState, RawPacket};
use crate::proto::{self, ClientState, RawPacket};
use crate::server::ServerState;
/// Minecraft protocol version used when polling server status.
@ -72,7 +72,7 @@ async fn send_handshake(stream: &mut TcpStream, addr: SocketAddr) -> Result<(),
let mut packet = Vec::new();
handshake.encode(&mut packet).map_err(|_| ())?;
let raw = RawPacket::new(protocol::HANDSHAKE_PACKET_ID_HANDSHAKE, packet)
let raw = RawPacket::new(proto::HANDSHAKE_PACKET_ID_HANDSHAKE, packet)
.encode()
.map_err(|_| ())?;
@ -93,7 +93,7 @@ async fn send_ping(stream: &mut TcpStream) -> Result<u64, ()> {
let mut packet = Vec::new();
ping.encode(&mut packet).map_err(|_| ())?;
let raw = RawPacket::new(protocol::STATUS_PACKET_ID_PING, packet)
let raw = RawPacket::new(proto::STATUS_PACKET_ID_PING, packet)
.encode()
.map_err(|_| ())?;
@ -110,14 +110,14 @@ async fn wait_for_ping(stream: &mut TcpStream, token: u64) -> Result<(), ()> {
loop {
// Read packet from stream
let (packet, _raw) = match crate::read_packet(&mut buf, &mut reader).await {
let (packet, _raw) = match proto::read_packet(&mut buf, &mut reader).await {
Ok(Some(packet)) => packet,
Ok(None) => break,
Err(_) => continue,
};
// Catch ping response
if packet.id == protocol::STATUS_PACKET_ID_PING {
if packet.id == proto::STATUS_PACKET_ID_PING {
let ping = PingResponse::decode(&mut packet.data.as_slice()).map_err(|_| ())?;
// Ensure ping token is correct

View File

@ -1,5 +1,10 @@
use std::sync::Mutex;
use bytes::BytesMut;
use tokio::io;
use tokio::io::AsyncReadExt;
use tokio::net::tcp::ReadHalf;
use crate::types;
pub const HANDSHAKE_PACKET_ID_HANDSHAKE: i32 = 0;
@ -113,3 +118,63 @@ impl RawPacket {
return Ok(packet);
}
}
/// Read raw packet from stream.
pub async fn read_packet<'a>(
buf: &mut BytesMut,
stream: &mut ReadHalf<'a>,
) -> Result<Option<(RawPacket, Vec<u8>)>, ()> {
// Keep reading until we have at least 2 bytes
while buf.len() < 2 {
// Read packet from socket
let mut tmp = Vec::with_capacity(64);
match stream.read_buf(&mut tmp).await {
Ok(_) => {}
Err(err) if err.kind() == io::ErrorKind::ConnectionReset => return Ok(None),
Err(err) => {
dbg!(err);
return Err(());
}
}
if tmp.is_empty() {
return Ok(None);
}
buf.extend(tmp);
}
// Attempt to read packet length
let (consumed, len) = match types::read_var_int(&buf) {
Ok(result) => result,
Err(err) => {
error!("Malformed packet, could not read packet length");
return Err(err);
}
};
// Keep reading until we have all packet bytes
while buf.len() < consumed + len as usize {
// Read packet from socket
let mut tmp = Vec::with_capacity(64);
match stream.read_buf(&mut tmp).await {
Ok(_) => {}
Err(err) if err.kind() == io::ErrorKind::ConnectionReset => return Ok(None),
Err(err) => {
dbg!(err);
return Err(());
}
}
if tmp.is_empty() {
return Ok(None);
}
buf.extend(tmp);
}
// Parse packet
let raw = buf.split_to(consumed + len as usize);
let packet = RawPacket::decode(&raw)?;
Ok(Some((packet, raw.to_vec())))
}