# adult data
# model_selection = c("adult_vision", "adult_vision_sound", "baseline", "matching")
# computational models
model_selection = c("vision", "both", "baseline", "matching")

# Posterior probability of each model for each child, assuming a uniform prior
# over `model_selection`: a participant's likelihood under a model is the
# product of the model's predicted probabilities for the answers they gave,
# and these likelihoods are normalized across models. The result is a plain
# (ungrouped, non-rowwise) tibble: age, participant, one column per model.
df.fits = df.choices %>%
  mutate(answer = factor(answer, levels = 1:3)) %>%
  left_join(df.predictions %>%
              select(world, answer, all_of(model_selection)),
            by = c("world", "answer")) %>%
  # floor zero predictions at 0.01 so a single "impossible" choice does not
  # rule a model out entirely
  mutate(across(all_of(model_selection), ~ ifelse(. == 0, 0.01, .))) %>%
  # keep summed log-probabilities instead of raw products: a product over many
  # trials can underflow to 0, which would make the normalization below 0/0
  group_by(age, participant) %>%
  summarize(across(all_of(model_selection), ~ sum(log(.))),
            .groups = "drop") %>%
  # normalize across models (softmax); subtracting each row's maximum leaves
  # the result mathematically unchanged but keeps exp() in range
  mutate(max_loglik = do.call(pmax, across(all_of(model_selection))),
         across(all_of(model_selection), ~ exp(. - max_loglik)),
         total = rowSums(across(all_of(model_selection))),
         across(all_of(model_selection), ~ . / total)) %>%
  select(-max_loglik, -total)
# model_selection = c("adult_vision", "adult_vision_sound", "baseline", "matching")
# computational models
model_selection = c("vision", "both", "baseline", "matching")

# Posterior probability of each model for each adult participant (per
# experiment), assuming a uniform prior over `model_selection`. Only worlds
# that also appear in the children's data are used. The result is a plain
# (ungrouped, non-rowwise) tibble: experiment, participant, one column per model.
df.fits.adults = df.adults %>%
  mutate(answer = as.factor(answer)) %>%
  filter(world %in% unique(df.choices$world)) %>%
  left_join(df.predictions %>%
              select(world, answer, all_of(model_selection)),
            by = c("world", "answer")) %>%
  # floor zero predictions at 0.01 so a single "impossible" choice does not
  # rule a model out entirely
  mutate(across(all_of(model_selection), ~ ifelse(. == 0, 0.01, .))) %>%
  # keep summed log-probabilities instead of raw products: a product over many
  # trials can underflow to 0, which would make the normalization below 0/0
  group_by(experiment, participant) %>%
  summarize(across(all_of(model_selection), ~ sum(log(.))),
            .groups = "drop") %>%
  # normalize across models (softmax); subtracting each row's maximum leaves
  # the result mathematically unchanged but keeps exp() in range
  mutate(max_loglik = do.call(pmax, across(all_of(model_selection))),
         across(all_of(model_selection), ~ exp(. - max_loglik)),
         total = rowSums(across(all_of(model_selection))),
         across(all_of(model_selection), ~ . / total)) %>%
  select(-max_loglik, -total)

# Proportion of correct choices per child age group, plus the adults from the
# "vision_sound" experiment (correct = answer matches ground_truth).
df.choices %>%
  mutate(age = as.character(age)) %>%
  group_by(age) %>%
  summarize(pct_correct = mean(correct)) %>%
  bind_rows(df.adults %>%
              filter(experiment == "vision_sound") %>%
              mutate(correct = answer == ground_truth) %>%
              summarize(pct_correct = mean(correct)) %>%
              mutate(age = "adult",
                     .before = pct_correct)) %>%
  print_table()
| age | pct_correct |
|---|---|
| 3 | 0.32 |
| 4 | 0.30 |
| 5 | 0.36 |
| 6 | 0.37 |
| 7 | 0.43 |
| 8 | 0.58 |
| adult | 0.67 |
# Does accuracy change with age? Bayesian logistic regression (brms) of
# trial-level correctness on age, with a random intercept per participant.
# The fit is saved to / reloaded from "cache/fit.age_correct" via brm()'s
# `file` argument.
df.stat = df.choices %>%
  select(participant, age, trial, correct)

fit.age_correct = brm(formula = correct ~ 1 + age + (1 | participant),
                      data = df.stat,
                      family = "bernoulli",
                      seed = 1,
                      file = "cache/fit.age_correct")

