package org.apache.spark.ml.feature;

import java.io.IOException;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.util.collection.OpenHashMap;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Iterable;
import scala.collection.Iterable$;
import scala.collection.Seq;
import scala.collection.mutable.ArrayOps;
import scala.math.Ordering$Long$;
import scala.math.Ordering$String$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: CountVectorizer.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005ue\u0001\u0002\u000b\u0016\u0001\u0001B\u0001B\r\u0001\u0003\u0006\u0004%\te\r\u0005\t\u0015\u0002\u0011\t\u0011)A\u0005i!)A\n\u0001C\u0001\u001b\")A\n\u0001C\u0001%\")A\u000b\u0001C\u0001+\")!\f\u0001C\u00017\")a\f\u0001C\u0001?\")a\r\u0001C\u0001O\")Q\u000e\u0001C\u0001]\")1\u000f\u0001C\u0001i\")q\u000f\u0001C\u0001q\"9\u0011\u0011\u0001\u0001\u0005B\u0005\r\u0001bBA\u0018\u0001\u0011\u0005\u0013\u0011\u0007\u0005\b\u0003\u000b\u0002A\u0011IA$\u000f\u001d\ti&\u0006E\u0001\u0003?2a\u0001F\u000b\t\u0002\u0005\u0005\u0004B\u0002'\u0011\t\u0003\t)\bC\u0004\u0002xA!\t%!\u001f\t\u0013\u0005\u0015\u0005#!A\u0005\n\u0005\u001d%aD\"pk:$h+Z2u_JL'0\u001a:\u000b\u0005Y9\u0012a\u00024fCR,(/\u001a\u0006\u00031e\t!!\u001c7\u000b\u0005iY\u0012!B:qCJ\\'B\u0001\u000f\u001e\u0003\u0019\t\u0007/Y2iK*\ta$A\u0002pe\u001e\u001c\u0001a\u0005\u0003\u0001C%b\u0003c\u0001\u0012$K5\tq#\u0003\u0002%/\tIQi\u001d;j[\u0006$xN\u001d\t\u0003M\u001dj\u0011!F\u0005\u0003QU\u0011AcQ8v]R4Vm\u0019;pe&TXM]'pI\u0016d\u0007C\u0001\u0014+\u0013\tYSCA\u000bD_VtGOV3di>\u0014\u0018N_3s!\u0006\u0014\u0018-\\:\u0011\u00055\u0002T\"\u0001\u0018\u000b\u0005=:\u0012\u0001B;uS2L!!\r\u0018\u0003+\u0011+g-Y;miB\u000b'/Y7t/JLG/\u00192mK\u0006\u0019Q/\u001b3\u0016\u0003Q\u0002\"!\u000e \u000f\u0005Yb\u0004CA\u001c;\u001b\u0005A$BA\u001d 
\u0003\u0019a$o\\8u})\t1(A\u0003tG\u0006d\u0017-\u0003\u0002>u\u00051\u0001K]3eK\u001aL!a\u0010!\u0003\rM#(/\u001b8h\u0015\ti$\bK\u0002\u0002\u0005\"\u0003\"a\u0011$\u000e\u0003\u0011S!!R\r\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0002H\t\n)1+\u001b8dK\u0006\n\u0011*A\u00032]Ur\u0003'\u0001\u0003vS\u0012\u0004\u0003f\u0001\u0002C\u0011\u00061A(\u001b8jiz\"\"AT(\u0011\u0005\u0019\u0002\u0001\"\u0002\u001a\u0004\u0001\u0004!\u0004fA(C\u0011\"\u001a1A\u0011%\u0015\u00039C3\u0001\u0002\"I\u0003-\u0019X\r^%oaV$8i\u001c7\u0015\u0005Y;V\"\u0001\u0001\t\u000ba+\u0001\u0019\u0001\u001b\u0002\u000bY\fG.^3)\u0007\u0015\u0011\u0005*\u0001\u0007tKR|U\u000f\u001e9vi\u000e{G\u000e\u0006\u0002W9\")\u0001L\u0002a\u0001i!\u001aaA\u0011%\u0002\u0019M,GOV8dC\n\u001c\u0016N_3\u0015\u0005Y\u0003\u0007\"\u0002-\b\u0001\u0004\t\u0007C\u00012d\u001b\u0005Q\u0014B\u00013;\u0005\rIe\u000e\u001e\u0015\u0004\u000f\tC\u0015\u0001C:fi6Kg\u000e\u0012$\u0015\u0005YC\u0007\"\u0002-\t\u0001\u0004I\u0007C\u00012k\u0013\tY'H\u0001\u0004E_V\u0014G.\u001a\u0015\u0004\u0011\tC\u0015\u0001C:fi6\u000b\u0007\u0010\u0012$\u0015\u0005Y{\u0007\"\u0002-\n\u0001\u0004I\u0007fA\u0005Cc\u0006\n!/A\u00033]Qr\u0003'\u0001\u0005tKRl\u0015N\u001c+G)\t1V\u000fC\u0003Y\u0015\u0001\u0007\u0011\u000eK\u0002\u000b\u0005\"\u000b\u0011b]3u\u0005&t\u0017M]=\u0015\u0005YK\b\"\u0002-\f\u0001\u0004Q\bC\u00012|\u0013\ta(HA\u0004C_>dW-\u00198)\u0007-\u0011e0I\u0001��\u0003\u0015\u0011d\u0006\r\u00181\u0003\r1\u0017\u000e\u001e\u000b\u0004K\u0005\u0015\u0001bBA\u0004\u0019\u0001\u0007\u0011\u0011B\u0001\bI\u0006$\u0018m]3ua\u0011\tY!a\u0007\u0011\r\u00055\u00111CA\f\u001b\t\tyAC\u0002\u0002\u0012e\t1a]9m\u0013\u0011\t)\"a\u0004\u0003\u000f\u0011\u000bG/Y:fiB!\u0011\u0011DA\u000e\u0019\u0001!A\"!\b\u0002\u0006\u0005\u0005\t\u0011!B\u0001\u0003?\u00111a\u0018\u00132#\u0011\t\t#a\n\u0011\u0007\t\f\u0019#C\u0002\u0002&i\u0012qAT8uQ&tw\rE\u0002c\u0003SI1!a\u000b;\u0005\r\te.\u001f\u0015\u0004\u0019\ts\u0018a\u0004;sC:\u001chm\u001c:n'\u0
00eDW-\\1\u0015\t\u0005M\u0012q\b\t\u0005\u0003k\tY$\u0004\u0002\u00028)!\u0011\u0011HA\b\u0003\u0015!\u0018\u0010]3t\u0013\u0011\ti$a\u000e\u0003\u0015M#(/^2u)f\u0004X\rC\u0004\u0002B5\u0001\r!a\r\u0002\rM\u001c\u0007.Z7bQ\ri!\tS\u0001\u0005G>\u0004\u0018\u0010F\u0002O\u0003\u0013Bq!a\u0013\u000f\u0001\u0004\ti%A\u0003fqR\u0014\u0018\r\u0005\u0003\u0002P\u0005USBAA)\u0015\r\t\u0019fF\u0001\u0006a\u0006\u0014\u0018-\\\u0005\u0005\u0003/\n\tF\u0001\u0005QCJ\fW.T1qQ\rq!\t\u0013\u0015\u0004\u0001\tC\u0015aD\"pk:$h+Z2u_JL'0\u001a:\u0011\u0005\u0019\u00022c\u0002\t\u0002d\u0005%\u0014q\u000e\t\u0004E\u0006\u0015\u0014bAA4u\t1\u0011I\\=SK\u001a\u0004B!LA6\u001d&\u0019\u0011Q\u000e\u0018\u0003+\u0011+g-Y;miB\u000b'/Y7t%\u0016\fG-\u00192mKB\u0019!-!\u001d\n\u0007\u0005M$H\u0001\u0007TKJL\u0017\r\\5{C\ndW\r\u0006\u0002\u0002`\u0005!An\\1e)\rq\u00151\u0010\u0005\u0007\u0003{\u0012\u0002\u0019\u0001\u001b\u0002\tA\fG\u000f\u001b\u0015\u0005%\t\u000b\t)\t\u0002\u0002\u0004\u0006)\u0011G\f\u001c/a\u0005Y!/Z1e%\u0016\u001cx\u000e\u001c<f)\t\tI\t\u0005\u0003\u0002\f\u0006UUBAAG\u0015\u0011\ty)!%\u0002\t1\fgn\u001a\u0006\u0003\u0003'\u000bAA[1wC&!\u0011qSAG\u0005\u0019y%M[3di\"\"\u0001CQAAQ\u0011y!)!!")
/* loaded from: input_file:org/apache/spark/ml/feature/CountVectorizer.class */
public class CountVectorizer extends Estimator<CountVectorizerModel> implements CountVectorizerParams, DefaultParamsWritable {
    /** Unique identifier of this estimator; assigned once in the constructor. */
    private final String uid;

