diff --git a/protocol-derive/src/error.rs b/protocol-derive/src/error.rs new file mode 100644 index 0000000..16d8787 --- /dev/null +++ b/protocol-derive/src/error.rs @@ -0,0 +1,45 @@ +use syn::Error as SynError; + +/// Possible errors while deriving. +#[derive(Debug)] +pub(crate) enum DeriveInputParserError { + /// Derive attribute must be placed on a structure or enum. + UnsupportedData, + /// Data fields must be named. + UnnamedDataFields, + FieldError { + field_error: FieldError, + }, +} + +/// Possible errors while parsing field. +#[derive(Debug)] +pub(crate) enum FieldError { + /// Failed to parse field meta due incorrect syntax. + BadAttributeSyntax { syn_error: SynError }, + /// Unsupported field attribute type. + UnsupportedAttribute, + /// Field meta has wrong value type. + /// For example an int was expected, but a string was supplied. + AttributeWrongValueType, +} + +impl From for DeriveInputParserError { + fn from(field_error: FieldError) -> Self { + DeriveInputParserError::FieldError { field_error } + } +} + +impl From for DeriveInputParserError { + fn from(syn_error: SynError) -> Self { + DeriveInputParserError::FieldError { + field_error: FieldError::BadAttributeSyntax { syn_error }, + } + } +} + +impl From for FieldError { + fn from(syn_error: SynError) -> Self { + FieldError::BadAttributeSyntax { syn_error } + } +} diff --git a/protocol-derive/src/lib.rs b/protocol-derive/src/lib.rs index ef4bafc..f0e26da 100644 --- a/protocol-derive/src/lib.rs +++ b/protocol-derive/src/lib.rs @@ -1,185 +1,28 @@ extern crate proc_macro; -use proc_macro::TokenStream as TokenStream1; -use proc_macro2::TokenStream as TokenStream2; -use proc_macro2::{Ident, Span}; -use quote::{quote, TokenStreamExt}; -use std::iter::FromIterator; -use syn::{parse_macro_input, Data, DeriveInput, Field, Fields, Lit, Meta, NestedMeta}; +use crate::parse::parse_derive_input; +use crate::render::decoder::render_decoder; +use crate::render::encoder::render_encoder; +use proc_macro::TokenStream; +use syn::parse_macro_input; +use syn::DeriveInput; -#[proc_macro_derive(Packet, attributes(packet))] -pub fn derive_packet(input: proc_macro::TokenStream) -> TokenStream1 { - let input = parse_macro_input!(input as DeriveInput); - let name = &input.ident; +mod error; +mod parse; +mod render; - match input.data { - Data::Struct(data) => { - let fields = &data.fields; +#[proc_macro_derive(Encoder, attributes(data_type))] +pub fn derive_encoder(tokens: TokenStream) -> TokenStream { + let input = parse_macro_input!(tokens as DeriveInput); + let (name, fields) = parse_derive_input(&input).expect("Failed to parse derive input"); - let encoder = impl_encoder_trait(name, fields); - let decoder = impl_decoder_trait(name, fields); - - TokenStream1::from(quote! { - #encoder - - #decoder - }) - } - _ => panic!("Expected only structures"), - } + TokenStream::from(render_encoder(name, &fields)) } -fn impl_encoder_trait(name: &Ident, fields: &Fields) -> TokenStream2 { - let encode = quote_field(fields, |field| { - let name = &field.ident; +#[proc_macro_derive(Decoder, attributes(data_type))] +pub fn derive_decoder(tokens: TokenStream) -> TokenStream { + let input = parse_macro_input!(tokens as DeriveInput); + let (name, fields) = parse_derive_input(&input).expect("Failed to parse derive input"); - let unparsed_meta = get_packet_field_meta(field); - let parsed_meta = parse_packet_field_meta(&unparsed_meta); - - // This is special case because max length are used only for strings. - if let Some(max_length) = parsed_meta.max_length { - return quote! { - crate::encoder::EncoderWriteExt::write_string(writer, &self.#name, #max_length)?; - }; - } - - let module = parsed_meta.module.as_deref().unwrap_or("Encoder"); - let module_ident = Ident::new(&module, Span::call_site()); - - quote! { - crate::encoder::#module_ident::encode(&self.#name, writer)?; - } - }); - - quote! { - #[automatically_derived] - impl crate::encoder::Encoder for #name { - fn encode(&self, writer: &mut W) -> Result<(), crate::error::EncodeError> { - #encode - - Ok(()) - } - } - } -} - -fn impl_decoder_trait(name: &Ident, fields: &Fields) -> TokenStream2 { - let decode = quote_field(fields, |field| { - let name = &field.ident; - let ty = &field.ty; - - let unparsed_meta = get_packet_field_meta(field); - let parsed_meta = parse_packet_field_meta(&unparsed_meta); - - // This is special case because max length are used only for strings. - if let Some(max_length) = parsed_meta.max_length { - return quote! { - let #name = crate::decoder::DecoderReadExt::read_string(reader, #max_length)?; - }; - } - - match parsed_meta.module { - Some(module) => { - let module_ident = Ident::new(&module, Span::call_site()); - - quote! { - let #name = crate::decoder::#module_ident::decode(reader)?; - } - } - None => { - quote! { - let #name = <#ty as crate::decoder::Decoder>::decode(reader)?; - } - } - } - }); - - let create = quote_field(fields, |field| { - let name = &field.ident; - - quote! { - #name, - } - }); - - quote! { - #[automatically_derived] - impl crate::decoder::Decoder for #name { - type Output = Self; - - fn decode(reader: &mut R) -> Result { - #decode - - Ok(#name { - #create - }) - } - } - } -} - -#[derive(Debug)] -struct PacketFieldMeta { - module: Option, - max_length: Option, -} - -fn parse_packet_field_meta(meta_list: &Vec) -> PacketFieldMeta { - let mut module = None; - let mut max_length = None; - - for meta in meta_list { - match meta { - NestedMeta::Meta(Meta::NameValue(named_meta)) => match &named_meta.path { - path if path.is_ident("with") => match &named_meta.lit { - Lit::Str(lit_str) => module = Some(lit_str.value()), - _ => panic!("\"with\" attribute value must be string"), - }, - path if path.is_ident("max_length") => match &named_meta.lit { - Lit::Int(lit_int) => { - max_length = Some( - lit_int - .base10_parse::() - .expect("Failed to parse max length attribute"), - ) - } - _ => panic!("\"max_length\" attribute value must be integer"), - }, - path => panic!( - "Received unrecognized attribute : \"{}\"", - path.get_ident().unwrap() - ), - }, - _ => panic!("Expected only named meta values"), - } - } - - PacketFieldMeta { module, max_length } -} - -fn get_packet_field_meta(field: &Field) -> Vec { - field - .attrs - .iter() - .filter(|a| a.path.is_ident("packet")) - .map(|a| a.parse_meta().expect("Failed to parse field attribute")) - .map(|m| match m { - Meta::List(meta_list) => Vec::from_iter(meta_list.nested), - _ => panic!("Expected only list attributes"), - }) - .flatten() - .collect() -} - -fn quote_field TokenStream2>(fields: &Fields, func: F) -> TokenStream2 { - let mut output = quote!(); - - match fields { - Fields::Named(named_fields) => { - output.append_all(named_fields.named.iter().map(|f| func(f))) - } - _ => panic!("Expected only for named fields"), - } - - output + TokenStream::from(render_decoder(name, &fields)) } diff --git a/protocol-derive/src/parse.rs b/protocol-derive/src/parse.rs new file mode 100644 index 0000000..ee466ab --- /dev/null +++ b/protocol-derive/src/parse.rs @@ -0,0 +1,118 @@ +use crate::error::{DeriveInputParserError, FieldError}; +use proc_macro2::Ident; +use std::iter::FromIterator; +use syn::Error as SynError; +use syn::{Data, DeriveInput, Field, Fields, FieldsNamed, Lit, Meta, NestedMeta, Type}; + +pub(crate) struct FieldData<'a> { + pub(crate) name: &'a Ident, + pub(crate) ty: &'a Type, + pub(crate) attribute: Attribute, +} + +#[derive(Debug, Eq, PartialEq)] +pub(crate) enum Attribute { + With { module: String }, + MaxLength { length: usize }, + Empty, +} + +pub(crate) fn parse_derive_input( + input: &DeriveInput, +) -> Result<(&Ident, Vec), DeriveInputParserError> { + let name = &input.ident; + + match &input.data { + Data::Struct(data) => match &data.fields { + Fields::Named(named_fields) => Ok((name, parse_fields(named_fields)?)), + _ => Err(DeriveInputParserError::UnnamedDataFields), + }, + _ => Err(DeriveInputParserError::UnsupportedData), + } +} + +fn parse_fields(named_fields: &FieldsNamed) -> Result, DeriveInputParserError> { + let mut fields_data = Vec::new(); + + for field in named_fields.named.iter() { + let name = field.ident.as_ref().unwrap(); + let ty = &field.ty; + + let nested_metas = parse_field_nested_metas(field)?; + let attribute = parse_attribute(nested_metas)?; + + fields_data.push(FieldData { + name, + ty, + attribute, + }) + } + + Ok(fields_data) +} + +fn parse_field_nested_metas(field: &Field) -> Result, DeriveInputParserError> { + let parsed_metas = field + .attrs + .iter() + .filter(|a| a.path.is_ident("data_type")) + .map(|a| a.parse_meta()) + .collect::, SynError>>()?; + + let nested_metas = parsed_metas + .into_iter() + .map(|m| match m { + Meta::List(meta_list) => Ok(Vec::from_iter(meta_list.nested)), + _ => Err(FieldError::UnsupportedAttribute), + }) + .collect::>, FieldError>>()?; + + Ok(nested_metas.into_iter().flatten().collect()) +} + +fn parse_attribute(nested_metas: Vec) -> Result { + let attribute_parsers: Vec Result> = + vec![get_module_attribute, get_max_length_attribute]; + + for nested_meta in nested_metas.iter() { + for attribute_parser in attribute_parsers.iter() { + let attribute = attribute_parser(nested_meta)?; + + if attribute != Attribute::Empty { + return Ok(attribute); + } + } + } + + Ok(Attribute::Empty) +} + +fn get_module_attribute(nested_meta: &NestedMeta) -> Result { + if let NestedMeta::Meta(Meta::NameValue(named_meta)) = nested_meta { + if matches!(&named_meta.path, path if path.is_ident("with")) { + return match &named_meta.lit { + Lit::Str(lit_str) => Ok(Attribute::With { + module: lit_str.value(), + }), + _ => Err(FieldError::AttributeWrongValueType), + }; + } + } + + Ok(Attribute::Empty) +} + +fn get_max_length_attribute(nested_meta: &NestedMeta) -> Result { + if let NestedMeta::Meta(Meta::NameValue(named_meta)) = nested_meta { + if matches!(&named_meta.path, path if path.is_ident("max_length")) { + return match &named_meta.lit { + Lit::Int(lit_int) => Ok(Attribute::MaxLength { + length: lit_int.base10_parse::()?, + }), + _ => Err(FieldError::AttributeWrongValueType), + }; + } + } + + Ok(Attribute::Empty) +} diff --git a/protocol-derive/src/render/decoder.rs b/protocol-derive/src/render/decoder.rs new file mode 100644 index 0000000..c9fa425 --- /dev/null +++ b/protocol-derive/src/render/decoder.rs @@ -0,0 +1,72 @@ +use crate::parse::{Attribute, FieldData}; +use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::{Ident, Span}; +use quote::quote; +use syn::Type; + +pub(crate) fn render_decoder(name: &Ident, fields: &Vec) -> TokenStream2 { + let struct_create = render_struct_create(name, fields); + let render_fields = render_fields(fields); + + quote! { + #[automatically_derived] + impl crate::decoder::Decoder for #name { + type Output = Self; + + fn decode(reader: &mut R) -> Result { + #render_fields + + Ok(#struct_create) + } + } + } +} + +fn render_struct_create(name: &Ident, fields: &Vec) -> TokenStream2 { + let struct_fields = fields + .iter() + .map(|f| f.name) + .map(|n| quote!(#n,)) + .collect::(); + + quote! { + #name { + #struct_fields + } + } +} + +fn render_fields(fields: &Vec) -> TokenStream2 { + fields.iter().map(|f| render_field(f)).flatten().collect() +} + +fn render_field(field: &FieldData) -> TokenStream2 { + let name = field.name; + let ty = field.ty; + + match &field.attribute { + Attribute::With { module } => render_with_field(name, module), + Attribute::MaxLength { length } => render_max_length_field(name, *length as u16), + Attribute::Empty => render_simple_field(name, ty), + } +} + +fn render_simple_field(name: &Ident, ty: &Type) -> TokenStream2 { + quote! { + let #name = <#ty as crate::decoder::Decoder>::decode(reader)?; + } +} + +fn render_with_field(name: &Ident, module: &str) -> TokenStream2 { + let module_ident = Ident::new(module, Span::call_site()); + + quote! { + let #name = crate::decoder::#module_ident::decode(reader)?; + } +} + +fn render_max_length_field(name: &Ident, max_length: u16) -> TokenStream2 { + quote! { + let #name = crate::decoder::DecoderReadExt::read_string(reader, #max_length)?; + } +} diff --git a/protocol-derive/src/render/encoder.rs b/protocol-derive/src/render/encoder.rs new file mode 100644 index 0000000..2888f1a --- /dev/null +++ b/protocol-derive/src/render/encoder.rs @@ -0,0 +1,51 @@ +use crate::parse::{Attribute, FieldData}; +use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::{Ident, Span}; +use quote::quote; + +pub(crate) fn render_encoder(name: &Ident, fields: &Vec) -> TokenStream2 { + let render_fields = render_fields(fields); + + quote! { + #[automatically_derived] + impl crate::encoder::Encoder for #name { + fn encode(&self, writer: &mut W) -> Result<(), crate::error::EncodeError> { + #render_fields + + Ok(()) + } + } + } +} + +fn render_fields(fields: &Vec) -> TokenStream2 { + fields.iter().map(|f| render_field(f)).flatten().collect() +} + +fn render_field(field: &FieldData) -> TokenStream2 { + let name = field.name; + + match &field.attribute { + Attribute::With { module } => render_with_field(name, module), + Attribute::MaxLength { length } => render_max_length_field(name, *length as u16), + Attribute::Empty => render_simple_field(name), + } +} + +fn render_simple_field(name: &Ident) -> TokenStream2 { + render_with_field(name, "Encoder") +} + +fn render_with_field(name: &Ident, module: &str) -> TokenStream2 { + let module_ident = Ident::new(module, Span::call_site()); + + quote! { + crate::encoder::#module_ident::encode(&self.#name, writer)?; + } +} + +fn render_max_length_field(name: &Ident, max_length: u16) -> TokenStream2 { + quote! { + crate::encoder::EncoderWriteExt::write_string(writer, &self.#name, #max_length)?; + } +} diff --git a/protocol-derive/src/render/mod.rs b/protocol-derive/src/render/mod.rs new file mode 100644 index 0000000..325e8d6 --- /dev/null +++ b/protocol-derive/src/render/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod decoder; +pub(crate) mod encoder; diff --git a/protocol/src/version/v1_14_4/game.rs b/protocol/src/version/v1_14_4/game.rs index 860480c..8bdc433 100644 --- a/protocol/src/version/v1_14_4/game.rs +++ b/protocol/src/version/v1_14_4/game.rs @@ -4,7 +4,7 @@ use crate::data::chat::Message; use crate::decoder::Decoder; use crate::error::DecodeError; use crate::impl_enum_encoder_decoder; -use minecraft_protocol_derive::Packet; +use minecraft_protocol_derive::{Decoder, Encoder}; use nbt::CompoundTag; use std::io::Read; @@ -89,9 +89,9 @@ impl GameClientBoundPacket { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct ServerBoundChatMessage { - #[packet(max_length = 256)] + #[data_type(max_length = 256)] pub message: String, } @@ -103,7 +103,7 @@ impl ServerBoundChatMessage { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct ClientBoundChatMessage { pub message: Message, pub position: MessagePosition, @@ -126,15 +126,15 @@ impl ClientBoundChatMessage { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct JoinGame { pub entity_id: u32, pub game_mode: GameMode, pub dimension: i32, pub max_players: u8, - #[packet(max_length = 16)] + #[data_type(max_length = 16)] pub level_type: String, - #[packet(with = "var_int")] + #[data_type(with = "var_int")] pub view_distance: i32, pub reduced_debug_info: bool, } @@ -174,7 +174,7 @@ impl JoinGame { } } -#[derive(Packet)] +#[derive(Encoder, Decoder)] pub struct ServerBoundKeepAlive { pub id: u64, } @@ -187,7 +187,7 @@ impl ServerBoundKeepAlive { } } -#[derive(Packet)] +#[derive(Encoder, Decoder)] pub struct ClientBoundKeepAlive { pub id: u64, } @@ -200,12 +200,12 @@ impl ClientBoundKeepAlive { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct ChunkData { pub x: i32, pub z: i32, pub full: bool, - #[packet(with = "var_int")] + #[data_type(with = "var_int")] pub primary_mask: i32, pub heights: CompoundTag, pub data: Vec, @@ -236,7 +236,7 @@ impl ChunkData { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct GameDisconnect { pub reason: Message, } diff --git a/protocol/src/version/v1_14_4/login.rs b/protocol/src/version/v1_14_4/login.rs index 560cdd9..14248a3 100644 --- a/protocol/src/version/v1_14_4/login.rs +++ b/protocol/src/version/v1_14_4/login.rs @@ -4,7 +4,7 @@ use uuid::Uuid; use crate::data::chat::Message; use crate::decoder::Decoder; use crate::error::DecodeError; -use minecraft_protocol_derive::Packet; +use minecraft_protocol_derive::{Decoder, Encoder}; pub enum LoginServerBoundPacket { LoginStart(LoginStart), @@ -102,7 +102,7 @@ impl LoginClientBoundPacket { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct LoginStart { pub name: String, } @@ -115,7 +115,7 @@ impl LoginStart { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct EncryptionResponse { pub shared_secret: Vec, pub verify_token: Vec, @@ -132,12 +132,12 @@ impl EncryptionResponse { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct LoginPluginResponse { - #[packet(with = "var_int")] + #[data_type(with = "var_int")] pub message_id: i32, pub successful: bool, - #[packet(with = "rest")] + #[data_type(with = "rest")] pub data: Vec, } @@ -153,7 +153,7 @@ impl LoginPluginResponse { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct LoginDisconnect { pub reason: Message, } @@ -166,9 +166,9 @@ impl LoginDisconnect { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct EncryptionRequest { - #[packet(max_length = 20)] + #[data_type(max_length = 20)] pub server_id: String, pub public_key: Vec, pub verify_token: Vec, @@ -190,11 +190,11 @@ impl EncryptionRequest { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct LoginSuccess { - #[packet(with = "uuid_hyp_str")] + #[data_type(with = "uuid_hyp_str")] pub uuid: Uuid, - #[packet(max_length = 16)] + #[data_type(max_length = 16)] pub username: String, } @@ -206,9 +206,9 @@ impl LoginSuccess { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct SetCompression { - #[packet(with = "var_int")] + #[data_type(with = "var_int")] pub threshold: i32, } @@ -220,12 +220,12 @@ impl SetCompression { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct LoginPluginRequest { - #[packet(with = "var_int")] + #[data_type(with = "var_int")] pub message_id: i32, pub channel: String, - #[packet(with = "rest")] + #[data_type(with = "rest")] pub data: Vec, } diff --git a/protocol/src/version/v1_14_4/status.rs b/protocol/src/version/v1_14_4/status.rs index 82cc010..57c281a 100644 --- a/protocol/src/version/v1_14_4/status.rs +++ b/protocol/src/version/v1_14_4/status.rs @@ -1,7 +1,7 @@ use crate::data::server_status::*; use crate::decoder::Decoder; use crate::error::DecodeError; -use minecraft_protocol_derive::Packet; +use minecraft_protocol_derive::{Decoder, Encoder}; use std::io::Read; pub enum StatusServerBoundPacket { @@ -44,7 +44,7 @@ impl StatusClientBoundPacket { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct PingRequest { pub time: u64, } @@ -57,7 +57,7 @@ impl PingRequest { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct PingResponse { pub time: u64, } @@ -70,7 +70,7 @@ impl PingResponse { } } -#[derive(Packet, Debug)] +#[derive(Encoder, Decoder, Debug)] pub struct StatusResponse { pub server_status: ServerStatus, }