package org.apache.mahout.classifier.sgd;

import com.google.common.base.Charsets;
import com.google.common.collect.Lists;
import com.google.common.io.Resources;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import org.antlr.tool.ErrorManager;
import org.antlr.tool.GrammarReport;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.RandomAccessSparseVector;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.class */
/**
 * Command-line driver that trains an {@link AdaptiveLogisticRegression} from a text input.
 *
 * <p>The first line of the input is handed to the {@link CsvRecordFactory} as a header; every
 * following line is converted to a feature vector and used as one training example. The input is
 * read {@code passes} times. Afterwards the parameters are written to the output file through
 * {@code AdaptiveLogisticModelParameters.saveTo} and a textual summary of the first model of the
 * best cross-fold learner is printed.
 *
 * <p>All configuration is kept in static fields that are (re)assigned on every invocation, so the
 * class is not suitable for concurrent use.
 */
public final class TrainAdaptiveLogistic {
    private static String inputFile;
    private static String outputFile;
    private static AdaptiveLogisticModelParameters lmp;
    private static int passes;
    private static boolean showperf;
    // Progress is reported every (skipperfnum + 1) training rows when showperf is set.
    private static int skipperfnum = 99;
    private static AdaptiveLogisticRegression model;

    private TrainAdaptiveLogistic() {
    }

    /**
     * Program entry point; runs the training and reports to standard output.
     *
     * @param strArr command-line arguments, see {@link #parseArgs(String[])}
     * @throws Exception if reading the input, training, or writing the output fails
     */
    public static void main(String[] strArr) throws Exception {
        mainToOutput(strArr, new PrintWriter(System.out, true));
    }

    /**
     * Parses the arguments, trains the model, saves the parameters and prints a summary.
     *
     * <p>Returns without doing anything when argument parsing yields no command line. If no
     * learner is available after training, a failure message is printed and nothing is written to
     * the output file.
     *
     * @param strArr command-line arguments
     * @param printWriter destination for progress lines and the final summary
     * @throws Exception if reading the input, training, or writing the output fails
     */
    static void mainToOutput(String[] strArr, PrintWriter printWriter) throws Exception {
        if (!parseArgs(strArr)) {
            return;
        }

        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        model = lmp.createAdaptiveLogisticRegression();

        CrossFoldLearner learner = train(csv, printWriter);

        // Pick up the final best learner; keep the last one seen if none is reported now.
        learner = latestLearner(learner);
        if (learner == null) {
            printWriter.printf(Locale.ENGLISH, "%s\n", "AdaptiveLogisticRegression has failed to train a model.");
            return;
        }

        // Only the save needs the stream; close it before the summary is printed.
        try (OutputStream modelOutput = new FileOutputStream(outputFile)) {
            lmp.saveTo(modelOutput);
        }

        printSummary(printWriter, csv, learner.getModels().get(0));
    }

    /**
     * Feeds the input file to the model {@code passes} times, optionally printing progress.
     *
     * @param csv record factory used to turn input lines into vectors
     * @param output destination for progress lines
     * @return the last learner observed while reporting progress, or {@code null} if none was seen
     * @throws IOException if the input cannot be opened or read
     */
    private static CrossFoldLearner train(CsvRecordFactory csv, PrintWriter output) throws IOException {
        CrossFoldLearner learner = null;
        int trained = 0;
        for (int pass = 0; pass < passes; pass++) {
            // try-with-resources so the reader is released even when a line fails to process.
            try (BufferedReader in = open(inputFile)) {
                csv.firstLine(in.readLine());
                String line;
                while ((line = in.readLine()) != null) {
                    RandomAccessSparseVector features = new RandomAccessSparseVector(lmp.getNumFeatures());
                    int target = csv.processLine(line, features);
                    model.train(target, features);
                    trained++;

                    if (showperf && trained % (skipperfnum + 1) == 0) {
                        learner = latestLearner(learner);
                        if (learner != null) {
                            // Locale.ENGLISH keeps the decimal separator consistent with all other output.
                            output.printf(Locale.ENGLISH, "%d\t%.3f\t%.2f\n", trained, learner.logLikelihood(), learner.percentCorrect() * 100.0d);
                        } else {
                            output.printf(Locale.ENGLISH, "%10d %2d %s\n", trained, target, "AdaptiveLogisticRegression has not found a good model ......");
                        }
                    }
                }
            }
        }
        return learner;
    }

