from abc import ABC, abstractmethod
from typing import Dict, Literal, Set

# A-Normal Form (ANF) is a way of writing programs where every subexpression is
# a variable or a function call. This is useful for compilers because it makes
# it easier to reason about the program and to perform optimizations.


# the kind of immediate values
ImmKind = Literal["int", "bool", "id"]
# interference graph is a graph where each node is a variable and each edge
# represents a conflict between two variables.
InterfGraph = Dict[str, Set[str]]


class AST(ABC):
    """
    Abstract syntax tree (AST) is a tree representation of the abstract syntactic
    structure of source code written in a programming language.
    """
    @abstractmethod
    def free_vars(self) -> Set[str]:
        """
        Returns the set of free variables in this AST.
        """
        pass

    def interfere(self, live: Set[str], remove: Set[str]) -> InterfGraph:
        """
        Returns the interference graph of this AST, setting all variables in
        `remove` to be removed at the first Let and adding all variables in
        `live` to be live at the first Let.
        """
        return {}


class AExpr(AST):
    pass


class CExpr(AST):
    pass


def merge_graphs(g1: InterfGraph, g2: InterfGraph) -> InterfGraph:
    g1 = g1.copy()
    for node in g2:
        if node in g1:
            g1[node] |= g2[node]
        else:
            g1[node] = g2[node]
    return g1


def add_node(g: InterfGraph, name: str) -> InterfGraph:
    if name in g:
        return g
    else:
        g = g.copy()
        g[name] = set()
        return g


def add_directed_edge(g: InterfGraph, n1: str, n2: str) -> InterfGraph:
    g = g.copy()
    g = add_node(g, n1)
    g = add_node(g, n2)
    neighbors = g[n1]
    neighbors.add(n2)
    return g


def add_edge(g: InterfGraph, n1: str, n2: str) -> InterfGraph:
    g = add_directed_edge(g, n1, n2)
    g = add_directed_edge(g, n2, n1)
    return g


class ImmExpr:
    def __init__(self, value, kind: ImmKind):
        self.value = value
        self.kind = kind

    def free_vars(self) -> Set[str]:
        if self.kind == "id":
            return {self.value}
        else:
            return set()


class CIf(CExpr):
    def __init__(self, cond: ImmExpr, then: AExpr, els: AExpr):
        self.cond = cond
        self.then = then
        self.els = els

    def interfere(self, live: Set[str], remove: Set[str]) -> InterfGraph:
        return merge_graphs(self.then.interfere(live, remove), self.els.interfere(live, remove))

    def free_vars(self):
        return self.cond.free_vars() | self.then.free_vars() | self.els.free_vars()


class CPrim(CExpr):
    def __init__(self, op: Literal["+", "-", "*", "/"], left: ImmExpr, right: ImmExpr):
        self.op = op
        self.left = left
        self.right = right

    def free_vars(self):
        return self.left.free_vars() | self.right.free_vars()


class CApp(CExpr):
    def __init__(self, func: ImmExpr, args: list[ImmExpr]):
        self.func = func
        self.args = args

    def free_vars(self):
        return self.func.free_vars() | set.union(*map(lambda arg: arg.free_vars(), self.args))


class CImmExpr(CExpr):
    def __init__(self, expr: ImmExpr):
        self.expr = expr

    def free_vars(self):
        return self.expr.free_vars()


class CLambda(CExpr):
    def __init__(self, params: list[str], body: AExpr):
        self.params = params
        self.body = body

    def free_vars(self):
        return self.body.free_vars() - set(self.params)


class ALet(AExpr):
    def __init__(self, name, value: CExpr, body: AExpr):
        self.name = name
        self.value = value
        self.body = body

    def free_vars(self):
        return self.value.free_vars() | (self.body.free_vars() - {self.name})

    def interfere(self, live: Set[str], remove: Set[str]) -> InterfGraph:
        fvs = self.free_vars()
        interf = (fvs - remove) | live
        g = add_node(self.value.interfere(live, remove), self.name)
        for fv in interf:
            g = add_edge(g, self.name, fv)
        return merge_graphs(g, self.body.interfere(live | {self.name}, remove))


class ASeq(AExpr):
    def __init__(self, expr1: CExpr, expr2: AExpr):
        self.expr1 = expr1
        self.expr2 = expr2

    def free_vars(self):
        return self.expr1.free_vars() | self.expr2.free_vars()

    def interfere(self, live: Set[str], remove: Set[str]) -> InterfGraph:
        return merge_graphs(self.expr1.interfere(live, remove), self.expr2.interfere(live, remove))


class ACExpr(AExpr):
    def __init__(self, expr: CExpr):
        self.expr = expr

    def free_vars(self):
        return self.expr.free_vars()

    def interfere(self, live: Set[str], remove: Set[str]) -> InterfGraph:
        return self.expr.interfere(live, remove)