package org.datavec.api.transform.ndarray;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import lombok.NonNull;
import org.datavec.api.transform.Distance;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.DoubleMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/datavec/api/transform/ndarray/NDArrayDistanceTransform.class */
public class NDArrayDistanceTransform implements Transform {
    private String newColumnName;
    private Distance distance;
    private String firstCol;
    private String secondCol;
    private Schema inputSchema;

    public NDArrayDistanceTransform(@NonNull @JsonProperty("newColumnName") String str, @NonNull @JsonProperty("distance") Distance distance, @NonNull @JsonProperty("firstCol") String str2, @NonNull @JsonProperty("secondCol") String str3) {
        if (str == null) {
            throw new NullPointerException("newColumnName");
        }
        if (distance == null) {
            throw new NullPointerException("distance");
        }
        if (str2 == null) {
            throw new NullPointerException("firstCol");
        }
        if (str3 == null) {
            throw new NullPointerException("secondCol");
        }
        this.newColumnName = str;
        this.distance = distance;
        this.firstCol = str2;
        this.secondCol = str3;
    }

    public String toString() {
        return "NDArrayDistanceTransform(newColumnName=\"" + this.newColumnName + "\",distance=" + this.distance + ",firstCol=" + this.firstCol + ",secondCol=" + this.secondCol + ")";
    }

    @Override // org.datavec.api.transform.ColumnOp
    public void setInputSchema(Schema schema) {
        if (!schema.hasColumn(this.firstCol)) {
            throw new IllegalStateException("Input schema does not have first column: " + this.firstCol);
        }
        if (!schema.hasColumn(this.secondCol)) {
            throw new IllegalStateException("Input schema does not have first column: " + this.secondCol);
        }
        this.inputSchema = schema;
    }

    @Override // org.datavec.api.transform.Transform
    public List<Writable> map(List<Writable> list) {
        double manhattanDistance;
        int indexOfColumn = this.inputSchema.getIndexOfColumn(this.firstCol);
        int indexOfColumn2 = this.inputSchema.getIndexOfColumn(this.secondCol);
        INDArray iNDArray = ((NDArrayWritable) list.get(indexOfColumn)).get();
        INDArray iNDArray2 = ((NDArrayWritable) list.get(indexOfColumn2)).get();
        switch (this.distance) {
            case COSINE:
                manhattanDistance = Transforms.cosineSim(iNDArray, iNDArray2);
                break;
            case EUCLIDEAN:
                manhattanDistance = Transforms.euclideanDistance(iNDArray, iNDArray2);
                break;
            case MANHATTAN:
                manhattanDistance = Transforms.manhattanDistance(iNDArray, iNDArray2);
                break;
            default:
                throw new UnsupportedOperationException("Unknown or not supported distance metric: " + this.distance);
        }
        ArrayList arrayList = new ArrayList(list.size() + 1);
        arrayList.addAll(list);
        arrayList.add(new DoubleWritable(manhattanDistance));
        return arrayList;
    }

    @Override // org.datavec.api.transform.Transform
    public List<List<Writable>> mapSequence(List<List<Writable>> list) {
        ArrayList arrayList = new ArrayList();
        Iterator<List<Writable>> it = list.iterator();
        while (it.hasNext()) {
            arrayList.add(map(it.next()));
        }
        return arrayList;
    }

    @Override // org.datavec.api.transform.Transform
    public Object map(Object obj) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override // org.datavec.api.transform.Transform
    public Object mapSequence(Object obj) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override // org.datavec.api.transform.ColumnOp
    public Schema transform(Schema schema) {
        ArrayList arrayList = new ArrayList(schema.getColumnMetaData());
        arrayList.add(new DoubleMetaData(this.newColumnName));
        return schema.newSchema(arrayList);
    }

    @Override // org.datavec.api.transform.ColumnOp
    public String outputColumnName() {
        return this.newColumnName;
    }

    @Override // org.datavec.api.transform.ColumnOp
    public String[] outputColumnNames() {
        return new String[]{outputColumnName()};
    }

    @Override // org.datavec.api.transform.ColumnOp
    public String[] columnNames() {
        return new String[]{this.firstCol, this.secondCol};
    }

    @Override // org.datavec.api.transform.ColumnOp
    public String columnName() {
        return columnNames()[0];
    }

    public String getNewColumnName() {
        return this.newColumnName;
    }

    public Distance getDistance() {
        return this.distance;
    }

    public String getFirstCol() {
        return this.firstCol;
    }

    public String getSecondCol() {
        return this.secondCol;
    }

    @Override // org.datavec.api.transform.ColumnOp
    public Schema getInputSchema() {
        return this.inputSchema;
    }

    public void setNewColumnName(String str) {
        this.newColumnName = str;
    }

    public void setDistance(Distance distance) {
        this.distance = distance;
    }

    public void setFirstCol(String str) {
        this.firstCol = str;
    }

    public void setSecondCol(String str) {
        this.secondCol = str;
    }

    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof NDArrayDistanceTransform)) {
            return false;
        }
        NDArrayDistanceTransform nDArrayDistanceTransform = (NDArrayDistanceTransform) obj;
        if (!nDArrayDistanceTransform.canEqual(this)) {
            return false;
        }
        String newColumnName = getNewColumnName();
        String newColumnName2 = nDArrayDistanceTransform.getNewColumnName();
        if (newColumnName == null) {
            if (newColumnName2 != null) {
                return false;
            }
        } else if (!newColumnName.equals(newColumnName2)) {
            return false;
        }
        Distance distance = getDistance();
        Distance distance2 = nDArrayDistanceTransform.getDistance();
        if (distance == null) {
            if (distance2 != null) {
                return false;
            }
        } else if (!distance.equals(distance2)) {
            return false;
        }
        String firstCol = getFirstCol();
        String firstCol2 = nDArrayDistanceTransform.getFirstCol();
        if (firstCol == null) {
            if (firstCol2 != null) {
                return false;
            }
        } else if (!firstCol.equals(firstCol2)) {
            return false;
        }
        String secondCol = getSecondCol();
        String secondCol2 = nDArrayDistanceTransform.getSecondCol();
        if (secondCol == null) {
            if (secondCol2 != null) {
                return false;
            }
        } else if (!secondCol.equals(secondCol2)) {
            return false;
        }
        Schema inputSchema = getInputSchema();
        Schema inputSchema2 = nDArrayDistanceTransform.getInputSchema();
        return inputSchema == null ? inputSchema2 == null : inputSchema.equals(inputSchema2);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof NDArrayDistanceTransform;
    }

    public int hashCode() {
        String newColumnName = getNewColumnName();
        int hashCode = (1 * 59) + (newColumnName == null ? 43 : newColumnName.hashCode());
        Distance distance = getDistance();
        int hashCode2 = (hashCode * 59) + (distance == null ? 43 : distance.hashCode());
        String firstCol = getFirstCol();
        int hashCode3 = (hashCode2 * 59) + (firstCol == null ? 43 : firstCol.hashCode());
        String secondCol = getSecondCol();
        int hashCode4 = (hashCode3 * 59) + (secondCol == null ? 43 : secondCol.hashCode());
        Schema inputSchema = getInputSchema();
        return (hashCode4 * 59) + (inputSchema == null ? 43 : inputSchema.hashCode());
    }
}
