Using tidybayes with the posterior package

Matthew Kay

2022-12-15

イントロダクション

この vignette は、tidybayesggdist パッケージを posterior パッケージと一緒に使用する方法を説明する。 (特に posterior::rvar() データ型) を抽出し、 tidy を可視化することができる。 モデル変数の事後分布、適合度、予測値から rvar s のデータフレームを作成する。

このワークフローは、「long-data-frame-of- rvar s ” workflow, which is bit different from the” long-data-frame-of-draws 」ワークフローである。 vignette("tidybayes") または vignette("tidy-brms") に記載されている。rvar に基づくアプローチは、特に次のような場合に有効であろう。 より大きなモデルでは、よりメモリ効率が良いためである。

セットアップ

このvignetteを実行するには、以下のライブラリが必要である。

library(dplyr)
library(purrr)
library(modelr)
library(ggdist)
library(tidybayes)
library(ggplot2)
library(cowplot)
library(rstan)
library(brms)
library(ggrepel)
library(RColorBrewer)
library(posterior)
library(distributional)

theme_set(theme_tidybayes() + panel_border())

これらのオプションは、Stanの動作を高速化するためのものである。

rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

サンプルデータセット

tidybayes を実証するために、5つの条件からそれぞれ10個のオブザベーションを持つ単純なデータセットを使用することにする。

set.seed(5)
n = 10
n_condition = 5
ABC =
  tibble(
    condition = rep(c("A","B","C","D","E"), n),
    response = rnorm(n * 5, c(0,1,2,1,-1), 0.5)
  )

データのスナップショットはこのような感じである。

head(ABC, 10)
## # A tibble: 10 × 2
##    condition response
##    <chr>        <dbl>
##  1 A           -0.420
##  2 B            1.69 
##  3 C            1.37 
##  4 D            1.04 
##  5 E           -0.144
##  6 A           -0.301
##  7 B            0.764
##  8 C            1.68 
##  9 D            0.857
## 10 E           -0.931

これは典型的な整頓されたフォーマットのデータフレームです: 1行に1つの観測値である。グラフィカルに

ABC %>%
  ggplot(aes(y = condition, x = response)) +
  geom_point()

#モデル {#model}

大域平均に向かって収縮する階層的モデルを当てはめよう。

m = brm(
  response ~ (1|condition), 
  data = ABC, 
  prior = c(
    prior(normal(0, 1), class = Intercept),
    prior(student_t(3, 0, 1), class = sd),
    prior(student_t(3, 0, 1), class = sigma)
  ),
  control = list(adapt_delta = .99),
  
  file = "models/tidy-brms_m.rds" # cache model (can be removed)  
)
## Compiling Stan program...
## Start sampling
## 
## SAMPLING FOR MODEL '309af4685b7e5d19f13a8c45140d2be2' NOW (CHAIN 1).
## Chain 1: 
## Chain 1: Gradient evaluation took 2.2e-05 seconds
## Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 0.22 seconds.
## Chain 1: Adjust your expectations accordingly!
## Chain 1: 
## Chain 1: 
## Chain 1: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 1: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 1: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 1: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 1: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 1: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 1: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 1: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 1: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 1: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 1: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 1: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 1: 
## Chain 1:  Elapsed Time: 0.2227 seconds (Warm-up)
## Chain 1:                0.185013 seconds (Sampling)
## Chain 1:                0.407713 seconds (Total)
## Chain 1: 
## 
## SAMPLING FOR MODEL '309af4685b7e5d19f13a8c45140d2be2' NOW (CHAIN 2).
## Chain 2: 
## Chain 2: Gradient evaluation took 5e-06 seconds
## Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 0.05 seconds.
## Chain 2: Adjust your expectations accordingly!
## Chain 2: 
## Chain 2: 
## Chain 2: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 2: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 2: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 2: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 2: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 2: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 2: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 2: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 2: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 2: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 2: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 2: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 2: 
## Chain 2:  Elapsed Time: 0.217713 seconds (Warm-up)
## Chain 2:                0.231222 seconds (Sampling)
## Chain 2:                0.448935 seconds (Total)
## Chain 2: 
## 
## SAMPLING FOR MODEL '309af4685b7e5d19f13a8c45140d2be2' NOW (CHAIN 3).
## Chain 3: 
## Chain 3: Gradient evaluation took 5e-06 seconds
## Chain 3: 1000 transitions using 10 leapfrog steps per transition would take 0.05 seconds.
## Chain 3: Adjust your expectations accordingly!
## Chain 3: 
## Chain 3: 
## Chain 3: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 3: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 3: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 3: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 3: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 3: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 3: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 3: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 3: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 3: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 3: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 3: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 3: 
## Chain 3:  Elapsed Time: 0.155416 seconds (Warm-up)
## Chain 3:                0.148867 seconds (Sampling)
## Chain 3:                0.304283 seconds (Total)
## Chain 3: 
## 
## SAMPLING FOR MODEL '309af4685b7e5d19f13a8c45140d2be2' NOW (CHAIN 4).
## Chain 4: 
## Chain 4: Gradient evaluation took 5e-06 seconds
## Chain 4: 1000 transitions using 10 leapfrog steps per transition would take 0.05 seconds.
## Chain 4: Adjust your expectations accordingly!
## Chain 4: 
## Chain 4: 
## Chain 4: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 4: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 4: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 4: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 4: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 4: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 4: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 4: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 4: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 4: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 4: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 4: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 4: 
## Chain 4:  Elapsed Time: 0.21751 seconds (Warm-up)
## Chain 4:                0.169343 seconds (Sampling)
## Chain 4:                0.386853 seconds (Total)
## Chain 4:

tidybayes::tidy_draws() が返すフォーマットは、posterior::draws_df() のフォーマットと互換性があるので posterior::summarise_draws() がサポートしている。このように、 を使って、簡単に posterior::summarise_draws() は、モデルからのドローを見る。

summarise_draws(tidy_draws(m))
## # A tibble: 16 × 10
##    variable                 mean   median      sd     mad       q5      q95     rhat ess_b…¹ ess_t…²
##    <chr>                   <dbl>    <dbl>   <dbl>   <dbl>    <dbl>    <dbl>    <dbl>   <dbl>   <dbl>
##  1 b_Intercept            0.489    0.499   0.463  4.31e-1  -0.288    1.23    1.00e 0  876.     1247.
##  2 sd_condition__Inter…   1.17     1.06    0.470  3.42e-1   0.648    2.01    1.00e 0  717.      870.
##  3 sigma                  0.562    0.556   0.0603 5.83e-2   0.474    0.671   1.00e 0 1814.     2186.
##  4 r_condition[A,Inter…  -0.294   -0.305   0.484  4.48e-1  -1.07     0.500   1.00e 0  943.     1336.
##  5 r_condition[B,Inter…   0.512    0.500   0.484  4.43e-1  -0.264    1.32    1.00e 0  910.     1513.
##  6 r_condition[C,Inter…   1.34     1.34    0.485  4.42e-1   0.577    2.16    1.00e 0  979.     1469.
##  7 r_condition[D,Inter…   0.525    0.513   0.486  4.47e-1  -0.269    1.35    1.00e 0  930.     1280.
##  8 r_condition[E,Inter…  -1.38    -1.38    0.484  4.53e-1  -2.17    -0.590   1.00e 0  942.     1395.
##  9 lprior                -2.73    -2.59    0.577  4.35e-1  -3.86    -2.11    1.00e 0  897.      992.
## 10 lp__                 -51.9    -51.5     2.42   2.36e+0 -56.4    -48.7     1.00e 0  835.     1559.
## 11 accept_stat__          0.986    0.994   0.0226 7.97e-3   0.950    1.00    1.03e 0  146.     1810.
## 12 stepsize__             0.0550   0.0549  0.0112 1.38e-2   0.0395   0.0705  3.38e14    4.03     NA 
## 13 treedepth__            5.27     5       0.953  1.48e+0   4        7       1.04e 0  102.       NA 
## 14 n_leapfrog__          62.7     63      39.1    4.74e+1  15      127       1.04e 0   95.3     994.
## 15 divergent__            0        0       0      0         0        0      NA         NA        NA 
## 16 energy__              55.9     55.5     3.13   3.08e+0  51.4     61.7     1.00e 0  918.     1612.
## # … with abbreviated variable names ¹​ess_bulk, ²​ess_tail

