aoc2023/Day20.scala
2023-12-21 18:21:32 +01:00

146 lines
5.1 KiB
Scala

package aoc.day20
import aoc._
import scala.collection.mutable
/** Signal level carried by a pulse between modules. */
enum Pulse:
  case Low, High

  /** Returns the opposite level: `Low` becomes `High` and `High` becomes `Low`.
    *
    * Matches the cases explicitly (checked for exhaustiveness by the compiler) instead of
    * doing arithmetic on `ordinal`, so the result does not depend on declaration order.
    */
  def flip: Pulse = this match
    case Pulse.Low  => Pulse.High
    case Pulse.High => Pulse.Low
import Pulse._
/** Base class of every module in the pulse network.
  *
  * @param targets names of the modules that receive this module's output pulses
  */
sealed abstract class Module(targets: Set[String]):
  /** Processes one incoming pulse.
    *
    * @param source name of the module that sent the pulse
    * @param pulse  level of the received pulse
    * @return the module after handling the pulse (possibly `this`), and the pulses to emit,
    *         keyed by target name
    */
  def handle(source: String, pulse: Pulse): (Module, Map[String, Pulse])

  /** Sends the same level `p` to every target. */
  final def broadcast(p: Pulse): Map[String, Pulse] =
    // `iterator` replaces the deprecated `toIterator`.
    targets.iterator.map(_ -> p).toMap
/** Flip-flop module: ignores `High` pulses; on a `Low` pulse it toggles its stored
  * level and sends the new level to all of its targets.
  *
  * @param state   the level currently stored by the flip-flop
  * @param targets modules receiving this module's output
  */
case class FlipFlop(state: Pulse, targets: Set[String]) extends Module(targets):
  def handle(source: String, pulse: Pulse) =
    if pulse == High then (this, Map.empty)
    else
      val toggled = copy(state = state.flip)
      (toggled, toggled.broadcast(toggled.state))
/** Companion of [[FlipFlop]]. */
object FlipFlop:
  /** Creates a flip-flop whose stored level starts as `Low`. */
  def apply(targets: Set[String]) = new FlipFlop(state = Low, targets = targets)
/** Conjunction module: remembers the last pulse received from each source and emits
  * `Low` when every remembered pulse is `High`, otherwise `High`.
  *
  * @param sources the most recent pulse received from each input module
  * @param targets modules receiving this module's output
  */
case class Conjunc(sources: Map[String, Pulse], targets: Set[String]) extends Module(targets):
  override def handle(source: String, pulse: Pulse): (Module, Map[String, Pulse]) =
    val remembered = sources.updated(source, pulse)
    val allHigh = remembered.valuesIterator.forall(_ == High)
    val updated = copy(sources = remembered)
    (updated, updated.broadcast(if allHigh then Low else High))
/** Companion of [[Conjunc]]. */
object Conjunc:
  /** Creates a conjunction whose memory records `Low` for every one of `sources`.
    *
    * @param sources names of the modules that send pulses to this conjunction
    * @param targets modules receiving this module's output
    */
  def apply(sources: Set[String], targets: Set[String]): Conjunc =
    // `iterator` replaces the deprecated `toIterator`.
    new Conjunc(sources.iterator.map(_ -> Low).toMap, targets)
/** Broadcaster module: forwards every received pulse, unchanged, to all of its targets. */
case class Broadcast(targets: Set[String]) extends Module(targets):
  def handle(source: String, pulse: Pulse): (Module, Map[String, Pulse]) =
    val forwarded = broadcast(pulse)
    (this, forwarded)
/** Kind of module as written in the input: `broadcaster`, `%name` (Flip) or `&name` (Conjunc). */
enum ModuleTyp:
  case Broadcast
  case Flip
  case Conjunc
