import argparse
import os
import random
import copy
from importlib.machinery import SourceFileLoader
from pathlib import Path
import xxhash
import base64

from metagen import *

class PerturbationTypes(Enum):
    """Kinds of random perturbation that ``make_variation_from_structure`` can apply.

    Members are used as keys of the ``allowedPerturbations`` table, whose values
    look like ``{"probability": p}``.
    """

    VERTEX_POSITION = 0  # move skeleton vertices inside their convex polytope
    THICKNESS = 1        # resample thickness (only applied alongside a lift-procedure swap)
    SMOOTHNESS = 2       # toggle edge chains between Polyline and Curve
    LIFT_PROC = 3        # swap the lifting procedure of a lifted skeleton

class LiftProcs(Enum):
    """Lifting procedures a skeleton can be lifted with; targets of ``swap_lifting_procedure``.

    Each member corresponds to one lifted-skeleton class, noted inline.
    """

    UNIFORM_BEAMS = 0                # UniformBeams
    SPATIALLY_VARYING_BEAMS = 1      # SpatiallyVaryingBeams
    UNIFORM_SHELL_CONJUGATE = 2      # UniformTPMSShellViaConjugation
    UNIFORM_SHELL_DIRECT = 3         # UniformDirectShell
    UNIFORM_SHELL_MIXED_MINIMAL = 4  # UniformTPMSShellViaMixedMinimal
    SPHERES = 5                      # Spheres

# Sorted grid of "nice" values used when sampling thicknesses, thickness-profile
# positions and vertex weights: tenths (0.0 .. 1.0) plus multiples of 1/8
# (0.125 .. 0.875). Built from a set so the value shared by both grids (0.5)
# appears only once -- otherwise random picks are biased toward it and
# random.sample() can return the same position twice. Stored as plain Python
# floats (not np.float64) so they serialize cleanly.
valid_floats = sorted(
    {round(i / 10, 1) for i in range(11)}         # tenths
    | {round(0.125 * i, 3) for i in range(1, 8)}  # multiples of 1/8
)

def swap_smoothness(_structure:Structure) -> None:
    for lskelID in range(len(_structure.tile.liftedSkeletons)):
        lskel = _structure.tile.liftedSkeletons[lskelID]
        bv = lskel.parentCP

        if isinstance(lskel.skel, PointSkeleton):
            continue # no smoothness attributes

        ecs:list[EdgeChain] = lskel.skel.get_edge_chains()
        ecs_out = []
        for ec in ecs:
            # decide whether to swap this specific ec smoothness flag. if not, keep the current one and move on
            if random.uniform(0.0, 1.0) > allowedPerturbations[PerturbationTypes.SMOOTHNESS]["probability"]:
                ecs_out.append(ec)
                continue
            # flip smoothness
            orderedPts = [ConvexPolytope.RelativeVert(bv, refdPt) for refdPt in ec.get_ordered_points_along_chain()]
            if isinstance(ec, Polyline):
                ec_out = Curve(orderedPts)
            else:
                ec_out = Polyline(orderedPts)
            ecs_out.append(ec_out)
        if len(ecs_out) == 0: # nothing to change
            raise Exception("This shouldn't happen; starting ec was non empty but ending ec was empty")

        skel_out = skeleton(ecs_out)

        # overwrite the original lifted skeleton with our new one
        lskelOut = None
        match lskel:
            case UniformBeams():
                lskelOut = UniformBeams(skel_out, lskel.uniformThicknessValue)
            case SpatiallyVaryingBeams():
                lskelOut = SpatiallyVaryingBeams(skel_out, lskel.varyingThicknessProfile)
            case UniformTPMSShellViaConjugation():
                lskelOut = UniformTPMSShellViaConjugation(skel_out, lskel.uniformThicknessValue)
            case UniformDirectShell():
                lskelOut = UniformDirectShell(skel_out, lskel.uniformThicknessValue)
            case UniformTPMSShellViaMixedMinimal():
                lskelOut = UniformTPMSShellViaMixedMinimal(skel_out, lskel.uniformThicknessValue)
            case Spheres():
                raise Exception("Can't swap smoothness for a point skeleton")
            case _:
                raise Exception("Unsupported lift procedure")
        _structure.tile.liftedSkeletons[lskelID] = lskelOut


