EurekaMoments

This is my studying logs about Autonomous driving, Machine learning technologies and etc.

Juliaでの基本的なユニットテストのやり方メモ

目次

目的

  • Juliaでユニットテストをするためのテストコードの書き方を学ぶ。
  • 実際にテストコードを書いて実行し、動作を確認してみる。

公式ドキュメント

  • Juliaで標準で用意されているテスト機能のドキュメント。
  • 分からないことはここで調べればOK。

docs.julialang.org

@testマクロによるユニットテスト

  • まずはテスト対象となるコードを書く。
  • モジュール化し、これをテストコードから呼び出す。
  • 例えば足し算をするコードなら下記のようにする。
  • ファイル名: add.jl
module Add
    function add(a, b)
        return a + b
    end

    function main()
        a = 5
        b = 9
        println("$a + $b = $(add(a, b))")
    end
end

if abspath(PROGRAM_FILE) == @__FILE__
    using .Add
    Add.main()
end
  • 次にテストコードを書く。
  • まずはファイルを作ってモジュールとして書き始める。
  • ファイル名: test_add.jl
module TestAdd
        
end
  • 続いて、標準のユニットテストパッケージである"Test"を呼び出すようにする。
module TestAdd
    using Test
end
  • 続いて、テスト対象のモジュールを呼び出せるようにする。
  • includeを使って対象モジュールが定義されたファイルを読み込む。
module TestAdd
    using Test
    include("add.jl")
end
  • 最後にテストコードを書く。
  • "Test"が提供する@testマクロを使うのが最もベーシックなやり方。
module TestAdd
    using Test
    include("add.jl")

    function main()
        @test Add.add(1, 3) == 4
    end
end

if abspath(PROGRAM_FILE) == @__FILE__
    using. TestAdd
    TestAdd.main()
end
  • これを実行すると1 + 3の足し算の結果が4となる事を期待するテストが実行され、問題が無ければ何事もなく終了する。
  • 逆に、期待する結果を1にすると、実際の結果と異なるためテスト失敗となる。
  • 失敗した際は、下記のようにその内容や失敗した箇所を教えてくれる。 f:id:sy4310:20200917152008p:plain

@test_throwsマクロによるユニットテスト

  • 特定のエクセプションが発報されることを確認するためのテスト。
  • 例えば上記のadd関数は入力引数は2つだが、これを3つとかにするとMethodErrorが起きる。
Add.add(5, 2, 3)

f:id:sy4310:20200917153159p:plain

  • こういったエクセプションが正しく出るかを確認するために、下記のようなテストコードを書く。
@test_throws MethodError Add.add(5, 2, 3)
  • ここで、期待するものとは違うエクセプションが出た時はテスト失敗となる。
  • 例えば、BoundsErrorが期待されるのに対してMethodErrorが出た場合は下記のようになる。 f:id:sy4310:20200917155051p:plain

@testsetマクロによる複数テストのグループ化

  • 数多くのテストコードを一括で実行する際は、どこまではテストをクリアしていて、どこでテストを失敗したのかが分かる方がいい。
  • 各パッケージやカテゴリごとにグループ分けして、該当するテストコードをまとめることができる。
  • 例えばこちらのGitHubリポジトリにある全Juliaコードに対するテストコードは下記のようになる。

github.com

module TestsRunner
    # packages
    using Test

    # external modules
    include("../HelloWorld/test_hello_world.jl")
    include("../Module/test_module.jl")
    include("../Operations/test_operations.jl")
    include("../ComplexNumber/test_complex_number.jl")
    include("../String/test_string.jl")
    include("../Control/test_control.jl")
    include("../Type/test_types.jl")
    include("../Collection/test_collection.jl")
    include("../Array/test_array.jl")
    include("../Plots/test_plots.jl")
    include("../DataFrames/test_df.jl")
    include("../StacksAndQueues/test_stack_queue.jl")
    include("../Trees/test_trees.jl")
    include("../LinkedLists/test_linked_lists.jl")

    # methods
    function tests()
        @testset "JuliaPractice" begin
            TestHelloWorld.test()
            TestModule.test()
            TestOperations.test()
            TestCompNum.test()
            TestString.test()
            TestControl.test()
            TestTypes.test()
            TestCollection.test()
            TestArray.test()
            TestPlots.test()
            TestDf.test()
            TestStackQueue.test()
            TestTrees.test()
            TestLinkedLists.test()
        end
    end
end

if abspath(PROGRAM_FILE) == @__FILE__
    using .TestsRunner
    @time TestsRunner.tests()
end
  • 全てのテストコードを"JuliaPractice"というグループで括っている。
  • また、それぞれのサブディレクトリにおけるテストコードを小さなグループに分けている。
  • 例えば、TestOperationsモジュールには四則演算関数の各テストコードが下記のようにまとめられている。
module TestOperations
    # packages
    using Test

    # extarnal modules
    include("add.jl")
    include("divide.jl")
    include("multiply.jl")
    include("subtract.jl")

    # method
    function test()
        @testset "Operations" begin
            @testset "Add" begin
                @test_nowarn Add.main()
            end
            @testset "Division" begin
                @test_nowarn Division.main()
            end
            @testset "Multiplication" begin
                @test_nowarn Multiplication.main()
            end
            @testset "Subtraction" begin
                @test_nowarn Subtraction.main()
            end
        end
    end
end

if abspath(PROGRAM_FILE) == @__FILE__
    using. TestOperations
    TestOperations.test()
