1 SETUP

1.1 Load libraries

1.2 Set options

2 DATA

2.1 Load and structure data

2.2 Add model predictions

3 STATS

3.1 Trial selection

3.2 Fit separate models to each participant

3.2.1 Children

# Model set used for the fits. To use the adult-data columns instead, swap in:
# model_selection = c("adult_vision", "adult_vision_sound", "baseline", "matching")

# computational models 
model_selection = c("vision", "both", "baseline", "matching")

# One row per child (age, participant): for every model, the product of the
# model's predictions over that child's rows, rescaled so that the values of
# the selected models sum to 1 within the row.
df.fits = df.choices %>% 
  mutate(answer = factor(answer, levels = 1:3)) %>% 
  # look up each model's prediction for the (world, answer) pair of the row
  left_join(select(df.predictions, world, answer, all_of(model_selection)),
            by = c("world", "answer")) %>% 
  # replace exact zeros by 0.01 so that log() below stays finite
  mutate(across(.cols = all_of(model_selection),
                .fns = ~ ifelse(.x == 0, 0.01, .x))) %>% 
  group_by(age, participant) %>% 
  # product over rows, computed as exp(sum(log(.)))
  summarize(across(.cols = all_of(model_selection),
                   .fns = ~ exp(sum(log(.x))))) %>% 
  ungroup() %>% 
  # normalize within each row
  rowwise() %>% 
  mutate(total = sum(c_across(all_of(model_selection))),
         across(.cols = all_of(model_selection),
                .fns = ~ .x / total)) %>% 
  select(-total)

3.2.2 Adults

# Alternative model set (adult-data columns):
# model_selection = c("adult_vision", "adult_vision_sound", "baseline", "matching")

# computational models 
model_selection = c("vision", "both", "baseline", "matching")

# One row per adult (experiment, participant): for every model, the product of
# the model's predictions over that adult's rows, rescaled so that the values
# of the selected models sum to 1 within the row.
df.fits.adults = df.adults %>% 
    mutate(answer = as.factor(answer)) %>% 
    # keep only worlds that also occur in the children's data
    filter(world %in% df.choices$world) %>% 
    left_join(select(df.predictions, world, answer, all_of(model_selection)),
              by = c("world", "answer")) %>% 
    # replace exact zeros by 0.01 so that log() below stays finite
    mutate(across(.cols = all_of(model_selection),
                  .fns = ~ ifelse(.x == 0, 0.01, .x))) %>% 
    group_by(experiment, participant) %>% 
    # product over rows, computed as exp(sum(log(.)))
    summarize(across(.cols = all_of(model_selection),
                     .fns = ~ exp(sum(log(.x))))) %>% 
    ungroup() %>% 
    # normalize within each row
    rowwise() %>% 
    mutate(total = sum(c_across(all_of(model_selection))),
           across(.cols = all_of(model_selection),
                  .fns = ~ .x / total)) %>% 
    select(-total)

3.3 Accuracy as a function of age

# Proportion of correct responses per child age group, followed by one row for
# the adults of the "vision_sound" experiment.
df.choices %>% 
    mutate(age = as.character(age)) %>% 
    group_by(age) %>% 
    summarize(pct_correct = sum(correct) / n()) %>% 
    bind_rows(
        df.adults %>% 
            filter(experiment == "vision_sound") %>% 
            # an adult response counts as correct when it equals ground_truth
            summarize(age = "adult",
                      pct_correct = sum(answer == ground_truth) / n())
    ) %>% 
    print_table()
age pct_correct
3 0.32
4 0.30
5 0.36
6 0.37
7 0.43
8 0.58
adult 0.67

3.4 Logistic regressions

3.4.1 Effect of age

  • tests whether the probability of being correct changes with age
# columns needed for the regression
df.stat = select(df.choices, participant, age, trial, correct)

# Bernoulli regression of `correct` on `age` with a varying intercept per
# participant. With `file` set, brms stores the fitted model on disk and loads
# it from there when the file already exists.
fit.age_correct = brm(
    formula = correct ~ 1 + age + (1 | participant),
    data = df.stat,
    family = "bernoulli",
    seed = 1,
    file = "cache/fit.age_correct"
)

# print the model summary
fit.age_correct
 Family: bernoulli 
  Links: mu = logit 
Formula: correct ~ 1 + age + (1 | participant) 
   Data: df.stat (Number of observations: 576) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Group-Level Effects: 
~participant (Number of levels: 64) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     0.35      0.17     0.04     0.67 1.00      927      921

Population-Level Effects: 
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept    -1.65      0.34    -2.36    -1.00 1.00     4212     2815
age           0.22      0.06     0.11     0.33 1.00     4324     2657

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

