# add angler to path (not necessary if pip installed)
import sys

import matplotlib.pylab as plt
import numpy as np
import torch
from tqdm import tqdm

from device_shape import (
    mmi_2x2_L_random,
    mmi_3x3_L,
    mmi_3x3_L_random,
    mmi_3x3_L_random_slots,
    mmi_4x4_L,
    mmi_4x4_L_random,
    mmi_4x4_L_random_3pads,
    mmi_5x5_L_random,
    mmi_6x6_L_random,
)

sys.path.append("..")

from itertools import product

# import the main simulation and optimization classes
from angler import Optimization, Simulation
from device_shape import *

# import some structure generators


def generate_simu_data(pol_list=["Hz"]):
    """Sweep two tuning-pad permittivities on a loaded MMI map and simulate it with angler.

    The base permittivity map is read from ``../mmi2x2/processed/training.pt``.
    An ``ext_wg``-um input-waveguide section containing two waveguide cores is
    prepended along the first grid axis. For every combination of pad settings
    (``n_points`` levels per pad), every polarization in ``pol_list`` and each
    of two source positions, an FDFD simulation is run. Field components are
    packed in the channel order (Ex, Ey, Ez, Hx, Hy, Hz), with zeros for the
    components that the chosen polarization does not produce.

    NOTE(review): the unconditional ``exit(0)`` after the first simulation
    (right after saving ``angler_mmi_simu.png``) terminates the process. As a
    result, the stacking and ``./raw/{pols}_fields_epsilon.pt`` saving code
    below the loop is currently unreachable. This looks like leftover
    debugging -- confirm before relying on the saved output.

    Args:
        pol_list: polarizations to simulate ("Hz" or "Ez"). The list default is
            shared across calls, but it is only read here, never mutated.
    """
    # define the simulation constants
    lambda0 = 1.55e-6  # free space wavelength (m)
    c0 = 299792458  # speed of light in vacuum (m/s)
    omega = 2 * np.pi * c0 / lambda0  # angular frequency (2pi/s)
    dl = 0.1  # grid size (L0) um
    NPML = [10, 10]  # number of pml grid points on x and y borders
    pol = "Hz"  # polarization (either 'Hz' or 'Ez'); rebound by the `for pol in pol_list` loop below
    source_amp = 1e-9  # amplitude of modal source (make around 1 for nonlinear effects)
    ext_wg = 10  # length of the extended input waveguide (um), prepended along axis 0
    wg = 0.4  # waveguide width (um)
    port1 = 2.5  # first input waveguide center location (um)
    port2 = 5.6  # second input waveguide center location (um)
    mode_center_1 = 2.5  # center (um) of the first modal source along the second grid axis
    mode_center_2 = 5.7  # center (um) of the second modal source along the second grid axis
    mode_width = 0.5  # modal source width (um)
    pads = ((14, 38, 2, 3.5), (14, 38, 4.5, 6))  # pad location (xl,xh,yl,yh) (um)
    eps_range = (11.9, 12.3)  # tuning range of Si permittivity
    # eps_range = (11, 13)            # tuning range of Si permittivity

    # define permittivity of three port system (old generator, kept for reference)
    # eps_r, design_region = N_port(N, L, H, w, d, l, spc, dl, NPML, eps_m)
    device = torch.device("cuda:0")  # NOTE(review): not used anywhere in this function
    # the unpacking assumes the file stores a 3-item sequence (wavelength, epsilon, <unused>)
    wavelength, epsilon, _ = torch.load("../mmi2x2/processed/training.pt")
    print(wavelength.shape)
    # real part of epsilon[0, 0], transposed; presumably the first sample's map -- TODO confirm layout
    eps_r = epsilon[0, 0].t().numpy().real
    # eps_r = np.kron(eps_r,np.array([[1,1],[1,1]]))
    print(eps_r.shape)
    # print(eps_r[1, :])
    # exit(0)
    ## add waveguide
    # core / cladding permittivities are taken as the extremes of the loaded map
    eps_si = eps_r.max()
    eps_sio2 = eps_r.min()

    ## binarize permittivities on SiO2 and Si boundary
    # eps_r[eps_r > (eps_si + eps_sio2) / 2] = eps_si
    # eps_r[eps_r <= (eps_si + eps_sio2) / 2] = eps_sio2

    ## extension of input waveguide: ext_wg/dl rows of cladding with two Si cores
    eps_i = np.zeros((int(ext_wg / dl), eps_r.shape[1])) + eps_sio2
    start, end = round((port1 - wg / 2) / dl), round((port1 + wg / 2) / dl) - 1
    print(start, end)
    # core columns [start, end) get eps_si (the slice below recomputes the same bounds)
    eps_i[:, round((port1 - wg / 2) / dl) : round((port1 + wg / 2) / dl) - 1] = eps_si
    # NOTE(review): magic edge value for the pixels flanking the core; presumably a
    # smoothed Si/SiO2 boundary permittivity matching the loaded map -- confirm
    eps_i[:, [start - 1, end]] = 6.177296

    # the second core uses different rounding offsets (+1 on start, no -1 on end);
    # presumably tuned to line up with the loaded map's pixel grid -- confirm
    start, end = round((port2 - wg / 2) / dl) + 1, round((port2 + wg / 2) / dl)
    print(start, end)
    eps_i[:, round((port2 - wg / 2) / dl) + 1 : round((port2 + wg / 2) / dl)] = eps_si
    eps_i[:, [start - 1, end]] = 6.177296
    # prepend the extension along axis 0 (it is sliced off again before saving)
    eps_r = np.concatenate([eps_i, eps_r], axis=0)
    (Nx, Ny) = eps_r.shape

    n_points = 10
    epsilon_list = []
    field_list = []
    # n_points normalized levels in [0, 1] per pad; product() below enumerates all combinations
    n_eps = [np.linspace(0, 1, n_points).tolist()] * len(pads)
    # pol_list = ["Ez", "Hz"]
    # pol_list = ["Hz"]

    for eps_list in tqdm(product(*n_eps)):
        for pol in pol_list:
            epsilon_list_tmp = []
            field_list_tmp = []
            for ny in [mode_center_1, mode_center_2]:
                # write the pad permittivities into eps_r in place (identical for both ny)
                for e, loc in zip(eps_list, pads):
                    top, bottom = int((ext_wg + loc[0]) / dl), int((ext_wg + loc[1]) / dl)
                    left, right = int(loc[2] / dl), int(loc[3] / dl)
                    eps = e * (eps_range[1] - eps_range[0]) + eps_range[0]
                    # one-pixel rim gets the average of the pad value and eps_si ...
                    eps_r[top - 1 : bottom + 1, left - 1 : right + 1] = (eps + eps_si) / 2
                    # ... then the interior is overwritten with the pad value itself
                    eps_r[top:bottom, left:right] = eps

                # the same eps_r object is appended for each ny; np.stack below copies it
                # before the next sweep step mutates eps_r again
                epsilon_list_tmp.append(eps_r)
                # make a new simulation object
                simulation = Simulation(omega, eps_r, dl, NPML, pol)

                # plot the permittivity distribution
                # simulation.plt_eps(outline=False)
                # plt.savefig("angler_mmi_eps.png", dpi=300)
                # exit(0)

                l = 0.5  # source offset (um) from the PML along the first axis (used as l / 2)

                # set the input waveguide modal source
                simulation.add_mode(
                    neff=np.sqrt(eps_si),
                    direction_normal="x",
                    center=[NPML[0] + int(l / 2 / dl), ny / dl],
                    width=int(1 * mode_width / dl),
                    scale=source_amp,
                )
                simulation.setup_modes()

                # set source and solve for electromagnetic fields
                (Ex, Ey, Hz) = simulation.solve_fields()
                # pack into 6 channels (Ex, Ey, Ez, Hx, Hy, Hz), zero-filling absent components
                if pol == "Hz":
                    field_list_tmp.append(
                        np.stack(
                            [
                                simulation.fields["Ex"],
                                simulation.fields["Ey"],
                                np.zeros_like(simulation.fields["Ey"]),
                                np.zeros_like(simulation.fields["Ey"]),
                                np.zeros_like(simulation.fields["Ey"]),
                                simulation.fields["Hz"],
                            ],
                            axis=0,
                        )
                    )
                else:
                    field_list_tmp.append(
                        np.stack(
                            [
                                np.zeros_like(Ex),
                                np.zeros_like(Ex),
                                simulation.fields["Ez"],
                                simulation.fields["Hx"],
                                simulation.fields["Hy"],
                                np.zeros_like(Ex),
                            ],
                            axis=0,
                        )
                    )
                # NOTE(review): debug plot + exit(0) -- the process terminates after the
                # first simulation, so nothing below this point ever runs
                simulation.plt_re()
                plt.savefig("angler_mmi_simu.png", dpi=300)
                exit(0)
            epsilon_list.append(np.stack(epsilon_list_tmp, axis=0))
            field_list.append(np.stack(field_list_tmp, axis=0))
    # add a singleton channel axis, drop the prepended extension rows, swap the last two axes
    epsilon_list = torch.from_numpy(
        np.stack(epsilon_list, axis=0).astype(np.complex64)[:, :, np.newaxis, int(ext_wg / dl) :, :]
    ).transpose(-1, -2)
    field_list = torch.from_numpy(
        np.stack(field_list, axis=0).astype(np.complex64)[..., int(ext_wg / dl) :, :]
    ).transpose(-1, -2)
    print(epsilon_list.shape, field_list.shape)
    # (samples, sources, 2) tensor filled with the isotropic grid step dl
    grid_step_list = torch.ones(epsilon_list.size(0), epsilon_list.size(1), 2).fill_(dl)
    import os

    if not os.path.isdir("./raw"):
        os.mkdir("./raw")
    torch.save(
        {
            "eps": epsilon_list,
            "fields": field_list,
            # NOTE(review): result shape depends on the loaded `wavelength` tensor -- confirm
            "wavelength": wavelength.unsqueeze(1).repeat(field_list.size(0), field_list.size(1), 1),
            "grid_step": grid_step_list,
        },
        f"./raw/{'_'.join(pol_list)}_fields_epsilon.pt",
    )