    /**
     * Returns the learner of the model's current best state.
     *
     * @param previous learner to fall back to when the model reports no best state
     * @return the learner attached to the best state, or {@code previous} if there is no best state
     */
    private static CrossFoldLearner latestLearner(CrossFoldLearner previous) {
        State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = model.getBest();
        if (best == null) {
            return previous;
        }
        return best.getPayload().getLearner();
    }

    /**
     * Prints the feature count, a one-line formula for row 0 and, for each row of the coefficient
     * matrix, the non-zero predictor weights followed by the raw coefficients.
     */
    private static void printSummary(PrintWriter output, CsvRecordFactory csv, OnlineLogisticRegression lr) {
        output.printf(Locale.ENGLISH, "%d\n", lmp.getNumFeatures());
        output.printf(Locale.ENGLISH, "%s ~ ", lmp.getTargetVariable());

        String separator = "";
        for (String predictor : csv.getTraceDictionary().keySet()) {
            double weight = predictorWeight(lr, 0, csv, predictor);
            if (weight != 0.0d) {
                output.printf(Locale.ENGLISH, "%s%.3f*%s", separator, weight, predictor);
                separator = " + ";
            }
        }
        output.printf("\n");

        for (int row = 0; row < lr.getBeta().numRows(); row++) {
            for (String predictor : csv.getTraceDictionary().keySet()) {
                double weight = predictorWeight(lr, row, csv, predictor);
                if (weight != 0.0d) {
                    output.printf(Locale.ENGLISH, "%20s %.5f\n", predictor, weight);
                }
            }
            for (int column = 0; column < lr.getBeta().numCols(); column++) {
                output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
            }
            output.println();
        }
    }

    /**
     * Sums the coefficients in the given row over every column the trace dictionary lists for
     * the predictor.
     *
     * @param onlineLogisticRegression model whose coefficient matrix is read
     * @param i row of the coefficient matrix
     * @param recordFactory source of the trace dictionary
     * @param str predictor name to look up in the trace dictionary
     * @return the summed coefficient, {@code 0.0} if the predictor maps to no columns
     */
    private static double predictorWeight(OnlineLogisticRegression onlineLogisticRegression, int i, RecordFactory recordFactory, String str) {
        double total = 0.0d;
        for (Integer column : recordFactory.getTraceDictionary().get(str)) {
            total += onlineLogisticRegression.getBeta().get(i, column.intValue());
        }
        return total;
    }

