Refactor protocol derive module (#11)

This commit is contained in:
vagola 2021-09-16 03:09:10 +03:00 committed by GitHub
parent f733fabedc
commit 70bfd01848
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 340 additions and 209 deletions

View File

@ -0,0 +1,45 @@
use syn::Error as SynError;
/// Possible errors while deriving.
#[derive(Debug)]
pub(crate) enum DeriveInputParserError {
/// Derive attribute must be placed on a structure or enum.
UnsupportedData,
/// Data fields must be named.
UnnamedDataFields,
FieldError {
field_error: FieldError,
},
}
/// Possible errors while parsing field.
#[derive(Debug)]
pub(crate) enum FieldError {
/// Failed to parse field meta due incorrect syntax.
BadAttributeSyntax { syn_error: SynError },
/// Unsupported field attribute type.
UnsupportedAttribute,
/// Field meta has wrong value type.
/// For example an int was expected, but a string was supplied.
AttributeWrongValueType,
}
impl From<FieldError> for DeriveInputParserError {
fn from(field_error: FieldError) -> Self {
DeriveInputParserError::FieldError { field_error }
}
}
impl From<SynError> for DeriveInputParserError {
fn from(syn_error: SynError) -> Self {
DeriveInputParserError::FieldError {
field_error: FieldError::BadAttributeSyntax { syn_error },
}
}
}
impl From<SynError> for FieldError {
fn from(syn_error: SynError) -> Self {
FieldError::BadAttributeSyntax { syn_error }
}
}

View File

@ -1,185 +1,28 @@
extern crate proc_macro;
use proc_macro::TokenStream as TokenStream1;
use proc_macro2::TokenStream as TokenStream2;
use proc_macro2::{Ident, Span};
use quote::{quote, TokenStreamExt};
use std::iter::FromIterator;
use syn::{parse_macro_input, Data, DeriveInput, Field, Fields, Lit, Meta, NestedMeta};
use crate::parse::parse_derive_input;
use crate::render::decoder::render_decoder;
use crate::render::encoder::render_encoder;
use proc_macro::TokenStream;
use syn::parse_macro_input;
use syn::DeriveInput;
#[proc_macro_derive(Packet, attributes(packet))]
pub fn derive_packet(input: proc_macro::TokenStream) -> TokenStream1 {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
mod error;
mod parse;
mod render;
match input.data {
Data::Struct(data) => {
let fields = &data.fields;
#[proc_macro_derive(Encoder, attributes(data_type))]
pub fn derive_encoder(tokens: TokenStream) -> TokenStream {
let input = parse_macro_input!(tokens as DeriveInput);
let (name, fields) = parse_derive_input(&input).expect("Failed to parse derive input");
let encoder = impl_encoder_trait(name, fields);
let decoder = impl_decoder_trait(name, fields);
TokenStream1::from(quote! {
#encoder
#decoder
})
}
_ => panic!("Expected only structures"),
}
TokenStream::from(render_encoder(name, &fields))
}
fn impl_encoder_trait(name: &Ident, fields: &Fields) -> TokenStream2 {
let encode = quote_field(fields, |field| {
let name = &field.ident;
#[proc_macro_derive(Decoder, attributes(data_type))]
pub fn derive_decoder(tokens: TokenStream) -> TokenStream {
let input = parse_macro_input!(tokens as DeriveInput);
let (name, fields) = parse_derive_input(&input).expect("Failed to parse derive input");
let unparsed_meta = get_packet_field_meta(field);
let parsed_meta = parse_packet_field_meta(&unparsed_meta);
// This is special case because max length are used only for strings.
if let Some(max_length) = parsed_meta.max_length {
return quote! {
crate::encoder::EncoderWriteExt::write_string(writer, &self.#name, #max_length)?;
};
}
let module = parsed_meta.module.as_deref().unwrap_or("Encoder");
let module_ident = Ident::new(&module, Span::call_site());
quote! {
crate::encoder::#module_ident::encode(&self.#name, writer)?;
}
});
quote! {
#[automatically_derived]
impl crate::encoder::Encoder for #name {
fn encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), crate::error::EncodeError> {
#encode
Ok(())
}
}
}
}
fn impl_decoder_trait(name: &Ident, fields: &Fields) -> TokenStream2 {
let decode = quote_field(fields, |field| {
let name = &field.ident;
let ty = &field.ty;
let unparsed_meta = get_packet_field_meta(field);
let parsed_meta = parse_packet_field_meta(&unparsed_meta);
// This is special case because max length are used only for strings.
if let Some(max_length) = parsed_meta.max_length {
return quote! {
let #name = crate::decoder::DecoderReadExt::read_string(reader, #max_length)?;
};
}
match parsed_meta.module {
Some(module) => {
let module_ident = Ident::new(&module, Span::call_site());
quote! {
let #name = crate::decoder::#module_ident::decode(reader)?;
}
}
None => {
quote! {
let #name = <#ty as crate::decoder::Decoder>::decode(reader)?;
}
}
}
});
let create = quote_field(fields, |field| {
let name = &field.ident;
quote! {
#name,
}
});
quote! {
#[automatically_derived]
impl crate::decoder::Decoder for #name {
type Output = Self;
fn decode<R: std::io::Read>(reader: &mut R) -> Result<Self::Output, crate::error::DecodeError> {
#decode
Ok(#name {
#create
})
}
}
}
}
#[derive(Debug)]
struct PacketFieldMeta {
module: Option<String>,
max_length: Option<u16>,
}
fn parse_packet_field_meta(meta_list: &Vec<NestedMeta>) -> PacketFieldMeta {
let mut module = None;
let mut max_length = None;
for meta in meta_list {
match meta {
NestedMeta::Meta(Meta::NameValue(named_meta)) => match &named_meta.path {
path if path.is_ident("with") => match &named_meta.lit {
Lit::Str(lit_str) => module = Some(lit_str.value()),
_ => panic!("\"with\" attribute value must be string"),
},
path if path.is_ident("max_length") => match &named_meta.lit {
Lit::Int(lit_int) => {
max_length = Some(
lit_int
.base10_parse::<u16>()
.expect("Failed to parse max length attribute"),
)
}
_ => panic!("\"max_length\" attribute value must be integer"),
},
path => panic!(
"Received unrecognized attribute : \"{}\"",
path.get_ident().unwrap()
),
},
_ => panic!("Expected only named meta values"),
}
}
PacketFieldMeta { module, max_length }
}
fn get_packet_field_meta(field: &Field) -> Vec<NestedMeta> {
field
.attrs
.iter()
.filter(|a| a.path.is_ident("packet"))
.map(|a| a.parse_meta().expect("Failed to parse field attribute"))
.map(|m| match m {
Meta::List(meta_list) => Vec::from_iter(meta_list.nested),
_ => panic!("Expected only list attributes"),
})
.flatten()
.collect()
}
fn quote_field<F: Fn(&Field) -> TokenStream2>(fields: &Fields, func: F) -> TokenStream2 {
let mut output = quote!();
match fields {
Fields::Named(named_fields) => {
output.append_all(named_fields.named.iter().map(|f| func(f)))
}
_ => panic!("Expected only for named fields"),
}
output
TokenStream::from(render_decoder(name, &fields))
}

