package org.apache.spark.ml.feature;

import java.io.IOException;
import org.apache.hadoop.fs.Path;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineModel$;
import org.apache.spark.ml.feature.RFormulaBase;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.util.DefaultParamsReader;
import org.apache.spark.ml.util.DefaultParamsReader$;
import org.apache.spark.ml.util.DefaultParamsWriter$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.BooleanType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.NumericType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Predef$;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;

/* compiled from: RFormula.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005Mh\u0001B\u0001\u0003\u00015\u0011QB\u0015$pe6,H.Y'pI\u0016d'BA\u0002\u0005\u0003\u001d1W-\u0019;ve\u0016T!!\u0002\u0004\u0002\u00055d'BA\u0004\t\u0003\u0015\u0019\b/\u0019:l\u0015\tI!\"\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002\u0017\u0005\u0019qN]4\u0004\u0001M!\u0001A\u0004\u000b\u0018!\ry\u0001CE\u0007\u0002\t%\u0011\u0011\u0003\u0002\u0002\u0006\u001b>$W\r\u001c\t\u0003'\u0001i\u0011A\u0001\t\u0003'UI!A\u0006\u0002\u0003\u0019I3uN]7vY\u0006\u0014\u0015m]3\u0011\u0005aYR\"A\r\u000b\u0005i!\u0011\u0001B;uS2L!\u0001H\r\u0003\u00155cuK]5uC\ndW\r\u0003\u0005\u001f\u0001\t\u0015\r\u0011\"\u0011 \u0003\r)\u0018\u000eZ\u000b\u0002AA\u0011\u0011e\n\b\u0003E\u0015j\u0011a\t\u0006\u0002I\u0005)1oY1mC&\u0011aeI\u0001\u0007!J,G-\u001a4\n\u0005!J#AB*ue&twM\u0003\u0002'G!\u001aQdK\u0019\u0011\u00051zS\"A\u0017\u000b\u000592\u0011AC1o]>$\u0018\r^5p]&\u0011\u0001'\f\u0002\u0006'&t7-Z\u0011\u0002e\u0005)\u0011GL\u001b/a!AA\u0007\u0001B\u0001B\u0003%\u0001%\u0001\u0003vS\u0012\u0004\u0003fA\u001a,c!Iq\u0007\u0001BC\u0002\u0013\u0005A\u0001O\u0001\u0010e\u0016\u001cx\u000e\u001c<fI\u001a{'/\\;mCV\t\u0011\b\u0005\u0002\u0014u%\u00111H\u0001\u0002\u0011%\u0016\u001cx\u000e\u001c<fIJ3uN]7vY\u0006D\u0001\"\u0010\u0001\u0003\u0002\u0003\u0006I!O\u0001\u0011e\u0016\u001cx\u000e\u001c<fI\u001a{'/\\;mC\u0002B\u0011b\u0010\u0001\u0003\u0006\u0004%\t\u0001\u0002!\u0002\u001bAL\u0007/\u001a7j]\u0016lu\u000eZ3m+\u0005\t\u0005CA\bC\u0013\t\u0019EAA\u0007QSB,G.\u001b8f\u001b>$W\r\u001c\u0005\t\u000b\u0002\u0011\t\u0011)A\u0005\u0003\u0006q\u0001/\u001b9fY&tW-T8eK2\u0004\u0003BB$\u0001\t\u0003\u0011\u0001*\u0001\u0004=S:LGO\u0010\u000b\u0005%%[E\nC\u0003\u001f\r\u0002\u0007\u0001\u0005K\u0002JWEBQa\u000e$A\u0002eBQa\u0010$A\u0002\u0005CQA\u0014\u0001\u0005B=\u000b\u0011\u0002\u001e:b]N4wN]7\u0015\u0005A#\u0007CA)b\u001d\t\u0011fL\u0004\u0002T9:\u0011Ak\u0017\b\u0003+js!AV-\u000e\u0003]S!\u0001\u0017\u0007\u0002\rq\u0012xn\u001c;?\u0013\u0005Y\u
0011BA\u0005\u000b\u0013\t9\u0001\"\u0003\u0002^\r\u0005\u00191/\u001d7\n\u0005}\u0003\u0017a\u00029bG.\fw-\u001a\u0006\u0003;\u001aI!AY2\u0003\u0013\u0011\u000bG/\u0019$sC6,'BA0a\u0011\u0015)W\n1\u0001g\u0003\u001d!\u0017\r^1tKR\u0004$aZ7\u0011\u0007!L7.D\u0001a\u0013\tQ\u0007MA\u0004ECR\f7/\u001a;\u0011\u00051lG\u0002\u0001\u0003\n]\u0012\f\t\u0011!A\u0003\u0002=\u00141a\u0018\u00133#\t\u00018\u000f\u0005\u0002#c&\u0011!o\t\u0002\b\u001d>$\b.\u001b8h!\t\u0011C/\u0003\u0002vG\t\u0019\u0011I\\=)\u00075[s/I\u0001y\u0003\u0015\u0011d\u0006\r\u00181\u0011\u0015Q\b\u0001\"\u0011|\u0003=!(/\u00198tM>\u0014XnU2iK6\fGc\u0001?\u0002\u0006A\u0019Q0!\u0001\u000e\u0003yT!a 1\u0002\u000bQL\b/Z:\n\u0007\u0005\raP\u0001\u0006TiJ,8\r\u001e+za\u0016Da!a\u0002z\u0001\u0004a\u0018AB:dQ\u0016l\u0017\rK\u0002zWEBq!!\u0004\u0001\t\u0003\ny!\u0001\u0003d_BLHc\u0001\n\u0002\u0012!A\u00111CA\u0006\u0001\u0004\t)\"A\u0003fqR\u0014\u0018\r\u0005\u0003\u0002\u0018\u0005uQBAA\r\u0015\r\tY\u0002B\u0001\u0006a\u0006\u0014\u0018-\\\u0005\u0005\u0003?\tIB\u0001\u0005QCJ\fW.T1qQ\u0011\tYaK\u0019\t\u000f\u0005\u0015\u0002\u0001\"\u0011\u0002(\u0005AAo\\*ue&tw\rF\u0001!Q\u0011\t\u0019cK<\t\u000f\u00055\u0002\u0001\"\u0003\u00020\u0005qAO]1og\u001a|'/\u001c'bE\u0016dGc\u0001)\u00022!9Q-a\u000bA\u0002\u0005M\u0002\u0007BA\u001b\u0003s\u0001B\u0001[5\u00028A\u0019A.!\u000f\u0005\u0017\u0005m\u0012\u0011GA\u0001\u0002\u0003\u0015\ta\u001c\u0002\u0004?\u0012\u001a\u0004bBA 
\u0001\u0011%\u0011\u0011I\u0001\u0012G\",7m[\"b]R\u0013\u0018M\\:g_JlG\u0003BA\"\u0003\u0013\u00022AIA#\u0013\r\t9e\t\u0002\u0005+:LG\u000fC\u0004\u0002\b\u0005u\u0002\u0019\u0001?\t\u000f\u00055\u0003\u0001\"\u0011\u0002P\u0005)qO]5uKV\u0011\u0011\u0011\u000b\t\u00041\u0005M\u0013bAA+3\tAQ\nT,sSR,'\u000f\u000b\u0003\u0002L-:\bf\u0001\u0001,c!\u001a\u0001!!\u0018\u0011\u00071\ny&C\u0002\u0002b5\u0012A\"\u0012=qKJLW.\u001a8uC2<q!!\u001a\u0003\u0011\u0003\t9'A\u0007S\r>\u0014X.\u001e7b\u001b>$W\r\u001c\t\u0004'\u0005%dAB\u0001\u0003\u0011\u0003\tYg\u0005\u0005\u0002j\u00055\u00141OA=!\r\u0011\u0013qN\u0005\u0004\u0003c\u001a#AB!osJ+g\r\u0005\u0003\u0019\u0003k\u0012\u0012bAA<3\tQQ\n\u0014*fC\u0012\f'\r\\3\u0011\u0007\t\nY(C\u0002\u0002~\r\u0012AbU3sS\u0006d\u0017N_1cY\u0016DqaRA5\t\u0003\t\t\t\u0006\u0002\u0002h!A\u0011QQA5\t\u0003\n9)\u0001\u0003sK\u0006$WCAAE!\u0011A\u00121\u0012\n\n\u0007\u00055\u0015D\u0001\u0005N\u0019J+\u0017\rZ3sQ\u0011\t\u0019iK<\t\u0011\u0005M\u0015\u0011\u000eC!\u0003+\u000bA\u0001\\8bIR\u0019!#a&\t\u000f\u0005e\u0015\u0011\u0013a\u0001A\u0005!\u0001/\u0019;iQ\u0011\t\tjK<\u0007\u0013\u0005}\u0015\u0011\u000e\u0001\u0002j\u0005\u0005&a\u0005*G_JlW\u000f\\1N_\u0012,Gn\u0016:ji\u0016\u00148\u0003BAO\u0003#B!\"!*\u0002\u001e\n\u0005\t\u0015!\u0003\u0013\u0003!Ign\u001d;b]\u000e,\u0007bB$\u0002\u001e\u0012\u0005\u0011\u0011\u0016\u000b\u0005\u0003W\u000by\u000b\u0005\u0003\u0002.\u0006uUBAA5\u0011\u001d\t)+a*A\u0002IA\u0001\"a-\u0002\u001e\u0012E\u0013QW\u0001\tg\u00064X-S7qYR!\u00111IA\\\u0011\u001d\tI*!-A\u0002\u00012q!a/\u0002j\u0011\tiLA\nS\r>\u0014X.\u001e7b\u001b>$W\r\u001c*fC\u0012,'o\u0005\u0003\u0002:\u0006%\u0005bB$\u0002:\u0012\u0005\u0011\u0011\u0019\u000b\u0003\u0003\u0007\u0004B!!,\u0002:\"Q\u0011qYA]\u0005\u0004%I!!3\u0002\u0013\rd\u0017m]:OC6,WCAAf!\u0011\ti-a6\u000e\u0005\u0005='\u0002BAi\u0003'\fA\u0001\\1oO*\u0011\u0011Q[\u0001\u0005U\u00064\u0018-C\u0002)\u0003\u001fD\u0011\"a7\u0002:\u0002\u0006I!a3\u0002\u0015\rd\u0017m]:O
C6,\u0007\u0005\u0003\u0005\u0002\u0014\u0006eF\u0011IAp)\r\u0011\u0012\u0011\u001d\u0005\b\u00033\u000bi\u000e1\u0001!\u0011)\t)/!\u001b\u0002\u0002\u0013%\u0011q]\u0001\fe\u0016\fGMU3t_24X\r\u0006\u0002\u0002jB!\u0011QZAv\u0013\u0011\ti/a4\u0003\r=\u0013'.Z2uQ\u0011\tIgK<)\t\u0005\r4f\u001e")
@Experimental
/* loaded from: input_file:org/apache/spark/ml/feature/RFormulaModel.class */
/*
 * Model produced by RFormula fitting. transform() checks the input schema, runs the fitted
 * pipelineModel() over the data, and then, when the resolved formula names a label column that is
 * present and labelCol is not already in the output, adds labelCol as that label cast to double.
 *
 * NOTE(review): this is decompiler (JADX) output of Scala bytecode (see the "compiled from:
 * RFormula.scala" markers). It is not valid Java as written:
 *  - labelCol and featuresCol are declared final, yet they are assigned inside the trait setter
 *    methods below.
 *  - It references Scala compiler-generated classes (RFormulaModel$, RFormulaModel$$anonfun$...,
 *    *.Cclass).
 * Treat it as reference output rather than buildable source.
 */
