"""
Generate a dataset by solving randomly generated holey-layer structures
with the rigorous coupled-wave analysis (RCWA) method.

The settings for the structures are defined in the `configs/datasets/*.json` files.
"""

import os

# Limit the native numerical back-ends (OpenBLAS, MKL, numexpr, OpenMP) to a
# single thread each. These must be set before numpy/khepri are imported below,
# since such libraries typically read them at load time. Parallelism comes from
# the multiprocessing pool in generate_structure instead, presumably to avoid
# oversubscribing cores with nested threading.
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["KHEPRI_MT_ON"] = "0"  # NOTE(review): presumably disables khepri's own multithreading — confirm

import json
import time
import argparse
import multiprocessing as mp
import uuid
from itertools import product
import h5py
import matplotlib.pyplot as plt
import numpy as np
from khepri.crystal import Crystal
from khepri.tools import block2dense


def interpolate_kp(kp, Nk):
    """
    Interpolates the kp values to construct the band diagram.

    The path kp[0] -> kp[1] -> ... -> kp[-1] is sampled piecewise-linearly with
    exactly ``Nk`` points in total. Each segment includes its start point but not
    its end point (so shared corners are not duplicated), except the final
    segment, which also includes ``kp[-1]`` so the path terminates on the last
    corner. When ``Nk`` is not divisible by the number of segments, the leftover
    points are given to the first segments.

    e.g., kp = [[0,0], [1,0], [1,1]]
              --> ks = [[0,0],...,[0.5,0],...,[1,0],...,[1,0.5],...,[1,1]]

    Args:
        kp: Sequence of at least two 2D corner points ``(x, y)``.
        Nk: Total number of points to sample along the path.

    Returns:
        np.ndarray of shape ``(Nk, 2)``.

    Raises:
        ValueError: If fewer than two corner points are given.
    """
    n_segments = len(kp) - 1
    if n_segments < 1:
        raise ValueError("kp must contain at least two points to define a path")

    # Distribute Nk over the segments so the total is exactly Nk; callers
    # reshape results assuming Nk k-points.
    base, extra = divmod(Nk, n_segments)

    kps = []
    for i, (start, stop) in enumerate(zip(kp[:-1], kp[1:])):
        num = base + (1 if i < extra else 0)
        # Only the last segment includes its end point; the others leave it to
        # be the start point of the following segment.
        endpoint = i == n_segments - 1

        _x = np.linspace(start[0], stop[0], num, endpoint=endpoint)
        _y = np.linspace(start[1], stop[1], num, endpoint=endpoint)

        kps.append(np.column_stack([_x, _y]))

    return np.concatenate(kps, axis=0)


def solver(params):
    """
    Solve one RCWA task and return ``(r, t, det)``.

    ``params`` is a single ``(crystal, wavelength, k_point)`` tuple, which is the
    element type produced by ``itertools.product`` in ``generate_structure``.
    Taking one tuple makes the function usable directly with ``Pool.imap``.

    The crystal's source is set for the given wavelength and k-point, then the
    crystal is solved. The function returns the two values from
    ``poynting_flux_end()`` (presumably reflected / transmitted flux) and the
    determinant of the dense form of ``crystal.Stot``.
    """
    crystal, wavelength, k_point = params
    # NOTE(review): the positional 1, 0 are presumably source amplitudes /
    # polarization components — confirm against the khepri API.
    crystal.set_source(wavelength, 1, 0, kp=k_point)
    crystal.solve()
    reflected, transmitted = crystal.poynting_flux_end()
    s_det = np.linalg.det(block2dense(crystal.Stot))
    return reflected, transmitted, s_det


