package com.maplesoft.pen.recognition.character.training;

import com.maplesoft.mathdoc.controller.WmiMenu;
import com.maplesoft.mathdoc.exception.WmiErrorLog;
import com.maplesoft.mathdoc.exception.WmiNoReadAccessException;
import com.maplesoft.mathdoc.exception.WmiNoWriteAccessException;
import com.maplesoft.mathdoc.model.WmiAttributeSet;
import com.maplesoft.mathdoc.model.WmiMathDocumentModel;
import com.maplesoft.mathdoc.model.WmiModel;
import com.maplesoft.mathdoc.model.WmiModelLock;
import com.maplesoft.mathdoc.model.math.WmiMathAttributeSet;
import com.maplesoft.mathdoc.model.math.WmiPowerBuilder;
import com.maplesoft.pen.controller.file.PenFileOpen;
import com.maplesoft.pen.controller.training.PenTrainingCommand;
import com.maplesoft.pen.controller.training.PenTrainingControllerManager;
import com.maplesoft.pen.controller.training.PenTrainingOpenConfusionMatrix;
import com.maplesoft.pen.exception.PenRecognizerInitializationException;
import com.maplesoft.pen.io.xml.PenXMLImportParser;
import com.maplesoft.pen.model.PenAttributeConstants;
import com.maplesoft.pen.model.PenCanvasModel;
import com.maplesoft.pen.model.PenDocumentModel;
import com.maplesoft.pen.model.PenStrokeCollectionModel;
import com.maplesoft.pen.recognition.database.PenConfusionMatrix;
import com.maplesoft.pen.recognition.database.PenRecognitionData;
import com.maplesoft.pen.recognition.database.PenRecognitionDataStore;
import com.maplesoft.util.PlatformInfo;
import com.maplesoft.util.ResourceLoader;
import com.maplesoft.worksheet.io.classic.attributes.WmiClassicPlotAttributeSet;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.HashSet;
import java.util.ResourceBundle;
import java.util.Set;

/* loaded from: input_file:com/maplesoft/pen/recognition/character/training/PenCharacterBatchTrainer.class */
public class PenCharacterBatchTrainer {
    public static final int DEFAULT_TEST_SPARSITY = 5;
    private static final String DEFAULT_TRAINING_PATH = "com/maplesoft/pen/trainingdata/character/penmath/";
    private static final String TRAINING_FILES_RESOURCE = "files";
    private static final String TRAINING_FILES_KEY = "character.training.files";
    private static final int UPDATE_DELAY = 15000;
    private static final String[] EXCLUDED_NAMES_FOR_SYMBOL_PALETTE = {"0", "1", "2", WmiClassicPlotAttributeSet.MODE_3D, "4", "5", "6", "7", "8", "9", "Ϝ", "ϝ", WmiPowerBuilder.SQUARE_ROOT_FUNCTION_NAME, WmiMathAttributeSet.SEMANTICS_SUM, "prod", "int", PlatformInfo.DEC_ALPHA, "beta", "gamma", "pi"};
    private String[] fileNameArray;
    private String testInputDatabase;
    private String testOutputDatabase;
    private Set<String> excludedNames;
    private WmiMathDocumentModel[] docArray = null;
    private int sparsity = 0;
    private int startIndex = 0;
    private boolean databaseLoaded = false;
    private PenCharacterTrainingController trainer = new PenCharacterTrainingController();

    public PenCharacterBatchTrainer() {
        this.excludedNames = null;
        this.trainer.setVerboseRecognition(false);
        PenTrainingControllerManager.setActiveController(this.trainer);
        this.excludedNames = new HashSet();
        for (int i = 0; i < EXCLUDED_NAMES_FOR_SYMBOL_PALETTE.length; i++) {
            this.excludedNames.add(EXCLUDED_NAMES_FOR_SYMBOL_PALETTE[i]);
        }
    }

    public void setTrainingInputFilePath(String str) {
        this.fileNameArray = ResourceBundle.getBundle(str + TRAINING_FILES_RESOURCE).getString(TRAINING_FILES_KEY).split(WmiMenu.LIST_DELIMITER);
        this.docArray = new WmiMathDocumentModel[this.fileNameArray.length];
        for (int i = 0; i < this.fileNameArray.length; i++) {
            this.docArray[i] = readFile(str, this.fileNameArray[i]);
        }
    }

    public void trainCharacterRecognizer() {
        for (int i = 0; i < this.fileNameArray.length; i++) {
            System.out.println("Training from file: " + this.fileNameArray[i] + " (" + (i + 1) + "/" + this.fileNameArray.length + ")");
            trainDocument(this.docArray[i]);
        }
        this.databaseLoaded = true;
    }

    public void testCharacterRecognizer() {
        testCharacterRecognizer(5);
    }