public class RFormulaModel extends Model<RFormulaModel> implements RFormulaBase, MLWritable {
    /** Identifier of this model instance, as passed to the constructor. */
    private final String uid;
    /** Resolved formula; persisted by the writer/reader as the columns label, terms, hasIntercept. */
    private final ResolvedRFormula resolvedFormula;
    /** Fitted pipeline applied to the input dataset/schema by transform and transformSchema. */
    private final PipelineModel pipelineModel;
    // NOTE(review): the two Param fields below are declared final but are assigned in the
    // HasLabelCol/HasFeaturesCol setter methods further down (decompiler artifact); javac rejects this.
    private final Param<String> labelCol;
    private final Param<String> featuresCol;

    /* compiled from: RFormula.scala */
    /* loaded from: input_file:org/apache/spark/ml/feature/RFormulaModel$RFormulaModelReader.class */
    /**
     * Reads an RFormulaModel from the layout written by RFormulaModelWriter.
     *
     * <p>The layout is:
     * <ul>
     *   <li>param metadata at the root path;</li>
     *   <li>the resolved formula as parquet under "data";</li>
     *   <li>the fitted pipeline under "pipelineModel".</li>
     * </ul>
     */
    public static class RFormulaModelReader extends MLReader<RFormulaModel> {
        /** Class name handed to loadMetadata (presumably used to validate the saved metadata — TODO confirm). */
        private final String className = RFormulaModel.class.getName();

