146 lines
5.1 KiB
Scala
146 lines
5.1 KiB
Scala
package aoc.day20
|
|
|
|
import aoc._
|
|
import scala.collection.mutable
|
|
|
|
/** Signal level carried by a pulse between modules. */
enum Pulse:
  case Low, High

  /** The opposite level: `Low` becomes `High` and `High` becomes `Low`. */
  def flip: Pulse = this match
    case Pulse.Low  => Pulse.High
    case Pulse.High => Pulse.Low
|
|
import Pulse._
|
|
|
|
/** A node of the pulse network.
  *
  * A module reacts to an incoming pulse by returning its (possibly updated) self together with
  * the pulses it emits, keyed by destination module name.
  *
  * @param targets names of the modules this module sends its pulses to
  */
sealed abstract class Module(targets: Set[String]):

  /** Processes `pulse` arriving from the module named `source`.
    *
    * @return the module's next state and the pulses it emits (destination name -> level)
    */
  def handle(source: String, pulse: Pulse): (Module, Map[String, Pulse])

  /** Sends the same level `p` to every target. */
  final def broadcast(p: Pulse): Map[String, Pulse] =
    // `iterator` replaces the deprecated `toIterator`; avoids building an intermediate Set.
    targets.iterator.map(_ -> p).toMap
|
|
|
|
/** Flip-flop module.
  *
  * A high pulse is ignored (no state change, nothing emitted). A low pulse toggles `state`, and
  * the new state is sent to every target.
  *
  * @param state   current level of the flip-flop
  * @param targets names of the modules it sends to
  */
case class FlipFlop(state: Pulse, targets: Set[String]) extends Module(targets):
  def handle(source: String, pulse: Pulse): (Module, Map[String, Pulse]) =
    if pulse == High then (this, Map.empty)
    else
      val toggled = copy(state = state.flip)
      (toggled, toggled.broadcast(toggled.state))
|
|
|
|
/** Factory for flip-flops in their initial state. */
object FlipFlop:
  /** A flip-flop sending to `targets`, starting with state `Low`. */
  def apply(targets: Set[String]): FlipFlop = new FlipFlop(state = Low, targets = targets)
|
|
|
|
/** Conjunction module.
  *
  * Remembers the last level received from each input. After recording the incoming pulse, it
  * emits `Low` to all targets if every remembered level is `High`, and `High` otherwise.
  *
  * @param sources last level seen per input module name
  * @param targets names of the modules it sends to
  */
case class Conjunc(sources: Map[String, Pulse], targets: Set[String]) extends Module(targets):
  override def handle(source: String, pulse: Pulse): (Module, Map[String, Pulse]) =
    val remembered = copy(sources = sources + (source -> pulse))
    val allHigh = remembered.sources.valuesIterator.forall(_ == High)
    (remembered, remembered.broadcast(if allHigh then Low else High))
|
|
|
|
/** Factory for conjunctions in their initial state. */
object Conjunc:
  /** A conjunction whose memory holds `Low` for every input in `sources`.
    *
    * @param sources names of the modules that send to this conjunction
    * @param targets names of the modules this conjunction sends to
    */
  def apply(sources: Set[String], targets: Set[String]): Conjunc =
    // `iterator` replaces the deprecated `toIterator`.
    new Conjunc(sources.iterator.map(_ -> Low).toMap, targets)
|
|
|
|
/** The broadcaster: forwards every incoming pulse unchanged to all of its targets. */
case class Broadcast(targets: Set[String]) extends Module(targets):
  def handle(source: String, pulse: Pulse): (Module, Map[String, Pulse]) =
    val outgoing = broadcast(pulse)
    (this, outgoing)
|
|
|
|
/** Kind of module on an input line: the broadcaster, a flip-flop (`%`) or a conjunction (`&`). */
enum ModuleTyp:
  case Broadcast, Flip, Conjunc
