163 lines
5.1 KiB
Scala
163 lines
5.1 KiB
Scala
package aoc.day21
|
|
|
|
import aoc._
|
|
import aoc.direction._
|
|
import scala.collection.mutable
|
|
|
|
/** The input grid as one character array per element of `lines`, indexed as `board(row)(col)`.
  *
  * NOTE(review): `lines` comes from the `aoc` package import — presumably the puzzle input's lines.
  */
val board: Array[Array[Char]] =
  lines.map(line => line.toArray).toArray
|
|
|
|
/** `(row, col)` of the first 'S' cell in `board`, scanning rows top to bottom and each row left to
  * right.
  *
  * @throws NoSuchElementException if the grid contains no 'S'
  */
val startPoint: (Int, Int) =
  board.iterator.zipWithIndex
    .collectFirst { case (row, r) if row.contains('S') => (r, row.indexOf('S')) }
    // A descriptive failure instead of the opaque "None.get" from a bare Option.get.
    .getOrElse(throw NoSuchElementException("no start cell 'S' found in the input grid"))
|
|
|
|
/** True when `(x, y)` is a valid `board(x)(y)` index, using row 0's width as the column count. */
def inside(x: Int, y: Int) =
  (0 until board.length).contains(x) && (0 until board(0).length).contains(y)
|
|
|
|
/** True when `(x, y)` lies on the board and is not a rock ('#'). */
def walkable(x: Int, y: Int) =
  if !inside(x, y) then false
  else board(x)(y) != '#'
|
|
|
|
// part 1
|
|
|
|
/** Positions occupied after taking exactly `steps` moves from every position in `starting`.
  *
  * Each move replaces every position with its walkable neighbours (one candidate per direction in
  * `Dir.all`). The result is therefore every cell reachable by a walk of exactly `steps` moves.
  *
  * @param steps    number of moves; must be non-negative
  * @param starting initial positions as `(row, col)`
  * @throws IllegalArgumentException if `steps` is negative
  */
def walk(steps: Int, starting: Set[(Int, Int)]): Set[(Int, Int)] =
  // A negative count would never reach the zero base case, so reject it up front.
  require(steps >= 0, s"steps must be non-negative, got $steps")
  (1 to steps).foldLeft(starting) { (frontier, _) =>
    frontier
      .flatMap((r, c) => Dir.all.map(dir => dir(r, c)))
      .filter((r, c) => walkable(r, c))
  }
|
|
|
|
/** Part 1: prints how many cells are occupied after exactly 64 moves from the start. */
def part1 =
  val reachableCount = walk(64, Set(startPoint)).size
  println(reachableCount)
|
|
|
|
// part 2
|
|
|
|
/** Memoised breadth-first search over a single copy of `board` (moves are bounded by `walkable`).
  *
  * `shortestPathsFrom(x, y)` returns two ascending arrays of BFS distances from `(x, y)`. They cover
  * every cell reachable through walkable cells, with the origin itself at distance 0. Index 0
  * holds the even distances and index 1 the odd ones. Results are cached per origin in `cache`.
  */
object shortestPathsFrom:
  val cache = mutable.Map.empty[(Int, Int), Array[Array[Int]]]

  def apply(x: Int, y: Int) =
    cache.getOrElseUpdate((x, y), distancesByParity(x, y))

  // BFS from (x, y), then split the distances into [even, odd] arrays, each sorted ascending.
  private def distancesByParity(x: Int, y: Int): Array[Array[Int]] =
    val origin = (x, y)
    val seen = mutable.Map(origin -> 0)
    val frontier = mutable.Queue(origin)
    while frontier.nonEmpty do
      val (cx, cy) = frontier.dequeue()
      val nextDist = seen((cx, cy)) + 1
      val neighbours = Dir.all.map(dir => dir(cx, cy)).filter((nx, ny) => walkable(nx, ny))
      neighbours.foreach { pt =>
        // Checked per neighbour so a cell is recorded (and enqueued) only once.
        if !seen.contains(pt) then
          seen(pt) = nextDist
          frontier.enqueue(pt)
      }
    val allDistances = seen.values.toArray
    Array(0, 1).map(parity => allDistances.filter(_ % 2 == parity).sorted)
|
|
|
|
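// The puzzle input has a special shape: the outer border of the tile and the whole row and
// column through 'S' contain no rocks ('#'). The assertions below verify this.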
|
|
|
|
/** Sanity checks on the input's shape; any violated property raises an AssertionError.
  *
  * NOTE(review): presumably these are the properties that `pointsReachable`'s straight-line
  * tile-distance formulas rely on — confirm.
  */
val _ =
  // The whole outer frame of the tile is open ground ('.').
  assert(!board.head.exists(_ != '.'))
  assert(!board.last.exists(_ != '.'))
  assert(!board.exists(_.head != '.'))
  assert(!board.exists(_.last != '.'))

  // Neither the start's column nor its row contains a rock ('#').
  val (startRow, startCol) = startPoint
  assert(board.forall(row => row(startCol) != '#'))
  assert((0 until board(0).length).forall(col => board(startRow)(col) != '#'))
|
|
|
|
/** Counts cells reachable in exactly `steps` moves on the infinitely tiled board.
  *
  * Tile `(dx, dy)` (the start tile is `(0, 0)`) is treated as entered at `startWhere(dx, dy)` after
  * `distToRow(dx) + distToCol(dy)` moves. Inside a tile, a cell counts when its BFS distance from
  * that entry cell is at most the remaining budget and has the same parity as the budget.
  * NOTE(review): this entry-point shortcut presumably depends on the input shape asserted at top
  * level (open border, open start row and column).
  *
  * @param steps exact number of moves taken from `startPoint`
  * @return number of reachable cells summed over all tiles
  */
def pointsReachable(steps: Int): Long =
  val (sx, sy) = startPoint
  val rows = board.size
  val cols = board(0).size

  // Moves from the start to the nearest row of the tile `delta` tiles down (> 0) or up (< 0):
  // that tile's first row when below, its last row when above.
  def distToRow(delta: Int): Long =
    if delta == 0 then 0L
    else if delta < 0 then 1L + sx + -(delta + 1L) * rows
    else (rows - sx) + (delta - 1L) * rows

  // Horizontal counterpart of distToRow: `delta` tiles right (> 0) or left (< 0).
  def distToCol(delta: Int): Long =
    if delta == 0 then 0L
    else if delta < 0 then 1L + sy + -(delta + 1L) * cols
    else (cols - sy) + (delta - 1L) * cols

  // Entry cell of tile (dx, dy): the start's row/column when dx/dy is 0, else the edge facing it.
  def startWhere(dx: Int, dy: Int): (Int, Int) =
    val x = if dx == 0 then sx else if dx < 0 then rows - 1 else 0
    val y = if dy == 0 then sy else if dy < 0 then cols - 1 else 0
    (x, y)

  // Upper-bound binary search: number of elements of an ascending array that are <= max.
  def countAtMost(sorted: Array[Int], max: Int): Int =
    var lo = 0
    var hi = sorted.length
    while lo < hi do
      val mid = (lo + hi) >>> 1
      if sorted(mid) <= max then lo = mid + 1 else hi = mid
    lo

  // Cells reachable inside tile (dx, dy), or 0 when the tile's entry is beyond the budget.
  def countWhere(dx: Int, dy: Int): Long =
    val dToBoard = distToRow(dx) + distToCol(dy)
    if dToBoard > steps then 0L
    else
      val remaining = (steps - dToBoard).toInt
      val (stx, sty) = startWhere(dx, dy)
      countAtMost(shortestPathsFrom(stx, sty)(remaining % 2), remaining).toLong

  // Cells reachable in every tile of tile-row `dx` strictly on one side of the start column
  // (`sign` = 1: right, -1: left). Tiles with enough budget left are counted in bulk by parity.
  def sideCount(dx: Int, sign: Int): Long =
    val dRow = distToRow(dx)
    // Tile k >= 1 on this side is entered after baseDistance + (k - 1) * cols moves.
    val baseDistance = dRow + distToCol(sign)
    if baseDistance > steps then 0L
    else
      val (stx, sty) = startWhere(dx, sign)
      val paths = shortestPathsFrom(stx, sty)
      val farthest = paths.iterator.flatMap(_.lastOption).max
      // Tiles 1 until fullTiles keep at least `farthest` moves on entry, so every cell of the
      // matching parity is reachable there. Horizontal stride is `cols`, not `rows`.
      val fullTiles = ((steps - baseDistance - farthest) / cols).toInt.max(0) + 1
      val parity = ((steps - baseDistance) % 2).toInt
      // Each further tile costs `cols` moves, flipping the parity when `cols` is odd.
      // Multiply in Long: these products exceed Int.MaxValue for large step budgets.
      val sameParityTiles = (1 until fullTiles by 2).size.toLong * paths(parity).length
      val flippedParityTiles =
        (2 until fullTiles by 2).size.toLong * paths((cols + parity) % 2).length
      val partialTiles = (fullTiles to steps).iterator
        .map(k => sign * k)
        .takeWhile(dy => dRow + distToCol(dy) <= steps)
        .map(dy => countWhere(dx, dy))
        .sum
      sameParityTiles + flippedParityTiles + partialTiles

  // Tile rows whose nearest row is within budget: 0, 1, 2, ... then -1, -2, ...
  val rowDeltas =
    (0 to steps).iterator.takeWhile(dx => distToRow(dx) <= steps) ++
      (-1 to -steps by -1).iterator.takeWhile(dx => distToRow(dx) <= steps)

  rowDeltas.map(dx => countWhere(dx, 0) + sideCount(dx, 1) + sideCount(dx, -1)).sum
|
|
|
|
/** Part 2: prints the reachable-cell count on the tiled board for each step budget listed. */
def part2 =
  println(Seq(26501365).map(pointsReachable))
|
|
|
|
/** Entry point: runs part 1 or part 2 of day 21, selected by the `part` argument.
  *
  * @throws IllegalArgumentException if `part` is neither 1 nor 2
  */
@main def Day21(part: Int): Unit = part match
  case 1 => part1
  case 2 => part2
  // Explicit error instead of an opaque MatchError for an unsupported part number.
  case other => throw IllegalArgumentException(s"part must be 1 or 2, got $other")
|