import math

import numpy as np
import requests
from flax import struct
from omegaconf import DictConfig, OmegaConf


def non_pytree(*args, **kwargs):
    """Declare a ``flax.struct`` dataclass field that is not a pytree node.

    Thin wrapper over ``struct.field`` that always sets ``pytree_node=False``.
    Every other positional and keyword argument is forwarded unchanged.
    Passing ``pytree_node`` explicitly raises ``TypeError`` (duplicate keyword).
    """
    return struct.field(*args, **kwargs, pytree_node=False)


def run_opt(init_params, fun, opt, max_iter, tol):
    """Minimise ``fun`` with an optax optimiser inside ``jax.lax.while_loop``.

    Each iteration evaluates value and gradient via
    ``optax.value_and_grad_from_state`` (so values cached in the optimiser
    state can be reused), then calls ``opt.update`` with the ``value``,
    ``grad`` and ``value_fn`` extra arguments and applies the updates.

    Args:
        init_params: Initial parameter pytree.
        fun: Scalar objective, called as ``fun(params)``.
        opt: Optax transformation whose ``update`` accepts the ``value``,
            ``grad`` and ``value_fn`` keywords, and whose state exposes
            ``count`` and ``grad`` entries (e.g. an L-BFGS-style solver —
            assumption, confirm for other optimisers).
        max_iter: Maximum number of update steps.
        tol: Iteration stops once the norm of the gradient stored in the
            optimiser state falls below this value.

    Returns:
        Tuple ``(final_params, final_state)``.
    """
    import jax
    import optax

    value_and_grad_fn = optax.value_and_grad_from_state(fun)

    def body_fn(loop_state):
        params, opt_state = loop_state
        value, grad = value_and_grad_fn(params, state=opt_state)
        updates, opt_state = opt.update(
            grad, opt_state, params, value=value, grad=grad, value_fn=fun
        )
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state

    def cond_fn(loop_state):
        opt_state = loop_state[1]
        count = optax.tree.get(opt_state, "count")
        stored_grad = optax.tree.get(opt_state, "grad")
        grad_norm = optax.tree_utils.tree_norm(stored_grad)
        not_converged = (count < max_iter) & (grad_norm >= tol)
        # Always take the first step; presumably the gradient held by a freshly
        # initialised state is only a placeholder (verify for the chosen opt).
        return (count == 0) | not_converged

    start = (init_params, opt.init(init_params))
    final_params, final_state = jax.lax.while_loop(cond_fn, body_fn, start)
    return final_params, final_state


def is_power_of_two(n):
    """Return ``True`` if ``n`` is a positive power of two (1, 2, 4, 8, ...).

    A power of two has exactly one set bit, so clearing its lowest set bit
    with ``n & (n - 1)`` leaves zero. Non-positive values are rejected first.
    """
    if not n > 0:
        return False
    return bool((n & (n - 1)) == 0)


