#![feature(portable_simd)]
use std::simd::{LaneCount, Simd, StdFloat, SupportedLaneCount};
/// Structure-of-Arrays (SoA) storage for four-component points `(x, y, z, w)`.
///
/// Point `i` is spread across index `i` of each component vector, so the four
/// vectors are expected to have the same length.
#[derive(Debug, Clone, PartialEq, Default)]
struct Points {
    /// `x` component of every point.
    x: Vec<f32>,
    /// `y` component of every point.
    y: Vec<f32>,
    /// `z` component of every point.
    z: Vec<f32>,
    /// `w` component of every point.
    w: Vec<f32>,
}

impl Points {
    /// Creates `n` points with every component initialised to `0.0`.
    fn new(n: usize) -> Self {
        Self {
            x: vec![0.0; n],
            y: vec![0.0; n],
            z: vec![0.0; n],
            w: vec![0.0; n],
        }
    }

    /// Returns the number of points, taken from the length of the `x` vector.
    fn len(&self) -> usize {
        self.x.len()
    }
}
// Apply a 4×4 *row-major* transform to all points in-place:
// [x', y', z', w']^T = M * [x, y, z, w]^T
//
// Works on any SIMD width supported by the platform (e.g., 4, 8, 16 lanes).
fn transform_points<const LANES: usize>(pts: &mut Points, m: [[f32; 4]; 4])
where
LaneCount<LANES>: SupportedLaneCount,
{
let n = pts.len();
assert_eq!(pts.y.len(), n);
assert_eq!(pts.z.len(), n);
assert_eq!(pts.w.len(), n);
// Broadcast matrix rows’ coefficients
let m00 = Simd::<f32, LANES>::splat(m[0][0]);
let m01 = Simd::<f32, LANES>::splat(m[0][1]);
let m02 = Simd::<f32, LANES>::splat(m[0][2]);
let m03 = Simd::<f32, LANES>::splat(m[0][3]);
let m10 = Simd::<f32, LANES>::splat(m[1][0]);
let m11 = Simd::<f32, LANES>::splat(m[1][1]);
let m12 = Simd::<f32, LANES>::splat(m[1][2]);
let m13 = Simd::<f32, LANES>::splat(m[1][3]);
let m20 = Simd::<f32, LANES>::splat(m[2][0]);
let m21 = Simd::<f32, LANES>::splat(m[2][1]);
let m22 = Simd::<f32, LANES>::splat(m[2][2]);
let m23 = Simd::<f32, LANES>::splat(m[2][3]);
let m30 = Simd::<f32, LANES>::splat(m[3][0]);
let m31 = Simd::<f32, LANES>::splat(m[3][1]);
let m32 = Simd::<f32, LANES>::splat(m[3][2]);
let m33 = Simd::<f32, LANES>::splat(m[3][3]);
// Process full SIMD chunks
let simd_chunks = n / LANES;
for chunk in 0..simd_chunks {
let i = chunk * LANES;
// Load SoA lanes
let vx = Simd::<f32, LANES>::from_slice(&pts.x[i..i + LANES]);
let vy = Simd::<f32, LANES>::from_slice(&pts.y[i..i + LANES]);
let vz = Simd::<f32, LANES>::from_slice(&pts.z[i..i + LANES]);
let vw = Simd::<f32, LANES>::from_slice(&pts.w[i..i + LANES]);
// Row 0: x' = m00*x + m01*y + m02*z + m03*w
let x_prime = vx.mul_add(m00, vy.mul_add(m01, vz.mul_add(m02, vw * m03)));
// Row 1: y'
let y_prime = vx.mul_add(m10, vy.mul_add(m11, vz.mul_add(m12, vw * m13)));
// Row 2: z'
let z_prime = vx.mul_add(m20, vy.mul_add(m21, vz.mul_add(m22, vw * m23)));
// Row 3: w'
let w_prime = vx.mul_add(m30, vy.mul_add(m31, vz.mul_add(m32, vw * m33)));
// Store back
x_prime.copy_to_slice(&mut pts.x[i..i + LANES]);
y_prime.copy_to_slice(&mut pts.y[i..i + LANES]);
z_prime.copy_to_slice(&mut pts.z[i..i + LANES]);
w_prime.copy_to_slice(&mut pts.w[i..i + LANES]);
}
// Scalar tail (if n is not a multiple of LANES)
let tail_start = simd_chunks * LANES;
for i in tail_start..n {
let x = pts.x[i];
let y = pts.y[i];
let z = pts.z[i];
let w = pts.w[i];
let x_p = m[0][0] * x + m[0][1] * y + m[0][2] * z + m[0][3] * w;
let y_p = m[1][0] * x + m[1][1] * y + m[1][2] * z + m[1][3] * w;
let z_p = m[2][0] * x + m[2][1] * y + m[2][2] * z + m[2][3] * w;
let w_p = m[3][0] * x + m[3][1] * y + m[3][2] * z + m[3][3] * w;
pts.x[i] = x_p;
pts.y[i] = y_p;
pts.z[i] = z_p;
pts.w[i] = w_p;
}
}
/// Demo entry point: builds 1000 points, runs them through an identity
/// transform with 16-lane SIMD, and prints the first three results.
fn main() {
    // SIMD width; it must be a lane count supported on the target.
    const LANES: usize = 16;
    const POINT_COUNT: usize = 1000;

    // Point i starts out as (i, i + 1, i + 2, 1).
    let mut pts = Points::new(POINT_COUNT);
    for (i, x) in pts.x.iter_mut().enumerate() {
        *x = i as f32;
    }
    for (i, y) in pts.y.iter_mut().enumerate() {
        *y = i as f32 + 1.0;
    }
    for (i, z) in pts.z.iter_mut().enumerate() {
        *z = i as f32 + 2.0;
    }
    pts.w.fill(1.0);

    // Identity matrix: ones on the diagonal, zeros elsewhere.
    let identity: [[f32; 4]; 4] = std::array::from_fn(|row| {
        std::array::from_fn(|col| if row == col { 1.0 } else { 0.0 })
    });

    transform_points::<LANES>(&mut pts, identity);

    // Show the first three transformed points.
    let first_points = pts
        .x
        .iter()
        .zip(&pts.y)
        .zip(&pts.z)
        .zip(&pts.w)
        .take(3);
    for (((x, y), z), w) in first_points {
        println!("({:.1}, {:.1}, {:.1}, {:.1})", x, y, z, w);
    }
}