View File

@ -0,0 +1,118 @@
use crate::error::{DeriveInputParserError, FieldError};
use proc_macro2::Ident;
use std::iter::FromIterator;
use syn::Error as SynError;
use syn::{Data, DeriveInput, Field, Fields, FieldsNamed, Lit, Meta, NestedMeta, Type};
pub(crate) struct FieldData<'a> {
pub(crate) name: &'a Ident,
pub(crate) ty: &'a Type,
pub(crate) attribute: Attribute,
}
#[derive(Debug, Eq, PartialEq)]
pub(crate) enum Attribute {
With { module: String },
MaxLength { length: usize },
Empty,
}
pub(crate) fn parse_derive_input(
input: &DeriveInput,
) -> Result<(&Ident, Vec<FieldData>), DeriveInputParserError> {
let name = &input.ident;
match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(named_fields) => Ok((name, parse_fields(named_fields)?)),
_ => Err(DeriveInputParserError::UnnamedDataFields),
},
_ => Err(DeriveInputParserError::UnsupportedData),
}
}
fn parse_fields(named_fields: &FieldsNamed) -> Result<Vec<FieldData>, DeriveInputParserError> {
let mut fields_data = Vec::new();
for field in named_fields.named.iter() {
let name = field.ident.as_ref().unwrap();
let ty = &field.ty;
let nested_metas = parse_field_nested_metas(field)?;
let attribute = parse_attribute(nested_metas)?;
fields_data.push(FieldData {
name,
ty,
attribute,
})
}
Ok(fields_data)
}
fn parse_field_nested_metas(field: &Field) -> Result<Vec<NestedMeta>, DeriveInputParserError> {
let parsed_metas = field
.attrs
.iter()
.filter(|a| a.path.is_ident("data_type"))
.map(|a| a.parse_meta())
.collect::<Result<Vec<Meta>, SynError>>()?;
let nested_metas = parsed_metas
.into_iter()
.map(|m| match m {
Meta::List(meta_list) => Ok(Vec::from_iter(meta_list.nested)),
_ => Err(FieldError::UnsupportedAttribute),
})
.collect::<Result<Vec<Vec<NestedMeta>>, FieldError>>()?;
Ok(nested_metas.into_iter().flatten().collect())
}
fn parse_attribute(nested_metas: Vec<NestedMeta>) -> Result<Attribute, DeriveInputParserError> {
let attribute_parsers: Vec<fn(&NestedMeta) -> Result<Attribute, FieldError>> =
vec![get_module_attribute, get_max_length_attribute];
for nested_meta in nested_metas.iter() {
for attribute_parser in attribute_parsers.iter() {
let attribute = attribute_parser(nested_meta)?;
if attribute != Attribute::Empty {
return Ok(attribute);
}
}
}
Ok(Attribute::Empty)
}
fn get_module_attribute(nested_meta: &NestedMeta) -> Result<Attribute, FieldError> {
if let NestedMeta::Meta(Meta::NameValue(named_meta)) = nested_meta {
if matches!(&named_meta.path, path if path.is_ident("with")) {
return match &named_meta.lit {
Lit::Str(lit_str) => Ok(Attribute::With {
module: lit_str.value(),
}),
_ => Err(FieldError::AttributeWrongValueType),
};
}
}
Ok(Attribute::Empty)
}
fn get_max_length_attribute(nested_meta: &NestedMeta) -> Result<Attribute, FieldError> {
if let NestedMeta::Meta(Meta::NameValue(named_meta)) = nested_meta {
if matches!(&named_meta.path, path if path.is_ident("max_length")) {
return match &named_meta.lit {
Lit::Int(lit_int) => Ok(Attribute::MaxLength {
length: lit_int.base10_parse::<usize>()?,
}),
_ => Err(FieldError::AttributeWrongValueType),
};
}
}
Ok(Attribute::Empty)
}