    // The param fields below are deliberately NOT final: they are assigned by the
    // generated `..._setter_$<name>_$eq` methods of this class rather than directly in
    // the constructor, and Java rejects assignment to a final field outside a
    // constructor or initializer.
    private IntParam vocabSize;
    private DoubleParam minDF;
    private DoubleParam maxDF;
    private DoubleParam minTF;
    private BooleanParam binary;
    private Param<String> outputCol;
    private Param<String> inputCol;

    /**
     * Static forwarder to the companion object's {@code load}.
     *
     * @param str path argument handed unchanged to the companion's {@code load}
     * @return whatever the companion's {@code load} returns for {@code str}
     */
    public static CountVectorizer load(String str) {
        final CountVectorizer$ companion = CountVectorizer$.MODULE$;
        return companion.load(str);
    }

    /**
     * Static forwarder to the companion object's {@code read}.
     *
     * @return the reader produced by the companion object
     */
    public static MLReader<CountVectorizer> read() {
        final CountVectorizer$ companion = CountVectorizer$.MODULE$;
        return companion.read();
    }

    /**
     * Returns the writer for this estimator by delegating to the implementation
     * inherited through {@link DefaultParamsWritable}.
     *
     * <p>The decompiled body invoked {@code write()} on itself, which recurses
     * unconditionally and ends in a {@code StackOverflowError}.
     *
     * @return the writer supplied by the inherited default method
     */
    @Override // org.apache.spark.ml.util.DefaultParamsWritable, org.apache.spark.ml.util.MLWritable
    public MLWriter write() {
        // NOTE(review): assumes the Scala trait is compiled to an interface with a
        // default write() method (Scala 2.12 encoding) — confirm against the class file.
        return DefaultParamsWritable.super.write();
    }

