package org.apache.spark.ml.optim.aggregator;

import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.Vector;
import scala.MatchError;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ScalaSignature;
import scala.runtime.DoubleRef;
import scala.runtime.java8.JFunction2$mcVID$sp;

/* compiled from: HuberAggregator.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00054Qa\u0004\t\u0001)qA\u0001B\f\u0001\u0003\u0002\u0003\u0006I\u0001\r\u0005\tg\u0001\u0011\t\u0011)A\u0005i!Aq\u0007\u0001B\u0001B\u0003%\u0001\b\u0003\u0005B\u0001\t\u0005\t\u0015!\u0003C\u0011\u0015I\u0005\u0001\"\u0001K\u0011\u001d\u0001\u0006A1A\u0005RECa!\u0016\u0001!\u0002\u0013\u0011\u0006b\u0002,\u0001\u0005\u0004%I!\u0015\u0005\u0007/\u0002\u0001\u000b\u0011\u0002*\t\u000fa\u0003!\u0019!C\u00053\"1!\f\u0001Q\u0001\nQBqa\u0017\u0001C\u0002\u0013%\u0011\f\u0003\u0004]\u0001\u0001\u0006I\u0001\u000e\u0005\u0006;\u0002!\tA\u0018\u0002\u0010\u0011V\u0014WM]!hOJ,w-\u0019;pe*\u0011\u0011CE\u0001\u000bC\u001e<'/Z4bi>\u0014(BA\n\u0015\u0003\u0015y\u0007\u000f^5n\u0015\t)b#\u0001\u0002nY*\u0011q\u0003G\u0001\u0006gB\f'o\u001b\u0006\u00033i\ta!\u00199bG\",'\"A\u000e\u0002\u0007=\u0014xmE\u0002\u0001;\r\u0002\"AH\u0011\u000e\u0003}Q\u0011\u0001I\u0001\u0006g\u000e\fG.Y\u0005\u0003E}\u0011a!\u00118z%\u00164\u0007\u0003\u0002\u0013&O5j\u0011\u0001E\u0005\u0003MA\u0011A\u0004R5gM\u0016\u0014XM\u001c;jC\ndW\rT8tg\u0006;wM]3hCR|'\u000f\u0005\u0002)W5\t\u0011F\u0003\u0002+)\u00059a-Z1ukJ,\u0017B\u0001\u0017*\u0005!Ien\u001d;b]\u000e,\u0007C\u0001\u0013\u0001\u000311\u0017\u000e^%oi\u0016\u00148-\u001a9u\u0007\u0001\u0001\"AH\u0019\n\u0005Iz\"a\u0002\"p_2,\u0017M\\\u0001\bKB\u001c\u0018\u000e\\8o!\tqR'\u0003\u00027?\t1Ai\\;cY\u0016\fQBY2GK\u0006$XO]3t'R$\u0007cA\u001d=}5\t!H\u0003\u0002<-\u0005I!M]8bI\u000e\f7\u000f^\u0005\u0003{i\u0012\u0011B\u0011:pC\u0012\u001c\u0017m\u001d;\u0011\u0007yyD'\u0003\u0002A?\t)\u0011I\u001d:bs\u0006a!m\u0019)be\u0006lW\r^3sgB\u0019\u0011\bP\"\u0011\u0005\u0011;U\"A#\u000b\u0005\u0019#\u0012A\u00027j]\u0006dw-\u0003\u0002I\u000b\n1a+Z2u_J\fa\u0001P5oSRtD\u0003B&N\u001d>#\"!\f'\t\u000b\u0005+\u0001\u0019\u0001\"\t\u000b9*\u0001\u0019\u0001\u0019\t\u000bM*\u0001\u0019\u0001\u001b\t\u000b]*\u0001\u0019\u0001\u001d\u0002\u0007\u0011LW.F\u0001S!\tq2+\u0003\u0002U?\t\u0019\u0011J\u001c;\u0002\t\u0011LW\u000eI
\u0001\f]Vlg)Z1ukJ,7/\u0001\u0007ok64U-\u0019;ve\u0016\u001c\b%A\u0003tS\u001el\u0017-F\u00015\u0003\u0019\u0019\u0018nZ7bA\u0005I\u0011N\u001c;fe\u000e,\u0007\u000f^\u0001\u000bS:$XM]2faR\u0004\u0013aA1eIR\u0011Qf\u0018\u0005\u0006A:\u0001\raJ\u0001\tS:\u001cH/\u00198dK\u0002")
/* loaded from: input_file:org/apache/spark/ml/optim/aggregator/HuberAggregator.class */
public class HuberAggregator implements DifferentiableLossAggregator<Instance, HuberAggregator> {
    private final boolean fitIntercept;
    private final double epsilon;
    private final Broadcast<double[]> bcFeaturesStd;
    private final Broadcast<Vector> bcParameters;
    private final int dim;
    private final int numFeatures;
    private final double sigma;
    private final double intercept;
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    private volatile boolean bitmap$0;

