Extracting and visualizing tidy draws from brms models

Matthew Kay

2022-12-15

イントロダクション

この vignette は、tidybayesggdist パッケージを使用して、 tidy データフレームを抽出し、可視化する方法について説明する。モデル変数の事後分布からのドロー、平均、および brms::brm からの予測値である。より一般的な tidybayes の紹介と、汎用のベイズモデリング言語 (Stan や JAGS など) での使用については、vignette("tidybayes") を参照。

セットアップ

このvignetteを実行するには、以下のライブラリが必要である。

library(magrittr)
library(dplyr)
library(purrr)
library(forcats)
library(tidyr)
library(modelr)
library(ggdist)
library(tidybayes)
library(ggplot2)
library(cowplot)
library(rstan)
library(brms)
library(ggrepel)
library(RColorBrewer)
library(gganimate)
library(posterior)

theme_set(theme_tidybayes() + panel_border())

これらのオプションは、Stanの動作を高速化するためのものである。

rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

サンプルデータセット

tidybayes を実証するために、5つの条件からそれぞれ10個のオブザベーションを持つ単純なデータセットを使用することにする。

set.seed(5)
n = 10
n_condition = 5
ABC =
  tibble(
    condition = rep(c("A","B","C","D","E"), n),
    response = rnorm(n * 5, c(0,1,2,1,-1), 0.5)
  )

データのスナップショットはこのような感じである。

head(ABC, 10)
condition response
A -0.4204277
B 1.6921797
C 1.3722541
D 1.0350714
E -0.1442796
A -0.3014540
B 0.7639168
C 1.6823143
D 0.8571132
E -0.9309459

これは典型的な整頓されたフォーマットのデータフレームです: 1行に1つの観測値である。グラフィカルに

ABC %>%
  ggplot(aes(y = condition, x = response)) +
  geom_point()

#モデル {#model}

大域平均に向かって収縮する階層的モデルを当てはめよう。

m = brm(
  response ~ (1|condition), 
  data = ABC, 
  prior = c(
    prior(normal(0, 1), class = Intercept),
    prior(student_t(3, 0, 1), class = sd),
    prior(student_t(3, 0, 1), class = sigma)
  ),
  control = list(adapt_delta = .99),
  
  file = "models/tidy-brms_m.rds" # cache model (can be removed)  
)

結果はこのようになる。

m
##  Family: gaussian 
##   Links: mu = identity; sigma = identity 
## Formula: response ~ (1 | condition) 
##    Data: ABC (Number of observations: 50) 
##   Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
##          total post-warmup draws = 4000
## 
## Group-Level Effects: 
## ~condition (Number of levels: 5) 
##               Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sd(Intercept)     1.17      0.47     0.61     2.31 1.00      717      870
## 
## Population-Level Effects: 
##           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## Intercept     0.49      0.46    -0.48     1.41 1.00      876     1247
## 
## Family Specific Parameters: 
##       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
## sigma     0.56      0.06     0.46     0.69 1.00     1814     2186
## 
## Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).

spread_draws を使ってフィットからドローをtidyフォーマットで抽出する

さて、結果が出たので、楽しいことが始まります:整頓されたフォーマットで描画を取り出すことである!まず、 関数を使って、生のモデル変数名のリストを取得する。まず、get_variables() 関数を使って、生のモデル変数名のリストを取得し、モデルからどの変数を抽出できるかを知ることにする。

get_variables(m)
##  [1] "b_Intercept"              "sd_condition__Intercept"  "sigma"                    "r_condition[A,Intercept]"
##  [5] "r_condition[B,Intercept]" "r_condition[C,Intercept]" "r_condition[D,Intercept]" "r_condition[E,Intercept]"
##  [9] "lprior"                   "lp__"                     "accept_stat__"            "stepsize__"              
## [13] "treedepth__"              "n_leapfrog__"             "divergent__"              "energy__"

ここで、b_Intercept は大域的な平均値、r_condition[] 変数は各条件におけるその平均値からのオフセットである。これらの変数が与えられると

各行がどちらか一方からの描画であるデータフレームが必要だろう。 r_condition [A,Intercept] , r_condition [B,Intercept] , ... [C,...] , ... [D,...] または ... [E,...] そして、その行がどのチェーン/繰り返し/描画から来たもので、どの条件( A to E )のものであるかのインデックスを持つ列があるようなデータフレームが欲しいだろう。これにより、条件ごとにグループ化された量を簡単に計算したり、ggplotを使って条件ごとにプロットを作成したり、あるいは、元のデータとドローをマージしてデータと後置を同時にプロットしたりすることができるようになる。

tidybayes の主力は spread_draws() 関数で、この関数はこの抽出を行ってくれる。この関数には、変数とそのインデックスを整頓されたデータフレームに抽出するために使用できる簡単な仕様書式が含まれている。

整頓されたフォーマットのデータフレームで、変数のインデックスを別の列に集める

このようにモデル内の変数が与えられると

r_condition [D,Intercept]

このような列の仕様で spread_draws() を提供することができる。

r_condition [condition,term]

ここで、conditionD に対応し、termIntercept に対応する。この指定に対して spread_draws() が行うことは、何も不思議なことではない。この仕様では、変数インデックスをカンマとスペースで分割する ( sep 引数を変更すれば、他の文字で分割できる) 。そして、その結果得られたインデックスに、順番に列を割り当てることができる。つまり r_condition [D,Intercept] はインデックス DIntercept を持ち、spread_draws() はこれらのインデックスを列として抽出し、r_condition からのドローの整頓されたデータフレームを生成することができる。

m %>%
  spread_draws(r_condition[condition,term]) %>%
  head(10)
## Warning: `gather_()` was deprecated in tidyr 1.2.0.
## ℹ Please use `gather()` instead.
## ℹ The deprecated feature was likely used in the tidybayes package.
##   Please report the issue at <]8;;https://github.com/mjskay/tidybayes/issues/newhttps://github.com/mjskay/tidybayes/issues/new]8;;>.
condition term r_condition .chain .iteration .draw
A Intercept -0.5316842 1 1 1
A Intercept 0.0585095 1 2 2
A Intercept 0.4812017 1 3 3
A Intercept 0.2272703 1 4 4
A Intercept 0.1997897 1 5 5
A Intercept -0.3346687 1 6 6
A Intercept -0.5019515 1 7 7
A Intercept -0.4542464 1 8 8
A Intercept -0.3041030 1 9 9
A Intercept -0.4556835 1 10 10

