quest_plus.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. from copy import deepcopy
  2. from itertools import product
  3. import numpy as np
  4. import pandas as pd
  5. # TODO : Currently `weibull` is not used as default function
  6. # from psychometric import weibull
  7. # PARAMETERS of the psychometric function
  8. chance_level = 0 #e.g. chance_level should be 0.5 for 2AFC (Two-alternative forced choice) procedure
  9. threshold_prob = 1.-(1.-chance_level)/2.0 #the probability level at the threshold
  10. def reformat_params(params):
  11. '''Unroll multiple lists into array of their products.'''
  12. if isinstance(params, list):
  13. n_params = len(params)
  14. params = np.array(list(product(*params)))
  15. elif isinstance(params, np.ndarray):
  16. assert params.ndim == 1
  17. params = params[:, np.newaxis]
  18. return params
  19. # quest_plus.py comes also with psychometric.py wich includes the definition of the weibull and weibull_db function
  20. # here I define the logistic function using the same template that works with the quest_plus implementation
  21. def logistic(x, params, corr_at_thresh=threshold_prob, chance_level=chance_level):
  22. # unpack params
  23. if len(params) == 3:
  24. THRESHOLD, SLOPE, lapse = params
  25. else:
  26. THRESHOLD, SLOPE = params
  27. lapse = 0.
  28. b = 4 * SLOPE
  29. a = -b * THRESHOLD
  30. return chance_level + (1 - lapse - chance_level) / (1 + np.exp(-(a + b*x)))
  31. # that's a wrapper function to specify wich psychometric function one we want to use for the QUEST procedure
  32. def psychometric_fun( x , params ):
  33. return logistic(x , params , corr_at_thresh=threshold_prob, chance_level=chance_level)
  34. # TODO:
  35. # - [ ] highlight lowest point in entropy in plot
  36. class QuestPlus(object):
  37. def __init__(self, stim, params, function):
  38. self.function = function
  39. self.stim_domain = stim
  40. self.param_domain = reformat_params(params)
  41. self._orig_params = deepcopy(params)
  42. self._orig_param_shape = (list(map(len, params)) if
  43. isinstance(params, list) else len(params))
  44. self._orig_stim_shape = (list(map(len, params)) if
  45. isinstance(params, list) else len(params))
  46. n_stim, n_param = self.stim_domain.shape[0], self.param_domain.shape[0]
  47. # setup likelihoods for all combinations
  48. # of stimulus and model parameter domains
  49. self.likelihoods = np.zeros((n_stim, n_param, 2))
  50. for p in range(n_param):
  51. self.likelihoods[:, p, 0] = self.function(self.stim_domain,
  52. self.param_domain[p, :])
  53. # assumes (correct, incorrect) responses
  54. self.likelihoods[:, :, 1] = 1. - self.likelihoods[:, :, 0]
  55. # we also assume a flat prior (so we init posterior to flat too)
  56. self.posterior = np.ones(n_param)
  57. self.posterior /= self.posterior.sum()
  58. self.stim_history = list()
  59. self.resp_history = list()
  60. self.entropy = np.ones(n_stim)
  61. def update(self, contrast, ifcorrect, approximate=False):
  62. '''Update posterior probability with outcome of current trial.
  63. contrast - contrast value for the given trial
  64. ifcorrect - whether response was correct or not
  65. 1 - correct, 0 - incorrect
  66. '''
  67. # turn ifcorrect to response index
  68. resp_idx = 1 - ifcorrect
  69. contrast_idx = self._find_contrast_index(
  70. contrast, approximate=approximate)[0]
  71. # take likelihood of such resp for whole model parameter domain
  72. likelihood = self.likelihoods[contrast_idx, :, resp_idx]
  73. self.posterior *= likelihood
  74. self.posterior /= self.posterior.sum()
  75. # log history of contrasts and responses
  76. self.stim_history.append(contrast)
  77. self.resp_history.append(ifcorrect)
  78. def _find_contrast_index(self, contrast, approximate=False):
  79. contrast = np.atleast_1d(contrast)
  80. if not approximate:
  81. idx = [np.nonzero(self.stim_domain == cntrst)[0][0]
  82. for cntrst in contrast]
  83. else:
  84. idx = np.abs(self.stim_domain[np.newaxis, :] -
  85. contrast[:, np.newaxis]).argmin(axis=1)
  86. return idx
  87. def next_contrast(self, axis=None):
  88. '''Get contrast value minimizing entropy of the posterior
  89. distribution.
  90. Expected entropy is updated in self.entropy.
  91. Returns
  92. -------
  93. contrast : contrast value for the next trial.'''
  94. full_posterior = self.likelihoods * self.posterior[
  95. np.newaxis, :, np.newaxis]
  96. if axis is not None:
  97. shp = full_posterior.shape
  98. new_shape = [shp[0]] + self._orig_param_shape + [shp[-1]]
  99. full_posterior = full_posterior.reshape(new_shape)
  100. reduce_axes = np.arange(len(self._orig_param_shape)) + 1
  101. reduce_axes = tuple(np.delete(reduce_axes, axis))
  102. full_posterior = full_posterior.sum(axis=reduce_axes)
  103. norm = full_posterior.sum(axis=1, keepdims=True)
  104. full_posterior /= norm
  105. H = -np.nansum(full_posterior * np.log(full_posterior), axis=1)
  106. self.entropy = (norm[:, 0, :] * H).sum(axis=1)
  107. # choose contrast with minimal entropy
  108. return self.stim_domain[self.entropy.argmin()]
  109. def get_entropy(self):
  110. return self.entropy.min()
  111. def get_posterior(self):
  112. return self.posterior.reshape(self._orig_param_shape)
  113. def get_fit_params(self, select='mode'):
  114. if select in ['max', 'mode']:
  115. # parameters corresponding to maximum peak in posterior probability
  116. return self.param_domain[self.posterior.argmax(), :]
  117. elif select == 'mean':
  118. # parameters weighted by their probability
  119. return (self.posterior[:, np.newaxis] *
  120. self.param_domain).sum(axis=0)
  121. def fit(self, contrasts, responses, approximate=False):
  122. for contrast, response in zip(contrasts, responses):
  123. self.update(contrast, response, approximate=approximate)
  124. def plot(self):
  125. '''Plot posterior model parameter probabilities and weibull fits.'''
  126. pass
  127. # TODO : implement this method
  128. # return plot_quest_plus(self)