Archived
1
0
Fork 0

Make EventBus coroutine-compatible

This commit is contained in:
Moritz Ruth 2020-08-13 23:43:03 +02:00
parent 321e73ca80
commit 81a39c2e96
No known key found for this signature in database
GPG key ID: AFD57E23E753841B
12 changed files with 117 additions and 51 deletions

View file

@ -14,9 +14,11 @@ val spekVersion = "2.0.12"
dependencies { dependencies {
implementation(kotlin("stdlib-jdk8")) implementation(kotlin("stdlib-jdk8"))
implementation(kotlin("reflect"))
implementation("com.google.code.gson:gson:2.8.6") implementation("com.google.code.gson:gson:2.8.6")
api("org.slf4j:slf4j-api:1.7.30") api("org.slf4j:slf4j-api:1.7.30")
api("io.netty:netty-buffer:4.1.50.Final") api("io.netty:netty-buffer:4.1.50.Final")
api("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.3.8")
testImplementation("io.strikt:strikt-core:0.26.1") testImplementation("io.strikt:strikt-core:0.26.1")
testImplementation("org.spekframework.spek2:spek-dsl-jvm:$spekVersion") testImplementation("org.spekframework.spek2:spek-dsl-jvm:$spekVersion")

View file

@ -12,4 +12,4 @@ inline fun Cancellable.ifCancelled(fn: () -> Unit) = if (isCancelled) fn() else
/** /**
* Only executes [fn] if [isCancelled][Cancellable.isCancelled] is false. * Only executes [fn] if [isCancelled][Cancellable.isCancelled] is false.
*/ */
inline fun Cancellable.ifNotCancelled(fn: () -> Unit) = if (!isCancelled) fn() else Unit inline fun <T> Cancellable.ifNotCancelled(fn: () -> T): T? = if (!isCancelled) fn() else null

View file

@ -1,23 +1,35 @@
package space.blokk.events package space.blokk.events
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch
import space.blokk.Blokk
import space.blokk.plugins.Plugin import space.blokk.plugins.Plugin
import java.lang.reflect.Method
import kotlin.reflect.KClass import kotlin.reflect.KClass
import kotlin.reflect.KFunction
import kotlin.reflect.full.*
class EventBus<EventT: Event>(eventType: KClass<EventT>) { class EventBus<EventT: Event>(private val eventClass: KClass<EventT>, private val scope: CoroutineScope) {
private val eventType: Class<EventT> = eventType.java
/** /**
* All event handlers, sorted by their priority and the order in which they were registered. * All event handlers, sorted by their priority and the order in which they were registered.
*/ */
private val handlers = mutableListOf<Handler>() private val handlers = mutableListOf<Handler<EventT>>()
/** /**
* Invokes all previously registered event handlers sorted by their priority * Invokes all previously registered event handlers sorted by their priority
* and the order in which they were registered. * and the order in which they were registered.
*/ */
suspend fun <T: EventT> emitAndAwait(event: T): T {
handlers.filter { it.eventType.isInstance(event) }.forEach {
scope.launch {
it.fn.callSuspend(it.listener, event)
}
}
return event
}
fun <T: EventT> emit(event: T): T { fun <T: EventT> emit(event: T): T {
handlers.filter { it.eventType.isInstance(event) }.forEach { it.fn.invoke(it.listener, event) } Blokk.server.scope.launch { emitAndAwait(event) }
return event return event
} }
@ -28,21 +40,29 @@ class EventBus<EventT: Event>(eventType: KClass<EventT>) {
* @throws InvalidEventHandlerException if one of the event handlers does not meet the requirements * @throws InvalidEventHandlerException if one of the event handlers does not meet the requirements
*/ */
fun <T: Listener> register(listener: T): T { fun <T: Listener> register(listener: T): T {
val handlersOfListener = listener::class.java.methods val handlersOfListener = listener::class.functions
.mapNotNull { method -> method.getAnnotation(EventHandler::class.java)?.let { method to it } } .mapNotNull { method -> method.findAnnotation<EventHandler>()?.let { method to it } }
.toMap() .toMap()
for ((method, data) in handlersOfListener) { for ((method, data) in handlersOfListener) {
if (method.parameters.count() != 1) if (method.valueParameters.count() != 1)
throw InvalidEventHandlerException("${method.name} must have exactly one parameter") throw InvalidEventHandlerException("${method.name} must have exactly one parameter")
val type = method.parameterTypes[0] @Suppress("UNCHECKED_CAST")
if (!eventType.isAssignableFrom(type)) val klass = method.parameters[1].type.classifier as KClass<EventT>
if (!eventClass.isSuperclassOf(klass))
throw InvalidEventHandlerException("${method.name}'s first parameter type is incompatible with the " + throw InvalidEventHandlerException("${method.name}'s first parameter type is incompatible with the " +
"one required by the EventBus") "one required by the EventBus")
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
val handler = Handler(type as Class<out Event>, listener, method, Plugin.getCalling(), data.priority) val handler = Handler(
klass,
listener,
method as KFunction<EventT>,
Plugin.getCalling(),
data.priority
)
val insertIndex = handlers.indexOfLast { it.priority.ordinal <= handler.priority.ordinal } + 1 val insertIndex = handlers.indexOfLast { it.priority.ordinal <= handler.priority.ordinal } + 1
handlers.add(insertIndex, handler) handlers.add(insertIndex, handler)
@ -60,11 +80,11 @@ class EventBus<EventT: Event>(eventType: KClass<EventT>) {
class InvalidEventHandlerException internal constructor(message: String): Exception(message) class InvalidEventHandlerException internal constructor(message: String): Exception(message)
private data class Handler( private data class Handler<T: Event>(
val eventType: Class<out Event>, val eventType: KClass<T>,
val listener: Listener, val listener: Listener,
val fn: Method, val fn: KFunction<T>,
val plugin: Plugin?, val plugin: Plugin?,
val priority: EventPriority val priority: EventPriority
) )
} }

View file

@ -1,5 +1,6 @@
package space.blokk.net package space.blokk.net
import kotlinx.coroutines.CoroutineScope
import space.blokk.events.EventTarget import space.blokk.events.EventTarget
import space.blokk.net.events.SessionEvent import space.blokk.net.events.SessionEvent
import space.blokk.net.protocols.OutgoingPacket import space.blokk.net.protocols.OutgoingPacket
@ -12,6 +13,7 @@ interface Session: EventTarget<SessionEvent> {
*/ */
var currentProtocol: Protocol var currentProtocol: Protocol
val address: InetAddress val address: InetAddress
val scope: CoroutineScope
fun send(packet: OutgoingPacket) suspend fun send(packet: OutgoingPacket)
} }

View file

@ -1,9 +1,11 @@
package space.blokk.server package space.blokk.server
import kotlinx.coroutines.CoroutineScope
import space.blokk.events.EventBus import space.blokk.events.EventBus
import space.blokk.events.EventTarget import space.blokk.events.EventTarget
import space.blokk.server.events.ServerEvent import space.blokk.server.events.ServerEvent
interface Server: EventTarget<ServerEvent> { interface Server: EventTarget<ServerEvent> {
override val eventBus: EventBus<ServerEvent> override val eventBus: EventBus<ServerEvent>
val scope: CoroutineScope
} }

View file

@ -18,6 +18,7 @@ dependencies {
implementation("io.netty:netty-all:4.1.50.Final") implementation("io.netty:netty-all:4.1.50.Final")
implementation("org.slf4j:slf4j-api:1.7.30") implementation("org.slf4j:slf4j-api:1.7.30")
implementation("ch.qos.logback:logback-classic:1.2.3") implementation("ch.qos.logback:logback-classic:1.2.3")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-jdk8:1.3.8")
testImplementation(kotlin("test-junit5")) testImplementation(kotlin("test-junit5"))
} }

View file

@ -1,5 +1,7 @@
package space.blokk package space.blokk
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.CoroutineScope
import space.blokk.events.EventBus import space.blokk.events.EventBus
import space.blokk.net.BlokkSocketServer import space.blokk.net.BlokkSocketServer
import space.blokk.server.Server import space.blokk.server.Server
@ -9,7 +11,8 @@ import space.blokk.utils.Logger
class BlokkServer internal constructor(): Server { class BlokkServer internal constructor(): Server {
init { i = this } init { i = this }
override val eventBus = EventBus(ServerEvent::class) override val scope = CoroutineScope(CoroutineName("BlokkServer"))
override val eventBus = EventBus(ServerEvent::class, scope)
val logger = Logger("BlokkServer") val logger = Logger("BlokkServer")
var blokkSocketServer = BlokkSocketServer(this); private set var blokkSocketServer = BlokkSocketServer(this); private set

View file

@ -1,6 +1,8 @@
package space.blokk.net package space.blokk.net
import io.netty.channel.Channel import io.netty.channel.Channel
import kotlinx.coroutines.*
import space.blokk.BlokkServer
import space.blokk.events.* import space.blokk.events.*
import space.blokk.net.events.SessionEvent import space.blokk.net.events.SessionEvent
import space.blokk.net.events.SessionPacketReceivedEvent import space.blokk.net.events.SessionPacketReceivedEvent
@ -8,13 +10,16 @@ import space.blokk.net.events.SessionPacketSendEvent
import space.blokk.net.protocols.OutgoingPacket import space.blokk.net.protocols.OutgoingPacket
import space.blokk.net.protocols.Protocol import space.blokk.net.protocols.Protocol
import space.blokk.net.protocols.handshaking.HandshakingProtocol import space.blokk.net.protocols.handshaking.HandshakingProtocol
import space.blokk.server.events.SessionInitializedEvent
import space.blokk.utils.Logger import space.blokk.utils.Logger
import space.blokk.utils.awaitSuspending
import java.net.InetAddress import java.net.InetAddress
import java.net.InetSocketAddress import java.net.InetSocketAddress
class BlokkSession(private val channel: Channel) : Session { class BlokkSession(private val channel: Channel) : Session {
override val address: InetAddress = (channel.remoteAddress() as InetSocketAddress).address override val address: InetAddress = (channel.remoteAddress() as InetSocketAddress).address
val logger = Logger("BlokkSession(${address.hostAddress})") private val identifier = "BlokkSession(${address.hostAddress})"
val logger = Logger(identifier)
override var currentProtocol: Protocol = HandshakingProtocol override var currentProtocol: Protocol = HandshakingProtocol
set(value) { set(value) {
@ -22,30 +27,41 @@ class BlokkSession(private val channel: Channel) : Session {
field = value field = value
} }
override val eventBus = EventBus(SessionEvent::class) override val scope = CoroutineScope(Dispatchers.Unconfined + CoroutineName(identifier))
override val eventBus = EventBus(SessionEvent::class, scope)
var active: Boolean = true var active: Boolean = true
init { init {
eventBus.register(object : Listener { eventBus.register(object : Listener {
@EventHandler(priority = EventPriority.INTERNAL) @EventHandler(priority = EventPriority.INTERNAL)
fun onSessionPacketReceived(event: SessionPacketReceivedEvent<*>) { suspend fun onSessionPacketReceived(event: SessionPacketReceivedEvent<*>) {
event.ifNotCancelled { SessionPacketReceivedEventHandler.handle(event.session as BlokkSession, event.packet) } SessionPacketReceivedEventHandler.handle(event.session as BlokkSession, event.packet)
} }
}) })
} }
override fun send(packet: OutgoingPacket) { fun onConnect() = scope.launch {
if (BlokkServer.i.eventBus.emit(SessionInitializedEvent(this@BlokkSession)).isCancelled) channel.close()
else BlokkServer.i.blokkSocketServer.allSessionsGroup.add(this@BlokkSession)
}
fun onDisconnect() {
active = false
scope.cancel("Disconnected")
BlokkServer.i.blokkSocketServer.allSessionsGroup.remove(this)
}
override suspend fun send(packet: OutgoingPacket) {
if (!active) throw IllegalStateException("The session is not active anymore") if (!active) throw IllegalStateException("The session is not active anymore")
logger debug { "Sending packet: $packet" } logger debug { "Sending packet: $packet" }
val event = eventBus.emit(SessionPacketSendEvent(this, packet)) val event = eventBus.emit(SessionPacketSendEvent(this@BlokkSession, packet))
if (event.isCancelled) return event.ifNotCancelled {
try {
val cf = channel.writeAndFlush(PacketMessage(this, event.packet)) channel.writeAndFlush(PacketMessage(this@BlokkSession, event.packet)).awaitSuspending()
cf.addListener { future -> } catch (e: Throwable) {
if (!future.isSuccess) { logger error { "Packet send failed: $e" }
logger error { "Packet send failed: ${future.cause()}" }
} }
} }
} }

View file

@ -10,7 +10,7 @@ import space.blokk.net.protocols.status.StatusProtocolHandler
open class ProtocolPacketReceivedEventHandler(handlers: Map<out IncomingPacketCompanion<*>, PacketReceivedEventHandler<out IncomingPacket>>) { open class ProtocolPacketReceivedEventHandler(handlers: Map<out IncomingPacketCompanion<*>, PacketReceivedEventHandler<out IncomingPacket>>) {
private val handlers = handlers.mapKeys { it.key.packetType } private val handlers = handlers.mapKeys { it.key.packetType }
fun handle(session: BlokkSession, packet: IncomingPacket) { suspend fun handle(session: BlokkSession, packet: IncomingPacket) {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
val handler = handlers[packet::class] as PacketReceivedEventHandler<IncomingPacket>? val handler = handlers[packet::class] as PacketReceivedEventHandler<IncomingPacket>?
handler?.handle(session, packet) handler?.handle(session, packet)
@ -18,18 +18,18 @@ open class ProtocolPacketReceivedEventHandler(handlers: Map<out IncomingPacketCo
} }
abstract class PacketReceivedEventHandler<T: IncomingPacket> { abstract class PacketReceivedEventHandler<T: IncomingPacket> {
abstract fun handle(session: BlokkSession, packet: T) abstract suspend fun handle(session: BlokkSession, packet: T)
companion object { companion object {
fun <T: IncomingPacket> of(fn: (session: BlokkSession, packet: T) -> Unit) = fun <T: IncomingPacket> of(fn: suspend (session: BlokkSession, packet: T) -> Unit) =
object : PacketReceivedEventHandler<T>() { object : PacketReceivedEventHandler<T>() {
override fun handle(session: BlokkSession, packet: T) = fn(session, packet) override suspend fun handle(session: BlokkSession, packet: T) = fn(session, packet)
} }
} }
} }
object SessionPacketReceivedEventHandler { object SessionPacketReceivedEventHandler {
fun handle(session: BlokkSession, packet: IncomingPacket) { suspend fun handle(session: BlokkSession, packet: IncomingPacket) {
val handler = when(session.currentProtocol) { val handler = when(session.currentProtocol) {
HandshakingProtocol -> HandshakingProtocolHandler HandshakingProtocol -> HandshakingProtocolHandler
StatusProtocol -> StatusProtocolHandler StatusProtocol -> StatusProtocolHandler

View file

@ -5,22 +5,14 @@ import io.netty.channel.ChannelHandlerContext
import io.netty.handler.codec.MessageToMessageCodec import io.netty.handler.codec.MessageToMessageCodec
import space.blokk.BlokkServer import space.blokk.BlokkServer
import space.blokk.net.protocols.OutgoingPacket import space.blokk.net.protocols.OutgoingPacket
import space.blokk.server.events.SessionInitializedEvent
import java.io.IOException import java.io.IOException
class PacketCodec(private val session: BlokkSession): MessageToMessageCodec<ByteBuf, PacketMessage<*>>() { class PacketCodec(private val session: BlokkSession): MessageToMessageCodec<ByteBuf, PacketMessage<*>>() {
override fun channelActive(ctx: ChannelHandlerContext) { override fun channelActive(ctx: ChannelHandlerContext) { session.onConnect() }
if (BlokkServer.i.eventBus.emit(SessionInitializedEvent(session)).isCancelled) ctx.channel().close() override fun channelInactive(ctx: ChannelHandlerContext) { session.onDisconnect() }
else BlokkServer.i.blokkSocketServer.allSessionsGroup.add(session)
}
override fun channelInactive(ctx: ChannelHandlerContext) {
session.active = false
BlokkServer.i.blokkSocketServer.allSessionsGroup.remove(session)
}
override fun encode(ctx: ChannelHandlerContext, msg: PacketMessage<*>, out: MutableList<Any>) { override fun encode(ctx: ChannelHandlerContext, msg: PacketMessage<*>, out: MutableList<Any>) {
if (msg.packet !is OutgoingPacket) throw Error("Only clientbound packets can be sent") if (msg.packet !is OutgoingPacket) throw Error("Only clientbound packets are allowed. This should never happen.")
val buffer = ctx.alloc().buffer() val buffer = ctx.alloc().buffer()
with(MinecraftDataTypes) { buffer.writeVarInt(msg.packetCompanion.id) } with(MinecraftDataTypes) { buffer.writeVarInt(msg.packetCompanion.id) }
msg.packet.encode(buffer) msg.packet.encode(buffer)
@ -42,5 +34,6 @@ class PacketCodec(private val session: BlokkSession): MessageToMessageCodec<Byte
// You can usually ignore connection errors as they are caused by modified clients such as hack clients // You can usually ignore connection errors as they are caused by modified clients such as hack clients
if (BlokkServer.i.silentConnectionErrors) session.logger.debug(message) else session.logger.error(message) if (BlokkServer.i.silentConnectionErrors) session.logger.debug(message) else session.logger.error(message)
cause.printStackTrace()
} }
} }

View file

@ -2,14 +2,17 @@ package space.blokk.net
import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelHandlerContext
import io.netty.channel.SimpleChannelInboundHandler import io.netty.channel.SimpleChannelInboundHandler
import kotlinx.coroutines.runBlocking
import space.blokk.net.events.SessionPacketReceivedEvent import space.blokk.net.events.SessionPacketReceivedEvent
import space.blokk.net.protocols.IncomingPacket import space.blokk.net.protocols.IncomingPacket
class PacketMessageHandler(private val session: BlokkSession): SimpleChannelInboundHandler<PacketMessage<*>>() { class PacketMessageHandler(private val session: BlokkSession): SimpleChannelInboundHandler<PacketMessage<*>>() {
override fun channelRead0(ctx: ChannelHandlerContext, msg: PacketMessage<*>) { override fun channelRead0(ctx: ChannelHandlerContext, msg: PacketMessage<*>) {
if (msg.packet !is IncomingPacket) throw Error("Only serverbound packets are allowed. This should never happen.") if (msg.packet !is IncomingPacket) throw Error("Only serverbound packets are allowed. This should never happen.")
msg.session.logger.debug { "Packet received: ${msg.packet}" } session.logger.debug { "Packet received: ${msg.packet}" }
msg.session.eventBus.emit(SessionPacketReceivedEvent(msg.session, msg.packet)) runBlocking {
session.eventBus.emitAndAwait(SessionPacketReceivedEvent(session, msg.packet))
}
// TODO: Disconnect when invalid data is received // TODO: Disconnect when invalid data is received
} }

View file

@ -0,0 +1,24 @@
package space.blokk.utils
import io.netty.channel.ChannelFuture
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlin.coroutines.Continuation
import kotlin.coroutines.suspendCoroutine
suspend fun ChannelFuture.awaitSuspending() {
fun listen(c: Continuation<Unit>) {
addListener { c.resumeWith(if (it.isSuccess) Result.success(Unit) else Result.failure(it.cause())) }
}
if (isCancellable) {
suspendCancellableCoroutine<Unit> { c ->
if (isCancelled) c.cancel()
else {
c.invokeOnCancellation { cancel(false) }
listen(c)
}
}
} else {
suspendCoroutine<Unit> { c -> listen(c) }
}
}