3.4.2 7s and 8s above chance (33%)?

# Intercept-only Bernoulli models (varying intercept per participant), fitted
# separately to the 7- and the 8-year-olds. `df.stat` is overwritten for each
# age group before the corresponding model is fitted.

# 7 year olds
df.stat = df.choices %>% 
    filter(age == 7) %>% 
    select(participant, age, trial, correct)

fit.seven_correct = brm(
    formula = correct ~ 1 + (1 | participant),
    data = df.stat,
    family = "bernoulli",
    seed = 1,
    file = "cache/fit.seven_correct"
)

# 8 year olds
df.stat = df.choices %>% 
    filter(age == 8) %>% 
    select(participant, age, trial, correct)

fit.eight_correct = brm(
    formula = correct ~ 1 + (1 | participant),
    data = df.stat,
    family = "bernoulli",
    seed = 1,
    file = "cache/fit.eight_correct"
)

# results (in probability scale)
# keep the fixed-effect row of the 7-year-old model and pass the estimate and
# its interval bounds through the inverse logit
fit.seven_correct %>% 
    tidy() %>% 
    filter(effect == "fixed") %>% 
    select(estimate, contains("conf")) %>% 
    mutate(across(everything(), inv.logit))
# A tibble: 1 × 3
  estimate conf.low conf.high
     <dbl>    <dbl>     <dbl>
1    0.429    0.292     0.562
# same summary for the 8-year-old model: fixed-effect estimate and interval
# bounds, transformed with the inverse logit
fit.eight_correct %>% 
    tidy() %>% 
    filter(effect == "fixed") %>% 
    select(estimate, contains("conf")) %>% 
    mutate(across(everything(), inv.logit))
# A tibble: 1 × 3
  estimate conf.low conf.high
     <dbl>    <dbl>     <dbl>
1    0.593    0.441     0.733

3.5 Probability of different strategies for different age ranges

# mean normalized fit of the matching model, children younger than 6
df.fits %>% 
    filter(age < 6) %>% 
    ungroup() %>% 
    summarize(across(.cols = matching, .fns = mean))
# A tibble: 1 × 1
  matching
     <dbl>
1    0.577
# summed mean normalized fit of the `vision` and `both` models,
# children younger than 6
df.fits %>% 
    filter(age < 6) %>% 
    ungroup() %>% 
    summarize(simulation = mean(vision) + mean(both))
# A tibble: 1 × 1
  simulation
       <dbl>
1      0.252
# mean normalized fit of the baseline model, 3-year-olds only
df.fits %>% 
    filter(age == 3) %>% 
    ungroup() %>% 
    summarize(guessing = mean(baseline))
# A tibble: 1 × 1
  guessing
     <dbl>
1    0.198
# summed mean normalized fit of the `vision` and `both` models,
# children aged 6 and older
df.fits %>% 
    filter(age >= 6) %>% 
    ungroup() %>% 
    summarize(simulation = mean(vision) + mean(both))
# A tibble: 1 × 1
  simulation
       <dbl>
1      0.546
# mean normalized fit of the `both` model alone, 8-year-olds only
# (reported under the column name `simulation`)
df.fits %>% 
    filter(age == 8) %>% 
    ungroup() %>% 
    summarize(simulation = mean(both))
# A tibble: 1 × 1
  simulation
       <dbl>
1      0.473

4 PLOTS

4.1 Answers per world

4.1.1 Function to create plot with images

4.1.2 Model predictions (one row)

# Prediction columns to plot. Exactly one `models` line is active; the other
# set is kept as a commented alternative instead of a live assignment that is
# immediately overwritten.
# models = c("baseline", "matching", "vision", "both")
models = c("baseline", "matching", "adult_vision", "adult_vision_sound")

# Long format: one row per world, answer and model holding that model's
# prediction; `model` becomes a factor ordered as in `models` with display
# labels.
df.plot = df.predictions %>% 
    select(world, answer, all_of(models)) %>% 
    pivot_longer(cols = -c(world, answer),
                 names_to = "model",
                 values_to = "prediction") %>% 
    # defensive: keep only rows that stem from the selected model columns
    filter(model %in% models) %>% 
    mutate(model = factor(model,
                          levels = models,
                          # labels = c("guessing", "matching", "vision", "vision & sound")))
                          labels = c("guessing", "matching",
                                     "simulation (vision)","simulation (vision & sound)")))

# Read the PNG stored for one world.
# `world` is inserted into the file name "trial_<world>.png" below
# ../../figures/ground_truth/; the value returned by readPNG() is passed on.
func_load_image = function(world) {
    image_path = str_c("../../figures/ground_truth/trial_", world, ".png")
    readPNG(image_path)
}