fit.age_correct
Family: bernoulli
Links: mu = logit
Formula: correct ~ 1 + age + (1 | participant)
Data: df.stat (Number of observations: 576)
Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
total post-warmup draws = 4000
Group-Level Effects:
~participant (Number of levels: 64)
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept) 0.35 0.17 0.04 0.67 1.00 927 921
Population-Level Effects:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept -1.65 0.34 -2.36 -1.00 1.00 4212 2815
age 0.22 0.06 0.11 0.33 1.00 4324 2657
Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
# Intercept-only logistic models of accuracy (random intercept per
# participant), fit separately for the 7- and the 8-year-olds. Each fit is
# saved to / reloaded from its cache/ path via brm()'s `file` argument.
# `df.stat` is left holding the 8-year-olds' data afterwards.

# 7 year olds
df.stat = df.choices %>%
  filter(age == 7) %>%
  select(participant, age, trial, correct)

fit.seven_correct = brm(correct ~ 1 + (1 | participant),
                        data = df.stat,
                        family = "bernoulli",
                        seed = 1,
                        file = "cache/fit.seven_correct")

# 8 year olds
df.stat = df.choices %>%
  filter(age == 8) %>%
  select(participant, age, trial, correct)

fit.eight_correct = brm(correct ~ 1 + (1 | participant),
                        data = df.stat,
                        family = "bernoulli",
                        seed = 1,
                        file = "cache/fit.eight_correct")
# results (in probability scale)
# 7-year-olds: the population-level intercept and its interval bounds,
# mapped from log-odds back to the probability of a correct choice.
fit.seven_correct %>%
  tidy() %>%
  filter(effect == "fixed") %>%
  select(estimate, contains("conf")) %>%
  mutate(across(everything(), inv.logit))
# A tibble: 1 × 3
estimate conf.low conf.high
<dbl> <dbl> <dbl>
1 0.429 0.292 0.562
# 8-year-olds: the population-level intercept and its interval bounds,
# mapped from log-odds back to the probability of a correct choice.
fit.eight_correct %>%
  tidy() %>%
  filter(effect == "fixed") %>%
  select(estimate, contains("conf")) %>%
  mutate(across(everything(), inv.logit))
# A tibble: 1 × 3
estimate conf.low conf.high
<dbl> <dbl> <dbl>
1 0.593 0.441 0.733
# Mean posterior weight of the "matching" model among children younger than 6.
# ungroup() first because df.fits may carry rowwise grouping.
df.fits %>%
  ungroup() %>%
  filter(age < 6) %>%
  summarize(across(matching, mean))
# A tibble: 1 × 1
matching
<dbl>
1 0.577
# Combined mean posterior weight of the two simulation models ("vision" and
# "both") among children younger than 6.
df.fits %>%
  ungroup() %>%
  filter(age < 6) %>%
  summarize(vision = mean(vision),
            both = mean(both)) %>%
  transmute(simulation = vision + both)
# A tibble: 1 × 1
simulation
<dbl>
1 0.252
# Mean posterior weight of the "baseline" (guessing) model among 3-year-olds.
df.fits %>%
  ungroup() %>%
  filter(age == 3) %>%
  summarize(across(baseline, mean)) %>%
  rename(guessing = baseline)
# A tibble: 1 × 1
guessing
<dbl>
1 0.198
# Combined mean posterior weight of the two simulation models ("vision" and
# "both") among children aged 6 and older.
df.fits %>%
  ungroup() %>%
  filter(age >= 6) %>%
  summarize(vision = mean(vision),
            both = mean(both)) %>%
  transmute(simulation = vision + both)
# A tibble: 1 × 1
simulation
<dbl>
1 0.546
# Mean posterior weight among 8-year-olds of the "both" model only (the
# output column is named "simulation", but it excludes the "vision" model).
df.fits %>%
  ungroup() %>%
  filter(age == 8) %>%
  summarize(across(both, mean)) %>%
  rename(simulation = both)
# A tibble: 1 × 1
simulation
<dbl>
1 0.473
# Legend label for every model column that may be plotted, keyed by column
# name, so the labels follow automatically when `models` is toggled below.
model_labels = c(baseline = "guessing",
                 matching = "matching",
                 vision = "vision",
                 both = "vision & sound",
                 adult_vision = "simulation (vision)",
                 adult_vision_sound = "simulation (vision & sound)")

# Which model predictions to plot (switch lines to plot the computational
# models instead of the adult-derived ones).
# models = c("baseline", "matching", "vision", "both")
models = c("baseline", "matching", "adult_vision", "adult_vision_sound")