/** Parser for one line of the module list, e.g. `%a -> b, c` or `broadcaster -> a, b`. */
object Parser extends CommonParser:
  /** A parsed line: module kind, module name, and the names of its targets. */
  type Node = (ModuleTyp, String, Set[String])

  /** A module name: one or more ASCII letters. */
  val name = """[a-zA-Z]+""".r

  /** Left-hand side of a line: the module kind paired with the module's name. */
  val moduleTyp: Parser[(ModuleTyp, String)] =
    val broadcaster: Parser[(ModuleTyp, String)] =
      "broadcaster" ^^^ ((ModuleTyp.Broadcast, "broadcaster"))
    val flipFlop: Parser[(ModuleTyp, String)] =
      '%' ~> name ^^ (n => (ModuleTyp.Flip, n))
    val conjunction: Parser[(ModuleTyp, String)] =
      '&' ~> name ^^ (n => (ModuleTyp.Conjunc, n))
    broadcaster | flipFlop | conjunction

  /** A whole line: `<kind+name> -> <target>, <target>, ...`. */
  val node: Parser[Node] =
    // Parenthesised because `<~` binds more loosely than `~`.
    (moduleTyp <~ "->") ~ repsep(name, ",") ^^ { case (typ, moduleName) ~ targets =>
      (typ, moduleName, targets.toSet)
    }
/** Every module from the input, keyed by name, in its initial state. */
val nodes =
  val parsed = lines.map(Parser.parse(Parser.node, _).get).toSeq
  // For each module name, the set of modules that list it as a target; conjunctions
  // need this to initialise their per-input memory.
  val inputsOf: Map[String, Set[String]] =
    parsed
      .flatMap((_, from, targets) => targets.toSeq.map(to => (to, from)))
      .groupMap(_._1)(_._2)
      .view
      .mapValues(_.toSet)
      .toMap
  parsed
    .map: (typ, name, targets) =>
      val module = typ match
        case ModuleTyp.Broadcast => Broadcast(targets)
        case ModuleTyp.Flip      => FlipFlop(targets)
        case ModuleTyp.Conjunc   => Conjunc(inputsOf.getOrElse(name, Set.empty), targets)
      (name, module)
    .toMap
/** Simulation operations on a whole module network, keyed by module name. */
extension (m: Map[String, Module])
  /** Presses the button once: delivers a `Low` pulse from `""` to `broadcaster`, then
    * processes the resulting pulses in FIFO order until none remain.
    *
    * @return the network afterwards, and every pulse delivered during the press as
    *         `(from, to, pulse)`, most recent first (including the initial button pulse)
    */
  def sendSig(): (Map[String, Module], Seq[(String, String, Pulse)]) =
    m.loop(mutable.Queue(("", "broadcaster", Low)), Seq.empty)

  /** Presses the button once, memoised in `cache` keyed on the network state before the press.
    *
    * @return the network afterwards, the number of `Low` pulses, and the number of `High` pulses
    */
  def send(
      cache: mutable.Map[Map[String, Module], (Map[String, Module], Int, Int)]
  ): (Map[String, Module], Int, Int) =
    cache.getOrElseUpdate(
      m, {
        val (next, pulses) = m.sendSig()
        val lows = pulses.count((_, _, p) => p == Low)
        (next, lows, pulses.size - lows)
      }
    )

  /** Presses the button repeatedly until `breakCond` has matched a delivered pulse during two
    * different presses.
    *
    * @param breakCond predicate on `(from, to, pulse)`
    * @param first     0-based indices of the presses where `breakCond` has matched so far
    * @param current   0-based index of the next press
    * @return the 0-based indices of the first two presses where `breakCond` matched
    * @note does not terminate if `breakCond` matches during fewer than two presses
    */
  @scala.annotation.tailrec
  def loopUntilTwice(
      breakCond: (String, String, Pulse) => Boolean,
      first: Seq[Int] = Seq.empty,
      current: Int = 0
  ): (Int, Int) =
    first match
      case Seq(a, b) => (a, b)
      case _ =>
        val (next, pulses) = sendSig()
        val matched = pulses.exists((from, to, pulse) => breakCond(from, to, pulse))
        val nextFirst = if matched then first :+ current else first
        next.loopUntilTwice(breakCond, nextFirst, current + 1)

  /** Drains `queue`, delivering each pulse in FIFO order and enqueuing whatever it triggers.
    *
    * @param queue  pending `(from, to, pulse)` deliveries; mutated by this call
    * @param pulses pulses already delivered, most recent first
    * @return the final network and all delivered pulses, most recent first
    */
  @scala.annotation.tailrec
  private def loop(
      queue: mutable.Queue[(String, String, Pulse)],
      pulses: Seq[(String, String, Pulse)]
  ): (Map[String, Module], Seq[(String, String, Pulse)]) =
    if queue.isEmpty then (m, pulses)
    else
      val (from, to, pulse) = queue.dequeue()
      // A pulse to a name with no module definition is recorded but changes nothing.
      val next = m.get(to) match
        case None => m
        case Some(module) =>
          val (updated, emitted) = module.handle(from, pulse)
          queue ++= emitted.iterator.map((target, p) => (to, target, p))
          m.updated(to, updated)
      next.loop(queue, (from, to, pulse) +: pulses)