    /* JADX WARN: Type inference failed for: r0v1, types: [org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator, org.apache.spark.ml.optim.aggregator.HuberAggregator] */
    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public HuberAggregator merge(HuberAggregator huberAggregator) {
        ?? merge;
        merge = merge(huberAggregator);
        return merge;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public Vector gradient() {
        Vector gradient;
        gradient = gradient();
        return gradient;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double weight() {
        double weight;
        weight = weight();
        return weight;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double loss() {
        double loss;
        loss = loss();
        return loss;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double weightSum() {
        return this.weightSum;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public void weightSum_$eq(double d) {
        this.weightSum = d;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double lossSum() {
        return this.lossSum;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public void lossSum_$eq(double d) {
        this.lossSum = d;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v8, types: [org.apache.spark.ml.optim.aggregator.HuberAggregator] */
    private double[] gradientSumArray$lzycompute() {
        double[] gradientSumArray;
        ?? r0 = this;
        synchronized (r0) {
            if (!this.bitmap$0) {
                gradientSumArray = gradientSumArray();
                this.gradientSumArray = gradientSumArray;
                r0 = this;
                r0.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public int dim() {
        return this.dim;
    }

    private int numFeatures() {
        return this.numFeatures;
    }

    private double sigma() {
        return this.sigma;
    }

    private double intercept() {
        return this.intercept;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public HuberAggregator add(Instance instance) {
        if (instance == null) {
            throw new MatchError(instance);
        }
        double label = instance.label();
        double weight = instance.weight();
        Vector features = instance.features();
        Predef$.MODULE$.require(numFeatures() == features.size(), () -> {
            return new StringBuilder(43).append("Dimensions mismatch when adding new sample.").append(new StringBuilder(21).append(" Expecting ").append(this.numFeatures()).append(" but got ").append(features.size()).append(".").toString()).toString();
        });
        Predef$.MODULE$.require(weight >= 0.0d, () -> {
            return new StringBuilder(34).append("instance weight, ").append(weight).append(" has to be >= 0.0").toString();
        });
        if (weight == 0.0d) {
            return this;
        }
        double[] dArr = (double[]) this.bcFeaturesStd.value();
        double[] dArr2 = (double[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(((Vector) this.bcParameters.value()).toArray())).slice(0, numFeatures());
        double[] gradientSumArray = gradientSumArray();
        DoubleRef create = DoubleRef.create(0.0d);
        features.foreachActive((i, d) -> {
            if (dArr[i] == 0.0d || d == 0.0d) {
                return;
            }
            create.elem += dArr2[i] * (d / dArr[i]);
        });
        if (this.fitIntercept) {
            create.elem += intercept();
        }
        double d2 = label - create.elem;
        if (package$.MODULE$.abs(d2) <= sigma() * this.epsilon) {
            lossSum_$eq(lossSum() + (0.5d * weight * (sigma() + (package$.MODULE$.pow(d2, 2.0d) / sigma()))));
            double sigma = d2 / sigma();
            features.foreachActive((i2, d3) -> {
                if (dArr[i2] == 0.0d || d3 == 0.0d) {
                    return;
                }
                gradientSumArray[i2] = gradientSumArray[i2] + ((-1.0d) * weight * sigma * (d3 / dArr[i2]));
            });
            if (this.fitIntercept) {
                int dim = dim() - 2;
                gradientSumArray[dim] = gradientSumArray[dim] + ((-1.0d) * weight * sigma);
            }
            int dim2 = dim() - 1;
            gradientSumArray[dim2] = gradientSumArray[dim2] + (0.5d * weight * (1.0d - package$.MODULE$.pow(sigma, 2.0d)));
        } else {
            double d4 = d2 >= ((double) 0) ? -1.0d : 1.0d;
            lossSum_$eq(lossSum() + (0.5d * weight * ((sigma() + ((2.0d * this.epsilon) * package$.MODULE$.abs(d2))) - ((sigma() * this.epsilon) * this.epsilon))));
            features.foreachActive((i3, d5) -> {
                if (dArr[i3] == 0.0d || d5 == 0.0d) {
                    return;
                }
                gradientSumArray[i3] = gradientSumArray[i3] + (weight * d4 * this.epsilon * (d5 / dArr[i3]));
            });
            if (this.fitIntercept) {
                int dim3 = dim() - 2;
                gradientSumArray[dim3] = gradientSumArray[dim3] + (weight * d4 * this.epsilon);
            }
            int dim4 = dim() - 1;
            gradientSumArray[dim4] = gradientSumArray[dim4] + (0.5d * weight * (1.0d - (this.epsilon * this.epsilon)));
        }
        weightSum_$eq(weightSum() + weight);
        return this;
    }

    public HuberAggregator(boolean z, double d, Broadcast<double[]> broadcast, Broadcast<Vector> broadcast2) {
        this.fitIntercept = z;
        this.epsilon = d;
        this.bcFeaturesStd = broadcast;
        this.bcParameters = broadcast2;
        DifferentiableLossAggregator.$init$(this);
        this.dim = ((Vector) broadcast2.value()).size();
        this.numFeatures = z ? dim() - 2 : dim() - 1;
        this.sigma = ((Vector) broadcast2.value()).apply(dim() - 1);
        this.intercept = z ? ((Vector) broadcast2.value()).apply(dim() - 2) : 0.0d;
    }
}