インデックス列には好きな名前を付けることができる。

m %>%
  spread_draws(r_condition[c,t]) %>%
  head(10)
c t r_condition .chain .iteration .draw
A Intercept -0.5316842 1 1 1
A Intercept 0.0585095 1 2 2
A Intercept 0.4812017 1 3 3
A Intercept 0.2272703 1 4 4
A Intercept 0.1997897 1 5 5
A Intercept -0.3346687 1 6 6
A Intercept -0.5019515 1 7 7
A Intercept -0.4542464 1 8 8
A Intercept -0.3041030 1 9 9
A Intercept -0.4556835 1 10 10

しかし、前の例のような、より説明的で暗号化されていない名前の方が望ましいだろう。

このモデルでは、項は1つしかない( Intercept )ので、そのインデックスを完全に省略して、各項目 condition とその条件に対する r_condition の値だけを取得することができる。

m %>%
  spread_draws(r_condition[condition,]) %>%
  head(10)
condition r_condition .chain .iteration .draw
A -0.5316842 1 1 1
A 0.0585095 1 2 2
A 0.4812017 1 3 3
A 0.2272703 1 4 4
A 0.1997897 1 5 5
A -0.3346687 1 6 6
A -0.5019515 1 7 7
A -0.4542464 1 8 8
A -0.3041030 1 9 9
A -0.4556835 1 10 10

注: spread_draws() を Stan または JAGS の生のサンプルで使用したことがある場合、spread_draws() の前に recover_types を使用してインデックス列の値を取得することに慣れているだろう (例えば、インデックスが要因であった場合) 。rstanarm モデルで spread_draws() を使用する場合、これらのモデルにはすでにその情報が変数名として含まれているため、この操作は必要ない。recover_types の詳細については、vignette("tidybayes") を参照。

ポイント要約と間隔

単純なモデル変数の場合

tidybayes は、描画から点の要約や区間を整然としたフォーマットで生成するための関数群を提供する。これらの関数は、以下の命名規則に従っている。 [median|mean|mode] _ [qi|hdi] 例えば、median_qi() , mean_qi() , mode_hdi() , などである。最初の名前 ( _ の前) は点要約のタイプを示し、2番目の名前は区間のタイプを示す。qi は分位点間隔 (別名:等値線間隔、中心間隔、パーセンタイル間隔) を生成し、hdi は最高 (事後) 密度間隔を生成する。カスタム点要約または区間関数は、point_interval() 関数を用いて適用することもできる。

例えば、オブザベーションの全体平均と標準偏差の事後分布に対応するドローを抽出することができる。

m %>%
  spread_draws(b_Intercept, sigma) %>%
  head(10)
.chain .iteration .draw b_Intercept sigma
1 1 1 0.5989321 0.5067684
1 2 2 -0.1737771 0.5259455
1 3 3 -0.0818084 0.5234344
1 4 4 -0.0438235 0.5305265
1 5 5 -0.0731284 0.5306666
1 6 6 0.6052858 0.5264126
1 7 7 0.7156502 0.5228760
1 8 8 0.5087900 0.6289509
1 9 9 0.5274868 0.5604209
1 10 10 0.6009850 0.5408009

と同様に r_condition [condition,term] と同様、整頓されたデータフレームが得られる。変数の中央値と95%分位間隔が欲しい場合は、median_qi() を適用する。

m %>%
  spread_draws(b_Intercept, sigma) %>%
  median_qi(b_Intercept, sigma)
b_Intercept b_Intercept.lower b_Intercept.upper sigma sigma.lower sigma.upper .width .point .interval
0.4992407 -0.477582 1.408877 0.5562745 0.4613838 0.6945521 0.95 median qi

上記のように、中央値と区間を取得したい列を指定することができるが、列のリストを省略すると、median_qi() は、グループ化列や特殊列ではないすべての列 ( .chain , .iteration , または .draw のような) を使用する。したがって、上記の例では、b_Interceptsigma は、モデルから収集した唯一の列でもあるため、median_qi() の冗長な引数となっている。そこで、次のように単純化することができる。

m %>%
  spread_draws(b_Intercept, sigma) %>%
  median_qi()
b_Intercept b_Intercept.lower b_Intercept.upper sigma sigma.lower sigma.upper .width .point .interval
0.4992407 -0.477582 1.408877 0.5562745 0.4613838 0.6945521 0.95 median qi

もし、長い形式のインターバルリストが欲しい場合は、代わりに gather_draws() を使ってみよう。

m %>%
  gather_draws(b_Intercept, sigma) %>%
  median_qi()
.variable .value .lower .upper .width .point .interval
b_Intercept 0.4992407 -0.4775820 1.4088766 0.95 median qi
sigma 0.5562745 0.4613838 0.6945521 0.95 median qi

gather_draws() の詳細については、vignette("tidybayes") を参照。

インデックス付きモデル変数を使用した場合

r_condition のような1つ以上のインデックスを持つモデル変数があるとき、前と同じように median_qi() (または point_interval() ファミリーの他の関数) を適用することができる。

m %>%
  spread_draws(r_condition[condition,]) %>%
  median_qi()
condition r_condition .lower .upper .width .point .interval
A -0.3050387 -1.2623359 0.7042673 0.95 median qi
B 0.4996977 -0.4445127 1.5156880 0.95 median qi
C 1.3369944 0.4167176 2.3585026 0.95 median qi
D 0.5134115 -0.4259470 1.5452761 0.95 median qi
E -1.3775040 -2.3418243 -0.3958398 0.95 median qi

median_qi() はどのようにして集計する対象を知ったのだろうか? spread_draws() によって返されたデータフレームは、あなたが渡したすべてのインデックス変数によって自動的にグループ化される。この場合、spread_draws() はその結果を condition によってグループ化することを意味する。median_qi() はこれらのグループを尊重し、すべてのグループ内のポイントの要約と間隔を計算する。そして、median_qi() には列が渡されなかったので、唯一の非特殊列 ( . -prefixed) 、非グループ列である r_condition に対して処理を行う。したがって、上記の短縮された構文は、より冗長なこの呼び出しと等価である。

