Stop Passing the Wrong Shape into Your Model with a Unit Test

When coding up a model, it is easy to make a few trivial mistakes that lead to serious errors later on, when you train the model. You can end up spending hours debugging, only to find that your data was in the wrong shape or that a layer was not configured properly.

Catching such mistakes earlier can make life so much easier.

I did some googling and found out that you can use testing libraries to catch those mistakes for you automatically.

Passing the wrong shape through your layers should now be a thing of the past.

Using unittest for your model

I’m going to use the standard unittest library, following the approach from this article: How to Trust Your Deep Learning Code.

All credit goes to him. Have a look at his blog post for a great tutorial on unit testing deep learning code.

This test simply checks whether your data has the shape you intend to feed into your model.

Trust me.

You wouldn’t believe how many times an error pops up that comes down to this, especially when you’re only half paying attention.

This test should take minutes to set up, and it can save you hours in the future.

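import unittest
import torch

# trainloader is assumed to be defined earlier (here: CIFAR-10 style images, batch_size=4)
dataiter = iter(trainloader)
images, labels = next(dataiter)  # next(); newer PyTorch loader iterators have no .next() method

class MyFirstTest(unittest.TestCase):
    def test_shape(self):
        # expected shape: (batch_size, channels, height, width)
        self.assertEqual(torch.Size((4, 3, 32, 32)), images.shape)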


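To run it in a notebook, pass argv=[''] so unittest doesn’t try to parse the notebook kernel’s command-line arguments, and exit=False so it doesn’t call sys.exit() when it finishes: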

unittest.main(argv=[''], verbosity=2, exit=False)
 
test_shape (__main__.MyFirstTest) ... ok
 
----------------------------------------------------------------------
Ran 1 test in 0.056s
 
OK
<unittest.main.TestProgram at 0x7fb137fe3a20>
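You may want to run just one test class, or check the result in code instead of reading the printout. In that case you can use unittest.TestLoader and unittest.TextTestRunner directly. This is a minimal sketch that mirrors the test above. It assumes a random tensor in place of the real batch from trainloader, so it runs on its own.

```python
import unittest

import torch

# Assumption: a random tensor stands in for the real batch from trainloader
# (CIFAR-10 style images, batch size 4).
images = torch.randn(4, 3, 32, 32)


class MyFirstTest(unittest.TestCase):
    def test_shape(self):
        self.assertEqual(torch.Size((4, 3, 32, 32)), images.shape)


def run_shape_tests():
    """Run only MyFirstTest and return the TestResult instead of exiting."""
    suite = unittest.TestLoader().loadTestsFromTestCase(MyFirstTest)
    return unittest.TextTestRunner(verbosity=2).run(suite)


result = run_shape_tests()
print("passed" if result.wasSuccessful() else "failed")
```

The returned result also has failures and errors lists, which is handy if you want a notebook cell to react to a broken shape.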

 

The batch size (4) is hard-coded here, but you can avoid that by storing the batch size in a separate variable and using it both in your DataLoader and in the test.
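This is a minimal sketch of that idea. BATCH_SIZE and IMAGE_SHAPE are hypothetical names I’ve introduced. The DataLoader is built from random tensors as a stand-in for a real dataset such as CIFAR-10. That way the loader and the expected shape read from the same variables.

```python
import unittest

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical constants: stand-ins for whatever your real pipeline uses.
BATCH_SIZE = 4
IMAGE_SHAPE = (3, 32, 32)  # channels, height, width

# Assumption: synthetic data in place of a real dataset such as CIFAR-10.
fake_images = torch.randn(20, *IMAGE_SHAPE)
fake_labels = torch.randint(0, 10, (20,))
trainloader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=BATCH_SIZE)


class ShapeTest(unittest.TestCase):
    def test_batch_shape(self):
        images, labels = next(iter(trainloader))
        # The expected shapes come from the same constants the loader uses.
        self.assertEqual(torch.Size((BATCH_SIZE, *IMAGE_SHAPE)), images.shape)
        self.assertEqual(torch.Size((BATCH_SIZE,)), labels.shape)
```

If you later change BATCH_SIZE, the test follows along instead of failing for the wrong reason.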

 

The test with the wrong shape

Now let’s see what the test does when the data has a different shape.

I’m just going to drop the batch dimension. This is a mistake that can easily happen when you manipulate some of your tensors.

images = images[0,:,:,:]
images.shape
 
torch.Size([3, 32, 32])
 
unittest.main(argv=[''], verbosity=5, exit=False)


(Screenshot of the unittest output, showing test_shape failing.)

As we can see, the unit test catches the error. This saves you time, because you won’t hit this issue later on when you start training.
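The same idea extends to the model itself, which covers the other mistake from the start of this post: layers that aren’t configured properly. This sketch assumes a hypothetical small CNN (TinyNet) for 3x32x32 images with 10 output classes. It is not taken from the article above. The first test pushes a random batch through the model and checks the output shape. The second checks that an input missing its batch dimension fails loudly.

```python
import unittest

import torch
import torch.nn as nn

BATCH_SIZE = 4
NUM_CLASSES = 10


class TinyNet(nn.Module):
    """Hypothetical small CNN for 3x32x32 images; swap in your own model."""

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5),  # -> (N, 6, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2),                 # -> (N, 6, 14, 14)
        )
        self.classifier = nn.Linear(6 * 14 * 14, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # keep dim 0 (the batch), flatten the rest
        return self.classifier(x)


class ModelShapeTest(unittest.TestCase):
    def setUp(self):
        self.model = TinyNet()

    def test_output_shape(self):
        batch = torch.randn(BATCH_SIZE, 3, 32, 32)
        with torch.no_grad():
            out = self.model(batch)
        # One row of class scores per image in the batch.
        self.assertEqual(torch.Size((BATCH_SIZE, NUM_CLASSES)), out.shape)

    def test_rejects_missing_batch_dim(self):
        single_image = torch.randn(3, 32, 32)
        # Without the batch dimension the forward pass should fail with a
        # RuntimeError (a shape mismatch) rather than silently succeed.
        with self.assertRaises(RuntimeError):
            self.model(single_image)
```

If you need to pass a single image on purpose, images[0].unsqueeze(0) puts the batch dimension back.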

I wanted to keep this one short. This is an area I’m still learning about, so I decided to share what I just learnt and give you something you can try out straight away.

 

Visit these links.

These are far more detailed resources about unit testing for machine learning:

https://krokotsch.eu/cleancode/2020/08/11/Unit-Tests-for-Deep-Learning.html

https://towardsdatascience.com/pytest-for-data-scientists-2990319e55e6

https://medium.com/@keeper6928/how-to-unit-test-machine-learning-code-57cf6fd81765

https://towardsdatascience.com/unit-testing-for-data-scientists-dc5e0cd397fb

As I use unit testing more in my deep learning projects, I plan to write more blog posts about other short tests you can write to save time and effort when debugging your model and data.

I used PyTorch for this, but the same idea works with most other frameworks. TensorFlow has its own test module (tf.test), so if that’s your thing, you should check it out.

Other people use pytest and other testing libraries. I wanted to keep things simple, but if you’re interested, check them out for yourself and see how they can improve your tests.
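For comparison, this is roughly what the batch-shape check could look like as a pytest test. It is a sketch, not taken from any of the articles above, and make_loader is a hypothetical stand-in for your own data pipeline. pytest collects plain functions whose names start with test_, so no TestCase class is needed. Plain assert statements are enough, because pytest rewrites them to show both sides when they differ.

```python
# test_shapes.py -- run with: pytest test_shapes.py
import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 4


def make_loader():
    """Hypothetical stand-in for your real data pipeline (random CIFAR-10 sized data)."""
    images = torch.randn(16, 3, 32, 32)
    labels = torch.randint(0, 10, (16,))
    return DataLoader(TensorDataset(images, labels), batch_size=BATCH_SIZE)


def test_batch_shape():
    images, labels = next(iter(make_loader()))
    # torch.Size is a tuple subclass, so it compares equal to a plain tuple.
    assert images.shape == (BATCH_SIZE, 3, 32, 32)
    assert labels.shape == (BATCH_SIZE,)
```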

If you liked this blog post, consider signing up to my mailing list, where I write more stuff like this.