package org.jpmml.sparkml;

import com.google.common.collect.Iterables;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.xml.bind.JAXBException;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.param.shared.HasProbabilityCol;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.VerificationField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.model.MetroJAXBUtil;

/**
 * Converts a fitted Apache Spark ML {@link PipelineModel} into a PMML document. The builder also
 * takes the {@link StructType} schema of the pipeline's input data.
 *
 * <p>Converter behaviour can be tuned with the {@code putOption}/{@code putOptions} methods. An
 * optional {@code ModelVerification} element can be requested with {@link #verify(Dataset)}.
 */
public class PMMLBuilder {
    private StructType schema = null;
    private PipelineModel pipelineModel = null;
    private Map<RegexKey, Map<String, Object>> options = new LinkedHashMap<>();
    private Verification verification = null;

    /**
     * Holds the data and tolerances used to embed a {@code ModelVerification} element in the
     * built PMML.
     *
     * <p>Instances are created only by {@link PMMLBuilder#verify(Dataset, double, double)}.
     */
    public static class Verification {
        private Dataset<Row> dataset = null;
        private Dataset<Row> transformedDataset = null;

        // NOTE(review): these are public even though accessors exist. They stay public so that
        // existing callers which read the fields directly keep compiling.
        public Double precision = null;
        public Double zeroThreshold = null;

        private Verification(Dataset<Row> dataset, Dataset<Row> transformedDataset) {
            setDataset(dataset);
            setTransformedDataset(transformedDataset);
        }

        /** Returns the input dataset; its columns supply the expected input field values. */
        public Dataset<Row> getDataset() {
            return this.dataset;
        }

        private Verification setDataset(Dataset<Row> dataset) {
            this.dataset = dataset;
            return this;
        }

        /** Returns the input dataset after the pipeline's {@code transform}; supplies expected outputs. */
        public Dataset<Row> getTransformedDataset() {
            return this.transformedDataset;
        }

        private Verification setTransformedDataset(Dataset<Row> transformedDataset) {
            this.transformedDataset = transformedDataset;
            return this;
        }

        /** Returns the precision applied to every output verification field. */
        public Double getPrecision() {
            return this.precision;
        }

        /**
         * Sets the precision applied to every output verification field.
         *
         * @return this instance, for chaining
         */
        public Verification setPrecision(Double precision) {
            this.precision = precision;
            return this;
        }

        /** Returns the zero threshold applied to every output verification field. */
        public Double getZeroThreshold() {
            return this.zeroThreshold;
        }

        /**
         * Sets the zero threshold applied to every output verification field.
         *
         * @return this instance, for chaining
         */
        public Verification setZeroThreshold(Double zeroThreshold) {
            this.zeroThreshold = zeroThreshold;
            return this;
        }
    }

    /**
     * Creates a builder for the given input schema and fitted pipeline.
     *
     * @throws IllegalArgumentException if either argument is {@code null}
     */
    public PMMLBuilder(StructType structType, PipelineModel pipelineModel) {
        setSchema(structType);
        setPipelineModel(pipelineModel);
    }

