import shape_primitives
import ShapeOps
import plot_utils
import ShapeCorpusGen
import yaml
import numpy as np
import matplotlib.pyplot as plt

def get_circle_mask(res=300, tol=1e-12):
    """Return the unit-disk mask sampled on a ``res`` x ``res`` polar grid.

    The grid spans rho in [0, 1] along columns and theta in [0, 2*pi] along
    rows (``np.meshgrid`` default ``'xy'`` indexing). Every sample therefore
    lies in the closed unit disk, and the mask is expected to be all True.

    Args:
        res: Number of samples along each of rho and theta.
        tol: Slack added to the squared-radius bound. In floating point,
            ``cos(t)**2 + sin(t)**2`` can round to slightly above 1, which
            without slack would wrongly drop samples on the rho == 1 boundary.

    Returns:
        Boolean array of shape ``(res, res)``.
    """
    rho = np.linspace(0, 1, res)
    theta = np.linspace(0, 2 * np.pi, res)
    r, t = np.meshgrid(rho, theta)

    # Convert polar to cartesian
    x = r * np.cos(t)
    y = r * np.sin(t)

    # Closed unit disk, tolerant of 1-ulp round-off at the rim
    circle_mask = (x**2 + y**2) <= 1.0 + tol

    return circle_mask

def get_square_mask(res=300):
    """Return a mask of the axis-aligned square inscribed in the unit disk.

    Samples come from a ``res`` x ``res`` polar grid: rows index theta in
    [0, 2*pi], columns index rho in [0, 1]. The square has half-side 1/sqrt(2)
    (side sqrt(2)), so its corners touch the unit circle.

    Returns:
        Boolean array of shape ``(res, res)``; True where the sample point
        falls inside (or on the edge of) the square.
    """
    radii = np.linspace(0, 1, res)
    angles = np.linspace(0, 2 * np.pi, res)
    rr, tt = np.meshgrid(radii, angles)

    half_side = 1 / np.sqrt(2)
    within_x = np.abs(rr * np.cos(tt)) <= half_side
    within_y = np.abs(rr * np.sin(tt)) <= half_side
    return within_x & within_y

def main_shape():
    """Render the primitive shapes and a few composites, saving each figure.

    Writes one PDF per primitive (``circle.pdf`` ... ``pentagon.pdf``), plus
    ``ring.pdf`` and a combined overview ``shapes_vue.svg``, into the current
    working directory.
    """
    ps = shape_primitives.PrimitiveShapes(canvas_size=256)
    binary_op = ShapeOps.BinaryShapeOp()

    # Example primitives
    circ = ps.circle()
    square = ps.square()
    rect = ps.rectangle()
    tri = ps.triangle()
    ell = ps.ellipse()
    diamond = ps.diamond()
    star = ps.star()
    # star1 is shown in the overview figure below, so it must be defined
    # (it was commented out, which made the overview raise NameError).
    star1 = ps.star()
    sector = ps.sector()
    pentagon = ps.pentagon()

    # One standalone figure per primitive: (shape, title, output file).
    singles = [
        (circ, 'Circle', 'circle.pdf'),
        (square, 'Square', 'square.pdf'),
        (rect, 'Rectangle', 'rectangle.pdf'),
        (tri, 'Triangle', 'triangle.pdf'),
        (ell, 'Ellipse', 'ellipse.pdf'),
        (diamond, 'Diamond', 'diamond.pdf'),
        (star, 'Star', 'star.pdf'),
        (sector, 'Sector', 'sector.pdf'),
        (pentagon, 'Pentagon', 'pentagon.pdf'),
    ]
    for shape, title, path in singles:
        plot_utils.plot_one_shape(shape, title=title, save_path=path)

    # Ring = larger circle with a smaller concentric circle subtracted.
    circle_small = ps.circle(radius_ratio=0.20)
    circle_large = ps.circle(radius_ratio=0.40)
    ring = binary_op.run_op(circle_large, circle_small, 'subtract')
    plot_utils.plot_one_shape(ring, title='Ring', save_path='ring.pdf')

    # Combined with the 'xor' op (not a union), so the title says XOR.
    star_ring_xor = binary_op.run_op(star, ring, 'xor')

    # Overview grid: (shape, title) pairs kept together so they cannot drift.
    overview = [
        (circle_large, 'Circle'),
        (pentagon, 'Pentagon'),
        (ring, 'Ring'),
        (tri, 'Triangle'),
        (ell, 'Ellipse'),
        (diamond, 'Diamond'),
        (rect, 'Rectangle'),
        (sector, 'Sector'),
        (star, 'Star'),
        (star1, 'Star1'),
        (square, 'Square'),
        (star_ring_xor, 'Star-Ring XOR'),
    ]
    plot_utils.plot_shapes_geoms(
        [shape for shape, _ in overview],
        titles=[title for _, title in overview],
        save_path='shapes_vue.svg',
    )

def generate_shapes(config_path='config.yaml'):
    """Build the complex-shape corpus described by a YAML config file.

    Args:
        config_path: Path to the YAML configuration. Defaults to
            ``config.yaml`` in the current working directory.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(config_path, 'r', encoding='utf-8') as fh:
        config = yaml.safe_load(fh)
    complex_shape_gen = ShapeCorpusGen.ComplexShapeConstructor(config)
    complex_shape_gen.create_shapes()

def main(res=300, n_max=25):
    """Compare Zernike reconstruction error for a square and a circle mask.

    Builds a Zernike basis corpus, projects each mask onto it, reconstructs
    the mask from the coefficients, and prints the mean squared
    reconstruction error for each shape.

    Args:
        res: Grid resolution passed to both the masks and the basis corpus.
        n_max: Maximum Zernike radial order passed to the basis corpus.
    """
    # The module was referenced but never imported (NameError); import it here.
    import ZernikeBasisCorpus_bak

    zernike_corpus = ZernikeBasisCorpus_bak.ZernikeBasisCorpus(n_max=n_max, res=res)

    square_mask = get_square_mask(res=res)
    circle_mask = get_circle_mask(res=res)

    avg_recon_error_square = _mean_recon_error(zernike_corpus, square_mask)
    avg_recon_error_circle = _mean_recon_error(zernike_corpus, circle_mask)

    print(f"Average Reconstruction Error (Square): {avg_recon_error_square}")
    print(f"Average Reconstruction Error (Circle): {avg_recon_error_circle}")


def _mean_recon_error(corpus, mask):
    """Project *mask* onto *corpus*, reconstruct it, and return the mean squared error.

    The boolean mask is cast to float32 before projection. The error uses
    ``abs(diff) ** 2``, which equals ``diff ** 2`` for real reconstructions and
    stays real-valued if ``reconstruct_shape`` returns complex output.
    NOTE(review): whether the reconstruction is complex is not visible here.
    """
    target = mask.astype(np.float32)
    coeffs = corpus.compute_zernike_coeffs(target)
    recon = corpus.reconstruct_shape(coeffs)
    return np.mean(np.abs(target - recon) ** 2)

if __name__ == "__main__":
    # Script entry point: run the Zernike reconstruction comparison.
    # To build the shape corpus from config.yaml instead, call generate_shapes().
    main()