# linking images and worlds
# one row per world with the loaded image stored in the list column `grob`
df.images = df.plot %>%
    distinct(world) %>%
    arrange(world) %>%
    mutate(grob = map(.x = world,
                      .f = ~ func_load_image(world = .x)))

# one row per world with a running index (in sorted world order) and the
# position at which that index is drawn in the plot
df.text = df.plot %>% 
    distinct(world) %>% 
    arrange(world) %>% 
    # row_number() instead of 1:n(): 1:n() yields c(1, 0) when there are no rows
    mutate(index = row_number(),
           x = 0.8,
           y = Inf)
  
# plotting
# One facet per world in a single row: dodged bars with each model's
# prediction per answer, a dashed reference line at 1/3, the world's image and
# the world's running index.
p = ggplot(data = df.plot,
           mapping = aes(x = answer,
                         y = prediction)) +
    geom_col(mapping = aes(fill = model,
                           group = model),
             position = position_dodge(width = 0.9),
             color = "black") +
    geom_hline(yintercept = 1/3,
               linetype = 2) + 
    # image anchored at the top-left panel corner (x = -Inf, y = Inf); the
    # justification offsets place it outside the panel area, which is why
    # clipping is switched off and the top plot margin is enlarged below
    geom_custom(data = df.images,
                mapping = aes(data = grob,
                              x = -Inf,
                              y = Inf),
                grob_fun = function(x) rasterGrob(x,
                                                  interpolate = TRUE,
                                                  vjust = -0.05,
                                                  hjust = 0)) +
    # running index of the world, shifted upwards via vjust
    geom_text(data = df.text,
              mapping = aes(x = x,
                            y = y,
                            label = index),
              size = 12,
              color = "white",
              vjust = -4) + 
    facet_wrap(~ world, 
               nrow = 1) +
    labs(y = "proportion %") + 
    # no size scale: no layer maps the size aesthetic, so a
    # scale_size_manual() would have no effect
    scale_y_continuous(breaks = seq(0, 1, 0.25),
                       labels = str_c(seq(0, 100, 25), "%"),
                       limits = c(0, 1),
                       expand = expansion(add = c(0, 0))) +
    coord_cartesian(clip = "off") +
    scale_fill_brewer(palette = "Set1") + 
    theme(panel.grid.major.y = element_line(),
          axis.text.y = element_text(size = 25),
          axis.text.x = element_text(size = 25),
          axis.title = element_blank(),
          legend.position = "bottom",
          strip.background = element_blank(),
          strip.text = element_blank(),
          panel.background = element_rect(fill = NA, color = "black"),
          panel.spacing.x = unit(0.5, "cm"),
          plot.margin = margin(t = 5, l = 1, r = 0.2, b = 0, unit = "cm"))
p

# Save the predictions figure. `plot = p` is passed explicitly so that the
# saved figure does not depend on which plot was displayed last.
# Alternative target, kept from the original:
# ggsave(filename = "../../figures/plots/model_predictions_models.pdf",
ggsave(filename = "../../figures/plots/model_predictions_adults.pdf",
       plot = p,
       width = 22,
       height = 5)

4.2 Accuracy with age (continuous)

# fix the random number generator state for this section
set.seed(1)

# proportion correct per participant, together with the participant's age
# group (`age`) and continuous age (`age_cts`)
df.plot.individual = df.choices %>% 
    # add continuous age
    left_join(select(df.data, participant, age_cts),
              by = "participant") %>% 
    group_by(participant, age, age_cts) %>% 
    summarize(pct_correct = sum(correct) / n()) %>% 
    ungroup() %>% 
    # same factor levels as used for the group means
    mutate(age = factor(age,
                        levels = c(as.character(3:8),
                                   "adult\nvision",
                                   "adult\nvision & sound")))

