feat: switch model to ade20k

This commit is contained in:
max_richter 2021-03-10 15:57:27 +01:00
parent 2a94207c73
commit 41243299af
4 changed files with 25 additions and 11 deletions

View File

@ -1,18 +1,14 @@
import '@tensorflow/tfjs-backend-webgl';
import "@tensorflow/tfjs-backend-cpu"
import * as tfconv from '@tensorflow/tfjs-converter';
import * as deeplab from '@tensorflow-models/deeplab';
import { getLabels, getColormap, getURL, SemanticSegmentation, toSegmentationImage } from '@tensorflow-models/deeplab';
import { loadGraphModel } from '@tensorflow/tfjs-converter';
import { getURL, SemanticSegmentation } from '@tensorflow-models/deeplab';
const base = 'cityscapes'; // set to your preferred model, out of `pascal`,
const base = 'ade20k'; // set to your preferred model, out of `pascal`,`cityscapes` and `ade20k`
const createModel = async () => {
// `cityscapes` and `ade20k`
const quantizationBytes = 2; // either 1, 2 or 4
// use the getURL utility function to get the URL to the pre-trained weights
const modelUrl = getURL(base, quantizationBytes);
const rawModel = await tfconv.loadGraphModel(modelUrl);
const modelName = 'pascal'; // set to your preferred model, out of `pascal`, `cityscapes` and `ade20k`
return new SemanticSegmentation(rawModel, modelName);
const rawModel = await loadGraphModel(modelUrl);
return new SemanticSegmentation(rawModel, base);
};
const model = createModel();

View File

@ -1,4 +1,4 @@
import { Renderer, Camera, Transform, Texture, Sphere, Program, Mesh } from "ogl"
import { Renderer, Camera, Transform, Texture, Raycast, Sphere, Program, Mesh } from "ogl"
import { Orbit } from "./CustomOrbit"
import { bufToImageUrl } from "../../helpers";
@ -87,6 +87,7 @@ export default (image: Image, canvas2D: CanvasRenderingContext2D, canvas3D: HTML
overlay.scale.x = -9.995;
overlay.setParent(scene);
const raycast = new Raycast();
let activeTool = "pan";
let shouldBeRendering = false;
@ -99,6 +100,10 @@ export default (image: Image, canvas2D: CanvasRenderingContext2D, canvas3D: HTML
renderer.render({ scene, camera });
}
function click() {
const res = raycast.intersectSphere(skybox);
console.log(res);
}
function start() {
!hasConfirmed && Toast.confirm("Drawing does not work correctly in 3D at the moment").then(c => {
@ -132,7 +137,7 @@ export default (image: Image, canvas2D: CanvasRenderingContext2D, canvas3D: HTML
}
return {
start, stop, setTool, updateOverlay, setOpacity
start, stop, setTool, updateOverlay, setOpacity, click
}
}

View File

@ -256,6 +256,7 @@
}
function handleResize() {
if (!wrapper) return;
const box = wrapper.getBoundingClientRect();
topLeftX = Math.floor(box.x);
topLeftY = Math.floor(box.y);

12
view/src/ogl.d.ts vendored
View File

@ -396,6 +396,18 @@ declare module 'ogl' {
remove: () => void;
}
export class Raycast {
constructor();
castMouse(camera: Camera, mouse: Vec2);
intersectBounds(meshes: Mesh[], { maxDistance, output = [] } = {})
intersectMeshes(meshes: Mesh[], { cullFace = true, maxDistance, includeUV = true, includeNormal = true, output = [] } = {})
intersectSphere(sphere, origin = this.origin, direction = this.direction)
intersectBox(box, origin = this.origin, direction = this.direction)
intersectTriangle(a, b, c, backfaceCulling = true, origin = this.origin, direction = this.direction, normal = tempVec3g)
getBarycoord(point, a, b, c, target = tempVec3h)
}
export interface MeshOptions {
mode?: number;
geometry: Geometry;