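# A Deep Dive into the JGroups RPC Implementation

# Starting with Building Blocks

Building blocks are a higher-level abstraction on top of JChannel. Different building block implementations layered over a JChannel make it possible to develop a variety of features based on network transport.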

public class JChannel implements Closeable {
    protected UpHandler                             up_handler; 
    protected Receiver                              receiver;

    public JChannel      setReceiver(Receiver r)             {receiver=r; return this;}
    public UpHandler     getUpHandler()                      {return up_handler;}

    public Object up(Message msg) {
        if(stats) {
            received_msgs++;
            received_bytes+=msg.getLength();
        }

        // discard local messages (sent by myself to me)
        if(discard_own_messages && local_addr != null && msg.getSrc() != null && local_addr.equals(msg.getSrc()))
            return null;

        // If UpHandler is installed, pass all events to it and return (UpHandler is e.g. a building block)
        if(up_handler != null)
            return up_handler.up(msg);

        if(receiver != null)
            receiver.receive(msg);
        return null;
    }
}

public interface UpHandler {
   
   /**
    * Invoked for all channel events except connection management and state transfer.
    */
    Object up(Event evt);

    Object up(Message msg);

    default void up(MessageBatch batch) {
        for(Message msg: batch) {
            try {
                up(msg);
            }
            catch(Throwable t) {
            }
        }
    }
}

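The up method of JChannel shows that once an UpHandler is registered on a JChannel, incoming messages bypass the Receiver entirely. So when a building block is in use, the UpHandler is the real message consumer and the JChannel becomes an intermediary.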
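The dispatch order described above can be illustrated with a small model. This is only a sketch using plain Java: `MiniChannel` and its nested `Receiver` and `UpHandler` interfaces are hypothetical stand-ins that work on strings, not the JGroups classes.

```java
import java.util.ArrayList;
import java.util.List;

// Simplified model of the dispatch order in JChannel.up(); not the JGroups classes.
class DispatchDemo {
    interface Receiver { void receive(String msg); }
    interface UpHandler { Object up(String msg); }

    // Hypothetical stand-in for a channel that only models the up path
    static class MiniChannel {
        Receiver receiver;
        UpHandler upHandler;

        Object up(String msg) {
            // An installed UpHandler gets the message and the Receiver is skipped
            if (upHandler != null)
                return upHandler.up(msg);
            if (receiver != null)
                receiver.receive(msg);
            return null;
        }
    }

    static List<String> run() {
        List<String> log = new ArrayList<>();
        MiniChannel ch = new MiniChannel();
        ch.receiver = msg -> log.add("receiver:" + msg);
        ch.up("m1");  // no UpHandler yet, so the Receiver sees m1
        ch.upHandler = msg -> { log.add("handler:" + msg); return null; };
        ch.up("m2");  // UpHandler installed, so the Receiver never sees m2
        return log;
    }

    public static void main(String[] args) {
        System.out.println(run());  // prints [receiver:m1, handler:m2]
    }
}
```

The receiver stays registered the whole time; it simply stops being called once a handler is present.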

# Wrapping JChannel

RpcDispatcher extends MessageDispatcher, and MessageDispatcher takes a JChannel as its constructor argument. The source shows that it creates a new Protocol named ProtocolAdapter and installs it as the JChannel's UpHandler.

public class MessageDispatcher implements RequestHandler, Closeable, ChannelListener {
    public MessageDispatcher(JChannel channel) {
        this.channel=channel;
        prot_adapter=new ProtocolAdapter();
        if(channel != null) {
            channel.addChannelListener(this);
            local_addr=channel.getAddress();
            installUpHandler(prot_adapter, true);
        }
        start();
    }

    protected <X extends MessageDispatcher> X installUpHandler(UpHandler handler, boolean canReplace) {
        UpHandler existing = channel.getUpHandler();
        if (existing == null)
            channel.setUpHandler(handler);
        else if(canReplace) {
            log.warn("Channel already has an up handler installed (%s) but now it is being overridden", existing);
            channel.setUpHandler(handler);
        }
        return (X)this;
    }
}

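ProtocolAdapter itself is nothing special: it consists of up and down methods. The corr field it uses is the RequestCorrelator, the class that wraps and matches requests and responses; what it does is covered later. What matters here is that a message sent from MessageDispatcher, like one sent from a JChannel, travels down the chained "protocol stack" to the sender at the bottom. The difference is that MessageDispatcher starts at ProtocolAdapter (the JChannel's UpHandler), passes through the JChannel, and then enters the real protocol stack below the JChannel.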

class ProtocolAdapter extends Protocol implements UpHandler {


        /* ------------------------- Protocol Interface --------------------------- */

        @Override
        public String getName() {
            return "MessageDispatcher";
        }


        /**
         * Called by channel (we registered before) when event is received. This is the UpHandler interface.
         */
        @Override
        public Object up(Event evt) {
            if(corr != null && !corr.receive(evt)) {
                try {
                    return handleUpEvent(evt);
                }
                catch(Throwable t) {
                    throw new RuntimeException(t);
                }
            }
            return null;
        }

        public Object up(Message msg) {
            if(corr != null)
                corr.receiveMessage(msg);
            return null;
        }

        public void up(MessageBatch batch) {
            if(corr == null)
                return;
            corr.receiveMessageBatch(batch);
        }

        @Override
        public Object down(Event evt) {
            return channel != null? channel.down(evt) : null;
        }

        public Object down(Message msg) {
            if(channel != null) {
                if(!(channel.isConnected() || channel.isConnecting())) {
                    // return null;
                    throw new IllegalStateException("channel is not connected");
                }
                return channel.down(msg);
            }
            return null;
        }

        /* ----------------------- End of Protocol Interface ------------------------ */

    }

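# Sending Requests and Responses with RequestCorrelator

Now we can look at what RequestCorrelator actually does.

The start() method of MessageDispatcher shows that RequestCorrelator takes three arguments: a Protocol (the JChannel's UpHandler), a RequestHandler (the MessageDispatcher itself), and the local address.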

 public <X extends MessageDispatcher> X start() {
        if(corr == null)
            corr=createRequestCorrelator(prot_adapter, this, local_addr)
              .asyncDispatching(async_dispatching).wrapExceptions(this.wrap_exceptions);
        correlatorStarted();
        corr.start();

        if(channel != null) {
            List<Address> tmp_mbrs=channel.getView() != null ? channel.getView().getMembers() : null;
            setMembers(tmp_mbrs);
            if(channel instanceof JChannel) {
                TP transport=channel.getProtocolStack().getTransport();
                corr.registerProbeHandler(transport);
            }
        }
        return (X)this;
    }

protected static RequestCorrelator createRequestCorrelator(Protocol transport, RequestHandler handler, Address local_addr) {
        return new RequestCorrelator(transport, handler, local_addr);
    }

Let's see what happens to a message inside RequestCorrelator.

public void sendRequest(Collection<Address> dest_mbrs, Buffer data, Request req, RequestOptions opts) throws Exception {
        if(transport == null) {
            log.warn("transport is not available !");
            return;
        }

        // i.   Create the request correlator header and add it to the msg
        // ii.  If a reply is expected (coll != null), add a coresponding entry in the pending requests table
        Header hdr=opts.hasExclusionList()? new MultiDestinationHeader(Header.REQ, 0, this.corr_id, opts.exclusionList())
          : new Header(Header.REQ, 0, this.corr_id);

        Message msg=new Message(null, data).putHeader(this.corr_id, hdr)
          .setFlag(opts.flags()).setTransientFlag(opts.transientFlags());

        if(req != null) { // sync
            long req_id=REQUEST_ID.getAndIncrement();
            req.requestId(req_id);
            hdr.requestId(req_id); // set the request-id only for *synchronous RPCs*
            if(log.isTraceEnabled())
                log.trace("%s: invoking multicast RPC [req-id=%d]", local_addr, req_id);
            requests.putIfAbsent(req_id, req);
            // make sure no view is received before we add ourself as a view handler (https://issues.jboss.org/browse/JGRP-1428)
            req.viewChange(view);
            if(rpc_stats.extendedStats())
                req.start_time=System.nanoTime();
        }
        else {  // async
            if(opts != null && opts.anycasting())
                rpc_stats.addAnycast(false, 0, dest_mbrs);
            else
                rpc_stats.add(RpcStats.Type.MULTICAST, null, false, 0);
        }

        if(opts.anycasting()) {
            if(opts.useAnycastAddresses()) {
                transport.down(msg.dest(new AnycastAddress(dest_mbrs)));
            }
            else {
                boolean first=true;
                for(Address mbr: dest_mbrs) {
                    Message copy=(first? msg : msg.copy(true)).dest(mbr);
                    first=false;
                    if(!mbr.equals(local_addr) && copy.isTransientFlagSet(Message.TransientFlag.DONT_LOOPBACK))
                        copy.clearTransientFlag(Message.TransientFlag.DONT_LOOPBACK);
                    transport.down(copy);
                }
            }
        }
        else
            transport.down(msg);
    }

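As the official manual describes, a header is added to the data message. It carries the message type, the request id, and the correlator id (corr_id); a MultiDestinationHeader additionally carries the exclusion list. Before sending, the two modes diverge: a synchronous request gets a fresh request id and is registered in the requests table so that a later response can be matched to it, while an asynchronous request is only counted in rpc_stats. Then comes the key step: the message is sent with transport.down(msg).

The implementation is somewhat crude in that it tells asynchronous from synchronous by checking whether req is null. The reason is that the caller passes the Request only when a response is expected, and passes null otherwise. This can be seen in UnicastRequest.sendRequest below, which passes null when the mode is ResponseMode.GET_NONE.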
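The bookkeeping described above can be modeled in a few lines. The following is a minimal sketch, not the JGroups code: `MiniCorrelator` is a hypothetical class, the payload is a plain string, and removing the entry when the response arrives is a simplification made for this example.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

// Simplified model of request correlation; MiniCorrelator is hypothetical.
class CorrelatorDemo {
    static class MiniCorrelator {
        final AtomicLong nextId = new AtomicLong(1);
        final Map<Long, CompletableFuture<String>> pending = new ConcurrentHashMap<>();

        // Returns the request id that would go into the header (0 for async)
        long send(String payload, CompletableFuture<String> req) {
            if (req == null)
                return 0;  // async: no reply expected, nothing to track
            long id = nextId.getAndIncrement();
            pending.putIfAbsent(id, req);  // sync: remember who waits for this id
            return id;
        }

        // Matches an incoming response to the waiting request by id
        void receiveResponse(long id, String value) {
            CompletableFuture<String> req = pending.remove(id);
            if (req != null)
                req.complete(value);
        }
    }

    static String roundTrip() {
        MiniCorrelator corr = new MiniCorrelator();
        CompletableFuture<String> req = new CompletableFuture<>();
        long id = corr.send("ping", req);
        corr.receiveResponse(id, "pong");
        return req.join();
    }

    static int pendingAfterAsync() {
        MiniCorrelator corr = new MiniCorrelator();
        corr.send("fire-and-forget", null);
        return corr.pending.size();
    }

    public static void main(String[] args) {
        System.out.println(roundTrip());          // prints pong
        System.out.println(pendingAfterAsync());  // prints 0
    }
}
```

Passing null for the request is what makes a call asynchronous in this model, mirroring the `req != null` check in sendRequest.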

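# Sending and Receiving a Request

Note the Request parameter of the method: it is the actual wrapper of a request. A Request is in fact a CompletableFuture. Its sendRequest method sends the request, and the future completes only when its receiveResponse method is called.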

public class RequestCorrelator {
    protected void handleResponse(Request req, Address sender, byte[] buf, int offset, int length, boolean is_exception) {
        Object retval;
        try {
            retval=replyFromBuffer(buf, offset, length, marshaller);
        }
        catch(Exception e) {
            log.error(Util.getMessage("FailedUnmarshallingBufferIntoReturnValue"), e);
            retval=e;
            is_exception=true;
        }
        req.receiveResponse(retval, sender, is_exception);
    }
}

handleResponse in RequestCorrelator is fairly concise. Below are the relevant methods of Request.

public abstract class Request<T> extends CompletableFuture<T> {
    public abstract void       sendRequest(Buffer data) throws Exception;

    public abstract void       receiveResponse(Object response_value, Address sender, boolean is_exception);
}

public class UnicastRequest<T> extends Request<T> {
    public void sendRequest(Buffer data) throws Exception {
        try {
            corr.sendUnicastRequest(target, data, options.mode() == ResponseMode.GET_NONE? null : this, this.options);
        }
        catch(Exception ex) {
            corrDone();
            throw ex;
        }
    }
    

    /* ---------------------- Interface RspCollector -------------------------- */
    /**
     * <b>Callback</b> (called by RequestCorrelator or Transport).
     * Adds a response to the response table. When all responses have been received, {@code execute()} returns.
     */
    public void receiveResponse(Object response_value, Address sender, boolean is_exception) {
        if(isDone())
            return;
        if(is_exception && response_value instanceof Throwable)
            completeExceptionally((Throwable)response_value);
        else
            complete((T)response_value);
        corrDone();
    }
}
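To make the future semantics concrete, here is a small sketch of a request that completes when a response is delivered. `MiniRequest` is a hypothetical class that copies only the completion logic of receiveResponse shown above; sending and correlator cleanup are left out.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;

class RequestFutureDemo {
    // Hypothetical stand-in for a unicast request; only completion is modeled
    static class MiniRequest<T> extends CompletableFuture<T> {
        @SuppressWarnings("unchecked")
        void receiveResponse(Object value, boolean isException) {
            if (isDone())
                return;  // a second response is ignored
            if (isException && value instanceof Throwable)
                completeExceptionally((Throwable) value);
            else
                complete((T) value);
        }
    }

    static String normal() {
        MiniRequest<String> req = new MiniRequest<>();
        req.receiveResponse("first", false);
        req.receiveResponse("second", false);  // ignored, already done
        return req.join();
    }

    static String failed() {
        MiniRequest<String> req = new MiniRequest<>();
        req.receiveResponse(new IllegalStateException("boom"), true);
        try {
            req.join();
            return "no error";
        } catch (CompletionException e) {
            // join() wraps the original exception as the cause
            return e.getCause().getMessage();
        }
    }

    public static void main(String[] args) {
        System.out.println(normal());  // prints first
        System.out.println(failed());  // prints boom
    }
}
```

Because the request is itself a CompletableFuture, a caller can either block on it or chain callbacks, which is what allows both synchronous and asynchronous use.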

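# What RPC Needs to Do on Top of Messages

With this groundwork, it is clear how a building block sends messages. In effect, the building block introduces the notions of asynchronous and synchronous requests on top of JChannel, which lays the foundation for asynchronous and synchronous RPC.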

public class RpcDispatcher extends MessageDispatcher {
    protected Object        server_obj;
    protected Marshaller    marshaller;
    protected MethodLookup  method_lookup;
    protected MethodInvoker method_invoker;
}

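server_obj is the object on which the RPC methods are invoked. marshaller is a custom serializer for request and response data, described in detail below. Of method_lookup and method_invoker, only one takes effect for a given call, and method_invoker wins if both are set. As the handle method shown later indicates, both apply only when the MethodCall is in ID mode. method_lookup is an interface whose implementation lets the dispatcher find a method of server_obj quickly by id, without searching the class on every call. method_invoker is more direct: it defines the invocation logic itself. Both are optimizations of the reflection overhead in RPC.

# Sending a Method Call to Remote Members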
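The two styles can be compared with a small sketch. Everything here is hypothetical and written against plain Java reflection rather than the JGroups MethodLookup and MethodInvoker interfaces: `METHODS` plays the role of a lookup table built once, and `viaInvoker` plays the role of a hand-written invoker that avoids reflection altogether.

```java
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

class MethodIdDemo {
    static class Server {
        public int add(int a, int b) { return a + b; }
    }

    // Style 1: id -> Method table, resolved once instead of on every call
    static final Map<Short, Method> METHODS = new HashMap<>();
    static {
        try {
            METHODS.put((short) 1, Server.class.getMethod("add", int.class, int.class));
        } catch (NoSuchMethodException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    static Object viaLookup(Object target, short id, Object... args) {
        Method m = METHODS.get(id);
        if (m == null)
            throw new IllegalArgumentException("no method for id " + id);
        try {
            return m.invoke(target, args);  // still reflective, but no method search
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    // Style 2: direct dispatch on the id, no reflection involved
    static Object viaInvoker(Server target, short id, Object... args) {
        switch (id) {
            case 1: return target.add((Integer) args[0], (Integer) args[1]);
            default: throw new IllegalArgumentException("unknown id " + id);
        }
    }

    static Object lookupDemo()  { return viaLookup(new Server(), (short) 1, 2, 3); }
    static Object invokerDemo() { return viaInvoker(new Server(), (short) 1, 2, 3); }

    public static void main(String[] args) {
        System.out.println(lookupDemo());   // prints 5
        System.out.println(invokerDemo());  // prints 5
    }
}
```

The lookup table keeps the code generic at the cost of a reflective call, while the invoker trades generality for a plain method call.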

public class RpcDispatcher extends MessageDispatcher {
    public <T> RspList<T> callRemoteMethods(Collection<Address> dests, MethodCall method_call,
                                            RequestOptions opts) throws Exception {
        if(dests != null && dests.isEmpty()) { // don't send if dest list is empty
            log.trace("destination list of %s() is empty: no need to send message", method_call.methodName());
            return empty_rsplist;
        }

        Buffer buf=methodCallToBuffer(method_call, marshaller);
        RspList<T> retval=super.castMessage(dests, buf, opts);
        if(log.isTraceEnabled())
            log.trace("dests=%s, method_call=%s, options=%s, responses: %s", dests, method_call, opts, retval);
        return retval;
    }

    protected static Buffer methodCallToBuffer(final MethodCall call, Marshaller marshaller) throws Exception {
        Object[] args=call.args();

        int estimated_size=64;
        if(args != null)
            for(Object arg: args)
                estimated_size+=marshaller != null? marshaller.estimatedSize(arg) : (arg == null? 2 : 50);

        ByteArrayDataOutputStream out=new ByteArrayDataOutputStream(estimated_size, true);
        call.writeTo(out, marshaller);
        return out.getBuffer();
    }
}

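One optimization is visible here. By default the buffer size is only a rough estimate: 64 bytes plus 50 bytes per non-null argument (2 for a null one). If we supply a marshaller that overrides estimatedSize, the estimate is refined per argument by our own logic. This matters because if estimated_size is much smaller than the actual size, writes to the output stream force the buffer to grow, which means copying the byte array. The number of such copies affects performance and memory usage.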
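The cost of a poor estimate can be shown by counting reallocations. This is a sketch under stated assumptions: `CountingBuffer` is a hypothetical buffer, and its growth policy (grow to the larger of the needed size and twice the current size) is chosen for illustration and is not claimed to match ByteArrayDataOutputStream.

```java
import java.util.Arrays;

class BufferGrowthDemo {
    // Hypothetical growable buffer that counts how often it reallocates
    static class CountingBuffer {
        byte[] buf;
        int pos;
        int copies;

        CountingBuffer(int estimatedSize) { buf = new byte[estimatedSize]; }

        void write(byte[] data) {
            int needed = pos + data.length;
            if (needed > buf.length) {
                // Growing means allocating a new array and copying the old content
                buf = Arrays.copyOf(buf, Math.max(needed, buf.length * 2));
                copies++;
            }
            System.arraycopy(data, 0, buf, pos, data.length);
            pos = needed;
        }
    }

    // Writes argCount arguments of argSize bytes each and reports the copies
    static int copiesFor(int estimatedSize, int argCount, int argSize) {
        CountingBuffer out = new CountingBuffer(estimatedSize);
        byte[] arg = new byte[argSize];
        for (int i = 0; i < argCount; i++)
            out.write(arg);
        return out.copies;
    }

    public static void main(String[] args) {
        // Default-style estimate: 64 + 4 * 50 = 264 bytes for four 1000-byte args
        System.out.println(copiesFor(64 + 4 * 50, 4, 1000));    // prints 3
        // Accurate estimate: 64 + 4 * 1000 = 4064 bytes
        System.out.println(copiesFor(64 + 4 * 1000, 4, 1000));  // prints 0
    }
}
```

With large arguments the fixed 50-byte guess leads to repeated copying, while an accurate estimate allocates once.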

public class MethodCall implements Streamable, Constructable<MethodCall> {
    public void writeTo(DataOutput out, Marshaller marshaller) throws IOException {
        out.write(mode);

        switch(mode) {
            case METHOD:
                Bits.writeString(method_name,out);
                writeMethod(out);
                break;
            case TYPES:
                Bits.writeString(method_name,out);
                writeTypes(out);
                break;
            case ID:
                out.writeShort(method_id);
                break;
            default:
                throw new IllegalStateException("mode " + mode + " unknown");
        }
        writeArgs(out, marshaller);
    }

    protected void writeArgs(DataOutput out, Marshaller marshaller) throws IOException {
        int args_len=args != null? args.length : 0;
        out.write(args_len);
        if(args_len == 0)
            return;
        for(Object obj: args) {
            if(marshaller != null)
                marshaller.objectToStream(obj, out);
            else
                Util.objectToStream(obj, out);
        }
    }
}

As shown, when a Marshaller is present, its custom serialization is used instead of the default one.

public interface Marshaller {
    default int estimatedSize(Object arg) {
        return arg == null? 2: 50;
    }
    void objectToStream(Object obj, DataOutput out) throws IOException;
    Object objectFromStream(DataInput in) throws IOException, ClassNotFoundException;
}
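A custom marshaller might look like the following sketch. To keep it compilable without JGroups on the classpath, the interface is redeclared locally with the same shape as the one shown above. `StringMarshaller` is a hypothetical implementation that handles only non-null strings.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

class MarshallerDemo {
    // Local copy of the interface shape shown above (not imported from JGroups)
    interface Marshaller {
        default int estimatedSize(Object arg) { return arg == null ? 2 : 50; }
        void objectToStream(Object obj, DataOutput out) throws IOException;
        Object objectFromStream(DataInput in) throws IOException, ClassNotFoundException;
    }

    // Hypothetical marshaller: length-prefixed UTF-8, non-null strings only
    static class StringMarshaller implements Marshaller {
        @Override public int estimatedSize(Object arg) {
            // 4 bytes for the length prefix plus the encoded bytes
            return arg == null ? 2 : 4 + ((String) arg).getBytes(StandardCharsets.UTF_8).length;
        }
        @Override public void objectToStream(Object obj, DataOutput out) throws IOException {
            byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8);
            out.writeInt(bytes.length);
            out.write(bytes);
        }
        @Override public Object objectFromStream(DataInput in) throws IOException {
            byte[] bytes = new byte[in.readInt()];
            in.readFully(bytes);
            return new String(bytes, StandardCharsets.UTF_8);
        }
    }

    static Object roundTrip(String value) {
        try {
            Marshaller m = new StringMarshaller();
            ByteArrayOutputStream bos = new ByteArrayOutputStream(m.estimatedSize(value));
            m.objectToStream(value, new DataOutputStream(bos));
            return m.objectFromStream(new DataInputStream(new ByteArrayInputStream(bos.toByteArray())));
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    static int estimate(String value) { return new StringMarshaller().estimatedSize(value); }

    public static void main(String[] args) {
        System.out.println(roundTrip("hello"));  // prints hello
        System.out.println(estimate("hello"));   // prints 9
    }
}
```

Overriding estimatedSize alongside the two stream methods is what lets the sender size its buffer exactly instead of relying on the 50-byte default.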

# Receiving the Method Call Request on the Remote Side

public class RpcDispatcher extends MessageDispatcher {
    public Object handle(Message req) throws Exception {
        if(server_obj == null) {
            log.error(Util.getMessage("NoMethodHandlerIsRegisteredDiscardingRequest"));
            return null;
        }

        if(req == null || req.getLength() == 0) {
            log.error(Util.getMessage("MessageOrMessageBufferIsNull"));
            return null;
        }

        MethodCall method_call=methodCallFromBuffer(req.getRawBuffer(), req.getOffset(), req.getLength(), marshaller);
        if(log.isTraceEnabled())
            log.trace("[sender=%s], method_call: %s", req.getSrc(), method_call);

        if(method_call.mode() == MethodCall.ID) {
            if(method_invoker != null) // this trumps a method lookup
                return method_invoker.invoke(server_obj, method_call.methodId(), method_call.args());
            if(method_lookup == null)
                throw new Exception(String.format("MethodCall uses ID=%d, but method_lookup has not been set", method_call.methodId()));
            Method m=method_lookup.findMethod(method_call.methodId());
            if(m == null)
                throw new Exception("no method found for " + method_call.methodId());
            method_call.method(m);
        }
        return method_call.invoke(server_obj);
    }

    protected static MethodCall methodCallFromBuffer(final byte[] buf, int offset, int length, Marshaller   marshaller) throws Exception {
        ByteArrayDataInputStream in=new ByteArrayDataInputStream(buf, offset, length);
        MethodCall call=new MethodCall();
        call.readFrom(in, marshaller);
        return call;
    }
}

public class MethodCall implements Streamable, Constructable<MethodCall> {

    public void readFrom(DataInput in, Marshaller marshaller) throws IOException, ClassNotFoundException {
        switch(mode=in.readByte()) {
            case METHOD:
                method_name=Bits.readString(in);
                readMethod(in);
                break;
            case TYPES:
                method_name=Bits.readString(in);
                readTypes(in);
                break;
            case ID:
                method_id=in.readShort();
                break;
            default:
                throw new IllegalStateException("mode " + mode + " unknown");
        }
        readArgs(in, marshaller);
    }

    protected void readArgs(DataInput in, Marshaller marshaller) throws IOException, ClassNotFoundException {
        int args_len=in.readByte();
        if(args_len == 0)
            return;
        args=new Object[args_len];
        for(int i=0; i < args_len; i++)
            args[i]=marshaller != null? marshaller.objectFromStream(in) : Util.objectFromStream(in);
    }

    public Object invoke(Object target) throws Exception {
        if(target == null)
            throw new IllegalArgumentException("target is null");

        Class cl=target.getClass();
        Method meth=null;

        switch(mode) {
            case METHOD:
                if(this.method != null)
                    meth=this.method;
                break;
            case TYPES:
                meth=getMethod(cl, method_name, types);
                break;
            case ID:
                break;
            default:
                throw new IllegalStateException("mode " + mode + " is invalid");
        }

        if(meth != null) {
            try {
                // allow method invocation on protected or (package-) private methods, too
                if(!Modifier.isPublic(meth.getModifiers()))
                    meth.setAccessible(true);
                return meth.invoke(target, args);
            }
            catch(InvocationTargetException target_ex) {
                Throwable exception=target_ex.getTargetException();
                if(exception instanceof Error) throw (Error)exception;
                else if(exception instanceof RuntimeException) throw (RuntimeException)exception;
                else if(exception instanceof Exception) throw (Exception)exception;
                else throw new RuntimeException(exception);
            }
        }
        else
            throw new NoSuchMethodException(method_name);
    }
}

In the end, this is just plain reflection.
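The reflective call and its exception handling can be reduced to the sketch below. It uses only java.lang.reflect; `Server`, `invoke` and `tryInvoke` are hypothetical names for this example, and the unwrapping follows the same idea as MethodCall.invoke: rethrow the exception raised by the target method rather than the InvocationTargetException wrapper.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

class ReflectiveInvokeDemo {
    static class Server {
        public String greet(String name) { return "hello " + name; }
        public void fail() { throw new IllegalStateException("boom"); }
    }

    static Object invoke(Object target, String name, Class<?>[] types, Object... args) throws Exception {
        Method m = target.getClass().getMethod(name, types);
        try {
            return m.invoke(target, args);
        } catch (InvocationTargetException e) {
            // Unwrap so the caller sees the exception thrown by the target method
            Throwable cause = e.getTargetException();
            if (cause instanceof Error) throw (Error) cause;
            if (cause instanceof Exception) throw (Exception) cause;
            throw new RuntimeException(cause);
        }
    }

    // Helper for the demo: returns the result or a description of the failure
    static String tryInvoke(Object target, String name, Class<?>[] types, Object... args) {
        try {
            return String.valueOf(invoke(target, name, types, args));
        } catch (Exception e) {
            return e.getClass().getSimpleName() + ": " + e.getMessage();
        }
    }

    static String demoOk()   { return tryInvoke(new Server(), "greet", new Class<?>[]{String.class}, "bob"); }
    static String demoFail() { return tryInvoke(new Server(), "fail", new Class<?>[0]); }

    public static void main(String[] args) {
        System.out.println(demoOk());    // prints hello bob
        System.out.println(demoFail());  // prints IllegalStateException: boom
    }
}
```

Unwrapping matters for RPC because the exception that travels back to the caller should describe the failure in the invoked method, not the reflection machinery.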