Implementation of the Izhikevich (quadratic integrate-and-fire) neuron model with the fourth-order Runge-Kutta method

[:large]

Neurdon

#-*- coding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt

# Izhikevich neuron: a quadratic integrate-and-fire model with a recovery variable u.
class IzhNeuron:
	"""Parameters and mutable state (v, u) of a single Izhikevich neuron.

	Args:
		label: descriptive name for the neuron.
		a, b: recovery-equation parameters (IzhSim uses du/dt = a*(b*v - u)).
		c: value v is reset to after a spike.
		d: amount added to u after a spike.
		v0: initial membrane potential.
		u0: initial recovery variable; when omitted it defaults to ``b * v0``.
	"""
	def __init__(self, label, a, b, c, d, v0, u0=None):
		if u0 is None:
			u0 = b*v0
		self.label = label
		self.a, self.b, self.c, self.d = a, b, c, d
		self.v, self.u = v0, u0

class IzhSim:
	"""Fixed-step simulation of one neuron driven by a per-sample stimulus.

	Attributes set by ``__init__``:
		neuron: neuron object whose ``v``/``u``/``a``/``b``/``c``/``d``
			attributes are read (and ``v``/``u`` updated) by :meth:`integrate`.
		dt: integration step.
		t: sample times ``np.arange(0, T + dt, dt)``.
		stim: stimulus value for each sample of ``t`` (initially zeros);
			fill it in before calling :meth:`integrate`.
		x, y, du: 5.0, 140.0 and a recovery-derivative lambda; these are not
			read by the methods of this class.
	"""
	def __init__(self, n, T, dt=0.005):
		self.neuron = n
		self.dt = dt
		self.t = t = np.arange(0, T+dt, dt)
		self.stim = np.zeros(len(t))
		self.x = 5.
		self.y = 140.
		self.du = lambda a, b, v, u: a*(b*v - u)

	def izhikevich(self, x, t, s, a, b):
		"""Return d[v, u]/dt of the Izhikevich model at state ``x = [v, u]``.

		dv/dt = 0.04 v^2 + 5 v + 140 - u + s
		du/dt = a (b v - u)

		``t`` is accepted to match the integrator's call signature but unused.
		"""
		return np.array([0.04 * x[0]**2 + 5.0*x[0] + 140 - x[1] + s, a*(b*x[0] - x[1])])

	def integrate(self, n=None, ng=None):
		"""Integrate a neuron over ``self.t`` with the classic 4th-order Runge-Kutta method.

		Args:
			n: neuron to simulate; defaults to ``self.neuron``. Its ``v`` and
				``u`` attributes are updated in place, so a second call
				continues from the state left by the first.
			ng: derivative function ``ng(x, t, s, a, b) -> array([dv, du])``;
				defaults to :meth:`izhikevich`.

		Returns:
			Array of shape ``(2, len(self.t))``: row 0 is ``v`` (recorded as 30
			on spike steps), row 1 is ``u`` (post-reset value on spike steps).
		"""
		if n is None:
			n = self.neuron
		if ng is None:
			ng = self.izhikevich
		dt = self.dt
		half = 0.5 * dt
		trace = np.zeros((2, len(self.t)))

		for i, si in enumerate(self.stim):
			ti = self.t[i]
			X = np.array([n.v, n.u])
			# RK4 stage times are t, t + dt/2, t + dt/2, t + dt; the stimulus
			# is held constant over the step.
			w1 = dt * ng(X, ti, si, n.a, n.b)
			w2 = dt * ng(X + w1 * 0.5, ti + half, si, n.a, n.b)
			w3 = dt * ng(X + w2 * 0.5, ti + half, si, n.a, n.b)
			w4 = dt * ng(X + w3, ti + dt, si, n.a, n.b)
			n.v += 1./6. * (w1[0] + 2*w2[0] + 2*w3[0] + w4[0])
			n.u += 1./6. * (w1[1] + 2*w2[1] + 2*w3[1] + w4[1])
			if n.v > 30:
				# Spike: record the peak as 30, then reset v to c and bump u by d.
				trace[0, i] = 30
				n.v = n.c
				n.u += n.d
			else:
				trace[0, i] = n.v
			# Record u on every step, spike steps included, so the trace has no gaps.
			trace[1, i] = n.u
		return trace

def main():
	"""Simulate two example neurons under a step current and plot the results.

	Each neuron receives a stimulus of 1.0 for t > 10 (0 before). Its
	membrane-potential trace is plotted with the neuron's label. The stimulus
	is rescaled into the band [-95, -85] and drawn underneath.
	"""
	def step_stimulus(sim, onset=10., amplitude=1.0):
		# Set sim.stim to `amplitude` where t > onset and to 0 elsewhere.
		sim.stim[:] = np.where(sim.t > onset, amplitude, 0.)

	sims = []
	## (A) phasic spiking
	n = IzhNeuron("(A) 5-HT phasic neuron", a=0.005, b=0.28, c=-57., d=2., v0=-60)
	s = IzhSim(n, T=200)
	step_stimulus(s)
	sims.append(s)

	## (B) phasic spiking
	n = IzhNeuron("(B) GABA phasic neuron", a=0.02, b=0.25, c=-67., d=2., v0=-60)
	s = IzhSim(n, T=200)
	step_stimulus(s)
	sims.append(s)

	for s in sims:
		res = s.integrate(ng=s.izhikevich)
		lo = np.min(s.stim)
		span = np.max(s.stim) - lo
		# Normalize the stimulus to [0, 1]; a constant stimulus would otherwise
		# divide by zero and produce NaNs.
		scaled = (s.stim - lo) / span if span > 0 else np.zeros_like(s.stim)
		plt.plot(s.t, res[0], label=s.neuron.label)
		plt.plot(s.t, -95 + scaled * 10)

	plt.legend()
	plt.show()

if __name__ == "__main__":
	# Run the simulation-and-plot demo only when executed as a script, not when imported.
	main()