Predictive coding light
Toy model of predictive coding network implemented in Nengo¶
Toy model of predictive coding network implemented in Nengo. The code realizes "predictive coding light" concept. The error signal teaches the integrator of upper hierarchy and becomes zero when finished teaching. The input signal and the signal from integrator specify the current error.
integrator part is taken from Nengo examples https://www.nengo.ai/nengo/examples/dynamics/integrator.html
Integrator equation dx/dt = u(t) where u(t) is the input stimulus
In [1]:
import plotly.graph_objects as go
import nengo
from nengo.processes import Piecewise
import time
from memory_profiler import memory_usage
start_time = time.time()
Construct network¶
In [2]:
model = nengo.Network()
# Parameters
tau_synapse = 0.2 # should be reasonably large
with model:
err = nengo.Ensemble(n_neurons=100, dimensions=1, radius=1)
layer1 = nengo.Ensemble(n_neurons=100, dimensions=1, radius=1)
stim = nengo.Node(Piecewise({0: 0, 0.2: 1, 3: -1, 10: 0.5}))
layer2 = nengo.Ensemble(n_neurons=100, dimensions=1, radius=1)
nengo.Connection(stim,layer1)
nengo.Connection(layer1, err)
def forward(u):
return tau_synapse*u
# feedforward error
nengo.Connection(err, layer2, function=forward, synapse=tau_synapse)
def recurrent(x):
return x
nengo.Connection(layer2, layer2, function=recurrent, synapse=tau_synapse)
nengo.Connection(layer2, err, transform=-1) # feedback to the error population
Add probes¶
In [3]:
with model:
layer1_probe = nengo.Probe(layer1, synapse=0.1)
error_probe = nengo.Probe(err, synapse=0.1)
layer2_probe = nengo.Probe(layer2, synapse=0.1)
Run the model¶
In [4]:
# Create simulator
with nengo.Simulator(model) as sim:
# Run it for 15 seconds
sim.run(15)
# mem_usage = memory_usage(A)
end_time = time.time()
print(f"Время симуляции: {end_time - start_time} секунд")
# print(f"Использование памяти: {max(mem_usage) - min(mem_usage)} МБ")
0%
0%
Время симуляции: 2.261934280395508 секунд
Plot the results¶
In [5]:
# Create Plotly figures
time = sim.trange()
error_data = sim.data[error_probe].flatten()
layer1_data = sim.data[layer1_probe].flatten()
layer2_data = sim.data[layer2_probe].flatten()
# RGB Plot
fig_rgb = go.Figure()
fig_rgb.update_layout(
autosize=False,
width=600,
height=600,
)
fig_rgb.add_trace(go.Scatter(x=time, y=error_data, mode='lines', name='Error', line=dict(color='red', dash='dash')))
fig_rgb.add_trace(go.Scatter(x=time, y=layer1_data, mode='lines', name='Layer 1 output', line=dict(color='green')))
fig_rgb.add_trace(go.Scatter(x=time, y=layer2_data, mode='lines', name='Layer 2 output', line=dict(color='blue', dash='dot')))
fig_rgb.update_layout(title='Decoded Output of Ensembles', xaxis_title='Time (s)', yaxis_title='Output value',
legend_title='Signal')
fig_rgb.show()
In [6]:
from brian2 import *
import numpy as np
import matplotlib.pyplot as plt
start_scope()
start_time = time.time()
# Параметры модели
n_neurons = 100 # Количество нейронов
duration = 15*second # Длительность симуляции
# Определение входного сигнала (Piecewise)
def stimulus_func(t):
if t < 0.2*second:
return 0
elif t < 3*second:
return 1
elif t < 10*second:
return -1
else:
return 1
stimulus = TimedArray([stimulus_func(t*defaultclock.dt) for t in range(int(duration/defaultclock.dt))], dt=defaultclock.dt)
# Первая популяция (входной слой)
layer1 = NeuronGroup(n_neurons, 'dv/dt = (-v + stimulus(t))/(20*ms) : 1',
threshold='v > 0.99', reset='v = 0.98', method='euler')
# Вторая популяция (рекуррентный слой)
eqs = '''
dv/dt = (-v + I) / (20*ms) : 1
I = ge + I_bias : 1
dge/dt = -ge / (5*ms) : 1
I_bias : 1
'''
layer2 = NeuronGroup(n_neurons, eqs, threshold='v > 0.35', reset='v = 0.29', method='euler')
layer2.I_bias = '0'
layer2.v = '0'
layer2.ge = '0'
# Параметры STDP
A_pre = 0.01
A_post = -A_pre * 1.2
tau_pre = 20*ms
tau_post = 20*ms
w_min = 0
w_max = 0.1
# Синапсы с STDP между layer1 и layer2
syn_1_2 = Synapses(layer1, layer2,
model='''
w : 1
dapre/dt = -apre / tau_pre : 1 (event-driven)
dapost/dt = -apost / tau_post : 1 (event-driven)
''',
on_pre='''
ge += w
apre += A_pre
w = clip(w + apost, w_min, w_max)
''',
on_post='''
apost += A_post
w = clip(w + apre, w_min, w_max)
''',
method='euler')
syn_1_2.connect(p=0.1)
syn_1_2.w = '0.1'
# Мониторинг
state_monitor_layer1 = StateMonitor(layer1, 'v', record=True)
state_monitor_layer2 = StateMonitor(layer2, 'v', record=True)
# Запуск симуляции
run(duration)
# mem_usage = memory_usage(A)
end_time = time.time()
print(f"Время симуляции: {end_time - start_time} секунд")
# Построение графиков
plt.figure(figsize=(12, 16))
# Средняя активность Layer 1
plt.subplot(4, 1, 1)
plt.plot(state_monitor_layer1.t/ms, state_monitor_layer1.v.mean(axis=0), label='Средняя активность Layer 1')
plt.title('Средняя активность Layer 1')
plt.xlabel('Время (мс)')
plt.ylabel('Среднее значение v')
plt.legend()
# Средняя активность Layer 2
plt.subplot(4, 1, 2)
plt.plot(state_monitor_layer2.t/ms, state_monitor_layer2.v.mean(axis=0), label='Средняя активность Layer 2')
plt.title('Средняя активность Layer 2')
plt.xlabel('Время (мс)')
plt.ylabel('Среднее значение v')
plt.legend()
plt.tight_layout()
plt.show()
Время симуляции: 12.89769434928894 секунд