def _add_random_bilayers(cl, config, Nx, default_eps):
    """
    Append a random stack of (holey, uniform) bi-layers to the crystal ``cl``.

    For every bi-layer, a permittivity, hole radius, thickness and configuration
    ("A"-"D") are drawn from the ranges in ``config``. Random draws happen in the
    same order as the original inline code (layer count, then per layer: eps,
    hole radius, thickness, configuration), so seeded runs stay reproducible.

    Returns:
        dict with "pixmap_layers", "thicknesses" and "layer_names" lists, one
        entry per added layer, in stacking order.

    Raises:
        ValueError: If a drawn configuration is not one of "A", "B", "C", "D".
    """
    variables = config["variables"]
    material_eps = config["constants"]["material_eps"]

    structure = {
        "pixmap_layers": [],  # List of 2D arrays representing the layers
        "thicknesses": [],  # List of thicknesses for each layer
        "layer_names": [],  # List of layer names (include the configuration)
    }

    # The sampling grid is identical for every layer, so build it once: the
    # radial distance from the centre of an (Nx, Nx) grid spanning [-0.5, 0.5].
    x = np.linspace(-0.5, 0.5, Nx)
    X, Y = np.meshgrid(x, x)
    radius = np.sqrt(X**2 + Y**2)
    half = Nx // 2

    n_layers = np.random.randint(variables["n_layers"][0], variables["n_layers"][1] + 1)
    for i in range(n_layers):
        # Random settings
        eps = np.random.uniform(material_eps[0], material_eps[1])
        hole_radius = np.random.uniform(
            variables["hole_radius"][0], variables["hole_radius"][1]
        )
        thickness = np.random.uniform(
            variables["layer_thickness"][0], variables["layer_thickness"][1]
        )
        configuration = np.random.choice(variables["configurations"])

        # Holey layer: background permittivity `eps` with a circular region of
        # `default_eps` at the grid centre.
        layer = np.full((Nx, Nx), eps)
        layer[radius < hole_radius] = default_eps

        # B/C/D move the hole by half the grid along array axis 0, axis 1, or both.
        if configuration == "A":
            pass
        elif configuration == "B":
            layer = np.roll(layer, shift=half, axis=0)
        elif configuration == "C":
            layer = np.roll(layer, shift=half, axis=1)
        elif configuration == "D":
            layer = np.roll(layer, shift=(half, half), axis=(0, 1))
        else:
            raise ValueError("Invalid configuration")

        holey_name = f"layer_{i}:holey:{configuration}"
        uniform_name = f"layer_{i}:uniform"
        cl.add_layer_pixmap(holey_name, layer, thickness)
        cl.add_layer_uniform(uniform_name, eps, thickness)

        structure["pixmap_layers"] += [layer, np.full((Nx, Nx), eps)]
        structure["thicknesses"] += [thickness] * 2
        structure["layer_names"] += [holey_name, uniform_name]

    return structure


def _log_det_second_derivative(S):
    """
    Return the 2nd derivative along axis 0 (frequency) of log10|S|.

    Each column (k-point) of |S| is first normalized by its mean. NaNs are
    treated as 0. Columns whose mean is 0 are left unnormalized instead of being
    divided by zero.
    """
    Sabs = np.abs(S)
    Sabs[np.isnan(Sabs)] = 0
    col_mean = np.mean(Sabs, axis=0)
    Sabs /= np.where(col_mean == 0, 1.0, col_mean)
    # The small offset keeps log10 finite where |S| is exactly 0.
    grad = np.gradient(np.log10(Sabs + 1e-9), axis=0)
    return np.gradient(grad, axis=0)