spread_rvars`を使ってtidy形式のフィットからドローを抽出する

さて、結果が出たので、楽しいことが始まります:変数を整頓されたフォーマットで取り出すことである!まず、 関数を使って、生のモデル変数名のリストを取得する。まず、get_variables() 関数を使って、生のモデル変数名のリストを取得し、モデルからどの変数を抽出できるかを知ることにする。

get_variables(m)
##  [1] "b_Intercept"              "sd_condition__Intercept"  "sigma"                   
##  [4] "r_condition[A,Intercept]" "r_condition[B,Intercept]" "r_condition[C,Intercept]"
##  [7] "r_condition[D,Intercept]" "r_condition[E,Intercept]" "lprior"                  
## [10] "lp__"                     "accept_stat__"            "stepsize__"              
## [13] "treedepth__"              "n_leapfrog__"             "divergent__"             
## [16] "energy__"

ここで、b_Intercept は大域的な平均値、r_condition[] 変数は各条件におけるその平均値からのオフセットである。これらの変数が与えられると

各行が、以下のすべてのドローを表す確率変数 ( rvar ) であるデータフレームが必要だろう。 r_condition [A,Intercept] , r_condition [B,Intercept] , ... [C,...] , ... [D,...] または ... [E,...] .これにより、条件ごとにグループ化した量を簡単に計算したり、ggplotを使って条件ごとにプロットを生成したり、あるいはドローと元データをマージしてデータと後置を同時にプロットしたりすることができるようになる。

これは、spread_rvars() 関数を用いて行うことができる。これは、変数とそのインデックスをtidy形式のデータフレームに抽出するために使用できる簡単な指定形式を含んでいる。この関数は vignette("tidy-brms") で説明した tidy-data-frames-of-draws ワークフローの spread_draws() と類似している。

整頓されたフォーマットのデータフレームで、変数のインデックスを別の列に集める

このようにモデル内の変数が与えられると

r_condition [D,Intercept]

このような列の仕様で spread_rvars() を提供することができる。

r_condition [condition,term]

ここで、conditionD に対応し、termIntercept に対応する。この指定に対して spread_rvars() が行うことは、何も不思議なことではない。この仕様では、変数インデックスをカンマとスペースで分割する ( sep 引数を変更すれば、他の文字で分割できる) 。そして、その結果得られたインデックスに、順番に列を割り当てることができる。つまり r_condition [D,Intercept] はインデックス DIntercept を持ち、spread_rvars() はこれらのインデックスを列として抽出し、r_condition からのドローの整頓されたデータフレームを生成することができる。

m %>%
  spread_rvars(r_condition[condition,term])
## # A tibble: 5 × 3
##   condition term        r_condition
##   <chr>     <chr>        <rvar[1d]>
## 1 A         Intercept  -0.29 ± 0.48
## 2 B         Intercept   0.51 ± 0.48
## 3 C         Intercept   1.34 ± 0.49
## 4 D         Intercept   0.52 ± 0.49
## 5 E         Intercept  -1.38 ± 0.48

上の r_condition 列は posterior::rvar() データ型で、確率変数からの抽選を表す配列のようなデータ型である。

m %>%
  spread_rvars(r_condition[condition,term]) %>%
  pull(r_condition)
## rvar<1000,4>[5] mean ± sd:
## [1] -0.29 ± 0.48   0.51 ± 0.48   1.34 ± 0.49   0.52 ± 0.49  -1.38 ± 0.48

この場合、この rvar ベクタの5つの要素それぞれについて、モデル内の4つのチェーンから1000ドローしていることになる。rvar データ型の詳細については vignette("rvar ", package =" posterior") .

インデックス列には好きな名前を付けることができる。

m %>%
  spread_rvars(r_condition[c,t])
## # A tibble: 5 × 3
##   c     t           r_condition
##   <chr> <chr>        <rvar[1d]>
## 1 A     Intercept  -0.29 ± 0.48
## 2 B     Intercept   0.51 ± 0.48
## 3 C     Intercept   1.34 ± 0.49
## 4 D     Intercept   0.52 ± 0.49
## 5 E     Intercept  -1.38 ± 0.48

しかし、前の例のような、より説明的で暗号化されていない名前の方が望ましいだろう。

インデックスの名前を省略すると、そのインデックスは列の中に「入れ子」の状態で残る。例えば は、どうせ "Intercept" という一つの値しか持たないので、term をネストすることができる。

m %>%
  spread_rvars(r_condition[condition,])
## # A tibble: 5 × 2
##   condition r_condition[,1]
##   <chr>          <rvar[,1]>
## 1 A            -0.29 ± 0.48
## 2 B             0.51 ± 0.48
## 3 C             1.34 ± 0.49
## 4 D             0.52 ± 0.49
## 5 E            -1.38 ± 0.48

あるいは、condition をネストすることもできるが、これはおそらく現実的にはそれほど有用ではないだろう。

m %>%
  spread_rvars(r_condition[,term])
## # A tibble: 1 × 2
##   term      r_condition[,1]         [,2]        [,3]         [,4]         [,5]
##   <chr>          <rvar[,1]>   <rvar[,1]>  <rvar[,1]>   <rvar[,1]>   <rvar[,1]>
## 1 Intercept    -0.29 ± 0.48  0.51 ± 0.48  1.3 ± 0.49  0.52 ± 0.49  -1.4 ± 0.48

ポイント要約と間隔

単純なモデル変数の場合

tidybayes は、描画から点の要約や区間を整然としたフォーマットで生成するための関数群を提供する。これらの関数は、以下の命名規則に従っている。 [median|mean|mode] _ [qi|hdi] 例えば、median_qi() , mean_qi() , mode_hdi() , などである。最初の名前 ( _ の前) は点要約のタイプを示し、2番目の名前は区間のタイプを示す。qi は分位点間隔 (別名:等値線間隔、中心間隔、パーセンタイル間隔) を生成し、hdi は最高 (事後) 密度間隔を生成する。カスタム点要約または区間関数は、point_interval() 関数を用いて適用することもできる。

例えば、オブザベーションの全体平均と標準偏差の事後分布に対応するドローを抽出することができる。

m %>%
  spread_rvars(b_Intercept, sigma)
## # A tibble: 1 × 2
##    b_Intercept        sigma
##     <rvar[1d]>   <rvar[1d]>
## 1  0.49 ± 0.46  0.56 ± 0.06

と同様に r_condition [condition,term] と同様、整頓されたデータフレームが得られる。もし、変数の中央値と95%分位間隔が欲しい場合は、median_qi() を適用することができる。

m %>%
  spread_rvars(b_Intercept, sigma) %>%
  median_qi(b_Intercept, sigma)
## # A tibble: 1 × 9
##   b_Intercept b_Intercept.lower b_Intercept.upper sigma sigma.lower sigma.up…¹ .width .point .inte…²
##         <dbl>             <dbl>             <dbl> <dbl>       <dbl>      <dbl>  <dbl> <chr>  <chr>  
## 1       0.499            -0.478              1.41 0.556       0.461      0.695   0.95 median qi     
## # … with abbreviated variable names ¹​sigma.upper, ²​.interval

上記のように中央値と区間を取得したい列を指定することができるが、列のリストを省略すると、median_qi() はグループ化列でないすべての列を使用する。したがって、上記の例では、b_Interceptsigma は、モデルから集めた唯一の列でもあるため、median_qi() の冗長な引数となる。そこで、次のように単純化することができる。

m %>%
  spread_rvars(b_Intercept, sigma) %>%
  median_qi()
## # A tibble: 1 × 9
##   b_Intercept b_Intercept.lower b_Intercept.upper sigma sigma.lower sigma.up…¹ .width .point .inte…²
##         <dbl>             <dbl>             <dbl> <dbl>       <dbl>      <dbl>  <dbl> <chr>  <chr>  
## 1       0.499            -0.478              1.41 0.556       0.461      0.695   0.95 median qi     
## # … with abbreviated variable names ¹​sigma.upper, ²​.interval

長文形式のリストが必要な場合は、代わりに gather_rvars() を使用する。

m %>%
  gather_rvars(b_Intercept, sigma)
## # A tibble: 2 × 2
##   .variable         .value
##   <chr>         <rvar[1d]>
## 1 b_Intercept  0.49 ± 0.46
## 2 sigma        0.56 ± 0.06

ここでは、median_qi() も使用でいた。

m %>%
  gather_rvars(b_Intercept, sigma) %>%
  median_qi(.value)
## # A tibble: 2 × 7
##   .variable   .value .lower .upper .width .point .interval
##   <chr>        <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
## 1 b_Intercept  0.499 -0.478  1.41    0.95 median qi       
## 2 sigma        0.556  0.461  0.695   0.95 median qi

インデックス付きモデル変数を使用した場合

r_condition のように、1つ以上のインデックスを持つモデル変数がある場合、前と同じように median_qi() (または point_interval() ファミリーの他の関数) を適用することができる。

m %>%
  spread_rvars(r_condition[condition,]) %>%
  median_qi(r_condition)
## # A tibble: 5 × 7
##   condition r_condition .lower .upper .width .point .interval
##   <chr>           <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
## 1 A              -0.305 -1.26   0.704   0.95 median qi       
## 2 B               0.500 -0.445  1.52    0.95 median qi       
## 3 C               1.34   0.417  2.36    0.95 median qi       
## 4 D               0.513 -0.426  1.55    0.95 median qi       
## 5 E              -1.38  -2.34  -0.396   0.95 median qi

Note for existing users of spread_draws() : you may notice that spread_rvars() requires us to be a bit more median_qi() これは、spread_rvars() がグループ化された列を返さないためである。 データフレームは、spread_draws() とは異なり、spread_rvars() からの出力では、すべての行が常にそれ自身のグループとなるためである。 グループ化前のデータフレームを返すのは冗長である。

また、rvar の列で posterior::summarise_draws() を使用すると、要約を生成することができる。 をコンバージェンス診断付きで提供する。その関数はデータフレームを返し、それを を dplyr::mutate() 関数に直接入力する。

m %>%
  spread_rvars(r_condition[condition,]) %>%
  mutate(summarise_draws(r_condition))
## # A tibble: 5 × 12
##   condition r_condition[,1] variable     mean median    sd   mad     q5    q95  rhat ess_b…¹ ess_t…²
##   <chr>          <rvar[,1]> <chr>       <dbl>  <dbl> <dbl> <dbl>  <dbl>  <dbl> <dbl>   <dbl>   <dbl>
## 1 A            -0.29 ± 0.48 r_conditi… -0.294 -0.305 0.484 0.448 -1.07   0.500  1.00    943.   1336.
## 2 B             0.51 ± 0.48 r_conditi…  0.512  0.500 0.484 0.443 -0.264  1.32   1.00    910.   1513.
## 3 C             1.34 ± 0.49 r_conditi…  1.34   1.34  0.485 0.442  0.577  2.16   1.00    979.   1469.
## 4 D             0.52 ± 0.49 r_conditi…  0.525  0.513 0.486 0.447 -0.269  1.35   1.00    930.   1280.
## 5 E            -1.38 ± 0.48 r_conditi… -1.38  -1.38  0.484 0.453 -2.17  -0.590  1.00    942.   1395.
## # … with abbreviated variable names ¹​ess_bulk, ²​ess_tail

異なるインデックスを持つ変数を1つの整頓されたフォーマットのデータフレームに結合する

spread_rvars() と は、異なるインデックスを持つ変数を同じデータフレームに抽出することをサポートしている。同じ名前のインデックスは自動的にマッチングされ、必要に応じて値が複製され、すべてのインデックスのレベルの組み合わせごとに1つの行が生成される。例えば、各条件における平均を計算したい場合がある (これを と呼ぶ) 。このモデルでは、その平均は切片( )と与えられた条件での効果( )である。
gather_rvars() condition_mean b_Intercept r_condition b_Interceptr_condition からの描画を一つのデータフレームにまとめることができる。

m %>% 
  spread_rvars(b_Intercept, r_condition[condition,])
## # A tibble: 5 × 3
##    b_Intercept condition r_condition[,1]
##     <rvar[1d]> <chr>          <rvar[,1]>
## 1  0.49 ± 0.46 A            -0.29 ± 0.48
## 2  0.49 ± 0.46 B             0.51 ± 0.48
## 3  0.49 ± 0.46 C             1.34 ± 0.49
## 4  0.49 ± 0.46 D             0.52 ± 0.49
## 5  0.49 ± 0.46 E            -1.38 ± 0.48

各抽選の中で、b_Interceptr_condition の各インデックスに対応するように必要に応じて繰り返される。 したがって、dplyr の mutate 関数を使用して、それらの合計 condition_mean (各条件の平均) を求めることができる。

m %>%
  spread_rvars(`b_Intercept`, r_condition[condition,Intercept]) %>%
  mutate(condition_mean = b_Intercept + r_condition)
## # A tibble: 5 × 5
##    b_Intercept condition Intercept   r_condition condition_mean
##     <rvar[1d]> <chr>     <chr>        <rvar[1d]>     <rvar[1d]>
## 1  0.49 ± 0.46 A         Intercept  -0.29 ± 0.48    0.19 ± 0.18
## 2  0.49 ± 0.46 B         Intercept   0.51 ± 0.48    1.00 ± 0.17
## 3  0.49 ± 0.46 C         Intercept   1.34 ± 0.49    1.83 ± 0.18
## 4  0.49 ± 0.46 D         Intercept   0.52 ± 0.49    1.01 ± 0.18
## 5  0.49 ± 0.46 E         Intercept  -1.38 ± 0.48   -0.89 ± 0.18

点の要約と間隔をプロットする

点要約と区間のプロットは、ggdist::stat_dist_pointinterval() を使えば簡単である。デフォルトでは、66%と95%の区間で可視化される (これは .width パラメータで変更可能、デフォルトは .width = c(.66, .95) ) 。

m %>%
  spread_rvars(b_Intercept, r_condition[condition,]) %>%
  mutate(condition_mean = b_Intercept + r_condition) %>%
  ggplot(aes(y = condition, dist = condition_mean)) +
  stat_dist_pointinterval()
## Warning: Using the `size` aesthietic with geom_segment was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.

median_qi() とその姉妹関数は、 引数を設定することで、任意の数の確率区間を生成することもできる。 .width =

m %>%
  spread_rvars(b_Intercept, r_condition[condition,]) %>%
  median_qi(condition_mean = b_Intercept + r_condition, .width = c(.95, .8, .5))
## # A tibble: 15 × 9
##     b_Intercept condition r_condition[,1] condition_mean  .lower .upper .width .point .interval
##      <rvar[1d]> <chr>          <rvar[,1]>          <dbl>   <dbl>  <dbl>  <dbl> <chr>  <chr>    
##  1  0.49 ± 0.46 A            -0.29 ± 0.48          0.195 -0.153   0.557   0.95 median qi       
##  2  0.49 ± 0.46 B             0.51 ± 0.48          1.00   0.663   1.34    0.95 median qi       
##  3  0.49 ± 0.46 C             1.34 ± 0.49          1.83   1.47    2.18    0.95 median qi       
##  4  0.49 ± 0.46 D             0.52 ± 0.49          1.02   0.662   1.36    0.95 median qi       
##  5  0.49 ± 0.46 E            -1.38 ± 0.48         -0.894 -1.23   -0.536   0.95 median qi       
##  6  0.49 ± 0.46 A            -0.29 ± 0.48          0.195 -0.0375  0.424   0.8  median qi       
##  7  0.49 ± 0.46 B             0.51 ± 0.48          1.00   0.778   1.22    0.8  median qi       
##  8  0.49 ± 0.46 C             1.34 ± 0.49          1.83   1.61    2.06    0.8  median qi       
##  9  0.49 ± 0.46 D             0.52 ± 0.49          1.02   0.785   1.24    0.8  median qi       
## 10  0.49 ± 0.46 E            -1.38 ± 0.48         -0.894 -1.11   -0.655   0.8  median qi       
## 11  0.49 ± 0.46 A            -0.29 ± 0.48          0.195  0.0738  0.312   0.5  median qi       
## 12  0.49 ± 0.46 B             0.51 ± 0.48          1.00   0.884   1.12    0.5  median qi       
## 13  0.49 ± 0.46 C             1.34 ± 0.49          1.83   1.71    1.95    0.5  median qi       
## 14  0.49 ± 0.46 D             0.52 ± 0.49          1.02   0.893   1.14    0.5  median qi       
## 15  0.49 ± 0.46 E            -1.38 ± 0.48         -0.894 -1.01   -0.771   0.5  median qi

結果は整然とした形式です:グループごとに1行、不確定性区間幅( .width )である。これはプロットを容易にし、本質的に ggdist::stat_dist_pointinterval() が上記のフード下であなたのためにやっていることである。例えば、-.widthsize の美学に割り当てると、すべての区間を表示し、太い線がより小さい区間に対応するようになる。

密度を持つ区間

区間とともに密度を見るには、ggdist::stat_dist_eye() (区間とバイオリンプロットを組み合わせた「アイプロット」) 、または ggdist::stat_dist_halfeye() (区間+密度プロット) を使用することができる。

m %>%
  spread_rvars(b_Intercept, r_condition[condition,]) %>%
  mutate(condition_mean = b_Intercept + r_condition) %>%
  ggplot(aes(y = condition, dist = condition_mean)) +
  stat_dist_halfeye()

あるいは、密度の一部を色でアノテートしたいとする。fill の美学は、ggdist::stat_dist_halfeye() を含む ggdist::geom_slabinterval() ファミリーのすべてのジオムおよび統計において、スラブ内で変化させることができる。例えば、ドメイン固有のROPE (region of practical equivalence) をアノテートしたい場合、次のようなことが可能である。

m %>%
  spread_rvars(b_Intercept, r_condition[condition,]) %>%
  mutate(condition_mean = b_Intercept + r_condition) %>%
  ggplot(aes(y = condition, dist = condition_mean, fill = stat(abs(x) < .8))) +
  stat_dist_halfeye() +
  geom_vline(xintercept = c(-.8, .8), linetype = "dashed") +
  scale_fill_manual(values = c("gray80", "skyblue"))
## Warning: `stat(abs(x) < 0.8)` was deprecated in ggplot2 3.4.0.
## ℹ Please use `after_stat(abs(x) < 0.8)` instead.

その他、分布の可視化を行うstat_dist_slabinterval` を利用する

