package org.apache.spark.ml.evaluation;

import java.io.IOException;
import java.io.Serializable;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.ParamValidators$;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.ml.util.SchemaUtils$;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.FractionalType;
import org.apache.spark.sql.types.StructType;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.SeqLike;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: RegressionEvaluator.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u0015e\u0001B\n\u0015\u0005}A\u0001\"\u000e\u0001\u0003\u0006\u0004%\tE\u000e\u0005\t\u001b\u0002\u0011\t\u0011)A\u0005o!)q\n\u0001C\u0001!\")q\n\u0001C\u0001+\"9q\u000b\u0001b\u0001\n\u0003A\u0006B\u00020\u0001A\u0003%\u0011\fC\u0003a\u0001\u0011\u0005a\u0007C\u0003c\u0001\u0011\u00051\rC\u0003i\u0001\u0011\u0005\u0011\u000eC\u0003m\u0001\u0011\u0005Q\u000eC\u0003q\u0001\u0011\u0005\u0013\u000fC\u0004\u0002\u001c\u0001!\t%!\b\t\u000f\u0005\u001d\u0002\u0001\"\u0011\u0002*\u001d9\u0011Q\t\u000b\t\u0002\u0005\u001dcAB\n\u0015\u0011\u0003\tI\u0005\u0003\u0004P\u001f\u0011\u0005\u0011Q\f\u0005\b\u0003?zA\u0011IA1\u0011%\tigDA\u0001\n\u0013\tyGA\nSK\u001e\u0014Xm]:j_:,e/\u00197vCR|'O\u0003\u0002\u0016-\u0005QQM^1mk\u0006$\u0018n\u001c8\u000b\u0005]A\u0012AA7m\u0015\tI\"$A\u0003ta\u0006\u00148N\u0003\u0002\u001c9\u00051\u0011\r]1dQ\u0016T\u0011!H\u0001\u0004_J<7\u0001A\n\u0006\u0001\u0001\"Cf\f\t\u0003C\tj\u0011\u0001F\u0005\u0003GQ\u0011\u0011\"\u0012<bYV\fGo\u001c:\u0011\u0005\u0015RS\"\u0001\u0014\u000b\u0005\u001dB\u0013AB:iCJ,GM\u0003\u0002*-\u0005)\u0001/\u0019:b[&\u00111F\n\u0002\u0011\u0011\u0006\u001c\bK]3eS\u000e$\u0018n\u001c8D_2\u0004\"!J\u0017\n\u000592#a\u0003%bg2\u000b'-\u001a7D_2\u0004\"\u0001M\u001a\u000e\u0003ER!A\r\f\u0002\tU$\u0018\u000e\\\u0005\u0003iE\u0012Q\u0003R3gCVdG\u000fU1sC6\u001cxK]5uC\ndW-A\u0002vS\u0012,\u0012a\u000e\t\u0003q\u0005s!!O 
\u0011\u0005ijT\"A\u001e\u000b\u0005qr\u0012A\u0002\u001fs_>$hHC\u0001?\u0003\u0015\u00198-\u00197b\u0013\t\u0001U(\u0001\u0004Qe\u0016$WMZ\u0005\u0003\u0005\u000e\u0013aa\u0015;sS:<'B\u0001!>Q\r\tQi\u0013\t\u0003\r&k\u0011a\u0012\u0006\u0003\u0011b\t!\"\u00198o_R\fG/[8o\u0013\tQuIA\u0003TS:\u001cW-I\u0001M\u0003\u0015\td\u0006\u000e\u00181\u0003\u0011)\u0018\u000e\u001a\u0011)\u0007\t)5*\u0001\u0004=S:LGO\u0010\u000b\u0003#J\u0003\"!\t\u0001\t\u000bU\u001a\u0001\u0019A\u001c)\u0007I+5\nK\u0002\u0004\u000b.#\u0012!\u0015\u0015\u0004\t\u0015[\u0015AC7fiJL7MT1nKV\t\u0011\fE\u0002[7^j\u0011\u0001K\u0005\u00039\"\u0012Q\u0001U1sC6D3!B#L\u0003-iW\r\u001e:jG:\u000bW.\u001a\u0011)\u0007\u0019)5*A\u0007hKRlU\r\u001e:jG:\u000bW.\u001a\u0015\u0004\u000f\u0015[\u0015!D:fi6+GO]5d\u001d\u0006lW\r\u0006\u0002eK6\t\u0001\u0001C\u0003g\u0011\u0001\u0007q'A\u0003wC2,X\rK\u0002\t\u000b.\u000b\u0001c]3u!J,G-[2uS>t7i\u001c7\u0015\u0005\u0011T\u0007\"\u00024\n\u0001\u00049\u0004fA\u0005F\u0017\u0006Y1/\u001a;MC\n,GnQ8m)\t!g\u000eC\u0003g\u0015\u0001\u0007q\u0007K\u0002\u000b\u000b.\u000b\u0001\"\u001a<bYV\fG/\u001a\u000b\u0003eZ\u0004\"a\u001d;\u000e\u0003uJ!!^\u001f\u0003\r\u0011{WO\u00197f\u0011\u001598\u00021\u0001y\u0003\u001d!\u0017\r^1tKR\u00044!_A\u0002!\rQXp`\u0007\u0002w*\u0011A\u0010G\u0001\u0004gFd\u0017B\u0001@|\u0005\u001d!\u0015\r^1tKR\u0004B!!\u0001\u0002\u00041\u0001AaCA\u0003m\u0006\u0005\t\u0011!B\u0001\u0003\u000f\u00111a\u0018\u00132#\u0011\tI!a\u0004\u0011\u0007M\fY!C\u0002\u0002\u000eu\u0012qAT8uQ&tw\rE\u0002t\u0003#I1!a\u0005>\u0005\r\te.\u001f\u0015\u0005\u0017\u0015\u000b9\"\t\u0002\u0002\u001a\u0005)!G\f\u0019/a\u0005q\u0011n\u001d'be\u001e,'OQ3ui\u0016\u0014XCAA\u0010!\r\u0019\u0018\u0011E\u0005\u0004\u0003Gi$a\u0002\"p_2,\u0017M\u001c\u0015\u0004\u0019\u0015[\u0015\u0001B2paf$2!UA\u0016\u0011\u001d\ti#\u0004a\u0001\u0003_\tQ!\u001a=ue\u0006\u00042AWA\u0019\u0013\r\t\u0019\u0004\u000b\u0002\t!\u0006\u0014\u0018-\\'ba\"\"Q\"RA\u001cC\t\tI$A\u00032]Ur\u0003\u000
7K\u0002\u0001\u0003{\u00012ARA \u0013\r\t\te\u0012\u0002\r\u000bb\u0004XM]5nK:$\u0018\r\u001c\u0015\u0004\u0001\u0015[\u0015a\u0005*fOJ,7o]5p]\u00163\u0018\r\\;bi>\u0014\bCA\u0011\u0010'\u001dy\u00111JA)\u0003/\u00022a]A'\u0013\r\ty%\u0010\u0002\u0007\u0003:L(+\u001a4\u0011\tA\n\u0019&U\u0005\u0004\u0003+\n$!\u0006#fM\u0006,H\u000e\u001e)be\u0006l7OU3bI\u0006\u0014G.\u001a\t\u0004g\u0006e\u0013bAA.{\ta1+\u001a:jC2L'0\u00192mKR\u0011\u0011qI\u0001\u0005Y>\fG\rF\u0002R\u0003GBa!!\u001a\u0012\u0001\u00049\u0014\u0001\u00029bi\"DC!E#\u0002j\u0005\u0012\u00111N\u0001\u0006c92d\u0006M\u0001\fe\u0016\fGMU3t_24X\r\u0006\u0002\u0002rA!\u00111OA?\u001b\t\t)H\u0003\u0003\u0002x\u0005e\u0014\u0001\u00027b]\u001eT!!a\u001f\u0002\t)\fg/Y\u0005\u0005\u0003\u007f\n)H\u0001\u0004PE*,7\r\u001e\u0015\u0005\u001f\u0015\u000bI\u0007\u000b\u0003\u000f\u000b\u0006%\u0004")
@Experimental
/* loaded from: input_file:org/apache/spark/ml/evaluation/RegressionEvaluator.class */
public final class RegressionEvaluator extends Evaluator implements HasPredictionCol, HasLabelCol, DefaultParamsWritable {
    private final String uid;
    private final Param<String> metricName;
    private final Param<String> labelCol;
    private final Param<String> predictionCol;

    public static RegressionEvaluator load(String str) {
        return RegressionEvaluator$.MODULE$.load(str);
    }

    public static MLReader<RegressionEvaluator> read() {
        return RegressionEvaluator$.MODULE$.read();
    }

    @Override // org.apache.spark.ml.util.DefaultParamsWritable, org.apache.spark.ml.util.MLWritable
    public MLWriter write() {
        MLWriter write;
        write = write();
        return write;
    }

    @Override // org.apache.spark.ml.util.MLWritable
    public void save(String str) throws IOException {
        save(str);
    }

    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final String getLabelCol() {
        String labelCol;
        labelCol = getLabelCol();
        return labelCol;
    }

    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final String getPredictionCol() {
        String predictionCol;
        predictionCol = getPredictionCol();
        return predictionCol;
    }

    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    @Override // org.apache.spark.ml.param.shared.HasLabelCol
    public final void org$apache$spark$ml$param$shared$HasLabelCol$_setter_$labelCol_$eq(Param<String> param) {
        this.labelCol = param;
    }

    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final Param<String> predictionCol() {
        return this.predictionCol;
    }

    @Override // org.apache.spark.ml.param.shared.HasPredictionCol
    public final void org$apache$spark$ml$param$shared$HasPredictionCol$_setter_$predictionCol_$eq(Param<String> param) {
        this.predictionCol = param;
    }

    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return this.uid;
    }

    public Param<String> metricName() {
        return this.metricName;
    }

    public String getMetricName() {
        return (String) $(metricName());
    }

    public RegressionEvaluator setMetricName(String str) {
        return (RegressionEvaluator) set((Param<Param<String>>) metricName(), (Param<String>) str);
    }

    public RegressionEvaluator setPredictionCol(String str) {
        return (RegressionEvaluator) set((Param<Param<String>>) predictionCol(), (Param<String>) str);
    }

    public RegressionEvaluator setLabelCol(String str) {
        return (RegressionEvaluator) set((Param<Param<String>>) labelCol(), (Param<String>) str);
    }

    @Override // org.apache.spark.ml.evaluation.Evaluator
    public double evaluate(Dataset<?> dataset) {
        double meanAbsoluteError;
        StructType schema = dataset.schema();
        SchemaUtils$.MODULE$.checkColumnTypes(schema, (String) $(predictionCol()), (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new FractionalType[]{DoubleType$.MODULE$, FloatType$.MODULE$})), SchemaUtils$.MODULE$.checkColumnTypes$default$4());
        SchemaUtils$.MODULE$.checkNumericType(schema, (String) $(labelCol()), SchemaUtils$.MODULE$.checkNumericType$default$3());
        RegressionMetrics regressionMetrics = new RegressionMetrics((RDD<Tuple2<Object, Object>>) dataset.select(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col((String) $(predictionCol())).cast(DoubleType$.MODULE$), functions$.MODULE$.col((String) $(labelCol())).cast(DoubleType$.MODULE$)})).rdd().map(row -> {
            Some unapplySeq = Row$.MODULE$.unapplySeq(row);
            if (!unapplySeq.isEmpty() && unapplySeq.get() != null && ((SeqLike) unapplySeq.get()).lengthCompare(2) == 0) {
                Object apply = ((SeqLike) unapplySeq.get()).apply(0);
                Object apply2 = ((SeqLike) unapplySeq.get()).apply(1);
                if (apply instanceof Double) {
                    double unboxToDouble = BoxesRunTime.unboxToDouble(apply);
                    if (apply2 instanceof Double) {
                        return new Tuple2.mcDD.sp(unboxToDouble, BoxesRunTime.unboxToDouble(apply2));
                    }
                }
            }
            throw new MatchError(row);
        }, ClassTag$.MODULE$.apply(Tuple2.class)));
        String str = (String) $(metricName());
        if ("rmse".equals(str)) {
            meanAbsoluteError = regressionMetrics.rootMeanSquaredError();
        } else if ("mse".equals(str)) {
            meanAbsoluteError = regressionMetrics.meanSquaredError();
        } else if ("r2".equals(str)) {
            meanAbsoluteError = regressionMetrics.r2();
        } else {
            if (!"mae".equals(str)) {
                throw new MatchError(str);
            }
            meanAbsoluteError = regressionMetrics.meanAbsoluteError();
        }
        return meanAbsoluteError;
    }

    @Override // org.apache.spark.ml.evaluation.Evaluator
    public boolean isLargerBetter() {
        boolean z;
        String str = (String) $(metricName());
        if ("rmse".equals(str)) {
            z = false;
        } else if ("mse".equals(str)) {
            z = false;
        } else if ("r2".equals(str)) {
            z = true;
        } else {
            if (!"mae".equals(str)) {
                throw new MatchError(str);
            }
            z = false;
        }
        return z;
    }

    @Override // org.apache.spark.ml.evaluation.Evaluator, org.apache.spark.ml.param.Params
    public RegressionEvaluator copy(ParamMap paramMap) {
        return (RegressionEvaluator) defaultCopy(paramMap);
    }

    public RegressionEvaluator(String str) {
        this.uid = str;
        HasPredictionCol.$init$((HasPredictionCol) this);
        HasLabelCol.$init$((HasLabelCol) this);
        MLWritable.$init$(this);
        DefaultParamsWritable.$init$((DefaultParamsWritable) this);
        this.metricName = new Param<>(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", ParamValidators$.MODULE$.inArray(new String[]{"mse", "rmse", "r2", "mae"}));
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{metricName().$minus$greater("rmse")}));
    }

    public RegressionEvaluator() {
        this(Identifiable$.MODULE$.randomUID("regEval"));
    }
}