def swap_lifting_procedure(_structure:Structure) -> None:
    for lskelID in range(len(_structure.tile.liftedSkeletons)):
        # decide whether to change this specific lskel
        if random.uniform(0.0, 1.0) > allowedPerturbations[PerturbationTypes.LIFT_PROC]["probability"]:
            continue

        lskel = _structure.tile.liftedSkeletons[lskelID]
        origWasUniform = True
        valid_swaps = []
        match lskel:
            case UniformBeams():
                valid_swaps = [LiftProcs.SPATIALLY_VARYING_BEAMS]
                if lskel.skel.is_single_connected_component() and lskel.skel.has_connected_component_type([ConnectedComponentType.SIMPLE_CLOSED_LOOP]):
                    valid_swaps.append(LiftProcs.UNIFORM_SHELL_DIRECT)
            case SpatiallyVaryingBeams():
                origWasUniform = False
                valid_swaps = [LiftProcs.UNIFORM_BEAMS]
                if lskel.skel.is_single_connected_component() and lskel.skel.has_connected_component_type([ConnectedComponentType.SIMPLE_CLOSED_LOOP]):
                    valid_swaps.append(LiftProcs.UNIFORM_SHELL_DIRECT)
            case UniformTPMSShellViaConjugation():
                valid_swaps = [LiftProcs.UNIFORM_SHELL_MIXED_MINIMAL, LiftProcs.UNIFORM_SHELL_DIRECT, LiftProcs.UNIFORM_BEAMS, LiftProcs.SPATIALLY_VARYING_BEAMS]
            case UniformDirectShell():
                valid_swaps = [LiftProcs.UNIFORM_BEAMS, LiftProcs.SPATIALLY_VARYING_BEAMS]
            case UniformTPMSShellViaMixedMinimal():
                valid_swaps = [LiftProcs.UNIFORM_SHELL_DIRECT, LiftProcs.UNIFORM_BEAMS, LiftProcs.SPATIALLY_VARYING_BEAMS]
            case Spheres():
                valid_swaps = []
            case _:
                raise Exception("Unsupported lift procedure")
        if len(valid_swaps) == 0:
            continue

        targLiftProc = valid_swaps[random.randint(0, len(valid_swaps)-1)]

        ## a bit of a hack -- this should be separate but this is quicker to implement
        updateThickness = True if PerturbationTypes.THICKNESS and \
                                random.uniform(0.0, 1.0) <= allowedPerturbations[PerturbationTypes.THICKNESS]["probability"] \
                                else False

        if origWasUniform and targLiftProc == LiftProcs.SPATIALLY_VARYING_BEAMS:
            # populate the thickness profile
            if updateThickness:
                valid_thickness_floats = [v for v in valid_floats if v > 0.03 and v <= 0.2]
                val = valid_thickness_floats[random.randint(0, len(valid_thickness_floats)-1)]
            else:
                val = lskel.uniformThicknessValue

            if random.uniform(0.0, 1.0) <= 0.7: # use the pentamode style thickening
                prof = np.array([[0.0, val], [0.5, val*2], [1.0, val]]) 
            else: # randomize this profile a bit more
                valid_prof_floats = [v for v in valid_floats if v > 0.1 and v < 0.9]
                n_inner_pts = random.randint(1,3)
                prof_pts = [0.0] + random.sample(valid_prof_floats, n_inner_pts) + [1.0]
                prof_pts = sorted(prof_pts)
                valid_thickness_floats = [v for v in valid_floats if v > 0.03 and v <= 0.2]
                prof_vals = [valid_thickness_floats[random.randint(0, len(valid_thickness_floats)-1)] for _ in range(len(prof_pts))]
                assert(len(prof_pts) == len(prof_vals))
                prof = [[prof_pts[i], prof_vals[i]] for i in range(len(prof_pts))]
            lskel.varyingThicknessProfile = prof
        elif not origWasUniform and targLiftProc != LiftProcs.SPATIALLY_VARYING_BEAMS:
            if updateThickness:
                valid_thickness_floats = [v for v in valid_floats if v > 0.03 and v <= 0.4]
                val = valid_thickness_floats[random.randint(0, len(valid_thickness_floats)-1)]
            else:
                val = min(0.03, lskel.varyingThicknessProfile[0][1]) # get the thickness used at the starting point of the edge chain
            # populate the uniform value
            lskel.uniformThicknessValue = val

        lskelOut = None
        match targLiftProc:
            case LiftProcs.UNIFORM_BEAMS:
                lskelOut = UniformBeams(lskel.skel, lskel.uniformThicknessValue)
            case LiftProcs.SPATIALLY_VARYING_BEAMS:
                lskelOut = SpatiallyVaryingBeams(lskel.skel, lskel.varyingThicknessProfile)
            case LiftProcs.UNIFORM_SHELL_CONJUGATE:
                lskelOut = UniformTPMSShellViaConjugation(lskel.skel, lskel.uniformThicknessValue)
            case LiftProcs.UNIFORM_SHELL_DIRECT:
                lskelOut = UniformDirectShell(lskel.skel, lskel.uniformThicknessValue)
            case LiftProcs.UNIFORM_SHELL_MIXED_MINIMAL:
                lskelOut = UniformTPMSShellViaMixedMinimal(lskel.skel, lskel.uniformThicknessValue)
            case LiftProcs.SPHERES:
                lskelOut = Spheres(lskel.skel, lskel.uniformThicknessValue)
            case _:
                raise Exception("Unsupported lift procedure")
        if lskelOut == None:
            raise Exception("Not initialized")
        
        # overwrite the original lifted skeleton with our new one
        _structure.tile.liftedSkeletons[lskelID] = lskelOut