def generate_simu_mmi(N, pol_list=["Hz"]):
    """Debug helper: simulate a single MMI device with angler and save diagnostic plots.

    An ``N x N`` ``MMI_NxM`` is built and then immediately replaced by
    ``mmi_3x3_L()``, so ``N`` currently only affects the output file names.
    The function then:

    - sets the pad permittivities to ``[11.9, 12.3, 12.3]``;
    - plots the permittivity to ``angler_mmi{N}x{N}_eps.png``;
    - excites input port index 1 with a modal source;
    - plots the solved field to ``angler_mmi{N}x{N}_simu.png``.

    NOTE(review): ``exit(0)`` right after that plot terminates the process.
    The second half is therefore unreachable: it excites the last input port,
    prints the relative mean difference against the ``np.fliplr``-mirrored
    first run, and plots Ey2. ``pol_list``, ``dl``, ``NPML``, ``eps_range`` and
    ``eps_sio2`` are never used.
    """
    # define the simulation constants
    lambda0 = 1.53e-6  # free space wavelength (m)
    # lambda0 = 1.53e-6  # free space wavelength (m)
    # lambda0 = 1.6e-6  # free space wavelength (m)
    c0 = 299792458  # speed of light in vacuum (m/s)
    omega = 2 * np.pi * c0 / lambda0  # angular frequency (2pi/s)
    dl = 0.1  # grid size (L0) um  -- unused: the device's own grid_step is passed to Simulation
    NPML = [15, 15]  # number of pml grid points on x and y borders  -- unused: mmi.NPML is passed instead
    pol = "Hz"  # polarization (either 'Hz' or 'Ez')
    source_amp = 1e-9  # amplitude of modal source (make around 1 for nonlinear effects)
    eps_range = (11.9, 12.3)  # tuning range of Si permittivity
    eps_sio2 = 1.44 ** 2
    eps_si = 3.48 ** 2
    # only the last assignment takes effect
    box_size = (31.5, 6.1)
    box_size = (25.9, 6.1)
    # box_size = (33, 6.1)
    # box_size = (16, 4)
    # box_size = (20.66, 4)
    mmi = MMI_NxM(
        N,
        N,
        # box_size=(27, 7),  # box [length, width], um
        # box_size=(12, 3),  # box [length, width], um
        # box_size=(20, 5),  # box [length, width], um
        box_size=box_size,  # box [length, width], um
        # wg_width=(1.55 / 3.48 / 2, 1.55 / 3.48 / 2),  # in/out wavelength width, um
        # wg_width=(0.8, 0.8),  # in/out wavelength width, um
        wg_width=(1.1, 1.1),  # in/out wavelength width, um
        # wg_width=(1.55/2, 1.55/2),  # in/out wavelength width, um
        # port_diff=(7/N, 7/N),  # distance between in/out waveguides. um
        port_diff=(box_size[1] / N, box_size[1] / N),  # distance between in/out waveguides. um
        port_len=3,  # length of in/out waveguide from PML to box. um
        border_width=0.25,  # space between box and PML. um
        # grid_step=0.03,  # isotropic grid step um
        grid_step=0.05,  # isotropic grid step um
        # grid_step=0.1,  # isotropic grid step um
        NPML=(30, 30),  # PML pixel width. pixel
    )
    # NOTE(review): overwrites the MMI_NxM built above; the N-dependent geometry is discarded
    mmi = mmi_3x3_L()
    mmi.epsilon_map = mmi.set_pad_eps([11.9, 12.3, 12.3])
    epsilon_map = mmi.trim_pml(mmi.epsilon_map)  # only used for the shape print below
    print(epsilon_map.shape)
    simulation = Simulation(omega, mmi.epsilon_map, mmi.grid_step, mmi.NPML, pol)
    simulation.plt_eps(outline=False)
    plt.savefig(f"angler_mmi{N}x{N}_eps.png", dpi=300)

    # excite input port index 1 with a source twice the port width
    simulation.add_mode(
        neff=np.sqrt(eps_si),
        direction_normal="x",
        center=mmi.in_port_centers_px[1],
        # center=[NPML[0] + int(l / 2 / dl), ny / dl],
        width=int(2 * mmi.in_port_width_px[1]),
        scale=source_amp,
    )
    simulation.setup_modes()

    # set source and solve for electromagnetic fields
    (Ex1, Ey1, Hz1) = simulation.solve_fields()
    # simulation.fields["Hz"] = Hz1
    simulation.plt_re(outline=False)
    plt.savefig(f"angler_mmi{N}x{N}_simu.png", dpi=300)
    # NOTE(review): debug exit -- everything below is unreachable
    exit(0)

    # second run: excite the last input port (source width 1x port width, unlike 2x above)
    simulation = Simulation(omega, mmi.epsilon_map, mmi.grid_step, mmi.NPML, pol)
    simulation.add_mode(
        neff=np.sqrt(eps_si),
        direction_normal="x",
        center=mmi.in_port_centers_px[-1],
        # center=[NPML[0] + int(l / 2 / dl), ny / dl],
        width=int(1 * mmi.in_port_width_px[1]),
        scale=source_amp,
    )
    simulation.setup_modes()

    # set source and solve for electromagnetic fields
    (Ex2, Ey2, Hz2) = simulation.solve_fields()
    # relative mean |difference| between run 1 and the np.fliplr-mirrored run 2 (symmetry check)
    print(np.mean(np.abs((Ex1 - np.fliplr(Ex2)))) / np.mean(np.abs(Ex1)))
    print(np.mean(np.abs((Ey1 - np.fliplr(Ey2)))) / np.mean(np.abs(Ey1)))
    print(np.mean(np.abs((Hz1 - np.fliplr(Hz2)))) / np.mean(np.abs(Hz1)))
    # presumably stores Ey2 in the "Hz" slot so that plt_re plots Ey -- confirm
    simulation.fields["Hz"] = Ey2
    simulation.plt_re(outline=False)
    plt.savefig(f"angler_mmi{N}x{N}_simu_Ey2.png", dpi=300)