def _plot_spectra(d2S, T, fmin, fmax, title):
    """Plot d2S (left) and transmission (right) against frequency; return the figure."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 3), sharey=True)
    # NOTE(review): the horizontal extent is fixed to (-pi, pi) and does not
    # reflect the actual coordinates of the sampled k-path.
    extent = (-np.pi, np.pi, fmin, fmax)

    ax1.imshow(d2S, origin="lower", aspect="auto", cmap="gray", extent=extent)
    im = ax2.matshow(
        T.real,
        origin="lower",
        aspect="auto",
        vmin=0,
        vmax=1,
        extent=extent,
    )
    fig.colorbar(im, ax=ax2)

    fig.suptitle(title)
    # fmin/fmax are fs * f0 with f0 = c / a * 1e-12, i.e. in THz (not c/a).
    ax1.set_ylabel("Frequency [THz]")
    return fig


def generate_structure(config, num_workers=4, timeit=False):
    """
    Generate one random holey-layer structure, solve it with RCWA and plot it.

    Args:
        config: Parsed dataset configuration. Reads
            numerical_parameters.{Nx, Nk, Nf, pw}, physical_parameters.{kp, f_range},
            constants.{default_eps, material_eps} and
            variables.{n_layers, hole_radius, layer_thickness, configurations}.
        num_workers: Number of worker processes solving the (frequency, k) grid.
        timeit: If True, also return the wall-clock time spent solving.

    Returns:
        ``(structure, fig)``, or ``(structure, fig, tot_time)`` when ``timeit``.

        ``structure`` holds the layer description ("pixmap_layers",
        "thicknesses", "layer_names"), the results "R", "T" and "d2S" (arrays
        indexed [frequency, k-point]), and the sampled "fs" and "kps".

        ``fig`` is the matplotlib figure. It is not closed here.
    """
    # Fixed parameters
    numerical = config["numerical_parameters"]
    Nx = numerical["Nx"]
    Nk = numerical["Nk"]
    Nf = numerical["Nf"]
    pw = numerical["pw"]
    kp = config["physical_parameters"]["kp"]
    f_min, f_max = config["physical_parameters"]["f_range"]
    default_eps = config["constants"]["default_eps"]

    # f0 converts the normalized frequency (units of c/a) to THz.
    a = 1.22e-6  # presumably the lattice constant in metres
    c = 299792458  # speed of light [m/s]
    f0 = c / a * 1e-12

    # Build the random stack and register it as the device.
    cl = Crystal(pw)
    structure = _add_random_bilayers(cl, config, Nx, default_eps)
    layer_names = list(structure["layer_names"])
    cl.set_device(layer_names)

    # Solving: frequency is the outer loop and k-point the inner one, which the
    # reshape below relies on.
    start = time.perf_counter()
    fs = np.linspace(f_min, f_max, Nf, endpoint=True)
    # NOTE(review): the 0.3 * pi / 2 scaling of the k-path is not explained here.
    kps = interpolate_kp(kp, Nk) * 0.3 * np.pi / 2.0
    params = list(product([cl], 1.0 / fs, kps))
    with mp.Pool(num_workers) as pool:
        results = list(pool.imap(solver, params))
    tot_time = time.perf_counter() - start

    # Use the real k-path length rather than Nk, in case they differ.
    RT = np.array(results).reshape(Nf, len(kps), 3)

    d2S = _log_det_second_derivative(RT[..., 2])

    structure["d2S"] = d2S
    structure["R"] = RT[..., 0]
    structure["T"] = RT[..., 1]
    structure["fs"] = fs
    structure["kps"] = kps

    # Title: the concatenated configuration letters of the holey layers.
    title = "".join(
        layer_name.split(":holey:")[-1]
        for layer_name in layer_names
        if "holey" in layer_name
    )
    fig = _plot_spectra(d2S, RT[..., 1], np.min(fs) * f0, np.max(fs) * f0, title)

    if timeit:
        return structure, fig, tot_time
    return structure, fig


def _save_structure_h5(path, structure):
    """Write the arrays and layer names of ``structure`` to a new HDF5 file."""
    with h5py.File(path, "w") as hf:
        hf.create_dataset("pixmap_layers", data=structure["pixmap_layers"])
        hf.create_dataset("thicknesses", data=structure["thicknesses"])
        # A plain list of str becomes a NumPy '<U' array, which h5py cannot
        # store. Store the names as variable-length UTF-8 strings instead.
        hf.create_dataset(
            "layer_names",
            data=np.array(
                structure["layer_names"], dtype=h5py.string_dtype(encoding="utf-8")
            ),
        )
        hf.create_dataset("fs", data=structure["fs"])
        hf.create_dataset("kps", data=structure["kps"])
        hf.create_dataset("R", data=structure["R"])
        hf.create_dataset("T", data=structure["T"])
        hf.create_dataset("d2S", data=structure["d2S"])


def main():
    """Command-line entry point: generate structures and save each as HDF5 + PNG."""
    parser = argparse.ArgumentParser(
        description="Generate dataset by solving the holey layer structure using RCWA method."
    )
    parser.add_argument(
        "--config_file",
        type=str,
        default="configs/datasets/holey_layer.json",
        help="Path to the configuration file.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of workers for parallel processing.",
    )
    parser.add_argument(
        "--n_structures",
        type=int,
        default=1,
        help="Number of structures to generate.",
    )
    parser.add_argument(
        "--output_folder",
        type=str,
        default="./structures/",
        help="Folder to save the generated dataset.",
    )
    parser.add_argument(
        "--timeit",
        action="store_true",
        help="Measure the time taken to generate each structure.",
    )
    args = parser.parse_args()

    with open(args.config_file, "r") as f:
        config = json.load(f)

    # Create the output folder if needed. Keep a copy of the config next to the
    # data whenever one is missing, but never overwrite an existing copy.
    os.makedirs(args.output_folder, exist_ok=True)
    config_copy = os.path.join(args.output_folder, "config.json")
    if not os.path.exists(config_copy):
        with open(config_copy, "w") as f:
            json.dump(config, f, indent=4)

    for _ in range(args.n_structures):
        if args.timeit:
            structure, fig, tot_time = generate_structure(
                config, args.num_workers, timeit=True
            )
            print(f"{tot_time:.5f}")
        else:
            structure, fig = generate_structure(config, args.num_workers, timeit=False)

        # Save the structure in HDF5 format, plus its figure, under a random id.
        structure_id = uuid.uuid4()
        _save_structure_h5(
            os.path.join(args.output_folder, f"{structure_id}.h5"), structure
        )

        fig_path = os.path.join(args.output_folder, f"{structure_id}.png")
        fig.savefig(fig_path, dpi=100)
        plt.close(fig)


if __name__ == "__main__":
    main()