m %>%
  spread_draws(r_condition[condition,]) %>%
  group_by(condition) %>%   # this line not necessary (done by spread_draws)
  median_qi(r_condition)      # b is not necessary (it is the only non-group column)
condition r_condition .lower .upper .width .point .interval
A -0.3050387 -1.2623359 0.7042673 0.95 median qi
B 0.4996977 -0.4445127 1.5156880 0.95 median qi
C 1.3369944 0.4167176 2.3585026 0.95 median qi
D 0.5134115 -0.4259470 1.5452761 0.95 median qi
E -1.3775040 -2.3418243 -0.3958398 0.95 median qi

tidybayes は、 の実装も提供している。 posterior::summarise_draws() グループ化されたデータフレーム ( tidybayes::summaries_draws.grouped_df() ) を作成することができる。 を使用すると、コンバージェンス診断が迅速に行える。

m %>%
  spread_draws(r_condition[condition,]) %>%
  summarise_draws()
condition variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
A r_condition -0.2941127 -0.3050387 0.4839759 0.4477507 -1.0717280 0.5003624 1.000754 942.7064 1336.319
B r_condition 0.5118595 0.4996977 0.4837979 0.4434061 -0.2640635 1.3221608 1.001596 910.2851 1512.549
C r_condition 1.3438245 1.3369944 0.4850489 0.4420331 0.5766851 2.1565351 1.001881 979.4256 1469.083
D r_condition 0.5245269 0.5134115 0.4859456 0.4471457 -0.2692829 1.3529230 1.002230 930.1624 1279.640
E r_condition -1.3782117 -1.3775040 0.4835818 0.4532613 -2.1710976 -0.5897791 1.001688 941.5406 1395.414

異なるインデックスを持つ変数を1つの整頓されたフォーマットのデータフレームに結合する

spread_draws() と は、異なるインデックスを持つ変数を同じデータフレームに抽出することをサポートしている。同じ名前のインデックスは自動的にマッチングされ、必要に応じて値が複製され、すべてのインデックスのすべてのレベルの組み合わせごとに1つの行が作成される。例えば、各条件における平均を計算したい場合がある (これを と呼ぶ) 。このモデルでは、その平均は切片( )と与えられた条件での効果( )である。
gather_draws() condition_mean b_Intercept r_condition b_Interceptr_condition からの描画を一つのデータフレームにまとめることができる。

m %>% 
  spread_draws(b_Intercept, r_condition[condition,]) %>%
  head(10)
.chain .iteration .draw b_Intercept condition r_condition
1 1 1 0.5989321 A -0.5316842
1 1 1 0.5989321 B 0.3296616
1 1 1 0.5989321 C 1.0567015
1 1 1 0.5989321 D 0.4020400
1 1 1 0.5989321 E -1.2508539
1 2 2 -0.1737771 A 0.0585095
1 2 2 -0.1737771 B 0.9002136
1 2 2 -0.1737771 C 1.9663147
1 2 2 -0.1737771 D 1.4094724
1 2 2 -0.1737771 E -0.8438687

各抽選の中で、b_Interceptr_condition の各インデックスに対応するように必要に応じて繰り返される。 したがって、dplyr の mutate 関数を使用して、それらの合計 condition_mean (各条件の平均) を求めることができる。

m %>%
  spread_draws(`b_Intercept`, r_condition[condition,]) %>%
  mutate(condition_mean = b_Intercept + r_condition) %>%
  median_qi(condition_mean)
condition condition_mean .lower .upper .width .point .interval
A 0.1946425 -0.1533118 0.5569814 0.95 median qi
B 1.0009559 0.6633324 1.3356618 0.95 median qi
C 1.8338983 1.4690128 2.1777677 0.95 median qi
D 1.0168412 0.6623191 1.3605038 0.95 median qi
E -0.8935669 -1.2326801 -0.5357873 0.95 median qi

