import pyboolector as bt
from product_machine import ProductMachine
import clause_utils as cl
from abstract_and_refine import TermsMap, PartitionAssignment, AcexTreeNode, AcexTree, CexSolver
from queue import PriorityQueue, SimpleQueue


class FrameSolver:
    r"""
    Maintains a solver for queries shaped `SAT(? /\ F[i])` (i.e. the frame is not primed).

    Frame clauses are asserted permanently at unroll position (0, 0); every query
    pushes a solver scope, adds its own constraints, and always pops that scope again
    (also when an exception is raised).
    """
    solver: bt.Boolector
    "The SAT solver."
    pm: ProductMachine
    "The product machine used with the solver."

    def __init__(self, btor2_lines_a: list[str], btor2_lines_b: list[str]):
        # create and configure the solver (incremental, so queries can Push/Pop):
        self.solver = bt.Boolector()
        self.solver.Set_opt(bt.BTOR_OPT_MODEL_GEN, True)
        self.solver.Set_opt(bt.BTOR_OPT_INCREMENTAL, True)

        # set up the product machine:
        self.pm = ProductMachine(self.solver, "@A", "@B", btor2_lines_a, btor2_lines_b)
        self.pm.grow(1, 1)  # initially, involved nodes are states + ObEq only, 1 is enough.
        self.pm.build_ob_eq_signal(0, 0)

    def _check_sat(self) -> bool:
        """
        Runs `Sat()` on the current solver state.

        Returns True for SAT and False for UNSAT.
        Raises RuntimeError for any other solver result.
        """
        res = self.solver.Sat()
        if res == self.solver.UNSAT:
            return False
        if res != self.solver.SAT:
            # `Sat()` returns a status code, not a string: format it instead of concatenating.
            raise RuntimeError(f"Unexpected solver result: {res}")
        return True

    def add_clause_to_solver(self, clause: cl.Clause):
        "Adds a clause (as a fixed constraint) to the solver, at unroll position (0, 0)."
        # NOTE: running the clause through `cl.expand_clause` first was tried as an experimental feature.
        self.solver.Assert(self.pm.get_clause_by_number(clause, 0, 0))

    def sat_query_for_initial(self, terms_map: TermsMap) -> PartitionAssignment:
        r"""
        Performs `SAT(I /\ F[i])`.

        Returns a `PartitionAssignment` built from the model values, at position (0, 0),
        of every term in `terms_map`. Returns None if UNSAT.
        Raises RuntimeError if the solver answers neither SAT nor UNSAT.
        """
        self.solver.Push()
        try:
            self.pm.apply_initial_relation()
            if not self._check_sat():
                return None
            # extract the assignments while the model is still available (i.e. before Pop):
            assignments = {}
            for terms in terms_map.m.values():
                for term in terms:
                    assignments[term] = self.pm.get_node_by_number(term, 0, 0).assignment
        finally:
            self.solver.Pop()
        return PartitionAssignment(assignments, terms_map)

    def sat_query_for_region(self, region: PartitionAssignment) -> bool:
        r"""
        Performs `SAT(region /\ F[i])`.

        Returns True if SAT, False if UNSAT.
        Raises RuntimeError if the solver answers neither SAT nor UNSAT.
        """
        self.solver.Push()
        try:
            for literal in region.to_literals(True):
                self.solver.Assert(self.pm.get_literal_by_number(literal, 0, 0))
            return self._check_sat()
        finally:
            self.solver.Pop()


