#!/usr/bin/python3
"""
Miscellaneous utility functions and transforms.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import json
import numpy as np
import re
from pydantic import BaseModel
from scipy.stats.qmc import Sobol
from typing import Any, Dict, List, Optional, Tuple

from .envs.base import BaseTask


# Public names exported by a star import of this module; `_erfinv` is
# intentionally left out as a private helper.
__all__ = ["Rd2Pydantic", "Pydantic2Rd", "initialize_designs", "json_loads"]


def _erfinv(x: np.ndarray) -> np.ndarray:
    """
    Approximate the inverse error function elementwise using Winitzki's
    closed-form expression (with the constant a = 0.147).
    Input:
        x: an array of shape (...,).
    Returns:
        An array of shape (...,) with the approximation evaluated at each
        element of x.
    """
    winitzki_a = 0.147
    log_term = np.log1p(-1.0 * (x * x))  # log(1 - x^2)
    half_sum = (2.0 / (np.pi * winitzki_a)) + (log_term / 2.0)
    inner_root = np.sqrt((half_sum * half_sum) - (log_term / winitzki_a))
    magnitude = np.sqrt(inner_root - half_sum)
    # The closed form is even in x; the sign restores the odd symmetry.
    return np.sign(x) * magnitude


def initialize_designs(
    task: BaseTask, batch_size: int, seed: int = 2025
) -> Tuple[List[BaseModel], Dict[str, int]]:
    """
    Initialize designs for a task using Sobol sequences.

    Scrambled Sobol points in the unit hypercube are pushed through an
    approximate inverse standard-normal CDF, then converted to design models
    via `Rd2Pydantic`.
    Input:
        task: the optimization task.
        batch_size: the number of designs to initialize (must be >= 1).
        seed: random seed.
    Returns:
        A tuple of initial designs and metadata (currently an empty dict).
    Raises:
        ValueError: if batch_size is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size}."
        )
    _sobol = Sobol(task.ndim, scramble=True, rng=seed)
    # `random_base2(m)` draws 2**m points, so round the exponent up and then
    # truncate. `(n - 1).bit_length()` is an exact integer ceil(log2(n)).
    exponent = (int(batch_size) - 1).bit_length()
    sobol_samples = _sobol.random_base2(exponent)
    new_xq = sobol_samples[:batch_size]
    # Keep samples strictly inside (0, 1) so the inverse CDF stays finite.
    _eps = np.finfo(new_xq.dtype).eps
    new_xq = new_xq.clip(min=_eps, max=(1.0 - _eps))
    # Standard-normal quantile: sqrt(2) * erfinv(2u - 1).
    new_xq = _erfinv((2.0 * new_xq) - 1.0) * np.sqrt(2.0)
    # Only collapse a singleton feature axis. An unrestricted squeeze would
    # also drop the batch axis when batch_size == 1 and break the per-design
    # iteration performed downstream.
    if new_xq.ndim == 2 and new_xq.shape[1] == 1:
        new_xq = new_xq[:, 0]
    designs = Rd2Pydantic(task, new_xq, continuous=True)
    if task.task_name == "IWPCWarfarin-v0":
        # Undo the standardization with the task's statistics and clamp the
        # resulting dose at zero.
        mu, sigma = task.mu["warfarin_dose"], task.std["warfarin_dose"]
        designs = [
            task.design_schema(
                warfarin_dose=max(
                    0, (getattr(x, "warfarin_dose", 0.0) * sigma) + mu
                )
            )
            for x in designs
        ]
    return designs, {}