def generate_mmi_data(configs, name):
    """Simulate pad-permittivity sweeps of fixed MMI devices and save ``./raw/{name}.pt``.

    For each config, every combination of ``n_points`` evenly spaced pad
    permittivities in ``eps_range`` (one value per input port) is simulated at
    every wavelength. Each input port is excited in turn with a modal source.

    Args:
        configs: sequence of ``(pol, device, eps_range, n_points, wavelengths, size)``
            tuples.
            - ``pol``: "Hz" saves (Ex, Ey, Hz); any other value saves (Hx, Hy, Ez).
            - ``device``: a pre-built device-shape object.
            - ``eps_range``: ``(eps_min, eps_max)``.
            - ``wavelengths``: in um; they are divided by 1e6 to get metres.
            - ``size``: forwarded to ``device.resize``.
        name: stem of the output file ``./raw/{name}.pt``.

    The saved dict holds one entry per (pad setting, wavelength), stacked over
    the excited input ports:
        - ``eps`` and ``fields``: complex64 with the last two axes swapped;
          ``eps`` also gains a singleton channel axis.
        - ``wavelength`` and ``grid_step``: float32.
        - ``input_len``: int32 input-waveguide length in resized pixels.

    Raises:
        ValueError: if ``configs`` produce no samples.
    """
    import os

    c0 = 299792458  # speed of light in vacuum (m/s)
    source_amp = 1e-9  # amplitude of modal source (make around 1 for nonlinear effects)
    neff_si = 3.48
    epsilon_list = []
    field_list = []
    grid_step_list = []
    wavelength_list = []
    input_len_list = []
    for idx, config in enumerate(configs):
        print(f"Generating data with config:\n\t{config} ({idx:4d}/{len(configs):4d})")
        pol, device, eps_range, n_points, wavelengths, size = config
        eps_min, eps_max = eps_range
        field_keys = ("Ex", "Ey", "Hz") if pol == "Hz" else ("Hx", "Hy", "Ez")

        # NOTE(review): one pad value per *input port*; the random-device generators size
        # this by len(device.pad_regions) instead -- confirm the two agree for these devices.
        eps_levels = (np.linspace(0, 1, n_points) * (eps_max - eps_min) + eps_min).tolist()
        n_eps = [eps_levels] * device.num_in_ports
        for eps in tqdm(product(*n_eps)):
            # The permittivity map and its trimmed/resized copy depend only on the pad
            # setting, not on wavelength or excited port, so build them once per setting.
            eps_map = device.set_pad_eps(eps)
            eps_map_resize, grid_step = device.resize(
                device.trim_pml(eps_map), size=size, mode="bilinear"
            )
            input_len = np.array([int(device.port_len / grid_step[0])])
            for wavelength in wavelengths:
                lambda0 = wavelength / 1e6  # free space wavelength (m)
                omega = 2 * np.pi * c0 / lambda0  # angular frequency (2pi/s)
                field_list_tmp = []
                for i in range(device.num_in_ports):
                    simulation = Simulation(omega, eps_map, device.grid_step, device.NPML, pol)
                    simulation.add_mode(
                        neff=neff_si,
                        direction_normal="x",
                        center=device.in_port_centers_px[i],
                        width=int(2 * device.in_port_width_px[i]),
                        scale=source_amp,
                    )
                    simulation.setup_modes()
                    simulation.solve_fields()  # results are read back from simulation.fields
                    field = np.stack([simulation.fields[key] for key in field_keys], axis=0)
                    field, _ = device.resize(device.trim_pml(field), size=size, mode="bilinear")
                    field_list_tmp.append(field)
                n_ports = len(field_list_tmp)
                epsilon_list.append(np.stack([eps_map_resize] * n_ports, axis=0))
                field_list.append(np.stack(field_list_tmp, axis=0))
                wavelength_list.append(np.stack([np.array([wavelength])] * n_ports, axis=0))
                grid_step_list.append(np.stack([np.array(grid_step)] * n_ports, axis=0))
                input_len_list.append(np.stack([input_len] * n_ports, axis=0))

    if not epsilon_list:
        raise ValueError(f"No simulation samples were generated for {name!r}; check configs")

    epsilon_list = torch.from_numpy(
        np.stack(epsilon_list, axis=0).astype(np.complex64)[:, :, np.newaxis, :, :]
    ).transpose(-1, -2)
    field_list = torch.from_numpy(np.stack(field_list, axis=0).astype(np.complex64)).transpose(-1, -2)
    grid_step_list = torch.from_numpy(np.stack(grid_step_list, axis=0).astype(np.float32))
    wavelength_list = torch.from_numpy(np.stack(wavelength_list, axis=0).astype(np.float32))
    input_len_list = torch.from_numpy(np.stack(input_len_list, axis=0).astype(np.int32))
    print(
        epsilon_list.shape,
        field_list.shape,
        grid_step_list.shape,
        wavelength_list.shape,
        input_len_list.shape,
    )

    os.makedirs("./raw", exist_ok=True)
    torch.save(
        {
            "eps": epsilon_list,
            "fields": field_list,
            "wavelength": wavelength_list,
            "grid_step": grid_step_list,
            "input_len": input_len_list,
        },
        f"./raw/{name}.pt",
    )
    print(f"Saved simulation data ./raw/{name}.pt")