class FramePrimedSolver:
    r"""
    Maintains a solver for queries shaped `SAT(? /\ T /\ F[i]')` (i.e. the frame is primed).

    The primed (A, B) unroll positions are (1, 1) for SYNC and (0, 1) for A_STUTTER.
    (1, 0) for B_STUTTER is added only when not in fast-slow mode.
    Frame clauses are asserted at every primed position. Each query pushes a solver
    scope and always pops it again, also when an exception is raised.
    """
    solver: bt.Boolector
    "The SAT solver."
    pm: ProductMachine
    "The product machine used with the solver."
    fast_slow_mode: bool
    "If `True`, circuit A is assumed to be faster than circuit B."

    def __init__(self, btor2_lines_a: list[str], btor2_lines_b: list[str], fast_slow_mode: bool):
        self.fast_slow_mode = fast_slow_mode

        # create and configure the solver (incremental, so queries can Push/Pop):
        self.solver = bt.Boolector()
        self.solver.Set_opt(bt.BTOR_OPT_MODEL_GEN, True)
        self.solver.Set_opt(bt.BTOR_OPT_INCREMENTAL, True)

        # set up the product machine:
        self.pm = ProductMachine(self.solver, "@A", "@B", btor2_lines_a, btor2_lines_b)
        self.pm.grow(2, 2)  # initially, involved nodes are states + ObEq only, 2 is enough.
        self.pm.build_ob_eq_signal(0, 0)
        # sync (1, 1), fast stutter (0, 1), and slow stutter (1, 0) unless in fast-slow mode:
        for a_prime, b_prime in self._prime_options():
            self.pm.build_ob_eq_signal(a_prime, b_prime)

    def _prime_options(self) -> list[tuple[int, int]]:
        "Returns the primed (A, B) unroll positions, ordered as [SYNC, A_STUTTER, B_STUTTER (if any)]."
        options = [(1, 1), (0, 1)]
        if not self.fast_slow_mode:
            options.append((1, 0))
        return options

    def _assert_clause_at_primes(self, clause: cl.Clause):
        "Asserts `clause` at every primed unroll position."
        for a_prime, b_prime in self._prime_options():
            self.solver.Assert(self.pm.get_clause_by_number(clause, a_prime, b_prime))

    def _check_sat(self) -> bool:
        """
        Runs `Sat()` on the current solver state.

        Returns True for SAT and False for UNSAT.
        Raises RuntimeError for any other solver result.
        """
        res = self.solver.Sat()
        if res == self.solver.UNSAT:
            return False
        if res != self.solver.SAT:
            # `Sat()` returns a status code, not a string: format it instead of concatenating.
            raise RuntimeError(f"Unexpected solver result: {res}")
        return True

    def add_clause_to_solver(self, clause: cl.Clause):
        "Adds a clause (as a fixed constraint) to the frame solver, at every primed position."
        # NOTE: running the clause through `cl.expand_clause` first was tried as an experimental feature.
        self._assert_clause_at_primes(clause)

    def sat_query_for_blocking_phase(self, region: PartitionAssignment, terms_map: TermsMap
    ) -> tuple[list[PartitionAssignment], list[cl.Literal]]:
        r"""
        Performs `SAT(F[i]' /\ ~region' /\ T; region)`.
        - If SAT, the tuple's first is `[SYNC, A_STUTTER, B_STUTTER/None]`; the second is None.
        - If UNSAT, the tuple's first is None; the second is the UNSAT core (`~ObEq` always included).
        Raises RuntimeError if the solver answers neither SAT nor UNSAT.
        """
        self.solver.Push()
        try:
            # `~region` is asserted at every primed position:
            clause = cl.make_clause(region.to_literals(False))
            self._assert_clause_at_primes(clause)

            # `region` itself (negation of the expanded clause) is *assumed* at (0, 0), so that a
            # failed-assumption core can be extracted. (`region.to_literals(True)` is the unexpanded variant.)
            literals = cl.neg_of_clause(cl.expand_clause(clause))
            lit_nodes = [self.pm.get_literal_by_number(literal, 0, 0) for literal in literals]
            self.solver.Assume(*lit_nodes)

            if not self._check_sat():
                failed = self.solver.Failed(*lit_nodes) if lit_nodes else []
                # NOTE(review): pyboolector may return a bare value instead of a list for a single
                # assumption; normalise so the zip below works either way.
                if not isinstance(failed, (list, tuple)):
                    failed = [failed]
                core = [literal for literal, is_failed in zip(literals, failed)
                        if is_failed or literal == cl.LIT_OB_EQ]
                return None, core

            # read the successor regions from the model while it is still available (before Pop):
            prime_regions = [None, None, None]
            for i, (a_prime, b_prime) in enumerate(self._prime_options()):
                assignments = {}
                for terms in terms_map.m.values():
                    for term in terms:
                        assignments[term] = self.pm.get_node_by_number(term, a_prime, b_prime).assignment
                prime_regions[i] = PartitionAssignment(assignments, terms_map)
            return prime_regions, None
        finally:
            self.solver.Pop()

    def sat_query_for_propagation_phase(self, clause: cl.Clause) -> bool:
        r"""
        Performs `SAT(F[i]' /\ T /\ ~clause)`, with `~clause` asserted at position (0, 0).

        Returns True if SAT, False if UNSAT.
        Raises RuntimeError if the solver answers neither SAT nor UNSAT.
        """
        self.solver.Push()
        try:
            literals = cl.neg_of_clause(clause)
            lit_nodes = [self.pm.get_literal_by_number(literal, 0, 0) for literal in literals]
            self.solver.Assert(*lit_nodes)
            return self._check_sat()
        finally:
            self.solver.Pop()


