org.tensorflow.Tensor

Here are examples of the Java API org.tensorflow.Tensor taken from open-source projects. By voting up, you can indicate which examples are most useful and appropriate.

130 Examples 7

19 Source : RNTensorflowInference.java
with Apache License 2.0
from reneweb

/**
 * Feeds {@code tensor} to the graph input named {@code inputName} on the runner held by
 * {@code tfContext}.
 *
 * NOTE(review): this method does not close {@code tensor}; presumably the caller owns it and
 * closes it after the run — confirm.
 *
 * @param inputName name of the graph input to feed
 * @param tensor value to feed
 */
public void feed(String inputName, Tensor tensor) {
    tfContext.runner.feed(inputName, tensor);
}

19 Source : MRCNNCorpusCallosum.java
with GNU General Public License v3.0
from mstritt

/**
 * Runs the Mask R-CNN graph on an image and post-processes its raw output.
 *
 * <p>The image is converted into a {@code size x size} input tensor. That tensor is closed as
 * soon as inference finishes, so repeated calls do not leak native TensorFlow memory.
 *
 * @param image512 image to segment
 * @return the processed detections, or {@code null} if no input tensor could be built
 */
public Detections detectCorpusCallosum(BufferedImage image512) {
    // try-with-resources tolerates a null resource, so the null check below still works.
    try (Tensor<Float> input = DLHelpers.convertBufferedImageToTensor(image512, size, size)) {
        if (input == null) {
            return null;
        }
        RawDetections rawDetections = DLHelpers.executeInceptionGraph(s, input, size, size, MAX_DETECTIONS, maskWidth, maskHeight);
        return processDetections(size, size, rawDetections);
    }
}

19 Source : TensorUtil.java
with GNU Affero General Public License v3.0
from jpmml

/**
 * Extracts a single float value from a tensor.
 *
 * <p>A scalar tensor is read directly with {@code floatValue()}. If that is rejected, the
 * tensor is read as a float array, which must contain exactly one element.
 *
 * @param tensor the tensor to read
 * @return the single float value
 * @throws IllegalArgumentException if the fallback array does not hold exactly one element
 *     (the original {@code floatValue()} failure is attached as a suppressed exception)
 */
static public float toFloatScalar(Tensor tensor) {
    try {
        return tensor.floatValue();
    } catch (IllegalArgumentException notScalar) {
        // floatValue() rejects non-scalar tensors; fall back to reading a 1-element array.
        float[] values = toFloatArray(tensor);
        if (values.length != 1) {
            IllegalArgumentException failure = new IllegalArgumentException("Expected 1-element array, got " + Arrays.toString(values));
            failure.addSuppressed(notScalar);
            throw failure;
        }
        return values[0];
    }
}

18 Source : ConvNet.java
with Apache License 2.0
from tomwhite

/**
 * Classifies an image with the CNN.
 *
 * <p>The image is JPEG-encoded and normalized by the normalization graph. The resulting
 * tensor is fed to the CNN and closed afterwards.
 *
 * @param image image to classify
 * @return the CNN's output values for the image
 * @throws IOException if JPEG encoding fails
 */
public float[] predict(BufferedImage image) throws IOException {
    final byte[] jpegBytes = asJpegByteArray(image);
    try (Tensor normalized = executeGraphToNormalizeImage(jpegBytes)) {
        return executeCnnGraph(normalized);
    }
}

18 Source : FaceRecognizer.java
with GNU General Public License v3.0
from sanstorik

/**
 * Runs the face-embedding network on a single face image.
 *
 * <p>Every tensor created or fetched here, and the session, is closed before returning,
 * including when the output shape check fails. This prevents native TensorFlow memory leaks.
 *
 * NOTE(review): a new Session is created for every call; consider reusing one if this is hot.
 *
 * @param image cropped, centered face image
 * @param faceType face type passed unchanged to the resulting {@link FaceFeatures}
 * @return a description of the face built from the network's 128 float features
 * @throws IllegalStateException if the network output is not shaped {@code [1, 128]}
 */
private FaceFeatures preplacedImageThroughNeuralNetwork(BufferedImage image, int faceType) {
    try (Session session = new Session(graph);
            Tensor<Float> feedImage = Tensors.create(imageToMultiDimensionalArray(image));
            Tensor<?> phaseTrain = Tensor.create(false)) {
        long timeResponse = System.currentTimeMillis();
        // Bind the raw fetched tensor first so it is closed even if expect() rejects its type.
        try (Tensor<?> fetched = session.runner()
                .feed("input", feedImage)
                .feed("phase_train", phaseTrain)
                .fetch("embeddings")
                .run()
                .get(0)) {
            Tensor<Float> response = fetched.expect(Float.class);
            FileUtils.timeSpent(timeResponse, "RESPONSE");
            final long[] shape = response.shape();
            // Expect one embedding (first dimension 1) of 128 features (second dimension).
            if (shape.length != 2 || shape[0] != 1 || shape[1] != 128) {
                throw new IllegalStateException("illegal output shape: " + java.util.Arrays.toString(shape));
            }
            float[][] featuresHolder = new float[1][128];
            response.copyTo(featuresHolder);
            return new FaceFeatures(featuresHolder[0], faceType);
        }
    }
}

18 Source : TFUtil.java
with Apache License 2.0
from rockyzhengwu

/**
 * Packs a batch of id sequences into a rank-2 int tensor.
 *
 * <p>The sequences are converted to an {@code int[][]} by {@code intListToMat} (presumably
 * padded/truncated to {@code maxLength} — confirm). The caller owns and must close the
 * returned tensor.
 *
 * @param idList one list of ids per sequence
 * @param maxLength length passed to {@code intListToMat}
 * @return a new tensor holding the id matrix
 */
public static Tensor createTensor(List<List<Integer>> idList, int maxLength) {
    return Tensor.create(intListToMat(idList, maxLength));
}

18 Source : GraphBuilder.java
with Apache License 2.0
from pravega

/**
 * Adds a {@code Const} operation holding {@code value} to the graph.
 *
 * <p>A temporary tensor carries the value into the op's {@code "value"} attribute. It is
 * closed once the op has been built.
 *
 * @param name name of the new operation
 * @param value Java value (scalar or array) to embed
 * @param type element class of the constant, used for both the tensor and its {@code dtype}
 * @param <T> element type
 * @return the constant's single output
 */
public <T> Output<T> constant(String name, Object value, Class<T> type) {
    try (Tensor<T> t = Tensor.<T>create(value, type)) {
        return graph.opBuilder("Const", name)
                .setAttr("dtype", DataType.fromClass(type))
                .setAttr("value", t)
                .build()
                .<T>output(0);
    }
}

18 Source : DLHelpers.java
with GNU General Public License v3.0
from mstritt

public clreplaced DLHelpers {

    // public static final int[] RPN_ANCHOR_SCALES = new int[]{8 , 16, 32, 64, 128};
    public static final float[] MEAN_PIXEL = new float[] { 123.7f, 116.8f, 103.9f };

    private static Random random = new Random();

    // anchors 1,65472,4
    public static transient Tensor<Float> anchors = null;

    public static RawDetections executeInceptionGraph(final Session s, final Tensor<Float> input, final int inputWidth, final int inputHeight, final int maxDetections, final int maskWidth, final int maskHeight) {
        // image metas
        // meta = np.array(
        // [image_id] +                  # size=1
        // list(original_image_shape) +  # size=3
        // list(image_shape) +           # size=3
        // list(window) +                # size=4 (y1, x1, y2, x2) in image cooredinates
        // [scale[0]] +                     # size=1 NO LONGER, I dont have time to correct this properly so take only the first element
        // list(active_clreplaced_ids)        # size=num_clreplacedes
        // )
        final FloatBuffer metas = FloatBuffer.wrap(new float[] { 0, inputWidth, inputHeight, 3, inputWidth, inputHeight, 3, 0, 0, inputWidth, inputHeight, 1, 0, 0 });
        final Tensor<Float> meta_data = Tensor.create(new long[] { 1, 14 }, metas);
        List<Tensor<?>> res = s.runner().feed("input_image", input).feed("input_image_meta", meta_data).feed("input_anchors", // dtype float and shape [?,?,4]
        getAnchors(inputWidth)).fetch(// mrcnn_mask/Reshape_1   mrcnn_detection/Reshape_1    mrcnn_bbox/Reshape     mrcnn_clreplaced/Reshape_1
        "mrcnn_detection/Reshape_1").fetch("mrcnn_mask/Reshape_1").run();
        // mrcnn_detection/Reshape_1   -> y1,x1,y2,x2,clreplaced_id,probability (ordered desc)
        float[][][] res_detection = new float[1][maxDetections][6];
        // mrcnn_mask/Reshape_1
        float[][][][][] res_mask = new float[1][maxDetections][maskHeight][maskWidth][2];
        Tensor<Float> mrcnn_detection = res.get(0).expect(Float.clreplaced);
        Tensor<Float> mrcnn_mask = res.get(1).expect(Float.clreplaced);
        mrcnn_detection.copyTo(res_detection);
        mrcnn_mask.copyTo(res_mask);
        RawDetections rawDetections = new RawDetections();
        rawDetections.objectBB = res_detection;
        rawDetections.masks = res_mask;
        return rawDetections;
    }

    public static Tensor<Float> convertBufferedImageToTensor(BufferedImage image, int targetWidth, int targetHeight) {
        // if (image.getWidth()!=DESIRED_SIZE || image.getHeight()!=DESIRED_SIZE)
        {
            // also make it an RGB image
            image = resize(image, targetWidth, targetHeight);
        // image = resize(image,image.getWidth(), image.getHeight());
        }
        int width = image.getWidth();
        int height = image.getHeight();
        Raster r = image.getRaster();
        int[] rgb = new int[3];
        // int[] data = new int[width * height];
        // image.getRGB(0, 0, width, height, data, 0, width);
        float[][][][] rgbArray = new float[1][height][width][3];
        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                // Color color = new Color(data[i * width + j]);
                // rgbArray[0][i][j][0] = color.getRed() - MEAN_PIXEL[0];
                // rgbArray[0][i][j][1] = color.getGreen() - MEAN_PIXEL[1];
                // rgbArray[0][i][j][2] = color.getBlue() - MEAN_PIXEL[2];
                rgb = r.getPixel(j, i, rgb);
                rgbArray[0][i][j][0] = rgb[0] - MEAN_PIXEL[0];
                rgbArray[0][i][j][1] = rgb[1] - MEAN_PIXEL[1];
                rgbArray[0][i][j][2] = rgb[2] - MEAN_PIXEL[2];
            }
        }
        return Tensor.create(rgbArray, Float.clreplaced);
    }

    public static BufferedImage augmentDetections(BufferedImage image, Detections detections) {
        boolean drawBoundingBox = false;
        boolean drawContour = true;
        BufferedImage outImg = new BufferedImage(image.getWidth(), image.getHeight(), image.getType());
        Graphics2D g = outImg.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
        g.setRenderingHint(RenderingHints.KEY_ALPHA_INTERPOLATION, RenderingHints.VALUE_ALPHA_INTERPOLATION_QUALITY);
        g.drawImage(image, 0, 0, null);
        int numObjects = detections.getBoundingBoxes().size();
        for (int i = 0; i < numObjects; i++) {
            float probability = detections.getProbabilities().get(i);
            RectangleExt rect = detections.getBoundingBoxes().get(i);
            PolygonExt poly = detections.getContours().get(i);
            int x = rect.x;
            int y = rect.y;
            int width = rect.width;
            int height = rect.height;
            if (drawBoundingBox) {
                g.setStroke(new BasicStroke(2));
                g.setColor(Color.yellow);
                g.drawRect(x, y, width, height);
            }
            // draw contour
            if (drawContour) {
                g.setStroke(new BasicStroke(2));
                Color color = Color.getHSBColor(random.nextFloat(), 1f, 1f);
                g.setColor(color);
                g.drawPolygon(poly);
            }
        }
        g.dispose();
        return outImg;
    }

    public static synchronized Tensor<Float> getAnchors(int img_size) {
        if (anchors == null) {
            float[] fArr = MaskRCNNAnchors.GenerateAnchors(img_size);
            anchors = Tensor.create(new long[] { 1, fArr.length / 4, 4 }, FloatBuffer.wrap(fArr));
        }
        return anchors;
    }

    public static BufferedImage resize(BufferedImage img, int width, int height) {
        // int type = img.getType()>0?img.getType():BufferedImage.TYPE_INT_RGB;
        int type = BufferedImage.TYPE_INT_RGB;
        // BufferedImage resizedImage = new BufferedImage(roundP2(width), roundP2(height), type);
        BufferedImage resizedImage = new BufferedImage(width, height, type);
        Graphics2D g = resizedImage.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(img, 0, 0, width, height, null);
        g.dispose();
        return resizedImage;
    }

    public static float[][] dilate(final float[][] buf) {
        final float cf = 1;
        final float[][] res = new float[buf.length][buf[0].length];
        for (int x = 0; x < buf.length; x++) for (int y = 0; y < buf[x].length; y++) {
            res[x][y] = buf[x][y];
            if (x > 0 && buf[x - 1][y] == cf)
                res[x][y] = cf;
            else if (x < buf.length - 1 && buf[x + 1][y] == cf)
                res[x][y] = cf;
            else if (y > 0 && buf[x][y - 1] == cf)
                res[x][y] = cf;
            else if (y < buf[x].length - 1 && buf[x][y + 1] == cf)
                res[x][y] = cf;
            // else
            // if (x>0 && y>0 && buf[x-1][y-1]==cf) res[x][y] = cf; else
            // if (x<buf.length-1 && y>0 && buf[x+1][y-1]==cf) res[x][y] = cf; else
            // if (y>0 && x<buf.length-1 && buf[x+1][y-1]==cf) res[x][y] = cf; else
            // if (y<buf[x].length-1 && x<buf.length-1 && buf[x+1][y+1]==cf) res[x][y] = cf;
            if (x == 0 || y == 0 || x == buf.length - 1 || y == buf[x].length - 1)
                // border pixel is always background (needed to 'close' objects)
                res[x][y] = 0;
        }
        return res;
    }

    public static float[][] erode(final float[][] buf) {
        final float cf = 0;
        final float[][] res = new float[buf.length][buf[0].length];
        for (int x = 0; x < buf.length; x++) for (int y = 0; y < buf[x].length; y++) {
            res[x][y] = buf[x][y];
            if (x > 0 && buf[x - 1][y] == cf)
                res[x][y] = cf;
            else if (x < buf.length - 1 && buf[x + 1][y] == cf)
                res[x][y] = cf;
            else if (y > 0 && buf[x][y - 1] == cf)
                res[x][y] = cf;
            else if (y < buf[x].length - 1 && buf[x][y + 1] == cf)
                res[x][y] = cf;
        }
        return res;
    }
}