def generate_mmi_random_data(configs, name):
    """Simulate pad sweeps on randomly re-sampled MMI shapes and save ``./raw/{name}.pt``.

    Every pad-permittivity combination gets its own device, produced by calling
    ``device_fn()`` right after ``np.random.seed(random_state + sample_index)``.
    All samples are treated as one unified permittivity distribution.

    Args:
        configs: sequence of
            ``(pol, device_fn, eps_range, n_points, wavelengths, size, random_state)``
            tuples.
            - ``device_fn``: zero-arg device factory, presumably drawing from the
              global NumPy RNG (hence the reseeding).
            - ``n_points``: permittivity levels per pad; all combinations are
              simulated.
            - ``wavelengths``: in um.
            - ``size``: forwarded to ``device.resize``.
            - ``pol``: "Hz" saves (Ex, Ey, Hz); anything else saves (Hx, Hy, Ez).
        name: stem of the output file ``./raw/{name}.pt``.

    Raises:
        ValueError: if ``configs`` produce no samples.
    """
    import os

    c0 = 299792458  # speed of light in vacuum (m/s)
    source_amp = 1e-9  # amplitude of modal source (make around 1 for nonlinear effects)
    neff_si = 3.48
    epsilon_list = []
    field_list = []
    grid_step_list = []
    wavelength_list = []
    input_len_list = []

    for idx, config in enumerate(configs):
        print(f"Generating data with config:\n\t{config} ({idx:4d}/{len(configs):4d})")
        pol, device_fn, eps_range, n_points, wavelengths, size, random_state = config
        eps_min, eps_max = eps_range
        field_keys = ("Ex", "Ey", "Hz") if pol == "Hz" else ("Hx", "Hy", "Ez")

        # Build one device up front only to learn how many tuning pads there are.
        device = device_fn()
        eps_levels = (np.linspace(0, 1, n_points) * (eps_max - eps_min) + eps_min).tolist()
        n_eps = [eps_levels] * len(device.pad_regions)
        for device_id, eps in enumerate(tqdm(product(*n_eps))):
            # Reseed per sample so every pad setting maps to a reproducible device shape.
            np.random.seed(random_state + device_id)
            device = device_fn()  # re-sample device shape
            # The permittivity map and its trimmed/resized copy depend only on the device and
            # pad setting, not on wavelength or excited port, so build them once per sample.
            eps_map = device.set_pad_eps(eps)
            eps_map_resize, grid_step = device.resize(
                device.trim_pml(eps_map), size=size, mode="bilinear"
            )
            input_len = np.array([int(device.port_len / grid_step[0])])
            for wavelength in wavelengths:
                lambda0 = wavelength / 1e6  # free space wavelength (m)
                omega = 2 * np.pi * c0 / lambda0  # angular frequency (2pi/s)
                field_list_tmp = []
                for i in range(device.num_in_ports):
                    simulation = Simulation(omega, eps_map, device.grid_step, device.NPML, pol)
                    simulation.add_mode(
                        neff=neff_si,
                        direction_normal="x",
                        center=device.in_port_centers_px[i],
                        width=int(2 * device.in_port_width_px[i]),
                        scale=source_amp,
                    )
                    simulation.setup_modes()
                    simulation.solve_fields()  # results are read back from simulation.fields
                    field = np.stack([simulation.fields[key] for key in field_keys], axis=0)
                    field, _ = device.resize(device.trim_pml(field), size=size, mode="bilinear")
                    field_list_tmp.append(field)
                n_ports = len(field_list_tmp)
                epsilon_list.append(np.stack([eps_map_resize] * n_ports, axis=0))
                field_list.append(np.stack(field_list_tmp, axis=0))
                wavelength_list.append(np.stack([np.array([wavelength])] * n_ports, axis=0))
                grid_step_list.append(np.stack([np.array(grid_step)] * n_ports, axis=0))
                input_len_list.append(np.stack([input_len] * n_ports, axis=0))

    if not epsilon_list:
        raise ValueError(f"No simulation samples were generated for {name!r}; check configs")

    epsilon_list = torch.from_numpy(
        np.stack(epsilon_list, axis=0).astype(np.complex64)[:, :, np.newaxis, :, :]
    ).transpose(-1, -2)
    field_list = torch.from_numpy(np.stack(field_list, axis=0).astype(np.complex64)).transpose(-1, -2)
    grid_step_list = torch.from_numpy(np.stack(grid_step_list, axis=0).astype(np.float32))
    wavelength_list = torch.from_numpy(np.stack(wavelength_list, axis=0).astype(np.float32))
    input_len_list = torch.from_numpy(np.stack(input_len_list, axis=0).astype(np.int32))
    print(
        epsilon_list.shape,
        field_list.shape,
        grid_step_list.shape,
        wavelength_list.shape,
        input_len_list.shape,
    )

    os.makedirs("./raw", exist_ok=True)
    torch.save(
        {
            "eps": epsilon_list,
            "fields": field_list,
            "wavelength": wavelength_list,
            "grid_step": grid_step_list,
            "input_len": input_len_list,
        },
        f"./raw/{name}.pt",
    )
    print(f"Saved simulation data ./raw/{name}.pt")


