package org.apache.mahout.classifier.sgd;

import java.io.BufferedReader;
import java.io.File;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.util.Locale;
import org.antlr.tool.ErrorManager;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.mahout.classifier.evaluation.Auc;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SequentialAccessSparseVector;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/RunLogistic.class */
/**
 * Command-line tool that applies a previously trained logistic regression model to a CSV file.
 *
 * <p>For every data line it can print the target, the model output and the log-likelihood
 * ({@code --scores}). After all lines are read it can print the AUC ({@code --auc}) and the
 * confusion and entropy matrices ({@code --confusion}). If none of these flags is given, AUC
 * and confusion output are enabled by default.
 *
 * <p>Parsed options are kept in static fields, so this class is not safe for concurrent use.
 */
public final class RunLogistic {
    /**
     * Line width passed to the cli2 help formatter. The decompiled code referenced an unrelated
     * ANTLR {@code ErrorManager} constant here. That was presumably an inlined literal which the
     * decompiler re-attributed; the value 130 matches upstream Mahout (NOTE(review): confirm).
     */
    private static final int HELP_LINE_WIDTH = 130;

    private static String inputFile;
    private static String modelFile;
    private static boolean showAuc;
    private static boolean showScores;
    private static boolean showConfusion;

    private RunLogistic() {
    }

    /**
     * Entry point: runs the tool and writes results to standard output.
     *
     * @param args command-line arguments; {@code --input} and {@code --model} are required
     * @throws Exception if the model or input cannot be read or processed
     */
    public static void main(String[] args) throws Exception {
        mainToOutput(args, new PrintWriter(System.out, true));
    }

    /**
     * Runs the tool and writes all results to {@code output}.
     *
     * <p>If argument parsing fails, or help was requested, nothing is written here (cli2 prints
     * its own help and errors) and the method returns without doing anything else.
     *
     * @param args command-line arguments
     * @param output destination for scores, AUC and matrices
     * @throws Exception if the model or input cannot be read or processed
     */
    static void mainToOutput(String[] args, PrintWriter output) throws Exception {
        if (!parseArgs(args)) {
            return;
        }
        if (!showAuc && !showConfusion && !showScores) {
            // No output selected explicitly: fall back to the AUC + confusion summary.
            showAuc = true;
            showConfusion = true;
        }

        Auc collector = new Auc();
        LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile));
        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        OnlineLogisticRegression lr = lmp.createRegression();

        BufferedReader in = TrainLogistic.open(inputFile);
        try {
            // The first line of the input is handed to the record factory as the header.
            csv.firstLine(in.readLine());
            if (showScores) {
                output.printf(Locale.ENGLISH, "\"%s\",\"%s\",\"%s\"\n",
                    "target", "model-output", "log-likelihood");
            }
            for (String line = in.readLine(); line != null; line = in.readLine()) {
                SequentialAccessSparseVector features =
                    new SequentialAccessSparseVector(lmp.getNumFeatures());
                // processLine fills `features` and returns the target value for this record.
                int target = csv.processLine(line, features);
                double score = lr.classifyScalar(features);
                if (showScores) {
                    output.printf(Locale.ENGLISH, "%d,%.3f,%.6f\n",
                        target, score, lr.logLikelihood(target, features));
                }
                collector.add(target, score);
            }
        } finally {
            // Previously the reader was never closed, leaking a file handle per run.
            in.close();
        }

        if (showAuc) {
            output.printf(Locale.ENGLISH, "AUC = %.2f\n", collector.auc());
        }
        if (showConfusion) {
            printMatrix(output, "confusion", collector.confusion());
            printMatrix(output, "entropy", collector.entropy());
        }
    }

    /**
     * Prints the top-left 2x2 entries of {@code m} as {@code label: [[a, b], [c, d]]}.
     *
     * <p>Elements are emitted in the order (0,0), (1,0), (0,1), (1,1), which preserves the
     * existing output layout.
     */
    private static void printMatrix(PrintWriter output, String label, Matrix m) {
        output.printf(Locale.ENGLISH, "%s: [[%.1f, %.1f], [%.1f, %.1f]]\n", label,
            m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
    }

    /**
     * Parses the command line into the static option fields.
     *
     * @param args raw command-line arguments
     * @return {@code false} if cli2 reported a parse error or help was requested
     *     ({@code parseAndHelp} returned {@code null}), {@code true} otherwise
     */
    private static boolean parseArgs(String[] args) {
        DefaultOptionBuilder builder = new DefaultOptionBuilder();

        DefaultOption help = builder.withLongName("help").withDescription("print this list").create();
        // Accepted for compatibility; not consulted anywhere in this tool.
        DefaultOption quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
        DefaultOption auc = builder.withLongName("auc").withDescription("print AUC").create();
        DefaultOption confusion = builder.withLongName("confusion")
            .withDescription("print confusion matrix").create();
        DefaultOption scores = builder.withLongName("scores").withDescription("print scores").create();

        ArgumentBuilder argBuilder = new ArgumentBuilder();
        DefaultOption input = builder.withLongName("input")
            .withRequired(true)
            .withArgument(argBuilder.withName("input").withMaximum(1).create())
            .withDescription("where to get training data")
            .create();
        DefaultOption model = builder.withLongName("model")
            .withRequired(true)
            .withArgument(argBuilder.withName("model").withMaximum(1).create())
            .withDescription("where to get a model")
            .create();

        // Option order here determines the order in the generated help text.
        Group normalArgs = new GroupBuilder()
            .withOption(help)
            .withOption(quiet)
            .withOption(auc)
            .withOption(scores)
            .withOption(confusion)
            .withOption(input)
            .withOption(model)
            .create();

        Parser parser = new Parser();
        parser.setHelpOption(help);
        parser.setHelpTrigger("--help");
        parser.setGroup(normalArgs);
        parser.setHelpFormatter(new HelpFormatter(" ", "", " ", HELP_LINE_WIDTH));

        CommandLine cmdLine = parser.parseAndHelp(args);
        if (cmdLine == null) {
            return false;
        }

        inputFile = getStringArgument(cmdLine, input);
        modelFile = getStringArgument(cmdLine, model);
        showAuc = getBooleanArgument(cmdLine, auc);
        showScores = getBooleanArgument(cmdLine, scores);
        showConfusion = getBooleanArgument(cmdLine, confusion);
        return true;
    }

    /** Returns whether {@code option} was present on the command line. */
    private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
        return cmdLine.hasOption(option);
    }

    /** Returns the single value supplied for {@code option}, or whatever cli2 yields if absent. */
    private static String getStringArgument(CommandLine cmdLine, Option option) {
        return (String) cmdLine.getValue(option);
    }
}
