import jdk.incubator.vector.{ShortVector, VectorMask, VectorOperators, ByteVector => BV}
import java.nio.file.{Files, Path}
import scala.concurrent.duration.DurationLong
import scala.util.chaining.scalaUtilChainingOps
object Day17:
/** Contents of the file `17.txt` (path relative to the working directory), split on '\n'. */
val input: Array[String] =
  val text = Files.readString(Path.of("17.txt"))
  text.split('\n')
/** Infix sugar for lane-wise byte addition: `a + b` is `a.add(b)`. */
extension (left: BV)
  def +(right: BV): BV = left.add(right)
/** Vector species used for every byte-vector load and store in this object. */
val spec = BV.SPECIES_PREFERRED

/** Number of lines in `input`. */
val inputHeight: Int = input.length

/** Length of the first line of `input`; all lines are indexed up to this width. */
val inputWidth: Int = input.head.length

/** Number of generations the simulation is advanced. */
val cycles: Int = 6

/** Row length of a board: the input width plus `cycles` cells on each side, rounded up to a
  * whole number of vector lanes, plus 2 further cells.
  */
val xDim: Int =
  val lanes = spec.length()
  val needed = cycles * 2 + inputWidth
  val wholeVectors = (needed + lanes - 1) / lanes
  wholeVectors * lanes + 2

/** Number of rows per plane: the input height plus `cycles` cells on each side, plus 2. */
val yDim: Int = cycles * 2 + 2 + inputHeight

/** Number of planes along z: one layer plus `cycles` cells on each side, plus 2. */
val zDim: Int = cycles * 2 + 2 + 1

/** Number of cubes along w: same extent as `zDim`. */
val wDim: Int = cycles * 2 + 2 + 1
/** Writes the input grid into the 3-D board `bs`, setting the byte of every '#' cell to 1.
  *
  * The board is addressed as `z * xDim * yDim + y * xDim + x`. The input's first character is
  * placed at x = y = z = `cycles + 1`; rows advance along y, columns along x.
  *
  * @param bs flat board of at least `zDim * yDim * xDim` bytes; only '#' cells are written
  */
def populate1(bs: Array[Byte]): Unit =
  val first = cycles + 1
  val origin = first * xDim * yDim + first * xDim + first
  var row = 0
  while row < inputHeight do
    val line = input(row)
    var col = 0
    while col < inputWidth do
      if line(col) == '#' then bs(origin + row * xDim + col) = 1
      col += 1
    row += 1
/** Writes the input grid into the 4-D board `bs`, setting the byte of every '#' cell to 1.
  *
  * The board is addressed as `w * zDim * yDim * xDim + z * xDim * yDim + y * xDim + x`. The
  * input's first character is placed at x = y = z = w = `cycles + 1`; rows advance along y,
  * columns along x.
  *
  * @param bs flat board of at least `wDim * zDim * yDim * xDim` bytes; only '#' cells are written
  */
def populate2(bs: Array[Byte]): Unit =
  val rowStride = xDim
  val planeStride = xDim * yDim
  val cubeStride = zDim * yDim * xDim
  val first = cycles + 1
  val origin = first * cubeStride + first * planeStride + first * rowStride + first
  var row = 0
  while row < inputHeight do
    val line = input(row)
    var col = 0
    while col < inputWidth do
      if line(col) == '#' then bs(origin + row * rowStride + col) = 1
      col += 1
    row += 1
/** Computes one generation of the 3-D board and recurses until `cycles` generations have run.
  *
  * The current generation is read from `board` and the next one is written to `board2`; the
  * two arrays swap roles on each recursive call. Counting from the outermost call, the final
  * generation therefore ends up in `board` when `cycles` is even and in `board2` when it is odd.
  *
  * Only planes and rows in `[cycles - cycle, dim - (cycles - cycle))` are recomputed, so the
  * processed region widens by one cell per generation. Along x every row is processed from
  * x = 1 in steps of one vector.
  *
  * Rule applied per cell, with n = number of active cells among its 26 neighbours: the cell is
  * active in the next generation iff n == 3, or n == 2 and the cell is currently active.
  *
  * @param board  generation to read
  * @param board2 receives the next generation
  * @param cycle  zero-based index of the generation being computed
  */