    /**
     * Converts the pipeline into a PMML document.
     *
     * <p>Stages are visited in pipeline order. Nested pipelines, and the best models of
     * {@link CrossValidatorModel} and {@link TrainValidationSplitModel}, are flattened in place.
     *
     * <ul>
     *   <li>Feature stages register features with the encoder.</li>
     *   <li>Model stages register models.</li>
     *   <li>A single model is used as-is. Several models are combined into a model chain whose
     *       mining schema lists the predicted and target fields of every member.</li>
     *   <li>Derived fields created after the last model stage are removed from the encoder and
     *       re-exposed as {@code TRANSFORMED_VALUE} output fields of the final model.</li>
     * </ul>
     *
     * <p>A {@code ModelVerification} element is attached to the final model only when both
     * conditions hold:
     *
     * <ul>
     *   <li>{@link #verify} was called.</li>
     *   <li>At least one prediction or probability column was collected.</li>
     * </ul>
     *
     * @return the PMML document
     * @throws IllegalArgumentException if a stage's converter is neither a feature converter nor a
     *     model converter, or if the pipeline contains no model stages
     */
    public PMML build() {
        StructType schema = getSchema();
        PipelineModel pipelineModel = getPipelineModel();
        Map<RegexKey, Map<String, Object>> options = getOptions();
        Verification verification = getVerification();

        ConverterFactory converterFactory = new ConverterFactory(options);
        SparkMLEncoder encoder = new SparkMLEncoder(schema, converterFactory);

        // The same map instance is re-read after every model stage and again after the loop.
        // This logic therefore relies on it reflecting fields registered later (NOTE(review):
        // confirm that getDerivedFields() returns a live view).
        Map<FieldName, DerivedField> derivedFields = encoder.getDerivedFields();

        List<Model> models = new ArrayList<>();
        List<String> predictionColumns = new ArrayList<>();
        List<String> probabilityColumns = new ArrayList<>();

        // Snapshot of the derived field names taken right after the most recent model stage.
        List<FieldName> preModelNames = Collections.emptyList();

        for (Transformer transformer : getTransformers(pipelineModel)) {
            TransformerConverter<?> converter = converterFactory.newConverter(transformer);

            if (converter instanceof FeatureConverter) {
                ((FeatureConverter<?>) converter).registerFeatures(encoder);
            } else if (converter instanceof ModelConverter) {
                Model model = ((ModelConverter<?>) converter).registerModel(encoder);
                models.add(model);

                if (transformer instanceof HasPredictionCol) {
                    // The prediction column of a GLM converted as a classification model is not
                    // verified. NOTE(review): presumably because Spark's GLM prediction column
                    // does not hold a class label in that case; confirm.
                    boolean classificationGlm = (transformer instanceof GeneralizedLinearRegressionModel)
                            && MiningFunction.CLASSIFICATION.equals(model.getMiningFunction());
                    if (!classificationGlm) {
                        predictionColumns.add(((HasPredictionCol) transformer).getPredictionCol());
                    }
                }

                if (transformer instanceof HasProbabilityCol) {
                    probabilityColumns.add(((HasProbabilityCol) transformer).getProbabilityCol());
                }

                preModelNames = new ArrayList<>(derivedFields.keySet());
            } else {
                throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + converter);
            }
        }

        // Derived fields created by feature stages that follow the last model stage.
        List<FieldName> postModelNames = new ArrayList<>(derivedFields.keySet());
        postModelNames.removeAll(preModelNames);

        Model model;
        if (models.isEmpty()) {
            throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
        } else if (models.size() == 1) {
            model = Iterables.getOnlyElement(models);
        } else {
            model = createModelChain(models);
        }

        for (FieldName name : postModelNames) {
            DerivedField derivedField = derivedFields.get(name);
            encoder.removeDerivedField(name);

            OutputField outputField = new OutputField(derivedField.getName(), derivedField.getDataType())
                    .setOpType(derivedField.getOpType())
                    .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                    .setExpression(derivedField.getExpression());
            ModelUtil.ensureOutput(model).addOutputFields(outputField);
        }

        PMML pmml = encoder.encodePMML(model);

        if (verification != null && (!predictionColumns.isEmpty() || !probabilityColumns.isEmpty())) {
            // Mutating the model after encodePMML is intentional. The PMML document holds a
            // reference to this model object (NOTE(review): relies on encodePMML not copying it).
            addModelVerification(model, encoder, verification, predictionColumns, probabilityColumns);
        }

