How to create QuickCheck tests for dependently typed matrix addition without specifying each dimension statically?

Hi! I have been attempting to write a test case with dependently typed matrices, and have been struggling to find a way to write them without statically specifying the dimensions. This isn’t ideal because I want to test on many different dimensions.

Here’s what I currently have:

-- Need to specify dimensions :(
sensStaticHMatrixPlus = SensStaticHMatrix.plus @3 @4 @L1 @Diff

-- This generates tests automatically
$( sensCheck
    "passing_tests"
    [ 'Correct.solo_double
    ...
    , 'sensStaticHMatrixMult
    ]
 )

This isn’t ideal: I would like to test adding matrices with many different numbers of rows and columns (with both operands sharing the same dimensions, of course, since addition requires that).

Here’s one of my attempts, based on this post: How to create Arbitrary instance for dependent types? - #7 by tomjaguarpaw
I was able to make an arbitrary generator using that.

data SomeMatrix c n s where
  SomeMatrix :: (KnownNat (x :: Nat), KnownNat (y :: Nat)) => SensStaticHMatrix (x :: TL.Nat) (y :: TL.Nat) c n s -> SomeMatrix c n s


example :: forall c n s. Gen (SomeMatrix c n s)
example = do
  x' <- arbitrary
  y' <- arbitrary
  reifyNat x' $ \(x :: Proxy x) -> reifyNat y' $ \(y :: Proxy y) -> do
    elems <- replicateM ( fromInteger x' * fromInteger y') (arbitrary @Double)
    pure $ SomeMatrix @x @y $ SensStaticHMatrixUNSAFE $ matrix elems

-- Generate two matrices of the same dimensions
-- Useful for addition
exampleTwo :: forall c n s1 s2. Gen (SomeMatrix c n s1, SomeMatrix c n s2)
exampleTwo = do
  x' <- arbitrary
  y' <- arbitrary
  reifyNat x' $ \(x :: Proxy x) -> reifyNat y' $ \(y :: Proxy y) -> do
    elems1 <- replicateM ( fromInteger x' * fromInteger y') (arbitrary @Double)
    elems2 <- replicateM ( fromInteger x' * fromInteger y') (arbitrary @Double)
    pure (SomeMatrix @x @y $ SensStaticHMatrixUNSAFE $ matrix elems1, SomeMatrix @x @y $ SensStaticHMatrixUNSAFE $ matrix elems2)

However, when I try to use it, the type checker doesn’t treat the x and y of the two matrices as the same. Maybe my use case is slightly different.

test = do
  (SomeMatrix m1, SomeMatrix m2) <- generate $ exampleTwo @L2 @Diff
  pure $ (plus m1 m2) == (plus m1 m2) -- toy example not the actual property I check
--                ^^^
-- Couldn't match type ‘y1’ with ‘y’
--   Expected: SensStaticHMatrix x y L2 Diff s20
--     Actual: SensStaticHMatrix x1 y1 L2 Diff s20

My attempted solutions can be found here: SensCheck/src/SensStaticHMatrix.hs at f89c98c40c3d7851a04cf5f44a9abf3d256261e2 · uvm-plaid/SensCheck · GitHub

I would greatly appreciate some help with this. Thanks!

You’ll have to do the same existential type trick again for pairs of matrices with matching dimensions. So instead of

data SomeMatrix c n s where
  SomeMatrix ::
    (KnownNat (x :: Nat), KnownNat (y :: Nat)) =>
    SensStaticHMatrix (x :: TL.Nat) (y :: TL.Nat) c n s ->
      SomeMatrix c n s

for this case you’ll want something like

data SomeMatrices c n s1 s2 where
  SomeMatrices ::
    (KnownNat (x :: Nat), KnownNat (y :: Nat)) =>
    SensStaticHMatrix (x :: TL.Nat) (y :: TL.Nat) c n s1 ->
    SensStaticHMatrix (x :: TL.Nat) (y :: TL.Nat) c n s2 ->
      SomeMatrices c n s1 s2

and you’ll return one of those from your generator instead of the pair (SomeMatrix c n s1, SomeMatrix c n s2).

You’ll need yet another type if you want a pair suitable for matrix multiplication, and so on.
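A minimal, base-only sketch of the SomeMatrices idea. It assumes a hypothetical toy type Mat (a row-major list with phantom dimensions) in place of SensStaticHMatrix, and a plain Int -> Double function for element values instead of QuickCheck’s Gen. Because one constructor hides a single r and c shared by both matrices, matching on it gives plus the equality it needs:

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, ScopedTypeVariables #-}
import Data.Proxy (Proxy (..))
import GHC.TypeNats (KnownNat, Nat, SomeNat (..), natVal, someNatVal)
import Numeric.Natural (Natural)

-- Hypothetical stand-in for SensStaticHMatrix: the dimensions live only in
-- the phantom indices r and c; the payload is a row-major list.
newtype Mat (r :: Nat) (c :: Nat) = Mat {unMat :: [[Double]]}

-- Element-wise addition only type checks when both shapes agree.
plus :: Mat r c -> Mat r c -> Mat r c
plus (Mat a) (Mat b) = Mat (zipWith (zipWith (+)) a b)

-- One constructor hides a single r and c shared by BOTH matrices.
data SomeMatrices where
  SomeMatrices :: (KnownNat r, KnownNat c) => Mat r c -> Mat r c -> SomeMatrices

-- Reify the runtime sizes once, then build both matrices at that (r, c).
mkSomeMatrices :: Natural -> Natural -> (Int -> Double) -> SomeMatrices
mkSomeMatrices rows cols elemAt =
  case (someNatVal rows, someNatVal cols) of
    (SomeNat (_ :: Proxy r), SomeNat (_ :: Proxy c)) ->
      let nr = fromIntegral rows
          nc = fromIntegral cols
          grid off = [[elemAt (off + i * nc + j) | j <- [0 .. nc - 1]] | i <- [0 .. nr - 1]]
       in SomeMatrices (Mat (grid 0) :: Mat r c) (Mat (grid (nr * nc)) :: Mat r c)

-- Matching on the single constructor lets plus see that the shapes agree.
sumOf :: SomeMatrices -> [[Double]]
sumOf (SomeMatrices m1 m2) = unMat (plus m1 m2)

-- The hidden dimensions are still recoverable through KnownNat.
dims :: SomeMatrices -> (Natural, Natural)
dims (SomeMatrices (_ :: Mat r c) _) = (natVal (Proxy :: Proxy r), natVal (Proxy :: Proxy c))
```

Unpacking a (SomeMatrix, SomeMatrix) pair instead binds two unrelated sets of dimensions, which is what produces the “Couldn't match type ‘y1’ with ‘y’” error in the question.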

I suggest avoiding “Some” wrappers and instead keeping the size variables in scope:

test2 = generate $ do
  SomeNat @x _ <- arbitraryKnownNat
  SomeNat @y _ <- arbitraryKnownNat
  SomeNat @z _ <- arbitraryKnownNat
  m1 <- exampleThree @x @y @L2 @Diff
  m2 <- exampleThree @y @z @L2 @Diff
  pure $ (mult m1 m2) == (mult m1 m2)

Full code at: Generate matrices that can be added and multiplied · uvm-plaid/SensCheck@8312227 · GitHub
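Since arbitraryKnownNat and exampleThree live in the linked commit and aren’t shown here, this base-only sketch of the “keep the size variables in scope” style assumes a hypothetical helper knownNatFrom (any monadic Natural turned into a SomeNat) and a toy Mat type. It binds the sizes with pattern signatures rather than SomeNat @x, which works on older GHCs too. Run in the list monad it enumerates every x, y and z; in QuickCheck’s Gen it would sample them instead:

```haskell
{-# LANGUAGE DataKinds, KindSignatures, ScopedTypeVariables, TypeApplications #-}
import Data.List (transpose)
import Data.Proxy (Proxy (..))
import GHC.TypeNats (KnownNat, Nat, SomeNat (..), natVal, someNatVal)
import Numeric.Natural (Natural)

-- Hypothetical stand-in for SensStaticHMatrix (phantom dimensions only).
newtype Mat (r :: Nat) (c :: Nat) = Mat {unMat :: [[Double]]}

-- Hypothetical analogue of arbitraryKnownNat, generic over the functor.
-- In Gen it could be knownNatFrom (fromInteger <$> choose (1, 10)).
knownNatFrom :: Functor m => m Natural -> m SomeNat
knownNatFrom = fmap someNatVal

-- Runtime value of a type-level dimension.
dim :: KnownNat n => Proxy n -> Int
dim = fromIntegral . natVal

-- A matrix filled with one value, sized from its type indices.
constMat :: forall r c. (KnownNat r, KnownNat c) => Double -> Mat r c
constMat v = Mat (replicate (dim (Proxy @r)) (replicate (dim (Proxy @c)) v))

-- The inner dimension b must agree, as for the real mult.
mult :: Mat a b -> Mat b c -> Mat a c
mult (Mat x) (Mat y) = Mat [[sum (zipWith (*) row col) | col <- transpose y] | row <- x]

-- Mirrors test2: x, y and z scope over the rest of the do block, so
-- m1 :: Mat x y and m2 :: Mat y z share y by construction.
multExample :: Monad m => m Natural -> m (Natural, Natural, [[Double]])
multExample sizes = do
  SomeNat (_ :: Proxy x) <- knownNatFrom sizes
  SomeNat (_ :: Proxy y) <- knownNatFrom sizes
  SomeNat (_ :: Proxy z) <- knownNatFrom sizes
  let m1 = constMat 1 :: Mat x y
      m2 = constMat 2 :: Mat y z
  pure (natVal (Proxy @x), natVal (Proxy @z), unMat (mult m1 m2))
```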

By the way, I appreciate you providing full source code. It made it much easier for me to know what to suggest. However, I couldn’t get the project to build with cabal because grenade’s dependencies on Hackage are severely outdated. I’ve reported that at Adjust bounds by sorki · Pull Request #107 · HuwCampbell/grenade · GitHub

Thanks! It compiles on my end using Stack! I am building against the latest commit from their repo, though.

There’s one more thing I’m trying to figure out: how to use this inside a property. I think I originally had it set up wrong.

testStaticPlus = quickCheck (forAll genTwo plusProp)

genTwo = do
  SomeNat @x _ <- arbitraryKnownNat
  SomeNat @y _ <- arbitraryKnownNat
  m1 <- exampleThree @x @y @L2 @Diff
  m2 <- exampleThree @x @y @L2 @Diff
  pure (m1, m2)

The problem is that I need to use it with forAll, which takes a Gen and a property, and I’m not sure how to use the matrices outside the Gen monad. Here’s my attempt at the use site: SensCheck/test/Spec.hs at main · uvm-plaid/SensCheck · GitHub

Thanks!

I’d do it like this:

genTwo ::
  (forall n m.
   KnownNat n =>
   KnownNat m =>
   SensStaticHMatrix n m L2 Diff s ->
   SensStaticHMatrix n m L2 Diff s ->
   SensStaticHMatrix n m L2 Diff s ->
   SensStaticHMatrix n m L2 Diff s ->
   Gen r) ->
  Gen r
genTwo cond = do
  SomeNat @x _ <- arbitraryKnownNat
  SomeNat @y _ <- arbitraryKnownNat
  m1 <- exampleThree @x @y @L2 @Diff
  m2 <- exampleThree @x @y @L2 @Diff
  m3 <- exampleThree @x @y @L2 @Diff
  m4 <- exampleThree @x @y @L2 @Diff
  cond m1 m2 m3 m4
testStaticPlus =
  quickCheck
    (forAll
      (SensStaticHMatrix.genTwo
         (\m1 m2 m3 m4 -> pure (plusProp m1 m2 m3 m4))
      )
      id
    )

See Test it · uvm-plaid/SensCheck@f33427a · GitHub
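A base-only sketch of the same continuation-passing idea, assuming a hypothetical toy Mat type and generic size/element supplies standing in for arbitraryKnownNat, exampleThree and arbitrary. The property is a rank-2 argument, so the reified r and c never escape and no Some wrapper is needed:

```haskell
{-# LANGUAGE DataKinds, KindSignatures, RankNTypes, ScopedTypeVariables, TypeApplications #-}
import Control.Monad (replicateM)
import Data.Proxy (Proxy (..))
import GHC.TypeNats (KnownNat, Nat, SomeNat (..), natVal, someNatVal)
import Numeric.Natural (Natural)

-- Hypothetical stand-in for SensStaticHMatrix (phantom dimensions only).
newtype Mat (r :: Nat) (c :: Nat) = Mat {unMat :: [[Double]]} deriving (Eq, Show)

plus :: Mat r c -> Mat r c -> Mat r c
plus (Mat a) (Mat b) = Mat (zipWith (zipWith (+)) a b)

-- Fill an r-by-c matrix from an element supply (in Gen: arbitrary).
genMat :: forall r c m. (KnownNat r, KnownNat c, Monad m) => m Double -> m (Mat r c)
genMat el = Mat <$> replicateM (dim (Proxy @r)) (replicateM (dim (Proxy @c)) el)
  where
    dim :: KnownNat n => Proxy n -> Int
    dim = fromIntegral . natVal

-- Hypothetical analogue of genTwo: reify both sizes, build two matrices of
-- the same shape, and hand them to the caller's polymorphic continuation.
withSameShape ::
  Monad m =>
  m Natural -> -- dimension supply
  m Double -> -- element supply
  (forall r c. (KnownNat r, KnownNat c) => Mat r c -> Mat r c -> m a) ->
  m a
withSameShape size el k = do
  SomeNat (_ :: Proxy r) <- someNatVal <$> size
  SomeNat (_ :: Proxy c) <- someNatVal <$> size
  m1 <- genMat @r @c el
  m2 <- genMat @r @c el
  k m1 m2

-- Toy property: addition commutes (exact for IEEE doubles without NaN).
plusComm :: Mat r c -> Mat r c -> Bool
plusComm m1 m2 = plus m1 m2 == plus m2 m1
```

With QuickCheck, the call would presumably mirror testStaticPlus, e.g. quickCheck (forAll (withSameShape (fromInteger <$> choose (1, 10)) arbitrary (\m1 m2 -> pure (plusComm m1 m2))) id); that line is an untested sketch.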

It looks like it works on my end. Thanks so much, Tom, this is hugely helpful! I was stuck on this for quite a while and I really appreciate you helping unblock me!

You’re welcome! I’ve become very familiar with this style because we’ve adopted it at work (Groq). It works very well, but requires becoming familiar with new idioms (such as Some).

That’s awesome! I’ve seen some of the press on Groq; it looks like you all are doing some cool things there!

I ran into one more thing, maybe slightly different in this case:

scalarMult :: forall x y scalar cmetric nmetric s.
              (KnownNat x, KnownNat y, KnownNat scalar) => 
              SensStaticHMatrix x y cmetric nmetric s ->
              SensStaticHMatrix x y cmetric nmetric (ScaleSens s scalar)
scalarMult m1 = SensStaticHMatrixUNSAFE $ unSensStaticHMatrix m1 * fromInteger (TL.natVal (Proxy @scalar))

My attempt:

scalarMult :: forall x y scalar cmetric nmetric s.
              (KnownNat x, KnownNat y, KnownNat scalar) => 
              SensStaticHMatrix x y cmetric nmetric s ->
              SensStaticHMatrix x y cmetric nmetric (ScaleSens s scalar)
scalarMult m1 = SensStaticHMatrixUNSAFE $ unSensStaticHMatrix m1 * fromInteger (TL.natVal (Proxy @scalar))
----

scalarMult :: forall x y scalar cmetric nmetric s.
              (KnownNat x, KnownNat y, KnownNat scalar) => 
              SensStaticHMatrix x y cmetric nmetric s ->
              SensStaticHMatrix x y cmetric nmetric (ScaleSens s scalar)
scalarMult m1 = SensStaticHMatrixUNSAFE $ unSensStaticHMatrix m1 * fromInteger (TL.natVal (Proxy @scalar))

I’m not sure what the question is or what you’ve pasted. Is that the same code three times?

Oh, sorry, that was an error. Here’s what I meant to post for the second code block:


genScalarMult ::
  (forall x y scalar.
   KnownNat x =>
   KnownNat y =>
   KnownNat scalar =>
   SensStaticHMatrix x y L2 Diff s ->
   SensStaticHMatrix x y L2 Diff s ->
   Proxy scalar ->
   Gen (Proxy scalar -> r)) ->
  Gen (Proxy scalar -> r)
genScalarMult cond = do
  SomeNat @x _ <- arbitraryKnownNat
  SomeNat @y _ <- arbitraryKnownNat
  SomeNat @scalar scalar <- arbitraryKnownNat
  m1 <- gen @x @y @L2 @Diff
  m2 <- gen @x @y @L2 @Diff
  cond scalar m1 m2
  --   Expected: Gen (Proxy @{k} scalar -> r)
  --   Actual:   Gen (Proxy @{Nat} scalar0 -> r)

testStaticScalarMult =
  quickCheck
    (forAll
      (SensStaticHMatrix.genScalar
         (\proxy m1 m2 -> pure (scalarMultProp m1 m2))
          -- ^^ Not using this right now
      )
      id
    )

And here’s the function again:

scalarMult :: forall x y scalar cmetric nmetric s.
              (KnownNat x, KnownNat y, KnownNat scalar) => 
              SensStaticHMatrix x y cmetric nmetric s ->
              SensStaticHMatrix x y cmetric nmetric (ScaleSens s scalar)
scalarMult m1 = SensStaticHMatrixUNSAFE $ unSensStaticHMatrix m1 * fromInteger (TL.natVal (Proxy @scalar))

It’s slightly different because it does a scalar multiplication. For my use case I need to track the scaling factor in the type system as scalar, which appears in the return type. That’s different from my first use case, where only the dimensions were arbitrary.
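To illustrate a scalar that lives only in the type, here is a base-only sketch. It assumes a hypothetical simplified matrix whose last index is a natural-number sensitivity, with GHC’s type-level (*) standing in for ScaleSens; the real ScaleSens and sensitivity kinds in SensCheck are likely different. The factor is chosen by a type application and read back with natVal:

```haskell
{-# LANGUAGE AllowAmbiguousTypes, DataKinds, KindSignatures, NoStarIsType #-}
{-# LANGUAGE ScopedTypeVariables, TypeApplications, TypeOperators #-}
import Data.Proxy (Proxy (..))
import GHC.TypeNats (KnownNat, Nat, natVal, type (*))
import Numeric.Natural (Natural)

-- Hypothetical simplification: s is a natural-number sensitivity and
-- type-level (*) plays the role of ScaleSens.
newtype SMat (r :: Nat) (c :: Nat) (s :: Nat) = SMat {unSMat :: [[Double]]}

-- The scalar k exists only at the type level; natVal turns it into a value.
scalarMult :: forall k r c s. KnownNat k => SMat r c s -> SMat r c (k * s)
scalarMult (SMat m) = SMat (map (map (* fromIntegral (natVal (Proxy @k)))) m)

-- Read back the sensitivity recorded in the type.
sensOf :: forall r c s. KnownNat s => SMat r c s -> Natural
sensOf _ = natVal (Proxy @s)
```

Because k appears only under the type family (*), this signature needs AllowAmbiguousTypes and callers must write scalarMult @3; if ScaleSens is a type family, the original signature is likely in the same position.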

I still don’t understand what the question is. Is the problem that genScalarMult doesn’t type check?

There are a few reasons that the version of genScalarMult you gave can’t work. Firstly, it’s equivalent to this:

genScalarMult ::
  forall {k} (scalar0 :: k).
  (forall x y (scalar :: Nat).
   KnownNat x =>
   KnownNat y =>
   KnownNat scalar =>
   SensStaticHMatrix x y L2 Diff s ->
   SensStaticHMatrix x y L2 Diff s ->
   Proxy scalar ->
   Gen (Proxy scalar -> r)) ->
  Gen (Proxy scalar0 -> r)
genScalarMult cond = do
  SomeNat @x _ <- arbitraryKnownNat
  SomeNat @y _ <- arbitraryKnownNat
  SomeNat @scalar scalar <- arbitraryKnownNat
  m1 <- gen @x @y @L2 @Diff
  m2 <- gen @x @y @L2 @Diff
  cond scalar m1 m2

The two “scalar” type arguments have nothing to do with each other! But I don’t understand why there’s a Proxy scalar argument inside Gen. It doesn’t seem to serve any purpose, so you can just get rid of it.

genScalarMult ::
  (forall x y scalar.
   KnownNat x =>
   KnownNat y =>
   KnownNat scalar =>
   SensStaticHMatrix x y L2 Diff s ->
   SensStaticHMatrix x y L2 Diff s ->
   Proxy scalar ->
   Gen r) ->
  Gen r
genScalarMult cond = do
  SomeNat @x _ <- arbitraryKnownNat
  SomeNat @y _ <- arbitraryKnownNat
  SomeNat @scalar scalar <- arbitraryKnownNat
  m1 <- gen @x @y @L2 @Diff
  m2 <- gen @x @y @L2 @Diff
  cond scalar m1 m2

That still doesn’t work because you’re simply applying the arguments to cond in the wrong order. This works:

genScalarMult ::
  (forall x y scalar.
   KnownNat x =>
   KnownNat y =>
   KnownNat scalar =>
   SensStaticHMatrix x y L2 Diff s ->
   SensStaticHMatrix x y L2 Diff s ->
   Proxy scalar ->
   Gen r) ->
  Gen r
genScalarMult cond = do
  SomeNat @x _ <- arbitraryKnownNat
  SomeNat @y _ <- arbitraryKnownNat
  SomeNat @scalar scalar <- arbitraryKnownNat
  m1 <- gen @x @y @L2 @Diff
  m2 <- gen @x @y @L2 @Diff
  cond m1 m2 scalar
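A base-only sketch of the corrected version, assuming a hypothetical toy Mat type and generic supplies in place of gen and arbitraryKnownNat. It adds a hypothetical sensitivity-style property (the L1 distance after scaling is at most k times the distance before) to show that the property can use the scalar through the Proxy k that cond receives:

```haskell
{-# LANGUAGE DataKinds, KindSignatures, RankNTypes, ScopedTypeVariables, TypeApplications #-}
import Control.Monad (replicateM)
import Data.Proxy (Proxy (..))
import GHC.TypeNats (KnownNat, Nat, SomeNat (..), natVal, someNatVal)
import Numeric.Natural (Natural)

-- Hypothetical stand-in for SensStaticHMatrix (phantom dimensions only).
newtype Mat (r :: Nat) (c :: Nat) = Mat {unMat :: [[Double]]}

genMat :: forall r c m. (KnownNat r, KnownNat c, Monad m) => m Double -> m (Mat r c)
genMat el = Mat <$> replicateM (dim (Proxy @r)) (replicateM (dim (Proxy @c)) el)
  where
    dim :: KnownNat n => Proxy n -> Int
    dim = fromIntegral . natVal

-- Multiply every entry by the type-level scalar k.
scale :: KnownNat k => Proxy k -> Mat r c -> Mat r c
scale p (Mat m) = Mat (map (map (* fromIntegral (natVal p))) m)

-- Hypothetical L1 distance between same-shaped matrices.
dist :: Mat r c -> Mat r c -> Double
dist (Mat a) (Mat b) = sum (zipWith (\u v -> abs (u - v)) (concat a) (concat b))

-- The scalar is reified like the dimensions and passed to cond in the
-- same order as cond's parameters.
genScalarMult ::
  Monad m =>
  m Natural -> -- dimension supply
  m Natural -> -- scalar supply
  m Double -> -- element supply
  (forall x y k. (KnownNat x, KnownNat y, KnownNat k) => Mat x y -> Mat x y -> Proxy k -> m a) ->
  m a
genScalarMult size scalarSupply el cond = do
  SomeNat (_ :: Proxy x) <- someNatVal <$> size
  SomeNat (_ :: Proxy y) <- someNatVal <$> size
  SomeNat scalar <- someNatVal <$> scalarSupply
  m1 <- genMat @x @y el
  m2 <- genMat @x @y el
  cond m1 m2 scalar

-- Depends on the scalar: scaling by k scales the distance by at most k.
scalarMultProp :: KnownNat k => Mat x y -> Mat x y -> Proxy k -> Bool
scalarMultProp m1 m2 p = dist (scale p m1) (scale p m2) <= factor * dist m1 m2 + 1e-9
  where
    factor = fromIntegral (natVal p)
```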

May I ask why you want to randomize the dimensions? For performance benchmarking, you could set dimensions by powers of 10, e.g. one test for d=100, one for d=1000, and so on. That is, unless you have some mmult logic (outside your control) that applies different optimizations depending on the size, which you want to discover by sampling.
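If you go with fixed sizes instead, the rank-2 style still helps: a test written once for any KnownNat n can be run at a chosen list of sizes by reifying each one. A base-only sketch with hypothetical helpers atSizes, identity and trace on a toy Mat type:

```haskell
{-# LANGUAGE DataKinds, KindSignatures, RankNTypes, ScopedTypeVariables, TypeApplications #-}
import Data.Proxy (Proxy (..))
import GHC.TypeNats (KnownNat, Nat, SomeNat (..), natVal, someNatVal)
import Numeric.Natural (Natural)

-- Hypothetical stand-in for SensStaticHMatrix (phantom dimensions only).
newtype Mat (r :: Nat) (c :: Nat) = Mat {unMat :: [[Double]]}

-- Run a size-polymorphic check once per fixed size, e.g. [100, 1000].
atSizes :: [Natural] -> (forall n. KnownNat n => Proxy n -> a) -> [a]
atSizes sizes k = [case someNatVal size of SomeNat p -> k p | size <- sizes]

-- n-by-n identity matrix, sized from its type index.
identity :: forall n. KnownNat n => Mat n n
identity = Mat [[if i == j then 1 else 0 | j <- [1 .. d]] | i <- [1 .. d]]
  where
    d = natVal (Proxy @n)

-- Sum of the diagonal entries.
trace :: Mat n n -> Double
trace (Mat m) = sum [row !! i | (i, row) <- zip [0 ..] m]
```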

Hey, so I was able to get that to compile, thanks! However, I realized that this doesn’t really work for my case. What I’m doing is generating QuickCheck properties from the type signature. In the former case the dimensions did not matter for the property, but in this case the scalar does affect the property. I don’t think I can really make it work (well, possibly, but it’s more work than I would want to do right now).

Sorry about that! This was really helpful!

That’s a good point. I think I would probably want to show that my property holds for any arbitrary dimension. I’m not sure how QuickCheck generates the number of elements in a standard list by default, for instance, but I would imagine it’s randomized. I could be wrong, though.

I don’t really understand. The scalar is passed in as an argument to cond exactly so that the property can depend on it! In any case, if you want to pass the scalar in separately, you can do this:

genScalarMult ::
  KnownNat scalar =>
  Proxy scalar ->
  (forall x y.
   KnownNat x =>
   KnownNat y =>
   SensStaticHMatrix x y L2 Diff s ->
   SensStaticHMatrix x y L2 Diff s ->
   Proxy scalar ->
   Gen r) ->
  Gen r
genScalarMult scalar cond = do
  SomeNat @x _ <- arbitraryKnownNat
  SomeNat @y _ <- arbitraryKnownNat
  m1 <- gen @x @y @L2 @Diff
  m2 <- gen @x @y @L2 @Diff
  cond m1 m2 scalar
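A base-only sketch of how the fixed-scalar variant might be called, assuming a hypothetical toy Mat type, generic supplies in place of gen and arbitraryKnownNat, and a hypothetical observation scaledTotal. The scalar can come from a type application such as Proxy @3, or from a runtime list reified once per run with someNatVal:

```haskell
{-# LANGUAGE DataKinds, KindSignatures, RankNTypes, ScopedTypeVariables, TypeApplications #-}
import Control.Monad (replicateM)
import Data.Proxy (Proxy (..))
import GHC.TypeNats (KnownNat, Nat, SomeNat (..), natVal, someNatVal)
import Numeric.Natural (Natural)

-- Hypothetical stand-in for SensStaticHMatrix (phantom dimensions only).
newtype Mat (r :: Nat) (c :: Nat) = Mat {unMat :: [[Double]]}

genMat :: forall r c m. (KnownNat r, KnownNat c, Monad m) => m Double -> m (Mat r c)
genMat el = Mat <$> replicateM (dim (Proxy @r)) (replicateM (dim (Proxy @c)) el)
  where
    dim :: KnownNat n => Proxy n -> Int
    dim = fromIntegral . natVal

-- The caller fixes the scalar's type; only the dimensions are reified inside.
genScalarMultAt ::
  (KnownNat k, Monad m) =>
  Proxy k ->
  m Natural -> -- dimension supply
  m Double -> -- element supply
  (forall x y. (KnownNat x, KnownNat y) => Mat x y -> Mat x y -> Proxy k -> m a) ->
  m a
genScalarMultAt scalar size el cond = do
  SomeNat (_ :: Proxy x) <- someNatVal <$> size
  SomeNat (_ :: Proxy y) <- someNatVal <$> size
  m1 <- genMat @x @y el
  m2 <- genMat @x @y el
  cond m1 m2 scalar

-- Hypothetical observation: all entries of both matrices, times k.
scaledTotal :: KnownNat k => Mat x y -> Mat x y -> Proxy k -> Double
scaledTotal (Mat a) (Mat b) p = fromIntegral (natVal p) * sum (concat a ++ concat b)
```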