diff --git a/protocol-derive/src/lib.rs b/protocol-derive/src/lib.rs index e2c4269..d972181 100644 --- a/protocol-derive/src/lib.rs +++ b/protocol-derive/src/lib.rs @@ -1,8 +1,8 @@ extern crate proc_macro; use crate::parse::{parse_derive_input, DeriveInputParseResult}; -use crate::render::decoder::{render_struct_decoder, render_struct_variant_decoder}; -use crate::render::encoder::{render_struct_encoder, render_struct_variant_encoder}; +use crate::render::decoder::{render_enum_decoder, render_struct_decoder}; +use crate::render::encoder::{render_enum_encoder, render_struct_encoder}; use proc_macro::TokenStream; use syn::parse_macro_input; use syn::DeriveInput; @@ -18,9 +18,7 @@ pub fn derive_encoder(tokens: TokenStream) -> TokenStream { TokenStream::from(match derive_parse_result { DeriveInputParseResult::Struct { name, fields } => render_struct_encoder(name, &fields), - DeriveInputParseResult::StructVariant { name, variants } => { - render_struct_variant_encoder(name, &variants) - } + DeriveInputParseResult::Enum { name, variants } => render_enum_encoder(name, &variants), }) } @@ -31,8 +29,6 @@ pub fn derive_decoder(tokens: TokenStream) -> TokenStream { TokenStream::from(match derive_parse_result { DeriveInputParseResult::Struct { name, fields } => render_struct_decoder(name, &fields), - DeriveInputParseResult::StructVariant { name, variants } => { - render_struct_variant_decoder(name, &variants) - } + DeriveInputParseResult::Enum { name, variants } => render_enum_decoder(name, &variants), }) } diff --git a/protocol-derive/src/parse.rs b/protocol-derive/src/parse.rs index 17672e4..b05b110 100644 --- a/protocol-derive/src/parse.rs +++ b/protocol-derive/src/parse.rs @@ -2,23 +2,23 @@ use crate::error::{DeriveInputParserError, FieldError}; use proc_macro2::Ident; use std::iter::FromIterator; use syn::punctuated::Punctuated; -use syn::Token; -use syn::{Data, DeriveInput, Field, Fields, FieldsNamed, Lit, Meta, NestedMeta, Type}; +use syn::{Data, DeriveInput, ExprLit, Field, Fields, FieldsNamed, Lit, Meta, NestedMeta, Type}; use syn::{Error as SynError, Variant}; +use syn::{Expr, Token}; pub(crate) enum DeriveInputParseResult<'a> { Struct { name: &'a Ident, fields: Vec>, }, - StructVariant { + Enum { name: &'a Ident, variants: Vec>, }, } pub(crate) struct VariantData<'a> { - pub(crate) idx: u8, + pub(crate) discriminant: u8, pub(crate) name: &'a Ident, pub(crate) fields: Vec>, } @@ -53,7 +53,7 @@ pub(crate) fn parse_derive_input( Data::Enum(data_enum) => { let variants = parse_variants(&data_enum.variants)?; - Ok(DeriveInputParseResult::StructVariant { name, variants }) + Ok(DeriveInputParseResult::Enum { name, variants }) } _ => Err(DeriveInputParserError::UnsupportedData), } @@ -70,6 +70,7 @@ fn parse_variants( } fn parse_variant(idx: u8, variant: &Variant) -> Result { + let discriminant = parse_variant_discriminant(variant).unwrap_or(idx); let name = &variant.ident; let fields = match &variant.fields { @@ -78,7 +79,24 @@ fn parse_variant(idx: u8, variant: &Variant) -> Result Err(DeriveInputParserError::UnnamedDataFields), }?; - Ok(VariantData { idx, name, fields }) + Ok(VariantData { + discriminant, + name, + fields, + }) +} + +fn parse_variant_discriminant(variant: &Variant) -> Option { + variant + .discriminant + .as_ref() + .and_then(|(_, expr)| match expr { + Expr::Lit(ExprLit { + lit: Lit::Int(lit_int), + .. + }) => lit_int.base10_parse().ok(), + _ => None, + }) } fn parse_fields(named_fields: &FieldsNamed) -> Result, DeriveInputParserError> { diff --git a/protocol-derive/src/render/decoder.rs b/protocol-derive/src/render/decoder.rs index 3381f1b..a726a64 100644 --- a/protocol-derive/src/render/decoder.rs +++ b/protocol-derive/src/render/decoder.rs @@ -24,10 +24,7 @@ pub(crate) fn render_struct_decoder(name: &Ident, fields: &Vec) -> To } } -pub(crate) fn render_struct_variant_decoder( - name: &Ident, - variants: &Vec, -) -> TokenStream2 { +pub(crate) fn render_enum_decoder(name: &Ident, variants: &Vec) -> TokenStream2 { let render_variants = render_variants(variants); quote! { @@ -52,26 +49,37 @@ fn render_variants(variants: &Vec) -> TokenStream2 { } fn render_variant(variant: &VariantData) -> TokenStream2 { - let idx = variant.idx; + if variant.fields.is_empty() { + render_unit_variant(variant) + } else { + render_struct_variant(variant) + } +} + +fn render_unit_variant(variant: &VariantData) -> TokenStream2 { + let discriminant = variant.discriminant; + let name = variant.name; + + quote! { + #discriminant => Ok(Self::#name), + } +} + +fn render_struct_variant(variant: &VariantData) -> TokenStream2 { + let discriminant = variant.discriminant; let name = variant.name; let fields = &variant.fields; - if fields.is_empty() { - quote! { - #idx => Ok(Self::#name), - } - } else { - let field_names_joined_comma = render_field_names_joined_comma(fields); - let render_fields = render_fields(fields); + let field_names_joined_comma = render_field_names_joined_comma(fields); + let render_fields = render_fields(fields); - quote! { - #idx => { - #render_fields + quote! { + #discriminant => { + #render_fields - Ok(Self::#name { - #field_names_joined_comma - }) - } + Ok(Self::#name { + #field_names_joined_comma + }) } } } diff --git a/protocol-derive/src/render/encoder.rs b/protocol-derive/src/render/encoder.rs index 271e5ed..9afe87c 100644 --- a/protocol-derive/src/render/encoder.rs +++ b/protocol-derive/src/render/encoder.rs @@ -18,10 +18,7 @@ pub(crate) fn render_struct_encoder(name: &Ident, fields: &Vec) -> To } } -pub(crate) fn render_struct_variant_encoder( - name: &Ident, - variants: &Vec, -) -> TokenStream2 { +pub(crate) fn render_enum_encoder(name: &Ident, variants: &Vec) -> TokenStream2 { let render_variants = render_variants(variants); quote! { @@ -43,28 +40,39 @@ fn render_variants(variants: &Vec) -> TokenStream2 { } fn render_variant(variant: &VariantData) -> TokenStream2 { - let idx = variant.idx; + if variant.fields.is_empty() { + render_unit_variant(variant) + } else { + render_struct_variant(variant) + } +} + +fn render_unit_variant(variant: &VariantData) -> TokenStream2 { + let discriminant = variant.discriminant; + let name = variant.name; + + quote! { + Self::#name => { + writer.write_u8(#discriminant)?; + } + } +} + +fn render_struct_variant(variant: &VariantData) -> TokenStream2 { + let discriminant = variant.discriminant; let name = variant.name; let fields = &variant.fields; - if fields.is_empty() { - quote! { - Self::#name => { - writer.write_u8(#idx)?; - } - } - } else { - let field_names_joined_comma = render_field_names_joined_comma(fields); - let render_fields = render_fields(fields, false); + let field_names_joined_comma = render_field_names_joined_comma(fields); + let render_fields = render_fields(fields, false); - quote! { - Self::#name { - #field_names_joined_comma - } => { - writer.write_u8(#idx)?; + quote! { + Self::#name { + #field_names_joined_comma + } => { + writer.write_u8(#discriminant)?; - #render_fields - } + #render_fields } } } diff --git a/protocol/Cargo.toml b/protocol/Cargo.toml index e0f11b4..ee36759 100644 --- a/protocol/Cargo.toml +++ b/protocol/Cargo.toml @@ -16,6 +16,4 @@ byteorder = "1" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" uuid = { version = "0.7", features = ["v4", "serde"] } -num-traits = "0.2" -num-derive = "0.2" named-binary-tag = "0.2" diff --git a/protocol/src/decoder.rs b/protocol/src/decoder.rs index 3208350..0b4f102 100644 --- a/protocol/src/decoder.rs +++ b/protocol/src/decoder.rs @@ -1,7 +1,6 @@ use crate::error::DecodeError; use byteorder::{BigEndian, ReadBytesExt}; use nbt::CompoundTag; -use num_traits::FromPrimitive; use std::io::Read; use uuid::Uuid; @@ -19,8 +18,6 @@ pub trait DecoderReadExt { fn read_byte_array(&mut self) -> Result, DecodeError>; - fn read_enum(&mut self) -> Result; - fn read_compound_tag(&mut self) -> Result; fn read_var_i32(&mut self) -> Result; @@ -86,13 +83,6 @@ impl DecoderReadExt for R { Ok(buf) } - fn read_enum(&mut self) -> Result { - let type_id = self.read_u8()?; - let result = FromPrimitive::from_u8(type_id); - - result.ok_or_else(|| DecodeError::UnknownEnumType { type_id }) - } - fn read_compound_tag(&mut self) -> Result { Ok(nbt::decode::read_compound_tag(self)?) } diff --git a/protocol/src/encoder.rs b/protocol/src/encoder.rs index 36caf1b..45b7ece 100644 --- a/protocol/src/encoder.rs +++ b/protocol/src/encoder.rs @@ -1,7 +1,6 @@ use crate::error::EncodeError; use byteorder::{BigEndian, WriteBytesExt}; use nbt::CompoundTag; -use num_traits::ToPrimitive; use std::io::Write; use uuid::Uuid; @@ -17,8 +16,6 @@ pub trait EncoderWriteExt { fn write_byte_array(&mut self, value: &[u8]) -> Result<(), EncodeError>; - fn write_enum(&mut self, value: &T) -> Result<(), EncodeError>; - fn write_compound_tag(&mut self, value: &CompoundTag) -> Result<(), EncodeError>; fn write_var_i32(&mut self, value: i32) -> Result<(), EncodeError>; @@ -80,13 +77,6 @@ impl EncoderWriteExt for W { Ok(()) } - fn write_enum(&mut self, value: &T) -> Result<(), EncodeError> { - let type_value = ToPrimitive::to_u8(value).unwrap(); - self.write_u8(type_value)?; - - Ok(()) - } - fn write_compound_tag(&mut self, value: &CompoundTag) -> Result<(), EncodeError> { nbt::encode::write_compound_tag(self, value.clone())?; diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index aa92f71..2f42969 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -10,25 +10,6 @@ pub mod version; /// Protocol limits maximum string length. const STRING_MAX_LENGTH: u16 = 32_768; -#[macro_export] -macro_rules! impl_enum_encoder_decoder ( - ($ty: ident) => ( - impl crate::encoder::Encoder for $ty { - fn encode(&self, writer: &mut W) -> Result<(), crate::error::EncodeError> { - Ok(crate::encoder::EncoderWriteExt::write_enum(writer, self)?) - } - } - - impl crate::decoder::Decoder for $ty { - type Output = Self; - - fn decode(reader: &mut R) -> Result { - Ok(crate::decoder::DecoderReadExt::read_enum(reader)?) - } - } - ); -); - #[macro_export] macro_rules! impl_json_encoder_decoder ( ($ty: ident) => ( diff --git a/protocol/src/version/v1_14_4/game.rs b/protocol/src/version/v1_14_4/game.rs index 18c3556..87534f0 100644 --- a/protocol/src/version/v1_14_4/game.rs +++ b/protocol/src/version/v1_14_4/game.rs @@ -1,9 +1,6 @@ -use num_derive::{FromPrimitive, ToPrimitive}; - use crate::data::chat::Message; use crate::decoder::Decoder; use crate::error::DecodeError; -use crate::impl_enum_encoder_decoder; use byteorder::{ReadBytesExt, WriteBytesExt}; use minecraft_protocol_derive::{Decoder, Encoder}; use nbt::CompoundTag; @@ -113,15 +110,13 @@ pub struct ClientBoundChatMessage { pub position: MessagePosition, } -#[derive(Debug, Eq, PartialEq, FromPrimitive, ToPrimitive)] +#[derive(Encoder, Decoder, Debug, Eq, PartialEq)] pub enum MessagePosition { Chat, System, HotBar, } -impl_enum_encoder_decoder!(MessagePosition); - impl ClientBoundChatMessage { pub fn new(message: Message, position: MessagePosition) -> GameClientBoundPacket { let chat_message = ClientBoundChatMessage { message, position }; @@ -143,7 +138,7 @@ pub struct JoinGame { pub reduced_debug_info: bool, } -#[derive(Debug, Eq, PartialEq, FromPrimitive, ToPrimitive)] +#[derive(Encoder, Decoder, Debug, Eq, PartialEq)] pub enum GameMode { Survival = 0, Creative = 1, @@ -152,8 +147,6 @@ pub enum GameMode { Hardcore = 8, } -impl_enum_encoder_decoder!(GameMode); - impl JoinGame { pub fn new( entity_id: u32, @@ -284,7 +277,7 @@ pub enum BossBarAction { }, } -#[derive(Debug, PartialEq, FromPrimitive, ToPrimitive)] +#[derive(Encoder, Decoder, Debug, PartialEq)] pub enum BossBarColor { Pink, Blue, @@ -295,9 +288,7 @@ pub enum BossBarColor { White, } -impl_enum_encoder_decoder!(BossBarColor); - -#[derive(Debug, PartialEq, FromPrimitive, ToPrimitive)] +#[derive(Encoder, Decoder, Debug, PartialEq)] pub enum BossBarDivision { None, Notches6, @@ -306,8 +297,6 @@ pub enum BossBarDivision { Notches20, } -impl_enum_encoder_decoder!(BossBarDivision); - impl BossBar { pub fn new(id: Uuid, action: BossBarAction) -> GameClientBoundPacket { let boss_bar = BossBar { id, action }; @@ -479,7 +468,7 @@ mod tests { fn test_join_game_encode() { let join_game = JoinGame { entity_id: 27, - game_mode: GameMode::Spectator, + game_mode: GameMode::Hardcore, dimension: 23, max_players: 100, level_type: String::from("default"), @@ -503,7 +492,7 @@ mod tests { let join_game = JoinGame::decode(&mut cursor).unwrap(); assert_eq!(join_game.entity_id, 27); - assert_eq!(join_game.game_mode, GameMode::Spectator); + assert_eq!(join_game.game_mode, GameMode::Hardcore); assert_eq!(join_game.dimension, 23); assert_eq!(join_game.max_players, 100); assert_eq!(join_game.level_type, String::from("default")); diff --git a/protocol/test/packet/game/join_game.dat b/protocol/test/packet/game/join_game.dat index e2e4383..8f6e3f2 100644 Binary files a/protocol/test/packet/game/join_game.dat and b/protocol/test/packet/game/join_game.dat differ