class Frame:
    r"""
    One IC3 frame: a set of clauses, plus the two solvers in which those clauses live.
    F[0] is the innermost frame and always stands for the negation of ObEq.
    Every other frame starts out as top and is refined step by step.
    """
    delta_clauses: set[cl.Clause]
    "Clauses held by this frame `F[i]` but not by its outer neighbour `F[i+1]`."
    unprimed_solver: FrameSolver
    r"Solver answering queries of the form `SAT(? /\ F[i])`."
    primed_solver: FramePrimedSolver
    r"Solver answering queries of the form `SAT(? /\ T /\ F[i]')`."

    def __init__(self, btor2_lines_a: list[str], btor2_lines_b: list[str], fast_slow_mode: bool):
        lines_a, lines_b = btor2_lines_a, btor2_lines_b
        self.unprimed_solver = FrameSolver(lines_a, lines_b)
        self.primed_solver = FramePrimedSolver(lines_a, lines_b, fast_slow_mode)
        # a fresh frame carries no delta clauses of its own yet
        self.delta_clauses = set()


class SubsumptionSolver:
    """
    Owns a solver used to simplify clauses by detecting implication (subsumption) between them.
    """
    solver: bt.Boolector
    "The SAT solver."
    pm: ProductMachine
    "The product machine used with the solver."

    def __init__(self, btor2_lines_a: list[str], btor2_lines_b: list[str]):
        # incremental solving so that `implies` can push and pop its own scope
        solver = bt.Boolector()
        self.solver = solver
        solver.Set_opt(bt.BTOR_OPT_INCREMENTAL, True)

        # product machine: states + ObEq are the only involved nodes at first, so depth 1 suffices
        machine = ProductMachine(solver, "@A", "@B", btor2_lines_a, btor2_lines_b)
        self.pm = machine
        machine.grow(1, 1)
        machine.build_ob_eq_signal(0, 0)

    def implies(self, clause_1: cl.Clause, clause_2: cl.Clause) -> bool:
        r"""
        Tells whether `clause_1` entails `clause_2`, i.e. whether `clause_1 /\ ~clause_2` is UNSAT.
        Both clauses are evaluated at unroll position (0, 0).
        """
        solver, machine = self.solver, self.pm
        solver.Push()
        solver.Assert(machine.get_clause_by_number(clause_1, 0, 0))
        for negated_literal in cl.neg_of_clause(clause_2):
            solver.Assert(machine.get_literal_by_number(negated_literal, 0, 0))
        outcome = solver.Sat()
        solver.Pop()
        return outcome == solver.UNSAT


class FinalizationSolver:
    """
    Owns a solver used when printing the invariant: it keeps only the clauses
    that are not already implied by the clauses accepted so far.
    """
    solver: bt.Boolector
    "The SAT solver."
    pm: ProductMachine
    "The product machine used with the solver."

    def __init__(self, btor2_lines_a: list[str], btor2_lines_b: list[str], unroll_a: int, unroll_b: int):
        # incremental solving, so that assumptions can be used per query
        solver = bt.Boolector()
        self.solver = solver
        solver.Set_opt(bt.BTOR_OPT_INCREMENTAL, True)

        # product machine unrolled to the requested depth for each circuit
        machine = ProductMachine(solver, "@A", "@B", btor2_lines_a, btor2_lines_b)
        self.pm = machine
        machine.grow(unroll_a, unroll_b)
        machine.build_ob_eq_signal(0, 0)

    def add_effective_clause(self, clause: cl.Clause) -> bool:
        """
        Asserts `clause` unless the clauses asserted so far already imply it.
        Returns True when the clause was asserted, False when it was redundant.
        """
        solver = self.solver
        # assume ~clause; UNSAT means the accumulated clauses already entail `clause`
        for negated_literal in cl.neg_of_clause(clause):
            solver.Assume(self.pm.get_literal_by_number(negated_literal, 0, 0))
        if solver.Sat() == solver.UNSAT:
            return False
        solver.Assert(self.pm.get_clause_by_number(clause, 0, 0))
        return True