分布を視覚化するための様々な追加統計が、ggdist::stat_dist_slabinterval() ファミリーの統計とジオムにはある。

The slabinterval family of geoms and stats

を参照。 vignette("slabinterval ", package =" ggdist") を参照。stat_dist_... で始まるすべてのジオムは rvar の列を dist の美観で使用することをサポートする。

Posterior means

前の例のように条件付き平均を手動で計算するのではなく、brms::posterior_epred() に類似した、応答の平均の事後分布 (すなわち、事後予測値の期待値の分布) から事後ドローを与える add_epred_draws() を使用することができる。これを modelr::data_grid() と組み合わせて、まず欲しい適合を記述するグリッドを生成し、そのグリッドに事後分布からのドローを表す rvar を入力することができる。

ABC %>%
  data_grid(condition) %>%
  add_epred_rvars(m)
## # A tibble: 5 × 2
##   condition        .epred
##   <chr>        <rvar[1d]>
## 1 A           0.19 ± 0.18
## 2 B           1.00 ± 0.17
## 3 C           1.83 ± 0.18
## 4 D           1.01 ± 0.18
## 5 E          -0.89 ± 0.18

分位点プロット

アルファレベルがたまたまあなたが行おうとしている決定と一致する場合、区間は良いが、事後的な形状を得ることはより良いことである (したがって、上記の目のプロット) 。一方、密度プロットからの推論は不正確である (ある形状の面積を他の形状の割合として推定することは難しい知覚的作業である) 。頻度形式での確率の推論はより簡単で、 quantile dotplots ( Kay et al. 2016, Fernandes et al. 2018)の動機となった。これは任意の間隔 (プロットのドット解像度まで、下の例では100) を正確に推定することも可能である。

