• Fichier: DigitNeuralNetwork.cs
  • Path: /ia_chiffre/IAChiffre/Assets/Scripts/DigitNeuralNetwork.cs
  • File size: 1.37 KB
  • MIME-type: text/plain
  • Charset: utf-8
 
Retour
using System;
using System.IO;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Two-layer fully connected network (Linear + ReLU, then Linear + Softmax)
/// whose pretrained parameters are loaded from a JSON file during <see cref="Setup"/>.
/// </summary>
public class DigitNeuralNetwork : NeuralNetwork
{
    // Relative path to the serialized parameters. Being relative, it resolves
    // against the process working directory (the project root when running in
    // the Unity editor). NOTE(review): the file is not under StreamingAssets,
    // so it is presumably absent from player builds — confirm.
    private const string WeightsPath = "Assets/Data/MnistNet.json";

    // Number of layers whose parameters must be present in the JSON file.
    private const int RequiredLayerCount = 2;

    // Layer sizes. 784 = 28 * 28, presumably a flattened MNIST digit image — confirm.
    private const int InputSize = 784;
    private const int HiddenSize = 300;
    // Presumably one output per digit class (0-9).
    private const int OutputSize = 10;

    /// <summary>
    /// JSON-serializable container for per-layer parameters, stored in layer order.
    /// Field names must match the JSON keys because <see cref="JsonUtility"/> maps them by name.
    /// </summary>
    [Serializable] private class NeuralNetworkWeights
    {
        /// <summary>Flat weight and bias arrays for a single layer.</summary>
        [Serializable] public struct Connections
        {
            public float[] weights;
            public float[] biases;
        }

        public Connections[] layerWeights;

        /// <summary>Number of layer entries read from the file (0 when the field is missing).</summary>
        public int Count => layerWeights?.Length ?? 0;

        /// <summary>Parameters of the layer at <paramref name="index"/>.</summary>
        public Connections this[int index] => layerWeights[index];
    }

    /// <summary>
    /// Reads the parameters from <see cref="WeightsPath"/> and builds the network:
    /// Linear + ReLU (784 -> 300) followed by Linear + Softmax (300 -> 10).
    /// </summary>
    /// <returns>The layers in forward order.</returns>
    /// <exception cref="IOException">The weights file cannot be read (e.g. <see cref="FileNotFoundException"/>).</exception>
    /// <exception cref="InvalidDataException">
    /// The file holds fewer than two layer entries, or a layer lacks its weights or biases.
    /// </exception>
    protected override List<NeuralNetworkLayer> Setup()
    {
        string json = File.ReadAllText(WeightsPath);
        NeuralNetworkWeights parameters = JsonUtility.FromJson<NeuralNetworkWeights>(json);

        // Report a clear error up front instead of a NullReference/IndexOutOfRange
        // exception from the indexer when the file is empty or incomplete.
        int layerCount = parameters?.Count ?? 0;
        if (layerCount < RequiredLayerCount)
        {
            throw new InvalidDataException(
                $"'{WeightsPath}' must contain parameters for {RequiredLayerCount} layers, found {layerCount}.");
        }

        NeuralNetworkWeights.Connections hidden = parameters[0];
        NeuralNetworkWeights.Connections output = parameters[1];
        RequireParameters(hidden, 0);
        RequireParameters(output, 1);

        return new List<NeuralNetworkLayer>
        {
            new NeuralNetworkLayer(InputSize, HiddenSize,
                NeuralNetworkLayer.Connection.Linear,
                NeuralNetworkLayer.Activation.ReLU,
                hidden.weights,
                hidden.biases),

            new NeuralNetworkLayer(HiddenSize, OutputSize,
                NeuralNetworkLayer.Connection.Linear,
                NeuralNetworkLayer.Activation.Softmax,
                output.weights,
                output.biases)
        };
    }

    // Throws when a layer entry is missing its weight or bias array, so the
    // problem is reported here rather than inside NeuralNetworkLayer.
    private static void RequireParameters(NeuralNetworkWeights.Connections layer, int index)
    {
        bool missingWeights = layer.weights == null || layer.weights.Length == 0;
        bool missingBiases = layer.biases == null || layer.biases.Length == 0;
        if (missingWeights || missingBiases)
        {
            throw new InvalidDataException(
                $"Layer {index} in '{WeightsPath}' is missing its weights or biases.");
        }
    }
}