def generate_mmi_random_spectra_data(configs, name):
    """Simulate randomly shaped MMIs with random quantized pad permittivities.

    For each config, ``n_points`` devices are drawn. Each draw seeds the global
    NumPy RNG with ``random_state + sample_index`` and then:

    1. re-samples the device via ``device_fn()``;
    2. draws one of 8 quantized permittivity levels in ``eps_range`` for every
       pad;
    3. simulates every wavelength, exciting each input port in turn.

    The result is saved to ``./raw/{name}.pt``.

    Args:
        configs: sequence of
            ``(pol, device_fn, eps_range, n_points, wavelengths, size, random_state)``
            tuples. ``n_points`` is the number of random samples; ``wavelengths``
            are in um; ``size`` is forwarded to ``device.resize``; ``pol`` "Hz"
            saves (Ex, Ey, Hz), anything else saves (Hx, Hy, Ez).
        name: stem of the output file ``./raw/{name}.pt``.

    Raises:
        ValueError: if ``configs`` produce no samples.
    """
    import os

    c0 = 299792458  # speed of light in vacuum (m/s)
    source_amp = 1e-9  # amplitude of modal source (make around 1 for nonlinear effects)
    neff_si = 3.48
    epsilon_list = []
    field_list = []
    grid_step_list = []
    wavelength_list = []
    input_len_list = []

    for idx, config in enumerate(configs):
        print(f"Generating data with config:\n\t{config} ({idx:4d}/{len(configs):4d})")
        pol, device_fn, eps_range, n_points, wavelengths, size, random_state = config
        eps_min, eps_max = eps_range
        field_keys = ("Ex", "Ey", "Hz") if pol == "Hz" else ("Hx", "Hy", "Ez")

        for device_id in tqdm(range(n_points)):
            # Reseed per sample so both the device shape and its pad values are reproducible.
            np.random.seed(random_state + device_id)
            device = device_fn()  # re-sample device shape
            # NOTE(review): rounding uniform(0, 7) yields 8 levels, but the end levels 0 and 7
            # are drawn half as often as the others (np.random.randint(0, 8) would be uniform).
            # Kept as-is so datasets generated from the same seeds stay reproducible.
            eps = (
                np.round(np.random.uniform(0, 7, size=[len(device.pad_regions)])) / 7 * (eps_max - eps_min)
                + eps_min
            ).tolist()
            # The permittivity map and its trimmed/resized copy do not depend on wavelength
            # or excited port, so build them once per sample.
            eps_map = device.set_pad_eps(eps)
            eps_map_resize, grid_step = device.resize(
                device.trim_pml(eps_map), size=size, mode="bilinear"
            )
            input_len = np.array([int(device.port_len / grid_step[0])])
            for wavelength in wavelengths:
                lambda0 = wavelength / 1e6  # free space wavelength (m)
                omega = 2 * np.pi * c0 / lambda0  # angular frequency (2pi/s)
                field_list_tmp = []
                for i in range(device.num_in_ports):
                    simulation = Simulation(omega, eps_map, device.grid_step, device.NPML, pol)
                    simulation.add_mode(
                        neff=neff_si,
                        direction_normal="x",
                        center=device.in_port_centers_px[i],
                        width=int(2 * device.in_port_width_px[i]),
                        scale=source_amp,
                    )
                    simulation.setup_modes()
                    simulation.solve_fields()  # results are read back from simulation.fields
                    field = np.stack([simulation.fields[key] for key in field_keys], axis=0)
                    field, _ = device.resize(device.trim_pml(field), size=size, mode="bilinear")
                    field_list_tmp.append(field)
                n_ports = len(field_list_tmp)
                epsilon_list.append(np.stack([eps_map_resize] * n_ports, axis=0))
                field_list.append(np.stack(field_list_tmp, axis=0))
                wavelength_list.append(np.stack([np.array([wavelength])] * n_ports, axis=0))
                grid_step_list.append(np.stack([np.array(grid_step)] * n_ports, axis=0))
                input_len_list.append(np.stack([input_len] * n_ports, axis=0))

    if not epsilon_list:
        raise ValueError(f"No simulation samples were generated for {name!r}; check configs")

    epsilon_list = torch.from_numpy(
        np.stack(epsilon_list, axis=0).astype(np.complex64)[:, :, np.newaxis, :, :]
    ).transpose(-1, -2)
    field_list = torch.from_numpy(np.stack(field_list, axis=0).astype(np.complex64)).transpose(-1, -2)
    grid_step_list = torch.from_numpy(np.stack(grid_step_list, axis=0).astype(np.float32))
    wavelength_list = torch.from_numpy(np.stack(wavelength_list, axis=0).astype(np.float32))
    input_len_list = torch.from_numpy(np.stack(input_len_list, axis=0).astype(np.int32))
    print(
        epsilon_list.shape,
        field_list.shape,
        grid_step_list.shape,
        wavelength_list.shape,
        input_len_list.shape,
    )

    os.makedirs("./raw", exist_ok=True)
    torch.save(
        {
            "eps": epsilon_list,
            "fields": field_list,
            "wavelength": wavelength_list,
            "grid_step": grid_step_list,
            "input_len": input_len_list,
        },
        f"./raw/{name}.pt",
    )
    print(f"Saved simulation data ./raw/{name}.pt")


