Extracting and visualizing tidy draws from rstanarm models

Matthew Kay

2022-12-16

イントロダクション

この vignette は、tidybayesggdist パッケージを使用して、 tidy データフレームを抽出し、可視化する方法について説明する。モデル変数の事後分布からのドロー、平均、および rstanarm からの予測値である。より一般的な tidybayes の紹介と、汎用のベイズモデリング言語 (Stan や JAGS など) での使用については、vignette("tidybayes") を参照。

セットアップ

このvignetteを実行するには、以下のライブラリが必要である。

library(magrittr)
library(dplyr)
library(purrr)
library(forcats)
library(tidyr)
library(modelr)
library(ggdist)
library(tidybayes)
library(ggplot2)
library(cowplot)
library(rstan)
library(rstanarm)
library(RColorBrewer)

theme_set(theme_tidybayes() + panel_border())

これらのオプションは、Stanの動作を高速化するためのものである。

rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

サンプルデータセット

tidybayes を実証するために、5つの条件からそれぞれ10個のオブザベーションを持つ単純なデータセットを使用することにする。

set.seed(5)
n = 10
n_condition = 5
ABC =
  tibble(
    condition = rep(c("A","B","C","D","E"), n),
    response = rnorm(n * 5, c(0,1,2,1,-1), 0.5)
  )

データのスナップショットはこのような感じである。

head(ABC, 10)
condition response
A -0.4204277
B 1.6921797
C 1.3722541
D 1.0350714
E -0.1442796
A -0.3014540
B 0.7639168
C 1.6823143
D 0.8571132
E -0.9309459

これは典型的な整頓されたフォーマットのデータフレーム: 1行に1つの観測値である。グラフィカルに

ABC %>%
  ggplot(aes(y = condition, x = response)) +
  geom_point()

#モデル {#model}

大域平均に向かって収縮する階層的モデルを当てはめよう。

m = stan_lmer(response ~ (1|condition), data = ABC, 
  prior = normal(0, 1, autoscale = FALSE),
  prior_aux = student_t(3, 0, 1, autoscale = FALSE),
  adapt_delta = .99)

結果はこのようになる。

m
## stan_lmer
##  family:       gaussian [identity]
##  formula:      response ~ (1 | condition)
##  observations: 50
## ------
##             Median MAD_SD
## (Intercept) 0.6    0.5   
## 
## Auxiliary parameter(s):
##       Median MAD_SD
## sigma 0.6    0.1   
## 
## Error terms:
##  Groups    Name        Std.Dev.
##  condition (Intercept) 1.13    
##  Residual              0.56    
## Num. levels: condition 5 
## 
## ------
## * For help interpreting the printed output see ?print.stanreg
## * For info on the priors used see ?prior_summary.stanreg

spread_draws を使ってフィットからドローをtidyフォーマットで抽出する

さて、結果が出たので、楽しいことが始まります:整頓されたフォーマットで描画を取り出すことである!まず、 関数を使って、生のモデル変数名のリストを取得する。まず、get_variables() 関数を使って、生のモデル変数名のリストを取得し、モデルからどの変数を抽出できるかを知ることにする。

get_variables(m)
##  [1] "(Intercept)"                              "b[(Intercept) condition:A]"              
##  [3] "b[(Intercept) condition:B]"               "b[(Intercept) condition:C]"              
##  [5] "b[(Intercept) condition:D]"               "b[(Intercept) condition:E]"              
##  [7] "sigma"                                    "Sigma[condition:(Intercept),(Intercept)]"
##  [9] "accept_stat__"                            "stepsize__"                              
## [11] "treedepth__"                              "n_leapfrog__"                            
## [13] "divergent__"                              "energy__"

ここで、(Intercept) は大域的な平均値、b のパラメータは各条件におけるその平均値からのオフセットである。これらのパラメータが与えられると

各行が b [(Intercept) condition:A] , b [(Intercept) condition:B] , ...:C] , ...:D] , or ...:E] の どちらか一方からの draw であるデータフレームが必要だろう。 , そして、行がどの chain/iteration/draw から来たか、そしてどの条件 ( A to E ) のものであるかのインデックスを持つ列があるデータフレームが欲しいと思うだろう。これにより、条件ごとにグループ化された量を簡単に計算したり、ggplotを使用して条件ごとにプロットを生成したり、あるいはドローとオリジナルデータをマージしてデータと後置をプロットしたりすることができるようになる。

