package smile.neighbor;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.distance.Metric;
import smile.sort.DoubleHeapSelect;
import smile.util.DoubleArrayList;

/**
 * Cover tree, a metric-space index supporting k-nearest-neighbor ({@link KNNSearch}) and
 * fixed-radius ({@link RNNSearch}) queries.
 *
 * <p>The tree is built eagerly in the constructor by batch insertion. Each node refers to a
 * data point by index. The first child of every internal node refers to the same point as the
 * node itself. During search, each node's {@code maxDist} is added to the current bound to
 * decide whether its subtree can still contain a qualifying point.
 *
 * <p>Both searches skip data points whose key is the very same object (reference equality) as
 * the query, so querying with a key taken from the dataset does not return that key itself.
 *
 * @param <K> the type of keys, compared by the metric.
 * @param <V> the type of data objects associated with the keys.
 */
public class CoverTree<K, V> implements KNNSearch<K, V>, RNNSearch<K, V>, Serializable {
    private static final long serialVersionUID = 2;
    private static final Logger logger = LoggerFactory.getLogger(CoverTree.class);

    /** The indexed keys; stored by reference, not copied. */
    private final List<K> keys;
    /** The data objects, parallel to {@code keys}; stored by reference, not copied. */
    private final List<V> data;
    /** The metric used to compare keys. */
    private final Metric<K> distance;
    /** The root of the tree. */
    private Node root;
    /** The expansion base; the cover radius at scale {@code i} is {@code base^i}. */
    private final double base;
    /** Cached {@code 1 / ln(base)}, used to convert a distance into a scale. */
    private final double invLogBase;

    /** A tree node paired with the distance from the current query to the node's key. */
    class DistanceNode implements Comparable<DistanceNode> {
        /** Distance from the query to {@code node}'s key. */
        double dist;
        /** The tree node. */
        Node node;

        DistanceNode(double dist, Node node) {
            this.dist = dist;
            this.node = node;
        }

        /** Orders by distance to the query. */
        @Override
        public int compareTo(DistanceNode other) {
            return Double.compare(dist, other.dist);
        }
    }

    /**
     * A point pending insertion during tree construction. It carries a stack of distances to
     * the successive candidate parents; only the last (most recent) entry is consulted.
     */
    public class DistanceSet {
        /** Index of the point in {@code keys} / {@code data}. */
        int idx;
        /** Distances to candidate parents, pushed/popped at the end. */
        DoubleArrayList dist = new DoubleArrayList();

        DistanceSet(int idx) {
            this.idx = idx;
        }

        /** Returns the key of this point. */
        K getKey() {
            return keys.get(idx);
        }

        /** Returns the data object of this point. */
        V getValue() {
            return data.get(idx);
        }
    }

    /**
     * A node of the cover tree.
     *
     * <p>NOTE: this class is Serializable without an explicit serialVersionUID, so its
     * non-private members and modifiers must not change, or previously serialized trees will no
     * longer deserialize.
     */
    public class Node implements Serializable {
        /** Index of this node's point in {@code keys} / {@code data}. */
        int idx;
        /** Largest distance from this node's point to the points consumed under it; 0 for leaves. */
        double maxDist;
        /** Distance to the parent's point as recorded during construction; 0 for first children and the root. */
        double parentDist;
        /** Child nodes, or {@code null} if this node is a leaf. */
        ArrayList<Node> children;
        /**
         * Scale bookkeeping: {@code topScale - maxScale} for internal nodes, and the fixed value
         * 100 for leaves and for nodes grouping zero-distance points. Not read by the searches.
         */
        int scale;

        Node(int idx) {
            this.idx = idx;
        }

        Node(int idx, double maxDist, double parentDist, ArrayList<Node> children, int scale) {
            this.idx = idx;
            this.maxDist = maxDist;
            this.parentDist = parentDist;
            this.children = children;
            this.scale = scale;
        }

        /** Returns the key of this node's point. */
        K getKey() {
            return keys.get(idx);
        }