median_qi() は整頓された評価 ( vignette("tidy-evaluation ", package =" rlang") を使うので、列名だけでなく、列の式も受け取ることができる。したがって、condition_mean の計算を mutate から median_qi() に移すことで、上の例を単純化することができる。

m %>%
  spread_draws(b_Intercept, r_condition[condition,]) %>%
  median_qi(condition_mean = b_Intercept + r_condition)
condition condition_mean .lower .upper .width .point .interval
A 0.1946425 -0.1533118 0.5569814 0.95 median qi
B 1.0009559 0.6633324 1.3356618 0.95 median qi
C 1.8338983 1.4690128 2.1777677 0.95 median qi
D 1.0168412 0.6623191 1.3605038 0.95 median qi
E -0.8935669 -1.2326801 -0.5357873 0.95 median qi

複数の確率レベルを持つ区間をプロットする

median_qi() とその姉妹関数は、 引数を設定することで、任意の数の確率区間を生成することができる。 .width =

m %>%
  spread_draws(b_Intercept, r_condition[condition,]) %>%
  median_qi(condition_mean = b_Intercept + r_condition, .width = c(.95, .8, .5))
condition condition_mean .lower .upper .width .point .interval
A 0.1946425 -0.1533118 0.5569814 0.95 median qi
B 1.0009559 0.6633324 1.3356618 0.95 median qi
C 1.8338983 1.4690128 2.1777677 0.95 median qi
D 1.0168412 0.6623191 1.3605038 0.95 median qi
E -0.8935669 -1.2326801 -0.5357873 0.95 median qi
A 0.1946425 -0.0374505 0.4237867 0.80 median qi
B 1.0009559 0.7781791 1.2235742 0.80 median qi
C 1.8338983 1.6097208 2.0638043 0.80 median qi
D 1.0168412 0.7850910 1.2380783 0.80 median qi
E -0.8935669 -1.1149866 -0.6554495 0.80 median qi
A 0.1946425 0.0738242 0.3115411 0.50 median qi
B 1.0009559 0.8844050 1.1174733 0.50 median qi
C 1.8338983 1.7129847 1.9530778 0.50 median qi
D 1.0168412 0.8934155 1.1362549 0.50 median qi
E -0.8935669 -1.0120600 -0.7714177 0.50 median qi

結果は整然とした形式で、1グループにつき1行、不確定性区間幅( .width )である。これはプロットを容易にする。たとえば、-.widthsize の美学に割り当てると、すべての区間が表示され、太い線がより小さい区間に対応するようになる。ggdist::geom_pointinterval() geomは、複数の確率レベルを持つ点のプロットを作成するために、データ中の .width 列に基づいて、size 美学を自動的に適切に設定する。

m %>%
  spread_draws(b_Intercept, r_condition[condition,]) %>%
  median_qi(condition_mean = b_Intercept + r_condition, .width = c(.95, .66)) %>%
  ggplot(aes(y = condition, x = condition_mean, xmin = .lower, xmax = .upper)) +
  geom_pointinterval() 
## Warning: Using the `size` aesthietic with geom_segment was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.

密度を持つ区間

区間とともに密度を見るには、ggdist::stat_eye() (区間とバイオリンプロットを組み合わせた「アイプロット」) 、または ggdist::stat_halfeye() (区間+密度プロット) を使用することができる。

m %>%
  spread_draws(b_Intercept, r_condition[condition,]) %>%
  mutate(condition_mean = b_Intercept + r_condition) %>%
  ggplot(aes(y = condition, x = condition_mean)) +
  stat_halfeye()

あるいは、密度の一部をカラーでアノテートしたいとする。fill の美学は、ggdist::stat_halfeye() を含む ggdist::geom_slabinterval() ファミリーのすべてのジオムおよび統計において、スラブ内で変化させることができる。例えば、ドメイン固有のROPE (region of practical equivalence) をアノテーションしたい場合、次のようなことが可能である。

m %>%
  spread_draws(b_Intercept, r_condition[condition,]) %>%
  mutate(condition_mean = b_Intercept + r_condition) %>%
  ggplot(aes(y = condition, x = condition_mean, fill = stat(abs(x) < .8))) +
  stat_halfeye() +
  geom_vline(xintercept = c(-.8, .8), linetype = "dashed") +
  scale_fill_manual(values = c("gray80", "skyblue"))
## Warning: `stat(abs(x) < 0.8)` was deprecated in ggplot2 3.4.0.
## ℹ Please use `after_stat(abs(x) < 0.8)` instead.

その他、分布の可視化を行う stat_slabinterval

分布を視覚化するための様々な追加統計が、ggdist::geom_slabinterval() ファミリーの統計とジオムにはある。

The slabinterval family of geoms and stats

を参照。 vignette("slabinterval ", package =" ggdist") を参照。

事後平均値と予測値

前の例のように条件付き平均を手動で計算するのではなく、brms::posterior_epred() (事後予測値の期待値から事後ドローを与える;すなわち、条件付き平均の事後分布) に類似しているが、整然としたデータ形式を使用する add_epred_draws() を使用することができる。これを modelr::data_grid() と組み合わせて、まず欲しい予測を記述したグリッドを生成し、次にそのグリッドを条件付き平均からのドローの長大なデータフレームに変換することができる。

ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m) %>%
  head(10)
condition .row .chain .iteration .draw .epred
A 1 NA NA 1 0.0672479
A 1 NA NA 2 -0.1152676
A 1 NA NA 3 0.3993933
A 1 NA NA 4 0.1834468
A 1 NA NA 5 0.1266612
A 1 NA NA 6 0.2706171
A 1 NA NA 7 0.2136986
A 1 NA NA 8 0.0545436
A 1 NA NA 9 0.2233838
A 1 NA NA 10 0.1453015

この例をプロットするために、ggplot 内で描画を点と区間にまとめる ggdist::geom_pointinterval() の代わりに ggdist::stat_pointinterval() を使用することも示す。

ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m) %>%
  ggplot(aes(x = .epred, y = condition)) +
  stat_pointinterval(.width = c(.66, .95))

分位点プロット

アルファレベルがたまたまあなたが行おうとしている決定と一致する場合、区間は良いが、事後的な形状を得ることはより良いことである (したがって、上記の目のプロット) 。一方、密度プロットから推論するのは不正確である (ある形状の面積を別の形状の割合で推定するのは難しい知覚的作業である) 。頻度形式での確率の推論はより簡単で、 quantile dotplots ( Kay et al. 2016, Fernandes et al. 2018)の動機となった。これは任意の間隔 (プロットのドット解像度まで、下の例では100) を正確に推定することも可能である。

tidybayes のジオムの slabinterval ファミリーには dotsdotsinterval ファミリーがあり、ドットプロットの適切なビンサイズを自動的に決定し、サンプルから分位を計算して分位ドットプロットを構築できる。ggdist::stat_dotsinterval() はサンプルに使用するために設計されたバリアントである。

ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m) %>%
  ggplot(aes(x = .epred, y = condition)) +
  stat_dotsinterval(quantiles = 100)

つまり、事後情報を1つの標準的な点または区間として考えるのではなく、 (例えば) 100のほぼ等しい可能性のある点として表現することである。

事後予測

ここで、add_epred_draws()brms::posterior_epred() に類似しており、add_predicted_draws()brms::posterior_predict() に類似しており、事後予測分布からのドローを与えている。

以下は、ggdist::stat_slab() を用いてプロットした事後予測分布の例である。

ABC %>%
  data_grid(condition) %>%
  add_predicted_draws(m) %>%
  ggplot(aes(x = .prediction, y = condition)) +
  stat_slab()

また、ggdist::stat_interval() を使って、データと一緒に予測バンドをプロットすることもできる。

ABC %>%
  data_grid(condition) %>%
  add_predicted_draws(m) %>%
  ggplot(aes(y = condition, x = .prediction)) +
  stat_interval(.width = c(.50, .80, .95, .99)) +
  geom_point(aes(x = response), data = ABC) +
  scale_color_brewer()

合わせて、データ、事後予測、平均の事後分布。

grid = ABC %>%
  data_grid(condition)

means = grid %>%
  add_epred_draws(m)

preds = grid %>%
  add_predicted_draws(m)

ABC %>%
  ggplot(aes(y = condition, x = response)) +
  stat_interval(aes(x = .prediction), data = preds) +
  stat_pointinterval(aes(x = .epred), data = means, .width = c(.66, .95), position = position_nudge(y = -0.3)) +
  geom_point() +
  scale_color_brewer()

クラシュケ流事後予測

