package com.mathworks.toolbox.nnet.gui.wizard;

import com.mathworks.mwswing.MJLabel;
import com.mathworks.mwswing.MJTextArea;
import com.mathworks.toolbox.nnet.library.gui.nnIcon;
import com.mathworks.toolbox.nnet.library.gui.nnIcons;
import com.mathworks.toolbox.nnet.library.layout.nnPanels;
import com.mathworks.toolbox.nnet.library.variables.nnChangeWatcher;
import com.mathworks.toolbox.nnet.library.variables.nnVariable;
import com.mathworks.toolbox.nnet.library.widgets.nnButton;
import com.mathworks.toolbox.nnet.library.widgets.nnStringMenu;
import com.mathworks.toolbox.nnet.library.widgets.nnWidgets;
import com.mathworks.toolbox.nnet.matlab.NNMatlab;
import com.mathworks.toolbox.nnet.matlab.nnAcceptor;
import com.mathworks.toolbox.nnet.matlab.nnMFunction;
import com.mathworks.toolbox.nnet.matlab.nnMatlabError;
import com.mathworks.toolbox.nnet.modules.nnMetric;
import com.mathworks.toolbox.nnet.modules.nnPlot;
import com.mathworks.toolbox.nnet.modules.nnSample;
import com.mathworks.toolbox.nnet.nntool.gui.NNStrings;
import java.awt.Component;
import java.awt.Container;
import java.awt.Font;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.util.Locale;
import java.util.Vector;
import javax.swing.Icon;
import javax.swing.SwingConstants;

/* loaded from: input_file:com/mathworks/toolbox/nnet/gui/wizard/nnTrainingPage.class */
public abstract class nnTrainingPage extends nnWizardPage {
    private static final String TITLE = "Train Network";
    private static final Font tableHeadingFont = new MJLabel().getFont().deriveFont(1);
    private final nnPlot[] plots;
    private final nnSample[] samples;
    private final nnMetric[] metrics;
    private final String units;
    private final TrainingChoice[] trainingChoices;
    private final Component panel;
    private final nnButton[] plotButtons;
    private final MJLabel trainExamplesLabel;
    private final MJLabel validateExamplesLabel;
    private final MJLabel testExamplesLabel;
    private final MJLabel[][] valueLabels;
    private nnButton trainButton;
    private final MJLabel samplesLabel;
    private final nnVariable<String> trainingAlgorithmVariable;
    private final String firstTrainingAlgorithm;
    private final nnStringMenu trainingAlgorithmMenu;
    private final MJLabel trainingAlgorithmLabel;
    private final MJTextArea trainingAlgorithmDescription;
    private final Container trainingAlgorithmLinks;
    private final nnChangeWatcher algorithmVariableListener;
    private boolean isTraining;
    private nnPlot isPlotting;
    private boolean hasTrained;
    private nnAcceptor trainAcceptor;
    private nnAcceptor viewFitAcceptor;

    /* loaded from: input_file:com/mathworks/toolbox/nnet/gui/wizard/nnTrainingPage$PlotListener.class */
    private final class PlotListener implements ActionListener {
        final nnPlot plot;

        PlotListener(nnPlot nnplot) {
            this.plot = nnplot;
        }

        public void actionPerformed(ActionEvent actionEvent) {
            nnTrainingPage.this.viewPlot(this.plot);
        }
    }

    /* loaded from: input_file:com/mathworks/toolbox/nnet/gui/wizard/nnTrainingPage$TrainingChoice.class */
    public static class TrainingChoice {
        public final String algorithm;
        public final String description;
        public final String[] functions;

        public TrainingChoice(String str, String str2, String... strArr) {
            this.algorithm = str;
            this.description = str2;
            this.functions = strArr;
        }
    }

