/// Mean-squared-error loss of `model` on a batch.
///
/// Runs a forward pass of `model` on `x`, then compares the output with `y`
/// using `mseLoss`, reduced with `.mean`.
///
/// - Parameters:
///   - model: The network being evaluated.
///   - x: Inputs fed to `model`.
///   - y: Target values the predictions are compared against.
/// - Returns: The mean-reduced MSE as an `MLXArray`.
func lossFunction(model: SimpleNetwork, x: MLXArray, y: MLXArray) -> MLXArray {
    // Forward pass first, so the prediction is a named value.
    let predicted = model(x)
    return mseLoss(predictions: predicted, targets: y, reduction: .mean)
}