package aoc.day20

import aoc._
import scala.collection.mutable

/** Signal level carried by a pulse. */
enum Pulse:
  case Low, High

  /** The opposite pulse level. */
  def flip: Pulse = this match
    case Pulse.Low  => Pulse.High
    case Pulse.High => Pulse.Low

import Pulse._

/** An immutable module in the pulse network: handling a pulse yields the module's next state.
  *
  * @param targets names of the modules this module sends its output pulses to
  */
sealed abstract class Module(targets: Set[String]):
  /** Processes `pulse` arriving from module `source`.
    *
    * @return the module's updated state and the pulses to emit, keyed by destination module name
    */
  def handle(source: String, pulse: Pulse): (Module, Map[String, Pulse])

  /** Sends the same pulse `p` to every target. */
  final def broadcast(p: Pulse): Map[String, Pulse] = targets.iterator.map(_ -> p).toMap

/** Flip-flop (`%`): ignores `High`; on `Low` it toggles `state` and emits the new state. */
case class FlipFlop(state: Pulse, targets: Set[String]) extends Module(targets):
  def handle(source: String, pulse: Pulse): (FlipFlop, Map[String, Pulse]) = pulse match
    case High => (this, Map.empty)
    case Low =>
      val toggled = FlipFlop(state.flip, targets)
      (toggled, toggled.broadcast(toggled.state))

object FlipFlop:
  /** A flip-flop starting in the `Low` state. */
  def apply(targets: Set[String]): FlipFlop = new FlipFlop(Low, targets)

/** Conjunction (`&`): remembers the last pulse seen from each input and emits `Low` when every
  * remembered pulse is `High`, otherwise `High`.
  */
case class Conjunc(sources: Map[String, Pulse], targets: Set[String]) extends Module(targets):
  override def handle(source: String, pulse: Pulse): (Module, Map[String, Pulse]) =
    val remembered = Conjunc(sources.updated(source, pulse), targets)
    val out = if remembered.sources.values.forall(_ == High) then Low else High
    (remembered, remembered.broadcast(out))

object Conjunc:
  /** A conjunction whose memory starts at `Low` for every input in `sources`. */
  def apply(sources: Set[String], targets: Set[String]): Conjunc =
    new Conjunc(sources.iterator.map(_ -> Low).toMap, targets)

/** Broadcaster: forwards every received pulse unchanged to all targets. */
case class Broadcast(targets: Set[String]) extends Module(targets):
  def handle(source: String, pulse: Pulse): (Module, Map[String, Pulse]) = (this, broadcast(pulse))

/** Module kind as written in the input (`broadcaster`, `%name`, `&name`). */
enum ModuleTyp:
  case Broadcast, Flip, Conjunc

object Parser extends CommonParser:
  /** A parsed input line: module kind, module name, destination module names. */
  type Node = (ModuleTyp, String, Set[String])

  val name = """[a-zA-Z]+""".r

  // The tuple is wrapped in an extra pair of parentheses so it is passed as ONE argument to
  // `^^^` instead of relying on (deprecated) auto-tupling of a multi-argument infix call.
  val moduleTyp: Parser[(ModuleTyp, String)] =
    ("broadcaster" ^^^ ((ModuleTyp.Broadcast, "broadcaster"))) |
      ('%' ~ name ^^ { case _ ~ id => (ModuleTyp.Flip, id) }) |
      ('&' ~ name ^^ { case _ ~ id => (ModuleTyp.Conjunc, id) })

  val node: Parser[Node] = moduleTyp ~ "->" ~ repsep(name, ",") ^^ { case (typ, id) ~ _ ~ targets =>
    (typ, id, targets.toSet)
  }

/** The initial network parsed from the puzzle input, keyed by module name. */
val nodes: Map[String, Module] =
  val parsed = lines.map(Parser.parse(Parser.node, _).get).toSeq
  // Input modules of each module, built in one pass (conjunctions need to know their inputs).
  val inputsOf: Map[String, Set[String]] =
    parsed
      .flatMap((_, from, to) => to.map(from -> _))
      .groupMapReduce(_._2)(edge => Set(edge._1))(_ ++ _)
  def sourcesOf(target: String): Set[String] = inputsOf.getOrElse(target, Set.empty[String])
  parsed.map { (typ, name, targets) =>
    val module: Module = typ match
      case ModuleTyp.Broadcast => Broadcast(targets)
      case ModuleTyp.Flip      => FlipFlop(targets)
      case ModuleTyp.Conjunc   => Conjunc(sourcesOf(name), targets)
    name -> module
  }.toMap

