Haskell and XLA

A while back I encountered JAX. In short, it’s a Python numerical library that provides automatic differentiation on top of XLA, which allows it to run really fast. Its use of pure functions in Python intrigued me, so I have been trying to implement it in Haskell, a task more Herculean than I thought.
Yam.zip
Yam is the working name for the project, and I will change it at some point.
To build it, you’ll need StableHLO and MLIR.
You will want to install MLIR on your system (I installed it in my .local directory). Wherever you install it, the configure scripts need to be able to run llvm-config to find the location of the libraries and headers.
StableHLO takes more work. Build it according to the instructions on its GitHub page, then change the StableHLO path in the configure script to point to where you built it.
That ought to be all that you need to build Yam.

So far, Yam has only the equivalent of JAX’s jit function, and in a much more limited form. I cannot get Haskell’s type inference to work for it, so you need to annotate both the function you feed to jit and the function jit returns.
There are two main types that make Yam work: Tracer and Tensor.
Tensor is a multidimensional array containing data that you can print out (only Float is supported currently). A Tracer, meanwhile, is an abstract Tensor that accumulates the operations that formed it, so that they can be fed into the XLA compiler.
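To make the Tensor/Tracer split concrete, here is a minimal, self-contained sketch of the tracing idea in general. It is not Yam’s implementation: `Expr`, `Tracer`, `evalExpr` and the use of plain lists of `Float` in place of device buffers are all assumptions for illustration. A function written once against `Num` runs either on concrete values or on a `Tracer`, which merely records the operations as an expression graph that a backend (XLA, in Yam’s case) could later compile.

```haskell
-- Hypothetical sketch of tracing; Expr, Tracer and evalExpr are not Yam's API.
module Main (main) where

-- Operations a Tracer can record. Arg i stands for the i-th input.
data Expr
  = Arg Int
  | Lit Float
  | Add Expr Expr
  | Mul Expr Expr
  | Neg Expr
  deriving (Show, Eq)

-- A Tracer holds no data, only the expression that produced it.
newtype Tracer = Tracer Expr
  deriving (Show, Eq)

instance Num Tracer where
  Tracer a + Tracer b = Tracer (Add a b)
  Tracer a * Tracer b = Tracer (Mul a b)
  negate (Tracer a)   = Tracer (Neg a)
  fromInteger n       = Tracer (Lit (fromInteger n))
  abs    = error "abs: not supported in this sketch"
  signum = error "signum: not supported in this sketch"

-- Stand-in for "compile with XLA and run": interpret the graph elementwise,
-- with lists of Floats playing the role of tensors. Literals are broadcast,
-- so the result is finite as long as some Arg appears in the expression.
evalExpr :: [[Float]] -> Expr -> [Float]
evalExpr args (Arg i)   = args !! i
evalExpr _    (Lit x)   = repeat x
evalExpr args (Add a b) = zipWith (+) (evalExpr args a) (evalExpr args b)
evalExpr args (Mul a b) = zipWith (*) (evalExpr args a) (evalExpr args b)
evalExpr args (Neg a)   = map negate (evalExpr args a)

-- Written once against Num, so it accepts concrete numbers or Tracers.
testFn :: Num a => a -> a -> a
testFn a b = (s + s) + (s + s)
  where s = a + b

main :: IO ()
main = do
  -- Tracing: apply the function to symbolic arguments to obtain a graph.
  let Tracer graph = testFn (Tracer (Arg 0)) (Tracer (Arg 1))
  print graph
  -- Running the recorded graph on concrete inputs.
  print (evalExpr [[0 .. 11], [1 .. 12]] graph)
  -- The same function applied directly to concrete numbers.
  print (testFn (2 :: Float) 3)
```

Note that `s` shows up four times in the printed graph: a plain tree loses the sharing in `testFn`, so a real tracer also has to track node identity (for example by numbering nodes as they are created).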
An example of what this looks like in Yam is in the test directory (not the example directory). It’s very clunky.

test/Main.hs

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main (main) where
import Foreign

import Yam.Tensor
import Yam.PjRt
import Yam.Jit

-- Generic over the tracing type t, so it can be applied to Tensors directly
-- or to Tracers (for jit).
test :: (Trace t, Num (t '[12] Float)) => t '[12] Float -> t '[12] Float -> t '[12] Float
test a b = (s + s) + (s + s)
  where s = a + b

main :: IO ()
main = do
  device <- head <$> clientAddressableDevices client
  -- Build device tensors from host arrays.
  a :: Tensor '[12] Float <- withArray [0 .. 11] $ tensorFromHostBuffer device
  b :: Tensor '[12] Float <- withArray [1 .. 12] $ tensorFromHostBuffer device

  -- Both the function passed to jit and the function it returns are annotated.
  let testjit :: (Trace t, Jit t (Tracer '[12] Float -> Tracer '[12] Float -> Tracer '[12] Float) f) => f
      testjit = jit (test :: Tracer '[12] Float -> Tracer '[12] Float -> Tracer '[12] Float)

  print $ testjit a b
  print $ test a b
  print $ recip b
  traceDebug (testjit :: Tracer '[12] Float -> Tracer '[12] Float -> Tracer '[12] Float)
  clientDestroy client
  return ()
```

This project is only a proof of concept at this point.
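On the annotation problem with jit: one possible way to make the result type of `jit` inferable, sketched here under heavy simplifications and not taken from Yam, is to compute the concrete function type from the traced one with a closed type family, instead of leaving it as a free parameter of a multi-parameter class like the `Jit t (...) f` constraint in test/Main.hs. `Concrete`, `Jit`, `jitWith` and the shape-free `Tracer`/`Tensor` types are hypothetical, and `evalExpr` stands in for compiling and running the graph with XLA.

```haskell
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
-- Hypothetical sketch: none of these names are Yam's API.
module Main (main) where

data Expr = Arg Int | Add Expr Expr | Mul Expr Expr

newtype Tracer = Tracer Expr
newtype Tensor = Tensor [Float] deriving (Show, Eq)

instance Num Tracer where
  Tracer a + Tracer b = Tracer (Add a b)
  Tracer a * Tracer b = Tracer (Mul a b)
  negate      = error "negate: not supported in this sketch"
  abs         = error "abs: not supported in this sketch"
  signum      = error "signum: not supported in this sketch"
  fromInteger = error "fromInteger: not supported in this sketch"

-- Stand-in for compiling and running the traced graph.
evalExpr :: [Tensor] -> Expr -> [Float]
evalExpr args (Arg i)   = let Tensor xs = args !! i in xs
evalExpr args (Add a b) = zipWith (+) (evalExpr args a) (evalExpr args b)
evalExpr args (Mul a b) = zipWith (*) (evalExpr args a) (evalExpr args b)

-- The concrete type corresponding to a traced function type. Because this
-- is a closed type family, `jit f` has exactly one possible result type.
type family Concrete f where
  Concrete (Tracer -> r) = Tensor -> Concrete r
  Concrete Tracer        = Tensor

class Jit f where
  -- Accumulate the concrete arguments, then run the finished graph.
  jitWith :: [Tensor] -> f -> Concrete f

instance Jit Tracer where
  jitWith args (Tracer e) = Tensor (evalExpr args e)

instance Jit r => Jit (Tracer -> r) where
  -- Feed the traced function a placeholder for the next argument.
  jitWith args f = \x -> jitWith (args ++ [x]) (f (Tracer (Arg (length args))))

jit :: Jit f => f -> Concrete f
jit = jitWith []

-- Only the traced type is written down.
testFn :: Tracer -> Tracer -> Tracer
testFn a b = (s + s) + (s + s)
  where s = a + b

main :: IO ()
main = do
  let testjit = jit testFn  -- inferred as Tensor -> Tensor -> Tensor
  print (testjit (Tensor [0 .. 11]) (Tensor [1 .. 12]))
```

Here only the traced function needs a signature. The sketch re-traces on every call; a real implementation would trace and compile once and cache the executable, and shape-indexed types like `Tracer '[12] Float` would need the family to carry the shapes and element types across as well.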

There is also mlir-hs, which could take care of the bootstrapping phase, I think, although it hasn’t been updated since January.

I know about mlir-hs, but I couldn’t get it to build, so in the end I made my own bindings to MLIR.

How do you find using MLIR? Is it good/better/different than LLVM? Is it more expressive? Is the toolchain nice?

MLIR is part of LLVM, and they share a license. XLA takes MLIR as input for the compilation process, so that’s why I use it. Yam emits MLIR bytecode in the StableHLO dialect (probably similar to what JAX does).
I have never used LLVM, so I cannot say whether it is better or worse, but the two are different. I think they share the same toolchain.
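To illustrate what emitting StableHLO involves, here is a small sketch that prints a traced expression as textual StableHLO-style MLIR. Yam itself emits bytecode through its MLIR bindings; this sketch only builds a string, assumes every value has the same type (`tensor<12xf32>`), and its `Expr`, `emit`, `binop` and `toStableHLO` names are hypothetical. The output is meant to resemble StableHLO’s pretty-printed form and has not been checked against the StableHLO parser.

```haskell
-- Hypothetical sketch: prints a traced expression as StableHLO-style text.
module Main (main) where

import Data.List (intercalate)

data Expr = Arg Int | Add Expr Expr | Mul Expr Expr

-- Emit one SSA instruction per operation. Returns the SSA name holding the
-- result, the next free SSA number, and the instructions emitted, in order.
emit :: String -> Expr -> Int -> (String, Int, [String])
emit _  (Arg i)   n = ("%arg" ++ show i, n, [])
emit ty (Add a b) n = binop ty "stablehlo.add" a b n
emit ty (Mul a b) n = binop ty "stablehlo.multiply" a b n

binop :: String -> String -> Expr -> Expr -> Int -> (String, Int, [String])
binop ty op a b n0 =
  let (va, n1, la) = emit ty a n0
      (vb, n2, lb) = emit ty b n1
      v    = "%" ++ show n2
      line = "  " ++ v ++ " = " ++ op ++ " " ++ va ++ ", " ++ vb ++ " : " ++ ty
  in (v, n2 + 1, la ++ lb ++ [line])

-- Wrap the instructions in a func.func whose `arity` arguments all have type ty.
toStableHLO :: Int -> String -> Expr -> String
toStableHLO arity ty e =
  let (res, _, body) = emit ty e 0
      params = ["%arg" ++ show i ++ ": " ++ ty | i <- [0 .. arity - 1]]
      header = "func.func @main(" ++ intercalate ", " params ++ ") -> " ++ ty ++ " {"
  in unlines ([header] ++ body ++ ["  return " ++ res ++ " : " ++ ty, "}"])

main :: IO ()
main = do
  -- The graph that tracing `(s + s) + (s + s) where s = a + b` yields.
  let s     = Add (Arg 0) (Arg 1)
      graph = Add (Add s s) (Add s s)
  putStr (toStableHLO 2 "tensor<12xf32>" graph)
```

Because the tree repeats `s` for every use, the emitted function computes `%arg0 + %arg1` four times; an emitter working from a graph with node identities would emit each shared value once.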

The main difference between MLIR and LLVM is that LLVM is a fixed IR consisting essentially of “C + vectors”, while MLIR is an extensible IR.

MLIR has been used quite extensively in the machine learning and GPU space so it has a rich ecosystem surrounding those use-cases, supporting many domain specific optimisations that LLVM is just too low-level for.

(At least, that’s the basic pitch I remember from reading the MLIR paper. I haven’t actually used it yet.)

This is exactly right. Squinting a little, MLIR is to LLVM what parser combinators are to yacc-based parsing: instead of doing everything in a single framework (LLVM), we have multiple dialects, and each optimization pass is implemented in the dialect where it is easiest to do so (and more dialects can be defined).

Yay, this is cool. If you publish this under a free enough licence (ideally in a git repo), I’d gladly take inspiration from it (e.g., the way you trace your operations) or outright steal some parts for Mikolaj/horde-ad on GitHub (Higher Order Reverse Derivatives Efficiently), an automatic differentiation library based on the paper “Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation”. It does automatic differentiation from the ground up in Haskell, but it requires tensor-operation backends (in particular, one for GPUs), and both MLIR and XLA sound like obvious targets (currently I’m using hmatrix, which provides bindings to BLAS/LAPACK on the CPU).

I’ve been implementing reverse-mode differentiation using dual numbers based on this paper, which seems to work, but you need to annotate each function with its type. I’ll read the paper you linked.
What license do you have in mind? I’ll gladly make a git repo.
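For readers following along, here is a minimal sketch of naive dual-numbers reverse-mode AD in the general style discussed in this thread, written from the idea rather than from either paper’s code: each value carries its primal result plus a backpropagator that maps an output cotangent to contributions per input index. The names `D`, `grad` and `example` are illustrative, and the sketch uses `Data.Map` from the containers package. Because a shared subterm’s backpropagator is called once per use, this naive version can take time exponential in the depth of sharing, which is the kind of inefficiency the optimised algorithms in these papers are designed to remove.

```haskell
-- Hypothetical sketch of naive dual-numbers reverse-mode AD.
module Main (main) where

import qualified Data.Map.Strict as M

-- Primal value paired with a backpropagator: given the cotangent of this
-- value, return the contribution to the gradient of each input index.
data D = D Double (Double -> M.Map Int Double)

instance Num D where
  D x dx + D y dy = D (x + y) (\c -> M.unionWith (+) (dx c) (dy c))
  D x dx * D y dy = D (x * y) (\c -> M.unionWith (+) (dx (c * y)) (dy (c * x)))
  negate (D x dx) = D (negate x) (dx . negate)
  abs (D x dx)    = D (abs x) (\c -> dx (c * signum x))
  signum (D x _)  = D (signum x) (const M.empty)
  fromInteger n   = D (fromInteger n) (const M.empty)

-- Seed input i with a backpropagator that reports its cotangent at index i,
-- then backpropagate a cotangent of 1 from the output.
grad :: ([D] -> D) -> [Double] -> [Double]
grad f xs =
  let inputs   = [D x (M.singleton i) | (i, x) <- zip [0 ..] xs]
      D _ back = f inputs
      contribs = back 1
  in [M.findWithDefault 0 i contribs | i <- [0 .. length xs - 1]]

-- f(x, y) = x * y + x, so the gradient is (y + 1, x).
example :: [D] -> D
example (x : y : _) = x * y + x
example _           = error "example: expects two inputs"

main :: IO ()
main = print (grad example [3, 4])
```

Each input gets its own index, so the final `Map` can be read back as a gradient vector; a shape-indexed tensor version would replace `Double` with a tensor type and the scalar rules with their tensor counterparts.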

Hey, that’s me :) It’s not the same paper as the one that @Mikolaj mentioned, but the basic story is essentially the same: from dual-numbers-style AD we derive something that’s essentially classical tracing AD. (A comparison is in section 12.2 of our paper.)

I’m curious to see what you make of it! If you have any questions don’t hesitate to reach out, either via private message here or otherwise.

I’m not sure what you mean when you say “to annotate each function with their type”; are you referring to the interleave/deinterleave definitions for mutually recursive types (i.e. section 9)? Or something else?

What I meant is that, in my implementation, the functions need to be annotated by their type, or else the compiler will encounter an ambiguous type variable. So far I’ve only implemented the naive version.
Sorry if that caused confusion.

I’ll read the paper you linked.

Since the papers describe similar algorithms, you may want to watch a video about that paper instead, which focuses on giving intuitions and motivation: https://www.youtube.com/watch?v=EPGqzkEZWyw

What license do you have in mind?

My personal preferences aside (I’m secretly a gnu), BSD3 would incur least friction.

I’ll gladly make a git repo.

Thank you. Again, the least friction for me would be github, but I respect personal preferences.

the functions need to be annotated by their type or else the compiler will encounter ambiguous type variable

Oh yeah, I have the same (if what you mean are just ordinary type signatures). I think that’s normal once your types get complex enough. Fortunately, type signatures are a common recommended style even if the compiler copes without them.

I’ve made the GitHub repo. Because of GitHub’s file size limit, it is missing one shared object file, which you can find inside the zip file in the original post.

Thank you very much. That’s all I needed to be able to follow your development and benefit from it (and mutually, I hope).

If, at your leisure, you could also write in the README where these binary files originally come from, or how to build them if you created them yourself, that would be helpful as well (and is standard practice, I think, given that binary blobs are generally insecure and hard to patch).

There is a fork of XLA on my GitHub where I’ve added a build target, “//xla/pjrt/plugin:PJRTPlugin”. Just follow the XLA build instructions to build it.
In the future I might make something more convoluted like exla.

You might be interested in this Idris/OpenXLA project.