tidybayes の主力は spread_draws() 関数で、この関数はこの抽出を行ってくれる。この関数には、モデル変数とそのインデックスを整頓されたデータフレームに抽出するために使用できる簡単な仕様書式が含まれている。

整頓されたフォーマットのデータフレームで、変数のインデックスを別の列に集める

このようなパラメータが与えられると

b [(Intercept) condition:D]

このような列の仕様で spread_draws() を提供することができる。

b [term,group]

ここで、term(Intercept) に、groupcondition:D に対応する。この指定に対して spread_draws() が行うことについては、特に不思議なことはない。この仕様では、パラメータのインデックスをカンマとスペースで分割している ( sep 引数を変更すれば、他の文字で分割することもできる)。そして、その結果得られたインデックスに、順番に列を割り当てることができるのである。つまり b [(Intercept) condition:D] はインデックス (Intercept)condition:D を持ち、spread_draws() はこれらのインデックスを列として抽出し、b から得られるドローのデータフレームを整頓することができる。

m %>%
  spread_draws(b[term,group]) %>%
  head(10)
## Warning: `gather_()` was deprecated in tidyr 1.2.0.
## ℹ Please use `gather()` instead.
## ℹ The deprecated feature was likely used in the tidybayes package.
##   Please report the issue at <]8;;https://github.com/mjskay/tidybayes/issues/newhttps://github.com/mjskay/tidybayes/issues/new]8;;>.
term group b .chain .iteration .draw
(Intercept) condition:A -1.2677915 1 1 1
(Intercept) condition:A -0.2132916 1 2 2
(Intercept) condition:A 0.0192337 1 3 3
(Intercept) condition:A -0.4226583 1 4 4
(Intercept) condition:A -0.4007204 1 5 5
(Intercept) condition:A 0.0275432 1 6 6
(Intercept) condition:A -0.3184563 1 7 7
(Intercept) condition:A -0.4145287 1 8 8
(Intercept) condition:A -0.5637111 1 9 9
(Intercept) condition:A -1.1462066 1 10 10

インデックス列には好きな名前を付けることができる。

m %>%
  spread_draws(b[t,g]) %>%
  head(10)
t g b .chain .iteration .draw
(Intercept) condition:A -1.2677915 1 1 1
(Intercept) condition:A -0.2132916 1 2 2
(Intercept) condition:A 0.0192337 1 3 3
(Intercept) condition:A -0.4226583 1 4 4
(Intercept) condition:A -0.4007204 1 5 5
(Intercept) condition:A 0.0275432 1 6 6
(Intercept) condition:A -0.3184563 1 7 7
(Intercept) condition:A -0.4145287 1 8 8
(Intercept) condition:A -0.5637111 1 9 9
(Intercept) condition:A -1.1462066 1 10 10

しかし、前の例のような、より説明的で暗号化されていない名前の方が望ましいだろう。

このモデルでは、項は1つしかない( (Intercept) )ので、そのインデックスを完全に省略して、各項目 group と対応する条件の b の値だけを取得することができる。

m %>%
  spread_draws(b[,group]) %>%
  head(10)
group b .chain .iteration .draw
condition:A -1.2677915 1 1 1
condition:A -0.2132916 1 2 2
condition:A 0.0192337 1 3 3
condition:A -0.4226583 1 4 4
condition:A -0.4007204 1 5 5
condition:A 0.0275432 1 6 6
condition:A -0.3184563 1 7 7
condition:A -0.4145287 1 8 8
condition:A -0.5637111 1 9 9
condition:A -1.1462066 1 10 10

この場合、すべてのグループが condition ファクターのものなので、対応する条件 ( A , B , C , etc) を含むだけの列を分離することもできる。これは、tidyr::separate を使って行うことができる。

m %>% 
  spread_draws(b[,group]) %>%
  separate(group, c("group", "condition"), ":") %>%
  head(10)
group condition b .chain .iteration .draw
condition A -1.2677915 1 1 1
condition A -0.2132916 1 2 2
condition A 0.0192337 1 3 3
condition A -0.4226583 1 4 4
condition A -0.4007204 1 5 5
condition A 0.0275432 1 6 6
condition A -0.3184563 1 7 7
condition A -0.4145287 1 8 8
condition A -0.5637111 1 9 9
condition A -1.1462066 1 10 10

あるいは、sep の引数を spread_draws() に変更して、: でも分割することもできる ( sep は正規表現)。Note: この例ではうまくいくが、マルチレベル・モデルで因子間の相互作用がグループ化レベルとして使われるrstanarmモデルではうまくいかないだろう。したがって、: はデフォルトの分離子に含まれない。