事後予測に対する上記のアプローチは、単一の事後予測分布を与えるために、パラメータの不確実性の上に統合される。もう一つのアプローチは、John Kruschkeが彼の著書 Doing Bayesian Data Analysisでよく使っているもので、事後予測によって暗示されるいくつかの可能な予測分布を示すことによって、予測の不確かさとパラメータの不確かさの両方を同時に示すことを試みるものである。

これは、ある予測値に対する分布パラメータを求めることで、非常に簡単に行うことができる。ここでは、明示的に dpar = c("mu "," sigma") add_epred_draws() を設定することで明示的に行う。明示的にパラメータを指定するのではなく、dpar = TRUE を設定して、モデル内のすべての分布パラメータからドローを取得することもできる。これはbrmsがサポートするすべての応答分布に対して機能する。そして、sample_draws() を使って少数のドローを選択し、ggdist::stat_dist_slab() を使って、musigma の値によって暗示される各予測分布を可視化することができるのである。

ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m, dpar = c("mu", "sigma")) %>%
  sample_draws(30) %>%
  ggplot(aes(y = condition)) +
  stat_dist_slab(aes(dist = "norm", arg1 = mu, arg2 = sigma), 
    slab_color = "gray65", alpha = 1/10, fill = NA
  ) +
  geom_point(aes(x = response), data = ABC, shape = 21, fill = "#9ECAE1", size = 2)

これらのチャート (およびその便利なバリエーション) のより詳しい説明については、 Solomon Kurz’s excellent blog post on the topic を参照。

予測分布のクルーシュケ風プロットと、事後平均を示す半眼を組み合わせることもできる。

ABC %>%
  data_grid(condition) %>%
  add_epred_draws(m, dpar = c("mu", "sigma")) %>%
  ggplot(aes(x = condition)) +
  stat_dist_slab(aes(dist = "norm", arg1 = mu, arg2 = sigma), 
    slab_color = "gray65", alpha = 1/10, fill = NA, data = . %>% sample_draws(30), scale = .5
  ) +
  stat_halfeye(aes(y = .epred), side = "bottom", scale = .5) +
  geom_point(aes(y = response), data = ABC, shape = 21, fill = "#9ECAE1", size = 2, position = position_nudge(x = -.2))

フィット/予測曲線

不確実性を伴う適合曲線の描画を実証するために、mtcars のデータセットの一部に、少し素朴なモデルを適合させてみよう。

m_mpg = brm(
  mpg ~ hp * cyl, 
  data = mtcars, 
  
  file = "models/tidy-brms_m_mpg.rds"  # cache model (can be removed)
)

確率帯を使ったフィットカーブを描くことができる。

mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 51)) %>%
  add_epred_draws(m_mpg) %>%
  ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
  stat_lineribbon(aes(y = .epred)) +
  geom_point(data = mtcars) +
  scale_fill_brewer(palette = "Greys") +
  scale_color_brewer(palette = "Set2")
## Warning: Using the `size` aesthietic with geom_ribbon was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.
## Warning: Unknown or uninitialised column: `linewidth`.
## Warning: Using the `size` aesthietic with geom_line was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.

あるいは、適当な数のフィットラインをサンプリングして (たとえば100本) 、それらをオーバープロットすることもできる。

mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 101)) %>%
  # NOTE: this shows the use of ndraws to subsample within add_epred_draws()
  # ONLY do this IF you are planning to make spaghetti plots, etc.
  # NEVER subsample to a small sample to plot intervals, densities, etc.
  add_epred_draws(m_mpg, ndraws = 100) %>%
  ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
  geom_line(aes(y = .epred, group = paste(cyl, .draw)), alpha = .1) +
  geom_point(data = mtcars) +
  scale_color_brewer(palette = "Dark2")

あるいは、フィットしたラインのアニメーション hypothetical outcome plots (HOPs) を作成することもできる。訳註:エラーのため eval=FALSE

set.seed(123456)
# NOTE: using a small number of draws to keep this example
# small, but in practice you probably want 50 or 100
ndraws = 20

p = mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 101)) %>%
  add_epred_draws(m_mpg, ndraws = ndraws) %>%
  ggplot(aes(x = hp, y = mpg, color = ordered(cyl))) +
  geom_line(aes(y = .epred, group = paste(cyl, .draw))) +
  geom_point(data = mtcars) +
  scale_color_brewer(palette = "Dark2") +
  transition_states(.draw, 0, 1) +
  shadow_mark(future = TRUE, color = "gray50", alpha = 1/20)

animate(p, nframes = ndraws, fps = 2.5, width = 432, height = 288, res = 96, dev = "png", type = "cairo")

あるいは、 (平均値ではなく) 事後予測値をプロットすることもできる。この例では また、alpha、重なり合ったバンドを見やすくするために使用する。

mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 101)) %>%
  add_predicted_draws(m_mpg) %>%
  ggplot(aes(x = hp, y = mpg, color = ordered(cyl), fill = ordered(cyl))) +
  stat_lineribbon(aes(y = .prediction), .width = c(.95, .80, .50), alpha = 1/4) +
  geom_point(data = mtcars) +
  scale_fill_brewer(palette = "Set2") +
  scale_color_brewer(palette = "Dark2")
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.

これはグループごとに判断するのは難しいので、おそらく複数のプロットにファセットするのがよいだろう。幸い、ggplotを使っているので、そのような関数が組み込まれている。

mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 101)) %>%
  add_predicted_draws(m_mpg) %>%
  ggplot(aes(x = hp, y = mpg)) +
  stat_lineribbon(aes(y = .prediction), .width = c(.99, .95, .8, .5), color = brewer.pal(5, "Blues")[[5]]) +
  geom_point(data = mtcars) +
  scale_fill_brewer() +
  facet_grid(. ~ cyl, space = "free_x", scales = "free_x")
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.

分布回帰パラメータの抽出

brms::brm() は、場所 (たとえば、平均) 以外の応答分布のパラメータのためのサブモデルをセットアップすることもできる。たとえば、標準偏差のような分散パラメータが、予測変数の関数であることを許すことができる。

この方法は、分散が一定でない場合 (ラテン語で難読化を好む人々は、_異種分散とも呼ぶ) に有用である。例えば、2つのグループがあり、それぞれが異なる平均応答_と分散を持つことを想像してみてみよう。