    /* JADX WARN: Type inference failed for: r1v110, types: [com.mathworks.mwswing.MJLabel[], com.mathworks.mwswing.MJLabel[][]] */
    public nnTrainingPage(nnWizard nnwizard, nnMFunction nnmfunction, nnMetric[] nnmetricArr, nnPlot[] nnplotArr, String str, TrainingChoice... trainingChoiceArr) {
        super(nnwizard, "trainNetworkPanel", nnmfunction);
        Container newColumnPanel;
        this.trainExamplesLabel = new MJLabel("", 0);
        this.validateExamplesLabel = new MJLabel("", 0);
        this.testExamplesLabel = new MJLabel("", 0);
        this.trainButton = nnWidgets.newButton("retrain_button", "Retrain", (Icon) nnIcons.TRAIN_16.toImageIcon());
        this.algorithmVariableListener = new nnChangeWatcher() { // from class: com.mathworks.toolbox.nnet.gui.wizard.nnTrainingPage.1
            @Override // com.mathworks.toolbox.nnet.library.variables.nnChangeWatcher
            public void changed() {
                nnTrainingPage.this.updateTrainingAlgorithm();
            }
        };
        this.isTraining = false;
        this.isPlotting = null;
        this.hasTrained = false;
        this.trainAcceptor = new nnAcceptor() { // from class: com.mathworks.toolbox.nnet.gui.wizard.nnTrainingPage.3
            @Override // com.mathworks.toolbox.nnet.matlab.nnAcceptor
            public void succeed(Object obj) {
                nnTrainingPage.this.invalidateFollowingPanels();
                Vector vector = (Vector) obj;
                for (int i = 0; i < nnTrainingPage.this.samples.length; i++) {
                    for (int i2 = 0; i2 < nnTrainingPage.this.metrics.length; i2++) {
                        double doubleValue = ((Double) vector.elementAt((i2 * nnTrainingPage.this.samples.length) + i)).doubleValue();
                        nnTrainingPage.this.valueLabels[i][i2].setText(NNStrings.doubleToEString(doubleValue));
                        nnTrainingPage.this.valueLabels[i][i2].setToolTipText("" + doubleValue);
                    }
                }
                nnTrainingPage.this.isTraining = false;
                nnTrainingPage.this.hasTrained = true;
                nnTrainingPage.this.wizard.toFront();
                nnTrainingPage.this.updateStatus();
            }

            @Override // com.mathworks.toolbox.nnet.matlab.nnAcceptor
            public void fail(nnMatlabError nnmatlaberror) {
                nnTrainingPage.this.wizard.feedbackDialog.launch("Unable to Train Network.", nnmatlaberror.message, nnIcons.ERROR_48);
                nnTrainingPage.this.isTraining = false;
                nnTrainingPage.this.updateStatus();
            }
        };
        this.viewFitAcceptor = new nnAcceptor() { // from class: com.mathworks.toolbox.nnet.gui.wizard.nnTrainingPage.4
            @Override // com.mathworks.toolbox.nnet.matlab.nnAcceptor
            public void succeed(Object obj) {
                nnTrainingPage.this.isPlotting = null;
                nnTrainingPage.this.updateStatus();
            }

            @Override // com.mathworks.toolbox.nnet.matlab.nnAcceptor
            public void fail(nnMatlabError nnmatlaberror) {
                nnTrainingPage.this.wizard.feedbackDialog.launch("Unable to Plot Fit.", nnmatlaberror.message, nnIcons.ERROR_48);
                nnTrainingPage.this.isPlotting = null;
                nnTrainingPage.this.updateStatus();
            }
        };
        this.metrics = nnmetricArr;
        this.plots = nnplotArr;
        this.units = str.toLowerCase();
        this.trainingChoices = trainingChoiceArr;
        this.samplesLabel = new MJLabel(str, nnIcons.DATA_ALL_16.toImageIcon(), 0);
        this.samples = new nnSample[]{nnSample.TRAINING, nnSample.VALIDATION, nnSample.TESTING};
        this.trainExamplesLabel.setName("trainExamplesLabel");
        this.validateExamplesLabel.setName("validateExamplesLabel");
        this.testExamplesLabel.setName("testExamplesLabel");
        this.trainButton.setName("trainButton");
        this.trainExamplesLabel.setToolTipText("Examples used to optimize the network");
        this.validateExamplesLabel.setToolTipText("Examples used to stop training for best generalization");
        this.testExamplesLabel.setToolTipText("Examples used to evaluate network performance");
        this.samplesLabel.setToolTipText("Number of " + this.units + " in each category");
        this.samplesLabel.setFont(tableHeadingFont);
        this.trainButton.setToolTipText("Optimize network on inputs and targets");
        this.plotButtons = new nnButton[nnplotArr.length];
        for (int i = 0; i < nnplotArr.length; i++) {
            this.plotButtons[i] = nnWidgets.newButton(nnplotArr[i].name.replaceAll(" ", "") + "_button", "Plot " + nnplotArr[i].name, "Open plot window.");
        }
        int length = (nnplotArr.length + 1) / 2;
        Component[] componentArr = new Component[length];
        for (int i2 = 0; i2 < length; i2++) {
            if ((i2 + 1) * 2 <= nnplotArr.length) {
                componentArr[i2] = nnPanels.newCenterHPanel(nnPanels.newRowPanel(this.plotButtons[i2 * 2], nnPanels.newHSpace(10), this.plotButtons[(i2 * 2) + 1]));
            } else {
                componentArr[i2] = nnPanels.newCenterHPanel(nnPanels.newRowPanel(this.plotButtons[i2 * 2]));
            }
        }
        if (nnmetricArr.length > 0) {
            MJLabel[] mJLabelArr = new MJLabel[this.samples.length];
            for (int i3 = 0; i3 < this.samples.length; i3++) {
                mJLabelArr[i3] = new MJLabel(this.samples[i3].name + ": ", this.samples[i3].icon.toImageIcon(), 2);
                mJLabelArr[i3].setToolTipText(this.samples[i3].tip);
            }
            MJLabel[] mJLabelArr2 = new MJLabel[nnmetricArr.length];
            for (int i4 = 0; i4 < nnmetricArr.length; i4++) {
                mJLabelArr2[i4] = new MJLabel(nnmetricArr[i4].abbreviation, nnmetricArr[i4].icon.toImageIcon(), 0);
                mJLabelArr2[i4].setFont(tableHeadingFont);
                mJLabelArr2[i4].setToolTipText(nnmetricArr[i4].tip);
            }
            this.valueLabels = new MJLabel[this.samples.length];
            for (int i5 = 0; i5 < this.samples.length; i5++) {
                this.valueLabels[i5] = new MJLabel[nnmetricArr.length];
                for (int i6 = 0; i6 < nnmetricArr.length; i6++) {
                    this.valueLabels[i5][i6] = new MJLabel("", 0);
                    this.valueLabels[i5][i6].setName(this.samples[i5].name.toLowerCase() + "_" + nnmetricArr[i6].name.replaceAll(" ", "") + "_label");
                }
            }
            newColumnPanel = nnPanels.newColumnPanel(nnPanels.newGridPanel(4, 4, 0, 8, new MJLabel(" "), this.samplesLabel, mJLabelArr2[0], mJLabelArr2[1], mJLabelArr[0], this.trainExamplesLabel, this.valueLabels[0][0], this.valueLabels[0][1], mJLabelArr[1], this.validateExamplesLabel, this.valueLabels[1][0], this.valueLabels[1][1], mJLabelArr[2], this.testExamplesLabel, this.valueLabels[2][0], this.valueLabels[2][1]), nnPanels.newVSpace(15), nnPanels.newSeparator(), nnPanels.newVSpace(10), nnPanels.newColumnPanel(10, componentArr));
        } else {
            this.valueLabels = (MJLabel[][]) null;
            newColumnPanel = nnPanels.newColumnPanel(10, componentArr);
        }
        Component[] componentArr2 = new Component[nnmetricArr.length];
        for (int i7 = 0; i7 < nnmetricArr.length; i7++) {
            componentArr2[i7] = nnPanels.newIconDisplayTextArea(nnmetricArr[i7].description, nnmetricArr[i7].icon.toImageIcon());
        }
        String[] strArr = new String[trainingChoiceArr.length];
        for (int i8 = 0; i8 < trainingChoiceArr.length; i8++) {
            strArr[i8] = trainingChoiceArr[i8].algorithm;
        }
        this.firstTrainingAlgorithm = strArr[0];
        this.trainingAlgorithmVariable = new nnVariable<>(this.firstTrainingAlgorithm);
        this.trainingAlgorithmMenu = nnWidgets.newStringMenu("trainingAlgorithmMenu", this.trainingAlgorithmVariable, strArr, this.firstTrainingAlgorithm);
        this.trainingAlgorithmLabel = new MJLabel("Train using " + trainingChoiceArr[0].algorithm + ".");
        this.trainingAlgorithmDescription = nnPanels.newDisplayTextArea(trainingChoiceArr[0].description);
        this.trainingAlgorithmLinks = nnPanels.newRowPanel(new Component[0]);
        for (int i9 = 0; i9 < trainingChoiceArr[0].functions.length; i9++) {
            this.trainingAlgorithmLinks.add(nnPanels.newHSpace(10));
            this.trainingAlgorithmLinks.add(nnWidgets.newFunctionLink(trainingChoiceArr[0].functions[i9]));
        }
        this.panel = nnPanels.newTopCenterPanel(nnPanels.newGridPanel(1, 2, 10, 10, nnPanels.newTitledBorderPanel(TITLE, trainingChoiceArr.length > 1 ? nnPanels.newTopBottomPanel(nnPanels.newColumnPanel(nnPanels.newLeftPanel(nnWidgets.newStringLabel("Choose a training algorithm:")), nnPanels.newVSpace(10), nnPanels.newCenterHPanel(this.trainingAlgorithmMenu), nnPanels.newVSpace(10), this.trainingAlgorithmDescription, nnPanels.newVSpace(10)), nnPanels.newColumnPanel(nnPanels.newLeftPanel(nnPanels.newRowPanel(this.trainingAlgorithmLabel, this.trainingAlgorithmLinks)), nnPanels.newVSpace(10), nnPanels.newCenterHPanel(this.trainButton))) : nnPanels.newTopBottomPanel(nnPanels.newColumnPanel(nnPanels.newLeftPanel(nnPanels.newRowPanel(this.trainingAlgorithmLabel, this.trainingAlgorithmLinks)), nnPanels.newVSpace(15), nnPanels.newCenterHPanel(this.trainButton), nnPanels.newVSpace(10)), nnPanels.newColumnPanel(nnPanels.newVSpace(10), this.trainingAlgorithmDescription))), nnPanels.newTitledBorderPanel("Results", nnPanels.newStretchyTopPanel(newColumnPanel))), nnPanels.newTitledBorderPanel("Notes", nnPanels.newGridPanel(1, 2, nnPanels.newStretchyTopPanel(nnPanels.newLeftPanel(nnPanels.newIconDisplayTextArea("Training multiple times will generate different results due to different initial conditions and sampling.", nnIcons.TRAIN_16.toImageIcon()))), nnPanels.newStretchyTopPanel(nnPanels.newLeftColumnPanel(8, componentArr2)))));
        this.trainButton.addActionListener(new ActionListener() { // from class: com.mathworks.toolbox.nnet.gui.wizard.nnTrainingPage.2
            public void actionPerformed(ActionEvent actionEvent) {
                nnTrainingPage.this.train();
            }
        });
        this.trainingAlgorithmVariable.addWatcher(this.algorithmVariableListener);
        for (int i10 = 0; i10 < nnplotArr.length; i10++) {
            this.plotButtons[i10].addActionListener(new PlotListener(nnplotArr[i10]));
        }
        super.finishSetup();
    }

