package aoc.day21

import aoc._
import aoc.direction._
import scala.collection.mutable

/** The puzzle grid as rows of characters, indexed `board(row)(col)`. */
val board = lines
  .map(_.toArray)
  .toArray

/** `(row, col)` of the `'S'` cell.
  * NOTE(review): assumes `findMap` yields an `Option`, so `.get` fails if there is no `'S'`.
  */
val startPoint = board.toSeq.zipWithIndex.findMap { (row, idx) =>
  row.toSeq.zipWithIndex.findMap: (cell, col) =>
    if cell == 'S' then Some((idx, col)) else None
}.get

/** Whether `(x, y)` lies on the board (`x` is the row, `y` the column). */
def inside(x: Int, y: Int): Boolean =
  x >= 0 && x < board.length && y >= 0 && y < board(0).length

/** Whether `(x, y)` is on the board and is not a rock (`'#'`). */
def walkable(x: Int, y: Int): Boolean = inside(x, y) && board(x)(y) != '#'

// part 1

/** Replaces the frontier `starting` by its walkable neighbours `steps` times, returning the set
  * of cells reachable by a walk of exactly `steps` moves.
  */
def walk(steps: Int, starting: Set[(Int, Int)]): Set[(Int, Int)] =
  if steps == 0 then starting
  else
    val ns = starting
      .flatMap((a, b) => Dir.all.map(_(a, b)))
      .filter(walkable(_, _))
    walk(steps - 1, ns)

def part1 =
  val possible = walk(64, Set(startPoint))
  println(possible.size)

// part 2

/** Memoized BFS over a single (non-repeating) copy of the board.
  *
  * `shortestPathsFrom(x, y)` returns two ascending-sorted arrays of BFS distances from `(x, y)`
  * to every cell reachable from it: index 0 holds the even distances, index 1 the odd ones.
  */
object shortestPathsFrom:
  val cache = mutable.Map.empty[(Int, Int), Array[Array[Int]]]

  def apply(x: Int, y: Int): Array[Array[Int]] =
    // the default argument is by-name, so the BFS only runs on a cache miss
    cache.getOrElseUpdate((x, y), bfs(x, y))

  /** Plain BFS from `(x, y)`, split into sorted even / odd distance arrays. */
  private def bfs(x: Int, y: Int): Array[Array[Int]] =
    val dist = mutable.Map((x, y) -> 0)
    val queue = mutable.Queue((x, y))
    while queue.nonEmpty do
      val (cx, cy) = queue.dequeue()
      val next = dist((cx, cy)) + 1
      Dir.all
        .map(_(cx, cy))
        .filter(walkable(_, _))
        .foreach { pt =>
          if !dist.contains(pt) then
            dist += pt -> next
            queue += pt
        }
    val (even, odd) = dist.values.partition(_ % 2 == 0)
    Array(even.toArray.sorted, odd.toArray.sorted)

// `pointsReachable` treats the distance from the start to any board copy as a straight
// Manhattan walk. That is only valid because this input has a fully open border and a fully
// open row and column through `S`; fail fast if the input breaks that assumption.
val _ =
  assert(board.head.forall(_ == '.'))
  assert(board.last.forall(_ == '.'))
  assert(board.forall(_.head == '.'))
  assert(board.forall(_.last == '.'))
  val (x, y) = startPoint
  assert((0 until board.length).forall(board(_)(y) != '#'))
  assert((0 until board(0).length).forall(board(x)(_) != '#'))

/** Number of cells reachable in exactly `steps` steps on the infinitely tiled board.
  *
  * Each board copy `(dx, dy)` (offset in copies from the original) is entered through its cell
  * nearest the start, `distToRow(dx) + distToCol(dy)` steps away. A cell at BFS distance `d`
  * inside the copy is counted when that entry distance plus `d` is `<= steps` and has the same
  * parity as `steps`. Copies with enough slack to reach every cell are counted in bulk by parity.
  */