def generate_slot_mmi_random_data(configs, name):
    """Simulate randomly re-sampled slotted MMIs at a fixed pad permittivity.

    For each config, ``n_points`` devices are drawn. Each draw seeds the global
    NumPy RNG with ``random_state + sample_index`` and then calls
    ``device_fn()``. Every pad is set to ``eps_val``, and each device is
    simulated at every wavelength, exciting each input port in turn. The
    result is saved to ``./raw/{name}.pt``.

    Args:
        configs: sequence of
            ``(pol, device_fn, eps_val, n_points, wavelengths, size, random_state)``
            tuples. ``n_points`` is how many devices are sampled (each is
            simulated at every wavelength); ``wavelengths`` are in um; ``size``
            is forwarded to ``device.resize``; ``pol`` "Hz" saves (Ex, Ey, Hz),
            anything else saves (Hx, Hy, Ez).
        name: stem of the output file ``./raw/{name}.pt``.

    Raises:
        ValueError: if ``configs`` produce no samples.
    """
    import os

    c0 = 299792458  # speed of light in vacuum (m/s)
    source_amp = 1e-9  # amplitude of modal source (make around 1 for nonlinear effects)
    neff_si = 3.48
    epsilon_list = []
    field_list = []
    grid_step_list = []
    wavelength_list = []
    input_len_list = []

    for idx, config in enumerate(configs):
        print(f"Generating data with config:\n\t{config} ({idx:4d}/{len(configs):4d})")
        pol, device_fn, eps_val, n_points, wavelengths, size, random_state = config
        field_keys = ("Ex", "Ey", "Hz") if pol == "Hz" else ("Hx", "Hy", "Ez")

        for device_id in tqdm(range(n_points)):
            # Reseed per sample so every device shape is reproducible.
            np.random.seed(random_state + device_id)
            device = device_fn()  # re-sample device shape
            # All pads share eps_val. The map and its trimmed/resized copy do not depend on
            # wavelength or excited port, so build them once per sampled device.
            eps_map = device.set_pad_eps(np.zeros([len(device.pad_regions)]) + eps_val)
            eps_map_resize, grid_step = device.resize(
                device.trim_pml(eps_map), size=size, mode="bilinear"
            )
            input_len = np.array([int(device.port_len / grid_step[0])])
            for wavelength in wavelengths:
                lambda0 = wavelength / 1e6  # free space wavelength (m)
                omega = 2 * np.pi * c0 / lambda0  # angular frequency (2pi/s)
                field_list_tmp = []
                for i in range(device.num_in_ports):
                    simulation = Simulation(omega, eps_map, device.grid_step, device.NPML, pol)
                    simulation.add_mode(
                        neff=neff_si,
                        direction_normal="x",
                        center=device.in_port_centers_px[i],
                        width=int(2 * device.in_port_width_px[i]),
                        scale=source_amp,
                    )
                    simulation.setup_modes()
                    simulation.solve_fields()  # results are read back from simulation.fields
                    field = np.stack([simulation.fields[key] for key in field_keys], axis=0)
                    field, _ = device.resize(device.trim_pml(field), size=size, mode="bilinear")
                    field_list_tmp.append(field)
                n_ports = len(field_list_tmp)
                epsilon_list.append(np.stack([eps_map_resize] * n_ports, axis=0))
                field_list.append(np.stack(field_list_tmp, axis=0))
                wavelength_list.append(np.stack([np.array([wavelength])] * n_ports, axis=0))
                grid_step_list.append(np.stack([np.array(grid_step)] * n_ports, axis=0))
                input_len_list.append(np.stack([input_len] * n_ports, axis=0))

    if not epsilon_list:
        raise ValueError(f"No simulation samples were generated for {name!r}; check configs")

    epsilon_list = torch.from_numpy(
        np.stack(epsilon_list, axis=0).astype(np.complex64)[:, :, np.newaxis, :, :]
    ).transpose(-1, -2)
    field_list = torch.from_numpy(np.stack(field_list, axis=0).astype(np.complex64)).transpose(-1, -2)
    grid_step_list = torch.from_numpy(np.stack(grid_step_list, axis=0).astype(np.float32))
    wavelength_list = torch.from_numpy(np.stack(wavelength_list, axis=0).astype(np.float32))
    input_len_list = torch.from_numpy(np.stack(input_len_list, axis=0).astype(np.int32))
    print(
        epsilon_list.shape,
        field_list.shape,
        grid_step_list.shape,
        wavelength_list.shape,
        input_len_list.shape,
    )

    os.makedirs("./raw", exist_ok=True)
    torch.save(
        {
            "eps": epsilon_list,
            "fields": field_list,
            "wavelength": wavelength_list,
            "grid_step": grid_step_list,
            "input_len": input_len_list,
        },
        f"./raw/{name}.pt",
    )
    print(f"Saved simulation data ./raw/{name}.pt")


def postprocess(name, epsilon_min=1, epsilon_max=12.3):
    """Normalize a raw simulation dump and attach a masked input-mode channel.

    Loads ``./raw/{name}.pt`` and writes ``./raw/{name}_mode.pt`` containing:

    * ``"eps"``: the permittivity rescaled to
      ``(eps - epsilon_min) / (epsilon_max - epsilon_min)``, concatenated along
      dim 2 with the input-mode channel described below;
    * ``"fields"``: field channel 2 only, divided by twice the std of its
      magnitude (only when the channel's mean magnitude exceeds 1e-18);
    * ``"wavelength"`` / ``"grid_step"``: copied through unchanged;
    * ``"eps_min"`` / ``"eps_max"``: the normalization bounds that were used.

    The input mode is a copy of the normalized field, zeroed along the last
    axis from index ``input_len[i, 0, 0]`` onward for each sample ``i``.

    Args:
        name: Dataset stem; ``./raw/{name}.pt`` must exist.
        epsilon_min: Permittivity mapped to 0. Defaults to 1, which matches
            the offset previously hard-coded here.
        epsilon_max: Permittivity mapped to 1.

    Raises:
        ValueError: if ``epsilon_max == epsilon_min`` (zero normalization range).
    """
    if epsilon_max == epsilon_min:
        raise ValueError(
            f"epsilon_max ({epsilon_max}) must differ from epsilon_min ({epsilon_min})"
        )

    data = torch.load(f"./raw/{name}.pt")
    epsilon = data["eps"]  # [bs, N, 1, h, w] complex (generator adds the channel axis)
    fields = data["fields"]  # [bs, N, 3, h, w] complex
    wavelengths = data["wavelength"]  # [bs, N, 1] real
    grid_steps = data["grid_step"]  # [bs, N, 2] real
    input_lens = data["input_len"]  # [bs, N, 1] int

    # Min-max normalize epsilon with the caller's bounds; the same bounds are
    # stored in the output so the scaling can be inverted downstream.
    epsilon = (epsilon - epsilon_min) / (epsilon_max - epsilon_min)
    print(epsilon.shape, epsilon_min, epsilon_max)

    # Keep only field channel 2 (Hz or Ez depending on polarization) and
    # normalize each kept channel by twice the std of its magnitude.
    fields = fields[:, :, 2:3]
    mag = fields.abs()
    mag_mean = mag.mean(dim=(0, 1, 3, 4), keepdim=True)
    for i in range(mag_mean.shape[2]):
        # Only rescale channels with a non-negligible mean magnitude.
        if mag_mean[:, :, i].mean() > 1e-18:
            mag_std = mag[:, :, i : i + 1].std()
            fields[:, :, i : i + 1] /= mag_std * 2
    print(fields.shape, fields.abs().max(), fields.abs().std())

    # Input mode: the field kept only in the first input_len positions of the
    # last axis (the mask broadcasts over the channel and h axes).
    input_mode = fields.clone()
    input_mask = torch.zeros(
        input_mode.shape[0], input_mode.shape[1], 1, 1, input_mode.shape[-1]
    )  # [bs, N, 1, 1, w]
    for i in range(input_mode.shape[0]):
        input_mask[i, ..., : int(input_lens[i, 0, 0])].fill_(1)
    input_mode.mul_(input_mask)

    # Model input = normalized permittivity stacked with the input mode.
    data = torch.cat([epsilon, input_mode], dim=2)
    print(data.shape, fields.shape, wavelengths.shape, grid_steps.shape)
    print(f"postprocessed {name}")
    torch.save(
        {
            "eps": data,
            "fields": fields,
            "wavelength": wavelengths,
            "grid_step": grid_steps,
            "eps_min": torch.tensor([float(epsilon_min)]),
            "eps_max": torch.tensor([epsilon_max]),
        },
        f"./raw/{name}_mode.pt",
    )