tidybayes のジオムの slabinterval ファミリーには dotsdotsinterval ファミリーがあり、ドットプロットの適切なビンサイズを自動的に決定し、サンプルから分位を計算して分位ドットプロットを構築できる。ggdist::stat_dist_dots()ggdist::stat_dist_dotsinterval()rvar s で使用するために設計されたバリアントである。

ABC %>%
  data_grid(condition) %>%
  add_epred_rvars(m) %>%
  ggplot(aes(dist = .epred, y = condition)) +
  stat_dist_dotsinterval(quantiles = 100)

つまり、事後情報を1つの標準的な点または区間として考えるのではなく、 (例えば) 100のほぼ等しい可能性のある点として表現することである。

事後予測

ここで、add_epred_rvars()brms::posterior_epred() に類似しており、add_predicted_rvars()brms::posterior_predict() に類似しており、事後予測分布からのドローを与えている。

ggdist::stat_dist_interval() を使って、データと一緒に予測バンドをプロットすることができる。

ABC %>%
  data_grid(condition) %>%
  add_predicted_rvars(m) %>%
  ggplot(aes(y = condition)) +
  stat_dist_interval(aes(dist = .prediction), .width = c(.50, .80, .95, .99)) +
  geom_point(aes(x = response), data = ABC) +
  scale_color_brewer()