end
  • これを実行し、全テストをクリアした場合は、下記のように全4つのテストをクリアした事が示される。
    f:id:sy4310:20200917191854p:plain

  • また、4つの内どれかでテスト失敗となった場合は、下記のように失敗箇所と発報されたエクセプションが示される。 f:id:sy4310:20200917192618p:plain

  • テスト自体は最後まで一通り実行され、下記のようにクリアしたテストと失敗したテストの集計結果を示してくれる。 f:id:sy4310:20200917192648p:plain

  • これはリポジトリにある全テストを実行して、そのうちの一部を失敗させてみた場合。 f:id:sy4310:20200917193326p:plain

@test_nowarnマクロによるユニットテスト

  • とりあえず何かしらのエラーが検出された時に失敗とするテストマクロ。
  • テスト対象コードの振る舞いを一意に定めないざっくりしたテストをしたい場合に使える。
  • 例えば下記のように特に返り値を持たないような関数のテストをしたい場合。
module DfCreator
    # packages
    using DataFrames

    # data
    data_mat = [
    1 0.179324  "A"
    2 0.818923  "B"
    3 0.979487  "C"
    4 0.882494  "A"
    5 0.0530208 "B"
    ]

    # methods
    function mat_2_df()
        df_mat = DataFrame(data_mat)
        println("DataFrame from Matrix = ", df_mat)
        println("")
    end

    function add_row_name()
        df_rn = DataFrame(data_mat, [:A, :B, :C])
        println("DataFrame + Row name = ", df_rn)
        println("")
    end

    function rv_2_df()
        df_rv = DataFrame(
            A = 1:5,
            B = rand(5),
            C = ["A", "B", "C", "A", "B"])
        println("DataFrame from row vector = ", df_rv)
        println("")
    end

    function dict_2_df()
        dict = Dict(
            :A => 1:5,
            :B => rand(5),
            :C => ["A", "B", "C", "A", "B"])
        df_dict = DataFrame(dict)
        println("DataFrame from dictionary = ", df_dict)
        println("")
    end

    function main()
        mat_2_df()
        add_row_name()
        rv_2_df()
        dict_2_df()
    end
end

if abspath(PROGRAM_FILE) == @__FILE__
    using .DfCreator
    DfCreator.main()
end
  • あるいは、下記のようにグラフの描画をするコードのテストにも使える。
module Plots2D
    # packages
    using PyPlot: plt

    # flag to switch show or not
    show_plot = true

    function set_show_plot(flag)
        global show_plot = flag
    end

    # method
    function plot_line()
        x = range(0, 2pi, length=100)
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(x, sin.(x), "-", c="b", ms=10)
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_title("2D Line Plot Sample")
        ax.grid(true)
    end

    function plot_point()
        x = range(0, 2pi, length=100)
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(x, cos.(x), ".", c="b", ms=10)
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_title("2D Point Plot Sample")
        ax.grid(true)
    end

    function plot_over()
        x = range(0, 2pi, length=100)
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(x, sin.(x), "-", c="b", ms=10, label="sin")
        ax.plot(x, cos.(x), ".", c="r", ms=10, label="cos")
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_title("2D Over Plot Sample")
        ax.grid(true)
        ax.legend()
    end

    function plot_scatter()
        x = rand(100)
        y = x.^2 + randn(100).* abs.(x) * 0.5
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.scatter(x, y, marker="*", s=13)
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_title("2D Scatter Sample")
        ax.grid(true)
    end

    function plot_histogram()
        x = vcat(randn(100), 2 * randn(100).+ 4)
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.hist(x, bins=30, color="g")
        ax.set_title("Histgram Sample")
        ax.grid(true)
    end

    function plot_pie_chart()
        x = 10:10:40
        labels = ["A", "B", "C", "D"]
        colors = ["r", "g", "b", "m"]
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.pie(x, labels=labels, colors=colors, autopct="%1.1f%%", counterclock=true, startangle=90)
        ax.legend(labels, fontsize=12, loc=1)
        ax.set_title("Pie Chart Sample")
    end

    function main()
        plot_line()
        plot_point()
        plot_over()
        plot_scatter()
        plot_histogram()
        plot_pie_chart()
        if show_plot == true
            plt.show()
        end
    end
end

if abspath(PROGRAM_FILE) == @__FILE__
    using .Plots2D
    Plots2D.main()
end
  • グラフを表示させるとそこでテストが止まってしまうので、テストコード実行時にはグラフ表示だけは行わないようにさせる工夫をする。
  • グラフ表示するか否かを切り替えるフラグのSetterを用意し、テストコード側からそれを通じて表示フラグをFalseにしている。
module TestPlots
    # packages
    using Test

    # external modules
    include("plots_2d.jl")
    include("plots_3d.jl")
    include("subplot_2d.jl")
    include("animation_2d.jl")

    # methods
    function test()
        @testset "Plots" begin
            @testset "Plots2D" begin
                Plots2D.set_show_plot(false)
                @test_nowarn Plots2D.main()
            end
            @testset "Plots3D" begin
                Plots3D.set_show_plot(false)
                @test_nowarn Plots3D.main()
            end
            @testset "SubPlot2D" begin
                SubPlot2D.set_show_plot(false)
                @test_nowarn SubPlot2D.main()
            end
            @testset "Anime2D" begin
                Anime2D.set_show_plot(false)
                @test_nowarn Anime2D.main()
            end
        end
    end
end

if abspath(PROGRAM_FILE) == @__FILE__
    using .TestPlots
    TestPlots.test()
end

GitHub

  • 紹介したテストコードは全て下記リポジトリで公開済み。 github.com