def append_input_mode(pol_list):
    """Load a combined field/permittivity dump and build its input-mode channels.

    Reads ``./raw/<pols>_fields_epsilon.pt``, where ``<pols>`` is the sorted
    polarization names joined with ``_``.

    * Permittivity channel 0 is rescaled as ``(eps - 1) / (max|eps| - 1)``.
    * Field channels are selected by polarization: ``0:2`` for ``["Hz"]``,
      ``2:3`` for ``["Ez"]``, ``0:3`` for anything else.
    * Each selected channel whose mean magnitude exceeds 1e-18 is divided by
      twice its magnitude std.
    * A copy of the normalized fields, zeroed along the last axis from index
      ``int(9.9 / 0.1)`` onward, is concatenated onto the permittivity along
      dim 2.

    Shapes and statistics are printed along the way.

    NOTE(review): the final ``torch.save`` was disabled, and nothing is
    returned, so this function currently only prints diagnostics.
    """
    ordered = sorted(pol_list)
    data = torch.load(f"./raw/{'_'.join(ordered)}_fields_epsilon.pt")

    eps = data["eps"][:, :, 0:1]
    eps_abs = eps.abs()
    eps_lo = eps_abs.min().item()
    eps_hi = eps_abs.max().item()
    eps = (eps - 1) / (eps_hi - 1)
    print(eps.shape, eps_lo, eps_hi)

    wavelength = data["wavelength"]
    # Channel window per (sorted) polarization set; anything else keeps 0:3.
    channel_by_pols = {("Hz",): slice(0, 2), ("Ez",): slice(2, 3)}
    window = channel_by_pols.get(tuple(ordered), slice(0, 3))
    fields = data["fields"][:, :, window]
    print(fields.shape)

    magnitude = fields.abs()
    channel_means = magnitude.mean(dim=(0, 1, 3, 4), keepdim=True)
    for ch in range(channel_means.shape[2]):
        if channel_means[:, :, ch].mean() > 1e-18:
            scale = magnitude[:, :, ch : ch + 1].std() * 2
            fields[:, :, ch : ch + 1] /= scale
    print(fields.abs().max(), fields.abs().std())

    # Only the leading part of the last axis is treated as the known input mode.
    taper_len, step = 9.9, 0.1
    cutoff = int(taper_len / step)
    input_mode = fields.clone()
    input_mode[..., cutoff:].fill_(0)
    eps = torch.cat([eps, input_mode], dim=2)

    grid_step = data["grid_step"]
    print(eps.shape, fields.shape, wavelength.shape, grid_step.shape)


def launch_rHz_data_generation():
    """Generate and postprocess random 3x3 MMI ``Hz`` datasets, one per task.

    Each active task ``(i, wavelength)`` is passed to
    ``generate_mmi_random_data`` under the name ``rHz_{i}_fields_epsilon``.
    That name is then normalized by ``postprocess``. Work is split across
    machines by editing the active ``tasks`` list below.
    """
    pol = "Hz"
    device_list = [mmi_3x3_L_random]
    points_per_port = [8]
    eps_range = [11.9, 12.3]
    size = (384, 80)

    # Task indices 0-4 used the grid np.arange(1.53, 1.571, 0.01), previously
    # split between eda15 and eda14. Indices 5-9 use the same grid offset by
    # 0.005; uncomment the entries assigned to the current machine.
    wavelengths = np.arange(1.535, 1.576, 0.01).tolist()
    tasks = [
        # (5, wavelengths[0]),
        # (6, wavelengths[1]),
        # (7, wavelengths[2]),  # eda05
        (8, wavelengths[3]),
        (9, wavelengths[4]),  # eda13
    ]

    for i, wavelength in tasks:
        name = f"rHz_{i}_fields_epsilon"
        # The seed depends on the task index, presumably so each wavelength
        # gets distinct random devices (TODO confirm in generate_mmi_random_data).
        seed = 10000 + i * 2000
        configs = [
            (pol, device, eps_range, n_points, [wavelength], size, seed)
            for device, n_points in zip(device_list, points_per_port)
        ]
        generate_mmi_random_data(configs, name=name)
        postprocess(name)


