package org.graphframes.examples;

import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.spark.graphx.Edge;
import org.apache.spark.graphx.EdgeContext;
import org.apache.spark.graphx.Graph;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.graphframes.GraphFrame;
import org.graphframes.GraphFrame$;
import org.graphframes.examples.BeliefPropagation;
import org.graphframes.lib.AggregateMessages$;
import scala.Function2;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Predef$$eq$colon$eq$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.runtime.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;

/* compiled from: BeliefPropagation.scala */
/* loaded from: input_file:org/graphframes/examples/BeliefPropagation$.class */
/**
 * Decompiled Java view of the Scala singleton {@code object BeliefPropagation} (module class
 * {@code BeliefPropagation$}): an example that runs belief propagation (BP) on a GraphFrame.
 *
 * <p>Two implementations of the same per-color sweep schedule are provided:
 * {@link #runBPwithGraphX} (GraphX {@code aggregateMessages} + {@code outerJoinVertices}) and
 * {@link #runBPwithGraphFrames} (GraphFrames {@code AggregateMessages} + DataFrame joins).
 * Both set a vertex's belief to the logistic sigmoid of {@code a + sum(b * neighbourBelief)}.
 *
 * <p>NOTE(review): this is decompiler output, not hand-written source. Several constructs are not
 * valid Java as written and should be read as a rendering of the bytecode only. Examples are the
 * nested lambdas that both name their parameter {@code i2}, the {@code (Predef$.eq.colon.eq) null}
 * casts, and the discarded {@code mapVertices$default$3(...)} calls.
 *
 * <p>NOTE(review): {@code CMAESOptimizer.DEFAULT_STOPFITNESS} appears where the Scala source
 * presumably had a {@code 0.0} literal. The decompiler seems to have matched an unrelated
 * commons-math3 constant with the same value. It is used as the initial belief and as the default
 * for a missing message sum; confirm against the original Scala source.
 */
public final class BeliefPropagation$ {
    /** The singleton instance; assigned by the private constructor, which the static initializer runs. */
    public static BeliefPropagation$ MODULE$;

    static {
        new BeliefPropagation$();
    }

    /**
     * Example entry point. Builds a small grid Ising model, prints its vertices and edges, runs
     * 5 iterations of GraphX-based belief propagation, prints the final {@code id}/{@code belief}
     * columns, and stops the Spark session.
     *
     * @param strArr command-line arguments (unused)
     */
    public void main(String[] strArr) {
        SparkSession orCreate = SparkSession$.MODULE$.builder().appName("BeliefPropagation example").getOrCreate();
        // 3 is passed to gridIsingModel; presumably the grid side length (3 x 3) -- confirm against Graphs.
        GraphFrame gridIsingModel = Graphs$.MODULE$.gridIsingModel(orCreate.sqlContext(), 3);
        Predef$.MODULE$.println("Original Ising model:");
        gridIsingModel.vertices().show();
        gridIsingModel.edges().show();
        // 5 = number of BP iterations; the same literal is repeated in the message below.
        Dataset select = runBPwithGraphX(gridIsingModel, 5).vertices().select("id", Predef$.MODULE$.wrapRefArray(new String[]{"belief"}));
        Predef$.MODULE$.println(new StringBuilder(46).append("Done with BP. Final beliefs after ").append(5).append(" iterations:").toString());
        select.show();
        orCreate.stop();
    }