18 Source : GraphBuilder.java
with GNU General Public License v3.0
from mstritt

/**
 * Adds a {@code Const} operation holding {@code value} to graph {@code g}.
 *
 * <p>A temporary tensor carries the value into the op's {@code "value"} attribute. It is
 * closed once the op has been built.
 *
 * @param name name of the new operation
 * @param value Java value (scalar or array) to embed
 * @param type element class of the constant, used for both the tensor and its {@code dtype}
 * @param <T> element type
 * @return the constant's single output
 */
<T> Output<T> constant(String name, Object value, Class<T> type) {
    try (Tensor<T> t = Tensor.<T>create(value, type)) {
        return g.opBuilder("Const", name)
                .setAttr("dtype", DataType.fromClass(type))
                .setAttr("value", t)
                .build()
                .<T>output(0);
    }
}

18 Source : DLSegment.java
with GNU General Public License v3.0
from mstritt

/**
 * Runs the segmentation graph and renders its label predictions as an image.
 *
 * <p>The fetched {@code predictions} tensor is closed after its values are copied, so
 * repeated calls do not leak native memory. {@code inputTensor} is left open for the caller.
 *
 * NOTE(review): copying into a 1-D {@code long[]} requires the predictions tensor to be
 * rank 1 — confirm against the model.
 *
 * @param inputTensor image batch fed as {@code image_batch}
 * @param s session holding the segmentation graph
 * @param bg background color passed to {@code decodeLabels}
 * @param fg foreground color passed to {@code decodeLabels}
 * @return the decoded label image
 */
public static BufferedImage segmentInput(final Tensor<Float> inputTensor, Session s, Color bg, Color fg) {
    // Bind the raw fetched tensor so it is closed even if expect() rejects its type.
    try (Tensor<?> fetched = s.runner().feed("image_batch", inputTensor).fetch("predictions").run().get(0)) {
        Tensor<Long> outputTensor = fetched.expect(Long.class);
        long[] mask = outputTensor.copyTo(new long[outputTensor.numElements()]);
        return decodeLabels(mask, bg, fg);
    }
}

18 Source : TensorUtil.java
with GNU Affero General Public License v3.0
from jpmml

/**
 * Copies all elements of a tensor into a new flat {@code long[]}.
 *
 * @param tensor tensor whose contents are written out via {@code writeTo}
 * @return a new array with {@code tensor.numElements()} values
 */
static public long[] toLongArray(Tensor tensor) {
    final long[] values = new long[tensor.numElements()];
    tensor.writeTo(LongBuffer.wrap(values));
    return values;
}

18 Source : SessionRunner.java
with MIT License
from dhruvrajan

/**
 * Feeds tensors to operands, taking pairs from {@code Pair.zip(tensors, ops)}.
 *
 * <p>Presumably {@code tensors[i]} is paired with {@code ops[i]} — confirm with {@code Pair.zip}.
 *
 * @param tensors values to feed
 * @param ops operands whose outputs receive the values
 * @return this runner, for chaining
 */
public SessionRunner feed(Tensor[] tensors, Operand[] ops) {
    final Iterable<Pair<Tensor, Operand>> zipped = () -> Pair.zip(tensors, ops);
    for (Pair<Tensor, Operand> entry : zipped) {
        final Tensor value = entry.first();
        this.runner.feed(entry.second().asOutput(), value);
    }
    return this;
}

18 Source : ImageClassifier.java
with Apache License 2.0
from baghelamit

/**
 * Classifies an encoded image and describes the most likely label.
 *
 * <p>The normalized image tensor is closed before returning.
 *
 * @param imageBytes encoded image bytes
 * @return text of the form {@code "<label> (<percent>% likely)"}
 */
public static String clreplacedifyImage(byte[] imageBytes) {
    try (Tensor<Float> normalized = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
        final float[] probabilities = executeInceptionGraph(graphDef, normalized);
        final int best = maxIndex(probabilities);
        return String.format("%s (%.2f%% likely)", labels.get(best), probabilities[best] * 100f);
    }
}

18 Source : TfNDManager.java
with Apache License 2.0
from awslabs

/**
 * {@inheritDoc}
 *
 * NOTE(review): the string tensor is closed when this method returns; this assumes
 * {@code TfNDArray} copies or otherwise does not rely on it afterwards — confirm.
 */
@Override
public NDArray create(String[] data) {
    try (final Tensor<TString> vector = TString.vectorOf(data)) {
        return new TfNDArray(this, vector);
    }
}

18 Source : TfNDManager.java
with Apache License 2.0
from awslabs

/**
 * {@inheritDoc}
 *
 * NOTE(review): the scalar tensor is closed when this method returns; this assumes
 * {@code TfNDArray} copies or otherwise does not rely on it afterwards — confirm.
 */
@Override
public NDArray create(float data) {
    // wrap the float in a scalar (rank-0) tensor
    try (final Tensor<TFloat32> scalar = TFloat32.scalarOf(data)) {
        return new TfNDArray(this, scalar);
    }
}

17 Source : ConvNet.java
with Apache License 2.0
from tomwhite

/**
 * Runs the CNN on a normalized image tensor.
 *
 * <p>The Keras learning-phase input is fed {@code false}. That boolean tensor and the result
 * tensor are both closed before returning; previously the learning-phase tensor leaked on
 * every call.
 *
 * @param image normalized image tensor (left open for the caller)
 * @return the {@code N} values of the {@code [1, N]} output
 * @throws RuntimeException if the output is not shaped {@code [1, N]}
 */
private float[] executeCnnGraph(Tensor image) {
    try (Tensor learningPhase = Tensor.create(false);
            Tensor result = session.runner()
                    .feed(INPUT_OPERATION_NAME, image)
                    .feed(KERAS_LEARNING_PHASE_OPERATION_NAME, learningPhase)
                    .fetch(OUTPUT_OPERATION_NAME)
                    .run()
                    .get(0)) {
        final long[] rshape = result.shape();
        if (result.numDimensions() != 2 || rshape[0] != 1) {
            throw new RuntimeException(String.format("Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", Arrays.toString(rshape)));
        }
        int nlabels = (int) rshape[1];
        return result.copyTo(new float[1][nlabels])[0];
    }
}

17 Source : ConvNet.java
with Apache License 2.0
from tomwhite

/**
 * Runs the image-normalization graph on encoded image bytes.
 *
 * <p>The temporary input tensor is closed here. The returned tensor is not; the caller must
 * close it.
 *
 * @param imageBytes encoded image bytes fed as {@code "input"}
 * @return the first fetched output of the normalization graph
 */
private Tensor executeGraphToNormalizeImage(byte[] imageBytes) {
    try (Tensor encodedImage = Tensor.create(imageBytes)) {
        return imageNormalizationSession.runner()
                .feed("input", encodedImage)
                .fetch(normalizationOutputOperationName)
                .run()
                .get(0);
    }
}

17 Source : Recognizer.java
with Apache License 2.0
from tahaemara

/**
 * Loads {@code graphDef} into a fresh graph and runs it on one image.
 *
 * <p>The image is fed to {@code "DecodeJpeg/contents"} and the {@code "softmax"} output is
 * fetched. A new Graph and Session are built for every call and closed on return.
 *
 * @param graphDef serialized graph definition
 * @param image tensor fed to {@code DecodeJpeg/contents} (left open for the caller)
 * @return the {@code N} values of the {@code [1, N]} softmax output
 * @throws RuntimeException if the output is not shaped {@code [1, N]}
 */
private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
    try (Graph graph = new Graph()) {
        graph.importGraphDef(graphDef);
        try (Session session = new Session(graph);
                Tensor softmax = session.runner().feed("DecodeJpeg/contents", image).fetch("softmax").run().get(0)) {
            final long[] dims = softmax.shape();
            final boolean isSingleRow = softmax.numDimensions() == 2 && dims[0] == 1;
            if (!isSingleRow) {
                throw new RuntimeException(String.format("Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", Arrays.toString(dims)));
            }
            final int labelCount = (int) dims[1];
            return softmax.copyTo(new float[1][labelCount])[0];
        }
    }
}