def shift_vertex_positions(_structure:Structure) -> None:
    for lskelID in range(len(_structure.tile.liftedSkeletons)):
        skel = _structure.tile.liftedSkeletons[lskelID].skel
        parentCP = skel.parentCP
        match skel:
            case sk.PointSkeleton():
                pass
                # raise NotImplementedError()
            case sk.EdgeSkeleton():
                ec_list = skel.get_edge_chains()
                for ec in ec_list:
                    ordered_vert_list = ec.get_ordered_points_along_chain()
                    # add the points and collect new references
                    for pt in ordered_vert_list:
                        if isinstance(pt, PointOnCPInterior):
                            # pick a new weighted vector of all the corner vertices
                            num_bv_corners = len(pt.weights)
                            newFloats = [0.0]*num_bv_corners
                            idOrder = list(range(num_bv_corners))
                            random.shuffle(idOrder) # shuffle so we aren't biased toward the early elements 
                            maxRemaining = 1.0
                            for idx in range(num_bv_corners-1):
                                cid = idOrder[idx]
                                targ = random.uniform(0.0, maxRemaining)
                                # get the closest valid_floats value that's smaller than the random target
                                vfid = 0
                                while vfid < len(valid_floats)-1 and valid_floats[vfid+1] < targ:
                                    vfid+=1
                                newFloats[cid] = valid_floats[vfid]
                                maxRemaining -= newFloats[cid]
                            newFloats[idOrder[num_bv_corners-1]] = round(maxRemaining, 3) # the last coordinate has to be set as 1-sum(others)

                            pt.weights = newFloats # TODO: this might cause issues if e.g. the point is suddenly reclassified as a non-interior point
                        elif isinstance(pt, PointOnCPCorner):
                            pass # no continuous offset possible
                        elif isinstance(pt, PointOnCPEdge):
                            edgeID = pt.incidentEdges[0].idx
                            newval = valid_floats[random.randint(0, len(valid_floats)-1)]
                            pt.weights = parentCP.get_weights_of_edge_point(edgeID, newval) # TODO: this might cause issues if e.g. the point is suddenly reclassified as a non-edge point
                        elif isinstance(pt, PointOnCPFace):
                            print("Not able to move vertices on face -- to implement")
                            pass
                            # raise NotImplementedError()
                        else:
                            raise Exception("Should never get here")