        private String className() {
            return this.className;
        }

        /* JADX WARN: Can't rename method to resolve collision */
        /**
         * Loads a model previously saved under {@code str}.
         *
         * @param str root directory of the saved model
         * @return a model whose uid comes from the saved metadata, whose formula and pipeline are read
         *     from the "data" and "pipelineModel" subdirectories, and whose params are restored from
         *     the metadata
         */
        @Override // org.apache.spark.ml.util.MLReader
        public RFormulaModel load(String str) {
            DefaultParamsReader.Metadata loadMetadata = DefaultParamsReader$.MODULE$.loadMetadata(str, sc(), className());
            // Only the first row is used. The writer stores a one-element dataset repartitioned to 1.
            // Column order after select is: 0 = label, 1 = terms, 2 = hasIntercept.
            Row row = (Row) sparkSession().read().parquet(new Path(str, "data").toString()).select("label", Predef$.MODULE$.wrapRefArray(new String[]{"terms", "hasIntercept"})).head();
            RFormulaModel rFormulaModel = new RFormulaModel(loadMetadata.uid(), new ResolvedRFormula(row.getString(0), (Seq) row.getAs(1), row.getBoolean(2)), PipelineModel$.MODULE$.load(new Path(str, "pipelineModel").toString()));
            // Restore the param values recorded in the metadata onto the freshly constructed model.
            DefaultParamsReader$.MODULE$.getAndSetParams(rFormulaModel, loadMetadata);
            return rFormulaModel;
        }
    }

