package ai.djl.nn;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.Pair;
import java.util.Iterator;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import jnr.ffi.provider.jffi.JNINativeInterface;
import org.slf4j.Marker;

/* loaded from: input_file:META-INF/bundled-dependencies/api-0.22.1.jar:ai/djl/nn/Blocks.class */
/** Static helpers for flattening arrays, building simple blocks, and describing block trees. */
public final class Blocks {

    /** Matches the start of every line (multiline mode); used to indent nested child descriptions. */
    private static final Pattern LINE_START = Pattern.compile("(?m)^");

    /** Separator placed between multiple input or output shapes in {@link #describe}. */
    private static final String SHAPE_SEPARATOR = "+";

    /** Initial capacity of the description buffer; only a sizing hint, not a limit. */
    private static final int DESCRIBE_INITIAL_CAPACITY = 200;

    private Blocks() {
    }

    /**
     * Reshapes {@code nDArray} to two dimensions, keeping axis 0 and collapsing all remaining
     * axes into the second dimension.
     *
     * @param nDArray the array to flatten
     * @return the reshaped array
     */
    public static NDArray batchFlatten(NDArray nDArray) {
        long size = nDArray.size(0);
        if (size == 0) {
            // With an empty leading axis the -1 placeholder cannot be inferred from the total
            // element count, so the trailing dimension is computed explicitly from the shape.
            return nDArray.reshape(size, nDArray.getShape().slice(1).size());
        }
        return nDArray.reshape(size, -1);
    }

    /**
     * Reshapes {@code nDArray} to two dimensions with a fixed second dimension of {@code j};
     * the first dimension is inferred.
     *
     * @param nDArray the array to flatten
     * @param j the size of the second (flattened) dimension
     * @return the reshaped array
     */
    public static NDArray batchFlatten(NDArray nDArray, long j) {
        return nDArray.reshape(-1, j);
    }

    /**
     * Returns a block that applies {@link #batchFlatten(NDArray)} to its input.
     *
     * @return a {@link LambdaBlock} named {@code "batchFlatten"}
     */
    public static Block batchFlattenBlock() {
        return LambdaBlock.singleton(Blocks::batchFlatten, "batchFlatten");
    }

    /**
     * Returns a block that applies {@link #batchFlatten(NDArray, long)} with the given size.
     *
     * @param j the size of the second (flattened) dimension
     * @return a {@link LambdaBlock} named {@code "batchFlatten"}
     */
    public static Block batchFlattenBlock(long j) {
        return LambdaBlock.singleton(nDArray -> batchFlatten(nDArray, j), "batchFlatten");
    }

    /**
     * Returns a block that passes its input list through unchanged.
     *
     * @return a {@link LambdaBlock} named {@code "identity"}
     */
    public static Block identityBlock() {
        return new LambdaBlock(nDList -> nDList, "identity");
    }

    /**
     * Builds a human-readable, possibly multi-line description of {@code block} and, recursively,
     * its children. Input shapes follow the name. Children are listed inside braces, each indented
     * by one tab. Output shapes follow {@code " -> "}. Shapes are only shown when the block is
     * initialized.
     *
     * @param block the block to describe
     * @param str the name to show when {@code block} is not a {@link LambdaBlock} with a
     *     non-default name; if {@code null}, the block's class simple name is used instead
     * @param i the axis from which each shape is printed (passed to {@code Shape.slice}); leading
     *     axes before it are omitted
     * @return the description text
     */
    public static String describe(Block block, String str, int i) {
        Shape[] inputShapes = block.isInitialized() ? block.getInputShapes() : null;
        Shape[] outputShapes = inputShapes != null ? block.getOutputShapes(inputShapes) : null;

        StringBuilder sb = new StringBuilder(DESCRIBE_INITIAL_CAPACITY);
        sb.append(displayName(block, str));
        if (inputShapes != null) {
            sb.append(joinShapes(inputShapes, i));
        }
        if (!block.getChildren().isEmpty()) {
            sb.append(" {\n");
            for (Pair<String, Block> child : block.getChildren()) {
                // NOTE(review): child keys appear to carry a 2-character prefix (presumably an
                // ordinal index) that is stripped for display — confirm against addChildBlock.
                String childDescription =
                        describe(child.getValue(), child.getKey().substring(2), i);
                sb.append(LINE_START.matcher(childDescription).replaceAll("\t")).append('\n');
            }
            sb.append('}');
        }
        if (outputShapes != null) {
            sb.append(" -> ");
            sb.append(joinShapes(outputShapes, i));
        }
        return sb.toString();
    }

    /** Chooses the label for a block: a custom lambda name, then {@code name}, then class name. */
    private static String displayName(Block block, String name) {
        if (block instanceof LambdaBlock) {
            String lambdaName = ((LambdaBlock) block).getName();
            if (!LambdaBlock.DEFAULT_NAME.equals(lambdaName)) {
                return lambdaName;
            }
        }
        return name != null ? name : block.getClass().getSimpleName();
    }

    /** Renders each shape sliced from {@code beginAxis} and joins them with {@code "+"}. */
    private static String joinShapes(Shape[] shapes, int beginAxis) {
        return Stream.of(shapes)
                .map(shape -> shape.slice(beginAxis).toString())
                .collect(Collectors.joining(SHAPE_SEPARATOR));
    }
}