        return pmml;
    }

    /**
     * Chains several models into one mining model.
     *
     * <p>The chain's mining schema contains the {@code PREDICTED} and {@code TARGET} mining fields
     * of all member models, in member order.
     */
    private static Model createModelChain(List<Model> models) {
        List<MiningField> targetMiningFields = new ArrayList<>();

        for (Model member : models) {
            for (MiningField miningField : member.getMiningSchema().getMiningFields()) {
                switch (miningField.getUsageType()) {
                    case PREDICTED:
                    case TARGET:
                        targetMiningFields.add(miningField);
                        break;
                    default:
                        break;
                }
            }
        }

        return MiningModelUtil.createModelChain(models, new Schema((Label) null, Collections.emptyList()))
                .setMiningSchema(new MiningSchema(targetMiningFields));
    }

    /**
     * Builds and sets the {@code ModelVerification} element of {@code model}. It has three kinds
     * of entries:
     *
     * <ul>
     *   <li>Every {@code ACTIVE} mining field, with expected values from the verification input
     *       dataset.</li>
     *   <li>Every prediction column, with expected values from the transformed dataset.</li>
     *   <li>Every element of every probability vector column, with expected values from the
     *       transformed dataset.</li>
     * </ul>
     *
     * <p>Output entries carry the verification's precision and zero threshold.
     */
    private static void addModelVerification(Model model, SparkMLEncoder encoder, Verification verification,
            List<String> predictionColumns, List<String> probabilityColumns) {
        Dataset<Row> dataset = verification.getDataset();
        Dataset<Row> transformedDataset = verification.getTransformedDataset();
        Double precision = verification.getPrecision();
        Double zeroThreshold = verification.getZeroThreshold();

        Map<VerificationField, List<?>> data = new LinkedHashMap<>();

        for (MiningField miningField : model.getMiningSchema().getMiningFields()) {
            if (miningField.getUsageType() == MiningField.UsageType.ACTIVE) {
                String column = miningField.getName().getValue();
                data.put(ModelUtil.createVerificationField(FieldName.create(column)), getColumn(dataset, column));
            }
        }

        for (String column : predictionColumns) {
            VerificationField field = ModelUtil.createVerificationField(encoder.getOnlyFeature(column).getName())
                    .setPrecision(precision)
                    .setZeroThreshold(zeroThreshold);
            data.put(field, getColumn(transformedDataset, column));
        }

        for (String column : probabilityColumns) {
            List<Feature> features = encoder.getFeatures(column);
            for (int i = 0; i < features.size(); i++) {
                VerificationField field = ModelUtil.createVerificationField(features.get(i).getName())
                        .setPrecision(precision)
                        .setZeroThreshold(zeroThreshold);
                data.put(field, getVectorColumn(transformedDataset, column, i));
            }
        }

        model.setModelVerification(ModelUtil.createModelVerification(data));
    }

    /**
     * Builds the PMML document and marshals it into a byte array, using an initial buffer of
     * 1 MiB.
     *
     * @throws RuntimeException wrapping any {@link JAXBException} raised while marshalling
     */
    public byte[] buildByteArray() {
        return buildByteArray(1024 * 1024);
    }

    private byte[] buildByteArray(int size) {
        PMML pmml = build();

        ByteArrayOutputStream os = new ByteArrayOutputStream(size);
        try {
            MetroJAXBUtil.marshalPMML(pmml, os);
        } catch (JAXBException je) {
            throw new RuntimeException(je);
        }

        return os.toByteArray();
    }

    /**
     * Builds the PMML document and marshals it into {@code file}, overwriting any existing content.
     *
     * @return {@code file}
     * @throws IOException if the file cannot be opened, written or closed
     * @throws RuntimeException wrapping any {@link JAXBException} raised while marshalling
     */
    public File buildFile(File file) throws IOException {
        PMML pmml = build();

        try (FileOutputStream os = new FileOutputStream(file)) {
            MetroJAXBUtil.marshalPMML(pmml, os);
        } catch (JAXBException je) {
            throw new RuntimeException(je);
        }

        return file;
    }

    /** Sets a single converter option that applies to every pipeline stage. */
    public PMMLBuilder putOption(String str, Object obj) {
        return putOptions(Collections.singletonMap(str, obj));
    }

    /** Sets converter options that apply to every pipeline stage. */
    public PMMLBuilder putOptions(Map<String, ?> map) {
        return putOptions(Pattern.compile(".*"), map);
    }

    /** Sets a single converter option that applies only to the stage with {@code pipelineStage}'s uid. */
    public PMMLBuilder putOption(PipelineStage pipelineStage, String str, Object obj) {
        return putOptions(pipelineStage, Collections.singletonMap(str, obj));
    }

    /** Sets converter options that apply only to the stage with {@code pipelineStage}'s uid. */
    public PMMLBuilder putOptions(PipelineStage pipelineStage, Map<String, ?> map) {
        // The uid is matched literally, so regex metacharacters in it have no special meaning.
        return putOptions(Pattern.compile(pipelineStage.uid(), Pattern.LITERAL), map);
    }

    /**
     * Sets converter options for the stages selected by {@code pattern}.
     *
     * <p>Options registered under an equal {@link RegexKey} are merged, with later values winning.
     *
     * @return this builder
     */
    public PMMLBuilder putOptions(Pattern pattern, Map<String, ?> map) {
        getOptions()
                .computeIfAbsent(new RegexKey(pattern), key -> new LinkedHashMap<>())
                .putAll(map);
        return this;
    }

    /**
     * Requests model verification against {@code dataset}. Both the precision and the zero
     * threshold default to {@code 1e-14}.
     */
    public PMMLBuilder verify(Dataset<Row> dataset) {
        return verify(dataset, 1e-14, 1e-14);
    }

    /**
     * Requests model verification against {@code dataset}.
     *
     * <p>The dataset is passed through the current pipeline's {@code transform}. Expected outputs
     * are later read from the result.
     *
     * @return this builder
     */
    public PMMLBuilder verify(Dataset<Row> dataset, double precision, double zeroThreshold) {
        Dataset<Row> transformedDataset = getPipelineModel().transform(dataset);

        Verification verification = new Verification(dataset, transformedDataset)
                .setPrecision(precision)
                .setZeroThreshold(zeroThreshold);

        return setVerification(verification);
    }

    public StructType getSchema() {
        return this.schema;
    }

    /** @throws IllegalArgumentException if {@code structType} is {@code null} */
    public PMMLBuilder setSchema(StructType structType) {
        if (structType == null) {
            throw new IllegalArgumentException();
        }
        this.schema = structType;
        return this;
    }

    public PipelineModel getPipelineModel() {
        return this.pipelineModel;
    }

    /** @throws IllegalArgumentException if {@code pipelineModel} is {@code null} */
    public PMMLBuilder setPipelineModel(PipelineModel pipelineModel) {
        if (pipelineModel == null) {
            throw new IllegalArgumentException();
        }
        this.pipelineModel = pipelineModel;
        return this;
    }

    /** Returns the live (mutable) option map, keyed by stage-selecting regex. */
    public Map<RegexKey, Map<String, Object>> getOptions() {
        return this.options;
    }

    private PMMLBuilder setOptions(Map<RegexKey, Map<String, Object>> map) {
        if (map == null) {
            throw new IllegalArgumentException();
        }
        this.options = map;
        return this;
    }

    /** Returns the pending verification request, or {@code null} if {@link #verify} was not called. */
    public Verification getVerification() {
        return this.verification;
    }

    private PMMLBuilder setVerification(Verification verification) {
        this.verification = verification;
        return this;
    }

    /**
     * Flattens {@code pipelineModel} into its leaf transformers, in pipeline order.
     *
     * <p>Nested {@link PipelineModel}s are expanded into their stages. {@link CrossValidatorModel}
     * and {@link TrainValidationSplitModel} are replaced by their best model.
     */
    private static Iterable<Transformer> getTransformers(PipelineModel pipelineModel) {
        List<Transformer> result = new ArrayList<>();
        collectTransformers(pipelineModel, result);
        return result;
    }

    /** Depth-first helper for {@link #getTransformers}: appends the leaves of {@code transformer}. */
    private static void collectTransformers(Transformer transformer, List<Transformer> result) {
        if (transformer instanceof PipelineModel) {
            for (Transformer stage : ((PipelineModel) transformer).stages()) {
                collectTransformers(stage, result);
            }
        } else if (transformer instanceof CrossValidatorModel) {
            collectTransformers(((CrossValidatorModel) transformer).bestModel(), result);
        } else if (transformer instanceof TrainValidationSplitModel) {
            collectTransformers(((TrainValidationSplitModel) transformer).bestModel(), result);
        } else {
            result.add(transformer);
        }
    }

    /** Collects all values of column {@code str} of {@code dataset} to the driver, in row order. */
    private static List<?> getColumn(Dataset<Row> dataset, String str) {
        List<Row> rows = dataset.select(str, new String[0]).collectAsList();

        return rows.stream()
                .map(row -> row.apply(0))
                .collect(Collectors.toList());
    }

    /**
     * Collects element {@code i} of every vector in column {@code str} of {@code dataset}.
     *
     * <p>The column is assumed to hold Spark ML {@code Vector} values; a {@link ClassCastException}
     * is thrown otherwise.
     */
    private static List<?> getVectorColumn(Dataset<Row> dataset, String str, int i) {
        return getColumn(dataset, str).stream()
                .map(value -> ((org.apache.spark.ml.linalg.Vector) value).apply(i))
                .collect(Collectors.toList());
    }

    /**
     * Calls ConverterFactory's {@code checkVersion}, {@code checkApplicationClasspath} and
     * {@code checkNoShading}. It is invoked once, when the class is initialized.
     */
    private static void init() {
        ConverterFactory.checkVersion();
        ConverterFactory.checkApplicationClasspath();
        ConverterFactory.checkNoShading();
    }

    static {
        init();
    }
}