    /**
     * Forwards to the {@code save} implementation inherited through
     * {@link DefaultParamsWritable}.
     *
     * <p>The decompiled body invoked {@code save(str)} on itself, which recurses
     * unconditionally and ends in a {@code StackOverflowError}.
     *
     * @param str path argument, passed through unchanged
     * @throws IOException as declared by the overridden method
     */
    @Override // org.apache.spark.ml.util.MLWritable
    public void save(String str) throws IOException {
        // MLWritable is not named in this class's implements clause, so the call is
        // qualified with DefaultParamsWritable, which is.
        // NOTE(review): assumes DefaultParamsWritable inherits a default save() from
        // MLWritable — confirm.
        DefaultParamsWritable.super.save(str);
    }

    /**
     * Returns the vocabulary-size value by delegating to the default implementation in
     * {@link CountVectorizerParams}.
     *
     * <p>The decompiled body called {@code getVocabSize()} on itself (infinite recursion).
     *
     * @return the value produced by the trait's {@code getVocabSize}
     */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public int getVocabSize() {
        // NOTE(review): assumes the trait exposes a default method — confirm.
        return CountVectorizerParams.super.getVocabSize();
    }

    /**
     * Returns the {@code minDF} value by delegating to the default implementation in
     * {@link CountVectorizerParams}.
     *
     * <p>The decompiled body called {@code getMinDF()} on itself (infinite recursion).
     *
     * @return the value produced by the trait's {@code getMinDF}
     */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public double getMinDF() {
        // NOTE(review): assumes the trait exposes a default method — confirm.
        return CountVectorizerParams.super.getMinDF();
    }