17 Source : DavisCNNTensorFlow.java
with GNU Lesser General Public License v2.1
from SensorsINI

/**
 * Executes the stored Graph of the CNN on one DVS frame.
 *
 * <p>The frame is flipped vertically. When {@code processor.isMakeRGBFrames()} is set, each
 * gray value is copied into 3 channels. The result is fed as an NHWC tensor of shape
 * {@code [1, height, width, numChannels]}.
 *
 * <p>References:
 * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/op/Operands.java
 * https://github.com/tensorflow/tensorflow/issues/7149
 * https://stackoverflow.com/questions/44774234/why-tensorflow-uses-channel-last-ordering-instead-of-row-major
 *
 * @param pixbuf the pixel buffer holding the frame, as collected from DVSFramer in DVSFrame.
 *     It must be array-backed, with gray values in row-major order ({@code x + width * y}).
 * @param width width of image
 * @param height height of image
 * @return activations of output
 * @throws IllegalArgumentException rethrown unchanged if building or running the graph fails
 *     with one (for example, wrong layer names); a warning dialog listing I/O layers is also shown
 */
private float[] executeDvsFrameGraph(FloatBuffer pixbuf, int width, int height) {
    final int numChannels = processor.isMakeRGBFrames() ? 3 : 1;
    // TODO hack since we don't know the input size yet until network runs
    inputLayer = new InputLayer(width, height, numChannels);
    // Flip the image vertically and copy each gray value into every channel (NHWC order).
    final float[] origarray = pixbuf.array();
    final float[] flippedarray = new float[pixbuf.limit() * numChannels];
    for (int y = 0; y < height; y++) {
        final int flippedRow = height - y - 1;
        for (int x = 0; x < width; x++) {
            final float value = origarray[x + (width * y)];
            final int base = numChannels * (x + (width * flippedRow));
            for (int c = 0; c < numChannels; c++) {
                flippedarray[base + c] = value;
            }
        }
    }
    try (Tensor<Float> imageTensor = Tensor.create(new long[] { 1, height, width, numChannels }, FloatBuffer.wrap(flippedarray))) {
        float[] output = TensorFlow.executeGraph(executionGraph, imageTensor, processor.getInputLayerName(), processor.getOutputLayerName());
        outputLayer = new OutputLayer(output);
        if (isSoftMaxOutput()) {
            computeSoftMax();
        }
        getSupport().firePropertyChange(EVENT_MADE_DECISION, null, this);
        return output;
    } catch (IllegalArgumentException ex) {
        // Escape '&' first so the entities produced for '<' and '>' are not double-escaped.
        String exhtml = ex.toString().replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;").replace("\n", "<br>");
        final StringBuilder msg = new StringBuilder("<html>Caught exception <p>" + exhtml + "</p>");
        msg.append("<br> Did you set <i>inputLayerName</i> and <i>outputLayerName</i>?");
        msg.append("<br>The IO layer names could be as follows (the string inside the single quotes): <ul> ");
        for (String s : ioLayers) {
            msg.append("<li>" + (s.replace("<", "").replace(">", "")) + "</li>");
        }
        msg.append("</ul></html>");
        log.warning(msg.toString());
        SwingUtilities.invokeLater(new Runnable() {

            @Override
            public void run() {
                JOptionPane.showMessageDialog(processor.getChip().getAeViewer(), msg.toString(), "Error computing network", JOptionPane.WARNING_MESSAGE);
            }
        });
        // Rethrow the original so its message and stack trace are kept (getCause() is often null).
        throw ex;
    }
}

17 Source : DavisCNNTensorFlow.java
with GNU Lesser General Public License v2.1
from SensorsINI

/**
 * Runs the network on an APS/DVS frame and writes its output into {@code array}.
 *
 * <p>The input is a hard-coded {@code [1, 90, 120, 3]} tensor. Channels 0 and 1 hold
 * {@code frame.getValue(c, x, y) * 255}; channel 2 is zero. Smaller frames are zero-padded at
 * the bottom/right. The input and {@code phase_train} tensors are closed once the graph has run.
 *
 * @param frame source frame
 * @param array destination passed to
 *     {@code TensorFlow.executeGraphAndReturnTensorWithBooleanArray}
 * @throws IllegalArgumentException if the frame is larger than the 120x90 network input
 */
@Override
public void processAPSDVSFrameArray(APSDVSFrame frame, long[][][] array) {
    final int numChannels = 3;
    // Hard-coded network input size (height x width).
    final int inputHeight = 90, inputWidth = 120;
    final int sx = frame.getWidth(), sy = frame.getHeight();
    if (sx > inputWidth || sy > inputHeight) {
        throw new IllegalArgumentException("frame " + sx + "x" + sy + " exceeds network input " + inputWidth + "x" + inputHeight);
    }
    float[][][][] buf = new float[1][inputHeight][inputWidth][numChannels];
    for (int y = 0; y < sy; y++) {
        for (int x = 0; x < sx; x++) {
            // Channel 2 stays 0 (Java zero-initializes the array).
            buf[0][y][x][0] = frame.getValue(0, x, y) * 255;
            buf[0][y][x][1] = frame.getValue(1, x, y) * 255;
        }
    }
    try (Tensor<Float> inputImageTensor = Tensor.create(buf, Float.class);
            Tensor<Boolean> t = Tensor.create(false, Boolean.class)) {
        TensorFlow.executeGraphAndReturnTensorWithBooleanArray(array, executionGraph, inputImageTensor, processor.getInputLayerName(), t, "phase_train", processor.getOutputLayerName());
    }
    getSupport().firePropertyChange(EVENT_MADE_DECISION, null, this);
}

17 Source : DavisCNNTensorFlow.java
with GNU Lesser General Public License v2.1
from SensorsINI

/**
 * Runs the network on an APS/DVS frame and returns the raw output tensor.
 *
 * <p>The input is an NHWC tensor {@code [1, sy, sx, 3]}, flipped vertically. Channels 0 and 1
 * hold {@code frame.getValue(c, x, y)}; channel 2 is zero. A boolean {@code false} tensor is
 * passed under the name {@code "phase_train"}. Both input tensors are closed once the graph
 * has run; the returned tensor is owned by the caller, who must close it.
 *
 * @param frame source frame
 * @return the tensor returned by {@code TensorFlow.executeGraphAndReturnTensorWithBoolean}
 */
@Override
public Tensor processAPSDVSFrame(APSDVSFrame frame) {
    final int numChannels = 3;
    final int sx = frame.getWidth(), sy = frame.getHeight();
    FloatBuffer fb = FloatBuffer.allocate(sx * sy * numChannels);
    for (int y = 0; y < sy; y++) {
        for (int x = 0; x < sx; x++) {
            for (int c = 0; c < numChannels; c++) {
                // Row y is written to row (sy - y - 1): vertical flip.
                final int newIdx = c + (numChannels * (x + (sx * (sy - y - 1))));
                if (c == 2) {
                    fb.put(newIdx, 0);
                } else {
                    fb.put(newIdx, frame.getValue(c, x, y));
                }
            }
        }
    }
    fb.rewind();
    final Tensor results;
    try (Tensor<Float> inputImageTensor = Tensor.create(new long[] { 1, sy, sx, numChannels }, fb);
            Tensor<Boolean> t = Tensor.create(false, Boolean.class)) {
        results = TensorFlow.executeGraphAndReturnTensorWithBoolean(executionGraph, inputImageTensor, processor.getInputLayerName(), t, "phase_train", processor.getOutputLayerName());
    }
    getSupport().firePropertyChange(EVENT_MADE_DECISION, null, this);
    return results;
}

17 Source : TensorFlow.java
with GNU Lesser General Public License v2.1
from SensorsINI

/**
 * Executes the graph once to discover the shape of the network output.
 *
 * <p>The shared session {@code s} is created lazily on the first call. The fetched result
 * tensor is always closed.
 *
 * NOTE(review): {@code s} is bound to the graph passed on the first call; later calls with a
 * different graph still use that session — confirm this is intended.
 *
 * @param graph graph used to create the session if none exists yet
 * @param image input tensor fed to {@code inputLayerName}
 * @param inputLayerName name of the input to feed
 * @param outputLayerName name of the output to fetch
 * @return output dimensions 1, 2 and 3 (dimension 0, presumably the batch, is dropped)
 * @throws IllegalStateException if the output has fewer than 4 dimensions
 */
public static int[] executeGraphStartup(Graph graph, Tensor<Float> image, String inputLayerName, String outputLayerName) {
    if (s == null) {
        s = new Session(graph);
    }
    // Bind the raw fetched tensor so it is closed even if expect() rejects its type.
    try (Tensor<?> fetched = s.runner().feed(inputLayerName, image).fetch(outputLayerName).run().get(0)) {
        Tensor<Float> result = fetched.expect(Float.class);
        long[] outputShapeLong = result.shape();
        if (outputShapeLong.length < 4) {
            throw new IllegalStateException("Expected a rank-4 network output, got shape " + java.util.Arrays.toString(outputShapeLong));
        }
        // cast the output shape of the network to integers.
        return new int[] { (int) outputShapeLong[1], (int) outputShapeLong[2], (int) outputShapeLong[3] };
    }
}

17 Source : DavisCNNTensorFlow.java
with GNU Lesser General Public License v2.1
from SensorsINI

/**
 * Executes the stored Graph of the CNN on one DVS frame.
 *
 * <p>The frame is flipped vertically. When {@code processor.isMakeRGBFrames()} is set, each
 * gray value is copied into 3 channels. The result is fed as an NHWC tensor of shape
 * {@code [1, height, width, numChannels]}. On the first call, the output shape is discovered
 * via {@code TensorFlow.executeGraphStartup}. Graph and output-layer timings are logged at
 * INFO level.
 *
 * <p>References:
 * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/op/Operands.java
 * https://github.com/tensorflow/tensorflow/issues/7149
 * https://stackoverflow.com/questions/44774234/why-tensorflow-uses-channel-last-ordering-instead-of-row-major
 *
 * @param pixbuf the pixel buffer holding the frame, as collected from DVSFramer in DVSFrame.
 *     It must be array-backed, with gray values in row-major order ({@code x + width * y}).
 * @param width width of image
 * @param height height of image
 * @return activations of output
 * @throws IllegalArgumentException rethrown unchanged if building or running the graph fails
 *     with one (for example, wrong layer names); a warning dialog listing I/O layers is also shown
 */