def Rd2Pydantic(
    task: BaseTask, x: np.ndarray, continuous: bool
) -> List[BaseModel]:
    """
    Build one design-schema instance per leading-axis entry of a numeric
    design array.
    Input:
        task: the optimization task.
        x: a numpy array of shape (N, D) or (N, D, D_model). For the
            IWPCWarfarin task each entry of x is passed directly as the dose.
        continuous: whether the input is from a continuous design space.
    Returns:
        A list of Pydantic models.
    """
    name = task.task_name

    if name == "IWPCWarfarin-v0":
        assert continuous
        return [task.design_schema(warfarin_dose=dose) for dose in x]

    if name == "HIVDB-v0":
        drug_names = task.train.drugs
        assert x.shape[1] == len(drug_names)
        models = []
        for row in x:
            # A drug is switched on exactly when its coordinate is positive.
            flags = (row > 0.0).tolist()  # type: ignore
            payload = dict(zip(drug_names, flags))
            models.append(task.design_schema.model_validate(payload))
        return models

    if name == "ToyTask-v0":
        assert continuous
        return [
            task.design_schema(dim_1=point[0], dim_2=point[1]) for point in x
        ]

    raise NotImplementedError("Implement your own task's method here!")


def Pydantic2Rd(
    task: BaseTask, x: List[BaseModel], continuous: bool, seed: int = 2025
) -> np.ndarray:
    """
    Convert a list of Pydantic models into a numpy array.
    Input:
        task: the optimization task.
        x: a list of Pydantic models; each must be an instance of
            `task.design_schema`.
        continuous: whether to map to a continuous design space.
        seed: random seed (only used by the HIVDB encoding).
    Returns:
        A numpy array of shape (N, D) or (N, D, D_model). For the
        IWPCWarfarin task the array holds one dose per design.
    """
    assert all(isinstance(design, task.design_schema) for design in x)
    if task.task_name == "IWPCWarfarin-v0":
        assert continuous
        return np.array([design.warfarin_dose for design in x])  # type: ignore
    elif task.task_name == "HIVDB-v0":
        drugs = list(task.train.drugs)
        rng = np.random.default_rng(seed)
        # Random magnitudes; the sign of each entry encodes the boolean flag.
        magnitude = np.abs(rng.normal(size=(len(x), len(drugs))))
        # Serialize each design exactly once instead of once per drug key.
        dumps = [design.model_dump() for design in x]
        selected = np.array(
            [[bool(dump[key]) for key in drugs] for dump in dumps],
            dtype=bool,
        ).reshape(magnitude.shape)  # keeps (0, D) when x is empty
        vec = np.where(selected, magnitude, -1.0 * magnitude)
        if not continuous:
            # Discrete encoding: 1 where the flag is set, 0 otherwise.
            vec = np.where(vec > 0.0, 1, 0).astype(int)
        return vec
    elif task.task_name == "ToyTask-v0":
        assert continuous
        return np.array([[design.dim_1, design.dim_2] for design in x])
    raise NotImplementedError("Implement your own task's method here!")


def json_loads(s: str) -> Dict[str, Any]:
    """
    Extract the most likely valid JSON value from a given string.

    The whole string is tried first. If that fails, a JSON object or array
    is decoded starting at every `{` / `[` position and the longest one is
    returned (ties go to the earliest start). Length is measured on the JSON
    value itself, excluding any surrounding text or whitespace.
    Input:
        s: input string potentially containing JSON-like content.
    Returns:
        The loaded longest valid JSON substring, or an empty dict if s is
        not a string or is empty.
    Raises:
        json.JSONDecodeError: if neither s nor any bracket-anchored
            substring of s is valid JSON.
    """
    if not isinstance(s, str) or not s:
        return {}
    try:
        return json.loads(s)
    except json.JSONDecodeError as err:
        # Kept so the original failure can be reported if nothing is found.
        whole_string_error = err

    # `raw_decode` parses exactly one JSON value starting at a given index
    # and reports where it ends. Unlike hand-rolled bracket counting, it
    # handles brackets that appear inside JSON string literals.
    decoder = json.JSONDecoder()
    best: Optional[Any] = None
    best_len = 0

    for match in re.finditer(r"[\[{]", s):
        start = match.start()
        try:
            value, end = decoder.raw_decode(s, start)
        except json.JSONDecodeError:
            continue
        length = end - start
        if length > best_len:
            best = value
            best_len = length

    # Any decoded object/array spans at least two characters, so a zero
    # length means no candidate was found.
    if best_len == 0:
        raise whole_string_error
    return best