add_XXX_rvars() 関数を連結して事後予測値を追加することができる。 ( predicted_rvars ) に、事後予測的な平均値の分布 ( epred_rvars ) を示した。
を同じデータフレームで表示する。これにより、データと共に両者を一緒にプロットすることが容易になる。

ABC %>%
  data_grid(condition) %>%
  add_epred_rvars(m) %>%
  add_predicted_rvars(m) %>%
  ggplot(aes(y = condition)) +
  stat_dist_interval(aes(dist = .prediction)) +
  stat_dist_pointinterval(aes(dist = .epred), position = position_nudge(y = -0.3)) +
  geom_point(aes(x = response), data = ABC) +
  scale_color_brewer()

クラシュケ流事後予測

事後予測に対する上記のアプローチは、単一の事後予測分布を与えるために、パラメータの不確実性の上に統合される。もう一つのアプローチは、John Kruschkeが彼の著書 Doing Bayesian Data Analysisでよく使っているもので、事後予測によって暗示されるいくつかの可能な予測分布を示すことによって、予測の不確かさとパラメータの不確かさの両方を同時に示すことを試みるものである。

これは、ある予測値に対する分布パラメータを求めることで、非常に簡単に行うことができる。ここでは、明示的に dpar = c("mu "," sigma") add_epred_draws() を設定することで明示的に行う。パラメータを明示的に指定するのではなく、dpar = TRUE を設定して、モデル内のすべての分布パラメータからドローを得ることもできる。これは、brms::brm() でサポートされるどの応答分布に対しても機能する。

ABC %>%
  data_grid(condition) %>%
  add_epred_rvars(m, dpar = c("mu", "sigma"))
## # A tibble: 5 × 4
##   condition        .epred            mu        sigma
##   <chr>        <rvar[1d]>    <rvar[1d]>   <rvar[1d]>
## 1 A           0.19 ± 0.18   0.19 ± 0.18  0.56 ± 0.06
## 2 B           1.00 ± 0.17   1.00 ± 0.17  0.56 ± 0.06
## 3 C           1.83 ± 0.18   1.83 ± 0.18  0.56 ± 0.06
## 4 D           1.01 ± 0.18   1.01 ± 0.18  0.56 ± 0.06
## 5 E          -0.89 ± 0.18  -0.89 ± 0.18  0.56 ± 0.06

このとき、より標準的な “long-data-frame-of-draws”形式を使用する必要がある。 tidybayesのワークフローである。の結合分布から少数のドローを選択することである。 mu と から予測密度をプロットする。私たちは、 を使って、 のすべての s sigma unnest_rvars() rvar を長大なデータフレームに出力し、sample_draws() を使って30個のドローをサンプリングし、次に
ggdist::stat_dist_slab() の値によって暗示される各予測分布を可視化するために、 と 。 mu sigma

ABC %>%
  data_grid(condition) %>%
  add_epred_rvars(m, dpar = c("mu", "sigma")) %>%
  unnest_rvars() %>%
  sample_draws(30) %>%
  ggplot(aes(y = condition)) +
  stat_dist_slab(
    aes(dist = dist_normal(mu, sigma)), 
    color = "gray65", alpha = 1/10, fill = NA
  ) +
  geom_point(aes(x = response), data = ABC, shape = 21, fill = "#9ECAE1", size = 2)

これらのチャート (およびその便利なバリエーション) のより詳しい説明については、 Solomon Kurz’s excellent blog post on the topic を参照。

add_epred_rvars() の後に unnest_rvars() を使用することは、本質的に同等である。 予測関数の _rvars() の代わりに _draws() を使用しただけである。
(例: add_epred_draws() ) の方が、より高速かつ便利な場合がある。 他にどのようなデータ操作が必要なのかによる。

フィット/予測曲線

不確実性を伴う適合曲線の描画を実証するために、mtcars のデータセットの一部に、少し素朴なモデルを適合させてみよう。

m_mpg = brm(
  mpg ~ hp * cyl, 
  data = mtcars, 
  
  file = "models/tidy-brms_m_mpg.rds"  # cache model (can be removed)
)

フィットカーブ (=条件付の不確実性を示す曲線) を描くことができる。 期待値、事後予測値の期待値) を確率帯で表示する。

mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 51)) %>%
  add_epred_rvars(m_mpg) %>%
  ggplot(aes(x = hp, color = ordered(cyl))) +
  stat_dist_lineribbon(aes(dist = .epred)) +
  geom_point(aes(y = mpg), data = mtcars) +
  scale_fill_brewer(palette = "Greys") +
  scale_color_brewer(palette = "Set2")
## Warning: Using the `size` aesthietic with geom_ribbon was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.
## Warning: Unknown or uninitialised column: `linewidth`.
## Warning: Using the `size` aesthietic with geom_line was deprecated in ggplot2 3.4.0.
## ℹ Please use the `linewidth` aesthetic instead.
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.

あるいは、 (平均値ではなく) 事後予測値をプロットすることもできる。この例では また、重なり合ったバンドを見やすくするために、alpha を使用する。

mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 101)) %>%
  add_predicted_rvars(m_mpg) %>%
  ggplot(aes(x = hp, color = ordered(cyl), fill = ordered(cyl))) +
  stat_dist_lineribbon(aes(dist = .prediction), .width = c(.95, .80, .50), alpha = 1/4) +
  geom_point(aes(y = mpg), data = mtcars) +
  scale_fill_brewer(palette = "Set2") +
  scale_color_brewer(palette = "Dark2")
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.

その他のフィットラインの例については、vignette("tidy-brms") を参照。
animated hypothetical outcome plots (HOPs)

分布回帰パラメータの抽出

brms::brm() は、場所 (たとえば、平均) 以外の応答分布のパラメータのためのサブモデルをセットアップすることもできる。たとえば、標準偏差のような分散パラメータが、予測変数の関数であることを許すことができる。

この方法は、分散が一定でない場合 (ラテン語で難読化を好む人々は_異種分散性とも呼ぶ) に有用である。例えば、2つのグループがあり、それぞれが異なる平均応答_と分散を持つことを想像してみてみよう。

set.seed(1234)
AB = tibble(
  group = rep(c("a", "b"), each = 20),
  response = rnorm(40, mean = rep(c(1, 5), each = 20), sd = rep(c(1, 3), each = 20))
)

