import pyboolector as bt
import clause_utils as cl
from btor2circuit import Btor2Circuit


class ProductMachine:
    """
    The product machine of two circuits, each with unrolled copies
    (can be lazily built and added to the solver).

    Copies of each circuit are created on demand by `grow`; each new copy is
    chained to its predecessor by asserting that the predecessor's next-state
    equals the new copy's current state.

    Constructor has no side effect.
    """
    solver: bt.Boolector
    "The underlying SMT (bit-vector) solver."
    suffix_a: str
    "The suffix for the first circuit, e.g. `@A`."
    suffix_b: str
    "The suffix for the second circuit, e.g. `@B`."
    btor2_lines_a: list[str]
    "The lines of the first circuit's BTOR2 file."
    btor2_lines_b: list[str]
    "The lines of the second circuit's BTOR2 file."
    circuit_a: list[Btor2Circuit]
    "The unrolling copies of the first circuit."
    circuit_b: list[Btor2Circuit]
    "The unrolling copies of the second circuit."
    ob_eq: dict[tuple[int, int], bt.BoolectorNode]
    "The ObEq signal, indexed by the `(prime_a, prime_b)` unrolling copy indices."

    def __init__(self, solver: bt.Boolector, suffix_a: str, suffix_b: str,
                 btor2_lines_a: list[str], btor2_lines_b: list[str]):
        # Only stores the arguments; no copies are built and nothing is asserted yet.
        self.solver = solver
        self.suffix_a = suffix_a
        self.suffix_b = suffix_b
        self.btor2_lines_a = btor2_lines_a
        self.btor2_lines_b = btor2_lines_b
        self.circuit_a = []
        self.circuit_b = []
        self.ob_eq = {}

    def _join_states(self, c_curr: Btor2Circuit, c_next: Btor2Circuit):
        """
        (private method) Joins the next states of `c_curr` with the current states of `c_next`.

        For every state name of `c_curr`, asserts that its next-state node equals
        the current-state node of the same name in `c_next`.
        """
        for name in c_curr.state_names():
            self.solver.Assert(c_curr.next_state_by_name(name) == c_next.curr_state_by_name(name))

    def _grow_one(self, copies: list[Btor2Circuit], base_suffix: str,
                  btor2_lines: list[str], height: int):
        """
        (private method) Appends unrolled copies to `copies` until it holds `height` of them.

        Copy `i` is built with suffix `f"{base_suffix}{i}"`; every copy after the
        first is joined to its predecessor via `_join_states`. Does nothing if
        `copies` already has at least `height` entries.
        """
        while len(copies) < height:
            index = len(copies)
            copies.append(Btor2Circuit(self.solver, f"{base_suffix}{index}", btor2_lines))
            if index > 0:
                self._join_states(copies[index - 1], copies[index])

    def grow(self, height_a: int, height_b: int):
        """
        Grows the product machine to the given heights.

        Circuit A is grown to `height_a` copies first, then circuit B to `height_b`
        copies. Existing copies are kept; the machine never shrinks.
        """
        self._grow_one(self.circuit_a, self.suffix_a, self.btor2_lines_a, height_a)
        self._grow_one(self.circuit_b, self.suffix_b, self.btor2_lines_b, height_b)

    def build_ob_eq_signal(self, prime_a: int, prime_b: int) -> bt.BoolectorNode:
        """
        Builds the ObEq signal for the given prime counts.

        ObEq holds when either
          * both copies are valid and every output of A's copy `prime_a` equals the
            same-named output of B's copy `prime_b`, or
          * both copies are invalid.

        The node is stored in `self.ob_eq[(prime_a, prime_b)]` (overwriting any
        previous entry) and also returned for convenience.

        Requires `grow` to have built at least `prime_a + 1` copies of A and
        `prime_b + 1` copies of B.
        """
        a = self.circuit_a[prime_a]
        b = self.circuit_b[prime_b]
        outputs_all_eq = self.solver.Const(True)
        for name in a.output_names():
            outputs_all_eq = outputs_all_eq & (a.output_by_name(name) == b.output_by_name(name))
        valid_a = a.valid_signal()
        valid_b = b.valid_signal()
        both_invalid = (~valid_a) & (~valid_b)
        both_valid = valid_a & valid_b
        ob_eq = (both_valid & outputs_all_eq) | both_invalid
        self.ob_eq[(prime_a, prime_b)] = ob_eq
        return ob_eq

    def _assert_local_init(self, circuit: Btor2Circuit, name: str):
        """(private method) Asserts `circuit`'s state `name` equals its init value, if it has one."""
        node = circuit.init_state_by_name_or_none(name)
        if node is not None:
            self.solver.Assert(circuit.curr_state_by_name(name) == node)

    def apply_initial_relation(self):
        """
        Applies the initial relation to the solver.

        On copy 0 of each circuit:
          * local states (names starting with `_`) are constrained to their init
            value, when one exists;
          * interface states (all other names) of A are asserted equal to the
            same-named current state of B.
        B's interface states are not visited separately; this assumes A and B
        expose the same interface-state names.
        """
        a = self.circuit_a[0]
        b = self.circuit_b[0]
        for name in a.state_names():
            if name.startswith("_"):  # local state
                self._assert_local_init(a, name)
            else:  # interface state
                self.solver.Assert(a.curr_state_by_name(name) == b.curr_state_by_name(name))
        for name in b.state_names():
            if name.startswith("_"):  # local state
                self._assert_local_init(b, name)
            # skip the interface states, they are already handled.

    def get_node_by_number(self, node_id: cl.Node, prime_a: int, prime_b: int) -> bt.BoolectorNode:
        """
        Returns the internal node for a numeric node.

        `node_id` is a `(copy_id, local_id)` pair:
          * `copy_id > 0`: node `local_id` of A's copy `prime_a + copy_id - 1`;
          * `copy_id < 0`: node `local_id` of B's copy `prime_b - copy_id - 1`;
          * `copy_id == 0`: the ObEq signal for `(prime_a, prime_b)` (`local_id` is ignored).

        Raises:
            KeyError: if `copy_id == 0` and `build_ob_eq_signal(prime_a, prime_b)`
                has not been called.
        """
        copy_id, local_id = node_id
        if copy_id > 0:  # circuit A
            return self.circuit_a[copy_id - 1 + prime_a].id_to_node[local_id]
        if copy_id < 0:  # circuit B
            return self.circuit_b[-copy_id - 1 + prime_b].id_to_node[local_id]
        # `copy_id == 0`: the ObEq signal
        try:
            return self.ob_eq[(prime_a, prime_b)]
        except KeyError:
            raise KeyError(
                f"ObEq signal for ({prime_a}, {prime_b}) has not been built; "
                f"call build_ob_eq_signal({prime_a}, {prime_b}) first"
            ) from None

    def get_literal_by_number(self, lit: cl.Literal, prime_a: int, prime_b: int) -> bt.BoolectorNode:
        """
        Returns the internal node for a numeric literal.

        A literal is either
          * `(pos, node_id)`: the node itself if `pos` is truthy, else its negation; or
          * `(eq, node_id_1, node_id_2)`: `node_1 == node_2` if `eq` is truthy,
            else `node_1 != node_2`.
        """
        if len(lit) == 2:
            pos, node_id = lit
            node = self.get_node_by_number(node_id, prime_a, prime_b)
            return node if pos else ~node
        eq, node_id_1, node_id_2 = lit
        node_1 = self.get_node_by_number(node_id_1, prime_a, prime_b)
        node_2 = self.get_node_by_number(node_id_2, prime_a, prime_b)
        return (node_1 == node_2) if eq else (node_1 != node_2)

    def get_clause_by_number(self, clause: cl.Clause, prime_a: int, prime_b: int) -> bt.BoolectorNode:
        """
        Returns the internal node for a numeric clause.

        The clause is the disjunction of its literals; an empty clause yields the
        constant `False`.
        """
        clause_node = self.solver.Const(False)
        for lit in clause:
            clause_node = clause_node | self.get_literal_by_number(lit, prime_a, prime_b)
        return clause_node
