A while back I came across JAX. In short, it's a Python numerical library that provides automatic differentiation on top of XLA, which lets it run really fast. Its use of pure functions in Python intrigued me, so I have been trying to implement something like it in Haskell, a task more herculean than I thought.
Yam.zip
Yam is the working name for the project and I will change it some time later.
To build it you'll need StableHLO and MLIR.
You will want to install MLIR onto your system (I installed it in my .local directory). Wherever you install it, the configure scripts need to be able to run llvm-config to find the location of the libraries and headers.
StableHLO is less trivial. Build StableHLO according to the instructions on its GitHub page, then change the StableHLO path in the configure script to point to where you built it.
That ought to be all that you need to build Yam.
So far, Yam only has an equivalent of JAX's jit function, and in a much more limited form. I can't get Haskell's type inference to work for it, so you need to annotate both the function you pass to jit and the function jit returns.
There are two main types that make Yam work: Tracer and Tensor.
A Tensor is a multidimensional array holding concrete data that you can print out (only Float is supported currently). A Tracer, meanwhile, is an abstract Tensor that accumulates the operations that produced it, so that they can be fed into the XLA compiler.
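To make the Tracer/Tensor split concrete, here is a toy sketch that runs without Yam, PJRT or XLA. It assumes a scalar Float graph instead of shaped tensors, and every name in it (Expr, eval, Jitted, Jit, traceFrom, bindArgs, jit) is hypothetical, not Yam's API. Because the operations come from a Num instance, the same polymorphic test function either computes numbers directly or, when applied to Expr values, records the operations that formed its result; eval then interprets the recorded graph, standing in for handing it to the XLA compiler. The closed type family Jitted is one possible way to compute the type of the jitted function from the traced one, so in this toy only the argument to jit needs an annotation; whether that carries over to Yam's shape-indexed types is untested.

```haskell
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Main (main) where

-- Hypothetical toy tracer: an expression graph recorded by overloading Num.
-- A stand-in for Yam's Tracer, not Yam's actual API.
data Expr
  = Input Int -- the i-th symbolic argument
  | Const Float
  | Add Expr Expr
  | Mul Expr Expr
  | Neg Expr
  | Abs Expr
  | Signum Expr
  deriving Show

-- Every arithmetic operation builds a node instead of computing a number.
instance Num Expr where
  (+) = Add
  (*) = Mul
  negate = Neg
  abs = Abs
  signum = Signum
  fromInteger = Const . fromInteger

-- Interpret a graph against concrete arguments; this stands in for
-- compiling the graph with XLA and running it on a device.
eval :: [Float] -> Expr -> Float
eval args = go
  where
    go (Input i) = args !! i
    go (Const c) = c
    go (Add x y) = go x + go y
    go (Mul x y) = go x * go y
    go (Neg x) = negate (go x)
    go (Abs x) = abs (go x)
    go (Signum x) = signum (go x)

-- Compute the jitted function's type from the traced function's type.
type family Jitted f where
  Jitted (Expr -> r) = Float -> Jitted r
  Jitted Expr = Float

class Jit f where
  -- Feed symbolic inputs Input i, Input (i+1), ... and return the output graph.
  traceFrom :: Int -> f -> Expr
  -- Collect concrete arguments (in reverse order), then evaluate the graph.
  bindArgs :: Expr -> [Float] -> Jitted f

instance Jit Expr where
  traceFrom _ e = e
  bindArgs g acc = eval (reverse acc) g

instance Jit r => Jit (Expr -> r) where
  traceFrom i f = traceFrom @r (i + 1) (f (Input i))
  bindArgs g acc = \x -> bindArgs @r g (x : acc)

-- Trace once, then return a function over concrete values.
jit :: forall f. Jit f => f -> Jitted f
jit f = bindArgs @f graph []
  where
    graph = traceFrom 0 f -- bound once and captured by the returned closure

-- The post's test function, without the shape types.
test :: Num a => a -> a -> a
test a b = (s + s) + (s + s)
  where
    s = a + b

main :: IO ()
main = do
  print (test 1 2 :: Float) -- 12.0, computed directly
  -- Only the argument is annotated; fast :: Float -> Float -> Float is inferred.
  let fast = jit (test :: Expr -> Expr -> Expr)
  print (fast 1 2) -- 12.0
  print (fast 3 4) -- 28.0
  print (traceFrom 0 (test :: Expr -> Expr -> Expr)) -- the recorded graph
```

Printing the graph shows a + b four times because Show flattens the tree, and eval likewise recomputes it. The sub-expression s is shared in memory, but a real tracer has to make that sharing explicit (for example by naming intermediate values) so the compiler sees each operation once.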
An example of what using Yam looks like is in the test directory (not the example directory). It's very clunky.
test/Main.hs
```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
module Main (main) where

import Foreign
import Yam.Tensor
import Yam.PjRt
import Yam.Jit

-- Works on both concrete Tensors and abstract Tracers.
test :: (Trace t, Num (t '[12] Float)) => t '[12] Float -> t '[12] Float -> t '[12] Float
test a b = (s + s) + (s + s)
  where s = a + b

main :: IO ()
main = do
  -- client is not defined in this file; it is presumably provided by one of the Yam modules.
  device <- head <$> clientAddressableDevices client
  -- Build 12-element Float tensors on the device from host arrays.
  a :: Tensor '[12] Float <- withArray [0..11] $ tensorFromHostBuffer device
  b :: Tensor '[12] Float <- withArray [1..12] $ tensorFromHostBuffer device
  -- Both the function given to jit and the jitted result need annotations.
  let testjit :: (Trace t, Jit t (Tracer '[12] Float -> Tracer '[12] Float -> Tracer '[12] Float) f) => f
      testjit = jit (test :: Tracer '[12] Float -> Tracer '[12] Float -> Tracer '[12] Float)
  print $ testjit a b  -- jitted version
  print $ test a b     -- the same function applied to Tensors directly
  print $ recip b
  traceDebug (testjit :: Tracer '[12] Float -> Tracer '[12] Float -> Tracer '[12] Float)
  clientDestroy client
  return ()
```
This project is only a proof of concept at this point.