private float[] executeDvsFrameGraph(FloatBuffer pixbuf, int width, int height) {
    final int numChannels = processor.isMakeRGBFrames() ? 3 : 1;
    // TODO hack since we don't know the input size yet until network runs
    inputLayer = new InputLayer(width, height, numChannels);
    // Flip the image vertically and copy each gray value into every channel (NHWC order).
    final float[] origarray = pixbuf.array();
    final float[] flippedarray = new float[pixbuf.limit() * numChannels];
    for (int y = 0; y < height; y++) {
        final int flippedRow = height - y - 1;
        for (int x = 0; x < width; x++) {
            final float value = origarray[x + (width * y)];
            final int base = numChannels * (x + (width * flippedRow));
            for (int c = 0; c < numChannels; c++) {
                flippedarray[base + c] = value;
            }
        }
    }
    try (Tensor<Float> imageTensor = Tensor.create(new long[] { 1, height, width, numChannels }, FloatBuffer.wrap(flippedarray))) {
        if (outShape == null) {
            outShape = TensorFlow.executeGraphStartup(executionGraph, imageTensor, processor.getInputLayerName(), processor.getOutputLayerName());
        }
        long startTimeexecuteGraph = System.nanoTime();
        float[] output = TensorFlow.executeGraph(executionGraph, imageTensor, processor.getInputLayerName(), processor.getOutputLayerName());
        long dtNs_executeGraph = (System.nanoTime() - startTimeexecuteGraph);
        log.info("executeGraph took " + (dtNs_executeGraph * 1e-6f) + " ms");
        // TIMING
        long startTime = System.nanoTime();
        outputLayer = new OutputLayer(output);
        long dtNs_outputLayer = (System.nanoTime() - startTime);
        log.info("outputLayer took " + (dtNs_outputLayer * 1e-6f) + " ms");
        // Softmax post-processing is disabled in this version (its implementation was removed).
        getSupport().firePropertyChange(EVENT_MADE_DECISION, null, this);
        return output;
    } catch (IllegalArgumentException ex) {
        // Escape '&' first so the entities produced for '<' and '>' are not double-escaped.
        String exhtml = ex.toString().replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;").replace("\n", "<br>");
        final StringBuilder msg = new StringBuilder("<html>Caught exception <p>" + exhtml + "</p>");
        msg.append("<br> Did you set <i>inputLayerName</i> and <i>outputLayerName</i>?");
        msg.append("<br>The IO layer names could be as follows (the string inside the single quotes): <ul> ");
        for (String s : ioLayers) {
            msg.append("<li>" + (s.replace("<", "").replace(">", "")) + "</li>");
        }
        msg.append("</ul></html>");
        log.warning(msg.toString());
        SwingUtilities.invokeLater(new Runnable() {

            @Override
            public void run() {
                JOptionPane.showMessageDialog(processor.getChip().getAeViewer(), msg.toString(), "Error computing network", JOptionPane.WARNING_MESSAGE);
            }
        });
        // Rethrow the original so its message and stack trace are kept (getCause() is often null).
        throw ex;
    }
}

17 Source : DavisCNNTensorFlow.java
with GNU Lesser General Public License v2.1
from SensorsINI

/**
 * Builds the network input tensors for an APS/DVS frame.
 *
 * <p>The graph call is currently disabled (see the TODO below), so {@code array} is not
 * written. The method builds an input of shape {@code [1, sy, sx, 3]}, with channels 0 and 1
 * holding {@code frame.getValue(c, x, y) * 255} and channel 2 zero. It then releases the
 * tensors and fires {@code EVENT_MADE_DECISION}. The original author noted this method is
 * never used.
 *
 * NOTE(review): the previous hard-coded buffer was {@code [1][260][344][1]} (one channel),
 * while the loop writes three channels, so every call threw. The buffer is now sized from the
 * frame and {@code numChannels}; confirm the model's expected channel count before re-enabling
 * the graph call.
 *
 * @param frame source frame
 * @param array intended destination of the graph output (currently unused)
 */
@Override
public void processAPSDVSFrameArray(APSDVSFrame frame, float[] array) {
    final int numChannels = 3;
    final int sx = frame.getWidth(), sy = frame.getHeight();
    float[][][][] buf = new float[1][sy][sx][numChannels];
    for (int y = 0; y < sy; y++) {
        for (int x = 0; x < sx; x++) {
            // Channel 2 stays 0 (Java zero-initializes the array).
            buf[0][y][x][0] = frame.getValue(0, x, y) * 255;
            buf[0][y][x][1] = frame.getValue(1, x, y) * 255;
        }
    }
    try (Tensor<Float> inputImageTensor = Tensor.create(buf, Float.class);
            Tensor<Boolean> t = Tensor.create(false, Boolean.class)) {
        // TODO graph execution is disabled:
        // TensorFlow.executeGraphAndReturnTensorWithBooleanArray(array, executionGraph, inputImageTensor, processor.getInputLayerName(), t, "phase_train", processor.getOutputLayerName());
    }
    getSupport().firePropertyChange(EVENT_MADE_DECISION, null, this);
}

17 Source : TFPredictor.java
with Apache License 2.0
from rockyzhengwu

public clreplaced TFPredictor {

    protected String modelPath = "";

    protected int numClreplaced = 0;

    protected Session session;

    protected Graph graph = new Graph();

    protected Tensor dropT = Tensor.create(1.0f);

    protected boolean isInit = false;

    protected float[][] tranValue;

    protected Session.Runner runner;

    public TFPredictor() {
    }

    public TFPredictor(String modelPath, int numClreplaced) {
        this.initModel(modelPath, numClreplaced);
    }

    public TFPredictor(byte[] modelByte, int numClreplaced) {
        initModel(modelByte, numClreplaced);
    }

    public void initModel(byte[] modelByte, int numClreplaced) {
        if (isInit) {
            return;
        }
        this.numClreplaced = numClreplaced;
        this.graph.importGraphDef(modelByte);
        this.session = new Session(graph);
        isInit = true;
    }

    public void initModel(String modelPath, int numClreplaced) {
        if (isInit) {
            return;
        }
        this.numClreplaced = numClreplaced;
        this.modelPath = modelPath;
        byte[] modelByte = null;
        Path path = Paths.get(modelPath);
        try {
            modelByte = Files.readAllBytes(path);
            this.graph.importGraphDef(modelByte);
            this.session = new Session(graph);
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println(e.getMessage());
            System.out.println("load model " + modelPath.toString() + " error ");
        }
        isInit = true;
    }

    public List<List<Integer>> predict(List<List<Integer>> sents) {
        List<List<Integer>> path = new ArrayList();
        int sentLength = sents.size();
        int maxSentLength = 0;
        int[] sentLengths = new int[sentLength];
        for (int i = 0; i < sentLength; i++) {
            int s = sents.get(i).size();
            sentLengths[i] = s;
            if (maxSentLength < s) {
                maxSentLength = s;
            }
        }
        Tensor inputx = TFUtil.createTensor(sents, maxSentLength);
        Tensor lengths = Tensor.create(sentLengths);
        Tensor logits = session.runner().feed("char_inputs", inputx).feed("dropout", dropT).feed("lengths", lengths).fetch("project/logits").run().get(0);
        if (tranValue == null) {
            tranValue = new float[numClreplaced + 1][numClreplaced + 1];
            Tensor trans = session.runner().feed("char_inputs", inputx).feed("dropout", dropT).feed("lengths", lengths).fetch("crf_loss/transitions").run().get(0);
            trans.copyTo(tranValue);
        }
        float[][][] flog = new float[sentLength][maxSentLength][this.numClreplaced];
        logits.copyTo(flog);
        for (int i = 0; i < flog.length; i++) {
            List<Integer> sentPath = decode(flog[i], tranValue, sentLengths[i]);
            path.add(sentPath);
        }
        return path;
    }

    public List<Integer> decode(float[][] logits, float[][] trans, int sentLength) {
        List<Integer> path = new ArrayList<>();
        float[][] scores = new float[sentLength + 1][numClreplaced + 1];
        int[][] paths = new int[sentLength + 1][numClreplaced + 1];
        float[][] logitsAppend = new float[sentLength + 1][numClreplaced + 1];
        for (int i = 0; i < sentLength + 1; i++) {
            for (int j = 0; j < numClreplaced + 1; j++) {
                if (i == 0) {
                    logitsAppend[i][j] = -1000;
                } else if (j == numClreplaced) {
                    logitsAppend[i][j] = -1000;
                } else {
                    logitsAppend[i][j] = logits[i - 1][j];
                }
            }
        }
        logitsAppend[0][numClreplaced] = 0;
        // System.out.println("trans");
        // for(int i=0; i<numClreplaced + 1; i++){
        // for(int j=0; j<numClreplaced + 1;j++){
        // System.out.print(trans[i][j]);
        // System.out.print(" ");
        // }
        // System.out.println("");
        // }
        // 
        // System.out.println("logits:");
        // for (int i = 0; i < sentLength + 1; i++) {
        // for (int j = 0; j < numClreplaced + 1; j++) {
        // System.out.print(logitsAppend[i][j]);
        // System.out.print(" ");
        // }
        // System.out.println(" ");
        // }
        for (int i = 0; i < numClreplaced + 1; i++) {
            scores[0][i] = logitsAppend[0][i];
        }
        for (int i = 1; i < sentLength + 1; i++) {
            for (int j = 0; j < numClreplaced + 1; j++) {
                float maxS = -10000;
                for (int t = 0; t < numClreplaced + 1; t++) {
                    float ss = scores[i - 1][t] + trans[t][j];
                    if (ss > maxS) {
                        maxS = ss;
                        paths[i][j] = t;
                    }
                }
                scores[i][j] = maxS + logitsAppend[i][j];
            }
        }
        // System.out.println("scores");
        // for (int i = 0; i < sentLength + 1; i++) {
        // for (int j = 0; j < numClreplaced + 1; j++) {
        // System.out.print(scores[i][j]);
        // System.out.print(" ");
        // }
        // System.out.println("");
        // }
        // 
        // for (int i = 0; i < sentLength + 1; i++) {
        // for (int j = 0; j < numClreplaced + 1; j++) {
        // System.out.print(paths[i][j]);
        // System.out.print(" ");
        // }
        // System.out.println(" ");
        // }
        // back path
        float maxvalue = -10000;
        int maxPath = 0;
        for (int j = 0; j < numClreplaced; j++) {
            if (scores[sentLength][j] > maxvalue) {
                maxvalue = scores[sentLength][j];
                maxPath = j;
            }
        }
        path.add(maxPath);
        for (int i = sentLength; i > 1; i--) {
            maxPath = paths[i][maxPath];
            path.add(maxPath);
        }
        Collections.reverse(path);
        return path;
    }
}

17 Source : MRCNNBrainDetector.java
with GNU General Public License v3.0
from mstritt

/**
 * Detects the brain in {@code image512} and crops the matching region from {@code smallImage}.
 *
 * <p>Only the first detected contour is used. Its bounds are padded by 20 px, scaled into
 * {@code smallImage} coordinates, clipped, cropped and resized to {@code size x size}.
 *
 * @param smallImage image the region is cropped from
 * @param image512 512x512 image fed to the detector
 * @return zero or one result; empty if nothing is detected or the clipped region is empty
 * @throws IOException declared for callers; not thrown by the visible code
 */
