package com.intel.daal.algorithms.neural_networks.training;

import com.intel.daal.algorithms.Input;
import com.intel.daal.data_management.data.Factory;
import com.intel.daal.data_management.data.KeyValueDataCollection;
import com.intel.daal.data_management.data.Tensor;
import com.intel.daal.services.DaalContext;
import com.intel.daal.utils.LibUtils;

/* loaded from: input_file:com/intel/daal/algorithms/neural_networks/training/TrainingInput.class */
/**
 * Input object of neural network training, backed by a native object.
 *
 * <p>Every accessor validates the identifier it is given and then forwards to a
 * native method, passing the {@code cObject} value inherited from {@link Input}.
 * Tensor inputs are addressed with {@link TrainingInputId#data} and
 * {@link TrainingInputId#groundTruth}; the ground-truth collection is addressed
 * with {@link TrainingInputCollectionId#groundTruthCollection}.
 */
public class TrainingInput extends Input {
    static {
        // The native methods declared below need the native library to be loaded first.
        LibUtils.loadLibrary();
    }

    /**
     * Wraps an existing native input object.
     *
     * @param daalContext context passed through to the {@link Input} constructor
     * @param j           native object value passed through to the {@link Input} constructor
     */
    public TrainingInput(DaalContext daalContext, long j) {
        super(daalContext, j);
    }

    /**
     * Sets a tensor input.
     *
     * @param trainingInputId identifier of the input; must be {@link TrainingInputId#data}
     *                        or {@link TrainingInputId#groundTruth}
     * @param tensor          tensor to store
     * @throws IllegalArgumentException if {@code trainingInputId} is not one of the supported identifiers
     */
    public void set(TrainingInputId trainingInputId, Tensor tensor) {
        if (!isTensorId(trainingInputId)) {
            throw new IllegalArgumentException("Incorrect TrainingInputId");
        }
        cSetInput(this.cObject, trainingInputId.getValue(), tensor.getCObject());
    }

    /**
     * Sets a key-value collection input.
     *
     * @param trainingInputCollectionId identifier of the input; must be
     *                                  {@link TrainingInputCollectionId#groundTruthCollection}
     * @param keyValueDataCollection    collection to store
     * @throws IllegalArgumentException if {@code trainingInputCollectionId} is not supported
     */
    public void set(TrainingInputCollectionId trainingInputCollectionId, KeyValueDataCollection keyValueDataCollection) {
        if (!isCollectionId(trainingInputCollectionId)) {
            // Name the identifier type that was actually rejected.
            throw new IllegalArgumentException("Incorrect TrainingInputCollectionId");
        }
        cSetInput(this.cObject, trainingInputCollectionId.getValue(), keyValueDataCollection.getCObject());
    }

    /**
     * Adds a tensor to a collection input.
     *
     * @param trainingInputCollectionId identifier of the collection; must be
     *                                  {@link TrainingInputCollectionId#groundTruthCollection}
     * @param i                         integer forwarded to the native side together with the tensor
     *                                  (presumably the key of the tensor inside the collection)
     * @param tensor                    tensor to add
     * @throws IllegalArgumentException if {@code trainingInputCollectionId} is not supported
     */
    public void add(TrainingInputCollectionId trainingInputCollectionId, int i, Tensor tensor) {
        if (!isCollectionId(trainingInputCollectionId)) {
            // Name the identifier type that was actually rejected.
            throw new IllegalArgumentException("Incorrect TrainingInputCollectionId");
        }
        cAddTensor(this.cObject, trainingInputCollectionId.getValue(), i, tensor.getCObject());
    }

    /**
     * Returns a tensor input.
     *
     * @param trainingInputId identifier of the input; must be {@link TrainingInputId#data}
     *                        or {@link TrainingInputId#groundTruth}
     * @return the object created by {@link Factory} for the value returned by the native getter,
     *         cast to {@link Tensor}
     * @throws IllegalArgumentException if {@code trainingInputId} is not one of the supported identifiers
     */
    public Tensor get(TrainingInputId trainingInputId) {
        if (!isTensorId(trainingInputId)) {
            throw new IllegalArgumentException("id unsupported");
        }
        long nativeTensor = cGetInput(this.cObject, trainingInputId.getValue());
        return (Tensor) Factory.instance().createObject(getContext(), nativeTensor);
    }

    /**
     * Returns a key-value collection input.
     *
     * @param trainingInputCollectionId identifier of the input; must be
     *                                  {@link TrainingInputCollectionId#groundTruthCollection}
     * @return a new {@link KeyValueDataCollection} wrapping the value returned by the native getter
     * @throws IllegalArgumentException if {@code trainingInputCollectionId} is not supported
     */
    public KeyValueDataCollection get(TrainingInputCollectionId trainingInputCollectionId) {
        if (!isCollectionId(trainingInputCollectionId)) {
            throw new IllegalArgumentException("id unsupported");
        }
        long nativeCollection = cGetInput(this.cObject, trainingInputCollectionId.getValue());
        return new KeyValueDataCollection(getContext(), nativeCollection);
    }

    /**
     * Returns a single tensor from a collection input.
     *
     * @param trainingInputCollectionId identifier of the collection; must be
     *                                  {@link TrainingInputCollectionId#groundTruthCollection}
     * @param i                         integer forwarded to the native getter
     *                                  (presumably the key of the tensor inside the collection)
     * @return the object created by {@link Factory} for the value returned by the native getter,
     *         cast to {@link Tensor}
     * @throws IllegalArgumentException if {@code trainingInputCollectionId} is not supported
     */
    public Tensor get(TrainingInputCollectionId trainingInputCollectionId, int i) {
        if (!isCollectionId(trainingInputCollectionId)) {
            throw new IllegalArgumentException("id unsupported");
        }
        long nativeTensor = cGetTensor(this.cObject, trainingInputCollectionId.getValue(), i);
        return (Tensor) Factory.instance().createObject(getContext(), nativeTensor);
    }

    /** Tells whether {@code id} is one of the identifiers accepted by the tensor accessors. */
    private static boolean isTensorId(TrainingInputId id) {
        return id == TrainingInputId.data || id == TrainingInputId.groundTruth;
    }

    /** Tells whether {@code id} is the identifier accepted by the collection accessors. */
    private static boolean isCollectionId(TrainingInputCollectionId id) {
        return id == TrainingInputCollectionId.groundTruthCollection;
    }

    /** Native setter; called with this object's {@code cObject}, an identifier value and the stored object's {@code getCObject()} value. */
    private native void cSetInput(long inputCObject, int idValue, long valueCObject);

    /** Native getter; called with this object's {@code cObject} and an identifier value. */
    private native long cGetInput(long inputCObject, int idValue);

    /** Native adder; called with this object's {@code cObject}, an identifier value, the caller's integer and the tensor's {@code getCObject()} value. */
    private native void cAddTensor(long inputCObject, int idValue, int index, long tensorCObject);

    /** Native getter for one tensor; called with this object's {@code cObject}, an identifier value and the caller's integer. */
    private native long cGetTensor(long inputCObject, int idValue, int index);
}