        /** Returns the data object of this node's point. */
        V getValue() {
            return data.get(idx);
        }

        /** Returns true if this node has no children. */
        boolean isLeaf() {
            return children == null;
        }
    }

    /**
     * Builds a cover tree with the default base 1.3.
     *
     * @param keys the keys; the array is wrapped, not copied.
     * @param data the data objects, parallel to {@code keys}.
     * @param distance the metric on keys.
     */
    public CoverTree(K[] keys, V[] data, Metric<K> distance) {
        this(keys, data, distance, 1.3);
    }

    /**
     * Builds a cover tree with the default base 1.3.
     *
     * @param keys the keys; the list is stored by reference.
     * @param data the data objects, parallel to {@code keys}.
     * @param distance the metric on keys.
     */
    public CoverTree(List<K> keys, List<V> data, Metric<K> distance) {
        this(keys, data, distance, 1.3);
    }

    /**
     * Builds a cover tree.
     *
     * @param keys the keys; the array is wrapped, not copied.
     * @param data the data objects, parallel to {@code keys}.
     * @param distance the metric on keys.
     * @param base the expansion base; must be greater than 1.
     */
    public CoverTree(K[] keys, V[] data, Metric<K> distance, double base) {
        this(Arrays.asList(keys), Arrays.asList(data), distance, base);
    }

    /**
     * Builds a cover tree.
     *
     * @param keys the keys; the list is stored by reference, so it must not be modified later.
     * @param data the data objects, parallel to {@code keys}.
     * @param distance the metric on keys.
     * @param base the expansion base; must be greater than 1.
     * @throws IllegalArgumentException if {@code keys} and {@code data} differ in size, are
     *     empty, or {@code base} is not greater than 1.
     */
    public CoverTree(List<K> keys, List<V> data, Metric<K> distance, double base) {
        if (keys.size() != data.size()) {
            throw new IllegalArgumentException("Different size of keys and data objects");
        }
        if (keys.isEmpty()) {
            throw new IllegalArgumentException("Empty dataset");
        }
        // Scales are log_base(distance); base <= 1 (or NaN) makes them meaningless.
        if (!(base > 1.0)) {
            throw new IllegalArgumentException("Invalid base: " + base);
        }
        this.keys = keys;
        this.data = data;
        this.distance = distance;
        this.base = base;
        this.invLogBase = 1.0 / Math.log(base);
        buildCoverTree();
    }

    /** Builds a cover tree whose keys are also its data objects, with base 1.3. */
    public static <T> CoverTree<T, T> of(T[] data, Metric<T> distance) {
        return new CoverTree<>(data, data, distance);
    }

    /** Builds a cover tree whose keys are also its data objects. */
    public static <T> CoverTree<T, T> of(T[] data, Metric<T> distance, double base) {
        return new CoverTree<>(data, data, distance, base);
    }

    /** Builds a cover tree whose keys are also its data objects, with base 1.3. */
    public static <T> CoverTree<T, T> of(List<T> data, Metric<T> distance) {
        return new CoverTree<>(data, data, distance);
    }

    /** Builds a cover tree whose keys are also its data objects. */
    public static <T> CoverTree<T, T> of(List<T> data, Metric<T> distance, double base) {
        return new CoverTree<>(data, data, distance, base);
    }

    @Override
    public String toString() {
        return String.format("Cover Tree (%s)", this.distance);
    }

    /** Builds the tree rooted at point 0 by batch-inserting every other point. */
    private void buildCoverTree() {
        ArrayList<DistanceSet> pointSet = new ArrayList<>();
        ArrayList<DistanceSet> consumedSet = new ArrayList<>();

        K point = keys.get(0);
        int n = keys.size();
        double maxDist = -1.0;
        for (int i = 1; i < n; i++) {
            DistanceSet set = new DistanceSet(i);
            double dist = distance.d(point, keys.get(i));
            set.dist.add(dist);
            pointSet.add(set);
            if (dist > maxDist) {
                maxDist = dist;
            }
        }

        root = batchInsert(0, getScale(maxDist), getScale(maxDist), pointSet, consumedSet);
    }