set.seed(1234)
AB = tibble(
  group = rep(c("a", "b"), each = 20),
  response = rnorm(40, mean = rep(c(1, 5), each = 20), sd = rep(c(1, 3), each = 20))
)

AB %>%
  ggplot(aes(x = response, y = group)) +
  geom_point()

ここでは、response の平均_と標準偏差_を group に依存させるモデルを紹介する。

m_ab = brm(
  bf(
    response ~ group,
    sigma ~ group
  ),
  data = AB,
  
  file = "models/tidy-brms_m_ab.rds"  # cache model (can be removed)
)

平均の事後分布 response を事後予測区間とデータと並べてプロットすることができる。

grid = AB %>%
  data_grid(group)

means = grid %>%
  add_epred_draws(m_ab)

preds = grid %>%
  add_predicted_draws(m_ab)

AB %>%
  ggplot(aes(x = response, y = group)) +
  stat_halfeye(aes(x = .epred), scale = 0.6, position = position_nudge(y = 0.175), data = means) +
  stat_interval(aes(x = .prediction), data = preds) +
  geom_point(data = AB) +
  scale_color_brewer()

これは、各グループの平均値の後置 (黒の区間と密度プロット) および事後予測区間 (青) を示している。

グループ b の予測区間はグループ a よりも大きくなっているが、これはモデルが各グループに対して異なる標準偏差を当てはめたことがある。対応する分布パラメータである sigmaadd_epred_draws() の引数 dpar を使って抽出することで、どのように変化するかを見ることができる。

grid %>%
  add_epred_draws(m_ab, dpar = TRUE) %>%
  ggplot(aes(x = sigma, y = group)) +
  stat_halfeye() +
  geom_vline(xintercept = 0, linetype = "dashed")

dpar = TRUE を設定すると、すべての分布パラメータが add_epred_draws() の結果に追加列として追加される。特定のパラメータだけが欲しい場合は、それを指定する (あるいは、欲しいパラメータだけのリストを指定する) 。上記のモデルにおいて、dpar = TRUE は以下と等価である。 dpar = list("mu "," sigma") .

因子の水準を比較する

各条件の平均を比較したい場合、compare_levels() は、ある因子のレベル間のある変数の値の比較を容易にする。デフォルトでは、一対の差分をすべて計算する。

ggdist::stat_halfeye() を使って compare_levels() をデモしてみよう。また 差の平均値で再注文する。

m %>%
  spread_draws(r_condition[condition,]) %>%
  compare_levels(r_condition, by = condition) %>%
  ungroup() %>%
  mutate(condition = reorder(condition, r_condition)) %>%
  ggplot(aes(y = condition, x = r_condition)) +
  stat_halfeye() +
  geom_vline(xintercept = 0, linetype = "dashed") 

Ordinal models

brmsの順序回帰モデルおよび多項回帰モデル用の関数 posterior_epred() は、各抽選に対して複数の変数を返する:各結果カテゴリに対して1つである (潜在線形予測変数の抽選を返す rstanarm::stan_polr() モデルとは対照的) 。tidybayes の理念は、モデルによって出力されるどんなフォーマットでも整頓することである。その理念に従って、順序および多項式 brms モデルに適用すると、add_epred_draws().category という追加の列を追加し、各カテゴリーの変数を含む別の行が、すべてのドローと予測変数について出力される。

連続予測変数の順序モデル

mtcars データセットを使って、車の燃費 (マイル/ガロン) が与えられたときの車の気筒数を予測するモデルをあてはめることにする。これは少し因果関係が逆であるが (おそらくシリンダー数が走行距離を引き起こすのだろう) 、だからといってこれは立派な予測タスクではない (私はおそらく車について何か知っている人に車のMPGを伝えることができ、彼らはエンジンのシリンダー数を当てるのにそれなりにうまくやれるだろう) 。

モデルを適合させる前に、cyl 列を順序付き因子にすることによって、データセットをきれいにしよう (デフォルトでは単なる数字) 。

mtcars_clean = mtcars %>%
  mutate(cyl = ordered(cyl))

head(mtcars_clean)
mpg cyl disp hp drat wt qsec vs am gear carb
Mazda RX4 21.0 6 160 110 3.90 2.620 16.46 0 1 4 4
Mazda RX4 Wag 21.0 6 160 110 3.90 2.875 17.02 0 1 4 4
Datsun 710 22.8 4 108 93 3.85 2.320 18.61 1 1 4 1
Hornet 4 Drive 21.4 6 258 110 3.08 3.215 19.44 1 0 3 1
Hornet Sportabout 18.7 8 360 175 3.15 3.440 17.02 0 0 3 2
Valiant 18.1 6 225 105 2.76 3.460 20.22 1 0 3 1

そして、順序回帰モデルを当てはめる。

m_cyl = brm(
  cyl ~ mpg, 
  data = mtcars_clean, 
  family = cumulative,
  seed = 58393,
  
  file = "models/tidy-brms_m_cyl.rds"  # cache model (can be removed)
)

add_epred_draws() は 列を含み、 は応答がそのカテゴリにある確率の事後分布からのドローを含む。例えば、これはデータセットの1行目のフィットである。 .category .epred

tibble(mpg = 21) %>%
  add_epred_draws(m_cyl) %>%
  median_qi(.epred)
mpg .row .category .epred .lower .upper .width .point .interval
21 1 4 0.3471635 0.0891023 0.7214931 0.95 median qi
21 1 6 0.6223695 0.2573263 0.8955651 0.95 median qi
21 1 8 0.0137229 0.0003088 0.1256028 0.95 median qi

注: .category 変数が元の因子レベルの名前を保持するためには、次のようにする。 は、brms バージョン 2.15.9 以降を使用している必要がある。

データセットに対して予測される確率のフィットラインをプロットすることができた。

data_plot = mtcars_clean %>%
  ggplot(aes(x = mpg, y = cyl, color = cyl)) +
  geom_point() +
  scale_color_brewer(palette = "Dark2", name = "cyl")

fit_plot = mtcars_clean %>%
  data_grid(mpg = seq_range(mpg, n = 101)) %>%
  add_epred_draws(m_cyl, value = "P(cyl | mpg)", category = "cyl") %>%
  ggplot(aes(x = mpg, y = `P(cyl | mpg)`, color = cyl)) +
  stat_lineribbon(aes(fill = cyl), alpha = 1/5) +
  scale_color_brewer(palette = "Dark2") +
  scale_fill_brewer(palette = "Dark2")

