Accessing the contents of a stanfit object

Stan Development Team

2022-11-26

この vignette は、stanfitオブジェクトに格納されたデータの大部分にアクセスする方法を示している。stanfitオブジェクト(クラス "stanfit" のオブジェクト)は、マルコフ連鎖モンテカルロ法、またはStanの変分近似(meanfieldまたはfull-rank)の1つを用いたStanモデルのフィットから得られたアウトプットを含む。このドキュメントでは、Eight Schoolsのサンプルモデルをフィットして得られたstanfitオブジェクトを使用する。

library(rstan)
fit <- stan_demo("eight_schools", refresh = 0)
Trying to compile a simple C file
class(fit)
[1] "stanfit"
attr(,"package")
[1] "rstan"

Posterior draws

stanfitオブジェクトに格納された事後分布のドローにアクセスするために使用できる関数がいくつかある。これらは extract , as.matrix , as.data.frame , そして as.array で、それぞれ異なるフォーマットでドローを返す。


extract()

extract 関数(デフォルトの引数)は、モデルパラメータに対応する名前付き成分を含むリストを返す。

list_of_draws <- extract(fit)
print(names(list_of_draws))
[1] "mu"    "tau"   "eta"   "theta" "lp__" 

このモデルでは、パラメータ mutau はスカラーで、theta は8つの要素を持つベクタである。これは、mutau のドローはベクタ (長さはウォームアップ後の反復回数に鎖の数をかけたもの) になり、theta のドローは行列になり、各列は8つの成分の1つに対応することを意味する。

head(list_of_draws$mu)
[1] 16.000946  7.981731  7.012390  2.530267  6.803282  9.845921
head(list_of_draws$tau)
[1] 12.5966129  1.4002067  4.4626962 20.1262659  1.3163576  0.1482642
head(list_of_draws$theta)
          
iterations      [,1]      [,2]       [,3]       [,4]       [,5]      [,6]
      [1,] 39.949844 17.356663  17.978173   7.276805 -0.3718173  5.550738
      [2,]  8.829259  4.974844   5.867984   8.638270  7.7770525  7.119422
      [3,] 14.349446  4.094581   3.199722   6.499695 -3.8042597 13.240770
      [4,] 15.037156 24.176008 -16.562654 -10.244731 -2.1426799 20.414409
      [5,]  6.047071  5.404840   8.816903   5.294886  7.6056360  6.089837
      [6,]  9.832194  9.982824  10.117211   9.921085 10.0903975  9.945737
          
iterations       [,7]       [,8]
      [1,]  2.9745414   7.740255
      [2,]  8.4564151   8.407565
      [3,] -0.4732309   1.684472
      [4,] 31.1770075 -13.882266
      [5,]  6.5967283   5.353370
      [6,]  9.8468725   9.982764