# Group-level accuracy: one row per age group (children 3-8 plus the two adult
# experiments) with the bootstrapped mean and confidence limits of the
# per-participant proportion correct, the number of participants `n`, and the
# x position `age_mean` used in the plot.
df.plot.means = df.choices %>% 
    mutate(age = as.character(age)) %>% 
    group_by(participant, age) %>% 
    summarize(pct_correct = sum(correct)/n()) %>% 
    # adults: one accuracy value per participant and experiment, labelled
    # "adult\n<experiment>" with "_" in the experiment name shown as " & "
    bind_rows(df.adults %>% 
                  mutate(correct = answer == ground_truth,
                         age = "adult") %>% 
                  group_by(participant, experiment, age) %>% 
                  summarize(pct_correct = sum(correct)/n()) %>% 
                  mutate(age = str_c(age,"\n", experiment),
                         age = str_replace(age, "_", " & "))) %>% 
    group_by(age) %>% 
    # smean.cl.boot() returns a named vector (Mean, Lower, Upper). It is kept
    # in a list column so that summarize() yields one row per group, and the
    # values are then extracted by name rather than by position.
    summarize(boot = list(Hmisc::smean.cl.boot(pct_correct)),
              n = n()) %>% 
    mutate(mean = map_dbl(boot, "Mean"),
           low = map_dbl(boot, "Lower"),
           high = map_dbl(boot, "Upper")) %>% 
    select(-boot) %>% 
    # mean continuous age per child age group
    left_join(df.choices %>% 
                  left_join(df.data %>% 
                                select(participant, age_cts),
                            by = "participant") %>% 
                  distinct(participant, age, age_cts) %>% 
                  group_by(age) %>% 
                  summarize(age_mean = mean(age_cts)) %>% 
                  mutate(age = as.character(age)),
              by = "age") %>% 
    # the adult groups have no continuous age; they are assigned the fixed
    # positions 10 and 11
    mutate(age_mean = ifelse(age == "adult\nvision", 10, age_mean),
           age_mean = ifelse(age == "adult\nvision & sound", 11, age_mean),
           age = factor(age, levels = c(3:8, "adult\nvision", "adult\nvision & sound")))

# Sample-size labels drawn at y = 0.95: the first age group ("3") is labelled
# "n = <n>", every other group shows the bare count.
df.text = df.plot.means %>% 
    select(age, age_mean, n) %>% 
    # if_else() with both branches as character keeps `label` a character
    # column regardless of which groups are present
    mutate(label = if_else(age == "3",
                           str_c("n = ", n),
                           as.character(n)),
           y = 0.95)

# Accuracy as a function of age: small points are individual participants
# (placed at their continuous age), large points with ranges are the group
# means with bootstrapped limits, and the numbers near the top are group sizes.
ggplot() + 
    # light horizontal guide lines every 10%
    geom_hline(yintercept = seq(0, 1, 0.1),
               linetype = 1,
               alpha = 0.1) + 
    # dashed reference line at 1/3
    geom_hline(yintercept = 1/3,
               linetype = 2,
               color = "gray50") + 
    geom_point(data = df.plot.individual,
               mapping = aes(x = age_cts,
                             y = pct_correct,
                             color = age),
               show.legend = FALSE,
               size = 1) +
    geom_pointrange(data = df.plot.means,
                    mapping = aes(x = age_mean,
                                  y = mean,
                                  ymin = low,
                                  ymax = high,
                                  fill = age),
                    shape = 21,
                    size = 1,
                    show.legend = FALSE) +
    geom_text(data = df.text,
              mapping = aes(x = age_mean,
                            y = y,
                            label = label),
              size = 6,
              color = "gray20") +
    labs(x = "age (in years)",
         y = "% correct") +
    coord_cartesian(ylim = c(0, 1)) + 
    # positions 10 and 11 carry the two adult groups
    scale_x_continuous(breaks = c(seq(3, 9, 1), 10:11),
                       labels = c(seq(3, 9, 1), "adult\nvision", "adult\nvision & sound"),
                       expand = expansion(add = c(0.2, 1))) + 
    scale_y_continuous(breaks = seq(0, 1, 0.2),
                       labels = str_c(seq(0, 1, 0.2) * 100, "%"),
                       expand = expansion(add = c(0, 0.05))) + 
    scale_fill_brewer(type = "qual", palette = 2) +
    scale_color_brewer(type = "qual", palette = 2) + 
    theme(axis.text.x = element_text(size = 14))
    
# write the most recently displayed plot to disk (10 x 4)
ggsave("../../figures/plots/points_accuracy_continuous.pdf",
       plot = last_plot(),
       width = 10,
       height = 4)

4.3 Model fits

4.3.1 Children

4.3.1.1 Per participant

4.3.1.2 Per age group (with adults)

5 Session Info

R version 4.2.0 (2022-04-22)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Big Sur/Monterey 10.16

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] grid      stats     graphics  grDevices utils     datasets  methods  
[8] base     

other attached packages:
 [1] brms_2.17.0         Rcpp_1.0.8.3        forcats_0.5.1      
 [4] stringr_1.4.0       dplyr_1.0.9         purrr_0.3.4        
 [7] readr_2.1.2         tidyr_1.2.0         tibble_3.1.7       
