この vignette は、tidybayes
と ggdist
パッケージを使用して、 tidy
データフレームを抽出し、可視化する方法について説明する。モデル変数の事後分布からのドロー、平均、および
rstanarm
からの予測値である。より一般的な
tidybayes
の紹介と、汎用のベイズモデリング言語 (Stan や
JAGS など) での使用については、vignette("tidybayes")
を参照。
このvignetteを実行するには、以下のライブラリが必要である。
library(magrittr)
library(dplyr)
library(purrr)
library(forcats)
library(tidyr)
library(modelr)
library(ggdist)
library(tidybayes)
library(ggplot2)
library(cowplot)
library(rstan)
library(rstanarm)
library(RColorBrewer)
theme_set(theme_tidybayes() + panel_border())
これらのオプションは、Stanの動作を高速化するためのものである。
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
tidybayes
を実証するために、5つの条件からそれぞれ10個のオブザベーションを持つ単純なデータセットを使用することにする。
set.seed(5)
= 10
n = 5
n_condition =
ABC tibble(
condition = rep(c("A","B","C","D","E"), n),
response = rnorm(n * 5, c(0,1,2,1,-1), 0.5)
)
データのスナップショットはこのような感じである。
head(ABC, 10)
condition | response |
---|---|
A | -0.4204277 |
B | 1.6921797 |
C | 1.3722541 |
D | 1.0350714 |
E | -0.1442796 |
A | -0.3014540 |
B | 0.7639168 |
C | 1.6823143 |
D | 0.8571132 |
E | -0.9309459 |
これは典型的な整頓されたフォーマットのデータフレーム: 1行に1つの観測値である。グラフィカルに
%>%
ABC ggplot(aes(y = condition, x = response)) +
geom_point()
#モデル {#model}
大域平均に向かって収縮する階層的モデルを当てはめよう。
= stan_lmer(response ~ (1|condition), data = ABC,
m prior = normal(0, 1, autoscale = FALSE),
prior_aux = student_t(3, 0, 1, autoscale = FALSE),
adapt_delta = .99)
結果はこのようになる。
m
## stan_lmer
## family: gaussian [identity]
## formula: response ~ (1 | condition)
## observations: 50
## ------
## Median MAD_SD
## (Intercept) 0.6 0.5
##
## Auxiliary parameter(s):
## Median MAD_SD
## sigma 0.6 0.1
##
## Error terms:
## Groups Name Std.Dev.
## condition (Intercept) 1.13
## Residual 0.56
## Num. levels: condition 5
##
## ------
## * For help interpreting the printed output see ?print.stanreg
## * For info on the priors used see ?prior_summary.stanreg
spread_draws
を使ってフィットからドローをtidyフォーマットで抽出するさて、結果が出たので、楽しいことが始まります:整頓されたフォーマットで描画を取り出すことである!まず、
関数を使って、生のモデル変数名のリストを取得する。まず、get_variables()
関数を使って、生のモデル変数名のリストを取得し、モデルからどの変数を抽出できるかを知ることにする。
get_variables(m)
## [1] "(Intercept)" "b[(Intercept) condition:A]"
## [3] "b[(Intercept) condition:B]" "b[(Intercept) condition:C]"
## [5] "b[(Intercept) condition:D]" "b[(Intercept) condition:E]"
## [7] "sigma" "Sigma[condition:(Intercept),(Intercept)]"
## [9] "accept_stat__" "stepsize__"
## [11] "treedepth__" "n_leapfrog__"
## [13] "divergent__" "energy__"
ここで、(Intercept)
は大域的な平均値、b
のパラメータは各条件におけるその平均値からのオフセットである。これらのパラメータが与えられると
b [(Intercept) condition:A]
b [(Intercept) condition:B]
b [(Intercept) condition:C]
b [(Intercept) condition:D]
b [(Intercept) condition:E]
各行が b [(Intercept) condition:A]
,
b [(Intercept) condition:B]
, ...:C]
,
...:D]
, or ...:E]
の どちらか一方からの draw
であるデータフレームが必要だろう。 , そして、行がどの
chain/iteration/draw から来たか、そしてどの条件 ( A
to
E
)
のものであるかのインデックスを持つ列があるデータフレームが欲しいと思うだろう。これにより、条件ごとにグループ化された量を簡単に計算したり、ggplotを使用して条件ごとにプロットを生成したり、あるいはドローとオリジナルデータをマージしてデータと後置をプロットしたりすることができるようになる。
tidybayes
の主力は spread_draws()
関数で、この関数はこの抽出を行ってくれる。この関数には、モデル変数とそのインデックスを整頓されたデータフレームに抽出するために使用できる簡単な仕様書式が含まれている。
このようなパラメータが与えられると
b [(Intercept) condition:D]
このような列の仕様で spread_draws()
を提供することができる。
b [term,group]
ここで、term
は (Intercept)
に、group
は condition:D
に対応する。この指定に対して spread_draws()
が行うことについては、特に不思議なことはない。この仕様では、パラメータのインデックスをカンマとスペースで分割している
( sep
引数を変更すれば、他の文字で分割することもできる)。そして、その結果得られたインデックスに、順番に列を割り当てることができるのである。つまり
b [(Intercept) condition:D]
はインデックス
(Intercept)
と condition:D
を持ち、spread_draws()
はこれらのインデックスを列として抽出し、b
から得られるドローのデータフレームを整頓することができる。
%>%
m spread_draws(b[term,group]) %>%
head(10)
## Warning: `gather_()` was deprecated in tidyr 1.2.0.
## ℹ Please use `gather()` instead.
## ℹ The deprecated feature was likely used in the tidybayes package.
## Please report the issue at <]8;;https://github.com/mjskay/tidybayes/issues/newhttps://github.com/mjskay/tidybayes/issues/new]8;;>.
term | group | b | .chain | .iteration | .draw |
---|---|---|---|---|---|
(Intercept) | condition:A | -1.2677915 | 1 | 1 | 1 |
(Intercept) | condition:A | -0.2132916 | 1 | 2 | 2 |
(Intercept) | condition:A | 0.0192337 | 1 | 3 | 3 |
(Intercept) | condition:A | -0.4226583 | 1 | 4 | 4 |
(Intercept) | condition:A | -0.4007204 | 1 | 5 | 5 |
(Intercept) | condition:A | 0.0275432 | 1 | 6 | 6 |
(Intercept) | condition:A | -0.3184563 | 1 | 7 | 7 |
(Intercept) | condition:A | -0.4145287 | 1 | 8 | 8 |
(Intercept) | condition:A | -0.5637111 | 1 | 9 | 9 |
(Intercept) | condition:A | -1.1462066 | 1 | 10 | 10 |
インデックス列には好きな名前を付けることができる。
%>%
m spread_draws(b[t,g]) %>%
head(10)
t | g | b | .chain | .iteration | .draw |
---|---|---|---|---|---|
(Intercept) | condition:A | -1.2677915 | 1 | 1 | 1 |
(Intercept) | condition:A | -0.2132916 | 1 | 2 | 2 |
(Intercept) | condition:A | 0.0192337 | 1 | 3 | 3 |
(Intercept) | condition:A | -0.4226583 | 1 | 4 | 4 |
(Intercept) | condition:A | -0.4007204 | 1 | 5 | 5 |
(Intercept) | condition:A | 0.0275432 | 1 | 6 | 6 |
(Intercept) | condition:A | -0.3184563 | 1 | 7 | 7 |
(Intercept) | condition:A | -0.4145287 | 1 | 8 | 8 |
(Intercept) | condition:A | -0.5637111 | 1 | 9 | 9 |
(Intercept) | condition:A | -1.1462066 | 1 | 10 | 10 |
しかし、前の例のような、より説明的で暗号化されていない名前の方が望ましいだろう。
このモデルでは、項は1つしかない( (Intercept)
)ので、そのインデックスを完全に省略して、各項目 group
と対応する条件の b
の値だけを取得することができる。
%>%
m spread_draws(b[,group]) %>%
head(10)
group | b | .chain | .iteration | .draw |
---|---|---|---|---|
condition:A | -1.2677915 | 1 | 1 | 1 |
condition:A | -0.2132916 | 1 | 2 | 2 |
condition:A | 0.0192337 | 1 | 3 | 3 |
condition:A | -0.4226583 | 1 | 4 | 4 |
condition:A | -0.4007204 | 1 | 5 | 5 |
condition:A | 0.0275432 | 1 | 6 | 6 |
condition:A | -0.3184563 | 1 | 7 | 7 |
condition:A | -0.4145287 | 1 | 8 | 8 |
condition:A | -0.5637111 | 1 | 9 | 9 |
condition:A | -1.1462066 | 1 | 10 | 10 |
この場合、すべてのグループが condition
ファクターのものなので、対応する条件 ( A
, B
,
C
, etc)
を含むだけの列を分離することもできる。これは、tidyr::separate
を使って行うことができる。
%>%
m spread_draws(b[,group]) %>%
separate(group, c("group", "condition"), ":") %>%
head(10)
group | condition | b | .chain | .iteration | .draw |
---|---|---|---|---|---|
condition | A | -1.2677915 | 1 | 1 | 1 |
condition | A | -0.2132916 | 1 | 2 | 2 |
condition | A | 0.0192337 | 1 | 3 | 3 |
condition | A | -0.4226583 | 1 | 4 | 4 |
condition | A | -0.4007204 | 1 | 5 | 5 |
condition | A | 0.0275432 | 1 | 6 | 6 |
condition | A | -0.3184563 | 1 | 7 | 7 |
condition | A | -0.4145287 | 1 | 8 | 8 |
condition | A | -0.5637111 | 1 | 9 | 9 |
condition | A | -1.1462066 | 1 | 10 | 10 |
あるいは、sep
の引数を spread_draws()
に変更して、:
でも分割することもできる ( sep
は正規表現)。Note:
この例ではうまくいくが、マルチレベル・モデルで因子間の相互作用がグループ化レベルとして使われるrstanarmモデルではうまくいかないだろう。したがって、:
はデフォルトの分離子に含まれない。
%>%
m spread_draws(b[,group,condition], sep = "[, :]") %>%
head(10)
group | condition | b | .chain | .iteration | .draw |
---|---|---|---|---|---|
condition | A | -1.2677915 | 1 | 1 | 1 |
condition | A | -0.2132916 | 1 | 2 | 2 |
condition | A | 0.0192337 | 1 | 3 | 3 |
condition | A | -0.4226583 | 1 | 4 | 4 |
condition | A | -0.4007204 | 1 | 5 | 5 |
condition | A | 0.0275432 | 1 | 6 | 6 |
condition | A | -0.3184563 | 1 | 7 | 7 |
condition | A | -0.4145287 | 1 | 8 | 8 |
condition | A | -0.5637111 | 1 | 9 | 9 |
condition | A | -1.1462066 | 1 | 10 | 10 |
注: spread_draws()
を Stan または JAGS
の生のサンプルで使用したことがある場合、spread_draws()
の前に recover_types()
を使用してインデックス列の値を取得することに慣れているだろう
(例えば、インデックスが要因であった場合) 。rstanarm
モデルで spread_draws()
を使用する場合、これらのモデルにはすでにその情報が変数名として含まれているため、この操作は必要ない。recover_types()
の詳細については、vignette("tidybayes")
を参照。
tidybayes
は、描画から点の要約や区間を整然としたフォーマットで生成するための関数群を提供する。これらの関数は、以下の命名規則に従っている。
[median|mean|mode] _ [qi|hdi]
例えば、median_qi()
, mean_qi()
,
mode_hdi()
, などである。最初の名前 ( _
の前)
は点要約の種類を示し、2番目の名前は区間の種類を示す。qi
は分位点間隔 (別名:等値線間隔、中心間隔、パーセンタイル間隔)
、hdi
は最高 (後)
密度間隔を生成する。カスタム点または区間関数は、point_interval()
関数を用いて適用することもできる。
例えば、オブザベーションの全体平均と標準偏差の事後分布に対応するドローを抽出することができる。
%>%
m spread_draws(`(Intercept)`, sigma) %>%
head(10)
.chain | .iteration | .draw | (Intercept) | sigma |
---|---|---|---|---|
1 | 1 | 1 | 1.4707274 | 0.7488978 |
1 | 2 | 2 | 0.2947973 | 0.5082004 |
1 | 3 | 3 | 0.3301366 | 0.5640631 |
1 | 4 | 4 | 0.5871531 | 0.5641068 |
1 | 5 | 5 | 0.5221652 | 0.5328644 |
1 | 6 | 6 | 0.6041039 | 0.5441214 |
1 | 7 | 7 | 0.3356784 | 0.6227631 |
1 | 8 | 8 | 0.5821142 | 0.5585193 |
1 | 9 | 9 | 1.0652540 | 0.6601671 |
1 | 10 | 10 | 1.1857010 | 0.5061982 |
と同様に b [term,group]
と同様、整頓されたデータフレームが得られる。変数の中央値と95%分位間隔が欲しい場合は、median_qi()
を適用する。
%>%
m spread_draws(`(Intercept)`, sigma) %>%
median_qi(`(Intercept)`, sigma)
(Intercept) | (Intercept).lower | (Intercept).upper | sigma | sigma.lower | sigma.upper | .width | .point | .interval |
---|---|---|---|---|---|---|---|---|
0.6086139 | -0.4855256 | 1.58487 | 0.5606439 | 0.4568133 | 0.6975851 | 0.95 | median | qi |
上記のように、中央値と区間を取得したい列を指定することもできるが、列のリストを省略すると、median_qi()
は、グループ化列や特殊列 ( .chain
, .iteration
, .draw
など)
でないすべての列を使用する。したがって、上記の例では、(Intercept)
と sigma
は、モデルから収集した唯一の列でもあるため、median_qi()
の冗長な引数となっている。そこで、次のように単純化することができる。
%>%
m spread_draws(`(Intercept)`, sigma) %>%
median_qi()
(Intercept) | (Intercept).lower | (Intercept).upper | sigma | sigma.lower | sigma.upper | .width | .point | .interval |
---|---|---|---|---|---|---|---|---|
0.6086139 | -0.4855256 | 1.58487 | 0.5606439 | 0.4568133 | 0.6975851 | 0.95 | median | qi |
もし、長い形式のインターバルリストが欲しい場合は、代わりに
gather_draws()
を使ってみよう。
%>%
m gather_draws(`(Intercept)`, sigma) %>%
median_qi()
.variable | .value | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
(Intercept) | 0.6086139 | -0.4855256 | 1.5848700 | 0.95 | median | qi |
sigma | 0.5606439 | 0.4568133 | 0.6975851 | 0.95 | median | qi |
gather_draws()
の詳細については、vignette("tidybayes")
を参照。
b
のような1つ以上のインデックスを持つモデル変数があるとき、前と同じように
median_qi()
(または point_interval()
ファミリーの他の関数) を適用することができる。
%>%
m spread_draws(b[,group]) %>%
median_qi()
group | b | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
condition:A | -0.4099501 | -1.4241652 | 0.6630132 | 0.95 | median | qi |
condition:B | 0.3853866 | -0.6377728 | 1.4480079 | 0.95 | median | qi |
condition:C | 1.2029809 | 0.2439749 | 2.3109965 | 0.95 | median | qi |
condition:D | 0.4012280 | -0.5931506 | 1.4861856 | 0.95 | median | qi |
condition:E | -1.4849292 | -2.5190222 | -0.4131681 | 0.95 | median | qi |
median_qi()
はどのようにして集計する対象を知ったのだろうか?
spread_draws()
によって返されたデータフレームは、あなたが渡したすべてのインデックス変数によって自動的にグループ化される。この場合、spread_draws()
はその結果を group
によってグループ化することを意味する。median_qi()
はこれらのグループを尊重し、すべてのグループ内のポイントの要約と間隔を計算する。そして、median_qi()
には列が渡されなかったので、唯一の非特殊列 ( .
-prefixed)
、非グループ列である b
に対して処理を行う。したがって、上記の短縮された構文は、より冗長なこの呼び出しと等価である。
%>%
m spread_draws(b[,group]) %>%
group_by(group) %>% # this line not necessary (done by spread_draws)
median_qi(b) # b is not necessary (it is the only non-group column)
group | b | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
condition:A | -0.4099501 | -1.4241652 | 0.6630132 | 0.95 | median | qi |
condition:B | 0.3853866 | -0.6377728 | 1.4480079 | 0.95 | median | qi |
condition:C | 1.2029809 | 0.2439749 | 2.3109965 | 0.95 | median | qi |
condition:D | 0.4012280 | -0.5931506 | 1.4861856 | 0.95 | median | qi |
condition:E | -1.4849292 | -2.5190222 | -0.4131681 | 0.95 | median | qi |
tidybayes
は、posterior::summarise_draws()
の実装も提供している。 グループ化されたデータフレーム (
tidybayes::summaries_draws.grouped_df()
)
を作成することができる。
を使用すると、コンバージェンス診断が迅速に行える。
%>%
m spread_draws(b[,group]) %>%
summarise_draws()
group | variable | mean | median | sd | mad | q5 | q95 | rhat | ess_bulk | ess_tail |
---|---|---|---|---|---|---|---|---|---|---|
condition:A | b | -0.4159312 | -0.4099501 | 0.5129151 | 0.4761362 | -1.2298734 | 0.4080505 | 1.002804 | 1013.889 | 1168.484 |
condition:B | b | 0.3872925 | 0.3853866 | 0.5161628 | 0.4713854 | -0.4266826 | 1.2161343 | 1.001365 | 1024.497 | 1349.859 |
condition:C | b | 1.2192802 | 1.2029809 | 0.5133706 | 0.4682209 | 0.4278880 | 2.0607642 | 1.001481 | 1003.779 | 1212.333 |
condition:D | b | 0.4019813 | 0.4012280 | 0.5147134 | 0.4744561 | -0.4179856 | 1.2377231 | 1.001843 | 1032.115 | 1378.726 |
condition:E | b | -1.4890782 | -1.4849292 | 0.5169411 | 0.4781767 | -2.3260892 | -0.6672363 | 1.001625 | 1034.943 | 1310.946 |
spread_draws()
と gather_draws()
は、異なるインデックスを持つ変数を同じデータフレームに抽出することをサポートしている。同じ名前のインデックスは自動的にマッチングされ、必要に応じて値が複製され、すべてのインデックスのレベルの組み合わせごとに1つの行が作成される。例えば、各条件における平均を計算したい場合がある
(これを condition_mean
と呼ぶ)
。このモデルでは、その平均は切片((Intercept)
)と与えられた条件での効果(b
)である。
(Intercept)
と b
からの draw
を一つのデータフレームにまとめることができる。
%>%
m spread_draws(`(Intercept)`, b[,group]) %>%
head(10)
.chain | .iteration | .draw | (Intercept) | group | b |
---|---|---|---|---|---|
1 | 1 | 1 | 1.4707274 | condition:A | -1.2677915 |
1 | 1 | 1 | 1.4707274 | condition:B | -0.1775179 |
1 | 1 | 1 | 1.4707274 | condition:C | 0.4365900 |
1 | 1 | 1 | 1.4707274 | condition:D | -0.4126736 |
1 | 1 | 1 | 1.4707274 | condition:E | -2.1516102 |
1 | 2 | 2 | 0.2947973 | condition:A | -0.2132916 |
1 | 2 | 2 | 0.2947973 | condition:B | 0.8856697 |
1 | 2 | 2 | 0.2947973 | condition:C | 1.6586191 |
1 | 2 | 2 | 0.2947973 | condition:D | 0.7126382 |
1 | 2 | 2 | 0.2947973 | condition:E | -1.2611341 |
各 draw の中で、(Intercept)
は b
の各インデックスに対応するように必要に応じて繰り返される。
したがって、dplyr の mutate
関数を使用して、それらの合計
condition_mean
(各条件の平均) を求めることができる。
%>%
m spread_draws(`(Intercept)`, b[,group]) %>%
mutate(condition_mean = `(Intercept)` + b) %>%
median_qi(condition_mean)
group | condition_mean | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
condition:A | 0.1987924 | -0.1490205 | 0.5322958 | 0.95 | median | qi |
condition:B | 0.9948956 | 0.6540783 | 1.3561487 | 0.95 | median | qi |
condition:C | 1.8290092 | 1.4753124 | 2.1830999 | 0.95 | median | qi |
condition:D | 1.0154953 | 0.6644074 | 1.3666874 | 0.95 | median | qi |
condition:E | -0.8807628 | -1.2302583 | -0.5128632 | 0.95 | median | qi |
median_qi()
は整頓された評価 (
vignette("tidy-evaluation ", package =" rlang")
を使うので、列名だけでなく、列の式も受け取ることができる。したがって、condition_mean
の計算を mutate
から median_qi()
に移すことで、上の例を単純化することができる。
%>%
m spread_draws(`(Intercept)`, b[,group]) %>%
median_qi(condition_mean = `(Intercept)` + b)
group | condition_mean | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
condition:A | 0.1987924 | -0.1490205 | 0.5322958 | 0.95 | median | qi |
condition:B | 0.9948956 | 0.6540783 | 1.3561487 | 0.95 | median | qi |
condition:C | 1.8290092 | 1.4753124 | 2.1830999 | 0.95 | median | qi |
condition:D | 1.0154953 | 0.6644074 | 1.3666874 | 0.95 | median | qi |
condition:E | -0.8807628 | -1.2302583 | -0.5128632 | 0.95 | median | qi |
median_qi()
とその姉妹関数は、
引数を設定することで、任意の数の確率区間を生成することができる。
.width =
%>%
m spread_draws(`(Intercept)`, b[,group]) %>%
median_qi(condition_mean = `(Intercept)` + b, .width = c(.95, .8, .5))
group | condition_mean | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
condition:A | 0.1987924 | -0.1490205 | 0.5322958 | 0.95 | median | qi |
condition:B | 0.9948956 | 0.6540783 | 1.3561487 | 0.95 | median | qi |
condition:C | 1.8290092 | 1.4753124 | 2.1830999 | 0.95 | median | qi |
condition:D | 1.0154953 | 0.6644074 | 1.3666874 | 0.95 | median | qi |
condition:E | -0.8807628 | -1.2302583 | -0.5128632 | 0.95 | median | qi |
condition:A | 0.1987924 | -0.0243302 | 0.4099228 | 0.80 | median | qi |
condition:B | 0.9948956 | 0.7687736 | 1.2267421 | 0.80 | median | qi |
condition:C | 1.8290092 | 1.5985594 | 2.0610815 | 0.80 | median | qi |
condition:D | 1.0154953 | 0.7813036 | 1.2365046 | 0.80 | median | qi |
condition:E | -0.8807628 | -1.1069926 | -0.6509841 | 0.80 | median | qi |
condition:A | 0.1987924 | 0.0816339 | 0.3077621 | 0.50 | median | qi |
condition:B | 0.9948956 | 0.8745806 | 1.1189284 | 0.50 | median | qi |
condition:C | 1.8290092 | 1.7087692 | 1.9516227 | 0.50 | median | qi |
condition:D | 1.0154953 | 0.8918738 | 1.1290154 | 0.50 | median | qi |
condition:E | -0.8807628 | -1.0014546 | -0.7577522 | 0.50 | median | qi |
結果は整然とした形式で、1グループにつき1行、不確定性区間幅(
.width
)である。これはプロットを容易にする。たとえば、-.width
を
size
の美学に割り当てると、すべての区間が表示され、太い線がより小さい区間に対応するようになる。ggdist::geom_pointinterval()
geomは、複数の確率レベルを持つ点のプロットを作成するために、データ中の
.width
列に基づいて、size
美学を自動的に適切に設定する。
%>%
m spread_draws(`(Intercept)`, b[,group]) %>%
median_qi(condition_mean = `(Intercept)` + b, .width = c(.95, .66)) %>%
ggplot(aes(y = group, x = condition_mean, xmin = .lower, xmax = .upper)) +
geom_pointinterval()
## Warning: Using the `size` aesthietic with geom_segment was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.
区間とともに密度を見るには、ggdist::stat_eye()
(区間とバイオリンプロットを組み合わせた「アイプロット」) 、または
ggdist::stat_halfeye()
(区間+密度プロット)
を使用することができる。
%>%
m spread_draws(`(Intercept)`, b[,group]) %>%
mutate(condition_mean = `(Intercept)` + b) %>%
ggplot(aes(y = group, x = condition_mean)) +
stat_halfeye()
あるいは、密度の一部をカラーでアノテートしたいとする。fill
の美学は、ggdist::stat_halfeye()
を含む
ggdist::geom_slabinterval()
ファミリーのすべてのジオムおよび統計において、スラブ内で変化させることができる。例えば、ドメイン固有のROPE
(region of practical equivalence)
をアノテートしたい場合、次のようなことが可能である。
%>%
m spread_draws(`(Intercept)`, b[,group]) %>%
mutate(condition_mean = `(Intercept)` + b) %>%
ggplot(aes(y = group, x = condition_mean, fill = stat(abs(x) < .8))) +
stat_halfeye() +
geom_vline(xintercept = c(-.8, .8), linetype = "dashed") +
scale_fill_manual(values = c("gray80", "skyblue"))
## Warning: `stat(abs(x) < 0.8)` was deprecated in ggplot2 3.4.0.
## ℹ Please use `after_stat(abs(x) < 0.8)` instead.
stat_slabinterval
分布を視覚化するための様々な追加統計が、ggdist::geom_slabinterval()
ファミリーの統計とジオムにはある。
を参照。 vignette("slabinterval ", package =" ggdist")
を参照。
前の例のように条件付き平均を手動で計算するのではなく、rstanarm::posterior_epred()
(事後予測値の期待値から事後ドローを与える;すなわち、条件付き平均の事後分布)
に類似しているが、整然としたデータ形式を使用する
add_epred_draws()
を使用することができる。これを
modelr::data_grid()
と組み合わせて、まず欲しい予測を記述したグリッドを生成し、次にそのグリッドを条件付き平均からのドローの長大なデータフレームに変換することができる。
%>%
ABC data_grid(condition) %>%
add_epred_draws(m) %>%
head(10)
condition | .row | .chain | .iteration | .draw | .epred |
---|---|---|---|---|---|
A | 1 | NA | NA | 1 | 0.2029359 |
A | 1 | NA | NA | 2 | 0.0815057 |
A | 1 | NA | NA | 3 | 0.3493703 |
A | 1 | NA | NA | 4 | 0.1644948 |
A | 1 | NA | NA | 5 | 0.1214448 |
A | 1 | NA | NA | 6 | 0.6316470 |
A | 1 | NA | NA | 7 | 0.0172222 |
A | 1 | NA | NA | 8 | 0.1675855 |
A | 1 | NA | NA | 9 | 0.5015429 |
A | 1 | NA | NA | 10 | 0.0394944 |
この例をプロットするために、ggplot 内で描画を点要約と区間に要約する
ggdist::geom_pointinterval()
の代わりに
ggdist::stat_pointinterval()
を使用することも紹介する。
%>%
ABC data_grid(condition) %>%
add_epred_draws(m) %>%
ggplot(aes(x = .epred, y = condition)) +
stat_pointinterval(.width = c(.66, .95))
アルファレベルがたまたまあなたが行おうとしている決定と一致する場合、区間は良いが、事後的な形状を得ることはより良いことである (したがって、上記の目のプロット) 。一方、密度プロットから推論するのは不正確である (ある形状の面積を別の形状の割合で推定するのは難しい知覚的作業である) 。頻度形式での確率の推論はより簡単で、 quantile dotplots ( Kay et al. 2016, Fernandes et al. 2018)の動機となった。これは任意の間隔 (プロットのドット解像度まで、下の例では100) を正確に推定することも可能である。
tidybayes のジオムの slabinterval ファミリーには dots
と
dotsinterval
ファミリーがあり、点描画の適切なビンサイズを自動的に決定し、サンプルから分位を計算して分位点描画を作成できる。ggdist::stat_dotsinterval()
はサンプルに使用するために設計された水平バリアントである。
%>%
ABC data_grid(condition) %>%
add_epred_draws(m) %>%
ggplot(aes(x = .epred, y = condition)) +
stat_dotsinterval(quantiles = 100)
つまり、事後情報を1つの標準的な点または区間として考えるのではなく、 (例えば) 100のほぼ等しい可能性のある点として表現することである。
ここで、add_epred_draws()
は
rstanarm::posterior_epred()
に類似しており、add_predicted_draws()
は
rstanarm::posterior_predict()
に類似しており、事後予測分布からのドローを与えている。
tidybayes::stat_interval()
を使って、データと平均の事後分布と並べて予測バンドをプロットすることができる。
= ABC %>%
grid data_grid(condition)
= grid %>%
means add_epred_draws(m)
= grid %>%
preds add_predicted_draws(m)
%>%
ABC ggplot(aes(y = condition, x = response)) +
stat_interval(aes(x = .prediction), data = preds) +
stat_pointinterval(aes(x = .epred), data = means, .width = c(.66, .95), position = position_nudge(y = -0.3)) +
geom_point() +
scale_color_brewer()
不確実性を伴う適合曲線の描画を実証するために、mtcars
のデータセットの一部に、少し素朴なモデルを適合させてみよう。
= stan_glm(mpg ~ hp * cyl, data = mtcars) m_mpg
フィットカーブを確率バンドでプロットすることができる。
%>%
mtcars group_by(cyl) %>%
data_grid(hp = seq_range(hp, n = 51)) %>%
add_epred_draws(m_mpg) %>%
ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
stat_lineribbon(aes(y = .epred)) +
geom_point(data = mtcars) +
scale_fill_brewer(palette = "Greys") +
scale_color_brewer(palette = "Set2")
## Warning: Using the `size` aesthietic with geom_ribbon was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.
## Warning: Unknown or uninitialised column: `linewidth`.
## Warning: Using the `size` aesthietic with geom_line was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
あるいは、適当な数のフィットラインをサンプリングして (たとえば100本) 、それらをオーバープロットすることもできる。
%>%
mtcars group_by(cyl) %>%
data_grid(hp = seq_range(hp, n = 101)) %>%
# NOTE: this shows the use of ndraws to subsample within add_epred_draws()
# ONLY do this IF you are planning to make spaghetti plots, etc.
# NEVER subsample to a small sample to plot intervals, densities, etc.
add_epred_draws(m_mpg, ndraws = 100) %>%
ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
geom_line(aes(y = .epred, group = paste(cyl, .draw)), alpha = .1) +
geom_point(data = mtcars) +
scale_color_brewer(palette = "Dark2")
あるいは、 (平均値ではなく)
事後予測値をプロットすることもできる。この例では
また、重なり合ったバンドを見やすくするために、alpha
を使用する。
%>%
mtcars group_by(cyl) %>%
data_grid(hp = seq_range(hp, n = 101)) %>%
add_predicted_draws(m_mpg) %>%
ggplot(aes(x = hp, y = mpg, color = ordered(cyl), fill = ordered(cyl))) +
stat_lineribbon(aes(y = .prediction), .width = c(.95, .80, .50), alpha = 1/4) +
geom_point(data = mtcars) +
scale_fill_brewer(palette = "Set2") +
scale_color_brewer(palette = "Dark2")
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
その他のフィットラインの例については、vignette("tidy-brms")
を参照。
animated hypothetical
outcome plots (HOPs)。
各条件の平均を比較したい場合、compare_levels()
は、ある因子のレベル間のある変数の値の比較を容易にする。デフォルトでは、一対の差分をすべて計算する。
ggdist::stat_halfeye()
を使って
compare_levels()
をデモしてみよう。また
差の平均値で再注文する。
%>%
m spread_draws(b[,,condition], sep = "[, :]") %>%
compare_levels(b, by = condition) %>%
ungroup() %>%
mutate(condition = reorder(condition, b)) %>%
ggplot(aes(y = condition, x = b)) +
stat_halfeye() +
geom_vline(xintercept = 0, linetype = "dashed")
ここでは、カテゴリ予測変数の順序モデルである。
data(esoph)
= stan_polr(tobgp ~ agegp, data = esoph, prior = R2(0.25), prior_counts = rstanarm::dirichlet(1)) m_esoph_rs
rstanarm の順序回帰モデル用の関数
rstanarm::posterior_linpred()
は、各描画のリンクレベルの予測値を返する
(順序モデル用にカテゴリごとに1つの予測値を返す
brms::posterior_epred()
とは対照的に、vignette("tidy-brms")
の順序回帰の例を参照)
。残念ながら、rstanarm::posterior_epred()
はこの形式を提供していない。tidybayes
の理念は、モデルによって出力されるどんな形式でも整頓することである。その理念に沿って、順序モデル
rstanarm
に適用する場合、add_linpred_draws()
の例を使用し、カテゴリごとの予測確率に変換する方法を紹介する。
例えば、ここにリンクレベルのフィットをプロットしたものがある。
%>%
esoph data_grid(agegp) %>%
add_linpred_draws(m_esoph_rs) %>%
ggplot(aes(x = as.numeric(agegp), y = .linpred)) +
stat_lineribbon() +
scale_fill_brewer(palette = "Greys")
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
これを解釈するのは難しいだろう。これをカテゴリごとの予測確率に変えるには、順序ロジスティック回帰がカテゴリ \(j\) 以下 の結果の確率を次のように定義していることを利用する必要がある。
\[ \textrm{logit}\left[Pr(Y\le j)\right] = \alpha_j - \beta x \]
したがって、カテゴリ \(j\) の確率は。
\[ \begin{align} Pr(Y = j) &= Pr(Y \le j) - Pr(Y \le j - 1)\\ &= \textrm{logit}^{-1}(\alpha_j - \beta x) - \textrm{logit}^{-1}(\alpha_{j-1} - \beta x) \end{align} \]
この値を導き出すには、2つのことが必要である。
\(\alpha_j\) の値である。これはモデルによって適合された閾値パラメータである。便宜上、 \(k\) のレベルがある場合、トップレベルかそれ以下の確率が1であるため、 \(\alpha_k = +\infty\) を取ることにする。
\(\beta x\) の値である。これらは
add_linpred_draws()
によって返される .linpred
の値だけである。
rstanarm
の閾値は、|
を含む名前を持つ係数で、どのカテゴリ間の閾値であるかを示している。これらのパラメータは、モデル中の変数のリストで見ることができる。
get_variables(m_esoph_rs)
## [1] "agegp.L" "agegp.Q" "agegp.C" "agegp^4" "agegp^5" "0-9g/day|10-19"
## [7] "10-19|20-29" "20-29|30+" "accept_stat__" "stepsize__" "treedepth__" "n_leapfrog__"
## [13] "divergent__" "energy__"
gather_draws()
の regex = TRUE
引数を使用して、|
の文字を含むすべての変数を見つけることによって、これらを自動的に抽出することができる。次に、dplyr::summarise_all(list)
を使ってこれらの閾値をリスト列に変換し、 \(+\infty\) に等しい最終閾値を追加する
(最高カテゴリーを表すため) 。
= m_esoph_rs %>%
thresholds gather_draws(`.*[|].*`, regex = TRUE) %>%
group_by(.draw) %>%
select(.draw, threshold = .value) %>%
summarise_all(list) %>%
mutate(threshold = map(threshold, ~ c(., Inf)))
head(thresholds, 10)
.draw | threshold |
---|---|
1 | -0.8200002, 0.6020005, 1.4201284, Inf |
2 | -1.1563621, 0.3755944, 1.4963719, Inf |
3 | -0.7694691, 0.1466003, 1.2931992, Inf |
4 | -0.8780252, 0.6834725, 1.4461758, Inf |
5 | -1.1254045, 0.2542654, 1.0647593, Inf |
6 | -1.2646026, 0.3302284, 0.9631856, Inf |
7 | -0.7925928, 0.1191108, 1.5750053, Inf |
8 | -1.162030, 0.320501, 1.026631, Inf |
9 | -0.8444674, 0.1494691, 1.6880150, Inf |
10 | -1.110080, 0.321730, 1.102269, Inf |
例えば、このデータフレームの1行から (つまり、事後的な1ドローから) の閾値ベクタは次のようになる。
1,]$threshold thresholds[
## [[1]]
## [1] -0.8200002 0.6020005 1.4201284 Inf
これらの閾値 (上式の \(\alpha_j\) )
と add_linpred_draws()
(上式の \(\beta x\) ) の .linpred
列を組み合わせて、カテゴリごとの確率を計算することができる。
%>%
esoph data_grid(agegp) %>%
add_linpred_draws(m_esoph_rs) %>%
inner_join(thresholds, by = ".draw") %>%
mutate(`P(Y = category)` = map2(threshold, .linpred, function(alpha, beta_x)
# this part is logit^-1(alpha_j - beta*x) - logit^-1(alpha_j-1 - beta*x)
plogis(alpha - beta_x) -
plogis(lag(alpha, default = -Inf) - beta_x)
%>%
)) mutate(.category = list(levels(esoph$tobgp))) %>%
unnest(c(threshold, `P(Y = category)`, .category)) %>%
ggplot(aes(x = agegp, y = `P(Y = category)`, color = .category)) +
stat_pointinterval(position = position_dodge(width = .4)) +
scale_size_continuous(guide = "none") +
scale_color_manual(values = brewer.pal(6, "Blues")[-c(1,2)])
上のプロットでは、カテゴリの変化がわかりにくいので、各年度の分布がよくわかるようなものを考えてみよう。
= esoph %>%
esoph_plot data_grid(agegp) %>%
add_linpred_draws(m_esoph_rs) %>%
inner_join(thresholds, by = ".draw") %>%
mutate(`P(Y = category)` = map2(threshold, .linpred, function(alpha, beta_x)
# this part is logit^-1(alpha_j - beta*x) - logit^-1(alpha_j-1 - beta*x)
plogis(alpha - beta_x) -
plogis(lag(alpha, default = -Inf) - beta_x)
%>%
)) mutate(.category = list(levels(esoph$tobgp))) %>%
unnest(c(threshold, `P(Y = category)`, .category)) %>%
ggplot(aes(x = `P(Y = category)`, y = .category)) +
coord_cartesian(expand = FALSE) +
facet_grid(. ~ agegp, switch = "x") +
theme_classic() +
theme(strip.background = element_blank(), strip.placement = "outside") +
ggtitle("P(tobacco consumption category | age group)") +
xlab("age group")
+
esoph_plot stat_summary(fun = median, geom = "bar", fill = "gray65", width = 1, color = "white") +
stat_pointinterval()
この場合、棒グラフは誤った精度感を与える可能性があるので、代わりにCCDF棒グラフを試してみることもできる。
+
esoph_plot stat_ccdfinterval() +
expand_limits(x = 0) #ensure bars go to 0
この出力は、vignette("tidy-brms")
の対応する
m_esoph_brm
モデルからの出力と非常によく似ているはずである
(異なるプライヤーを適用している) 。ただし、rstanarm
では
brms
と比較して、出力に少し手間がかかるようになっている。