Add errorbars

imperator 1 year ago
parent 0f2a06a74b
commit 660342aab0

@ -1,6 +1,7 @@
from itertools import product from itertools import product
import numpy as np import numpy as np
from scipy.integrate import trapezoid from scipy.integrate import trapezoid
from scipy.stats import sem
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.lines import Line2D from matplotlib.lines import Line2D
from IPython.display import display from IPython.display import display
@ -36,6 +37,22 @@ def draw_figure(fig):
plt.show() plt.show()
def update_errorbar(err_container, x, y, yerr):
err_container.lines[0].set_data(x, y)
linecol = err_container.lines[2][0]
segments = []
for xi, yi, yerri in zip(x, y, yerr):
segments.append([[xi, yi - yerri], [xi, yi + yerri]])
linecol.set_segments(segments)
if len(err_container.lines[1]) == 2:
lower_caps, upper_caps = err_container.lines[1]
lower_caps.set_data(x, y - yerr)
upper_caps.set_data(x, y + yerr)
def maybe_setup(setup_fun, state): def maybe_setup(setup_fun, state):
if not is_colab: if not is_colab:
return return
@ -202,17 +219,6 @@ def plot_sims(C_size=11, num_sims=30 if not is_colab else 5):
interact(update_plot, **sliders) interact(update_plot, **sliders)
def update_errorbar(err_container, x, y, yerr):
err_container.lines[0].set_data(x, y)
linecol = err_container.lines[2][0]
segments = []
for xi, yi, yerri in zip(x, y, yerr):
segments.append([[xi, yi - yerri], [xi, yi + yerri]])
linecol.set_segments(segments)
def plot_model_free_analysis_conditions(C, ks, num_sims_per_condition=2_000): def plot_model_free_analysis_conditions(C, ks, num_sims_per_condition=2_000):
setup_matplotlib_magic() setup_matplotlib_magic()
@ -437,7 +443,7 @@ def plot_single_neuron(mat_data):
def setup(): def setup():
fig, axes = plt.subplots(figsize=(6.5, 4.5)) fig, axes = plt.subplots(figsize=(6.5, 4.5))
neuron_line = axes.plot([], [])[0] neuron_line = axes.errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label='mean firing rate with 95% CI')
axes.set( axes.set(
ylabel=r'$\sqrt{N_\mathrm{spikes}}$', ylabel=r'$\sqrt{N_\mathrm{spikes}}$',
@ -445,6 +451,8 @@ def plot_single_neuron(mat_data):
xlim=(0, 800) xlim=(0, 800)
) )
axes.legend(loc='upper right')
return {'fig': fig, 'axes': axes, 'neuron_line': neuron_line} return {'fig': fig, 'axes': axes, 'neuron_line': neuron_line}
state = setup() state = setup()
@ -452,7 +460,12 @@ def plot_single_neuron(mat_data):
def update_plot(neuron_idx): def update_plot(neuron_idx):
maybe_setup(setup, state) maybe_setup(setup, state)
state['neuron_line'].set_data(time, binned_spike_matrix.mean(axis=0)[:, neuron_idx]) update_errorbar(
state['neuron_line'],
time,
binned_spike_matrix[:, :, neuron_idx].mean(axis=0),
yerr=sem(binned_spike_matrix[:, :, neuron_idx], axis=0) * 1.96
)
state['axes'].relim() state['axes'].relim()
state['axes'].autoscale(axis='y') state['axes'].autoscale(axis='y')
@ -478,15 +491,15 @@ def plot_neuron_by_choice(mat_data):
def setup(): def setup():
fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True) fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True)
choices = ['right choice', 'left choice'] choices = ['right choice (95% CI)', 'left choice (95% CI)']
correct_lines = [] correct_lines = []
for choice in choices: for choice in choices:
correct_line = axes[0].plot([], [], label=choice)[0] correct_line = axes[0].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=choice)
correct_lines += [correct_line] correct_lines += [correct_line]
incorrect_lines = [] incorrect_lines = []
for choice in choices: for choice in choices:
incorrect_line = axes[1].plot([], [], label=choice)[0] incorrect_line = axes[1].errorbar([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=choice)
incorrect_lines += [incorrect_line] incorrect_lines += [incorrect_line]
axes[0].set( axes[0].set(
@ -509,10 +522,30 @@ def plot_neuron_by_choice(mat_data):
def update_plot(neuron_idx): def update_plot(neuron_idx):
maybe_setup(setup, state) maybe_setup(setup, state)
state['correct_lines'][0].set_data(time, binned_spike_matrix[correct_trials_mask & right_choice].mean(axis=0)[:, neuron_idx]) update_errorbar(
state['correct_lines'][1].set_data(time, binned_spike_matrix[correct_trials_mask & ~right_choice].mean(axis=0)[:, neuron_idx]) state['correct_lines'][0],
state['incorrect_lines'][0].set_data(time, binned_spike_matrix[~correct_trials_mask & right_choice].mean(axis=0)[:, neuron_idx]) time,
state['incorrect_lines'][1].set_data(time, binned_spike_matrix[~correct_trials_mask & ~right_choice].mean(axis=0)[:, neuron_idx]) binned_spike_matrix[correct_trials_mask & right_choice][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[correct_trials_mask & right_choice][:, :, neuron_idx], axis=0) * 1.96
)
update_errorbar(
state['correct_lines'][1],
time,
binned_spike_matrix[correct_trials_mask & ~right_choice][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[correct_trials_mask & ~right_choice][:, :, neuron_idx], axis=0) * 1.96
)
update_errorbar(
state['incorrect_lines'][0],
time,
binned_spike_matrix[~correct_trials_mask & right_choice][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[~correct_trials_mask & right_choice][:, :, neuron_idx], axis=0) * 1.96
)
update_errorbar(
state['incorrect_lines'][1],
time,
binned_spike_matrix[~correct_trials_mask & ~right_choice][:, :, neuron_idx].mean(axis=0),
sem(binned_spike_matrix[~correct_trials_mask & ~right_choice][:, :, neuron_idx], axis=0) * 1.96
)
state['axes'][0].relim() state['axes'][0].relim()
state['axes'][1].relim() state['axes'][1].relim()
@ -546,12 +579,12 @@ def plot_neuron_by_coherence(mat_data):
choices = ['right choice', 'left choice'] choices = ['right choice', 'left choice']
correct_lines = [] correct_lines = []
for coherence in coherences: for coherence in coherences:
correct_line = axes[0].plot([], [], label=f'{coherence = :.1%}')[0] correct_line = axes[0].plot([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=f'{coherence = :.1%} (95% CI)')
correct_lines += [correct_line] correct_lines += [correct_line]
incorrect_lines = [] incorrect_lines = []
for coherence in coherences: for coherence in coherences:
incorrect_line = axes[1].plot([], [], label=f'{coherence = :.1%}')[0] incorrect_line = axes[1].plot([np.nan], [np.nan], yerr=[np.nan], capsize=4., label=f'{coherence = :.1%} (95% CI)')
incorrect_lines += [incorrect_line] incorrect_lines += [incorrect_line]
axes[0].set( axes[0].set(
@ -577,8 +610,16 @@ def plot_neuron_by_coherence(mat_data):
for i, coherence in enumerate(coherences): for i, coherence in enumerate(coherences):
coherence_mask = (mat_data['dot_coh'].flatten() == coherence) coherence_mask = (mat_data['dot_coh'].flatten() == coherence)
state['correct_lines'][i].set_data(time, binned_spike_matrix[correct_trials_mask & coherence_mask].mean(axis=0)[:, neuron_idx]) update_errorbar(
state['incorrect_lines'][i].set_data(time, binned_spike_matrix[~correct_trials_mask & coherence_mask].mean(axis=0)[:, neuron_idx]) state['correct_lines'][i],
time,
binned_spike_matrix[correct_trials_mask & coherence_mask][:, :, neuron_idx].mean(axis=0)
)
update_errorbar(
state['incorrect_lines'][i],
time,
sem(binned_spike_matrix[~correct_trials_mask & coherence_mask][:, :, neuron_idx], axis=0) * 1.96
)
state['axes'][0].relim() state['axes'][0].relim()
state['axes'][1].relim() state['axes'][1].relim()

Loading…
Cancel
Save