def make_variation_from_structure(structureIn:Structure, allowedPerturbations:dict) -> Structure:
    allowedPerturbationNames = list(allowedPerturbations.keys())

    # s = copy.deepcopy(structureIn) # TODO: figure out why this fails
    s = structureIn

    # figure out how many substructures there are, and keep a backtrace so we can roll it forward to build up the bools at the end
    boolOpTrace = LifoQueue()
    non_bool_substructures:list[Structure] = []
    q = Queue()
    q.put(structureIn)
    while not q.empty():
        curr_structure = q.get()
        assert isinstance(curr_structure, Structure), "Invalid input to ProcMetaTranslator: ensure that the input is a valid Structure object"
        if not isinstance(curr_structure, CSGBoolean):
            non_bool_substructures.append(curr_structure)
            continue
        # we have a boolean
        q.put(curr_structure.A)
        q.put(curr_structure.B) 
        boolOpTrace.put(curr_structure)

    # update the individual substructures
    subs2subsOut = {}
    for s_sub in non_bool_substructures:
        if PerturbationTypes.VERTEX_POSITION in allowedPerturbationNames \
                and random.uniform(0.0, 1.0) <= allowedPerturbations[PerturbationTypes.VERTEX_POSITION]["probability"]:
            shift_vertex_positions(s_sub)
        if PerturbationTypes.SMOOTHNESS in allowedPerturbationNames \
                and random.uniform(0.0, 1.0) <= allowedPerturbations[PerturbationTypes.SMOOTHNESS]["probability"]:
            swap_smoothness(s_sub)
        if PerturbationTypes.LIFT_PROC in allowedPerturbationNames \
                and random.uniform(0.0, 1.0) <= allowedPerturbations[PerturbationTypes.LIFT_PROC]["probability"]:
            swap_lifting_procedure(s_sub)
        subs2subsOut[s_sub] = s_sub

    # this might not be necessary, maybe it's enough to just update the (sub)structures that the boolean ops are pointing to
    # once all structures have been added independently, add the CSG operations between them using the backtrace contstructed at the beginning
    booleanOpOut = None
    if len(non_bool_substructures) == 1: # no boolean operations involved, just take the output of the final structure
        assert len(subs2subsOut) == 1, "There should be exactly one substructure in this graph"
        booleanOpOut = subs2subsOut[structureIn]
    else:
        while not boolOpTrace.empty():
            currOp:CSGBoolean = boolOpTrace.get()
            in1:Structure = subs2subsOut[currOp.A]
            in2:Structure = subs2subsOut[currOp.B]
            currOpPmnOut = None
            match currOp.op_type:
                case CSGBooleanTypes.UNION:
                    currOpPmnOut = Union(in1, in2)
                case CSGBooleanTypes.INTERSECT:
                    currOpPmnOut = Intersect(in1, in2)
                case CSGBooleanTypes.DIFFERENCE:
                    currOpPmnOut = Subtract(in1, in2)
                case _:
                    print(f"Error: unsupported CSG Boolean type: {currOp.op_type}. Aborting.")
                    return
            assert currOpPmnOut != None
            subs2subsOut[currOp] = currOpPmnOut
            # update opOut in case this is the last one
            booleanOpOut = currOpPmnOut

    return booleanOpOut

