diff --git a/protocol-derive/src/parse.rs b/protocol-derive/src/parse.rs index 8e52084..6714424 100644 --- a/protocol-derive/src/parse.rs +++ b/protocol-derive/src/parse.rs @@ -3,7 +3,7 @@ use proc_macro2::Ident; use std::iter::FromIterator; use syn::punctuated::Punctuated; use syn::{ - Attribute, Data, DeriveInput, ExprLit, Fields, FieldsNamed, Lit, Meta, NestedMeta, Type, + Attribute, Data, DeriveInput, ExprLit, Field, Fields, FieldsNamed, Lit, Meta, NestedMeta, Type, }; use syn::{Error as SynError, Variant}; use syn::{Expr, Token}; @@ -36,6 +36,7 @@ pub(crate) struct FieldData<'a> { pub(crate) enum AttributeData { With { module: String }, MaxLength { length: usize }, + Bitfield { idx: u8, position: BitfieldPosition }, Empty, } @@ -45,6 +46,13 @@ pub(crate) enum DiscriminantType { VarInt, } +#[derive(Debug, Eq, PartialEq)] +pub(crate) enum BitfieldPosition { + Start, + Intermediate, + End, +} + pub(crate) fn parse_derive_input( input: &DeriveInput, ) -> Result { @@ -77,7 +85,7 @@ fn parse_discriminant_type( attributes: &Vec, ) -> Result { let nested_metas = parse_attributes_nested_metas(attributes)?; - let attribute = parse_attribute(nested_metas)?; + let attribute = parse_attribute(nested_metas, None, 0)?; match attribute { AttributeData::With { module } if module == "var_int" => Ok(DiscriminantType::VarInt), @@ -127,13 +135,26 @@ fn parse_variant_discriminant(variant: &Variant) -> Option { fn parse_fields(named_fields: &FieldsNamed) -> Result, DeriveInputParserError> { let mut fields_data = Vec::new(); + let mut current_bitfield_idx = 0; - for field in named_fields.named.iter() { + let fields: Vec<&Field> = named_fields.named.iter().collect(); + + for (idx, field) in fields.iter().enumerate() { let name = field.ident.as_ref().unwrap(); let ty = &field.ty; let nested_metas = parse_attributes_nested_metas(&field.attrs)?; - let attribute = parse_attribute(nested_metas)?; + + let next_field_opt = fields.get(idx + 1); + let next_nested_metas_opt = next_field_opt + .and_then(|next_field| parse_attributes_nested_metas(&next_field.attrs).ok()); + + let attribute = parse_attribute(nested_metas, next_nested_metas_opt, current_bitfield_idx)?; + + match attribute { + AttributeData::Bitfield { .. } => current_bitfield_idx += 1, + _ => current_bitfield_idx = 0, + } fields_data.push(FieldData { name, @@ -165,12 +186,23 @@ fn parse_attributes_nested_metas( Ok(nested_metas.into_iter().flatten().collect()) } -fn parse_attribute(nested_metas: Vec) -> Result { - let attribute_parsers: Vec Result> = +fn parse_attribute( + nested_metas: Vec, + next_nested_metas_opt: Option>, + current_bitfield_idx: u8, +) -> Result { + let simple_attribute_parsers: Vec Result> = vec![get_module_attribute, get_max_length_attribute]; for nested_meta in nested_metas.iter() { - for attribute_parser in attribute_parsers.iter() { + let bitfield_attribute = + get_bitfield_attribute(current_bitfield_idx, nested_meta, &next_nested_metas_opt); + + if bitfield_attribute != AttributeData::Empty { + return Ok(bitfield_attribute); + } + + for attribute_parser in simple_attribute_parsers.iter() { let attribute = attribute_parser(nested_meta)?; if attribute != AttributeData::Empty { @@ -211,3 +243,49 @@ fn get_max_length_attribute(nested_meta: &NestedMeta) -> Result>, +) -> AttributeData { + if is_bitfield_attribute(nested_meta) { + let position = calc_bitfield_position(current_bitfield_idx, next_nested_metas_opt); + + AttributeData::Bitfield { + idx: current_bitfield_idx, + position, + } + } else { + AttributeData::Empty + } +} + +fn calc_bitfield_position( + current_bitfield_idx: u8, + next_nested_metas_opt: &Option>, +) -> BitfieldPosition { + fn next_has_bitfield_attribute(next_nested_metas: &Vec) -> bool { + next_nested_metas + .iter() + .any(|nested_meta| is_bitfield_attribute(nested_meta)) + } + + match next_nested_metas_opt { + Some(next_nested_metas) if (next_has_bitfield_attribute(&next_nested_metas)) => { + if current_bitfield_idx == 0 { + BitfieldPosition::Start + } else { + BitfieldPosition::Intermediate + } + } + _ => BitfieldPosition::End, + } +} + +fn is_bitfield_attribute(nested_meta: &NestedMeta) -> bool { + match nested_meta { + NestedMeta::Meta(Meta::Path(path)) => path.is_ident("bitfield"), + _ => false, + } +} diff --git a/protocol-derive/src/render/decoder.rs b/protocol-derive/src/render/decoder.rs index 29ec6e9..001f3cc 100644 --- a/protocol-derive/src/render/decoder.rs +++ b/protocol-derive/src/render/decoder.rs @@ -1,4 +1,4 @@ -use crate::parse::{AttributeData, DiscriminantType, FieldData, VariantData}; +use crate::parse::{AttributeData, BitfieldPosition, DiscriminantType, FieldData, VariantData}; use proc_macro2::TokenStream as TokenStream2; use proc_macro2::{Ident, Span}; use quote::quote; @@ -140,6 +140,7 @@ fn render_field(field: &FieldData) -> TokenStream2 { match &field.attribute { AttributeData::With { module } => render_with_field(name, module), AttributeData::MaxLength { length } => render_max_length_field(name, *length as u16), + AttributeData::Bitfield { idx, position } => render_bitfield(name, *idx, position), AttributeData::Empty => render_simple_field(name, ty), } } @@ -163,3 +164,22 @@ fn render_max_length_field(name: &Ident, max_length: u16) -> TokenStream2 { let #name = crate::decoder::DecoderReadExt::read_string(reader, #max_length)?; } } + +fn render_bitfield(name: &Ident, idx: u8, position: &BitfieldPosition) -> TokenStream2 { + let mask = 1u8 << idx; + + let render_mask = quote! { + let #name = flags & #mask > 0; + }; + + match position { + BitfieldPosition::Start => { + quote! { + let flags = reader.read_u8()?; + + #render_mask + } + } + _ => render_mask, + } +} diff --git a/protocol-derive/src/render/encoder.rs b/protocol-derive/src/render/encoder.rs index 2490a31..eae74dd 100644 --- a/protocol-derive/src/render/encoder.rs +++ b/protocol-derive/src/render/encoder.rs @@ -1,4 +1,4 @@ -use crate::parse::{AttributeData, DiscriminantType, FieldData, VariantData}; +use crate::parse::{AttributeData, BitfieldPosition, DiscriminantType, FieldData, VariantData}; use proc_macro2::TokenStream as TokenStream2; use proc_macro2::{Ident, Span}; use quote::quote; @@ -130,6 +130,7 @@ fn render_field(field: &FieldData, with_self: bool) -> TokenStream2 { AttributeData::MaxLength { length } => { render_max_length_field(name, *length as u16, with_self) } + AttributeData::Bitfield { idx, position } => render_bitfield(name, *idx, position), AttributeData::Empty => render_simple_field(name, with_self), } } @@ -155,6 +156,32 @@ fn render_max_length_field(name: &Ident, max_length: u16, with_self: bool) -> To } } +fn render_bitfield(name: &Ident, idx: u8, position: &BitfieldPosition) -> TokenStream2 { + let mask = 1u8 << idx; + + let render_mask = quote! { + if self.#name { + flags |= #mask; + } + }; + + match position { + BitfieldPosition::Start => quote!( + let mut flags = 0; + + #render_mask + ), + BitfieldPosition::Intermediate => render_mask, + BitfieldPosition::End => { + quote! { + #render_mask + + writer.write_u8(flags)?; + } + } + } +} + fn get_field_final_name(name: &Ident, with_self: bool) -> TokenStream2 { if with_self { quote!(&self.#name) diff --git a/protocol/src/version/v1_14_4/game.rs b/protocol/src/version/v1_14_4/game.rs index 639904a..4400831 100644 --- a/protocol/src/version/v1_14_4/game.rs +++ b/protocol/src/version/v1_14_4/game.rs @@ -12,6 +12,7 @@ use uuid::Uuid; pub enum GameServerBoundPacket { ServerBoundChatMessage(ServerBoundChatMessage), ServerBoundKeepAlive(ServerBoundKeepAlive), + ServerBoundAbilities(ServerBoundAbilities), } pub enum GameClientBoundPacket { @@ -29,6 +30,7 @@ impl GameServerBoundPacket { match self { GameServerBoundPacket::ServerBoundChatMessage(_) => 0x03, GameServerBoundPacket::ServerBoundKeepAlive(_) => 0x0F, + GameServerBoundPacket::ServerBoundAbilities(_) => 0x19, } } @@ -332,6 +334,20 @@ pub enum EntityActionId { StartFlyingWithElytra, } +#[derive(Encoder, Decoder, Debug, PartialEq)] +pub struct ServerBoundAbilities { + #[data_type(bitfield)] + pub invulnerable: bool, + #[data_type(bitfield)] + pub allow_flying: bool, + #[data_type(bitfield)] + pub flying: bool, + #[data_type(bitfield)] + pub creative_mode: bool, + pub fly_speed: f32, + pub walk_speed: f32, +} + #[cfg(test)] mod tests { use crate::data::chat::Payload; @@ -686,4 +702,33 @@ mod tests { } ); } + + #[test] + fn test_serverbound_abilities_encode() { + let abilities = ServerBoundAbilities { + invulnerable: true, + flying: true, + allow_flying: false, + creative_mode: true, + fly_speed: 0.0, + walk_speed: 0.0, + }; + + let mut vec = Vec::new(); + abilities.encode(&mut vec).unwrap(); + + assert_eq!(vec, [13, 0, 0, 0, 0, 0, 0, 0, 0]); + } + + #[test] + fn test_serverbound_abilities_decode() { + let vec = [13, 0, 0, 0, 0, 0, 0, 0, 0].to_vec(); + let mut cursor = Cursor::new(vec); + + let abilities = ServerBoundAbilities::decode(&mut cursor).unwrap(); + assert!(abilities.invulnerable); + assert!(!abilities.allow_flying); + assert!(abilities.flying); + assert!(abilities.creative_mode); + } }