from .tile import *
from .pattern import *
from .lifting import *
from .structure import *
from . import skeleton as sk
from . import convex_polytope as cp
from pathlib import Path
from .dsl_translator_formatter import DSLTranslator_Formatter
import math
from queue import LifoQueue, Queue
from .codegen_utils import CodeBlock
from pyaml import yaml

class DSLTranslator:
    def __init__(self, finalStructure:Structure, _formatter:DSLTranslator_Formatter) -> None:
        self.formatter = _formatter
        self.program:CodeBlock = None

        # figure out how many substructures there are, and keep a backtrace so we can roll it forward to build up the bools at the end
        boolOpTrace = LifoQueue()
        non_bool_substructures:list[Structure] = []
        q = Queue()
        q.put(finalStructure)
        while not q.empty():
            curr_structure = q.get()
            assert isinstance(curr_structure, Structure), "Invalid input to ProcMetaTranslator: ensure that the input is a valid Structure object"
            if not isinstance(curr_structure, CSGBoolean):
                non_bool_substructures.append(curr_structure)
                continue
            # we have a boolean
            q.put(curr_structure.A)
            q.put(curr_structure.B) 
            boolOpTrace.put(curr_structure)

        # for each substructure, create the appropriate nodes of the graph
        # TODO: for now, have a separate set of operations (vertices, edge chains etc) for each one. Potentially figure out how to reuse some things later.
        full_content = []
        
        substructure2structureOutName = {}
        structureID = 0
        for structure in non_bool_substructures:
            structure_suffix = f"_s{structureID}"

            bv = structure.tile.bv_template

            substructure_content:list[str | CodeBlock] = []
            lifted_skel_names = []
            lskel_id = 0
            for liftedSkel in structure.tile.liftedSkeletons:
                lskel_suffix = f"_lsk{lskel_id}"

                CPRefdPt2PyName:dict[cp.PointReferencedToCP, str] = {}
                SkelEC2PyName:dict[sk.EdgeChain, str] = {}
                vertex_lines = []
                ec_lines = []
                skel_line = []
                skel_out_name = "skel" + structure_suffix + lskel_suffix
                match liftedSkel.skel:
                    case sk.PointSkeleton():
                        for pt in liftedSkel.skel.get_points():
                            if pt not in CPRefdPt2PyName:
                                    (alias, weights) = bv.get_aliased_spec_from_RelVert(pt)
                                    vline = f"vertex({alias})" if len(weights) == 0 else f"vertex({alias}, {weights})"
                                    vname = "v" + str(len(CPRefdPt2PyName)) + structure_suffix + lskel_suffix
                                    vertex_lines.append(f"{vname} = {vline}")
                                    CPRefdPt2PyName[pt] = vname
                        
                        # create the skeleton
                        named_p_list = [CPRefdPt2PyName[p] for p in liftedSkel.skel.get_points()]
                        named_p_list_str = "[" + ", ".join(named_p_list) + "]"
                        skel_line = f"{skel_out_name} = skeleton({named_p_list_str})"
                    case sk.EdgeSkeleton():
                        ec_list = liftedSkel.skel.get_edge_chains()
                        for ec in ec_list:
                            ordered_vert_list = ec.get_ordered_points_along_chain()
                            # add the points and collect new references
                            for pt in ordered_vert_list:
                                if pt not in CPRefdPt2PyName:
                                    (alias, weights) = bv.get_aliased_spec_from_RelVert(pt)
                                    vline = f"vertex({alias})" if len(weights) == 0 else f"vertex({alias}, {weights})"
                                    vname = "v" + str(len(CPRefdPt2PyName)) + structure_suffix + lskel_suffix
                                    vertex_lines.append(f"{vname} = {vline}")
                                    CPRefdPt2PyName[pt] = vname

                            # add the edge chain creation line (using the vert names)
                            named_vert_list = [CPRefdPt2PyName[p_orig] for p_orig in ordered_vert_list]
                            named_vert_list_str = "[" + ", ".join(named_vert_list) + "]"
                            match ec:
                                case Polyline():
                                    ec_name = f"p{len(SkelEC2PyName)}" + structure_suffix + lskel_suffix
                                    ec_lines.append(f"{ec_name} = Polyline({named_vert_list_str})")
                                    SkelEC2PyName[ec] = ec_name
                                case Curve():
                                    ec_name = f"c{len(SkelEC2PyName)}" + structure_suffix + lskel_suffix
                                    ec_lines.append(f"{ec_name} = Curve({named_vert_list_str})")
                                    SkelEC2PyName[ec] = ec_name
                                case _:
                                    raise Exception("Unsupported edgechain type")
                        # create the skeleton
                        named_ec_list = [SkelEC2PyName[orig_ec] for orig_ec in ec_list]
                        named_ec_list_str = "[" + ", ".join(named_ec_list) + "]"
                        skel_line = f"{skel_out_name} = skeleton({named_ec_list_str})"
                    case _:
                        raise Exception("Unsupported skeleton type")
                    
                    
                # add the vertex and edge chain blocks to the function body
                substructure_content.extend(vertex_lines)
                substructure_content.extend([""])
                substructure_content.extend(ec_lines)
                substructure_content.extend([""])
                substructure_content.extend([skel_line])

                # apply the lifting procedure
                liftLine:str = None
                lift_out_name = "liftedSkel" + structure_suffix + lskel_suffix
                match liftedSkel:
                    case UniformBeams():
                        liftLine = f"{lift_out_name} = UniformBeams({skel_out_name}, {liftedSkel.uniformThicknessValue})"
                    case SpatiallyVaryingBeams():
                        liftLine = f"{lift_out_name} = SpatiallyVaryingBeams({skel_out_name}, {liftedSkel.varyingThicknessProfile.tolist()})"
                    case UniformTPMSShellViaConjugation():
                        liftLine = f"{lift_out_name} = UniformTPMSShellViaConjugation({skel_out_name}, {liftedSkel.uniformThicknessValue})"
                    case UniformDirectShell():
                        liftLine = f"{lift_out_name} = UniformDirectShell({skel_out_name}, {liftedSkel.uniformThicknessValue})"
                    case UniformTPMSShellViaMixedMinimal():
                        liftLine = f"{lift_out_name} = UniformTPMSShellViaMixedMinimal({skel_out_name}, {liftedSkel.uniformThicknessValue})"
                    case Spheres():
                        liftLine = f"{lift_out_name} = Spheres({skel_out_name}, {liftedSkel.uniformThicknessValue})"
                    case _:
                        raise Exception("Unsupported lift type")

                substructure_content.extend([liftLine])
                substructure_content.extend([""])

                lifted_skel_names.append(lift_out_name)
                lskel_id += 1

            # Group together into the tile
            embedding_name = "embedding" + structure_suffix
            embedding_line = f"{embedding_name} = {bv.infer_embed_call_from_corners(structure.tile.bv_corner_positions)}"
            tile_out_name = "tile" + structure_suffix
            lifted_skel_names_str = "[" + ", ".join(lifted_skel_names) + "]"
            tile_line = f"{tile_out_name} = Tile({lifted_skel_names_str}, {embedding_name})"
            substructure_content.extend([embedding_line, tile_line])

            # get the pattern
            pat_name = "pat" + structure_suffix
            match structure.pat:
                case TetFullMirror():
                    pat_line = f"{pat_name} = TetFullMirror()"
                case TriPrismFullMirror():
                    pat_line = f"{pat_name} = TriPrismFullMirror()"
                case CuboidFullMirror():
                    pat_line = f"{pat_name} = CuboidFullMirror()"
                case Identity():
                    pat_line = f"{pat_name} = Identity()"
                case Custom():
                    opQ = structure.pat.ops.opQueue
                    opStr = None
                    while not opQ.empty():
                        op:PatternOp = opQ.get()
                        if isinstance(op, NoOp):
                            break
                        opStr = op.get_op_call_string(opStr)
                    pat_line_start = f"{pat_name} = Custom("
                    # fix the indentation on the nested operations
                    offset = " " * len(pat_line_start)
                    opStrLines = opStr.split("\n")
                    for lid in range(1, len(opStrLines)): # offset all but the first line (which is already offset by the pat_line_start text)
                        opStrLines[lid] = offset + opStrLines[lid]
                    alignedOpLines = "\n".join(opStrLines)
                    # put the full line together
                    pat_line = f"{pat_line_start}{alignedOpLines})"
                case _:
                    raise Exception("unsupported pattern")
            substructure_content.extend([pat_line])
            substructure_content.extend([""])

            structure_name = "obj" + str(structureID) + structure_suffix
            structure_line = structure_name + f" = Structure({tile_out_name}, {pat_name})"
            substructure_content.extend([structure_line])
            substructure2structureOutName[structure] = structure_name

            full_content.extend(substructure_content)
            structureID += 1

        # once all structures have been added independently, 
        # add the CSG operations between them using the backtrace constructed at the beginning
        booleanNameOut = None
        if len(non_bool_substructures) == 1: # no boolean operations involved, just take the output of the final structure
            assert len(substructure2structureOutName) == 1, "There should be exactly one substructure in this graph"
            booleanNameOut = substructure2structureOutName[finalStructure]
        else:
            bool_id = 0
            while not boolOpTrace.empty():
                currOp:CSGBoolean = boolOpTrace.get()
                in1:Structure = substructure2structureOutName[currOp.A]
                in2:Structure = substructure2structureOutName[currOp.B]
                currOpNameOut = None
                match currOp.op_type:
                    case CSGBooleanTypes.UNION:
                        bool_name = "bool" + str(bool_id)
                        bool_line = f"{bool_name} = Union({in1}, {in2})"
                        full_content.extend([bool_line])
                    case CSGBooleanTypes.INTERSECT:
                        bool_name = "bool" + str(bool_id)
                        bool_line = f"{bool_name} = Intersect({in1}, {in2})"
                        full_content.extend([bool_line])
                    case CSGBooleanTypes.DIFFERENCE:
                        bool_name = "bool" + str(bool_id)
                        bool_line = f"{bool_name} = Subtract({in1}, {in2})"
                        full_content.extend([bool_line])
                    case _:
                        print(f"Error: unsupported CSG Boolean type: {currOp.op_type}. Aborting.")
                        return
                currOpNameOut = bool_name
                assert currOpNameOut != None
                substructure2structureOutName[currOp] = currOpNameOut
                # update opOut in case this is the last one
                booleanNameOut = currOpNameOut


        # write out the final structure
        full_content.extend([f"return {booleanNameOut}"])
        
        program_params = _formatter.get_param_signature_string()
        self.program:CodeBlock = CodeBlock(f"def make_structure({program_params}) -> Structure", full_content)


    def boilerplate_filehead(self):
        return f"""
'''
{yaml.safe_dump(self.formatter.get_head_comment(), indent=4, sort_keys=False)}
'''

from metagen import *\n\n
"""
    
    def boilerplate_filetail(self):
        return """
\n\n\n\n
# --- END: structure description ---
# --- BEGIN: scratch space ---

if __name__ == "__main__":
    import os
    outputFolder = Path("data/generated/individual/")

    obj = make_structure()
    pmt = ProcMetaTranslator(obj)
    structure_name = os.path.splitext(os.path.basename(__file__))[0]
    pmt.save(outputFolder / f"{structure_name}.json")
"""

    def get_file_contents(self) -> str:
        s = self.boilerplate_filehead()
        s += str(self.program)
        s += self.boilerplate_filetail()
        return s

    def save(self, filename:Path) -> None:
        fh = open(filename, 'w')
        fh.write(self.get_file_contents())
        fh.close()