m %>% 
  spread_draws(b[,group,condition], sep = "[, :]") %>%
  head(10)
group condition b .chain .iteration .draw
condition A -1.2677915 1 1 1
condition A -0.2132916 1 2 2
condition A 0.0192337 1 3 3
condition A -0.4226583 1 4 4
condition A -0.4007204 1 5 5
condition A 0.0275432 1 6 6
condition A -0.3184563 1 7 7
condition A -0.4145287 1 8 8
condition A -0.5637111 1 9 9
condition A -1.1462066 1 10 10

注: spread_draws() を Stan または JAGS の生のサンプルで使用したことがある場合、spread_draws() の前に recover_types() を使用してインデックス列の値を取得することに慣れているだろう (例えば、インデックスが要因であった場合) 。rstanarm モデルで spread_draws() を使用する場合、これらのモデルにはすでにその情報が変数名として含まれているため、この操作は必要ない。recover_types() の詳細については、vignette("tidybayes") を参照。

ポイント要約と間隔

単純なモデル変数の場合

tidybayes は、描画から点の要約や区間を整然としたフォーマットで生成するための関数群を提供する。これらの関数は、以下の命名規則に従っている。 [median|mean|mode] _ [qi|hdi] 例えば、median_qi() , mean_qi() , mode_hdi() , などである。最初の名前 ( _ の前) は点要約の種類を示し、2番目の名前は区間の種類を示す。qi は分位点間隔 (別名:等値線間隔、中心間隔、パーセンタイル間隔) 、hdi は最高 (後) 密度間隔を生成する。カスタム点または区間関数は、point_interval() 関数を用いて適用することもできる。

例えば、オブザベーションの全体平均と標準偏差の事後分布に対応するドローを抽出することができる。

m %>%
  spread_draws(`(Intercept)`, sigma) %>%
  head(10)
.chain .iteration .draw (Intercept) sigma
1 1 1 1.4707274 0.7488978
1 2 2 0.2947973 0.5082004
1 3 3 0.3301366 0.5640631
1 4 4 0.5871531 0.5641068
1 5 5 0.5221652 0.5328644
1 6 6 0.6041039 0.5441214
1 7 7 0.3356784 0.6227631
1 8 8 0.5821142 0.5585193
1 9 9 1.0652540 0.6601671
1 10 10 1.1857010 0.5061982

と同様に b [term,group] と同様、整頓されたデータフレームが得られる。変数の中央値と95%分位間隔が欲しい場合は、median_qi() を適用する。

m %>%
  spread_draws(`(Intercept)`, sigma) %>%
  median_qi(`(Intercept)`, sigma)
(Intercept) (Intercept).lower (Intercept).upper sigma sigma.lower sigma.upper .width .point .interval
0.6086139 -0.4855256 1.58487 0.5606439 0.4568133 0.6975851 0.95 median qi

上記のように、中央値と区間を取得したい列を指定することもできるが、列のリストを省略すると、median_qi() は、グループ化列や特殊列 ( .chain , .iteration , .draw など) でないすべての列を使用する。したがって、上記の例では、(Intercept)sigma は、モデルから収集した唯一の列でもあるため、median_qi() の冗長な引数となっている。そこで、次のように単純化することができる。

m %>%
  spread_draws(`(Intercept)`, sigma) %>%
  median_qi()
(Intercept) (Intercept).lower (Intercept).upper sigma sigma.lower sigma.upper .width .point .interval
0.6086139 -0.4855256 1.58487 0.5606439 0.4568133 0.6975851 0.95 median qi

もし、長い形式のインターバルリストが欲しい場合は、代わりに gather_draws() を使ってみよう。

m %>%
  gather_draws(`(Intercept)`, sigma) %>%
  median_qi()
.variable .value .lower .upper .width .point .interval
(Intercept) 0.6086139 -0.4855256 1.5848700 0.95 median qi
sigma 0.5606439 0.4568133 0.6975851 0.95 median qi

gather_draws() の詳細については、vignette("tidybayes") を参照。

インデックス付き変数の場合

b のような1つ以上のインデックスを持つモデル変数があるとき、前と同じように median_qi() (または point_interval() ファミリーの他の関数) を適用することができる。

m %>%
  spread_draws(b[,group]) %>%
  median_qi()
group b .lower .upper .width .point .interval
condition:A -0.4099501 -1.4241652 0.6630132 0.95 median qi
condition:B 0.3853866 -0.6377728 1.4480079 0.95 median qi
condition:C 1.2029809 0.2439749 2.3109965 0.95 median qi
condition:D 0.4012280 -0.5931506 1.4861856 0.95 median qi
condition:E -1.4849292 -2.5190222 -0.4131681 0.95 median qi

