diff --git a/protocol-generator/Cargo.toml b/protocol-generator/Cargo.toml index 47169da..a4fd963 100644 --- a/protocol-generator/Cargo.toml +++ b/protocol-generator/Cargo.toml @@ -14,3 +14,4 @@ clap = "2.33.3" serde = "1.0.120" serde_json = "1.0" handlebars = "3.5.2" +heck = "0.3.2" diff --git a/protocol-generator/src/data.rs b/protocol-generator/src/data.rs new file mode 100644 index 0000000..8249979 --- /dev/null +++ b/protocol-generator/src/data.rs @@ -0,0 +1,149 @@ +use serde::Serialize; +use std::fmt; +use std::fmt::Display; + +pub enum State { + Handshake, + Status, + Login, + Game, +} + +impl Display for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let name = match self { + State::Handshake => "Handshake", + State::Status => "Status", + State::Login => "Login", + State::Game => "Game", + }; + + write!(f, "{}", name) + } +} + +pub enum Bound { + Server, + Client, +} + +impl Display for Bound { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let name = match self { + Bound::Server => "Server", + Bound::Client => "Client", + }; + + write!(f, "{}", name) + } +} + +#[derive(Serialize)] +pub struct Packet { + pub name: String, + pub fields: Vec, +} + +impl Packet { + pub fn new(name: impl ToString, fields: Vec) -> Packet { + Packet { + name: name.to_string(), + fields, + } + } +} + +#[derive(Serialize)] +pub struct Field { + pub name: String, + #[serde(flatten)] + pub data_type: DataType, +} + +impl Field { + pub fn new(name: impl ToString, data_type: DataType) -> Field { + Field { + name: name.to_string(), + data_type, + } + } +} + +#[derive(Serialize, Eq, PartialEq)] +#[serde(tag = "type")] +pub enum DataType { + #[serde(rename(serialize = "bool"))] + Boolean, + #[serde(rename(serialize = "i8"))] + Byte, + #[serde(rename(serialize = "u8"))] + UnsignedByte, + #[serde(rename(serialize = "i16"))] + Short, + #[serde(rename(serialize = "u16"))] + UnsignedShort, + #[serde(rename(serialize = "i32"))] + Int { + var_int: bool, + }, + #[serde(rename(serialize = "i64"))] + Long { + var_long: bool, + }, + #[serde(rename(serialize = "f32"))] + Float, + #[serde(rename(serialize = "f64"))] + Double, + String { + max_length: u16, + }, + Uuid { + hyphenated: bool, + }, + #[serde(rename(serialize = "Vec"))] + ByteArray { + rest: bool, + }, + CompoundTag, + RefType { + ref_name: String, + }, +} + +pub struct Protocol { + pub state: State, + pub server_bound_packets: Vec, + pub client_bound_packets: Vec, +} + +impl Protocol { + pub fn new( + state: State, + server_bound_packets: Vec, + client_bound_packets: Vec, + ) -> Protocol { + Protocol { + state, + server_bound_packets, + client_bound_packets, + } + } + + pub fn contains_field_with_type(&self, data_type: DataType) -> bool { + self.server_bound_packets + .iter() + .chain(self.client_bound_packets.iter()) + .flat_map(|p| p.fields.iter()) + .find(|f| f.data_type == data_type) + .is_some() + } + + pub fn contains_field_with_predicate bool>(&self, fun: F) -> bool { + self.server_bound_packets + .iter() + .chain(self.client_bound_packets.iter()) + .flat_map(|p| p.fields.iter()) + .find(|f| fun(*f)) + .is_some() + } +} diff --git a/protocol-generator/src/main.rs b/protocol-generator/src/main.rs index b6b0d5a..1634df7 100644 --- a/protocol-generator/src/main.rs +++ b/protocol-generator/src/main.rs @@ -1,197 +1,190 @@ -use handlebars::{Handlebars, TemplateRenderError}; -use serde::{Deserialize, Serialize}; +mod data; + +use crate::data::*; +use handlebars::*; +use heck::SnakeCase; +use serde::Serialize; use serde_json::json; -use std::fmt; -use std::fmt::Display; use std::fs::File; -use std::io::{BufWriter, Read}; - -enum State { - Handshake, - Status, - Login, - Game, -} - -impl Display for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let name = match self { - State::Handshake => "Handshake", - State::Status => "Status", - State::Login => "Login", - State::Game => "Game", - }; - - write!(f, "{}", name) - } -} - -enum Bound { - Server, - Client, -} - -impl Display for Bound { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let name = match self { - Bound::Server => "Server", - Bound::Client => "Client", - }; - - write!(f, "{}", name) - } -} - -#[derive(Serialize)] -struct Packet { - name: String, - fields: Vec, -} - -impl Packet { - pub fn new(name: impl ToString, fields: Vec) -> Packet { - Packet { - name: name.to_string(), - fields, - } - } -} - -#[derive(Serialize)] -struct Field { - name: String, - #[serde(rename(serialize = "type"))] - data_type: DataType, -} - -impl Field { - pub fn new(name: impl ToString, data_type: DataType) -> Field { - Field { - name: name.to_string(), - data_type, - } - } -} - -#[derive(Serialize)] -enum DataType { - Boolean, - Byte, - UnsignedByte, - Short, - UnsignedShort, - Int, - Long, - Float, - Double, - String, - Chat, - VarInt, - VarLong, - ByteArray, -} - -struct Protocol { - state: State, - server_bound_packets: Vec, - client_bound_packets: Vec, -} - -impl Protocol { - pub fn new( - state: State, - server_bound_packets: Vec, - client_bound_packets: Vec, - ) -> Protocol { - Protocol { - state, - server_bound_packets, - client_bound_packets, - } - } - - pub fn generate_rust_file( - &self, - template_engine: &Handlebars, - writer: &mut BufWriter<&File>, - ) -> Result<(), TemplateRenderError> { - write_protocol_enum( - writer, - &template_engine, - &self.server_bound_packets, - &Bound::Server, - &self.state, - )?; - - write_protocol_enum( - writer, - &template_engine, - &self.client_bound_packets, - &Bound::Client, - &self.state, - )?; - - Ok(()) - } -} - -fn write_protocol_enum( - writer: &mut BufWriter<&File>, - template_engine: &Handlebars, - packets: &Vec, - bound: &Bound, - state: &State, -) -> Result<(), TemplateRenderError> { - if !packets.is_empty() { - let enum_name = format!("{}{}BoundPacket", state, bound); - - let data = json!({ - "protocol_state_name": enum_name, - "packets": &packets - }); - - template_engine.render_to_write("protocol_state_enum", &data, writer)?; - } - - Ok(()) -} +use std::io::Write; pub fn main() { let mut template_engine = Handlebars::new(); + template_engine.register_helper("snake_case", Box::new(format_snake_case)); + template_engine.register_helper("packet_id", Box::new(format_packet_id)); + template_engine.register_escape_fn(|s| s.to_owned()); + template_engine .register_template_file( - "protocol_state_enum", - "protocol-generator/templates/protocol_state_enum.hbs", + "protocol_imports", + "protocol-generator/templates/protocol_imports.hbs", + ) + .expect("Failed to register template"); + + template_engine + .register_template_file( + "protocol_enum", + "protocol-generator/templates/protocol_enum.hbs", + ) + .expect("Failed to register template"); + + template_engine + .register_template_file( + "protocol_structs", + "protocol-generator/templates/protocol_structs.hbs", ) .expect("Failed to register template"); let protocol = Protocol::new( State::Login, vec![ - Packet::new("LoginStart", vec![Field::new("name", DataType::String)]), + Packet::new( + "LoginStart", + vec![Field::new("name", DataType::String { max_length: 256 })], + ), Packet::new( "EncryptionResponse", vec![ - Field::new("shared_secret", DataType::ByteArray), - Field::new("verify_token", DataType::ByteArray), + Field::new("shared_secret", DataType::ByteArray { rest: true }), + Field::new("verify_token", DataType::ByteArray { rest: true }), + ], + ), + Packet::new( + "LoginPluginResponse", + vec![ + Field::new("message_id", DataType::Int { var_int: true }), + Field::new("successful", DataType::Boolean), + Field::new("data", DataType::ByteArray { rest: true }), ], ), - Packet::new("LoginPluginResponse", vec![]), ], vec![ - Packet::new("LoginDisconnect", vec![]), - Packet::new("EncryptionRequest", vec![]), - Packet::new("LoginSuccess", vec![]), + Packet::new( + "LoginDisconnect", + vec![ + Field::new("hyphenated", DataType::Uuid { hyphenated: true }), + Field::new("default", DataType::Uuid { hyphenated: false }), + ], + ), + Packet::new( + "EncryptionRequest", + vec![Field::new( + "game_mode", + DataType::RefType { + ref_name: "GameMode".to_string(), + }, + )], + ), + Packet::new( + "LoginSuccess", + vec![Field::new( + "server_status", + DataType::RefType { + ref_name: "ServerStatus".to_string(), + }, + )], + ), Packet::new("SetCompression", vec![]), Packet::new("LoginPluginRequest", vec![]), ], ); let file = File::create("login.rs").expect("Failed to create file"); - let mut writer = BufWriter::new(&file); - protocol - .generate_rust_file(&template_engine, &mut writer) - .expect("Failed to generate rust file"); + generate_rust_file(&protocol, &template_engine, &file).expect("Failed to generate rust file"); +} + +#[derive(Serialize)] +struct GenerateContext<'a> { + protocol_enum_name: String, + packets: &'a Vec, +} + +pub fn generate_rust_file( + protocol: &Protocol, + template_engine: &Handlebars, + mut writer: W, +) -> Result<(), TemplateRenderError> { + let server_bound_ctx = GenerateContext { + protocol_enum_name: format!("{}{}BoundPacket", &protocol.state, Bound::Server), + packets: &protocol.server_bound_packets, + }; + + let client_bound_ctx = GenerateContext { + protocol_enum_name: format!("{}{}BoundPacket", &protocol.state, Bound::Client), + packets: &protocol.client_bound_packets, + }; + + let mut imports = vec![ + "crate::DecodeError", + "crate::Decoder", + "std::io::Read", + "minecraft_protocol_derive::Packet", + ]; + + if protocol.contains_field_with_predicate(|f| match f.data_type { + DataType::Uuid { .. } => true, + _ => false, + }) { + imports.push("uuid::Uuid") + } + + if protocol.contains_field_with_type(DataType::CompoundTag) { + imports.push("nbt::CompoundTag") + } + + template_engine.render_to_write( + "protocol_imports", + &json!({ "imports": imports }), + &mut writer, + )?; + + template_engine.render_to_write("protocol_enum", &server_bound_ctx, &mut writer)?; + template_engine.render_to_write("protocol_enum", &client_bound_ctx, &mut writer)?; + + template_engine.render_to_write("protocol_structs", &server_bound_ctx, &mut writer)?; + template_engine.render_to_write("protocol_structs", &client_bound_ctx, &mut writer)?; + + Ok(()) +} + +fn format_snake_case( + h: &Helper, + _: &Handlebars, + _: &Context, + _: &mut RenderContext, + out: &mut dyn Output, +) -> Result<(), RenderError> { + let str = h + .param(0) + .and_then(|v| v.value().as_str()) + .ok_or(RenderError::new( + "Param 0 with str type is required for snake case helper.", + ))? as &str; + + let snake_case_str = str.to_snake_case(); + + out.write(snake_case_str.as_ref())?; + Ok(()) +} + +fn format_packet_id( + h: &Helper, + _: &Handlebars, + _: &Context, + _: &mut RenderContext, + out: &mut dyn Output, +) -> Result<(), RenderError> { + let id = h + .param(0) + .and_then(|v| v.value().as_u64()) + .ok_or(RenderError::new( + "Param 0 with u64 type is required for packet id helper.", + ))? as u64; + + let packet_id_str = format!("{:#04X}", id); + + out.write(packet_id_str.as_ref())?; + Ok(()) } diff --git a/protocol-generator/templates/protocol_enum.hbs b/protocol-generator/templates/protocol_enum.hbs new file mode 100644 index 0000000..7915ecc --- /dev/null +++ b/protocol-generator/templates/protocol_enum.hbs @@ -0,0 +1,50 @@ + +pub enum {{protocol_enum_name}} { +{{~#each packets as |p|}} + {{p.name}}{{#if p.fields}}({{p.name}}){{/if}}{{#unless @last}},{{/unless}} +{{~/each}} +} + +impl {{protocol_enum_name}} { + pub fn get_type_id(&self) -> u8 { + match self { + {{~#each packets as |p|}} + Self::{{p.name}}{{#if p.fields}}(_){{/if}} => {{packet_id @index}}{{#unless @last}},{{/unless}} + {{~/each}} + } + } + + pub fn decode(type_id: u8, reader: &mut R) -> Result { + match type_id { + {{~#each packets as |p|}} + {{@index}} => { + {{~#if p.fields}} + let {{snake_case p.name}} = {{p.name}}::decode(reader)?; + + Ok(Self::{{p.name}}({{snake_case p.name}})) + {{~/if}} + {{~#unless p.fields}} + Ok(Self::{{p.name}}) + {{~/unless}} + } + {{~/each}} + _ => Err(DecodeError::UnknownPacketType { type_id }) + } + } +{{#each packets as |p|}} + pub fn {{snake_case p.name}}({{~#each p.fields as |f|}}{{f.name}}: {{f.type}}{{#unless @last}}, {{/unless}}{{~/each}}) -> Self { + {{~#if p.fields}} + let {{snake_case p.name}} = {{p.name}} { + {{~#each p.fields as |f|}} + {{f.name}}{{#unless @last}},{{/unless}} + {{~/each}} + }; + + Self::{{p.name}}({{snake_case p.name}}) + {{~/if}} + {{~#unless p.fields}} + Self::{{p.name}} + {{~/unless}} + } +{{/each~}} +} diff --git a/protocol-generator/templates/protocol_imports.hbs b/protocol-generator/templates/protocol_imports.hbs new file mode 100644 index 0000000..03e74a9 --- /dev/null +++ b/protocol-generator/templates/protocol_imports.hbs @@ -0,0 +1,3 @@ +{{#each imports as |i|~}} +use {{i}}; +{{/each}} diff --git a/protocol-generator/templates/protocol_state_enum.hbs b/protocol-generator/templates/protocol_state_enum.hbs deleted file mode 100644 index ed31718..0000000 --- a/protocol-generator/templates/protocol_state_enum.hbs +++ /dev/null @@ -1,27 +0,0 @@ - -pub enum {{protocol_state_name}} { -{{~#each packets as |p|}} - {{p.name}}{{#unless @last}},{{/unless}} -{{~/each}} -} - -impl {{protocol_state_name}} { - pub fn get_type_id(&self) -> u8 { - match self { - {{~#each packets as |p|}} - Self::{{p.name}} => {{@index}}{{#unless @last}},{{/unless}} - {{~/each}} - } - } - - pub fn decode(type_id: u8, reader: &mut R) -> Result { - match type_id { - {{~#each packets as |p|}} - {{@index}} => { - Ok(Self::{{p.name}}) - } - {{~/each}} - _ => Err(DecodeError::UnknownPacketType { type_id }) - } - } -} diff --git a/protocol-generator/templates/protocol_structs.hbs b/protocol-generator/templates/protocol_structs.hbs new file mode 100644 index 0000000..c6d4822 --- /dev/null +++ b/protocol-generator/templates/protocol_structs.hbs @@ -0,0 +1,24 @@ +{{~#each packets as |p|}} +{{~#if p.fields}} +#[derive(Packet, Debug)] +pub struct {{p.name}} { +{{~#each p.fields as |f|}} + {{~#if f.var_int}} + #[packet(with = "var_int")]{{/if}} + {{~#if f.var_long}} + #[packet(with = "var_long")]{{/if}} + {{~#if f.rest}} + #[packet(with = "rest")]{{/if}} + {{~#if f.hyphenated}} + #[packet(with = "uuid_hyp_str")]{{/if}} + {{~#if f.max_length}} + #[packet(max_length = {{f.max_length}})]{{/if}} + {{~#if (ne f.type "RefType")}} + pub {{f.name}}: {{f.type}}{{#unless @last}},{{/unless}} + {{~else}} + pub {{f.name}}: {{f.ref_name}}{{#unless @last}},{{/unless}} + {{~/if}} +{{~/each}} +} +{{/if}} +{{~/each}}