[10] tidyverse_1.3.1     boot_1.3-28         emmeans_1.7.3      
[13] broom.mixed_0.2.9.4 corrr_0.4.3         patchwork_1.1.1    
[16] xtable_1.8-4        kableExtra_1.3.4    transport_0.12-2   
[19] janitor_2.1.0       knitr_1.39          ggpubr_0.4.0       
[22] egg_0.4.5           ggplot2_3.3.6       gridExtra_2.3      
[25] png_0.1-7           readxl_1.4.0       

loaded via a namespace (and not attached):
  [1] backports_1.4.1      Hmisc_4.7-0          systemfonts_1.0.4   
  [4] plyr_1.8.7           igraph_1.3.1         splines_4.2.0       
  [7] crosstalk_1.2.0      listenv_0.8.0        inline_0.3.19       
 [10] rstantools_2.2.0     digest_0.6.29        htmltools_0.5.2     
 [13] fansi_1.0.3          magrittr_2.0.3       checkmate_2.1.0     
 [16] cluster_2.1.3        tzdb_0.3.0           globals_0.14.0      
 [19] modelr_0.1.8         RcppParallel_5.1.5   matrixStats_0.62.0  
 [22] vroom_1.5.7          xts_0.12.1           svglite_2.1.0       
 [25] prettyunits_1.1.1    jpeg_0.1-9           colorspace_2.0-3    
 [28] rvest_1.0.2          haven_2.5.0          xfun_0.30           
 [31] callr_3.7.0          crayon_1.5.1         jsonlite_1.8.0      
 [34] survival_3.3-1       zoo_1.8-10           glue_1.6.2          
 [37] gtable_0.3.0         webshot_0.5.3        distributional_0.3.0
 [40] pkgbuild_1.3.1       car_3.0-13           rstan_2.21.5        
 [43] abind_1.4-5          scales_1.2.0         mvtnorm_1.1-3       
 [46] DBI_1.1.2            rstatix_0.7.0        miniUI_0.1.1.1      
 [49] htmlTable_2.4.0      viridisLite_0.4.0    diffobj_0.3.5       
 [52] foreign_0.8-82       bit_4.0.4            Formula_1.2-4       
 [55] StanHeaders_2.21.0-7 stats4_4.2.0         DT_0.22             
 [58] htmlwidgets_1.5.4    httr_1.4.2           threejs_0.3.3       
 [61] RColorBrewer_1.1-3   posterior_1.2.1      ellipsis_0.3.2      
 [64] pkgconfig_2.0.3      loo_2.5.1            farver_2.1.0        
 [67] nnet_7.3-17          sass_0.4.1           dbplyr_2.1.1        
 [70] utf8_1.2.2           tidyselect_1.1.2     rlang_1.0.2         
 [73] reshape2_1.4.4       later_1.3.0          munsell_0.5.0       
 [76] cellranger_1.1.0     tools_4.2.0          cli_3.3.0           
 [79] generics_0.1.2       broom_0.8.0          ggridges_0.5.3      
 [82] evaluate_0.15        fastmap_1.1.0        yaml_2.3.5          
 [85] bit64_4.0.5          processx_3.5.3       fs_1.5.2            
 [88] future_1.25.0        nlme_3.1-157         mime_0.12           
 [91] xml2_1.3.3           shinythemes_1.2.0    compiler_4.2.0      
 [94] bayesplot_1.9.0      rstudioapi_0.13      ggsignif_0.6.3      
 [97] reprex_2.0.1         bslib_0.3.1          stringi_1.7.6       
[100] highr_0.9            ps_1.7.0             Brobdingnag_1.2-8   
[103] lattice_0.20-45      Matrix_1.4-1         markdown_1.1        
[106] shinyjs_2.1.0        tensorA_0.36.2       vctrs_0.4.1         
[109] pillar_1.7.0         lifecycle_1.0.1      furrr_0.3.0         
[112] jquerylib_0.1.4      bridgesampling_1.1-2 estimability_1.3    
[115] data.table_1.14.2    httpuv_1.6.5         latticeExtra_0.6-29 
[118] R6_2.5.1             bookdown_0.26        promises_1.2.0.1    
[121] parallelly_1.31.1    codetools_0.2-18     colourpicker_1.1.1  
[124] gtools_3.9.2         assertthat_0.2.1     withr_2.5.0         
[127] shinystan_2.6.0      parallel_4.2.0       hms_1.1.1           
[130] rpart_4.1.16         coda_0.19-4          rmarkdown_2.14      
[133] snakecase_0.11.0     carData_3.0-5        shiny_1.7.1         
[136] lubridate_1.8.0      base64enc_0.1-3      dygraphs_1.1.1.6