as.matrix()、as.data.frame()、as.array() {#as.matrix(),-as.data.frame(),-as.array()}

as.matrix , as.data.frame , as.array 関数は、stanfit オブジェクトから事後ドローを取得するために使用することも可能である。

matrix_of_draws <- as.matrix(fit)
print(colnames(matrix_of_draws))
 [1] "mu"       "tau"      "eta[1]"   "eta[2]"   "eta[3]"   "eta[4]"  
 [7] "eta[5]"   "eta[6]"   "eta[7]"   "eta[8]"   "theta[1]" "theta[2]"
[13] "theta[3]" "theta[4]" "theta[5]" "theta[6]" "theta[7]" "theta[8]"
[19] "lp__"    
df_of_draws <- as.data.frame(fit)
print(colnames(df_of_draws))
 [1] "mu"       "tau"      "eta[1]"   "eta[2]"   "eta[3]"   "eta[4]"  
 [7] "eta[5]"   "eta[6]"   "eta[7]"   "eta[8]"   "theta[1]" "theta[2]"
[13] "theta[3]" "theta[4]" "theta[5]" "theta[6]" "theta[7]" "theta[8]"
[19] "lp__"    
array_of_draws <- as.array(fit)
print(dimnames(array_of_draws))
$iterations
NULL

$chains
[1] "chain:1" "chain:2" "chain:3" "chain:4"

$parameters
 [1] "mu"       "tau"      "eta[1]"   "eta[2]"   "eta[3]"   "eta[4]"  
 [7] "eta[5]"   "eta[6]"   "eta[7]"   "eta[8]"   "theta[1]" "theta[2]"
[13] "theta[3]" "theta[4]" "theta[5]" "theta[6]" "theta[7]" "theta[8]"
[19] "lp__"    

as.matrixas.data.frame メソッドは、それぞれ行列とデータフレーム形式を除いて、本質的に同じものを返す。as.array メソッドは各チェーンからのドローを別々に返すので、次元が追加されている。

print(dim(matrix_of_draws))
print(dim(df_of_draws))
print(dim(array_of_draws))
[1] 4000   19
[1] 4000   19
[1] 1000    4   19

デフォルトでは、事後ドローを取得するすべての関数は、all parameters (and generated quantities) に対するドローを返す。オプションの引数 pars (文字ベクタ) は、例えば、パラメータの部分集合のみが必要な場合に使用することができる。

mu_and_theta1 <- as.matrix(fit, pars = c("mu", "theta[1]"))
head(mu_and_theta1)
          parameters
iterations        mu  theta[1]
      [1,]  4.818409 27.308941
      [2,] 15.502630 11.626600
      [3,]  5.717083 11.553901
      [4,]  7.639136 -1.974016
      [5,]  5.834604  8.108018
      [6,]  6.492155  5.747456


事後要約統計と収束診断

要約統計は summary 関数で取得する。返されるオブジェクトは2つの要素を持つリストである。

fit_summary <- summary(fit)
print(names(fit_summary))
[1] "summary"   "c_summary"

fit_summary$summary では全てのチェインはマージされるが、 fit_summary$c_summary には、各チェーンの要約が個別に格納されている。通常、私たちはすべてのチェーンの要約をマージしたいのであるが、ここではそれに焦点を当てる。

要約は、パラメータに対応する行と、さまざまな要約量に対応する列を持つ行列である。これらは、事後平均、事後標準偏差、そして、抽選から計算された様々な分位数である。probs 引数は計算する分位数を指定するために使用され、pars は要約に含めるパラメータのサブセットを指定するために使用することができる。

MCMC を用いて適合したモデルについては、モンテカルロ標準誤差 ( se_mean )、有効サンプルサイズ ( n_eff )、R-hat統計量 ( Rhat )も要約に含まれている。

print(fit_summary$summary)
                 mean    se_mean        sd        2.5%         25%
mu         8.00697487 0.12695143 5.1818438  -2.2435023   4.6066332
tau        6.68324919 0.15188170 5.6457090   0.1685776   2.4568062
eta[1]     0.39575261 0.01546834 0.9358576  -1.4874527  -0.1906544
eta[2]     0.01066301 0.01380506 0.8740841  -1.7097379  -0.5756675
eta[3]    -0.19175138 0.01372680 0.9165516  -1.9948825  -0.8173802
eta[4]    -0.03618882 0.01399424 0.8688513  -1.7261816  -0.6061251
eta[5]    -0.35810031 0.01430566 0.8535416  -2.0251242  -0.9140415
eta[6]    -0.21237584 0.01579286 0.8707081  -1.8390263  -0.8224697
eta[7]     0.32086489 0.01390132 0.8801882  -1.4584343  -0.2501303
eta[8]     0.06311189 0.01489611 0.9260779  -1.7879700  -0.5520251
theta[1]  11.60401196 0.17198504 8.4744519  -2.1127163   6.0648662
theta[2]   7.98716870 0.09771440 6.2996284  -4.4230716   4.0106556
theta[3]   6.17555351 0.13208098 7.8231607 -11.5793793   2.1728332
theta[4]   7.61695463 0.10405501 6.5785326  -6.1174739   3.6677558
theta[5]   5.13998767 0.09368144 6.3530986  -8.5385294   1.2144369
theta[6]   6.08588123 0.10905135 6.8663894  -9.3456206   2.1304987
theta[7]  10.71466567 0.12362457 6.8737817  -1.2268722   6.0955530
theta[8]   8.62939570 0.14987971 8.0256347  -7.3080820   3.9258219
lp__     -39.47902403 0.07264253 2.6050627 -45.2539576 -41.0345173
                   50%         75%      97.5%    n_eff      Rhat
mu         7.948417136  11.3035903  18.755149 1666.071 1.0023964
tau        5.320002201   9.4660194  20.823182 1381.739 1.0016867
eta[1]     0.411158136   1.0289319   2.193529 3660.430 0.9997490
eta[2]     0.003562785   0.5790822   1.823607 4008.942 0.9993210
eta[3]    -0.197587363   0.4269975   1.618491 4458.361 0.9992458
eta[4]    -0.033070148   0.5393708   1.697758 3854.715 1.0002961
eta[5]    -0.375139037   0.1949353   1.371675 3559.865 1.0007069
eta[6]    -0.230941616   0.3700844   1.542308 3039.651 1.0004185
eta[7]     0.338738217   0.9136309   1.993621 4009.025 1.0008989
eta[8]     0.068498493   0.6845922   1.907414 3865.000 1.0002710
theta[1]  10.295902350  15.8475550  32.450989 2427.962 1.0005995
theta[2]   7.923974901  11.8355694  20.529657 4156.356 0.9997165
theta[3]   6.606830806  10.9101686  20.824607 3508.197 1.0008608
theta[4]   7.688560397  11.6879137  20.539363 3996.981 0.9996747
theta[5]   5.669796280   9.4558881  16.588829 4599.007 0.9997273
theta[6]   6.426207704  10.5361335  18.816166 3964.558 0.9996848
theta[7]  10.026985377  14.4577322  26.571155 3091.590 1.0002103
theta[8]   8.298905673  13.0591050  26.374881 2867.300 1.0001130
lp__     -39.234603636 -37.6340732 -35.097994 1286.040 1.0012221

例えば、含まれる分位数が10%と90%だけで、含まれるパラメータが mutau だけなら、次のように指定する。

mu_tau_summary <- summary(fit, pars = c("mu", "tau"), probs = c(0.1, 0.9))$summary
print(mu_tau_summary)
        mean   se_mean       sd       10%      90%    n_eff     Rhat
mu  8.006975 0.1269514 5.181844 1.6423816 14.64906 1666.071 1.002396
tau 6.683249 0.1518817 5.645709 0.9319342 14.16945 1381.739 1.001687

mu_tau_summary は行列なので、その名前を使って列を取り出すことができる。

mu_tau_80pct <- mu_tau_summary[, c("10%", "90%")]
print(mu_tau_80pct)
          10%      90%
mu  1.6423816 14.64906
tau 0.9319342 14.16945


サンプラー診断

MCMC を使って適合したモデルについては、stanfitオブジェクトはサンプラーに使われたパラメータの値も含んでいる。get_sampler_params 関数は、この情報にアクセスするために使用することができる。

get_sampler_params が返すオブジェクトは、1つの鎖につき1つのコンポーネント(行列)を持つリストである。各行列は、サンプラーパラメータの数に対応する数の列を持ち、列の名前はパラメータ名である。オプションの引数 inc_warmup (デフォルトは TRUE ) は、ウォームアップ期間を含めるかどうかを示す。

sampler_params <- get_sampler_params(fit, inc_warmup = FALSE)
sampler_params_chain1 <- sampler_params[[1]]
colnames(sampler_params_chain1)
[1] "accept_stat__" "stepsize__"    "treedepth__"   "n_leapfrog__" 
[5] "divergent__"   "energy__"     

各鎖の accept_stat__ の平均値(NUTS アルゴリズムを使用する場合は各鎖の treedepth__ の最大値など)を計算する場合、sapply 関数は sampler_params の各成分に同じ関数を適用するため便利である。

mean_accept_stat_by_chain <- sapply(sampler_params, function(x) mean(x[, "accept_stat__"]))
print(mean_accept_stat_by_chain)
[1] 0.9244861 0.8394644 0.9205059 0.8121741
max_treedepth_by_chain <- sapply(sampler_params, function(x) max(x[, "treedepth__"]))
print(max_treedepth_by_chain)
[1] 4 4 5 4


モデルコード

Stanプログラム自体もstanfitオブジェクトに格納されており、get_stancode でアクセスすることができる。

code <- get_stancode(fit)

オブジェクト code は単一の文字列であり、表示するとあまり分かりやすいものではない。

print(code)
[1] "data {\n  int<lower=0> J;          // number of schools\n  real y[J];               // estimated treatment effects\n  real<lower=0> sigma[J];  // s.e. of effect estimates\n}\nparameters {\n  real mu;\n  real<lower=0> tau;\n  vector[J] eta;\n}\ntransformed parameters {\n  vector[J] theta;\n  theta = mu + tau * eta;\n}\nmodel {\n  target += normal_lpdf(eta | 0, 1);\n  target += normal_lpdf(y | theta, sigma);\n}"
attr(,"model_name2")
[1] "schools"

読みやすいバージョンは、cat を使って表示することができる。

cat(code)
data {
  int<lower=0> J;          // number of schools
  real y[J];               // estimated treatment effects
  real<lower=0> sigma[J];  // s.e. of effect estimates
}
parameters {
  real mu;
  real<lower=0> tau;
  vector[J] eta;
}
transformed parameters {
  vector[J] theta;
  theta = mu + tau * eta;
}
model {
  target += normal_lpdf(eta | 0, 1);
  target += normal_lpdf(y | theta, sigma);
}


初期値

get_inits 関数は、チェーンごとに1つのコンポーネントを持つリストとして、初期値を返す。各コンポーネントはそれ自体、対応するチェーンの各パラメータの初期値を含む(名前付き)リストである。

inits <- get_inits(fit)
inits_chain1 <- inits[[1]]
print(inits_chain1)
$mu
[1] 0.09690272

$tau
[1] 0.1412118

$eta
[1] -0.7262940  0.5685015  1.6062456  0.5386035 -1.0973997 -1.0391758  1.7643938
[8] -1.0770862

$theta
[1] -0.005658565  0.177181843  0.323723563  0.172959890 -0.058063069
[6] -0.049841160  0.346055956 -0.055194565


(P)RNG シード

get_seed 関数は、(P)RNG の種子を整数で返す。

print(get_seed(fit))
[1] 57260474


ウォームアップとサンプリング時間

get_elapsed_time 関数は、各チェーンのウォームアップ時間とサンプリング時間を行列にして返す。

print(get_elapsed_time(fit))
        warmup sample
chain:1  0.083  0.084
chain:2  0.073  0.065
chain:3  0.080  0.089
chain:4  0.089  0.056