# quest_plus.py
  1. from copy import deepcopy
  2. from itertools import product
  3. import numpy as np
  4. import pandas as pd
  5. # TODO : Currently `weibull` is not used as default function
  6. # from psychometric import weibull
  7. def reformat_params(params):
  8. '''Unroll multiple lists into array of their products.'''
  9. if isinstance(params, list):
  10. n_params = len(params)
  11. params = np.array(list(product(*params)))
  12. elif isinstance(params, np.ndarray):
  13. assert params.ndim == 1
  14. params = params[:, np.newaxis]
  15. return params
# TODO:
# - [ ] highlight lowest point in entropy in plot
  18. class QuestPlus(object):
  19. def __init__(self, stim, params, function):
  20. self.function = function
  21. self.stim_domain = stim
  22. self.param_domain = reformat_params(params)
  23. self._orig_params = deepcopy(params)
  24. self._orig_param_shape = (list(map(len, params)) if
  25. isinstance(params, list) else len(params))
  26. self._orig_stim_shape = (list(map(len, params)) if
  27. isinstance(params, list) else len(params))
  28. n_stim, n_param = self.stim_domain.shape[0], self.param_domain.shape[0]
  29. # setup likelihoods for all combinations
  30. # of stimulus and model parameter domains
  31. self.likelihoods = np.zeros((n_stim, n_param, 2))
  32. for p in range(n_param):
  33. self.likelihoods[:, p, 0] = self.function(self.stim_domain,
  34. self.param_domain[p, :])
  35. # assumes (correct, incorrect) responses
  36. self.likelihoods[:, :, 1] = 1. - self.likelihoods[:, :, 0]
  37. # we also assume a flat prior (so we init posterior to flat too)
  38. self.posterior = np.ones(n_param)
  39. self.posterior /= self.posterior.sum()
  40. self.stim_history = list()
  41. self.resp_history = list()
  42. self.entropy = np.ones(n_stim)
  43. def update(self, contrast, ifcorrect, approximate=False):
  44. '''Update posterior probability with outcome of current trial.
  45. contrast - contrast value for the given trial
  46. ifcorrect - whether response was correct or not
  47. 1 - correct, 0 - incorrect
  48. '''
  49. # turn ifcorrect to response index
  50. resp_idx = 1 - ifcorrect
  51. contrast_idx = self._find_contrast_index(
  52. contrast, approximate=approximate)[0]
  53. # take likelihood of such resp for whole model parameter domain
  54. likelihood = self.likelihoods[contrast_idx, :, resp_idx]
  55. self.posterior *= likelihood
  56. self.posterior /= self.posterior.sum()
  57. # log history of contrasts and responses
  58. self.stim_history.append(contrast)
  59. self.resp_history.append(ifcorrect)
  60. def _find_contrast_index(self, contrast, approximate=False):
  61. contrast = np.atleast_1d(contrast)
  62. if not approximate:
  63. idx = [np.nonzero(self.stim_domain == cntrst)[0][0]
  64. for cntrst in contrast]
  65. else:
  66. idx = np.abs(self.stim_domain[np.newaxis, :] -
  67. contrast[:, np.newaxis]).argmin(axis=1)
  68. return idx
  69. def next_contrast(self, axis=None):
  70. '''Get contrast value minimizing entropy of the posterior
  71. distribution.
  72. Expected entropy is updated in self.entropy.
  73. Returns
  74. -------
  75. contrast : contrast value for the next trial.'''
  76. full_posterior = self.likelihoods * self.posterior[
  77. np.newaxis, :, np.newaxis]
  78. if axis is not None:
  79. shp = full_posterior.shape
  80. new_shape = [shp[0]] + self._orig_param_shape + [shp[-1]]
  81. full_posterior = full_posterior.reshape(new_shape)
  82. reduce_axes = np.arange(len(self._orig_param_shape)) + 1
  83. reduce_axes = tuple(np.delete(reduce_axes, axis))
  84. full_posterior = full_posterior.sum(axis=reduce_axes)
  85. norm = full_posterior.sum(axis=1, keepdims=True)
  86. full_posterior /= norm
  87. H = -np.nansum(full_posterior * np.log(full_posterior), axis=1)
  88. self.entropy = (norm[:, 0, :] * H).sum(axis=1)
  89. # choose contrast with minimal entropy
  90. return self.stim_domain[self.entropy.argmin()]
  91. def get_entropy(self):
  92. return self.entropy.min()
  93. def get_posterior(self):
  94. return self.posterior.reshape(self._orig_param_shape)
  95. def get_fit_params(self, select='mode'):
  96. if select in ['max', 'mode']:
  97. # parameters corresponding to maximum peak in posterior probability
  98. return self.param_domain[self.posterior.argmax(), :]
  99. elif select == 'mean':
  100. # parameters weighted by their probability
  101. return (self.posterior[:, np.newaxis] *
  102. self.param_domain).sum(axis=0)
  103. def fit(self, contrasts, responses, approximate=False):
  104. for contrast, response in zip(contrasts, responses):
  105. self.update(contrast, response, approximate=approximate)
  106. def plot(self):
  107. '''Plot posterior model parameter probabilities and weibull fits.'''
  108. pass
  109. # TODO : implement this method
  110. # return plot_quest_plus(self)