package com.intel.daal.algorithms.neural_networks.training;

import com.intel.daal.algorithms.Precision;
import com.intel.daal.algorithms.optimization_solver.iterative_solver.Batch;
import com.intel.daal.services.DaalContext;
import com.intel.daal.utils.LibUtils;
import java.util.Objects;

/* loaded from: input_file:com/intel/daal/algorithms/neural_networks/training/TrainingBatch.class */
/**
 * Batch-mode training algorithm for neural networks.
 *
 * <p>This class is a thin Java wrapper around a native implementation. Every operation is
 * forwarded to a {@code native} method together with the handle stored in {@code cObject} and
 * the integer codes of the selected {@link Precision} and {@link TrainingMethod}. The native
 * library is loaded by the static initializer of this class.
 *
 * <p>Supported configurations:
 * <ul>
 *   <li>methods: {@link TrainingMethod#defaultDense} and {@link TrainingMethod#feedforwardDense};</li>
 *   <li>floating-point types: {@code Float.class} (single precision) and {@code Double.class}
 *       (double precision).</li>
 * </ul>
 */
public class TrainingBatch extends com.intel.daal.algorithms.TrainingBatch {
    /** Training method of this algorithm; assigned by the constructors. */
    public TrainingMethod method;
    /** Wrapper around the native algorithm's input object. */
    public TrainingInput input;
    /** Wrapper around the native algorithm's parameter object. */
    public TrainingParameter parameter;
    /** Floating-point precision of the computations; assigned by the constructors. */
    protected Precision prec;

    /**
     * Constructs the training algorithm with an explicit type and method.
     *
     * @param daalContext    context passed to the superclass and to the created input/parameter wrappers
     * @param cls            floating-point type of the computations: {@code Float.class} or {@code Double.class}
     * @param trainingMethod training method: {@code defaultDense} or {@code feedforwardDense}
     * @param batch          iterative optimization solver whose native handle is passed to the native initializer
     * @throws IllegalArgumentException if {@code trainingMethod} or {@code cls} is unsupported (including {@code null})
     * @throws NullPointerException     if {@code batch} is {@code null}
     */
    public TrainingBatch(DaalContext daalContext, Class<? extends Number> cls, TrainingMethod trainingMethod, Batch batch) {
        super(daalContext);
        initialize(daalContext, cls, trainingMethod, batch);
    }

    /**
     * Constructs the training algorithm using {@link TrainingMethod#defaultDense}.
     *
     * @param daalContext context passed to the superclass and to the created input/parameter wrappers
     * @param cls         floating-point type of the computations: {@code Float.class} or {@code Double.class}
     * @param batch       iterative optimization solver whose native handle is passed to the native initializer
     * @throws IllegalArgumentException if {@code cls} is unsupported (including {@code null})
     * @throws NullPointerException     if {@code batch} is {@code null}
     */
    public TrainingBatch(DaalContext daalContext, Class<? extends Number> cls, Batch batch) {
        super(daalContext);
        initialize(daalContext, cls, TrainingMethod.defaultDense, batch);
    }

    /**
     * Constructs the training algorithm using single precision ({@code Float.class}) and
     * {@link TrainingMethod#defaultDense}.
     *
     * @param daalContext context passed to the superclass and to the created input/parameter wrappers
     * @param batch       iterative optimization solver whose native handle is passed to the native initializer
     * @throws NullPointerException if {@code batch} is {@code null}
     */
    public TrainingBatch(DaalContext daalContext, Batch batch) {
        super(daalContext);
        initialize(daalContext, Float.class, TrainingMethod.defaultDense, batch);
    }

    /**
     * Constructs a copy of another training algorithm.
     *
     * <p>The native object of {@code trainingBatch} is cloned. The new instance gets its own input
     * and parameter wrappers, bound to the clone.
     *
     * @param daalContext   context passed to the superclass and to the created input/parameter wrappers
     * @param trainingBatch algorithm to copy
     * @throws NullPointerException if {@code trainingBatch} is {@code null}
     */
    public TrainingBatch(DaalContext daalContext, TrainingBatch trainingBatch) {
        super(daalContext);
        Objects.requireNonNull(trainingBatch, "trainingBatch");
        this.method = trainingBatch.method;
        this.prec = trainingBatch.prec;
        this.cObject = cClone(trainingBatch.cObject, this.prec.getValue(), this.method.getValue());
        this.input = new TrainingInput(daalContext, cGetInput(this.cObject, this.prec.getValue(), this.method.getValue()));
        this.parameter = new TrainingParameter(daalContext, cInitParameter(this.cObject, this.prec.getValue(), this.method.getValue()));
    }