plot_grid(ncol = 1, align = "v",
  data_plot,
  fit_plot
)
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.

上記の表示では、ある特定の値 mpg において、cyl の異なる値に対する P(cyl|mpg) の相関を見ることはできない。 例えば、P(cyl = 6|mpg = 20) が高い事後処理の部分では、P(cyl = 4|mpg = 20)P(cyl = 8|mpg = 20) は低くなければならない (これらの合計は 1 になるはずだことがある) 。

この相関を見る一つの方法は、フィットした線だけに hypothetical outcome plots (HOPs) を使って、リボンから「切り離す」ことだろう (別の方法は、このドキュメントで以前に示したように、線のアンサンブルの上に HOP を使うことである)。アニメーションを使うことで、線がどのように並んだり反対方向に動いたりして、それらの相関のパターンを明らかにすることができる。

# NOTE: using a small number of draws to keep this example
# small, but in practice you probably want 50 or 100
ndraws = 20

p = mtcars_clean %>%
  data_grid(mpg = seq_range(mpg, n = 101)) %>%
  add_epred_draws(m_cyl, value = "P(cyl | mpg)", category = "cyl") %>%
  ggplot(aes(x = mpg, y = `P(cyl | mpg)`, color = cyl)) +
  # we remove the `.draw` column from the data for stat_lineribbon so that the same ribbons
  # are drawn on every frame (since we use .draw to determine the transitions below)
  stat_lineribbon(aes(fill = cyl), alpha = 1/5, color = NA, data = . %>% select(-.draw)) +
  # we use sample_draws to subsample at the level of geom_line (rather than for the full dataset
  # as in previous HOPs examples) because we need the full set of draws for stat_lineribbon above
  geom_line(aes(group = paste(.draw, cyl)), size = 1, data = . %>% sample_draws(ndraws)) +
  scale_color_brewer(palette = "Dark2") +
  scale_fill_brewer(palette = "Dark2") +
  transition_manual(.draw)
## Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` instead.
animate(p, nframes = ndraws, fps = 2.5, width = 576, height = 192, res = 96, dev = "png", type = "cairo")
## Warning: Unknown or uninitialised column: `linewidth`.
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.

線がどのように一緒に動くか、また、どのように一緒に上下に動くか、あるいは反対に動くかに注目してみよう。上のグラフのxの位置でこれらの線をスライスして (例えば、mpg = 20 ) 、散布図行列を使ってそれらの間の相関関係を見ることができる。

tibble(mpg = 20) %>%
  add_epred_draws(m_cyl, value = "P(cyl | mpg = 20)", category = "cyl") %>%
  ungroup() %>%
  select(.draw, cyl, `P(cyl | mpg = 20)`) %>%
  gather_pairs(cyl, `P(cyl | mpg = 20)`, triangle = "both") %>%
  filter(.row != .col) %>%
  ggplot(aes(.x, .y)) +
  geom_point(alpha = 1/50) +
  facet_grid(.row ~ .col) +
  ylab("P(cyl = row | mpg = 20)") +
  xlab("P(cyl = col | mpg = 20)")

順序分布の平均について話すことは、しばしば意味を持たないが、この特定のケースでは、ガロンあたりのマイル数を与えられた車のシリンダー数の期待値は、意味のある量であると主張することができる。あるガロンあたりのマイル数を与えられた車の平均シリンダー数の事後分布を次のようにプロットすることができる。

\[ \textrm{E}[\textrm{cyl}|\textrm{mpg}=m] = \sum_{c \in \{4,6,8\}} c\cdot \textrm{P}(\textrm{cyl}=c|\textrm{mpg}=m) \] の事後分布をモデルから導き出すことができる。 $ [|=m] $ の事後分布を導出できる。このモデルは、 \(\textrm{P}(\textrm{cyl}=c|\textrm{mpg}=m)\) の事後分布を与える : mpg = \(m\) のとき、cyl (aka .category ) = \(c\) の応答スケール線形予測器 ( add_epred_draws() からの .epred 列) は、 \(\textrm{P}(\textrm{cyl}=c|\textrm{mpg}=m)\) になる。したがって、私たちは .draw 内でグループ化し、summarise を使って期待値を計算することができる。

label_data_function = . %>% 
  ungroup() %>%
  filter(mpg == quantile(mpg, .47)) %>%
  summarise_if(is.numeric, mean)

data_plot_with_mean = mtcars_clean %>%
  data_grid(mpg = seq_range(mpg, n = 101)) %>%
  # NOTE: this shows the use of ndraws to subsample within add_epred_draws()
  # ONLY do this IF you are planning to make spaghetti plots, etc.
  # NEVER subsample to a small sample to plot intervals, densities, etc.
  add_epred_draws(m_cyl, value = "P(cyl | mpg)", category = "cyl", ndraws = 100) %>%
  group_by(mpg, .draw) %>%
  # calculate expected cylinder value
  mutate(cyl = as.numeric(as.character(cyl))) %>%
  summarise(cyl = sum(cyl * `P(cyl | mpg)`), .groups = "drop") %>%
  ggplot(aes(x = mpg, y = cyl)) +
  geom_line(aes(group = .draw), alpha = 5/100) +
  geom_point(aes(y = as.numeric(as.character(cyl)), fill = cyl), data = mtcars_clean, shape = 21, size = 2) +
  geom_text(aes(x = mpg + 4), label = "E[cyl | mpg]", data = label_data_function, hjust = 0) +
  geom_segment(aes(yend = cyl, xend = mpg + 3.9), data = label_data_function) +
  scale_fill_brewer(palette = "Set2", name = "cyl")

plot_grid(ncol = 1, align = "v",
  data_plot_with_mean,
  fit_plot
)
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.

では、事後予測チェックをしてみよう。事後予測はデータと同じように見えるだろうか? mpg ここでは、元のデータセットにあったのと同じ値 (灰色の円) で新たに予測を行い、観測されたデータ (色のついた円) と共にプロットしてむ。

