package com.intel.daal.algorithms.neural_networks.training;

import com.intel.daal.algorithms.Model;
import com.intel.daal.algorithms.Precision;
import com.intel.daal.algorithms.neural_networks.BackwardLayers;
import com.intel.daal.algorithms.neural_networks.ForwardLayers;
import com.intel.daal.algorithms.neural_networks.NextLayersCollection;
import com.intel.daal.algorithms.neural_networks.layers.BackwardLayer;
import com.intel.daal.algorithms.neural_networks.layers.ForwardLayer;
import com.intel.daal.algorithms.neural_networks.prediction.PredictionModel;
import com.intel.daal.data_management.data.Factory;
import com.intel.daal.data_management.data.NumericTable;
import com.intel.daal.services.DaalContext;
import com.intel.daal.utils.LibUtils;

/* loaded from: input_file:com/intel/daal/algorithms/neural_networks/training/TrainingModel.class */
/**
 * Java-side wrapper around a neural-network training model whose operations are
 * implemented by native methods.
 *
 * <p>The native library is loaded through {@link LibUtils#loadLibrary()} when this class is
 * initialized. Every accessor forwards the inherited {@code cObject} value to the
 * corresponding native method and wraps the returned {@code long} in a Java object bound
 * to this model's {@link DaalContext}.
 */
public class TrainingModel extends Model {
    static {
        // Must run before any native method of this class is invoked.
        LibUtils.loadLibrary();
    }

    /**
     * Creates a model backed by a fresh native object obtained from {@code cInit()}.
     *
     * @param daalContext context passed to the {@link Model} constructor
     */
    public TrainingModel(DaalContext daalContext) {
        super(daalContext);
        cObject = cInit();
    }

    /**
     * Creates a model whose native object is obtained from {@code cInit(long)} using the
     * {@code cObject} of another model.
     * NOTE(review): presumably this produces a native-side copy of {@code trainingModel} -
     * confirm against the native implementation.
     *
     * @param daalContext   context passed to the {@link Model} constructor
     * @param trainingModel model whose {@code cObject} is handed to the native initializer
     */
    public TrainingModel(DaalContext daalContext, TrainingModel trainingModel) {
        super(daalContext);
        long sourceCObject = trainingModel.cObject;
        cObject = cInit(sourceCObject);
    }

    /**
     * Creates a model by forwarding both arguments to the two-argument {@link Model}
     * constructor; no native call is made here.
     * NOTE(review): {@code j} is presumably the address of an existing native model - confirm.
     *
     * @param daalContext context passed to the {@link Model} constructor
     * @param j           value passed unchanged to the {@link Model} constructor
     */
    public TrainingModel(DaalContext daalContext, long j) {
        super(daalContext, j);
    }

    /**
     * Invokes the native {@code cInitialize} with a precision code, the given array and the
     * {@code cObject} of the given topology.
     *
     * @param cls              {@code Double.class} selects double precision; any other value
     *                         (including {@code null}) selects single precision
     * @param jArr             array forwarded unchanged to the native call
     * @param trainingTopology topology whose {@code cObject} is forwarded to the native call
     */
    public void initialize(Class<? extends Number> cls, long[] jArr, TrainingTopology trainingTopology) {
        int precision = toPrecisionValue(cls);
        cInitialize(cObject, precision, jArr, trainingTopology.cObject);
    }

    /**
     * Wraps the value returned by the native {@code cGetForwardLayers} call.
     *
     * @return a new {@link ForwardLayers} bound to this model's context
     */
    public ForwardLayers getForwardLayers() {
        DaalContext context = getContext();
        long layersCObject = cGetForwardLayers(cObject);
        return new ForwardLayers(context, layersCObject);
    }