    /**
     * Returns the {@code maxDF} value by delegating to the default implementation in
     * {@link CountVectorizerParams}.
     *
     * <p>The decompiled body called {@code getMaxDF()} on itself (infinite recursion).
     *
     * @return the value produced by the trait's {@code getMaxDF}
     */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public double getMaxDF() {
        // NOTE(review): assumes the trait exposes a default method — confirm.
        return CountVectorizerParams.super.getMaxDF();
    }

    /**
     * Delegates schema validation and transformation to the default implementation in
     * {@link CountVectorizerParams}.
     *
     * <p>The decompiled body called {@code validateAndTransformSchema} on itself
     * (infinite recursion).
     *
     * @param structType input schema, passed through unchanged
     * @return the schema returned by the trait implementation
     */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public StructType validateAndTransformSchema(StructType structType) {
        // NOTE(review): assumes the trait exposes a default method — confirm.
        return CountVectorizerParams.super.validateAndTransformSchema(structType);
    }

    /**
     * Returns the {@code minTF} value by delegating to the default implementation in
     * {@link CountVectorizerParams}.
     *
     * <p>The decompiled body called {@code getMinTF()} on itself (infinite recursion).
     *
     * @return the value produced by the trait's {@code getMinTF}
     */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public double getMinTF() {
        // NOTE(review): assumes the trait exposes a default method — confirm.
        return CountVectorizerParams.super.getMinTF();
    }

    /**
     * Returns the {@code binary} flag by delegating to the default implementation in
     * {@link CountVectorizerParams}.
     *
     * <p>The decompiled body called {@code getBinary()} on itself (infinite recursion).
     *
     * @return the value produced by the trait's {@code getBinary}
     */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public boolean getBinary() {
        // NOTE(review): assumes the trait exposes a default method — confirm.
        return CountVectorizerParams.super.getBinary();
    }

    /**
     * Returns the output column name by delegating to the inherited trait
     * implementation.
     *
     * <p>The decompiled body called {@code getOutputCol()} on itself (infinite recursion).
     *
     * @return the value produced by the inherited {@code getOutputCol}
     */
    @Override // org.apache.spark.ml.param.shared.HasOutputCol
    public final String getOutputCol() {
        // HasOutputCol is not a direct superinterface of this class, so the super call
        // is routed through CountVectorizerParams.
        // NOTE(review): assumes CountVectorizerParams extends HasOutputCol and inherits
        // a default getOutputCol() — confirm.
        return CountVectorizerParams.super.getOutputCol();
    }

    /**
     * Returns the input column name by delegating to the inherited trait
     * implementation.
     *
     * <p>The decompiled body called {@code getInputCol()} on itself (infinite recursion).
     *
     * @return the value produced by the inherited {@code getInputCol}
     */
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final String getInputCol() {
        // HasInputCol is not a direct superinterface of this class, so the super call
        // is routed through CountVectorizerParams.
        // NOTE(review): assumes CountVectorizerParams extends HasInputCol and inherits
        // a default getInputCol() — confirm.
        return CountVectorizerParams.super.getInputCol();
    }

    /** Accessor for the {@code vocabSize} param object stored on this instance. */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public IntParam vocabSize() {
        return vocabSize;
    }

    /** Accessor for the {@code minDF} param object stored on this instance. */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public DoubleParam minDF() {
        return minDF;
    }

    /** Accessor for the {@code maxDF} param object stored on this instance. */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public DoubleParam maxDF() {
        return maxDF;
    }

    /** Accessor for the {@code minTF} param object stored on this instance. */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public DoubleParam minTF() {
        return minTF;
    }

    /** Accessor for the {@code binary} param object stored on this instance. */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public BooleanParam binary() {
        return binary;
    }

