I’ve been struggling with something and am hoping for some beginner-friendly answers. As part of a developer-grant project, I’m building a STARK prover with Winterfell for verifiable training of a linear regression model. Unfortunately, it keeps failing. My constraint evaluation function looks roughly like this:
```rust
fn evaluate_transition<E: FieldElement<BaseField = BaseElement> + From<Self::BaseField>>(
    &self,
    frame: &EvaluationFrame<E>,
    _periodic_values: &[E],
    constraints: &mut [E],
) {
    let current_state = frame.current();
    let flag = current_state[0]; // 1 for dataset members, 0 for separator row
    let len_sample = current_state[1]; // Number of features for each dataset member
    let learning_rate = current_state[2];
    let mut dataset_hash = current_state[3]; // Current dataset hash
    let mut w_b_hash = current_state[4]; // Hash of weights and biases
    let input = current_state.get(5..(5 + len_sample)); // Input features
    let expected = current_state[5 + len_sample]; // Expected output
    let mut weights = current_state.get((6 + len_sample)..(6 + len_sample * 2)); // Weights
    let mut bias = current_state[6 + len_sample * 2]; // Bias
    let mut output = bias;
    dataset_hash = (update_hash(
        dataset_hash,
        &[
            input, vec![expected],
        ].concat(),
    ) * flag); // progressively updates the hash at each row
    for i in 0..len_sample {
        output += input[i] * weights[i] / E4;
    }
    let error = output - expected;
    let gradient = error;
    for i in 0..len_sample {
        weights[i] -= (learning_rate * gradient * input[i] * flag / E8);
    }
    bias -= learning_rate * gradient * flag / E4;
    w_b_hash = row_hash(&[
        weights,
        vec![bias]
    ].concat()); // new hash for each row
    constraints[0] = frame.next()[3] - dataset_hash; // should equal zero
    constraints[1] = frame.next()[4] - w_b_hash; // should equal zero
}
```
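To make the question concrete, here is a minimal, self-contained sketch of just the weight/bias part of the transition, written outside Winterfell so I can reason about it. Everything in it is a hypothetical stand-in: a toy prime field (p = 2^31 - 1) instead of Winterfell’s `BaseElement`, a fixed feature count `N = 2`, a made-up column layout, an integer learning rate of 1 to sidestep the E4/E8 scaling, and no hash columns at all. It differs from my function in two ways that I suspect matter. First, `N` is a constant known when the AIR is built rather than a value read from the trace; as far as I understand, the trace width is fixed up front, so `len_sample` can’t drive indexing. Second, each constraint is a plain polynomial in the current and next rows that must be zero on every consecutive pair. My `update_hash`/`row_hash` calls would presumably need to become an arithmetization-friendly hash whose rounds are themselves expressed as low-degree constraints, which this sketch leaves out.

```rust
// Toy prime field arithmetic (hypothetical stand-in for Winterfell's BaseElement).
const P: u64 = 2_147_483_647; // 2^31 - 1, prime

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { ((a as u128 * b as u128) % P as u128) as u64 }

// Number of features is fixed when the AIR is built, so the trace width is fixed.
const N: usize = 2;
// Hypothetical column layout: [flag, lr, x_0..x_{N-1}, y, w_0..w_{N-1}, b]
const FLAG: usize = 0;
const LR: usize = 1;
const X: usize = 2;
const Y: usize = X + N;
const W: usize = Y + 1;
const B: usize = W + N;
const WIDTH: usize = B + 1;

/// Update rule: step = flag * lr * (prediction - y).
/// flag = 0 (separator row) makes step 0, so the parameters are copied unchanged.
fn step_size(cur: &[u64; WIDTH]) -> u64 {
    let mut pred = cur[B];
    for i in 0..N {
        pred = add(pred, mul(cur[W + i], cur[X + i]));
    }
    let err = sub(pred, cur[Y]);
    mul(cur[FLAG], mul(cur[LR], err))
}

/// Constraint evaluation for one pair of rows; every entry of `out` must be 0.
fn evaluate_transition(cur: &[u64; WIDTH], next: &[u64; WIDTH], out: &mut [u64; N + 1]) {
    let step = step_size(cur);
    for i in 0..N {
        // next w_i - (w_i - step * x_i)
        out[i] = sub(next[W + i], sub(cur[W + i], mul(step, cur[X + i])));
    }
    // next b - (b - step)
    out[N] = sub(next[B], sub(cur[B], step));
}

/// Trace generation: builds the next row from the current one plus the next sample.
fn step_row(cur: &[u64; WIDTH], next_x: [u64; N], next_y: u64, next_flag: u64) -> [u64; WIDTH] {
    let step = step_size(cur);
    let mut row = [0u64; WIDTH];
    row[FLAG] = next_flag;
    row[LR] = cur[LR];
    for i in 0..N {
        row[X + i] = next_x[i];
        row[W + i] = sub(cur[W + i], mul(step, cur[X + i]));
    }
    row[Y] = next_y;
    row[B] = sub(cur[B], step);
    row
}

fn main() {
    // Two samples followed by separator rows (flag = 0).
    let row0: [u64; WIDTH] = [1, 1, 1, 2, 3, 0, 0, 0];
    let row1 = step_row(&row0, [2, 1], 4, 1);
    let row2 = step_row(&row1, [0, 0], 0, 0);
    let row3 = step_row(&row2, [0, 0], 0, 0);
    let trace = [row0, row1, row2, row3];

    let mut out = [0u64; N + 1];
    for pair in trace.windows(2) {
        evaluate_transition(&pair[0], &pair[1], &mut out);
        assert!(out.iter().all(|&c| c == 0), "honest trace must satisfy constraints");
    }

    // Tampering with a weight in the next row makes its constraint non-zero.
    let mut bad = row1;
    bad[W] = add(bad[W], 1);
    evaluate_transition(&row0, &bad, &mut out);
    assert_ne!(out[0], 0);
    println!("all checks passed");
}
```

Running it checks every consecutive row pair of an honest trace, including the transition out of a separator row, and shows that bumping one weight in the next row makes the corresponding constraint non-zero.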
My question at this stage: Is there something fundamentally wrong with how I’ve set up the constraints?
Note that this is pseudo-Rust: I’ve removed several type conversions etc. for readability. E4 and E8 are 10^4 (10,000) and 10^8 (100,000,000), respectively. Dividing by them raises a separate problem, which I’ll come back to once I’ve solved the first one.

I’ve framed this question more formally in a GitHub issue here, and the complete function can be seen here.
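On the E4/E8 problem: in a prime field, dividing by 10^4 means multiplying by its field inverse, so nothing gets truncated the way integer division would. The sketch below, again on a hypothetical toy field (p = 2^31 - 1, not Winterfell’s), shows that, plus one common workaround: the prover writes the quotient `q` and remainder `r` into extra trace columns, and the constraint only checks `v = q * 10^4 + r`. The helper names are made up, the range check `0 <= r < 10^4` is not implemented, and negative values (which the field stores as p - |v|) are not handled.

```rust
// Toy prime field (hypothetical; not Winterfell's field).
const P: u64 = 2_147_483_647; // 2^31 - 1, prime
const SCALE: u64 = 10_000; // E4

fn mul(a: u64, b: u64) -> u64 { ((a as u128 * b as u128) % P as u128) as u64 }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }

/// a^e mod P by square-and-multiply.
fn pow(mut a: u64, mut e: u64) -> u64 {
    let mut acc = 1;
    a %= P;
    while e > 0 {
        if e & 1 == 1 {
            acc = mul(acc, a);
        }
        a = mul(a, a);
        e >>= 1;
    }
    acc
}

/// Field inverse via Fermat's little theorem (requires a != 0 mod P).
fn inv(a: u64) -> u64 { pow(a, P - 2) }

/// Prover side (hypothetical helper): split a non-negative scaled value into q and r.
fn rescale_witness(v: u64) -> (u64, u64) { (v / SCALE, v % SCALE) }

/// Constraint side (hypothetical helper): must be 0. Does NOT enforce 0 <= r < SCALE.
fn rescale_constraint(v: u64, q: u64, r: u64) -> u64 {
    sub(v, (mul(q, SCALE) + r) % P)
}

fn main() {
    let v: u64 = 12_345; // fixed-point 1.2345 at scale 10^4

    // Field "division" multiplies by the inverse; the result is not the truncated 1.
    let field_quotient = mul(v, inv(SCALE));
    assert_ne!(field_quotient, 1);
    assert_eq!(mul(field_quotient, SCALE), v);

    // Integer rescaling: q and r live in the trace, the constraint ties them to v.
    let (q, r) = rescale_witness(v);
    assert_eq!((q, r), (1, 2_345));
    assert_eq!(rescale_constraint(v, q, r), 0);

    // Without a range check on r, a cheating split also passes.
    assert_eq!(rescale_constraint(v, 0, 12_345), 0);
    println!("field quotient = {field_quotient}, q = {q}, r = {r}");
}
```

The last assertion is the catch: without a range check, `q = 0, r = 12345` also satisfies the constraint, so the remainder would need its own range-check columns (for example, a bit decomposition) before this is sound.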