    /**
     * Initializes the native algorithm with the given sizes and network topology.
     *
     * @param jArr             sizes forwarded unchanged to the native initializer
     *                         (NOTE(review): presumably the dimensions of the training data — confirm against the native API)
     * @param trainingTopology topology whose native handle is forwarded to the native initializer
     * @throws NullPointerException if {@code trainingTopology} is {@code null}
     */
    public void initialize(long[] jArr, TrainingTopology trainingTopology) {
        Objects.requireNonNull(trainingTopology, "trainingTopology");
        cInitialize(this.cObject, this.prec.getValue(), this.method.getValue(), jArr, trainingTopology.cObject);
    }

    /**
     * Runs the computation and returns its result.
     *
     * @return a new {@link TrainingResult} wrapping the native result obtained after the superclass computation
     */
    @Override
    public TrainingResult compute() {
        super.compute();
        return new TrainingResult(getContext(), cGetResult(this.cObject, this.prec.getValue(), this.method.getValue()));
    }

    /**
     * Registers a result object with the native algorithm.
     *
     * @param trainingResult result whose native handle is passed to the native algorithm
     * @throws NullPointerException if {@code trainingResult} is {@code null}
     */
    public void setResult(TrainingResult trainingResult) {
        Objects.requireNonNull(trainingResult, "trainingResult");
        cSetResult(this.cObject, this.prec.getValue(), this.method.getValue(), trainingResult.getCObject());
    }

    /**
     * Returns a copy of this algorithm, built with the copy constructor.
     *
     * @param daalContext context for the new algorithm
     * @return a new {@code TrainingBatch} backed by a clone of this algorithm's native object
     */
    @Override
    public TrainingBatch clone(DaalContext daalContext) {
        return new TrainingBatch(daalContext, this);
    }

    /**
     * Validates the configuration, creates the native algorithm and binds the input/parameter wrappers.
     *
     * <p>All arguments are validated before any field is assigned, so a rejected call leaves no
     * partially initialized state behind.
     */
    private void initialize(DaalContext daalContext, Class<? extends Number> cls, TrainingMethod trainingMethod, Batch batch) {
        if (trainingMethod != TrainingMethod.defaultDense && trainingMethod != TrainingMethod.feedforwardDense) {
            throw new IllegalArgumentException("method unsupported");
        }
        if (cls != Double.class && cls != Float.class) {
            throw new IllegalArgumentException("type unsupported");
        }
        // batch.cObject is dereferenced below; fail with a descriptive message instead of a bare NPE.
        Objects.requireNonNull(batch, "batch");

        this.method = trainingMethod;
        this.prec = (cls == Double.class) ? Precision.doublePrecision : Precision.singlePrecision;
        this.cObject = cInit(this.prec.getValue(), this.method.getValue(), batch.cObject);
        this.input = new TrainingInput(daalContext, cGetInput(this.cObject, this.prec.getValue(), this.method.getValue()));
        this.parameter = new TrainingParameter(daalContext, cInitParameter(this.cObject, this.prec.getValue(), this.method.getValue()));
    }

    // Native entry points. Names and signatures must match the native library; parameter names are
    // irrelevant to native method linkage.

    private native long cInit(int precision, int trainingMethod, long batchAddr);

    private native long cInitParameter(long algAddr, int precision, int trainingMethod);

    private native long cGetInput(long algAddr, int precision, int trainingMethod);

    private native long cGetResult(long algAddr, int precision, int trainingMethod);

    private native void cSetResult(long algAddr, int precision, int trainingMethod, long resultAddr);

    private native void cInitialize(long algAddr, int precision, int trainingMethod, long[] sizes, long topologyAddr);

    private native long cClone(long algAddr, int precision, int trainingMethod);

    static {
        LibUtils.loadLibrary();
    }
}
