/// Computes the cross-entropy loss of `model` on one batch, reduced with `.mean`.
///
/// - Parameters:
///   - model: The network to evaluate; its output on `x` is treated as logits.
///   - x: The input batch passed to `model`.
///   - y: The targets handed to `crossEntropy`. NOTE(review): whether these are
///     class indices or one-hot/probability vectors is not visible here — confirm
///     against the caller.
/// - Returns: The result of `crossEntropy` with `.mean` reduction applied.
func lossFunction(model: SimpleNetwork, x: MLXArray, y: MLXArray) -> MLXArray {
    // Forward pass first; the result is used directly as the logits.
    let logits = model(x)
    return crossEntropy(logits: logits, targets: y, reduction: .mean)
}