    /**
     * Returns a copy of {@code graphFrame} whose vertices have an extra integer {@code color}
     * column equal to {@code (i + j) % 2}. The value is computed by a UDF over the vertex columns
     * {@code i} and {@code j}. Edges are passed through unchanged.
     *
     * <p>Assumes (not checked here) that vertices carry integer grid coordinates {@code i}/{@code j}
     * and that edges join grid neighbours, so adjacent vertices receive different colors.
     */
    private GraphFrame colorGraph(GraphFrame graphFrame) {
        return GraphFrame$.MODULE$.apply(graphFrame.vertices().withColumn("color", functions$.MODULE$.udf((i, i2) -> {
            return (i + i2) % 2;
        }, package$.MODULE$.universe().TypeTag().Int(), package$.MODULE$.universe().TypeTag().Int(), package$.MODULE$.universe().TypeTag().Int()).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col("i"), functions$.MODULE$.col("j")}))), graphFrame.edges());
    }

    /**
     * Runs belief propagation using GraphX.
     *
     * <p>The method first colors the graph with {@link #colorGraph} and converts it to GraphX.
     * Each vertex row becomes {@code VertexAttr(a, initialBelief, color)} and each edge row becomes
     * {@code EdgeAttr(b)}; columns {@code a}, {@code color} and {@code b} are read by index.
     *
     * <p>It then makes {@code i} sweeps over the colors {@code 0 .. numColors - 1}. For each color:
     * <ul>
     *   <li>every vertex of that color sums the non-zero messages {@code b * neighbourBelief} from
     *       its neighbours;</li>
     *   <li>it sets its belief to {@code exp(-log1pExp(-(a + sum)))}, which is the logistic
     *       sigmoid of {@code a + sum};</li>
     *   <li>vertices of other colors are unchanged during that sub-step.</li>
     * </ul>
     *
     * @param graphFrame input graph; vertices need columns {@code i}, {@code j} (used for coloring)
     *     and a double {@code a}; edges need a double {@code b}
     * @param i number of full sweeps over all colors
     * @return a GraphFrame rebuilt from the colored input graph and the final GraphX graph, whose
     *     vertex attribute (the belief) is exposed under the column name {@code belief}; edge
     *     attributes are mapped to Unit. (Exact column merging is up to
     *     {@code GraphFrame.fromGraphX}; confirm.)
     */
    public GraphFrame runBPwithGraphX(GraphFrame graphFrame, int i) {
        GraphFrame colorGraph = colorGraph(graphFrame);
        // Number of distinct colors. The color loop below assumes the colors are exactly 0 .. count - 1.
        int count = (int) colorGraph.vertices().select("color", Predef$.MODULE$.wrapRefArray(new String[0])).distinct().count();
        Graph<Row, Row> graphX = colorGraph.toGraphX();
        // Column-name -> column-index maps, used to read Row fields by position.
        Map<String, Object> vertexColumnMap = colorGraph.vertexColumnMap();
        Map<String, Object> edgeColumnMap = colorGraph.edgeColumnMap();
        Function2 function2 = (obj, row) -> {
            return $anonfun$runBPwithGraphX$1(vertexColumnMap, BoxesRunTime.unboxToLong(obj), row);
        };
        ClassTag apply = ClassTag$.MODULE$.apply(BeliefPropagation.VertexAttr.class);
        // Decompiler artifact: this appears to evaluate the default of mapVertices' implicit =:= evidence
        // parameter and discard it; null is passed explicitly in the call below.
        graphX.mapVertices$default$3(function2);
        // ObjectRef is a mutable box (Scala `var` captured by closures), so the loop lambdas can replace the graph.
        ObjectRef create = ObjectRef.create(graphX.mapVertices(function2, apply, (Predef$.eq.colon.eq) null).mapEdges(edge -> {
            return new BeliefPropagation.EdgeAttr(((Row) edge.attr()).getDouble(BoxesRunTime.unboxToInt(edgeColumnMap.apply("b"))));
        }, ClassTag$.MODULE$.apply(BeliefPropagation.EdgeAttr.class)));
        // Outer loop: i sweeps (loop variable unused). Inner loop: one sub-step per color; the inner i2 is the current color.
        scala.package$.MODULE$.Range().apply(0, i).foreach$mVc$sp(i2 -> {
            scala.package$.MODULE$.Range().apply(0, count).foreach$mVc$sp(i2 -> {
                // Messages are computed from the graph as it was at the start of this sub-step.
                Graph graph = (Graph) create.elem;
                create.elem = ((Graph) create.elem).outerJoinVertices(graph.aggregateMessages(edgeContext -> {
                    $anonfun$runBPwithGraphX$5(i2, edgeContext);
                    return BoxedUnit.UNIT;
                }, (d, d2) -> {
                    // Messages arriving at the same vertex are summed.
                    return d + d2;
                }, graph.aggregateMessages$default$3(), ClassTag$.MODULE$.Double()), (obj2, vertexAttr, option) -> {
                    return $anonfun$runBPwithGraphX$7(i2, BoxesRunTime.unboxToLong(obj2), vertexAttr, option);
                }, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.apply(BeliefPropagation.VertexAttr.class), Predef$$eq$colon$eq$.MODULE$.tpEquals());
            });
        });
        // Project the final graph to vertex belief (Double) and edge Unit, then convert back to a GraphFrame.
        Graph graph = (Graph) create.elem;
        Function2 function22 = (obj2, vertexAttr) -> {
            return BoxesRunTime.boxToDouble($anonfun$runBPwithGraphX$9(BoxesRunTime.unboxToLong(obj2), vertexAttr));
        };
        ClassTag Double = ClassTag$.MODULE$.Double();
        // Decompiler artifact, as above: the evaluated default is discarded and null is passed below.
        graph.mapVertices$default$3(function22);
        return GraphFrame$.MODULE$.fromGraphX(colorGraph, graph.mapVertices(function22, Double, (Predef$.eq.colon.eq) null).mapEdges(edge2 -> {
            $anonfun$runBPwithGraphX$10(edge2);
            return BoxedUnit.UNIT;
        }, ClassTag$.MODULE$.Unit()), (Seq) new $colon.colon("belief", Nil$.MODULE$), GraphFrame$.MODULE$.fromGraphX$default$4(), package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().Unit());
    }

    /**
     * Runs belief propagation using GraphFrames' {@code AggregateMessages} and DataFrame operations.
     * It uses the same sweep schedule and sigmoid update as {@link #runBPwithGraphX}.
     *
     * <p>For each iteration and each color:
     * <ul>
     *   <li>each edge sends {@code b * dst.belief} to its source when the source has the current
     *       color;</li>
     *   <li>each edge sends {@code b * src.belief} to its destination when the destination has the
     *       current color;</li>
     *   <li>messages are summed per vertex into {@code aggMess};</li>
     *   <li>vertices of the current color with a non-null {@code aggMess} get
     *       {@code belief = sigmoid(aggMess + a)}; all others keep their belief.</li>
     * </ul>
     *
     * <p>NOTE(review): the GraphX version recomputes every vertex of the current color, using a
     * message sum of 0 when nothing arrived. Here, a vertex whose aggregated message is null or
     * absent from the left-outer join (e.g. a vertex with no incident edges) keeps its old belief.
     * Confirm whether that divergence is intended.
     *
     * @param graphFrame input graph with vertex columns {@code i}, {@code j}, {@code a} and edge
     *     column {@code b}
     * @param i number of full sweeps over all colors
     * @return the graph with a {@code belief} vertex column and the temporary {@code color} column removed
     */
    public GraphFrame runBPwithGraphFrames(GraphFrame graphFrame, int i) {
        GraphFrame colorGraph = colorGraph(graphFrame);
        // Number of distinct colors. The color loop assumes they are exactly 0 .. count - 1.
        int count = (int) colorGraph.vertices().select("color", Predef$.MODULE$.wrapRefArray(new String[0])).distinct().count();
        // Every vertex starts with the same constant belief (see the class-level note on DEFAULT_STOPFITNESS).
        ObjectRef create = ObjectRef.create(GraphFrame$.MODULE$.apply(colorGraph.vertices().withColumn("belief", functions$.MODULE$.lit(BoxesRunTime.boxToDouble(CMAESOptimizer.DEFAULT_STOPFITNESS))), colorGraph.edges()));
        // Outer loop: i sweeps (loop variable unused). Inner loop: one sub-step per color; the inner i2 is the current color.
        scala.package$.MODULE$.Range().apply(0, i).foreach$mVc$sp(i2 -> {
            scala.package$.MODULE$.Range().apply(0, count).foreach$mVc$sp(i2 -> {
                AggregateMessages$ aggregateMessages$ = AggregateMessages$.MODULE$;
                // when(...) without otherwise() yields null for edges whose endpoint lacks the current
                // color (Spark SQL semantics); sum() ignores those nulls.
                Column when = functions$.MODULE$.when(aggregateMessages$.src().apply("color").$eq$eq$eq(BoxesRunTime.boxToInteger(i2)), aggregateMessages$.edge().apply("b").$times(aggregateMessages$.dst().apply("belief")));
                Column when2 = functions$.MODULE$.when(aggregateMessages$.dst().apply("color").$eq$eq$eq(BoxesRunTime.boxToInteger(i2)), aggregateMessages$.edge().apply("b").$times(aggregateMessages$.src().apply("belief")));
                // Sigmoid via exp(-log1pExp(-d)). The UDF is loop-invariant but is rebuilt on every sub-step.
                UserDefinedFunction udf = functions$.MODULE$.udf(d -> {
                    return scala.math.package$.MODULE$.exp(-MODULE$.log1pExp(-d));
                }, package$.MODULE$.universe().TypeTag().Double(), package$.MODULE$.universe().TypeTag().Double());
                Dataset<Row> agg = ((GraphFrame) create.elem).aggregateMessages().sendToSrc(when).sendToDst(when2).agg(functions$.MODULE$.sum(aggregateMessages$.msg()).as("aggMess"));
                Dataset<Row> vertices = ((GraphFrame) create.elem).vertices();
                // Steps in the expression below:
                //  - left-join the per-vertex message sums onto the vertices;
                //  - compute newBelief for vertices of the current color that received a message;
                //  - swap newBelief in for belief.
                // getCachedDataFrame presumably caches the result so the query lineage does not keep
                // growing across iterations; confirm in AggregateMessages.
                create.elem = GraphFrame$.MODULE$.apply(aggregateMessages$.getCachedDataFrame(vertices.join(agg, vertices.apply("id").$eq$eq$eq(agg.apply("id")), "left_outer").drop(agg.apply("id")).withColumn("newBelief", functions$.MODULE$.when(vertices.apply("color").$eq$eq$eq(BoxesRunTime.boxToInteger(i2)).$amp$amp(agg.apply("aggMess").isNotNull()), udf.apply(Predef$.MODULE$.wrapRefArray(new Column[]{agg.apply("aggMess").$plus(vertices.apply("a"))}))).otherwise(vertices.apply("belief"))).drop("aggMess").drop("belief").withColumnRenamed("newBelief", "belief")), ((GraphFrame) create.elem).edges());
            });
        });
        // Remove the helper color column added by colorGraph.
        return GraphFrame$.MODULE$.apply(((GraphFrame) create.elem).vertices().drop("color"), ((GraphFrame) create.elem).edges());
    }

    /**
     * Computes {@code log(1 + exp(d))} in a numerically stable way.
     *
     * <p>For positive {@code d} it uses the identity {@code d + log1p(exp(-d))}. As a result,
     * {@code exp} is only ever applied to a non-positive argument and cannot overflow.
     */
    private double log1pExp(double d) {
        return d > ((double) 0) ? d + scala.math.package$.MODULE$.log1p(scala.math.package$.MODULE$.exp(-d)) : scala.math.package$.MODULE$.log1p(scala.math.package$.MODULE$.exp(d));
    }

    /**
     * Body of the vertex-initialisation closure in {@link #runBPwithGraphX}. Converts a vertex
     * {@code Row} into a {@code VertexAttr} holding column {@code a}, the initial belief constant
     * and column {@code color}; column positions come from {@code map}. The vertex id {@code j} is
     * not used.
     */
    public static final /* synthetic */ BeliefPropagation.VertexAttr $anonfun$runBPwithGraphX$1(Map map, long j, Row row) {
        // Decompiled Scala pattern match on (id, row); the tuple was just built, so the MatchError branch cannot fire.
        Tuple2 tuple2 = new Tuple2(BoxesRunTime.boxToLong(j), row);
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        Row row2 = (Row) tuple2._2();
        return new BeliefPropagation.VertexAttr(row2.getDouble(BoxesRunTime.unboxToInt(map.apply("a"))), CMAESOptimizer.DEFAULT_STOPFITNESS, row2.getInt(BoxesRunTime.unboxToInt(map.apply("color"))));
    }

    /**
     * Message-sending closure body for {@link #runBPwithGraphX}; edges are treated as undirected.
     *
     * <ul>
     *   <li>If the destination has color {@code i}, sends {@code b * srcBelief} to the destination.</li>
     *   <li>Otherwise, if the source has color {@code i}, sends {@code b * dstBelief} to the source.</li>
     * </ul>
     *
     * <p>Zero-valued messages are not sent.
     */
    public static final /* synthetic */ void $anonfun$runBPwithGraphX$5(int i, EdgeContext edgeContext) {
        if (((BeliefPropagation.VertexAttr) edgeContext.dstAttr()).color() == i) {
            double b = ((BeliefPropagation.EdgeAttr) edgeContext.attr()).b() * ((BeliefPropagation.VertexAttr) edgeContext.srcAttr()).belief();
            if (b != 0) {
                edgeContext.sendToDst(BoxesRunTime.boxToDouble(b));
                return;
            }
            return;
        }
        if (((BeliefPropagation.VertexAttr) edgeContext.srcAttr()).color() == i) {
            double b2 = ((BeliefPropagation.EdgeAttr) edgeContext.attr()).b() * ((BeliefPropagation.VertexAttr) edgeContext.dstAttr()).belief();
            if (b2 != 0) {
                edgeContext.sendToSrc(BoxesRunTime.boxToDouble(b2));
            }
        }
    }

    /**
     * Vertex-update closure body for {@link #runBPwithGraphX}.
     *
     * <p>For a vertex of color {@code i}, returns a new {@code VertexAttr} with the same {@code a},
     * belief {@code sigmoid(a + msgSum)} and color {@code i}. The message sum falls back to the
     * initial belief constant when {@code option} is empty, i.e. the vertex received no message.
     * Vertices of other colors are returned unchanged.
     */
    public static final /* synthetic */ BeliefPropagation.VertexAttr $anonfun$runBPwithGraphX$7(int i, long j, BeliefPropagation.VertexAttr vertexAttr, Option option) {
        BeliefPropagation.VertexAttr vertexAttr2;
        // Decompiled Scala pattern match on (id, attr, optMsg); the null check cannot fire.
        Tuple3 tuple3 = new Tuple3(BoxesRunTime.boxToLong(j), vertexAttr, option);
        if (tuple3 == null) {
            throw new MatchError(tuple3);
        }
        BeliefPropagation.VertexAttr vertexAttr3 = (BeliefPropagation.VertexAttr) tuple3._2();
        Option option2 = (Option) tuple3._3();
        if (vertexAttr3.color() == i) {
            // exp(-log1pExp(-x)) == 1 / (1 + exp(-x)): the logistic sigmoid of x = a + msgSum.
            vertexAttr2 = new BeliefPropagation.VertexAttr(vertexAttr3.a(), scala.math.package$.MODULE$.exp(-MODULE$.log1pExp(-(vertexAttr3.a() + BoxesRunTime.unboxToDouble(option2.getOrElse(() -> {
                return CMAESOptimizer.DEFAULT_STOPFITNESS;
            }))))), i);
        } else {
            vertexAttr2 = vertexAttr3;
        }
        return vertexAttr2;
    }

    /** Final vertex projection for {@link #runBPwithGraphX}: keeps only the belief; the id {@code j} is unused. */
    public static final /* synthetic */ double $anonfun$runBPwithGraphX$9(long j, BeliefPropagation.VertexAttr vertexAttr) {
        return vertexAttr.belief();
    }

    /** Final edge projection for {@link #runBPwithGraphX}: intentionally a no-op; the caller maps every edge to Unit. */
    public static final /* synthetic */ void $anonfun$runBPwithGraphX$10(Edge edge) {
    }

    /** Private constructor; registers this instance as the singleton {@link #MODULE$}. */
    private BeliefPropagation$() {
        MODULE$ = this;
    }
}