    /* compiled from: RFormula.scala */
    /* loaded from: input_file:org/apache/spark/ml/feature/RFormulaModel$RFormulaModelWriter.class */
    /**
     * Persists an RFormulaModel in the layout read back by RFormulaModelReader.
     *
     * <p>The layout is:
     * <ul>
     *   <li>param metadata at the root path;</li>
     *   <li>the resolved formula as a single-partition parquet dataset under "data";</li>
     *   <li>the fitted pipeline under "pipelineModel".</li>
     * </ul>
     */
    public static class RFormulaModelWriter extends MLWriter {
        /** The model being saved. */
        private final RFormulaModel instance;

        /**
         * Writes metadata, formula data and the pipeline model under {@code str}.
         *
         * @param str root output directory
         */
        @Override // org.apache.spark.ml.util.MLWriter
        public void saveImpl(String str) {
            DefaultParamsWriter$.MODULE$.saveMetadata(this.instance, str, sc(), DefaultParamsWriter$.MODULE$.saveMetadata$default$4(), DefaultParamsWriter$.MODULE$.saveMetadata$default$5());
            // Builds a one-row DataFrame from the resolved formula. The anonymous TypeCreator supplies a
            // runtime TypeTag for ResolvedRFormula, which appears to be compiler-generated.
            // NOTE(review): `new TypeCreator(this)` passes a synthetic outer argument and is not valid Java.
            sparkSession().createDataFrame(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new ResolvedRFormula[]{this.instance.resolvedFormula()})), scala.reflect.runtime.package$.MODULE$.universe().TypeTag().apply(scala.reflect.runtime.package$.MODULE$.universe().runtimeMirror(RFormulaModelWriter.class.getClassLoader()), new TypeCreator(this) { // from class: org.apache.spark.ml.feature.RFormulaModel$RFormulaModelWriter$$typecreator1$1
                public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                    // Result is discarded; presumably a leftover of the compiled Scala code.
                    mirror.universe();
                    return mirror.staticClass("org.apache.spark.ml.feature.ResolvedRFormula").asType().toTypeConstructor();
                }
            })).repartition(1).write().parquet(new Path(str, "data").toString());
            this.instance.pipelineModel().save(new Path(str, "pipelineModel").toString());
        }

        /**
         * @param rFormulaModel the model to persist
         */
        public RFormulaModelWriter(RFormulaModel rFormulaModel) {
            this.instance = rFormulaModel;
        }
    }

    /** Loads a saved model from {@code str}; delegates to the companion object RFormulaModel$. */
    public static RFormulaModel load(String str) {
        return RFormulaModel$.MODULE$.load(str);
    }

    /** Returns a reader for this model type; delegates to the companion object RFormulaModel$. */
    public static MLReader<RFormulaModel> read() {
        return RFormulaModel$.MODULE$.read();
    }

    /** Saves this model to {@code str}; delegates to the MLWritable trait implementation. */
    @Override // org.apache.spark.ml.util.MLWritable
    public void save(String str) throws IOException {
        MLWritable.Cclass.save(this, str);
    }

    /** Delegates to the RFormulaBase trait implementation. */
    @Override // org.apache.spark.ml.feature.RFormulaBase
    public boolean hasLabelCol(StructType structType) {
        return RFormulaBase.Cclass.hasLabelCol(this, structType);
    }

    /** Param holding the name of the output label column. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    /** Scala trait field setter, presumably invoked from HasLabelCol.Cclass.$init$ during construction. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    /** Returns the current labelCol value; delegates to the HasLabelCol trait implementation. */
    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final String getLabelCol() {
        return HasLabelCol.Cclass.getLabelCol(this);
    }

    /** Param holding the name of the features column. */
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    public final Param<String> featuresCol() {
        return this.featuresCol;
    }

    /** Scala trait field setter, presumably invoked from HasFeaturesCol.Cclass.$init$ during construction. */
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    public final void org$apache$spark$ml$param$shared$HasFeaturesCol$_setter_$featuresCol_$eq(Param param) {
        this.featuresCol = param;
    }

    /** Returns the current featuresCol value; delegates to the HasFeaturesCol trait implementation. */
    @Override // org.apache.spark.ml.param.shared.HasFeaturesCol
    public final String getFeaturesCol() {
        return HasFeaturesCol.Cclass.getFeaturesCol(this);
    }

    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return this.uid;
    }

    /** The resolved formula this model was constructed with. */
    public ResolvedRFormula resolvedFormula() {
        return this.resolvedFormula;
    }

    /** The fitted pipeline this model was constructed with. */
    public PipelineModel pipelineModel() {
        return this.pipelineModel;
    }

    /**
     * Validates the input schema, applies the fitted pipeline, then adds the label column if needed.
     *
     * @param dataset input data
     * @return the pipeline output, possibly with labelCol appended (see transformLabel)
     * @throws IllegalArgumentException if the schema fails checkCanTransform, or if the formula's
     *     label column has a type other than numeric or boolean
     */
    @Override // org.apache.spark.ml.Transformer
    public Dataset<Row> transform(Dataset<?> dataset) {
        checkCanTransform(dataset.schema());
        return transformLabel(pipelineModel().transform(dataset));
    }

    /**
     * Computes the output schema, mirroring {@link #transform}.
     *
     * <p>The pipeline's output schema is returned unchanged when any of these hold:
     * <ul>
     *   <li>the formula has no label;</li>
     *   <li>the output already has the label column;</li>
     *   <li>the input lacks the formula's label field.</li>
     * </ul>
     * Otherwise a DoubleType field named labelCol is appended.
     *
     * @param structType input schema
     * @return the expected output schema
     */
    @Override // org.apache.spark.ml.PipelineStage
    public StructType transformSchema(StructType structType) {
        // z: true when the source label type is numeric or boolean. The appended column is nullable only otherwise.
        boolean z;
        checkCanTransform(structType);
        StructType transformSchema = pipelineModel().transformSchema(structType);
        if (resolvedFormula().label().isEmpty() || hasLabelCol(transformSchema)) {
            return transformSchema;
        }
        // The anonymous predicate is compiler-generated; it presumably tests whether a field is named
        // resolvedFormula().label(). When the label is absent, the schema is returned without a label column.
        if (!structType.exists(new RFormulaModel$$anonfun$transformSchema$2(this))) {
            return transformSchema;
        }
        DataType dataType = structType.apply(resolvedFormula().label()).dataType();
        if (dataType instanceof NumericType) {
            z = true;
        } else {
            // Compiled form of Scala's null-safe `BooleanType == dataType`.
            BooleanType$ booleanType$ = BooleanType$.MODULE$;
            z = booleanType$ != null ? booleanType$.equals(dataType) : dataType == null;
        }
        return new StructType((StructField[]) Predef$.MODULE$.refArrayOps(transformSchema.fields()).$colon$plus(new StructField((String) $(labelCol()), DoubleType$.MODULE$, !z, StructField$.MODULE$.apply$default$4()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(StructField.class))));
    }

    /**
     * Creates a copy with the same uid, resolved formula and pipeline model. The copy also gets the
     * same parent, and the extra params are applied to it.
     *
     * @param paramMap extra params to apply to the copy
     * @return the copied model
     */
    @Override // org.apache.spark.ml.Model, org.apache.spark.ml.Transformer, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public RFormulaModel copy(ParamMap paramMap) {
        return (RFormulaModel) copyValues(new RFormulaModel(uid(), resolvedFormula(), pipelineModel()).setParent(parent()), paramMap);
    }

    /** Returns "RFormulaModel(&lt;resolvedFormula&gt;) (uid=&lt;uid&gt;)". */
    @Override // org.apache.spark.ml.PipelineStage, org.apache.spark.ml.util.Identifiable
    public String toString() {
        return new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"RFormulaModel(", ") (uid=", ")"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{resolvedFormula(), uid()}));
    }

    /**
     * Adds labelCol as the formula's label column cast to double, when applicable.
     *
     * <p>The dataset is returned unchanged (as a DataFrame) when any of these hold:
     * <ul>
     *   <li>the formula has no label;</li>
     *   <li>labelCol already exists;</li>
     *   <li>the label column is missing.</li>
     * </ul>
     *
     * @param dataset output of the fitted pipeline
     * @return the dataset, possibly with labelCol added
     * @throws IllegalArgumentException if the label column exists but is neither numeric nor boolean
     */
    private Dataset<Row> transformLabel(Dataset<?> dataset) {
        // z: true when the label column's type is numeric or boolean, i.e. castable to double here.
        boolean z;
        String label = resolvedFormula().label();
        if (label.isEmpty() || hasLabelCol(dataset.schema())) {
            return dataset.toDF();
        }
        // Compiler-generated predicate; presumably checks for a field named `label`. When none
        // matches, the dataset is passed through without a label column.
        if (!dataset.schema().exists(new RFormulaModel$$anonfun$transformLabel$1(this, label))) {
            return dataset.toDF();
        }
        DataType dataType = dataset.schema().apply(label).dataType();
        if (dataType instanceof NumericType) {
            z = true;
        } else {
            // Compiled form of Scala's null-safe `BooleanType == dataType`.
            BooleanType$ booleanType$ = BooleanType$.MODULE$;
            z = booleanType$ != null ? booleanType$.equals(dataType) : dataType == null;
        }
        if (z) {
            return dataset.withColumn((String) $(labelCol()), dataset.apply(label).cast(DoubleType$.MODULE$));
        }
        throw new IllegalArgumentException(new StringBuilder().append("Unsupported type for label: ").append(dataType).toString());
    }

    /**
     * Validates that the input schema can be transformed. Each failure is raised through
     * Predef.require, i.e. as IllegalArgumentException.
     *
     * <p>Two conditions are checked:
     * <ul>
     *   <li>featuresCol must not already be present;</li>
     *   <li>if labelCol is present, it must be a NumericType.</li>
     * </ul>
     * The failure messages come from compiler-generated closures that are not visible here.
     *
     * @param structType input schema
     */
    private void checkCanTransform(StructType structType) {
        // Compiler-generated mapping; presumably collects the field names of the schema.
        Seq seq = (Seq) structType.map(new RFormulaModel$$anonfun$3(this), Seq$.MODULE$.canBuildFrom());
        Predef$.MODULE$.require(!seq.contains($(featuresCol())), new RFormulaModel$$anonfun$checkCanTransform$1(this));
        Predef$.MODULE$.require(!seq.contains($(labelCol())) || (structType.apply((String) $(labelCol())).dataType() instanceof NumericType), new RFormulaModel$$anonfun$checkCanTransform$2(this));
    }

    /** Returns a writer that persists this model (see RFormulaModelWriter). */
    @Override // org.apache.spark.ml.util.MLWritable
    public MLWriter write() {
        return new RFormulaModelWriter(this);
    }

    /**
     * @param str model uid
     * @param resolvedRFormula resolved formula (label, terms, intercept flag)
     * @param pipelineModel fitted pipeline applied by transform/transformSchema
     */
    public RFormulaModel(String str, ResolvedRFormula resolvedRFormula, PipelineModel pipelineModel) {
        this.uid = str;
        this.resolvedFormula = resolvedRFormula;
        this.pipelineModel = pipelineModel;
        // Scala trait initializers, in linearization order. These presumably create the featuresCol and
        // labelCol Params via the *_setter_$... methods above.
        HasFeaturesCol.Cclass.$init$(this);
        HasLabelCol.Cclass.$init$(this);
        RFormulaBase.Cclass.$init$(this);
        MLWritable.Cclass.$init$(this);
    }
}