def get_closest_square(full_size):
    """Split ``full_size`` into its most nearly square pair of integer factors.

    Returns ``(dim1, dim2)`` with ``dim1 * dim2 == full_size`` and
    ``dim1 <= dim2``, where ``dim1`` is the largest divisor of ``full_size``
    that does not exceed ``floor(sqrt(full_size))``. Primes yield
    ``(1, full_size)``.

    Every divisor above the square root pairs with one below it, and
    ``d + full_size / d >= 2 * sqrt(full_size)`` means the lower member is
    never farther from the square root than the upper one. Scanning downward
    from the integer square root is therefore sufficient. It runs in
    O(sqrt(full_size)) iterations without recursion; the previous recursive
    search could overflow the stack.

    Args:
        full_size: Positive integer to factor. Integral floats such as
            ``16.0`` are accepted; ``dim2`` then keeps that type.

    Returns:
        Tuple ``(dim1, dim2)``; ``dim1`` is an ``int`` and ``dim2`` is
        ``full_size // dim1``.

    Raises:
        ValueError: If ``full_size`` is less than 1 or not integral.
    """
    if full_size < 1:
        raise ValueError(f"full_size must be a positive integer, got {full_size!r}")

    # Exact integer square root (np.sqrt can round up for very large ints).
    target = math.isqrt(int(full_size))
    for dim1 in range(target, 0, -1):
        if full_size % dim1 == 0:
            return (dim1, full_size // dim1)

    # Only reachable when full_size is non-integral (even 1 fails to divide it).
    raise ValueError(f"full_size must be an integer, got {full_size!r}")


def mplfig_to_npimage(fig):
    """Render a matplotlib figure and return it as an RGB image array.

    The figure is drawn onto a new Agg canvas. The canvas's RGBA buffer is
    read back and the alpha channel is dropped.

    Args:
        fig: A matplotlib ``Figure``.

    Returns:
        A C-contiguous array of shape ``(rows, cols, 3)`` that owns its data.
        The dtype is that of the Agg buffer (uint8 per the matplotlib docs).
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg

    canvas = FigureCanvasAgg(fig)
    canvas.draw()  # rasterise the figure into the canvas buffer

    rgba = np.asarray(canvas.buffer_rgba())
    # Copy instead of returning the strided slice: the slice would alias the
    # canvas's render buffer, which may be overwritten if the canvas is drawn
    # again, and non-contiguous frames trip up some image/video consumers.
    return np.ascontiguousarray(rgba[:, :, :3])


def submit_to_api(cfg: DictConfig, timeout=30.0):
    """Post a resolved copy of ``cfg`` to the job server's ``/submit_job`` endpoint.

    The request is authenticated with the ``X-Submit-Token`` header. The
    outcome (status code and JSON or raw body) is only printed; HTTP error
    statuses are not raised.

    Args:
        cfg: Job configuration. Mutated in place: ``OmegaConf.resolve``
            replaces its interpolations with their values before it is
            converted to a plain container and sent as JSON.
        timeout: Seconds ``requests.post`` waits for the server. ``None``
            waits indefinitely, which was the previous behaviour.

    Raises:
        requests.exceptions.Timeout: If the server does not respond within
            ``timeout``.
    """
    from jadex.global_configs.api_token import SERVER_URL, SUBMIT_TOKEN

    OmegaConf.resolve(cfg)
    cfg_dict = OmegaConf.to_container(cfg)

    # Without a timeout an unresponsive server would block job launch forever.
    response = requests.post(
        f"{SERVER_URL}/submit_job",
        headers={"X-Submit-Token": SUBMIT_TOKEN},
        json=cfg_dict,
        timeout=timeout,
    )
    print("Status code:", response.status_code)

    try:
        print("Response JSON:", response.json())
    except ValueError:
        # Body is not valid JSON; show it verbatim instead.
        print("Response text:", response.text)


# Function to update the progress on the server
def update_ablation_progress(uuid, progress, timeout=30.0):
    """Update the progress of the job on the server.

    Posts ``{"uuid": uuid, "progress": progress}`` as JSON to the server's
    ``/update_progress`` endpoint, authenticated with the ``X-Submit-Token``
    header. The outcome is only printed: a 200 response is reported as a
    success, and anything else prints the response body.

    Args:
        uuid: Identifier of the job/ablation whose progress is updated.
        progress: Value sent as-is in the JSON body (``submit_job`` passes
            the string ``"SUBMITTED"``).
        timeout: Seconds ``requests.post`` waits for the server. ``None``
            waits indefinitely, which was the previous behaviour.

    Raises:
        requests.exceptions.Timeout: If the server does not respond within
            ``timeout``.
    """
    # Imported after the docstring so the docstring is actually the
    # function's ``__doc__`` (a string after a statement is a no-op).
    from jadex.global_configs.api_token import SERVER_URL, SUBMIT_TOKEN

    data = {"uuid": uuid, "progress": progress}
    response = requests.post(
        f"{SERVER_URL}/update_progress",
        headers={"X-Submit-Token": SUBMIT_TOKEN},
        json=data,
        timeout=timeout,
    )
    if response.status_code == 200:
        print(f"Progress for {uuid} updated to {progress}.")
    else:
        print(f"Failed to update progress for {uuid}. Error: {response.text}")


def submit_job(fn, cfg: DictConfig):
    """Run ``fn(cfg)`` as a SLURM job when possible, otherwise in this process.

    Before launching, the config is optionally posted to the job server
    (``cfg.job.submit_to_api``), and the ablation is optionally marked
    ``"SUBMITTED"`` (``cfg.job.update_ablation_progress``).

    If ``cfg.cluster`` is not None and a ``submitit.SlurmExecutor`` can be
    created, ``fn`` is submitted with ``cfg.cluster`` as executor parameters.
    The bookkeeping keys ``submitted``, ``name`` and ``id`` are removed first,
    and logs go under ``slurm_out/%j``. The job id is returned.

    Otherwise ``fn(cfg)`` is called directly and ``None`` is returned.
    """
    cluster_cfg = cfg.cluster

    if cluster_cfg is not None:
        # Imported only when a cluster is requested, so local runs do not need
        # submitit. It is imported before contacting the server so a missing
        # package fails before the job is reported as submitted.
        import submitit

    if cfg.job.get("submit_to_api", False):
        submit_to_api(cfg)

    if cfg.job.get("update_ablation_progress", False):
        update_ablation_progress(cfg.job.ablation_uuid, "SUBMITTED")

    if cluster_cfg is not None:
        try:
            executor = submitit.SlurmExecutor(folder="slurm_out/%j")
        except Exception as e:
            # RuntimeError is treated as the expected "no SLURM here" signal
            # and stays silent; anything else is reported. Either way we fall
            # through to running locally.
            if not isinstance(e, RuntimeError):
                print(e)
        else:
            cluster_cfg_dict = dict(cluster_cfg)
            # Bookkeeping entries that are not SlurmExecutor parameters.
            for key in ("submitted", "name", "id"):
                cluster_cfg_dict.pop(key, None)
            executor.update_parameters(**cluster_cfg_dict)
            job = executor.submit(fn, cfg)
            print(f"Submitted job {job.job_id}")
            return job.job_id

    fn(cfg)
