Universal MLPRegressor for preset interpolation

<3 It’s a good point. If the buffer is being updated manually (meaning from the language), we already know when it has changed. If there’s an LFO on the server modulating the input space, then the buffer is pretty much always changing anyway!

1 Like

Just for context: the current SynthDef I'm using with the MLP has 74 params (minus the ones I ignore for interpolation), and I get 352 UGens for running the MLP, mainly for reading the param values from the buffer and detecting a change.

1 Like

Hmm. Yeah, maybe I’m missing something, @dietcv, but this does seem like a lot of overhead. Is the code somewhere I can peek at?

Also, you could check out @Sam_Pluta’s RTNeural, which doesn’t require buffers, so it may have the kind of implementation you’re looking for.

1 Like

I have shared a version you can try out in my initial post. If you debug the MLP synth, you can see that we create a BufRd (via FluidBufToKr) and a Changed UGen for every frame in the buffer, and then sum the Changed outputs because we don't care which frame has changed. The more params we use, the more frames we have in our params buffer.

If you change the buffer yourself, you already know when it changes, so you can trigger the synth directly:

// first write the new input value into the XY buffer ...
self.bufXY.set(index, value);
// ... then fire the synth's trigger control so it can react to the new input
self.mlpSynth.set(\t_xyTrig, 1);

And then in the synth:

		self.mlpSynth = {
			arg t_xyTrig;
			var paramValsUni;

			// t_xyTrig is set from the language right after bufXY has been written,
			// so no per-frame BufRd/Changed detection of the input is needed here.

			// On that trigger, run the MLP: read bufXY, write the prediction into bufParams ...
			self.mlp.kr(t_xyTrig, self.bufXY, self.bufParams);

			// ... read the predicted parameter values back as control signals ...
			paramValsUni = FluidBufToKr.kr(self.bufParams);
			// no need to check here, if there's a different input, there will be a different output!

			// ... and send them to the language on the same trigger to update the NodeProxy.
			SendReply.kr(t_xyTrig, "/paramsChanged", paramValsUni);

			Silent.ar;
		}.play;
1 Like

Thanks. I guess I can add that to the .addDependant/.changed MVC system I have already implemented. Let's see :slight_smile:

2 Likes

I have implemented that in the makeMLP ProtoDef; it's much cleaner now :slight_smile:
(makeOrbit and makeTrajectory have their own ProtoDefs, there is also a makeNodeProxy ProtoDef, and they are all combined in one main ProtoDef.)

See the MLP ProtoDef below:

