Easy Mocking in Golang Unit Tests with Interfaces

Mocking is one of the most common practices in unit testing. In Golang, interfaces make mocking a breeze, letting us exercise every branch of our code, error paths included, and get the most coverage possible.

Background

One topic that always sparks a lively debate is what a unit test should cover. As far as I'm concerned, a unit test covers the code you've written within a specific function and, ideally, nothing else. If your function calls into an imported library, that dependency is exactly the type of thing you'd want to mock.

Mocking lets us simulate the different results the mocked dependency can return, so we can test how our code reacts to each one. Let's say we have the following code in our repo:

import (
  "fmt"
  "os"
)

func PrintFullName(first, last string) error {
  // Output always goes to os.Stdout, which a test can't easily make fail.
  if _, err := fmt.Fprintln(os.Stdout, first, last); err != nil {
    return fmt.Errorf("failed to print full name: %s", err)
  }
  return nil
}

To fully cover this, we'd need to test both the happy path and the error case.

In its current form, that's not easy to do. The function always writes to os.Stdout, and writing to standard output rarely fails, so without a way to mock the call fmt.Fprintln(os.Stdout, first, last), we have no reliable way to trigger the error case and check that it's handled as we expect.

Refactoring with Interfaces

It's not often said, but writing testable code is one of the most important skills in software development. Without testing in mind, developers often write code that does what they want (that is, itWorksFine()) but then find they can't properly test it.

Let's refactor the code above so that it's easier to test:

import (
  "fmt"
  "io"
)

// out is an io.Writer, so the caller (or a test) decides where the output goes.
func PrintFullName(out io.Writer, first, last string) error {
  if _, err := fmt.Fprintln(out, first, last); err != nil {
    return fmt.Errorf("failed to print full name: %s", err)
  }
  return nil
}

Notice that we're now passing in an io.Writer as a new argument. The io.Writer interface declares a single method, and any type with a Write method of exactly this signature satisfies it automatically; Go needs no explicit "implements" declaration:

type Writer interface {
    Write(p []byte) (n int, err error)
}

That means the code that runs when Write is called can be anything we want it to be, including code that always returns an error.

Writing Our Tests

First, we'll test the happy path. To do that, we want to check that the expected string is written to the output writer.

We could pass in os.Stdout as our writer, but that comes with some overhead. We'd need a way to read back what was written, and we'd have to make sure nothing else writes to standard output during the test, so that the captured string matches exactly what we expect.

Instead, we can leverage the io.Writer interface to pass in something else that has a matching Write method: a *bytes.Buffer, which records everything written to it in memory.

import (
    "bytes"
    "testing"
)

func TestPrintFullName(t *testing.T) {
    buff := &bytes.Buffer{}
    if err := PrintFullName(buff, "This", "That"); err != nil {
        t.Fatalf("unexpected error: %s", err)
    }
    // Fprintln separates operands with a space and appends a newline.
    want := "This That\n"
    if got := buff.String(); got != want {
        t.Fatalf("expected printed value %q, got %q", want, got)
    }
}

That confirms the happy path works as intended. Now it's on to the error case.

To do that, we'll create a struct whose Write method always returns an error, which fmt.Fprintln then passes back to PrintFullName.

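// errorWriter satisfies io.Writer, but every call to Write fails.
// (errors.New needs "errors" in the test file's imports.)
type errorWriter struct {
    ErrMsg string
}

func (ew errorWriter) Write(p []byte) (n int, err error) {
    // errors.New avoids treating ErrMsg as a format string, which go vet flags.
    return 0, errors.New(ew.ErrMsg)
}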

This errorWriter satisfies the io.Writer interface, so we can pass it into PrintFullName and its Write method will return an error every time. That lets us test the error case and get full coverage.

func TestPrintFullName_Error(t *testing.T) {
    ew := &errorWriter{
        ErrMsg: "ERROR",
    }
    err := PrintFullName(ew, "This", "That")
    if err == nil {
        t.Fatal("expected an error, got nil")
    }
    want := "failed to print full name: ERROR"
    if err.Error() != want {
        t.Fatalf("expected error value %q, got %q", want, err.Error())
    }
}