Add errorbars

imperator 1 year ago
parent 0f2a06a74b
commit b3413d955f

@ -1,6 +1,7 @@
from itertools import product from itertools import product
import numpy as np import numpy as np
from scipy.integrate import trapezoid from scipy.integrate import trapezoid
from scipy.stats import sem
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.lines import Line2D from matplotlib.lines import Line2D
from IPython.display import display from IPython.display import display
@ -36,6 +37,22 @@ def draw_figure(fig):
plt.show() plt.show()
def update_errorbar(err_container, x, y, yerr):
err_container.lines[0].set_data(x, y)
linecol = err_container.lines[2][0]
segments = []
for xi, yi, yerri in zip(x, y, yerr):
segments.append([[xi, yi - yerri], [xi, yi + yerri]])
linecol.set_segments(segments)
if len(err_container.lines[1]) == 2:
lower_caps, upper_caps = err_container.lines[1]
lower_caps.set_data(x, y - yerr)
upper_caps.set_data(x, y + yerr)
def maybe_setup(setup_fun, state): def maybe_setup(setup_fun, state):
if not is_colab: if not is_colab:
return return
@ -202,17 +219,6 @@ def plot_sims(C_size=11, num_sims=30 if not is_colab else 5):
interact(update_plot, **sliders) interact(update_plot, **sliders)
def update_errorbar(err_container, x, y, yerr):
err_container.lines[0].set_data(x, y)
linecol = err_container.lines[2][0]
segments = []
for xi, yi, yerri in zip(x, y, yerr):
segments.append([[xi, yi - yerri], [xi, yi + yerri]])
linecol.set_segments(segments)
def plot_model_free_analysis_conditions(C, ks, num_sims_per_condition=2_000): def plot_model_free_analysis_conditions(C, ks, num_sims_per_condition=2_000):
setup_matplotlib_magic() setup_matplotlib_magic()
@ -437,13 +443,15 @@ def plot_single_neuron(mat_data):
def setup(): def setup():
fig, axes = plt.subplots(figsize=(6.5, 4.5)) fig, axes = plt.subplots(figsize=(6.5, 4.5))
neuron_line = axes.plot([], [])[0] neuron_line = axes.errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label='mean firing rate with 95% CI')
axes.set( axes.set(
ylabel=r'$\sqrt{N_\mathrm{spikes}}$', ylabel=r'$\sqrt{N_\mathrm{spikes}}$',
xlabel='time [ms]', xlabel='time [ms]',
xlim=(0, 800) xlim=(0, 800)
) )
axes.legend(loc='upper right', fontsize='small')
return {'fig': fig, 'axes': axes, 'neuron_line': neuron_line} return {'fig': fig, 'axes': axes, 'neuron_line': neuron_line}
@ -451,12 +459,17 @@ def plot_single_neuron(mat_data):
def update_plot(neuron_idx): def update_plot(neuron_idx):
maybe_setup(setup, state) maybe_setup(setup, state)
state['neuron_line'].set_data(time, binned_spike_matrix.mean(axis=0)[:, neuron_idx]) update_errorbar(
state['neuron_line'],
time,
binned_spike_matrix[:, :, neuron_idx].mean(axis=0),
yerr=sem(binned_spike_matrix[:, :, neuron_idx], axis=0) * 1.96
)
state['axes'].relim() state['axes'].relim()
state['axes'].autoscale(axis='y') state['axes'].autoscale(axis='y')
state['axes'].set_title(f'Neuron #{neuron_idx}', fontsize='small') state['axes'].set_title(f'Neuron #{neuron_idx}')
state['fig'].tight_layout() state['fig'].tight_layout()
draw_figure(state['fig']) draw_figure(state['fig'])
@ -478,15 +491,15 @@ def plot_neuron_by_choice(mat_data):
def setup(): def setup():
fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True) fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True)
choices = ['right choice', 'left choice'] choices = ['right choice (95% CI)', 'left choice (95% CI)']
correct_lines = [] correct_lines = []
for choice in choices: for choice in choices:
correct_line = axes[0].plot([], [], label=choice)[0] correct_line = axes[0].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=choice)
correct_lines += [correct_line] correct_lines += [correct_line]
incorrect_lines = [] incorrect_lines = []
for choice in choices: for choice in choices:
incorrect_line = axes[1].plot([], [], label=choice)[0] incorrect_line = axes[1].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=choice)
incorrect_lines += [incorrect_line] incorrect_lines += [incorrect_line]
axes[0].set( axes[0].set(
@ -499,8 +512,8 @@ def plot_neuron_by_choice(mat_data):
title='incorrect trials', title='incorrect trials',
xlabel='time [ms]' xlabel='time [ms]'
) )
axes[0].legend(loc='upper right') axes[0].legend(loc='upper right', fontsize='small')
axes[1].legend(loc='upper right') axes[1].legend(loc='upper right', fontsize='small')
return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines} return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines}
@ -509,16 +522,36 @@ def plot_neuron_by_choice(mat_data):
def update_plot(neuron_idx): def update_plot(neuron_idx):
maybe_setup(setup, state) maybe_setup(setup, state)
state['correct_lines'][0].set_data(time, binned_spike_matrix[correct_trials_mask & right_choice].mean(axis=0)[:, neuron_idx]) update_errorbar(
state['correct_lines'][1].set_data(time, binned_spike_matrix[correct_trials_mask & ~right_choice].mean(axis=0)[:, neuron_idx]) state['correct_lines'][0],
state['incorrect_lines'][0].set_data(time, binned_spike_matrix[~correct_trials_mask & right_choice].mean(axis=0)[:, neuron_idx]) time,
state['incorrect_lines'][1].set_data(time, binned_spike_matrix[~correct_trials_mask & ~right_choice].mean(axis=0)[:, neuron_idx]) binned_spike_matrix[correct_trials_mask & right_choice][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[correct_trials_mask & right_choice][:, :, neuron_idx], axis=0) * 1.96
)
update_errorbar(
state['correct_lines'][1],
time,
binned_spike_matrix[correct_trials_mask & ~right_choice][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[correct_trials_mask & ~right_choice][:, :, neuron_idx], axis=0) * 1.96
)
update_errorbar(
state['incorrect_lines'][0],
time,
binned_spike_matrix[~correct_trials_mask & right_choice][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[~correct_trials_mask & right_choice][:, :, neuron_idx], axis=0) * 1.96
)
update_errorbar(
state['incorrect_lines'][1],
time,
binned_spike_matrix[~correct_trials_mask & ~right_choice][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[~correct_trials_mask & ~right_choice][:, :, neuron_idx], axis=0) * 1.96
)
state['axes'][0].relim() state['axes'][0].relim()
state['axes'][1].relim() state['axes'][1].relim()
state['axes'][0].autoscale(axis='y') state['axes'][0].autoscale(axis='y')
state['axes'][1].autoscale(axis='y') state['axes'][1].autoscale(axis='y')
state['fig'].suptitle(f'Neuron #{neuron_idx}', fontsize='small') state['fig'].suptitle(f'Neuron #{neuron_idx}')
state['fig'].tight_layout() state['fig'].tight_layout()
draw_figure(state['fig']) draw_figure(state['fig'])
@ -546,12 +579,12 @@ def plot_neuron_by_coherence(mat_data):
choices = ['right choice', 'left choice'] choices = ['right choice', 'left choice']
correct_lines = [] correct_lines = []
for coherence in coherences: for coherence in coherences:
correct_line = axes[0].plot([], [], label=f'{coherence = :.1%}')[0] correct_line = axes[0].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=f'{coherence = :.1%} (95% CI)')
correct_lines += [correct_line] correct_lines += [correct_line]
incorrect_lines = [] incorrect_lines = []
for coherence in coherences: for coherence in coherences:
incorrect_line = axes[1].plot([], [], label=f'{coherence = :.1%}')[0] incorrect_line = axes[1].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=f'{coherence = :.1%} (95% CI)')
incorrect_lines += [incorrect_line] incorrect_lines += [incorrect_line]
axes[0].set( axes[0].set(
@ -564,8 +597,8 @@ def plot_neuron_by_coherence(mat_data):
title='incorrect trials', title='incorrect trials',
xlabel='time [ms]' xlabel='time [ms]'
) )
axes[0].legend(loc='upper right') axes[0].legend(loc='upper right', fontsize='small')
axes[1].legend(loc='upper right') axes[1].legend(loc='upper right', fontsize='small')
return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines} return {'fig': fig, 'axes': axes, 'correct_lines': correct_lines, 'incorrect_lines': incorrect_lines}
@ -577,14 +610,22 @@ def plot_neuron_by_coherence(mat_data):
for i, coherence in enumerate(coherences): for i, coherence in enumerate(coherences):
coherence_mask = (mat_data['dot_coh'].flatten() == coherence) coherence_mask = (mat_data['dot_coh'].flatten() == coherence)
state['correct_lines'][i].set_data(time, binned_spike_matrix[correct_trials_mask & coherence_mask].mean(axis=0)[:, neuron_idx]) update_errorbar(
state['incorrect_lines'][i].set_data(time, binned_spike_matrix[~correct_trials_mask & coherence_mask].mean(axis=0)[:, neuron_idx]) state['correct_lines'][i],
time,
binned_spike_matrix[correct_trials_mask & coherence_mask][:, :, neuron_idx].mean(axis=0)
)
update_errorbar(
state['incorrect_lines'][i],
time,
sem(binned_spike_matrix[~correct_trials_mask & coherence_mask][:, :, neuron_idx], axis=0) * 1.96
)
state['axes'][0].relim() state['axes'][0].relim()
state['axes'][1].relim() state['axes'][1].relim()
state['axes'][0].autoscale(axis='y') state['axes'][0].autoscale(axis='y')
state['axes'][1].autoscale(axis='y') state['axes'][1].autoscale(axis='y')
state['fig'].suptitle(f'Neuron #{neuron_idx}', fontsize='small') state['fig'].suptitle(f'Neuron #{neuron_idx}')
state['fig'].tight_layout() state['fig'].tight_layout()
draw_figure(state['fig']) draw_figure(state['fig'])

Loading…
Cancel
Save