pystanでdivergenceだったりmax_tree_depthなサンプルを観察

お急ぎな方向け説明

pystanでdivergenceなパラメータだとかmax_tree_depthに陥ってるパラメータってどんなのか気になる事ありますよね。そんなときは以下のようにして、抽出できます。

def get_divert_params(params):
    #get_sample_params["divertgent__"]にdivertしたかどうかのフラグが入っている
    #ただしこれはwarmupも含めた結果なので、warmup分をさっ引いて拾ってくる
    #chain毎に存在するものは後が楽になるように結合しておく
    divert=np.array([chain["divergent__"][fit.sim["warmup"]:] for chain in  fit.get_sampler_params()  ]).transpose().reshape(-1)
    didx=np.nonzero(divert)[0]
    #サンプルを取得するがpermutedをFalseにして、シャッフルされる前のデータを取ってくる
    #そうしないと、上のdivert情報と一致しない。
    p=fit.extract(params,permuted=False)
    ret={}
    #divertな要素のみを切りだす(didx)
    for pi in params:
        yy=p[pi]
        yy=yy.reshape( (-1, *yy.shape[2::] ))[didx]
        ret[pi]=yy
    return ret
def get_params_at_treedepth(params,depth):
    #get_sample_params["treedepth__"]にNUTS探索時のdepthが入ってる
    d=np.array([chain["treedepth__"][fit.sim["warmup"]:] for chain in  fit.get_sampler_params()  ]).transpose().reshape(-1)
    didx=np.nonzero(d==depth)[0]
    p=fit.extract(params,permuted=False)
    ret={}
    for pi in params:
        yy=p[pi]
        yy=yy.reshape( (-1, *yy.shape[2::] ))[didx]
        ret[pi]=yy
    return ret

何がしたいのか

以前HMCの挙動ってこんな風になってるよ。っていう説明をしました。尤度の曲率が急激に急になって、MCMCサンプラーのleapFrog部分がぶっ飛ぶのがdivergenceでした。そしてdivergenceが起きるほど急激な変化はないけれども、曲率が急な部分と緩やかな部分の差が大きい時に緩やかな部分で多量のステップ数を消費するのがmax_tree_depthの警告でした。

そういった警告の様子を直接可視化してみたいです。

Divergenceを見る普通の方法

普通にやるには、単純にarvizを使ってtrace_plotやpaire_plotをしてあげると各パラメータがどんな値を取るときにdivergenceを起こす値が観察できます。

例えば上のtrace_plotの黒いちょびひげみたいなマークがdivergenceを意味しています。ただ、関数のパラメータを可視化しても、その時にどんな関数だったか良く分からないですよね。そんなときは手動で可視化してあげる必要があります。

Divergentなサンプルだけを抽出

samplingしたときに帰ってくるfitにはサンプル中の様々な記録が残されています。この中で以下の値の中に、チェイン毎のdivergent(発散)したサンプルであったかが0/1で記録されています。

fit.get_sampler_params()[“divergent__”]

この値が1のサンプルを拾えばいいのですが、注意が2つあります。一つはこの変数がwarmup時の記録も持っているので、普通はそれを捨ててwarmup後の挙動のみを見る必要があります。2つめは、fit.extract()で得られるサンプルはランダムシャッフルされてるという事です。 シャッフルされる前の値をとるにはpermutate=Falseを指定する必要があります。そこにさえ気をつければ、簡単に値が拾って来れます。

以下は以前紹介した、赤色のノイズ混じりのガウス関数に事前分布なしでフィットさせた例です。

発散したサンプルのみをぷろっとしてみると、ちゃんとフィット出来てないサンプルのみで有る事が解ります。

Tree depth毎に可視化する

tree depth毎に可視化する方法はarvizには用意されていないようです。(みつけれてないだけ?)ということでこれも上の方法と同様に手動で切りだします。使ったtree_depthは

fit.get_sampler_params()[“treedepth__”]

に記録されています。今回の例ではあまり面白く無くて、tree-depth 0~9までは特に全体をやったものとの差はありません。

一方でmax_tree_depthである10のときはフィット出来ていないものが多いです。

単純に正解から非常に遠いところでは尤度の曲率も緩やかになる場合もあってそれが全部max_tree_depthのところに集まった感じです

非常に重い裾での例

また以前の説明でもでてきたような非常に裾の重い分布の場合をためしてみました。

中心に集まりがあるものの、その他の部分はほぼ一様分布じゃないか?ってぐらいに平らな分布担っています。こういう分布をtree_depthごとにみていくと、tree_depthが小さいところでは中心部分を主にサンプルしていて、depthが大きいときに周辺も見ている事がよくわかります。

途中省略

尤度の曲率変化と軌道の長さ(=tree_depth)の関係が直感的にも分かりやすくて面白いですよね。

ソース

ソースは以下にあります。Githubはipynbファイルをスグには表示してくれない事が多いので、気長にリロードしてみてください。

https://github.com/akirayou/stan_example/blob/master/divergence_and_tree.ipynb