median_qi() はどのようにして集計する対象を知ったのだろうか? spread_draws() によって返されたデータフレームは、あなたが渡したすべてのインデックス変数によって自動的にグループ化される。この場合、spread_draws() はその結果を group によってグループ化することを意味する。median_qi() はこれらのグループを尊重し、すべてのグループ内のポイントの要約と間隔を計算する。そして、median_qi() には列が渡されなかったので、唯一の非特殊列 ( . -prefixed) 、非グループ列である b に対して処理を行う。したがって、上記の短縮された構文は、より冗長なこの呼び出しと等価である。

m %>%
  spread_draws(b[,group]) %>%
  group_by(group) %>%       # this line not necessary (done by spread_draws)
  median_qi(b)                # b is not necessary (it is the only non-group column)
group b .lower .upper .width .point .interval
condition:A -0.4099501 -1.4241652 0.6630132 0.95 median qi
condition:B 0.3853866 -0.6377728 1.4480079 0.95 median qi
condition:C 1.2029809 0.2439749 2.3109965 0.95 median qi
condition:D 0.4012280 -0.5931506 1.4861856 0.95 median qi
condition:E -1.4849292 -2.5190222 -0.4131681 0.95 median qi

tidybayes は、posterior::summarise_draws() の実装も提供している。 グループ化されたデータフレーム ( tidybayes::summaries_draws.grouped_df() ) を作成することができる。 を使用すると、コンバージェンス診断が迅速に行える。

m %>%
  spread_draws(b[,group]) %>%
  summarise_draws()
group variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
condition:A b -0.4159312 -0.4099501 0.5129151 0.4761362 -1.2298734 0.4080505 1.002804 1013.889 1168.484
condition:B b 0.3872925 0.3853866 0.5161628 0.4713854 -0.4266826 1.2161343 1.001365 1024.497 1349.859
condition:C b 1.2192802 1.2029809 0.5133706 0.4682209 0.4278880 2.0607642 1.001481 1003.779 1212.333
condition:D b 0.4019813 0.4012280 0.5147134 0.4744561 -0.4179856 1.2377231 1.001843 1032.115 1378.726
condition:E b -1.4890782 -1.4849292 0.5169411 0.4781767 -2.3260892 -0.6672363 1.001625 1034.943 1310.946

異なるインデックスを持つ変数を1つの整頓されたフォーマットのデータフレームに結合する

spread_draws()gather_draws() は、異なるインデックスを持つ変数を同じデータフレームに抽出することをサポートしている。同じ名前のインデックスは自動的にマッチングされ、必要に応じて値が複製され、すべてのインデックスのレベルの組み合わせごとに1つの行が作成される。例えば、各条件における平均を計算したい場合がある (これを condition_mean と呼ぶ) 。このモデルでは、その平均は切片((Intercept) )と与えられた条件での効果(b)である。

(Intercept)b からの draw を一つのデータフレームにまとめることができる。

m %>% 
  spread_draws(`(Intercept)`, b[,group]) %>%
  head(10)
.chain .iteration .draw (Intercept) group b
1 1 1 1.4707274 condition:A -1.2677915
1 1 1 1.4707274 condition:B -0.1775179
1 1 1 1.4707274 condition:C 0.4365900
1 1 1 1.4707274 condition:D -0.4126736
1 1 1 1.4707274 condition:E -2.1516102
1 2 2 0.2947973 condition:A -0.2132916
1 2 2 0.2947973 condition:B 0.8856697
1 2 2 0.2947973 condition:C 1.6586191
1 2 2 0.2947973 condition:D 0.7126382
1 2 2 0.2947973 condition:E -1.2611341

各 draw の中で、(Intercept)b の各インデックスに対応するように必要に応じて繰り返される。 したがって、dplyr の mutate 関数を使用して、それらの合計 condition_mean (各条件の平均) を求めることができる。

m %>%
  spread_draws(`(Intercept)`, b[,group]) %>%
  mutate(condition_mean = `(Intercept)` + b) %>%
  median_qi(condition_mean)
group condition_mean .lower .upper .width .point .interval
condition:A 0.1987924 -0.1490205 0.5322958 0.95 median qi
condition:B 0.9948956 0.6540783 1.3561487 0.95 median qi
condition:C 1.8290092 1.4753124 2.1830999 0.95 median qi
condition:D 1.0154953 0.6644074 1.3666874 0.95 median qi
condition:E -0.8807628 -1.2302583 -0.5128632 0.95 median qi