    /**
     * Recursively builds the subtree rooted at point {@code p}.
     *
     * @param p index of the subtree's root point.
     * @param maxScale the current (upper) scale.
     * @param topScale the scale of the whole tree; used for {@code Node.scale} bookkeeping.
     * @param pointSet points still to be placed under {@code p}. On return it holds the points
     *     that were not consumed and must be handled by the caller.
     * @param consumedSet receives the points placed in this subtree.
     * @return the subtree root.
     */
    private Node batchInsert(int p, int maxScale, int topScale,
                             ArrayList<DistanceSet> pointSet, ArrayList<DistanceSet> consumedSet) {
        if (pointSet.isEmpty()) {
            return newLeaf(p);
        }

        int nextScale = Math.min(maxScale - 1, getScale(max(pointSet)));
        if (nextScale == Integer.MIN_VALUE) {
            // getScale(0) is MIN_VALUE (log(0) = -inf), i.e. every remaining point lies at
            // distance 0 from p: hang them all as leaves directly under one node.
            ArrayList<Node> children = new ArrayList<>();
            children.add(newLeaf(p));
            while (!pointSet.isEmpty()) {
                DistanceSet set = pointSet.remove(pointSet.size() - 1);
                children.add(newLeaf(set.idx));
                consumedSet.add(set);
            }
            Node node = new Node(p);
            node.scale = 100;
            node.maxDist = 0.0;
            node.children = children;
            return node;
        }

        ArrayList<DistanceSet> far = new ArrayList<>();
        split(pointSet, far, maxScale);

        // The first child always refers to the same point p; the searches rely on this.
        Node child = batchInsert(p, nextScale, topScale, pointSet, consumedSet);
        if (pointSet.isEmpty()) {
            pointSet.addAll(far);
            return child;
        }

        ArrayList<Node> children = new ArrayList<>();
        children.add(child);
        ArrayList<DistanceSet> newPointSet = new ArrayList<>();
        ArrayList<DistanceSet> newConsumedSet = new ArrayList<>();
        while (!pointSet.isEmpty()) {
            // Promote a leftover point to be a new child of p.
            DistanceSet set = pointSet.remove(pointSet.size() - 1);
            double newDist = set.dist.get(set.dist.size() - 1);
            consumedSet.add(set);

            // Gather the points (near or far) covered by the new child.
            distSplit(pointSet, newPointSet, set.getKey(), maxScale);
            distSplit(far, newPointSet, set.getKey(), maxScale);

            Node newChild = batchInsert(set.idx, nextScale, topScale, newPointSet, newConsumedSet);
            newChild.parentDist = newDist;
            children.add(newChild);

            // Pop the distance to the new child; the remaining entry is the distance to p,
            // which decides whether the point is still near p or goes back to the far set.
            double fmax = getCoverRadius(maxScale);
            for (DistanceSet ds : newPointSet) {
                ds.dist.remove(ds.dist.size() - 1);
                if (ds.dist.get(ds.dist.size() - 1) <= fmax) {
                    pointSet.add(ds);
                } else {
                    far.add(ds);
                }
            }

            for (DistanceSet ds : newConsumedSet) {
                ds.dist.remove(ds.dist.size() - 1);
                consumedSet.add(ds);
            }

            newPointSet.clear();
            newConsumedSet.clear();
        }

        pointSet.addAll(far);

        Node node = new Node(p);
        node.scale = topScale - maxScale;
        node.maxDist = max(consumedSet);
        node.children = children;
        return node;
    }

    /** Returns the cover radius {@code base^scale}. */
    private double getCoverRadius(int scale) {
        return Math.pow(base, scale);
    }