def loop1(board: Array[Byte], board2: Array[Byte], cycle: Int = 0): Unit =
  val lanes = spec.length()
  val rowStride = xDim
  val planeStride = xDim * yDim
  // Flat-index offsets of the 26 neighbours: every (dz, dy, dx) in -1..1 except (0, 0, 0).
  val offsets =
    for
      dz <- -1 to 1
      dy <- -1 to 1
      dx <- -1 to 1
      if dz != 0 || dy != 0 || dx != 0
    yield dz * planeStride + dy * rowStride + dx
  val neighbours = offsets.toArray
  val inactive = BV.broadcast(spec, 0)
  val active = BV.broadcast(spec, 1)
  val margin = cycles - cycle
  var z = margin
  while z < zDim - margin do
    var y = margin
    while y < yDim - margin do
      val rowBase = z * planeStride + y * rowStride
      var x = 1
      while x < xDim - lanes do
        val centre = rowBase + x
        // Lane i of `count` holds the neighbour count of cell (x + i, y, z).
        var count = BV.fromArray(spec, board, centre + neighbours(0))
        var k = 1
        while k < neighbours.length do
          count = count.add(BV.fromArray(spec, board, centre + neighbours(k)))
          k += 1
        val wasActive = BV.fromArray(spec, board, centre).eq(1.toByte)
        val hasThree = count.eq(3.toByte)
        val hasTwo = count.eq(2.toByte)
        // (2 or 3 and active) or (3 and inactive)  ==  3 or (2 and active)
        val alive = hasThree.or(hasTwo.and(wasActive))
        inactive.blend(active, alive).intoArray(board2, centre)
        x += lanes
      y += 1
    z += 1
  if cycle < cycles - 1 then loop1(board2, board, cycle + 1)
/** Computes one generation of the 4-D board and recurses until `cycles` generations have run.
  *
  * The current generation is read from `board` and the next one is written to `board2`; the
  * two arrays swap roles on each recursive call. Counting from the outermost call, the final
  * generation therefore ends up in `board` when `cycles` is even and in `board2` when it is odd.
  *
  * Only cubes, planes and rows in `[cycles - cycle, dim - (cycles - cycle))` are recomputed, so
  * the processed region widens by one cell per generation. Along x every row is processed from
  * x = 1 in steps of one vector.
  *
  * Rule applied per cell, with n = number of active cells among its 80 neighbours: the cell is
  * active in the next generation iff n == 3, or n == 2 and the cell is currently active.
  *
  * @param board  generation to read
  * @param board2 receives the next generation
  * @param cycle  zero-based index of the generation being computed
  */
def loop2(board: Array[Byte], board2: Array[Byte], cycle: Int = 0): Unit =
  val lanes = spec.length()
  val rowStride = xDim
  val planeStride = xDim * yDim
  val cubeStride = xDim * yDim * zDim
  // Flat-index offsets of the 80 neighbours: every (dw, dz, dy, dx) in -1..1 except all-zero.
  val offsets =
    for
      dw <- -1 to 1
      dz <- -1 to 1
      dy <- -1 to 1
      dx <- -1 to 1
      if dw != 0 || dz != 0 || dy != 0 || dx != 0
    yield dw * cubeStride + dz * planeStride + dy * rowStride + dx
  val neighbours = offsets.toArray
  val inactive = BV.broadcast(spec, 0)
  val active = BV.broadcast(spec, 1)
  val margin = cycles - cycle
  var w = margin
  while w < wDim - margin do
    var z = margin
    while z < zDim - margin do
      var y = margin
      while y < yDim - margin do
        val rowBase = w * cubeStride + z * planeStride + y * rowStride
        var x = 1
        while x < xDim - lanes do
          val centre = rowBase + x
          // Lane i of `count` holds the neighbour count of cell (x + i, y, z, w).
          var count = BV.fromArray(spec, board, centre + neighbours(0))
          var k = 1
          while k < neighbours.length do
            count = count.add(BV.fromArray(spec, board, centre + neighbours(k)))
            k += 1
          val wasActive = BV.fromArray(spec, board, centre).eq(1.toByte)
          val hasThree = count.eq(3.toByte)
          val hasTwo = count.eq(2.toByte)
          // (2 or 3 and active) or (3 and inactive)  ==  3 or (2 and active)
          val alive = hasThree.or(hasTwo.and(wasActive))
          inactive.blend(active, alive).intoArray(board2, centre)
          x += lanes
        y += 1
      z += 1
    w += 1
  if cycle < cycles - 1 then loop2(board2, board, cycle + 1)
