package ai.djl.nn.norm;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

/* loaded from: input_file:META-INF/bundled-dependencies/api-0.22.1.jar:ai/djl/nn/norm/LayerNorm.class */
/**
 * A block that applies layer normalization to a single input array.
 *
 * <p>The block owns two parameters, {@code gamma} and {@code beta}, whose shape is the
 * "normalized shape" computed in {@link #beforeInitialize(Shape...)}. Whether each parameter
 * receives gradients is controlled by {@link Builder#optScale(boolean)} (gamma) and {@link
 * Builder#optCenter(boolean)} (beta). The actual computation is delegated to the engine via
 * {@link #layerNorm(NDArray, Shape, NDArray, NDArray, float)}.
 */
public class LayerNorm extends AbstractBlock {
    /** Small constant handed to the engine's layer-norm op; defaults to {@code 1e-5}. */
    protected float epsilon;
    /** Shape of gamma/beta; derived from the input shape and {@link #axis} before init. */
    protected Shape normalizedShape;
    /** If true, {@link #beta} is trainable (requires gradients). */
    protected boolean center;
    /** If true, {@link #gamma} is trainable (requires gradients). */
    protected boolean scale;
    /** Input axes whose sizes form {@link #normalizedShape}; {@code null} means all but axis 0. */
    protected int[] axis;
    protected Parameter gamma;
    protected Parameter beta;

    /** Builder for {@link LayerNorm}. */
    public static class Builder {
        private float epsilon = 1.0E-5f;
        private boolean scale = true;
        private boolean center = true;
        private int[] axis;

        protected Builder() {
        }

        /**
         * Sets the input axes whose sizes make up the normalized shape.
         *
         * <p>Negative values count from the end of the input shape ({@code -1} is the last
         * axis). The array is copied, so later changes to the caller's array have no effect.
         *
         * @param axes the axes to normalize over; {@code null} restores the default (every
         *     axis except the first)
         * @return this builder
         */
        public Builder axis(int... axes) {
            // Defensive copy: the block must not observe later mutation of the caller's array.
            this.axis = axes == null ? null : axes.clone();
            return this;
        }

        /**
         * Sets whether the {@code beta} parameter requires gradients. Defaults to {@code true}.
         *
         * @param center whether beta is trainable
         * @return this builder
         */
        public Builder optCenter(boolean center) {
            this.center = center;
            return this;
        }

        /**
         * Sets whether the {@code gamma} parameter requires gradients. Defaults to {@code true}.
         *
         * @param scale whether gamma is trainable
         * @return this builder
         */
        public Builder optScale(boolean scale) {
            this.scale = scale;
            return this;
        }

        /**
         * Sets the epsilon passed to the layer-norm op. Defaults to {@code 1e-5}.
         *
         * @param epsilon the epsilon value
         * @return this builder
         */
        public Builder optEpsilon(float epsilon) {
            this.epsilon = epsilon;
            return this;
        }

        /**
         * Builds a {@link LayerNorm} block from the current configuration.
         *
         * @return a new block
         */
        public LayerNorm build() {
            return new LayerNorm(this);
        }
    }

    /**
     * Creates the block and registers its {@code gamma} and {@code beta} parameters.
     *
     * @param builder the configuration to copy from
     */
    protected LayerNorm(Builder builder) {
        this.epsilon = builder.epsilon;
        this.scale = builder.scale;
        this.center = builder.center;
        this.axis = builder.axis;
        this.gamma =
                addParameter(
                        Parameter.builder()
                                .setName("gamma")
                                .setType(Parameter.Type.GAMMA)
                                .optRequiresGrad(this.scale)
                                .build());
        this.beta =
                addParameter(
                        Parameter.builder()
                                .setName("beta")
                                .setType(Parameter.Type.BETA)
                                .optRequiresGrad(this.center)
                                .build());
    }

    /**
     * Applies layer normalization by delegating to the input array's engine implementation.
     *
     * @param input the array to normalize
     * @param normalizedShape the normalized shape (shape of gamma and beta)
     * @param gamma the scale parameter value
     * @param beta the shift parameter value
     * @param eps the epsilon value
     * @return the engine's result for the layer-norm op
     */
    public static NDList layerNorm(
            NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
        return input.getNDArrayInternal().layerNorm(input, normalizedShape, gamma, beta, eps);
    }

    /**
     * Creates a new builder with default settings.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Runs the block on exactly one input array.
     *
     * <p>Gamma and beta are fetched from {@code parameterStore} on the input's device; the
     * {@code training} flag is forwarded to {@link ParameterStore#getValue}. {@code params} is
     * not used by this block.
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        Device device = input.getDevice();
        return layerNorm(
                input,
                this.normalizedShape,
                parameterStore.getValue(this.gamma, device, training),
                parameterStore.getValue(this.beta, device, training),
                this.epsilon);
    }

    /** The output has the same shape as the (first) input. */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return new Shape[] {inputShapes[0]};
    }

    /**
     * Computes {@link #normalizedShape} from the first input shape.
     *
     * <p>With no axes configured, the normalized shape is the input shape without its first
     * dimension. Otherwise it is the size of each configured axis, in order; negative axes are
     * counted from the end of the input shape.
     *
     * <p>NOTE(review): the decompiler widened this from {@code protected}; it is kept public so
     * existing callers are unaffected.
     */
    @Override
    public void beforeInitialize(Shape... inputShapes) {
        super.beforeInitialize(inputShapes);
        Shape input = inputShapes[0];
        if (this.axis == null) {
            this.normalizedShape = input.slice(1);
        } else {
            int rank = input.dimension();
            // Map negative axes (e.g. -1) to their positive index; non-negative axes are unchanged.
            this.normalizedShape =
                    new Shape(
                            Arrays.stream(this.axis)
                                    .mapToLong(dim -> input.get(dim < 0 ? dim + rank : dim))
                                    .toArray());
        }
    }

    /** Gives gamma and beta the normalized shape computed in {@link #beforeInitialize}. */
    @Override
    public void prepare(Shape[] inputShapes) {
        this.gamma.setShape(this.normalizedShape);
        this.beta.setShape(this.normalizedShape);
    }

    /** Writes the input shapes followed by the encoded normalized shape. */
    @Override
    protected void saveMetadata(DataOutputStream os) throws IOException {
        saveInputShapes(os);
        os.write(this.normalizedShape.getEncoded());
    }

    /**
     * Reads back what {@link #saveMetadata} wrote.
     *
     * @throws MalformedModelException if {@code loadVersion} differs from this block's version
     */
    @Override
    public void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedModelException {
        if (loadVersion != this.version) {
            throw new MalformedModelException("Unsupported encoding version: " + ((int) loadVersion));
        }
        readInputShapes(is);
        this.normalizedShape = Shape.decode(is);
    }
}