public List<DetectorResult> detectBrains(final BufferedImage smallImage, BufferedImage image512) throws IOException {
    List<DetectorResult> resList = new ArrayList<>();
    // The input tensor holds native memory; try-with-resources closes it (a null resource is skipped).
    try (Tensor<Float> input = DLHelpers.convertBufferedImageToTensor(image512, 512, 512)) {
        if (input == null) {
            return resList;
        }
        RawDetections rawDetections = DLHelpers.executeInceptionGraph(s, input, 512, 512, MAX_DETECTIONS, 28, 28);
        Detections detections = processDetections(512, 512, rawDetections);
        if (detections.getContours() == null || detections.getContours().isEmpty()) {
            // Nothing detected: get(0) below previously threw IndexOutOfBoundsException.
            return resList;
        }
        double scaleW = smallImage.getWidth() / (double) image512.getWidth();
        double scaleH = smallImage.getHeight() / (double) image512.getHeight();
        Rectangle bb = detections.getContours().get(0).getBounds();
        int pad = 20;
        bb = new Rectangle(bb.x - pad, bb.y - pad, bb.width + pad * 2, bb.height + pad * 2);
        Rectangle bbScaled = new Rectangle((int) (bb.x * scaleW), (int) (bb.y * scaleH), (int) (bb.width * scaleW), (int) (bb.height * scaleH));
        bbScaled = new Rectangle(smallImage.getMinX(), smallImage.getMinY(), smallImage.getWidth(), smallImage.getHeight()).intersection(bbScaled);
        if (bbScaled.isEmpty()) {
            // getSubimage rejects non-positive sizes.
            return resList;
        }
        BufferedImage roiImage = smallImage.getSubimage(bbScaled.x, bbScaled.y, bbScaled.width, bbScaled.height);
        roiImage = DLHelpers.resize(roiImage, size, size);
        // NOTE(review): the reported box is the padded box in image512 coordinates, not
        // smallImage coordinates — confirm this is intended.
        DetectorResult result = new DetectorResult(roiImage, bb.x, bb.y, bb.x + bb.width, bb.y + bb.height);
        resList.add(result);
    }
    return resList;
}

17 Source : InstSegMaskRCNN.java
with GNU General Public License v3.0
from mstritt

public clreplaced InstSegMaskRCNN {

    private static final Logger logger = LoggerFactory.getLogger(InstSegMaskRCNN.clreplaced);

    private static final String MODEL_DIR = "C:\\git\\python\\DSB_2018_DEEPRETINA\\logs\\final";

    private static final String MODEL_NAME = "deepretina_final.pb";

    // private static final String INPUT_IMAGE = "D:\\NoBackup\\databowl2018\\stage1_test\\da6c593410340b19bb212b9f6d274f95b08c0fc8f2570cd66bc5ed42c560acab\\images\\da6c593410340b19bb212b9f6d274f95b08c0fc8f2570cd66bc5ed42c560acab.png";
    private static final String INPUT_IMAGE = "D:\\NoBackup\\databowl2018\\in\\sample512\\images\\sample512.png";

    private static final String OUTPUT_IMAGE = "D:\\NoBackup\\databowl2018\\out\\out.jpg";

    private static final int DESIRED_SIZE = 512;

    private static final int[] RPN_ANCHOR_SCALES = new int[] { 8, 16, 32, 64, 128 };

    private static final float[] MEAN_PIXEL = new float[] { 123.7f, 116.8f, 103.9f };

    // anchors 1,65472,4
    private static transient Tensor<Float> anchors = null;

    private Random random = new Random();

    public static void main2(String[] args) throws IOException {
        Date startDate = new Date();
        final InstSegMaskRCNN maskRCNN = new InstSegMaskRCNN();
        final byte[] graphDef = Files.readAllBytes(Paths.get(MODEL_DIR, MODEL_NAME));
        final Graph g = maskRCNN.loadGraph(graphDef);
        final Session s = maskRCNN.createSession(g);
        try {
            BufferedImage originalImage = ImageIO.read(new File(INPUT_IMAGE));
            long startt = System.currentTimeMillis();
            List<Callable<BufferedImage>> tasks = new ArrayList<>();
            for (int i = 0; i < 100; i++) {
                tasks.add(new Callable<BufferedImage>() {

                    @Override
                    public BufferedImage call() throws Exception {
                        Tensor<Float> input = maskRCNN.convertBufferedImageToTensor(originalImage);
                        if (input != null) {
                            RawDetections rawDetections = maskRCNN.executeInceptionGraph(s, input);
                            input.close();
                            Detections detections = maskRCNN.processDetections(512, 512, rawDetections);
                            BufferedImage outputImage = maskRCNN.augmentDetections(originalImage, detections);
                            // ImageIO.write(outputImage, "jpg", new File(OUTPUT_IMAGE));
                            return outputImage;
                        }
                        return null;
                    }
                });
            }
            ExecutorService executor = Executors.newFixedThreadPool(1);
            try {
                executor.invokeAll(tasks);
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                executor.shutdownNow();
            }
            long used = System.currentTimeMillis() - startt;
            System.out.println("time used: " + used / 1000d);
        } finally {
            s.close();
            g.close();
        }
        long elapsedTimeInSec = (new Date().getTime() - startDate.getTime()) / 1000;
        System.out.println(String.format("Ended in %ds .", elapsedTimeInSec));
    }

    public static void main(String[] args) throws IOException {
        Date startDate = new Date();
        InstSegMaskRCNN maskRCNN = new InstSegMaskRCNN();
        byte[] graphDef = Files.readAllBytes(Paths.get(MODEL_DIR, MODEL_NAME));
        try (Graph g = maskRCNN.loadGraph(graphDef);
            Session s = maskRCNN.createSession(g)) {
            BufferedImage originalImage = ImageIO.read(new File(INPUT_IMAGE));
            long startt = System.currentTimeMillis();
            for (int i = 0; i < 100; i++) {
                Tensor<Float> input = maskRCNN.convertBufferedImageToTensor(originalImage);
                if (input != null) {
                    RawDetections rawDetections = maskRCNN.executeInceptionGraph(s, input);
                    Detections detections = maskRCNN.processDetections(512, 512, rawDetections);
                    BufferedImage outputImage = maskRCNN.augmentDetections(originalImage, detections);
                    ImageIO.write(outputImage, "jpg", new File(OUTPUT_IMAGE));
                }
            }
            long used = System.currentTimeMillis() - startt;
            System.out.println("time used: " + used / 1000d);
        }
        long elapsedTimeInSec = (new Date().getTime() - startDate.getTime()) / 1000;
        System.out.println(String.format("Ended in %ds .", elapsedTimeInSec));
    }

    public Detections processDetections(int imgWidth, int imgHeight, RawDetections rawDetections) {
        Detections detections = new Detections();
        detections.setBoundingBoxes(new ArrayList<>());
        detections.setContours(new ArrayList<>());
        detections.setProbabilities(new ArrayList<>());
        // only one image      y1,x1,y2,x2,clreplaced_id,probability (ordered desc)
        float[][] objects = rawDetections.objectBB[0];
        for (int i = 0; i < objects.length; i++) {
            // y1,x1,y2,x2,clreplaced_id,probability
            float[] bb = objects[i];
            if (bb[5] > 0.1 && ((bb[3] - bb[1]) * (bb[2] - bb[0]) > 1E-5)) {
                // probability > 0.8 and area > 1E-5
                float probability = bb[5];
                int x = (int) (bb[1] * imgWidth);
                int y = (int) (bb[0] * imgHeight);
                int width = (int) ((bb[3] - bb[1]) * imgWidth);
                int height = (int) ((bb[2] - bb[0]) * imgHeight);
                RectangleExt boundingBox = new RectangleExt(x, y, width, height);
                // masks    [1][512][28][28][2]
                float[][][] mask = rawDetections.masks[0][i];
                int rw = mask[0].length;
                int rh = mask.length;
                float scaleW = (float) width / (float) rw;
                float scaleH = (float) height / (float) rh;
                BufferedImage roi = new BufferedImage(rw, rh, BufferedImage.TYPE_INT_ARGB);
                for (int xr = 0; xr < rw; xr++) {
                    for (int yr = 0; yr < rh; yr++) {
                        // if (mask[yr][xr][1]>0.5f)
                        if (mask[yr][xr][0] < mask[yr][xr][1]) // if (mask[yr][xr][0]<mask[yr][xr][1] && mask[yr][xr][1]>0.5f)
                        {
                            roi.setRGB(xr, yr, Color.MAGENTA.getRGB());
                        }
                    }
                }
                // scale mask to bb size
                BufferedImage roiScaled = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
                Graphics2D roiG = roiScaled.createGraphics();
                roiG.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
                roiG.drawImage(roi, 0, 0, width, height, null);
                // draw contour
                int pad = 10;
                int[] argb = new int[4];
                Raster raster = roiScaled.getRaster();
                float[][] area = new float[roiScaled.getWidth() + pad * 2][roiScaled.getHeight() + pad * 2];
                for (int x1 = 0; x1 < roiScaled.getWidth(); x1++) {
                    for (int y1 = 0; y1 < roiScaled.getHeight(); y1++) {
                        argb = raster.getPixel(x1, y1, argb);
                        long sum = argb[1] + argb[2] + argb[3];
                        area[x1 + pad][y1 + pad] = sum > 0l ? 1 : 0;
                    }
                }
                // fill holes
                area = dilate(area);
                area = erode(area);
                final ArrayList<Point2D> contour = new ArrayList<>();
                if (MarchingSquares.calculateContour(contour, area, 1, 0.5f)) {
                    int[] xpoints = new int[contour.size()];
                    int[] ypoints = new int[contour.size()];
                    for (int j = 0; j < contour.size(); j++) {
                        xpoints[j] = (int) (x + ((contour.get(j).getX() - pad) * 1f));
                        ypoints[j] = (int) (y + ((contour.get(j).getY() - pad) * 1f));
                    }
                    PolygonExt polygon = new PolygonExt(new Polygon(xpoints, ypoints, xpoints.length));
                    detections.getProbabilities().add(probability);
                    detections.getBoundingBoxes().add(boundingBox);
                    detections.getContours().add(polygon);
                }
                roiG.dispose();
            }
        }
        return detections;
    }

    private BufferedImage augmentDetections(BufferedImage image, Detections detections) {
        boolean drawBoundingBox = false;
        boolean drawContour = true;
        BufferedImage outImg = new BufferedImage(image.getWidth(), image.getHeight(), image.getType());
        Graphics2D g = outImg.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
        g.setRenderingHint(RenderingHints.KEY_ALPHA_INTERPOLATION, RenderingHints.VALUE_ALPHA_INTERPOLATION_QUALITY);
        g.drawImage(image, 0, 0, null);
        int numObjects = detections.boundingBoxes.size();
        for (int i = 0; i < numObjects; i++) {
            float probability = detections.getProbabilities().get(i);
            RectangleExt rect = detections.getBoundingBoxes().get(i);
            PolygonExt poly = detections.getContours().get(i);
            int x = rect.x;
            int y = rect.y;
            int width = rect.width;
            int height = rect.height;
            if (drawBoundingBox) {
                g.setStroke(new BasicStroke(2));
                g.setColor(Color.yellow);
                g.drawRect(x, y, width, height);
            }
            // draw contour
            if (drawContour) {
                g.setStroke(new BasicStroke(2));
                Color color = Color.getHSBColor(random.nextFloat(), 1f, 1f);
                g.setColor(color);
                g.drawPolygon(poly);
            }
        }
        g.dispose();
        return outImg;
    }

    private float[][] dilate(final float[][] buf) {
        final float cf = 1;
        final float[][] res = new float[buf.length][buf[0].length];
        for (int x = 0; x < buf.length; x++) for (int y = 0; y < buf[x].length; y++) {
            res[x][y] = buf[x][y];
            if (x > 0 && buf[x - 1][y] == cf)
                res[x][y] = cf;
            else if (x < buf.length - 1 && buf[x + 1][y] == cf)
                res[x][y] = cf;
            else if (y > 0 && buf[x][y - 1] == cf)
                res[x][y] = cf;
            else if (y < buf[x].length - 1 && buf[x][y + 1] == cf)
                res[x][y] = cf;
            // else
            // if (x>0 && y>0 && buf[x-1][y-1]==cf) res[x][y] = cf; else
            // if (x<buf.length-1 && y>0 && buf[x+1][y-1]==cf) res[x][y] = cf; else
            // if (y>0 && x<buf.length-1 && buf[x+1][y-1]==cf) res[x][y] = cf; else
            // if (y<buf[x].length-1 && x<buf.length-1 && buf[x+1][y+1]==cf) res[x][y] = cf;
            if (x == 0 || y == 0 || x == buf.length - 1 || y == buf[x].length - 1)
                // border pixel is always background (needed to 'close' objects)
                res[x][y] = 0;
        }
        return res;
    }

    private float[][] erode(final float[][] buf) {
        final float cf = 0;
        final float[][] res = new float[buf.length][buf[0].length];
        for (int x = 0; x < buf.length; x++) for (int y = 0; y < buf[x].length; y++) {
            res[x][y] = buf[x][y];
            if (x > 0 && buf[x - 1][y] == cf)
                res[x][y] = cf;
            else if (x < buf.length - 1 && buf[x + 1][y] == cf)
                res[x][y] = cf;
            else if (y > 0 && buf[x][y - 1] == cf)
                res[x][y] = cf;
            else if (y < buf[x].length - 1 && buf[x][y + 1] == cf)
                res[x][y] = cf;
        }
        return res;
    }

    private synchronized Tensor<Float> getAnchors() {
        if (anchors == null) {
            float[] fArr = MaskRCNNAnchors.GenerateAnchors(DESIRED_SIZE);
            anchors = Tensor.create(new long[] { 1, fArr.length / 4, 4 }, FloatBuffer.wrap(fArr));
        }
        return anchors;
    }

    public Tensor<Float> convertBufferedImageToTensor(BufferedImage image) {
        // if (image.getWidth()!=DESIRED_SIZE || image.getHeight()!=DESIRED_SIZE)
        {
            // also make it an RGB image
            image = resize(image, DESIRED_SIZE, DESIRED_SIZE);
        // image = resize(image,image.getWidth(), image.getHeight());
        }
        int width = image.getWidth();
        int height = image.getHeight();
        Raster r = image.getRaster();
        int[] rgb = new int[3];
        // int[] data = new int[width * height];
        // image.getRGB(0, 0, width, height, data, 0, width);
        float[][][][] rgbArray = new float[1][height][width][3];
        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                // Color color = new Color(data[i * width + j]);
                // rgbArray[0][i][j][0] = color.getRed() - MEAN_PIXEL[0];
                // rgbArray[0][i][j][1] = color.getGreen() - MEAN_PIXEL[1];
                // rgbArray[0][i][j][2] = color.getBlue() - MEAN_PIXEL[2];
                rgb = r.getPixel(j, i, rgb);
                rgbArray[0][i][j][0] = rgb[0] - MEAN_PIXEL[0];
                rgbArray[0][i][j][1] = rgb[1] - MEAN_PIXEL[1];
                rgbArray[0][i][j][2] = rgb[2] - MEAN_PIXEL[2];
            }
        }
        return Tensor.create(rgbArray, Float.clreplaced);
    }

    public BufferedImage resize(BufferedImage img, int width, int height) {
        // int type = img.getType()>0?img.getType():BufferedImage.TYPE_INT_RGB;
        int type = BufferedImage.TYPE_INT_RGB;
        // BufferedImage resizedImage = new BufferedImage(roundP2(width), roundP2(height), type);
        BufferedImage resizedImage = new BufferedImage(width, height, type);
        Graphics2D g = resizedImage.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(img, 0, 0, width, height, null);
        g.dispose();
        return resizedImage;
    }

    private int roundP2(int x) {
        return (int) Math.pow(2, Math.ceil(Math.log(x) / Math.log(2)));
    }

    public Graph loadGraph(byte[] graphDef) {
        logger.info("TF version " + TensorFlow.version());
        Graph g = new Graph();
        g.importGraphDef(graphDef);
        return g;
    }

    public Session createSession(Graph g) {
        // output all node names
        // Iterator<Operation> ops = g.operations();
        // while (ops.hasNext()) {
        // Operation op = ops.next();
        // System.out.println(op.name());
        // }
        // System.out.println("finished");
        Session s = new Session(g);
        logger.trace("TF session created");
        return s;
    }

    public RawDetections executeInceptionGraph(Session s, Tensor<Float> input) {
        // image metas
        // meta = np.array(
        // [image_id] +                  # size=1
        // list(original_image_shape) +  # size=3
        // list(image_shape) +           # size=3
        // list(window) +                # size=4 (y1, x1, y2, x2) in image cooredinates
        // [scale[0]] +                     # size=1 NO LONGER, I dont have time to correct this properly so take only the first element
        // list(active_clreplaced_ids)        # size=num_clreplacedes
        // )
        final FloatBuffer metas = FloatBuffer.wrap(new float[] { 0, 512, 512, 3, 512, 512, 3, 0, 0, 512, 512, 1, 0, 0 });
        final Tensor<Float> meta_data = Tensor.create(new long[] { 1, 14 }, metas);
        List<Tensor<?>> res = s.runner().feed("input_image", input).feed("input_image_meta", meta_data).feed("input_anchors", // dtype float and shape [?,?,4]
        getAnchors()).fetch(// mrcnn_mask/Reshape_1   mrcnn_detection/Reshape_1    mrcnn_bbox/Reshape     mrcnn_clreplaced/Reshape_1
        "mrcnn_detection/Reshape_1").fetch("mrcnn_mask/Reshape_1").run();
        // mrcnn_detection/Reshape_1   -> y1,x1,y2,x2,clreplaced_id,probability (ordered desc)
        float[][][] res_detection = new float[1][512][6];
        // mrcnn_mask/Reshape_1
        float[][][][][] res_mask = new float[1][512][28][28][2];
        Tensor<Float> mrcnn_detection = res.get(0).expect(Float.clreplaced);
        Tensor<Float> mrcnn_mask = res.get(1).expect(Float.clreplaced);
        mrcnn_detection.copyTo(res_detection);
        mrcnn_mask.copyTo(res_mask);
        RawDetections rawDetections = new RawDetections();
        rawDetections.objectBB = res_detection;
        rawDetections.masks = res_mask;
        return rawDetections;
    }

    public clreplaced RawDetections {

        // y1,x1,y2,x2,clreplaced_id,probability (ordered desc)
        float[][][] objectBB;

        // float[1][512][28][28][2] max 512 instances, for each a 28x28 mask x probability foreground/background
        float[][][][][] masks;
    }

    public clreplaced Detections {

        private List<PolygonExt> contours;

        private List<RectangleExt> boundingBoxes;

        private List<Float> probabilities;

        public List<PolygonExt> getContours() {
            return contours;
        }

        public void setContours(List<PolygonExt> contours) {
            this.contours = contours;
        }

        public List<RectangleExt> getBoundingBoxes() {
            return boundingBoxes;
        }

        public void setBoundingBoxes(List<RectangleExt> boundingBoxes) {
            this.boundingBoxes = boundingBoxes;
        }

        public List<Float> getProbabilities() {
            return probabilities;
        }

        public void setProbabilities(List<Float> probabilities) {
            this.probabilities = probabilities;
        }
    }
}