mtcars_clean %>%
  # we use `select` instead of `data_grid` here because we want to make posterior predictions
  # for exactly the same set of observations we have in the original data
  select(mpg) %>%
  add_predicted_draws(m_cyl, seed = 1234) %>%
  # recover original factor labels
  mutate(cyl = levels(mtcars_clean$cyl)[.prediction]) %>%
  ggplot(aes(x = mpg, y = cyl)) +
  geom_count(color = "gray75") +
  geom_point(aes(fill = cyl), data = mtcars_clean, shape = 21, size = 2) +
  scale_fill_brewer(palette = "Dark2") +
  geom_label_repel(
    data = . %>% ungroup() %>% filter(cyl == "8") %>% filter(mpg == max(mpg)) %>% dplyr::slice(1),
    label = "posterior predictions", xlim = c(26, NA), ylim = c(NA, 2.8), point.padding = 0.3,
    label.size = NA, color = "gray50", segment.color = "gray75"
  ) +
  geom_label_repel(
    data = mtcars_clean %>% filter(cyl == "6") %>% filter(mpg == max(mpg)) %>% dplyr::slice(1),
    label = "observed data", xlim = c(26, NA), ylim = c(2.2, NA), point.padding = 0.2,
    label.size = NA, segment.color = "gray35"
  )

これはかなり良さそうである。もう1つの典型的な事後予測チェックプロットを使ってチェックしてみよう:応答の観察された分布に対して、応答の多くのシミュレーションされた分布( cyl )である。連続応答変数の場合、これは通常、密度プロットで行われる。ここでは、応答変数が離散なので、各ビンでの事後予測数を折れ線グラフでプロットする。

mtcars_clean %>%
  select(mpg) %>%
  add_predicted_draws(m_cyl, ndraws = 100, seed = 12345) %>%
  # recover original factor labels
  mutate(cyl = levels(mtcars_clean$cyl)[.prediction]) %>%
  ggplot(aes(x = cyl)) +
  stat_count(aes(group = NA), geom = "line", data = mtcars_clean, color = "red", size = 3, alpha = .5) +
  stat_count(aes(group = .draw), geom = "line", position = "identity", alpha = .05) +
  geom_label(data = data.frame(cyl = "4"), y = 9.5, label = "posterior\npredictions",
    hjust = 1, color = "gray50", lineheight = 1, label.size = NA) +
  geom_label(data = data.frame(cyl = "8"), y = 14, label = "observed\ndata",
    hjust = 0, color = "red", lineheight = 1, label.size = NA)

これもまた良さそうである。

これらの事後予測は、散布図行列として見ることもできる。gather_pairs() は、ggplot2::facet_grid() を用いて、ggplot でカスタム散布図行列 (あるいは、任意の行列スタイルの小倍数プロット) の作成に適した長大なデータフレームを簡単に生成することが可能である。

set.seed(12345)

mtcars_clean %>%
  select(mpg) %>%
  add_predicted_draws(m_cyl) %>%
  # recover original factor labels. Must ungroup first so that the
  # factor is created in the same way in all groups; this is a workaround
  # because brms no longer returns labelled predictions (hopefully that
  # is fixed then this will no longer be necessary)
  ungroup() %>%
  mutate(cyl = factor(levels(mtcars_clean$cyl)[.prediction])) %>%
  # need .drop = FALSE to ensure 0 counts are not dropped
  group_by(.draw, .drop = FALSE) %>%
  count(cyl) %>%
  gather_pairs(cyl, n) %>%
  ggplot(aes(.x, .y)) +
  geom_count(color = "gray75") +
  geom_point(data = mtcars_clean %>% count(cyl) %>% gather_pairs(cyl, n), color = "red") +
  facet_grid(vars(.row), vars(.col)) +
  xlab("Number of observations with cyl = col") +
  ylab("Number of observations with cyl = row") 
## Warning: Combining variables of class <factor> and <ordered> was deprecated in ggplot2 3.4.0.
## ℹ Please ensure your variables are compatible before plotting (location: `combine_vars()`)
## Warning: Combining variables of class <ordered> and <factor> was deprecated in ggplot2 3.4.0.
## ℹ Please ensure your variables are compatible before plotting (location: `combine_vars()`)
## Warning: Combining variables of class <ordered> and <factor> was deprecated in ggplot2 3.4.0.
## ℹ Please ensure your variables are compatible before plotting (location: `join_keys()`)

カテゴリカル予測変数の順序モデル

ここでは、カテゴリ予測変数の順序モデルである。

data(esoph)
m_esoph_brm = brm(
  tobgp ~ agegp, 
  data = esoph, 
  family = cumulative(),

  file = "models/tidy-brms_m_esoph_brm.rds"  
)

そして、予測変数の各レベル内の各結果カテゴリについて、予測された確率をプロットできる。

esoph %>%
  data_grid(agegp) %>%
  add_epred_draws(m_esoph_brm, dpar = TRUE, category = "tobgp") %>%
  ggplot(aes(x = agegp, y = .epred, color = tobgp)) +
  stat_pointinterval(position = position_dodge(width = .4)) +
  scale_size_continuous(guide = "none") +
  scale_color_manual(values = brewer.pal(6, "Blues")[-c(1,2)])

上のプロットでは、カテゴリの変化がわかりにくいので、各年度の分布がよくわかるようなものを考えてみよう。

esoph_plot = esoph %>%
  data_grid(agegp) %>%
  add_epred_draws(m_esoph_brm, category = "tobgp") %>%
  ggplot(aes(x = .epred, y = tobgp)) +
  coord_cartesian(expand = FALSE) +
  facet_grid(. ~ agegp, switch = "x") +
  theme_classic() +
  theme(strip.background = element_blank(), strip.placement = "outside") +
  ggtitle("P(tobacco consumption category | age group)") +
  xlab("age group")

esoph_plot +
  stat_summary(fun = median, geom = "bar", fill = "gray65", width = 1, color = "white") +
  stat_pointinterval()

この場合、棒グラフは誤った精度感を与える可能性があるので、代わりにCCDF棒グラフを試してみることもできる。

esoph_plot +
  stat_ccdfinterval() +
  expand_limits(x = 0) #ensure bars go to 0

この出力は、vignette("tidy-rstanarm") の対応する m_esoph_rs モデルからの出力と非常に似ているはずである (異なるプライヤーを適用している)。ただし、brms は rstanarm よりも出力するための作業をより多く行っている。