Skip to content

Commit 84d6558

Browse files
authored
Add Cuda support (#108)
1 parent 13e7e0f commit 84d6558

File tree

17 files changed

+590
-142
lines changed

17 files changed

+590
-142
lines changed

Cargo.lock

Lines changed: 47 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ default = ["wayland"]
1212
wayland = ["processing_render/wayland"]
1313
x11 = ["processing_render/x11"]
1414
webcam = ["dep:processing_webcam"]
15+
cuda = ["dep:processing_cuda"]
1516

1617
[workspace]
1718
resolver = "3"
@@ -24,12 +25,14 @@ too_many_arguments = "allow"
2425
[workspace.dependencies]
2526
bevy = { git = "https://github.com/bevyengine/bevy", branch = "main", features = ["file_watcher", "shader_format_wesl", "free_camera", "pan_camera"] }
2627
bevy_naga_reflect = { git = "https://github.com/tychedelia/bevy_naga_reflect" }
28+
bevy_cuda = { git = "https://github.com/tychedelia/bevy_cuda" }
2729
naga = { version = "29", features = ["wgsl-in"] }
2830
wesl = { version = "0.3", default-features = false }
2931
pyo3 = { git = "https://github.com/PyO3/pyo3", branch = "main" }
3032
pyo3-introspection = { git = "https://github.com/PyO3/pyo3", branch = "main" }
3133
processing = { path = "." }
3234
processing_core = { path = "crates/processing_core" }
35+
processing_cuda = { path = "crates/processing_cuda" }
3336
processing_pyo3 = { path = "crates/processing_pyo3" }
3437
processing_render = { path = "crates/processing_render" }
3538
processing_midi = { path = "crates/processing_midi" }
@@ -44,6 +47,7 @@ processing_render = { workspace = true }
4447
processing_midi = { workspace = true }
4548
processing_input = { workspace = true }
4649
processing_webcam = { workspace = true, optional = true }
50+
processing_cuda = { workspace = true, optional = true }
4751
tracing = "0.1"
4852
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
4953

crates/processing_core/src/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,6 @@ pub enum ProcessingError {
4444
ShaderNotFound,
4545
#[error("MIDI port {0} not found")]
4646
MidiPortNotFound(usize),
47+
#[error("CUDA error: {0}")]
48+
CudaError(String),
4749
}

crates/processing_cuda/Cargo.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[package]
2+
name = "processing_cuda"
3+
version = "0.1.0"
4+
edition = "2024"
5+
6+
[lints]
7+
workspace = true
8+
9+
[features]
10+
default = ["cuda-11040"]
11+
cuda-11040 = ["bevy_cuda/cuda-11040"]
12+
13+
[dependencies]
14+
bevy = { workspace = true }
15+
bevy_cuda = { workspace = true }
16+
processing_core = { workspace = true }
17+
processing_render = { workspace = true }

crates/processing_cuda/src/lib.rs

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
use bevy::prelude::*;
2+
use bevy::render::RenderApp;
3+
use bevy::render::render_resource::{Texture, TextureFormat};
4+
use bevy::render::renderer::RenderDevice;
5+
use bevy_cuda::{CudaBuffer, CudaContext};
6+
use processing_core::app_mut;
7+
use processing_core::error::{ProcessingError, Result};
8+
use processing_render::graphics::view_target;
9+
use processing_render::image::{Image, gpu_image, pixel_size};
10+
11+
#[derive(Component)]
12+
pub struct CudaImageBuffer {
13+
pub buffer: CudaBuffer,
14+
pub width: u32,
15+
pub height: u32,
16+
pub texture_format: TextureFormat,
17+
}
18+
19+
pub struct CudaPlugin;
20+
21+
impl Plugin for CudaPlugin {
22+
fn build(&self, _app: &mut App) {}
23+
24+
fn finish(&self, app: &mut App) {
25+
let render_app = app.sub_app(RenderApp);
26+
let render_device = render_app.world().resource::<RenderDevice>();
27+
let wgpu_device = render_device.wgpu_device();
28+
match CudaContext::new(wgpu_device, 0) {
29+
Ok(ctx) => {
30+
app.insert_resource(ctx);
31+
}
32+
Err(e) => {
33+
warn!("CUDA not available, GPU interop disabled: {e}");
34+
}
35+
}
36+
}
37+
}
38+
39+
fn cuda_ctx(world: &World) -> Result<&CudaContext> {
40+
world
41+
.get_resource::<CudaContext>()
42+
.ok_or(ProcessingError::CudaError("CUDA not available".into()))
43+
}
44+
45+
fn resolve_texture(app: &mut App, entity: Entity) -> Result<(Texture, TextureFormat, u32, u32)> {
46+
if app.world().get::<Image>(entity).is_some() {
47+
let texture = gpu_image(app, entity)?.texture.clone();
48+
let p_image = app.world().get::<Image>(entity).unwrap();
49+
return Ok((
50+
texture,
51+
p_image.texture_format,
52+
p_image.size.width,
53+
p_image.size.height,
54+
));
55+
}
56+
if let Ok(vt) = view_target(app, entity) {
57+
let texture = vt.main_texture().clone();
58+
let fmt = vt.main_texture_format();
59+
let size = texture.size();
60+
return Ok((texture, fmt, size.width, size.height));
61+
}
62+
Err(ProcessingError::ImageNotFound)
63+
}
64+
65+
pub fn cuda_export(entity: Entity) -> Result<()> {
66+
app_mut(|app| {
67+
let (texture, texture_format, width, height) = resolve_texture(app, entity)?;
68+
69+
let px_size = pixel_size(texture_format)?;
70+
let buffer_size = (width as u64) * (height as u64) * (px_size as u64);
71+
72+
let existing = app.world().get::<CudaImageBuffer>(entity);
73+
let needs_alloc = existing.is_none_or(|buf| buf.buffer.size() != buffer_size);
74+
75+
if needs_alloc {
76+
let cuda_ctx = cuda_ctx(app.world())?;
77+
let buffer = cuda_ctx
78+
.create_buffer(buffer_size)
79+
.map_err(|e| ProcessingError::CudaError(format!("Buffer creation failed: {e}")))?;
80+
app.world_mut().entity_mut(entity).insert(CudaImageBuffer {
81+
buffer,
82+
width,
83+
height,
84+
texture_format,
85+
});
86+
}
87+
88+
let world = app.world();
89+
let cuda_buf = world.get::<CudaImageBuffer>(entity).unwrap();
90+
let cuda_ctx = cuda_ctx(world)?;
91+
92+
cuda_ctx
93+
.copy_texture_to_buffer(&texture, &cuda_buf.buffer, width, height, texture_format)
94+
.map_err(|e| {
95+
ProcessingError::CudaError(format!("Texture-to-buffer copy failed: {e}"))
96+
})?;
97+
98+
Ok(())
99+
})
100+
}
101+
102+
pub fn cuda_import(entity: Entity, src_device_ptr: u64, byte_size: u64) -> Result<()> {
103+
app_mut(|app| {
104+
let (texture, texture_format, width, height) = resolve_texture(app, entity)?;
105+
106+
let existing = app.world().get::<CudaImageBuffer>(entity);
107+
let needs_alloc = existing.is_none_or(|buf| buf.buffer.size() != byte_size);
108+
109+
if needs_alloc {
110+
let cuda_ctx = cuda_ctx(app.world())?;
111+
let buffer = cuda_ctx
112+
.create_buffer(byte_size)
113+
.map_err(|e| ProcessingError::CudaError(format!("Buffer creation failed: {e}")))?;
114+
app.world_mut().entity_mut(entity).insert(CudaImageBuffer {
115+
buffer,
116+
width,
117+
height,
118+
texture_format,
119+
});
120+
}
121+
122+
let world = app.world();
123+
let cuda_buf = world.get::<CudaImageBuffer>(entity).unwrap();
124+
let cuda_ctx = cuda_ctx(world)?;
125+
126+
// wait for work (i.e. python) to be done with the buffer before we read from it
127+
cuda_ctx
128+
.synchronize()
129+
.map_err(|e| ProcessingError::CudaError(format!("synchronize failed: {e}")))?;
130+
131+
cuda_buf
132+
.buffer
133+
.copy_from_device_ptr(src_device_ptr, byte_size)
134+
.map_err(|e| ProcessingError::CudaError(format!("memcpy_dtod failed: {e}")))?;
135+
136+
cuda_ctx
137+
.copy_buffer_to_texture(&cuda_buf.buffer, &texture, width, height, texture_format)
138+
.map_err(|e| {
139+
ProcessingError::CudaError(format!("Buffer-to-texture copy failed: {e}"))
140+
})?;
141+
142+
Ok(())
143+
})
144+
}
145+
146+
pub fn cuda_write_back(entity: Entity) -> Result<()> {
147+
app_mut(|app| {
148+
let (texture, _, _, _) = resolve_texture(app, entity)?;
149+
150+
let cuda_buf = app
151+
.world()
152+
.get::<CudaImageBuffer>(entity)
153+
.ok_or(ProcessingError::ImageNotFound)?;
154+
155+
let cuda_ctx = cuda_ctx(app.world())?;
156+
157+
cuda_ctx
158+
.copy_buffer_to_texture(
159+
&cuda_buf.buffer,
160+
&texture,
161+
cuda_buf.width,
162+
cuda_buf.height,
163+
cuda_buf.texture_format,
164+
)
165+
.map_err(|e| {
166+
ProcessingError::CudaError(format!("Buffer-to-texture copy failed: {e}"))
167+
})?;
168+
169+
Ok(())
170+
})
171+
}
172+
173+
pub struct CudaBufferInfo {
174+
pub device_ptr: u64,
175+
pub width: u32,
176+
pub height: u32,
177+
pub texture_format: TextureFormat,
178+
}
179+
180+
pub fn cuda_buffer(entity: Entity) -> Result<CudaBufferInfo> {
181+
app_mut(|app| {
182+
let cuda_buf = app
183+
.world()
184+
.get::<CudaImageBuffer>(entity)
185+
.ok_or(ProcessingError::ImageNotFound)?;
186+
Ok(CudaBufferInfo {
187+
device_ptr: cuda_buf.buffer.device_ptr(),
188+
width: cuda_buf.width,
189+
height: cuda_buf.height,
190+
texture_format: cuda_buf.texture_format,
191+
})
192+
})
193+
}
194+
195+
pub fn typestr_for_format(format: TextureFormat) -> Result<&'static str> {
196+
match format {
197+
TextureFormat::Rgba8Unorm | TextureFormat::Rgba8UnormSrgb => Ok("|u1"),
198+
TextureFormat::Rgba16Float => Ok("<f2"),
199+
TextureFormat::Rgba32Float => Ok("<f4"),
200+
_ => Err(ProcessingError::UnsupportedTextureFormat),
201+
}
202+
}
203+
204+
pub fn elem_size_for_typestr(typestr: &str) -> Result<usize> {
205+
match typestr {
206+
"|u1" => Ok(1),
207+
"<f2" => Ok(2),
208+
"<f4" => Ok(4),
209+
_ => Err(ProcessingError::CudaError(format!(
210+
"unsupported typestr: {typestr}"
211+
))),
212+
}
213+
}

crates/processing_pyo3/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ wayland = ["processing/wayland", "processing_glfw/wayland"]
1616
static-link = ["processing_glfw/static-link"]
1717
x11 = ["processing/x11"]
1818
webcam = ["processing/webcam", "dep:processing_webcam"]
19+
cuda = ["dep:processing_cuda", "processing/cuda"]
1920

2021
[dependencies]
21-
pyo3 = { workspace = true, features = ["experimental-inspect"] }
22+
pyo3 = { workspace = true, features = ["experimental-inspect", "multiple-pymethods"] }
2223
processing = { workspace = true }
2324
processing_webcam = { workspace = true, optional = true }
2425
processing_glfw = { workspace = true }
2526
bevy = { workspace = true, features = ["file_watcher"] }
2627
png = "0.18"
28+
processing_cuda = { workspace = true, optional = true }

0 commit comments

Comments
 (0)