extension (m: Map[String, Module])
  /** Presses the button once: a `Low` pulse is sent to `broadcaster` and propagated until the
    * queue drains.
    *
    * @return the resulting network and every pulse delivered, as `(from, to, pulse)`, most
    *         recent first (the button pulse itself is included, with `from == ""`)
    */
  def sendSig(): (Map[String, Module], Seq[(String, String, Pulse)]) =
    m.loop(mutable.Queue(("", "broadcaster", Low)), Seq.empty)

  /** Presses the button once, memoizing on the full network state.
    *
    * @return the next network, the number of low pulses and the number of high pulses
    */
  def send(cache: mutable.Map[Map[String, Module], (Map[String, Module], Int, Int)]): (Map[String, Module], Int, Int) =
    cache.getOrElseUpdate(
      m, {
        val (next, pulses) = m.sendSig()
        val lows = pulses.count(_._3 == Low)
        (next, lows, pulses.size - lows)
      }
    )

  /** Keeps pressing the button until `breakCond` has matched some pulse during two presses.
    *
    * @return the 0-based indices of the first two presses in which `breakCond` matched
    */
  @scala.annotation.tailrec
  def loopUntilTwice(
      breakCond: (String, String, Pulse) => Boolean,
      first: Seq[Int] = Seq.empty,
      current: Int = 0
  ): (Int, Int) =
    first match
      case Seq(a, b) => (a, b)
      case _ =>
        val (next, pulses) = m.sendSig()
        val hits = if pulses.exists(breakCond.tupled) then first :+ current else first
        next.loopUntilTwice(breakCond, hits, current + 1)

  /** Delivers queued pulses one at a time, appending newly emitted pulses to `queue`.
    * Pulses addressed to modules absent from the network (e.g. output sinks) are recorded but
    * change nothing.
    */
  @scala.annotation.tailrec
  private def loop(
      queue: mutable.Queue[(String, String, Pulse)],
      pulses: Seq[(String, String, Pulse)]
  ): (Map[String, Module], Seq[(String, String, Pulse)]) =
    if queue.isEmpty then (m, pulses)
    else
      val (from, to, pulse) = queue.dequeue()
      val next = m.get(to) match
        case None => m
        case Some(module) =>
          val (updated, outgoing) = module.handle(from, pulse)
          queue ++= outgoing.iterator.map((target, p) => (to, target, p))
          m.updated(to, updated)
      next.loop(queue, (from, to, pulse) +: pulses)

/** Part 1: presses the button 1000 times and prints (low pulse count) * (high pulse count). */
def part1: Unit =
  val cache = mutable.Map.empty[Map[String, Module], (Map[String, Module], Int, Int)]
  val (_, lows, highs) = (1 to 1000).foldLeft((nodes, 0, 0)) { case ((state, lowAcc, highAcc), _) =>
    val (next, pl, ph) = state.send(cache)
    (next, lowAcc + pl, highAcc + ph)
  }
  println(1L * lows * highs)

/** Part 2: finds the single conjunction feeding `rx`, measures the press cycle at which each of
  * its inputs sends it `High`, and prints the LCM of those cycle lengths.
  */
def part2: Unit =
  /** Modules that list `target` among their destinations. */
  def sendsTo(target: String): Map[String, Module] = nodes.filter { (_, node) =>
    val targets = node match
      case Broadcast(ts)   => ts
      case FlipFlop(_, ts) => ts
      case Conjunc(_, ts)  => ts
    targets.contains(target)
  }
  // Exactly one module must feed rx, and it must be a conjunction for the LCM reasoning to hold.
  val v = sendsTo("rx").toSeq match
    case Seq((name, _: Conjunc)) => name
    case other                   => throw Exception(s"Unexpected $other")
  // Find the cycle with which each input of v sends it a High pulse.
  // NOTE(review): taking b - a as the period assumes each cycle is pure (no offset), as in AoC inputs.
  val res = sendsTo(v).keys
    .map(feeder => nodes.loopUntilTwice((from, to, pulse) => from == feeder && to == v && pulse == High))
    .foldLeft(1L) { case (acc, (a, b)) => lcm(acc, b - a) }
  println(res)

/** Greatest common divisor (Euclid). */
@scala.annotation.tailrec
def gcd(a: Long, b: Long): Long = if b == 0 then a else gcd(b, a % b)

/** Least common multiple; divides before multiplying to reduce overflow risk. */
def lcm(a: Long, b: Long): Long = a / gcd(a, b) * b

/** Entry point: `part` selects which half of the puzzle to solve (1 or 2). */
@main def Day20(part: Int): Unit = part match
  case 1     => part1
  case 2     => part2
  case other => throw IllegalArgumentException(s"Unknown part $other; expected 1 or 2")