package org.jpmml.sparkml.feature;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.feature.StandardScalerModel;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.OpType;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.ProductFeature;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;

/* loaded from: input_file:org/jpmml/sparkml/feature/StandardScalerModelConverter.class */
/**
 * Converts a Spark ML {@link StandardScalerModel} into PMML features.
 *
 * <p>For every input feature {@code x_i} the converter emits the equivalent of
 * {@code (x_i - mean_i) * (1 / std_i)}. The centering step is omitted when {@code withMean} is
 * disabled or the mean is zero. The scaling step is omitted when {@code withStd} is disabled or
 * the standard deviation is one.
 */
public class StandardScalerModelConverter extends FeatureConverter<StandardScalerModel> {

    public StandardScalerModelConverter(StandardScalerModel standardScalerModel) {
        super(standardScalerModel);
    }

    /**
     * Encodes one output feature per input feature of the scaler's input column.
     *
     * @param sparkMLEncoder encoder that supplies the input features and owns the derived fields
     * @return the standardized features, in the same order as the input features
     * @throws IllegalArgumentException if {@code withMean} (resp. {@code withStd}) is enabled and
     *     the mean (resp. std) vector length differs from the number of input features
     */
    @Override
    public List<Feature> encodeFeatures(SparkMLEncoder sparkMLEncoder) {
        StandardScalerModel standardScalerModel = (StandardScalerModel) getTransformer();

        List<Feature> features = sparkMLEncoder.getFeatures(standardScalerModel.getInputCol());

        boolean withMean = standardScalerModel.getWithMean();
        boolean withStd = standardScalerModel.getWithStd();

        Vector mean = standardScalerModel.mean();
        if (withMean && mean.size() != features.size()) {
            throw new IllegalArgumentException("Expected " + features.size() + " mean value(s), got " + mean.size());
        }

        Vector std = standardScalerModel.std();
        if (withStd && std.size() != features.size()) {
            throw new IllegalArgumentException("Expected " + features.size() + " std value(s), got " + std.size());
        }

        List<Feature> result = new ArrayList<>(features.size());

        for (int i = 0; i < features.size(); i++) {
            Feature feature = features.get(i);
            final FieldName name = formatName(standardScalerModel, i);

            // Non-null once a centering (and possibly scaling) expression has been built.
            Apply apply = null;

            if (withMean) {
                double meanValue = mean.apply(i);

                if (!ValueUtil.isZero(Double.valueOf(meanValue))) {
                    apply = PMMLUtil.createApply("-", new Expression[]{feature.toContinuousFeature().ref(), PMMLUtil.createConstant(Double.valueOf(meanValue))});
                }
            }

            if (withStd) {
                double stdValue = std.apply(i);

                if (!ValueUtil.isOne(Double.valueOf(stdValue))) {
                    Double factor = Double.valueOf(scaleFactor(stdValue));

                    if (apply != null) {
                        apply = PMMLUtil.createApply("*", new Expression[]{apply, PMMLUtil.createConstant(factor)});
                    } else {
                        // Scale-only case: no DerivedField is created here. One is created (or an
                        // existing one with the same name is reused) only when a continuous
                        // representation of the product is requested.
                        feature = new ProductFeature(sparkMLEncoder, feature, factor) {

                            @Override
                            public ContinuousFeature toContinuousFeature() {
                                PMMLEncoder encoder = ensureEncoder();

                                DerivedField derivedField = encoder.getDerivedField(name);
                                if (derivedField == null) {
                                    derivedField = encoder.createDerivedField(name, OpType.CONTINUOUS, DataType.DOUBLE, PMMLUtil.createApply("*", new Expression[]{getFeature().toContinuousFeature().ref(), PMMLUtil.createConstant(getFactor())}));
                                }

                                return new ContinuousFeature(encoder, derivedField);
                            }
                        };
                    }
                }
            }

            if (apply != null) {
                result.add(new ContinuousFeature(sparkMLEncoder, sparkMLEncoder.createDerivedField(name, OpType.CONTINUOUS, DataType.DOUBLE, apply)));
            } else {
                result.add(feature);
            }
        }

        return result;
    }

    /**
     * Returns the multiplier for a feature whose standard deviation is {@code stdValue}.
     *
     * <p>Spark's StandardScalerModel maps a zero standard deviation to a scale of 0 (the scaled
     * value becomes 0) rather than dividing by zero. Mirroring that keeps the PMML output
     * consistent with Spark and avoids emitting an infinite constant.
     */
    private static double scaleFactor(double stdValue) {
        return (stdValue != 0d) ? (1d / stdValue) : 0d;
    }
}