def pointsReachable(steps: Int): Long =
  val (sx, sy) = startPoint
  val height = board.size
  val width = board(0).size

  /** Steps from the start to the nearest row of the copy `delta` copies below (`< 0`: above). */
  def distToRow(delta: Int): Long =
    if delta == 0 then 0L
    else if delta < 0 then 1L + sx + (-1L - delta) * height // go up, out through row 0
    else (height - sx).toLong + (delta - 1L) * height // go down, out through the last row

  /** Steps from the start to the nearest column of the copy `delta` copies right (`< 0`: left). */
  def distToCol(delta: Int): Long =
    if delta == 0 then 0L
    else if delta < 0 then 1L + sy + (-1L - delta) * width // go left, out through column 0
    else (width - sy).toLong + (delta - 1L) * width // go right, out through the last column

  /** In-copy coordinates of the cell of copy `(dx, dy)` nearest the start. */
  def startWhere(dx: Int, dy: Int): (Int, Int) =
    val row = if dx == 0 then sx else if dx < 0 then height - 1 else 0
    val col = if dy == 0 then sy else if dy < 0 then width - 1 else 0
    (row, col)

  /** Number of elements `<= max` in the ascending-sorted `sorted` (upper-bound binary search). */
  def countAtMost(sorted: Array[Int], max: Int): Int =
    var low = 0
    var high = sorted.length
    while low < high do
      val mid = (low + high) >>> 1
      if sorted(mid) <= max then low = mid + 1 else high = mid
    low

  /** Cells reachable inside the single copy `(dx, dy)`; 0 when the copy is out of range. */
  def countWhere(steps: Int, dx: Int, dy: Int): Long =
    val dToBoard = distToRow(dx) + distToCol(dy)
    if dToBoard > steps then 0L
    else
      val remaining = (steps - dToBoard).toInt
      val (stx, sty) = startWhere(dx, dy)
      countAtMost(shortestPathsFrom(stx, sty)(remaining % 2), remaining).toLong

  /** Cells reachable in the copies of row `dx` strictly right (`side = 1`) or strictly left
    * (`side = -1`) of the start column.
    */
  def countSide(dx: Int, side: Int): Long =
    val dRow = distToRow(dx)
    // copy `side * k` (k >= 1) is entered after baseDistance + (k - 1) * width steps
    val baseDistance = dRow + distToCol(side)
    if baseDistance > steps then 0L
    else
      val (stx, sty) = startWhere(dx, side)
      val byParity = shortestPathsFrom(stx, sty)
      // the even array always holds the entry cell's own 0, so this max is well defined
      val maxDist = byParity.flatMap(_.lastOption).max
      // copies k in [1, fullUntil) have enough remaining steps to reach every BFS-visited cell
      val fullUntil = ((steps - baseDistance - maxDist) / width).toInt.max(0) + 1
      val parity = ((steps - baseDistance) % 2).toInt
      // copy k has remaining-step parity (parity + (k - 1) * width) % 2
      val oddOffsetCopies = (1 until fullUntil by 2).size.toLong
      val evenOffsetCopies = (2 until fullUntil by 2).size.toLong
      // Long arithmetic on purpose: copies * cells overflows Int for large `steps`
      val full =
        oddOffsetCopies * byParity(parity).length +
          evenOffsetCopies * byParity((width + parity) % 2).length
      val partial = (fullUntil to steps).iterator
        .map(_ * side)
        .takeWhile(dy => dRow + distToCol(dy) <= steps)
        .map(countWhere(steps, dx, _))
        .sum
      full + partial

  // each half gets its own takeWhile so the upward rows are still visited
  val belowAndLevel = (0 to steps).iterator.takeWhile(distToRow(_) <= steps)
  val above = (-1 to -steps by -1).iterator.takeWhile(distToRow(_) <= steps)
  (belowAndLevel ++ above)
    .map(dx => countWhere(steps, dx, 0) + countSide(dx, 1) + countSide(dx, -1))
    .sum

/** Prints the reachable-cell count for the part-2 step count. */
def part2 =
  val inputs = Seq(26501365)
  println(inputs.map(pointsReachable))

@main def Day21(part: Int) =
  part match
    case 1     => part1
    case 2     => part2
    case other => throw IllegalArgumentException(s"part must be 1 or 2, got $other")