    /**
     * Declares the command-line options, parses the arguments and stores the result in the
     * static configuration fields.
     *
     * @param strArr command-line arguments
     * @return {@code false} if the parser returned no command line, {@code true} otherwise
     */
    private static boolean parseArgs(String[] strArr) {
        // Both builders are reused; each option/argument is finished with create() before the next starts.
        DefaultOptionBuilder builder = new DefaultOptionBuilder();
        ArgumentBuilder argumentBuilder = new ArgumentBuilder();

        Option helpOpt = builder.withLongName("help").withDescription("print this list").create();
        Option quietOpt = builder.withLongName("quiet").withDescription("be extra quiet").create();
        Option showperfOpt = builder.withLongName("showperf").withDescription("output performance measures during training").create();
        Option inputOpt = builder.withLongName("input").withRequired(true).withArgument(argumentBuilder.withName("input").withMaximum(1).create()).withDescription("where to get training data").create();
        Option outputOpt = builder.withLongName("output").withRequired(true).withArgument(argumentBuilder.withName("output").withMaximum(1).create()).withDescription("where to write the model content").create();
        // NOTE(review): the default thread count is taken from an ANTLR constant, which looks like a
        // decompiler substitution for a string literal - confirm the intended value.
        Option threadsOpt = builder.withLongName("threads").withArgument(argumentBuilder.withName("threads").withDefault(GrammarReport.Version).create()).withDescription("the number of threads AdaptiveLogisticRegression uses").create();
        Option predictorsOpt = builder.withLongName("predictors").withRequired(true).withArgument(argumentBuilder.withName("predictors").create()).withDescription("a list of predictor variables").create();
        Option typesOpt = builder.withLongName("types").withRequired(true).withArgument(argumentBuilder.withName("types").create()).withDescription("a list of predictor variable types (numeric, word, or text)").create();
        Option targetOpt = builder.withLongName("target").withDescription("the name of the target variable").withRequired(true).withArgument(argumentBuilder.withName("target").withMaximum(1).create()).create();
        Option categoriesOpt = builder.withLongName("categories").withDescription("the number of target categories to be considered").withRequired(true).withArgument(argumentBuilder.withName("categories").withMaximum(1).create()).create();
        Option featuresOpt = builder.withLongName("features").withDescription("the number of internal hashed features to use").withArgument(argumentBuilder.withName("numFeatures").withDefault("1000").withMaximum(1).create()).create();
        Option passesOpt = builder.withLongName("passes").withDescription("the number of times to pass over the input data").withArgument(argumentBuilder.withName("passes").withDefault("2").withMaximum(1).create()).create();
        Option intervalOpt = builder.withLongName("interval").withArgument(argumentBuilder.withName("interval").withDefault("500").create()).withDescription("the interval property of AdaptiveLogisticRegression").create();
        Option windowOpt = builder.withLongName("window").withArgument(argumentBuilder.withName("window").withDefault("800").create()).withDescription("the average propery of AdaptiveLogisticRegression").create();
        Option skipperfnumOpt = builder.withLongName("skipperfnum").withArgument(argumentBuilder.withName("skipperfnum").withDefault("99").create()).withDescription("show performance measures every (skipperfnum + 1) rows").create();
        Option priorOpt = builder.withLongName("prior").withArgument(argumentBuilder.withName("prior").withDefault("L1").create()).withDescription("the prior algorithm to use: L1, L2, ebp, tp, up").create();
        Option priorOptionOpt = builder.withLongName("prioroption").withArgument(argumentBuilder.withName("prioroption").create()).withDescription("constructor parameter for ElasticBandPrior and TPrior").create();
        Option aucOpt = builder.withLongName("auc").withArgument(argumentBuilder.withName("auc").withDefault("global").create()).withDescription("the auc to use: global or grouped").create();

        Group options = new GroupBuilder()
                .withOption(helpOpt)
                .withOption(quietOpt)
                .withOption(inputOpt)
                .withOption(outputOpt)
                .withOption(targetOpt)
                .withOption(categoriesOpt)
                .withOption(predictorsOpt)
                .withOption(typesOpt)
                .withOption(passesOpt)
                .withOption(intervalOpt)
                .withOption(windowOpt)
                .withOption(threadsOpt)
                .withOption(priorOpt)
                .withOption(featuresOpt)
                .withOption(showperfOpt)
                .withOption(skipperfnumOpt)
                .withOption(priorOptionOpt)
                .withOption(aucOpt)
                .create();

        Parser parser = new Parser();
        parser.setHelpOption(helpOpt);
        parser.setHelpTrigger("--help");
        parser.setGroup(options);
        // NOTE(review): the last HelpFormatter argument is an ANTLR message id, presumably a
        // decompiler substitution for a numeric literal - confirm the intended value.
        parser.setHelpFormatter(new HelpFormatter(" ", "", " ", ErrorManager.MSG_RULE_HAS_NO_ARGS));

        CommandLine commandLine = parser.parseAndHelp(strArr);
        if (commandLine == null) {
            return false;
        }

        inputFile = getStringArgument(commandLine, inputOpt);
        outputFile = getStringArgument(commandLine, outputOpt);

        List<String> typeList = stringValues(commandLine, typesOpt);
        List<String> predictorList = stringValues(commandLine, predictorsOpt);

        lmp = new AdaptiveLogisticModelParameters();
        lmp.setTargetVariable(getStringArgument(commandLine, targetOpt));
        lmp.setMaxTargetCategories(getIntegerArgument(commandLine, categoriesOpt));
        lmp.setNumFeatures(getIntegerArgument(commandLine, featuresOpt));
        lmp.setInterval(getIntegerArgument(commandLine, intervalOpt));
        lmp.setAverageWindow(getIntegerArgument(commandLine, windowOpt));
        lmp.setThreads(getIntegerArgument(commandLine, threadsOpt));
        lmp.setAuc(getStringArgument(commandLine, aucOpt));
        lmp.setPrior(getStringArgument(commandLine, priorOpt));
        // prioroption has no default, so it is only applied when given explicitly.
        if (commandLine.getValue(priorOptionOpt) != null) {
            lmp.setPriorOption(getDoubleArgument(commandLine, priorOptionOpt));
        }
        lmp.setTypeMap(predictorList, typeList);

        showperf = getBooleanArgument(commandLine, showperfOpt);
        skipperfnum = getIntegerArgument(commandLine, skipperfnumOpt);
        passes = getIntegerArgument(commandLine, passesOpt);

        lmp.checkParameters();
        return true;
    }