def make_variations(inputFilename:str, outPath:str, allowedPerturbations:dict, numVariants:int=1) -> None:
    """Generate ``numVariants`` randomly perturbed variants of a structure spec file.

    The spec at ``inputFilename`` is imported as module ``structureSpec``. Its
    ``make_structure()`` is called once per variant, because structures are
    rebuilt rather than deep-copied. Each variant is translated with
    ``DSLTranslator`` and saved as
    ``<outPath>/<parent-dir-name>__autogen_variants/<input-stem>_variant<ID>.py``.
    ``<ID>`` is the unpadded base32 of the xxh32 digest of the file contents, so
    identical variants map to (and overwrite) the same file.

    Args:
        inputFilename: path to the structure spec file.
        outPath: directory under which the variants folder is created.
        allowedPerturbations: ``{PerturbationTypes: {"probability": float}}``.
        numVariants: number of variant attempts; failed ones are skipped.
    """
    import importlib.util
    import sys

    fname = Path(inputFilename)
    parent_structure_name = fname.parent.stem
    variantFolder = Path(outPath) / f"{parent_structure_name}__autogen_variants"
    variantFolder.mkdir(parents=True, exist_ok=True)

    # Import the spec file. This replaces the deprecated SourceFileLoader.load_module().
    # The explicit loader keeps accepting files whose suffix is not ".py".
    spec = importlib.util.spec_from_file_location(
        "structureSpec", fname, loader=SourceFileLoader("structureSpec", str(fname)))
    loadedModule = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = loadedModule  # load_module() registered the module as well
    spec.loader.exec_module(loadedModule)

    # Generator path recorded in each output file: the part after "/data" when
    # present, else just this file's name (the old code recorded "" in that case).
    _, sep, tail = Path(__file__).as_posix().partition("/data")
    generatorPath = tail if sep else Path(__file__).name

    for _ in range(numVariants):
        # rebuild the structure each time since there is no working deep copy
        structureOut = make_variation_from_structure(loadedModule.make_structure(), allowedPerturbations)
        if structureOut is None:
            continue

        # write the new structure to a py file using the dsl_translator
        fmt = DSLTranslator_Formatter()
        fmt.set_file_info('1.1.0', generatorPath, {})
        fmt.add_related_structure(parent_structure_name, ["parent"])
        translator = DSLTranslator(structureOut, fmt)

        # variant id derived from the to-be-written file contents
        x = xxhash.xxh32()
        x.update(translator.get_file_contents().encode())
        variantID = base64.b32encode(x.digest()).decode(encoding='utf-8').rstrip('=')

        translator.save(variantFolder / f"{fname.stem}_variant{variantID}.py")

def main():
    """Command-line entry point: write random variants of a structure spec file."""
    parser = argparse.ArgumentParser(description="Generate randomly perturbed variants of a structure spec file.")
    parser.add_argument('-i', '--input', type=str, required=True,
                        help="structure spec file defining make_structure()")
    parser.add_argument('-o', '--output', type=str, required=True,
                        help="directory under which the <name>__autogen_variants folder is created")
    # default was 0, which made a run without -n silently produce nothing
    parser.add_argument('-n', '--numvariants', type=int, default=1,
                        help="number of variants to generate (default: 1)")

    args = parser.parse_args()

    # per-type probability of applying each perturbation
    allowedPerturbations = {
        PerturbationTypes.VERTEX_POSITION: {"probability": 0.9},
        PerturbationTypes.THICKNESS:       {"probability": 0.98},
        PerturbationTypes.SMOOTHNESS:      {"probability": 0.7},
        PerturbationTypes.LIFT_PROC:       {"probability": 0.7},
    }

    make_variations(args.input, args.output, allowedPerturbations, args.numvariants)


if __name__ == "__main__":
    main()