    @Override // com.mathworks.toolbox.nnet.gui.wizard.nnWizardPage
    public String title() {
        return TITLE;
    }

    @Override // com.mathworks.toolbox.nnet.gui.wizard.nnWizardPage
    public abstract String subtitle();

    @Override // com.mathworks.toolbox.nnet.gui.wizard.nnWizardPage
    public nnIcon icon() {
        return nnIcons.TRAIN_48;
    }

    @Override // com.mathworks.toolbox.nnet.gui.wizard.nnWizardPage
    public Component panel() {
        return this.panel;
    }

    @Override // com.mathworks.toolbox.nnet.gui.wizard.nnWizardPage
    public void initialize() {
        this.isTraining = false;
        this.isPlotting = null;
        this.hasTrained = false;
        this.trainingAlgorithmVariable.set(this.firstTrainingAlgorithm);
        this.trainExamplesLabel.setText("" + getSetting(nnSettingKeys.NUM_TRAINING_SAMPLES_SETTING));
        this.validateExamplesLabel.setText("" + getSetting(nnSettingKeys.NUM_VALIDATION_SAMPLES_SETTING));
        this.testExamplesLabel.setText("" + getSetting(nnSettingKeys.NUM_TEST_SAMPLES_SETTING));
        for (int i = 0; i < this.samples.length; i++) {
            for (int i2 = 0; i2 < this.metrics.length; i2++) {
                this.valueLabels[i][i2].setText("-");
                this.valueLabels[i][i2].setToolTipText((String) null);
            }
        }
        this.trainButton.requestFocus();
    }