|
|
/** Parser for one line of the module configuration, e.g. `%a -> b, c`. */
object Parser extends CommonParser:

  /** A parsed line: module kind, module name, and the names of its targets. */
  type Node = (ModuleTyp, String, Set[String])

  /** A module name: one or more ASCII letters. */
  val name = """[a-zA-Z]+""".r

  /** Module prefix and name. The broadcaster has no prefix and a fixed name; `%` marks a
    * flip-flop and `&` a conjunction.
    */
  val moduleTyp: Parser[(ModuleTyp, String)] =
    // Double parentheses pass a single tuple argument instead of relying on auto-tupling.
    ("broadcaster" ^^^ ((ModuleTyp.Broadcast, "broadcaster"))) |
      ('%' ~> name ^^ (n => (ModuleTyp.Flip, n))) |
      ('&' ~> name ^^ (n => (ModuleTyp.Conjunc, n)))

  /** A full line: `<module> -> <target>, <target>, ...`. */
  val node: Parser[Node] =
    moduleTyp ~ ("->" ~> repsep(name, ",")) ^^ { case (typ, moduleName) ~ targets =>
      (typ, moduleName, targets.toSet)
    }
|
|
|
|
/** The module network parsed from the puzzle input, keyed by module name.
  *
  * Each conjunction's memory is seeded with `Low` for every module that lists it as a target.
  * Throws `IllegalArgumentException` naming the offending line if an input line does not parse.
  */
val nodes: Map[String, Module] =
  def parseLine(line: String): Parser.Node =
    val result = Parser.parse(Parser.node, line)
    if result.successful then result.get
    else throw IllegalArgumentException(s"Cannot parse input line '$line': $result")

  val ns = lines.map(parseLine).toSeq

  // Reverse adjacency (target -> its sources), built once instead of rescanning all edges
  // for every conjunction.
  val edges = ns.flatMap((_, from, to) => to.map((from, _)))
  val sourcesByTarget: Map[String, Set[String]] =
    edges.groupMap(_._2)(_._1).view.mapValues(_.toSet).toMap

  ns.map { (typ, name, targets) =>
    val mod: Module = typ match
      case ModuleTyp.Broadcast => Broadcast(targets)
      case ModuleTyp.Flip      => FlipFlop(targets)
      case ModuleTyp.Conjunc   => Conjunc(sourcesByTarget.getOrElse(name, Set.empty), targets)
    name -> mod
  }.toMap
|
|
|
|
/** Simulation helpers on a network state (module name -> module). */
extension (m: Map[String, Module])

  /** Presses the button once.
    *
    * A `Low` pulse from `""` is delivered to `broadcaster`, and the resulting pulses are
    * propagated until none remain.
    *
    * @return the network state afterwards and every pulse sent, including the button's, as
    *         `(from, to, level)` triples with the most recent first
    */
  def sendSig(): (Map[String, Module], Seq[(String, String, Pulse)]) =
    m.loop(mutable.Queue(("", "broadcaster", Low)), Seq.empty)

  /** One button press, memoised on the whole network state.
    *
    * @param cache previously computed presses: state -> (next state, #low pulses, #high pulses)
    * @return the next state and the numbers of low and high pulses sent during the press
    */
  def send(cache: mutable.Map[Map[String, Module], (Map[String, Module], Int, Int)]): (Map[String, Module], Int, Int) =
    cache.getOrElseUpdate(
      m, {
        val (next, pulses) = m.sendSig()
        val lows = pulses.count(_._3 == Low)
        (next, lows, pulses.size - lows)
      }
    )

  /** Presses the button repeatedly until `breakCond` has matched some pulse during two presses.
    *
    * Does not terminate if fewer than two presses ever match.
    *
    * @param breakCond predicate on `(from, to, level)` of each pulse sent during a press
    * @param first     0-based indices of the presses on which `breakCond` matched so far
    * @param current   0-based index of the press about to be simulated
    * @return the indices of the first two matching presses
    */
  @scala.annotation.tailrec
  def loopUntilTwice(
      breakCond: (String, String, Pulse) => Boolean,
      first: Seq[Int] = Seq.empty,
      current: Int = 0
  ): (Int, Int) =
    first match
      case Seq(a, b) => (a, b)
      case _ =>
        val (nm, pulses) = sendSig()
        val nfirst = if pulses.exists(breakCond.tupled) then first :+ current else first
        nm.loopUntilTwice(breakCond, nfirst, current + 1)

  /** Drains `queue`, delivering each pulse to its destination and enqueueing what it emits.
    *
    * A pulse addressed to a name with no module is recorded but changes nothing.
    *
    * @param queue  pending pulses as `(from, to, level)`; drained (mutated) by this call
    * @param pulses pulses delivered so far, most recent first
    * @return the final network state and all delivered pulses, most recent first
    */
  @scala.annotation.tailrec
  private def loop(
      queue: mutable.Queue[(String, String, Pulse)],
      pulses: Seq[(String, String, Pulse)]
  ): (Map[String, Module], Seq[(String, String, Pulse)]) =
    if queue.isEmpty then (m, pulses)
    else
      val (from, to, pulse) = queue.dequeue()
      val nm = m.get(to) match
        case None => m
        case Some(module) =>
          val (updated, emitted) = module.handle(from, pulse)
          // Appended at the back of the FIFO, so pulses are handled in the order they were sent.
          queue ++= emitted.iterator.map((target, level) => (to, target, level))
          m.updated(to, updated)
      nm.loop(queue, (from, to, pulse) +: pulses)