def launch_rHz_data_spectra_generation():
    """Generate spectra data for random 3x3 MMIs, split across hosts.

    The wavelength sweep ``np.arange(1.53, 1.565, 0.002)`` is enumerated.
    Each known hostname is assigned a fixed slice of that sweep, and this
    machine runs only its own slice. Every task uses seed 40000, presumably
    so the same random devices are simulated at each wavelength (TODO confirm
    in generate_mmi_random_spectra_data). Each task's output is then passed
    to ``postprocess``.

    Raises:
        KeyError: if the current hostname has no slice assigned.
    """
    import platform

    pol = "Hz"
    device_list = [mmi_3x3_L_random]
    points_per_wavelength = [32]
    eps_range = [11.9, 12.3]
    size = (384, 80)

    wavelengths = np.arange(1.53, 1.565, 0.002).tolist()
    all_tasks = list(enumerate(wavelengths))
    tasks_by_host = {
        "eda05": all_tasks[0:4],
        "eda13": all_tasks[4:8],
        "eda14": all_tasks[8:12],
        "eda15": all_tasks[12:],
    }

    # platform.node() matches os.uname()[1] on Unix and also works on Windows.
    machine = platform.node()
    if machine not in tasks_by_host:
        raise KeyError(
            f"no spectra tasks assigned to host {machine!r}; "
            f"expected one of {sorted(tasks_by_host)}"
        )

    for i, wavelength in tasks_by_host[machine]:
        # NOTE(review): the name says "3pads" but the device is
        # mmi_3x3_L_random; confirm the intended generator/name.
        name = f"cband_rHz_mmi3x3_3pads_{i}_fields_epsilon"
        configs = [
            (pol, device, eps_range, n_points, [wavelength], size, 40000)
            for device, n_points in zip(device_list, points_per_wavelength)
        ]
        generate_mmi_random_spectra_data(configs, name=name)
        postprocess(name)


def _launch_mmi_random_spectra_sweep(device, n_points, tag):
    """Run this host's share of the 5-point wavelength sweep for one MMI generator.

    The sweep ``np.arange(1.53, 1.571, 0.01)`` is enumerated.
    * Host ``eda05`` runs task 0, ``eda13`` runs task 1, ``eda14`` runs
      task 2, and ``eda15`` runs the rest.
    * Task ``i`` is generated by ``generate_mmi_random_spectra_data`` under
      the name ``rHz_{tag}_{i}_fields_epsilon`` with seed ``30000 + 2000 * i``.
    * Each result is then passed to ``postprocess``.

    Args:
        device: Device-shape generator forwarded in the config tuple.
        n_points: Number of samples per wavelength forwarded in the config.
        tag: Name fragment identifying the device family, e.g. ``"mmi2x2"``.

    Raises:
        KeyError: if the current hostname has no task slice assigned.
    """
    import platform

    pol = "Hz"
    eps_range = [11.9, 12.3]
    size = (384, 80)

    wavelengths = np.arange(1.53, 1.571, 0.01).tolist()
    all_tasks = list(enumerate(wavelengths))
    tasks_by_host = {
        "eda05": all_tasks[0:1],
        "eda13": all_tasks[1:2],
        "eda14": all_tasks[2:3],
        "eda15": all_tasks[3:],
    }

    # platform.node() matches os.uname()[1] on Unix and also works on Windows.
    machine = platform.node()
    if machine not in tasks_by_host:
        raise KeyError(
            f"no {tag} tasks assigned to host {machine!r}; "
            f"expected one of {sorted(tasks_by_host)}"
        )

    for i, wavelength in tasks_by_host[machine]:
        name = f"rHz_{tag}_{i}_fields_epsilon"
        configs = [
            (pol, device, eps_range, n_points, [wavelength], size, 30000 + 2000 * i)
        ]
        generate_mmi_random_spectra_data(configs, name=name)
        postprocess(name)


def launch_rHz_data_mmi2x2_generation():
    """Run this host's share of the random 2x2 MMI sweep (128 points/wavelength)."""
    _launch_mmi_random_spectra_sweep(mmi_2x2_L_random, 128, "mmi2x2")


def launch_rHz_data_mmi4x4_generation():
    """Run this host's share of the random 4x4 MMI sweep (256 points/wavelength)."""
    _launch_mmi_random_spectra_sweep(mmi_4x4_L_random, 256, "mmi4x4")


def launch_rHz_data_mmi5x5_generation():
    """Run this host's share of the random 5x5 MMI sweep (256 points/wavelength)."""
    _launch_mmi_random_spectra_sweep(mmi_5x5_L_random, 256, "mmi5x5")


def launch_rHz_data_mmi6x6_generation():
    """Run this host's share of the random 6x6 MMI sweep (170 points/wavelength)."""
    _launch_mmi_random_spectra_sweep(mmi_6x6_L_random, 170, "mmi6x6")


def launch_slot_rHz_data_generation():
    """Generate and postprocess random slotted 3x3 MMI ``Hz`` datasets, one per task.

    Each active task ``(i, wavelength)`` is passed to
    ``generate_slot_mmi_random_data`` under the name
    ``slot_rHz_{i}_fields_epsilon``. That name is then normalized by
    ``postprocess``. Work is split across machines by editing the active
    ``tasks`` list below.
    """
    pol = "Hz"
    device_list = [mmi_3x3_L_random_slots]
    points_per_port = [512]
    # Slot-fill permittivity; 1.44 looks like a refractive index (n**2 = eps)
    # — NOTE(review): confirm the intended material.
    eps_val = 1.44 ** 2
    size = (384, 80)

    # Task indices 0-4 used the grid np.arange(1.53, 1.571, 0.01), previously
    # split between eda05 and eda13. Indices 5-9 use the same grid offset by
    # 0.005; uncomment the entries assigned to the current machine.
    wavelengths = np.arange(1.535, 1.576, 0.01).tolist()
    tasks = [
        # (5, wavelengths[0]),
        # (6, wavelengths[1]),  # eda05
        # (7, wavelengths[2]),  # eda13
        # (8, wavelengths[3]),  # eda14
        (9, wavelengths[4]),  # eda15
    ]

    for i, wavelength in tasks:
        name = f"slot_rHz_{i}_fields_epsilon"
        # Task-index-dependent seed, as in launch_rHz_data_generation.
        seed = 10000 + i * 2000
        configs = [
            (pol, device, eps_val, n_points, [wavelength], size, seed)
            for device, n_points in zip(device_list, points_per_port)
        ]
        generate_slot_mmi_random_data(configs, name=name)
        postprocess(name)


if __name__ == "__main__":
    # Module-level polarization list; not referenced directly by the launcher
    # invoked below.
    pol_list = ["Hz"]

    # Exactly one dataset launcher is active. Alternatives, swap in as needed:
    #   launch_slot_rHz_data_generation()
    #   launch_rHz_data_spectra_generation()
    #   launch_rHz_data_mmi2x2_generation()
    #   launch_rHz_data_mmi4x4_generation()
    #   launch_rHz_data_mmi5x5_generation()
    #   launch_rHz_data_mmi6x6_generation()
    launch_rHz_data_generation()