AB %>%
  ggplot(aes(x = response, y = group)) +
  geom_point()

ここでは、response の平均_と標準偏差_を group に依存させるモデルを紹介する。

m_ab = brm(
  bf(
    response ~ group,
    sigma ~ group
  ),
  data = AB,
  
  file = "models/tidy-brms_m_ab.rds"  # cache model (can be removed)
)
## Compiling Stan program...
## Start sampling
## 
## SAMPLING FOR MODEL '3922432d3973a130f23968e8af99dd9e' NOW (CHAIN 1).
## Chain 1: 
## Chain 1: Gradient evaluation took 2.8e-05 seconds
## Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 0.28 seconds.
## Chain 1: Adjust your expectations accordingly!
## Chain 1: 
## Chain 1: 
## Chain 1: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 1: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 1: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 1: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 1: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 1: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 1: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 1: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 1: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 1: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 1: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 1: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 1: 
## Chain 1:  Elapsed Time: 0.041109 seconds (Warm-up)
## Chain 1:                0.036047 seconds (Sampling)
## Chain 1:                0.077156 seconds (Total)
## Chain 1: 
## 
## SAMPLING FOR MODEL '3922432d3973a130f23968e8af99dd9e' NOW (CHAIN 2).
## Chain 2: 
## Chain 2: Gradient evaluation took 9e-06 seconds
## Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 0.09 seconds.
## Chain 2: Adjust your expectations accordingly!
## Chain 2: 
## Chain 2: 
## Chain 2: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 2: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 2: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 2: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 2: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 2: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 2: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 2: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 2: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 2: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 2: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 2: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 2: 
## Chain 2:  Elapsed Time: 0.039677 seconds (Warm-up)
## Chain 2:                0.037123 seconds (Sampling)
## Chain 2:                0.0768 seconds (Total)
## Chain 2: 
## 
## SAMPLING FOR MODEL '3922432d3973a130f23968e8af99dd9e' NOW (CHAIN 3).
## Chain 3: 
## Chain 3: Gradient evaluation took 7e-06 seconds
## Chain 3: 1000 transitions using 10 leapfrog steps per transition would take 0.07 seconds.
## Chain 3: Adjust your expectations accordingly!
## Chain 3: 
## Chain 3: 
## Chain 3: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 3: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 3: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 3: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 3: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 3: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 3: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 3: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 3: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 3: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 3: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 3: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 3: 
## Chain 3:  Elapsed Time: 0.04064 seconds (Warm-up)
## Chain 3:                0.031559 seconds (Sampling)
## Chain 3:                0.072199 seconds (Total)
## Chain 3: 
## 
## SAMPLING FOR MODEL '3922432d3973a130f23968e8af99dd9e' NOW (CHAIN 4).
## Chain 4: 
## Chain 4: Gradient evaluation took 7e-06 seconds
## Chain 4: 1000 transitions using 10 leapfrog steps per transition would take 0.07 seconds.
## Chain 4: Adjust your expectations accordingly!
## Chain 4: 
## Chain 4: 
## Chain 4: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 4: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 4: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 4: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 4: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 4: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 4: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 4: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 4: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 4: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 4: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 4: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 4: 
## Chain 4:  Elapsed Time: 0.041404 seconds (Warm-up)
## Chain 4:                0.04474 seconds (Sampling)
## Chain 4:                0.086144 seconds (Total)
## Chain 4:

平均の事後分布 response を事後予測区間とデータと並べてプロットすることができる。

AB %>%
  data_grid(group) %>%
  add_epred_rvars(m_ab) %>%
  add_predicted_rvars(m_ab) %>%
  ggplot(aes(y = group)) +
  stat_dist_halfeye(aes(dist = .epred), scale = 0.6, position = position_nudge(y = 0.175)) +
  stat_dist_interval(aes(dist = .prediction)) +
  geom_point(aes(x = response), data = AB) +
  scale_color_brewer()

これは、各グループの平均値の後置 (黒の区間と密度プロット) および事後予測区間 (青) を示している。

グループ b の予測区間はグループ a よりも大きくなっているが、これはモデルが各グループに対して異なる標準偏差を当てはめたことがある。対応する分布パラメータである sigmaadd_epred_rvars() の引数 dpar を使って抽出することで、どのように変化するかを見ることができる。

AB %>%
  data_grid(group) %>%
  add_epred_rvars(m_ab, dpar = TRUE) %>%
  ggplot(aes(dist = sigma, y = group)) +
  stat_dist_halfeye() +
  geom_vline(xintercept = 0, linetype = "dashed")

dpar = TRUE を設定すると、すべての分布パラメータが add_epred_rvars() の結果に追加列として追加される。特定のパラメータだけが欲しい場合は、それを指定する (あるいは、欲しいパラメータだけのリストを指定する) 。上記のモデルにおいて、dpar = TRUE は以下と等価である。 dpar = list("mu "," sigma") .

因子の水準を比較する

各条件の平均を比較したい場合、compare_levels() は、ある因子のレベル間のある変数の値の比較を容易にする。デフォルトでは、一対の差分をすべて計算する。

ggdist::stat_dist_halfeye() を使って compare_levels() をデモしてみよう。また 差の平均値で再注文する。

m %>%
  spread_rvars(r_condition[condition,]) %>%
  compare_levels(r_condition, by = condition) %>%
  ungroup() %>%
  mutate(condition = reorder(condition, r_condition)) %>%
  ggplot(aes(y = condition, dist = r_condition)) +
  stat_dist_halfeye() +
  geom_vline(xintercept = 0, linetype = "dashed") 

Ordinal models

brmsの順序・多項回帰モデル用の関数 brms::posterior_epred() は、各抽選に対して多次元変数を返し、その結果の追加次元には結果カテゴリが含まれる。tidybayes の理念は、モデルによって出力されるどんなフォーマットでも整頓することである。その理念に沿って、順序および多項式 brms モデルに適用すると、add_epred_draws() は、応答変数の各レベルに対して追加の列を持つネストされた .epred 変数を出力する。

連続予測変数の順序モデル

mtcars データセットを使って、車の燃費 (マイル/ガロン) が与えられたときの車の気筒数を予測するモデルをあてはめることにする。これは少し因果関係が逆であるが (おそらくシリンダー数が走行距離を引き起こすのだろう) 、だからといってこれは立派な予測タスクではない (私はおそらく車について何か知っている人に車のMPGを伝えることができ、彼らはエンジンのシリンダー数を当てるのにそれなりにうまくやれるだろう) 。

モデルを適合させる前に、cyl 列を順序付き因子にすることによって、データセットをきれいにしよう (デフォルトでは単なる数字) 。

mtcars_clean = mtcars %>%
  mutate(cyl = ordered(cyl))

head(mtcars_clean)
##                    mpg cyl disp  hp drat    wt  qsec vs am gear carb
## Mazda RX4         21.0   6  160 110 3.90 2.620 16.46  0  1    4    4
## Mazda RX4 Wag     21.0   6  160 110 3.90 2.875 17.02  0  1    4    4
## Datsun 710        22.8   4  108  93 3.85 2.320 18.61  1  1    4    1
## Hornet 4 Drive    21.4   6  258 110 3.08 3.215 19.44  1  0    3    1
## Hornet Sportabout 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2
## Valiant           18.1   6  225 105 2.76 3.460 20.22  1  0    3    1