    public void testCharacterRecognizer(int i) {
        this.sparsity = i;
        this.startIndex = 0;
        if (this.testInputDatabase == null) {
            trainCharacterRecognizer();
            if (this.testOutputDatabase != null) {
                PenRecognitionDataStore.writeToSerializedFile(this.testOutputDatabase, (PenConfusionMatrix) getDatabase());
                System.out.println("Writing training database to file: " + this.testOutputDatabase);
            }
        } else if (!this.databaseLoaded) {
            System.out.println("Reading training database from file: " + this.testInputDatabase);
            try {
                this.trainer.setRecognitionData(PenTrainingOpenConfusionMatrix.loadFromNativeFile(new File(this.testInputDatabase)));
                System.out.println("   ... done.");
                this.databaseLoaded = true;
            } catch (PenRecognizerInitializationException e) {
                e.printStackTrace();
                throw new RuntimeException("recognizer initialization exception");
            }
        }
        for (int i2 = 0; i2 < this.docArray.length; i2++) {
            System.out.println("Testing file: " + this.fileNameArray[i2] + " (" + (i2 + 1) + "/" + this.fileNameArray.length + ")");
            testDocument(this.docArray[i2]);
        }
        this.trainer.outputTally();
        System.out.println(getStatisticsDescription());
        this.sparsity = 0;
        this.startIndex = 0;
    }

    public void setTestInputFilename(String str) {
        this.testInputDatabase = str;
    }

    public void setTestOutputFilename(String str) {
        this.testOutputDatabase = str;
    }

    public void resetTest() {
        this.trainer.resetRecognitionCount();
    }

    public PenRecognitionData getDatabase() {
        return this.trainer.getRecognitionData();
    }

    public String getStatisticsDescription() {
        int totalRecognitionCount = this.trainer.getTotalRecognitionCount();
        int correctRecognitionCount = this.trainer.getCorrectRecognitionCount();
        int correctAsBestRecognitionCount = this.trainer.getCorrectAsBestRecognitionCount();
        return "Total: " + totalRecognitionCount + " Correct: " + correctRecognitionCount + " (" + percent(correctRecognitionCount, totalRecognitionCount) + "%) Best: " + correctAsBestRecognitionCount + " (" + percent(correctAsBestRecognitionCount, totalRecognitionCount) + "%)";
    }

    public int getTotalRecognitionCount() {
        return this.trainer.getTotalRecognitionCount();
    }

    public int getCorrectRecognitionCount() {
        return this.trainer.getCorrectRecognitionCount();
    }

    public int getCorrectAsBestRecognitionCount() {
        return this.trainer.getCorrectAsBestRecognitionCount();
    }

    private int percent(int i, int i2) {
        return (int) ((i / i2) * 100.0f);
    }

    private WmiMathDocumentModel readFile(String str, String str2) {
        System.out.print("Reading file: " + str2);
        PenDocumentModel penDocumentModel = new PenDocumentModel();
        try {
            try {
                InputStreamReader inputStreamReader = new InputStreamReader(ResourceLoader.getResourceAsStream(str + str2 + ".mpn"));
                WmiModelLock.writeLock(penDocumentModel, true);
                if (PenFileOpen.readFromReader(inputStreamReader, penDocumentModel, new PenXMLImportParser())) {
                    System.out.print(" (" + penDocumentModel.getChildCount() + " samples)");
                }
                WmiModelLock.writeUnlock(penDocumentModel);
            } catch (WmiNoReadAccessException e) {
                WmiErrorLog.log(e);
                penDocumentModel = null;
                WmiModelLock.writeUnlock(null);
            } catch (WmiNoWriteAccessException e2) {
                WmiErrorLog.log(e2);
                penDocumentModel = null;
                WmiModelLock.writeUnlock(null);
            }
            System.out.println();
            return penDocumentModel;
        } catch (Throwable th) {
            WmiModelLock.writeUnlock(penDocumentModel);
            throw th;
        }
    }

    private void trainDocument(WmiMathDocumentModel wmiMathDocumentModel) {
        System.out.println("    training ...");
        try {
            try {
                WmiModelLock.writeLock(wmiMathDocumentModel, true);
                long currentTimeMillis = System.currentTimeMillis();
                int childCount = wmiMathDocumentModel.getChildCount();
                for (int i = 0; i < childCount; i++) {
                    if (this.sparsity == 0 || i % this.sparsity != this.startIndex) {
                        WmiModel child = wmiMathDocumentModel.getChild(i);
                        String str = (String) child.getAttributesForRead().getAttribute(PenAttributeConstants.TRAINING_DATA);
                        if (this.excludedNames == null || !this.excludedNames.contains(str)) {
                            PenTrainingCommand.trainModel(child, this.trainer);
                        } else {
                            System.out.print("*** IGNORING: " + str + " (");
                            for (int i2 = 0; i2 < str.length(); i2++) {
                                System.out.print("0x");
                                System.out.print(Integer.toHexString(str.charAt(i2)));
                                if (i2 < str.length() - 1) {
                                    System.out.print(WmiMenu.LIST_DELIMITER);
                                }
                            }
                            System.out.println(")");
                        }
                    }
                    long currentTimeMillis2 = System.currentTimeMillis();
                    if (currentTimeMillis2 - currentTimeMillis >= 15000) {
                        System.out.println("     - " + percent(i, childCount) + "% (" + i + "/" + childCount + ")");
                        currentTimeMillis = currentTimeMillis2;
                    }
                }
            } catch (WmiNoReadAccessException e) {
                WmiErrorLog.log(e);
                WmiModelLock.writeUnlock(wmiMathDocumentModel);
            }
            System.out.println("    ... done");
        } finally {
            WmiModelLock.writeUnlock(wmiMathDocumentModel);
        }
    }

