CARLsim  3.0.3
CARLsim: a GPU-accelerated SNN simulator
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
simple_weight_tuner.cpp
Go to the documentation of this file.
1 #include "simple_weight_tuner.h"
2 
3 #include <carlsim.h> // CARLsim, SpikeMonitor
4 #include <math.h> // fabs
5 #include <stdio.h> // printf
6 #include <limits> // double::max
7 #include <assert.h> // assert
8 
9 // ****************************************************************************************************************** //
10 // CONSTRUCTOR / DESTRUCTOR
11 // ****************************************************************************************************************** //
12 
13 SimpleWeightTuner::SimpleWeightTuner(CARLsim *sim, double errorMargin, int maxIter, double stepSizeFraction) {
14  assert(sim!=NULL);
15  assert(errorMargin>0);
16  assert(maxIter>0);
17  assert(stepSizeFraction>0.0f && stepSizeFraction<=1.0f);
18 
19  sim_ = sim;
20  assert(sim_->getCARLsimState()!=RUN_STATE);
21 
22  errorMargin_ = errorMargin;
23  stepSizeFraction_ = stepSizeFraction;
24  maxIter_ = maxIter;
25 
26  connId_ = -1;
27  wtRange_ = NULL;
28  wtInit_ = -1.0;
29 
30  grpId_ = -1;
31  targetRate_ = -1.0;
32 
33  wtStepSize_ = -1.0;
34  cntIter_ = 0;
35 
36  wtShouldIncrease_ = true;
37  adjustRange_ = true;
38 
39  needToInitConnection_ = true;
40  needToInitTargetFiring_ = true;
41 
42  needToInitAlgo_ = true;
43 }
44 
46  if (wtRange_!=NULL)
47  delete wtRange_;
48  wtRange_=NULL;
49 }
50 
51 
52 
53 // ****************************************************************************************************************** //
54 // PUBLIC METHODS
55 // ****************************************************************************************************************** //
56 
57 // user function to reset algo
59  needToInitAlgo_ = true;
60  initAlgo();
61 }
62 
63 bool SimpleWeightTuner::done(bool printMessage) {
64  // algo not initalized: we're not done
65  if (needToInitConnection_ || needToInitTargetFiring_ || needToInitAlgo_)
66  return false;
67 
68  // success: margin reached
69  if (fabs(currentError_) < errorMargin_) {
70  if (printMessage) {
71  printf("SimpleWeightTuner successful: Error margin reached in %d iterations.\n",cntIter_);
72  }
73  return true;
74  }
75 
76  // failure: max iter reached
77  if (cntIter_ >= maxIter_) {
78  if (printMessage) {
79  printf("SimpleWeightTuner failed: Max number of iterations (%d) reached.\n",maxIter_);
80  }
81  return true;
82  }
83 
84  // else we're not done
85  return false;
86 }
87 
88 void SimpleWeightTuner::setConnectionToTune(short int connId, double initWt, bool adjustRange) {
89  assert(connId>=0 && connId<sim_->getNumConnections());
90 
91  connId_ = connId;
92  wtInit_ = initWt;
93  adjustRange_ = adjustRange;
94 
95  needToInitConnection_ = false;
96  needToInitAlgo_ = true;
97 }
98 
99 void SimpleWeightTuner::setTargetFiringRate(int grpId, double targetRate) {
100  grpId_ = grpId;
101  targetRate_ = targetRate;
102  currentError_ = targetRate;
103 
104  // check whether group has SpikeMonitor
105  SM_ = sim_->getSpikeMonitor(grpId);
106  if (SM_==NULL) {
107  // setSpikeMonitor has not been called yet
108  SM_ = sim_->setSpikeMonitor(grpId,"NULL");
109  }
110 
111  needToInitTargetFiring_ = false;
112  needToInitAlgo_ = true;
113 }
114 
115 void SimpleWeightTuner::iterate(int runDurationMs, bool printStatus) {
116  assert(runDurationMs>0);
117 
118  // if we're done, don't iterate
119  if (done(printStatus)) {
120  return;
121  }
122 
123  // make sure we have initialized algo
124  assert(!needToInitConnection_);
125  assert(!needToInitTargetFiring_);
126  if (needToInitAlgo_)
127  initAlgo();
128 
129  // in case the user has already been messing with the SpikeMonitor, we need to make sure that
130  // PersistentMode is off
131  SM_->setPersistentData(false);
132 
133  // now iterate
134  SM_->startRecording();
135  sim_->runNetwork(runDurationMs/1000, runDurationMs%1000, false);
136  SM_->stopRecording();
137 
138  double thisRate = SM_->getPopMeanFiringRate();
139  if (printStatus) {
140  printf("#%d: rate=%.4fHz, target=%.4fHz, error=%.7f, errorMargin=%.7f\n", cntIter_, thisRate, targetRate_,
141  thisRate-targetRate_, errorMargin_);
142  }
143 
144  currentError_ = thisRate - targetRate_;
145  cntIter_++;
146 
147  // check if we're done now
148  if (done(printStatus)) {
149  return;
150  }
151 
152  // else update parameters
153  if ((wtStepSize_>0 && thisRate>targetRate_) || (wtStepSize_<0 && thisRate<targetRate_)) {
154  // we stepped too far to the right or too far to the left
155  // turn around and cut step size in half
156  // note that this should work for inhibitory connections, too: they have negative weights, so adding
157  // to the weight will actually decrease it (make it less negative)
158  wtStepSize_ = -wtStepSize_/2.0;
159  }
160 
161  // find new weight
162  sim_->biasWeights(connId_, wtStepSize_, adjustRange_);
163 }
164 
165 
166 // ****************************************************************************************************************** //
167 // PRIVATE METHODS
168 // ****************************************************************************************************************** //
169 
170 // need to call this whenever connection or target firing changes
171 // or when user calls reset
172 void SimpleWeightTuner::initAlgo() {
173  if (!needToInitAlgo_)
174  return;
175 
176  // make sure we have all the data structures we need
177  assert(!needToInitConnection_);
178  assert(!needToInitTargetFiring_);
179 
180  // update weight ranges
181  RangeWeight wt = sim_->getWeightRange(connId_);
182  wtRange_ = new RangeWeight(wt.min, wt.init, wt.max);
183 
184  // reset algo
185  wtShouldIncrease_ = true;
186  wtStepSize_ = stepSizeFraction_ * (wtRange_->max - wtRange_->min);
187 #if defined(WIN32) || defined(WIN64)
188  currentError_ = DBL_MAX;
189 #else
190  currentError_ = std::numeric_limits<double>::max();
191 #endif
192 
193  // make sure we're in the right CARLsim state
194  if (sim_->getCARLsimState()!=RUN_STATE)
195  sim_->runNetwork(0,0,false);
196 
197  // initialize weights
198  if (wtInit_>=0) {
199  // start at some specified initWt
200  if (wt.init != wtInit_) {
201  // specified starting point is not what is specified in connect
202 
203  sim_->biasWeights(connId_, wtInit_ - wt.init, adjustRange_);
204  }
205  }
206 
207  needToInitAlgo_ = false;
208 }
void startRecording()
Starts a new recording period.
carlsimState_t getCARLsimState()
Returns the current CARLsim state.
Definition: carlsim.h:1114
CARLsim User Interface This class provides a user interface to the public sections of CARLsimCore sou...
Definition: carlsim.h:142
float getPopMeanFiringRate()
Returns the mean firing rate of the entire neuronal population.
void setConnectionToTune(short int connId, double initWt=-1.0, bool adjustRange=true)
Sets up the connection to tune.
~SimpleWeightTuner()
Destructor.
SimpleWeightTuner(CARLsim *sim, double errorMargin=1e-3, int maxIter=100, double stepSizeFraction=0.5)
Creates a new instance of class SimpleWeightTuner.
void stopRecording()
Ends a recording period.
void reset()
Resets the algorithm to initial conditions.
bool done(bool printMessage=false)
Determines whether a termination criterion has been met.
int runNetwork(int nSec, int nMsec=0, bool printRunSummary=true, bool copyState=false)
run the simulation for time=(nSec*seconds + nMsec*milliseconds)
void setTargetFiringRate(int grpId, double targetRate)
Sets up the target firing rate of a specific group.
void iterate(int runDurationMs=1000, bool printStatus=true)
Performs an iteration step of the tuning algorithm.
void biasWeights(short int connId, float bias, bool updateWeightRange=false)
Adds a constant bias to the weight of every synapse in the connection.
run state, where the model is stepped
a range struct for synaptic weight magnitudes
SpikeMonitor * setSpikeMonitor(int grpId, const std::string &fileName)
Sets a SpikeMonitor for a group; prints spikes to a binary file.
SpikeMonitor * getSpikeMonitor(int grpId)
Returns a pointer to the previously allocated SpikeMonitor object, or NULL otherwise.
void setPersistentData(bool persistentData)
Sets PersistentMode either on (true) or off (false)
RangeWeight getWeightRange(short int connId)
returns the RangeWeight struct for a specific connection ID