median_qi() は整頓された評価 ( vignette("tidy-evaluation ", package =" rlang") を使うので、列名だけでなく、列の式も受け取ることができる。したがって、condition_mean の計算を mutate から median_qi() に移すことで、上の例を単純化することができる。

m %>%
  spread_draws(`(Intercept)`, b[,group]) %>%
  median_qi(condition_mean = `(Intercept)` + b)
group condition_mean .lower .upper .width .point .interval
condition:A 0.1987924 -0.1490205 0.5322958 0.95 median qi
condition:B 0.9948956 0.6540783 1.3561487 0.95 median qi
condition:C 1.8290092 1.4753124 2.1830999 0.95 median qi
condition:D 1.0154953 0.6644074 1.3666874 0.95 median qi
condition:E -0.8807628 -1.2302583 -0.5128632 0.95 median qi

複数の確率レベルを持つ区間をプロットする

median_qi() とその姉妹関数は、 引数を設定することで、任意の数の確率区間を生成することができる。 .width =

m %>%
  spread_draws(`(Intercept)`, b[,group]) %>%
  median_qi(condition_mean = `(Intercept)` + b, .width = c(.95, .8, .5))
group condition_mean .lower .upper .width .point .interval
condition:A 0.1987924 -0.1490205 0.5322958 0.95 median qi
condition:B 0.9948956 0.6540783 1.3561487 0.95 median qi
condition:C 1.8290092 1.4753124 2.1830999 0.95 median qi
condition:D 1.0154953 0.6644074 1.3666874 0.95 median qi
condition:E -0.8807628 -1.2302583 -0.5128632 0.95 median qi
condition:A 0.1987924 -0.0243302 0.4099228 0.80 median qi
condition:B 0.9948956 0.7687736 1.2267421 0.80 median qi
condition:C 1.8290092 1.5985594 2.0610815 0.80 median qi
condition:D 1.0154953 0.7813036 1.2365046 0.80 median qi
condition:E -0.8807628 -1.1069926 -0.6509841 0.80 median qi
condition:A 0.1987924 0.0816339 0.3077621 0.50 median qi
condition:B 0.9948956 0.8745806 1.1189284 0.50 median qi
condition:C 1.8290092 1.7087692 1.9516227 0.50 median qi
condition:D 1.0154953 0.8918738 1.1290154 0.50 median qi
condition:E -0.8807628 -1.0014546 -0.7577522 0.50 median qi

結果は整然とした形式で、1グループにつき1行、不確定性区間幅( .width )である。これはプロットを容易にする。たとえば、-.widthsize の美学に割り当てると、すべての区間が表示され、太い線がより小さい区間に対応するようになる。ggdist::geom_pointinterval() geomは、複数の確率レベルを持つ点のプロットを作成するために、データ中の .width 列に基づいて、size 美学を自動的に適切に設定する。

m %>%
  spread_draws(`(Intercept)`, b[,group]) %>%
  median_qi(condition_mean = `(Intercept)` + b, .width = c(.95, .66)) %>%
  ggplot(aes(y = group, x = condition_mean, xmin = .lower, xmax = .upper)) +
  geom_pointinterval()
## Warning: Using the `size` aesthietic with geom_segment was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.

密度を持つ区間

区間とともに密度を見るには、ggdist::stat_eye() (区間とバイオリンプロットを組み合わせた「アイプロット」) 、または ggdist::stat_halfeye() (区間+密度プロット) を使用することができる。

m %>%
  spread_draws(`(Intercept)`, b[,group]) %>%
  mutate(condition_mean = `(Intercept)` + b) %>%
  ggplot(aes(y = group, x = condition_mean)) +
  stat_halfeye()

あるいは、密度の一部をカラーでアノテートしたいとする。fill の美学は、ggdist::stat_halfeye() を含む ggdist::geom_slabinterval() ファミリーのすべてのジオムおよび統計において、スラブ内で変化させることができる。例えば、ドメイン固有のROPE (region of practical equivalence) をアノテートしたい場合、次のようなことが可能である。

m %>%
  spread_draws(`(Intercept)`, b[,group]) %>%
  mutate(condition_mean = `(Intercept)` + b) %>%
  ggplot(aes(y = group, x = condition_mean, fill = stat(abs(x) < .8))) +
  stat_halfeye() +
  geom_vline(xintercept = c(-.8, .8), linetype = "dashed") +
  scale_fill_manual(values = c("gray80", "skyblue"))
## Warning: `stat(abs(x) < 0.8)` was deprecated in ggplot2 3.4.0.
## ℹ Please use `after_stat(abs(x) < 0.8)` instead.

その他、分布の可視化を行う stat_slabinterval

分布を視覚化するための様々な追加統計が、ggdist::geom_slabinterval() ファミリーの統計とジオムにはある。

The slabinterval family of geoms and stats

を参照。 vignette("slabinterval ", package =" ggdist") を参照。

事後平均値と予測値

前の例のように条件付き平均を手動で計算するのではなく、rstanarm::posterior_epred() (事後予測値の期待値から事後ドローを与える;すなわち、条件付き平均の事後分布) に類似しているが、整然としたデータ形式を使用する add_epred_draws() を使用することができる。これを modelr::data_grid() と組み合わせて、まず欲しい予測を記述したグリッドを生成し、次にそのグリッドを条件付き平均からのドローの長大なデータフレームに変換することができる。

ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m) %>%
  head(10)