# Long format: one row per world x answer x model. `model` becomes a factor
# ordered as in `models`, with its labels taken from `model_labels`.
df.plot = df.predictions %>%
  select(world, answer, all_of(models)) %>%
  pivot_longer(cols = -c(world, answer),
               names_to = "model",
               values_to = "prediction") %>%
  mutate(model = factor(model,
                        levels = models,
                        labels = unname(model_labels[models])))
# Load the ground-truth image for one world from
# ../../figures/ground_truth/trial_<world>.png and return whatever
# png::readPNG() gives back for it.
func_load_image = function(world){
  image_path = str_c("../../figures/ground_truth/trial_", world, ".png")
  readPNG(image_path)
}
# linking images and worlds
# One row per world (sorted), with that world's image in a list column for
# the geom_custom() layer.
df.images = df.plot %>%
  distinct(world) %>%
  arrange(world) %>%
  mutate(grob = map(world, func_load_image))

# Panel index labels: worlds are numbered 1..n in sorted order, drawn at
# x = 0.8 / y = Inf by the geom_text() layer. row_number() (unlike 1:n())
# also behaves correctly when there are no rows.
df.text = df.plot %>%
  distinct(world) %>%
  arrange(world) %>%
  mutate(index = row_number(),
         x = 0.8,
         y = Inf)
# plotting
# One facet per world (in a single row): dodged bars of each model's predicted
# probability for each answer, a dashed reference line at 1/3, the world's
# image drawn above the panel, and the world's index number.
p = ggplot(data = df.plot,
           mapping = aes(x = answer,
                         y = prediction)) +
  geom_col(mapping = aes(fill = model,
                         group = model),
           position = position_dodge(width = 0.9),
           color = "black") +
  geom_hline(yintercept = 1/3,
             linetype = 2) +
  # image anchored at the panel's top-left; vjust < 0 pushes it above the
  # panel, which stays visible because of coord_cartesian(clip = "off")
  geom_custom(data = df.images,
              mapping = aes(data = grob,
                            x = -Inf,
                            y = Inf),
              grob_fun = function(x) rasterGrob(x,
                                                interpolate = TRUE,
                                                vjust = -0.05,
                                                hjust = 0)) +
  geom_text(data = df.text,
            mapping = aes(x = x,
                          y = y,
                          label = index),
            size = 12,
            color = "white",
            vjust = -4) +
  facet_wrap(~ world,
             nrow = 1) +
  labs(y = "proportion %") +
  scale_y_continuous(breaks = seq(0, 1, 0.25),
                     labels = str_c(seq(0, 100, 25), "%"),
                     limits = c(0, 1),
                     expand = expansion(add = c(0, 0))) +
  coord_cartesian(clip = "off") +
  scale_fill_brewer(palette = "Set1") +
  theme(panel.grid.major.y = element_line(),
        axis.text.y = element_text(size = 25),
        axis.text.x = element_text(size = 25),
        axis.title = element_blank(),
        legend.position = "bottom",
        strip.background = element_blank(),
        strip.text = element_blank(),
        panel.background = element_rect(fill = NA, color = "black"),
        panel.spacing.x = unit(0.5, "cm"),
        plot.margin = margin(t = 5, l = 1, r = 0.2, b = 0, unit = "cm"))
p
# Save the model-prediction plot `p` (swap in the commented filename when
# plotting the computational models instead).
ggsave(filename = "../../figures/plots/model_predictions_adults.pdf",
       # ggsave(filename = "../../figures/plots/model_predictions_models.pdf",
       plot = p,
       width = 22,
       height = 5)

# NOTE(review): seed presumably fixed for the bootstrapped CIs
# (Hmisc::smean.cl.boot) computed further down — confirm
set.seed(1)
# Per-child accuracy (share of correct choices) together with the child's
# continuous age (age_cts, looked up in df.data) for the x position. `age`
# gets the same factor levels as the group means so fill/color scales match.
df.plot.individual = df.choices %>%
  # add continuous age
  left_join(df.data %>%
              select(participant, age_cts),
            by = "participant") %>%
  group_by(participant, age, age_cts) %>%
  summarize(pct_correct = sum(correct) / n()) %>%
  ungroup() %>%
  mutate(age = factor(age,
                      levels = c(as.character(3:8),
                                 "adult\nvision",
                                 "adult\nvision & sound")))