// part 1
/** Part 1: presses the button 1000 times from the initial `nodes` and prints
  * (total `Low` pulses) * (total `High` pulses), multiplied as `Long`.
  */
def part1 =
  // Per-press results memoised on the network state before the press.
  val memo = mutable.Map.empty[Map[String, Module], (Map[String, Module], Int, Int)]
  val start: (Map[String, Module], Int, Int) = (nodes, 0, 0)
  val (_, lowTotal, highTotal) = Iterator.range(0, 1000).foldLeft(start) { (acc, _) =>
    val (state, lowsSoFar, highsSoFar) = acc
    val (nextState, lows, highs) = state.send(memo)
    (nextState, lowsSoFar + lows, highsSoFar + highs)
  }
  println(lowTotal.toLong * highTotal)
/** Part 2: finds the single module feeding `rx` (required to be a conjunction), measures for
  * each of that module's inputs the gap between the first two presses in which it sends a
  * `High` pulse to it, and prints the LCM of those gaps.
  *
  * NOTE(review): this relies on the structure of the puzzle input rather than a general
  * simulation. It assumes each input's `High` pulses recur with a fixed period that also
  * determines the answer. Confirm against the input.
  *
  * @throws Exception if `rx` does not have exactly one sender, or that sender is not a
  *                   conjunction
  */
def part2 =
  // Modules (keyed by name) that list `target` among their outputs.
  def sendsTo(target: String) = nodes.filter: (_, node) =>
    node match
      case Broadcast(targets)   => targets.contains(target)
      case FlipFlop(_, targets) => targets.contains(target)
      case Conjunc(_, targets)  => targets.contains(target)
  // A type pattern instead of `isInstanceOf` + `assert`: the check cannot be disabled by
  // compiler flags, and a wrong sender reports what was actually found.
  val feeder = sendsTo("rx").toSeq match
    case Seq((name, _: Conjunc)) => name
    case ss                      => throw Exception(s"Unexpected $ss")
  val feederInputs = sendsTo(feeder)
  println(feederInputs)
  // `toSeq` so equal (a, b) pairs from different inputs are not merged by a Set first
  // (harmless for LCM, but it should not be implicit).
  val res = feederInputs.keys.toSeq
    .map(input =>
      nodes.loopUntilTwice((from, to, pulse) => from == input && to == feeder && pulse == High)
    )
    .foldLeft(1L) { case (acc, (a, b)) => lcm(acc, b - a) }
  println(res)
/** Greatest common divisor by Euclid's algorithm.
  *
  * The result is non-negative (for inputs other than `Long.MinValue`), and `gcd(0, 0) == 0`.
  */
@scala.annotation.tailrec
def gcd(a: Long, b: Long): Long = if b == 0 then math.abs(a) else gcd(b, a % b)

/** Least common multiple. Non-negative, and 0 when either argument is 0 (which previously
  * divided by zero for `lcm(0, 0)`). Divides before multiplying to limit intermediate overflow.
  */
def lcm(a: Long, b: Long): Long =
  if a == 0 || b == 0 then 0L else math.abs(a / gcd(a, b) * b)
/** Entry point: runs puzzle part 1 or 2.
  *
  * @param part which part to run; any value other than 1 or 2 is rejected
  * @throws IllegalArgumentException if `part` is not 1 or 2 (instead of a bare `MatchError`)
  */
@main def Day20(part: Int) = part match
  case 1     => part1
  case 2     => part2
  case other => throw IllegalArgumentException(s"part must be 1 or 2, got $other")