condition .row .chain .iteration .draw .epred
A 1 NA NA 1 0.2029359
A 1 NA NA 2 0.0815057
A 1 NA NA 3 0.3493703
A 1 NA NA 4 0.1644948
A 1 NA NA 5 0.1214448
A 1 NA NA 6 0.6316470
A 1 NA NA 7 0.0172222
A 1 NA NA 8 0.1675855
A 1 NA NA 9 0.5015429
A 1 NA NA 10 0.0394944

この例をプロットするために、ggplot 内で描画を点要約と区間に要約する ggdist::geom_pointinterval() の代わりに ggdist::stat_pointinterval() を使用することも紹介する。

ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m) %>%
  ggplot(aes(x = .epred, y = condition)) +
  stat_pointinterval(.width = c(.66, .95))

分位点プロット

アルファレベルがたまたまあなたが行おうとしている決定と一致する場合、区間は良いが、事後的な形状を得ることはより良いことである (したがって、上記の目のプロット) 。一方、密度プロットから推論するのは不正確である (ある形状の面積を別の形状の割合で推定するのは難しい知覚的作業である) 。頻度形式での確率の推論はより簡単で、 quantile dotplots ( Kay et al. 2016, Fernandes et al. 2018)の動機となった。これは任意の間隔 (プロットのドット解像度まで、下の例では100) を正確に推定することも可能である。

tidybayes のジオムの slabinterval ファミリーには dotsdotsinterval ファミリーがあり、点描画の適切なビンサイズを自動的に決定し、サンプルから分位を計算して分位点描画を作成できる。ggdist::stat_dotsinterval() はサンプルに使用するために設計された水平バリアントである。

ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m) %>%
  ggplot(aes(x = .epred, y = condition)) +
  stat_dotsinterval(quantiles = 100)

つまり、事後情報を1つの標準的な点または区間として考えるのではなく、 (例えば) 100のほぼ等しい可能性のある点として表現することである。

事後予測

ここで、add_epred_draws()rstanarm::posterior_epred() に類似しており、add_predicted_draws()rstanarm::posterior_predict() に類似しており、事後予測分布からのドローを与えている。

tidybayes::stat_interval() を使って、データと平均の事後分布と並べて予測バンドをプロットすることができる。

grid = ABC %>%
  data_grid(condition)

means = grid %>%
  add_epred_draws(m)

preds = grid %>%
  add_predicted_draws(m)

ABC %>%
  ggplot(aes(y = condition, x = response)) +
  stat_interval(aes(x = .prediction), data = preds) +
  stat_pointinterval(aes(x = .epred), data = means, .width = c(.66, .95), position = position_nudge(y = -0.3)) +
  geom_point() +
  scale_color_brewer()

フィット/予測曲線

不確実性を伴う適合曲線の描画を実証するために、mtcars のデータセットの一部に、少し素朴なモデルを適合させてみよう。

m_mpg = stan_glm(mpg ~ hp * cyl, data = mtcars)

フィットカーブを確率バンドでプロットすることができる。

mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 51)) %>%
  add_epred_draws(m_mpg) %>%
  ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
  stat_lineribbon(aes(y = .epred)) +
  geom_point(data = mtcars) +
  scale_fill_brewer(palette = "Greys") +
  scale_color_brewer(palette = "Set2")
## Warning: Using the `size` aesthietic with geom_ribbon was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.
## Warning: Unknown or uninitialised column: `linewidth`.
## Warning: Using the `size` aesthietic with geom_line was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.