class Checker:
    """
    Checks whether two circuits are stuttering equivalent.
    - When equivalent, prints the invariant.
    - When non-equivalent, prints the cex tree.
    """
    btor2_lines_a: list[str]
    "The lines of the first circuit's BTOR2 file."
    btor2_lines_b: list[str]
    "The lines of the second circuit's BTOR2 file."
    fast_slow_mode: bool
    "If `True`, circuit A is assumed to be faster than circuit B."
    frames: list[Frame]
    "A list of frames as in IC3."
    subsumption_solver: SubsumptionSolver
    "The solver used for clause simplification."
    terms_map: TermsMap
    "A map from bit width to a list of involved terms of that width (for ET abstraction)."
    acex_tree: AcexTree
    "The current abstract cex tree (None until a search has started)."
    printer: cl.Printer
    "The printer used to print the invariant or cex tree."

    def __init__(self, btor2_lines_a: list[str], btor2_lines_b: list[str], fast_slow_mode: bool):
        self.btor2_lines_a = btor2_lines_a
        self.btor2_lines_b = btor2_lines_b
        self.fast_slow_mode = fast_slow_mode
        self.frames = []
        self.subsumption_solver = SubsumptionSolver(btor2_lines_a, btor2_lines_b)
        self.terms_map = TermsMap(btor2_lines_a, btor2_lines_b)
        self.acex_tree = None
        self.printer = cl.Printer(btor2_lines_a, btor2_lines_b)

    def add_frame(self):
        """
        Appends a new outermost frame.
        Its unprimed solver is grown to the current unroll height, and its primed solver one step further.
        """
        new_frame = Frame(self.btor2_lines_a, self.btor2_lines_b, self.fast_slow_mode)
        a_height, b_height = self.terms_map.unroll_height
        new_frame.unprimed_solver.pm.grow(a_height, b_height)
        new_frame.primed_solver.pm.grow(a_height + 1, b_height + 1)  # primed queries need one extra step
        self.frames.append(new_frame)
        print(f"Current frame count: {len(self.frames)}")

    def grow_all_frames(self):
        "Grows all the frames (plus the subsumption solver) to the current unroll height."
        a_height, b_height = self.terms_map.unroll_height
        for frame in self.frames:
            frame.unprimed_solver.pm.grow(a_height, b_height)
            frame.primed_solver.pm.grow(a_height + 1, b_height + 1)
        self.subsumption_solver.pm.grow(a_height, b_height)

    def k(self) -> int:
        "Returns the maximal frame index."
        return len(self.frames) - 1

    def add_clause(self, index: int, clause: cl.Clause):
        """
        Adds `clause` to F[index] and all its inner frames F[index-1], ..., F[0].
        The clause goes into each frame's solvers and into the `delta_clauses` bookkeeping.
        In F[index], every delta clause `c` with `cl.is_strict_subclause(clause, c)` is dropped.
        """
        for i in range(index, -1, -1):
            frame = self.frames[i]
            frame.unprimed_solver.add_clause_to_solver(clause)
            frame.primed_solver.add_clause_to_solver(clause)
            if i == index:
                # Only prune subsumed clauses in the target frame; otherwise overhead is too high.
                # (`self.subsumption_solver.implies` is a semantic alternative to this syntactic check.)
                frame.delta_clauses = {c for c in frame.delta_clauses if not cl.is_strict_subclause(clause, c)}
                frame.delta_clauses.add(clause)
            else:
                # The clause now belongs to an outer frame, so it is no longer a delta of this one.
                frame.delta_clauses.discard(clause)

    def propagate_clause(self, index: int, clause: cl.Clause):
        """
        Propagates a clause from F[index] to F[index+1].
        Raises KeyError if `clause` is not a delta clause of F[index].
        """
        inner, outer = self.frames[index], self.frames[index + 1]
        inner.delta_clauses.remove(clause)
        # (`self.subsumption_solver.implies` is a semantic alternative to this syntactic check.)
        outer.delta_clauses = {c for c in outer.delta_clauses if not cl.is_strict_subclause(clause, c)}
        outer.delta_clauses.add(clause)
        outer.unprimed_solver.add_clause_to_solver(clause)
        outer.primed_solver.add_clause_to_solver(clause)

    def in_frame_i(self, i: int, region: PartitionAssignment) -> bool:
        r"Checks whether `region` meets F[i], i.e. whether `SAT(region /\ F[i])` holds."
        return self.frames[i].unprimed_solver.sat_query_for_region(region)

    def get_frame_index(self, max_index: int, region: PartitionAssignment) -> int:
        """
        Returns the minimal frame index `i` (at most `max_index`) such that `region` is inside `F[i]`.

        Regions already in the acex tree reuse their stored index. Bad regions map to 0.
        Otherwise the search walks inward starting from `max_index - 1`: `region` is taken to be
        inside F[max_index] already, since callers obtain it from F[max_index]'.
        The walk never goes below 0. A negative index would silently address the outermost frame.
        """
        if region in self.acex_tree.value_to_node:  # avoid re-calculating
            return self.acex_tree.value_to_node[region].frame_index
        if region.is_bad:
            return 0
        i = max_index - 1
        while i >= 0 and self.in_frame_i(i, region):
            i -= 1
        return i + 1

    def _move_outward(self, tree_node: AcexTreeNode, queue: PriorityQueue,
                      invalidation_queue: SimpleQueue):
        """
        Moves `tree_node` one frame outward (frame_index + 1) and re-inserts it into `queue`.
        Any parent whose frame index is now <= the node's is scheduled for invalidation and re-inserted.
        """
        tree_node.frame_index += 1
        queue.put((-tree_node.frame_index, tree_node))
        for parent in tree_node.parents:
            if parent.frame_index <= tree_node.frame_index:
                invalidation_queue.put(parent)
                queue.put((-parent.frame_index, parent))

    def _settle_child(self, tree_node: AcexTreeNode, child: PartitionAssignment, add_child,
                      queue: PriorityQueue):
        """
        Attaches successor region `child` to `tree_node` using `add_child`
        (one of the `AcexTree.add_child_*` methods).
        The child's frame index is the minimal one not exceeding `tree_node.frame_index - 1`.
        The child is queued when `add_child` returns a truthy value and the child is not bad.
        """
        child_frame_index = self.get_frame_index(tree_node.frame_index - 1, child)
        if add_child(tree_node, child, child_frame_index) and not child.is_bad:
            queue.put((-child_frame_index, self.acex_tree.value_to_node[child]))

    def _invalidate(self, invalidation_queue: SimpleQueue):
        "Detaches the children of every queued parent, removing (and recursing into) children left orphaned."
        while not invalidation_queue.empty():
            parent = invalidation_queue.get()
            for child in self.acex_tree.give_up_children(parent):
                if child.is_orphan():
                    self.acex_tree.remove_orphan(child)
                    invalidation_queue.put(child)  # iteratively process grand-children

    def can_block(self, root_assignment: PartitionAssignment) -> bool:
        """
        Tries to find an acex from root, or block it out.
        Returns True if the root was blocked, or False if an abstract cex tree (left in `self.acex_tree`) was found.
        """
        self.acex_tree = AcexTree(AcexTreeNode(root_assignment, self.k()))
        # NOTE(review): entries with equal priority are ordered by comparing the AcexTreeNode objects
        # themselves, so this relies on AcexTreeNode supporting `<` — confirm in abstract_and_refine.
        queue: PriorityQueue[tuple[int, AcexTreeNode]] = PriorityQueue()
        queue.put((-self.k(), self.acex_tree.root))  # priority == -frame_index

        while not queue.empty():
            print(",", end="", flush=True)
            _, tree_node = queue.get()
            if (tree_node is not self.acex_tree.root) and tree_node.is_orphan():
                continue  # skip orphan nodes (p.s. priority queues cannot remove items like a set)

            invalidation_queue: SimpleQueue[AcexTreeNode] = SimpleQueue()
            if not self.in_frame_i(tree_node.frame_index, tree_node.value):
                # outdated position due to shrinking frames
                self._move_outward(tree_node, queue, invalidation_queue)
            else:  # correct position
                regions, core = self.frames[tree_node.frame_index - 1].primed_solver.sat_query_for_blocking_phase(
                    tree_node.value, self.terms_map)  # TODO: core seems weak... More aggressive generalization?
                if regions is None:  # UNSAT case: refine frame / push tree node
                    self.add_clause(tree_node.frame_index, cl.neg_of_cube(core))
                    if tree_node is self.acex_tree.root:
                        return True  # root is blocked
                    self._move_outward(tree_node, queue, invalidation_queue)
                else:  # SAT case: put the successor nodes into the queue
                    sync, a_stutter, b_stutter = regions
                    self._settle_child(tree_node, sync, self.acex_tree.add_child_sync, queue)
                    self._settle_child(tree_node, a_stutter, self.acex_tree.add_child_a_stutter, queue)
                    if b_stutter is not None:  # None in fast-slow mode
                        self._settle_child(tree_node, b_stutter, self.acex_tree.add_child_b_stutter, queue)
            self._invalidate(invalidation_queue)

        # Now that queue is empty, acex is found:
        return False

    def check(self) -> bool:
        """
        The Main Procedure.
        Returns True after printing an inductive invariant (equivalent).
        Returns False after printing a concrete counterexample tree (non-equivalent).
        TODO: dump important info to some output path.
        """
        print("Current Terms:", self.terms_map.print(self.printer))
        # Initial checking:
        self.add_frame()  # the initial frame F[0]
        self.add_clause(0, cl.make_clause([cl.LIT_NEG_OB_EQ]))  # F[0] = {~ObEq}
        print(";", end="", flush=True)
        assignment = self.frames[0].unprimed_solver.sat_query_for_initial(self.terms_map)
        if assignment is not None:  # 0-step counterexample
            print("0-step abstract counterexample found.")
            self.acex_tree = AcexTree(AcexTreeNode(assignment, 0))
            cex_solver = CexSolver(self.btor2_lines_a, self.btor2_lines_b, self.fast_slow_mode,
                self.acex_tree, self.terms_map)
            # Deliberately not an `assert` statement: under `python -O` the call itself would be stripped.
            if not cex_solver.build_cex_tree():
                raise AssertionError("Failed to build the 0-step concrete counterexample tree.")
            print("0-step concrete counterexample found:")
            print("\n".join(cex_solver.cex_tree.print_tree(self.printer, self.fast_slow_mode)))
            return False

        # Main loop:
        self.add_frame()  # the first refinement frame F[1]
        while True:
            ## Blocking Phase:
            while True:
                print(";", end="", flush=True)
                assignment = self.frames[self.k()].unprimed_solver.sat_query_for_initial(self.terms_map)
                if assignment is None:
                    break
                if not self.can_block(assignment):  # counterexample found
                    print(f"\n{self.k()}-step abstract counterexample found.")
                    cex_solver = CexSolver(self.btor2_lines_a, self.btor2_lines_b, self.fast_slow_mode,
                        self.acex_tree, self.terms_map)
                    if cex_solver.build_cex_tree():
                        print(f"{self.k()}-step concrete counterexample found:")
                        print("\n".join(cex_solver.cex_tree.print_tree(self.printer, self.fast_slow_mode)))
                        return False
                    print("Counterexample is spurious, refining with more terms...")
                    print("Current Terms:", self.terms_map.print(self.printer))
                    self.grow_all_frames()

            ## Propagation Phase:
            self.add_frame()
            for i in range(0, self.k()):
                clauses = set(self.frames[i].delta_clauses)  # copy before mutate
                for clause in clauses:
                    print(".", end="", flush=True)
                    if not self.frames[i].primed_solver.sat_query_for_propagation_phase(clause):
                        self.propagate_clause(i, clause)
                if len(self.frames[i].delta_clauses) == 0:  # F[i] == F[i+1]: found inductive invariant!
                    print("\nInductive invariant found:")
                    net_clauses = []
                    unroll_a, unroll_b = self.terms_map.unroll_height
                    finalizer = FinalizationSolver(self.btor2_lines_a, self.btor2_lines_b, unroll_a, unroll_b)
                    for j in range(i + 1, self.k() + 1):
                        for clause in self.frames[j].delta_clauses:
                            if finalizer.add_effective_clause(clause):  # skip already-implied clauses
                                net_clauses.append(clause)
                    print("\n".join(self.printer.invariant(net_clauses)))
                    return True
            ## Continue for next iteration...
