Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 1 addition & 17 deletions examples/webgpu_compute_cloth.html
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

import * as THREE from 'three/webgpu';

import { Fn, If, Return, instancedArray, instanceIndex, uniform, select, attribute, uint, Loop, float, transformNormalToView, cross, triNoise3D, time } from 'three/tsl';
import { Fn, If, Return, instancedArray, instanceIndex, uniform, select, attribute, Loop, float, transformNormalToView, cross, triNoise3D, time } from 'three/tsl';

import { Inspector } from 'three/addons/inspector/Inspector.js';

Expand Down Expand Up @@ -307,14 +307,6 @@
// This shader computes a force for each spring, depending on the distance between the two vertices connected by that spring and the targeted rest length
computeSpringForces = Fn( () => {

If( instanceIndex.greaterThanEqual( uint( springCount ) ), () => {

// compute Shaders are executed in groups of 64, so instanceIndex might be bigger than the amount of springs.
// in that case, return.
Return();

} );

const vertexIds = springVertexIdBuffer.element( instanceIndex );
const restLength = springRestLengthBuffer.element( instanceIndex );

Expand All @@ -335,14 +327,6 @@
// In the end it adds the force to the vertex' position.
computeVertexForces = Fn( () => {

If( instanceIndex.greaterThanEqual( uint( vertexCount ) ), () => {

// compute Shaders are executed in groups of 64, so instanceIndex might be bigger than the amount of vertices.
// in that case, return.
Return();

} );

const params = vertexParamsBuffer.element( instanceIndex ).toVar();
const isFixed = params.x;
const springCount = params.y;
Expand Down
31 changes: 4 additions & 27 deletions examples/webgpu_compute_particles_fluid.html
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

import * as THREE from 'three/webgpu';

import { Fn, If, Return, instancedArray, instanceIndex, uniform, attribute, uint, float, clamp, struct, atomicStore, int, ivec3, array, vec3, atomicAdd, Loop, atomicLoad, max, pow, mat3, vec4, cross, step, storage } from 'three/tsl';
import { Fn, If, Return, instancedArray, instanceIndex, uniform, attribute, float, clamp, struct, atomicStore, int, ivec3, array, vec3, atomicAdd, Loop, atomicLoad, max, pow, mat3, vec4, cross, step, storage } from 'three/tsl';

import { Inspector } from 'three/addons/inspector/Inspector.js';

Expand Down Expand Up @@ -132,6 +132,9 @@
gui.add( params, 'particleCount', 4096, maxParticles, 4096 ).onChange( value => {

particleMesh.count = value;
p2g1Kernel.count = value;
p2g2Kernel.count = value;
g2pKernel.count = value;
particleCountUniform.value = value;

} );
Expand Down Expand Up @@ -219,12 +222,6 @@
const cellCount = gridSize.x * gridSize.y * gridSize.z;
clearGridKernel = Fn( () => {

If( instanceIndex.greaterThanEqual( uint( cellCount ) ), () => {

Return();

} );

atomicStore( cellBuffer.element( instanceIndex ).get( 'x' ), 0 );
atomicStore( cellBuffer.element( instanceIndex ).get( 'y' ), 0 );
atomicStore( cellBuffer.element( instanceIndex ).get( 'z' ), 0 );
Expand All @@ -234,11 +231,6 @@

p2g1Kernel = Fn( () => {

If( instanceIndex.greaterThanEqual( particleCountUniform ), () => {

Return();

} );
const particlePosition = particleBuffer.element( instanceIndex ).get( 'position' ).toConst( 'particlePosition' );
const particleVelocity = particleBuffer.element( instanceIndex ).get( 'velocity' ).toConst( 'particleVelocity' );
const C = particleBuffer.element( instanceIndex ).get( 'C' ).toConst( 'C' );
Expand Down Expand Up @@ -282,11 +274,6 @@

p2g2Kernel = Fn( () => {

If( instanceIndex.greaterThanEqual( particleCountUniform ), () => {

Return();

} );
const particlePosition = particleBuffer.element( instanceIndex ).get( 'position' ).toConst( 'particlePosition' );
const gridPosition = particlePosition.mul( gridSizeUniform ).toVar();

Expand Down Expand Up @@ -353,11 +340,6 @@

updateGridKernel = Fn( () => {

If( instanceIndex.greaterThanEqual( uint( cellCount ) ), () => {

Return();

} );
const cell = cellBuffer.element( instanceIndex );
const mass = decodeFixedPoint( atomicLoad( cell.get( 'mass' ) ) ).toConst();
If( mass.lessThanEqual( 0 ), () => {
Expand Down Expand Up @@ -412,11 +394,6 @@

g2pKernel = Fn( () => {

If( instanceIndex.greaterThanEqual( particleCountUniform ), () => {

Return();

} );
const particlePosition = particleBuffer.element( instanceIndex ).get( 'position' ).toVar( 'particlePosition' );
const gridPosition = particlePosition.mul( gridSizeUniform ).toVar();
const particleVelocity = vec3( 0 ).toVar();
Expand Down
3 changes: 2 additions & 1 deletion src/nodes/core/IndexNode.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import Node from './Node.js';
import { nodeImmutable, varying } from '../tsl/TSLBase.js';
import { nodeImmutable } from '../tsl/TSLCore.js';
import { varying } from './VaryingNode.js';

/**
* This class represents shader indices of different types. The following predefined node
Expand Down
4 changes: 4 additions & 0 deletions src/nodes/gpgpu/BarrierNode.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,17 @@ class BarrierNode extends Node {

this.scope = scope;

this.isBarrierNode = true;

}

generate( builder ) {

const { scope } = this;
const { renderer } = builder;

builder.allowEarlyReturns = false;

if ( renderer.backend.isWebGLBackend === true ) {

builder.addFlowCode( `\t// ${scope}Barrier \n` );
Expand Down
113 changes: 69 additions & 44 deletions src/nodes/gpgpu/ComputeNode.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import Node from '../core/Node.js';
import { instanceIndex } from '../core/IndexNode.js';
import StackTrace from '../core/StackTrace.js';
import { uniform } from '../core/UniformNode.js';
import { NodeUpdateType } from '../core/constants.js';
import { addMethodChaining, nodeObject } from '../tsl/TSLCore.js';
import { warn, error } from '../../utils.js';

/**
* TODO
* Represents a compute shader node.
*
* @augments Node
*/
Expand All @@ -20,8 +22,8 @@ class ComputeNode extends Node {
/**
* Constructs a new compute node.
*
* @param {Node} computeNode - TODO
* @param {Array<number>} workgroupSize - TODO.
* @param {Node} computeNode - The node that defines the compute shader logic.
* @param {Array<number>} workgroupSize - An array defining the X, Y, and Z dimensions of the workgroup for compute shader execution.
*/
constructor( computeNode, workgroupSize ) {

Expand All @@ -37,30 +39,38 @@ class ComputeNode extends Node {
this.isComputeNode = true;

/**
* TODO
* The node that defines the compute shader logic.
*
* @type {Node}
*/
this.computeNode = computeNode;


/**
* TODO
* An array defining the X, Y, and Z dimensions of the workgroup for compute shader execution.
*
* @type {Array<number>}
* @default [ 64 ]
*/
this.workgroupSize = workgroupSize;

/**
* TODO
* The total number of threads (invocations) to execute. If it is a number, it will be used
* to automatically generate bounds checking against `instanceIndex`.
*
* @type {number|Array<number>}
*/
this.count = null;

/**
* TODO
* The dispatch size for workgroups on X, Y, and Z axes.
* Used directly if `count` is not provided.
*
* @type {number|Array<number>}
*/
this.dispatchSize = null;

/**
* The version of the node.
*
* @type {number}
*/
Expand All @@ -84,36 +94,19 @@ class ComputeNode extends Node {
this.updateBeforeType = NodeUpdateType.OBJECT;

/**
* TODO
* A callback executed when the compute node finishes initialization.
*
* @type {?Function}
*/
this.onInitFunction = null;

}

/**
* TODO
*
* @param {number|Array<number>} count - Array with [ x, y, z ] values for dispatch or a single number for the count
* @return {ComputeNode}
*/
setCount( count ) {

this.count = count;

return this;

}

/**
* TODO
*
* @return {number|Array<number>}
*/
getCount() {

return this.count;
/**
* A uniform node holding the dispatch count for bounds checking.
* Created automatically when `count` is a number.
*
* @type {?UniformNode}
*/
this.countNode = null;

}

Expand Down Expand Up @@ -156,9 +149,9 @@ class ComputeNode extends Node {
}

/**
* TODO
* Sets the callback to run during initialization.
*
* @param {Function} callback - TODO.
* @param {Function} callback - The callback function.
* @return {ComputeNode} A reference to this node.
*/
onInit( callback ) {
Expand All @@ -182,6 +175,12 @@ class ComputeNode extends Node {

setup( builder ) {

if ( this.count !== null && this.countNode === null ) {

this.countNode = uniform( this.count, 'uint' ).onObjectUpdate( () => this.count );

}

const result = this.computeNode.build( builder );

if ( result ) {
Expand Down Expand Up @@ -211,6 +210,16 @@ class ComputeNode extends Node {

}

if ( this.count !== null && builder.allowEarlyReturns === true ) {

const countSnippet = this.countNode.build( builder, 'uint' );
const indexSnippet = instanceIndex.build( builder, 'uint' );

builder.flow.code = `${ builder.tab }if ( ${ indexSnippet } >= ${ countSnippet } ) { return; }\n\n${ builder.flow.code }`;

}


} else {

const properties = builder.getNodeProperties( this );
Expand All @@ -235,9 +244,9 @@ export default ComputeNode;
*
* @tsl
* @function
* @param {Node} node - TODO
* @param {Array<number>} [workgroupSize=[64]] - TODO.
* @returns {AtomicFunctionNode}
* @param {Node} node - The TSL logic for the compute shader.
* @param {Array<number>} [workgroupSize=[64]] - The workgroup size.
* @returns {ComputeNode}
*/
export const computeKernel = ( node, workgroupSize = [ 64 ] ) => {

Expand Down Expand Up @@ -274,12 +283,28 @@ export const computeKernel = ( node, workgroupSize = [ 64 ] ) => {
*
* @tsl
* @function
* @param {Node} node - TODO
* @param {number|Array<number>} count - TODO.
* @param {Array<number>} [workgroupSize=[64]] - TODO.
* @returns {AtomicFunctionNode}
*/
export const compute = ( node, count, workgroupSize ) => computeKernel( node, workgroupSize ).setCount( count );
* @param {Node} node - The TSL logic for the compute shader.
* @param {number|Array<number>} count - The compute count or dispatch size.
* @param {Array<number>} [workgroupSize=[64]] - The workgroup size.
* @returns {ComputeNode}
, */
export const compute = ( node, count, workgroupSize ) => {

const computeNode = computeKernel( node, workgroupSize );

if ( typeof count === 'number' ) {

computeNode.count = count;

} else {

computeNode.dispatchSize = count;

}

return computeNode;

};

addMethodChaining( 'compute', compute );
addMethodChaining( 'computeKernel', computeKernel );
2 changes: 1 addition & 1 deletion src/renderers/common/ComputePipeline.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Pipeline from './Pipeline.js';
class ComputePipeline extends Pipeline {

/**
* Constructs a new render pipeline.
* Constructs a new compute pipeline.
*
* @param {string} cacheKey - The pipeline's cache key.
* @param {ProgrammableStage} computeProgram - The pipeline's compute shader.
Expand Down
4 changes: 2 additions & 2 deletions src/renderers/webgpu/WebGPUBackend.js
Original file line number Diff line number Diff line change
Expand Up @@ -1406,13 +1406,13 @@ class WebGPUBackend extends Backend {

if ( dispatchSize === null ) {

dispatchSize = computeNode.count;
dispatchSize = computeNode.dispatchSize || computeNode.count;

}

// When the dispatchSize is set with a StorageBuffer from the GPU.

if ( dispatchSize && typeof dispatchSize === 'object' && dispatchSize.isIndirectStorageBufferAttribute ) {
if ( dispatchSize && dispatchSize.isIndirectStorageBufferAttribute ) {

const dispatchBuffer = this.get( dispatchSize ).buffer;

Expand Down
8 changes: 8 additions & 0 deletions src/renderers/webgpu/nodes/WGSLNodeBuilder.js
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,14 @@ class WGSLNodeBuilder extends NodeBuilder {
*/
this.scopedArrays = new Map();

/**
* A flag that indicates that early returns are allowed.
*
* @type {boolean}
* @default true
*/
this.allowEarlyReturns = true;

}

/**
Expand Down