あるいは、適当な数のフィットラインをサンプリングして (たとえば100本) 、それらをオーバープロットすることもできる。

mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 101)) %>%
  # NOTE: this shows the use of ndraws to subsample within add_epred_draws()
  # ONLY do this IF you are planning to make spaghetti plots, etc.
  # NEVER subsample to a small sample to plot intervals, densities, etc.
  add_epred_draws(m_mpg, ndraws = 100) %>%
  ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
  geom_line(aes(y = .epred, group = paste(cyl, .draw)), alpha = .1) +
  geom_point(data = mtcars) +
  scale_color_brewer(palette = "Dark2")

あるいは、 (平均値ではなく) 事後予測値をプロットすることもできる。この例では また、重なり合ったバンドを見やすくするために、alpha を使用する。

mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 101)) %>%
  add_predicted_draws(m_mpg) %>%
  ggplot(aes(x = hp, y = mpg, color = ordered(cyl), fill = ordered(cyl))) +
  stat_lineribbon(aes(y = .prediction), .width = c(.95, .80, .50), alpha = 1/4) +
  geom_point(data = mtcars) +
  scale_fill_brewer(palette = "Set2") +
  scale_color_brewer(palette = "Dark2")
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.

その他のフィットラインの例については、vignette("tidy-brms") を参照。
animated hypothetical outcome plots (HOPs)

因子の水準を比較する

各条件の平均を比較したい場合、compare_levels() は、ある因子のレベル間のある変数の値の比較を容易にする。デフォルトでは、一対の差分をすべて計算する。

ggdist::stat_halfeye() を使って compare_levels() をデモしてみよう。また 差の平均値で再注文する。

m %>%
  spread_draws(b[,,condition], sep = "[, :]") %>%
  compare_levels(b, by = condition) %>%
  ungroup() %>%
  mutate(condition = reorder(condition, b)) %>%
  ggplot(aes(y = condition, x = b)) +
  stat_halfeye() +
  geom_vline(xintercept = 0, linetype = "dashed") 

Ordinal models

カテゴリカル予測変数の順序モデル

ここでは、カテゴリ予測変数の順序モデルである。

data(esoph)
m_esoph_rs = stan_polr(tobgp ~ agegp, data = esoph, prior = R2(0.25), prior_counts = rstanarm::dirichlet(1))

rstanarm の順序回帰モデル用の関数 rstanarm::posterior_linpred() は、各描画のリンクレベルの予測値を返する (順序モデル用にカテゴリごとに1つの予測値を返す brms::posterior_epred() とは対照的に、vignette("tidy-brms") の順序回帰の例を参照) 。残念ながら、rstanarm::posterior_epred() はこの形式を提供していない。tidybayes の理念は、モデルによって出力されるどんな形式でも整頓することである。その理念に沿って、順序モデル rstanarm に適用する場合、add_linpred_draws() の例を使用し、カテゴリごとの予測確率に変換する方法を紹介する。

例えば、ここにリンクレベルのフィットをプロットしたものがある。

esoph %>%
  data_grid(agegp) %>%
  add_linpred_draws(m_esoph_rs) %>%
  ggplot(aes(x = as.numeric(agegp), y = .linpred)) +
  stat_lineribbon() +
  scale_fill_brewer(palette = "Greys")
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.

これを解釈するのは難しいだろう。これをカテゴリごとの予測確率に変えるには、順序ロジスティック回帰がカテゴリ \(j\) 以下 の結果の確率を次のように定義していることを利用する必要がある。

\[ \textrm{logit}\left[Pr(Y\le j)\right] = \alpha_j - \beta x \]

したがって、カテゴリ \(j\) の確率は。

\[ \begin{align} Pr(Y = j) &= Pr(Y \le j) - Pr(Y \le j - 1)\\ &= \textrm{logit}^{-1}(\alpha_j - \beta x) - \textrm{logit}^{-1}(\alpha_{j-1} - \beta x) \end{align} \]

この値を導き出すには、2つのことが必要である。

rstanarm の閾値は、| を含む名前を持つ係数で、どのカテゴリ間の閾値であるかを示している。これらのパラメータは、モデル中の変数のリストで見ることができる。

get_variables(m_esoph_rs)
##  [1] "agegp.L"        "agegp.Q"        "agegp.C"        "agegp^4"        "agegp^5"        "0-9g/day|10-19"
##  [7] "10-19|20-29"    "20-29|30+"      "accept_stat__"  "stepsize__"     "treedepth__"    "n_leapfrog__"  
## [13] "divergent__"    "energy__"

