package org.psjava.algo.graph;

import org.psjava.algo.graph.dfs.DFSVisitorBase;
import org.psjava.algo.graph.dfs.SingleSourceDFS;
import org.psjava.ds.array.DynamicArray;
import org.psjava.ds.graph.DirectedWeightedEdge;
import org.psjava.ds.graph.RootedTree;
import org.psjava.ds.map.MutableMap;
import org.psjava.ds.map.MutableMapFactory;
import org.psjava.ds.math.BinaryOperator;
import org.psjava.ds.numbersystrem.AddInvert;
import org.psjava.ds.numbersystrem.AddableNumberSystem;
import org.psjava.ds.tree.segmenttree.SegmentTree;
import org.psjava.ds.tree.segmenttree.SegmentTreeFactory;
import org.psjava.util.VisitorStopper;

/* loaded from: input_file:psjava-0.1.19.jar:org/psjava/algo/graph/DistanceCalculatorInRootedTree.class */
/**
 * Answers weighted path-length queries between pairs of vertices of a rooted tree, and allows the
 * weight of a single tree edge to be changed after preprocessing.
 *
 * <p>A depth-first traversal from the root flattens the tree into a sequence of weights. Walking down
 * an edge appends its weight. Walking back up the same edge appends the value produced by
 * {@code AddInvert} for that weight (presumably its additive inverse). The sequence is stored in a
 * segment tree whose combine operation is the number system's addition. A lowest-common-ancestor
 * session splits every distance query into two ancestor-to-descendant range sums.
 */
public class DistanceCalculatorInRootedTree {
    private final LowestCommonAncestorAlgorithm lca;
    private final SegmentTreeFactory segmentTreeFactory;
    private final MutableMapFactory mapFactory;

    /**
     * @param lowestCommonAncestorAlgorithm builds the LCA query session used to split distance queries
     * @param segmentTreeFactory creates the segment tree over the flattened weight sequence
     * @param mutableMapFactory creates the vertex-to-position lookup maps
     */
    public DistanceCalculatorInRootedTree(LowestCommonAncestorAlgorithm lowestCommonAncestorAlgorithm, SegmentTreeFactory segmentTreeFactory, MutableMapFactory mutableMapFactory) {
        this.lca = lowestCommonAncestorAlgorithm;
        this.segmentTreeFactory = segmentTreeFactory;
        this.mapFactory = mutableMapFactory;
    }

    /**
     * Preprocesses {@code rootedTree} so that distances between its vertices can be queried and edge
     * weights can be modified.
     *
     * @param rootedTree tree whose {@code graph} is traversed depth-first starting at {@code root}
     * @param addableNumberSystem supplies zero and addition for edge weights; {@code AddInvert} uses it
     *     to produce the walk-up entries
     * @return an object answering {@code getDistance} / {@code modifyDistance} for this tree
     */
    public <V, W> DistanceCalculatorInRootedTreeResult<V, W> calc(RootedTree<V, DirectedWeightedEdge<V, W>> rootedTree, final AddableNumberSystem<W> addableNumberSystem) {
        // Flattened weight sequence: weight on walk-down, AddInvert(weight) on walk-up.
        final DynamicArray<W> weights = DynamicArray.create();
        // Size of `weights` at the moment each vertex is discovered.
        final MutableMap<V, Integer> discoverIndex = this.mapFactory.create();
        // Position of the walk-down entry of the edge leading into each non-root vertex.
        final MutableMap<V, Integer> downIndex = this.mapFactory.create();
        // Position of the walk-up entry of the edge leading into each non-root vertex.
        final MutableMap<V, Integer> upIndex = this.mapFactory.create();

        SingleSourceDFS.traverse(rootedTree.graph, rootedTree.root, new DFSVisitorBase<V, DirectedWeightedEdge<V, W>>() {
            @Override
            public void onDiscovered(V vertex, int depth, VisitorStopper stopper) {
                // NOTE(review): assumes onWalkDown(edge into vertex) fires before onDiscovered(vertex),
                // so this index lies just past the vertex's own walk-down entry — confirm in DFSCore.
                discoverIndex.add(vertex, weights.size());
            }

            @Override
            public void onWalkDown(DirectedWeightedEdge<V, W> edge) {
                downIndex.add(edge.to(), weights.size());
                weights.addToLast(edge.weight());
            }

            @Override
            public void onWalkUp(DirectedWeightedEdge<V, W> edge) {
                upIndex.add(edge.to(), weights.size());
                weights.addToLast(AddInvert.calc(addableNumberSystem, edge.weight()));
            }
        });

        final SegmentTree<W> sums = this.segmentTreeFactory.create(weights, createAdder(addableNumberSystem));
        final LowestCommonAncestorQuerySession<V> lcaSession = this.lca.calc(rootedTree);

        return new DistanceCalculatorInRootedTreeResult<V, W>() {
            /**
             * Returns the sum of edge weights on the tree path between {@code v} and {@code v2}.
             *
             * <p>Each half of the path is the segment-tree sum between the discovery positions of the
             * LCA and the endpoint. Fully explored subtrees in between contribute an entry and its
             * inverse. NOTE(review): this relies on {@code SegmentTree.query(start, end)} summing the
             * half-open range [start, end) — confirm.
             */
            @Override
            public W getDistance(V v, V v2) {
                V ancestor = lcaSession.query(v, v2);
                W distance = addableNumberSystem.getZero();
                // Skip the query when an endpoint is the LCA itself: the range would be empty.
                if (!v.equals(ancestor)) {
                    distance = addableNumberSystem.add(distance, sums.query(discoverIndex.get(ancestor), discoverIndex.get(v)));
                }
                if (!v2.equals(ancestor)) {
                    distance = addableNumberSystem.add(distance, sums.query(discoverIndex.get(ancestor), discoverIndex.get(v2)));
                }
                return distance;
            }

            /**
             * Sets the weight of the edge between {@code v} and {@code v2} to {@code w}.
             *
             * <p>The endpoint discovered later is treated as the child, and both of its entries (walk-down
             * and walk-up) are rewritten. The caller must pass the two endpoints of an existing tree
             * edge; adjacency is not validated here.
             */
            @Override
            public void modifyDistance(V v, V v2, W w) {
                V child = discoverIndex.get(v) < discoverIndex.get(v2) ? v2 : v;
                sums.update(downIndex.get(child), w);
                sums.update(upIndex.get(child), AddInvert.calc(addableNumberSystem, w));
            }
        };
    }

    /**
     * Adapts {@code addableNumberSystem}'s addition to the {@code BinaryOperator} the segment tree
     * combines elements with.
     */
    private static <T> BinaryOperator<T> createAdder(final AddableNumberSystem<T> addableNumberSystem) {
        return new BinaryOperator<T>() {
            @Override
            public T calc(T left, T right) {
                // Refer to the captured parameter; `AddableNumberSystem.this` is not a legal qualifier here.
                return addableNumberSystem.add(left, right);
            }
        };
    }
}