/** Runs the 3-D simulation for `cycles` generations and sums the cells of the final board.
  *
  * Cells hold 0 or 1, so the sum is the number of active cells.
  *
  * @return sum of all bytes of the final generation
  */
def part1: Int =
  val cells = zDim * yDim * xDim
  val board = Array.ofDim[Byte](cells)
  val board2 = Array.ofDim[Byte](cells)
  populate1(board)
  loop1(board, board2)
  // loop1 alternates between the two arrays, writing generation k (0-based) into `board2`
  // when k is even and into `board` when k is odd. The last generation is therefore in
  // `board` only for an even `cycles`; pick the buffer by parity instead of assuming it.
  val last = if cycles % 2 == 0 then board else board2
  sumBytes(last)
/** Runs the 4-D simulation for `cycles` generations and sums the cells of the final board.
  *
  * Cells hold 0 or 1, so the sum is the number of active cells.
  *
  * @return sum of all bytes of the final generation
  */
def part2: Int =
  val cells = wDim * zDim * yDim * xDim
  val board = Array.ofDim[Byte](cells)
  val board2 = Array.ofDim[Byte](cells)
  populate2(board)
  loop2(board, board2)
  // loop2 alternates between the two arrays, writing generation k (0-based) into `board2`
  // when k is even and into `board` when k is odd. The last generation is therefore in
  // `board` only for an even `cycles`; pick the buffer by parity instead of assuming it.
  val last = if cycles % 2 == 0 then board else board2
  sumBytes(last)
/** Sums the elements of `xs`, treating every byte as an unsigned value in 0..255.
  *
  * The bulk of the array is processed one byte vector at a time: each vector is reinterpreted
  * as 16-bit lanes, and the low and high byte of every lane are added to a 16-bit accumulator
  * vector. Because those lanes (and `reduceLanes`, which returns a `Short`) would wrap for
  * larger totals, the accumulator is drained into an `Int` total before any lane can
  * overflow. Elements past the last whole vector are added one by one.
  *
  * @param xs bytes to sum; may be empty
  * @return the sum in `Int` arithmetic (wraps only if the true sum exceeds `Int.MaxValue`)
  */
def sumBytes(xs: Array[Byte]): Int =
  val specS = spec.withLanes(java.lang.Short.TYPE)
  val lowByte = ShortVector.broadcast(specS, 0x00FF)
  val zero = ShortVector.broadcast(specS, 0)
  val step = spec.length()
  val vectorEnd = spec.loopBound(xs.length)
  // One vector step adds at most 2 * 255 = 510 to a lane; 128 steps give at most 65280,
  // which still fits in an unsigned 16-bit lane.
  val stepsPerFlush = 128

  // Adds up the lanes of `v` as unsigned 16-bit values, in Int arithmetic.
  def drain(v: ShortVector): Int =
    v.toArray.foldLeft(0)((sum, lane) => sum + (lane & 0xFFFF))

  var total = 0
  var acc = zero
  var pending = 0
  var i = 0
  while i < vectorEnd do
    val pairs = BV.fromArray(spec, xs, i).reinterpretAsShorts()
    acc = acc.add(pairs.and(lowByte))
    acc = acc.add(pairs.lanewise(VectorOperators.LSHR, 8).and(lowByte))
    pending += 1
    if pending == stepsPerFlush then
      total += drain(acc)
      acc = zero
      pending = 0
    i += step
  total += drain(acc)
  // Scalar tail: the elements that do not fill a whole vector.
  while i < xs.length do
    total += xs(i) & 0xFF
    i += 1
  total
/** Entry point: prints the result of `part1`, then the result of `part2`, one per line.
  *
  * @param args command-line arguments (unused)
  */
def main(args: Array[String]): Unit =
  println(part1)
  println(part2)