123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- from copy import deepcopy
- from itertools import product
- import numpy as np
- import pandas as pd
# Lower asymptote of the psychometric function (probability of a correct
# response when the stimulus carries no information).
chance_level = 0
# Midpoint between chance_level and 1, written as (1 + c) / 2, which equals
# 1 - (1 - c) / 2.
threshold_prob = (1.0 + chance_level) / 2.0
def reformat_params(params):
    '''Unroll a parameter specification into a 2-D grid of parameter rows.

    Parameters
    ----------
    params : list of sequences, 1-D numpy.ndarray, or anything else
        * list: each element holds the candidate values of one model
          parameter. The result contains every combination of them
          (``itertools.product`` order, last parameter varying fastest),
          one combination per row.
        * 1-D ndarray: candidate values of a single parameter. The result
          is that array as a column of shape ``(len(params), 1)``.
        * any other object is returned unchanged.

    Returns
    -------
    numpy.ndarray (or the unchanged input)
        Shape ``(n_combinations, n_parameters)`` for the two cases above.

    Raises
    ------
    ValueError
        If ``params`` is an ndarray that is not 1-D.
    '''
    if isinstance(params, list):
        return np.array(list(product(*params)))
    if isinstance(params, np.ndarray):
        # Explicit check rather than ``assert``: asserts vanish under -O and a
        # 2-D array would then silently become 3-D below.
        if params.ndim != 1:
            raise ValueError(
                'ndarray params must be 1-D, got shape {!r}'.format(
                    params.shape))
        return params[:, np.newaxis]
    return params
def logistic(x, params, corr_at_thresh=threshold_prob, chance_level=chance_level):
    '''Logistic psychometric function.

    Computes::

        chance_level + (1 - lapse - chance_level)
                       / (1 + exp(-4 * slope * (x - threshold)))

    At ``x == threshold`` the curve lies halfway between ``chance_level``
    and ``1 - lapse``. With ``lapse == chance_level == 0`` its derivative
    there equals ``slope``.

    Parameters
    ----------
    x : array_like
        Stimulus value(s).
    params : sequence
        ``(threshold, slope)`` or ``(threshold, slope, lapse)``. ``lapse``
        is 0 when omitted.
    corr_at_thresh : float
        Accepted for interface compatibility. NOTE: this implementation
        does not use it.
    chance_level : float
        Lower asymptote of the curve.

    Returns
    -------
    Probability of a correct response, elementwise over ``x``.
    '''
    if len(params) != 3:
        threshold, slope = params
        lapse = 0.
    else:
        threshold, slope, lapse = params
    gain = 4 * slope
    offset = -gain * threshold
    span = 1 - lapse - chance_level
    return chance_level + span / (1 + np.exp(-(offset + gain * x)))
-
def psychometric_fun(x, params):
    '''Evaluate the default psychometric curve (``logistic``) at ``x``.

    The module-level ``threshold_prob`` and ``chance_level`` are read at
    call time and passed explicitly.
    '''
    settings = dict(chance_level=chance_level, corr_at_thresh=threshold_prob)
    return logistic(x, params, **settings)