17 Source : InstSegMaskRCNN.java
with GNU General Public License v3.0
from mstritt

/**
 * Runs the Mask R-CNN graph on one input image and copies out detections and masks.
 *
 * <p>The meta-data tensor and the fetched output tensors are closed here. The caller keeps
 * ownership of {@code input}.
 *
 * @param s open session on the Mask R-CNN graph
 * @param input image tensor fed as "input_image"
 * @return detections [1][512][6] and masks [1][512][28][28][2]
 */
public RawDetections executeInceptionGraph(Session s, Tensor<Float> input) {
    // image metas
    // meta = np.array(
    // [image_id] +                  # size=1
    // list(original_image_shape) +  # size=3
    // list(image_shape) +           # size=3
    // list(window) +                # size=4 (y1, x1, y2, x2) in image coordinates
    // [scale[0]] +                  # size=1 (only the first scale element is used)
    // list(active_class_ids)        # size=num_classes
    // )
    final FloatBuffer metas = FloatBuffer.wrap(new float[] { 0, 512, 512, 3, 512, 512, 3, 0, 0, 512, 512, 1, 0, 0 });
    List<Tensor<?>> res;
    try (Tensor<Float> metaData = Tensor.create(new long[] { 1, 14 }, metas)) {
        res = s.runner()
            .feed("input_image", input)
            .feed("input_image_meta", metaData)
            // dtype float and shape [?,?,4]
            .feed("input_anchors", getAnchors())
            // other outputs: mrcnn_bbox/Reshape, mrcnn_class/Reshape_1
            .fetch("mrcnn_detection/Reshape_1")
            .fetch("mrcnn_mask/Reshape_1")
            .run();
    }
    try {
        // mrcnn_detection/Reshape_1 -> y1,x1,y2,x2,class_id,probability (ordered desc)
        float[][][] resDetection = new float[1][512][6];
        // mrcnn_mask/Reshape_1
        float[][][][][] resMask = new float[1][512][28][28][2];
        res.get(0).expect(Float.class).copyTo(resDetection);
        res.get(1).expect(Float.class).copyTo(resMask);
        RawDetections rawDetections = new RawDetections();
        rawDetections.objectBB = resDetection;
        rawDetections.masks = resMask;
        return rawDetections;
    } finally {
        // Output tensors hold native memory; previously they leaked on every call.
        for (Tensor<?> t : res) {
            t.close();
        }
    }
}

17 Source : DLSegment.java
with GNU General Public License v3.0
from mstritt

/**
 * Loads a graph, feeds {@code input} as "image_batch" and returns the int64 "predictions".
 *
 * @param graphDef serialized GraphDef
 * @param input image batch tensor (owned by the caller)
 * @return the predictions as a flat array
 * @throws IllegalArgumentException if "predictions" is not an int64 tensor
 */
private static long[] executeInceptionGraph(byte[] graphDef, Tensor<Float> input) {
    try (Graph g = new Graph()) {
        g.importGraphDef(graphDef);
        // Bind the raw fetched tensor as the resource, so it is closed even if expect()
        // rejects its dtype.
        try (Session s = new Session(g);
            Tensor<?> fetched = s.runner().feed("image_batch", input).fetch("predictions").run().get(0)) {
            Tensor<Long> result = fetched.expect(Long.class);
            return result.copyTo(new long[result.numElements()]);
        }
    }
}

17 Source : Kafka_Streams_TensorFlow_Image_Recognition_Example_IntegrationTest.java
with Apache License 2.0
from kaiwaehner

/**
 * Imports the Inception GraphDef, runs it on {@code image} and returns the label scores.
 *
 * @param graphDef serialized Inception GraphDef
 * @param image tensor fed to the "input" op (owned by the caller)
 * @return one score per label, taken from the single row of the "output" op
 * @throws RuntimeException if "output" is not shaped [1, N]
 */
private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
    try (Graph graph = new Graph()) {
        graph.importGraphDef(graphDef);
        try (Session session = new Session(graph);
            Tensor output = session.runner().feed("input", image).fetch("output").run().get(0)) {
            final long[] shape = output.shape();
            final boolean isSingleRow = output.numDimensions() == 2 && shape[0] == 1;
            if (!isSingleRow) {
                throw new RuntimeException(String.format("Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", Arrays.toString(shape)));
            }
            final float[][] scores = new float[1][(int) shape[1]];
            output.copyTo(scores);
            return scores[0];
        }
    }
}