gather_draws()regex = TRUE 引数を使用して、| の文字を含むすべての変数を見つけることによって、これらを自動的に抽出することができる。次に、dplyr::summarise_all(list) を使ってこれらの閾値をリスト列に変換し、 \(+\infty\) に等しい最終閾値を追加する (最高カテゴリーを表すため) 。

thresholds = m_esoph_rs %>%
  gather_draws(`.*[|].*`, regex = TRUE) %>%
  group_by(.draw) %>%
  select(.draw, threshold = .value) %>%
  summarise_all(list) %>%
  mutate(threshold = map(threshold, ~ c(., Inf)))

head(thresholds, 10)
.draw threshold
1 -0.8200002, 0.6020005, 1.4201284, Inf
2 -1.1563621, 0.3755944, 1.4963719, Inf
3 -0.7694691, 0.1466003, 1.2931992, Inf
4 -0.8780252, 0.6834725, 1.4461758, Inf
5 -1.1254045, 0.2542654, 1.0647593, Inf
6 -1.2646026, 0.3302284, 0.9631856, Inf
7 -0.7925928, 0.1191108, 1.5750053, Inf
8 -1.162030, 0.320501, 1.026631, Inf
9 -0.8444674, 0.1494691, 1.6880150, Inf
10 -1.110080, 0.321730, 1.102269, Inf

例えば、このデータフレームの1行から (つまり、事後的な1ドローから) の閾値ベクタは次のようになる。

thresholds[1,]$threshold
## [[1]]
## [1] -0.8200002  0.6020005  1.4201284        Inf

これらの閾値 (上式の \(\alpha_j\) ) と add_linpred_draws() (上式の \(\beta x\) ) の .linpred 列を組み合わせて、カテゴリごとの確率を計算することができる。

esoph %>%
  data_grid(agegp) %>%
  add_linpred_draws(m_esoph_rs) %>%
  inner_join(thresholds, by = ".draw") %>%
  mutate(`P(Y = category)` = map2(threshold, .linpred, function(alpha, beta_x)
      # this part is logit^-1(alpha_j - beta*x) - logit^-1(alpha_j-1 - beta*x)
      plogis(alpha - beta_x) - 
      plogis(lag(alpha, default = -Inf) - beta_x)
    )) %>%
  mutate(.category = list(levels(esoph$tobgp))) %>%
  unnest(c(threshold, `P(Y = category)`, .category)) %>%
  ggplot(aes(x = agegp, y = `P(Y = category)`, color = .category)) +
  stat_pointinterval(position = position_dodge(width = .4)) +
  scale_size_continuous(guide = "none") +
  scale_color_manual(values = brewer.pal(6, "Blues")[-c(1,2)]) 

上のプロットでは、カテゴリの変化がわかりにくいので、各年度の分布がよくわかるようなものを考えてみよう。

esoph_plot = esoph %>%
  data_grid(agegp) %>%
  add_linpred_draws(m_esoph_rs) %>%
  inner_join(thresholds, by = ".draw") %>%
  mutate(`P(Y = category)` = map2(threshold, .linpred, function(alpha, beta_x)
      # this part is logit^-1(alpha_j - beta*x) - logit^-1(alpha_j-1 - beta*x)
      plogis(alpha - beta_x) - 
      plogis(lag(alpha, default = -Inf) - beta_x)
    )) %>%
  mutate(.category = list(levels(esoph$tobgp))) %>%
  unnest(c(threshold, `P(Y = category)`, .category)) %>%
  ggplot(aes(x = `P(Y = category)`, y = .category)) +
  coord_cartesian(expand = FALSE) +
  facet_grid(. ~ agegp, switch = "x") +
  theme_classic() +
  theme(strip.background = element_blank(), strip.placement = "outside") +
  ggtitle("P(tobacco consumption category | age group)") +
  xlab("age group")

esoph_plot +
  stat_summary(fun = median, geom = "bar", fill = "gray65", width = 1, color = "white") +
  stat_pointinterval()

この場合、棒グラフは誤った精度感を与える可能性があるので、代わりにCCDF棒グラフを試してみることもできる。

esoph_plot +
  stat_ccdfinterval() +
  expand_limits(x = 0) #ensure bars go to 0

この出力は、vignette("tidy-brms") の対応する m_esoph_brm モデルからの出力と非常によく似ているはずである (異なるプライヤーを適用している) 。ただし、rstanarm では brms と比較して、出力に少し手間がかかるようになっている。