class QuestPlus(object):
    '''Adaptive stimulus selection over a discrete parameter grid.

    Keeps a posterior over every combination of model parameters and
    updates it after each binary (correct / incorrect) trial. It proposes
    the stimulus that minimises the expected entropy of the posterior.

    Parameters
    ----------
    stim : array_like
        Candidate stimulus values (the stimulus domain). Assumed 1-D:
        ``_find_contrast_index`` indexes it as a vector.
    params : list of sequences or 1-D numpy.ndarray
        Parameter domain, expanded with ``reformat_params``.
    function : callable
        Called as ``function(stim_domain, param_row)``. It must return one
        probability of a *correct* response per stimulus value.
    '''

    def __init__(self, stim, params, function):
        self.function = function
        # asarray lets plain sequences be passed; an ndarray is kept as-is.
        self.stim_domain = np.asarray(stim)
        self.param_domain = reformat_params(params)

        self._orig_params = deepcopy(params)
        # List of per-parameter lengths (list input) or a single length.
        self._orig_param_shape = (list(map(len, params)) if
                                  isinstance(params, list) else len(params))
        # Number of stimulus values (was mistakenly derived from ``params``).
        self._orig_stim_shape = len(self.stim_domain)

        n_stim, n_param = self.stim_domain.shape[0], self.param_domain.shape[0]

        # likelihoods[s, p, 0]: P(correct   | stim s, parameter row p)
        # likelihoods[s, p, 1]: P(incorrect | stim s, parameter row p)
        self.likelihoods = np.zeros((n_stim, n_param, 2))
        for p in range(n_param):
            self.likelihoods[:, p, 0] = self.function(self.stim_domain,
                                                      self.param_domain[p, :])
        self.likelihoods[:, :, 1] = 1. - self.likelihoods[:, :, 0]

        # Uniform prior over parameter rows.
        self.posterior = np.ones(n_param)
        self.posterior /= self.posterior.sum()
        self.stim_history = list()
        self.resp_history = list()
        # Placeholder until next_contrast() computes expected entropies.
        self.entropy = np.ones(n_stim)

    def update(self, contrast, ifcorrect, approximate=False):
        '''Update posterior probability with outcome of current trial.

        contrast - contrast value for the given trial
        ifcorrect - whether response was correct or not
            1 - correct, 0 - incorrect
        approximate - if True, use the stimulus-domain value nearest to
            ``contrast``; otherwise ``contrast`` must be in the domain
            exactly (IndexError if not).
        '''
        # Column 0 of likelihoods holds P(correct), column 1 P(incorrect).
        resp_idx = 1 - ifcorrect
        contrast_idx = self._find_contrast_index(
            contrast, approximate=approximate)[0]

        likelihood = self.likelihoods[contrast_idx, :, resp_idx]
        self.posterior *= likelihood
        self.posterior /= self.posterior.sum()

        self.stim_history.append(contrast)
        self.resp_history.append(ifcorrect)

    def _find_contrast_index(self, contrast, approximate=False):
        '''Return stimulus-domain indices for one or more contrast values.

        Exact mode raises IndexError when a value is not in the domain.
        Approximate mode returns the index of the nearest domain value.
        '''
        contrast = np.atleast_1d(contrast)
        if not approximate:
            idx = []
            for cntrst in contrast:
                matches = np.nonzero(self.stim_domain == cntrst)[0]
                if matches.size == 0:
                    raise IndexError(
                        'contrast {!r} is not in the stimulus domain; pass '
                        'approximate=True to use the nearest value'.format(
                            cntrst))
                idx.append(matches[0])
        else:
            idx = np.abs(self.stim_domain[np.newaxis, :] -
                         contrast[:, np.newaxis]).argmin(axis=1)
        return idx

    def next_contrast(self, axis=None):
        '''Get contrast value minimizing entropy of the posterior
        distribution.

        Expected entropy is updated in self.entropy.

        Parameters
        ----------
        axis : int, optional
            If given, the entropy is computed for the marginal posterior of
            this single parameter (index into the original parameter list)
            instead of the joint posterior.

        Returns
        -------
        contrast : contrast value for the next trial.
        '''
        # Joint of outcome and parameters for every stimulus:
        # shape (n_stim, n_param, 2).
        full_posterior = self.likelihoods * self.posterior[
            np.newaxis, :, np.newaxis]
        if axis is not None:
            shp = full_posterior.shape
            # _orig_param_shape is a list (list params) or an int (ndarray
            # params); atleast_1d turns both into a list of lengths.
            param_shape = np.atleast_1d(self._orig_param_shape).tolist()
            new_shape = [shp[0]] + param_shape + [shp[-1]]
            full_posterior = full_posterior.reshape(new_shape)
            reduce_axes = np.arange(len(param_shape)) + 1
            reduce_axes = tuple(np.delete(reduce_axes, axis))
            full_posterior = full_posterior.sum(axis=reduce_axes)
        # log(0) and 0/0 are expected here (certain outcomes / empty cells);
        # the resulting NaNs are dropped by nansum, so silence the warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            # norm[s, 0, r]: probability of outcome r for stimulus s.
            norm = full_posterior.sum(axis=1, keepdims=True)
            full_posterior /= norm
            H = -np.nansum(full_posterior * np.log(full_posterior), axis=1)
        # Expected entropy per stimulus, weighted by outcome probability.
        self.entropy = (norm[:, 0, :] * H).sum(axis=1)

        return self.stim_domain[self.entropy.argmin()]

    def get_entropy(self):
        '''Return the smallest expected entropy from the last
        next_contrast() call (1.0 before any call).'''
        return self.entropy.min()

    def get_posterior(self):
        '''Return the posterior reshaped to the original parameter grid.'''
        return self.posterior.reshape(self._orig_param_shape)

    def get_fit_params(self, select='mode'):
        '''Return a point estimate of the parameters.

        select : 'mode' / 'max' for the most probable parameter row, or
            'mean' for the posterior-weighted mean of each parameter.

        Raises ValueError for any other ``select``.
        '''
        if select in ['max', 'mode']:
            return self.param_domain[self.posterior.argmax(), :]
        elif select == 'mean':
            return (self.posterior[:, np.newaxis] *
                    self.param_domain).sum(axis=0)
        raise ValueError(
            "select must be 'mode', 'max' or 'mean', got {!r}".format(select))

    def fit(self, contrasts, responses, approximate=False):
        '''Apply update() for each (contrast, response) pair in order.'''
        for contrast, response in zip(contrasts, responses):
            self.update(contrast, response, approximate=approximate)

    def plot(self):
        '''Placeholder: plotting is not implemented; does nothing.'''
        pass
-
-
|