    /**
     * Compiler-generated setter that stores the {@code vocabSize} param object.
     * Presumably invoked from the trait initializer — confirm.
     *
     * @param intParam param object to store
     */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public void org$apache$spark$ml$feature$CountVectorizerParams$_setter_$vocabSize_$eq(IntParam intParam) {
        vocabSize = intParam;
    }

    /**
     * Compiler-generated setter that stores the {@code minDF} param object.
     * Presumably invoked from the trait initializer — confirm.
     *
     * @param doubleParam param object to store
     */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public void org$apache$spark$ml$feature$CountVectorizerParams$_setter_$minDF_$eq(DoubleParam doubleParam) {
        minDF = doubleParam;
    }

    /**
     * Compiler-generated setter that stores the {@code maxDF} param object.
     * Presumably invoked from the trait initializer — confirm.
     *
     * @param doubleParam param object to store
     */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public void org$apache$spark$ml$feature$CountVectorizerParams$_setter_$maxDF_$eq(DoubleParam doubleParam) {
        maxDF = doubleParam;
    }

    /**
     * Compiler-generated setter that stores the {@code minTF} param object.
     * Presumably invoked from the trait initializer — confirm.
     *
     * @param doubleParam param object to store
     */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public void org$apache$spark$ml$feature$CountVectorizerParams$_setter_$minTF_$eq(DoubleParam doubleParam) {
        minTF = doubleParam;
    }

    /**
     * Compiler-generated setter that stores the {@code binary} param object.
     * Presumably invoked from the trait initializer — confirm.
     *
     * @param booleanParam param object to store
     */
    @Override // org.apache.spark.ml.feature.CountVectorizerParams
    public void org$apache$spark$ml$feature$CountVectorizerParams$_setter_$binary_$eq(BooleanParam booleanParam) {
        binary = booleanParam;
    }

    /** Accessor for the {@code outputCol} param object stored on this instance. */
    @Override // org.apache.spark.ml.param.shared.HasOutputCol
    public final Param<String> outputCol() {
        return outputCol;
    }

    /**
     * Compiler-generated setter that stores the {@code outputCol} param object.
     * Presumably invoked from the trait initializer — confirm.
     *
     * @param param param object to store
     */
    @Override // org.apache.spark.ml.param.shared.HasOutputCol
    public final void org$apache$spark$ml$param$shared$HasOutputCol$_setter_$outputCol_$eq(Param<String> param) {
        outputCol = param;
    }

    /** Accessor for the {@code inputCol} param object stored on this instance. */
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final Param<String> inputCol() {
        return inputCol;
    }