# Group-level accuracy for the plot: for each child age group and each adult
# experiment, the mean of participants' accuracies with a bootstrapped CI
# (Hmisc::smean.cl.boot) and the number of participants n. age_mean is the
# mean continuous age of each child group; the two adult groups are placed at
# x = 10 and x = 11.
df.plot.means = df.choices %>%
  mutate(age = as.character(age)) %>%
  group_by(participant, age) %>%
  summarize(pct_correct = mean(correct)) %>%
  bind_rows(df.adults %>%
              mutate(correct = answer == ground_truth,
                     age = "adult") %>%
              group_by(participant, experiment, age) %>%
              summarize(pct_correct = mean(correct)) %>%
              # e.g. "adult" + "vision_sound" -> "adult\nvision & sound"
              mutate(age = str_c(age, "\n", experiment),
                     age = str_replace(age, "_", " & "))) %>%
  group_by(age) %>%
  # one bootstrap call per group, kept as a named vector and unpacked by name
  # (rather than a multi-row summarize() labelled by position)
  summarize(ci = list(Hmisc::smean.cl.boot(pct_correct)),
            n = n()) %>%
  mutate(mean = map_dbl(ci, ~ .x[["Mean"]]),
         low = map_dbl(ci, ~ .x[["Lower"]]),
         high = map_dbl(ci, ~ .x[["Upper"]])) %>%
  select(-ci) %>%
  left_join(df.choices %>%
              left_join(df.data %>%
                          select(participant, age_cts),
                        by = "participant") %>%
              distinct(participant, age, age_cts) %>%
              group_by(age) %>%
              summarize(age_mean = mean(age_cts)) %>%
              mutate(age = as.character(age)),
            by = "age") %>%
  mutate(age_mean = case_when(age == "adult\nvision" ~ 10,
                              age == "adult\nvision & sound" ~ 11,
                              TRUE ~ age_mean),
         age = factor(age,
                      levels = c(3:8, "adult\nvision", "adult\nvision & sound")))
# Group-size labels along the top of the accuracy plot (y = 0.95): the
# 3-year-old group reads "n = <n>", every other group just "<n>".
df.text = df.plot.means %>%
  select(age, age_mean, n) %>%
  mutate(label = ifelse(age == "3", str_c("n = ", n), n),
         y = 0.95)
# Accuracy by age: each child at their continuous age (small points), group
# means with their low/high interval from df.plot.means (point ranges), group
# sizes along the top, and a dashed reference line at 1/3. The adult groups
# sit at x = 10 and 11. Saved to points_accuracy_continuous.pdf.
ggplot() +
  geom_hline(yintercept = seq(0, 1, 0.1),
             linetype = 1,
             alpha = 0.1) +
  geom_hline(yintercept = 1/3,
             linetype = 2,
             color = "gray50") +
  geom_point(data = df.plot.individual,
             mapping = aes(x = age_cts,
                           y = pct_correct,
                           color = age),
             show.legend = FALSE,
             size = 1) +
  geom_pointrange(data = df.plot.means,
                  mapping = aes(x = age_mean,
                                y = mean,
                                ymin = low,
                                ymax = high,
                                fill = age),
                  shape = 21,
                  size = 1,
                  show.legend = FALSE) +
  geom_text(data = df.text,
            mapping = aes(x = age_mean,
                          y = y,
                          label = label),
            size = 6,
            color = "gray20") +
  labs(x = "age (in years)",
       y = "% correct") +
  coord_cartesian(ylim = c(0, 1)) +
  scale_x_continuous(breaks = c(seq(3, 9, 1), 10:11),
                     labels = c(seq(3, 9, 1), "adult\nvision", "adult\nvision & sound"),
                     expand = expansion(add = c(0.2, 1))) +
  scale_y_continuous(breaks = seq(0, 1, 0.2),
                     labels = str_c(seq(0, 1, 0.2) * 100, "%"),
                     expand = expansion(add = c(0, 0.05))) +
  scale_fill_brewer(type = "qual", palette = 2) +
  scale_color_brewer(type = "qual", palette = 2) +
  theme(axis.text.x = element_text(size = 14))

ggsave(filename = "../../figures/plots/points_accuracy_continuous.pdf",
       width = 10,
       height = 4)