    public void updateTrainingAlgorithm() {
        int selectedIndex = this.trainingAlgorithmMenu.getSelectedIndex();
        this.trainingAlgorithmLabel.setText("Train using " + this.trainingChoices[selectedIndex].algorithm + ".");
        this.trainingAlgorithmDescription.setText(this.trainingChoices[selectedIndex].description);
        this.trainingAlgorithmLinks.removeAll();
        for (int i = 0; i < this.trainingChoices[selectedIndex].functions.length; i++) {
            this.trainingAlgorithmLinks.add(nnPanels.newHSpace(10));
            this.trainingAlgorithmLinks.add(nnWidgets.newFunctionLink(this.trainingChoices[selectedIndex].functions[i]));
        }
    }

    @Override // com.mathworks.toolbox.nnet.gui.wizard.nnWizardPage
    public void updateStatus() {
        this.trainButton.setText(this.hasTrained ? "Retrain" : "Train");
        for (int i = 0; i < this.samples.length; i++) {
            for (int i2 = 0; i2 < this.metrics.length; i2++) {
                this.valueLabels[i][i2].setEnabled(this.hasTrained);
            }
        }
        if (this.isTraining) {
            this.trainButton.setEnabled(false);
            for (int i3 = 0; i3 < this.plots.length; i3++) {
                this.plotButtons[i3].setEnabled(false);
            }
            setStatus(nnIcons.WIZARD_STATUS_BUSY, "Training network.", false, true);
            return;
        }
        if (this.isPlotting != null) {
            this.trainButton.setEnabled(false);
            for (int i4 = 0; i4 < this.plots.length; i4++) {
                this.plotButtons[i4].setEnabled(false);
            }
            setStatus(nnIcons.WIZARD_STATUS_BUSY, "Plotting " + this.isPlotting.name.toLowerCase() + ".", false, true);
            return;
        }
        this.trainButton.setEnabled(true);
        for (int i5 = 0; i5 < this.plots.length; i5++) {
            this.plotButtons[i5].setEnabled(this.hasTrained);
        }
        if (this.hasTrained) {
            setStatus(nnIcons.WIZARD_STATUS_NEXT, "Open a plot, retrain, or click [Next] to continue.", true, false);
        } else {
            setStatus(nnIcons.WIZARD_STATUS_INFO, "Train network, then click [Next].", false, false);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    public void train() {
        this.isTraining = true;
        updateStatus();
        NNMatlab.call(this.trainAcceptor, this.mfunction, "trainNetwork", (Integer) getSetting(nnSettingKeys.VALIDATION_PERCENT_SETTING), (Integer) getSetting(nnSettingKeys.TEST_PERCENT_SETTING), Integer.valueOf(this.trainingAlgorithmMenu.getSelectedIndex()));
    }

    /* JADX INFO: Access modifiers changed from: private */
    public void viewPlot(nnPlot nnplot) {
        this.isPlotting = nnplot;
        updateStatus();
        NNMatlab.call(this.viewFitAcceptor, this.mfunction, "viewTrainPlot", nnplot.mfunction);
    }

    @Override // com.mathworks.toolbox.nnet.gui.wizard.nnWizardPage
    public Object getPageSetting(Object obj) {
        int selectedIndex = this.trainingAlgorithmMenu.getSelectedIndex();
        if (obj == nnSettingKeys.TRAINING_FUNCTION) {
            return this.trainingChoices[selectedIndex].functions[0];
        }
        if (obj == nnSettingKeys.TRAINING_ALGORITHM) {
            return this.trainingChoices[selectedIndex].algorithm;
        }
        return null;
    }
}