17 Source : Kafka_Streams_TensorFlow_Image_Recognition_Example.java
with Apache License 2.0
from kaiwaehner

/**
 * Runs a pre-trained Inception model on an image tensor.
 *
 * @param graphDef serialized Inception GraphDef
 * @param image tensor fed to the "input" op (owned by the caller)
 * @return the label scores from the single row of the "output" op
 * @throws RuntimeException if "output" is not shaped [1, N]
 */
private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
    // Model loading: import the pre-trained Inception GraphDef.
    try (Graph inception = new Graph()) {
        inception.importGraphDef(graphDef);
        // Graph execution: one session run yields the label scores.
        try (Session session = new Session(inception);
            Tensor labelScores = session.runner().feed("input", image).fetch("output").run().get(0)) {
            final long[] dims = labelScores.shape();
            if (labelScores.numDimensions() == 2 && dims[0] == 1) {
                final int labelCount = (int) dims[1];
                final float[][] batch = new float[1][labelCount];
                labelScores.copyTo(batch);
                return batch[0];
            }
            throw new RuntimeException(String.format("Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", Arrays.toString(dims)));
        }
    }
}

17 Source : TensorUtil.java
with GNU Affero General Public License v3.0
from jpmml

/**
 * Copies every element of a double tensor into a new flat array.
 *
 * @param tensor source tensor; writeTo rejects tensors of another dtype
 * @return a new array holding {@code tensor.numElements()} values
 */
static public double[] toDoubleArray(Tensor tensor) {
    final double[] values = new double[tensor.numElements()];
    tensor.writeTo(DoubleBuffer.wrap(values));
    return values;
}

17 Source : InProcessClassification.java
with Apache License 2.0
from hazelcast

/**
 * Scores one review with the saved model.
 *
 * @param review raw review text
 * @param model loaded SavedModel exposing "embedding_input:0" and "dense_1/Sigmoid:0"
 * @param wordIndex encoder that turns the review into the model's input
 * @return the review paired with the model's single sigmoid output
 */
private static Tuple2<String, Float> clreplacedify(String review, SavedModelBundle model, WordIndex wordIndex) {
    final float[][] prediction = new float[1][1];
    try (Tensor<Float> encodedReview = Tensors.create(wordIndex.createTensorInput(review))) {
        try (Tensor<?> sigmoid = model.session().runner()
                .feed("embedding_input:0", encodedReview)
                .fetch("dense_1/Sigmoid:0")
                .run()
                .get(0)) {
            sigmoid.copyTo(prediction);
        }
    }
    return tuple2(review, prediction[0][0]);
}

17 Source : ObjectDetector.java
with MIT License
from chen0040

/**
 * Runs the object-detection model on {@code img} and returns every detection scoring at least 0.5.
 *
 * @param img the image to analyse
 * @return detections that passed the threshold, in model output order
 * @throws IOException presumably from {@code TensorUtils.makeImageTensor} — TODO confirm
 */
public List<DetectedObj> detectObjects(BufferedImage img) throws IOException {
    return detectObjects(img, 0.5f);
}

/**
 * Runs the object-detection model on {@code img} and returns every detection whose score is at
 * least {@code minScore}.
 *
 * <p>All fetched output tensors are closed, even if one of them has an unexpected dtype.
 *
 * @param img the image to analyse
 * @param minScore detections scoring below this value are dropped
 * @return detections that passed the threshold, in model output order
 * @throws IOException presumably from {@code TensorUtils.makeImageTensor} — TODO confirm
 */
public List<DetectedObj> detectObjects(BufferedImage img, float minScore) throws IOException {
    logger.info("begin detecting objects from image ...");
    List<Tensor<?>> outputs;
    try (Tensor<UInt8> input = TensorUtils.makeImageTensor(img)) {
        outputs = model.session().runner()
            .feed("image_tensor", input)
            .fetch("detection_scores")
            .fetch("detection_classes")
            .fetch("detection_boxes")
            .run();
    }
    float[] scores;
    float[] classes;
    float[][] boxes;
    try {
        Tensor<Float> scoresT = outputs.get(0).expect(Float.class);
        Tensor<Float> classesT = outputs.get(1).expect(Float.class);
        Tensor<Float> boxesT = outputs.get(2).expect(Float.class);
        // All these tensors have:
        // - 1 as the first dimension
        // - maxObjects as the second dimension
        // While boxesT will have 4 as the third dimension (2 sets of (x, y) coordinates).
        // This can be verified by looking at scoresT.shape() etc.
        int maxObjects = (int) scoresT.shape()[1];
        scores = scoresT.copyTo(new float[1][maxObjects])[0];
        classes = classesT.copyTo(new float[1][maxObjects])[0];
        boxes = boxesT.copyTo(new float[1][maxObjects][4])[0];
    } finally {
        // Close every fetched tensor, not only those that passed expect(), so native memory is
        // released even when a dtype check throws.
        for (Tensor<?> output : outputs) {
            output.close();
        }
    }
    List<DetectedObj> result = new ArrayList<>();
    for (int i = 0; i < scores.length; ++i) {
        if (scores[i] < minScore) {
            continue;
        }
        // NOTE(review): class ids index `labels` directly; confirm the label array's offset matches
        // the model's id numbering.
        String label = labels[(int) classes[i]];
        result.add(new DetectedObj(label, scores[i], boxes[i]));
    }
    logger.info("object detection completed on image");
    return result;
}

17 Source : TfNDManager.java
with Apache License 2.0
from awslabs

/**
 * {@inheritDoc}
 *
 * <p>A zero-dimension (scalar) shape is served by a scalar {@code 0f} converted to
 * {@code dataType}. Any other shape allocates a new tensor of that type and shape.
 */
@Override
public NDArray create(Shape shape, DataType dataType) {
    if (shape.dimension() != 0) {
        Tensor<?> allocated = Tensor.of(TfDataType.toTf(dataType), TfNDArray.toTfShape(shape));
        return new TfNDArray(this, allocated);
    }
    // TensorFlow does not support an empty scalar (an empty NDArray with 0 dimensions),
    // so fall back to a scalar initialized with 0.
    return create(0f).toType(dataType, false);
}

17 Source : TfNDManager.java
with Apache License 2.0
from awslabs

/**
 * {@inheritDoc}
 *
 * <p>Wraps {@code data} in a rank-0 {@code TInt32} tensor.
 */
@Override
public NDArray create(int data) {
    // NOTE(review): the tensor is closed when this method returns, so TfNDArray presumably copies
    // what it needs in its constructor — confirm.
    try (Tensor<TInt32> scalar = TInt32.scalarOf(data)) {
        return new TfNDArray(this, scalar);
    }
}

17 Source : TfNDManager.java
with Apache License 2.0
from awslabs

/**
 * {@inheritDoc}
 *
 * <p>Wraps {@code data} in a rank-0 {@code TString} tensor.
 */
@Override
public NDArray create(String data) {
    // NOTE(review): the tensor is closed when this method returns, so TfNDArray presumably copies
    // what it needs in its constructor — confirm.
    try (Tensor<TString> scalar = TString.scalarOf(data)) {
        return new TfNDArray(this, scalar);
    }
}

17 Source : TfNDManager.java
with Apache License 2.0
from awslabs

/**
 * Creates an array of {@code shape} with every element set to {@code value}.
 *
 * @param shape the shape of the new array
 * @param value the fill value, converted with the {@link Number} accessor matching the type
 * @param dataType the element type; any type without its own case below is filled as float32
 * @return the filled array
 */
public NDArray fill(Shape shape, Number value, DataType dataType) {
    switch (dataType) {
        case INT32:
            try (Tensor<?> tensor = tf.fill(tf.constant(shape.getShape()), tf.constant(value.intValue())).asTensor()) {
                return new TfNDArray(this, tensor);
            }
        case INT64:
            try (Tensor<?> tensor = tf.fill(tf.constant(shape.getShape()).asOutput(), tf.constant(value.longValue())).asTensor()) {
                return new TfNDArray(this, tensor);
            }
        case FLOAT16:
            // value.shortValue() would truncate any fractional part and build an integer constant
            // rather than a float16 one. Fill in float32 and convert instead, the same way
            // create(Shape, DataType) reaches arbitrary types via toType(...).
            return fill(shape, value, DataType.FLOAT32).toType(DataType.FLOAT16, false);
        case FLOAT64:
            try (Tensor<?> tensor = tf.fill(tf.constant(shape.getShape()).asOutput(), tf.constant(value.doubleValue())).asTensor()) {
                return new TfNDArray(this, tensor);
            }
        case FLOAT32:
        default:
            // NOTE(review): every type without its own case lands here and yields float32 — confirm
            // callers never request such types and expect them honoured.
            try (Tensor<?> tensor = tf.fill(tf.constant(shape.getShape()).asOutput(), tf.constant(value.floatValue())).asTensor()) {
                return new TfNDArray(this, tensor);
            }
    }
}

17 Source : TfNDManager.java
with Apache License 2.0
from awslabs

/**
 * {@inheritDoc}
 *
 * <p>When {@code endpoint} is false, {@code num + 1} evenly spaced samples are generated and the
 * last one (which equals {@code stop}) is dropped.
 */
@Override
public NDArray linspace(float start, float stop, int num, boolean endpoint) {
    if (num < 0) {
        throw new IllegalArgumentException("number of samples must be non-negative.");
    }
    if (num == 0) {
        return create(new Shape(0));
    }
    int samples = endpoint ? num : num + 1;
    try (Tensor<?> tensor = org.tensorflow.op.core.LinSpace.create(tf.scope(), tf.constant(start), tf.constant(stop), tf.constant(samples)).asTensor()) {
        NDArray evenlySpaced = new TfNDArray(this, tensor);
        return endpoint ? evenlySpaced : evenlySpaced.get(new NDIndex(":-1"));
    }
}

17 Source : TfNDManager.java
with Apache License 2.0
from awslabs

/**
 * {@inheritDoc}
 *
 * <p>Returns an empty array of {@code dataType} when {@code step} points away from {@code stop}
 * (or {@code start == stop}), instead of delegating to TensorFlow's Range op.
 *
 * @throws IllegalArgumentException if {@code step} is zero
 */
@Override
public NDArray arange(float start, float stop, float step, DataType dataType) {
    if (step == 0) {
        throw new IllegalArgumentException("step must not be zero.");
    }
    // Range only produces values when step moves from start toward stop. TensorFlow's Range
    // kernel errors out (rather than returning an empty tensor) when start > stop with step > 0,
    // or start < stop with step < 0. Handle both directions here, not just the ascending one.
    boolean ascendingEmpty = step > 0 && stop <= start;
    boolean descendingEmpty = step < 0 && stop >= start;
    if (ascendingEmpty || descendingEmpty) {
        return create(new Shape(0), dataType);
    }
    try (Tensor<?> tensor = tf.range(toConstant(start, dataType), toConstant(stop, dataType), toConstant(step, dataType)).asTensor()) {
        return new TfNDArray(this, tensor);
    }
}

17 Source : TfNDArrayEx.java
with Apache License 2.0
from awslabs

/**
 * {@inheritDoc}
 *
 * <p>Built from {@code tf.math.sigmoid} applied to this array's operand.
 */
