/**
 * One member of the ensemble: a current position in parameter space plus the
 * cached log-likelihood at that position.
 *
 * @param {number[]} initialPosition - starting point; stored by reference.
 * @param {function(number[]): number} logLikelihood - log-density, evaluated
 *   once here and once per proposal in `update`.
 */
function Walker(initialPosition, logLikelihood) {
  const { length } = initialPosition;
  this.length = length;
  this.position = initialPosition;
  this.logLikelihood = logLikelihood;
  this.logProb = logLikelihood(initialPosition);
}

/**
 * Proposes a stretch move along the line through `pos` and this walker's
 * current position, then accepts it with probability
 * min(1, z^(d-1) * p(proposal) / p(current)), where d is `this.length`.
 *
 * @param {number} z - stretch factor.
 * @param {number[]} pos - position of the complementary walker.
 * @returns {number} 1 if the proposal was accepted, 0 if it was rejected.
 */
Walker.prototype.update = function(z, pos) {
  const current = this.position;
  const proposal = Array.from(
    { length: this.length },
    (_, i) => pos[i] - z * (pos[i] - current[i]),
  );
  const proposedLogProb = this.logLikelihood(proposal);
  const logRatio =
    (this.length - 1) * Math.log(z) + proposedLogProb - this.logProb;
  // Metropolis test in log space: accept when log(ratio) > log(U), U ~ [0, 1).
  const accepted = logRatio > Math.log(Math.random());
  if (!accepted) {
    return 0;
  }
  this.position = proposal;
  this.logProb = proposedLogProb;
  return 1;
};
/**
 * Affine-invariant ensemble sampler using the serial stretch move
 * (Goodman & Weare, 2010). Each `advance` proposes one move per walker and
 * appends a snapshot of every walker's position to `chain`.
 *
 * @param {function(number[]): number} logLikelihood - log-density evaluated
 *   at a position; shared by all walkers.
 * @param {number[][]} initialPosition - one starting position per walker
 *   (each array is handed to its Walker by reference).
 * @param {number} [a=2] - stretch-move scale; proposals use z in [1/a, a].
 * @throws {RangeError} if fewer than two walkers are given (every move needs a
 *   distinct complementary walker) or if `a` is not greater than 1.
 */
function EnsembleSampler(logLikelihood, initialPosition, a = 2) {
  if (initialPosition.length < 2) {
    throw new RangeError("EnsembleSampler requires at least two walkers");
  }
  if (!(a > 1)) {
    throw new RangeError("EnsembleSampler stretch scale `a` must be greater than 1");
  }
  this.logLikelihood = logLikelihood;
  this.a = a;
  this.walkers = [];
  this.numWalkers = initialPosition.length;
  for (let k = 0; k < this.numWalkers; k++) {
    this.walkers[k] = new Walker(initialPosition[k], this.logLikelihood);
  }
  // chain[i][k] is walker k's position after iteration i.
  this.chain = [];
  this.numAccepted = 0;
}

/**
 * Performs one serial sweep: each walker in turn is stretched relative to a
 * uniformly chosen *other* walker, using that walker's current position
 * (which may already have moved earlier in this same sweep). Accepted moves
 * are added to `numAccepted`; the resulting positions are appended to `chain`.
 */
EnsembleSampler.prototype.advance = function() {
  const snapshot = [];
  for (let k = 0; k < this.numWalkers; k++) {
    // Inverse-CDF draw of z from g(z) ∝ 1/sqrt(z) on [1/a, a].
    const z = Math.pow((this.a - 1) * Math.random() + 1, 2) / this.a;
    // Uniform over the other numWalkers - 1 walkers: draw from
    // [0, numWalkers - 2], then skip over index k.
    let j = Math.floor((this.numWalkers - 1) * Math.random());
    if (j >= k) {
      j++;
    }
    this.numAccepted += this.walkers[k].update(z, this.walkers[j].position);
    snapshot.push(this.walkers[k].position);
  }
  this.chain.push(snapshot);
};

/**
 * Returns every recorded position as one flat array, iteration-major:
 * entry `i * numWalkers + k` is walker k at iteration i. Position arrays are
 * shared with `chain`, not copied.
 */
EnsembleSampler.prototype.getChain = function() {
  return this.chain.flat();
};

/**
 * Fraction of all proposals so far that were accepted. Returns NaN before
 * the first `advance`, since no proposals have been made yet.
 */
EnsembleSampler.prototype.getAcceptanceFraction = function() {
  return this.numAccepted / (this.chain.length * this.numWalkers);
};
/**
 * Creates a configurable, cancellable MCMC run over an EnsembleSampler.
 *
 * Typical use:
 *   MCMC(logLikelihood, initialPosition, n).delay(ms).callback(fn).signal(s)(interval)
 *
 * @param {function(number[]): number} logLikelihood - log-density to sample.
 * @param {number[][]} initialPosition - one starting position per walker.
 * @param {number} numIterations - number of ensemble steps to take.
 * @returns {function} `sample(interval)`, which starts a run and returns a
 *   Promise of its sampler. It also carries chainable getter/setters `delay`,
 *   `callback` and `signal`.
 */
export function MCMC(logLikelihood, initialPosition, numIterations) {
  let delay = 0;
  let callback;
  let signal;

  /**
   * Starts a run, taking at most one ensemble step per scheduler tick.
   *
   * @param {function} interval - scheduler invoked as `interval(tick, delay)`.
   *   It must return an object with a `stop()` method. NOTE(review): this
   *   presumably is d3-timer's `interval` or something of the same shape;
   *   confirm with callers.
   * @returns {Promise} resolves with the sampler after exactly `numIterations`
   *   steps, or early with the partial sampler once `signal` is aborted.
   *   Rejects if a step or the callback throws. The scheduler is stopped once
   *   the promise settles.
   */
  function sample(interval) {
    // Local to each run so that overlapping runs never advance each other's sampler.
    const sampler = new EnsembleSampler(logLikelihood, initialPosition);
    let timer;
    let settled = false;

    const iteration = new Promise((resolve, reject) => {
      function finish() {
        settled = true;
        resolve(sampler);
      }

      function iterate() {
        // A tick can still arrive after settling, before `finally` stops the timer.
        if (settled) {
          return;
        }
        // Cancellation is not an error: hand back the partial run as-is.
        if (signal?.aborted || sampler.chain.length >= numIterations) {
          finish();
          return;
        }
        try {
          sampler.advance();
          // NOTE: `chain` is the full flattened chain, rebuilt on every tick.
          callback?.({
            "iteration": sampler.chain[sampler.chain.length - 1],
            "chain": sampler.getChain(),
          });
        } catch (error) {
          settled = true;
          reject(error);
          return;
        }
        if (sampler.chain.length >= numIterations) {
          finish();
        }
      }

      timer = interval(iterate, delay);
    }).finally(() => {
      // `timer` is unset only if `interval` itself threw.
      timer?.stop();
    });
    return iteration;
  }

  // Getter/setter: with an argument, sets the tick delay (clamped to >= 0) and
  // returns `sample` for chaining; without one, returns the current delay.
  sample.delay = function(_) {
    return arguments.length ? (delay = Math.max(_, 0), sample) : delay;
  };

  // Getter/setter for the per-step callback, which receives
  // { iteration, chain } after every step.
  sample.callback = function(_) {
    return arguments.length ? (callback = _, sample) : callback;
  };

  // Getter/setter for an AbortSignal-like object. It is checked once per tick
  // via its `aborted` flag.
  sample.signal = function(_) {
    return arguments.length ? (signal = _, sample) : signal;
  };

  return sample;
}