// aoc2023/Day21.scala
// 163 lines, 5.1 KiB, Scala
// 2023-12-21 17:21:38 +00:00
package aoc.day21
import aoc._
import aoc.direction._
import scala.collection.mutable
/** The puzzle grid as a two-dimensional array, indexed as `board(row)(col)`.
  *
  * Built from `lines` (provided by the `aoc` package import), one inner array per line.
  */
val board =
  (for line <- lines yield line.toArray).toArray
/** Coordinates `(row, col)` of the board cell that holds `'S'`.
  *
  * The search is delegated to `findMap`, a helper that is not defined in this file
  * (presumably it returns the first `Some` produced by the callback — confirm in `aoc`).
  *
  * Throws a `NoSuchElementException` with a descriptive message when no cell holds `'S'`;
  * the previous bare `.get` raised the same exception type with the opaque text `None.get`.
  */
val startPoint =
  board.toSeq.zipWithIndex
    .findMap { (row, idx) =>
      row.toSeq.zipWithIndex.findMap { (cell, col) =>
        Option.when(cell == 'S')((idx, col))
      }
    }
    .getOrElse(throw new NoSuchElementException("no start cell 'S' found on the board"))
/** True when `(x, y)` is a valid `(row, col)` index into `board`.
  *
  * The column bound is taken from the first row, so all rows are treated as having
  * the same length as `board(0)`.
  */
def inside(x: Int, y: Int) =
  board.indices.contains(x) && board(0).indices.contains(y)
/** True when `(x, y)` lies on the board and the cell there is not a `'#'`. */
def walkable(x: Int, y: Int) =
  if inside(x, y) then board(x)(y) != '#' else false
// part 1
/** Positions that can be occupied after taking exactly `steps` more steps.
  *
  * @param steps    number of single-cell moves still to perform
  * @param starting the set of positions occupied before those moves
  * @return every walkable position reachable by applying one direction from `Dir.all`
  *         per step, `steps` times in a row
  */
def walk(steps: Int, starting: Set[(Int, Int)]): Set[(Int, Int)] =
  steps match
    case 0 => starting
    case _ =>
      // debugging aid: println(s"$steps => ${starting.size}")
      // Apply every direction to every occupied cell, then keep the walkable results.
      val candidates =
        for
          (row, col) <- starting
          dir <- Dir.all
        yield dir(row, col)
      walk(steps - 1, candidates.filter(walkable(_, _)))
/** Part 1: prints how many cells can be occupied after 64 steps from `startPoint`. */
def part1 =
  println(walk(64, Set(startPoint)).size)
// part 2
/** Memoised breadth-first search over the walkable cells of `board`.
  *
  * `shortestPathsFrom(x, y)` returns a two-element array:
  *   - index 0: the shortest distances from `(x, y)` that are even, in ascending order;
  *   - index 1: the shortest distances that are odd, in ascending order.
  *
  * Each reached cell contributes one entry; the origin itself contributes a `0`
  * to the even array. The origin is not checked for walkability.
  */
object shortestPathsFrom:
  /** Results computed so far, keyed by the search origin. */
  val cache = mutable.Map.empty[(Int, Int), Array[Array[Int]]]

  def apply(x: Int, y: Int): Array[Array[Int]] =
    cache.getOrElseUpdate((x, y), search((x, y)))

  /** Runs the breadth-first search from `origin` and splits the distances by parity. */
  private def search(origin: (Int, Int)): Array[Array[Int]] =
    val distance = mutable.Map(origin -> 0)
    val pending = mutable.Queue(origin)
    while pending.nonEmpty do
      val (row, col) = pending.dequeue()
      val next = distance((row, col)) + 1
      Dir.all
        .map(_(row, col))
        .filter(walkable(_, _))
        .foreach { cell =>
          // The first visit fixes the shortest distance; later visits are ignored.
          if !distance.contains(cell) then
            distance(cell) = next
            pending.enqueue(cell)
        }
    val (even, odd) = distance.values.toArray.partition(_ % 2 == 0)
    Array(even.sorted, odd.sorted)
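// The input has a convenient shape: its border consists only of '.' cells, and the
// start cell's row and column contain no '#'.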
/** Sanity checks on the input: the outer border consists only of `'.'`, and the row and
  * column through `startPoint` contain no `'#'`.
  */
val _ =
  val (startRow, startCol) = startPoint
  val isPlot = (c: Char) => c == '.'
  val notRock = (c: Char) => c != '#'
  // The four border lines must be completely free.
  assert(board.head.forall(isPlot))
  assert(board.last.forall(isPlot))
  assert(board.forall(row => isPlot(row.head)))
  assert(board.forall(row => isPlot(row.last)))
  // The start's column, then the start's row, must not contain rocks.
  assert(board.indices.forall(r => notRock(board(r)(startCol))))
  assert(board(0).indices.forall(c => notRock(board(startRow)(c))))
/** Counts the cells reachable from `startPoint` on an endlessly repeated board whose
  * shortest distance is at most `steps` and has the same parity as `steps`.
  *
  * Board copies are addressed by an offset `(dx, dy)` (rows, columns) from the original.
  * For each copy the walk is modelled as a straight-line trip to an entry cell of that
  * copy followed by a search inside the copy (`shortestPathsFrom`). This shortcut assumes
  * the input properties asserted at file level (free border, free row and column through
  * the start).
  *
  * Copies that are far enough inside the reachable area are not searched individually:
  * all cells of the matching parity are counted at once.
  *
  * @param steps number of steps available
  * @return the number of matching cells, as a `Long`
  */
def pointsReachable(steps: Int): Long =
  val (sx, sy) = startPoint
  val height = board.size
  val width = board(0).size

  // Steps needed to walk straight from index `start` into the copy `delta` boards away
  // along one axis of length `extent` (0 for the original board).
  def distToEdge(delta: Int, start: Int, extent: Int): Long =
    if delta == 0 then 0L
    else if delta < 0 then 1L + start + (delta + 1L).abs * extent // towards index 0
    else 0L + (extent - start) + (delta - 1L) * extent // towards the last index
  def distToRow(delta: Int): Long = distToEdge(delta, sx, height)
  def distToCol(delta: Int): Long = distToEdge(delta, sy, width)

  // Cell at which the copy at offset (dx, dy) is entered: the start's own coordinate on
  // an axis without offset, otherwise the edge facing the original board.
  def startWhere(dx: Int, dy: Int): (Int, Int) =
    def entry(delta: Int, start: Int, extent: Int) =
      if delta == 0 then start else if delta < 0 then extent - 1 else 0
    (entry(dx, sx, height), entry(dy, sy, width))

  // Number of entries <= max in an ascending array (binary search).
  def countAtMost(sorted: Array[Int], max: Int): Int =
    var low = 0
    var hi = sorted.length
    while low < hi do
      val mid = low + (hi - low) / 2
      if sorted(mid) <= max then low = mid + 1 else hi = mid
    low

  // Matching cells inside the single copy at offset (dx, dy).
  def countWhere(dx: Int, dy: Int): Long =
    val dToBoard = distToRow(dx) + distToCol(dy)
    if dToBoard > steps then 0L
    else
      val remaining = (steps - dToBoard).toInt
      val (stx, sty) = startWhere(dx, dy)
      countAtMost(shortestPathsFrom(stx, sty)(remaining % 2), remaining).toLong

  // Matching cells in all copies of board-row `dx` lying strictly on one side of the
  // original column (sign = 1 for higher column offsets, -1 for lower ones).
  // The distance to the copy k columns out is baseDistance + (k - 1) * width.
  def countSide(dx: Int, dRow: Long, sign: Int): Long =
    val baseDistance = dRow + distToCol(sign)
    if baseDistance > steps then 0L
    else
      val (stx, sty) = startWhere(dx, sign)
      val dists = shortestPathsFrom(stx, sty)
      // Largest shortest distance inside one copy; the even class always holds the origin.
      val farthest = dists.flatMap(_.lastOption).max
      // Copies at offsets 1 until firstPartial are covered completely.
      val firstPartial = ((steps - baseDistance - farthest) / width).toInt.max(0) + 1
      val parity = ((steps - baseDistance) % 2).toInt
      val fullBoards = firstPartial - 1L
      // Offsets 1, 3, 5, ... keep `parity`; offsets 2, 4, ... are shifted by `width`.
      // Long arithmetic: the products exceed Int range for large step counts.
      val sameParity = (fullBoards + 1) / 2 * dists(parity).length
      val flipped = fullBoards / 2 * dists((width + parity) % 2).length
      // Remaining copies near the rim are counted one by one.
      val rest = (firstPartial to steps).iterator
        .map(_ * sign)
        .takeWhile(dy => dRow + distToCol(dy) <= steps)
        .map(countWhere(dx, _))
        .sum
      sameParity + flipped + rest

  val rowsDown = (0 to steps).iterator.takeWhile(distToRow(_) <= steps)
  val rowsUp = (-1 to -steps by -1).iterator.takeWhile(distToRow(_) <= steps)
  (rowsDown ++ rowsUp).map { dx =>
    val dRow = distToRow(dx)
    countWhere(dx, 0) + countSide(dx, dRow, 1) + countSide(dx, dRow, -1)
  }.sum
/** Part 2: prints the list of `pointsReachable` results for the configured step counts. */
def part2 =
  val answers =
    for target <- Seq(26501365) yield pointsReachable(target)
  println(answers)
/** Entry point: runs the requested puzzle part.
  *
  * @param part 1 for `part1`, 2 for `part2`; any other value is rejected with an
  *             `IllegalArgumentException` instead of an anonymous `MatchError`
  */
@main def Day21(part: Int) = part match
  case 1 => part1
  case 2 => part2
  case other => throw new IllegalArgumentException(s"part must be 1 or 2, got $other")