    /**
     * Returns {@code ceil(log_base(d))}. Because of the double-to-int narrowing, d == 0 yields
     * {@code Integer.MIN_VALUE} and a negative or NaN d yields 0.
     */
    private int getScale(double d) {
        return (int) Math.ceil(invLogBase * Math.log(d));
    }

    /** Creates a leaf for point {@code idx}. */
    private Node newLeaf(int idx) {
        return new Node(idx, 0.0, 0.0, null, 100);
    }

    /** Returns the largest last-recorded distance in {@code sets}, or 0 if it is empty. */
    private double max(ArrayList<DistanceSet> sets) {
        double result = 0.0;
        for (DistanceSet set : sets) {
            double dist = set.dist.get(set.dist.size() - 1);
            if (result < dist) {
                result = dist;
            }
        }
        return result;
    }

    /**
     * Moves points whose last-recorded distance exceeds {@code base^maxScale} from
     * {@code pointSet} to {@code farSet}, preserving order.
     */
    private void split(ArrayList<DistanceSet> pointSet, ArrayList<DistanceSet> farSet, int maxScale) {
        double fmax = getCoverRadius(maxScale);
        ArrayList<DistanceSet> near = new ArrayList<>();
        for (DistanceSet set : pointSet) {
            if (set.dist.get(set.dist.size() - 1) <= fmax) {
                near.add(set);
            } else {
                farSet.add(set);
            }
        }
        pointSet.clear();
        pointSet.addAll(near);
    }

    /**
     * Moves points within {@code base^maxScale} of {@code newPoint} from {@code pointSet} to
     * {@code newPointSet}, pushing their distance to {@code newPoint} onto their stacks.
     */
    private void distSplit(ArrayList<DistanceSet> pointSet, ArrayList<DistanceSet> newPointSet,
                           K newPoint, int maxScale) {
        double fmax = getCoverRadius(maxScale);
        ArrayList<DistanceSet> remaining = new ArrayList<>();
        for (DistanceSet set : pointSet) {
            double dist = distance.d(newPoint, set.getKey());
            if (dist <= fmax) {
                set.dist.add(dist);
                newPointSet.add(set);
            } else {
                remaining.add(set);
            }
        }
        pointSet.clear();
        pointSet.addAll(remaining);
    }

    /**
     * Returns the k nearest neighbors of {@code q}.
     *
     * <p>The nearest k candidates are kept and the result is returned farthest-first. The
     * array may hold fewer than k entries (a warning is logged), e.g. when candidates are
     * reference-equal to {@code q}.
     *
     * @throws IllegalArgumentException if k is not positive or exceeds the dataset size.
     */
    @Override
    public Neighbor<K, V>[] search(K q, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }
        if (k > data.size()) {
            throw new IllegalArgumentException("Neighbor array length is larger than the dataset size");
        }

        K rootKey = root.getKey();
        double rootDist = distance.d(rootKey, q);
        List<Neighbor<K, V>> found = new ArrayList<>();

        if (root.isLeaf()) {
            // Single-point tree; apply the same self-exclusion as the general path below.
            if (rootKey != q) {
                found.add(new Neighbor<>(rootKey, root.getValue(), root.idx, rootDist));
            }
            return toNeighborArray(found, k);
        }

        List<DistanceNode> frontier = new ArrayList<>();
        List<DistanceNode> leaves = new ArrayList<>();
        frontier.add(new DistanceNode(rootDist, root));

        DoubleHeapSelect heap = new DoubleHeapSelect(k);
        // Sentinel: presumably keeps heap.peek() at MAX_VALUE until k real distances have been
        // collected (depends on DoubleHeapSelect internals; TODO confirm).
        heap.add(Double.MAX_VALUE);
        boolean emptyHeap = true;
        if (rootKey != q) {
            heap.add(rootDist);
            emptyHeap = false;
        }