|
|
|
|
// Part 1: product of the total low and high pulse counts over 1000 button presses.
|
|
def part1: Unit =
  // Memo shared by all presses: network state -> (next state, #low pulses, #high pulses).
  val memo = mutable.Map.empty[Map[String, Module], (Map[String, Module], Int, Int)]

  // One press: advance the state and add this press's pulse counts to the running totals.
  val pressOnce: ((Map[String, Module], Int, Int)) => (Map[String, Module], Int, Int) =
    (state, lowTotal, highTotal) =>
      val (next, lowCount, highCount) = state.send(memo)
      (next, lowTotal + lowCount, highTotal + highCount)

  // Element k of the iterator is the accumulated result after k presses.
  val (_, lowSum, highSum) =
    Iterator.iterate[(Map[String, Module], Int, Int)]((nodes, 0, 0))(pressOnce).drop(1000).next()

  println(1L * lowSum * highSum)
|
|
|
|
/** Part 2: prints the LCM of the high-pulse periods of every input of the conjunction feeding
  * `rx`.
  *
  * Assumption (not checked here): each such input sends `High` to that conjunction with a fixed
  * period starting from the first press. Under that assumption, the LCM is the number of presses
  * before `rx` first receives a low pulse.
  */
def part2: Unit =
  // Modules that list `target` among their outputs.
  def sendsTo(target: String): Map[String, Module] = nodes.filter: (_, node) =>
    node match
      case Broadcast(targets)   => targets.contains(target)
      case FlipFlop(_, targets) => targets.contains(target)
      case Conjunc(_, targets)  => targets.contains(target)

  // Exactly one module, a conjunction, is expected to feed `rx`.
  val v = sendsTo("rx").toSeq match
    case Seq((n, _: Conjunc)) => n
    case ss                   => throw Exception(s"Unexpected $ss")

  // Period of each input of `v`: distance between the first two presses in which it sends `v` High.
  val periods = sendsTo(v).keys.toSeq.map: feeder =>
    val (a, b) = nodes.loopUntilTwice((from, to, pulse) => from == feeder && to == v && pulse == High)
    (b - a).toLong

  println(periods.foldLeft(1L)(lcm))
|
|
|
|
/** Greatest common divisor (Euclid). Always non-negative; `gcd(0, 0) == 0`. */
@scala.annotation.tailrec
def gcd(a: Long, b: Long): Long = if b == 0 then math.abs(a) else gcd(b, a % b)

/** Least common multiple. Always non-negative; `0` if either argument is `0`.
  *
  * Divides before multiplying to reduce the risk of intermediate overflow.
  */
def lcm(a: Long, b: Long): Long =
  if a == 0 || b == 0 then 0L else math.abs(a / gcd(a, b) * b)
|
|
|
|
/** Entry point: runs puzzle part 1 or 2. Any other value fails with a descriptive error. */
@main def Day20(part: Int): Unit = part match
  case 1     => part1
  case 2     => part2
  case other => throw IllegalArgumentException(s"Unknown part $other; expected 1 or 2")
|