@Override
public NDArray sigmoid() {
    try (Tensor<?> activated = tf.math.sigmoid(array.getOperand()).asTensor()) {
        return new TfNDArray(manager, activated);
    }
}

17 Source : TfNDArrayEx.java
with Apache License 2.0
from awslabs

/**
 * {@inheritDoc}
 *
 * <p>Built from {@code tf.nn.softsign} applied to this array's operand.
 */
@Override
public NDArray softSign() {
    try (Tensor<?> activated = tf.nn.softsign(array.getOperand()).asTensor()) {
        return new TfNDArray(manager, activated);
    }
}

17 Source : TfNDArrayEx.java
with Apache License 2.0
from awslabs

/**
 * {@inheritDoc}
 *
 * <p>Built from {@code tf.nn.selu} applied to this array's operand.
 */
@Override
public NDArray selu() {
    try (Tensor<?> activated = tf.nn.selu(array.getOperand()).asTensor()) {
        return new TfNDArray(manager, activated);
    }
}

17 Source : TfNDArrayEx.java
with Apache License 2.0
from awslabs

/**
 * {@inheritDoc}
 *
 * <p>Built from {@code tf.nn.relu} applied to this array's operand.
 */
@Override
public NDArray relu() {
    try (Tensor<?> activated = tf.nn.relu(array.getOperand()).asTensor()) {
        return new TfNDArray(manager, activated);
    }
}

16 Source : TensorflowProcessorConfiguration.java
with Apache License 2.0
from tzolov

/**
 * Converts the incoming message into model inputs, evaluates the model, and emits the converted
 * result.
 *
 * <p>If {@code properties.isSaveOutputInHeader()} is true, the original payload is kept and the
 * result is added as the {@code TF_OUTPUT_HEADER} header (unless already present). Otherwise the
 * result becomes the payload. Input headers are copied in both cases.
 *
 * @param input the incoming message
 * @return the outgoing message
 */
@ServiceActivator(inputChannel = Processor.INPUT, outputChannel = Processor.OUTPUT)
public Message<?> evaluate(Message<?> input) {
    // Per-message scratch state shared by the input and output converters.
    Map<String, Object> context = new ConcurrentHashMap<>();
    Map<String, Object> modelInputs = tensorflowInputConverter.convert(input, context);
    // NOTE(review): neither the input tensors nor this output tensor are closed here — confirm the
    // service/converters own their lifecycle, otherwise native memory leaks per message.
    Tensor resultTensor = tensorFlowService.evaluate(modelInputs, properties.getOutputName(), properties.getOutputIndex());
    Object converted = tensorflowOutputConverter.convert(resultTensor, context);
    if (!properties.isSaveOutputInHeader()) {
        // The converted result replaces the payload.
        return MessageBuilder.withPayload(converted).copyHeadersIfAbsent(input.getHeaders()).build();
    }
    // Keep the original payload and attach the result as a header.
    return MessageBuilder.withPayload(input.getPayload()).copyHeadersIfAbsent(input.getHeaders()).setHeaderIfAbsent(TF_OUTPUT_HEADER, converted).build();
}

16 Source : LabelImageTensorflowInputConverter.java
with Apache License 2.0
from tzolov

/**
 * Turns a raw image payload into the model's input map.
 *
 * <p>The payload bytes are normalized by {@code constructAndExecuteGraphToNormalizeImage3}. The
 * resulting tensor is stored under the key {@code "input"}.
 *
 * @param input message whose payload must be a {@code byte[]}
 * @param processorContext per-message scratch state (unused here)
 * @return a mutable map holding the normalized image tensor
 * @throws IllegalArgumentException if the payload is not a {@code byte[]}
 */
@Override
public Map<String, Object> convert(Message<?> input, Map<String, Object> processorContext) {
    Object payload = input.getPayload();
    if (!(payload instanceof byte[])) {
        // Report the payload's class rather than its toString(), which may be huge or sensitive.
        String type = (payload == null) ? "null" : payload.getClass().getName();
        throw new IllegalArgumentException("Unsupported payload type: " + type);
    }
    Tensor inputImageTensor = constructAndExecuteGraphToNormalizeImage3((byte[]) payload);
    Map<String, Object> inputMap = new HashMap<>();
    inputMap.put("input", inputImageTensor);
    return inputMap;
}

16 Source : Recognizer.java
with Apache License 2.0
from tahaemara

/**
 * Dispatches button clicks: model-directory chooser, image chooser, or prediction.
 *
 * @param e the action event; its source selects the handler
 */
@Override
public void actionPerformed(ActionEvent e) {
    Object source = e.getSource();
    if (source == incep) {
        handleModelSelection();
    } else if (source == img) {
        handleImageSelection();
    } else if (source == predict) {
        handlePrediction();
    }
}

/** Lets the user choose the model directory and loads the graph and label files from it. */
private void handleModelSelection() {
    int returnVal = incepch.showOpenDialog(this);
    if (returnVal != JFileChooser.APPROVE_OPTION) {
        System.out.println("Process was cancelled by user.");
        return;
    }
    File file = incepch.getSelectedFile();
    modelpath = file.getAbsolutePath();
    modelpth.setText(modelpath);
    System.out.println("Opening: " + file.getAbsolutePath());
    modelselected = true;
    graphDef = readAllBytesOrExit(Paths.get(modelpath, "tensorflow_inception_graph.pb"));
    labels = readAllLinesOrExit(Paths.get(modelpath, "imagenet_comp_graph_label_strings.txt"));
}

/** Lets the user choose an image, previews it, and enables prediction once a model is loaded. */
private void handleImageSelection() {
    int returnVal = imgch.showOpenDialog(Recognizer.this);
    if (returnVal != JFileChooser.APPROVE_OPTION) {
        System.out.println("Process was cancelled by user.");
        return;
    }
    try {
        File file = imgch.getSelectedFile();
        imagepath = file.getAbsolutePath();
        imgpth.setText(imagepath);
        System.out.println("Image Path: " + imagepath);
        // Named "preview" so it does not shadow the "img" button field.
        Image preview = ImageIO.read(file);
        if (preview == null) {
            // ImageIO.read returns null when no installed reader understands the file.
            Logger.getLogger(Recognizer.class.getName()).log(Level.WARNING, "Unsupported image format: {0}", imagepath);
            predict.setEnabled(false);
            return;
        }
        // The third argument is a scaling hint; it must be one of the Image.SCALE_* constants.
        viewer.setIcon(new ImageIcon(preview.getScaledInstance(200, 200, Image.SCALE_SMOOTH)));
        if (modelselected) {
            predict.setEnabled(true);
        }
    } catch (IOException ex) {
        Logger.getLogger(Recognizer.class.getName()).log(Level.SEVERE, null, ex);
    }
}

/** Classifies the selected image with the loaded graph and shows the best-scoring label. */
private void handlePrediction() {
    // NOTE(review): this runs on the event-dispatch thread and blocks the UI while the graph
    // executes — consider a SwingWorker.
    byte[] imageBytes = readAllBytesOrExit(Paths.get(imagepath));
    try (Tensor image = Tensor.create(imageBytes)) {
        float[] labelProbabilities = executeInceptionGraph(graphDef, image);
        int bestLabelIdx = maxIndex(labelProbabilities);
        String message = String.format("BEST MATCH: %s (%.2f%% likely)", labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f);
        result.setText(message);
        System.out.println(message);
    }
}

16 Source : TensorFlowGraphModelTest.java
with Apache License 2.0
from spotify

/**
 * Loads the dummy graph without a name prefix and checks that feeding 3.0 into the input op
 * yields 6.0 from the multiply op.
 */
@Test
public void testDummyLoadOfTensorFlowGraph() throws Exception {
    final Path graphFile = createADummyTFGraph();
    try (final TensorFlowGraphModel model = TensorFlowGraphModel.create(graphFile.toUri(), null, null);
        final Session session = model.instance();
        final Tensor<TFloat64> double3 = TFloat64.scalarOf(3.0D)) {
        List<Tensor<?>> result = null;
        try {
            result = session.runner().fetch(mulResult).feed(inputOpName, double3).run();
            // JUnit's assertEquals takes (expected, actual); the order matters for failure messages.
            assertEquals(NdArrays.scalarOf(6.0D), result.get(0).data());
        } finally {
            // Fetched tensors hold native memory and must be closed explicitly.
            if (result != null) {
                result.forEach(Tensor::close);
            }
        }
    }
}

16 Source : TensorFlowGraphModelTest.java
with Apache License 2.0
from spotify

/**
 * Loads the dummy graph under the name prefix {@code "test"} and checks that feeding 3.0 into the
 * prefixed input op yields 6.0 from the prefixed multiply op.
 */
@Test
public void testDummyLoadOfTensorFlowGraphWithPrefix() throws Exception {
    final String prefix = "test";
    final Path graphFile = createADummyTFGraph();
    try (final TensorFlowGraphModel model = TensorFlowGraphModel.create(graphFile.toUri(), null, prefix);
        final Session session = model.instance();
        final Tensor<TFloat64> double3 = TFloat64.scalarOf(3.0D)) {
        List<Tensor<?>> result = null;
        try {
            result = session.runner().fetch(prefix + "/" + mulResult).feed(prefix + "/" + inputOpName, double3).run();
            // JUnit's assertEquals takes (expected, actual, delta).
            assertEquals(6.0D, result.get(0).rawData().asDoubles().getDouble(0), Double.MIN_VALUE);
        } finally {
            // Fetched tensors hold native memory and must be closed explicitly.
            if (result != null) {
                result.forEach(Tensor::close);
            }
        }
    }
}

16 Source : TensorFlow.java
with GNU Lesser General Public License v2.1
from SensorsINI

/**
 * Feeds {@code image} into {@code inputLayerName}, fetches {@code outputLayerName}, and returns its
 * single row of float scores.
 *
 * <p>A static {@code Session} is created lazily on first use and reused afterwards. On any failure
 * the error is logged, the session is closed and discarded so the next call builds a fresh one,
 * and {@code null} is returned.
 *
 * @param graph graph used to create the session on first use
 * @param image input tensor
 * @param inputLayerName name of the node to feed
 * @param outputLayerName name of the node to fetch; expected to produce a {@code [1, N]} float
 *     tensor
 * @return the {@code N} output values, or {@code null} if execution failed
 */
public static float[] executeGraph(Graph graph, Tensor<Float> image, String inputLayerName, String outputLayerName) {
    try {
        if (session == null) {
            // NOTE(review): the cached session stays bound to the first graph passed in; later calls
            // with a different graph still run the old one — confirm this is intended.
            session = new Session(graph);
        }
        // The fetched tensor owns native memory; close it once its values are copied out.
        try (Tensor<?> output = session.runner().feed(inputLayerName, image).fetch(outputLayerName).run().get(0)) {
            Tensor<Float> result = output.expect(Float.class);
            final long[] rshape = result.shape();
            if (result.numDimensions() != 2 || rshape[0] != 1) {
                throw new RuntimeException(String.format("Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", Arrays.toString(rshape)));
            }
            int nlabels = (int) rshape[1];
            return result.copyTo(new float[1][nlabels])[0];
        }
    } catch (Exception e) {
        // Log the exception itself; e.getCause() is frequently null and would drop the stack trace.
        log.log(Level.SEVERE, "Exception running network: " + e.toString(), e);
        if (session != null) {
            session.close();
            // Forget the closed session so later calls recreate it instead of failing permanently.
            session = null;
        }
        return null;
    }
}

See More Examples