        while (!frontier.isEmpty()) {
            List<DistanceNode> next = new ArrayList<>();
            for (DistanceNode parent : frontier) {
                List<Node> children = parent.node.children;
                for (int c = 0; c < children.size(); c++) {
                    Node child = children.get(c);
                    // Child 0 is the parent's own point, so its distance is already known.
                    double d = c == 0 ? parent.dist : distance.d(child.getKey(), q);
                    double upperBound = emptyHeap ? Double.MAX_VALUE : heap.peek();
                    if (d <= upperBound + child.maxDist) {
                        if (c > 0 && d < upperBound && child.getKey() != q) {
                            heap.add(d);
                            // Without clearing the flag, pruning stays disabled for the whole
                            // search whenever q is the root's own key.
                            emptyHeap = false;
                        }
                        if (!child.isLeaf()) {
                            next.add(new DistanceNode(d, child));
                        } else if (d <= upperBound) {
                            leaves.add(new DistanceNode(d, child));
                        }
                    }
                }
            }
            frontier = next;
        }

        double bound = heap.peek();
        for (DistanceNode leaf : leaves) {
            if (leaf.dist <= bound && leaf.node.getKey() != q) {
                found.add(new Neighbor<>(leaf.node.getKey(), leaf.node.getValue(), leaf.node.idx, leaf.dist));
            }
        }

        return toNeighborArray(found, k);
    }

    /**
     * Converts kNN candidates into the result array. The array is sorted, truncated to the k
     * smallest, then reversed.
     */
    private Neighbor<K, V>[] toNeighborArray(List<Neighbor<K, V>> found, int k) {
        // Zero-length template: toArray always allocates an exactly sized array. A non-empty
        // template would leave a trailing null (e.g. {null} when nothing was found).
        @SuppressWarnings("unchecked")
        Neighbor<K, V>[] neighbors = found.toArray(new Neighbor[0]);
        if (neighbors.length < k) {
            logger.warn("CoverTree.knn({}) returns only {} neighbors", k, neighbors.length);
        }

        Arrays.sort(neighbors);
        if (neighbors.length > k) {
            neighbors = Arrays.copyOf(neighbors, k);
        }
        MathEx.reverse(neighbors);
        return neighbors;
    }

    /**
     * Appends to {@code neighbors} every data point within {@code radius} of {@code q}. Points
     * whose key is reference-equal to {@code q} are skipped.
     *
     * @throws IllegalArgumentException if radius is not positive.
     */
    @Override
    public void search(K q, double radius, List<Neighbor<K, V>> neighbors) {
        if (radius <= 0.0) {
            throw new IllegalArgumentException("Invalid radius: " + radius);
        }

        double rootDist = distance.d(root.getKey(), q);
        if (root.isLeaf()) {
            // Single-point tree: there are no children to descend into.
            if (rootDist <= radius && root.getKey() != q) {
                neighbors.add(new Neighbor<>(root.getKey(), root.getValue(), root.idx, rootDist));
            }
            return;
        }

        List<DistanceNode> frontier = new ArrayList<>();
        List<DistanceNode> leaves = new ArrayList<>();
        frontier.add(new DistanceNode(rootDist, root));

        while (!frontier.isEmpty()) {
            List<DistanceNode> next = new ArrayList<>();
            for (DistanceNode parent : frontier) {
                List<Node> children = parent.node.children;
                for (int c = 0; c < children.size(); c++) {
                    Node child = children.get(c);
                    // Child 0 is the parent's own point, so its distance is already known.
                    double d = c == 0 ? parent.dist : distance.d(child.getKey(), q);
                    if (d <= radius + child.maxDist) {
                        if (!child.isLeaf()) {
                            next.add(new DistanceNode(d, child));
                        } else if (d <= radius) {
                            leaves.add(new DistanceNode(d, child));
                        }
                    }
                }
            }
            frontier = next;
        }

        for (DistanceNode leaf : leaves) {
            if (leaf.node.getKey() != q) {
                neighbors.add(new Neighbor<>(leaf.node.getKey(), leaf.node.getValue(), leaf.node.idx, leaf.dist));
            }
        }
    }
}
