import cc3d
import numpy as np
from .util import *
from functools import cache

def pack_voxels(voxels, res=100) -> bytes:
    """Pack a voxel occupancy array into a compact bit string.

    Every element of *voxels* is reduced to a single bit (nonzero -> 1),
    the array is flattened in C order and packed 8 bits per byte, most
    significant bit first. The final byte is zero-padded when the element
    count is not a multiple of 8.

    Args:
        voxels: NumPy array of voxel values (any shape).
        res: Unused here; accepted for signature symmetry with
            ``unpack_voxels``.

    Returns:
        The packed bits as ``bytes``.
    """
    occupancy = voxels.astype(bool)
    bit_array = np.packbits(occupancy)
    return bit_array.tobytes()

def unpack_voxels(voxel_bytes: bytes, res=100) -> np.ndarray:
    """Decode packed occupancy bits into a cubic ``uint8`` array.

    This is the inverse of ``np.packbits`` applied to a flattened
    ``(res-1, res-1, res-1)`` grid. ``np.packbits`` zero-pads the final byte
    when the voxel count is not a multiple of 8 (e.g. 99**3 % 8 == 3 for the
    default ``res``), so those padding bits are dropped before reshaping.

    Args:
        voxel_bytes: Packed bits (``bytes``-like, or an iterable of ints in
            0..255).
        res: Grid resolution; the decoded cube has side length ``res - 1``.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` and shape ``(res-1, res-1, res-1)``
        holding 0/1 values.

    Raises:
        ValueError: if the number of bytes does not match a
            ``(res-1)**3``-voxel grid.
    """
    side = res - 1
    n_voxels = side ** 3
    # bytes() keeps accepting iterables of ints, as the old list() path did.
    packed = np.frombuffer(bytes(voxel_bytes), dtype=np.uint8)
    expected_bytes = -(-n_voxels // 8)  # ceil(n_voxels / 8)
    if packed.size != expected_bytes:
        raise ValueError(
            f"expected {expected_bytes} bytes for a {side}^3 voxel grid, "
            f"got {packed.size}"
        )
    bits = np.unpackbits(packed)
    # Discard the zero padding np.packbits appended to the final byte.
    return bits[:n_voxels].reshape((side, side, side))

def load_voxels(f):
    """Parse a text voxel file into a cubic integer array.

    Layout as parsed here:
      * the first line is a header whose text after the last ``':'`` is the
        grid side length ``dim``;
      * the second line is skipped;
      * every remaining non-blank line holds one integer voxel value, in the
        order ``np.reshape`` expects (C order).

    Args:
        f: File-like object opened in text or binary mode; binary lines are
            decoded with the default (UTF-8) codec.

    Returns:
        ``np.ndarray`` of integers with shape ``(dim, dim, dim)``.

    Raises:
        IndexError: if the file is empty.
        ValueError: if the header or a value is not an integer, or the number
            of values is not ``dim**3``.
    """
    lines = [
        (raw.decode() if isinstance(raw, bytes) else raw).strip()
        for raw in f.readlines()
    ]
    voxel_dim = int(lines[0].split(':')[-1].strip())
    # lines[1] is ignored -- presumably a second header line; TODO confirm
    # against whatever writes this format.
    # Blank lines (e.g. an extra empty line at EOF) are skipped instead of
    # crashing on int('').
    values = [int(v) for v in lines[2:] if v]
    return np.array(values).reshape((voxel_dim, voxel_dim, voxel_dim))

def validate_simulation(sim_data_location):
    """Return True if the simulation results at *sim_data_location* are usable.

    The file is opened with ``open_file`` and parsed as JSON. The result is
    rejected when:
      * ``sim_E_VRH`` is missing/null or greater than 1 (used as a proxy for
        numerical issues), or
      * ``validate_material_properties`` returns any status other than VALID
        or S_NON_ORTHOTROPIC, or raises any ``Exception``.

    Exceptions from opening/parsing the file, or from comparing a non-numeric
    ``sim_E_VRH`` with 1, are not caught.
    """
    with open_file(sim_data_location) as handle:
        sim_data = json.load(handle)

    # Proxy for numerical issues: a missing modulus, or one above 1, is rejected.
    # `not (x > 1)` (rather than `x <= 1`) deliberately lets NaN through.
    modulus = sim_data['sim_E_VRH'] if 'sim_E_VRH' in sim_data else None
    modulus_ok = modulus is not None and not (modulus > 1)

    # S_NON_ORTHOTROPIC is intentionally absent: such results are kept.
    rejected_statuses = (
        PropertyValidationStatus.E_EXCEEDS_1,
        PropertyValidationStatus.C_ASYMMETRIC,
        PropertyValidationStatus.S_ASYMMETRIC,
        PropertyValidationStatus.EGK_MODULUS_IS_ZERO,
        PropertyValidationStatus.V_IS_ZERO,
    )
    try:
        status = validate_material_properties(sim_data)
    except Exception:
        # Any failure while checking (e.g. a missing key) marks the run invalid.
        return False

    return modulus_ok and status not in rejected_statuses

def check_symmetric(a: np.ndarray, rtol=1e-05, atol=1e-08):
    """Return True if *a* equals its transpose within the given tolerances.

    Elementwise test is ``|a - a.T| <= atol + rtol * |a.T|`` as implemented by
    ``np.allclose``; NaN entries compare unequal.

    Args:
        a: 2D array (a non-square array cannot broadcast against its
            transpose and makes ``np.allclose`` raise).
        rtol: Relative tolerance.
        atol: Absolute tolerance.

    Returns:
        ``bool``.
    """
    return np.allclose(a, a.T, rtol=rtol, atol=atol)

class PropertyValidationStatus(Enum):
    """Result codes returned by ``validate_material_properties``.

    VALID means every check passed. Every other member names the first check
    that failed. ``validate_simulation`` treats all non-VALID members except
    S_NON_ORTHOTROPIC as invalid.
    """
    VALID = 0
    E_EXCEEDS_1 = 1  # "sim_E_VRH" is greater than 1
    C_ASYMMETRIC = 2  # "sim_C_matrix" failed the symmetry check
    S_ASYMMETRIC = 3  # "sim_S_matrix" failed the symmetry check
    S_NON_ORTHOTROPIC = 4  # S has non-negligible entries outside the top-left 3x3 block and the diagonal
    EGK_MODULUS_IS_ZERO = 5  # one of sim_E_VRH / sim_G_VRH / sim_K_VRH is (near) zero
    V_IS_ZERO = 6  # "thickened_occupied_volume_fraction" is (near) zero

def validate_material_properties(
    structure_prop_info: dict,
    *,
    zero_tol: float = 1e-7,
    symmetry_rtol: float = 1e-3,
    symmetry_atol: float = 1e-3,
    orthotropy_tol: float = 1e-6,
) -> PropertyValidationStatus:
    """Classify simulated material properties by the first check they fail.

    Checks run in this order; the first failure is returned:
      1. V_IS_ZERO: |thickened_occupied_volume_fraction| < zero_tol.
      2. E_EXCEEDS_1: sim_E_VRH > 1.
      3. EGK_MODULUS_IS_ZERO: |x| < zero_tol for any of sim_E_VRH, sim_G_VRH,
         sim_K_VRH.
      4. C_ASYMMETRIC: sim_C_matrix is not symmetric within
         (symmetry_rtol, symmetry_atol).
      5. S_ASYMMETRIC: sim_S_matrix is not symmetric within the same
         tolerances.
      6. S_NON_ORTHOTROPIC: in the leading 6x6 block of S, an entry outside
         the top-left 3x3 block and the main diagonal has magnitude greater
         than orthotropy_tol.
    If none fail, VALID is returned.

    Args:
        structure_prop_info: Mapping with the keys listed above; the matrices
            must convert to float arrays, and S is indexed up to 6x6.
        zero_tol: Magnitude below which a scalar counts as zero.
        symmetry_rtol: Relative tolerance for the C and S symmetry checks.
        symmetry_atol: Absolute tolerance for the C and S symmetry checks.
        orthotropy_tol: Largest magnitude tolerated for S entries that should
            vanish in an orthotropic material.

    Returns:
        A ``PropertyValidationStatus`` member.

    Raises:
        KeyError: if a required key is missing.
        IndexError: if sim_S_matrix is smaller than 6x6 and passes the
            symmetry check.
    """
    def is_zero(x: float) -> bool:
        return abs(x) < zero_tol

    if is_zero(structure_prop_info["thickened_occupied_volume_fraction"]):
        return PropertyValidationStatus.V_IS_ZERO

    if structure_prop_info["sim_E_VRH"] > 1:
        return PropertyValidationStatus.E_EXCEEDS_1

    if (
        is_zero(structure_prop_info["sim_E_VRH"])
        or is_zero(structure_prop_info["sim_G_VRH"])
        or is_zero(structure_prop_info["sim_K_VRH"])
    ):
        return PropertyValidationStatus.EGK_MODULUS_IS_ZERO

    C = np.array(structure_prop_info["sim_C_matrix"], dtype=float)
    if not check_symmetric(C, symmetry_rtol, symmetry_atol):
        # This should always be thrown out: the simulator should produce a
        # symmetric C, so something has gone wrong in this case.
        return PropertyValidationStatus.C_ASYMMETRIC

    S = np.array(structure_prop_info["sim_S_matrix"], dtype=float)
    if not check_symmetric(S, symmetry_rtol, symmetry_atol):
        # This one should also be thrown out, though it is less critical than
        # C_ASYMMETRIC. Since S = C^{-1}, it is typically caused by numerical
        # instability / a poor condition number of C. np.linalg.cond(C) could
        # detect that, but no threshold has been established for it.
        return PropertyValidationStatus.S_ASYMMETRIC

    # In an orthotropic material only the top-left 3x3 block and the main
    # diagonal of S may be nonzero; every other entry should vanish.
    must_vanish = np.ones((6, 6), dtype=bool)
    must_vanish[:3, :3] = False
    np.fill_diagonal(must_vanish, False)
    if np.any(np.abs(S[:6, :6][must_vanish]) > orthotropy_tol):
        # Not necessarily a reason to discard the result, but properties that
        # rely on orthotropic symmetry may be inaccurate.
        return PropertyValidationStatus.S_NON_ORTHOTROPIC

    return PropertyValidationStatus.VALID

def validate_voxels(voxels):
    """
    Validate a voxelized unit cell for periodicity and connectivity.

    The cell is tiled 3x3x3, and the tiled volume is labelled into
    6-connected components with ``cc3d``. The cell is valid when both hold:

    * each pair of opposite faces of the tiled volume is identical
      (the cell's boundaries are periodic), and
    * at least one connected component touches all six faces of the
      tiled volume.

    Tiling before labelling matters for interpenetrating lattices. It lets
    pieces that cross the cell boundary join up with their copies in
    neighbouring cells before connectivity is judged.

    Args:
        voxels: 3D array of voxel values. Zero is treated as background
            (label 0); presumably nonzero means solid -- confirm against
            callers.

    Returns:
        ``True`` if the cell is periodic and has a component spanning every
        face of the tiled volume, else ``False``.
    """
    tiled = np.tile(voxels, [3, 3, 3])

    # Opposite faces must match for the cell to tile periodically. This check
    # is cheap, so do it first and skip the expensive labelling on failure.
    boundaries_equiv = (
        (tiled[0, :, :] == tiled[-1, :, :]).all()
        and (tiled[:, 0, :] == tiled[:, -1, :]).all()
        and (tiled[:, :, 0] == tiled[:, :, -1]).all()
    )
    if not boundaries_equiv:
        return False

    labels = cc3d.connected_components(tiled, connectivity=6)

    # Intersect the label sets found on each of the six faces. Any surviving
    # nonzero label is a component touching every face. This avoids building
    # one full-volume mask per component.
    faces = (
        labels[0, :, :], labels[-1, :, :],
        labels[:, 0, :], labels[:, -1, :],
        labels[:, :, 0], labels[:, :, -1],
    )
    spanning = set(np.unique(faces[0]).tolist())
    spanning.discard(0)  # label 0 is background, not a component
    for face in faces[1:]:
        if not spanning:
            break
        spanning.intersection_update(np.unique(face).tolist())
    return bool(spanning)