この vignette は、tidybayes
と ggdist
パッケージを使用して、 tidy
データフレームを抽出し、可視化する方法について説明する。モデル変数の事後分布からのドロー、平均、および
brms::brm
からの予測値である。より一般的な
tidybayes
の紹介と、汎用のベイズモデリング言語 (Stan や
JAGS など) での使用については、vignette("tidybayes")
を参照。
このvignetteを実行するには、以下のライブラリが必要である。
library(magrittr)
library(dplyr)
library(purrr)
library(forcats)
library(tidyr)
library(modelr)
library(ggdist)
library(tidybayes)
library(ggplot2)
library(cowplot)
library(rstan)
library(brms)
library(ggrepel)
library(RColorBrewer)
library(gganimate)
library(posterior)
theme_set(theme_tidybayes() + panel_border())
これらのオプションは、Stanの動作を高速化するためのものである。
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
tidybayes
を実証するために、5つの条件からそれぞれ10個のオブザベーションを持つ単純なデータセットを使用することにする。
set.seed(5)
= 10
n = 5
n_condition =
ABC tibble(
condition = rep(c("A","B","C","D","E"), n),
response = rnorm(n * 5, c(0,1,2,1,-1), 0.5)
)
データのスナップショットはこのような感じである。
head(ABC, 10)
condition | response |
---|---|
A | -0.4204277 |
B | 1.6921797 |
C | 1.3722541 |
D | 1.0350714 |
E | -0.1442796 |
A | -0.3014540 |
B | 0.7639168 |
C | 1.6823143 |
D | 0.8571132 |
E | -0.9309459 |
これは典型的な整頓されたフォーマットのデータフレームです: 1行に1つの観測値である。グラフィカルに
%>%
ABC ggplot(aes(y = condition, x = response)) +
geom_point()
#モデル {#model}
大域平均に向かって収縮する階層的モデルを当てはめよう。
= brm(
m ~ (1|condition),
response data = ABC,
prior = c(
prior(normal(0, 1), class = Intercept),
prior(student_t(3, 0, 1), class = sd),
prior(student_t(3, 0, 1), class = sigma)
),control = list(adapt_delta = .99),
file = "models/tidy-brms_m.rds" # cache model (can be removed)
)
結果はこのようになる。
m
## Family: gaussian
## Links: mu = identity; sigma = identity
## Formula: response ~ (1 | condition)
## Data: ABC (Number of observations: 50)
## Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
## total post-warmup draws = 4000
##
## Group-Level Effects:
## ~condition (Number of levels: 5)
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(Intercept) 1.17 0.47 0.61 2.31 1.00 717 870
##
## Population-Level Effects:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept 0.49 0.46 -0.48 1.41 1.00 876 1247
##
## Family Specific Parameters:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma 0.56 0.06 0.46 0.69 1.00 1814 2186
##
## Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
spread_draws
を使ってフィットからドローをtidyフォーマットで抽出するさて、結果が出たので、楽しいことが始まります:整頓されたフォーマットで描画を取り出すことである!まず、
関数を使って、生のモデル変数名のリストを取得する。まず、get_variables()
関数を使って、生のモデル変数名のリストを取得し、モデルからどの変数を抽出できるかを知ることにする。
get_variables(m)
## [1] "b_Intercept" "sd_condition__Intercept" "sigma" "r_condition[A,Intercept]"
## [5] "r_condition[B,Intercept]" "r_condition[C,Intercept]" "r_condition[D,Intercept]" "r_condition[E,Intercept]"
## [9] "lprior" "lp__" "accept_stat__" "stepsize__"
## [13] "treedepth__" "n_leapfrog__" "divergent__" "energy__"
ここで、b_Intercept
は大域的な平均値、r_condition[]
変数は各条件におけるその平均値からのオフセットである。これらの変数が与えられると
r_condition [A,Intercept]
r_condition [B,Intercept]
r_condition [C,Intercept]
r_condition [D,Intercept]
r_condition [E,Intercept]
各行がどちらか一方からの描画であるデータフレームが必要だろう。
r_condition [A,Intercept]
,
r_condition [B,Intercept]
, ... [C,...]
,
... [D,...]
または ... [E,...]
そして、その行がどのチェーン/繰り返し/描画から来たもので、どの条件(
A
to E
)のものであるかのインデックスを持つ列があるようなデータフレームが欲しいだろう。これにより、条件ごとにグループ化された量を簡単に計算したり、ggplotを使って条件ごとにプロットを作成したり、あるいは、元のデータとドローをマージしてデータと後置を同時にプロットしたりすることができるようになる。
tidybayes
の主力は spread_draws()
関数で、この関数はこの抽出を行ってくれる。この関数には、変数とそのインデックスを整頓されたデータフレームに抽出するために使用できる簡単な仕様書式が含まれている。
このようにモデル内の変数が与えられると
r_condition [D,Intercept]
このような列の仕様で spread_draws()
を提供することができる。
r_condition [condition,term]
ここで、condition
は D
に対応し、term
は Intercept
に対応する。この指定に対して spread_draws()
が行うことは、何も不思議なことではない。この仕様では、変数インデックスをカンマとスペースで分割する
( sep
引数を変更すれば、他の文字で分割できる)
。そして、その結果得られたインデックスに、順番に列を割り当てることができる。つまり
r_condition [D,Intercept]
はインデックス D
と
Intercept
を持ち、spread_draws()
はこれらのインデックスを列として抽出し、r_condition
からのドローの整頓されたデータフレームを生成することができる。
%>%
m spread_draws(r_condition[condition,term]) %>%
head(10)
## Warning: `gather_()` was deprecated in tidyr 1.2.0.
## ℹ Please use `gather()` instead.
## ℹ The deprecated feature was likely used in the tidybayes package.
## Please report the issue at <]8;;https://github.com/mjskay/tidybayes/issues/newhttps://github.com/mjskay/tidybayes/issues/new]8;;>.
condition | term | r_condition | .chain | .iteration | .draw |
---|---|---|---|---|---|
A | Intercept | -0.5316842 | 1 | 1 | 1 |
A | Intercept | 0.0585095 | 1 | 2 | 2 |
A | Intercept | 0.4812017 | 1 | 3 | 3 |
A | Intercept | 0.2272703 | 1 | 4 | 4 |
A | Intercept | 0.1997897 | 1 | 5 | 5 |
A | Intercept | -0.3346687 | 1 | 6 | 6 |
A | Intercept | -0.5019515 | 1 | 7 | 7 |
A | Intercept | -0.4542464 | 1 | 8 | 8 |
A | Intercept | -0.3041030 | 1 | 9 | 9 |
A | Intercept | -0.4556835 | 1 | 10 | 10 |
インデックス列には好きな名前を付けることができる。
%>%
m spread_draws(r_condition[c,t]) %>%
head(10)
c | t | r_condition | .chain | .iteration | .draw |
---|---|---|---|---|---|
A | Intercept | -0.5316842 | 1 | 1 | 1 |
A | Intercept | 0.0585095 | 1 | 2 | 2 |
A | Intercept | 0.4812017 | 1 | 3 | 3 |
A | Intercept | 0.2272703 | 1 | 4 | 4 |
A | Intercept | 0.1997897 | 1 | 5 | 5 |
A | Intercept | -0.3346687 | 1 | 6 | 6 |
A | Intercept | -0.5019515 | 1 | 7 | 7 |
A | Intercept | -0.4542464 | 1 | 8 | 8 |
A | Intercept | -0.3041030 | 1 | 9 | 9 |
A | Intercept | -0.4556835 | 1 | 10 | 10 |
しかし、前の例のような、より説明的で暗号化されていない名前の方が望ましいだろう。
このモデルでは、項は1つしかない( Intercept
)ので、そのインデックスを完全に省略して、各項目 condition
とその条件に対する r_condition
の値だけを取得することができる。
%>%
m spread_draws(r_condition[condition,]) %>%
head(10)
condition | r_condition | .chain | .iteration | .draw |
---|---|---|---|---|
A | -0.5316842 | 1 | 1 | 1 |
A | 0.0585095 | 1 | 2 | 2 |
A | 0.4812017 | 1 | 3 | 3 |
A | 0.2272703 | 1 | 4 | 4 |
A | 0.1997897 | 1 | 5 | 5 |
A | -0.3346687 | 1 | 6 | 6 |
A | -0.5019515 | 1 | 7 | 7 |
A | -0.4542464 | 1 | 8 | 8 |
A | -0.3041030 | 1 | 9 | 9 |
A | -0.4556835 | 1 | 10 | 10 |
注: spread_draws()
を Stan または JAGS
の生のサンプルで使用したことがある場合、spread_draws()
の前に recover_types
を使用してインデックス列の値を取得することに慣れているだろう
(例えば、インデックスが要因であった場合) 。rstanarm
モデルで spread_draws()
を使用する場合、これらのモデルにはすでにその情報が変数名として含まれているため、この操作は必要ない。recover_types
の詳細については、vignette("tidybayes")
を参照。
tidybayes
は、描画から点の要約や区間を整然としたフォーマットで生成するための関数群を提供する。これらの関数は、以下の命名規則に従っている。
[median|mean|mode] _ [qi|hdi]
例えば、median_qi()
, mean_qi()
,
mode_hdi()
, などである。最初の名前 ( _
の前)
は点要約のタイプを示し、2番目の名前は区間のタイプを示す。qi
は分位点間隔 (別名:等値線間隔、中心間隔、パーセンタイル間隔)
を生成し、hdi
は最高 (事後)
密度間隔を生成する。カスタム点要約または区間関数は、point_interval()
関数を用いて適用することもできる。
例えば、オブザベーションの全体平均と標準偏差の事後分布に対応するドローを抽出することができる。
%>%
m spread_draws(b_Intercept, sigma) %>%
head(10)
.chain | .iteration | .draw | b_Intercept | sigma |
---|---|---|---|---|
1 | 1 | 1 | 0.5989321 | 0.5067684 |
1 | 2 | 2 | -0.1737771 | 0.5259455 |
1 | 3 | 3 | -0.0818084 | 0.5234344 |
1 | 4 | 4 | -0.0438235 | 0.5305265 |
1 | 5 | 5 | -0.0731284 | 0.5306666 |
1 | 6 | 6 | 0.6052858 | 0.5264126 |
1 | 7 | 7 | 0.7156502 | 0.5228760 |
1 | 8 | 8 | 0.5087900 | 0.6289509 |
1 | 9 | 9 | 0.5274868 | 0.5604209 |
1 | 10 | 10 | 0.6009850 | 0.5408009 |
と同様に r_condition [condition,term]
と同様、整頓されたデータフレームが得られる。変数の中央値と95%分位間隔が欲しい場合は、median_qi()
を適用する。
%>%
m spread_draws(b_Intercept, sigma) %>%
median_qi(b_Intercept, sigma)
b_Intercept | b_Intercept.lower | b_Intercept.upper | sigma | sigma.lower | sigma.upper | .width | .point | .interval |
---|---|---|---|---|---|---|---|---|
0.4992407 | -0.477582 | 1.408877 | 0.5562745 | 0.4613838 | 0.6945521 | 0.95 | median | qi |
上記のように、中央値と区間を取得したい列を指定することができるが、列のリストを省略すると、median_qi()
は、グループ化列や特殊列ではないすべての列 ( .chain
,
.iteration
, または .draw
のような)
を使用する。したがって、上記の例では、b_Intercept
と
sigma
は、モデルから収集した唯一の列でもあるため、median_qi()
の冗長な引数となっている。そこで、次のように単純化することができる。
%>%
m spread_draws(b_Intercept, sigma) %>%
median_qi()
b_Intercept | b_Intercept.lower | b_Intercept.upper | sigma | sigma.lower | sigma.upper | .width | .point | .interval |
---|---|---|---|---|---|---|---|---|
0.4992407 | -0.477582 | 1.408877 | 0.5562745 | 0.4613838 | 0.6945521 | 0.95 | median | qi |
もし、長い形式のインターバルリストが欲しい場合は、代わりに
gather_draws()
を使ってみよう。
%>%
m gather_draws(b_Intercept, sigma) %>%
median_qi()
.variable | .value | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
b_Intercept | 0.4992407 | -0.4775820 | 1.4088766 | 0.95 | median | qi |
sigma | 0.5562745 | 0.4613838 | 0.6945521 | 0.95 | median | qi |
gather_draws()
の詳細については、vignette("tidybayes")
を参照。
r_condition
のような1つ以上のインデックスを持つモデル変数があるとき、前と同じように
median_qi()
(または point_interval()
ファミリーの他の関数) を適用することができる。
%>%
m spread_draws(r_condition[condition,]) %>%
median_qi()
condition | r_condition | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
A | -0.3050387 | -1.2623359 | 0.7042673 | 0.95 | median | qi |
B | 0.4996977 | -0.4445127 | 1.5156880 | 0.95 | median | qi |
C | 1.3369944 | 0.4167176 | 2.3585026 | 0.95 | median | qi |
D | 0.5134115 | -0.4259470 | 1.5452761 | 0.95 | median | qi |
E | -1.3775040 | -2.3418243 | -0.3958398 | 0.95 | median | qi |
median_qi()
はどのようにして集計する対象を知ったのだろうか?
spread_draws()
によって返されたデータフレームは、あなたが渡したすべてのインデックス変数によって自動的にグループ化される。この場合、spread_draws()
はその結果を condition
によってグループ化することを意味する。median_qi()
はこれらのグループを尊重し、すべてのグループ内のポイントの要約と間隔を計算する。そして、median_qi()
には列が渡されなかったので、唯一の非特殊列 ( .
-prefixed)
、非グループ列である r_condition
に対して処理を行う。したがって、上記の短縮された構文は、より冗長なこの呼び出しと等価である。
%>%
m spread_draws(r_condition[condition,]) %>%
group_by(condition) %>% # this line not necessary (done by spread_draws)
median_qi(r_condition) # b is not necessary (it is the only non-group column)
condition | r_condition | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
A | -0.3050387 | -1.2623359 | 0.7042673 | 0.95 | median | qi |
B | 0.4996977 | -0.4445127 | 1.5156880 | 0.95 | median | qi |
C | 1.3369944 | 0.4167176 | 2.3585026 | 0.95 | median | qi |
D | 0.5134115 | -0.4259470 | 1.5452761 | 0.95 | median | qi |
E | -1.3775040 | -2.3418243 | -0.3958398 | 0.95 | median | qi |
tidybayes
は、 の実装も提供している。
posterior::summarise_draws()
グループ化されたデータフレーム
( tidybayes::summaries_draws.grouped_df()
)
を作成することができる。
を使用すると、コンバージェンス診断が迅速に行える。
%>%
m spread_draws(r_condition[condition,]) %>%
summarise_draws()
condition | variable | mean | median | sd | mad | q5 | q95 | rhat | ess_bulk | ess_tail |
---|---|---|---|---|---|---|---|---|---|---|
A | r_condition | -0.2941127 | -0.3050387 | 0.4839759 | 0.4477507 | -1.0717280 | 0.5003624 | 1.000754 | 942.7064 | 1336.319 |
B | r_condition | 0.5118595 | 0.4996977 | 0.4837979 | 0.4434061 | -0.2640635 | 1.3221608 | 1.001596 | 910.2851 | 1512.549 |
C | r_condition | 1.3438245 | 1.3369944 | 0.4850489 | 0.4420331 | 0.5766851 | 2.1565351 | 1.001881 | 979.4256 | 1469.083 |
D | r_condition | 0.5245269 | 0.5134115 | 0.4859456 | 0.4471457 | -0.2692829 | 1.3529230 | 1.002230 | 930.1624 | 1279.640 |
E | r_condition | -1.3782117 | -1.3775040 | 0.4835818 | 0.4532613 | -2.1710976 | -0.5897791 | 1.001688 | 941.5406 | 1395.414 |
spread_draws()
と
は、異なるインデックスを持つ変数を同じデータフレームに抽出することをサポートしている。同じ名前のインデックスは自動的にマッチングされ、必要に応じて値が複製され、すべてのインデックスのすべてのレベルの組み合わせごとに1つの行が作成される。例えば、各条件における平均を計算したい場合がある
(これを と呼ぶ) 。このモデルでは、その平均は切片(
)と与えられた条件での効果( )である。
gather_draws()
condition_mean
b_Intercept
r_condition
b_Intercept
と r_condition
からの描画を一つのデータフレームにまとめることができる。
%>%
m spread_draws(b_Intercept, r_condition[condition,]) %>%
head(10)
.chain | .iteration | .draw | b_Intercept | condition | r_condition |
---|---|---|---|---|---|
1 | 1 | 1 | 0.5989321 | A | -0.5316842 |
1 | 1 | 1 | 0.5989321 | B | 0.3296616 |
1 | 1 | 1 | 0.5989321 | C | 1.0567015 |
1 | 1 | 1 | 0.5989321 | D | 0.4020400 |
1 | 1 | 1 | 0.5989321 | E | -1.2508539 |
1 | 2 | 2 | -0.1737771 | A | 0.0585095 |
1 | 2 | 2 | -0.1737771 | B | 0.9002136 |
1 | 2 | 2 | -0.1737771 | C | 1.9663147 |
1 | 2 | 2 | -0.1737771 | D | 1.4094724 |
1 | 2 | 2 | -0.1737771 | E | -0.8438687 |
各抽選の中で、b_Intercept
は r_condition
の各インデックスに対応するように必要に応じて繰り返される。
したがって、dplyr の mutate
関数を使用して、それらの合計
condition_mean
(各条件の平均) を求めることができる。
%>%
m spread_draws(`b_Intercept`, r_condition[condition,]) %>%
mutate(condition_mean = b_Intercept + r_condition) %>%
median_qi(condition_mean)
condition | condition_mean | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
A | 0.1946425 | -0.1533118 | 0.5569814 | 0.95 | median | qi |
B | 1.0009559 | 0.6633324 | 1.3356618 | 0.95 | median | qi |
C | 1.8338983 | 1.4690128 | 2.1777677 | 0.95 | median | qi |
D | 1.0168412 | 0.6623191 | 1.3605038 | 0.95 | median | qi |
E | -0.8935669 | -1.2326801 | -0.5357873 | 0.95 | median | qi |
median_qi()
は整頓された評価 (
vignette("tidy-evaluation ", package =" rlang")
を使うので、列名だけでなく、列の式も受け取ることができる。したがって、condition_mean
の計算を mutate
から median_qi()
に移すことで、上の例を単純化することができる。
%>%
m spread_draws(b_Intercept, r_condition[condition,]) %>%
median_qi(condition_mean = b_Intercept + r_condition)
condition | condition_mean | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
A | 0.1946425 | -0.1533118 | 0.5569814 | 0.95 | median | qi |
B | 1.0009559 | 0.6633324 | 1.3356618 | 0.95 | median | qi |
C | 1.8338983 | 1.4690128 | 2.1777677 | 0.95 | median | qi |
D | 1.0168412 | 0.6623191 | 1.3605038 | 0.95 | median | qi |
E | -0.8935669 | -1.2326801 | -0.5357873 | 0.95 | median | qi |
median_qi()
とその姉妹関数は、
引数を設定することで、任意の数の確率区間を生成することができる。
.width =
%>%
m spread_draws(b_Intercept, r_condition[condition,]) %>%
median_qi(condition_mean = b_Intercept + r_condition, .width = c(.95, .8, .5))
condition | condition_mean | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|
A | 0.1946425 | -0.1533118 | 0.5569814 | 0.95 | median | qi |
B | 1.0009559 | 0.6633324 | 1.3356618 | 0.95 | median | qi |
C | 1.8338983 | 1.4690128 | 2.1777677 | 0.95 | median | qi |
D | 1.0168412 | 0.6623191 | 1.3605038 | 0.95 | median | qi |
E | -0.8935669 | -1.2326801 | -0.5357873 | 0.95 | median | qi |
A | 0.1946425 | -0.0374505 | 0.4237867 | 0.80 | median | qi |
B | 1.0009559 | 0.7781791 | 1.2235742 | 0.80 | median | qi |
C | 1.8338983 | 1.6097208 | 2.0638043 | 0.80 | median | qi |
D | 1.0168412 | 0.7850910 | 1.2380783 | 0.80 | median | qi |
E | -0.8935669 | -1.1149866 | -0.6554495 | 0.80 | median | qi |
A | 0.1946425 | 0.0738242 | 0.3115411 | 0.50 | median | qi |
B | 1.0009559 | 0.8844050 | 1.1174733 | 0.50 | median | qi |
C | 1.8338983 | 1.7129847 | 1.9530778 | 0.50 | median | qi |
D | 1.0168412 | 0.8934155 | 1.1362549 | 0.50 | median | qi |
E | -0.8935669 | -1.0120600 | -0.7714177 | 0.50 | median | qi |
結果は整然とした形式で、1グループにつき1行、不確定性区間幅(
.width
)である。これはプロットを容易にする。たとえば、-.width
を
size
の美学に割り当てると、すべての区間が表示され、太い線がより小さい区間に対応するようになる。ggdist::geom_pointinterval()
geomは、複数の確率レベルを持つ点のプロットを作成するために、データ中の
.width
列に基づいて、size
美学を自動的に適切に設定する。
%>%
m spread_draws(b_Intercept, r_condition[condition,]) %>%
median_qi(condition_mean = b_Intercept + r_condition, .width = c(.95, .66)) %>%
ggplot(aes(y = condition, x = condition_mean, xmin = .lower, xmax = .upper)) +
geom_pointinterval()
## Warning: Using the `size` aesthietic with geom_segment was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.
区間とともに密度を見るには、ggdist::stat_eye()
(区間とバイオリンプロットを組み合わせた「アイプロット」) 、または
ggdist::stat_halfeye()
(区間+密度プロット)
を使用することができる。
%>%
m spread_draws(b_Intercept, r_condition[condition,]) %>%
mutate(condition_mean = b_Intercept + r_condition) %>%
ggplot(aes(y = condition, x = condition_mean)) +
stat_halfeye()
あるいは、密度の一部をカラーでアノテートしたいとする。fill
の美学は、ggdist::stat_halfeye()
を含む
ggdist::geom_slabinterval()
ファミリーのすべてのジオムおよび統計において、スラブ内で変化させることができる。例えば、ドメイン固有のROPE
(region of practical equivalence)
をアノテーションしたい場合、次のようなことが可能である。
%>%
m spread_draws(b_Intercept, r_condition[condition,]) %>%
mutate(condition_mean = b_Intercept + r_condition) %>%
ggplot(aes(y = condition, x = condition_mean, fill = stat(abs(x) < .8))) +
stat_halfeye() +
geom_vline(xintercept = c(-.8, .8), linetype = "dashed") +
scale_fill_manual(values = c("gray80", "skyblue"))
## Warning: `stat(abs(x) < 0.8)` was deprecated in ggplot2 3.4.0.
## ℹ Please use `after_stat(abs(x) < 0.8)` instead.
stat_slabinterval
分布を視覚化するための様々な追加統計が、ggdist::geom_slabinterval()
ファミリーの統計とジオムにはある。
を参照。 vignette("slabinterval ", package =" ggdist")
を参照。
前の例のように条件付き平均を手動で計算するのではなく、brms::posterior_epred()
(事後予測値の期待値から事後ドローを与える;すなわち、条件付き平均の事後分布)
に類似しているが、整然としたデータ形式を使用する
add_epred_draws()
を使用することができる。これを
modelr::data_grid()
と組み合わせて、まず欲しい予測を記述したグリッドを生成し、次にそのグリッドを条件付き平均からのドローの長大なデータフレームに変換することができる。
%>%
ABC data_grid(condition) %>%
add_epred_draws(m) %>%
head(10)
condition | .row | .chain | .iteration | .draw | .epred |
---|---|---|---|---|---|
A | 1 | NA | NA | 1 | 0.0672479 |
A | 1 | NA | NA | 2 | -0.1152676 |
A | 1 | NA | NA | 3 | 0.3993933 |
A | 1 | NA | NA | 4 | 0.1834468 |
A | 1 | NA | NA | 5 | 0.1266612 |
A | 1 | NA | NA | 6 | 0.2706171 |
A | 1 | NA | NA | 7 | 0.2136986 |
A | 1 | NA | NA | 8 | 0.0545436 |
A | 1 | NA | NA | 9 | 0.2233838 |
A | 1 | NA | NA | 10 | 0.1453015 |
この例をプロットするために、ggplot 内で描画を点と区間にまとめる
ggdist::geom_pointinterval()
の代わりに
ggdist::stat_pointinterval()
を使用することも示す。
%>%
ABC data_grid(condition) %>%
add_epred_draws(m) %>%
ggplot(aes(x = .epred, y = condition)) +
stat_pointinterval(.width = c(.66, .95))
アルファレベルがたまたまあなたが行おうとしている決定と一致する場合、区間は良いが、事後的な形状を得ることはより良いことである (したがって、上記の目のプロット) 。一方、密度プロットから推論するのは不正確である (ある形状の面積を別の形状の割合で推定するのは難しい知覚的作業である) 。頻度形式での確率の推論はより簡単で、 quantile dotplots ( Kay et al. 2016, Fernandes et al. 2018)の動機となった。これは任意の間隔 (プロットのドット解像度まで、下の例では100) を正確に推定することも可能である。
tidybayes のジオムの slabinterval ファミリーには dots
と
dotsinterval
ファミリーがあり、ドットプロットの適切なビンサイズを自動的に決定し、サンプルから分位を計算して分位ドットプロットを構築できる。ggdist::stat_dotsinterval()
はサンプルに使用するために設計されたバリアントである。
%>%
ABC data_grid(condition) %>%
add_epred_draws(m) %>%
ggplot(aes(x = .epred, y = condition)) +
stat_dotsinterval(quantiles = 100)
つまり、事後情報を1つの標準的な点または区間として考えるのではなく、 (例えば) 100のほぼ等しい可能性のある点として表現することである。
ここで、add_epred_draws()
は
brms::posterior_epred()
に類似しており、add_predicted_draws()
は
brms::posterior_predict()
に類似しており、事後予測分布からのドローを与えている。
以下は、ggdist::stat_slab()
を用いてプロットした事後予測分布の例である。
%>%
ABC data_grid(condition) %>%
add_predicted_draws(m) %>%
ggplot(aes(x = .prediction, y = condition)) +
stat_slab()
また、ggdist::stat_interval()
を使って、データと一緒に予測バンドをプロットすることもできる。
%>%
ABC data_grid(condition) %>%
add_predicted_draws(m) %>%
ggplot(aes(y = condition, x = .prediction)) +
stat_interval(.width = c(.50, .80, .95, .99)) +
geom_point(aes(x = response), data = ABC) +
scale_color_brewer()
合わせて、データ、事後予測、平均の事後分布。
= ABC %>%
grid data_grid(condition)
= grid %>%
means add_epred_draws(m)
= grid %>%
preds add_predicted_draws(m)
%>%
ABC ggplot(aes(y = condition, x = response)) +
stat_interval(aes(x = .prediction), data = preds) +
stat_pointinterval(aes(x = .epred), data = means, .width = c(.66, .95), position = position_nudge(y = -0.3)) +
geom_point() +
scale_color_brewer()
事後予測に対する上記のアプローチは、単一の事後予測分布を与えるために、パラメータの不確実性の上に統合される。もう一つのアプローチは、John Kruschkeが彼の著書 Doing Bayesian Data Analysisでよく使っているもので、事後予測によって暗示されるいくつかの可能な予測分布を示すことによって、予測の不確かさとパラメータの不確かさの両方を同時に示すことを試みるものである。
これは、ある予測値に対する分布パラメータを求めることで、非常に簡単に行うことができる。ここでは、明示的に
dpar = c("mu "," sigma")
add_epred_draws()
を設定することで明示的に行う。明示的にパラメータを指定するのではなく、dpar = TRUE
を設定して、モデル内のすべての分布パラメータからドローを取得することもできる。これはbrmsがサポートするすべての応答分布に対して機能する。そして、sample_draws()
を使って少数のドローを選択し、ggdist::stat_dist_slab()
を使って、mu
と sigma
の値によって暗示される各予測分布を可視化することができるのである。
%>%
ABC data_grid(condition) %>%
add_epred_draws(m, dpar = c("mu", "sigma")) %>%
sample_draws(30) %>%
ggplot(aes(y = condition)) +
stat_dist_slab(aes(dist = "norm", arg1 = mu, arg2 = sigma),
slab_color = "gray65", alpha = 1/10, fill = NA
+
) geom_point(aes(x = response), data = ABC, shape = 21, fill = "#9ECAE1", size = 2)
これらのチャート (およびその便利なバリエーション) のより詳しい説明については、 Solomon Kurz’s excellent blog post on the topic を参照。
予測分布のクルーシュケ風プロットと、事後平均を示す半眼を組み合わせることもできる。
%>%
ABC data_grid(condition) %>%
add_epred_draws(m, dpar = c("mu", "sigma")) %>%
ggplot(aes(x = condition)) +
stat_dist_slab(aes(dist = "norm", arg1 = mu, arg2 = sigma),
slab_color = "gray65", alpha = 1/10, fill = NA, data = . %>% sample_draws(30), scale = .5
+
) stat_halfeye(aes(y = .epred), side = "bottom", scale = .5) +
geom_point(aes(y = response), data = ABC, shape = 21, fill = "#9ECAE1", size = 2, position = position_nudge(x = -.2))
不確実性を伴う適合曲線の描画を実証するために、mtcars
のデータセットの一部に、少し素朴なモデルを適合させてみよう。
= brm(
m_mpg ~ hp * cyl,
mpg data = mtcars,
file = "models/tidy-brms_m_mpg.rds" # cache model (can be removed)
)
確率帯を使ったフィットカーブを描くことができる。
%>%
mtcars group_by(cyl) %>%
data_grid(hp = seq_range(hp, n = 51)) %>%
add_epred_draws(m_mpg) %>%
ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
stat_lineribbon(aes(y = .epred)) +
geom_point(data = mtcars) +
scale_fill_brewer(palette = "Greys") +
scale_color_brewer(palette = "Set2")
## Warning: Using the `size` aesthietic with geom_ribbon was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.
## Warning: Unknown or uninitialised column: `linewidth`.
## Warning: Using the `size` aesthietic with geom_line was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
あるいは、適当な数のフィットラインをサンプリングして (たとえば100本) 、それらをオーバープロットすることもできる。
%>%
mtcars group_by(cyl) %>%
data_grid(hp = seq_range(hp, n = 101)) %>%
# NOTE: this shows the use of ndraws to subsample within add_epred_draws()
# ONLY do this IF you are planning to make spaghetti plots, etc.
# NEVER subsample to a small sample to plot intervals, densities, etc.
add_epred_draws(m_mpg, ndraws = 100) %>%
ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
geom_line(aes(y = .epred, group = paste(cyl, .draw)), alpha = .1) +
geom_point(data = mtcars) +
scale_color_brewer(palette = "Dark2")
あるいは、フィットしたラインのアニメーション hypothetical outcome plots (HOPs) を作成することもできる。訳註:エラーのため eval=FALSE
set.seed(123456)
# NOTE: using a small number of draws to keep this example
# small, but in practice you probably want 50 or 100
= 20
ndraws
= mtcars %>%
p group_by(cyl) %>%
data_grid(hp = seq_range(hp, n = 101)) %>%
add_epred_draws(m_mpg, ndraws = ndraws) %>%
ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
geom_line(aes(y = .epred, group = paste(cyl, .draw))) +
geom_point(data = mtcars) +
scale_color_brewer(palette = "Dark2") +
transition_states(.draw, 0, 1) +
shadow_mark(future = TRUE, color = "gray50", alpha = 1/20)
animate(p, nframes = ndraws, fps = 2.5, width = 432, height = 288, res = 96, dev = "png", type = "cairo")
あるいは、 (平均値ではなく)
事後予測値をプロットすることもできる。この例では
また、alpha
、重なり合ったバンドを見やすくするために使用する。
%>%
mtcars group_by(cyl) %>%
data_grid(hp = seq_range(hp, n = 101)) %>%
add_predicted_draws(m_mpg) %>%
ggplot(aes(x = hp, y = mpg, color = ordered(cyl), fill = ordered(cyl))) +
stat_lineribbon(aes(y = .prediction), .width = c(.95, .80, .50), alpha = 1/4) +
geom_point(data = mtcars) +
scale_fill_brewer(palette = "Set2") +
scale_color_brewer(palette = "Dark2")
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
これはグループごとに判断するのは難しいので、おそらく複数のプロットにファセットするのがよいだろう。幸い、ggplotを使っているので、そのような関数が組み込まれている。
%>%
mtcars group_by(cyl) %>%
data_grid(hp = seq_range(hp, n = 101)) %>%
add_predicted_draws(m_mpg) %>%
ggplot(aes(x = hp, y = mpg)) +
stat_lineribbon(aes(y = .prediction), .width = c(.99, .95, .8, .5), color = brewer.pal(5, "Blues")[[5]]) +
geom_point(data = mtcars) +
scale_fill_brewer() +
facet_grid(. ~ cyl, space = "free_x", scales = "free_x")
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
brms::brm()
は、場所 (たとえば、平均)
以外の応答分布のパラメータのためのサブモデルをセットアップすることもできる。たとえば、標準偏差のような分散パラメータが、予測変数の関数であることを許すことができる。
この方法は、分散が一定でない場合 (ラテン語で難読化を好む人々は、_異種分散とも呼ぶ) に有用である。例えば、2つのグループがあり、それぞれが異なる平均応答_と分散を持つことを想像してみてみよう。
set.seed(1234)
= tibble(
AB group = rep(c("a", "b"), each = 20),
response = rnorm(40, mean = rep(c(1, 5), each = 20), sd = rep(c(1, 3), each = 20))
)
%>%
AB ggplot(aes(x = response, y = group)) +
geom_point()
ここでは、response
の平均_と標準偏差_を
group
に依存させるモデルを紹介する。
= brm(
m_ab bf(
~ group,
response ~ group
sigma
),data = AB,
file = "models/tidy-brms_m_ab.rds" # cache model (can be removed)
)
平均の事後分布 response
を事後予測区間とデータと並べてプロットすることができる。
= AB %>%
grid data_grid(group)
= grid %>%
means add_epred_draws(m_ab)
= grid %>%
preds add_predicted_draws(m_ab)
%>%
AB ggplot(aes(x = response, y = group)) +
stat_halfeye(aes(x = .epred), scale = 0.6, position = position_nudge(y = 0.175), data = means) +
stat_interval(aes(x = .prediction), data = preds) +
geom_point(data = AB) +
scale_color_brewer()
これは、各グループの平均値の後置 (黒の区間と密度プロット) および事後予測区間 (青) を示している。
グループ b
の予測区間はグループ a
よりも大きくなっているが、これはモデルが各グループに対して異なる標準偏差を当てはめたことがある。対応する分布パラメータである
sigma
を add_epred_draws()
の引数
dpar
を使って抽出することで、どのように変化するかを見ることができる。
%>%
grid add_epred_draws(m_ab, dpar = TRUE) %>%
ggplot(aes(x = sigma, y = group)) +
stat_halfeye() +
geom_vline(xintercept = 0, linetype = "dashed")
dpar = TRUE
を設定すると、すべての分布パラメータが
add_epred_draws()
の結果に追加列として追加される。特定のパラメータだけが欲しい場合は、それを指定する
(あるいは、欲しいパラメータだけのリストを指定する)
。上記のモデルにおいて、dpar = TRUE
は以下と等価である。
dpar = list("mu "," sigma")
.
各条件の平均を比較したい場合、compare_levels()
は、ある因子のレベル間のある変数の値の比較を容易にする。デフォルトでは、一対の差分をすべて計算する。
ggdist::stat_halfeye()
を使って
compare_levels()
をデモしてみよう。また
差の平均値で再注文する。
%>%
m spread_draws(r_condition[condition,]) %>%
compare_levels(r_condition, by = condition) %>%
ungroup() %>%
mutate(condition = reorder(condition, r_condition)) %>%
ggplot(aes(y = condition, x = r_condition)) +
stat_halfeye() +
geom_vline(xintercept = 0, linetype = "dashed")
brmsの順序回帰モデルおよび多項回帰モデル用の関数
posterior_epred()
は、各抽選に対して複数の変数を返する:各結果カテゴリに対して1つである
(潜在線形予測変数の抽選を返す rstanarm::stan_polr()
モデルとは対照的) 。tidybayes
の理念は、モデルによって出力されるどんなフォーマットでも整頓することである。その理念に従って、順序および多項式
brms
モデルに適用すると、add_epred_draws()
は
.category
という追加の列を追加し、各カテゴリーの変数を含む別の行が、すべてのドローと予測変数について出力される。
mtcars
データセットを使って、車の燃費 (マイル/ガロン)
が与えられたときの車の気筒数を予測するモデルをあてはめることにする。これは少し因果関係が逆であるが
(おそらくシリンダー数が走行距離を引き起こすのだろう)
、だからといってこれは立派な予測タスクではない
(私はおそらく車について何か知っている人に車のMPGを伝えることができ、彼らはエンジンのシリンダー数を当てるのにそれなりにうまくやれるだろう)
。
モデルを適合させる前に、cyl
列を順序付き因子にすることによって、データセットをきれいにしよう
(デフォルトでは単なる数字) 。
= mtcars %>%
mtcars_clean mutate(cyl = ordered(cyl))
head(mtcars_clean)
mpg | cyl | disp | hp | drat | wt | qsec | vs | am | gear | carb | |
---|---|---|---|---|---|---|---|---|---|---|---|
Mazda RX4 | 21.0 | 6 | 160 | 110 | 3.90 | 2.620 | 16.46 | 0 | 1 | 4 | 4 |
Mazda RX4 Wag | 21.0 | 6 | 160 | 110 | 3.90 | 2.875 | 17.02 | 0 | 1 | 4 | 4 |
Datsun 710 | 22.8 | 4 | 108 | 93 | 3.85 | 2.320 | 18.61 | 1 | 1 | 4 | 1 |
Hornet 4 Drive | 21.4 | 6 | 258 | 110 | 3.08 | 3.215 | 19.44 | 1 | 0 | 3 | 1 |
Hornet Sportabout | 18.7 | 8 | 360 | 175 | 3.15 | 3.440 | 17.02 | 0 | 0 | 3 | 2 |
Valiant | 18.1 | 6 | 225 | 105 | 2.76 | 3.460 | 20.22 | 1 | 0 | 3 | 1 |
そして、順序回帰モデルを当てはめる。
= brm(
m_cyl ~ mpg,
cyl data = mtcars_clean,
family = cumulative,
seed = 58393,
file = "models/tidy-brms_m_cyl.rds" # cache model (can be removed)
)
add_epred_draws()
は 列を含み、
は応答がそのカテゴリにある確率の事後分布からのドローを含む。例えば、これはデータセットの1行目のフィットである。
.category
.epred
tibble(mpg = 21) %>%
add_epred_draws(m_cyl) %>%
median_qi(.epred)
mpg | .row | .category | .epred | .lower | .upper | .width | .point | .interval |
---|---|---|---|---|---|---|---|---|
21 | 1 | 4 | 0.3471635 | 0.0891023 | 0.7214931 | 0.95 | median | qi |
21 | 1 | 6 | 0.6223695 | 0.2573263 | 0.8955651 | 0.95 | median | qi |
21 | 1 | 8 | 0.0137229 | 0.0003088 | 0.1256028 | 0.95 | median | qi |
注: .category
変数が元の因子レベルの名前を保持するためには、次のようにする。
は、brms
バージョン 2.15.9
以降を使用している必要がある。
データセットに対して予測される確率のフィットラインをプロットすることができた。
= mtcars_clean %>%
data_plot ggplot(aes(x = mpg, y = cyl, color = cyl)) +
geom_point() +
scale_color_brewer(palette = "Dark2", name = "cyl")
= mtcars_clean %>%
fit_plot data_grid(mpg = seq_range(mpg, n = 101)) %>%
add_epred_draws(m_cyl, value = "P(cyl | mpg)", category = "cyl") %>%
ggplot(aes(x = mpg, y = `P(cyl | mpg)`, color = cyl)) +
stat_lineribbon(aes(fill = cyl), alpha = 1/5) +
scale_color_brewer(palette = "Dark2") +
scale_fill_brewer(palette = "Dark2")
plot_grid(ncol = 1, align = "v",
data_plot,
fit_plot )
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
上記の表示では、ある特定の値 mpg
において、cyl
の異なる値に対する P(cyl|mpg)
の相関を見ることはできない。 例えば、P(cyl = 6|mpg = 20)
が高い事後処理の部分では、P(cyl = 4|mpg = 20)
と
P(cyl = 8|mpg = 20)
は低くなければならない (これらの合計は
1 になるはずだことがある) 。
この相関を見る一つの方法は、フィットした線だけに hypothetical outcome plots (HOPs) を使って、リボンから「切り離す」ことだろう (別の方法は、このドキュメントで以前に示したように、線のアンサンブルの上に HOP を使うことである)。アニメーションを使うことで、線がどのように並んだり反対方向に動いたりして、それらの相関のパターンを明らかにすることができる。
# NOTE: using a small number of draws to keep this example
# small, but in practice you probably want 50 or 100
= 20
ndraws
= mtcars_clean %>%
p data_grid(mpg = seq_range(mpg, n = 101)) %>%
add_epred_draws(m_cyl, value = "P(cyl | mpg)", category = "cyl") %>%
ggplot(aes(x = mpg, y = `P(cyl | mpg)`, color = cyl)) +
# we remove the `.draw` column from the data for stat_lineribbon so that the same ribbons
# are drawn on every frame (since we use .draw to determine the transitions below)
stat_lineribbon(aes(fill = cyl), alpha = 1/5, color = NA, data = . %>% select(-.draw)) +
# we use sample_draws to subsample at the level of geom_line (rather than for the full dataset
# as in previous HOPs examples) because we need the full set of draws for stat_lineribbon above
geom_line(aes(group = paste(.draw, cyl)), size = 1, data = . %>% sample_draws(ndraws)) +
scale_color_brewer(palette = "Dark2") +
scale_fill_brewer(palette = "Dark2") +
transition_manual(.draw)
## Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` instead.
animate(p, nframes = ndraws, fps = 2.5, width = 576, height = 192, res = 96, dev = "png", type = "cairo")
## Warning: Unknown or uninitialised column: `linewidth`.
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
線がどのように一緒に動くか、また、どのように一緒に上下に動くか、あるいは反対に動くかに注目してみよう。上のグラフのxの位置でこれらの線をスライスして
(例えば、mpg = 20
)
、散布図行列を使ってそれらの間の相関関係を見ることができる。
tibble(mpg = 20) %>%
add_epred_draws(m_cyl, value = "P(cyl | mpg = 20)", category = "cyl") %>%
ungroup() %>%
select(.draw, cyl, `P(cyl | mpg = 20)`) %>%
gather_pairs(cyl, `P(cyl | mpg = 20)`, triangle = "both") %>%
filter(.row != .col) %>%
ggplot(aes(.x, .y)) +
geom_point(alpha = 1/50) +
facet_grid(.row ~ .col) +
ylab("P(cyl = row | mpg = 20)") +
xlab("P(cyl = col | mpg = 20)")
順序分布の平均について話すことは、しばしば意味を持たないが、この特定のケースでは、ガロンあたりのマイル数を与えられた車のシリンダー数の期待値は、意味のある量であると主張することができる。あるガロンあたりのマイル数を与えられた車の平均シリンダー数の事後分布を次のようにプロットすることができる。
\[
\textrm{E}[\textrm{cyl}|\textrm{mpg}=m] = \sum_{c \in \{4,6,8\}} c\cdot
\textrm{P}(\textrm{cyl}=c|\textrm{mpg}=m)
\] の事後分布をモデルから導き出すことができる。 $ [|=m] $
の事後分布を導出できる。このモデルは、 \(\textrm{P}(\textrm{cyl}=c|\textrm{mpg}=m)\)
の事後分布を与える : mpg
= \(m\) のとき、cyl
(aka
.category
) = \(c\)
の応答スケール線形予測器 ( add_epred_draws()
からの
.epred
列) は、 \(\textrm{P}(\textrm{cyl}=c|\textrm{mpg}=m)\)
になる。したがって、私たちは .draw
内でグループ化し、summarise
を使って期待値を計算することができる。
= . %>%
label_data_function ungroup() %>%
filter(mpg == quantile(mpg, .47)) %>%
summarise_if(is.numeric, mean)
= mtcars_clean %>%
data_plot_with_mean data_grid(mpg = seq_range(mpg, n = 101)) %>%
# NOTE: this shows the use of ndraws to subsample within add_epred_draws()
# ONLY do this IF you are planning to make spaghetti plots, etc.
# NEVER subsample to a small sample to plot intervals, densities, etc.
add_epred_draws(m_cyl, value = "P(cyl | mpg)", category = "cyl", ndraws = 100) %>%
group_by(mpg, .draw) %>%
# calculate expected cylinder value
mutate(cyl = as.numeric(as.character(cyl))) %>%
summarise(cyl = sum(cyl * `P(cyl | mpg)`), .groups = "drop") %>%
ggplot(aes(x = mpg, y = cyl)) +
geom_line(aes(group = .draw), alpha = 5/100) +
geom_point(aes(y = as.numeric(as.character(cyl)), fill = cyl), data = mtcars_clean, shape = 21, size = 2) +
geom_text(aes(x = mpg + 4), label = "E[cyl | mpg]", data = label_data_function, hjust = 0) +
geom_segment(aes(yend = cyl, xend = mpg + 3.9), data = label_data_function) +
scale_fill_brewer(palette = "Set2", name = "cyl")
plot_grid(ncol = 1, align = "v",
data_plot_with_mean,
fit_plot )
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
では、事後予測チェックをしてみよう。事後予測はデータと同じように見えるだろうか?
mpg
ここでは、元のデータセットにあったのと同じ値 (灰色の円)
で新たに予測を行い、観測されたデータ (色のついた円)
と共にプロットしてむ。
%>%
mtcars_clean # we use `select` instead of `data_grid` here because we want to make posterior predictions
# for exactly the same set of observations we have in the original data
select(mpg) %>%
add_predicted_draws(m_cyl, seed = 1234) %>%
# recover original factor labels
mutate(cyl = levels(mtcars_clean$cyl)[.prediction]) %>%
ggplot(aes(x = mpg, y = cyl)) +
geom_count(color = "gray75") +
geom_point(aes(fill = cyl), data = mtcars_clean, shape = 21, size = 2) +
scale_fill_brewer(palette = "Dark2") +
geom_label_repel(
data = . %>% ungroup() %>% filter(cyl == "8") %>% filter(mpg == max(mpg)) %>% dplyr::slice(1),
label = "posterior predictions", xlim = c(26, NA), ylim = c(NA, 2.8), point.padding = 0.3,
label.size = NA, color = "gray50", segment.color = "gray75"
+
) geom_label_repel(
data = mtcars_clean %>% filter(cyl == "6") %>% filter(mpg == max(mpg)) %>% dplyr::slice(1),
label = "observed data", xlim = c(26, NA), ylim = c(2.2, NA), point.padding = 0.2,
label.size = NA, segment.color = "gray35"
)
これはかなり良さそうである。もう1つの典型的な事後予測チェックプロットを使ってチェックしてみよう:応答の観察された分布に対して、応答の多くのシミュレーションされた分布(
cyl
)である。連続応答変数の場合、これは通常、密度プロットで行われる。ここでは、応答変数が離散なので、各ビンでの事後予測数を折れ線グラフでプロットする。
%>%
mtcars_clean select(mpg) %>%
add_predicted_draws(m_cyl, ndraws = 100, seed = 12345) %>%
# recover original factor labels
mutate(cyl = levels(mtcars_clean$cyl)[.prediction]) %>%
ggplot(aes(x = cyl)) +
stat_count(aes(group = NA), geom = "line", data = mtcars_clean, color = "red", size = 3, alpha = .5) +
stat_count(aes(group = .draw), geom = "line", position = "identity", alpha = .05) +
geom_label(data = data.frame(cyl = "4"), y = 9.5, label = "posterior\npredictions",
hjust = 1, color = "gray50", lineheight = 1, label.size = NA) +
geom_label(data = data.frame(cyl = "8"), y = 14, label = "observed\ndata",
hjust = 0, color = "red", lineheight = 1, label.size = NA)
これもまた良さそうである。
これらの事後予測は、散布図行列として見ることもできる。gather_pairs()
は、ggplot2::facet_grid()
を用いて、ggplot
でカスタム散布図行列 (あるいは、任意の行列スタイルの小倍数プロット)
の作成に適した長大なデータフレームを簡単に生成することが可能である。
set.seed(12345)
%>%
mtcars_clean select(mpg) %>%
add_predicted_draws(m_cyl) %>%
# recover original factor labels. Must ungroup first so that the
# factor is created in the same way in all groups; this is a workaround
# because brms no longer returns labelled predictions (hopefully that
# is fixed then this will no longer be necessary)
ungroup() %>%
mutate(cyl = factor(levels(mtcars_clean$cyl)[.prediction])) %>%
# need .drop = FALSE to ensure 0 counts are not dropped
group_by(.draw, .drop = FALSE) %>%
count(cyl) %>%
gather_pairs(cyl, n) %>%
ggplot(aes(.x, .y)) +
geom_count(color = "gray75") +
geom_point(data = mtcars_clean %>% count(cyl) %>% gather_pairs(cyl, n), color = "red") +
facet_grid(vars(.row), vars(.col)) +
xlab("Number of observations with cyl = col") +
ylab("Number of observations with cyl = row")
## Warning: Combining variables of class <factor> and <ordered> was deprecated in ggplot2 3.4.0.
## ℹ Please ensure your variables are compatible before plotting (location: `combine_vars()`)
## Warning: Combining variables of class <ordered> and <factor> was deprecated in ggplot2 3.4.0.
## ℹ Please ensure your variables are compatible before plotting (location: `combine_vars()`)
## Warning: Combining variables of class <ordered> and <factor> was deprecated in ggplot2 3.4.0.
## ℹ Please ensure your variables are compatible before plotting (location: `join_keys()`)
ここでは、カテゴリ予測変数の順序モデルである。
data(esoph)
= brm(
m_esoph_brm ~ agegp,
tobgp data = esoph,
family = cumulative(),
file = "models/tidy-brms_m_esoph_brm.rds"
)
そして、予測変数の各レベル内の各結果カテゴリについて、予測された確率をプロットできる。
%>%
esoph data_grid(agegp) %>%
add_epred_draws(m_esoph_brm, dpar = TRUE, category = "tobgp") %>%
ggplot(aes(x = agegp, y = .epred, color = tobgp)) +
stat_pointinterval(position = position_dodge(width = .4)) +
scale_size_continuous(guide = "none") +
scale_color_manual(values = brewer.pal(6, "Blues")[-c(1,2)])
上のプロットでは、カテゴリの変化がわかりにくいので、各年度の分布がよくわかるようなものを考えてみよう。
= esoph %>%
esoph_plot data_grid(agegp) %>%
add_epred_draws(m_esoph_brm, category = "tobgp") %>%
ggplot(aes(x = .epred, y = tobgp)) +
coord_cartesian(expand = FALSE) +
facet_grid(. ~ agegp, switch = "x") +
theme_classic() +
theme(strip.background = element_blank(), strip.placement = "outside") +
ggtitle("P(tobacco consumption category | age group)") +
xlab("age group")
+
esoph_plot stat_summary(fun = median, geom = "bar", fill = "gray65", width = 1, color = "white") +
stat_pointinterval()
この場合、棒グラフは誤った精度感を与える可能性があるので、代わりにCCDF棒グラフを試してみることもできる。
+
esoph_plot stat_ccdfinterval() +
expand_limits(x = 0) #ensure bars go to 0
この出力は、vignette("tidy-rstanarm")
の対応する
m_esoph_rs
モデルからの出力と非常に似ているはずである
(異なるプライヤーを適用している)。ただし、brms は rstanarm
よりも出力するための作業をより多く行っている。