R version 4.2.0 (2022-04-22)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Big Sur/Monterey 10.16
Matrix products: default
BLAS: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib
locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
attached base packages:
[1] grid stats graphics grDevices utils datasets methods
[8] base
other attached packages:
[1] brms_2.17.0 Rcpp_1.0.8.3 forcats_0.5.1
[4] stringr_1.4.0 dplyr_1.0.9 purrr_0.3.4
[7] readr_2.1.2 tidyr_1.2.0 tibble_3.1.7
[10] tidyverse_1.3.1 boot_1.3-28 emmeans_1.7.3
[13] broom.mixed_0.2.9.4 corrr_0.4.3 patchwork_1.1.1
[16] xtable_1.8-4 kableExtra_1.3.4 transport_0.12-2
[19] janitor_2.1.0 knitr_1.39 ggpubr_0.4.0
[22] egg_0.4.5 ggplot2_3.3.6 gridExtra_2.3
[25] png_0.1-7 readxl_1.4.0
loaded via a namespace (and not attached):
[1] backports_1.4.1 Hmisc_4.7-0 systemfonts_1.0.4
[4] plyr_1.8.7 igraph_1.3.1 splines_4.2.0
[7] crosstalk_1.2.0 listenv_0.8.0 inline_0.3.19
[10] rstantools_2.2.0 digest_0.6.29 htmltools_0.5.2
[13] fansi_1.0.3 magrittr_2.0.3 checkmate_2.1.0
[16] cluster_2.1.3 tzdb_0.3.0 globals_0.14.0
[19] modelr_0.1.8 RcppParallel_5.1.5 matrixStats_0.62.0
[22] vroom_1.5.7 xts_0.12.1 svglite_2.1.0
[25] prettyunits_1.1.1 jpeg_0.1-9 colorspace_2.0-3
[28] rvest_1.0.2 haven_2.5.0 xfun_0.30
[31] callr_3.7.0 crayon_1.5.1 jsonlite_1.8.0
[34] survival_3.3-1 zoo_1.8-10 glue_1.6.2
[37] gtable_0.3.0 webshot_0.5.3 distributional_0.3.0
[40] pkgbuild_1.3.1 car_3.0-13 rstan_2.21.5
[43] abind_1.4-5 scales_1.2.0 mvtnorm_1.1-3
[46] DBI_1.1.2 rstatix_0.7.0 miniUI_0.1.1.1
[49] htmlTable_2.4.0 viridisLite_0.4.0 diffobj_0.3.5
[52] foreign_0.8-82 bit_4.0.4 Formula_1.2-4
[55] StanHeaders_2.21.0-7 stats4_4.2.0 DT_0.22
[58] htmlwidgets_1.5.4 httr_1.4.2 threejs_0.3.3
[61] RColorBrewer_1.1-3 posterior_1.2.1 ellipsis_0.3.2
[64] pkgconfig_2.0.3 loo_2.5.1 farver_2.1.0
[67] nnet_7.3-17 sass_0.4.1 dbplyr_2.1.1
[70] utf8_1.2.2 tidyselect_1.1.2 rlang_1.0.2
[73] reshape2_1.4.4 later_1.3.0 munsell_0.5.0
[76] cellranger_1.1.0 tools_4.2.0 cli_3.3.0
[79] generics_0.1.2 broom_0.8.0 ggridges_0.5.3
[82] evaluate_0.15 fastmap_1.1.0 yaml_2.3.5
[85] bit64_4.0.5 processx_3.5.3 fs_1.5.2
[88] future_1.25.0 nlme_3.1-157 mime_0.12
[91] xml2_1.3.3 shinythemes_1.2.0 compiler_4.2.0
[94] bayesplot_1.9.0 rstudioapi_0.13 ggsignif_0.6.3
[97] reprex_2.0.1 bslib_0.3.1 stringi_1.7.6
[100] highr_0.9 ps_1.7.0 Brobdingnag_1.2-8
[103] lattice_0.20-45 Matrix_1.4-1 markdown_1.1
[106] shinyjs_2.1.0 tensorA_0.36.2 vctrs_0.4.1
[109] pillar_1.7.0 lifecycle_1.0.1 furrr_0.3.0
[112] jquerylib_0.1.4 bridgesampling_1.1-2 estimability_1.3
[115] data.table_1.14.2 httpuv_1.6.5 latticeExtra_0.6-29
[118] R6_2.5.1 bookdown_0.26 promises_1.2.0.1
[121] parallelly_1.31.1 codetools_0.2-18 colourpicker_1.1.1
[124] gtools_3.9.2 assertthat_0.2.1 withr_2.5.0
[127] shinystan_2.6.0 parallel_4.2.0 hms_1.1.1
[130] rpart_4.1.16 coda_0.19-4 rmarkdown_2.14
[133] snakecase_0.11.0 carData_3.0-5 shiny_1.7.1
[136] lubridate_1.8.0 base64enc_0.1-3 dygraphs_1.1.1.6