(
// ProtoDef that maps a 2-D XY position to the parameters of `nodeProxy` with a
// FluCoMa MLP regressor: bufXY (2 frames) is the MLP input, bufParams (one frame
// per parameter value) is its output. Keys listed in excludeParams and
// ignoreParams are left out of the mapping.
ProtoDef(\makeMLP) { |nodeProxy, excludeParams=#[], ignoreParams=#[]|

	// Create datasets, buffers, the MLP and the window, then build the GUI.
	~init = { |self|

		self.styles = Prototype(\makeStyles);

		self.params = IdentityDictionary.new();
		self.paramValsUni = IdentityDictionary.new();

		self.dsXY = FluidDataSet(s);
		self.dsParams = FluidDataSet(s);

		// input: one frame for x, one for y
		self.bufXY = Buffer.alloc(s, 2);
		// output: one frame per (flattened) parameter value
		self.bufParams = Buffer.alloc(s, self.getNumParamVals);

		self.mlp = FluidMLPRegressor(
			server: s,
			hiddenLayers: [7],
			activation: FluidMLPRegressor.sigmoid,
			outputActivation: FluidMLPRegressor.sigmoid,
			maxIter: 1000,
			learnRate: 0.1,
			batchSize: 1,
			validation: 0,
		);

		self.window = Window.new("MLP", Rect(left: 10, top: 760, width: 440, height: 280)).front;
		self.window.layout = VLayout.new();

		self.setUpDependencies;
		self.makeParamSection;

		// Apply the font only after makeParamSection has added the views;
		// before that the window's layout is empty and this loop does nothing.
		self.window.view.children.do{ |c| c.font = self.styles[\font] };

	};

	// Number of frames bufParams needs: 1 per numeric parameter value and
	// `size` per array-valued parameter. Values that are neither (e.g. a mapped
	// bus) count as 0 frames; getParamValsUni skips such values as well.
	~getNumParamVals = { |self|

		var paramVals;

		// `? []` guards against an empty key list (flop[1] would be nil)
		paramVals = self.nodeProxy
		.getKeysValues(except: self.excludeParams ++ self.ignoreParams)
		.flop[1] ? [];

		paramVals.collect{ |val|

			case
			{ val.isNumber } { 1 }
			{ val.isArray } { val.size }
			{ 0 };

		}.sum;

	};

	// Make `self.changed(what, ...)` reach dataChanged, and release every
	// resource when the window is closed.
	~setUpDependencies = { |self|

		var dataChangedFunc = { |obj ...args| self.dataChanged(*args) };

		self.addDependant(dataChangedFunc);

		self.window.onClose_{

			self.unmapMidi;
			self.removeDependant(dataChangedFunc);

			// Stop everything that still uses the server objects before freeing
			// them: the training Task would otherwise keep calling mlp.fit
			// forever, and the MLP synth reads/writes bufXY and bufParams.
			self.trainTask !? { self.trainTask.stop; self.trainTask = nil };
			self.stopMLP;

			self.bufXY.free;
			self.bufParams.free;

			self.params.clear;
			self.paramValsUni.clear;

			// free (not just clear) the server-side objects; they are never
			// used again once the buffers above are gone
			self.mlp.free;
			self.dsXY.free;
			self.dsParams.free;

		};

	};

	// Dependant callback; `what` names the source of the change:
	//   \slider     - GUI slider moved; args are x, y, written to bufXY
	//   \buffer     - args are x, y, shown on the slider
	//   \controller - MIDI; args are index (0 = x, 1 = y) and value for bufXY
	// In every case the running MLP synth (if any) is re-triggered.
	~dataChanged = { |self, what ...args|

		case
		{ what == \slider } {

			self.bufXY.setn(0, args);
			self.mlpSynth !? { self.mlpSynth.set(\t_trig, 1) };

		}
		{ what == \buffer } {

			// the caller's thread is not known here, so defer the view update
			// (as the \controller branch does)
			{ self.xySlider.setXY(*args) }.defer;
			self.mlpSynth !? { self.mlpSynth.set(\t_trig, 1) };

		}
		{ what == \controller } {

			var index = args[0], value = args[1];

			self.bufXY.set(index, value);
			self.mlpSynth !? { self.mlpSynth.set(\t_trig, 1) };

			{
				switch(index,
					0, { self.xySlider.x_(value) },
					1, { self.xySlider.y_(value) }
				);
			}.defer;

		};

	};

	// Play (train == true) or pause (train == false) a Task that repeatedly
	// fits the MLP on dsXY -> dsParams, waiting for each fit to report its
	// loss before starting the next one.
	~trainMLP = { |self, train|

		self.trainTask = self.trainTask ?? {

			Task {

				loop {

					var condition = CondVar.new;
					var done = false;

					self.mlp.fit(self.dsXY, self.dsParams) { |loss|
						done = true;
						condition.signalOne;
						"> Training done. Loss: %".format(loss).postln;
					};

					condition.wait { done };

					0.1.wait;
				};

			};

		};

		if (train) {
			self.trainTask.play;
		} {
			self.trainTask.pause;
		};

	};

	// (Re)start the synth that runs the MLP whenever its t_trig control is set,
	// and the OSCFunc that forwards its "/paramsChanged" replies to setParamVals.
	~startMLP = { |self|

		self.stopMLP;

		self.mlpSynth = { |t_trig|

			var paramValsUni;

			// trigger MLP when t_trig is received...
			self.mlp.kr(t_trig, self.bufXY, self.bufParams);

			// ... to get param values from params buffer
			paramValsUni = FluidBufToKr.kr(self.bufParams);

			// ... and send OSC message to update NodeProxy
			SendReply.kr(t_trig, "/paramsChanged", paramValsUni);

			Silent.ar;
		}.play;

		// argTemplate: only accept replies coming from this synth's node
		self.responder = OSCFunc({ |msg|
			// drop reply path, nodeID and replyID; the rest are the values
			var paramValsUniOSC = msg.drop(3);
			self.setParamVals(paramValsUniOSC);
		}, "/paramsChanged", argTemplate:[self.mlpSynth.nodeID]);

	};

	// Free the MLP synth and its OSC responder, if present.
	~stopMLP = { |self|

		self.mlpSynth !? { self.mlpSynth.free; self.mlpSynth = nil };
		self.responder !? { self.responder.free; self.responder = nil };

	};

	// Collect the NodeProxy's controls (minus internal, excluded and ignored
	// keys) with their specs (the proxy's own spec, else Spec.specs, converted
	// with asSpec; one spec per element for array controls), then rebuild the
	// parameter layout.
	~makeParamSection = { |self|

		self.params.clear;

		self.nodeProxy
		.controlKeysValues(except:
			self.nodeProxy.internalKeys ++ self.excludeParams ++ self.ignoreParams
		)
		.pairsDo{ |key, val|

			var spec;

			spec = case
			{ val.isNumber } {
				(self.nodeProxy.specs.at(key) ?? { Spec.specs.at(key) }).asSpec;
			}
			{ val.isArray } {
				(self.nodeProxy.specs.at(key) ?? { Spec.specs.at(key) }).asSpec.dup(val.size);
			};

			self.params.put(key, spec);
		};

		if(self.paramLayout.notNil) { self.paramLayout.remove };
		self.paramLayout = self.makeParamLayout;
		self.window.layout.add(self.paramLayout);

	};

	// Read the current value of each parameter in self.params from the
	// NodeProxy and normalise it with spec.unmap (element-wise for arrays).
	// Stores the result in self.paramValsUni and returns it.
	~getParamValsUni = { |self|

		self.paramValsUni.clear;
		self.params.sortedKeysValuesDo{ |key, spec|

			var paramVal;

			paramVal = self.nodeProxy.get(key);

			case
			{ paramVal.isNumber } {
				var paramValUni;
				paramValUni = spec.unmap(paramVal);
				self.paramValsUni.put(key, paramValUni);
			}
			{ paramVal.isArray } {
				var paramValsUni;
				paramValsUni = paramVal.collect{ |val, n|
					spec.wrapAt(n).unmap(val);
				};
				self.paramValsUni.put(key, paramValsUni);
			};

		};

		self.paramValsUni;

	};

	// Store the current XY input (bufXY) and the current normalised parameter
	// values under `id` in dsXY / dsParams. bufParams is filled asynchronously
	// (loadCollection), so dsParams gets its point in the completion action.
	// NOTE(review): a running MLP synth also writes bufParams on t_trig.
	~addPoint = { |self, id|

		var paramValsUniSerialised;

		self.getParamValsUni;

		// flatten in sorted-key order; setParamVals uses the same order
		paramValsUniSerialised = self.paramValsUni.asSortedArray.flop[1].flat;

		self.bufParams.loadCollection(paramValsUniSerialised, 0) {
			self.dsParams.addPoint(id, self.bufParams);
		};

		self.dsXY.addPoint(id, self.bufXY);

	};

	// Map the flat array of normalised values received from the MLP synth back
	// through each parameter's spec and set them on the NodeProxy.
	~setParamVals = { |self, paramValsUniOSC|

		var paramValsUni;

		// Take the nesting (scalar vs. array per key) from self.params: it is
		// always built by makeParamSection and is what the loop below iterates.
		// self.paramValsUni is only filled by addPoint, so it is still empty
		// when the MLP is run on loaded data/model before any point was added.
		paramValsUni = paramValsUniOSC.reshapeLike(self.params.asSortedArray.flop[1]);

		self.params.sortedKeysValuesDo{ |key, spec, i|

			var paramVal = self.nodeProxy.get(key);

			case
			{ paramVal.isNumber } {
				self.nodeProxy.set(key, spec.map(paramValsUni[i]));
			}
			{ paramVal.isArray } {
				// one seti per element; only the side effect matters, hence `do`
				paramVal.do{ |val, n|
					var value = spec.wrapAt(n).map(paramValsUni[i].wrapAt(n));
					self.nodeProxy.seti(key, n, value);
				};
			};

		};

	};

	// Build the view with the XY slider over the dataset plot, plus the
	// buttons for data, training and MLP control.
	~makeParamLayout = { |self|

		var view, layout, buttons;
		var graphView, pointView, plotDataSetXY;
		var counter = 0;
		// ids currently in dsXY, refreshed by every plotDataSetXY call; used so
		// "Add Point" never re-uses an id (e.g. after "Load Data" or a rebuild)
		var existingIDs = [];

		view = View.new().layout_(VLayout.new());

		// Dump dsXY into the plotter (an empty dataset becomes an empty 2-D plot).
		plotDataSetXY = {
			self.dsXY.dump { |v|
				if (v["cols"] == 0) {
					v = Dictionary["cols" -> 2, "data" -> Dictionary[]]
				};
				pointView.dict = v;

				// sorted, so each point keeps the same colour across redraws
				// (Set iteration order changes as points are added)
				existingIDs = v["data"].keys.asArray.sort;

				// Assign colors to each point
				existingIDs.do { |pointID, i|
					var color = self.styles[\colors].values.wrapAt(i);
					pointView.pointColor_(pointID, color);
				};

			};
		};

		pointView = FluidPlotter(standalone: false).pointSizeScale_(20/12);
		plotDataSetXY.();

		self.xySlider = Slider2D()
		.background_(Color.white.alpha_(0))
		.action_{ |obj|
			self.changed(\slider, obj.x, obj.y);
		};

		graphView = StackLayout(self.xySlider, View.new().layout_(
			VLayout(pointView).margins_(10)
		)).mode_(\stackAll);

		buttons = VLayout(

			Button(bounds: 100@20)
			.states_([["Add Point"]])
			.action_{
				var id = "point-%".format(counter);
				// skip ids already present in dsXY; addPoint expects a new id
				while { existingIDs.any { |existing| existing == id } } {
					counter = counter + 1;
					id = "point-%".format(counter);
				};
				self.addPoint(id);
				counter = counter + 1;
				plotDataSetXY.();
			},

			Button(bounds: 100@20)
			.states_([["Save Data"]])
			.action_{
				FileDialog({ |folder|
					self.dsXY.write(folder +/+ "xydata.json");
					self.dsParams.write(folder +/+ "paramsdata.json");
				}, {}, 2, 0, true);
			},

			Button(bounds: 100@20)
			.states_([["Load Data"]])
			.action_{
				FileDialog({ |folder|
					// refresh the plot only after a folder was chosen and dsXY
					// has actually been read
					self.dsXY.read(folder +/+ "xydata.json", { { plotDataSetXY.() }.defer });
					self.dsParams.read(folder +/+ "paramsdata.json");
				}, fileMode: 2, acceptMode: 0, stripResult: true);
			},

			Button(bounds: 100@20)
			.states_([
				["Train"],
				["Train", Color.white, self.styles[\colors][\orange]]
			])
			.action_{ |obj|
				self.trainMLP(obj.value > 0)
			}
			// reflect a training Task that is still running when the layout is rebuilt
			.value_((self.trainTask.notNil and: { self.trainTask.isPlaying }).binaryValue),

			Button(bounds: 100@20)
			.states_([["Save MLP"]])
			.action_{
				Dialog.savePanel({ |path|
					if(PathName(path).extension != "json"){
						path = "%.json".format(path);
					};
					self.mlp.write(path);
				});
			},

			Button(bounds: 100@20)
			.states_([["Load MLP"]])
			.action_{
				Dialog.openPanel({ |path|
					self.mlp.read(path, action: {
						{ plotDataSetXY.() }.defer
					});
				});
			},

			Button(bounds: 100@20)
			.states_([
				["Start MLP"],
				["Stop MLP", Color.white, self.styles[\colors][\red]]
			])
			.action_{ |obj|
				if (obj.value > 0) { self.startMLP } { self.stopMLP };
			}
			.value_(self.mlpSynth.notNil.binaryValue),

			Button(bounds: 100@20)
			.states_([["Clear"]])
			.action_{|state|
				self.mlp.clear;
				self.dsXY.clear;
				self.dsParams.clear;
				plotDataSetXY.();
			}

		);

		layout = HLayout([graphView, s: 1], buttons);

		view.layout.add(layout);

		view.children.do{ |c| c.font = self.styles[\font] };

		view;
	};

	// Map MIDI CC numbers to the XY input: ccs[0] drives x, ccs[1] drives y.
	// CC values 0..127 are scaled to 0..1 and sent as a \controller change.
	~mapMidi = { |self, ccs|

		// drop any previous mapping, then start from an empty Order
		self.unmapMidi;
		self.midiResponders = Order[];

		ccs.do { |cc, n|
			self.midiResponders[cc] = MIDIFunc.cc({ |v, c|
				var val = v / 127;
				self.changed(\controller, n, val);
			}, cc).fix
		};

	};

	// Free all MIDI responders created by mapMidi and forget them, so freed
	// responders do not linger in the Order.
	~unmapMidi = { |self|
		self.midiResponders.do(_.free);
		self.midiResponders = nil;
	};

};
)
2 Likes