そして、順序回帰モデルを当てはめる。

m_cyl = brm(
  cyl ~ mpg, 
  data = mtcars_clean, 
  family = cumulative,
  seed = 58393,
  
  file = "models/tidy-brms_m_cyl.rds"  # cache model (can be removed)
)
## Compiling Stan program...
## Start sampling
## 
## SAMPLING FOR MODEL 'e8d53bbd1abddecfe3c098a6dc689c51' NOW (CHAIN 1).
## Chain 1: 
## Chain 1: Gradient evaluation took 4.4e-05 seconds
## Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 0.44 seconds.
## Chain 1: Adjust your expectations accordingly!
## Chain 1: 
## Chain 1: 
## Chain 1: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 1: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 1: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 1: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 1: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 1: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 1: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 1: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 1: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 1: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 1: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 1: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 1: 
## Chain 1:  Elapsed Time: 0.087775 seconds (Warm-up)
## Chain 1:                0.07582 seconds (Sampling)
## Chain 1:                0.163595 seconds (Total)
## Chain 1: 
## 
## SAMPLING FOR MODEL 'e8d53bbd1abddecfe3c098a6dc689c51' NOW (CHAIN 2).
## Chain 2: 
## Chain 2: Gradient evaluation took 1.3e-05 seconds
## Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 0.13 seconds.
## Chain 2: Adjust your expectations accordingly!
## Chain 2: 
## Chain 2: 
## Chain 2: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 2: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 2: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 2: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 2: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 2: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 2: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 2: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 2: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 2: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 2: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 2: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 2: 
## Chain 2:  Elapsed Time: 0.09316 seconds (Warm-up)
## Chain 2:                0.091307 seconds (Sampling)
## Chain 2:                0.184467 seconds (Total)
## Chain 2: 
## 
## SAMPLING FOR MODEL 'e8d53bbd1abddecfe3c098a6dc689c51' NOW (CHAIN 3).
## Chain 3: 
## Chain 3: Gradient evaluation took 1.1e-05 seconds
## Chain 3: 1000 transitions using 10 leapfrog steps per transition would take 0.11 seconds.
## Chain 3: Adjust your expectations accordingly!
## Chain 3: 
## Chain 3: 
## Chain 3: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 3: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 3: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 3: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 3: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 3: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 3: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 3: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 3: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 3: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 3: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 3: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 3: 
## Chain 3:  Elapsed Time: 0.093323 seconds (Warm-up)
## Chain 3:                0.093912 seconds (Sampling)
## Chain 3:                0.187235 seconds (Total)
## Chain 3: 
## 
## SAMPLING FOR MODEL 'e8d53bbd1abddecfe3c098a6dc689c51' NOW (CHAIN 4).
## Chain 4: 
## Chain 4: Gradient evaluation took 1.2e-05 seconds
## Chain 4: 1000 transitions using 10 leapfrog steps per transition would take 0.12 seconds.
## Chain 4: Adjust your expectations accordingly!
## Chain 4: 
## Chain 4: 
## Chain 4: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 4: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 4: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 4: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 4: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 4: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 4: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 4: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 4: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 4: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 4: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 4: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 4: 
## Chain 4:  Elapsed Time: 0.09055 seconds (Warm-up)
## Chain 4:                0.096241 seconds (Sampling)
## Chain 4:                0.186791 seconds (Total)
## Chain 4:

add_epred_rvars() は 列のベクタの代わりに行列を返すようになる。ここで、 のネストされた列は、応答がそのカテゴリにある確率を表する。たとえば、ここにデータセット中の の2つの値に対するフィットがある。 .epred .epred mpg

tibble(mpg = c(21,22)) %>%
  add_epred_rvars(m_cyl)
## # A tibble: 2 × 2
##     mpg .epred[,"4"]       [,"6"]          [,"8"]
##   <dbl>   <rvar[,1]>   <rvar[,1]>      <rvar[,1]>
## 1    21  0.37 ± 0.17  0.61 ± 0.17  0.0262 ± 0.035
## 2    22  0.72 ± 0.16  0.27 ± 0.16  0.0078 ± 0.015

この形式は場合によっては便利なのであるが、当面の目的には 各カテゴリーの予測は、別々の行になる。add_epred_rvars()columns_to パラメータを利用すればよい。 を使用して、ネストされた列ヘッダを列の値 (ここでは "cyl" ) に移動する。これはまた、.row 列を追加する。 各予測が入力データフレームのどの行から来たかを示すインデックス。

tibble(mpg = c(21,22)) %>%
  add_epred_rvars(m_cyl, columns_to = "cyl")
## # A tibble: 6 × 4
##     mpg  .row cyl            .epred
##   <dbl> <int> <chr>      <rvar[1d]>
## 1    21     1 4      0.3653 ± 0.169
## 2    22     2 4      0.7247 ± 0.163
## 3    21     1 6      0.6085 ± 0.171
## 4    22     2 6      0.2675 ± 0.159
## 5    21     1 8      0.0262 ± 0.035
## 6    22     2 8      0.0078 ± 0.015

注: cyl 変数が元の因子レベルの名前を保持するためには、次のようにする。 は、brms バージョン 2.15.9 以降を使用している必要がある。

データセットに対して、フィットした確率のフィットラインをプロットすることができた。

data_plot = mtcars_clean %>%
  ggplot(aes(x = mpg, y = cyl, color = cyl)) +
  geom_point() +
  scale_color_brewer(palette = "Dark2", name = "cyl")

fit_plot = mtcars_clean %>%
  data_grid(mpg = seq_range(mpg, n = 101)) %>%
  add_epred_rvars(m_cyl, value = "P(cyl | mpg)", columns_to = "cyl") %>%
  ggplot(aes(x = mpg, color = cyl)) +
  stat_dist_lineribbon(aes(dist = `P(cyl | mpg)`, fill = cyl), alpha = 1/5) +
  scale_color_brewer(palette = "Dark2") +
  scale_fill_brewer(palette = "Dark2") +
  labs(y = "P(cyl | mpg)")

plot_grid(ncol = 1, align = "v",
  data_plot,
  fit_plot
)
## Warning: Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.
## Unknown or uninitialised column: `linewidth`.

順序分布の平均について話すことは、しばしば意味を持たないが、この特定のケースでは、ガロンあたりのマイル数を与えられた車のシリンダー数の期待値は、意味のある量であると主張することができる。あるガロンあたりのマイル数を与えられた車の平均シリンダー数の事後分布を次のようにプロットすることができる。

\[ \textrm{E}[\textrm{cyl}|\textrm{mpg}=m] = \sum_{c \in \{4,6,8\}} c\cdot \textrm{P}(\textrm{cyl}=c|\textrm{mpg}=m) \]

add_epred_rvars() の出力の行列形式を考えると (つまり columns_to を使わない場合) 、この 量は、P(cyl|mpg)c(4,6,8) の内積に過ぎない。rvar のフォーマットは 行列の乗算を含む数学演算 ( %**% 演算子として) を行うことで、予測値を変換することができる。 列を簡単に期待値に変換することができる。以下は2行の例である。

