import MLX
import MLXNN

/// A two-layer feed-forward network.
///
/// The forward pass is `outputLayer(sigmoid(hiddenLayer(x)))`.
nonisolated class SimpleNetwork: Module, UnaryLayer {
   /// First projection: `Linear(inputDim, hiddenDim)`.
   let hiddenLayer: Linear
   /// Second projection, applied after the sigmoid: `Linear(hiddenDim, outputDim)`.
   let outputLayer: Linear

   /// Builds both linear layers.
   /// - Parameters:
   ///   - inputDim: Number of input features fed to `hiddenLayer`.
   ///   - hiddenDim: Width of the hidden layer (output of `hiddenLayer`, input of `outputLayer`).
   ///   - outputDim: Number of features produced by `outputLayer`.
   init(inputDim: Int, hiddenDim: Int, outputDim: Int) {
      hiddenLayer = Linear(inputDim, hiddenDim)
      outputLayer = Linear(hiddenDim, outputDim)
   }

   /// Applies the hidden projection and a sigmoid, then the output projection.
   /// - Parameter x: Input array passed to `hiddenLayer`.
   /// - Returns: The result of `outputLayer` applied to the activated hidden features.
   func callAsFunction(_ x: MLXArray) -> MLXArray {
      let activated = sigmoid(hiddenLayer(x))
      return outputLayer(activated)
   }
}