    /** Collects every value given for the option, converted with {@code toString()}. */
    private static List<String> stringValues(CommandLine commandLine, Option option) {
        List<String> values = Lists.newArrayList();
        for (Object value : commandLine.getValues(option)) {
            values.add(value.toString());
        }
        return values;
    }

    /** Returns the option's value cast to {@code String}. */
    private static String getStringArgument(CommandLine commandLine, Option option) {
        return (String) commandLine.getValue(option);
    }

    /** Returns whether the option is present on the command line. */
    private static boolean getBooleanArgument(CommandLine commandLine, Option option) {
        return commandLine.hasOption(option);
    }

    /** Returns the option's value parsed with {@link Integer#parseInt(String)}. */
    private static int getIntegerArgument(CommandLine commandLine, Option option) {
        return Integer.parseInt((String) commandLine.getValue(option));
    }

    /** Returns the option's value parsed with {@link Double#parseDouble(String)}. */
    private static double getDoubleArgument(CommandLine commandLine, Option option) {
        return Double.parseDouble((String) commandLine.getValue(option));
    }

    /**
     * Returns the regression created by the most recent run.
     *
     * @return the model, or {@code null} if no run has created one yet
     */
    public static AdaptiveLogisticRegression getModel() {
        return model;
    }

    /**
     * Returns the parameters built by the most recent argument parsing.
     *
     * @return the parameters, or {@code null} if no arguments have been parsed yet
     */
    public static LogisticModelParameters getParameters() {
        return lmp;
    }

    /**
     * Opens the named input as UTF-8 text, trying the classpath first and the file system second.
     *
     * @param str resource name or file path
     * @return a buffered reader over the input; the caller is responsible for closing it
     * @throws IOException if neither the resource nor the file can be opened
     */
    public static BufferedReader open(String str) throws IOException {
        // Declared as InputStream: the resource branch yields a plain InputStream, not a FileInputStream.
        InputStream in;
        try {
            in = Resources.getResource(str).openStream();
        } catch (IllegalArgumentException ignored) {
            // Resource lookup rejected the name; treat it as a file path instead.
            in = new FileInputStream(new File(str));
        }
        return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8));
    }
}
