Ray: A Distributed Framework for Emerging AI Applications
by Ruben Berenguel
A new entry in the data papers series. Ray is a distributed framework for next-generation AI applications. What does this mean? A scam? Blockchain on AI? Nah, it’s actually pretty cool: it has actors.
Ray was created as an answer to a problem that didn’t seem that well known: properly parallelizing reinforcement learning. Most work in recent years has gone either into parallelizing or improving classical machine learning algorithms, or into adding more and more “stuff” (layers, memory, layers with memory, layers with memory and forgetting, layers with memory and selective forgetting…) to deep learning systems. But there was a decently sized opportunity in reinforcement learning, because a lot of problems are better tackled with reinforcement learning approaches. In this article I have a look at the paper Ray: A Distributed Framework for Emerging AI Applications.
Explained succinctly, in reinforcement learning (RL from now on) you have agents modifying a state space (or environment) while trying to maximise (or minimise) some reward function. This is done via some form of action policy that, given the current state of the environment, chooses the next action of the agent(s). The best-known example of a reinforcement learning system would be AlphaGo (and all the other Alpha systems publicised by DeepMind), where the deep nets within AlphaGo were trained using RL, with a reward function based on positional value and estimated (simulated) endgame scores via Monte Carlo simulations.
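To make the vocabulary concrete, here is a minimal, hypothetical Python sketch (nothing here comes from the paper or from Ray): `LineWorld` is a made-up environment, a policy is just a function from state to action, and a policy is evaluated by averaging the total reward over many simulated episodes. Each rollout is independent of the others, which is why this kind of simulation parallelizes so naturally, and why you want so many of them.

```python
import random

class LineWorld:
    """Hypothetical environment: positions 0..size, starting at `start`.
    Reaching `size` ends the episode with reward +1, reaching 0 with -1."""

    def __init__(self, size=10, start=5):
        self.size, self.start = size, start
        self.state = start

    def reset(self):
        self.state = self.start
        return self.state

    def step(self, action):
        # action is -1 (left) or +1 (right); returns (state, reward, done)
        self.state += action
        if self.state >= self.size:
            return self.state, 1.0, True
        if self.state <= 0:
            return self.state, -1.0, True
        return self.state, 0.0, False

def rollout(env, policy, max_steps=100):
    """Simulate one episode and return the total reward collected."""
    state, total = env.reset(), 0.0
    for _ in range(max_steps):
        state, reward, done = env.step(policy(state))
        total += reward
        if done:
            break
    return total

def evaluate(policy, episodes=1000):
    """Estimate a policy's value by averaging many independent rollouts."""
    env = LineWorld()
    return sum(rollout(env, policy) for _ in range(episodes)) / episodes

def always_right(state):
    return +1

def make_random_policy(seed=0):
    rng = random.Random(seed)
    return lambda state: rng.choice([-1, +1])

print(evaluate(always_right))            # 1.0
print(evaluate(make_random_policy()))    # some value near 0 (symmetric walk)
```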
The problems Ray tries to address are:
Simulating thousands of millions of state transitions efficiently, to evaluate policies;
Distributing training across nodes and machines, to improve the policy using all those simulations;
Finally, serving the policy in an interactive query loop.
It may not be immediately obvious, but these requirements imply fine-grained computation. We can’t distribute at the dataset level (batch Spark) or even at the micro-batch level (Spark Streaming). We can’t even easily distribute at the row level (Flink), although that may not be clear yet in this piece. The AlphaGo example might already hint at the other “issue”: the system is going to be very heterogeneous. Training may take ages and require enormous amounts of GPU time, whereas serving is probably going to be near-instantaneous and need much less compute. Clearly, none of the systems I have covered here before fits these requirements. Which is kind of good, because those are exactly the challenges Ray tries to solve.
The paper pitches Ray as:
A general-purpose cluster-computing framework…
that enables simulation, training and serving
for RL applications.
The last bullet there is kind of false: you can actually use Ray for more general things. More on this later.
What is the magic sauce?
The trick, and the thing I like about Ray, is that it offers all this by providing a unified interface for task-parallel and actor-based computations. This is a fancy way of saying that it offers cancellable tasks that return futures (promises of a computation) as well as a stateful actor model. I won’t introduce the actor model here (I have talked about it before somewhere), but this should remind you a lot of Akka. Akka, however, does not offer a nice interface for tasks: it essentially defers to Scala futures (or whatever other future-like abstraction you build on).
Tasks make it easy to distribute simulations, process large inputs and recover from failures. Actors make it easy to handle stateful, distributed computations, which are fundamental for training.
Ray’s computation graph has:
node: data objects (immutable);
node: remote function invocations (possibly over immutable data objects, or returning them);
edge: data edge, capturing the dependency between a data object and a task;
edge: control edge, capturing the computation dependencies between nested remote function calls;
edge: stateful edge, chaining consecutive method calls on the same actor to capture the implicit dependency through the actor’s internal state.
I didn’t even bother drawing all this: it makes my head spin. But if you look at it blurrily enough, it is saying that even if actors fail, there is a lineage graph of all the computations that took place on them, ending in some immutable data object, so lost results can be recomputed.
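A toy model of the graph may help with the head-spinning. The sketch below is my own illustration, not Ray’s internal data structure: node and edge kinds follow the list above, and node names such as `rollout` or `actor.m1` are hypothetical. Walking the edges backwards from an object gives its lineage, i.e. everything that would have to be re-executed to rebuild it. Note how the stateful edge pulls the earlier actor method `actor.m1` into the lineage of `obj_B`: rebuilding the actor’s state means replaying its previous calls.

```python
from dataclasses import dataclass, field

NODE_KINDS = {"object", "task"}
EDGE_KINDS = {"data", "control", "stateful"}

@dataclass
class ComputationGraph:
    kinds: dict = field(default_factory=dict)    # node -> "object" or "task"
    parents: dict = field(default_factory=dict)  # node -> [(parent, edge kind)]

    def add_node(self, name, kind):
        if kind not in NODE_KINDS:
            raise ValueError(f"unknown node kind: {kind}")
        self.kinds[name] = kind
        self.parents.setdefault(name, [])

    def add_edge(self, src, dst, kind):
        if kind not in EDGE_KINDS:
            raise ValueError(f"unknown edge kind: {kind}")
        self.parents[dst].append((src, kind))

    def lineage(self, name):
        """All nodes that `name` transitively depends on, via any edge kind."""
        seen, stack = set(), [name]
        while stack:
            for parent, _ in self.parents[stack.pop()]:
                if parent not in seen:
                    seen.add(parent)
                    stack.append(parent)
        return seen

    def tasks_to_replay(self, name):
        """Tasks in the lineage, i.e. what recovery would have to re-run."""
        return {n for n in self.lineage(name) if self.kinds[n] == "task"}

g = ComputationGraph()
for name, kind in [("parent_task", "task"), ("rollout", "task"), ("obj_A", "object"),
                   ("actor.m1", "task"), ("actor.m2", "task"), ("obj_B", "object")]:
    g.add_node(name, kind)
g.add_edge("parent_task", "rollout", "control")   # nested remote call
g.add_edge("rollout", "obj_A", "data")            # task produces obj_A
g.add_edge("parent_task", "actor.m1", "control")
g.add_edge("parent_task", "actor.m2", "control")
g.add_edge("actor.m1", "actor.m2", "stateful")    # same actor, called in order
g.add_edge("obj_A", "actor.m2", "data")           # actor.m2 consumes obj_A
g.add_edge("actor.m2", "obj_B", "data")           # and produces obj_B

print(sorted(g.tasks_to_replay("obj_B")))
# ['actor.m1', 'actor.m2', 'parent_task', 'rollout']
```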
Ray introduces a paradigm shift compared to similar systems:
The task scheduler is distributed, not centralised in a driver instance;
The metadata store maintaining lineage and directory of data objects is also distributed.
These two changes are what (supposedly) allow Ray to handle millions of tasks per second with millisecond-level latencies. Importantly for data people, Ray supports lineage-based fault tolerance for both tasks and actors.
Seeing this mention of lineage-based fault tolerance, I can hear your neurons screaming: can this replace RDDs, or the whole of Apache Spark? The answer is no. Although the building blocks could in theory allow for it, Ray has no higher-level functionality (query planning, straggler recomputation, a dataframe API…). It would be a really good project to build on top of it, though.
Ray has a dynamic task graph computation model, i.e. a Ray application is a graph of dependent tasks that evolves during execution, with nodes and edges as described before. If this sounds a bit too mumbo-jumbo, think of a spreadsheet with many linked, dependent cells, like when you are trying to fill in your taxes. This is what a Ray application looks like to Ray: modifying one cell changes a lot of other cells. Within this model, Ray offers a task-based API and an actor-based API, and which one to choose basically depends on how you will handle the internal state of your process.
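To see what “evolves during execution” means in practice, here is a hypothetical sketch that uses Python’s standard `concurrent.futures` as a stand-in for Ray (the objective `score`, the step size and the budget are all made up). It runs a tiny local search in which only results that improve on the best so far spawn follow-up tasks, so the shape of the graph is only known once results come in. In Ray, a loop like this would typically be built around `ray.wait`, which hands back the futures that are ready so far.

```python
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

def score(x):
    # Hypothetical objective to maximise (peak at x = 3); stands in for
    # an expensive simulation.
    return 10 - (x - 3) ** 2

def search(pool, initial, budget, step=0.5):
    pending = {pool.submit(score, x): x for x in initial}
    submitted = len(initial)
    edges = []          # (parent input, child input): the graph as it unfolds
    best = None
    while pending:
        done, _ = wait(pending, return_when=FIRST_COMPLETED)
        for fut in done:
            x = pending.pop(fut)
            s = fut.result()
            if best is None or s > best[1]:
                best = (x, s)
                # Only improving results spawn new tasks, so the final shape
                # of the graph depends on values computed at runtime.
                for child in (x - step, x + step):
                    if submitted < budget:
                        pending[pool.submit(score, child)] = child
                        edges.append((x, child))
                        submitted += 1
    return best, edges

with ThreadPoolExecutor(max_workers=4) as pool:
    best, edges = search(pool, initial=[0.0, 6.0], budget=20)
print(best, len(edges))   # (3.0, 10.0) 14
```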
Tasks: a task runs a remote function on a stateless worker. Think of a Lambda function on AWS, or an edge function on Cloudflare: it can run anywhere. Ray uses futures representing the computation, and these can be passed around until you wait on them. Passing these futures around is part of what builds the computation graph. These functions should act on immutable objects and be both stateless and side-effect free: Ray’s failure recovery system will aggressively re-run them if needed.
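Here is a rough, hypothetical imitation of the task side in plain Python, with a thread pool standing in for the cluster. The `remote` decorator and its `.remote(...)` method are my own stand-ins, loosely modelled on the shape of Ray’s `@ray.remote` / `f.remote()` / `ray.get()` API but sharing none of its machinery. What it shows: calling a task returns a future immediately, and passing that future into another task is what records the data dependency.

```python
from concurrent.futures import Future, ThreadPoolExecutor

_pool = ThreadPoolExecutor(max_workers=4)   # stand-in for the cluster's workers

def remote(fn):
    """Hypothetical decorator: `fn.remote(*args)` schedules `fn` on the pool
    and immediately returns a Future instead of the result."""
    def submit(*args):
        def run():
            # Wait for any future arguments: these are the data dependencies.
            resolved = [a.result() if isinstance(a, Future) else a for a in args]
            return fn(*resolved)
        return _pool.submit(run)
    fn.remote = submit
    return fn

@remote
def load(n):
    # Pure, stateless task: same input, same output.
    return list(range(n))

@remote
def total(xs):
    return sum(xs)

data_ref = load.remote(10)          # a Future, not a list
sum_ref = total.remote(data_ref)    # passing the Future links the two tasks
print(sum_ref.result())             # 45
```

Because `load` and `total` are pure, re-running either of them after a failure yields the same result, which is what makes aggressive re-execution safe. Blocking on an argument inside a worker thread is a simplification that only works because each task waits on futures created before it; a real scheduler would instead hold the task back until its inputs are ready.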
Actors: actors represent stateful computations. Each actor exposes methods that can be called remotely and that are executed serially, in the order they are called (wink wink, actor message queues). Each method call is similar to a task: it is executed somewhere and returns a future. The difference is that it runs on a remote, stateful worker. Actors have handles, and you can pass actor handles around. This passing “by reference”, combined with the futures returned by method calls, is the other building block of Ray’s task graph, and is what allows lineage tracking for actor computations.
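A minimal, hypothetical sketch of the actor side, again using the standard library rather than Ray: `ActorHandle` wraps an object plus a single-worker executor that acts as its mailbox, so method calls run one at a time, in the order they were sent, and each returns a future. The handle itself can be passed to other code, which is the “by reference” part. In Ray you would instead decorate a class with `@ray.remote` and call `handle.method.remote(...)`; none of that machinery is reproduced here.

```python
from concurrent.futures import ThreadPoolExecutor

class ActorHandle:
    """Hypothetical actor: the state lives in one object, and a
    single-worker executor serialises calls in FIFO order."""

    def __init__(self, cls, *args):
        self._instance = cls(*args)                          # private actor state
        self._mailbox = ThreadPoolExecutor(max_workers=1)    # one call at a time

    def call(self, method, *args):
        # Enqueue the method call and return a future for its result.
        return self._mailbox.submit(getattr(self._instance, method), *args)

class Counter:
    def __init__(self):
        self.value = 0

    def incr(self, by=1):
        self.value += by
        return self.value

def bump(handle, by):
    # Receives the handle, not a copy of the state: calls go to the same actor.
    return handle.call("incr", by)

counter = ActorHandle(Counter)
futures = [counter.call("incr") for _ in range(5)]
print([f.result() for f in futures])   # [1, 2, 3, 4, 5]
print(bump(counter, 10).result())      # 15
```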
Ray’s distributed scheduler triggers computations like a chain of dominoes: as soon as the inputs to a remote function or actor method are available, the function or method is triggered and evaluated. Then its dependents are triggered, and so on. This is obviously extremely heavy on the scheduler if you have millions of tasks and actors, which is why the scheduler is distributed.
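The domino effect can be sketched as a small, single-process dataflow executor. This is a hypothetical illustration of the triggering rule only (each task counts its missing inputs and fires when that count reaches zero), not of Ray’s distributed scheduler; the task names and functions in the example are made up.

```python
from collections import defaultdict, deque

def run_dataflow(tasks, initial):
    """Hypothetical single-process dataflow executor.

    tasks:   name -> (input object names, function, output object name)
    initial: object name -> value that is already available
    A task fires as soon as all of its inputs exist; its output can then
    make other tasks ready, like a row of dominoes.
    """
    store = dict(initial)
    missing = {}                      # task -> number of inputs not yet available
    waiting_on = defaultdict(list)    # object -> tasks waiting for it
    ready = deque()
    for name, (inputs, _, _) in tasks.items():
        absent = [i for i in inputs if i not in store]
        missing[name] = len(absent)
        for obj in absent:
            waiting_on[obj].append(name)
        if not absent:
            ready.append(name)

    order = []
    while ready:
        name = ready.popleft()
        inputs, fn, output = tasks[name]
        store[output] = fn(*(store[i] for i in inputs))
        order.append(name)
        # The new object may complete the inputs of waiting tasks.
        for dependent in waiting_on.pop(output, []):
            missing[dependent] -= 1
            if missing[dependent] == 0:
                ready.append(dependent)
    return store, order

tasks = {
    "square": (["z"], lambda z: z * z, "w"),
    "add":    (["x", "y"], lambda x, y: x + y, "z"),
    "double": (["x"], lambda x: 2 * x, "y"),
}
store, order = run_dataflow(tasks, {"x": 3})
print(order, store["w"])   # ['double', 'add', 'square'] 81
```

Note that execution order follows data availability, not the order in which the tasks were declared.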
Ray uses a standard driver-worker architecture… with an additional actor piece. The driver runs the user program and hands tasks to a set of workers; since tasks are stateless, any worker will do. Actors, instead, are explicitly instantiated to run actor computations.
The metadata store is called the Global Control Store (GCS) and maintains the entire control state of the system. It is essentially a key-value store with a pubsub (publish-subscribe) API, sharded across nodes for scalability and replicated per shard for fault tolerance. All object metadata is stored in the GCS, which takes unnecessary load off the scheduler.
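To make “key-value store with pubsub” concrete, here is a hypothetical toy in Python. `ToyControlStore`, its methods and the key format `obj:42:locations` are my own inventions, not the GCS API, and replication is left out. Keys are assigned to shards with a stable hash, and subscribers are called back whenever a key they watch is written, which is how a component could learn that, say, an object’s location has changed without polling.

```python
import zlib
from collections import defaultdict

class ToyControlStore:
    """Hypothetical sharded key-value store with publish/subscribe.
    Not the GCS API; replication for fault tolerance is omitted."""

    def __init__(self, num_shards=4):
        self.data = [dict() for _ in range(num_shards)]
        self.subscribers = [defaultdict(list) for _ in range(num_shards)]

    def _shard(self, key):
        # Stable hash, so a key always maps to the same shard.
        return zlib.crc32(key.encode()) % len(self.data)

    def subscribe(self, key, callback):
        self.subscribers[self._shard(key)][key].append(callback)

    def put(self, key, value):
        shard = self._shard(key)
        self.data[shard][key] = value
        for callback in self.subscribers[shard][key]:
            callback(key, value)          # publish the update to subscribers

    def get(self, key):
        return self.data[self._shard(key)].get(key)

gcs = ToyControlStore()
seen = []
gcs.subscribe("obj:42:locations", lambda key, value: seen.append(value))
gcs.put("obj:42:locations", ["node-1"])
print(gcs.get("obj:42:locations"), seen)   # ['node-1'] [['node-1']]
```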
The scheduler is bottom-up and distributed. What does this mean? Ray has two levels of scheduling: a local scheduler on each node and a global scheduler. Tasks created on a node are first submitted to that node’s local scheduler, which checks whether the task can run locally. If it can’t (for instance, because the node is overloaded or lacks a required resource), the task is passed to the global scheduler. Hence the bottom-up, and also the distributed part. The global scheduler uses an estimated waiting time metric to choose which node should take the task, while also accounting for any hardware requirements the task or actor may have. Note that the global scheduler is not a contention point: it can be replicated freely.
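A sketch of the bottom-up decision, under loudly stated assumptions: the `Node` and `Task` classes, the queue threshold and the bandwidth figure are hypothetical, and the waiting-time estimate (queue length times average task duration, plus the time to transfer inputs not already on the node) is my reading of the paper’s description rather than Ray’s actual code.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    name: str
    resources: dict                   # e.g. {"CPU": 8, "GPU": 1}
    queue_len: int = 0                # tasks waiting on this node
    avg_task_time: float = 1.0        # seconds (assumed)
    local_objects: set = field(default_factory=set)

@dataclass
class Task:
    needs: dict                       # resource requirements
    inputs: dict                      # object id -> size in bytes

def satisfies(node, task):
    return all(node.resources.get(r, 0) >= n for r, n in task.needs.items())

def estimated_wait(node, task, bandwidth):
    # Queueing delay plus time to fetch the inputs this node does not hold.
    remote_bytes = sum(size for obj, size in task.inputs.items()
                       if obj not in node.local_objects)
    return node.queue_len * node.avg_task_time + remote_bytes / bandwidth

def schedule(task, origin, cluster, threshold=10, bandwidth=1e9):
    # Bottom-up: the local scheduler keeps the task if it can.
    if satisfies(origin, task) and origin.queue_len < threshold:
        return origin
    # Otherwise escalate: the global scheduler picks the node with the
    # lowest estimated waiting time among those that meet the requirements.
    candidates = [n for n in cluster if satisfies(n, task)]
    if not candidates:
        raise RuntimeError("no node can run this task")
    return min(candidates, key=lambda n: estimated_wait(n, task, bandwidth))

a = Node("a", {"CPU": 8})
b = Node("b", {"CPU": 8, "GPU": 1}, queue_len=3, local_objects={"weights"})
c = Node("c", {"CPU": 4, "GPU": 1}, queue_len=1)
cpu_task = Task(needs={"CPU": 1}, inputs={})
gpu_task = Task(needs={"GPU": 1}, inputs={"weights": 4e9})
print(schedule(cpu_task, a, [a, b, c]).name)   # a
print(schedule(gpu_task, a, [a, b, c]).name)   # b
```

The GPU task lands on `b` despite its longer queue, because `c` would first have to pull the large `weights` object over the network.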
Inputs and outputs of tasks (all stateless computations) are stored in an in-memory object store on each node, backed by shared memory, which allows zero-copy data sharing between tasks on the same node (via Apache Arrow). Since inputs and outputs are immutable, this is ideal: lowest possible latency. When inputs are non-local, they are replicated to the local object store before the task runs, which avoids bottlenecks on hot objects. For simplicity, the object store does not offer distributed objects (like dataframes in Spark): every object needs to fit on a single node (or you need to write your own code to split it).
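A hypothetical toy of the two properties above: objects are immutable once stored, local readers get a `memoryview` onto the stored bytes instead of a copy, and reading an object that lives elsewhere first replicates it into the local store. `NodeStore` and the shared `directory` dict are made-up names; Ray’s real object store shares memory across processes, which this in-process sketch does not attempt.

```python
class NodeStore:
    """Hypothetical per-node object store (in-process toy, no real sharing
    between processes). Objects are immutable bytes; `directory` maps each
    object id to the set of stores holding a copy."""

    def __init__(self, name, directory):
        self.name = name
        self.objects = {}
        self.directory = directory

    def put(self, obj_id, data):
        self.objects[obj_id] = bytes(data)          # stored as immutable bytes
        self.directory.setdefault(obj_id, set()).add(self)

    def get(self, obj_id):
        if obj_id not in self.objects:
            # Non-local input: replicate it into this store first. In a real
            # cluster this would be a network transfer from `source`.
            source = next(iter(self.directory[obj_id]))
            self.put(obj_id, source.objects[obj_id])
        # Local read: a read-only view onto the stored bytes, not a copy.
        return memoryview(self.objects[obj_id])

directory = {}
n1, n2 = NodeStore("n1", directory), NodeStore("n2", directory)
n1.put("weights", b"\x00" * 1024)
view = n2.get("weights")
print(len(view), sorted(s.name for s in directory["weights"]), view.readonly)
# 1024 ['n1', 'n2'] True
```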
The system layer, which is most of what I have described in this architecture section, is implemented in C++, whereas the application-layer APIs are written in Python. The GCS is built on Redis.
Well, it’s an actor system, and you can basically do whatever you want with actors. In my specific scenario, I am looking into leveraging Ray Serve (the serving part of Ray, which offers something vaguely reminiscent of akka-http) as a high-throughput distributed tracking server.