    /**
     * Wraps the value returned by the native {@code cGetForwardLayer} call.
     *
     * @param j value forwarded unchanged to the native call (presumably a layer index -
     *          confirm)
     * @return a new {@link ForwardLayer} bound to this model's context
     */
    public ForwardLayer getForwardLayer(long j) {
        DaalContext context = getContext();
        long layerCObject = cGetForwardLayer(cObject, j);
        return new ForwardLayer(context, layerCObject);
    }

    /**
     * Wraps the value returned by the native {@code cGetBackwardLayers} call.
     *
     * @return a new {@link BackwardLayers} bound to this model's context
     */
    public BackwardLayers getBackwardLayers() {
        DaalContext context = getContext();
        long layersCObject = cGetBackwardLayers(cObject);
        return new BackwardLayers(context, layersCObject);
    }

    /**
     * Wraps the value returned by the native {@code cGetBackwardLayer} call.
     *
     * @param j value forwarded unchanged to the native call (presumably a layer index -
     *          confirm)
     * @return a new {@link BackwardLayer} bound to this model's context
     */
    public BackwardLayer getBackwardLayer(long j) {
        DaalContext context = getContext();
        long layerCObject = cGetBackwardLayer(cObject, j);
        return new BackwardLayer(context, layerCObject);
    }

    /**
     * Wraps the value returned by the native {@code cGetNextLayers} call.
     *
     * @return a new {@link NextLayersCollection} bound to this model's context
     */
    public NextLayersCollection getNextLayers() {
        DaalContext context = getContext();
        long nextLayersCObject = cGetNextLayers(cObject);
        return new NextLayersCollection(context, nextLayersCObject);
    }

    /**
     * Wraps the value returned by the native {@code cGetPredictionModel} call.
     *
     * @param cls {@code Double.class} selects double precision; any other value (including
     *            {@code null}) selects single precision
     * @return a new {@link PredictionModel} bound to this model's context
     */
    public PredictionModel getPredictionModel(Class<? extends Number> cls) {
        DaalContext context = getContext();
        int precision = toPrecisionValue(cls);
        long predictionCObject = cGetPredictionModel(precision, cObject);
        return new PredictionModel(context, predictionCObject);
    }

    /**
     * Asks the {@link Factory} singleton to create a Java object for the value returned by
     * the native {@code cGetWeightsAndBiases} call and casts it to {@link NumericTable}.
     *
     * @return the object created by the factory, cast to {@link NumericTable}
     */
    public NumericTable getWeightsAndBiases() {
        return (NumericTable) Factory.instance()
                .createObject(getContext(), cGetWeightsAndBiases(cObject));
    }

    /**
     * Passes the native object of the given table to the native
     * {@code cSetWeightsAndBiases} call.
     *
     * @param numericTable table whose {@code getCObject()} value is forwarded
     */
    public void setWeightsAndBiases(NumericTable numericTable) {
        cSetWeightsAndBiases(cObject, numericTable.getCObject());
    }

    /**
     * Maps a Java number class to the integer precision code expected by the native layer.
     * Only an exact match with {@code Double.class} yields double precision; everything
     * else falls back to single precision.
     */
    private static int toPrecisionValue(Class<? extends Number> cls) {
        if (cls == Double.class) {
            return Precision.doublePrecision.getValue();
        }
        return Precision.singlePrecision.getValue();
    }

    // ---- Native entry points (implemented in the library loaded above) ----

    private native long cInit();

    private native long cInit(long otherCObject);

    private native void cInitialize(long modelCObject, int precision, long[] values, long topologyCObject);

    private native long cGetForwardLayers(long modelCObject);

    private native long cGetForwardLayer(long modelCObject, long index);

    private native long cGetBackwardLayers(long modelCObject);

    private native long cGetBackwardLayer(long modelCObject, long index);

    private native long cGetNextLayers(long modelCObject);

    private native long cGetPredictionModel(int precision, long modelCObject);

    private native long cGetWeightsAndBiases(long modelCObject);

    private native void cSetWeightsAndBiases(long modelCObject, long tableCObject);
}