tibble(mpg = c(21,22)) %>%
  # note we are *not* using `columns_to` anymore
  add_epred_rvars(m_cyl, value = "P(cyl | mpg)") %>%
  mutate(cyl = `P(cyl | mpg)` %**% c(4,6,8))
## # A tibble: 2 × 3
##     mpg `P(cyl | mpg)`[,"4"]       [,"6"]          [,"8"]     cyl[,1]
##   <dbl>           <rvar[,1]>   <rvar[,1]>      <rvar[,1]>  <rvar[,1]>
## 1    21          0.37 ± 0.17  0.61 ± 0.17  0.0262 ± 0.035  5.3 ± 0.35
## 2    22          0.72 ± 0.16  0.27 ± 0.16  0.0078 ± 0.015  4.6 ± 0.34

Altogetherに続き、unnest_rvars()、スパゲッティ・プロットを作成できるようになった。 訳註:下記はエラーが出るため eval=FALSE。

label_data_function = . %>% 
  ungroup() %>%
  filter(mpg == quantile(mpg, .47)) %>%
  summarise_if(is.numeric, mean)

data_plot_with_mean = mtcars_clean %>%
  data_grid(mpg = seq_range(mpg, n = 101)) %>%
  # NOTE: use of ndraws = 100 here subsets draws for the creation of spaghetti plots;
  # DOT NOT do this if you are making other chart types like intervals or densities
  add_epred_rvars(m_cyl, value = "P(cyl | mpg)", ndraws = 100) %>%
  # calculate expected cylinder value
  mutate(cyl = `P(cyl | mpg)` %**% c(4,6,8)) %>%
  # transform in long-data-frame-of-draws format for making spaghetti plots
  unnest_rvars() %>%
  ggplot(aes(x = mpg, y = cyl)) +
  geom_line(aes(group = .draw), alpha = 5/100) +
  geom_point(aes(y = as.numeric(as.character(cyl)), fill = cyl), data = mtcars_clean, shape = 21, size = 2) +
  geom_text(aes(x = mpg + 4), label = "E[cyl | mpg]", data = label_data_function, hjust = 0) +
  geom_segment(aes(yend = cyl, xend = mpg + 3.9), data = label_data_function) +
  scale_fill_brewer(palette = "Set2", name = "cyl")

plot_grid(ncol = 1, align = "v",
  data_plot_with_mean,
  fit_plot
)

ここで、使用したしきい値に対する潜在線形予測変数のプロットを追加してみよう。 を使って、各カテゴリーの確率を決定する。posterior::as_draws_rvars() を利用することができる。 を使って、モデルからパラメータを rvar オブジェクトとして取得する。

draws_cyl = m_cyl %>%
  tidy_draws() %>%
  as_draws_rvars()

draws_cyl
## # A draws_rvars: 1000 iterations, 4 chains, and 11 variables
## $b_Intercept: rvar<1000,4>[2] mean ± sd:
## [1] -38 ± 12  -33 ± 11 
## 
## $b_mpg: rvar<1000,4>[1] mean ± sd:
## [1] -1.8 ± 0.57 
## 
## $disc: rvar<1000,4>[1] mean ± sd:
## [1] 1 ± 0 
## 
## $lprior: rvar<1000,4>[1] mean ± sd:
## [1] -5.1 ± 0.71 
## 
## $lp__: rvar<1000,4>[1] mean ± sd:
## [1] -13 ± 1.3 
## 
## $accept_stat__: rvar<1000,4>[1] mean ± sd:
## [1] 0.92 ± 0.12 
## 
## $stepsize__: rvar<1000,4>[1] mean ± sd:
## [1] 0.33 ± 0.035 
## 
## $treedepth__: rvar<1000,4>[1] mean ± sd:
## [1] 2.6 ± 0.65 
## 
## # ... with 3 more variables

私たちは、閾値を表す b_Intercept パラメータにとても興味がある。 を潜在的な線形予測変数の上に置く。

beta = draws_cyl$b_Intercept
beta
## rvar<1000,4>[2] mean ± sd:
## [1] -38 ± 12  -33 ± 11

また、線形予測変数が切片となる位置が欲しいところである。
の閾値と傾き( b_mpg )を使って計算することができる。

x_intercept = with(draws_cyl, b_Intercept / b_mpg)
x_intercept
## rvar<1000,4>[2] mean ± sd:
## [1] 21 ± 0.51  18 ± 0.50

add_linpred_rvars()add_epred_rvars() になぞらえて使うことで
潜在線形予測器を得る。これを beta の閾値と組み合わせる。 減算 beta [1] を線形予測器から、もう一方の閾値から減算する。 beta [2] , これらの値はすべて相関が高いため、 (したがって の不確実性を、その差異を見ることなく有意義な形で表現する)。私たちは また、.width = ppoints(XXX)stat_dist_lineribbon() で使用するデモも行っている。XXX は数字である。 3050 のように、低い alpha の値と組み合わせることで、グラデーションのような表現が可能になる。 lineribbons:

beta_2_color = brewer.pal(n = 3, name = "Dark2")[[3]]
beta_1_color = brewer.pal(n = 3, name = "Dark2")[[1]]

# vertical lines we will use to show the relationship between the linear 
# predictor and P(cyl | mpg)
x_intercept_lines = geom_vline(
  # this works because `rvar`s define median() to take the median of the 
  # distribution of each element, see vignette("rvar", package = "posterior")
  xintercept = median(x_intercept),
  color = "gray50",
  alpha = 0.2,
  size = 1
)

thresholds_plot = mtcars_clean %>%
  data_grid(mpg = seq_range(mpg, n = 101)) %>%
  add_linpred_rvars(m_cyl) %>%
  ggplot(aes(x = mpg)) +
  stat_dist_lineribbon(
    aes(dist = beta[2] - beta[1]),
    color = beta_2_color, fill = beta_2_color, 
    alpha = 1/30, .width = ppoints(30),
    size = 1, linetype = "21"
  ) +
  geom_line(aes(y = 0), size = 1, color = beta_1_color, linetype = "21") +
  stat_dist_lineribbon(
    aes(dist = .linpred - beta[1]),
    fill = "black", color = "black",
    alpha = 1/30, .width = ppoints(30)
  ) +
  labs(y = expression("linear predictor" - beta[1])) + 
  annotate("label",
    label = "beta[1]", parse = TRUE,
    x = max(mtcars_clean$mpg), y = 0, hjust = 0.8,
    color = beta_1_color
  ) +
  annotate("label",
    label = "beta[2] - beta[1]", parse = TRUE,
    x = max(mtcars_clean$mpg), y = median(beta[2] - beta[1]), hjust = 0.9,
    color = beta_2_color
  ) +
  coord_cartesian(ylim = c(-10, 10))

plot_grid(ncol = 1, align = "v", axis = "lr",
  data_plot_with_mean + x_intercept_lines,
  fit_plot + x_intercept_lines,
  thresholds_plot + x_intercept_lines
)

線形の予測変数が線と交差するとき、どのようになるか注意。 beta [1] カテゴリ 1と2が同じ確率で存在し、その線と交差するときに beta [2] 2と3は同じように可能性がある。

このモデルでlong-data-frame-of-drawsワークフローを使用した他の例については、こちらを参照。 (特定のタスクではこちらの方が簡単な場合がある) の該当セクションを参照。 vignette("tidy-brms") .