    /**
     * Compiler-generated setter that stores the {@code inputCol} param object; the
     * one-argument constructor of this class calls it directly.
     *
     * @param param param object to store
     */
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param<String> param) {
        inputCol = param;
    }

    /** Returns the identifier this instance was constructed with. */
    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return uid;
    }

    /**
     * Sets the value of the {@code inputCol} param.
     *
     * @param str input column name
     * @return this estimator, as returned by {@code set}
     */
    public CountVectorizer setInputCol(String str) {
        // The decompiler had cast the value to Param<String>; a String can never be a
        // Param, so the param and its value are passed with their real types.
        return (CountVectorizer) set(inputCol(), str);
    }

    /**
     * Sets the value of the {@code outputCol} param.
     *
     * @param str output column name
     * @return this estimator, as returned by {@code set}
     */
    public CountVectorizer setOutputCol(String str) {
        // The decompiler had cast the value to Param<String>; a String can never be a
        // Param, so the param and its value are passed with their real types.
        return (CountVectorizer) set(outputCol(), str);
    }

    /**
     * Sets the value of the {@code vocabSize} param.
     *
     * @param i vocabulary size to store
     * @return this estimator, as returned by {@code set}
     */
    public CountVectorizer setVocabSize(int i) {
        // The boxed Integer is the param VALUE; the decompiler's cast of it to IntParam
        // could never succeed, so it is dropped.
        return (CountVectorizer) set(vocabSize(), BoxesRunTime.boxToInteger(i));
    }

    /**
     * Sets the value of the {@code minDF} param.
     *
     * @param d value to store; {@code fit} treats values below 1.0 as a fraction of the
     *          document count and other values as an absolute count
     * @return this estimator, as returned by {@code set}
     */
    public CountVectorizer setMinDF(double d) {
        // The boxed Double is the param VALUE; the decompiler's cast of it to
        // DoubleParam could never succeed, so it is dropped.
        return (CountVectorizer) set(minDF(), BoxesRunTime.boxToDouble(d));
    }

    /**
     * Sets the value of the {@code maxDF} param.
     *
     * @param d value to store; {@code fit} treats values below 1.0 as a fraction of the
     *          document count and other values as an absolute count
     * @return this estimator, as returned by {@code set}
     */
    public CountVectorizer setMaxDF(double d) {
        // The boxed Double is the param VALUE; the decompiler's cast of it to
        // DoubleParam could never succeed, so it is dropped.
        return (CountVectorizer) set(maxDF(), BoxesRunTime.boxToDouble(d));
    }

    /**
     * Sets the value of the {@code minTF} param.
     *
     * @param d value to store
     * @return this estimator, as returned by {@code set}
     */
    public CountVectorizer setMinTF(double d) {
        // The boxed Double is the param VALUE; the decompiler's cast of it to
        // DoubleParam could never succeed, so it is dropped.
        return (CountVectorizer) set(minTF(), BoxesRunTime.boxToDouble(d));
    }

    /**
     * Sets the value of the {@code binary} param.
     *
     * @param z flag value to store
     * @return this estimator, as returned by {@code set}
     */
    public CountVectorizer setBinary(boolean z) {
        // The boxed Boolean is the param VALUE; the decompiler's cast of it to
        // BooleanParam could never succeed, so it is dropped.
        return (CountVectorizer) set(binary(), BoxesRunTime.boxToBoolean(z));
    }

    /**
     * Builds a {@link CountVectorizerModel} whose vocabulary is taken from the token
     * sequences in the input column of {@code dataset}.
     *
     * <p>Steps performed:
     * <ol>
     *   <li>Validate the schema and read the first selected column of each row as a token
     *       sequence.</li>
     *   <li>Resolve {@code minDF}/{@code maxDF}: values {@code >= 1.0} are used as-is,
     *       smaller values are multiplied by the number of input rows (the input is only
     *       cached and counted when at least one of them is below 1.0).</li>
     *   <li>For every term, sum its per-document counts and the number of documents that
     *       contain it.</li>
     *   <li>If {@code minDF} or {@code maxDF} was explicitly set, keep only terms whose
     *       document frequency lies within {@code [minDF, maxDF]}.</li>
     *   <li>Take at most {@code vocabSize} terms, ordered by total count, as the
     *       vocabulary.</li>
     * </ol>
     *
     * <p>Relative to the raw decompiler output this version fixes: an {@code Option}
     * declared as {@code Some} but assigned {@code None}; locals that redeclared lambda
     * parameters; and a dead {@code BoxedUnit} else-branch.
     *
     * @param dataset dataset containing the input column
     * @return the fitted model, parented to this estimator, with this estimator's param
     *         values copied onto it
     * @throws IllegalArgumentException from {@code Predef.require} when the resolved
     *         {@code maxDF} is below {@code minDF} or the resulting vocabulary is empty
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.apache.spark.ml.Estimator
    public CountVectorizerModel fit(Dataset<?> dataset) {
        transformSchema(dataset.schema(), true);
        final int vocSize = BoxesRunTime.unboxToInt($(vocabSize()));
        final double minDfParam = BoxesRunTime.unboxToDouble($(minDF()));
        final double maxDfParam = BoxesRunTime.unboxToDouble($(maxDF()));

        RDD input = dataset.select((String) $(inputCol()), Predef$.MODULE$.wrapRefArray(new String[0])).rdd().map(row -> {
            return (Seq) row.getAs(0);
        }, ClassTag$.MODULE$.apply(Seq.class));

        // The row count is only needed to turn a fractional threshold into an absolute
        // one, so the input is cached and counted only in that case.
        final boolean countingRequired = minDfParam < 1.0d || maxDfParam < 1.0d;
        final long numDocs = countingRequired ? input.cache().count() : 0L;
        final double minDf = minDfParam >= 1.0d ? minDfParam : minDfParam * numDocs;
        final double maxDf = maxDfParam >= 1.0d ? maxDfParam : maxDfParam * numDocs;
        Predef$.MODULE$.require(maxDf >= minDf, () -> {
            return "maxDF must be >= minDF.";
        });

        // term -> (total occurrences, number of documents containing the term)
        // NOTE(review): `OpenHashMap.mcJ.sp` / `Tuple2.mcJI.sp` are the decompiler's
        // spelling of Scala-specialized classes and are kept as emitted — confirm the
        // class names before compiling.
        RDD termStats = RDD$.MODULE$.rddToPairRDDFunctions(input.flatMap(seq -> {
            OpenHashMap.mcJ.sp spVar = new OpenHashMap.mcJ.sp(ClassTag$.MODULE$.apply(String.class), ClassTag$.MODULE$.Long());
            seq.foreach(str -> {
                return BoxesRunTime.boxToLong($anonfun$fit$4(spVar, str));
            });
            return (Iterable) spVar.map(tuple2 -> {
                if (tuple2 != null) {
                    // Each document contributes exactly 1 to a term's document frequency.
                    return new Tuple2((String) tuple2._1(), new Tuple2.mcJI.sp(tuple2._2$mcJ$sp(), 1));
                }
                throw new MatchError(tuple2);
            }, Iterable$.MODULE$.canBuildFrom());
        }, ClassTag$.MODULE$.apply(Tuple2.class)), ClassTag$.MODULE$.apply(String.class), ClassTag$.MODULE$.apply(Tuple2.class), Ordering$String$.MODULE$).reduceByKey((left, right) -> {
            Tuple2 first = (Tuple2) left;
            Tuple2 second = (Tuple2) right;
            if (first == null || second == null) {
                throw new MatchError(new Tuple2(left, right));
            }
            return new Tuple2.mcJI.sp(first._1$mcJ$sp() + second._1$mcJ$sp(), first._2$mcI$sp() + second._2$mcI$sp());
        });

        // Document-frequency filtering applies only when a bound was explicitly set.
        RDD filtered = isSet(minDF()) || isSet(maxDF()) ? termStats.filter(entry -> {
            return BoxesRunTime.boxToBoolean($anonfun$fit$9(minDf, maxDf, (Tuple2) entry));
        }) : termStats;

        // Drop the document frequency, keeping (term, total occurrences).
        RDD vocabCandidates = filtered.map(entry -> {
            Tuple2 termAndStats = (Tuple2) entry;
            if (termAndStats != null) {
                Tuple2 stats = (Tuple2) termAndStats._2();
                if (stats != null) {
                    return new Tuple2((String) termAndStats._1(), BoxesRunTime.boxToLong(stats._1$mcJ$sp()));
                }
            }
            throw new MatchError(entry);
        }, ClassTag$.MODULE$.apply(Tuple2.class)).cache();

        if (countingRequired) {
            input.unpersist(input.unpersist$default$1());
        }

        final int numTerms = (int) Math.min(vocabCandidates.count(), vocSize);
        Object[] topTerms = (Object[]) vocabCandidates.top(numTerms, scala.package$.MODULE$.Ordering().by(entry -> {
            return BoxesRunTime.boxToLong(((Tuple2) entry)._2$mcJ$sp());
        }, Ordering$Long$.MODULE$));
        String[] vocabulary = new String[topTerms.length];
        for (int i = 0; i < topTerms.length; i++) {
            vocabulary[i] = (String) ((Tuple2) topTerms[i])._1();
        }
        Predef$.MODULE$.require(vocabulary.length > 0, () -> {
            return "The vocabulary size should be > 0. Lower minDF as necessary.";
        });
        return (CountVectorizerModel) copyValues(new CountVectorizerModel(uid(), vocabulary).setParent(this), copyValues$default$2());
    }

    /**
     * Derives the output schema by handing the input schema to
     * {@code validateAndTransformSchema}.
     *
     * @param structType input schema
     * @return the schema returned by {@code validateAndTransformSchema}
     */
    @Override // org.apache.spark.ml.PipelineStage
    public StructType transformSchema(StructType structType) {
        final StructType outputSchema = validateAndTransformSchema(structType);
        return outputSchema;
    }

    /**
     * Produces a copy of this estimator via {@code defaultCopy}.
     *
     * @param paramMap extra params forwarded to {@code defaultCopy}
     * @return the result of {@code defaultCopy}, cast to {@code CountVectorizer}
     */
    @Override // org.apache.spark.ml.Estimator, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public CountVectorizer copy(ParamMap paramMap) {
        final CountVectorizer duplicate = (CountVectorizer) defaultCopy(paramMap);
        return duplicate;
    }

    /**
     * Synthetic bridge that forwards to {@code fit(Dataset<?>)}.
     *
     * <p>NOTE(review): as Java source this overload has the same erasure as
     * {@code fit(Dataset<?>)}; it looks like a bytecode-level bridge surfaced by the
     * decompiler — confirm whether it should exist in source form at all.
     *
     * @param dataset dataset forwarded unchanged
     * @return the model returned by {@code fit(Dataset<?>)}
     */
    @Override // org.apache.spark.ml.Estimator
    public /* bridge */ /* synthetic */ CountVectorizerModel fit(Dataset dataset) {
        final CountVectorizerModel fitted = fit((Dataset<?>) dataset);
        return fitted;
    }

    /**
     * Lambda body lifted out of {@code fit}: calls {@code changeValue$mcJ$sp} on the map
     * for {@code str} with a supplier yielding {@code 1} and a function adding {@code 1}
     * to its argument.
     *
     * <p>Presumably this inserts 1 for an unseen term and increments the stored count
     * otherwise — confirm against {@code OpenHashMap.changeValue}.
     *
     * @param openHashMap map to update
     * @param str key passed to {@code changeValue$mcJ$sp}
     * @return the value returned by {@code changeValue$mcJ$sp}
     */
    public static final /* synthetic */ long $anonfun$fit$4(OpenHashMap openHashMap, String str) {
        return openHashMap.changeValue$mcJ$sp(str, () -> 1L, previous -> previous + 1);
    }

    /**
     * Lambda body lifted out of {@code fit}: tests whether the int in the second slot of
     * the nested tuple lies within the closed interval {@code [d, d2]}.
     *
     * @param d lower bound (inclusive)
     * @param d2 upper bound (inclusive)
     * @param tuple2 pair whose second element is itself a pair; the int component of
     *               that inner pair is the value tested
     * @return {@code true} when the tested value is {@code >= d} and {@code <= d2}
     * @throws MatchError when {@code tuple2} or its second element is {@code null}
     */
    public static final /* synthetic */ boolean $anonfun$fit$9(double d, double d2, Tuple2 tuple2) {
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        final Tuple2 stats = (Tuple2) tuple2._2();
        if (stats == null) {
            throw new MatchError(tuple2);
        }
        final double docFreq = (double) stats._2$mcI$sp();
        if (docFreq >= d) {
            return docFreq <= d2;
        }
        return false;
    }

    /**
     * Creates an estimator with the given identifier and runs the trait initializers.
     *
     * @param str identifier stored as this instance's {@code uid}
     */
    public CountVectorizer(String str) {
        this.uid = str;
        // Creates the "inputCol" param and stores it through the generated setter.
        org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(new Param<>(this, "inputCol", "input column name"));
        // Static trait initializers, run in this fixed order. Presumably they populate the
        // remaining param fields through the generated _setter_ methods — confirm.
        HasOutputCol.$init$((HasOutputCol) this);
        CountVectorizerParams.$init$((CountVectorizerParams) this);
        MLWritable.$init$(this);
        DefaultParamsWritable.$init$((DefaultParamsWritable) this);
    }

    /**
     * Creates an estimator whose identifier is obtained from
     * {@code Identifiable.randomUID} with the prefix {@code "cntVec"}.
     */
    public CountVectorizer() {
        this(Identifiable$.MODULE$.randomUID("cntVec"));
    }
}