View File

@ -0,0 +1,72 @@
use crate::parse::{Attribute, FieldData};
use proc_macro2::TokenStream as TokenStream2;
use proc_macro2::{Ident, Span};
use quote::quote;
use syn::Type;
pub(crate) fn render_decoder(name: &Ident, fields: &Vec<FieldData>) -> TokenStream2 {
let struct_create = render_struct_create(name, fields);
let render_fields = render_fields(fields);
quote! {
#[automatically_derived]
impl crate::decoder::Decoder for #name {
type Output = Self;
fn decode<R: std::io::Read>(reader: &mut R) -> Result<Self::Output, crate::error::DecodeError> {
#render_fields
Ok(#struct_create)
}
}
}
}
fn render_struct_create(name: &Ident, fields: &Vec<FieldData>) -> TokenStream2 {
let struct_fields = fields
.iter()
.map(|f| f.name)
.map(|n| quote!(#n,))
.collect::<TokenStream2>();
quote! {
#name {
#struct_fields
}
}
}
fn render_fields(fields: &Vec<FieldData>) -> TokenStream2 {
fields.iter().map(|f| render_field(f)).flatten().collect()
}
fn render_field(field: &FieldData) -> TokenStream2 {
let name = field.name;
let ty = field.ty;
match &field.attribute {
Attribute::With { module } => render_with_field(name, module),
Attribute::MaxLength { length } => render_max_length_field(name, *length as u16),
Attribute::Empty => render_simple_field(name, ty),
}
}
fn render_simple_field(name: &Ident, ty: &Type) -> TokenStream2 {
quote! {
let #name = <#ty as crate::decoder::Decoder>::decode(reader)?;
}
}
fn render_with_field(name: &Ident, module: &str) -> TokenStream2 {
let module_ident = Ident::new(module, Span::call_site());
quote! {
let #name = crate::decoder::#module_ident::decode(reader)?;
}
}
fn render_max_length_field(name: &Ident, max_length: u16) -> TokenStream2 {
quote! {
let #name = crate::decoder::DecoderReadExt::read_string(reader, #max_length)?;
}
}

View File

@ -0,0 +1,51 @@
use crate::parse::{Attribute, FieldData};
use proc_macro2::TokenStream as TokenStream2;
use proc_macro2::{Ident, Span};
use quote::quote;
pub(crate) fn render_encoder(name: &Ident, fields: &Vec<FieldData>) -> TokenStream2 {
let render_fields = render_fields(fields);
quote! {
#[automatically_derived]
impl crate::encoder::Encoder for #name {
fn encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), crate::error::EncodeError> {
#render_fields
Ok(())
}
}
}
}
fn render_fields(fields: &Vec<FieldData>) -> TokenStream2 {
fields.iter().map(|f| render_field(f)).flatten().collect()
}
fn render_field(field: &FieldData) -> TokenStream2 {
let name = field.name;
match &field.attribute {
Attribute::With { module } => render_with_field(name, module),
Attribute::MaxLength { length } => render_max_length_field(name, *length as u16),
Attribute::Empty => render_simple_field(name),
}
}
fn render_simple_field(name: &Ident) -> TokenStream2 {
render_with_field(name, "Encoder")
}
fn render_with_field(name: &Ident, module: &str) -> TokenStream2 {
let module_ident = Ident::new(module, Span::call_site());
quote! {
crate::encoder::#module_ident::encode(&self.#name, writer)?;
}
}
fn render_max_length_field(name: &Ident, max_length: u16) -> TokenStream2 {
quote! {
crate::encoder::EncoderWriteExt::write_string(writer, &self.#name, #max_length)?;
}
}

View File

@ -0,0 +1,2 @@
pub(crate) mod decoder;
pub(crate) mod encoder;

View File

@ -4,7 +4,7 @@ use crate::data::chat::Message;
use crate::decoder::Decoder;
use crate::error::DecodeError;
use crate::impl_enum_encoder_decoder;
use minecraft_protocol_derive::Packet;
use minecraft_protocol_derive::{Decoder, Encoder};
use nbt::CompoundTag;
use std::io::Read;
@ -89,9 +89,9 @@ impl GameClientBoundPacket {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct ServerBoundChatMessage {
#[packet(max_length = 256)]
#[data_type(max_length = 256)]
pub message: String,
}
@ -103,7 +103,7 @@ impl ServerBoundChatMessage {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct ClientBoundChatMessage {
pub message: Message,
pub position: MessagePosition,
@ -126,15 +126,15 @@ impl ClientBoundChatMessage {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct JoinGame {
pub entity_id: u32,
pub game_mode: GameMode,
pub dimension: i32,
pub max_players: u8,
#[packet(max_length = 16)]
#[data_type(max_length = 16)]
pub level_type: String,
#[packet(with = "var_int")]
#[data_type(with = "var_int")]
pub view_distance: i32,
pub reduced_debug_info: bool,
}
@ -174,7 +174,7 @@ impl JoinGame {
}
}
#[derive(Packet)]
#[derive(Encoder, Decoder)]
pub struct ServerBoundKeepAlive {
pub id: u64,
}
@ -187,7 +187,7 @@ impl ServerBoundKeepAlive {
}
}
#[derive(Packet)]
#[derive(Encoder, Decoder)]
pub struct ClientBoundKeepAlive {
pub id: u64,
}
@ -200,12 +200,12 @@ impl ClientBoundKeepAlive {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct ChunkData {
pub x: i32,
pub z: i32,
pub full: bool,
#[packet(with = "var_int")]
#[data_type(with = "var_int")]
pub primary_mask: i32,
pub heights: CompoundTag,
pub data: Vec<u8>,
@ -236,7 +236,7 @@ impl ChunkData {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct GameDisconnect {
pub reason: Message,
}

View File

@ -4,7 +4,7 @@ use uuid::Uuid;
use crate::data::chat::Message;
use crate::decoder::Decoder;
use crate::error::DecodeError;
use minecraft_protocol_derive::Packet;
use minecraft_protocol_derive::{Decoder, Encoder};
pub enum LoginServerBoundPacket {
LoginStart(LoginStart),
@ -102,7 +102,7 @@ impl LoginClientBoundPacket {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct LoginStart {
pub name: String,
}
@ -115,7 +115,7 @@ impl LoginStart {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct EncryptionResponse {
pub shared_secret: Vec<u8>,
pub verify_token: Vec<u8>,
@ -132,12 +132,12 @@ impl EncryptionResponse {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct LoginPluginResponse {
#[packet(with = "var_int")]
#[data_type(with = "var_int")]
pub message_id: i32,
pub successful: bool,
#[packet(with = "rest")]
#[data_type(with = "rest")]
pub data: Vec<u8>,
}
@ -153,7 +153,7 @@ impl LoginPluginResponse {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct LoginDisconnect {
pub reason: Message,
}
@ -166,9 +166,9 @@ impl LoginDisconnect {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct EncryptionRequest {
#[packet(max_length = 20)]
#[data_type(max_length = 20)]
pub server_id: String,
pub public_key: Vec<u8>,
pub verify_token: Vec<u8>,
@ -190,11 +190,11 @@ impl EncryptionRequest {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct LoginSuccess {
#[packet(with = "uuid_hyp_str")]
#[data_type(with = "uuid_hyp_str")]
pub uuid: Uuid,
#[packet(max_length = 16)]
#[data_type(max_length = 16)]
pub username: String,
}
@ -206,9 +206,9 @@ impl LoginSuccess {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct SetCompression {
#[packet(with = "var_int")]
#[data_type(with = "var_int")]
pub threshold: i32,
}
@ -220,12 +220,12 @@ impl SetCompression {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct LoginPluginRequest {
#[packet(with = "var_int")]
#[data_type(with = "var_int")]
pub message_id: i32,
pub channel: String,
#[packet(with = "rest")]
#[data_type(with = "rest")]
pub data: Vec<u8>,
}

View File

@ -1,7 +1,7 @@
use crate::data::server_status::*;
use crate::decoder::Decoder;
use crate::error::DecodeError;
use minecraft_protocol_derive::Packet;
use minecraft_protocol_derive::{Decoder, Encoder};
use std::io::Read;
pub enum StatusServerBoundPacket {
@ -44,7 +44,7 @@ impl StatusClientBoundPacket {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct PingRequest {
pub time: u64,
}
@ -57,7 +57,7 @@ impl PingRequest {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct PingResponse {
pub time: u64,
}
@ -70,7 +70,7 @@ impl PingResponse {
}
}
#[derive(Packet, Debug)]
#[derive(Encoder, Decoder, Debug)]
pub struct StatusResponse {
pub server_status: ServerStatus,
}