    private void testDocument(WmiMathDocumentModel wmiMathDocumentModel) {
        System.out.println("    testing ...");
        try {
            try {
                WmiModelLock.writeLock(wmiMathDocumentModel, true);
                long currentTimeMillis = System.currentTimeMillis();
                int childCount = wmiMathDocumentModel.getChildCount();
                int i = childCount / this.sparsity;
                int i2 = 0;
                int i3 = this.startIndex;
                while (i3 < childCount) {
                    PenCanvasModel penCanvasModel = (PenCanvasModel) wmiMathDocumentModel.getChild(i3);
                    PenStrokeCollectionModel penStrokeCollectionModel = (PenStrokeCollectionModel) penCanvasModel.getCompositeLayer(2).getChild(0);
                    if (penStrokeCollectionModel.getChildCount() > 0) {
                        try {
                            this.trainer.sendRecognitionRequest(penStrokeCollectionModel);
                            i2++;
                        } catch (Exception e) {
                            System.err.println("Exception caught during recogition: " + e + " at index " + i3);
                            WmiAttributeSet attributesForRead = penCanvasModel.getAttributesForRead();
                            System.err.println("Training data: " + attributesForRead.getAttribute(PenAttributeConstants.TRAINING_DATA));
                            System.err.println("Training data details: " + attributesForRead.getAttribute(PenAttributeConstants.TRAINING_DATA_DETAILS));
                            e.printStackTrace();
                        }
                    }
                    long currentTimeMillis2 = System.currentTimeMillis();
                    if (currentTimeMillis2 - currentTimeMillis >= 15000) {
                        System.out.println("     - " + percent(i3, childCount) + "% (" + i2 + "/" + i + ")");
                        currentTimeMillis = currentTimeMillis2;
                    }
                    i3 += this.sparsity;
                }
            } catch (WmiNoReadAccessException e2) {
                WmiErrorLog.log(e2);
                WmiModelLock.writeUnlock(wmiMathDocumentModel);
            }
            System.out.println("    ... done");
        } finally {
            WmiModelLock.writeUnlock(wmiMathDocumentModel);
        }
    }

    private static void printUsage() {
        System.out.println("Usage: PenCharacterBatchTrainer [-o <output-file>] [-p <input-path>]");
        System.out.println("       PenCharacterBatchTrainer -test [-i <input-file>] [-o <output-file>] [-s <sparsity>]");
    }

    public static void train(String[] strArr) {
        boolean z = false;
        String str = null;
        String str2 = null;
        String str3 = DEFAULT_TRAINING_PATH;
        int i = 5;
        if (strArr != null && strArr.length > 0) {
            int i2 = 0;
            if (strArr[0].equals("-test")) {
                System.out.println("test");
                z = true;
                i2 = 0 + 1;
            }
            int i3 = i2;
            while (i3 < strArr.length) {
                try {
                    if (strArr[i3].equals("-i") && i3 < strArr.length - 1) {
                        i3++;
                        str = strArr[i3];
                    } else if (strArr[i3].equals("-o") && i3 < strArr.length - 1) {
                        i3++;
                        str2 = strArr[i3];
                    } else if (strArr[i3].equals("-s") && i3 < strArr.length - 1) {
                        i3++;
                        i = Integer.parseInt(strArr[i3]);
                    } else if (strArr[i3].equals("-p") && i3 < strArr.length - 1) {
                        i3++;
                        str3 = strArr[i3];
                    }
                    i3++;
                } catch (ArrayIndexOutOfBoundsException e) {
                    printUsage();
                    System.exit(1);
                }
            }
        }
        PenCharacterBatchTrainer penCharacterBatchTrainer = new PenCharacterBatchTrainer();
        if (z) {
            if (str != null) {
                penCharacterBatchTrainer.setTestInputFilename(str);
            } else {
                penCharacterBatchTrainer.setTrainingInputFilePath(str3);
            }
            penCharacterBatchTrainer.setTestOutputFilename(str2);
            penCharacterBatchTrainer.testCharacterRecognizer(i);
        } else {
            penCharacterBatchTrainer.setTrainingInputFilePath(str3);
            penCharacterBatchTrainer.trainCharacterRecognizer();
            PenRecognitionData database = penCharacterBatchTrainer.getDatabase();
            if (str2 == null) {
                str2 = "RecognitionData." + database.getSerializedFileExtension();
            }
            PenRecognitionDataStore.writeToSerializedFile(str2, database);
            System.out.println("Database written to: " + str2);
        }
        System.exit(0);
    }
}
