using System;
using System.IO;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Fully connected network for digit classification: 784 inputs, a 300-unit
/// ReLU hidden layer and a 10-unit Softmax output layer. Weights and biases
/// are read from a JSON file when <see cref="Setup"/> runs.
/// </summary>
public class DigitNeuralNetwork : NeuralNetwork
{
    // Resolved relative to the process working directory, which is the project
    // root when running inside the Unity Editor.
    // NOTE(review): a standalone player build will not find this path unless the
    // file is shipped next to the executable; consider StreamingAssets or a
    // TextAsset reference — confirm the intended deployment target.
    private const string WeightsPath = "Assets/Data/MnistNet.json";

    // Layer widths. 784 = 28 * 28, presumably a flattened MNIST image — confirm
    // against whatever produced MnistNet.json.
    private const int InputSize = 784;
    private const int HiddenSize = 300;
    private const int OutputSize = 10;

    // Number of entries of NeuralNetworkWeights.layerWeights that Setup consumes.
    private const int LayerCount = 2;

    /// <summary>
    /// JsonUtility-serializable mirror of the weights file. Field names must match
    /// the JSON keys exactly, so they intentionally keep their lower-case names.
    /// </summary>
    [Serializable]
    private class NeuralNetworkWeights
    {
        /// <summary>Weight and bias arrays for a single layer.</summary>
        [Serializable]
        public struct Connections
        {
            public float[] weights;
            public float[] biases;
        }

        public Connections[] layerWeights;

        /// <summary>Returns the parameters of layer <paramref name="index"/>.</summary>
        public Connections this[int index] => layerWeights[index];
    }

    /// <summary>
    /// Loads the weights file and builds the layer stack.
    /// </summary>
    /// <returns>
    /// Two layers: 784 -> 300 (Linear, ReLU) followed by 300 -> 10 (Linear, Softmax),
    /// parameterised by entries 0 and 1 of the file's <c>layerWeights</c> array.
    /// </returns>
    /// <exception cref="IOException">The weights file could not be read (e.g. it does not exist).</exception>
    /// <exception cref="InvalidDataException">
    /// The file does not contain a usable weights and biases array for every layer.
    /// </exception>
    protected override List<NeuralNetworkLayer> Setup()
    {
        string json = File.ReadAllText(WeightsPath);
        NeuralNetworkWeights parameters = JsonUtility.FromJson<NeuralNetworkWeights>(json);

        // Fail with a descriptive error instead of a NullReference/IndexOutOfRange
        // exception deep inside layer construction.
        ValidateWeights(parameters);

        return new List<NeuralNetworkLayer>
        {
            new NeuralNetworkLayer(InputSize, HiddenSize,
                NeuralNetworkLayer.Connection.Linear,
                NeuralNetworkLayer.Activation.ReLU,
                parameters[0].weights,
                parameters[0].biases),
            new NeuralNetworkLayer(HiddenSize, OutputSize,
                NeuralNetworkLayer.Connection.Linear,
                NeuralNetworkLayer.Activation.Softmax,
                parameters[1].weights,
                parameters[1].biases)
        };
    }

    /// <summary>
    /// Checks that the deserialized file provides non-empty weights and biases for
    /// each of the <see cref="LayerCount"/> layers that <see cref="Setup"/> reads.
    /// Array lengths are not checked against the layer widths because the layout
    /// expected by <c>NeuralNetworkLayer</c> is defined elsewhere.
    /// </summary>
    /// <exception cref="InvalidDataException">Any required array is missing or empty.</exception>
    private static void ValidateWeights(NeuralNetworkWeights parameters)
    {
        if (parameters?.layerWeights == null)
        {
            throw new InvalidDataException(
                $"'{WeightsPath}' does not contain a 'layerWeights' array.");
        }

        if (parameters.layerWeights.Length < LayerCount)
        {
            throw new InvalidDataException(
                $"'{WeightsPath}' defines {parameters.layerWeights.Length} layer(s); {LayerCount} are required.");
        }

        for (int i = 0; i < LayerCount; i++)
        {
            NeuralNetworkWeights.Connections layer = parameters.layerWeights[i];
            bool missingWeights = layer.weights == null || layer.weights.Length == 0;
            bool missingBiases = layer.biases == null || layer.biases.Length == 0;
            if (missingWeights || missingBiases)
            {
                throw new InvalidDataException(
                    $"Layer {i} in '{WeightsPath}' is missing its weights or biases.");
            }
        }
    }
}