ifelse() と関数の分離による高速化 -- Base.randn() を題材にして --

JuliaLang Advent Calendar 2014 の 18 日目の記事です。遅れてすみません。 今回も速度向上の記事です。

概要

Base.randn() に最近なされた速度向上のための更新(#9126, #9132)を通じて、 ifelse() 関数や関数分離の有用性を示します。

サンプルコード

https://gist.github.com/yomichi/62d5f121ab11831b0759

Base.randn() について

正規分布はホワイトノイズ の分布として誤差解析によく使われ、 中心極限定理によって平均値の分布にもなり、 さらには拡散方程式の解として直接現れるなど、 自然科学や工学の分野では一様分布と同じぐらい重要な分布となっています。 また、手で積分の計算ができるほどに非常に性質が良いため、 様々なランダム性を含む理論モデルで、そのランダム性の分布として使われています。 モンテカルロ法などの数値シミュレーションにもよく使われるので、 正規分布に従うガウス乱数を高速に生成することはとても大事になります。

Base.randn() は 平均 μ = 0, 分散 σ^2 = 1正規分布 N(0,1) に従うガウス乱数を生成する関数です。 任意の平均、分散のガウス乱数を使いたい場合、μ + σ * randn() で変換できます。

ガウス乱数を生成するアルゴリズムとして、Box-Muller 法が(簡単なので)有名ですが、 Julia ではより洗練されて高速な Ziggurat 法 が用いられています *1四辻「計算機シミュレーションのための確率分布乱数生成法」によると*2、 Ziggurat 法は Box-Muller のおよそ5-6 倍速いようです。 アルゴリズムの詳細は本記事の範囲ではないので、 興味のある方はWikipedia や原著論文、上記教科書を参照してください。 もちろん、Julia の実装を読み解くというのもありだと思います:D

最近、#9126#9132 によって、 Base.randn() が更新されました。 これは恐ろしく単純な変更なのですが、 環境にもよりますが、2倍近くまで加速しています。 ここで使われた考えは決して Base.randn() に特有なものではなく、 普段のパフォーマンスチューニングにも役立つものとなっているので*3、本記事ではそれを紹介します。

初期バージョン

Ziggurat 法には必要な定数や関数がいくつかあるのですが、 あまりにも多いので Base.Random モジュールの中から拝借してきます *4

## 必要なパラメータや関数を持ってくる
import Base.Random: GLOBAL_RNG, rand_ui52
import Base.Random: ki, wi, fi, ke, we, fe  # それぞれ256 個のFloat64 の配列
import Base.Random: ziggurat_nor_r, ziggurat_nor_inv_r, ziggurat_exp_r

さて、加速する前の最初のバージョンは次のとおりです。

## 初期バージョン
@inline function randn_first(rng::AbstractRNG=GLOBAL_RNG)
    @inbounds begin
        r = rand_ui52(rng)
        rabs = int64(r>>1)
        idx = rabs & 0xFF

        # Point 1
        # 分岐命令
        x = (r % Bool ? -rabs : rabs)*wi[idx+1]

        rabs < ki[idx+1] && return x # 99.3% の確率で、ここでreturn

        # Point 2
        # ここに来る確率は 0.7 %
        @inbounds if idx == 0
          while true
            xx = -ziggurat_nor_inv_r*log(rand(rng))
            yy = -log(rand(rng))
            yy+yy > xx*xx && return (rabs >> 8) % Bool ? -ziggurat_nor_r-xx : ziggurat_nor_r+xx
          end
        elseif (fi[idx] - fi[idx+1])*rand(rng) + fi[idx+1] < exp(-0.5*x*x)
          return x # return from the triangular area
        else
          return randn_first(rng)
        end
    end
end

アルゴリズムの詳細は重要ではないので述べません。 この randn_first()N = 10000*10000 回呼び出すのにかかる時間を計測すると、

julia> @time for i in 1:N; randn_first(); end
elapsed time: 1.835435398 seconds (0 bytes allocated)

となりました*5。 加速にあたって改造する場所は2箇所、Point 1 とPoint 2 です。

Point 1 : ifelse(c :: Bool, x, y)

まずPoint 1 ですが、ここでは三項条件演算子による分岐が生じます *6。 ここを次のように ifelse(c :: Bool ,x, y) で 書き換えます。

## Point 1:
## 分岐命令(if ... end / ?: ) をifelse() 関数に置き換え

@inline function randn_ifelse(rng::AbstractRNG=GLOBAL_RNG)
    @inbounds begin
        r = rand_ui52(rng)
        rabs = int64(r>>1)
        idx = rabs & 0xFF

        # Point 1
        # ifelse() 関数(分岐命令なしで値が選択される!)
        x = ifelse(r % Bool, -rabs, rabs)*wi[idx+1]

        rabs < ki[idx+1] && return x # 99.3% の確率で、ここでreturn

        @inbounds if idx == 0
          while true
            xx = -ziggurat_nor_inv_r*log(rand(rng))
            yy = -log(rand(rng))
            yy+yy > xx*xx && return (rabs >> 8) % Bool ? -ziggurat_nor_r-xx : ziggurat_nor_r+xx
          end
        elseif (fi[idx] - fi[idx+1])*rand(rng) + fi[idx+1] < exp(-0.5*x*x)
          return x # return from the triangular area
        else
          return randn_ifelse(rng)
        end
    end
end

これについて計算時間を測定すると、

julia> @time for i in 1:N; randn_ifelse(); end
elapsed time: 1.1361549 seconds (0 bytes allocated)

と、たったの1行を書き換えただけで、およそ6 割程度も速くなりました。

今回使った ifelse(c::Bool, x, y) 関数 は、基本的には if ... else ... end?: を再現した関数で、 c == true ならば x を、 c == false ならば y を返すのですが、

  1. 関数なので xy はあらかじめ評価される
  2. LLVM IR における select instructionコンパイルされる

という2点が他の構文と大きく異なります。 特に2つ目が非常に重要で、これにより ifelse 関数は(LLVM IR のレベルで)分岐命令を生成しなくなり、 特に最内ループで使うことで性能が向上する 可能性 があります。 これらの特徴は、生成されたLLVM IR を実際に眺めれば確認できて、 まず条件構文は

julia> condition_operator(N) = iseven(N) ? div(N,2) : 3N+1
condition_operator (generic function with 1 method)

julia> code_llvm(condition_operator, (Int,) )

define i64 @julia_condition_operator_64603(i64) {
top:
  %1 = and i64 %0, 1, !dbg !577
  %2 = icmp eq i64 %1, 0, !dbg !577
  br i1 %2, label %pass2, label %L, !dbg !577

pass2:                                            ; preds = %top
  %3 = sdiv i64 %0, 2, !dbg !577
  ret i64 %3, !dbg !577

L:                                                ; preds = %top
  %4 = mul i64 %0, 3, !dbg !577
  %5 = add i64 %4, 1, !dbg !577
  ret i64 %5, !dbg !577
}

となります。br が分岐命令で、変数 %2 の真偽で %pass2:%L: へとジャンプします。 返り値はジャンプしてから計算していることがわかりますね。 一方のifelse() 関数は

julia> condition_ifelse(N) = ifelse(iseven(N), div(N,2), 3N+1)
condition_ifelse (generic function with 1 method)

julia> code_llvm(condition_ifelse, (Int,) )

define i64 @julia_condition_ifelse_64604(i64) {
pass2:
  %1 = mul i64 %0, 3, !dbg !580
  %2 = add i64 %1, 1, !dbg !580
  %3 = and i64 %0, 1, !dbg !580
  %4 = icmp ne i64 %3, 0, !dbg !580
  %5 = sdiv i64 %0, 2, !dbg !580
  %6 = select i1 %4, i64 %2, i64 %5, !dbg !580
  ret i64 %6, !dbg !580
}

となります。こちらでは先に条件の真偽値 %4 と2通りの返り値%2%5 とを計算しておいて、 select にこれらを渡して結果を返しています *7

ifelse を使うと速くなるケースが割とあるのですが、 xy の両方の式をかならず評価しないといけないため、 稀にしか選ばれない方の式が重い場合、 例えばPoint 2 で示すような rand() < 0.993 ? x : heavy_function(x) などでは、逆に遅くなります。 また、再帰関数で使うとスタックオーバーフローで落ちますし、 さらには前段の条件によってはどちらかでエラーを吐くような場合、 例えば isdefined(:x) ? x : 0 なんかはifelse の呼び出しにすらたどりつかずにエラーで落ちます。 何事も適材適所があるのです。。。

Point 2 : 関数の分離

次にPoint 2 ですが、この直前にある

rabs < ki[idx+1] && return x # 99.3% の確率で、ここでreturn

のおかげで、この部分にはおよそ0.7 % という小さな確率でしか到達しません *8。 しかしながら、この稀にしか到達しない分岐には、まだまだ長い処理が待っています *9。 こういう時には、長い部分を別の関数に分けて、関数呼び出しに置き換えることで、 性能が向上する 場合が あります。

## Point 2:
## およそ0.7% でしか到達しないのにやたら長い部分を
## 別関数に分離

@inline function randn_separate(rng::AbstractRNG=GLOBAL_RNG)
    @inbounds begin
        r = rand_ui52(rng)
        rabs = int64(r>>1)
        idx = rabs & 0xFF
        x = (r % Bool ? -rabs : rabs)*wi[idx+1]
        rabs < ki[idx+1] && return x # 99.3% の確率で、ここでreturn

        # Point 2
        # ここに来る確率は0.7 %
        # 関数呼び出しに置き換えて関数の長さを減らす
        return separate_unlikely(rng, idx, rabs, x)
    end
end

# Point 2
# たまにしか使われないくせに長い部分
function separate_unlikely(rng, idx, rabs, x)
    @inbounds if idx == 0
        while true
            xx = -ziggurat_nor_inv_r*log(rand(rng))
            yy = -log(rand(rng))
            yy+yy > xx*xx && return (rabs >> 8) % Bool ? -ziggurat_nor_r-xx : ziggurat_nor_r+xx
        end
    elseif (fi[idx] - fi[idx+1])*rand(rng) + fi[idx+1] < exp(-0.5*x*x)
        return x # return from the triangular area
    else
        return randn_separate(rng)
    end
end

こいつの実行時間を計測すると

julia> @time for i in 1:N; randn_separate(); end
elapsed time: 1.629525055 seconds (0 bytes allocated)

となり、最初のバージョンから比べて1割程度速くなります *10

Point 1+2 : 現在のrandn()

上記2つの加速を合わせたのが次のコードで、実際にBase.randn() として使われているものになります。

## 20141218 現在のrandn()
## Point 1+2
@inline function randn_final(rng::AbstractRNG=GLOBAL_RNG)
    @inbounds begin
        r = rand_ui52(rng)
        rabs = int64(r>>1)
        idx = rabs & 0xFF

        # Point 1
        x = ifelse(r % Bool, -rabs, rabs)*wi[idx+1]

        rabs < ki[idx+1] && return x # 99.3% の確率で、ここでreturn

        # Point 2
        return randn_unlikely(rng, idx, rabs, x)
    end
end

function randn_unlikely(rng, idx, rabs, x)
    @inbounds if idx == 0
        while true
            xx = -ziggurat_nor_inv_r*log(rand(rng))
            yy = -log(rand(rng))
            yy+yy > xx*xx && return (rabs >> 8) % Bool ? -ziggurat_nor_r-xx : ziggurat_nor_r+xx
        end
    elseif (fi[idx] - fi[idx+1])*rand(rng) + fi[idx+1] < exp(-0.5*x*x)
        return x # return from the triangular area
    else
        return randn_final(rng)
    end
end

実行時間は

julia> @time for i in 1:N; randn_final(); end
elapsed time: 1.014320666 seconds (0 bytes allocated)

となり、最初と比べるとおよそ8割、2倍近くまで速度が向上します。

まとめ

ifelse() 関数による分岐の書き換えはうまくハマるとかなり速くなります。 もちろん大した効果の得られない時もありますし、場合によっては 遅くなることもありますが、 パフォーマンス測定をしながら程々に使っていくとよいでしょう。

関数の分解をすることで可読性や保守性、再利用性が高まりますが、 今回のように速度面でも有利になることがあるので、 がんがん切り分けていきましょう。 もちろん、あまりやり過ぎると可読性や保守性が落ちるし 関数呼び出しのオーバーヘッドも現れうる*11ので、 やはりパフォーマンス測定をしつつ、何事も程々に*12

*1:ちなみに他の言語やライブラリでは、GSL (C) やnumpy (Python) がBox-Muller, Boost (C++) がZiggurat でした。

*2:どこもかしこも在庫切れのようですが。。。

*3:それと単純に感動したので

*4:これらはexport されていないので、いつの日か名前が変わったり別のものになったりなくなったりするかもしれません。あまりマネしないようにしましょう:P

*5:計算時間は同時に走っている関係ないプロセスから かなり影響を受けるため、本当に正確に測定するならば、 できるだけ他のプロセスを殺してから測定したり、 その上で何度も測定して最小値を取ったりする必要があります。 ただし、今回はそこまでやっていないので、 速度向上の倍率などは参考程度に捉えてください。

*6: 大多数の人は見慣れないかと思いますが、 r % Bool isodd(r) と同じです。 生成されるLLVM IR も全く同じです。

*7:アセンブラは私がよくわかっていないので見ないことにします(ぇー

*8: ここでは && 演算子によるショートカットを利用しています。 生成されるLLVM IR は if 構文と同じで、 標準ライブラリにおいては一行でif 構文を書くのに好んで用いられるようです。

*9:つまり、Point 1 のようにifelse 関数を使うと確実に遅くなります。 再帰っぽいですが、再帰するかしないかは乱数で決まるので、一応帰ってきます。

*10:結構揺らぎますが。。。

*11:一応LLVM は自動インライン化もあるようですが

*12:大事なことなので2回言いました